Add model prefix support when loading a model from safetensors, enabling use of a specific model prefix (e.g., ModernBertModel) instead of the full model. Tested with the text embeddings server project.

This commit is contained in:
Tarry Singh 2025-02-12 13:18:27 +00:00
parent 1cafcc3c60
commit af8844c1f1
3 changed files with 43 additions and 3 deletions

View File

@ -504,18 +504,39 @@ pub fn loadBuffers(
buffer_store: BufferStore, buffer_store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
) !zml.Bufferized(Model) {
return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///
/// This will represent the weights of the model, loaded on a specific platform.
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
/// `module.ExeWithWeights` ready to be called.
///
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
pub fn loadBuffersWithPrefix(
comptime Model: type,
init_args: anytype,
buffer_store: BufferStore,
allocator: std.mem.Allocator,
platform: zml.Platform,
prefix: []const u8,
) !zml.Bufferized(Model) { ) !zml.Bufferized(Model) {
var arena_state = std.heap.ArenaAllocator.init(allocator); var arena_state = std.heap.ArenaAllocator.init(allocator);
defer arena_state.deinit(); defer arena_state.deinit();
const arena = arena_state.allocator(); const arena = arena_state.allocator();
var model: Model = try zml.aio.populateModel(Model, arena, buffer_store);
// Get model structure with tensor shapes from the buffer store with prefix
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix);
// If the Model has a "init" function, call it with the given parameters. // If the Model has a "init" function, call it with the given parameters.
if (@hasDecl(Model, "init")) { if (@hasDecl(Model, "init")) {
@call(.auto, Model.init, .{&model} ++ init_args); @call(.auto, Model.init, .{&model} ++ init_args);
} }
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix);
} }
/// Creates a bufferized version of a Model from the given BufferStore. For details about /// Creates a bufferized version of a Model from the given BufferStore. For details about

View File

@ -30,13 +30,31 @@ pub fn compile(
args_shapes: ShapeOf(ModuleSignature(func).ArgsT), args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
buffer_store: aio.BufferStore, buffer_store: aio.BufferStore,
platform: Platform, platform: Platform,
) !FnExe(func) {
return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
}
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
/// Uses a prefix for looking up model weights in the buffer store.
/// The steps are:
/// * lookup at tensors available in the store and create a `model: Model` struct with them
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
pub fn compileWithPrefix(
allocator: std.mem.Allocator,
comptime func: anytype,
init_args: anytype,
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
buffer_store: aio.BufferStore,
platform: Platform,
prefix: []const u8,
) !FnExe(func) { ) !FnExe(func) {
const ModelT = ModuleSignature(func).ModelT; const ModelT = ModuleSignature(func).ModelT;
var arena_state = std.heap.ArenaAllocator.init(allocator); var arena_state = std.heap.ArenaAllocator.init(allocator);
defer arena_state.deinit(); defer arena_state.deinit();
const arena = arena_state.allocator(); const arena = arena_state.allocator();
var model = try aio.populateModel(ModelT, arena, buffer_store); var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix);
// If the Model has a "init" function, call it with the given parameters. // If the Model has a "init" function, call it with the given parameters.
if (@hasDecl(ModelT, "init")) { if (@hasDecl(ModelT, "init")) {

View File

@ -35,6 +35,7 @@ pub const tokenizer = @import("zml/tokenizer");
pub const call = ops.call; pub const call = ops.call;
pub const compile = exe.compile; pub const compile = exe.compile;
pub const compileWithPrefix = exe.compileWithPrefix;
pub const compileFn = exe.compileFn; pub const compileFn = exe.compileFn;
pub const compileModel = exe.compileModel; pub const compileModel = exe.compileModel;
pub const FnExe = exe.FnExe; pub const FnExe = exe.FnExe;