diff --git a/zml/context.zig b/zml/context.zig index 0255d49..94308ce 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -6,7 +6,7 @@ const runtimes = @import("runtimes"); const std = @import("std"); const stdx = @import("stdx"); -const platform = @import("platform.zig"); +const zml_platform = @import("platform.zig"); const pjrt = @import("pjrtx.zig"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; @@ -18,6 +18,10 @@ const Target = @import("platform.zig").Target; const available_targets = @import("platform.zig").available_targets; const log = std.log.scoped(.@"zml/context"); +test { + std.testing.refAllDecls(Context); +} + /// Every program using ZML must start with a `zml.Context.init(.{});` /// The ZML context contains global state to interact with the different /// devices available on your system. @@ -149,7 +153,7 @@ pub const Context = struct { return platform_ orelse @panic("No platform found !"); } - pub fn printAvailablePlatforms(self: Context, selected: platform.Platform) void { + pub fn printAvailablePlatforms(self: Context, selected: Platform) void { // List available targets log.info("Available Platforms:", .{}); const selected_prefix = "✅"; @@ -157,7 +161,7 @@ pub const Context = struct { const selected_postfix = "(AUTO-SELECTED)"; const not_selected_postfix = ""; - for (platform.available_targets) |target| { + for (zml_platform.available_targets) |target| { log.info(" {s} {s} {s}", .{ if (target == selected.target) selected_prefix else not_selected_prefix, @tagName(target), diff --git a/zml/exe.zig b/zml/exe.zig index 12a3fec..b34d5e9 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -241,8 +241,7 @@ pub const BaseExe = struct { shards.appendAssumeCapacity(dev_out[i]); } - const out_shape = self.inner.result_buffer_shapes[i]; - return Buffer.fromPjrtBuffers(self.platform(), out_shape, shards.constSlice()); + return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice()); } }; diff --git a/zml/meta.zig b/zml/meta.zig index 8a08792..2e62605 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -120,9 +120,11 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam return; } + if (@sizeOf(ToStruct) == 0) return; + switch (type_info_to) { .Struct => |info| inline for (info.fields) |field| { - // if (field.is_comptime) continue; + if (field.is_comptime or @sizeOf(field.type) == 0) continue; const field_type_info = @typeInfo(field.type); // If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field. switch (field_type_info) { @@ -187,6 +189,8 @@ test mapAlloc { } }; + const Empty = struct {}; + const AA = struct { field: A, array: [2]A, @@ -195,6 +199,7 @@ test mapAlloc { // We want to allow conversion from comptime to runtime, because Zig type inference works like this. comptime static_val: u8 = 8, comptime static_slice: [2]A = .{ .{ .a = 11 }, .{ .a = 12 } }, + field_with_empty: struct { A, Empty }, }; const BB = struct { field: B, @@ -203,6 +208,7 @@ test mapAlloc { other: u8, static_val: u8, static_slice: []B, + field_with_empty: struct { B, Empty }, }; const aa: AA = .{ @@ -210,6 +216,7 @@ test mapAlloc { .array = .{ .{ .a = 5 }, .{ .a = 6 } }, .other = 7, .slice = &.{ .{ .a = 9 }, .{ .a = 10 } }, + .field_with_empty = .{ .{ .a = 9 }, .{} }, }; var bb: BB = undefined; diff --git a/zml/testing.zig b/zml/testing.zig index eebc2ce..e5626d2 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -201,8 +201,8 @@ pub fn testLayerOut( const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform); const n_out_exp = activations.countLayers(out_name); - if (exe.inner.result_buffer_count != n_out_exp) { - log.warn("Reference models produces {d} outputs, but implementation produces {d}", .{ n_out_exp, exe.inner.result_buffer_count }); + if (exe.inner.result_shapes.len != n_out_exp) { + log.warn("Reference models produces {d} outputs, but implementation produces {d}", .{ n_out_exp, exe.inner.result_shapes.len }); } const mod = exe.prepare(layer_weights); @@ -243,13 +243,13 @@ pub fn testLayerOut( var buf: [1024]u8 = undefined; var failed: bool = false; - for (0..mod.inner.result_buffer_count) |i| { + for (0..mod.inner.result_shapes.len) |i| { const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable; const expected_out = activations.get(full_name) orelse { log.warn("Output buffer not found: {s}", .{full_name}); continue; }; - zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| switch (err) { + zml.testing.expectClose(expected_out, mod.inner.getOutputBuffer(i), tolerance) catch |err| switch (err) { error.TestUnexpectedResult => { log.err("{s}.{d} doesn't match !", .{ out_name, i }); failed = true; @@ -263,6 +263,34 @@ pub fn testLayerOut( log.info("all good for {s} !", .{name}); } +test testLayer { + const platform = env(); + + // create a model + const layer: zml.nn.Linear = .{ + .weight = zml.Tensor{ ._shape = zml.Shape.init(.{ 5, 2 }, .f32), ._id = .{ .buffer_id = 42 } }, + }; + const layer_weights: zml.Bufferized(zml.nn.Linear) = .{ + .weight = try zml.Buffer.fromArray( + platform, + [5][2]f32{ .{ 0, 0 }, .{ 0, 1 }, .{ 1, 2 }, .{ -1, -1 }, .{ -1, 0 } }, + ), + }; + + // create a buffer store containing the activations: + var activations = try zml.aio.BufferStore.init(std.testing.allocator, &.{}); + defer activations.deinit(); + { + const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 }); + try activations.buffers.put(activations.arena.allocator(), "model.layer.in.0", input); + const output = zml.HostBuffer.fromArray(&[5]f32{ 0, -1, -1, 0, -1 }); + try activations.buffers.put(activations.arena.allocator(), "model.layer.out.0", output); + } + + // test the ZML layer reproduces the "captured" activations: + try zml.testing.testLayer(platform, activations, "model.layer", layer, layer_weights, 1e-5); +} + pub inline fn expectEqual(expected: anytype, actual: @TypeOf(expected)) !void { return std.testing.expectEqual(expected, actual); } diff --git a/zml/zml.zig b/zml/zml.zig index 2f5b799..292708e 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -36,6 +36,7 @@ pub const compileFn = exe.compileFn; pub const compileModel = exe.compileModel; pub const FnExe = exe.FnExe; pub const ModuleExe = exe.ModuleExe; +pub const ModuleSignature = exe.ModuleSignature; pub const ops = @import("ops.zig"); pub const tools = struct {