// Radix/examples/qwen3_vl/main.zig
// Qwen3-VL example: loads a HuggingFace checkpoint, preprocesses an image and
// prompt, compiles prefill/decode graphs with ZML, and streams generated text.

const std = @import("std");
const async = @import("async");
const zml = @import("zml");
const qwen = @import("qwen3_vl.zig");
const clap = @import("clap");
const stdx = @import("stdx");
const zignal = @import("zignal");
const floats = zml.floats;
const log = std.log.scoped(.qwen);
// Force semantic analysis of every declaration in this file (and run any
// nested tests) when built in test mode.
test {
    std.testing.refAllDecls(@This());
}
/// Global std configuration: info-level logging, with std.log routed through
/// the async runtime's log wrapper.
pub const std_options: std.Options = .{
    .log_level = .info,
    .logFn = async.logFn(std.log.defaultLog),
};
// Command-line flag specification, parsed at comptime by clap.
// (The string content doubles as the --help text; do not reformat it.)
const params = clap.parseParamsComptime(
    \\--help print this help
    \\--prompt <STRING> the prompt
    \\--image <STRING> path to the image file (BMP format)
    \\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
    \\--seed <UINT> random seed (optional)
    \\--seq-len <UINT> sequence length (default: 512)
    \\--create-options <STRING> platform creation options in ZON format, defaults to {}
);
/// Runs one complete generation pass: preprocesses the image + prompt, runs a
/// single prefill over the full sequence, then decodes token by token until a
/// stop token or the sequence-length budget is reached. Decoded text is
/// streamed to `writer` as it is produced.
///
/// NOTE(review): after the generation loop, the full token list is decoded and
/// written again through the same stateful decoder, so the text is emitted
/// twice — confirm whether the final replay loop is intentional.
pub fn generateText(
    config: qwen.Qwen.Config,
    _: qwen.Qwen3VL, // unused here; kept so the call site's argument list stays stable
    mod_prefill: zml.ModuleExe(qwen.Qwen3VL.forward),
    mod_decode: zml.ModuleExe(qwen.Qwen3VL.forward_decode),
    kv_cache_: zml.Bufferized(qwen.KvCache),
    tokenizer: zml.tokenizer.Tokenizer,
    allocator: std.mem.Allocator,
    seed: u128,
    prompt: []const u8,
    image_path: []const u8,
    preprocessor_config: PreprocessorConfig,
    max_seq_len: u32,
    writer: *std.Io.Writer,
) !void {
    // Preprocess image and prompt
    const preprocessor_input = try preprocessor(
        allocator,
        tokenizer,
        prompt,
        config,
        preprocessor_config,
        image_path,
        max_seq_len,
        4096, // max_side of the image
    );
    // Release every host-side buffer produced by the preprocessor.
    defer {
        preprocessor_input.image_buffer_chw.deinit(allocator);
        preprocessor_input.prompt_tokens.deinit(allocator);
        preprocessor_input.prompt_shape.deinit(allocator);
        preprocessor_input.image_dim.deinit(allocator);
        preprocessor_input.token_index.deinit(allocator);
    }
    const platform = mod_decode.platform();
    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();
    // Extract prompt_shape values before converting to device buffer
    const prompt_shape_values = preprocessor_input.prompt_shape.items(u32);
    // The three segments (text prefix, image pads, text suffix) sum to the prompt length.
    const total_seq_len = prompt_shape_values[0] + prompt_shape_values[1] + prompt_shape_values[2];
    // Prepare device buffers for prefill
    const image_buffer_chw = try preprocessor_input.image_buffer_chw.toDevice(platform);
    defer image_buffer_chw.deinit();
    const prompt_tokens = try preprocessor_input.prompt_tokens.toDevice(platform);
    defer prompt_tokens.deinit();
    const prompt_shape = try preprocessor_input.prompt_shape.toDevice(platform);
    defer prompt_shape.deinit();
    const image_dim = try preprocessor_input.image_dim.toDevice(platform);
    defer image_dim.deinit();
    const token_index = try preprocessor_input.token_index.toDevice(platform);
    defer token_index.deinit();
    // init RNG and buffers
    var rng = try zml.Tensor.Rng.init(platform, seed);
    // Single-element host staging buffer used to read back each sampled token.
    var generated_token_buffer = [_]u32{undefined};
    // Prefill: process the full prompt with image
    var kv_cache, var mrope_position_deltas, rng = prefill: {
        const next_token, const kv_cache, const mrope_deltas, const new_rng = mod_prefill.call(.{
            image_buffer_chw,
            prompt_tokens,
            image_dim,
            token_index,
            prompt_shape,
            kv_cache_,
            rng,
        });
        // Extract the generated token
        _ = try next_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
        break :prefill .{ kv_cache, mrope_deltas, new_rng };
    };
    defer zml.aio.unloadBuffers(&kv_cache);
    defer mrope_position_deltas.deinit();
    // Prepare for token-by-token generation,
    // start with the token generated based on the full prompt.
    var current_token = try zml.Buffer.fromSlice(platform, .{ .bs = 1, .seq = 1 }, &generated_token_buffer);
    // `current_token` is reassigned every step; the defer releases whichever
    // buffer it holds when the function exits.
    defer current_token.deinit();
    const output_tokens_len = max_seq_len - total_seq_len - 1;
    const start = std.time.microTimestamp();
    // One token has already been generated by the prefill.
    // NOTE(review): the loop below also increments for that prefill token on
    // its first iteration, so the reported count looks off by one — verify.
    var num_tokens_generated: usize = 1;
    // Store all generated tokens
    var generated_tokens = try std.ArrayList(u32).initCapacity(allocator, output_tokens_len);
    defer generated_tokens.deinit(allocator);
    const token_gen = max_seq_len - total_seq_len;
    generation: for (0..token_gen) |i| {
        // Collect and print generated sequence
        num_tokens_generated += 1;
        const generated_token = generated_token_buffer[0];
        try generated_tokens.append(allocator, generated_token);
        if (try tokenizer_decoder.next(generated_token)) |chunk| {
            try writer.writeAll(chunk);
        }
        // check for eos
        if (i == output_tokens_len) break :generation;
        // NOTE(review): hardcoded stop ids — presumably Qwen's end-of-text and
        // <|im_end|> tokens; consider resolving them via the tokenizer instead.
        if (generated_token == 151643 or generated_token == 151645) break :generation;
        // Current token pos needs to go into a zml.Buffer
        const cache_position_buffer = &[_]i64{@intCast(total_seq_len - 1 + i)};
        const cache_position = try zml.Buffer.fromSlice(platform, .{}, cache_position_buffer);
        defer cache_position.deinit();
        // Call to generate the next token
        const next_token, const updated_kv_cache, const new_rng = mod_decode.call(.{ current_token, cache_position, kv_cache, mrope_position_deltas, rng });
        current_token = next_token;
        kv_cache = updated_kv_cache;
        rng = new_rng;
        // Extract the generated token from the buffer
        _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
    }
    const end = std.time.microTimestamp();
    const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
    const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;
    // Decode and print all generated tokens at the end
    std.debug.print("\n", .{});
    for (generated_tokens.items) |token| {
        if (try tokenizer_decoder.next(token)) |chunk| {
            try writer.writeAll(chunk);
        }
    }
    std.debug.print("\n", .{});
    log.info("Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
}
/// Process entry point: hands control to `asyncMain` on the async runtime,
/// using the C allocator for the whole program.
pub fn main() !void {
    try async.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
/// Async entry point: parses CLI flags, loads config/tokenizer/weights,
/// compiles the prefill and decode graphs concurrently with the weight load,
/// then runs `generateText` for the given prompt and image.
pub fn asyncMain() !void {
    log.info(" Qwen3-VL was compiled with {}", .{@import("builtin").mode});
    const allocator = std.heap.c_allocator;
    // clap value parsers for each placeholder used in `params`.
    const parsers = comptime .{
        .BOOL = bool_parser,
        .UINT = clap.parsers.int(u32, 0),
        .STRING = clap.parsers.string,
        .PATH = clap.parsers.string,
    };
    var diag: clap.Diagnostic = .{};
    var stderr_buffer: [1024]u8 = undefined;
    var stderr = std.fs.File.stderr().writer(&stderr_buffer);
    defer stderr.interface.flush() catch {};
    var cli = clap.parse(clap.Help, &params, parsers, .{
        .diagnostic = &diag,
        .allocator = allocator,
    }) catch |err| {
        // Bad arguments: report the problem plus a usage line, then exit cleanly.
        diag.report(&stderr.interface, err) catch {};
        stderr.interface.writeAll("usage: ") catch {};
        clap.usage(&stderr.interface, clap.Help, &params) catch {};
        stderr.interface.writeAll("\n") catch {};
        return;
    };
    defer cli.deinit();
    if (cli.args.help != 0) {
        clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
        return;
    }
    const hf_model_path = cli.args.@"hf-model-path" orelse {
        log.err("Missing --hf-model-path", .{});
        return;
    };
    const image_path = cli.args.image orelse {
        log.err("Missing --image", .{});
        return;
    };
    const model_config_path = try std.fs.path.join(allocator, &.{ hf_model_path, "config.json" });
    defer allocator.free(model_config_path);
    // Prefer a single-file checkpoint; fall back to the sharded index file.
    const model_weights_path = b: {
        const simple_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors" });
        if (async.File.access(simple_path, .{})) {
            break :b simple_path;
        } else |_| {
            allocator.free(simple_path);
        }
        const sharded_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
        break :b sharded_path;
    };
    defer allocator.free(model_weights_path);
    const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
    defer allocator.free(model_tokenizer_path);
    const preprocessor_config_path = try std.fs.path.join(allocator, &.{ hf_model_path, "preprocessor_config.json" });
    defer allocator.free(preprocessor_config_path);
    // Load config. Parsed leakily: the config borrows from `allocator` for the
    // lifetime of the process.
    const config = blk: {
        var config_json_file = try async.File.open(model_config_path, .{ .mode = .read_only });
        defer config_json_file.close() catch unreachable;
        var config_json_buffer: [256]u8 = undefined;
        var config_reader = config_json_file.reader(&config_json_buffer);
        var reader = std.json.Reader.init(allocator, &config_reader.interface);
        defer reader.deinit();
        const config_obj = try std.json.parseFromTokenSourceLeaky(qwen.Qwen.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
        break :blk config_obj;
    };
    // Load preprocessor config (same leaky-parse scheme as above).
    const preprocessor_config = blk: {
        var preprocessor_config_json_file = try async.File.open(preprocessor_config_path, .{ .mode = .read_only });
        defer preprocessor_config_json_file.close() catch unreachable;
        var preprocessor_config_json_buffer: [256]u8 = undefined;
        var preprocessor_config_reader = preprocessor_config_json_file.reader(&preprocessor_config_json_buffer);
        var reader = std.json.Reader.init(allocator, &preprocessor_config_reader.interface);
        defer reader.deinit();
        const preprocessor_config_obj = try std.json.parseFromTokenSourceLeaky(PreprocessorConfig, allocator, &reader, .{ .ignore_unknown_fields = true });
        break :blk preprocessor_config_obj;
    };
    var context = try zml.Context.init();
    defer context.deinit();
    const compilation_options = zml.CompilationOptions{
        .xla_dump_to = "/tmp/zml/qwen3vl",
        .sharding_enabled = true,
    };
    // Initialize ZML platform
    const create_opts_zon = cli.args.@"create-options" orelse ".{}";
    const create_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(create_opts_zon), null, .{ .free_on_error = false }) catch |err| {
        log.err("Failed to parse --create-options as ZON ({}): {s}", .{ err, create_opts_zon });
        return err;
    };
    const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
    context.printAvailablePlatforms(platform);
    var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
    defer store.deinit();
    // Initialize model
    const seq_len: u32 = cli.args.@"seq-len" orelse 512;
    // Options for the generation
    const qwen_options: qwen.Qwen.Options = .{
        .max_seq_len = seq_len,
        .sampling_strategy = .{
            .topk = 3,
            .temperature = 1.2,
        },
    };
    var compiler_arena = std.heap.ArenaAllocator.init(allocator);
    defer compiler_arena.deinit();
    const qwen_tensors: qwen.Qwen3VL = try qwen.Qwen3VL.init(compiler_arena.allocator(), config, qwen_options, store);
    // Load tokenizer early (needed for preprocessor)
    var tokenizer = blk: {
        log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
        var timer = try stdx.time.Timer.start();
        defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });
        break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
    };
    // Fix: `defer` (was `errdefer`) so the tokenizer is also released on the
    // success path; generateText receives it by value and does not deinit it.
    defer tokenizer.deinit();
    const prompt = cli.args.prompt orelse "Describe this image.";
    // Use preprocessor to calculate all needed values for compilation
    const preprocessor_input = try preprocessor(
        allocator,
        tokenizer,
        prompt,
        config,
        preprocessor_config,
        image_path,
        seq_len,
        4096, // max_side
    );
    defer {
        preprocessor_input.image_buffer_chw.deinit(allocator);
        preprocessor_input.prompt_tokens.deinit(allocator);
        preprocessor_input.prompt_shape.deinit(allocator);
        preprocessor_input.image_dim.deinit(allocator);
        preprocessor_input.token_index.deinit(allocator);
    }
    // Use shapes from preprocessor for compilation
    const image_buffer_shape = preprocessor_input.image_buffer_chw.shape();
    const prompt_tokens_shape = preprocessor_input.prompt_tokens.shape();
    const prompt_shape_shape = preprocessor_input.prompt_shape.shape();
    const image_dim_shape = preprocessor_input.image_dim.shape();
    const token_index_shape = preprocessor_input.token_index.shape();
    // Specify shapes for decode
    const decode_input_ids_shape = zml.Shape.init(.{ .bs = 1, .seq = 1 }, .u32);
    const decode_cache_position_shape = zml.Shape.init(.{}, .i64);
    const decode_mrope_shape = zml.Shape.init(.{ .seq = 1 }, .i32);
    const dtype = qwen_tensors.qwen.text_model.embed_tokens.weight.dtype();
    const kv_shape = zml.Shape.init(.{
        .bs = 1,
        .layer = config.text_config.num_hidden_layers,
        .k = seq_len,
        .h = config.text_config.num_key_value_heads,
        .hd = config.text_config.head_dim,
    }, dtype).withSharding(.{.h});
    const kv_cache_shape: zml.ShapeOf(qwen.KvCache) = qwen.KvCache.initShape(kv_shape);
    const rng_shape = zml.Tensor.Rng.shape();
    // Compile both graphs asynchronously; the weight load below overlaps them.
    var start = try std.time.Timer.start();
    var fut_mod_prefill = try async.async(zml.compileModel, .{
        allocator,
        qwen.Qwen3VL.forward,
        qwen_tensors,
        .{
            image_buffer_shape,
            prompt_tokens_shape,
            image_dim_shape,
            token_index_shape,
            prompt_shape_shape,
            kv_cache_shape,
            preprocessor_input.h_resized,
            preprocessor_input.w_resized,
            rng_shape,
        },
        platform,
    });
    var fut_mod_decode = try async.async(zml.compileModel, .{
        allocator,
        qwen.Qwen3VL.forward_decode,
        qwen_tensors,
        .{
            decode_input_ids_shape,
            decode_cache_position_shape,
            kv_cache_shape,
            decode_mrope_shape,
            rng_shape,
        },
        platform,
    });
    // Load weights while compiling
    log.info("\tLoading Qwen3-VL weights from {s}...", .{model_weights_path});
    var qwen_buffers = try store.loadModelById(qwen.Qwen3VL, compiler_arena.allocator(), qwen_tensors, platform);
    defer zml.aio.unloadBuffers(&qwen_buffers);
    log.info("\tLoaded weights in {D}", .{start.read()});
    var qwen_module_prefill = (try fut_mod_prefill.await()).prepare(qwen_buffers);
    defer qwen_module_prefill.deinit();
    var qwen_module_decode = (try fut_mod_decode.await()).prepare(qwen_buffers);
    defer qwen_module_decode.deinit();
    log.info("\tCompiled model in {D}", .{start.read()});
    log.info("Creating KvCache", .{});
    // NOTE(review): `kv_cache` is handed to generateText (which unloads its own
    // copy after prefill); no deinit here mirrors the original behavior.
    const kv_cache = try qwen.KvCache.initBuffer(kv_shape, platform);
    log.info("\tPrompt: {s}", .{prompt});
    log.info("\tImage: {s} \n", .{image_path});
    var stdout = std.fs.File.stdout().writer(&.{});
    const seed: u128 = cli.args.seed orelse @bitCast(std.time.nanoTimestamp());
    try generateText(
        config,
        qwen_tensors,
        qwen_module_prefill,
        qwen_module_decode,
        kv_cache,
        tokenizer,
        allocator,
        seed,
        prompt[0..],
        image_path[0..],
        preprocessor_config,
        // Fix: was a hardcoded 512, which diverged from the shapes compiled
        // above whenever --seq-len != 512. Pass the configured length instead.
        seq_len,
        &stdout.interface,
    );
}
/// clap parser for BOOL flag values: true iff the first character is one of
/// "tTyY1" ("true", "Yes", "1", ...). Fix: an empty value no longer indexes
/// out of bounds; it parses as false.
fn bool_parser(in: []const u8) error{}!bool {
    if (in.len == 0) return false;
    return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
// ---- Image / prompt preprocessing helpers ----
/// Mirrors the "size" entry of HuggingFace `preprocessor_config.json`.
pub const Size = struct {
    // NOTE(review): despite the "edge" naming, `preprocessor` uses these as
    // max/min total pixel counts (max_pixels / min_pixels) — confirm against
    // the model's preprocessor_config.json.
    longest_edge: u64,
    shortest_edge: u64,
};
/// Subset of HuggingFace `preprocessor_config.json` used for Qwen3-VL image
/// preprocessing (unknown fields are ignored when parsing).
pub const PreprocessorConfig = struct {
    /// Pixel-count bounds used when computing the resize target.
    size: Size,
    patch_size: u32,
    temporal_patch_size: u32,
    /// Per-channel normalization constants (parsed but not applied in this file).
    image_mean: []const f32,
    image_std: []const f32,
};
/// Loads a PNG or JPEG image via zignal, decoded as pixel type `T`.
/// The format is chosen from the file extension, matched case-insensitively
/// (generalizes the original upper/lower-only check to mixed case like ".Png").
/// Returns `error.UnsupportedImageFormat` for any other extension; BMP files
/// are handled separately by `loadBmpAsRgb`.
fn loadImageWithZignal(comptime T: type, allocator: std.mem.Allocator, path: []const u8) !zignal.Image(T) {
    if (std.ascii.endsWithIgnoreCase(path, ".png")) {
        return zignal.png.load(T, allocator, path, .{});
    }
    if (std.ascii.endsWithIgnoreCase(path, ".jpg") or std.ascii.endsWithIgnoreCase(path, ".jpeg")) {
        return zignal.jpeg.load(T, allocator, path, .{});
    }
    return error.UnsupportedImageFormat;
}
/// Host-side tensors produced by `preprocessor`, plus the resolved resize
/// target. All HostBuffers own allocator-backed memory; release with `deinit`
/// (or free each field individually, as the callers in this file do).
const Input = struct {
    /// Zero-padded max_side x max_side x 3 u8 image.
    /// NOTE(review): the field name says "chw" but `preprocessor` builds this
    /// buffer with an HWC-tagged shape — confirm which layout the model expects.
    image_buffer_chw: zml.HostBuffer,
    /// Chat-templated token ids, padded out to max_seq_len ({bs=1, seq}).
    prompt_tokens: zml.HostBuffer,
    /// Three u32 segment lengths: text prefix, image pads, text suffix.
    prompt_shape: zml.HostBuffer,
    /// Original image dimensions as {height, width, 3} (u32).
    image_dim: zml.HostBuffer,
    /// Scalar i64 starting decode position (0).
    token_index: zml.HostBuffer,
    /// Resize target, rounded to the patch grid.
    h_resized: u32,
    w_resized: u32,
    /// Frees every HostBuffer owned by this Input.
    pub fn deinit(self: *Input, allocator: std.mem.Allocator) void {
        self.image_buffer_chw.deinit(allocator);
        self.prompt_tokens.deinit(allocator);
        self.prompt_shape.deinit(allocator);
        self.image_dim.deinit(allocator);
        self.token_index.deinit(allocator);
    }
};
/// Decodes the image, pads it into a fixed max_side x max_side buffer,
/// computes the patch-aligned resize target, and builds all host tensors the
/// model needs (templated prompt tokens, segment shape, image dims, start
/// index). Caller owns every buffer in the returned `Input`.
///
/// Fixes vs. the original: the decoded RGB pixels are freed (they were
/// leaked), oversized images are rejected instead of writing out of bounds,
/// the realloc'd tail of the token buffer is zeroed (realloc leaves grown
/// memory undefined), prompts longer than max_seq_len error out instead of
/// being silently truncated, and intermediate allocations get errdefers.
pub fn preprocessor(
    allocator: std.mem.Allocator,
    tokenizer: zml.tokenizer.Tokenizer,
    prompt: []const u8,
    config: qwen.Qwen.Config,
    preprocessor_config: PreprocessorConfig,
    image_path: []const u8,
    max_seq_len: u32,
    max_side: u32,
) !Input {
    // Detect the extension of the file (bmp, png, jpeg).
    const ext = if (std.mem.lastIndexOf(u8, image_path, ".")) |idx|
        std.mem.trim(u8, image_path[idx + 1 ..], " \t\n\r")
    else
        "";
    const is_bmp = std.ascii.eqlIgnoreCase(ext, "bmp");
    // Decode to tightly-packed interleaved RGB (HWC, u8).
    var image: RgbImage = if (is_bmp)
        try loadBmpAsRgb(allocator, image_path)
    else img: {
        var zignal_image = loadImageWithZignal(zignal.Rgb, allocator, image_path) catch |err| {
            log.err("zignal failed to load {s}: {}. Please convert the image to BMP format (24-bit uncompressed) or use a supported format.", .{ image_path, err });
            return err;
        };
        defer zignal_image.deinit(allocator);
        const rows: u32 = @intCast(zignal_image.rows);
        const cols: u32 = @intCast(zignal_image.cols);
        const rgb = try allocator.alloc(u8, @as(usize, rows) * cols * 3);
        errdefer allocator.free(rgb);
        // Iterate over all pixels and extract R, G, B.
        var pix_dest: usize = 0;
        var y: u32 = 0;
        while (y < rows) : (y += 1) {
            var x: u32 = 0;
            while (x < cols) : (x += 1) {
                // at() takes (row, col) which is (y, x)
                const pixel = zignal_image.at(y, x).*;
                rgb[pix_dest + 0] = pixel.r;
                rgb[pix_dest + 1] = pixel.g;
                rgb[pix_dest + 2] = pixel.b;
                pix_dest += 3;
            }
        }
        break :img RgbImage{ .width = cols, .height = rows, .data = rgb };
    };
    // The raw pixels are only needed until they are copied into the padded
    // buffer below; free them on exit (previously leaked).
    defer image.deinit(allocator);
    const height = image.height;
    const width = image.width;
    // Reject images that don't fit the fixed padded buffer: the row-by-row
    // copy below would otherwise write out of bounds.
    if (height > max_side or width > max_side) return error.ImageTooLarge;
    // Zero-filled padding buffer; the actual image goes in the top-left corner.
    const image_buffer = try allocator.alloc(u8, @as(usize, max_side) * max_side * 3);
    errdefer allocator.free(image_buffer);
    @memset(image_buffer, 0);
    const image_small_hwc = zml.HostBuffer.fromSlice(zml.Shape.init(.{ .h = height, .w = width, .c = 3 }, .u8), image.data);
    const image_buffer_hwc = zml.HostBuffer.fromSlice(zml.Shape.init(.{ .h = max_side, .w = max_side, .c = 3 }, .u8), image_buffer);
    // Copy line per line into the padded buffer.
    const channels = 3;
    const row_size_small = @as(usize, width) * channels;
    const row_size_large = @as(usize, max_side) * channels;
    const small_bytes = image_small_hwc.bytes();
    const large_bytes = image_buffer_hwc.mutBytes();
    for (0..height) |h| {
        const src_offset = h * row_size_small;
        const dst_offset = h * row_size_large;
        @memcpy(large_bytes[dst_offset .. dst_offset + row_size_small], small_bytes[src_offset .. src_offset + row_size_small]);
    }
    // The resize target must be a multiple of patch_size * temporal_patch_size.
    const factor = preprocessor_config.patch_size * preprocessor_config.temporal_patch_size;
    const min_pixels = preprocessor_config.size.shortest_edge;
    const max_pixels = preprocessor_config.size.longest_edge;
    stdx.debug.assert(@max(height, width) / @min(height, width) <= 200, "Invalid image ratio", .{});
    // Calculate the resized height and width (rounded to nearest multiple of factor)
    var h_resized: u32 = @as(u32, @intFromFloat(@round(stdx.math.divFloat(f64, height, factor)))) * factor;
    var w_resized: u32 = @as(u32, @intFromFloat(@round(stdx.math.divFloat(f64, width, factor)))) * factor;
    // Adjust if pixel count constraints are violated
    if (@as(u64, h_resized) * @as(u64, w_resized) > max_pixels) {
        // Too many pixels: shrink both sides by sqrt(area / max_pixels), floored to factor.
        const beta = std.math.sqrt(stdx.math.divFloat(f64, height * width, max_pixels));
        const h_scaled = stdx.math.divFloat(f64, height, beta);
        const w_scaled = stdx.math.divFloat(f64, width, beta);
        h_resized = @max(factor, @as(u32, @intFromFloat(std.math.floor(stdx.math.divFloat(f64, h_scaled, factor)))) * factor);
        w_resized = @max(factor, @as(u32, @intFromFloat(std.math.floor(stdx.math.divFloat(f64, w_scaled, factor)))) * factor);
    } else if (@as(u64, h_resized) * @as(u64, w_resized) < min_pixels) {
        // Too few pixels: grow both sides by sqrt(min_pixels / area), ceiled to factor.
        const beta = std.math.sqrt(stdx.math.divFloat(f64, min_pixels, height * width));
        const h_scaled = stdx.math.divFloat(f64, height, 1) * beta;
        const w_scaled = stdx.math.divFloat(f64, width, 1) * beta;
        h_resized = @max(factor, @as(u32, @intFromFloat(std.math.ceil(stdx.math.divFloat(f64, h_scaled, factor)))) * factor);
        w_resized = @max(factor, @as(u32, @intFromFloat(std.math.ceil(stdx.math.divFloat(f64, w_scaled, factor)))) * factor);
    }
    const patch_size = config.vision_config.patch_size;
    // Number of <|image_pad|> placeholder tokens for the resized patch grid.
    const number_image_pad_tokens = 1 * (h_resized / patch_size) * (w_resized / patch_size) / std.math.pow(u32, config.vision_config.spatial_merge_size, 2);
    // Apply the chat template to the prompt
    const prompt_processed = try applyChatTemplate(allocator, tokenizer, prompt, number_image_pad_tokens);
    const prompt_shape = prompt_processed.prompt_shape;
    const token_count = prompt_processed.prompt_tokens.len;
    if (token_count > max_seq_len) {
        allocator.free(prompt_processed.prompt_tokens);
        return error.PromptTooLong;
    }
    // Grow the token buffer to max_seq_len and zero the tail: realloc leaves
    // the newly added region undefined.
    const prompt_buffer = try allocator.realloc(prompt_processed.prompt_tokens, max_seq_len);
    errdefer allocator.free(prompt_buffer);
    @memset(prompt_buffer[token_count..], 0);
    const prompt_tokens = zml.HostBuffer.fromSlice(.{ .bs = 1, .seq = max_seq_len }, prompt_buffer);
    // Create the HostBuffer for the prompt shape
    const prompt_shape_buffer = (try zml.HostBuffer.fromArray(allocator, prompt_shape)).withTags(.{.prompt_shape});
    errdefer prompt_shape_buffer.deinit(allocator);
    // Create the HostBuffer for the image size
    const image_size_buffer = (try zml.HostBuffer.fromArray(allocator, [_]u32{ height, width, 3 })).withTags(.{.chw});
    errdefer image_size_buffer.deinit(allocator);
    // Create the HostBuffer for the token index (scalar i64, starts at 0)
    const token_index_buffer = try zml.HostBuffer.empty(allocator, zml.Shape.init(.{}, .i64));
    const index: i64 = 0;
    @memcpy(token_index_buffer.mutItems(i64), &[_]i64{index});
    return Input{
        .image_buffer_chw = image_buffer_hwc,
        .prompt_tokens = prompt_tokens,
        .prompt_shape = prompt_shape_buffer,
        .image_dim = image_size_buffer,
        .token_index = token_index_buffer,
        .h_resized = h_resized,
        .w_resized = w_resized,
    };
}
// Apply the chat template to the prompt.
// Returns the prompt tokens (caller owns) and the prompt shape.
// The shape is 4 txt tokens, n vision pad tokens, m + 6 text tokens
// (4 and 6 are the tokens hardcoded by the chat template).
// Fixes vs. the original: the prompt is tokenized once instead of twice, the
// pad-token loop uses appendNTimes, and the token list is freed on error.
pub fn applyChatTemplate(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, number_image_pad_tokens: u32) !struct { prompt_tokens: []u32, prompt_shape: [3]u32 } {
    var encoder = try tokenizer.encoder();
    defer encoder.deinit();
    const im_start_id = tokenizer.tokenToId("<|im_start|>") orelse return error.NoSuchToken;
    const im_end_id = tokenizer.tokenToId("<|im_end|>") orelse return error.NoSuchToken;
    const user = tokenizer.tokenToId("user") orelse return error.NoSuchToken;
    const assistant = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken;
    const vision_start_id = tokenizer.tokenToId("<|vision_start|>") orelse return error.NoSuchToken;
    const vision_end_id = tokenizer.tokenToId("<|vision_end|>") orelse return error.NoSuchToken;
    const image_pad_id = tokenizer.tokenToId("<|image_pad|>") orelse return error.NoSuchToken;
    const newline = (try encoder.encode("\n"))[0];
    var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len);
    errdefer tokens.deinit(allocator);
    // Segment 1 (4 tokens): <|im_start|> user \n <|vision_start|>
    try tokens.appendSlice(allocator, &.{ im_start_id, user, newline, vision_start_id });
    // Segment 2: one placeholder per merged vision patch.
    try tokens.appendNTimes(allocator, image_pad_id, number_image_pad_tokens);
    // Segment 3: <|vision_end|> + prompt + <|im_end|> \n <|im_start|> assistant \n
    try tokens.append(allocator, vision_end_id);
    const prompt_token_ids = try encoder.encode(prompt);
    const prompt_token_count: u32 = @intCast(prompt_token_ids.len);
    try tokens.appendSlice(allocator, prompt_token_ids);
    try tokens.appendSlice(allocator, &.{ im_end_id, newline, im_start_id, assistant, newline });
    const prompt_shape: [3]u32 = .{ 4, number_image_pad_tokens, prompt_token_count + 6 };
    return .{ .prompt_tokens = try tokens.toOwnedSlice(allocator), .prompt_shape = prompt_shape };
}
/// A tightly-packed interleaved RGB image (3 bytes per pixel, row-major).
pub const RgbImage = struct {
    width: u32,
    height: u32,
    /// Owned pixel bytes, `width * height * 3` long.
    data: []u8,

    /// Frees the pixel data and poisons the struct.
    pub fn deinit(self: *RgbImage, allocator: std.mem.Allocator) void {
        allocator.free(self.data);
        self.* = undefined;
    }
};

