diff --git a/zml/aio.zig b/zml/aio.zig index 3a5ca4d..61d928d 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -276,7 +276,7 @@ fn _populateStruct( }; } - switch (type_info) { + return switch (type_info) { .Pointer => |ptr_info| { if (ptr_info.size == .Slice) { obj.* = &.{}; @@ -288,7 +288,6 @@ fn _populateStruct( for (obj.*, 0..) |*value, i| { try prefix_builder.pushDigit(allocator, i); defer prefix_builder.pop(); - const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); if (!found) { std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) }); @@ -307,12 +306,12 @@ fn _populateStruct( .Struct => |struct_info| { var partial_struct = false; inline for (struct_info.fields) |field| { + if (field.is_comptime or @sizeOf(field.type) == 0) continue; try prefix_builder.push(allocator, field.name); defer prefix_builder.pop(); var has_default = false; if (field.default_value) |_| has_default = true; - const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default); partial_struct = partial_struct or field_found; if (!field_found) { @@ -345,11 +344,64 @@ fn _populateStruct( obj.* = undefined; return true; }, + .Void => true, else => if (required) { std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) }); return error.UnsupportedMetadataType; } else return false, - } + }; +} + +test populateModel { + const Model = struct { + a: zml.Tensor, + b: struct { a: zml.Tensor, b: u32 }, + c: []zml.Tensor, + d: []struct { a: zml.Tensor, b: u32 }, + e: struct { zml.Tensor, u32, struct { a: u32, b: zml.Tensor, c: void } }, + f: ?zml.Tensor, + g: ?zml.Tensor, + + // Create a fake HostBuffer, we use the given integer to identify the created buffer. + fn _newHostBuffer(n: u32) zml.HostBuffer { + return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined }; + } + }; + + var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena_state.deinit(); + var store: BufferStore = .{ .arena = arena_state }; + try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16); + store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10)); + store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20)); + store.buffers.putAssumeCapacity("c.0", Model._newHostBuffer(30)); + store.buffers.putAssumeCapacity("c.1", Model._newHostBuffer(31)); + store.buffers.putAssumeCapacity("c.2", Model._newHostBuffer(32)); + store.buffers.putAssumeCapacity("d.0.a", Model._newHostBuffer(40)); + store.buffers.putAssumeCapacity("d.1.a", Model._newHostBuffer(41)); + store.buffers.putAssumeCapacity("d.2.a", Model._newHostBuffer(42)); + store.buffers.putAssumeCapacity("e.0", Model._newHostBuffer(50)); + store.buffers.putAssumeCapacity("e.2.b", Model._newHostBuffer(51)); + store.buffers.putAssumeCapacity("f", Model._newHostBuffer(60)); + // no entry for g. + store.buffers.putAssumeCapacity("unused_entry", Model._newHostBuffer(1000)); + + const model = try populateModel(Model, arena_state.allocator(), store); + + try std.testing.expectEqual(10, model.a.dim(0)); + try std.testing.expectEqual(20, model.b.a.dim(0)); + try std.testing.expectEqual(3, model.c.len); + try std.testing.expectEqual(30, model.c[0].dim(0)); + try std.testing.expectEqual(31, model.c[1].dim(0)); + try std.testing.expectEqual(32, model.c[2].dim(0)); + try std.testing.expectEqual(3, model.d.len); + try std.testing.expectEqual(40, model.d[0].a.dim(0)); + try std.testing.expectEqual(41, model.d[1].a.dim(0)); + try std.testing.expectEqual(42, model.d[2].a.dim(0)); + try std.testing.expectEqual(50, model.e[0].dim(0)); + try std.testing.expectEqual(51, model.e[2].b.dim(0)); + try std.testing.expectEqual(60, model.f.?.dim(0)); + try std.testing.expectEqual(null, model.g); } /// Creates a bufferized version of a Model from the given BufferStore. For details about @@ -476,6 +528,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi }, .Struct => |struct_info| { inline for (struct_info.fields) |field| { + if (field.is_comptime or @sizeOf(field.type) == 0) continue; try prefix_builder.push(allocator, field.name); defer prefix_builder.pop(); diff --git a/zml/module.zig b/zml/module.zig index 69f2c95..75fd31f 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -847,6 +847,16 @@ pub fn ExeWithWeights(comptime func: anytype) type { return self.inner.platform; } + pub fn getOutputBuffer(self: Self, i: usize) Buffer { + var shards: Buffer.Shards = .{}; + for (self.output_per_device) |dev_out| { + shards.appendAssumeCapacity(dev_out[i]); + } + + const out_shape = self.inner.result_buffer_shapes[i]; + return Buffer.fromPjrtBuffers(self.platform(), out_shape, shards.constSlice()); + } + pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); var event: [1]*pjrt.Event = undefined; diff --git a/zml/testing.zig b/zml/testing.zig index df3c04e..5a3a3bc 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -145,8 +145,15 @@ pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_ar return mod.call(buffer_args); } -pub fn testLayer(platform: zml.Platform, buffer_store: zml.aio.BufferStore, comptime name: []const u8, layer: anytype, layer_weights: zml.Bufferized(@TypeOf(layer)), tolerance: f32) !void { - try testLayerOut(platform, buffer_store, name, name ++ ".out", layer, layer_weights, tolerance); +pub fn testLayer( + platform: zml.Platform, + activations: zml.aio.BufferStore, + comptime name: []const u8, + layer: anytype, + layer_weights: zml.Bufferized(@TypeOf(layer)), + tolerance: f32, +) !void { + try testLayerOut(platform, activations, name, name ++ ".out", layer, layer_weights, tolerance); } pub fn testLayerOut( @@ -214,18 +221,18 @@ pub fn testLayerOut( try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32); fetch_ctx.prefix.appendSliceAssumeCapacity(name ++ ".in."); try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers); - defer zml.aio.unloadBuffers(input_buffers); + defer zml.aio.unloadBuffers(&input_buffers); _ = mod.call(input_buffers); } var buf: [1024]u8 = undefined; - for (mod.output_buffers, 0..) |out, i| { + for (0..mod.inner.result_buffer_count) |i| { const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable; const expected_out = activations.get(full_name) orelse { log.warn("Output buffer not found: {s}", .{full_name}); continue; }; - zml.testing.expectClose(expected_out, zml.Buffer.fromPjrtBuffer(platform, out), tolerance) catch |err| { + zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| { log.err("{s}.{d} doesn't match !", .{ out_name, i }); return err; };