mlir: rework stablehlo custom call implementation and add a Triton example
This commit is contained in:
parent
aec1d96e6d
commit
42dee5d0e0
@ -31,6 +31,7 @@ zig_library(
|
||||
deps = [
|
||||
"//mlir",
|
||||
"//mlir:c",
|
||||
"//stdx",
|
||||
"@stablehlo//:stablehlo_dialect_capi",
|
||||
],
|
||||
)
|
||||
|
||||
@ -2,6 +2,7 @@ const std = @import("std");
|
||||
|
||||
const c = @import("c");
|
||||
const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
pub const abs = functors.unary_fn("stablehlo.abs").call;
|
||||
pub const cosine = functors.unary_fn("stablehlo.cosine").call;
|
||||
@ -733,53 +734,113 @@ pub fn convolution(
|
||||
}
|
||||
|
||||
pub const CustomCallOpts = struct {
|
||||
pub const ApiVersion = enum(i32) {
|
||||
original = 1,
|
||||
status_returning = 2,
|
||||
status_returning_unified = 3,
|
||||
typed_ffi = 4,
|
||||
};
|
||||
|
||||
pub const BackendConfig = union(enum) {
|
||||
string: [:0]const u8,
|
||||
dict: mlir.DictionaryAttribute,
|
||||
};
|
||||
|
||||
call_target_name: [:0]const u8,
|
||||
has_side_effect: bool,
|
||||
backend_config: [:0]const u8 = &.{},
|
||||
api_version: i32,
|
||||
output_operand_aliases: []const i64,
|
||||
backend_config: BackendConfig = .{ .string = &.{} },
|
||||
operand_layouts: []const []const usize = &.{},
|
||||
result_layouts: []const []const usize = &.{},
|
||||
output_operand_aliases: []const i64 = &.{},
|
||||
addional_attributes: []const mlir.AttrTuple = &.{},
|
||||
api_version: ApiVersion,
|
||||
};
|
||||
|
||||
pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
var buffer: [1024]u8 = undefined;
|
||||
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
||||
const allocator = fba.allocator();
|
||||
const MAX_OPERANDS = 64;
|
||||
const MAX_RESULTS = 16;
|
||||
|
||||
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable;
|
||||
for (opts.output_operand_aliases, 0..) |alias, i| {
|
||||
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute);
|
||||
}
|
||||
const operand_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||
for (opts.operand_layouts) |ol| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(ol.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const result_layouts = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.result_layouts) |rl| {
|
||||
const tensor_type = mlir.RankedTensorType.init(
|
||||
&.{@intCast(rl.len)},
|
||||
mlir.IndexType.init(ctx).as(mlir.Type),
|
||||
).as(mlir.Type);
|
||||
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
|
||||
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const output_operand_aliases = blk: {
|
||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||
for (opts.output_operand_aliases) |alias| {
|
||||
ret.appendAssumeCapacity(
|
||||
OutputOperandAliasAttribute.init(
|
||||
ctx,
|
||||
&.{},
|
||||
alias,
|
||||
&.{},
|
||||
).as(mlir.Attribute),
|
||||
);
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const backend_config = switch (opts.backend_config) {
|
||||
.string => blk: {
|
||||
stdx.debug.assert(
|
||||
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
|
||||
"Only API version of less than 4 is supported for backend_config as string",
|
||||
.{},
|
||||
);
|
||||
break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute);
|
||||
},
|
||||
.dict => blk: {
|
||||
stdx.debug.assert(
|
||||
opts.api_version == .typed_ffi,
|
||||
"Only API version 4 is supported for backend_config as dictionary",
|
||||
.{},
|
||||
);
|
||||
break :blk opts.backend_config.dict.as(mlir.Attribute);
|
||||
},
|
||||
};
|
||||
|
||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||
.{ "backend_config", backend_config },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
|
||||
.{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
|
||||
.{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
|
||||
});
|
||||
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
||||
|
||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||
.operands = inputs,
|
||||
.results = res_types,
|
||||
.attributes = &.{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
|
||||
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||
.operands = inputs,
|
||||
.results = res_types,
|
||||
.attributes = &.{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() },
|
||||
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "mhlo.sharding", sharding_spec.asAttr() },
|
||||
},
|
||||
.attributes = attrs.constSlice(),
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
// todo: move out of stablehlo.zig when we start to implement the frontend
|
||||
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||
const frontend_attributes = mlir.DictionaryAttribute.init(
|
||||
ctx,
|
||||
@ -787,19 +848,14 @@ pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value,
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
|
||||
},
|
||||
).asAttr();
|
||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||
.operands = inputs,
|
||||
.results = res_types,
|
||||
.attributes = &.{
|
||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
|
||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() },
|
||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() },
|
||||
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
|
||||
.{ "mhlo.frontend_attributes", frontend_attributes },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
|
||||
return custom_call(ctx, inputs, .{
|
||||
.call_target_name = "annotate_device_placement",
|
||||
.has_side_effect = true,
|
||||
.backend_config = .{ .string = &.{} },
|
||||
.addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
|
||||
.api_version = .original,
|
||||
}, res_types, location);
|
||||
}
|
||||
|
||||
pub const DotDimensionNumbersAttribute = struct {
|
||||
|
||||
@ -125,10 +125,9 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
&.{ q.value(), k.value(), v.value(), bias.value() },
|
||||
.{
|
||||
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
||||
.backend_config = backend_config,
|
||||
.api_version = 2,
|
||||
.backend_config = .{ .string = backend_config },
|
||||
.has_side_effect = false,
|
||||
.output_operand_aliases = &.{},
|
||||
.api_version = .original,
|
||||
},
|
||||
&.{
|
||||
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||
|
||||
161
zml/ops.zig
161
zml/ops.zig
@ -773,34 +773,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Produces a custom call to `name` that takes a tensor and returns it.
|
||||
///
|
||||
/// For example, this can be used to extract tokens quickly if they run on a loop on the
|
||||
/// GPU.
|
||||
pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque) Tensor {
|
||||
const address: [8]u8 = @bitCast(@intFromPtr(context));
|
||||
var backend_config: [8:0]u8 = undefined;
|
||||
@memcpy(backend_config[0..8], address[0..8]);
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
|
||||
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
&.{input.value()},
|
||||
.{
|
||||
.api_version = 1,
|
||||
.has_side_effect = false,
|
||||
.call_target_name = name,
|
||||
.backend_config = backend_config[0..],
|
||||
.output_operand_aliases = &.{0},
|
||||
},
|
||||
&.{input.value().getType()},
|
||||
loc,
|
||||
);
|
||||
return Tensor._result(input.shape(), op.result(0));
|
||||
}
|
||||
|
||||
/// At runtime the given tensor will be materialized and copied to host,
|
||||
/// and the callback will be called on it.
|
||||
pub fn addHostCallback(
|
||||
@ -835,11 +807,11 @@ pub fn addHostCallback(
|
||||
ctx.mlirCtx(),
|
||||
&.{input.value()},
|
||||
.{
|
||||
.api_version = 1,
|
||||
.has_side_effect = false,
|
||||
.call_target_name = "zmlHostBufferCallback",
|
||||
.backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
|
||||
.backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) },
|
||||
.output_operand_aliases = &.{0},
|
||||
.api_version = .original,
|
||||
},
|
||||
&.{input.value().getType()},
|
||||
loc,
|
||||
@ -847,6 +819,135 @@ pub fn addHostCallback(
|
||||
return Tensor._result(input.shape(), op.result(0));
|
||||
}
|
||||
|
||||
pub const TritonOps = struct {
|
||||
debug: bool = false,
|
||||
name: [:0]const u8,
|
||||
ir: [:0]const u8,
|
||||
grid: [3]i32,
|
||||
num_stages: i32,
|
||||
num_warps: i32,
|
||||
};
|
||||
|
||||
/// Generate an MLIR call to the given member function with the given tensors.
|
||||
pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]Tensor {
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
var values: [inputs.len]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &values);
|
||||
|
||||
var res_types: [outputs.len]mlir.Type = undefined;
|
||||
inline for (outputs, 0..) |output, i| {
|
||||
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
||||
}
|
||||
|
||||
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.ir).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)),
|
||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)),
|
||||
});
|
||||
|
||||
const MINOR_TO_MAJOR = blk: {
|
||||
var ret: [Shape.MAX_RANK]usize = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const operands_layouts = blk: {
|
||||
var ret: [inputs.len][]const usize = undefined;
|
||||
inline for (inputs, 0..) |input, i| {
|
||||
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..];
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const results_layouts = blk: {
|
||||
var ret: [outputs.len][]const usize = undefined;
|
||||
inline for (outputs, 0..) |output, i| {
|
||||
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..];
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
&values,
|
||||
.{
|
||||
.call_target_name = "__gpu$xla.gpu.triton",
|
||||
.backend_config = .{ .dict = attrs },
|
||||
.has_side_effect = false,
|
||||
.api_version = .typed_ffi,
|
||||
.operand_layouts = &operands_layouts,
|
||||
.result_layouts = &results_layouts,
|
||||
},
|
||||
&res_types,
|
||||
ctx.mlirCtx().location(@src()),
|
||||
);
|
||||
|
||||
var outputs_: [outputs.len]Tensor = undefined;
|
||||
inline for (outputs, 0..) |output, i| {
|
||||
outputs_[i] = Tensor._result(output, op.result(i));
|
||||
}
|
||||
|
||||
return outputs_;
|
||||
}
|
||||
|
||||
test "triton" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
if (platform.target != .cuda and platform.target != .rocm) return error.SkipZigTest;
|
||||
|
||||
const ir =
|
||||
\\ module {
|
||||
\\ tt.func public @add_one(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}) {
|
||||
\\ %0 = tt.get_program_id x : i32
|
||||
\\ %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
|
||||
\\ %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
|
||||
\\ %cst = arith.constant 1.000000e+00 : f32
|
||||
\\ %3 = arith.addf %1, %cst : f32
|
||||
\\ tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
|
||||
\\ tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
|
||||
\\ tt.return
|
||||
\\ }
|
||||
\\ }
|
||||
;
|
||||
|
||||
const TritonMod = struct {
|
||||
pub fn forward(a: Tensor, b: Tensor) [2]Tensor {
|
||||
return triton(.{ a, b }, .{ a.shape(), b.shape() }, .{
|
||||
.debug = false,
|
||||
.name = "add_one",
|
||||
.ir = ir,
|
||||
.grid = .{ 1, 1, 1 },
|
||||
.num_stages = 1,
|
||||
.num_warps = 1,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const a = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{1});
|
||||
const b = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{3});
|
||||
|
||||
const results = try zml.testing.compileAndCall(platform, TritonMod.forward, .{ a, b });
|
||||
|
||||
var cpu_result_0 = try results[0].toHostAlloc(std.testing.allocator);
|
||||
defer cpu_result_0.deinit(std.testing.allocator);
|
||||
var cpu_result_1 = try results[1].toHostAlloc(std.testing.allocator);
|
||||
defer cpu_result_1.deinit(std.testing.allocator);
|
||||
|
||||
const expected_result_a: f32 = 2.0;
|
||||
const expected_result_b: f32 = 3.0;
|
||||
|
||||
try std.testing.expectEqual(expected_result_a, cpu_result_0.items(f32)[0]);
|
||||
try std.testing.expectEqual(expected_result_b, cpu_result_1.items(f32)[0]);
|
||||
}
|
||||
|
||||
/// Generalized version of scatter to many inputs.
|
||||
/// See `zml.Tensor.scatterSlices` for documentation on scatter.
|
||||
///
|
||||
|
||||
@ -175,10 +175,15 @@ pub const Tensor = struct {
|
||||
|
||||
const sharding = ctx.getShardingAttr(res._shape);
|
||||
|
||||
const op = dialect.stablehlo.sharding(
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
&.{self.value()},
|
||||
sharding,
|
||||
.{
|
||||
.call_target_name = "Sharding",
|
||||
.has_side_effect = false,
|
||||
.addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
|
||||
.api_version = .original,
|
||||
},
|
||||
&.{self.value().getType()},
|
||||
ctx.mlirCtx().location(@src()),
|
||||
);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user