const std = @import("std");
const testing = std.testing;

const async = @import("async");
const stdx = @import("stdx");
const zml = @import("zml");
const Buffer = zml.Buffer;
const Tensor = zml.Tensor;
const ShapeOf = zml.ShapeOf;
const Linear = zml.nn.Linear;
const Shape = zml.Shape;

const log = std.log.scoped(.qwen3_vl);

pub const std_options: std.Options = .{
    .log_level = .info,
    .logFn = async.logFn(std.log.defaultLog),
};

test {
    std.testing.refAllDecls(@This());
}

/// Top-level wrapper: image preprocessing followed by the `Qwen` vision + text model.
pub const Qwen3VL = struct {
    qwen: Qwen,

    /// Builds the wrapped `Qwen` model from `store`; propagates any error from `Qwen.init`.
    pub fn init(
        allocator: std.mem.Allocator,
        config: Qwen.Config,
        options: Qwen.Options,
        store: zml.aio.BufferStore,
    ) !Qwen3VL {
        return .{
            .qwen = try Qwen.init(allocator, config, options, store),
        };
    }

    /// Forward pass for the prefill phase: preprocesses the image, then runs `Qwen.forward`.
    ///
    /// - image_hwc: the image as a 3-dim tensor (height, width, channels).
    /// - input_ids: prompt token ids.
    /// - image_dim: single-dim vector forwarded to `ResizeBicubic` as `.original_len`
    ///   (presumably the real size of the image inside `image_hwc` — confirm against the caller).
    /// - token_index: cache position (0 for the prefill because the whole sequence is processed).
    /// - prompt_shape: the prompt layout, forwarded unchanged to `Qwen.forward`.
    /// - kv_cache: key-value cache.
    /// - h_resized / w_resized: size of the resized image; plain integers fixed when the
    ///   graph is built, not runtime tensors.
    /// - rng: sampling random state.
    ///
    /// Returns `{ next_token, updated_cache, mrope_position_deltas, new_rng }`.
    pub fn forward(
        self: Qwen3VL,
        image_hwc: Tensor,
        input_ids: Tensor,
        image_dim: Tensor,
        token_index: Tensor,
        prompt_shape: Tensor,
        kv_cache: KvCache,
        h_resized: u32,
        w_resized: u32,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor, Tensor.Rng } {
        const pixel_value, const image_grid_thw = self.processImage(image_hwc, image_dim, h_resized, w_resized);
        const next_token, const updated_cache, const mrope_position_deltas, const new_rng = zml.call(self.qwen, .forward, .{ input_ids, pixel_value, token_index, image_grid_thw, kv_cache, prompt_shape, rng });
        return .{ next_token, updated_cache, mrope_position_deltas, new_rng };
    }

    /// Resizes, normalizes and patchifies one image.
    ///
    /// Returns the flattened patches, shaped
    /// `(grid_h * grid_w, 3 * temporal_patch_size * patch_size * patch_size)`,
    /// together with the patch grid `{ t, h, w }`.
    /// `h_resized` and `w_resized` must be exact multiples of `patch_size`, and the
    /// resulting grid an exact multiple of `spatial_merge_size` (`@divExact`).
    pub fn processImage(
        self: Qwen3VL,
        image_hwc: Tensor,
        image_size: Tensor,
        h_resized: u32,
        w_resized: u32,
    ) struct { Tensor, [3]u32 } {
        const vision_config = self.qwen.config.vision_config;
        const temporal_patch_size = vision_config.temporal_patch_size;
        const patch_size = vision_config.patch_size;
        const merge_size = vision_config.spatial_merge_size;

        // Resize the image and transpose to (channels, height, width)
        var image_chw = ResizeBicubic(image_hwc, .{ .h = h_resized, .w = w_resized }, .{ .original_len = image_size }).transpose(.{ .c, .h, .w });
        image_chw = image_chw.convert(.f32);

        // Rescale and normalize the image
        const rescale_factor: f32 = 1.0 / 255.0;
        const image_mean: f32 = 0.5;
        const image_std: f32 = 0.5;
        image_chw = image_chw.scale(rescale_factor); // pixel / 255.0
        var image_chw_rescaled_normalized = image_chw.sub(Tensor.scalar(image_mean, .f32)).div(Tensor.scalar(image_std, .f32));

        // Introduce the temporal dimension (1)
        image_chw_rescaled_normalized = image_chw_rescaled_normalized.reshape(.{ .c = 3, .temporal_patch_size = 1, .h = h_resized, .w = w_resized });
        // Repeat the single frame along the temporal axis. The repeat count has to be the
        // configured temporal patch size, because the final reshape below relies on it.
        image_chw_rescaled_normalized = image_chw_rescaled_normalized.repeat1d(1, temporal_patch_size);

        // Hardcoded because we only have 1 temporal patch (image)
        const grid_t = 1;
        // Compute the number of grid cells based on the patch size (size of the patch embedding)
        const grid_h: u32 = @divExact(h_resized, patch_size);
        const grid_w: u32 = @divExact(w_resized, patch_size);
        const grid_thw = [3]u32{ grid_t, grid_h, grid_w };

        // Split height axis: h -> h_div, m1, patch1
        image_chw_rescaled_normalized = image_chw_rescaled_normalized.splitAxis(.h, .{
            .h_div = @divExact(grid_h, merge_size),
            .m1 = merge_size,
            .patch1 = patch_size,
        });
        // Split width axis: w -> w_div, m2, patch2
        image_chw_rescaled_normalized = image_chw_rescaled_normalized.splitAxis(.w, .{
            .w_div = @divExact(grid_w, merge_size),
            .m2 = merge_size,
            .patch2 = patch_size,
        });
        // After splitting axes, the shape is:
        // .{ .c = 3, .temporal_patch_size = temporal_patch_size,
        //    .h_div = grid_h / merge_size, .m1 = merge_size, .patch1 = patch_size,
        //    .w_div = grid_w / merge_size, .m2 = merge_size, .patch2 = patch_size }
        image_chw_rescaled_normalized = image_chw_rescaled_normalized.transpose(.{ .h_div, .w_div, .m1, .m2, .c, .temporal_patch_size, .patch1, .patch2 });
        const flatten_image = image_chw_rescaled_normalized.reshape(.{ .a = grid_h * grid_w, .b = 3 * temporal_patch_size * patch_size * patch_size });

        // Return the flattened image and the grid dimensions
        return .{ flatten_image, grid_thw };
    }

    /// Decode step: runs `Qwen.forward_decode` (no image processing) and converts the
    /// sampled token to `u32`.
    pub fn forward_decode(
        self: Qwen3VL,
        input_ids: Tensor,
        cache_position: Tensor,
        kv_cache: KvCache,
        mrope_position_deltas: Tensor,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor.Rng } {
        const next_token, const updated_cache, const new_rng = zml.call(self.qwen, .forward_decode, .{ input_ids, cache_position, kv_cache, mrope_position_deltas, rng });
        return .{ next_token.convert(.u32), updated_cache, new_rng };
    }
};