/// Loads an uncompressed 24-bit BMP (BITMAPINFOHEADER or newer) and converts
/// it to tightly-packed RGB, handling both bottom-up (positive height) and
/// top-down (negative height) row order. Caller owns the returned pixels.
/// Fixes vs. the original: `const` for never-reassigned locals (a compile
/// error under current Zig), and u64 arithmetic for the stride/size bounds
/// check so crafted headers cannot overflow u32 and bypass validation.
pub fn loadBmpAsRgb(allocator: std.mem.Allocator, path: []const u8) !RgbImage {
    const file = try std.fs.cwd().openFile(path, .{ .mode = .read_only });
    defer file.close();
    const max_len = 64 * 1024 * 1024; // 64 MiB safety cap
    const file_bytes = try file.readToEndAlloc(allocator, max_len);
    defer allocator.free(file_bytes);
    // 54 = 14-byte file header + 40-byte BITMAPINFOHEADER.
    if (file_bytes.len < 54) return error.InvalidBmpHeader;
    if (!std.mem.eql(u8, file_bytes[0..2], "BM")) return error.InvalidBmpSignature;
    const data_offset = std.mem.readInt(u32, file_bytes[10..14], .little);
    const dib_header_size = std.mem.readInt(u32, file_bytes[14..18], .little);
    if (dib_header_size < 40) return error.UnsupportedBmpFormat;
    const width_i32 = std.mem.readInt(i32, file_bytes[18..22], .little);
    const height_i32 = std.mem.readInt(i32, file_bytes[22..26], .little);
    // Negative height means top-down row order; non-positive width is invalid.
    if (width_i32 <= 0 or height_i32 == 0) return error.InvalidBmpDimensions;
    const planes = std.mem.readInt(u16, file_bytes[26..28], .little);
    const bits_per_pixel = std.mem.readInt(u16, file_bytes[28..30], .little);
    const compression = std.mem.readInt(u32, file_bytes[30..34], .little);
    if (planes != 1 or compression != 0 or bits_per_pixel != 24) return error.UnsupportedBmpFormat;
    const width: u32 = @intCast(width_i32);
    const abs_height: u32 = @abs(height_i32);
    const is_top_down = height_i32 < 0;
    // Each BMP row is padded to a 4-byte boundary.
    const row_stride: u64 = ((@as(u64, width) * 3 + 3) / 4) * 4;
    const pixel_array_size: u64 = row_stride * abs_height;
    if (@as(u64, data_offset) + pixel_array_size > file_bytes.len) return error.TruncatedBmp;
    const rgb_len: usize = @intCast(@as(u64, width) * abs_height * 3);
    const rgb_data = try allocator.alloc(u8, rgb_len);
    errdefer allocator.free(rgb_data);
    var row: u32 = 0;
    while (row < abs_height) : (row += 1) {
        // Bottom-up BMPs store the last image row first.
        const src_row_index: u64 = if (is_top_down) row else abs_height - 1 - row;
        const src_start: usize = @intCast(data_offset + src_row_index * row_stride);
        const src_slice = file_bytes[src_start .. src_start + @as(usize, @intCast(row_stride))];
        const dst_start = @as(usize, row) * width * 3;
        var col: u32 = 0;
        while (col < width) : (col += 1) {
            const src_pixel = @as(usize, col) * 3;
            const dst_pixel = dst_start + @as(usize, col) * 3;
            // BMP pixels are stored in BGR order; swap to RGB.
            rgb_data[dst_pixel + 0] = src_slice[src_pixel + 2];
            rgb_data[dst_pixel + 1] = src_slice[src_pixel + 1];
            rgb_data[dst_pixel + 2] = src_slice[src_pixel + 0];
        }
    }
    return RgbImage{
        .width = width,
        .height = abs_height,
        .data = rgb_data,
    };
}