mlir: rework stablehlo custom call implementation and add a Triton example

This commit is contained in:
Tarry Singh 2024-07-16 13:23:07 +00:00
parent aec1d96e6d
commit 42dee5d0e0
5 changed files with 243 additions and 81 deletions

View File

@ -31,6 +31,7 @@ zig_library(
deps = [ deps = [
"//mlir", "//mlir",
"//mlir:c", "//mlir:c",
"//stdx",
"@stablehlo//:stablehlo_dialect_capi", "@stablehlo//:stablehlo_dialect_capi",
], ],
) )

View File

@ -2,6 +2,7 @@ const std = @import("std");
const c = @import("c"); const c = @import("c");
const mlir = @import("mlir"); const mlir = @import("mlir");
const stdx = @import("stdx");
pub const abs = functors.unary_fn("stablehlo.abs").call; pub const abs = functors.unary_fn("stablehlo.abs").call;
pub const cosine = functors.unary_fn("stablehlo.cosine").call; pub const cosine = functors.unary_fn("stablehlo.cosine").call;
@ -733,53 +734,113 @@ pub fn convolution(
} }
pub const CustomCallOpts = struct { pub const CustomCallOpts = struct {
pub const ApiVersion = enum(i32) {
original = 1,
status_returning = 2,
status_returning_unified = 3,
typed_ffi = 4,
};
pub const BackendConfig = union(enum) {
string: [:0]const u8,
dict: mlir.DictionaryAttribute,
};
call_target_name: [:0]const u8, call_target_name: [:0]const u8,
has_side_effect: bool, has_side_effect: bool,
backend_config: [:0]const u8 = &.{}, backend_config: BackendConfig = .{ .string = &.{} },
api_version: i32, operand_layouts: []const []const usize = &.{},
output_operand_aliases: []const i64, result_layouts: []const []const usize = &.{},
output_operand_aliases: []const i64 = &.{},
addional_attributes: []const mlir.AttrTuple = &.{},
api_version: ApiVersion,
}; };
pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
var buffer: [1024]u8 = undefined; const MAX_OPERANDS = 64;
var fba = std.heap.FixedBufferAllocator.init(&buffer); const MAX_RESULTS = 16;
const allocator = fba.allocator();
const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable; const operand_layouts = blk: {
for (opts.output_operand_aliases, 0..) |alias, i| { var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute); for (opts.operand_layouts) |ol| {
} const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(ol.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const result_layouts = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.result_layouts) |rl| {
const tensor_type = mlir.RankedTensorType.init(
&.{@intCast(rl.len)},
mlir.IndexType.init(ctx).as(mlir.Type),
).as(mlir.Type);
const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl);
ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute));
}
break :blk ret;
};
const output_operand_aliases = blk: {
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (opts.output_operand_aliases) |alias| {
ret.appendAssumeCapacity(
OutputOperandAliasAttribute.init(
ctx,
&.{},
alias,
&.{},
).as(mlir.Attribute),
);
}
break :blk ret;
};
const backend_config = switch (opts.backend_config) {
.string => blk: {
stdx.debug.assert(
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
"Only API version of less than 4 is supported for backend_config as string",
.{},
);
break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute);
},
.dict => blk: {
stdx.debug.assert(
opts.api_version == .typed_ffi,
"Only API version 4 is supported for backend_config as dictionary",
.{},
);
break :blk opts.backend_config.dict.as(mlir.Attribute);
},
};
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "backend_config", backend_config },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
.{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
.{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
});
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs, .operands = inputs,
.results = res_types, .results = res_types,
.attributes = &.{ .attributes = attrs.constSlice(),
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) },
.{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
.{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) },
},
.location = location,
});
}
pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs,
.results = res_types,
.attributes = &.{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
.{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() },
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
.{ "mhlo.sharding", sharding_spec.asAttr() },
},
.location = location, .location = location,
}); });
} }
// todo: move out of stablehlo.zig when we start to implement the frontend
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
const frontend_attributes = mlir.DictionaryAttribute.init( const frontend_attributes = mlir.DictionaryAttribute.init(
ctx, ctx,
@ -787,19 +848,14 @@ pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value,
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()), mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
}, },
).asAttr(); ).asAttr();
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs, return custom_call(ctx, inputs, .{
.results = res_types, .call_target_name = "annotate_device_placement",
.attributes = &.{ .has_side_effect = true,
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() }, .backend_config = .{ .string = &.{} },
.{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() }, .addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() }, .api_version = .original,
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() }, }, res_types, location);
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
.{ "mhlo.frontend_attributes", frontend_attributes },
},
.location = location,
});
} }
pub const DotDimensionNumbersAttribute = struct { pub const DotDimensionNumbersAttribute = struct {

View File

@ -125,10 +125,9 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
&.{ q.value(), k.value(), v.value(), bias.value() }, &.{ q.value(), k.value(), v.value(), bias.value() },
.{ .{
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax", .call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
.backend_config = backend_config, .backend_config = .{ .string = backend_config },
.api_version = 2,
.has_side_effect = false, .has_side_effect = false,
.output_operand_aliases = &.{}, .api_version = .original,
}, },
&.{ &.{
mlir.ext.mlirType(mlir_ctx, q.shape()), mlir.ext.mlirType(mlir_ctx, q.shape()),

View File

@ -773,34 +773,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
return res; return res;
} }
/// Produces a custom call to `name` that takes a tensor and returns it.
///
/// For example, this can be used to extract tokens quickly if they run on a loop on the
/// GPU.
pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque) Tensor {
const address: [8]u8 = @bitCast(@intFromPtr(context));
var backend_config: [8:0]u8 = undefined;
@memcpy(backend_config[0..8], address[0..8]);
const ctx = CompilationContext.current();
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(),
&.{input.value()},
.{
.api_version = 1,
.has_side_effect = false,
.call_target_name = name,
.backend_config = backend_config[0..],
.output_operand_aliases = &.{0},
},
&.{input.value().getType()},
loc,
);
return Tensor._result(input.shape(), op.result(0));
}
/// At runtime the given tensor will be materialized and copied to host, /// At runtime the given tensor will be materialized and copied to host,
/// and the callback will be called on it. /// and the callback will be called on it.
pub fn addHostCallback( pub fn addHostCallback(
@ -835,11 +807,11 @@ pub fn addHostCallback(
ctx.mlirCtx(), ctx.mlirCtx(),
&.{input.value()}, &.{input.value()},
.{ .{
.api_version = 1,
.has_side_effect = false, .has_side_effect = false,
.call_target_name = "zmlHostBufferCallback", .call_target_name = "zmlHostBufferCallback",
.backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)), .backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) },
.output_operand_aliases = &.{0}, .output_operand_aliases = &.{0},
.api_version = .original,
}, },
&.{input.value().getType()}, &.{input.value().getType()},
loc, loc,
@ -847,6 +819,135 @@ pub fn addHostCallback(
return Tensor._result(input.shape(), op.result(0)); return Tensor._result(input.shape(), op.result(0));
} }
pub const TritonOps = struct {
debug: bool = false,
name: [:0]const u8,
ir: [:0]const u8,
grid: [3]i32,
num_stages: i32,
num_warps: i32,
};
/// Generate an MLIR call to the given member function with the given tensors.
pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]Tensor {
const ctx = CompilationContext.current();
var values: [inputs.len]mlir.Value = undefined;
ctx.extractValues(&inputs, &values);
var res_types: [outputs.len]mlir.Type = undefined;
inline for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
}
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.ir).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)),
mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)),
});
const MINOR_TO_MAJOR = blk: {
var ret: [Shape.MAX_RANK]usize = undefined;
for (0..Shape.MAX_RANK) |i| {
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
}
break :blk ret;
};
const operands_layouts = blk: {
var ret: [inputs.len][]const usize = undefined;
inline for (inputs, 0..) |input, i| {
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..];
}
break :blk ret;
};
const results_layouts = blk: {
var ret: [outputs.len][]const usize = undefined;
inline for (outputs, 0..) |output, i| {
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..];
}
break :blk ret;
};
const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(),
&values,
.{
.call_target_name = "__gpu$xla.gpu.triton",
.backend_config = .{ .dict = attrs },
.has_side_effect = false,
.api_version = .typed_ffi,
.operand_layouts = &operands_layouts,
.result_layouts = &results_layouts,
},
&res_types,
ctx.mlirCtx().location(@src()),
);
var outputs_: [outputs.len]Tensor = undefined;
inline for (outputs, 0..) |output, i| {
outputs_[i] = Tensor._result(output, op.result(i));
}
return outputs_;
}
test "triton" {
const zml = @import("zml.zig");
const platform = zml.testing.env();
if (platform.target != .cuda and platform.target != .rocm) return error.SkipZigTest;
const ir =
\\ module {
\\ tt.func public @add_one(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}) {
\\ %0 = tt.get_program_id x : i32
\\ %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
\\ %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
\\ %cst = arith.constant 1.000000e+00 : f32
\\ %3 = arith.addf %1, %cst : f32
\\ tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
\\ tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
\\ tt.return
\\ }
\\ }
;
const TritonMod = struct {
pub fn forward(a: Tensor, b: Tensor) [2]Tensor {
return triton(.{ a, b }, .{ a.shape(), b.shape() }, .{
.debug = false,
.name = "add_one",
.ir = ir,
.grid = .{ 1, 1, 1 },
.num_stages = 1,
.num_warps = 1,
});
}
};
const a = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{1});
const b = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{3});
const results = try zml.testing.compileAndCall(platform, TritonMod.forward, .{ a, b });
var cpu_result_0 = try results[0].toHostAlloc(std.testing.allocator);
defer cpu_result_0.deinit(std.testing.allocator);
var cpu_result_1 = try results[1].toHostAlloc(std.testing.allocator);
defer cpu_result_1.deinit(std.testing.allocator);
const expected_result_a: f32 = 2.0;
const expected_result_b: f32 = 3.0;
try std.testing.expectEqual(expected_result_a, cpu_result_0.items(f32)[0]);
try std.testing.expectEqual(expected_result_b, cpu_result_1.items(f32)[0]);
}
/// Generalized version of scatter to many inputs. /// Generalized version of scatter to many inputs.
/// See `zml.Tensor.scatterSlices` for documentation on scatter. /// See `zml.Tensor.scatterSlices` for documentation on scatter.
/// ///

View File

@ -175,10 +175,15 @@ pub const Tensor = struct {
const sharding = ctx.getShardingAttr(res._shape); const sharding = ctx.getShardingAttr(res._shape);
const op = dialect.stablehlo.sharding( const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(), ctx.mlirCtx(),
&.{self.value()}, &.{self.value()},
sharding, .{
.call_target_name = "Sharding",
.has_side_effect = false,
.addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
.api_version = .original,
},
&.{self.value().getType()}, &.{self.value().getType()},
ctx.mlirCtx().location(@src()), ctx.mlirCtx().location(@src()),
); );