/// Qwen3-VL architecture, using huggingface transformers naming.
/// Vision-Language model with vision transformer and text model.
pub const Qwen = struct {
    /// Vision tower hyper-parameters.
    pub const VisionConfig = struct {
        depth: u32 = 32,
        hidden_size: u32 = 1280,
        hidden_act: []const u8 = "silu",
        intermediate_size: u32 = 3420,
        num_heads: u32 = 16,
        in_channels: u32 = 3,
        patch_size: u32 = 14,
        spatial_merge_size: u32 = 2,
        temporal_patch_size: u32 = 2,
        out_hidden_size: u32 = 2048,
        initializer_range: f32 = 0.02,
        deepstack_visual_indexes: []const u32 = &[_]u32{ 5, 11, 17 },
        num_position_embeddings: u32 = 48 * 48,
    };

    /// Text decoder hyper-parameters.
    pub const TextConfig = struct {
        hidden_size: u32 = 2560,
        bos_token_id: u32,
        eos_token_id: u32,
        head_dim: i64 = 128,
        num_hidden_layers: u32,
        num_attention_heads: u32,
        num_key_value_heads: u32,
        max_position_embeddings: u32,
        rms_norm_eps: f32,
        tie_word_embeddings: bool = true,
        rope_scaling: RopeScaling = .{ .mrope_section = .{ 24, 20, 20 } },
        rope_theta: f32 = 5000000.0,
    };

    pub const Config = struct {
        vision_config: VisionConfig,
        text_config: TextConfig,
        tie_word_embeddings: bool = true,
    };

    pub const Options = struct {
        sampling_strategy: ?zml.nn.SamplingStrategy,
        max_seq_len: u32,
    };

    pub const RopeScaling = struct {
        mrope_section: [3]u32 = .{ 24, 20, 20 },
    };

    vision_transformer: VisionTransformer,
    text_model: TextModel,
    // Options controlling generation
    gen_opts: zml.nn.SamplingStrategy = .{},
    config: Config,

    /// Initializes the Qwen model (vision and text models) from `store`.
    /// `options.sampling_strategy` falls back to the default strategy when null.
    pub fn init(allocator: std.mem.Allocator, config: Config, options: Options, store: zml.aio.BufferStore) !Qwen {
        return .{
            .config = config,
            .gen_opts = options.sampling_strategy orelse .{},
            .vision_transformer = try VisionTransformer.init(allocator, config, store),
            .text_model = try TextModel.init(allocator, config, store),
        };
    }

    /// Prefill forward pass: embeds the prompt, splices the vision embedding into it,
    /// runs the text model and samples the next token.
    ///
    /// `prompt_shape` is read as three entries: number of text tokens before the image,
    /// number of image tokens, number of text tokens after the image.
    /// Returns `{ next_token, updated_cache, mrope_position_deltas, new_rng }`.
    pub fn forward(
        self: Qwen,
        input_ids: Tensor,
        pixel_values: Tensor,
        cache_position: Tensor,
        image_grid_thw: [3]u32,
        kv_cache: KvCache,
        prompt_shape: Tensor,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor, Tensor.Rng } {
        // Embed the input ids
        const embedded = zml.call(self.text_model.embed_tokens, .forward, .{input_ids}).withTags(.{ .bs, .seq, .d });

        // Forward pass for the vision transformer
        const vision_embed, const deepstack_features = zml.call(self.vision_transformer, .forward, .{ pixel_values, image_grid_thw });

        // The number of text tokens before the image (4 tokens according to the chat template)
        const text_before_image = prompt_shape.choose1d(0, 0).convert(.i32);
        // The number of image tokens after the spatial merge:
        // patches on height * patches on width / spatial_merge_size^2
        const num_image_tokens = prompt_shape.choose1d(0, 1).convert(.i32);
        // The number of text tokens after the image (prompt tokens + 6 according to the chat template)
        const text_after_image = prompt_shape.choose1d(0, 2).convert(.i32);

        // Update the embedding with the vision embedding
        const text_with_image = embedded.dynamicUpdateSlice(.{ .seq = text_before_image }, zml.torch.unsqueeze(vision_embed.convert(embedded.dtype()), 0));
        const seq_len = text_with_image.dim(.seq);

        // Build the 3D positional ids
        const position_ids, const mrope_position_deltas = buildVisionPositionIds(
            self.config.vision_config.spatial_merge_size,
            input_ids,
            seq_len,
            prompt_shape,
            image_grid_thw,
        );
        const real_seq_len = text_before_image.add(text_after_image).add(num_image_tokens);
        const hidden, const updated_cache = zml.call(self.text_model, .forward, .{ position_ids, text_with_image, cache_position, deepstack_features, kv_cache });

        // Sample the next token using RNG, from the last real (non-padding) position
        const last_pos = real_seq_len.addConstant(-1).asScalar();
        const last_hidden = hidden.dynamicSlice1d(hidden.axis(.seq), .{ .start = last_pos, .len = 1 });
        const last_logits = projectToVocab(last_hidden, self.text_model.embed_tokens.weight);
        const next_token, const new_rng = self.sampleTokens(last_logits, rng);
        const next_token_with_shape = next_token.withTags(.{ .bs, .seq });
        return .{ next_token_with_shape, updated_cache, mrope_position_deltas, new_rng };
    }

    /// Decode forward pass. The visual embeddings are not recomputed; positions are
    /// derived from `cache_position` and `mrope_position_deltas`.
    pub fn forward_decode(
        self: Qwen,
        input_ids: Tensor,
        cache_position: Tensor,
        kv_cache: KvCache,
        mrope_position_deltas: Tensor,
        rng: Tensor.Rng,
    ) struct { Tensor, KvCache, Tensor.Rng } {
        const embedded = zml.call(self.text_model.embed_tokens, .forward, .{input_ids}).withTags(.{ .bs, .seq, .d });
        const position_ids = buildDecodePositionIds(cache_position, mrope_position_deltas);
        const hidden, const updated_cache = zml.call(self.text_model, .forward_decode, .{ position_ids, embedded, cache_position, kv_cache });
        const logits = projectToVocab(hidden, self.text_model.embed_tokens.weight);
        // Sample the next token using RNG
        const last_logits = logits.slice1d(.seq, .{ .start = logits.dim(.seq) - 1, .end = logits.dim(.seq) });
        const next_token, const new_rng = self.sampleTokens(last_logits, rng);
        return .{ next_token.reuseBuffer(input_ids), updated_cache, new_rng };
    }

    fn initKvCache(k: Tensor, v: Tensor, layer_index: Tensor) KvCache {
        return .{
            .k = k.withTags(.{ .layer, .k, .h, .hd }),
            .v = v.withTags(.{ .layer, .k, .h, .hd }),
            .layer_index = layer_index,
        };
    }

    /// Projects hidden states onto the vocabulary by contracting the last axis of
    /// `hidden` with the last axis of the (tied) embedding matrix, in f32.
    fn projectToVocab(hidden: Tensor, embedding_weight: Tensor) Tensor {
        return hidden.convert(.f32).dotGeneral(embedding_weight.convert(.f32), &.{.{ -1, -1 }}, &.{}).withTags(.{ .bs, .seq, .voc });
    }

    /// Samples tokens from `logits_` with the configured strategy. Panics when the
    /// logits carry no `.voc` tag.
    pub fn sampleTokens(
        self: Qwen,
        logits_: Tensor,
        rng: Tensor.Rng,
    ) struct { Tensor, Tensor.Rng } {
        const logits = logits_.withPartialTags(.{ .bs, .seq, .voc });
        if (logits.shape().hasTag(.voc) == null) @panic("logits must have .voc tag");
        const next_tokens, const new_rng = zml.nn.sampleTokens(logits, self.gen_opts, rng);
        return .{ next_tokens, new_rng };
    }

    /// Builds the 3D (temporal, height, width) position ids for a prompt containing one image.
    /// Returns `{ stacked_position_ids, mrope_position_deltas }`.
    pub fn buildVisionPositionIds(
        spatial_merge_size: u32,
        input_ids: Tensor,
        seq_len: i64,
        prompt_shape: Tensor,
        image_grid_thw: [3]u32,
    ) struct { Tensor, Tensor } {
        // Number of text tokens before the image, of image tokens, and of text tokens after the image
        const text_before_image = prompt_shape.choose1d(0, 0).convert(.i32);
        const num_image_tokens = prompt_shape.choose1d(0, 1).convert(.i32);
        const text_after_image = prompt_shape.choose1d(0, 2).convert(.i32);

        // Positions of the text before the image.
        // NOTE(review): the length 4 is static (chat template) while `text_before_image`
        // is read at runtime; the two have to agree.
        const before_image_positions = zml.Tensor.iota(Shape.init(.{ .bs = input_ids.dim(.bs), .seq = 4 }, .i32), .seq);
        const t = image_grid_thw[0];
        const h = @divExact(image_grid_thw[1], spatial_merge_size);
        const w = @divExact(image_grid_thw[2], spatial_merge_size);

        // Default positions (valid for the text after the image): the image block only
        // advances the position by max(w, h, t) instead of by its token count.
        var position_ids = zml.Tensor.iota(
            Shape.init(.{ .bs = input_ids.dim(.bs), .seq = seq_len }, .i32),
            .seq,
        )
            .sub(num_image_tokens)
            .addConstant(@max(w, h, t));
        position_ids = position_ids.dynamicUpdateSlice(.{ .seq = zml.Tensor.scalar(0, .i32) }, before_image_positions);

        // Define the shape of the iota tensor
        const iota_shape = Shape.init(.{ .bs = input_ids.dim(.bs), .t = t, .h = h, .w = w }, .i32);
        const reshape_shape = Shape.init(.{ .bs = input_ids.dim(.bs), .seq = t * h * w }, .i32);

        // Repeat the index along the 3 dimensions based on grid size, offset by the text before the image
        const t_index = zml.Tensor.iota(iota_shape, .t).reshape(reshape_shape).add(text_before_image);
        const h_index = zml.Tensor.iota(iota_shape, .h).reshape(reshape_shape).add(text_before_image);
        const w_index = zml.Tensor.iota(iota_shape, .w).reshape(reshape_shape).add(text_before_image);

        // Update the position ids with the 3D positional ids
        const position_ids_t = position_ids.dynamicUpdateSlice(.{ .seq = text_before_image }, t_index);
        const position_ids_h = position_ids.dynamicUpdateSlice(.{ .seq = text_before_image }, h_index);
        const position_ids_w = position_ids.dynamicUpdateSlice(.{ .seq = text_before_image }, w_index);

        // Stack the position ids
        const stacked_position_ids = zml.Tensor.stack(&.{ position_ids_t, position_ids_h, position_ids_w }, 0, .g);

        // Position max after 3d compression - real seq len
        const position_max_after_3d_compression = zml.Tensor.scalar(@max(w, h, t), .i32).add(text_after_image).add(text_before_image);
        const real_seq_len = text_before_image.add(text_after_image).add(num_image_tokens);
        const mrope_position_deltas = position_max_after_3d_compression.sub(real_seq_len).reshape(.{ .seq = 1 });
        return .{ stacked_position_ids, mrope_position_deltas };
    }

    test "buildVisionPositionIds" {
        std.debug.print("buildVisionPositionIds test started\n", .{});
        const platform = zml.testing.env();
        const allocator = std.testing.allocator;

        // Parameters
        const batch_size: u32 = 1;
        const seq_len: u32 = 79;

        // Create input buffers
        const input_ids_data = try allocator.alloc(i32, batch_size * seq_len);
        defer allocator.free(input_ids_data);
        for (0..batch_size * seq_len) |i| {
            input_ids_data[i] = @intCast(i % seq_len);
        }
        const input_ids_d = try zml.Buffer.fromSlice(platform, .{ .bs = batch_size, .seq = seq_len }, input_ids_data);
        defer input_ids_d.deinit();
        const prompt_shape_d = try zml.Buffer.fromSlice(platform, .{ .seq = 3 }, &[_]i32{ 4, 64, 11 });
        defer prompt_shape_d.deinit();

        // Compile and execute buildVisionPositionIds
        const Local = struct {
            pub fn positionIds(input_ids: zml.Tensor, prompt_shape: zml.Tensor) zml.Tensor {
                return buildVisionPositionIds(2, input_ids, seq_len, prompt_shape, .{ 1, 16, 16 })[0];
            }
        };
        const result = try zml.testing.compileAndCall(
            platform,
            Local.positionIds,
            .{ input_ids_d, prompt_shape_d },
        );
        defer result.deinit();

        // Layout of each row: 4 text positions, 64 image positions (8x8 grid), 11 text positions.
        const expected = [3][79]i32{
            // temporal
            .{
                0,  1,  2,  3,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                4,  4,  4,  4,  4,  4,  4,  4,
                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
            },
            // height
            .{
                0,  1,  2,  3,
                4,  4,  4,  4,  4,  4,  4,  4,
                5,  5,  5,  5,  5,  5,  5,  5,
                6,  6,  6,  6,  6,  6,  6,  6,
                7,  7,  7,  7,  7,  7,  7,  7,
                8,  8,  8,  8,  8,  8,  8,  8,
                9,  9,  9,  9,  9,  9,  9,  9,
                10, 10, 10, 10, 10, 10, 10, 10,
                11, 11, 11, 11, 11, 11, 11, 11,
                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
            },
            // width
            .{
                0,  1,  2,  3,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                4,  5,  6,  7,  8,  9,  10, 11,
                12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22,
            },
        };
        try std.testing.expectEqual(expected, try result.getValue([3][79]i32));
    }

    /// Builds the 3D position ids for a decode step from the mrope delta (difference
    /// between the max position after 3D compression and the real sequence length).
    /// The three components (t, h, w) are identical during decoding.
    fn buildDecodePositionIds(cache_position: Tensor, mrope_position_deltas: Tensor) Tensor {
        const cache_pos = cache_position.reshape(.{ .bs = 1, .seq = 1 });
        const deltas = mrope_position_deltas.convert(.i64).reshape(.{ .bs = 1, .seq = 1 });
        const position = cache_pos.add(deltas);
        return zml.Tensor.stack(&.{ position, position, position }, 0, .g).withTags(.{ .g, .bs, .seq });
    }
};

