Add customCall operation to zml/ops.
This commit is contained in:
parent
9f1cc762cd
commit
09c43b8759
204
zml/ops.zig
204
zml/ops.zig
@ -783,44 +783,19 @@ pub fn addHostCallback(
|
||||
output_shapes: []const Shape,
|
||||
opts: HostCallbackOpt,
|
||||
) []Tensor {
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
const backend_config = mlir.Attribute.dict(mlir_ctx, &.{
|
||||
.{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) },
|
||||
.{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) },
|
||||
});
|
||||
|
||||
const values = stdx.stackSlice(8, mlir.Value, inputs.len);
|
||||
for (inputs, values) |i, *v| {
|
||||
v.* = ctx.getValue(i.toMemory(.host_pinned));
|
||||
}
|
||||
const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len);
|
||||
for (res_types, output_shapes) |*r, o| {
|
||||
r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type);
|
||||
}
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
values,
|
||||
return customCall(
|
||||
"zmlHostBufferCallback",
|
||||
inputs,
|
||||
output_shapes,
|
||||
.{
|
||||
.callback = @intFromPtr(callback),
|
||||
.user_context = @intFromPtr(blkctx),
|
||||
},
|
||||
.{
|
||||
.call_target_name = "zmlHostBufferCallback",
|
||||
.api_version = .typed_ffi,
|
||||
.backend_config = backend_config,
|
||||
.has_side_effect = opts.has_side_effect,
|
||||
.output_operand_aliases = opts.output_operand_aliases,
|
||||
},
|
||||
res_types,
|
||||
loc,
|
||||
);
|
||||
|
||||
const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM");
|
||||
for (res, output_shapes, 0..) |*r, o, i| {
|
||||
r.* = Tensor._result(o, op.result(i)).toMemory(.device);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
pub const TritonOps = struct {
|
||||
@ -1248,6 +1223,169 @@ fn scatterPrepareIndices(
|
||||
return Tensor.stack(indices.constSlice(), .last, .coord);
|
||||
}
|
||||
|
||||
fn TensorOrTensorArray(comptime T: type) type {
|
||||
const type_info = @typeInfo(T);
|
||||
return switch (type_info) {
|
||||
.@"struct" => |struct_info| b: {
|
||||
if (T == Tensor) break :b Tensor;
|
||||
if (!struct_info.is_tuple) @compileError("Expected tuple");
|
||||
break :b if (struct_info.fields.len == 1)
|
||||
Tensor
|
||||
else
|
||||
[struct_info.fields.len]Tensor;
|
||||
},
|
||||
.array => |array_info| b: {
|
||||
break :b if (array_info.len == 1)
|
||||
Tensor
|
||||
else
|
||||
[array_info.len]Tensor;
|
||||
},
|
||||
.pointer => |pointer_info| b: {
|
||||
if (pointer_info.size != .slice) @compileError("Expected slice");
|
||||
break :b []Tensor;
|
||||
},
|
||||
else => @compileError("Unsupported type: " ++ @typeName(T)),
|
||||
};
|
||||
}
|
||||
|
||||
pub const CustomCallOptions = struct {
|
||||
has_side_effect: bool,
|
||||
output_operand_aliases: ?[]const i64 = null,
|
||||
};
|
||||
|
||||
pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype, metadata: anytype, opts: CustomCallOptions) TensorOrTensorArray(@TypeOf(outputs)) {
|
||||
// Transform generic inputs to flat slice.
|
||||
const inputs_: []const Tensor = switch (@typeInfo(@TypeOf(inputs))) {
|
||||
.@"struct" => |struct_info| b: {
|
||||
if (@TypeOf(inputs) == Tensor) {
|
||||
break :b &[1]Tensor{inputs};
|
||||
}
|
||||
if (!struct_info.is_tuple) @compileError("Expected tuple");
|
||||
var inputs_: [struct_info.fields.len]Tensor = undefined;
|
||||
meta.collectBuf((struct {
|
||||
pub fn func(t: Tensor) Tensor {
|
||||
return t;
|
||||
}
|
||||
}).func, {}, &inputs, &inputs_);
|
||||
break :b &inputs_;
|
||||
},
|
||||
.array => &inputs,
|
||||
.pointer => |pointer_info| b: {
|
||||
if (pointer_info.size != .slice) @compileError("Expected slice");
|
||||
break :b inputs;
|
||||
},
|
||||
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(inputs))),
|
||||
};
|
||||
|
||||
// Transform generic outputs to flat slice.
|
||||
const output_shapes: []const Shape = switch (@typeInfo(@TypeOf(outputs))) {
|
||||
.@"struct" => |struct_info| b: {
|
||||
if (@TypeOf(outputs) == Shape) {
|
||||
break :b &[1]Shape{outputs};
|
||||
}
|
||||
if (!struct_info.is_tuple) @compileError("Expected tuple");
|
||||
var output_shapes: [struct_info.fields.len]Shape = undefined;
|
||||
meta.collectBuf((struct {
|
||||
pub fn func(t: Shape) Shape {
|
||||
return t;
|
||||
}
|
||||
}).func, {}, &outputs, &output_shapes);
|
||||
break :b &output_shapes;
|
||||
},
|
||||
.array => &outputs,
|
||||
.pointer => |pointer_info| b: {
|
||||
if (pointer_info.size != .slice) @compileError("Expected slice");
|
||||
break :b outputs;
|
||||
},
|
||||
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
|
||||
};
|
||||
|
||||
const outputs_flat = customCallInternal(target_name, inputs_, output_shapes, metadata, opts);
|
||||
|
||||
// Transform flat slice to generic outputs.
|
||||
return switch (@typeInfo(@TypeOf(outputs))) {
|
||||
.@"struct" => |struct_info| b: {
|
||||
if (@TypeOf(outputs) == Shape) break :b outputs_flat[0];
|
||||
if (!struct_info.is_tuple) @compileError("Expected tuple");
|
||||
if (struct_info.fields.len == 1) break :b outputs_flat[0];
|
||||
var outputs_: [struct_info.fields.len]Tensor = undefined;
|
||||
@memcpy(&outputs_, outputs_flat);
|
||||
break :b outputs_;
|
||||
},
|
||||
.array => |array_info| b: {
|
||||
if (array_info.len == 1) break :b outputs_flat[0];
|
||||
var outputs_: [array_info.fields.len]Tensor = undefined;
|
||||
@memcpy(&outputs_, outputs_flat);
|
||||
break :b outputs_;
|
||||
},
|
||||
.pointer => |pointer_info| b: {
|
||||
if (pointer_info.size != .slice) @compileError("Expected slice");
|
||||
break :b outputs_flat;
|
||||
},
|
||||
else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
|
||||
};
|
||||
}
|
||||
|
||||
fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
|
||||
const ctx = module.CompilationContext.current();
|
||||
|
||||
const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
|
||||
ctx.extractValues(inputs, values);
|
||||
|
||||
const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
|
||||
for (outputs, 0..) |output, i| {
|
||||
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
||||
}
|
||||
|
||||
const metadata_type_info = @typeInfo(@TypeOf(metadata));
|
||||
var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
|
||||
inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
|
||||
const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
|
||||
.int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
|
||||
else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
|
||||
};
|
||||
metadata_attributes_tuple[i] = .{ field.name, attribute };
|
||||
}
|
||||
|
||||
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
|
||||
.{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
|
||||
.{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
|
||||
} ++ metadata_attributes_tuple));
|
||||
|
||||
const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
|
||||
for (inputs, 0..) |input, i| {
|
||||
operands_layouts[i] = minorToMajor(input.rank());
|
||||
}
|
||||
|
||||
const results_layouts = ctx.allocator().alloc([]const usize, outputs.len) catch unreachable;
|
||||
for (outputs, 0..) |output, i| {
|
||||
results_layouts[i] = minorToMajor(output.rank());
|
||||
}
|
||||
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
ctx.mlirCtx(),
|
||||
values,
|
||||
.{
|
||||
.call_target_name = target_name,
|
||||
.backend_config = backend_config,
|
||||
.has_side_effect = true,
|
||||
.api_version = .typed_ffi,
|
||||
.operand_layouts = operands_layouts,
|
||||
.result_layouts = results_layouts,
|
||||
.output_operand_aliases = opts.output_operand_aliases orelse &.{},
|
||||
},
|
||||
res_types,
|
||||
ctx.mlirCtx().location(@src()),
|
||||
);
|
||||
|
||||
const outputs_ = ctx.allocator().alloc(Tensor, outputs.len) catch unreachable;
|
||||
for (outputs, 0..) |output, i| {
|
||||
outputs_[i] = Tensor._result(output, op.result(i));
|
||||
}
|
||||
|
||||
return outputs_;
|
||||
}
|
||||
|
||||
inline fn toI64(values: anytype) []i64 {
|
||||
var res: [Tensor.MAX_RANK]i64 = undefined;
|
||||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||||
|
||||
Loading…
Reference in New Issue
Block a user