mlir: rework stablehlo custom call implementation and add a Triton example
This commit is contained in:
parent
aec1d96e6d
commit
42dee5d0e0
@ -31,6 +31,7 @@ zig_library(
|
|||||||
deps = [
|
deps = [
|
||||||
"//mlir",
|
"//mlir",
|
||||||
"//mlir:c",
|
"//mlir:c",
|
||||||
|
"//stdx",
|
||||||
"@stablehlo//:stablehlo_dialect_capi",
|
"@stablehlo//:stablehlo_dialect_capi",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ const std = @import("std");
|
|||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub const abs = functors.unary_fn("stablehlo.abs").call;
|
pub const abs = functors.unary_fn("stablehlo.abs").call;
|
||||||
pub const cosine = functors.unary_fn("stablehlo.cosine").call;
|
pub const cosine = functors.unary_fn("stablehlo.cosine").call;
|
||||||
@ -733,53 +734,113 @@ pub fn convolution(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Options consumed by `custom_call` to build a `stablehlo.custom_call` operation.
pub const CustomCallOpts = struct {
    /// Value emitted as the `api_version` attribute (stored as an `i32`).
    pub const ApiVersion = enum(i32) {
        original = 1,
        status_returning = 2,
        status_returning_unified = 3,
        typed_ffi = 4,
    };

    /// Payload of the `backend_config` attribute.
    ///
    /// `custom_call` asserts that `.string` is only used with an API version
    /// below `.typed_ffi`, and that `.dict` is only used with `.typed_ffi`.
    pub const BackendConfig = union(enum) {
        string: [:0]const u8,
        dict: mlir.DictionaryAttribute,
    };

    /// Emitted as the `call_target_name` string attribute.
    call_target_name: [:0]const u8,
    /// Emitted as the `has_side_effect` boolean attribute.
    has_side_effect: bool,
    /// Defaults to an empty string configuration.
    backend_config: BackendConfig = .{ .string = &.{} },
    /// One entry per operand; each entry is encoded as a dense `index` tensor.
    operand_layouts: []const []const usize = &.{},
    /// One entry per result; encoded the same way as `operand_layouts`.
    result_layouts: []const []const usize = &.{},
    /// Each value is turned into one `OutputOperandAliasAttribute`.
    output_operand_aliases: []const i64 = &.{},
    /// Extra attributes appended after the built-in ones.
    /// NOTE: the field name is misspelled ("addional"); it is kept as-is
    /// because callers initialize the struct by this name.
    addional_attributes: []const mlir.AttrTuple = &.{},
    /// Required: there is no default API version.
    api_version: ApiVersion,
};
|
||||||
|
|
||||||
/// Builds a `stablehlo.custom_call` operation described by `opts`.
///
/// `inputs` become the operands of the operation and `res_types` its result
/// types. The attributes `api_version`, `call_target_name`, `has_side_effect`,
/// `backend_config`, `output_operand_aliases`, `operand_layouts` and
/// `result_layouts` are always emitted, followed by `opts.addional_attributes`.
///
/// All temporaries live in fixed-size stack arrays, so the option slices are
/// bounded: at most `MAX_OPERANDS` operand layouts, `MAX_RESULTS` result
/// layouts and output aliases, and `MAX_ATTRIBUTES` attributes in total.
/// Exceeding a bound, or pairing a `backend_config` variant with an
/// unsupported `api_version`, fails a `stdx.debug.assert`.
pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCallOpts, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
    const MAX_OPERANDS = 64;
    const MAX_RESULTS = 16;
    const MAX_ATTRIBUTES = 32;

    // Check every capacity up front with an explicit message: the
    // `*AssumeCapacity` calls below do not report which limit was exceeded.
    stdx.debug.assert(
        opts.operand_layouts.len <= MAX_OPERANDS,
        "custom_call supports at most {} operand layouts, got {}",
        .{ MAX_OPERANDS, opts.operand_layouts.len },
    );
    stdx.debug.assert(
        opts.result_layouts.len <= MAX_RESULTS,
        "custom_call supports at most {} result layouts, got {}",
        .{ MAX_RESULTS, opts.result_layouts.len },
    );
    stdx.debug.assert(
        opts.output_operand_aliases.len <= MAX_RESULTS,
        "custom_call supports at most {} output operand aliases, got {}",
        .{ MAX_RESULTS, opts.output_operand_aliases.len },
    );

    var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
    for (opts.operand_layouts) |layout| {
        operand_layouts.appendAssumeCapacity(custom_call_layout_attr(ctx, layout));
    }

    var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
    for (opts.result_layouts) |layout| {
        result_layouts.appendAssumeCapacity(custom_call_layout_attr(ctx, layout));
    }

    var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
    for (opts.output_operand_aliases) |alias| {
        output_operand_aliases.appendAssumeCapacity(
            OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
        );
    }

    // A string config is only accepted below API version 4; a dictionary
    // config is only accepted with API version 4 (`typed_ffi`).
    const backend_config = switch (opts.backend_config) {
        .string => |str| blk: {
            stdx.debug.assert(
                @intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
                "Only API version of less than 4 is supported for backend_config as string",
                .{},
            );
            break :blk mlir.StringAttribute.init(ctx, str).as(mlir.Attribute);
        },
        .dict => |dict| blk: {
            stdx.debug.assert(
                opts.api_version == .typed_ffi,
                "Only API version 4 is supported for backend_config as dictionary",
                .{},
            );
            break :blk dict.as(mlir.Attribute);
        },
    };

    const builtin_attrs = [_]mlir.AttrTuple{
        .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, @intFromEnum(opts.api_version)).as(mlir.Attribute) },
        .{ "call_target_name", mlir.StringAttribute.init(ctx, opts.call_target_name).as(mlir.Attribute) },
        .{ "has_side_effect", mlir.BoolAttribute.init(ctx, opts.has_side_effect).as(mlir.Attribute) },
        .{ "backend_config", backend_config },
        .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, output_operand_aliases.constSlice()).as(mlir.Attribute) },
        .{ "operand_layouts", mlir.ArrayAttribute.init(ctx, operand_layouts.constSlice()).as(mlir.Attribute) },
        .{ "result_layouts", mlir.ArrayAttribute.init(ctx, result_layouts.constSlice()).as(mlir.Attribute) },
    };

    stdx.debug.assert(
        builtin_attrs.len + opts.addional_attributes.len <= MAX_ATTRIBUTES,
        "custom_call supports at most {} additional attributes, got {}",
        .{ MAX_ATTRIBUTES - builtin_attrs.len, opts.addional_attributes.len },
    );

    var attrs: std.BoundedArray(mlir.AttrTuple, MAX_ATTRIBUTES) = .{};
    attrs.appendSliceAssumeCapacity(&builtin_attrs);
    attrs.appendSliceAssumeCapacity(opts.addional_attributes);

    return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
        .operands = inputs,
        .results = res_types,
        .attributes = attrs.constSlice(),
        .location = location,
    });
}