//========================Vision model========================

/// Vision tower: patch embedding, interpolated position embedding, transformer
/// blocks, final patch merger and deepstack mergers.
pub const VisionTransformer = struct {
    vision_patch_embed: VisionPatchEmbed,
    pos_embed: zml.nn.TokenEmbedding,
    blocks: []VisionBlock,
    patch_merger: PatchMerger,
    deepstack_patch_mergers: []PatchMerger,
    num_heads: u32,
    hidden_size: u32,
    spatial_merge_size: u32,
    num_position_embeddings: u32,
    rope_opts: zml.nn.RopeOpts,

    /// Loads every vision sub-module from `store` under the `model.visual.*` prefixes.
    /// `blocks` and `deepstack_patch_mergers` are allocated with `allocator`.
    pub fn init(allocator: std.mem.Allocator, config: Qwen.Config, store: zml.aio.BufferStore) !VisionTransformer {
        const spatial_merge_size = config.vision_config.spatial_merge_size;
        const blocks = try allocator.alloc(VisionBlock, config.vision_config.depth);
        var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
        try prefix.push(stdx.noalloc, "model.visual.blocks");
        for (0.., blocks) |i, *block| {
            try prefix.pushDigit(stdx.noalloc, i);
            defer prefix.pop();
            var vision_attn = try zml.aio.populateModelWithPrefix(VisionAttention, allocator, store, prefix.concat("attn"));
            vision_attn.num_heads = config.vision_config.num_heads;
            var mlp = try zml.aio.populateModelWithPrefix(VisionMlp, allocator, store, prefix.concat("mlp"));
            mlp.hidden_act = zml.nn.Activation{ .gelu = {} };
            var norm1 = try zml.aio.populateModelWithPrefix(zml.nn.LayerNorm, allocator, store, prefix.concat("norm1"));
            norm1.eps = 1e-6;
            var norm2 = try zml.aio.populateModelWithPrefix(zml.nn.LayerNorm, allocator, store, prefix.concat("norm2"));
            norm2.eps = 1e-6;
            block.* = .{
                .attn = vision_attn,
                .mlp = mlp,
                .norm1 = norm1,
                .norm2 = norm2,
                .num_heads = config.vision_config.num_heads,
            };
        }

        prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
        try prefix.push(stdx.noalloc, "model.visual.deepstack_merger_list");
        const deepstack_patch_mergers = try allocator.alloc(PatchMerger, config.vision_config.deepstack_visual_indexes.len);
        for (0.., deepstack_patch_mergers) |i, *deepstack_patch_merger| {
            try prefix.pushDigit(stdx.noalloc, i);
            defer prefix.pop();
            const norm = try zml.aio.populateModelWithPrefix(zml.nn.LayerNorm, allocator, store, prefix.concat("norm"));
            const linear_fc1 = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("linear_fc1"));
            const linear_fc2 = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, prefix.concat("linear_fc2"));
            deepstack_patch_merger.* = .{
                .norm = norm,
                .linear_fc1 = linear_fc1,
                .linear_fc2 = linear_fc2,
                .out_hidden_size = config.vision_config.out_hidden_size,
            };
        }

        return .{
            .pos_embed = try zml.aio.populateModelWithPrefix(zml.nn.TokenEmbedding, allocator, store, "model.visual.pos_embed"),
            .blocks = blocks,
            .patch_merger = .{
                .norm = try zml.aio.populateModelWithPrefix(zml.nn.LayerNorm, allocator, store, "model.visual.merger.norm"),
                .linear_fc1 = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, "model.visual.merger.linear_fc1"),
                .linear_fc2 = try zml.aio.populateModelWithPrefix(zml.nn.Linear, allocator, store, "model.visual.merger.linear_fc2"),
                .out_hidden_size = config.vision_config.out_hidden_size,
            },
            .num_heads = config.vision_config.num_heads,
            .deepstack_patch_mergers = deepstack_patch_mergers,
            .vision_patch_embed = try VisionPatchEmbed.init(allocator, config.vision_config.patch_size, config.vision_config.temporal_patch_size, config.vision_config.in_channels, config.vision_config.out_hidden_size, store),
            .hidden_size = config.vision_config.hidden_size,
            .spatial_merge_size = spatial_merge_size,
            .num_position_embeddings = config.vision_config.num_position_embeddings,
            .rope_opts = zml.nn.RopeOpts{
                .layout = .sequential,
                .freq_base = 10000.0,
                .scaling = .{ .default = {} },
            },
        };
    }

    /// Forward pass for the vision transformer.
    /// Outputs:
    /// - hidden_states: the output of the final patch merger (visual embedding)
    /// - deepstack_features_list: the deepstack features (merged intermediate
    ///   representations taken after blocks 5, 11 and 17)
    pub fn forward(self: VisionTransformer, x_input: Tensor, grid_thw: [3]u32) struct { Tensor, [3]Tensor } {
        const pos_embeds = self.posEmbedInterpolate(&grid_thw);
        var rotary_pos_emb = rotaryPosEmbed(&grid_thw, self.spatial_merge_size, self.hidden_size, self.num_heads, self.rope_opts);
        rotary_pos_emb = zml.Tensor.concatenate(&.{ rotary_pos_emb, rotary_pos_emb }, 2);
        var hidden_states = zml.call(self.vision_patch_embed, .forward, .{x_input});
        hidden_states = hidden_states.add(pos_embeds[0].convert(hidden_states.dtype()));
        const cos = rotary_pos_emb.cos();
        const sin = rotary_pos_emb.sin();

        // NOTE(review): mirrors the default of `VisionConfig.deepstack_visual_indexes`;
        // the config value itself is not stored on this struct.
        const deepstack_visual_indexes = [3]u32{ 5, 11, 17 };
        var count: usize = 0;
        var deepstack_features_list: [3]Tensor = undefined;
        for (0.., self.blocks) |layer, block| {
            hidden_states = zml.call(block, .forward, .{ hidden_states, cos, sin });
            for (deepstack_visual_indexes) |index| {
                if (layer == index) {
                    deepstack_features_list[count] = zml.call(self.deepstack_patch_mergers[count], .forward, .{ hidden_states, true });
                    count += 1;
                }
            }
        }
        // Every slot must have been written, otherwise `undefined` tensors would be returned
        // (happens when the model has fewer blocks than the largest deepstack index).
        std.debug.assert(count == deepstack_features_list.len);

        hidden_states = zml.call(self.patch_merger, .forward, .{ hidden_states, false });
        return .{ hidden_states, deepstack_features_list };
    }

    /// Positional embedding interpolation: bilinearly interpolates the learned square
    /// position grid (sqrt(num_position_embeddings) per side, 48 x 48 by default) to the
    /// `h x w` patch grid, then reorders it by spatial-merge blocks and repeats it `t` times.
    /// `grid` is `{ t, h, w }`.
    pub fn posEmbedInterpolate(self: VisionTransformer, grid: []const u32) [1]Tensor {
        // Number of grid points per side (sqrt of the number of position embeddings)
        const num_grid_per_side = std.math.pow(f32, @as(f32, @floatFromInt(self.num_position_embeddings)), 0.5);
        const m_size = self.spatial_merge_size;
        const embedding_dim = self.hidden_size;

        // Retrieve the dims in the image grid
        const t = grid[0];
        const h = grid[1];
        const w = grid[2];
        const tensor_filled_1_h = zml.Tensor.constant(.{h}, zml.Data.init(.f32, 1));
        const tensor_filled_1_w = zml.Tensor.constant(.{w}, zml.Data.init(.f32, 1));

        // Build the indices for the height and width by linearly spacing the grid points
        const h_idxs = zml.Tensor.linspace(.{ .start = 0, .end = num_grid_per_side - 1, .steps = h }, .f32);
        const w_idxs = zml.Tensor.linspace(.{ .start = 0, .end = num_grid_per_side - 1, .steps = w }, .f32);
        const h_floor = h_idxs.floor();
        const w_floor = w_idxs.floor();

        // Build the ceil for the height and width (floor + 1, clamped to the grid)
        const h_ceil = h_floor.add(tensor_filled_1_h).clamp(
            zml.Tensor.scalar(0, .f32),
            zml.Tensor.scalar(num_grid_per_side - 1, .f32),
        );
        const w_ceil = w_floor.add(tensor_filled_1_w).clamp(
            zml.Tensor.scalar(0, .f32),
            zml.Tensor.scalar(num_grid_per_side - 1, .f32),
        );

        // Difference between the indices and their floor -> fractional offset inside a grid cell
        const dh = h_idxs.sub(h_floor);
        const dw = w_idxs.sub(w_floor);

        // Build the meshgrid for the height and width
        const tensor_filled_1_h_v = zml.Tensor.constant(.{ h, w }, zml.Data.init(.f32, 1));
        const d_tensors_meshgrid = [2]Tensor{ dh, dw };
        const floor_tensors_meshgrid = [2]Tensor{ h_floor, w_floor };
        const ceil_tensors_meshgrid = [2]Tensor{ h_ceil, w_ceil };
        const dhw_grid = zml.Tensor.cartesianProduct(2, d_tensors_meshgrid);
        const floorhw_grid = zml.Tensor.cartesianProduct(2, floor_tensors_meshgrid);
        const ceilhw_grid = zml.Tensor.cartesianProduct(2, ceil_tensors_meshgrid);

        // Bilinear weights: w00 = (1-dh)(1-dw), w01 = (1-dh)dw, w10 = dh(1-dw), w11 = dh*dw
        const w11 = dhw_grid[0].mul(dhw_grid[1]);
        const w10 = dhw_grid[0].sub(w11);
        const w01 = dhw_grid[1].sub(w11);
        const w00 = tensor_filled_1_h_v.sub(dhw_grid[0]).sub(w01);

        // Corner coordinates, in the same order as the weights above
        const h_list = [4]Tensor{ floorhw_grid[0], floorhw_grid[0], ceilhw_grid[0], ceilhw_grid[0] };
        const w_list = [4]Tensor{ floorhw_grid[1], ceilhw_grid[1], floorhw_grid[1], ceilhw_grid[1] };

        // Stack the height and width lists and flatten them to row-major embedding indices
        const h_grid = zml.Tensor.stack(&h_list, 0, .layers);
        const w_grid = zml.Tensor.stack(&w_list, 0, .layers);
        const h_grid_idx = h_grid.scale(num_grid_per_side);
        const indices = h_grid_idx.add(w_grid).reshape(.{ 4, -1 }).convert(.i32);
        const weights = zml.Tensor.stack(&[4]Tensor{ w00, w01, w10, w11 }, 0, .layers).reshape(.{ 4, -1, 1 });
        const embeds = zml.call(self.pos_embed, .forward, .{indices});
        const weights_embed = embeds.convert(.f32).mul(weights.repeat1d(-1, embedding_dim));
        const combined = weights_embed.sum(0).withTags(.{ .bs, .hw, .d });

        // Split the combined tensor into the height and width dimensions
        const combined_reshape = combined.splitAxis(.hw, .{ .h = @divExact(h, m_size), .m1 = m_size, .w = @divExact(w, m_size), .m2 = m_size });
        const combined_permuted = combined_reshape.transpose(.{ .bs, .h, .w, .m1, .m2, .d });

        // Repeat the combined tensor t times along the temporal dimension
        const t_u63: u63 = @intCast(t);
        const repeated = combined_permuted.repeat1d(0, t_u63).reshape(.{ -1, embedding_dim });
        return .{repeated};
    }

    /// Rotary position embedding angles for the vision transformer.
    /// `grid_thw` is `{ t, h, w }`; the positions are ordered by spatial-merge blocks.
    pub fn rotaryPosEmbed(grid_thw: []const u32, m_size: u32, hidden_size: u32, num_heads: u32, rope_opts: zml.nn.RopeOpts) Tensor {
        const t = grid_thw[0];
        const h = grid_thw[1];
        const w = grid_thw[2];
        const pos_shape = zml.Shape.init(.{ .h = h, .w = w }, .f32);

        // Build the height position ids
        var hpos_ids = zml.Tensor.iota(pos_shape, 0);
        hpos_ids = hpos_ids.splitAxis(.h, .{ .h_div = @divExact(h, m_size), .m1 = m_size }).splitAxis(.w, .{ .w_div = @divExact(w, m_size), .m2 = m_size });
        hpos_ids = hpos_ids.transpose(.{ .h_div, .w_div, .m1, .m2 });
        hpos_ids = hpos_ids.reshape(.{ .seq = -1 });

        // Build the width position ids
        var wpos_ids = zml.Tensor.iota(pos_shape, 1);
        wpos_ids = wpos_ids.splitAxis(.h, .{ .h_div = @divExact(h, m_size), .m1 = m_size }).splitAxis(.w, .{ .w_div = @divExact(w, m_size), .m2 = m_size });
        wpos_ids = wpos_ids.transpose(.{ .h_div, .w_div, .m1, .m2 });
        wpos_ids = wpos_ids.reshape(.{ .seq = -1 });
        const pos_ids = zml.Tensor.stack(&[2]Tensor{ hpos_ids, wpos_ids }, 1, .layers).repeat1d(1, @as(u63, @intCast(t))).convert(.i32);

        // Compute the inverse frequency
        // NOTE(review): the 32 is hard-coded, and the table below has
        // hidden_size / num_heads / 2 rows that are gathered by position id; this looks
        // like it only covers grids up to that many patches per side — confirm against
        // the reference implementation.
        const inv_freq = zml.nn.invFreq(@intCast(32), rope_opts).withTags(.{.s});
        const seq = zml.Tensor.arange(.{ .end = hidden_size / num_heads / 2 }, .f32).withTags(.{.d});

        // Compute the outer product of the sequence and the inverse frequency
        const rotary_pos_emb_full = zml.Tensor.outer(seq, inv_freq);
        const output = rotary_pos_emb_full.gather(.{ .d = pos_ids }, .{}).merge(.{ .d = .{ .layers, .s } });

        // Add a bs axis for the moment: the image processing is not done in batch, but
        // the axis is needed for the text processing
        const output_with_bs = output.reshape(.{ .bs = 1, .s = output.dim(.seq), .d = output.dim(.d) });
        return output_with_bs;
    }
};

