Add model prefix support when loading a model from safetensors, so that the weights stored under a specific prefix (e.g., ModernBertModel) can be loaded instead of the full model. Tested with the text embeddings server project.
This commit is contained in:
parent
1cafcc3c60
commit
af8844c1f1
25
zml/aio.zig
25
zml/aio.zig
@ -504,18 +504,39 @@ pub fn loadBuffers(
|
|||||||
buffer_store: BufferStore,
|
buffer_store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
|
) !zml.Bufferized(Model) {
|
||||||
|
return loadBuffersWithPrefix(Model, init_args, buffer_store, allocator, platform, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a bufferized version of a Model from the given BufferStore with a specified prefix.
|
||||||
|
/// For details about bufferization, see the documentation of Bufferized(T).
|
||||||
|
///
|
||||||
|
/// This will represent the weights of the model, loaded on a specific platform.
|
||||||
|
/// It can be used with a `module.Exe` (a compiled version of the same Model), to make a
|
||||||
|
/// `module.ExeWithWeights` ready to be called.
|
||||||
|
///
|
||||||
|
/// The `init_args` are used to initialize the non-Buffer fields, using the `Model.init` function.
|
||||||
|
pub fn loadBuffersWithPrefix(
|
||||||
|
comptime Model: type,
|
||||||
|
init_args: anytype,
|
||||||
|
buffer_store: BufferStore,
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
platform: zml.Platform,
|
||||||
|
prefix: []const u8,
|
||||||
) !zml.Bufferized(Model) {
|
) !zml.Bufferized(Model) {
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
const arena = arena_state.allocator();
|
const arena = arena_state.allocator();
|
||||||
var model: Model = try zml.aio.populateModel(Model, arena, buffer_store);
|
|
||||||
|
// Get model structure with tensor shapes from the buffer store with prefix
|
||||||
|
var model: Model = try zml.aio.populateModelWithPrefix(Model, arena, buffer_store, prefix);
|
||||||
|
|
||||||
// If the Model has a "init" function, call it with the given parameters.
|
// If the Model has a "init" function, call it with the given parameters.
|
||||||
if (@hasDecl(Model, "init")) {
|
if (@hasDecl(Model, "init")) {
|
||||||
@call(.auto, Model.init, .{&model} ++ init_args);
|
@call(.auto, Model.init, .{&model} ++ init_args);
|
||||||
}
|
}
|
||||||
|
|
||||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, prefix);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
||||||
|
|||||||
20
zml/exe.zig
20
zml/exe.zig
@ -30,13 +30,31 @@ pub fn compile(
|
|||||||
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||||
buffer_store: aio.BufferStore,
|
buffer_store: aio.BufferStore,
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
|
) !FnExe(func) {
|
||||||
|
return compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compiles a Model struct with the given configuration and shapes, for the given platform.
|
||||||
|
/// Uses a prefix for looking up model weights in the buffer store.
|
||||||
|
/// The steps are:
|
||||||
|
/// * look up the tensors available in the store and create a `model: Model` struct with them
|
||||||
|
/// * call `model.init(init_args)` to initialize the fields of the model that aren't Tensor, i.e. hyperparameters/config
|
||||||
|
/// * generate MLIR by calling `model.forward` with tensor of the given shapes and other arguments
|
||||||
|
pub fn compileWithPrefix(
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
comptime func: anytype,
|
||||||
|
init_args: anytype,
|
||||||
|
args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
|
||||||
|
buffer_store: aio.BufferStore,
|
||||||
|
platform: Platform,
|
||||||
|
prefix: []const u8,
|
||||||
) !FnExe(func) {
|
) !FnExe(func) {
|
||||||
const ModelT = ModuleSignature(func).ModelT;
|
const ModelT = ModuleSignature(func).ModelT;
|
||||||
|
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
const arena = arena_state.allocator();
|
const arena = arena_state.allocator();
|
||||||
var model = try aio.populateModel(ModelT, arena, buffer_store);
|
var model = try aio.populateModelWithPrefix(ModelT, arena, buffer_store, prefix);
|
||||||
|
|
||||||
// If the Model has a "init" function, call it with the given parameters.
|
// If the Model has a "init" function, call it with the given parameters.
|
||||||
if (@hasDecl(ModelT, "init")) {
|
if (@hasDecl(ModelT, "init")) {
|
||||||
|
|||||||
@ -35,6 +35,7 @@ pub const tokenizer = @import("zml/tokenizer");
|
|||||||
|
|
||||||
pub const call = ops.call;
|
pub const call = ops.call;
|
||||||
pub const compile = exe.compile;
|
pub const compile = exe.compile;
|
||||||
|
pub const compileWithPrefix = exe.compileWithPrefix;
|
||||||
pub const compileFn = exe.compileFn;
|
pub const compileFn = exe.compileFn;
|
||||||
pub const compileModel = exe.compileModel;
|
pub const compileModel = exe.compileModel;
|
||||||
pub const FnExe = exe.FnExe;
|
pub const FnExe = exe.FnExe;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user