Add model prefix support when loading a model from safetensors, enabling use of a specific model prefix (e.g., ModernBertModel) instead of the full model. Tested with the text embeddings server project.
This commit is contained in:
parent
1cafcc3c60
commit
af8844c1f1
25
zml/aio.zig
25
zml/aio.zig
@ -504,18 +504,39 @@ pub fn loadBuffers(
|
||||
buffer_store: BufferStore,
|
||||
allocator: std.mem.Allocator,
|
||||
platform: zml.Platform,
|
||||
) !zml.Bufferized(Model) {
|
||||
return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, "");
|
||||
}
|
||||
|
||||
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
|
||||
/// For details about bufferization, see the documentation of Bufferized(T).
|
||||
///
|
||||
/// This will represent the weights of the model, loaded on a specific platform.
|
||||
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
|
||||
/// `module.ExeWithWeights` ready to be called.
|
||||
///
|
||||
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
|
||||
pub fn loadBuffersWithPrefix(
|
||||
comptime Model: type,
|
||||
init_args: anytype,
|
||||
buffer_store: BufferStore,
|
||||
allocator: std.mem.Allocator,
|
||||
platform: zml.Platform,
|
||||
prefix: []const u8,
|
||||
) !zml.Bufferized(Model) {
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const arena = arena_state.allocator();
|
||||
var model: Model = try zml.aio.populateModel(Model, arena, buffer_store);
|
||||
|
||||
// Get model structure with tensor shapes from the buffer store with prefix
|
||||
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix);
|
||||
|
||||
// If the Model has a "init" function, call it with the given parameters.
|
||||
if (@hasDecl(Model, "init")) {
|
||||
@call(.auto, Model.init, .{&model} ++ init_args);
|
||||
}
|
||||
|
||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix);
|
||||
}
|
||||
|
||||
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
||||
|
||||
20
zml/exe.zig
20
zml/exe.zig
@ -30,13 +30,31 @@ pub fn compile(
|
||||
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||
buffer_store: aio.BufferStore,
|
||||
platform: Platform,
|
||||
) !FnExe(func) {
|
||||
return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
|
||||
}
|
||||
|
||||
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
||||
/// Uses a prefix for looking up model weights in the buffer store.
|
||||
/// The steps are:
|
||||
/// * lookup at tensors available in the store and create a `model: Model` struct with them
|
||||
/// * call `model.init(init_args)` to fields of the model that aren't Tensor, ie hyperparemeters/config
|
||||
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
||||
pub fn compileWithPrefix(
|
||||
allocator: std.mem.Allocator,
|
||||
comptime func: anytype,
|
||||
init_args: anytype,
|
||||
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||
buffer_store: aio.BufferStore,
|
||||
platform: Platform,
|
||||
prefix: []const u8,
|
||||
) !FnExe(func) {
|
||||
const ModelT = ModuleSignature(func).ModelT;
|
||||
|
||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||
defer arena_state.deinit();
|
||||
const arena = arena_state.allocator();
|
||||
var model = try aio.populateModel(ModelT, arena, buffer_store);
|
||||
var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix);
|
||||
|
||||
// If the Model has a "init" function, call it with the given parameters.
|
||||
if (@hasDecl(ModelT, "init")) {
|
||||
|
||||
@ -35,6 +35,7 @@ pub const tokenizer = @import("zml/tokenizer");
|
||||
|
||||
pub const call = ops.call;
|
||||
pub const compile = exe.compile;
|
||||
pub const compileWithPrefix = exe.compileWithPrefix;
|
||||
pub const compileFn = exe.compileFn;
|
||||
pub const compileModel = exe.compileModel;
|
||||
pub const FnExe = exe.FnExe;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user