/// Computes "SAME" padding for a 3D convolution.
///
/// For each spatial dim the output keeps `ceil(input / stride)` positions; the padding
/// needed for that (never negative) is split between start and end, with the extra
/// element going to the end. Returns `{ start0, end0, start1, end1, start2, end2 }`.
fn getPaddingForSame(input_dims: [3]i64, kernel_dims: [3]i64, strides: [3]i64) [6]i64 {
    var res: [6]i64 = undefined;
    for (0..3) |i| {
        // Ceil division: a partially covered trailing window still produces an output.
        const output_size = @divFloor(input_dims[i] + strides[i] - 1, strides[i]);
        // Clamp at 0: when the kernel is smaller than the stride no padding is needed.
        const total_padding: i64 = @max(0, (output_size - 1) * strides[i] + kernel_dims[i] - input_dims[i]);
        const pad_start = @divFloor(total_padding, 2);
        const pad_end = total_padding - pad_start;
        res[2 * i] = pad_start;
        res[2 * i + 1] = pad_end;
    }
    return res;
}

/// 3D convolution over (batch, feature, d0, d1, d2) inputs with an
/// (out_feature, in_feature, k0, k1, k2) kernel.
pub const Conv3d = struct {
    weight: Tensor,
    bias: ?Tensor = null,
    temporal_stride: u32 = 2,
    spatial_stride: u32 = 16,

    /// Applies the convolution with SAME padding, then adds the bias (broadcast on axis 1) if any.
    pub fn forward(self: Conv3d, input: Tensor) Tensor {
        const x = input;
        var strides: [3]i64 = .{ self.temporal_stride, self.spatial_stride, self.spatial_stride };
        const padding = getPaddingForSame(x.dims()[2..5].*, self.weight.dims()[2..5].*, strides);
        const loc = input.getContext().location(@src(), "Conv3d.forward", .{});
        var y = x.convolution(
            self.weight.convert(x.dtype()),
            .{
                .window_strides = &strides,
                .pad_value = &padding,
                .lhs_dilation = &.{ 1, 1, 1 },
                .rhs_dilation = &.{ 1, 1, 1 },
                .window_reversal = &.{ false, false, false },
                .input_batch_dimension = 0,
                .input_feature_dimension = 1,
                .input_spatial_dimensions = &.{ 2, 3, 4 },
                .kernel_input_feature_dimension = 1,
                .kernel_output_feature_dimension = 0,
                .kernel_spatial_dimensions = &.{ 2, 3, 4 },
                .output_batch_dimension = 0,
                .output_feature_dimension = 1,
                .output_spatial_dimensions = &.{ 2, 3, 4 },
                .feature_group_count = 1,
                .batch_group_count = 1,
            },
            loc,
        );
        if (self.bias) |b| y = y.add(b.convert(y.dtype()).broadcast(y._shape, &.{1}));
        return y;
    }
};

