diff --git a/examples/modernbert/main.zig b/examples/modernbert/main.zig
index 1206eb9..0df16b3 100644
--- a/examples/modernbert/main.zig
+++ b/examples/modernbert/main.zig
@@ -216,27 +216,34 @@ pub fn unmask(
 }
 
+/// Encodes `prompt` and wraps the result in the "[CLS]" and "[SEP]" token ids.
+/// Returns `error.NoSuchToken` if either marker is unknown; caller owns the returned slice.
 pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
-    var tokens = std.array_list.Managed(u32).init(allocator);
     var encoder = try tokenizer.encoder();
     defer encoder.deinit();
 
     const bos = tokenizer.tokenToId("[CLS]") orelse return error.NoSuchToken;
     const eos = tokenizer.tokenToId("[SEP]") orelse return error.NoSuchToken;
 
-    try tokens.append(bos);
-    try tokens.appendSlice(try encoder.encode(prompt));
-    try tokens.append(eos);
+    var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len);
+    // Release the buffer if a later step fails; on success the caller takes ownership.
+    errdefer tokens.deinit(allocator);
+    try tokens.append(allocator, bos);
+    try tokens.appendSlice(allocator, try encoder.encode(prompt));
+    try tokens.append(allocator, eos);
 
-    return tokens.toOwnedSlice();
+    return tokens.toOwnedSlice(allocator);
 }
 
+/// Collects the index of every element of `tokens` equal to `mask_token`.
+/// Caller owns the returned slice, allocated with `allocator`.
 fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
-    var mask_positions = std.array_list.Managed(usize).init(allocator);
-    defer mask_positions.deinit();
+    var mask_positions: std.ArrayList(usize) = .empty;
+    // Free collected positions on any error return; on success the caller takes ownership.
+    errdefer mask_positions.deinit(allocator);
 
     for (tokens, 0..) |token, i| {
         if (token == mask_token) {
-            try mask_positions.append(i);
+            try mask_positions.append(allocator, i);
         }
     }
 
@@ -247,7 +254,7 @@ fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_tok
 
     if (mask_positions.items.len > 1) log.warn("Currently only supporting one [MASK] per input", .{});
 
-    return mask_positions.toOwnedSlice();
+    return mask_positions.toOwnedSlice(allocator);
 }
 
 fn prepareTensorInputs(