From 42dee5d0e00118d9ac1dc0c65f1ed087ffcb16aa Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 16 Jul 2024 13:23:07 +0000 Subject: [PATCH] mlir: rework stablehlo custom call implementation and add a Triton example --- mlir/dialects/BUILD.bazel | 1 + mlir/dialects/stablehlo.zig | 148 ++++++++++++++++++++++----------- zml/nn/cuda.zig | 5 +- zml/ops.zig | 161 +++++++++++++++++++++++++++++------- zml/tensor.zig | 9 +- 5 files changed, 243 insertions(+), 81 deletions(-) diff --git a/mlir/dialects/BUILD.bazel b/mlir/dialects/BUILD.bazel index 0a275d8..9570cb8 100644 --- a/mlir/dialects/BUILD.bazel +++ b/mlir/dialects/BUILD.bazel @@ -31,6 +31,7 @@ zig_library( deps = [ "//mlir", "//mlir:c", + "//stdx", "@stablehlo//:stablehlo_dialect_capi", ], ) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 4311de3..58ba5eb 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -2,6 +2,7 @@ const std = @import("std"); const c = @import("c"); const mlir = @import("mlir"); +const stdx = @import("stdx"); pub const abs = functors.unary_fn("stablehlo.abs").call; pub const cosine = functors.unary_fn("stablehlo.cosine").call; @@ -733,53 +734,113 @@ pub fn convolution( } pub const CustomCallOpts = struct { + pub const ApiVersion = enum(i32) { + original = 1, + status_returning = 2, + status_returning_unified = 3, + typed_ffi = 4, + }; + + pub const BackendConfig = union(enum) { + string: [:0]const u8, + dict: mlir.DictionaryAttribute, + }; + call_target_name: [:0]const u8, has_side_effect: bool, - backend_config: [:0]const u8 = &.{}, - api_version: i32, - output_operand_aliases: []const i64, + backend_config: BackendConfig = .{ .string = &.{} }, + operand_layouts: []const []const usize = &.{}, + result_layouts: []const []const usize = &.{}, + output_operand_aliases: []const i64 = &.{}, + addional_attributes: []const mlir.AttrTuple = &.{}, + api_version: ApiVersion, }; pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { - var buffer: [1024]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&buffer); - const allocator = fba.allocator(); + const MAX_OPERANDS = 64; + const MAX_RESULTS = 16; - const output_operand_aliases = allocator.alloc(mlir.Attribute, opts.output_operand_aliases.len) catch unreachable; - for (opts.output_operand_aliases, 0..) |alias, i| { - output_operand_aliases[i] = OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute); - } + const operand_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (opts.operand_layouts) |ol| { + const tensor_type = mlir.RankedTensorType.init( + &.{@intCast(ol.len)}, + mlir.IndexType.init(ctx).as(mlir.Type), + ).as(mlir.Type); + const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, ol); + ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + } + break :blk ret; + }; + + const result_layouts = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (opts.result_layouts) |rl| { + const tensor_type = mlir.RankedTensorType.init( + &.{@intCast(rl.len)}, + mlir.IndexType.init(ctx).as(mlir.Type), + ).as(mlir.Type); + const layout_attr = mlir.DenseElementsAttribute(.index).init(tensor_type, rl); + ret.appendAssumeCapacity(layout_attr.as(mlir.Attribute)); + } + break :blk ret; + }; + + const output_operand_aliases = blk: { + var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (opts.output_operand_aliases) |alias| { + ret.appendAssumeCapacity( + OutputOperandAliasAttribute.init( + ctx, + &.{}, + alias, + &.{}, + ).as(mlir.Attribute), + ); + } + break :blk ret; + }; + + const backend_config = switch (opts.backend_config) { + .string => blk: { + stdx.debug.assert( + @intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi), + "Only API version of less than 4 is supported for backend_config as string", + .{}, + ); + break :blk mlir.StringAttribute.init(ctx, opts.backend_config.string).as(mlir.Attribute); + }, + .dict => blk: { + stdx.debug.assert( + opts.api_version == .typed_ffi, + "Only API version 4 is supported for backend_config as dictionary", + .{}, + ); + break :blk opts.backend_config.dict.as(mlir.Attribute); + }, + }; + + var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; + attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ + .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) }, + .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) }, + .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) }, + .{ "backend_config", backend_config }, + .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) }, + .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) }, + .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) }, + }); + attrs.appendSliceAssumeCapacity(opts.addional_attributes); return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ .operands = inputs, .results = res_types, - .attributes = &.{ - .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, opts.api_version).as(mlir.Attribute) }, - .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) }, - .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) }, - .{ "backend_config", mlir.StringAttribute.init(ctx, opts.backend_config).as(mlir.Attribute) }, - .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases).as(mlir.Attribute) }, - }, - .location = location, - }); -} - -pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { - return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ - .operands = inputs, - .results = res_types, - .attributes = &.{ - .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() }, - .{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() }, - .{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() }, - .{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() }, - .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() }, - .{ "mhlo.sharding", sharding_spec.asAttr() }, - }, + .attributes = attrs.constSlice(), .location = location, }); } +// todo: move out of stablehlo.zig when we start to implement the frontend pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { const frontend_attributes = mlir.DictionaryAttribute.init( ctx, @@ -787,19 +848,14 @@ pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()), }, ).asAttr(); - return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ - .operands = inputs, - .results = res_types, - .attributes = &.{ - .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() }, - .{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() }, - .{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() }, - .{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() }, - .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() }, - .{ "mhlo.frontend_attributes", frontend_attributes }, - }, - .location = location, - }); + + return custom_call(ctx, inputs, .{ + .call_target_name = "annotate_device_placement", + .has_side_effect = true, + .backend_config = .{ .string = &.{} }, + .addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }}, + .api_version = .original, + }, res_types, location); } pub const DotDimensionNumbersAttribute = struct { diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 94e8e2e..81ff600 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -125,10 +125,9 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { &.{ q.value(), k.value(), v.value(), bias.value() }, .{ .call_target_name = "__cudnn$fmhaScaleBiasSoftmax", - .backend_config = backend_config, - .api_version = 2, + .backend_config = .{ .string = backend_config }, .has_side_effect = false, - .output_operand_aliases = &.{}, + .api_version = .original, }, &.{ mlir.ext.mlirType(mlir_ctx, q.shape()), diff --git a/zml/ops.zig b/zml/ops.zig index 852ec34..c049985 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -773,34 +773,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base return res; } -/// Produces a custom call to `name` that takes a tensor and returns it. -/// -/// For example, this can be used to extract tokens quickly if they run on a loop on the -/// GPU. -pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque) Tensor { - const address: [8]u8 = @bitCast(@intFromPtr(context)); - var backend_config: [8:0]u8 = undefined; - @memcpy(backend_config[0..8], address[0..8]); - const ctx = CompilationContext.current(); - - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name}); - - const op = dialect.stablehlo.custom_call( - ctx.mlirCtx(), - &.{input.value()}, - .{ - .api_version = 1, - .has_side_effect = false, - .call_target_name = name, - .backend_config = backend_config[0..], - .output_operand_aliases = &.{0}, - }, - &.{input.value().getType()}, - loc, - ); - return Tensor._result(input.shape(), op.result(0)); -} - /// At runtime the given tensor will be materialized and copied to host, /// and the callback will be called on it. pub fn addHostCallback( @@ -835,11 +807,11 @@ pub fn addHostCallback( ctx.mlirCtx(), &.{input.value()}, .{ - .api_version = 1, .has_side_effect = false, .call_target_name = "zmlHostBufferCallback", - .backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)), + .backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) }, .output_operand_aliases = &.{0}, + .api_version = .original, }, &.{input.value().getType()}, loc, @@ -847,6 +819,135 @@ pub fn addHostCallback( return Tensor._result(input.shape(), op.result(0)); } +pub const TritonOps = struct { + debug: bool = false, + name: [:0]const u8, + ir: [:0]const u8, + grid: [3]i32, + num_stages: i32, + num_warps: i32, +}; + +/// Generate an MLIR call to the given member function with the given tensors. +pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]Tensor { + const ctx = CompilationContext.current(); + + var values: [inputs.len]mlir.Value = undefined; + ctx.extractValues(&inputs, &values); + + var res_types: [outputs.len]mlir.Type = undefined; + inline for (outputs, 0..) |output, i| { + res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); + } + + const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{ + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "name"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.name).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "ir"), mlir.StringAttribute.init(ctx.mlirCtx(), opts.ir).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_x"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[0])).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_y"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[1])).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "grid_z"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.grid[2])).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_stages"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_stages)).as(mlir.Attribute)), + mlir.NamedAttribute.init(mlir.Identifier.get(ctx.mlirCtx(), "num_warps"), mlir.IntegerAttribute(.i32).init(ctx.mlirCtx(), @intCast(opts.num_warps)).as(mlir.Attribute)), + }); + + const MINOR_TO_MAJOR = blk: { + var ret: [Shape.MAX_RANK]usize = undefined; + for (0..Shape.MAX_RANK) |i| { + ret[i] = @intCast(Shape.MAX_RANK - i - 1); + } + break :blk ret; + }; + + const operands_layouts = blk: { + var ret: [inputs.len][]const usize = undefined; + inline for (inputs, 0..) |input, i| { + ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..]; + } + break :blk ret; + }; + + const results_layouts = blk: { + var ret: [outputs.len][]const usize = undefined; + inline for (outputs, 0..) |output, i| { + ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..]; + } + break :blk ret; + }; + + const op = dialect.stablehlo.custom_call( + ctx.mlirCtx(), + &values, + .{ + .call_target_name = "__gpu$xla.gpu.triton", + .backend_config = .{ .dict = attrs }, + .has_side_effect = false, + .api_version = .typed_ffi, + .operand_layouts = &operands_layouts, + .result_layouts = &results_layouts, + }, + &res_types, + ctx.mlirCtx().location(@src()), + ); + + var outputs_: [outputs.len]Tensor = undefined; + inline for (outputs, 0..) |output, i| { + outputs_[i] = Tensor._result(output, op.result(i)); + } + + return outputs_; +} + +test "triton" { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + if (platform.target != .cuda and platform.target != .rocm) return error.SkipZigTest; + + const ir = + \\ module { + \\ tt.func public @add_one(%arg0: !tt.ptr {tt.divisibility = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 32 : i32}, %arg3: !tt.ptr {tt.divisibility = 32 : i32}) { + \\ %0 = tt.get_program_id x : i32 + \\ %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr + \\ %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr + \\ %cst = arith.constant 1.000000e+00 : f32 + \\ %3 = arith.addf %1, %cst : f32 + \\ tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr + \\ tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr + \\ tt.return + \\ } + \\ } + ; + + const TritonMod = struct { + pub fn forward(a: Tensor, b: Tensor) [2]Tensor { + return triton(.{ a, b }, .{ a.shape(), b.shape() }, .{ + .debug = false, + .name = "add_one", + .ir = ir, + .grid = .{ 1, 1, 1 }, + .num_stages = 1, + .num_warps = 1, + }); + } + }; + + const a = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{1}); + const b = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{3}); + + const results = try zml.testing.compileAndCall(platform, TritonMod.forward, .{ a, b }); + + var cpu_result_0 = try results[0].toHostAlloc(std.testing.allocator); + defer cpu_result_0.deinit(std.testing.allocator); + var cpu_result_1 = try results[1].toHostAlloc(std.testing.allocator); + defer cpu_result_1.deinit(std.testing.allocator); + + const expected_result_a: f32 = 2.0; + const expected_result_b: f32 = 3.0; + + try std.testing.expectEqual(expected_result_a, cpu_result_0.items(f32)[0]); + try std.testing.expectEqual(expected_result_b, cpu_result_1.items(f32)[0]); +} + /// Generalized version of scatter to many inputs. /// See `zml.Tensor.scatterSlices` for documentation on scatter. /// diff --git a/zml/tensor.zig b/zml/tensor.zig index 81578a2..83e6c28 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -175,10 +175,15 @@ pub const Tensor = struct { const sharding = ctx.getShardingAttr(res._shape); - const op = dialect.stablehlo.sharding( + const op = dialect.stablehlo.custom_call( ctx.mlirCtx(), &.{self.value()}, - sharding, + .{ + .call_target_name = "Sharding", + .has_side_effect = false, + .addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }}, + .api_version = .original, + }, &.{self.value().getType()}, ctx.mlirCtx().location(@src()), );