From e659dc8fa3aa739dc8a518a74ef6d389d8f65d4f Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 29 Dec 2025 16:17:11 +0000 Subject: [PATCH] Add Qwen3VL bf16 example implementation and integrate zignal image format support; update Bazel build files and core ZML modules. --- MODULE.bazel | 2 +- .../com_github_bfactory_ai_zignal/BUILD.bazel | 1 + .../com_github_bfactory_ai_zignal/repo.bzl | 9 ++ .../zignal.bazel | 9 ++ third_party/non_module_deps.bzl | 2 + zml/buffer.zig | 4 +- zml/hostbuffer.zig | 15 ++- zml/meta.zig | 2 + zml/nn.zig | 25 ++-- zml/tensor.zig | 54 +++++++- zml/testing.zig | 120 +++++++++++++----- zml/torch.zig | 2 +- zml/zml.zig | 1 + 13 files changed, 189 insertions(+), 57 deletions(-) create mode 100644 third_party/com_github_bfactory_ai_zignal/BUILD.bazel create mode 100644 third_party/com_github_bfactory_ai_zignal/repo.bzl create mode 100644 third_party/com_github_bfactory_ai_zignal/zignal.bazel diff --git a/MODULE.bazel b/MODULE.bazel index 22e95e7..0ef6245 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -102,7 +102,7 @@ use_repo(zls, "zls_toolchains") register_toolchains("@zls_toolchains//:all") non_module_deps = use_extension("//:third_party/non_module_deps.bzl", "non_module_deps") -use_repo(non_module_deps, "com_github_hejsil_clap", "com_google_sentencepiece", "mnist", "org_swig_swig", "xla") +use_repo(non_module_deps, "com_github_hejsil_clap", "com_google_sentencepiece", "mnist", "org_swig_swig", "xla", "com_github_bfactory_ai_zignal") xla = use_extension("//third_party/xla:xla.bzl", "xla") use_repo( diff --git a/third_party/com_github_bfactory_ai_zignal/BUILD.bazel b/third_party/com_github_bfactory_ai_zignal/BUILD.bazel new file mode 100644 index 0000000..245f310 --- /dev/null +++ b/third_party/com_github_bfactory_ai_zignal/BUILD.bazel @@ -0,0 +1 @@ +# Empty BUILD.bazel to make this a Bazel package diff --git a/third_party/com_github_bfactory_ai_zignal/repo.bzl b/third_party/com_github_bfactory_ai_zignal/repo.bzl new file mode 100644 
index 0000000..d644e30 --- /dev/null +++ b/third_party/com_github_bfactory_ai_zignal/repo.bzl @@ -0,0 +1,9 @@ +load("@bazel_tools//tools/build_defs/repo:git.bzl", "new_git_repository") + +def repo(): + new_git_repository( + name = "com_github_bfactory_ai_zignal", + remote = "https://github.com/loupicaaa/zignal.git", + commit = "21553a48014add0e7f069f8c72b9277786185127", + build_file = "//:third_party/com_github_bfactory_ai_zignal/zignal.bazel", + ) diff --git a/third_party/com_github_bfactory_ai_zignal/zignal.bazel b/third_party/com_github_bfactory_ai_zignal/zignal.bazel new file mode 100644 index 0000000..0f59688 --- /dev/null +++ b/third_party/com_github_bfactory_ai_zignal/zignal.bazel @@ -0,0 +1,9 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + +zig_library( + name = "zignal", + import_name = "zignal", + srcs = glob(["**/*.zig"], exclude = ["build.zig", "build.zig.zon"]), + main = "src/root.zig", # The main entry file is expected at the repo root + visibility = ["//visibility:public"], +) diff --git a/third_party/non_module_deps.bzl b/third_party/non_module_deps.bzl index 5c76946..07353fb 100644 --- a/third_party/non_module_deps.bzl +++ b/third_party/non_module_deps.bzl @@ -1,3 +1,4 @@ +load("//third_party/com_github_bfactory_ai_zignal:repo.bzl", com_github_bfactory_ai_zignal = "repo") load("//third_party/com_github_hejsil_clap:repo.bzl", com_github_hejsil_clap = "repo") load("//third_party/com_google_sentencepiece:repo.bzl", com_google_sentencepiece = "repo") load("//third_party/mnist:repo.bzl", mnist = "repo") @@ -10,6 +11,7 @@ def _non_module_deps_impl(mctx): com_github_hejsil_clap() mnist() xla() + com_github_bfactory_ai_zignal() return mctx.extension_metadata( reproducible = True, diff --git a/zml/buffer.zig b/zml/buffer.zig index e76039d..57c46b2 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -140,7 +140,7 @@ pub const Buffer = struct { /// Copies the given Zig array to the accelerator memory and /// return a Buffer using the array shape.
pub fn fromArray(platform: Platform, arr: anytype) !Buffer { - const host_buffer = HostBuffer.fromArray(&arr); + const host_buffer = HostBuffer.fromArrayPtr(&arr); return try from(platform, host_buffer, .{ .wait = true }); } @@ -160,7 +160,7 @@ pub const Buffer = struct { /// Copies the given Zig array to the accelerator memory and /// return a Buffer using the array shape. pub fn fromArrayOpts(platform: Platform, arr: anytype, opts: FromOptions) !Buffer { - const host_buffer = HostBuffer.fromArray(&arr); + const host_buffer = HostBuffer.fromArrayPtr(&arr); return try from(platform, host_buffer, opts); } diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 0469f1e..59b97ef 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -93,7 +93,7 @@ pub const HostBuffer = struct { /// Creates a tensor from a **pointer** to a "multi dimension" array. /// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object. /// Typically this is use with constant arrays. - pub fn fromArray(arr_ptr: anytype) HostBuffer { + pub fn fromArrayPtr(arr_ptr: anytype) HostBuffer { const T = @TypeOf(arr_ptr.*); const sh = parseArrayInfo(T); std.debug.assert(sh.byteSize() == @sizeOf(T)); @@ -105,6 +105,17 @@ pub const HostBuffer = struct { }; } + /// Creates a tensor from an array by allocating and copying the content. + pub fn fromArray(allocator: std.mem.Allocator, arr: anytype) !HostBuffer { + const T = @TypeOf(arr); + const sh = parseArrayInfo(T); + std.debug.assert(sh.byteSize() == @sizeOf(T)); + + const buffer = try empty(allocator, sh); + @memcpy(std.mem.sliceAsBytes(buffer.mutItems(@TypeOf(arr[0]))), std.mem.sliceAsBytes(&arr)); + return buffer; + } + /// Returns a HostBuffer tagged with the tags in 'tagz'. 
pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer { var res = self; @@ -328,7 +339,7 @@ pub const HostBuffer = struct { return self.prettyPrintIndented(writer, 4, 0, options); } - fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void { + fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u32, indent_level: u8, options: std.fmt.Number) !void { if (self.rank() == 0) { // Special case input tensor is a scalar return switch (self.dtype()) { diff --git a/zml/meta.zig b/zml/meta.zig index 891a27e..2733bd5 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -302,6 +302,8 @@ test mapAlloc { pub fn MapRestrict(From: type, To: type) type { return struct { pub fn map(T: type) type { + @setEvalBranchQuota(10_000); + switch (T) { From => return To, ?From => return ?To, diff --git a/zml/nn.zig b/zml/nn.zig index 8bfd2f2..30a7fc1 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -32,7 +32,7 @@ pub const Linear = struct { } // log.debug("Linear({*}): {d} -> {d} -> {d}", .{ self, x.dims(), y.dims(), if (self.bias) |bias| y.add(bias).dims() else y.dims() }); - return if (self.bias) |bias| y.add(bias.broadcast(y.shape(), &.{y.axis(-1)})) else y; + return if (self.bias) |bias| y.add(bias.convert(y.dtype()).broadcast(y.shape(), &.{y.axis(-1)})) else y; } }; @@ -100,10 +100,11 @@ pub const LayerNorm = struct { pub fn forward(self: LayerNorm, x: Tensor) Tensor { const normed = normalizeVariance(x, self.eps); const ax = x.axis(-1); - var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax})); - if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax})); + var out = normed.mul(self.weight.broadcast(x.shape(), &.{ax}).convert(.f32)); - return out; + if (self.bias) |bias| out = out.add(bias.broadcast(x.shape(), &.{ax}).convert(.f32)); + + return out.convert(x.dtype()); } }; @@ -112,6 +113,7 @@ pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor { // upcast to improve 
precision const variance = x.convert(.f32).powByConst(2).mean(ax); const rsqrt = Tensor.rsqrt(variance.addConstant(eps)).convert(x.dtype()); + return x.mul(rsqrt.broad(x.shape())); } @@ -190,7 +192,7 @@ pub const RopeOpts = struct { if (content != .object) return error.InvalidEnumTag; const obj = content.object; - const impl = obj.get("rope_type") orelse return error.MissingField; + const impl = obj.get("rope_type") orelse obj.get("type") orelse return error.MissingField; if (impl != .string) return error.InvalidEnumTag; if (std.mem.eql(u8, impl.string, "llama3")) { // Note: leaky is fine here cause Llama3 struct don't need to allocate memory. @@ -583,7 +585,7 @@ test nearest { const result = try zml.testing.compileAndCall(platform, upsample, .{ input_3d_basic, .{ .scale_factor = &.{3}, .mode = .nearest } }); try std.testing.expectEqualSlices(i64, &.{ 1, 1, 6 }, result.dims()); const expected: [1][1][6]i32 = .{.{.{ 1, 1, 1, 2, 2, 2 }}}; - try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0); + try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0); } // 3D Tensor (advanced) { @@ -605,7 +607,7 @@ test nearest { .{ 21, 21, 22, 22, 23, 23, 24, 24 }, }, }; - try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0); + try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0); } // 4D Tensor (basic) { @@ -663,7 +665,7 @@ test nearest { }, }, }; - try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0); + try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0); } // 5D Tensor (basic) { @@ -688,7 +690,7 @@ test nearest { }, }, }; - try zml.testing.expectClose(zml.HostBuffer.fromArray(&expected), result, 0); + try zml.testing.expectClose(zml.HostBuffer.fromArrayPtr(&expected), result, 0); } } @@ -835,7 +837,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten .{ 1, -2.5, 2, -0.5 }, .{ -0.5, 1.5, -1.5, 0.5 }, }; - 
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors }); + const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArrayPtr(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors }); // actually do the interpolation. // Note: ideally this matmul should be inlined with the gather, but that's currently not the case. @@ -940,7 +942,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { k = k.mul(head_scaling.convert(k.dtype())); var attn_weights = q.dot(k, .{.hd}); - // log.debug("attn_weights : {f}, attn_mask : {?f}", .{ attn_weights, attn_mask }); + if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); attn_weights = attn_weights.convert(.f32); attn_weights = if (opts.softmax_bias) |softmax_bias| attn: { @@ -949,7 +951,6 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { const bias = softmax_bias.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto }); break :attn attn_weights.convert(.f32).softmaxBiased(.k, bias).convert(q.dtype()); } else attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); - var attn = attn_weights.dot(v, .{.k}); return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } }); } diff --git a/zml/tensor.zig b/zml/tensor.zig index d5acc54..778c0c4 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -715,7 +715,7 @@ pub const Tensor = struct { for (&powers, 0..) 
|*p, i| p.* = std.math.pow(u64, 2, i * 16); break :blk powers; }; - const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d}); + const values = Tensor.constantTensor(HostBuffer.fromArrayPtr(&powers)).withTags(.{.d}); const counts = values.gather(.{ .d = samples }, .{}).sum(.n).bitCast(.u16); const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n)); return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } }; @@ -764,7 +764,7 @@ pub const Tensor = struct { return _result(self._shape, op.result(0)); } - inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor { + pub inline fn convolution(self: Tensor, other: Tensor, opts: dialect.stablehlo.ConvolutionOpts, loc: mlir.Location) Tensor { stdx.debug.assert(self.rank() == other.rank(), "convolution expects tensor ranks to match, got {} and {}", .{ self.rank(), other.rank() }); const N = self.rank(); stdx.debug.guard(opts.window_strides.len == N - 2, @src()); @@ -859,6 +859,49 @@ pub const Tensor = struct { return _result(new_shape, op.result(0)); } + pub fn conv3d( + input: Tensor, + kernel: Tensor, + opts: struct { + window_strides: []const i64 = &.{ 1, 1, 1 }, //[time, height, width] + padding: []const i64 = &.{ 0, 0, 0, 0, 0, 0 }, //[front, back, top, bottom, left, right] + lhs_dilation: []const i64 = &.{ 1, 1, 1 }, //[time, height, width] + rhs_dilation: []const i64 = &.{ 1, 1, 1 }, //[time, height, width] + window_reversal: []const bool = &.{ false, false, false }, //[time, height, width] + input_batch_dimension: i64 = 0, + input_feature_dimension: i64 = 1, + input_spatial_dimensions: []const i64 = &.{ 2, 3, 4 }, + kernel_input_feature_dimension: i64 = 1, + kernel_output_feature_dimension: i64 = 0, + kernel_spatial_dimensions: []const i64 = &.{ 2, 3, 4 }, + output_batch_dimension: i64 = 0, + output_feature_dimension: i64 = 1, + output_spatial_dimensions: 
[]const i64 = &.{ 2, 3, 4 }, + feature_group_count: i64 = 1, + batch_group_count: i64 = 1, + }, + ) Tensor { + const loc = input.getContext().location(@src(), "opts={}", .{opts}); + return input.convolution(kernel, .{ + .window_strides = opts.window_strides, + .pad_value = opts.padding, + .lhs_dilation = opts.lhs_dilation, + .rhs_dilation = opts.rhs_dilation, + .window_reversal = opts.window_reversal, + .input_batch_dimension = opts.input_batch_dimension, + .input_feature_dimension = opts.input_feature_dimension, + .input_spatial_dimensions = opts.input_spatial_dimensions, + .kernel_input_feature_dimension = opts.kernel_input_feature_dimension, + .kernel_output_feature_dimension = opts.kernel_output_feature_dimension, + .kernel_spatial_dimensions = opts.kernel_spatial_dimensions, + .output_batch_dimension = opts.output_batch_dimension, + .output_feature_dimension = opts.output_feature_dimension, + .output_spatial_dimensions = opts.output_spatial_dimensions, + .feature_group_count = opts.feature_group_count, + .batch_group_count = opts.batch_group_count, + }, loc); + } + /// Returns a Tensor containing the result of the 1D convolution of 'input' by 'kernel'. 
pub fn conv1d( input: Tensor, @@ -1283,7 +1326,7 @@ pub const Tensor = struct { const input = try zml.Buffer.fromSlice(platform, .{2}, &[_]f32{ -0.6884, 1.6795 }); const res = try zml.testing.compileAndCall(platform, leakyReLU, .{ input, 0.1 }); - const expectation = zml.HostBuffer.fromArray(&[2]f32{ -0.0688, 1.6795 }); + const expectation = zml.HostBuffer.fromArrayPtr(&[2]f32{ -0.0688, 1.6795 }); try zml.testing.expectClose(expectation, res, 1e-4); } @@ -1979,9 +2022,10 @@ pub const Tensor = struct { const sh = Shape.init(.{args.steps}, dt); var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlirx.tensorType(ctx.mlirCtx(), sh), loc); var res = _result(sh, iota_op.result(0)); + const range = args.end - args.start; if (args.steps != 1) { - res = res.scale(args.steps); + res = res.scale(range / @as(f64, @floatFromInt(args.steps - 1))); } if (args.start != 0) { @@ -2600,7 +2644,7 @@ pub const Tensor = struct { const result = try zml.testing.compileAndCall(platform, Local._gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} }); - const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{ + const expected = zml.HostBuffer.fromArrayPtr(&[2][2][2][3]u16{ .{ .{ .{ 13, 14, 15 }, .{ 19, 20, 21 } }, .{ .{ 37, 38, 39 }, .{ 43, 44, 45 } }, diff --git a/zml/testing.zig b/zml/testing.zig index 38952dd..036efaa 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -213,14 +213,28 @@ pub fn testLayerOut( const fwd = @TypeOf(layer).forward; const FwdSign = zml.ModuleSignature(fwd); - const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in"); - const input_shapes = try shapesOf(input_tensors, alloc); + const ArgsT = FwdSign.ArgsT; - const n_in = zml.module.countTensors(&input_tensors); - const n_in_exp = activations.countLayers(name ++ ".in"); - if (n_in != n_in_exp) { - log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in }); - } + // Check if layer has 
inputs + const has_inputs = switch (@typeInfo(ArgsT)) { + .@"struct" => |info| info.fields.len > 0, + else => false, + }; + + // Get input shapes (empty for layers without input) + const input_shapes = if (has_inputs) blk: { + const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in"); + const n_in = zml.module.countTensors(&input_tensors); + const n_in_exp = activations.countLayers(name ++ ".in"); + if (n_in != n_in_exp) { + log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in }); + } + break :blk try zml.shapesOf(input_tensors, alloc); + } else blk: { + // For layers without input, ArgsT should be void or empty tuple + const empty_shapes: zml.ShapeOf(ArgsT) = undefined; + break :blk empty_shapes; + }; const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform); @@ -230,32 +244,29 @@ pub fn testLayerOut( } const mod = exe.prepare(layer_weights); - const FetchCtx = struct { - store: zml.aio.BufferStore, - index: u32, - prefix: std.ArrayList(u8), - platform: zml.Platform, + // Call the module with input buffers (empty for layers without input) + if (has_inputs) { + const FetchCtx = struct { + store: zml.aio.BufferStore, + index: u32, + prefix: std.ArrayList(u8), + platform: zml.Platform, - fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer { - _ = x; - defer ctx.index += 1; - var full_prefix = ctx.*.prefix; - _ = full_prefix.writer(undefined).print("{d}", .{ctx.index}) catch unreachable; - log.info("prefix: {s}", .{full_prefix.items}); - const host = ctx.store.get(full_prefix.items) orelse { - log.err("Didn't find test input: {s}", .{full_prefix.items}); - @panic("Missing test input"); - }; - return host.toDevice(ctx.platform) catch unreachable; - } - }; + fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer { + _ = x; + defer ctx.index += 1; + var full_prefix = ctx.*.prefix; + _ = full_prefix.writer(undefined).print("{d}", .{ctx.index}) catch unreachable; + 
log.info("prefix: {s}", .{full_prefix.items}); + const host = ctx.store.get(full_prefix.items) orelse { + log.err("Didn't find test input: {s}", .{full_prefix.items}); + @panic("Missing test input"); + }; + return host.toDevice(ctx.platform) catch unreachable; + } + }; - // Note: zml.populateModelWithPrefix isn't enough, - // because it assumes we have the same structure in the activation file - // than in the function signature. - // But for sake of decoupling the reference implementation - // and ZML code that's not always the case. - { + const input_tensors = try zml.aio.populateModelWithPrefix(FwdSign.ArgsT, alloc, activations, name ++ ".in"); var input_buffers: zml.Bufferized(FwdSign.ArgsT) = undefined; var fetch_ctx: FetchCtx = .{ .store = activations, .index = 0, .prefix = .{}, .platform = platform }; try fetch_ctx.prefix.ensureTotalCapacity(alloc, name.len + 32); @@ -263,10 +274,16 @@ pub fn testLayerOut( try zml.meta.mapAlloc(FetchCtx.fetch, alloc, &fetch_ctx, input_tensors, &input_buffers); defer zml.aio.unloadBuffers(&input_buffers); _ = mod.call(input_buffers); + } else { + // For layers without input, ArgsT should be void + // Bufferized(void) is void, so we can't call mod.call normally + // Use _unsafeCall directly and then get the results manually + mod.inner._unsafeCall(); } var buf: [1024]u8 = undefined; var failed: bool = false; + log.info("COMPARAISON DES SORTIES", .{}); for (0..mod.inner.result_shapes.len) |i| { const full_name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ out_name, i }) catch unreachable; const expected_out = activations.get(full_name) orelse { @@ -305,14 +322,49 @@ test testLayer { var activations = zml.aio.BufferStore.init(std.testing.allocator); defer activations.deinit(); { - const input = zml.HostBuffer.fromArray(&[2]f32{ 1, -1 }); + const input = zml.HostBuffer.fromArrayPtr(&[2]f32{ 1, -1 }); try activations.buffers.put(activations.arena.allocator(), "model.layer.in.0", input); - const output = zml.HostBuffer.fromArray(&[5]f32{ 
0, -1, -1, 0, -1 }); + const output = zml.HostBuffer.fromArrayPtr(&[5]f32{ 0, -1, -1, 0, -1 }); try activations.buffers.put(activations.arena.allocator(), "model.layer.out.0", output); } - // test the ZML layer reproduces the "captured" activations: + // Test the ZML layer reproduces the activations: try zml.testing.testLayer(platform, activations, "model.layer", layer, layer_weights, 1e-5); + + const LayerWithoutInput = struct { + weight: zml.Tensor, + + pub fn forward(self: @This()) zml.Tensor { + // Return the weights + return self.weight; + } + }; + + const layer_no_input: LayerWithoutInput = .{ + .weight = zml.Tensor{ ._shape = zml.Shape.init(.{ 3, 4 }, .f32), ._id = .{ .buffer_id = 43 } }, + }; + + const layer_no_input_weights: zml.Bufferized(LayerWithoutInput) = .{ + .weight = try zml.Buffer.fromArray( + platform, + [3][4]f32{ + .{ 1.0, 2.0, 3.0, 4.0 }, + .{ 5.0, 6.0, 7.0, 8.0 }, + .{ 9.0, 10.0, 11.0, 12.0 }, + }, + ), + }; + + // Expected output + const expected_output = zml.HostBuffer.fromArrayPtr(&[3][4]f32{ + .{ 1.0, 2.0, 3.0, 4.0 }, + .{ 5.0, 6.0, 7.0, 8.0 }, + .{ 9.0, 10.0, 11.0, 12.0 }, + }); + try activations.buffers.put(activations.arena.allocator(), "model.layer_no_input.out.0", expected_output); + + // Test the ZML layer without input reproduces the "captured" activations: + try zml.testing.testLayer(platform, activations, "model.layer_no_input", layer_no_input, layer_no_input_weights, 1e-5); } pub inline fn expectEqual(expected: anytype, actual: @TypeOf(expected)) !void { diff --git a/zml/torch.zig b/zml/torch.zig index 09888ae..8ddf3c4 100644 --- a/zml/torch.zig +++ b/zml/torch.zig @@ -148,7 +148,7 @@ test pixelShuffle { const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor }); - const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{ + const exp = zml.HostBuffer.fromArrayPtr(&[1][1][12][12]i32{.{.{ .{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 }, .{ 48, 64, 80, 49, 65, 81, 50, 66, 82, 51, 67, 83 }, 
.{ 96, 112, 128, 97, 113, 129, 98, 114, 130, 99, 115, 131 }, diff --git a/zml/zml.zig b/zml/zml.zig index 55559f5..db9aecf 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -42,6 +42,7 @@ pub const Shape = @import("shape.zig").Shape; pub const ShapeOf = @import("tensor.zig").ShapeOf; pub const Target = @import("platform.zig").Target; pub const Tensor = @import("tensor.zig").Tensor; +pub const shapesOf = @import("tensor.zig").shapesOf; pub const testing = @import("testing.zig"); pub const torch = @import("torch.zig");