Fix testLayer by removing unnecessary compile_options argument and updating testing logic for new sharded output, ensuring proper usage by llama.zig.

This commit is contained in:
Tarry Singh 2023-03-31 14:23:45 +00:00
parent 05d23beb23
commit 66881899ca
3 changed files with 79 additions and 9 deletions

View File

@ -276,7 +276,7 @@ fn _populateStruct(
}; };
} }
switch (type_info) { return switch (type_info) {
.Pointer => |ptr_info| { .Pointer => |ptr_info| {
if (ptr_info.size == .Slice) { if (ptr_info.size == .Slice) {
obj.* = &.{}; obj.* = &.{};
@ -288,7 +288,6 @@ fn _populateStruct(
for (obj.*, 0..) |*value, i| { for (obj.*, 0..) |*value, i| {
try prefix_builder.pushDigit(allocator, i); try prefix_builder.pushDigit(allocator, i);
defer prefix_builder.pop(); defer prefix_builder.pop();
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
if (!found) { if (!found) {
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) }); std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) });
@ -307,12 +306,12 @@ fn _populateStruct(
.Struct => |struct_info| { .Struct => |struct_info| {
var partial_struct = false; var partial_struct = false;
inline for (struct_info.fields) |field| { inline for (struct_info.fields) |field| {
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
try prefix_builder.push(allocator, field.name); try prefix_builder.push(allocator, field.name);
defer prefix_builder.pop(); defer prefix_builder.pop();
var has_default = false; var has_default = false;
if (field.default_value) |_| has_default = true; if (field.default_value) |_| has_default = true;
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default); const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
partial_struct = partial_struct or field_found; partial_struct = partial_struct or field_found;
if (!field_found) { if (!field_found) {
@ -345,11 +344,64 @@ fn _populateStruct(
obj.* = undefined; obj.* = undefined;
return true; return true;
}, },
.Void => true,
else => if (required) { else => if (required) {
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) }); std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
return error.UnsupportedMetadataType; return error.UnsupportedMetadataType;
} else return false, } else return false,
};
}
test populateModel {
    // Exercises plain tensors, nested structs, slices, slices of structs,
    // tuples (with a `void` member) and optional tensors.
    const Model = struct {
        a: zml.Tensor,
        b: struct { a: zml.Tensor, b: u32 },
        c: []zml.Tensor,
        d: []struct { a: zml.Tensor, b: u32 },
        e: struct { zml.Tensor, u32, struct { a: u32, b: zml.Tensor, c: void } },
        f: ?zml.Tensor,
        g: ?zml.Tensor,

        /// Builds a placeholder HostBuffer whose only dimension is `n`,
        /// so each buffer can be recognized later through `dim(0)`.
        fn _newHostBuffer(n: u32) zml.HostBuffer {
            return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined };
        }
    };

    var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena_state.deinit();

    var store: BufferStore = .{ .arena = arena_state };
    try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);

    // Store content: key -> identifying dimension.
    // There is deliberately no "g" entry, and one entry that matches no field.
    const Entry = struct { key: []const u8, id: u32 };
    const entries = [_]Entry{
        .{ .key = "a", .id = 10 },
        .{ .key = "b.a", .id = 20 },
        .{ .key = "c.0", .id = 30 },
        .{ .key = "c.1", .id = 31 },
        .{ .key = "c.2", .id = 32 },
        .{ .key = "d.0.a", .id = 40 },
        .{ .key = "d.1.a", .id = 41 },
        .{ .key = "d.2.a", .id = 42 },
        .{ .key = "e.0", .id = 50 },
        .{ .key = "e.2.b", .id = 51 },
        .{ .key = "f", .id = 60 },
        .{ .key = "unused_entry", .id = 1000 },
    };
    for (entries) |entry| {
        store.buffers.putAssumeCapacity(entry.key, Model._newHostBuffer(entry.id));
    }

    const model = try populateModel(Model, arena_state.allocator(), store);

    // Direct and nested tensor fields.
    try std.testing.expectEqual(10, model.a.dim(0));
    try std.testing.expectEqual(20, model.b.a.dim(0));

    // Slice of tensors.
    try std.testing.expectEqual(3, model.c.len);
    try std.testing.expectEqual(30, model.c[0].dim(0));
    try std.testing.expectEqual(31, model.c[1].dim(0));
    try std.testing.expectEqual(32, model.c[2].dim(0));

    // Slice of structs.
    try std.testing.expectEqual(3, model.d.len);
    try std.testing.expectEqual(40, model.d[0].a.dim(0));
    try std.testing.expectEqual(41, model.d[1].a.dim(0));
    try std.testing.expectEqual(42, model.d[2].a.dim(0));

    // Tuple members.
    try std.testing.expectEqual(50, model.e[0].dim(0));
    try std.testing.expectEqual(51, model.e[2].b.dim(0));

    // Optional tensors: `f` has a store entry, `g` does not.
    try std.testing.expectEqual(60, model.f.?.dim(0));
    try std.testing.expectEqual(null, model.g);
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about /// Creates a bufferized version of a Model from the given BufferStore. For details about
@ -476,6 +528,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
}, },
.Struct => |struct_info| { .Struct => |struct_info| {
inline for (struct_info.fields) |field| { inline for (struct_info.fields) |field| {
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
try prefix_builder.push(allocator, field.name); try prefix_builder.push(allocator, field.name);
defer prefix_builder.pop(); defer prefix_builder.pop();

View File

@ -847,6 +847,16 @@ pub fn ExeWithWeights(comptime func: anytype) type {
return self.inner.platform; return self.inner.platform;
} }
/// Returns output number `i` as a single `Buffer`.
///
/// Entry `i` of every element of `output_per_device` is collected into a
/// `Buffer.Shards`, then wrapped with `Buffer.fromPjrtBuffers` using the
/// shape stored in `inner.result_buffer_shapes[i]`.
pub fn getOutputBuffer(self: Self, i: usize) Buffer {
    var per_device: Buffer.Shards = .{};
    for (self.output_per_device) |device_outputs| per_device.appendAssumeCapacity(device_outputs[i]);

    return Buffer.fromPjrtBuffers(
        self.platform(),
        self.inner.result_buffer_shapes[i],
        per_device.constSlice(),
    );
}
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) { pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count); fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
var event: [1]*pjrt.Event = undefined; var event: [1]*pjrt.Event = undefined;

View File

@ -145,8 +145,15 @@ pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_ar
return mod.call(buffer_args); return mod.call(buffer_args);
} }
/// Convenience wrapper around `testLayerOut` that uses `name ++ ".out"` as the
/// output name; all other arguments are forwarded unchanged.
pub fn testLayer(
    platform: zml.Platform,
    activations: zml.aio.BufferStore,
    comptime name: []const u8,
    layer: anytype,
    layer_weights: zml.Bufferized(@TypeOf(layer)),
    tolerance: f32,
) !void {
    // `name` is comptime, so the concatenation is resolved at compile time.
    const out_name = name ++ ".out";
    return testLayerOut(platform, activations, name, out_name, layer, layer_weights, tolerance);
}
pub fn testLayerOut( pub fn testLayerOut(
@ -214,18 +221,18 @@ pub fn testLayerOut(
try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32); try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32);
fetch_ctx.prefix.appendSliceAssumeCapacity(name ++ ".in."); fetch_ctx.prefix.appendSliceAssumeCapacity(name ++ ".in.");
try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers); try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers);
defer zml.aio.unloadBuffers(input_buffers); defer zml.aio.unloadBuffers(&input_buffers);
_ = mod.call(input_buffers); _ = mod.call(input_buffers);
} }
var buf: [1024]u8 = undefined; var buf: [1024]u8 = undefined;
for (mod.output_buffers, 0..) |out, i| { for (0..mod.inner.result_buffer_count) |i| {
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable; const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
const expected_out = activations.get(full_name) orelse { const expected_out = activations.get(full_name) orelse {
log.warn("Output buffer not found: {s}", .{full_name}); log.warn("Output buffer not found: {s}", .{full_name});
continue; continue;
}; };
zml.testing.expectClose(expected_out, zml.Buffer.fromPjrtBuffer(platform, out), tolerance) catch |err| { zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| {
log.err("{s}.{d} doesn't match !", .{ out_name, i }); log.err("{s}.{d} doesn't match !", .{ out_name, i });
return err; return err;
}; };