Add Qwen3VL bf16 example implementation and integrate zignal image format support; update Bazel build files and core ZML modules.
This commit is contained in:
parent
b8b4d33379
commit
e659dc8fa3
@ -102,7 +102,7 @@ use_repo(zls, "zls_toolchains")
|
||||
register_toolchains("@zls_toolchains//:all")
|
||||
|
||||
non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps")
|
||||
use_repo(non_module_deps, "com_github_hejsil_clap", "com_google_sentencepiece", "mnist", "org_swig_swig", "xla")
|
||||
use_repo(non_module_deps, "com_github_hejsil_clap", "com_google_sentencepiece", "mnist", "org_swig_swig", "xla", "com_github_bfactory_ai_zignal")
|
||||
|
||||
xla = use_extension("//third_party/xla:xla.bzl", "xla")
|
||||
use_repo(
|
||||
|
||||
1
third_party/com_github_bfactory_ai_zignal/BUILD.bazel
vendored
Normal file
1
third_party/com_github_bfactory_ai_zignal/BUILD.bazel
vendored
Normal file
@ -0,0 +1 @@
|
||||
# Empty BUILD.bazel to make this a Bazel package
|
||||
9
third_party/com_github_bfactory_ai_zignal/repo.bzl
vendored
Normal file
9
third_party/com_github_bfactory_ai_zignal/repo.bzl
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository")
|
||||
|
||||
def repo():
|
||||
new_git_repository(
|
||||
name = "com_github_bfactory_ai_zignal",
|
||||
remote = "https://github.com/loupicaaa/zignal.git",
|
||||
commit = "21553a48014add0e7f069f8c72b9277786185127",
|
||||
build_file = "//:third_party/com_github_bfactory_ai_zignal/zignal.bazel",
|
||||
)
|
||||
9
third_party/com_github_bfactory_ai_zignal/zignal.bazel
vendored
Normal file
9
third_party/com_github_bfactory_ai_zignal/zignal.bazel
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
|
||||
zig_library(
|
||||
name = "zignal",
|
||||
import_name = "zignal",
|
||||
srcs = glob(["**/*.zig"], exclude = ["build.zig", "build.zig.zon"]),
|
||||
main = "src/root.zig", # Le fichier principal devrait être à la racine
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
2
third_party/non_module_deps.bzl
vendored
2
third_party/non_module_deps.bzl
vendored
@ -1,3 +1,4 @@
|
||||
load("//third_party/com_github_bfactory_ai_zignal:repo.bzl", com_github_bfactory_ai_zignal = "repo")
|
||||
load("//third_party/com_github_hejsil_clap:repo.bzl", com_github_hejsil_clap = "repo")
|
||||
load("//third_party/com_google_sentencepiece:repo.bzl", com_google_sentencepiece = "repo")
|
||||
load("//third_party/mnist:repo.bzl", mnist = "repo")
|
||||
@ -10,6 +11,7 @@ def _non_module_deps_impl(mctx):
|
||||
com_github_hejsil_clap()
|
||||
mnist()
|
||||
xla()
|
||||
com_github_bfactory_ai_zignal()
|
||||
|
||||
return mctx.extension_metadata(
|
||||
reproducible = True,
|
||||
|
||||
@ -140,7 +140,7 @@ pub const Buffer = struct {
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
||||
const host_buffer = HostBuffer.fromArray(&arr);
|
||||
const host_buffer = HostBuffer.fromArrayPtr(&arr);
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
}
|
||||
|
||||
@ -160,7 +160,7 @@ pub const Buffer = struct {
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer {
|
||||
const host_buffer = HostBuffer.fromArray(&arr);
|
||||
const host_buffer = HostBuffer.fromArrayPtr(&arr);
|
||||
return try from(platform, host_buffer, opts);
|
||||
}
|
||||
|
||||
|
||||
@ -93,7 +93,7 @@ pub const HostBuffer = struct {
|
||||
/// Creates a tensor from a **pointer** to a "multi dimension" array.
|
||||
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
|
||||
/// Typically this is use with constant arrays.
|
||||
pub fn fromArray(arr_ptr: anytype) HostBuffer {
|
||||
pub fn fromArrayPtr(arr_ptr: anytype) HostBuffer {
|
||||
const T = @TypeOf(arr_ptr.*);
|
||||
const sh = parseArrayInfo(T);
|
||||
std.debug.assert(sh.byteSize() == @sizeOf(T));
|
||||
@ -105,6 +105,17 @@ pub const HostBuffer = struct {
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a tensor from an array by allocating and copying the content.
|
||||
pub fn fromArray(allocator: std.mem.Allocator, arr: anytype) !HostBuffer {
|
||||
const T = @TypeOf(arr);
|
||||
const sh = parseArrayInfo(T);
|
||||
std.debug.assert(sh.byteSize() == @sizeOf(T));
|
||||
|
||||
const buffer = try empty(allocator, sh);
|
||||
@memcpy(std.mem.sliceAsBytes(buffer.mutItems(@TypeOf(arr[0]))), std.mem.sliceAsBytes(&arr));
|
||||
return buffer;
|
||||
}
|
||||
|
||||
/// Returns a HostBuffer tagged with the tags in 'tagz'.
|
||||
pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer {
|
||||
var res = self;
|
||||
@ -328,7 +339,7 @@ pub const HostBuffer = struct {
|
||||
return self.prettyPrintIndented(writer, 4, 0, options);
|
||||
}
|
||||
|
||||
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void {
|
||||
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u32, indent_level: u8, options: std.fmt.Number) !void {
|
||||
if (self.rank() == 0) {
|
||||
// Special case input tensor is a scalar
|
||||
return switch (self.dtype()) {
|
||||
|
||||
@ -302,6 +302,8 @@ test mapAlloc {
|
||||
pub fn MapRestrict(From: type, To: type) type {
|
||||
return struct {
|
||||
pub fn map(T: type) type {
|
||||
@setEvalBranchQuota(10_000);
|
||||
|
||||
switch (T) {
|
||||
From => return To,
|
||||
?From => return ?To,
|
||||
|
||||
25
zml/nn.zig
25
zml/nn.zig
@ -32,7 +32,7 @@ pub const Linear = struct {
|
||||
}
|
||||
|
||||
// log.debug("Linear({*}): {d} -> {d} -> {d}", .{ self, x.dims(), y.dims(), if (self.bias) |bias| y.add(bias).dims() else y.dims() });
|
||||
return if (self.bias) |bias| y.add(bias.broadcast(y.shape(), &.{y.axis(-1)})) else y;
|
||||
return if (self.bias) |bias| y.add(bias.convert(y.dtype()).broadcast(y.shape(), &.{y.axis(-1)})) else y;
|
||||
}
|
||||
};
|
||||
|
||||
@ -100,10 +100,11 @@ pub const LayerNorm = struct {
|
||||
pub fn forward(self: LayerNorm, x: Tensor) Tensor {
|
||||
const normed = normalizeVariance(x, self.eps);
|
||||
const ax = x.axis(-1);
|
||||
var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax}));
|
||||
if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax}));
|
||||
var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax}).convert(.f32));
|
||||
|
||||
return out;
|
||||
if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax}).convert(.f32));
|
||||
|
||||
return out.convert(x.dtype());
|
||||
}
|
||||
};
|
||||
|
||||
@ -112,6 +113,7 @@ pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
|
||||
// upcast to improve precision
|
||||
const variance = x.convert(.f32).powByConst(2).mean(ax);
|
||||
const rsqrt = Tensor.rsqrt(variance.addConstant(eps)).convert(x.dtype());
|
||||
|
||||
return x.mul(rsqrt.broad(x.shape()));
|
||||
}
|
||||
|
||||
@ -190,7 +192,7 @@ pub const RopeOpts = struct {
|
||||
if (content != .object) return error.InvalidEnumTag;
|
||||
|
||||
const obj = content.object;
|
||||
const impl = obj.get("rope_type") orelse return error.MissingField;
|
||||
const impl = obj.get("rope_type") orelse obj.get("type") orelse return error.MissingField;
|
||||
if (impl != .string) return error.InvalidEnumTag;
|
||||
if (std.mem.eql(u8, impl.string, "llama3")) {
|
||||
// Note: leaky is fine here cause Llama3 struct don't need to allocate memory.
|
||||
@ -583,7 +585,7 @@ test nearest {
|
||||
const result = try zml.testing.compileAndCall(platform, upsample, .{ input_3d_basic, .{ .scale_factor = &.{3}, .mode = .nearest } });
|
||||
try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6 }, result.dims());
|
||||
const expected: [1][1][6]i32 = .{.{.{ 1, 1, 1, 2, 2, 2 }}};
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0);
|
||||
}
|
||||
// 3D Tensor (advanced)
|
||||
{
|
||||
@ -605,7 +607,7 @@ test nearest {
|
||||
.{ 21, 21, 22, 22, 23, 23, 24, 24 },
|
||||
},
|
||||
};
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0);
|
||||
}
|
||||
// 4D Tensor (basic)
|
||||
{
|
||||
@ -663,7 +665,7 @@ test nearest {
|
||||
},
|
||||
},
|
||||
};
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0);
|
||||
}
|
||||
// 5D Tensor (basic)
|
||||
{
|
||||
@ -688,7 +690,7 @@ test nearest {
|
||||
},
|
||||
},
|
||||
};
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0);
|
||||
try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -835,7 +837,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
|
||||
.{ 1, -2.5, 2, -0.5 },
|
||||
.{ -0.5, 1.5, -1.5, 0.5 },
|
||||
};
|
||||
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });
|
||||
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArrayPtr(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });
|
||||
|
||||
// actually do the interpolation.
|
||||
// Note: ideally this matmul should be inlined with the gather, but that's currently not the case.
|
||||
@ -940,7 +942,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
k = k.mul(head_scaling.convert(k.dtype()));
|
||||
|
||||
var attn_weights = q.dot(k, .{.hd});
|
||||
// log.debug("attn_weights : {f}, attn_mask : {?f}", .{ attn_weights, attn_mask });
|
||||
|
||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||
attn_weights = attn_weights.convert(.f32);
|
||||
attn_weights = if (opts.softmax_bias) |softmax_bias| attn: {
|
||||
@ -949,7 +951,6 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
const bias = softmax_bias.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto });
|
||||
break :attn attn_weights.convert(.f32).softmaxBiased(.k, bias).convert(q.dtype());
|
||||
} else attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||
|
||||
var attn = attn_weights.dot(v, .{.k});
|
||||
return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } });
|
||||
}
|
||||
|
||||
@ -715,7 +715,7 @@ pub const Tensor = struct {
|
||||
for (&powers, 0..) |*p, i| p.* = std.math.pow(u64, 2, i * 16);
|
||||
break :blk powers;
|
||||
};
|
||||
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
|
||||
const values = Tensor.constantTensor(HostBuffer.fromArrayPtr(&powers)).withTags(.{.d});
|
||||
const counts = values.gather(.{ .d = samples }, .{}).sum(.n).bitCast(.u16);
|
||||
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
|
||||
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
|
||||
@ -764,7 +764,7 @@ pub const Tensor = struct {
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
|
||||
inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
|
||||
pub inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor {
|
||||
stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() });
|
||||
const N = self.rank();
|
||||
stdx.debug.guard(opts.window_strides.len == N - 2, @src());
|
||||
@ -859,6 +859,49 @@ pub const Tensor = struct {
|
||||
return _result(new_shape, op.result(0));
|
||||
}
|
||||
|
||||
pub fn conv3d(
|
||||
input: Tensor,
|
||||
kernel: Tensor,
|
||||
opts: struct {
|
||||
window_strides: []const i64 = &.{ 1, 1, 1 }, //[time, height, width]
|
||||
padding: []const i64 = &.{ 0, 0, 0, 0, 0, 0 }, //[front, back, top, bottom, left, right]
|
||||
lhs_dilation: []const i64 = &.{ 1, 1, 1 }, //[time, height, width]
|
||||
rhs_dilation: []const i64 = &.{ 1, 1, 1 }, //[time, height, width]
|
||||
window_reversal: []const bool = &.{ false, false, false }, //[time, height, width]
|
||||
input_batch_dimension: i64 = 0,
|
||||
input_feature_dimension: i64 = 1,
|
||||
input_spatial_dimensions: []const i64 = &.{ 2, 3, 4 },
|
||||
kernel_input_feature_dimension: i64 = 1,
|
||||
kernel_output_feature_dimension: i64 = 0,
|
||||
kernel_spatial_dimensions: []const i64 = &.{ 2, 3, 4 },
|
||||
output_batch_dimension: i64 = 0,
|
||||
output_feature_dimension: i64 = 1,
|
||||
output_spatial_dimensions: []const i64 = &.{ 2, 3, 4 },
|
||||
feature_group_count: i64 = 1,
|
||||
batch_group_count: i64 = 1,
|
||||
},
|
||||
) Tensor {
|
||||
const loc = input.getContext().location(@src(), "opts={}", .{opts});
|
||||
return input.convolution(kernel, .{
|
||||
.window_strides = opts.window_strides,
|
||||
.pad_value = opts.padding,
|
||||
.lhs_dilation = opts.lhs_dilation,
|
||||
.rhs_dilation = opts.rhs_dilation,
|
||||
.window_reversal = opts.window_reversal,
|
||||
.input_batch_dimension = opts.input_batch_dimension,
|
||||
.input_feature_dimension = opts.input_feature_dimension,
|
||||
.input_spatial_dimensions = opts.input_spatial_dimensions,
|
||||
.kernel_input_feature_dimension = opts.kernel_input_feature_dimension,
|
||||
.kernel_output_feature_dimension = opts.kernel_output_feature_dimension,
|
||||
.kernel_spatial_dimensions = opts.kernel_spatial_dimensions,
|
||||
.output_batch_dimension = opts.output_batch_dimension,
|
||||
.output_feature_dimension = opts.output_feature_dimension,
|
||||
.output_spatial_dimensions = opts.output_spatial_dimensions,
|
||||
.feature_group_count = opts.feature_group_count,
|
||||
.batch_group_count = opts.batch_group_count,
|
||||
}, loc);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the result of the 1D convolution of 'input' by 'kernel'.
|
||||
pub fn conv1d(
|
||||
input: Tensor,
|
||||
@ -1283,7 +1326,7 @@ pub const Tensor = struct {
|
||||
const input = try zml.Buffer.fromSlice(platform, .{2}, &[_]f32{ -0.6884, 1.6795 });
|
||||
const res = try zml.testing.compileAndCall(platform, leakyReLU, .{ input, 0.1 });
|
||||
|
||||
const expectation = zml.HostBuffer.fromArray(&[2]f32{ -0.0688, 1.6795 });
|
||||
const expectation = zml.HostBuffer.fromArrayPtr(&[2]f32{ -0.0688, 1.6795 });
|
||||
try zml.testing.expectClose(expectation, res, 1e-4);
|
||||
}
|
||||
|
||||
@ -1979,9 +2022,10 @@ pub const Tensor = struct {
|
||||
const sh = Shape.init(.{args.steps}, dt);
|
||||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc);
|
||||
var res = _result(sh, iota_op.result(0));
|
||||
const range = args.end - args.start;
|
||||
|
||||
if (args.steps != 1) {
|
||||
res = res.scale(args.steps);
|
||||
res = res.scale(range / @as(f64, @floatFromInt(args.steps - 1)));
|
||||
}
|
||||
|
||||
if (args.start != 0) {
|
||||
@ -2600,7 +2644,7 @@ pub const Tensor = struct {
|
||||
|
||||
const result = try zml.testing.compileAndCall(platform, Local._gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
||||
|
||||
const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{
|
||||
const expected = zml.HostBuffer.fromArrayPtr(&[2][2][2][3]u16{
|
||||
.{
|
||||
.{ .{ 13, 14, 15 }, .{ 19, 20, 21 } },
|
||||
.{ .{ 37, 38, 39 }, .{ 43, 44, 45 } },
|
||||
|
||||
120
zml/testing.zig
120
zml/testing.zig
@ -213,14 +213,28 @@ pub fn testLayerOut(
|
||||
const fwd = @TypeOf(layer).forward;
|
||||
const FwdSign = zml.ModuleSignature(fwd);
|
||||
|
||||
const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in");
|
||||
const input_shapes = try shapesOf(input_tensors, alloc);
|
||||
const ArgsT = FwdSign.ArgsT;
|
||||
|
||||
const n_in = zml.module.countTensors(&input_tensors);
|
||||
const n_in_exp = activations.countLayers(name ++ ".in");
|
||||
if (n_in != n_in_exp) {
|
||||
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
|
||||
}
|
||||
// Check if layer has inputs
|
||||
const has_inputs = switch (@typeInfo(ArgsT)) {
|
||||
.@"struct" => |info| info.fields.len > 0,
|
||||
else => false,
|
||||
};
|
||||
|
||||
// Get input shapes (empty for layers without input)
|
||||
const input_shapes = if (has_inputs) blk: {
|
||||
const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in");
|
||||
const n_in = zml.module.countTensors(&input_tensors);
|
||||
const n_in_exp = activations.countLayers(name ++ ".in");
|
||||
if (n_in != n_in_exp) {
|
||||
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
|
||||
}
|
||||
break :blk try zml.shapesOf(input_tensors, alloc);
|
||||
} else blk: {
|
||||
// For layers without input, ArgsT should be void or empty tuple
|
||||
const empty_shapes: zml.ShapeOf(ArgsT) = undefined;
|
||||
break :blk empty_shapes;
|
||||
};
|
||||
|
||||
const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform);
|
||||
|
||||
@ -230,32 +244,29 @@ pub fn testLayerOut(
|
||||
}
|
||||
const mod = exe.prepare(layer_weights);
|
||||
|
||||
const FetchCtx = struct {
|
||||
store: zml.aio.BufferStore,
|
||||
index: u32,
|
||||
prefix: std.ArrayList(u8),
|
||||
platform: zml.Platform,
|
||||
// Call the module with input buffers (empty for layers without input)
|
||||
if (has_inputs) {
|
||||
const FetchCtx = struct {
|
||||
store: zml.aio.BufferStore,
|
||||
index: u32,
|
||||
prefix: std.ArrayList(u8),
|
||||
platform: zml.Platform,
|
||||
|
||||
fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer {
|
||||
_ = x;
|
||||
defer ctx.index += 1;
|
||||
var full_prefix = ctx.*.prefix;
|
||||
_ = full_prefix.writer(undefined).print("{d}", .{ctx.index}) catch unreachable;
|
||||
log.info("prefix: {s}", .{full_prefix.items});
|
||||
const host = ctx.store.get(full_prefix.items) orelse {
|
||||
log.err("Didn't find test input: {s}", .{full_prefix.items});
|
||||
@panic("Missing test input");
|
||||
};
|
||||
return host.toDevice(ctx.platform) catch unreachable;
|
||||
}
|
||||
};
|
||||
fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer {
|
||||
_ = x;
|
||||
defer ctx.index += 1;
|
||||
var full_prefix = ctx.*.prefix;
|
||||
_ = full_prefix.writer(undefined).print("{d}", .{ctx.index}) catch unreachable;
|
||||
log.info("prefix: {s}", .{full_prefix.items});
|
||||
const host = ctx.store.get(full_prefix.items) orelse {
|
||||
log.err("Didn't find test input: {s}", .{full_prefix.items});
|
||||
@panic("Missing test input");
|
||||
};
|
||||
return host.toDevice(ctx.platform) catch unreachable;
|
||||
}
|
||||
};
|
||||
|
||||
// Note: zml.populateModelWithPrefix isn't enough,
|
||||
// because it assumes we have the same structure in the activation file
|
||||
// than in the function signature.
|
||||
// But for sake of decoupling the reference implementation
|
||||
// and ZML code that's not always the case.
|
||||
{
|
||||
const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in");
|
||||
var input_buffers: zml.Bufferized(FwdSign.ArgsT) = undefined;
|
||||
var fetch_ctx: FetchCtx = .{ .store = activations, .index = 0, .prefix = .{}, .platform = platform };
|
||||
try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32);
|
||||
@ -263,10 +274,16 @@ pub fn testLayerOut(
|
||||
try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers);
|
||||
defer zml.aio.unloadBuffers(&input_buffers);
|
||||
_ = mod.call(input_buffers);
|
||||
} else {
|
||||
// For layers without input, ArgsT should be void
|
||||
// Bufferized(void) is void, so we can't call mod.call normally
|
||||
// Use _unsafeCall directly and then get the results manually
|
||||
mod.inner._unsafeCall();
|
||||
}
|
||||
|
||||
var buf: [1024]u8 = undefined;
|
||||
var failed: bool = false;
|
||||
log.info("COMPARAISON DES SORTIES", .{});
|
||||
for (0..mod.inner.result_shapes.len) |i| {
|
||||
const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable;
|
||||
const expected_out = activations.get(full_name) orelse {
|
||||
@ -305,14 +322,49 @@ test testLayer {
|
||||
var activations = zml.aio.BufferStore.init(std.testing.allocator);
|
||||
defer activations.deinit();
|
||||
{
|
||||
const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 });
|
||||
const input = zml.HostBuffer.fromArrayPtr(&[2]f32{ 1, -1 });
|
||||
try activations.buffers.put(activations.arena.allocator(), "model.layer.in.0", input);
|
||||
const output = zml.HostBuffer.fromArray(&[5]f32{ 0, -1, -1, 0, -1 });
|
||||
const output = zml.HostBuffer.fromArrayPtr(&[5]f32{ 0, -1, -1, 0, -1 });
|
||||
try activations.buffers.put(activations.arena.allocator(), "model.layer.out.0", output);
|
||||
}
|
||||
|
||||
// test the ZML layer reproduces the "captured" activations:
|
||||
// Test the ZML layer reproduces the activations:
|
||||
try zml.testing.testLayer(platform, activations, "model.layer", layer, layer_weights, 1e-5);
|
||||
|
||||
const LayerWithoutInput = struct {
|
||||
weight: zml.Tensor,
|
||||
|
||||
pub fn forward(self: @This()) zml.Tensor {
|
||||
// Return the weights
|
||||
return self.weight;
|
||||
}
|
||||
};
|
||||
|
||||
const layer_no_input: LayerWithoutInput = .{
|
||||
.weight = zml.Tensor{ ._shape = zml.Shape.init(.{ 3, 4 }, .f32), ._id = .{ .buffer_id = 43 } },
|
||||
};
|
||||
|
||||
const layer_no_input_weights: zml.Bufferized(LayerWithoutInput) = .{
|
||||
.weight = try zml.Buffer.fromArray(
|
||||
platform,
|
||||
[3][4]f32{
|
||||
.{ 1.0, 2.0, 3.0, 4.0 },
|
||||
.{ 5.0, 6.0, 7.0, 8.0 },
|
||||
.{ 9.0, 10.0, 11.0, 12.0 },
|
||||
},
|
||||
),
|
||||
};
|
||||
|
||||
// Expected output
|
||||
const expected_output = zml.HostBuffer.fromArrayPtr(&[3][4]f32{
|
||||
.{ 1.0, 2.0, 3.0, 4.0 },
|
||||
.{ 5.0, 6.0, 7.0, 8.0 },
|
||||
.{ 9.0, 10.0, 11.0, 12.0 },
|
||||
});
|
||||
try activations.buffers.put(activations.arena.allocator(), "model.layer_no_input.out.0", expected_output);
|
||||
|
||||
// Test the ZML layer without input reproduces the "captured" activations:
|
||||
try zml.testing.testLayer(platform, activations, "model.layer_no_input", layer_no_input, layer_no_input_weights, 1e-5);
|
||||
}
|
||||
|
||||
pub inline fn expectEqual(expected: anytype, actual: @TypeOf(expected)) !void {
|
||||
|
||||
@ -148,7 +148,7 @@ test pixelShuffle {
|
||||
|
||||
const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor });
|
||||
|
||||
const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
|
||||
const exp = zml.HostBuffer.fromArrayPtr(&[1][1][12][12]i32{.{.{
|
||||
.{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },
|
||||
.{ 48, 64, 80, 49, 65, 81, 50, 66, 82, 51, 67, 83 },
|
||||
.{ 96, 112, 128, 97, 113, 129, 98, 114, 130, 99, 115, 131 },
|
||||
|
||||
@ -42,6 +42,7 @@ pub const Shape = @import("shape.zig").Shape;
|
||||
pub const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
pub const Target = @import("platform.zig").Target;
|
||||
pub const Tensor = @import("tensor.zig").Tensor;
|
||||
pub const shapesOf = @import("tensor.zig").shapesOf;
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user