746 lines
29 KiB
Zig
746 lines
29 KiB
Zig
|
|
const std = @import("std");
|
||
|
|
const async = @import("async");
|
||
|
|
const zml = @import("zml");
|
||
|
|
const qwen = @import("qwen3_vl.zig");
|
||
|
|
const clap = @import("clap");
|
||
|
|
const stdx = @import("stdx");
|
||
|
|
const zignal = @import("zignal");
|
||
|
|
|
||
|
|
const floats = zml.floats;
|
||
|
|
|
||
|
|
const log = std.log.scoped(.qwen);
|
||
|
|
|
||
|
|
// Reference all nested declarations so `zig test` compiles and runs
// every `test` block reachable from this module.
test {
    std.testing.refAllDecls(@This());
}
|
||
|
|
|
||
|
|
/// Standard-library configuration for this executable: info-level logging,
/// with log output routed through the async runtime's wrapper around the
/// default log function.
pub const std_options: std.Options = .{
    .log_level = .info,
    .logFn = async.logFn(std.log.defaultLog),
};
|
||
|
|
|
||
|
|
// Command-line flag definitions, parsed at comptime by zig-clap.
// Placeholder names (<STRING>, <UINT>) select the parser from the `parsers`
// struct in `asyncMain`.
const params = clap.parseParamsComptime(
    \\--help print this help
    \\--prompt <STRING> the prompt
    \\--image <STRING> path to the image file (BMP format)
    \\--hf-model-path <STRING> path to the directory containing model weights, config and tokenizer
    \\--seed <UINT> random seed (optional)
    \\--seq-len <UINT> sequence length (default: 512)
    \\--create-options <STRING> platform creation options in ZON format, defaults to {}
);
|
||
|
|
|
||
|
|
/// Runs one full generation pass: preprocesses the image + prompt, executes
/// the prefill module once over the whole prompt, then decodes token by token
/// until EOS or the sequence budget is exhausted, streaming decoded text to
/// `writer`.
///
/// `kv_cache_` is the freshly initialized device KV cache consumed by the
/// prefill call; the updated cache returned by each module call is threaded
/// through the loop. `seed` seeds the on-device sampling RNG.
pub fn generateText(
    config: qwen.Qwen.Config,
    _: qwen.Qwen3VL,
    mod_prefill: zml.ModuleExe(qwen.Qwen3VL.forward),
    mod_decode: zml.ModuleExe(qwen.Qwen3VL.forward_decode),
    kv_cache_: zml.Bufferized(qwen.KvCache),
    tokenizer: zml.tokenizer.Tokenizer,
    allocator: std.mem.Allocator,
    seed: u128,
    prompt: []const u8,
    image_path: []const u8,
    preprocessor_config: PreprocessorConfig,
    max_seq_len: u32,
    writer: *std.Io.Writer,
) !void {
    // Preprocess image and prompt into host buffers sized for the compiled model.
    const preprocessor_input = try preprocessor(
        allocator,
        tokenizer,
        prompt,
        config,
        preprocessor_config,
        image_path,
        max_seq_len,
        4096, // max_side of the image
    );
    defer {
        preprocessor_input.image_buffer_chw.deinit(allocator);
        preprocessor_input.prompt_tokens.deinit(allocator);
        preprocessor_input.prompt_shape.deinit(allocator);
        preprocessor_input.image_dim.deinit(allocator);
        preprocessor_input.token_index.deinit(allocator);
    }

    const platform = mod_decode.platform();
    var tokenizer_decoder = try tokenizer.decoder();
    defer tokenizer_decoder.deinit();

    // Extract prompt_shape values before converting to device buffer.
    // total_seq_len = leading template tokens + image pad tokens + trailing prompt/template tokens.
    const prompt_shape_values = preprocessor_input.prompt_shape.items(u32);
    const total_seq_len = prompt_shape_values[0] + prompt_shape_values[1] + prompt_shape_values[2];

    // Prepare device buffers for prefill
    const image_buffer_chw = try preprocessor_input.image_buffer_chw.toDevice(platform);
    defer image_buffer_chw.deinit();

    const prompt_tokens = try preprocessor_input.prompt_tokens.toDevice(platform);
    defer prompt_tokens.deinit();

    const prompt_shape = try preprocessor_input.prompt_shape.toDevice(platform);
    defer prompt_shape.deinit();

    const image_dim = try preprocessor_input.image_dim.toDevice(platform);
    defer image_dim.deinit();

    const token_index = try preprocessor_input.token_index.toDevice(platform);
    defer token_index.deinit();

    // init RNG and the single-token host staging buffer
    var rng = try zml.Tensor.Rng.init(platform, seed);
    var generated_token_buffer = [_]u32{undefined};

    // Prefill: process the full prompt with image in one module call.
    // The block returns the KV cache, mrope deltas and updated RNG; the first
    // sampled token lands in generated_token_buffer.
    var kv_cache, var mrope_position_deltas, rng = prefill: {
        const next_token, const kv_cache, const mrope_deltas, const new_rng = mod_prefill.call(.{
            image_buffer_chw,
            prompt_tokens,
            image_dim,
            token_index,
            prompt_shape,
            kv_cache_,
            rng,
        });

        // Extract the generated token to the host staging buffer.
        _ = try next_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));

        break :prefill .{ kv_cache, mrope_deltas, new_rng };
    };
    defer zml.aio.unloadBuffers(&kv_cache);
    defer mrope_position_deltas.deinit();

    // Prepare for token-by-token generation,
    // start with the token generated based on the full prompt.
    var current_token = try zml.Buffer.fromSlice(platform, .{ .bs = 1, .seq = 1 }, &generated_token_buffer);
    defer current_token.deinit();

    const output_tokens_len = max_seq_len - total_seq_len - 1;
    const start = std.time.microTimestamp();

    // One token has already been generated by the prefill.
    // NOTE(review): this counter is also incremented on the first loop
    // iteration below for that same prefill token, so the reported count
    // looks off by one — confirm intent.
    var num_tokens_generated: usize = 1;

    // Store all generated tokens
    var generated_tokens = try std.ArrayList(u32).initCapacity(allocator, output_tokens_len);
    defer generated_tokens.deinit(allocator);

    const token_gen = max_seq_len - total_seq_len;
    generation: for (0..token_gen) |i| {
        // Collect and stream the token produced by the previous step.
        num_tokens_generated += 1;
        const generated_token = generated_token_buffer[0];
        try generated_tokens.append(allocator, generated_token);
        if (try tokenizer_decoder.next(generated_token)) |chunk| {
            try writer.writeAll(chunk);
        }

        // Stop on sequence budget or EOS.
        // 151643 = <|endoftext|>, 151645 = <|im_end|> — presumably the Qwen EOS
        // ids; TODO confirm against the tokenizer config.
        if (i == output_tokens_len) break :generation;
        if (generated_token == 151643 or generated_token == 151645) break :generation;

        // Current token position needs to go into a zml.Buffer (scalar i64).
        const cache_position_buffer = &[_]i64{@intCast(total_seq_len - 1 + i)};
        const cache_position = try zml.Buffer.fromSlice(platform, .{}, cache_position_buffer);
        defer cache_position.deinit();

        // Call to generate the next token
        const next_token, const updated_kv_cache, const new_rng = mod_decode.call(.{ current_token, cache_position, kv_cache, mrope_position_deltas, rng });

        current_token = next_token;
        kv_cache = updated_kv_cache;
        rng = new_rng;

        // Extract the generated token from the device buffer for the next iteration.
        _ = try current_token.toHost(std.mem.sliceAsBytes(&generated_token_buffer));
    }

    const end = std.time.microTimestamp();
    const duration = stdx.math.divFloat(f64, end - start, std.time.us_per_s);
    const speed = @as(f64, @floatFromInt(num_tokens_generated)) / duration;

    // Decode and print all generated tokens at the end.
    // NOTE(review): the tokens were already streamed to `writer` inside the
    // loop above, so this prints the whole sequence a second time through the
    // same (stateful) decoder — confirm this duplication is intended.
    std.debug.print("\n", .{});
    for (generated_tokens.items) |token| {
        if (try tokenizer_decoder.next(token)) |chunk| {
            try writer.writeAll(chunk);
        }
    }

    std.debug.print("\n", .{});
    log.info("Generated {d} tokens in {:.3}s: {d:.3}tok/s", .{ num_tokens_generated, duration, speed });
}
|
||
|
|
|
||
|
|
/// Entry point: hands control to `asyncMain` on the async runtime's main
/// thread, using the C allocator for the runtime itself.
pub fn main() !void {
    try async.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
|
||
|
|
|
||
|
|
/// Program body run on the async runtime: parses CLI flags, loads the model
/// config / preprocessor config / tokenizer / weights, compiles the prefill
/// and decode executables concurrently, then runs `generateText` on the
/// given prompt and image.
pub fn asyncMain() !void {
    log.info(" Qwen3-VL was compiled with {}", .{@import("builtin").mode});

    const allocator = std.heap.c_allocator;

    // Map clap's value-type placeholders to parser functions.
    const parsers = comptime .{
        .BOOL = bool_parser,
        .UINT = clap.parsers.int(u32, 0),
        .STRING = clap.parsers.string,
        .PATH = clap.parsers.string,
    };
    var diag: clap.Diagnostic = .{};
    var stderr_buffer: [1024]u8 = undefined;
    var stderr = std.fs.File.stderr().writer(&stderr_buffer);
    defer stderr.interface.flush() catch {};

    var cli = clap.parse(clap.Help, &params, parsers, .{
        .diagnostic = &diag,
        .allocator = allocator,
    }) catch |err| {
        // Bad invocation: report the error, print usage, exit cleanly.
        diag.report(&stderr.interface, err) catch {};
        stderr.interface.writeAll("usage: ") catch {};
        clap.usage(&stderr.interface, clap.Help, &params) catch {};
        stderr.interface.writeAll("\n") catch {};
        return;
    };
    defer cli.deinit();

    if (cli.args.help != 0) {
        clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
        return;
    }

    // Required flags.
    const hf_model_path = cli.args.@"hf-model-path" orelse {
        log.err("Missing --hf-model-path", .{});
        return;
    };

    const image_path = cli.args.image orelse {
        log.err("Missing --image", .{});
        return;
    };

    const model_config_path = try std.fs.path.join(allocator, &.{ hf_model_path, "config.json" });
    defer allocator.free(model_config_path);

    // Prefer a single-file checkpoint; fall back to the sharded index file.
    const model_weights_path = b: {
        const simple_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors" });
        if (async.File.access(simple_path, .{})) {
            break :b simple_path;
        } else |_| {
            allocator.free(simple_path);
        }

        const sharded_path = try std.fs.path.join(allocator, &.{ hf_model_path, "model.safetensors.index.json" });
        break :b sharded_path;
    };
    defer allocator.free(model_weights_path);

    const model_tokenizer_path = try std.fs.path.join(allocator, &.{ hf_model_path, "tokenizer.json" });
    defer allocator.free(model_tokenizer_path);

    const preprocessor_config_path = try std.fs.path.join(allocator, &.{ hf_model_path, "preprocessor_config.json" });
    defer allocator.free(preprocessor_config_path);

    // Load model config (JSON, unknown fields ignored; leaky parse — the
    // result borrows from `allocator` and lives for the whole run).
    const config = blk: {
        var config_json_file = try async.File.open(model_config_path, .{ .mode = .read_only });
        defer config_json_file.close() catch unreachable;
        var config_json_buffer: [256]u8 = undefined;
        var config_reader = config_json_file.reader(&config_json_buffer);
        var reader = std.json.Reader.init(allocator, &config_reader.interface);
        defer reader.deinit();
        const config_obj = try std.json.parseFromTokenSourceLeaky(qwen.Qwen.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
        break :blk config_obj;
    };

    // Load preprocessor config (same scheme as above).
    const preprocessor_config = blk: {
        var preprocessor_config_json_file = try async.File.open(preprocessor_config_path, .{ .mode = .read_only });
        defer preprocessor_config_json_file.close() catch unreachable;
        var preprocessor_config_json_buffer: [256]u8 = undefined;
        var preprocessor_config_reader = preprocessor_config_json_file.reader(&preprocessor_config_json_buffer);
        var reader = std.json.Reader.init(allocator, &preprocessor_config_reader.interface);
        defer reader.deinit();
        const preprocessor_config_obj = try std.json.parseFromTokenSourceLeaky(PreprocessorConfig, allocator, &reader, .{ .ignore_unknown_fields = true });
        break :blk preprocessor_config_obj;
    };

    var context = try zml.Context.init();
    defer context.deinit();

    const compilation_options = zml.CompilationOptions{
        .xla_dump_to = "/tmp/zml/qwen3vl",
        .sharding_enabled = true,
    };

    // Initialize ZML platform from the (optional) ZON creation options.
    const create_opts_zon = cli.args.@"create-options" orelse ".{}";
    const create_opts = std.zon.parse.fromSlice(zml.Platform.CreateOptions, allocator, @ptrCast(create_opts_zon), null, .{ .free_on_error = false }) catch |err| {
        log.err("Failed to parse --create-options as ZON ({}): {s}", .{ err, create_opts_zon });
        return err;
    };

    const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
    context.printAvailablePlatforms(platform);

    var store = try zml.aio.detectFormatAndOpen(allocator, model_weights_path);
    defer store.deinit();

    // Sequence length drives both the compiled shapes and the generation loop.
    const seq_len: u32 = cli.args.@"seq-len" orelse 512;

    // Options for the generation
    const qwen_options: qwen.Qwen.Options = .{
        .max_seq_len = seq_len,
        .sampling_strategy = .{
            .topk = 3,
            .temperature = 1.2,
        },
    };

    var compiler_arena = std.heap.ArenaAllocator.init(allocator);
    defer compiler_arena.deinit();

    const qwen_tensors: qwen.Qwen3VL = try qwen.Qwen3VL.init(compiler_arena.allocator(), config, qwen_options, store);

    // Load tokenizer early (needed for preprocessor)
    var tokenizer = blk: {
        log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
        var timer = try stdx.time.Timer.start();
        defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });

        break :blk try zml.tokenizer.Tokenizer.fromFile(allocator, model_tokenizer_path);
    };
    // NOTE(review): freed only on the error path; on success it lives until
    // process exit (harmless for a CLI, but intent unconfirmed).
    errdefer tokenizer.deinit();

    const prompt = cli.args.prompt orelse "Describe this image.";

    // Run the preprocessor once up front to learn the input shapes to compile for.
    const preprocessor_input = try preprocessor(
        allocator,
        tokenizer,
        prompt,
        config,
        preprocessor_config,
        image_path,
        seq_len,
        4096, // max_side
    );
    defer {
        preprocessor_input.image_buffer_chw.deinit(allocator);
        preprocessor_input.prompt_tokens.deinit(allocator);
        preprocessor_input.prompt_shape.deinit(allocator);
        preprocessor_input.image_dim.deinit(allocator);
        preprocessor_input.token_index.deinit(allocator);
    }

    // Use shapes from preprocessor for compilation
    const image_buffer_shape = preprocessor_input.image_buffer_chw.shape();
    const prompt_tokens_shape = preprocessor_input.prompt_tokens.shape();
    const prompt_shape_shape = preprocessor_input.prompt_shape.shape();
    const image_dim_shape = preprocessor_input.image_dim.shape();
    const token_index_shape = preprocessor_input.token_index.shape();

    // Decode step consumes one token at a time.
    const decode_input_ids_shape = zml.Shape.init(.{ .bs = 1, .seq = 1 }, .u32);
    const decode_cache_position_shape = zml.Shape.init(.{}, .i64);
    const decode_mrope_shape = zml.Shape.init(.{ .seq = 1 }, .i32);

    const dtype = qwen_tensors.qwen.text_model.embed_tokens.weight.dtype();
    const kv_shape = zml.Shape.init(.{
        .bs = 1,
        .layer = config.text_config.num_hidden_layers,
        .k = seq_len,
        .h = config.text_config.num_key_value_heads,
        .hd = config.text_config.head_dim,
    }, dtype).withSharding(.{.h});

    const kv_cache_shape: zml.ShapeOf(qwen.KvCache) = qwen.KvCache.initShape(kv_shape);
    const rng_shape = zml.Tensor.Rng.shape();

    // Compile prefill and decode executables concurrently.
    var start = try std.time.Timer.start();
    var fut_mod_prefill = try async.async(zml.compileModel, .{
        allocator,
        qwen.Qwen3VL.forward,
        qwen_tensors,
        .{
            image_buffer_shape,
            prompt_tokens_shape,
            image_dim_shape,
            token_index_shape,
            prompt_shape_shape,
            kv_cache_shape,
            preprocessor_input.h_resized,
            preprocessor_input.w_resized,
            rng_shape,
        },
        platform,
    });

    var fut_mod_decode = try async.async(zml.compileModel, .{
        allocator,
        qwen.Qwen3VL.forward_decode,
        qwen_tensors,
        .{
            decode_input_ids_shape,
            decode_cache_position_shape,
            kv_cache_shape,
            decode_mrope_shape,
            rng_shape,
        },
        platform,
    });

    // Load weights while compiling
    log.info("\tLoading Qwen3-VL weights from {s}...", .{model_weights_path});
    var qwen_buffers = try store.loadModelById(qwen.Qwen3VL, compiler_arena.allocator(), qwen_tensors, platform);
    defer zml.aio.unloadBuffers(&qwen_buffers);
    log.info("✅\tLoaded weights in {D}", .{start.read()});

    var qwen_module_prefill = (try fut_mod_prefill.await()).prepare(qwen_buffers);
    defer qwen_module_prefill.deinit();
    var qwen_module_decode = (try fut_mod_decode.await()).prepare(qwen_buffers);
    defer qwen_module_decode.deinit();
    log.info("✅\tCompiled model in {D}", .{start.read()});

    log.info("Creating KvCache", .{});
    const kv_cache = try qwen.KvCache.initBuffer(kv_shape, platform);

    log.info("✅\tPrompt: {s}", .{prompt});
    log.info("✅\tImage: {s} \n", .{image_path});

    var stdout = std.fs.File.stdout().writer(&.{});

    const seed: u128 = cli.args.seed orelse @bitCast(std.time.nanoTimestamp());

    try generateText(
        config,
        qwen_tensors,
        qwen_module_prefill,
        qwen_module_decode,
        kv_cache,
        tokenizer,
        allocator,
        seed,
        prompt[0..],
        image_path[0..],
        preprocessor_config,
        // Fix: previously hard-coded 512. The value must match the seq_len the
        // modules were compiled with, otherwise any non-default --seq-len
        // produced mismatched buffer shapes at runtime.
        seq_len,
        &stdout.interface,
    );
}
|
||
|
|
|
||
|
|
/// Parses a boolean CLI flag value by its first character.
/// Any value starting with 't', 'T', 'y', 'Y' or '1' is true; everything
/// else — including the empty string — is false.
/// Fix: the original indexed `in[0]` unconditionally and panicked
/// (index out of bounds) when the flag was given an empty value.
fn bool_parser(in: []const u8) error{}!bool {
    if (in.len == 0) return false;
    return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
}
|
||
|
|
|
||
|
|
// Keep all existing helper functions unchanged
|
||
|
|
/// Bounds from the HuggingFace preprocessor config (`size` object).
/// NOTE(review): despite the field names, `preprocessor` compares these
/// against total pixel counts (h*w), not edge lengths — confirm against the
/// upstream preprocessor_config.json semantics.
pub const Size = struct {
    // Upper bound used as max_pixels in `preprocessor`.
    longest_edge: u64,
    // Lower bound used as min_pixels in `preprocessor`.
    shortest_edge: u64,
};
|
||
|
|
|
||
|
|
/// Subset of HuggingFace `preprocessor_config.json` needed by `preprocessor`.
/// Parsed with `.ignore_unknown_fields = true`, so extra JSON keys are fine.
pub const PreprocessorConfig = struct {
    // Pixel-count bounds (see `Size`).
    size: Size,
    // Spatial patch edge in pixels.
    patch_size: u32,
    // Temporal patching factor; combined with patch_size to form the resize factor.
    temporal_patch_size: u32,
    // Per-channel normalization constants.
    // NOTE(review): not referenced anywhere in this file — presumably consumed
    // by the model-side preprocessing; confirm.
    image_mean: []const f32,
    image_std: []const f32,
};
|
||
|
|
|
||
|
|
/// Loads an image via zignal, dispatching on the file extension
/// (case-insensitive). Supports PNG and JPEG; any other extension is
/// rejected with `error.UnsupportedImageFormat`.
/// Caller owns the returned image (`image.deinit(allocator)`).
/// Fix: the original only matched all-lowercase or all-uppercase extensions,
/// so mixed-case names like ".Png" or ".Jpg" were wrongly rejected.
fn loadImageWithZignal(comptime T: type, allocator: std.mem.Allocator, path: []const u8) !zignal.Image(T) {
    if (std.ascii.endsWithIgnoreCase(path, ".png")) {
        return zignal.png.load(T, allocator, path, .{});
    } else if (std.ascii.endsWithIgnoreCase(path, ".jpg") or std.ascii.endsWithIgnoreCase(path, ".jpeg")) {
        return zignal.jpeg.load(T, allocator, path, .{});
    } else {
        return error.UnsupportedImageFormat;
    }
}
|
||
|
|
|
||
|
|
/// Host-side model inputs produced by `preprocessor`.
/// All HostBuffers own allocator-backed memory; release with `deinit`
/// (or field-by-field, as the call sites in this file do).
const Input = struct {
    // Zero-padded max_side x max_side x 3 image, actual pixels in the top-left.
    // NOTE(review): field name says chw but the buffer is built with
    // .{ .h, .w, .c } (HWC) tags — confirm which layout the model expects.
    image_buffer_chw: zml.HostBuffer,
    // Templated prompt token ids, padded/truncated to max_seq_len.
    prompt_tokens: zml.HostBuffer,
    // [3]u32: leading template tokens, image pad tokens, trailing tokens.
    prompt_shape: zml.HostBuffer,
    // [3]u32: original image height, width, channels.
    image_dim: zml.HostBuffer,
    // Scalar i64 position index, initialized to 0.
    token_index: zml.HostBuffer,
    // Model-input image size after rounding to the patch factor.
    h_resized: u32,
    w_resized: u32,

    // Frees every buffer owned by this struct.
    pub fn deinit(self: *Input, allocator: std.mem.Allocator) void {
        self.image_buffer_chw.deinit(allocator);
        self.prompt_tokens.deinit(allocator);
        self.prompt_shape.deinit(allocator);
        self.image_dim.deinit(allocator);
        self.token_index.deinit(allocator);
    }
};
|
||
|
|
|
||
|
|
/// Builds all host-side model inputs from an image file and a text prompt:
/// decodes the image (BMP natively, PNG/JPEG via zignal), pads it into a
/// max_side x max_side canvas, computes the patch-aligned resize target,
/// applies the chat template, and wraps everything in HostBuffers.
/// Caller frees the returned `Input` (see `Input.deinit`).
pub fn preprocessor(
    allocator: std.mem.Allocator,
    tokenizer: zml.tokenizer.Tokenizer,
    prompt: []const u8,
    config: qwen.Qwen.Config,
    preprocessor_config: PreprocessorConfig,
    image_path: []const u8,
    max_seq_len: u32,
    max_side: u32,
) !Input {

    // Detect the extension of the file (bmp, png, jpeg)
    const ext = if (std.mem.lastIndexOf(u8, image_path, ".")) |idx|
        std.mem.trim(u8, image_path[idx + 1 ..], " \t\n\r")
    else
        "";

    // Case-insensitive "bmp" check; everything else is handed to zignal.
    const is_bmp = (ext.len == 3 and
        std.ascii.toLower(ext[0]) == 'b' and
        std.ascii.toLower(ext[1]) == 'm' and
        std.ascii.toLower(ext[2]) == 'p');

    var height: u32 = undefined;
    var width: u32 = undefined;
    var rgb_data: []u8 = undefined;

    // Decode the image into tightly-packed RGB bytes.
    const image: RgbImage = if (is_bmp) img: {
        const image_rgb = try loadBmpAsRgb(allocator, image_path);
        break :img image_rgb;
    } else img: {
        var image_from_zignal = loadImageWithZignal(zignal.Rgb, allocator, image_path) catch |err| {
            log.err("zignal failed to load {s}: {}. Please convert the image to BMP format (24-bit uncompressed) or use a supported format.", .{ image_path, err });
            return err;
        };
        defer image_from_zignal.deinit(allocator);
        height = @as(u32, @intCast(image_from_zignal.rows));
        width = @as(u32, @intCast(image_from_zignal.cols));
        const rgb_len = height * width * 3;
        rgb_data = try allocator.alloc(u8, rgb_len);
        errdefer allocator.free(rgb_data);

        // Iterate over all pixels and extract R, G, B
        var pix_dest: u32 = 0;
        var y: u32 = 0;
        while (y < height) : (y += 1) {
            var x: u32 = 0;
            while (x < width) : (x += 1) {
                // at() takes (row, col) which is (y, x)
                const pixel = image_from_zignal.at(y, x).*;

                rgb_data[pix_dest + 0] = pixel.r;
                rgb_data[pix_dest + 1] = pixel.g;
                rgb_data[pix_dest + 2] = pixel.b;
                pix_dest += 3;
            }
        }
        const image_rgb = RgbImage{
            .width = width,
            .height = height,
            .data = rgb_data,
        };
        break :img image_rgb;
    };

    height = image.height;
    width = image.width;
    rgb_data = image.data;

    // Create the HostBuffer for the actual image (small) and the padding image (large).
    // NOTE(review): `rgb_data` (the decoded source image) appears to be freed
    // by nobody — fromSlice does not take ownership and only the large padded
    // buffer is returned — looks like a leak; confirm against HostBuffer semantics.
    const image_hwc = rgb_data;
    const image_buffer = try allocator.alloc(u8, max_side * max_side * 3);
    @memset(image_buffer, 0);
    const image_small_hwc = zml.HostBuffer.fromSlice(zml.Shape.init(.{ .h = height, .w = width, .c = 3 }, .u8), image_hwc);
    const image_buffer_hwc = zml.HostBuffer.fromSlice(zml.Shape.init(.{ .h = max_side, .w = max_side, .c = 3 }, .u8), image_buffer);

    // Insert the actual image into the padding image (top left corner of the padding image)
    const small_height = @as(usize, @intCast(height));
    const small_width = @as(usize, @intCast(width));
    const channels = 3;
    const row_size_small = small_width * channels;
    const row_size_large = @as(usize, @intCast(max_side)) * channels;

    // Copy line per line
    const small_bytes = image_small_hwc.bytes();
    var large_bytes = image_buffer_hwc.mutBytes();
    for (0..small_height) |h| {
        const src_offset = h * row_size_small;
        const dst_offset = h * row_size_large;
        @memcpy(large_bytes[dst_offset .. dst_offset + row_size_small], small_bytes[src_offset .. src_offset + row_size_small]);
    }

    // Resize target must be a multiple of patch_size * temporal_patch_size.
    const factor = preprocessor_config.patch_size * preprocessor_config.temporal_patch_size;
    const min_pixels = preprocessor_config.size.shortest_edge;
    const max_pixels = preprocessor_config.size.longest_edge;
    stdx.debug.assert(@max(height, width) / @min(height, width) <= 200, "Invalid image ratio", .{});

    // Calculate the resized height and width (rounded to nearest multiple of factor)
    var h_resized: u32 = @as(u32, @intFromFloat(@round(stdx.math.divFloat(f64, height, factor)))) * factor;
    var w_resized: u32 = @as(u32, @intFromFloat(@round(stdx.math.divFloat(f64, width, factor)))) * factor;

    // Adjust if pixel count constraints are violated: scale both axes by the
    // same factor (preserving aspect ratio), flooring toward max_pixels or
    // ceiling toward min_pixels, still in multiples of `factor`.
    if (@as(u64, h_resized) * @as(u64, w_resized) > max_pixels) {
        const beta = std.math.sqrt(stdx.math.divFloat(f64, height * width, max_pixels));
        const h_scaled = stdx.math.divFloat(f64, height, beta);
        const w_scaled = stdx.math.divFloat(f64, width, beta);
        h_resized = @max(factor, @as(u32, @intFromFloat(std.math.floor(stdx.math.divFloat(f64, h_scaled, factor)))) * factor);
        w_resized = @max(factor, @as(u32, @intFromFloat(std.math.floor(stdx.math.divFloat(f64, w_scaled, factor)))) * factor);
    } else if (@as(u64, h_resized) * @as(u64, w_resized) < min_pixels) {
        const beta = std.math.sqrt(stdx.math.divFloat(f64, min_pixels, height * width));
        const h_scaled = stdx.math.divFloat(f64, height, 1) * beta;
        const w_scaled = stdx.math.divFloat(f64, width, 1) * beta;
        h_resized = @max(factor, @as(u32, @intFromFloat(std.math.ceil(stdx.math.divFloat(f64, h_scaled, factor)))) * factor);
        w_resized = @max(factor, @as(u32, @intFromFloat(std.math.ceil(stdx.math.divFloat(f64, w_scaled, factor)))) * factor);
    }
    const patch_size = config.vision_config.patch_size;

    // Calculate the number of image pad tokens: one per spatial-merged patch group.
    const number_image_pad_tokens = 1 * (h_resized / patch_size) * (w_resized / patch_size) / std.math.pow(u32, config.vision_config.spatial_merge_size, 2);

    // Apply the chat template to the prompt
    const prompt_processed = try applyChatTemplate(allocator, tokenizer, prompt, number_image_pad_tokens);
    const prompt_shape = prompt_processed.prompt_shape;

    // Create the HostBuffer by reallocating the prompt tokens to the max_seq_len.
    // NOTE(review): when growing, realloc leaves the tail elements undefined
    // (not zeroed); when the templated prompt exceeds max_seq_len it is
    // silently truncated — confirm the model masks positions past prompt_shape.
    const prompt_buffer = try allocator.realloc(prompt_processed.prompt_tokens, max_seq_len);
    const prompt_tokens = zml.HostBuffer.fromSlice(.{ .bs = 1, .seq = max_seq_len }, prompt_buffer);
    // Create the HostBuffer for the prompt shape
    const prompt_shape_buffer = (try zml.HostBuffer.fromArray(allocator, prompt_shape)).withTags(.{.prompt_shape});
    // Create the HostBuffer for the image size
    const image_size_buffer = (try zml.HostBuffer.fromArray(allocator, [_]u32{ height, width, 3 })).withTags(.{.chw});
    // Create the HostBuffer for the token index (scalar, starts at position 0)
    const token_index_buffer = try zml.HostBuffer.empty(allocator, zml.Shape.init(.{}, .i64));
    const index: i64 = 0;
    @memcpy(token_index_buffer.mutItems(i64), &[_]i64{index});

    return Input{
        .image_buffer_chw = image_buffer_hwc,
        .prompt_tokens = prompt_tokens,
        .prompt_shape = prompt_shape_buffer,
        .image_dim = image_size_buffer,
        .token_index = token_index_buffer,
        .h_resized = h_resized,
        .w_resized = w_resized,
    };
}
|
||
|
|
|
||
|
|
/// Applies the Qwen chat template to a user prompt:
/// <|im_start|>user\n <|vision_start|> [n image pads] <|vision_end|> prompt <|im_end|>\n <|im_start|>assistant\n
///
/// Returns the full token sequence (caller owns the slice) plus the prompt
/// shape: { 4 leading text tokens, n image-pad tokens, prompt tokens + 6
/// trailing template tokens } (4 and 6 are the tokens hardcoded by the template).
///
/// Fixes vs. the original: the prompt is encoded once instead of twice,
/// the pad tokens use `appendNTimes` instead of a one-element appendSlice
/// loop, and `tokens` is freed on the error path (`errdefer`).
pub fn applyChatTemplate(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8, number_image_pad_tokens: u32) !struct { prompt_tokens: []u32, prompt_shape: [3]u32 } {
    var encoder = try tokenizer.encoder();
    defer encoder.deinit();
    const im_start_id = tokenizer.tokenToId("<|im_start|>") orelse return error.NoSuchToken;
    const im_end_id = tokenizer.tokenToId("<|im_end|>") orelse return error.NoSuchToken;
    const user = tokenizer.tokenToId("user") orelse return error.NoSuchToken;
    const assistant = tokenizer.tokenToId("assistant") orelse return error.NoSuchToken;
    const vision_start_id = tokenizer.tokenToId("<|vision_start|>") orelse return error.NoSuchToken;
    const vision_end_id = tokenizer.tokenToId("<|vision_end|>") orelse return error.NoSuchToken;
    const image_pad_id = tokenizer.tokenToId("<|image_pad|>") orelse return error.NoSuchToken;
    const newline = (try encoder.encode("\n"))[0];

    // Encode the prompt once; the slice stays valid because no further
    // encode() call happens before it is copied into `tokens`.
    const encoded_prompt = try encoder.encode(prompt);

    var tokens: std.ArrayList(u32) = try .initCapacity(allocator, prompt.len);
    errdefer tokens.deinit(allocator);
    try tokens.appendSlice(allocator, &.{ im_start_id, user, newline });
    try tokens.append(allocator, vision_start_id);
    try tokens.appendNTimes(allocator, image_pad_id, number_image_pad_tokens);
    try tokens.append(allocator, vision_end_id);
    try tokens.appendSlice(allocator, encoded_prompt);
    try tokens.appendSlice(allocator, &.{ im_end_id, newline });
    try tokens.appendSlice(allocator, &.{ im_start_id, assistant, newline });

    const prompt_shape: [3]u32 = .{ 4, number_image_pad_tokens, @as(u32, @intCast(encoded_prompt.len)) + 6 };
    return .{ .prompt_tokens = try tokens.toOwnedSlice(allocator), .prompt_shape = prompt_shape };
}
|
||
|
|
|
||
|
|
/// A decoded image as tightly-packed RGB bytes (row-major, top-down,
/// 3 bytes per pixel). `data.len == width * height * 3`.
pub const RgbImage = struct {
    width: u32,
    height: u32,
    // Owned pixel bytes; free via `deinit` with the same allocator that
    // produced them.
    data: []u8,

    // Frees the pixel data and poisons the struct to catch use-after-free
    // in safe builds.
    pub fn deinit(self: *RgbImage, allocator: std.mem.Allocator) void {
        allocator.free(self.data);
        self.* = undefined;
    }
};
|
||
|
|
|
||
|
|
/// Loads a 24-bit uncompressed BMP from `path` and returns tightly-packed
/// RGB pixels (top-down, row-major). Caller owns `RgbImage.data`.
///
/// Supported: BITMAPINFOHEADER (>= 40 bytes), 1 plane, BI_RGB (no
/// compression), 24 bpp; both bottom-up (positive height) and top-down
/// (negative height) files.
///
/// Fixes vs. the original: all size arithmetic is done in u64 because
/// width/height/data_offset come from untrusted file bytes and the u32
/// products (row stride, pixel array size, rgb length) could overflow —
/// a trap in safe builds and UB in ReleaseFast; the height negation now
/// uses `@abs`, which the old `-height_i32` trapped on for minInt(i32).
pub fn loadBmpAsRgb(allocator: std.mem.Allocator, path: []const u8) !RgbImage {
    var file = try std.fs.cwd().openFile(path, .{ .mode = .read_only });
    defer file.close();

    const max_len = 64 * 1024 * 1024; // 64 MiB safety cap
    const file_bytes = try file.readToEndAlloc(allocator, max_len);
    defer allocator.free(file_bytes);

    // 14-byte file header + 40-byte BITMAPINFOHEADER minimum.
    if (file_bytes.len < 54) return error.InvalidBmpHeader;
    if (!std.mem.eql(u8, file_bytes[0..2], "BM")) return error.InvalidBmpSignature;

    // Little-endian field readers for the fixed-layout headers.
    const readU16 = struct {
        fn f(bytes: []const u8) u16 {
            return std.mem.readInt(u16, bytes[0..2], .little);
        }
    }.f;
    const readI32 = struct {
        fn f(bytes: []const u8) i32 {
            return std.mem.readInt(i32, bytes[0..4], .little);
        }
    }.f;
    const readU32 = struct {
        fn f(bytes: []const u8) u32 {
            return std.mem.readInt(u32, bytes[0..4], .little);
        }
    }.f;

    const data_offset = readU32(file_bytes[10..14]);
    const dib_header_size = readU32(file_bytes[14..18]);
    if (dib_header_size < 40) return error.UnsupportedBmpFormat;

    const width_i32 = readI32(file_bytes[18..22]);
    const height_i32 = readI32(file_bytes[22..26]);
    if (width_i32 <= 0 or height_i32 == 0) return error.InvalidBmpDimensions;

    const planes = readU16(file_bytes[26..28]);
    const bits_per_pixel = readU16(file_bytes[28..30]);
    const compression = readU32(file_bytes[30..34]);
    if (planes != 1 or compression != 0 or bits_per_pixel != 24) return error.UnsupportedBmpFormat;

    const width: u32 = @intCast(width_i32);
    // A negative height means the rows are stored top-down; @abs is safe
    // even for minInt(i32).
    const abs_height: u32 = @abs(height_i32);
    const is_top_down = height_i32 < 0;

    // Each row is padded to a multiple of 4 bytes.
    const row_stride: u64 = (@as(u64, width) * 3 + 3) / 4 * 4;
    const pixel_array_size: u64 = row_stride * @as(u64, abs_height);
    if (@as(u64, data_offset) + pixel_array_size > file_bytes.len) return error.TruncatedBmp;

    // Fits in usize: bounded above by pixel_array_size <= file_bytes.len.
    const rgb_len: usize = @intCast(@as(u64, width) * @as(u64, abs_height) * 3);
    const rgb_data = try allocator.alloc(u8, rgb_len);
    errdefer allocator.free(rgb_data);

    const stride: usize = @intCast(row_stride);
    var row: u32 = 0;
    while (row < abs_height) : (row += 1) {
        // Bottom-up files store the last image row first.
        const src_row_index: usize = if (is_top_down) row else abs_height - 1 - row;
        const src_start = @as(usize, data_offset) + src_row_index * stride;
        const src_slice = file_bytes[src_start .. src_start + stride];

        const dst_start = @as(usize, row) * width * 3;
        var col: u32 = 0;
        while (col < width) : (col += 1) {
            const src_pixel = @as(usize, col) * 3;
            const dst_pixel = dst_start + @as(usize, col) * 3;
            // BMP pixels are stored in BGR order; swap to RGB.
            rgb_data[dst_pixel + 0] = src_slice[src_pixel + 2];
            rgb_data[dst_pixel + 1] = src_slice[src_pixel + 1];
            rgb_data[dst_pixel + 2] = src_slice[src_pixel + 0];
        }
    }

    return RgbImage{
        .width = width,
        .height = abs_height,
        .data = rgb_data,
    };
}
|