From 09c43b8759f51aed525d7759866bf3522381d32a Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 9 Jan 2025 15:01:33 +0000 Subject: [PATCH] Add customCall operation to zml/ops. --- zml/ops.zig | 204 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 171 insertions(+), 33 deletions(-) diff --git a/zml/ops.zig b/zml/ops.zig index da95536..11a76cd 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -783,44 +783,19 @@ pub fn addHostCallback( output_shapes: []const Shape, opts: HostCallbackOpt, ) []Tensor { - const ctx = CompilationContext.current(); - - const mlir_ctx = ctx.mlirCtx(); - const backend_config = mlir.Attribute.dict(mlir_ctx, &.{ - .{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) }, - .{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) }, - }); - - const values = stdx.stackSlice(8, mlir.Value, inputs.len); - for (inputs, values) |i, *v| { - v.* = ctx.getValue(i.toMemory(.host_pinned)); - } - const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len); - for (res_types, output_shapes) |*r, o| { - r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type); - } - - const loc = ctx.mlirCtx().location(@src()); - const op = dialect.stablehlo.custom_call( - ctx.mlirCtx(), - values, + return customCall( + "zmlHostBufferCallback", + inputs, + output_shapes, + .{ + .callback = @intFromPtr(callback), + .user_context = @intFromPtr(blkctx), + }, .{ - .call_target_name = "zmlHostBufferCallback", - .api_version = .typed_ffi, - .backend_config = backend_config, .has_side_effect = opts.has_side_effect, .output_operand_aliases = opts.output_operand_aliases, }, - res_types, - loc, ); - - const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM"); - for (res, output_shapes, 0..) 
|*r, o, i| {
-        r.* = Tensor._result(o, op.result(i)).toMemory(.device);
-    }
-
-    return res;
 }
 
 pub const TritonOps = struct {
@@ -1248,6 +1223,169 @@ fn scatterPrepareIndices(
     return Tensor.stack(indices.constSlice(), .last, .coord);
 }
 
+fn TensorOrTensorArray(comptime T: type) type {
+    const type_info = @typeInfo(T);
+    return switch (type_info) {
+        .@"struct" => |struct_info| b: {
+            if (T == Shape) break :b Tensor;
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            break :b if (struct_info.fields.len == 1)
+                Tensor
+            else
+                [struct_info.fields.len]Tensor;
+        },
+        .array => |array_info| b: {
+            break :b if (array_info.len == 1)
+                Tensor
+            else
+                [array_info.len]Tensor;
+        },
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b []Tensor;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(T)),
+    };
+}
+
+pub const CustomCallOptions = struct {
+    has_side_effect: bool,
+    output_operand_aliases: ?[]const i64 = null,
+};
+
+pub fn customCall(target_name: [:0]const u8, inputs: anytype, outputs: anytype, metadata: anytype, opts: CustomCallOptions) TensorOrTensorArray(@TypeOf(outputs)) {
+    // Transform generic inputs to flat slice.
+    const inputs_: []const Tensor = switch (@typeInfo(@TypeOf(inputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(inputs) == Tensor) {
+                break :b &[1]Tensor{inputs};
+            }
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            var inputs_: [struct_info.fields.len]Tensor = undefined;
+            meta.collectBuf((struct {
+                pub fn func(t: Tensor) Tensor {
+                    return t;
+                }
+            }).func, {}, &inputs, &inputs_);
+            break :b &inputs_;
+        },
+        .array => &inputs,
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b inputs;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(inputs))),
+    };
+
+    // Transform generic outputs to flat slice.
+    const output_shapes: []const Shape = switch (@typeInfo(@TypeOf(outputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(outputs) == Shape) {
+                break :b &[1]Shape{outputs};
+            }
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            var output_shapes: [struct_info.fields.len]Shape = undefined;
+            meta.collectBuf((struct {
+                pub fn func(t: Shape) Shape {
+                    return t;
+                }
+            }).func, {}, &outputs, &output_shapes);
+            break :b &output_shapes;
+        },
+        .array => &outputs,
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b outputs;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
+    };
+
+    const outputs_flat = customCallInternal(target_name, inputs_, output_shapes, metadata, opts);
+
+    // Transform flat slice to generic outputs.
+    return switch (@typeInfo(@TypeOf(outputs))) {
+        .@"struct" => |struct_info| b: {
+            if (@TypeOf(outputs) == Shape) break :b outputs_flat[0];
+            if (!struct_info.is_tuple) @compileError("Expected tuple");
+            if (struct_info.fields.len == 1) break :b outputs_flat[0];
+            var outputs_: [struct_info.fields.len]Tensor = undefined;
+            @memcpy(&outputs_, outputs_flat);
+            break :b outputs_;
+        },
+        .array => |array_info| b: {
+            if (array_info.len == 1) break :b outputs_flat[0];
+            var outputs_: [array_info.len]Tensor = undefined;
+            @memcpy(&outputs_, outputs_flat);
+            break :b outputs_;
+        },
+        .pointer => |pointer_info| b: {
+            if (pointer_info.size != .slice) @compileError("Expected slice");
+            break :b outputs_flat;
+        },
+        else => @compileError("Unsupported type: " ++ @typeName(@TypeOf(outputs))),
+    };
+}
+
+fn customCallInternal(target_name: [:0]const u8, inputs: []const Tensor, outputs: []const Shape, metadata: anytype, opts: CustomCallOptions) []Tensor {
+    const ctx = module.CompilationContext.current();
+
+    const values = ctx.allocator().alloc(mlir.Value, inputs.len) catch unreachable;
+    ctx.extractValues(inputs, values);
+
+    const res_types =
ctx.allocator().alloc(mlir.Type, outputs.len) catch unreachable;
+    for (outputs, 0..) |output, i| {
+        res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
+    }
+
+    const metadata_type_info = @typeInfo(@TypeOf(metadata));
+    var metadata_attributes_tuple: [metadata_type_info.@"struct".fields.len]mlir.AttrTuple = undefined;
+    inline for (metadata_type_info.@"struct".fields, 0..) |field, i| {
+        const attribute: mlir.Attribute = switch (@typeInfo(field.type)) {
+            .int, .comptime_int => mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@field(metadata, field.name))),
+            else => @compileError("Unsupported metadata type: " ++ @typeName(field.type)),
+        };
+        metadata_attributes_tuple[i] = .{ field.name, attribute };
+    }
+
+    const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &(.{
+        .{ "pjrt_api", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_api))) },
+        .{ "pjrt_client", mlir.Attribute.int(ctx.mlirCtx(), .u64, @bitCast(@intFromPtr(ctx._platform.pjrt_client))) },
+    } ++ metadata_attributes_tuple));
+
+    const operands_layouts = ctx.allocator().alloc([]const usize, inputs.len) catch unreachable;
+    for (inputs, 0..) |input, i| {
+        operands_layouts[i] = minorToMajor(input.rank());
+    }
+
+    const results_layouts = ctx.allocator().alloc([]const usize, outputs.len) catch unreachable;
+    for (outputs, 0..) |output, i| {
+        results_layouts[i] = minorToMajor(output.rank());
+    }
+
+    const op = dialect.stablehlo.custom_call(
+        ctx.mlirCtx(),
+        values,
+        .{
+            .call_target_name = target_name,
+            .backend_config = backend_config,
+            .has_side_effect = opts.has_side_effect,
+            .api_version = .typed_ffi,
+            .operand_layouts = operands_layouts,
+            .result_layouts = results_layouts,
+            .output_operand_aliases = opts.output_operand_aliases orelse &.{},
+        },
+        res_types,
+        ctx.mlirCtx().location(@src()),
+    );
+
+    const outputs_ = ctx.allocator().alloc(Tensor, outputs.len) catch unreachable;
+    for (outputs, 0..)
|output, i| { + outputs_[i] = Tensor._result(output, op.result(i)); + } + + return outputs_; +} + inline fn toI64(values: anytype) []i64 { var res: [Tensor.MAX_RANK]i64 = undefined; for (values, 0..) |val, i| res[i] = @intCast(val);