/// One pre-norm vision transformer block: attention and MLP, each with a residual connection.
pub const VisionBlock = struct {
    norm1: zml.nn.LayerNorm,
    norm2: zml.nn.LayerNorm,
    attn: VisionAttention,
    mlp: VisionMlp,
    num_heads: u32,

    pub fn forward(self: VisionBlock, hidden_states: Tensor, cos: Tensor, sin: Tensor) Tensor {
        const x = zml.call(self.norm1, .forward, .{hidden_states});
        // The attention output is squeezed to remove the bs axis because the image
        // processing is not done in batch.
        // To be discussed when the model handles several images and several sequences.
        const x1 = hidden_states.add(zml.call(self.attn, .forward, .{ x, cos, sin }).squeeze(0));
        const x2 = zml.call(self.norm2, .forward, .{x1});
        const x3 = x1.add(zml.call(self.mlp, .forward, .{x2}));
        return x3.reuseBuffer(hidden_states);
    }
};

/// Vision patch embedding.
/// Projects the image into a visual embedding by applying a 3D convolution along the
/// temporal and spatial dims.
pub const VisionPatchEmbed = struct {
    proj: Conv3d,
    patch_size: u32 = 14,
    temporal_patch_size: u32 = 2,
    in_channels: u32 = 3,
    embed_dim: u32 = 1152,

    /// Loads the projection from `model.visual.patch_embed.proj` and sets its strides to
    /// the patch sizes, so that each patch yields exactly one output position.
    pub fn init(
        allocator: std.mem.Allocator,
        patch_size: u32,
        temporal_patch_size: u32,
        in_channels: u32,
        embed_dim: u32,
        store: zml.aio.BufferStore,
    ) !VisionPatchEmbed {
        var conv3d = try zml.aio.populateModelWithPrefix(Conv3d, allocator, store, "model.visual.patch_embed.proj");
        conv3d.temporal_stride = temporal_patch_size;
        conv3d.spatial_stride = patch_size;
        return .{
            .proj = conv3d,
            .patch_size = patch_size,
            .temporal_patch_size = temporal_patch_size,
            .in_channels = in_channels,
            .embed_dim = embed_dim,
        };
    }

    /// Reshapes the flattened patches to (n, c, t, p, p), convolves them and returns the
    /// result reshaped to its first two dims.
    pub fn forward(self: VisionPatchEmbed, hidden_states: Tensor) Tensor {
        const reshaped = hidden_states.reshape(.{ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size });
        const conv_output = zml.call(self.proj, .forward, .{reshaped});
        const reshaped_output = conv_output.reshape(.{ conv_output.dim(0), conv_output.dim(1) });
        return reshaped_output;
    }
};

