diff --git a/zml/exe.zig b/zml/exe.zig index 2292b70..3f4b6aa 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -231,6 +231,35 @@ pub const BaseExe = struct { } } + pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void { + const LocalContext = struct { + index: u32, + platform: Platform, + outputs: []const [*]*pjrt.Buffer, + output_shapes: []Shape, + }; + var local_ctx: LocalContext = .{ + .index = 0, + .platform = self.platform, + .outputs = self.output_per_device, + .output_shapes = self.result_shapes, + }; + meta.visit((struct { + fn cb(ctx: *LocalContext, buffer: *Buffer) void { + const i = ctx.index; + ctx.index += 1; + if (i >= ctx.output_shapes.len) return; + + var shards: Buffer.Shards = .{}; + for (ctx.outputs) |buff| { + shards.appendAssumeCapacity(buff[i]); + } + buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice()); + } + }).cb, &local_ctx, result); + stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index }); + } + pub fn serialize(self: BaseExe, writer: anytype) !void { var executable = try self.exe.getExecutable(self.platform.pjrt_api); var serialize_result = try executable.serialize(self.platform.pjrt_api); @@ -305,7 +334,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type { std.debug.assert(total_ready == self.inner.input_buffer_count); self.inner._unsafeCall(); var result: Bufferized(ReturnT) = undefined; - assignRawBuffers(&result, self.inner.platform, self.inner.output_per_device, self.inner.result_shapes); + self.inner._unsafeAssignResults(Bufferized(ReturnT), &result); return result; } }; @@ -349,36 +378,6 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff return context.index; } -/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers. -fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape) void { - const LocalContext = struct { - index: u32, - platform: Platform, - buffers: []const [*]*pjrt.Buffer, - buffer_shapes: []Shape, - }; - var local_ctx: LocalContext = .{ - .index = 0, - .platform = platform, - .buffers = buffers, - .buffer_shapes = buffer_shapes, - }; - meta.visit((struct { - fn cb(ctx: *LocalContext, buffer: *Buffer) void { - const i = ctx.index; - ctx.index += 1; - if (i >= ctx.buffer_shapes.len) return; - - var shards: Buffer.Shards = .{}; - for (ctx.buffers) |buff| { - shards.appendAssumeCapacity(buff[i]); - } - buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); - } - }).cb, &local_ctx, v); - stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); -} - fn prettyFnName( comptime func: anytype, allocator: std.mem.Allocator, diff --git a/zml/meta.zig b/zml/meta.zig index b3dcf68..483b03a 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -1,5 +1,6 @@ const std = @import("std"); const testing = std.testing; +const builtin = @import("builtin"); const stdx = @import("stdx"); const FnParam = stdx.meta.FnParam; @@ -9,6 +10,7 @@ test { std.testing.refAllDecls(@This()); } +/// Visit a given type `T` and replace all fields containing `From` by fields containing `To`. pub fn MapType(From: type, To: type) type { return struct { pub fn map(T: type) type { @@ -78,6 +80,32 @@ pub fn MapType(From: type, To: type) type { }; } +test MapType { + const A = struct { a: u32 }; + const B = struct { b: u32 }; + + const A2B = MapType(A, B); + + const StructA = struct { some: []const A, one: A, maybe: ?A, other: u32 }; + const struct_b = A2B.map(StructA){ + .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } }, + .maybe = null, + .one = .{ .b = 2 }, + .other = 43, + }; + _ = struct_b; + + // TODO(corendos) fixme, union_b should contains Bs not As. + const UnionA = union { some: []const A, one: A, maybe: ?A, other: u32 }; + const union_b = [_]A2B.map(UnionA){ + .{ .some = &[2]A{ .{ .a = 0 }, .{ .a = 1 } } }, + .{ .one = .{ .a = 2 } }, + .{ .maybe = null }, + .{ .other = 43 }, + }; + _ = union_b; +} + /// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct /// and calls the callback when it finds a `From` element, and writes it to the `to` struct. /// The `to` parameter must be passed with mutable pointer, and tensor data need to be mutable if callback needs it. @@ -136,10 +164,10 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam @field(from, field.name), &@field(to, field.name), ); - } else if (field.default_value) |_| { + } else if (field.default_value_ptr) |_| { @field(to, field.name) = null; } else { - stdx.debug.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name }); + stdx.debug.compileError("Mapping {} -> {} inside {} failed. Missing field {s} in {}", .{ From, To, FromStruct, field.name, ToStruct }); }, else => @field(to, field.name) = @field(from, field.name), } @@ -234,6 +262,136 @@ test mapAlloc { try testing.expectEqual(12, bb.static_slice[1].b); } +/// Visit a given type `T` and: +/// * replace all fields containing `From` by fields containing `To` +/// * drop all fields not containing any `From`. +/// The returned type will contains only `To` making it easy for the compiler to produce compact layout. +/// Used by `zml.Bufferized` to strip compile time arguments from a model struct. +pub fn MapRestrict(From: type, To: type) type { + return struct { + pub fn map(T: type) type { + switch (T) { + From => return To, + ?From => return ?To, + *From => return *To, + *const From => return *const To, + []From => return []To, + []const From => return []const To, + else => {}, + } + + if (!Contains(T, From)) return void; + + return switch (@typeInfo(T)) { + .@"struct" => |struct_infos| { + // We know that at least one of the struct field contains a From. + // We map each field individually. Fields without From and comptime fields are removed. + const fields = struct_infos.fields; + var num_fields: usize = 0; + + var struct_fields: [fields.len]std.builtin.Type.StructField = undefined; + for (fields) |field| { + if (!field.is_comptime and Contains(field.type, From)) { + const R = map(field.type); + if (R == field.type) { + struct_fields[num_fields] = field; + } else { + const name = if (struct_infos.is_tuple) struct_infos.fields[num_fields].name else field.name; + struct_fields[num_fields] = .{ + .name = name, + .type = R, + .default_value_ptr = null, + .is_comptime = false, + .alignment = @alignOf(R), + }; + // Handle the case `field: ?Tensor = null` + // Generic handling of default value is not possible. + if (R == ?To) { + struct_fields[num_fields].default_value_ptr = &@as(R, null); + } + } + num_fields += 1; + } + } + if (num_fields == 0) return void; + return @Type(.{ .@"struct" = .{ + .layout = .auto, + .fields = struct_fields[0..num_fields], + .decls = &.{}, + .is_tuple = struct_infos.is_tuple, + } }); + }, + .@"union" => |union_info| { + // We know that at least one of the union field contains a From. + // We map each field individually. Fields without From, are replaced by "void". + const fields = union_info.fields; + var union_fields: [fields.len]std.builtin.Type.UnionField = undefined; + for (0.., fields) |i, field| { + union_fields[i] = .{ + .name = field.name, + .type = map(field.type), + .alignment = 0, + }; + } + return @Type(.{ .@"union" = .{ + .layout = .auto, + .tag_type = union_info.tag_type, + .fields = union_fields[0..], + .decls = &.{}, + } }); + }, + .array => |arr_info| [arr_info.len]map(arr_info.child), + .pointer => |ptr_info| switch (ptr_info.size) { + .slice => if (ptr_info.is_const) + []const map(ptr_info.child) + else + []map(ptr_info.child), + .one => if (ptr_info.is_const) + *const map(ptr_info.child) + else + *map(ptr_info.child), + .many => if (ptr_info.is_const) + [*]const map(ptr_info.child) + else + [*]map(ptr_info.child), + .c => if (ptr_info.is_const) + [*c]map(ptr_info.child) + else + [*c]map(ptr_info.child), + }, + .optional => |opt_info| ?map(opt_info.child), + else => T, + }; + } + }; +} + +test MapRestrict { + const A = struct { a: u32 }; + const B = struct { b: u32 }; + + const A2B = MapRestrict(A, B); + + const StructA = struct { some: []const A, one: A, maybe: ?A, other: u32 }; + const struct_b = A2B.map(StructA){ + .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } }, + .maybe = null, + .one = .{ .b = 2 }, + // Note how struct_b doesn't even have a .other field now. + }; + _ = struct_b; + + const UnionA = union { some: []const A, one: A, maybe: ?A, other: u32 }; + const union_b = [_]A2B.map(UnionA){ + .{ .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } } }, + .{ .one = .{ .b = 2 } }, + .{ .maybe = null }, + // Note how union_b.other is void now. + .{ .other = {} }, + }; + _ = union_b; +} + /// Recursively visit the given struct and calls the callback for each K found. /// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it. pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { @@ -554,7 +712,7 @@ pub fn Contains(Haystack: type, T: type) bool { return switch (@typeInfo(Haystack)) { .@"struct" => |info| { inline for (info.fields) |field| { - if (Contains(field.type, T)) + if (!field.is_comptime and Contains(field.type, T)) return true; } return false; diff --git a/zml/module.zig b/zml/module.zig index 798b81c..9650010 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1011,21 +1011,21 @@ test FnCache { } }; - const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 }); - const nn: zml.Bufferized(NN) = .{ + const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 }); + const nn: zml.testing.BufferizedWithArgs(NN) = .{ .layers = .{ .{ - .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }), - .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }), + .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }), + .b = try .fromArray(platform, [2]f16{ 0, 0 }), }, .{ - .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }), - .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }), + .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }), + .b = try .fromArray(platform, [2]f16{ 10, 10 }), }, // third layer is different .{ - .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }), - .b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }), + .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }), + .b = try .fromArray(platform, [3]f16{ -10, -10, -10 }), }, }, }; @@ -1084,13 +1084,13 @@ test "FnCache with mixed integer/tensor" { } }; - const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 }); - const nn: zml.Bufferized(NN) = .{ + const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 }); + const nn: zml.testing.BufferizedWithArgs(NN) = .{ .layers = .{ - .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) }, - .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) }, + .{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }) }, + .{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }) }, // third layer has different shape - .{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) }, + .{ .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }) }, }, }; const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x }); diff --git a/zml/nn.zig b/zml/nn.zig index c6a8560..ffdb7a9 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -401,7 +401,7 @@ test "real/img" { { const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform); defer mod.deinit(); - const ret = mod.call(.{}); + const ret = mod.call({}); try testing.expectEqual(20, ret.getValue(i32)); } const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{}); @@ -1106,11 +1106,11 @@ test sdpaMemEfficient { const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); defer rng_mask.deinit(); - // Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable. - const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd }); - const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd }); - const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd }); - const mask = rng_mask.call(undefined).withTags(.{ .q, .k }); + // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs. + const q = rng.call({}).withTags(.{ .b, .h, .q, .hd }); + const k = rng.call({}).withTags(.{ .b, .h, .k, .hd }); + const v = rng.call({}).withTags(.{ .b, .h, .k, .hd }); + const mask = rng_mask.call({}).withTags(.{ .q, .k }); const ref_res = try zml.testing.compileAndCall( platform, @@ -1164,11 +1164,11 @@ test "sdpaMemEfficient transposed" { const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); defer rng_mask.deinit(); - // Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable. - const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd }); - const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd }); - const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd }); - const mask = rng_mask.call(undefined).withTags(.{ .q, .k }); + // Note: we pass void here, cause Rng.normal doesn't take any runtime inputs. + const q = rng.call({}).withTags(.{ .b, .q, .h, .hd }); + const k = rng.call({}).withTags(.{ .b, .k, .h, .hd }); + const v = rng.call({}).withTags(.{ .b, .k, .h, .hd }); + const mask = rng_mask.call({}).withTags(.{ .q, .k }); const ref_res = try zml.testing.compileAndCall( platform, @@ -1266,7 +1266,7 @@ test sampleTokens { const logits, const expected: i32 = logits_expected; var logits_buff = try zml.Buffer.fromArray(platform, logits); defer logits_buff.deinit(); - var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff }); + var sampled, rng_buff = mod.call(.{ logits_buff, rng_buff }); defer sampled.deinit(); try zml.testing.expectEqual(expected, try sampled.getValue(i32)); } @@ -1304,7 +1304,6 @@ pub const DynamicSamplingStrategy = struct { opts: Opts, ) !zml.Bufferized(DynamicSamplingStrategy) { return .{ - .max_top_k = 0, .top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32), .temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype), .top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype), diff --git a/zml/ops.zig b/zml/ops.zig index 1765888..da95536 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -108,13 +108,15 @@ test "simple while" { const zml = @import("zml.zig"); const platform = zml.testing.env(); - const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0}); - const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0}); - const counter: zml.Bufferized(CountInts) = .{ - .step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}), - .end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}), - }; - const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum }); + const res0, const res1 = try zml.testing.compileAndCall( + platform, + CountInts._fwd, + .{ + .{ .step = try .scalar(platform, 1, .i64), .end = try .scalar(platform, 10, .i64) }, + try .scalar(platform, 0, .i64), + try .scalar(platform, 0, .i64), + }, + ); const last_i = try res0.getValue(i64); const sum = try res1.getValue(i64); diff --git a/zml/tensor.zig b/zml/tensor.zig index ae06a77..51517ef 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -532,7 +532,7 @@ pub const Tensor = struct { } pub const Rng = struct { - _state: Tensor, + _state: Tensor = .{ ._shape = .init(.{2}, .u64), ._id = .{ .buffer_id = 0 } }, algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT, pub fn shape() ShapeOf(Rng) { @@ -542,10 +542,8 @@ pub const Tensor = struct { } pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) { - const bits: [2]u64 = @bitCast(seed); return .{ - ._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits), - .algorithm = undefined, + ._state = try Buffer.fromBytes(platform, Rng.shape()._state, std.mem.asBytes(&seed)), }; } @@ -643,17 +641,15 @@ pub const Tensor = struct { const platform = zml.testing.env(); // Compute stats over a uniform distribution on [-2, 10]. - const rand, const stats = try zml.testing.compileAndCall( + const rand, const stats = try zml.testing.compileAndCallWithTensors( platform, Stats.uniformStats, - .{ - try Rng.init(platform, 1234), - Shape.init(.{1024}, .f32), - .{ .min = -2, .max = 10 }, - }, + .{ Rng.shape(), zml.Shape.init(.{1024}, .f32), .{ .min = -2, .max = 10 } }, + .{try Rng.init(platform, 1234)}, ); + // Check the Rng state has been modified. - try std.testing.expect(try rand._state.getValue(i128) != 1234); + try std.testing.expect(try rand._state.getValue(u128) != 1234); // Check the mean and variance are close to theoritical values. const mean_ = try stats.mean.getValue(f32); @@ -746,9 +742,12 @@ pub const Tensor = struct { const platform = zml.testing.env(); const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 }; - const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{ - try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform), - }); + const rand, const stats = try zml.testing.compileAndCallWithTensors( + platform, + Stats.gumbelStats, + .{ Rng.shape(), zml.Shape.init(.{tgt_dist.len}, .f32) }, + .{ try Rng.init(platform, 1234), try .fromArray(platform, tgt_dist) }, + ); // Check the Rng state has been modified. try std.testing.expect(try rand._state.getValue(i128) != 1234); @@ -1620,15 +1619,15 @@ pub const Tensor = struct { }; { - const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } }); + const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } }); try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32)); } { - const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } }); + const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } }); try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32)); } { - const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } }); + const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } }); try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32)); } } @@ -3965,13 +3964,12 @@ test "Tensor.maxPool2d" { ); } -/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. pub fn Bufferized(comptime T: type) type { // TODO: we should strip out the non-buffer fields. // Currently it's confusing cause the Bufferized struct contains field that are never read. // Also it will simplify the layout of the Bufferized struct. // accelerating the calls to execute. - return meta.MapType(Tensor, Buffer).map(T); + return meta.MapRestrict(Tensor, Buffer).map(T); } /// Return a clone of a type with Tensors replaced by Shapes. diff --git a/zml/testing.zig b/zml/testing.zig index 8946f78..0dcf394 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -112,13 +112,30 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec return error.TestExpectedEqual; } +/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. +/// This is similar to zml.Bufferized, +/// but also keep every other fields that could be used during compilation. +/// +/// see `compileAndCall`. +pub fn BufferizedWithArgs(comptime T: type) type { + // TODO: we should strip out the non-buffer fields. + // Currently it's confusing cause the Bufferized struct contains field that are never read. + // Also it will simplify the layout of the Bufferized struct. + // accelerating the calls to execute. + return meta.MapType(zml.Tensor, zml.Buffer).map(T); +} + /// Compile a function and immediatly call it with the given buffers. /// The compiled module is discarded after the call. /// Useful during testing when a module is typically called only once. /// /// Note: `func` needs explicit types on all parameters. /// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types. -pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) { +pub fn compileAndCall( + platform: zml.Platform, + func: anytype, + buffer_and_args: BufferizedWithArgs(stdx.meta.FnArgs(func)), +) !zml.Bufferized(stdx.meta.FnResult(func)) { // This simplify test API and also ensure this fn isn't used outside of tests. const allocator = std.testing.allocator; var arena = std.heap.ArenaAllocator.init(allocator); @@ -130,18 +147,30 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu } }; var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined; - try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args); + try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_and_args, &shape_args); - const mod = try zml.compileFn(allocator, func, shape_args, platform); + var mod = try zml.compileFn(allocator, func, shape_args, platform); defer mod.deinit(); - return mod.call(buffer_args); + // Note: we don't use the type safe API of mod, + // cause mod.call expects a `zml.Bufferized` while we have `BufferizedWithArgs`. + mod.inner.prepare(buffer_and_args); + mod.inner._unsafeCall(); + + var result: zml.Bufferized(stdx.meta.FnResult(func)) = undefined; + mod.inner._unsafeAssignResults(@TypeOf(result), &result); + return result; } /// Compile a function and immediatly call it with the given buffers. /// The compiled module is discarded after the call. /// Useful during testing when a module is typically called only once. -pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) { +pub fn compileAndCallWithTensors( + platform: zml.Platform, + func: anytype, + shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), + buffer_args: zml.Bufferized(stdx.meta.FnArgs(func)), +) !zml.Bufferized(stdx.meta.FnResult(func)) { // This simplify test API and also ensure this fn isn't used outside of tests. const allocator = std.testing.allocator; var arena = std.heap.ArenaAllocator.init(allocator); diff --git a/zml/torch.zig b/zml/torch.zig index 94c9693..60f3c19 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -1,8 +1,8 @@ const std = @import("std"); + const stdx = @import("stdx"); const zml = @import("zml.zig"); - const Tensor = zml.Tensor; const log = std.log.scoped(.zml_torch); @@ -141,16 +141,14 @@ test pixelShuffle { const platform = zml.testing.env(); const upscale_factor = 3; - var digits: [9 * 4 * 4]i32 = undefined; - for (&digits, 0..) |*d, i| d.* = @intCast(i); - // TODO should we have tags in buffers ? - const input = try zml.Buffer.fromSlice(platform, .{ 1, 9, 4, 4 }, &digits); - const output = try zml.testing.compileAndCallWithTensors( - platform, - pixelShuffle, - .{ zml.Shape.init(.{ .batch_size = 1, .c = 9, .h = 4, .w = 4 }, .i32), upscale_factor }, - .{ input, upscale_factor }, - ); + const shape = zml.Shape.init(.{ .b = 1, .c = 9, .h = 4, .w = 4 }, .i32); + const input = input: { + var digits: [9 * 4 * 4]i32 = undefined; + for (&digits, 0..) |*d, i| d.* = @intCast(i); + break :input try zml.Buffer.fromSlice(platform, shape, &digits); + }; + + const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor }); const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{ .{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },