Add customCall operation to zml/ops.

This commit is contained in:
Tarry Singh 2025-01-09 15:01:33 +00:00
parent 9f1cc762cd
commit 09c43b8759

View File

@ -783,44 +783,19 @@ pub fn addHostCallback(
output_shapes: []const Shape, output_shapes: []const Shape,
opts: HostCallbackOpt, opts: HostCallbackOpt,
) []Tensor { ) []Tensor {
const ctx = CompilationContext.current(); return customCall(
"zmlHostBufferCallback",
const mlir_ctx = ctx.mlirCtx(); inputs,
const backend_config = mlir.Attribute.dict(mlir_ctx, &.{ output_shapes,
.{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) }, .{
.{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) }, .callback = @intFromPtr(callback),
}); .user_context = @intFromPtr(blkctx),
},
const values = stdx.stackSlice(8, mlir.Value, inputs.len);
for (inputs, values) |i, *v| {
v.* = ctx.getValue(i.toMemory(.host_pinned));
}
const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len);
for (res_types, output_shapes) |*r, o| {
r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type);
}
const loc = ctx.mlirCtx().location(@src());
const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(),
values,
.{ .{
.call_target_name = "zmlHostBufferCallback",
.api_version = .typed_ffi,
.backend_config = backend_config,
.has_side_effect = opts.has_side_effect, .has_side_effect = opts.has_side_effect,
.output_operand_aliases = opts.output_operand_aliases, .output_operand_aliases = opts.output_operand_aliases,
}, },
res_types,
loc,
); );
const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM");
for (res, output_shapes, 0..) |*r, o, i| {
r.* = Tensor._result(o, op.result(i)).toMemory(.device);
}
return res;
} }
pub const TritonOps = struct { pub const TritonOps = struct {
@ -1248,6 +1223,169 @@ fn scatterPrepareIndices(
return Tensor.stack(indices.constSlice(), .last, .coord); return Tensor.stack(indices.constSlice(), .last, .coord);
} }
fn TensorOrTensorArray(comptime T: type) type {
const type_info = @typeInfo(T);
return switch (type_info) {
.@"struct" => |struct_info| b: {
if (T == Tensor) break :b Tensor;
if (!struct_info.is_tuple) @compileError("Expected tuple");
break :b if (struct_info.fields.len == 1)
Tensor
else
[struct_info.fields.len]Tensor;
},
.array => |array_info| b: {
break :b if (array_info.len == 1)
Tensor
else
[array_info.len]Tensor;
},
.pointer => |pointer_info| b: {
if (pointer_info.size != .slice) @compileError("Expected slice");
break :b []Tensor;
},
else => @compileError("Unsupported type: " ++ @typeName(T)),
};
}
pub const CustomCallOptions = struct {
has_side_effect: bool,
output_operand_aliases: ?[]const i64 = null,
};
pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype, metadata: anytype, opts: CustomCallOptions) TensorOrTensorArray(@TypeOf(outputs)) {
// Transform generic inputs to flat slice.
const inputs_: []const Tensor = switch (@typeInfo(@TypeOf(inputs))) {
.@"struct" => |struct_info| b: {
if (@TypeOf(inputs) == Tensor) {
break :b &[1]Tensor{inputs};
}
if (!struct_info.is_tuple) @compileError("Expected tuple");
var inputs_: [struct_info.fields.len]Tensor = undefined;
meta.collectBuf((struct {
pub fn func(t: Tensor) Tensor {
return t;
}
}).func, {}, &inputs, &inputs_);
break :b &inputs_;
},
.array => &inputs,
.pointer => |pointer_info| b: {
if (pointer_info.size != .slice) @compileError("Expected slice");
break :b inputs;
},
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(inputs))),
};
// Transform generic outputs to flat slice.
const output_shapes: []const Shape = switch (@typeInfo(@TypeOf(outputs))) {
.@"struct" => |struct_info| b: {
if (@TypeOf(outputs) == Shape) {
break :b &[1]Shape{outputs};
}
if (!struct_info.is_tuple) @compileError("Expected tuple");
var output_shapes: [struct_info.fields.len]Shape = undefined;
meta.collectBuf((struct {
pub fn func(t: Shape) Shape {
return t;
}
}).func, {}, &outputs, &output_shapes);
break :b &output_shapes;
},
.array => &outputs,
.pointer => |pointer_info| b: {
if (pointer_info.size != .slice) @compileError("Expected slice");
break :b outputs;
},
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
};
const outputs_flat = customCallInternal(target_name, inputs_, output_shapes, metadata, opts);
// Transform flat slice to generic outputs.
return switch (@typeInfo(@TypeOf(outputs))) {
.@"struct" => |struct_info| b: {
if (@TypeOf(outputs) == Shape) break :b outputs_flat[0];
if (!struct_info.is_tuple) @compileError("Expected tuple");
if (struct_info.fields.len == 1) break :b outputs_flat[0];
var outputs_: [struct_info.fields.len]Tensor = undefined;
@memcpy(&outputs_, outputs_flat);
break :b outputs_;
},
.array => |array_info| b: {
if (array_info.len == 1) break :b outputs_flat[0];
var outputs_: [array_info.fields.len]Tensor = undefined;
@memcpy(&outputs_, outputs_flat);
break :b outputs_;
},
.pointer => |pointer_info| b: {
if (pointer_info.size != .slice) @compileError("Expected slice");
break :b outputs_flat;
},
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
};
}
fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
const ctx = module.CompilationContext.current();
const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
ctx.extractValues(inputs, values);
const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
for (outputs, 0..) |output, i| {
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
}
const metadata_type_info = @typeInfo(@TypeOf(metadata));
var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
.int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
};
metadata_attributes_tuple[i] = .{ field.name, attribute };
}
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
.{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
.{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
} ++ metadata_attributes_tuple));
const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
for (inputs, 0..) |input, i| {
operands_layouts[i] = minorToMajor(input.rank());
}
const results_layouts = ctx.allocator().alloc([]const usize, outputs.len) catch unreachable;
for (outputs, 0..) |output, i| {
results_layouts[i] = minorToMajor(output.rank());
}
const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(),
values,
.{
.call_target_name = target_name,
.backend_config = backend_config,
.has_side_effect = true,
.api_version = .typed_ffi,
.operand_layouts = operands_layouts,
.result_layouts = results_layouts,
.output_operand_aliases = opts.output_operand_aliases orelse &.{},
},
res_types,
ctx.mlirCtx().location(@src()),
);
const outputs_ = ctx.allocator().alloc(Tensor, outputs.len) catch unreachable;
for (outputs, 0..) |output, i| {
outputs_[i] = Tensor._result(output, op.result(i));
}
return outputs_;
}
inline fn toI64(values: anytype) []i64 { inline fn toI64(values: anytype) []i64 {
var res: [Tensor.MAX_RANK]i64 = undefined; var res: [Tensor.MAX_RANK]i64 = undefined;
for (values, 0..) |val, i| res[i] = @intCast(val); for (values, 0..) |val, i| res[i] = @intCast(val);