/// Rotates the last axis by half: (x1, x2) -> (-x2, x1). The last dim must be even.
pub fn rotate_half(x: Tensor) Tensor {
    const x1 = x.slice1d(-1, .{ .start = 0, .end = @divExact(x.dim(-1), 2) });
    const x2 = x.slice1d(-1, .{ .start = @divExact(x.dim(-1), 2), .end = x.dim(-1) });
    return Tensor.concatenate(&.{ x2.negate(), x1 }, -1);
}

/// Applies rotary embeddings to `q` and `k`: `x * cos + rotate_half(x) * sin`.
/// The computation runs in the dtype of `cos`; results are converted back to the
/// input dtypes and tagged (bs, q|k, h, hd).
pub fn applyRotaryPositionalEmbedding(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) struct { Tensor, Tensor } {
    // Broadcast cos and sin to the shape of q and k
    const cos_q = cos.broadcast(q.shape(), &.{ 0, 1, 3 });
    const sin_q = sin.broadcast(q.shape(), &.{ 0, 1, 3 });
    const cos_k = cos.broadcast(k.shape(), &.{ 0, 1, 3 });
    const sin_k = sin.broadcast(k.shape(), &.{ 0, 1, 3 });
    const q_dtype = q.convert(cos.dtype());
    const k_dtype = k.convert(cos.dtype());
    const q_embed = q_dtype.mul(cos_q).add(rotate_half(q_dtype).mul(sin_q)).withTags(.{ .bs, .q, .h, .hd });
    const k_embed = k_dtype.mul(cos_k).add(rotate_half(k_dtype).mul(sin_k)).withTags(.{ .bs, .k, .h, .hd });
    return .{ q_embed.convert(q.dtype()), k_embed.convert(k.dtype()) };
}

// Vision-specific components

/// Multi-head self-attention over the patch sequence, with a fused qkv projection.
pub const VisionAttention = struct {
    qkv: zml.nn.Linear,
    proj: zml.nn.Linear,
    num_heads: u32,

    pub fn forward(self: VisionAttention, hidden_states: Tensor, cos: Tensor, sin: Tensor) Tensor {
        const qkv = zml.call(self.qkv, .forward, .{hidden_states});
        const qkv_reshaped = qkv.reshape(.{
            hidden_states.dim(0),
            3,
            self.num_heads,
            -1, // head_dim
        }).withTags(.{ .s, .qkv, .h, .hd });
        const qkv_permuted = qkv_reshaped.transpose(.{ .qkv, .s, .h, .hd });
        const q, const k, var v = qkv_permuted.chunkExact(.qkv, 3);
        const q_embed, const k_embed = applyRotaryPositionalEmbedding(q, k, cos, sin);
        // The leading axis of size 1 left by the chunk is re-tagged as the batch axis.
        v = v.withTags(.{ .bs, .k, .h, .hd });
        const attn_output = zml.nn.sdpa(q_embed, k_embed, v, .{ .allow_cudnn = true });
        const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s });
        const result = zml.call(self.proj, .forward, .{attn});
        return result;
    }
};

/// Merges groups of patches: norm -> linear_fc1 -> gelu -> linear_fc2.
pub const PatchMerger = struct {
    norm: zml.nn.LayerNorm,
    linear_fc1: zml.nn.Linear,
    linear_fc2: zml.nn.Linear,
    out_hidden_size: u32,

    /// When `use_post_shuffle_norm` is true the norm is applied after the rows are
    /// regrouped to width 1024 * 4, otherwise before.
    /// NOTE(review): 1024 * 4 is hard-coded; presumably hidden_size * spatial_merge_size^2 — confirm.
    pub fn forward(self: PatchMerger, x: Tensor, use_post_shuffle_norm: bool) Tensor {
        const gelu: zml.nn.Activation = .gelu;
        // Apply the post shuffle norm if needed
        var x1 = if (use_post_shuffle_norm) zml.call(self.norm, .forward, .{x.reshape(.{ -1, 1024 * 4 })}) else zml.call(self.norm, .forward, .{x});
        x1 = x1.reshape(.{ -1, 1024 * 4 });
        const x2 = zml.call(self.linear_fc1, .forward, .{x1});
        const x3 = gelu.forward(x2);
        const x4 = zml.call(self.linear_fc2, .forward, .{x3}).withTags(.{ .seq, .d });
        return x4;
    }
};

/// Classic two-layer MLP: linear_fc1 -> activation -> linear_fc2.
pub const VisionMlp = struct {
    linear_fc1: zml.nn.Linear,
    linear_fc2: zml.nn.Linear,
    hidden_act: zml.nn.Activation = .{ .gelu = {} },

    /// The activation is computed in f32 and converted back to the dtype of `x`.
    pub fn forward(self: VisionMlp, x: Tensor) Tensor {
        const x1 = zml.call(self.linear_fc1, .forward, .{x});
        const x2 = self.hidden_act.forward(x1.convert(.f32)).convert(x.dtype());
        const x3 = zml.call(self.linear_fc2, .forward, .{x2});
        return x3;
    }
};

