Fix testLayer by removing unnecessary compile_options argument and updating testing logic for new sharded output, ensuring proper usage by llama.zig.
This commit is contained in:
parent
05d23beb23
commit
66881899ca
61
zml/aio.zig
61
zml/aio.zig
@ -276,7 +276,7 @@ fn _populateStruct(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (type_info) {
|
return switch (type_info) {
|
||||||
.Pointer => |ptr_info| {
|
.Pointer => |ptr_info| {
|
||||||
if (ptr_info.size == .Slice) {
|
if (ptr_info.size == .Slice) {
|
||||||
obj.* = &.{};
|
obj.* = &.{};
|
||||||
@ -288,7 +288,6 @@ fn _populateStruct(
|
|||||||
for (obj.*, 0..) |*value, i| {
|
for (obj.*, 0..) |*value, i| {
|
||||||
try prefix_builder.pushDigit(allocator, i);
|
try prefix_builder.pushDigit(allocator, i);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
||||||
if (!found) {
|
if (!found) {
|
||||||
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) });
|
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) });
|
||||||
@ -307,12 +306,12 @@ fn _populateStruct(
|
|||||||
.Struct => |struct_info| {
|
.Struct => |struct_info| {
|
||||||
var partial_struct = false;
|
var partial_struct = false;
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
|
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
var has_default = false;
|
var has_default = false;
|
||||||
if (field.default_value) |_| has_default = true;
|
if (field.default_value) |_| has_default = true;
|
||||||
|
|
||||||
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
|
const field_found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, &@field(obj, field.name), required and !has_default);
|
||||||
partial_struct = partial_struct or field_found;
|
partial_struct = partial_struct or field_found;
|
||||||
if (!field_found) {
|
if (!field_found) {
|
||||||
@ -345,11 +344,64 @@ fn _populateStruct(
|
|||||||
obj.* = undefined;
|
obj.* = undefined;
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
|
.Void => true,
|
||||||
else => if (required) {
|
else => if (required) {
|
||||||
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
||||||
return error.UnsupportedMetadataType;
|
return error.UnsupportedMetadataType;
|
||||||
} else return false,
|
} else return false,
|
||||||
}
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
test populateModel {
|
||||||
|
const Model = struct {
|
||||||
|
a: zml.Tensor,
|
||||||
|
b: struct { a: zml.Tensor, b: u32 },
|
||||||
|
c: []zml.Tensor,
|
||||||
|
d: []struct { a: zml.Tensor, b: u32 },
|
||||||
|
e: struct { zml.Tensor, u32, struct { a: u32, b: zml.Tensor, c: void } },
|
||||||
|
f: ?zml.Tensor,
|
||||||
|
g: ?zml.Tensor,
|
||||||
|
|
||||||
|
// Create a fake HostBuffer, we use the given integer to identify the created buffer.
|
||||||
|
fn _newHostBuffer(n: u32) zml.HostBuffer {
|
||||||
|
return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined };
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
|
defer arena_state.deinit();
|
||||||
|
var store: BufferStore = .{ .arena = arena_state };
|
||||||
|
try store.buffers.ensureUnusedCapacity(arena_state.allocator(), 16);
|
||||||
|
store.buffers.putAssumeCapacity("a", Model._newHostBuffer(10));
|
||||||
|
store.buffers.putAssumeCapacity("b.a", Model._newHostBuffer(20));
|
||||||
|
store.buffers.putAssumeCapacity("c.0", Model._newHostBuffer(30));
|
||||||
|
store.buffers.putAssumeCapacity("c.1", Model._newHostBuffer(31));
|
||||||
|
store.buffers.putAssumeCapacity("c.2", Model._newHostBuffer(32));
|
||||||
|
store.buffers.putAssumeCapacity("d.0.a", Model._newHostBuffer(40));
|
||||||
|
store.buffers.putAssumeCapacity("d.1.a", Model._newHostBuffer(41));
|
||||||
|
store.buffers.putAssumeCapacity("d.2.a", Model._newHostBuffer(42));
|
||||||
|
store.buffers.putAssumeCapacity("e.0", Model._newHostBuffer(50));
|
||||||
|
store.buffers.putAssumeCapacity("e.2.b", Model._newHostBuffer(51));
|
||||||
|
store.buffers.putAssumeCapacity("f", Model._newHostBuffer(60));
|
||||||
|
// no entry for g.
|
||||||
|
store.buffers.putAssumeCapacity("unused_entry", Model._newHostBuffer(1000));
|
||||||
|
|
||||||
|
const model = try populateModel(Model, arena_state.allocator(), store);
|
||||||
|
|
||||||
|
try std.testing.expectEqual(10, model.a.dim(0));
|
||||||
|
try std.testing.expectEqual(20, model.b.a.dim(0));
|
||||||
|
try std.testing.expectEqual(3, model.c.len);
|
||||||
|
try std.testing.expectEqual(30, model.c[0].dim(0));
|
||||||
|
try std.testing.expectEqual(31, model.c[1].dim(0));
|
||||||
|
try std.testing.expectEqual(32, model.c[2].dim(0));
|
||||||
|
try std.testing.expectEqual(3, model.d.len);
|
||||||
|
try std.testing.expectEqual(40, model.d[0].a.dim(0));
|
||||||
|
try std.testing.expectEqual(41, model.d[1].a.dim(0));
|
||||||
|
try std.testing.expectEqual(42, model.d[2].a.dim(0));
|
||||||
|
try std.testing.expectEqual(50, model.e[0].dim(0));
|
||||||
|
try std.testing.expectEqual(51, model.e[2].b.dim(0));
|
||||||
|
try std.testing.expectEqual(60, model.f.?.dim(0));
|
||||||
|
try std.testing.expectEqual(null, model.g);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
||||||
@ -476,6 +528,7 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
|||||||
},
|
},
|
||||||
.Struct => |struct_info| {
|
.Struct => |struct_info| {
|
||||||
inline for (struct_info.fields) |field| {
|
inline for (struct_info.fields) |field| {
|
||||||
|
if (field.is_comptime or @sizeOf(field.type) == 0) continue;
|
||||||
try prefix_builder.push(allocator, field.name);
|
try prefix_builder.push(allocator, field.name);
|
||||||
defer prefix_builder.pop();
|
defer prefix_builder.pop();
|
||||||
|
|
||||||
|
|||||||
@ -847,6 +847,16 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
return self.inner.platform;
|
return self.inner.platform;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn getOutputBuffer(self: Self, i: usize) Buffer {
|
||||||
|
var shards: Buffer.Shards = .{};
|
||||||
|
for (self.output_per_device) |dev_out| {
|
||||||
|
shards.appendAssumeCapacity(dev_out[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
const out_shape = self.inner.result_buffer_shapes[i];
|
||||||
|
return Buffer.fromPjrtBuffers(self.platform(), out_shape, shards.constSlice());
|
||||||
|
}
|
||||||
|
|
||||||
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
||||||
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
||||||
var event: [1]*pjrt.Event = undefined;
|
var event: [1]*pjrt.Event = undefined;
|
||||||
|
|||||||
@ -145,8 +145,15 @@ pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_ar
|
|||||||
return mod.call(buffer_args);
|
return mod.call(buffer_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn testLayer(platform: zml.Platform, buffer_store: zml.aio.BufferStore, comptime name: []const u8, layer: anytype, layer_weights: zml.Bufferized(@TypeOf(layer)), tolerance: f32) !void {
|
pub fn testLayer(
|
||||||
try testLayerOut(platform, buffer_store, name, name ++ ".out", layer, layer_weights, tolerance);
|
platform: zml.Platform,
|
||||||
|
activations: zml.aio.BufferStore,
|
||||||
|
comptime name: []const u8,
|
||||||
|
layer: anytype,
|
||||||
|
layer_weights: zml.Bufferized(@TypeOf(layer)),
|
||||||
|
tolerance: f32,
|
||||||
|
) !void {
|
||||||
|
try testLayerOut(platform, activations, name, name ++ ".out", layer, layer_weights, tolerance);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn testLayerOut(
|
pub fn testLayerOut(
|
||||||
@ -214,18 +221,18 @@ pub fn testLayerOut(
|
|||||||
try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32);
|
try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32);
|
||||||
fetch_ctx.prefix.appendSliceAssumeCapacity(name ++ ".in.");
|
fetch_ctx.prefix.appendSliceAssumeCapacity(name ++ ".in.");
|
||||||
try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers);
|
try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers);
|
||||||
defer zml.aio.unloadBuffers(input_buffers);
|
defer zml.aio.unloadBuffers(&input_buffers);
|
||||||
_ = mod.call(input_buffers);
|
_ = mod.call(input_buffers);
|
||||||
}
|
}
|
||||||
|
|
||||||
var buf: [1024]u8 = undefined;
|
var buf: [1024]u8 = undefined;
|
||||||
for (mod.output_buffers, 0..) |out, i| {
|
for (0..mod.inner.result_buffer_count) |i| {
|
||||||
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
|
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
|
||||||
const expected_out = activations.get(full_name) orelse {
|
const expected_out = activations.get(full_name) orelse {
|
||||||
log.warn("Output buffer not found: {s}", .{full_name});
|
log.warn("Output buffer not found: {s}", .{full_name});
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
zml.testing.expectClose(expected_out, zml.Buffer.fromPjrtBuffer(platform, out), tolerance) catch |err| {
|
zml.testing.expectClose(expected_out, mod.getOutputBuffer(i), tolerance) catch |err| {
|
||||||
log.err("{s}.{d} doesn't match !", .{ out_name, i });
|
log.err("{s}.{d} doesn't match !", .{ out_name, i });
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user