Add customCall operation to zml/ops.
commit 09c43b8759 (parent 9f1cc762cd)

zml/ops.zig: 204 lines changed
@@ -783,44 +783,19 @@ pub fn addHostCallback(
     output_shapes: []const Shape,
     opts: HostCallbackOpt,
 ) []Tensor {
-    const ctx = CompilationContext.current();
-    const mlir_ctx = ctx.mlirCtx();
-    const backend_config = mlir.Attribute.dict(mlir_ctx, &.{
-        .{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) },
-        .{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) },
-    });
-
-    const values = stdx.stackSlice(8, mlir.Value, inputs.len);
-    for (inputs, values) |i, *v| {
-        v.* = ctx.getValue(i.toMemory(.host_pinned));
-    }
-    const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len);
-    for (res_types, output_shapes) |*r, o| {
-        r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type);
-    }
-
-    const loc = ctx.mlirCtx().location(@src());
-    const op = dialect.stablehlo.custom_call(
-        ctx.mlirCtx(),
-        values,
-        .{
-            .call_target_name = "zmlHostBufferCallback",
-            .api_version = .typed_ffi,
-            .backend_config = backend_config,
-            .has_side_effect = opts.has_side_effect,
-            .output_operand_aliases = opts.output_operand_aliases,
-        },
-        res_types,
-        loc,
-    );
-
-    const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM");
-    for (res, output_shapes, 0..) |*r, o, i| {
-        r.* = Tensor._result(o, op.result(i)).toMemory(.device);
-    }
-
-    return res;
+    return customCall(
+        "zmlHostBufferCallback",
+        inputs,
+        output_shapes,
+        .{
+            .callback = @intFromPtr(callback),
+            .user_context = @intFromPtr(blkctx),
+        },
+        .{
+            .has_side_effect = opts.has_side_effect,
+            .output_operand_aliases = opts.output_operand_aliases,
+        },
+    );
 }
 
 pub const TritonOps = struct {
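With this hunk, addHostCallback becomes a thin wrapper over the generic customCall added below. For orientation, a minimal usage sketch (not part of the diff; the target name, tensor, and dims are hypothetical, and zml.ops is assumed to be the public path to this file):

    // Hypothetical sketch: single input, single output. "myFfiTarget" is an
    // assumed FFI handler registered with the PJRT plugin; `x` is a Tensor in scope.
    const y: Tensor = zml.ops.customCall(
        "myFfiTarget",
        x, // a lone Tensor is wrapped into a one-element input list
        Shape.init(.{ .d = 128 }, .f32), // a lone Shape yields a single Tensor result
        .{ .tag = @as(u64, 42) }, // integer metadata, packed into backend_config
        .{ .has_side_effect = false },
    );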
@@ -1248,6 +1223,169 @@ fn scatterPrepareIndices(
     return Tensor.stack(indices.constSlice(), .last, .coord);
 }
 
+fn TensorOrTensorArray(comptime T: type) type {
+    const type_info = @typeInfo(T);
+    return switch (type_info) {
+        .@"struct" => |struct_info| b: {
+            if (T == Shape) break :b Tensor;
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            break :b if (struct_info.fields.len == 1)
+                Tensor
+            else
+                [struct_info.fields.len]Tensor;
+        },
+        .array => |array_info| b: {
+            break :b if (array_info.len == 1)
+                Tensor
+            else
+                [array_info.len]Tensor;
+        },
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b []Tensor;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(T)),
+    };
+}
+
+pub const CustomCallOptions = struct {
+    has_side_effect: bool,
+    output_operand_aliases: ?[]const i64 = null,
+};
+
+pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype, metadata: anytype, opts: CustomCallOptions) TensorOrTensorArray(@TypeOf(outputs)) {
+    // Transform generic inputs to flat slice.
+    const inputs_: []const Tensor = switch (@typeInfo(@TypeOf(inputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(inputs) == Tensor) {
+                break :b &[1]Tensor{inputs};
+            }
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            var inputs_: [struct_info.fields.len]Tensor = undefined;
+            meta.collectBuf((struct {
+                pub fn func(t: Tensor) Tensor {
+                    return t;
+                }
+            }).func, {}, &inputs, &inputs_);
+            break :b &inputs_;
+        },
+        .array => &inputs,
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b inputs;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(inputs))),
+    };
+
+    // Transform generic outputs to flat slice.
+    const output_shapes: []const Shape = switch (@typeInfo(@TypeOf(outputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(outputs) == Shape) {
+                break :b &[1]Shape{outputs};
+            }
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            var output_shapes: [struct_info.fields.len]Shape = undefined;
+            meta.collectBuf((struct {
+                pub fn func(t: Shape) Shape {
+                    return t;
+                }
+            }).func, {}, &outputs, &output_shapes);
+            break :b &output_shapes;
+        },
+        .array => &outputs,
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b outputs;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
+    };
+
+    const outputs_flat = customCallInternal(target_name, inputs_, output_shapes, metadata, opts);
+
+    // Transform flat slice to generic outputs.
+    return switch (@typeInfo(@TypeOf(outputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(outputs) == Shape) break :b outputs_flat[0];
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            if (struct_info.fields.len == 1) break :b outputs_flat[0];
+            var outputs_: [struct_info.fields.len]Tensor = undefined;
+            @memcpy(&outputs_, outputs_flat);
+            break :b outputs_;
+        },
+        .array => |array_info| b: {
+            if (array_info.len == 1) break :b outputs_flat[0];
+            var outputs_: [array_info.len]Tensor = undefined;
+            @memcpy(&outputs_, outputs_flat);
+            break :b outputs_;
+        },
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b outputs_flat;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
+    };
+}
+
+fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
+    const ctx = module.CompilationContext.current();
+
+    const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
+    ctx.extractValues(inputs, values);
+
+    const res_types = ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
+    for (outputs, 0..) |output, i| {
+        res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
+    }
+
+    const metadata_type_info = @typeInfo(@TypeOf(metadata));
+    var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
+    inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
+        const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
+            .int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
+            else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
+        };
+        metadata_attributes_tuple[i] = .{ field.name, attribute };
+    }
+
+    const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
+        .{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
+        .{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
+    } ++ metadata_attributes_tuple));
+
+    const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
+    for (inputs, 0..) |input, i| {
+        operands_layouts[i] = minorToMajor(input.rank());
+    }
+
+    const results_layouts = ctx.allocator().alloc([]const usize, outputs.len) catch unreachable;
+    for (outputs, 0..) |output, i| {
+        results_layouts[i] = minorToMajor(output.rank());
+    }
+
+    const op = dialect.stablehlo.custom_call(
+        ctx.mlirCtx(),
+        values,
+        .{
+            .call_target_name = target_name,
+            .backend_config = backend_config,
+            .has_side_effect = opts.has_side_effect,
+            .api_version = .typed_ffi,
+            .operand_layouts = operands_layouts,
+            .result_layouts = results_layouts,
+            .output_operand_aliases = opts.output_operand_aliases orelse &.{},
+        },
+        res_types,
+        ctx.mlirCtx().location(@src()),
+    );
+
+    const outputs_ = ctx.allocator().alloc(Tensor, outputs.len) catch unreachable;
+    for (outputs, 0..) |output, i| {
+        outputs_[i] = Tensor._result(output, op.result(i));
+    }
+
+    return outputs_;
+}
+
 inline fn toI64(values: anytype) []i64 {
     var res: [Tensor.MAX_RANK]i64 = undefined;
     for (values, 0..) |val, i| res[i] = @intCast(val);
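customCall maps the static type of its outputs argument to its return type through TensorOrTensorArray: a single Shape yields Tensor, an n-element tuple or array of Shapes yields [n]Tensor, and a slice yields []Tensor. A sketch with two outputs (the target name is hypothetical; `a` and `b` are assumed Tensors in scope):

    // Hypothetical two-output call; the tuple of Shapes makes the return type [2]Tensor.
    const res: [2]Tensor = zml.ops.customCall(
        "myTwoOutputTarget",
        .{ a, b }, // input tuple, flattened to a slice internally
        .{ a.shape(), b.shape() }, // output shapes
        .{}, // no metadata beyond the implicit pjrt_api/pjrt_client handles
        .{ .has_side_effect = true },
    );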
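customCallInternal packs each integer field of the metadata struct as a u64 attribute into the op's backend_config dictionary, next to the implicit pjrt_api and pjrt_client handles. A self-contained sketch of that comptime technique (a simplification, not the zml API: mlir.AttrTuple is stood in by a plain (name, u64) pair so the snippet runs on its own):

    const std = @import("std");

    const Attr = struct { []const u8, u64 };

    // Mirrors the inline-for in customCallInternal: turn an anonymous struct of
    // integers into (name, value) pairs at comptime.
    fn packMetadata(metadata: anytype) [@typeInfo(@TypeOf(metadata)).@"struct".fields.len]Attr {
        const fields = @typeInfo(@TypeOf(metadata)).@"struct".fields;
        var out: [fields.len]Attr = undefined;
        inline for (fields, 0..) |field, i| {
            out[i] = .{ field.name, @intCast(@field(metadata, field.name)) };
        }
        return out;
    }

    test packMetadata {
        const attrs = packMetadata(.{ .callback = @as(u64, 0xdead), .user_context = @as(u64, 0xbeef) });
        try std.testing.expectEqualStrings("callback", attrs[0][0]);
        try std.testing.expectEqual(@as(u64, 0xdead), attrs[0][1]);
    }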