/// Encodes one layout as a rank-1 dense `index` tensor attribute whose single
/// dimension is `layout.len`; this is the form used for each entry of the
/// `operand_layouts` and `result_layouts` attributes.
fn custom_call_layout_attr(ctx: mlir.Context, layout: []const usize) mlir.Attribute {
    const tensor_type = mlir.RankedTensorType.init(
        &.{@intCast(layout.len)},
        mlir.IndexType.init(ctx).as(mlir.Type),
    ).as(mlir.Type);
    return mlir.DenseElementsAttribute(.index).init(tensor_type, layout).as(mlir.Attribute);
}
|
||||||
|
|
||||||
|
// TODO: move this out of stablehlo.zig once the frontend is implemented.
|
||||||
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||||
const frontend_attributes = mlir.DictionaryAttribute.init(
|
const frontend_attributes = mlir.DictionaryAttribute.init(
|
||||||
ctx,
|
ctx,
|
||||||
@ -787,19 +848,14 @@ pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value,
|
|||||||
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
|
mlir.NamedAttribute.init(mlir.Identifier.get(ctx, "_xla_buffer_placement"), memory_kind.asAttr()),
|
||||||
},
|
},
|
||||||
).asAttr();
|
).asAttr();
|
||||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
|
||||||
.operands = inputs,
|
return custom_call(ctx, inputs, .{
|
||||||
.results = res_types,
|
.call_target_name = "annotate_device_placement",
|
||||||
.attributes = &.{
|
.has_side_effect = true,
|
||||||
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
|
.backend_config = .{ .string = &.{} },
|
||||||
.{ "call_target_name", mlir.StringAttribute.init(ctx, "annotate_device_placement").asAttr() },
|
.addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
|
||||||
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, true).asAttr() },
|
.api_version = .original,
|
||||||
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
|
}, res_types, location);
|
||||||
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
|
|
||||||
.{ "mhlo.frontend_attributes", frontend_attributes },
|
|
||||||
},
|
|
||||||
.location = location,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const DotDimensionNumbersAttribute = struct {
|
pub const DotDimensionNumbersAttribute = struct {
|
||||||
|
|||||||
@ -125,10 +125,9 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
&.{ q.value(), k.value(), v.value(), bias.value() },
|
&.{ q.value(), k.value(), v.value(), bias.value() },
|
||||||
.{
|
.{
|
||||||
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
||||||
.backend_config = backend_config,
|
.backend_config = .{ .string = backend_config },
|
||||||
.api_version = 2,
|
|
||||||
.has_side_effect = false,
|
.has_side_effect = false,
|
||||||
.output_operand_aliases = &.{},
|
.api_version = .original,
|
||||||
},
|
},
|
||||||
&.{
|
&.{
|
||||||
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||||
|
|||||||
161
zml/ops.zig
161
zml/ops.zig
@ -773,34 +773,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Produces a custom call to `name` that takes a tensor and returns it.
|
|
||||||
///
|
|
||||||
/// For example, this can be used to extract tokens quickly if they run on a loop on the
|
|
||||||
/// GPU.
|
|
||||||
pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque) Tensor {
|
|
||||||
const address: [8]u8 = @bitCast(@intFromPtr(context));
|
|
||||||
var backend_config: [8:0]u8 = undefined;
|
|
||||||
@memcpy(backend_config[0..8], address[0..8]);
|
|
||||||
const ctx = CompilationContext.current();
|
|
||||||
|
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
|
|
||||||
|
|
||||||
const op = dialect.stablehlo.custom_call(
|
|
||||||
ctx.mlirCtx(),
|
|
||||||
&.{input.value()},
|
|
||||||
.{
|
|
||||||
.api_version = 1,
|
|
||||||
.has_side_effect = false,
|
|
||||||
.call_target_name = name,
|
|
||||||
.backend_config = backend_config[0..],
|
|
||||||
.output_operand_aliases = &.{0},
|
|
||||||
},
|
|
||||||
&.{input.value().getType()},
|
|
||||||
loc,
|
|
||||||
);
|
|
||||||
return Tensor._result(input.shape(), op.result(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// At runtime the given tensor will be materialized and copied to host,
|
/// At runtime the given tensor will be materialized and copied to host,
|
||||||
/// and the callback will be called on it.
|
/// and the callback will be called on it.
|
||||||
pub fn addHostCallback(
|
pub fn addHostCallback(
|
||||||
@ -835,11 +807,11 @@ pub fn addHostCallback(
|
|||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
&.{input.value()},
|
&.{input.value()},
|
||||||
.{
|
.{
|
||||||
.api_version = 1,
|
|
||||||
.has_side_effect = false,
|
.has_side_effect = false,
|
||||||
.call_target_name = "zmlHostBufferCallback",
|
.call_target_name = "zmlHostBufferCallback",
|
||||||
.backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
|
.backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) },
|
||||||
.output_operand_aliases = &.{0},
|
.output_operand_aliases = &.{0},
|
||||||
|
.api_version = .original,
|
||||||
},
|
},
|
||||||
&.{input.value().getType()},
|
&.{input.value().getType()},
|
||||||
loc,
|
loc,
|
||||||
@ -847,6 +819,135 @@ pub fn addHostCallback(
|
|||||||
return Tensor._result(input.shape(), op.result(0));
|
return Tensor._result(input.shape(), op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Launch description of a Triton kernel, consumed by `triton`.
///
/// Apart from `debug`, every field is forwarded to the custom call's
/// dictionary `backend_config`.
pub const TritonOps = struct {
    /// Not read by `triton` in the code visible here.
    debug: bool = false,
    /// Forwarded as the `name` entry.
    name: [:0]const u8,
    /// Triton IR source text, forwarded as the `ir` entry.
    ir: [:0]const u8,
    /// Forwarded as `grid_x`, `grid_y` and `grid_z`, in that order.
    grid: [3]i32,
    /// Forwarded as the `num_stages` entry.
    num_stages: i32,
    /// Forwarded as the `num_warps` entry.
    num_warps: i32,
};
|
||||||
|
|
||||||
|
/// Emits a `stablehlo.custom_call` targeting `__gpu$xla.gpu.triton`.
///
/// - `inputs`: tuple of tensors; the values obtained through
///   `ctx.extractValues` become the operands of the call.
/// - `outputs`: tuple of shapes, one per result; each is converted to an MLIR
///   result type and used to wrap the matching op result.
/// - `opts`: kernel name, IR and launch parameters, passed as a dictionary
///   `backend_config` with API version `.typed_ffi`.
///
/// Every operand and result gets the layout `rank - 1, ..., 1, 0`.
/// Returns one tensor per entry of `outputs`, in the same order.
pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]Tensor {
    const ctx = CompilationContext.current();
    const mlir_ctx = ctx.mlirCtx();

    // Small builders for the entries of the backend_config dictionary.
    const entry = struct {
        fn string(c: mlir.Context, key: [:0]const u8, value: [:0]const u8) mlir.NamedAttribute {
            return mlir.NamedAttribute.init(
                mlir.Identifier.get(c, key),
                mlir.StringAttribute.init(c, value).as(mlir.Attribute),
            );
        }

        fn int32(c: mlir.Context, key: [:0]const u8, value: i32) mlir.NamedAttribute {
            return mlir.NamedAttribute.init(
                mlir.Identifier.get(c, key),
                mlir.IntegerAttribute(.i32).init(c, @intCast(value)).as(mlir.Attribute),
            );
        }
    };

    var operands: [inputs.len]mlir.Value = undefined;
    ctx.extractValues(&inputs, &operands);

    var result_types: [outputs.len]mlir.Type = undefined;
    inline for (outputs, 0..) |shape, i| {
        result_types[i] = mlir.ext.mlirType(mlir_ctx, shape);
    }

    const config = mlir.DictionaryAttribute.init(mlir_ctx, &.{
        entry.string(mlir_ctx, "name", opts.name),
        entry.string(mlir_ctx, "ir", opts.ir),
        entry.int32(mlir_ctx, "grid_x", opts.grid[0]),
        entry.int32(mlir_ctx, "grid_y", opts.grid[1]),
        entry.int32(mlir_ctx, "grid_z", opts.grid[2]),
        entry.int32(mlir_ctx, "num_stages", opts.num_stages),
        entry.int32(mlir_ctx, "num_warps", opts.num_warps),
    });

    // Descending sequence MAX_RANK - 1, ..., 1, 0. Its last `r` entries are
    // r - 1, ..., 0, which is the layout used for a rank-`r` value.
    var minor_to_major: [Shape.MAX_RANK]usize = undefined;
    for (&minor_to_major, 0..) |*slot, i| {
        slot.* = @intCast(Shape.MAX_RANK - i - 1);
    }

    var operand_layouts: [inputs.len][]const usize = undefined;
    inline for (inputs, 0..) |input, i| {
        operand_layouts[i] = minor_to_major[minor_to_major.len - input.rank() ..];
    }

    var result_layouts: [outputs.len][]const usize = undefined;
    inline for (outputs, 0..) |shape, i| {
        result_layouts[i] = minor_to_major[minor_to_major.len - shape.rank() ..];
    }

    const op = dialect.stablehlo.custom_call(
        mlir_ctx,
        &operands,
        .{
            .call_target_name = "__gpu$xla.gpu.triton",
            .backend_config = .{ .dict = config },
            .has_side_effect = false,
            .api_version = .typed_ffi,
            .operand_layouts = &operand_layouts,
            .result_layouts = &result_layouts,
        },
        &result_types,
        mlir_ctx.location(@src()),
    );

    var results: [outputs.len]Tensor = undefined;
    inline for (outputs, 0..) |shape, i| {
        results[i] = Tensor._result(shape, op.result(i));
    }
    return results;
}
|
||||||
|
|
||||||
|
test "triton" {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    // The Triton custom call is only exercised on GPU targets.
    switch (platform.target) {
        .cuda, .rocm => {},
        else => return error.SkipZigTest,
    }

    // Kernel with two inputs and two outputs: stores `arg0 + 1` into the first
    // output and copies `arg1` into the second.
    const ir =
        \\ module {
        \\ tt.func public @add_one(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg1: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg2: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}, %arg3: !tt.ptr<f32, 1> {tt.divisibility = 32 : i32}) {
        \\ %0 = tt.get_program_id x : i32
        \\ %1 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
        \\ %2 = tt.load %arg1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr<f32>
        \\ %cst = arith.constant 1.000000e+00 : f32
        \\ %3 = arith.addf %1, %cst : f32
        \\ tt.store %arg2, %3 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
        \\ tt.store %arg3, %2 {cache = 1 : i32, evict = 1 : i32} : !tt.ptr<f32>
        \\ tt.return
        \\ }
        \\ }
    ;

    const AddOne = struct {
        pub fn forward(a: Tensor, b: Tensor) [2]Tensor {
            const launch: TritonOps = .{
                .debug = false,
                .name = "add_one",
                .ir = ir,
                .grid = .{ 1, 1, 1 },
                .num_stages = 1,
                .num_warps = 1,
            };
            return triton(.{ a, b }, .{ a.shape(), b.shape() }, launch);
        }
    };

    const input_a = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{1});
    const input_b = try zml.Buffer.fromSlice(platform, .{}, &[1]f32{3});

    const outputs = try zml.testing.compileAndCall(platform, AddOne.forward, .{ input_a, input_b });

    var host_a = try outputs[0].toHostAlloc(std.testing.allocator);
    defer host_a.deinit(std.testing.allocator);
    var host_b = try outputs[1].toHostAlloc(std.testing.allocator);
    defer host_b.deinit(std.testing.allocator);

    // First output is input_a + 1, second output is input_b unchanged.
    const want_a: f32 = 2.0;
    const want_b: f32 = 3.0;
    try std.testing.expectEqual(want_a, host_a.items(f32)[0]);
    try std.testing.expectEqual(want_b, host_b.items(f32)[0]);
}
|
||||||
|
|
||||||
/// Generalized version of scatter to many inputs.
|
/// Generalized version of scatter to many inputs.
|
||||||
/// See `zml.Tensor.scatterSlices` for documentation on scatter.
|
/// See `zml.Tensor.scatterSlices` for documentation on scatter.
|
||||||
///
|
///
|
||||||
|
|||||||
@ -175,10 +175,15 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const sharding = ctx.getShardingAttr(res._shape);
|
const sharding = ctx.getShardingAttr(res._shape);
|
||||||
|
|
||||||
const op = dialect.stablehlo.sharding(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
&.{self.value()},
|
&.{self.value()},
|
||||||
sharding,
|
.{
|
||||||
|
.call_target_name = "Sharding",
|
||||||
|
.has_side_effect = false,
|
||||||
|
.addional_attributes = &.{.{ "mhlo.sharding", sharding.asAttr() }},
|
||||||
|
.api_version = .original,
|
||||||
|
},
|
||||||
&.{self.value().getType()},
|
&.{self.value().getType()},
|
||||||
ctx.mlirCtx().location(@src()),
|
ctx.mlirCtx().location(@src()),
|
||||||
);
|
);
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user