//========================Text model========================
pub const TextModel = struct {
    embed_tokens: zml.nn.TokenEmbedding,
    layers: []TransformerLayer,
    norm: RmsNorm,
    rotary_embed: TextRotaryEmbedding,
    max_seq_len: u32 = 1000,
    num_heads: u32,
    num_kv_heads: u32,
    mrope_section: [3]u32,

    pub fn init(allocator: std.mem.Allocator, config: Qwen.Config, store: zml.aio.BufferStore) !TextModel {
        const layers = try allocator.alloc(TransformerLayer, config.text_config.num_hidden_layers);
        var prefix = try zml.aio.PrefixBuilder.initCapacity(allocator, 1024);
        const text_rotary_embed = try TextRotaryEmbedding.init(allocator, config.text_config.hidden_size, config.text_config.rope_theta, config.text_config.rope_scaling.mrope_section);
        try prefix.push(stdx.noalloc, "model.language_model.layers");
        for (0.., layers) |i, *layer| {
            try prefix.pushDigit(stdx.noalloc, i);
            defer prefix.pop();
            var self_attn = try zml.aio.populateModelWithPrefix(SelfAttn, allocator, store, prefix.concat("self_attn"));
            self_attn.num_heads = config.text_config.num_attention_heads;
            self_attn.num_kv_heads = config.text_config.num_key_value_heads;
            const mlp = try zml.aio.populateModelWithPrefix(Mlp, allocator, store, prefix.concat("mlp"));
            var input_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("input_layernorm"));
            input_layernorm.eps = 1e-6;
            var post_attention_layernorm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, prefix.concat("post_attention_layernorm"));
            post_attention_layernorm.eps = config.text_config.rms_norm_eps;
            layer.* = .{
                .self_attn = self_attn,
                .mlp = mlp,
                .input_layernorm = input_layernorm,
                .post_attention_layernorm = post_attention_layernorm,
                .num_heads = config.text_config.num_attention_heads,
            };
        }
        return .{
            .embed_tokens = try zml.aio.populateModelWithPrefix(zml.nn.TokenEmbedding, allocator, store, "model.language_model.embed_tokens"),
            .layers = layers,
            .norm = try zml.aio.populateModelWithPrefix(RmsNorm, allocator, store, "model.language_model.norm"),
            .num_heads = config.text_config.num_attention_heads,
            .num_kv_heads = config.text_config.num_key_value_heads,
            .rotary_embed = text_rotary_embed,
            .mrope_section = config.text_config.rope_scaling.mrope_section,
        };
    }

    // Forward prefill pass for the text model
    pub fn forward(self: TextModel, position_ids: Tensor, inputs_embeds: Tensor, cache_position: Tensor, deepstack_visual_embeds: [3]Tensor, kv_cache: KvCache) struct { Tensor, KvCache } {
        var hidden_states = inputs_embeds;
        const cos, const sin = self.rotary_embed.forward(position_ids);
        var count: u32 = 0;
        // Build the indices for the deepstack visual embeddings addition
        const indices = zml.Tensor.iota(Shape.init(.{ .seq = deepstack_visual_embeds[0].dim(.seq) }, .u32), .seq).addConstant(4);
        var updated_kv_cache = kv_cache;
        for (self.layers, 0..)
|layer, i| { hidden_states, updated_kv_cache = zml.call(layer, .forward, .{ hidden_states, cache_position, cos, sin, updated_kv_cache.atLayer(i) }); hidden_states = hidden_states.withTags(.{ .bs, .seq, .d }); // Add the n visual embeddings at the n first layers outputs if (count < deepstack_visual_embeds.len) { const deepstack = deepstack_visual_embeds[count]; hidden_states = hidden_states.scatterSlices(.{ .seq = indices }, zml.torch.unsqueeze(deepstack, 0).convert(hidden_states.dtype()).withTags(.{ .bs, .seq, .d }), .{ .update_fn = zml.Tensor.ScatterOpts.increment }); count += 1; } } const output = zml.call(self.norm, .forward, .{hidden_states}); return .{ output, updated_kv_cache }; } // Forward decode pass for the text model // Similar to the prefill pass, but without the deepstack visual embeddings addition pub fn forward_decode(self: TextModel, position_ids: Tensor, inputs_embeds: Tensor, cache_position: Tensor, kv_cache: KvCache) struct { Tensor, KvCache } { var hidden_states = inputs_embeds; const cos, const sin = self.rotary_embed.forward(position_ids); var updated_kv_cache = kv_cache; for (self.layers, 0..) 
|layer, i| { hidden_states, updated_kv_cache = zml.call(layer, .forward, .{ hidden_states, cache_position, cos, sin, updated_kv_cache.atLayer(i) }); hidden_states = hidden_states.withTags(.{ .bs, .seq, .d }); } const output = zml.call(self.norm, .forward, .{hidden_states}); return .{ output, updated_kv_cache.reuseBuffer(kv_cache) }; } }; pub const TextRotaryEmbedding = struct { rope_opts: zml.nn.RopeOpts, dim: u32, mrope_section: [3]u32, pub fn init(allocator: std.mem.Allocator, dim: u32, theta: f32, mrope_section: [3]u32) !TextRotaryEmbedding { _ = allocator; return .{ .rope_opts = zml.nn.RopeOpts{ .layout = .sequential, .freq_base = theta, .scaling = .{ .default = {} }, }, .dim = dim, .mrope_section = mrope_section, }; } pub fn forward(self: TextRotaryEmbedding, position_ids: Tensor) struct { Tensor, Tensor } { const mrope_section = [3]u32{ self.mrope_section[0], self.mrope_section[1], self.mrope_section[2] }; // from config 24 +20 +20 = 64 i.e. hd / 2 const inv_freq = zml.nn.invFreq(@intCast(128), self.rope_opts).withTags(.{.dh}).convert(.f32); // perform the outer product between the position ids and the inverse frequencies, output shape is (3, bs, dim_head//2, seq len) var freqs = position_ids.convert(inv_freq.dtype()).outer(inv_freq); // Interleaved mrope // Slice the frequency tensor to get the frequency for the temporal, height and width dimensions var freqs_t, var freqs_h, var freqs_w = freqs.chunkExact(.g, 3); //Squeeze the grid dim because we process per dim independently freqs_t = freqs_t.squeeze(.g); freqs_h = freqs_h.squeeze(.g); freqs_w = freqs_w.squeeze(.g); const indices = zml.Tensor.iota(Shape.init(.{ .h = @as(u32, @intCast(mrope_section[1])) }, .i32), .h); // Build the indices for the height and width dimensions const h_indices = indices.scale(3).addConstant(1); const w_indices = indices.scale(3).addConstant(2); // Gather scatter the frequencies to build the tensor such as [t,h,w,t,h,w,...,t,h,w,t,t,t,t] const h_input = freqs_h.gather(.{ .dh = 
h_indices }, .{ .indices_are_sorted = true }); const w_input = freqs_w.gather(.{ .dh = w_indices }, .{ .indices_are_sorted = true }); freqs_t = freqs_t.transpose(.{ .dh, .bs, .seq }); freqs_t = freqs_t.scatterSlices(.{ .dh = h_indices }, h_input, .{ .update_fn = zml.Tensor.ScatterOpts.override }); freqs = freqs_t.scatterSlices(.{ .dh = w_indices }, w_input, .{ .update_fn = zml.Tensor.ScatterOpts.override }); freqs = freqs.transpose(.{ .bs, .seq, .dh }); const emb = zml.Tensor.concatenate(&.{ freqs, freqs }, -1); const cos = emb.cos(); const sin = emb.sin(); return .{ cos, sin }; } }; pub const TransformerLayer = struct { input_layernorm: RmsNorm, self_attn: SelfAttn, mlp: Mlp, post_attention_layernorm: RmsNorm, num_heads: u32, pub fn forward( self: TransformerLayer, x0: Tensor, token_index: Tensor, cos: Tensor, sin: Tensor, kv_cache: KvCache, ) struct { Tensor, KvCache } { const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0}); const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, cos, sin, kv_cache, }); const x1 = x0.add(delta0); const x1_normalized = zml.call(self.post_attention_layernorm, .forward, .{x1}); const x2 = zml.call(self.mlp, .forward, .{x1_normalized}).add(x1); const result = .{ x2.reuseBuffer(x0), updated_kv_cache }; return result; } }; pub const SelfAttn = struct { q_proj: zml.nn.Linear, k_proj: zml.nn.Linear, v_proj: zml.nn.Linear, q_norm: RmsNorm, k_norm: RmsNorm, o_proj: zml.nn.Linear, num_heads: i64 = undefined, num_kv_heads: i64 = 0, pub fn forward( self: SelfAttn, x: Tensor, token_position: Tensor, cos: Tensor, sin: Tensor, kv_cache: KvCache, ) struct { Tensor, KvCache } { // Compute key query and value projections (split the dimension into head and dimension) var q = zml.call(self.q_proj, .forward, .{x}).splitAxis(-1, .{ .h = 32, .hd = .auto }); var k = zml.call(self.k_proj, .forward, .{x}).splitAxis(-1, .{ .h = 8, .hd = .auto }); var v = zml.call(self.v_proj, .forward, 
.{x}).splitAxis(-1, .{ .h = 8, .hd = .auto }); const token_index = token_position.convert(kv_cache.layer_index.dtype()); const seq_len = kv_cache.k.dim(.k); // Generate the attention mask var attn_mask = zml.nn.causalAttnMask(.{ .q = seq_len, .k = seq_len }, x.dtype(), null); attn_mask = attn_mask.gatherSlices(zml.Shape.init(.{ .q = x.dim(.seq) }, attn_mask.dtype()), token_index.reshape(.{ .coord = 1 }), .{}); q = zml.call(self.q_norm, .forward, .{q.rename(.{ .hd = .d })}).rename(.{ .d = .hd }).withTags(.{ .bs, .q, .h, .hd }); k = zml.call(self.k_norm, .forward, .{k.rename(.{ .hd = .d })}).rename(.{ .d = .hd }).withTags(.{ .bs, .k, .h, .hd }); v = v.withTags(.{ .bs, .k, .h, .hd }); q, k = applyRotaryPositionalEmbedding(q, k, cos, sin); // Update the key-value cache const kv_cache_updated = kv_cache.update(k, v, token_index); // Retrieve the cached key and value const cached_k = kv_cache_updated.keys().convert(q.dtype()); const cached_v = kv_cache_updated.values().convert(q.dtype()); const orig_dtype = q.dtype(); // Attention const attn_output = zml.nn.sdpa(q, cached_k, cached_v, .{ .attn_mask = attn_mask, .allow_cudnn = true }).convert(orig_dtype); // Merge head and dimension back together const attn = attn_output.merge(.{ .d = .{ .h, .hd } }).rename(.{ .q = .s }); const result = .{ zml.call(self.o_proj, .forward, .{attn}), kv_cache_updated }; return result; } }; pub const Mlp = struct { up_proj: zml.nn.Linear, gate_proj: zml.nn.Linear, down_proj: zml.nn.Linear, pub fn forward(self: Mlp, x: Tensor) Tensor { const proj = zml.call(self.up_proj, .forward, .{x}); var output = zml.call(self.gate_proj, .forward, .{x}); output = output.silu().mul(proj); const result = zml.call(self.down_proj, .forward, .{output}); return result; } }; const RmsNorm = struct { weight: Tensor, eps: f32 = 1e-6, pub fn forward(self: RmsNorm, input: Tensor) Tensor { const x = if (input.shape().isFullyTagged()) input else input.withPartialTags(.{.d}); const normalized = zml.nn.rmsNorm(x, .d, 
self.eps); return normalized.mul(self.weight.withTags(.{.d}).broad(x.shape()).convert(x.dtype())); } }; pub const KvCache = struct { k: Tensor, v: Tensor, layer_index: Tensor, pub fn init(kv_shape: zml.Shape) KvCache { // The KV-cache is initialized with ones to detect reads of uninitialized memory. return .{ .k = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), .v = Tensor.constant(kv_shape, kv_shape.dtype().one()).withSharding(.{.h}), .layer_index = Tensor.scalar(-1, .i64), }; } pub fn initShape(kv_shape: zml.Shape) ShapeOf(KvCache) { return .{ .k = kv_shape, .v = kv_shape, .layer_index = zml.Shape.init(.{}, .i64), }; } pub fn initBuffer(kv_shape: zml.Shape, platform: zml.Platform) !zml.Bufferized(KvCache) { return .{ .k = try zml.Buffer.uninitialized(platform, kv_shape, .{}), .v = try zml.Buffer.uninitialized(platform, kv_shape, .{}), .layer_index = try zml.Buffer.scalar(platform, 0, .i64), }; } pub fn keys(self: KvCache) Tensor { return self.k.dynamicSlice(.{ .layer = Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer); } pub fn values(self: KvCache) Tensor { return self.v.dynamicSlice(.{ .layer = Tensor.DynSlice{ .start = self.layer_index, .len = 1 } }).squeeze(.layer); } pub fn update(self: KvCache, new_k: Tensor, new_v: Tensor, token_index: ?Tensor) KvCache { const k_shape = self.k.shape().drop(.layer); var layer = self.layer_index; layer = if (token_index) |idx| layer.broad(idx.shape()) else layer; return if (token_index) |idx| .{ .k = self.k.scatterSlices( .{ .layer = layer, .k = idx }, new_k.convert(self.k.dtype()).transpose(k_shape), .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override }, ).reuseBuffer(self.k), .v = self.v.scatterSlices( .{ .layer = layer, .k = idx }, new_v.convert(self.v.dtype()).transpose(k_shape), .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override }, ).reuseBuffer(self.v), .layer_index = self.layer_index, } else .{ .k = self.k.scatterSlices( .{ 
.layer = layer }, new_k.convert(self.k.dtype()).transpose(k_shape), .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override }, ).reuseBuffer(self.k), .v = self.v.scatterSlices( .{ .layer = layer }, new_v.convert(self.v.dtype()).transpose(k_shape), .{ .indices_are_sorted = true, .update_fn = zml.Tensor.ScatterOpts.override }, ).reuseBuffer(self.v), .layer_index = self.layer_index, }; } pub fn atLayer(self: KvCache, layer_index: usize) KvCache { return .{ .k = self.k, .v = self.v, .layer_index = Tensor.scalar(layer_index, .i64), }; } pub fn reuseBuffer(self: KvCache, other: KvCache) KvCache { return .{ .k = self.k.reuseBuffer(other.k), .v = self.v.reuseBuffer(other.v), .layer_index = self.layer_index.reuseBuffer(other.layer_index), }; } }; // Resize bicubic function // Resize the image using bicubic interpolation // image: Tensor, the image to be resized // resized_axes: anytype, the axes to be resized // opt: zml.nn.ResizeOpts, the options for the resize -> contain original length // returns: Tensor, the resized image pub fn ResizeBicubic(image: Tensor, resized_axes: anytype, opt: zml.nn.ResizeOpts) Tensor { const new_size, const tags_ = zml.Shape.parseStruct(u63, resized_axes); var out = image; for (new_size.constSlice(), tags_.constSlice()) |d, t| { const ax = image.shape().axis(t); const child_opt: zml.nn.ResizeOpts = .{ .original_len = opt.original_len, .precision = opt.precision, }; out = ResizeCubic1d(out, ax, d, child_opt); } return out; } /// Bicubic interpolation along a single axis fn ResizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: zml.nn.ResizeOpts) Tensor { const ax = image.axis(axis); const res_shape = image.shape().set(ax, new_len); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); // Extract the correct dimension from original_len if it's a vector const og_len = if (opt.original_len) |o| blk: { // If original_len is a vector ( here chw=3), extract the dimension for this 
axis if (o.rank() == 1) { // Get the index of the axis in the original length const idx_in_original = @as(i64, @intCast(ax)); break :blk o.choose1d(0, idx_in_original).convert(dtype); } else { // It's already a scalar break :blk o.convert(dtype); } } else Tensor.scalar(image.dim(ax), dtype); // Calculate scale const align_corners = false; // Compute the scale between the original length (not the padded one) and the new length const scale = if (align_corners and new_len > 1) og_len.addConstant(-1).scale(stdx.math.divFloat(f32, 1, new_len - 1)) else og_len.scale(stdx.math.divFloat(f32, 1, new_len)); // Generate output positions const dst_indices = Tensor.arange(.{ .end = new_len }, dtype); const src_f = if (align_corners) dst_indices.mul(scale) else dst_indices.addConstant(0.5).mul(scale).addConstant(-0.5); // Calculate floor and fractional part const input_index_floor = src_f.floor(); const t = src_f.sub(input_index_floor); // Start index for 4-pixel window (leftmost pixel is at floor - 1) const start_idx = input_index_floor.convert(.i32).addConstant(-1); // Calculate bicubic weights for all positions const A: f32 = -0.75; // Catmull-Rom coefficient const weights = computeBicubicWeights(t, A); // For each of the 4 neighbors, compute indices and gather values var accumulated = Tensor.constant(res_shape, dtype.zero()); inline for (0..4) |i| { // Compute neighbor indices const neighbor_idx_raw = start_idx.addConstant(@as(i32, @intCast(i))); const neighbor_idx_clamped = neighbor_idx_raw .maximum(Tensor.scalar(0, .i32)) .minimum(og_len.convert(.i32).addConstant(-1)); // Gather neighbor values using gather_ (like resizeLinear1d) const neighbor_values = image .gather_(&.{ax}, &.{neighbor_idx_clamped}, .{ .indices_are_sorted = true }) .convert(dtype); // Get weight for this neighbor const weight = weights[i]; // Broadcast weight to res_shape (matching the output shape) along axis ax const weight_broadcasted = weight.broadcast(res_shape, &.{ax}); // Multiply and accumulate 
const weighted = neighbor_values.mul(weight_broadcasted); accumulated = accumulated.add(weighted); } return accumulated.convert(image.dtype()).withTags(image.shape().tags()); } /// Compute bicubic weights for fractional distance t /// Returns array of 4 weight tensors (one for each neighbor at offsets -1, 0, 1, 2) fn computeBicubicWeights(t: Tensor, A: f32) [4]Tensor { const x = t; const x2 = x.mul(x); const one_minus_x = Tensor.scalar(1.0, x.dtype()).sub(x); const one_minus_x2 = one_minus_x.mul(one_minus_x); const one_minus_x_plus_1 = one_minus_x.addConstant(1.0); // Weight for neighbor -1 (distance 1+x) const w0 = ((x.addConstant(1.0).scale(A).addConstant(-5.0 * A)).mul(x.addConstant(1.0)).addConstant(8.0 * A)).mul(x.addConstant(1.0)).addConstant(-4.0 * A); // Weight for neighbor 0 (distance x) const w1 = ((x.scale(A + 2.0).addConstant(-(A + 3.0))).mul(x2)).addConstant(1.0); // Weight for neighbor 1 (distance 1-x) const w2 = ((one_minus_x.scale(A + 2.0).addConstant(-(A + 3.0))).mul(one_minus_x2)).addConstant(1.0); // Weight for neighbor 2 (distance 2-x) const w3 = ((one_minus_x_plus_1.scale(A).addConstant(-5.0 * A)).mul(one_minus_x_plus_1).addConstant(8.0 * A)).mul(one_minus_x_plus_1).addConstant(-4.0 * A); return .{ w0, w1, w2, w3 }; }