From af8844c1f1adfc89818836238673e436eb8b17f3 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 12 Feb 2025 13:18:27 +0000 Subject: [PATCH] Add model prefix support when loading a model from safetensors, enabling use of a specific model prefix (e.g., ModernBertModel) instead of the full model. Tested with the text embeddings server project. --- zml/aio.zig | 25 +++++++++++++++++++++++-- zml/exe.zig | 20 +++++++++++++++++++- zml/zml.zig | 1 + 3 files changed, 43 insertions(+), 3 deletions(-) diff --git a/zml/aio.zig b/zml/aio.zig index 4dc205c..82d1858 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -504,18 +504,39 @@ pub fn loadBuffers( buffer_store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, +) !zml.Bufferized(Model) { + return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, ""); +} + +/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix. +/// For details about bufferization, see the documentation of Bufferized(T). +/// +/// This will represent the weights of the model, loaded on a specific platform. +/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a +/// `module.ExeWithWeights` ready to be called. +/// +/// The `init_args` are used to initialize the non-Buffer fields, using the `Model.init` function. 
+pub fn loadBuffersWithPrefix( + comptime Model: type, + init_args: anytype, + buffer_store: BufferStore, + allocator: std.mem.Allocator, + platform: zml.Platform, + prefix: []const u8, ) !zml.Bufferized(Model) { var arena_state = std.heap.ArenaAllocator.init(allocator); defer arena_state.deinit(); const arena = arena_state.allocator(); - var model: Model = try zml.aio.populateModel(Model, arena, buffer_store); + + // Get model structure with tensor shapes from the buffer store with prefix + var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix); // If the Model has a "init" function, call it with the given parameters. if (@hasDecl(Model, "init")) { @call(.auto, Model.init, .{&model} ++ init_args); } - return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); + return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix); } /// Creates a bufferized version of a Model from the given BufferStore. For details about diff --git a/zml/exe.zig b/zml/exe.zig index 2ab8bcb..0923f17 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -30,13 +30,31 @@ pub fn compile( args_shapes: ShapeOf(ModuleSignature(func).ArgsT), buffer_store: aio.BufferStore, platform: Platform, +) !FnExe(func) { + return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, ""); +} + +/// Compiles a Model struct with the given configuration and shapes, for the given platform. +/// Uses a prefix for looking up model weights in the buffer store. 
+/// The steps are: +/// * look up the tensors available in the store and create a `model: Model` struct with them +/// * call `model.init(init_args)` to initialize the fields of the model that aren't Tensor, i.e. hyperparameters/config +/// * generate MLIR by calling `model.forward` with tensors of the given shapes and other arguments +pub fn compileWithPrefix( + allocator: std.mem.Allocator, + comptime func: anytype, + init_args: anytype, + args_shapes: ShapeOf(ModuleSignature(func).ArgsT), + buffer_store: aio.BufferStore, + platform: Platform, + prefix: []const u8, ) !FnExe(func) { const ModelT = ModuleSignature(func).ModelT; var arena_state = std.heap.ArenaAllocator.init(allocator); defer arena_state.deinit(); const arena = arena_state.allocator(); - var model = try aio.populateModel(ModelT, arena, buffer_store); + var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix); // If the Model has a "init" function, call it with the given parameters. if (@hasDecl(ModelT, "init")) { diff --git a/zml/zml.zig b/zml/zml.zig index bcb86db..45abfcc 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -35,6 +35,7 @@ pub const tokenizer = @import("zml/tokenizer"); pub const call = ops.call; pub const compile = exe.compile; +pub const compileWithPrefix = exe.compileWithPrefix; pub const compileFn = exe.compileFn; pub const compileModel = exe.compileModel; pub const FnExe = exe.FnExe;