2023-10-13 16:08:08 +00:00
const std = @import ( " std " ) ;
2024-10-28 11:21:46 +00:00
2023-10-13 16:08:08 +00:00
const stdx = @import ( " stdx " ) ;
const aio = @import ( " aio.zig " ) ;
const Buffer = @import ( " buffer.zig " ) . Buffer ;
const Bufferized = @import ( " tensor.zig " ) . Bufferized ;
2025-08-20 10:27:54 +00:00
const callback = @import ( " callback.zig " ) ;
2023-10-13 16:08:08 +00:00
const CompilationContext = @import ( " module.zig " ) . CompilationContext ;
2024-10-28 11:21:46 +00:00
const meta = @import ( " meta.zig " ) ;
const pjrt = @import ( " pjrtx.zig " ) ;
2023-10-13 16:08:08 +00:00
const Platform = @import ( " platform.zig " ) . Platform ;
const Shape = @import ( " shape.zig " ) . Shape ;
const ShapeOf = @import ( " tensor.zig " ) . ShapeOf ;
2024-12-10 09:36:37 +00:00
/// Scoped logger used by this file.
const log = std.log.scoped(.@"zml/exe");

// Reference every declaration of this file so that `zig test` analyzes them
// and runs any nested tests.
test {
    const testing = std.testing;
    testing.refAllDecls(@This());
}
/// Compiles `func` (a method whose first argument is the model) for `platform`,
/// looking up model weights in `buffer_store` without any prefix.
///
/// This is `compileWithPrefix` with an empty prefix. The steps are:
/// * look up the tensors available in the store and build a `model: Model` struct with them,
/// * if `Model` declares `init`, call `model.init(init_args...)` to set the fields that
///   aren't tensors, i.e. hyperparameters/config,
/// * generate MLIR by calling `func` (typically `Model.forward`) with the model and
///   tensors of the given shapes.
pub fn compile(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
) !FnExe(func) {
    const no_prefix = "";
    return try compileWithPrefix(allocator, func, init_args, args_shapes, buffer_store, platform, no_prefix);
}
/// Compiles `func` (a method whose first argument is the model) for `platform`,
/// looking up model weights in `buffer_store` under `prefix`.
///
/// The steps are:
/// * look up the tensors available in the store and build a `model: Model` struct with them,
/// * if `Model` declares `init`, call `model.init(init_args...)` to set the fields that
///   aren't tensors, i.e. hyperparameters/config,
/// * generate MLIR by calling `func` with the model and tensors of the given shapes
///   (see `compileModel`).
///
/// The model struct is built in a temporary arena that is released before returning.
pub fn compileWithPrefix(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    init_args: anytype,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    buffer_store: aio.BufferStore,
    platform: Platform,
    prefix: []const u8,
) !FnExe(func) {
    const Model = ModuleSignature(func).ModelT;

    // Scratch memory for the populated model; freed once compilation has finished.
    var scratch = std.heap.ArenaAllocator.init(allocator);
    defer scratch.deinit();
    const scratch_allocator = scratch.allocator();

    var model = try aio.populateModelWithPrefix(Model, scratch_allocator, buffer_store, prefix);

    // Models exposing an `init` decl receive the user-provided init arguments.
    // TODO(Corentin,@Improvement): Add a warning/error if there is no init function but init_args is non-void.
    if (comptime @hasDecl(Model, "init")) {
        const model_ptr: *Model = &model;
        @call(.auto, Model.init, .{model_ptr} ++ init_args);
    }

    return compileModel(allocator, func, model, args_shapes, platform);
}
/// Compiles `func` for `platform`, using an already-built `model` as its first argument.
///
/// MLIR is generated by calling `func` with `model` followed by tensors of the given
/// shapes. The module is named after the model type (`<Model>.forward`).
pub fn compileModel(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    model: ModuleSignature(func).ModelT,
    args_shapes: ShapeOf(ModuleSignature(func).ArgsT),
    platform: Platform,
) !FnExe(func) {
    const Model = ModuleSignature(func).ModelT;
    const module_name = @typeName(Model) ++ ".forward";
    log.info("Compiling {s} with {f}", .{ module_name, stdx.fmt.any(args_shapes) });

    var ctx = try CompilationContext.init(allocator, module_name, platform);
    defer ctx.deinit();

    const inner = try ctx.compileInternal(allocator, func, .{model} ++ args_shapes);
    return .{ .inner = inner };
}
/// Compiles a free function `func` for `platform`.
///
/// MLIR is generated by calling `func` with tensors of the shapes given in `args`.
/// The module is named after a shortened form of `func`'s type (see `prettyFnName`).
pub fn compileFn(
    allocator: std.mem.Allocator,
    comptime func: anytype,
    args: ShapeOf(stdx.meta.FnArgs(func)),
    platform: Platform,
) !FnExe(func) {
    var module_name = try prettyFnName(func, allocator);
    defer module_name.deinit(allocator);

    var ctx = try CompilationContext.init(allocator, module_name.items, platform);
    defer ctx.deinit();

    const inner = try ctx.compileInternal(allocator, func, args);
    return .{ .inner = inner };
}
/// Type of the executable produced by compiling `func`:
/// its `call` takes all of `func`'s arguments and returns `func`'s result, as buffers.
pub fn FnExe(comptime func: anytype) type {
    const Args = stdx.meta.FnArgs(func);
    const Result = stdx.meta.FnResult(func);
    return Exe(Args, Result);
}
/// Represents a ZML model, compiled into a PJRT executable, and ready to call.
/// The buffers for the model weights are stored inside the struct and used in `call`,
/// so you only need to pass the remaining arguments.
/// Creating a `ModuleExe` is a two-step process:
///
/// ```
/// const exe: zml.FnExe(MyModel.forward) = try zml.compile(allocator, MyModel.forward, init_args, model_shapes, buffer_store, platform);
/// const module: zml.ModuleExe(MyModel.forward) = exe.prepare(model_buffers);
/// ```
pub fn ModuleExe(comptime func: anytype) type {
    // `ModuleSignature` checks that `func` has at least one argument and splits
    // its arguments into the module (first one) and the remaining call arguments.
    const signature = ModuleSignature(func);
    return Exe(signature.ArgsT, signature.ReturnT);
}
/// Signature of a module method: model type, remaining arguments, and result type.
///
/// Making this a struct forces all fields to be evaluated on creation,
/// which gives a better error stack trace than delaying the error
/// until the fields are read.
const Sign = struct {
    /// Type of the first argument, treated as the module/model.
    ModelT: type,
    /// Tuple type of the arguments following the model.
    ArgsT: type,
    /// Return type of the method.
    ReturnT: type,
};

/// Splits the signature of `func` into its model argument, its remaining
/// arguments, and its result type. Compile error if `func` takes no argument.
pub fn ModuleSignature(comptime func: anytype) Sign {
    const Params = stdx.meta.FnArgs(func);
    const param_count = std.meta.fields(Params).len;
    stdx.debug.assertComptime(param_count > 0, "ModuleExe expects a function with at least one argument where the first one is treated as the module, got {}", .{func});

    const Model = stdx.meta.Head(Params);
    const Rest = stdx.meta.Tail(Params);
    const Result = stdx.meta.FnResult(func);
    return .{ .ModelT = Model, .ArgsT = Rest, .ReturnT = Result };
}
/// Represents an MLIR module compiled into a PJRT executable.
/// The BaseExe is a plain old struct and doesn't have information about Zig types.
///
/// It also contains pre-allocated buffers so that we can pass them to PJRT_LoadedExecutable_Execute
/// without allocations.
pub const BaseExe = struct {
    /// The platform for which this module was compiled.
    platform: Platform,
    /// The PJRT executable representing the compiled module.
    exe: *pjrt.LoadedExecutable,
    /// The execution context passed to PJRT on every call.
    /// Only created when the PJRT API exposes FFI support; `null` otherwise.
    execute_context: ?*pjrt.ExecuteContext,

    /// Pre-allocated input buffer pointers: one array of `input_buffer_count` entries per device.
    input_per_device: []const [*]*pjrt.Buffer,
    /// Pre-allocated output buffer pointers: one array of `result_shapes.len` entries per device.
    output_per_device: []const [*]*pjrt.Buffer,
    /// Number of input buffers already fed to the executable (see `prepare`).
    ready_buffer_count: u32,
    /// Number of input buffers (per device) needed to call this executable.
    input_buffer_count: u32,
    /// Expected shape of each input buffer, in argument order.
    input_shapes: []Shape,
    /// Shape of each output buffer, in result order.
    result_shapes: []Shape,
    /// Num devices used (>1 for sharded executable)
    num_devices: u8,
    /// Allocator backing memory
    _arena: std.heap.ArenaAllocator,

    /// Allocates the per-device buffer tables and copies of the shapes for `exe`.
    /// When the PJRT API supports FFI, also creates an execute context and binds the
    /// internal callbacks to it. Release everything with `deinit`.
    pub fn init(
        parent_allocator: std.mem.Allocator,
        platform: Platform,
        exe: *pjrt.LoadedExecutable,
        args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
    ) !BaseExe {
        var arena = std.heap.ArenaAllocator.init(parent_allocator);
        errdefer arena.deinit();
        const allocator = arena.allocator();

        const n_in = args.input_shapes.len;
        const n_out = args.result_shapes.len;
        // Widen to usize so that size computations such as `2 * n_devices` can't overflow a u8.
        const n_devices: usize = args.n_devices;

        // Allocate once for all the *pjrt.Buffer we need to store ...
        const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
        const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });

        // ... and once for all the [*]*pjrt.Buffer.
        const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
        const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
        for (0..n_devices) |i| {
            input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
            output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
        }

        // Input and result shapes share one allocation: [inputs..., results...].
        const all_shapes = try allocator.alloc(Shape, n_in + n_out);
        @memcpy(all_shapes[0..n_in], args.input_shapes);
        @memcpy(all_shapes[n_in..], args.result_shapes);

        const execute_context: ?*pjrt.ExecuteContext = if (platform.pjrt_api.ffi()) |ffi| blk: {
            const ctx = try platform.pjrt_api.createExecuteContext();
            // Don't leak the context if binding the internal callbacks fails.
            errdefer ctx.deinit(platform.pjrt_api);
            try callback.bindInternalCallbacks(allocator, platform, ffi, ctx);
            break :blk ctx;
        } else null;

        return .{
            .platform = platform,
            .exe = exe,
            .execute_context = execute_context,
            .ready_buffer_count = 0,
            .input_buffer_count = @intCast(n_in),
            .num_devices = args.n_devices,
            .input_per_device = input_per_device,
            .output_per_device = output_per_device,
            .input_shapes = all_shapes[0..n_in],
            .result_shapes = all_shapes[n_in..],
            ._arena = arena,
        };
    }

    /// Releases the execute context (if any) and all memory allocated by `init`.
    pub fn deinit(self: BaseExe) void {
        if (self.execute_context) |ctx| {
            ctx.deinit(self.platform.pjrt_api);
        }
        self._arena.deinit();
    }

    /// Runs the executable. All `input_buffer_count` inputs must have been fed beforehand.
    pub fn call(self: BaseExe) void {
        stdx.debug.assert(self.input_buffer_count == self.ready_buffer_count, "BaseExe isn't ready to be called, expected {} buffer inputs got {}", .{ self.input_buffer_count, self.ready_buffer_count });
        return self._unsafeCall();
    }

    /// Runs the executable without checking that all inputs were provided.
    /// Panics if PJRT reports an execution error.
    pub fn _unsafeCall(self: BaseExe) void {
        var events = [_]?*pjrt.Event{null} ** Platform.MAX_NUM_DEVICES;
        const sharding = self.platform.sharding();
        self.exe.execute(self.platform.pjrt_api, .{
            .arguments = self.input_per_device,
            .num_args = self.input_buffer_count,
            .results = self.output_per_device,
            .events = events[0..sharding.num_partitions],
            // this allows to tell a specific buffer shouldn't be donated,
            // even if it has been marked as "can be donated" during compilation.
            // TODO: expose it ?
            .non_donatable_input_indices = &.{},
            .context = self.execute_context,
        }) catch |err| {
            std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
        };
        // NOTE: the completion events filled in by `execute` are not awaited here.
    }

    /// Wraps the output PJRT buffers into the `Buffer`s found while visiting `result`,
    /// in visiting order. Asserts that `T` contains exactly one `Buffer` per output.
    pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
        const LocalContext = struct {
            index: u32,
            platform: Platform,
            outputs: []const [*]*pjrt.Buffer,
            output_shapes: []Shape,
        };
        var local_ctx: LocalContext = .{
            .index = 0,
            .platform = self.platform,
            .outputs = self.output_per_device,
            .output_shapes = self.result_shapes,
        };
        meta.visit((struct {
            fn cb(ctx: *LocalContext, buffer: *Buffer) void {
                const i = ctx.index;
                ctx.index += 1;
                // Extra buffers are counted but not filled; the assert below reports the mismatch.
                if (i >= ctx.output_shapes.len) return;

                var shards: Buffer.Shards = .{};
                for (ctx.outputs) |buff| {
                    shards.appendAssumeCapacity(buff[i]);
                }
                buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
            }
        }).cb, &local_ctx, result);
        stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime known number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
    }

    /// Attaches `op` as user data of type `Callback` to this executable's execute context.
    /// Panics if the target doesn't support callbacks.
    pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
        stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});
        const pjrt_api = exe.platform.pjrt_api;
        if (pjrt_api.ffi()) |ffi| {
            try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op);
        } else {
            stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
        }
    }

    /// Writes the serialized PJRT executable to `writer`.
    pub fn serialize(self: BaseExe, writer: anytype) !void {
        var executable = try self.exe.getExecutable(self.platform.pjrt_api);
        var serialize_result = try executable.serialize(self.platform.pjrt_api);
        defer serialize_result.deinit();
        try writer.writeAll(serialize_result.bytes);
    }

    // pub fn deserialize(allocator: std.mem.Allocator, platform: Platform, reader: anytype) !Self {
    //     const bytes = try reader.readToEndAlloc(allocator, max_pjrt_executable_size);
    //     defer allocator.free(bytes);
    //     return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes);
    // }

    /// Feeds the buffers found in `x` as the next inputs of the executable.
    pub fn prepare(self: *BaseExe, x: anytype) void {
        const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
        self.ready_buffer_count += n;
    }

    /// Returns the `i`-th output of the last call, gathering its shards from every device.
    pub fn getOutputBuffer(self: BaseExe, i: usize) Buffer {
        // `output_per_device[d]` is an unchecked many-pointer: validate `i` before reading it.
        stdx.debug.assert(i < self.result_shapes.len, "Output index {} out of range, executable has {} outputs", .{ i, self.result_shapes.len });
        var shards: Buffer.Shards = .{};
        for (self.output_per_device) |dev_out| {
            shards.appendAssumeCapacity(dev_out[i]);
        }
        return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
    }

    /// Creates a new BaseExe for the same executable, with its own buffer tables
    /// (no inputs fed yet) allocated from `parent_allocator`.
    ///
    /// **Warning:** the clone shares `self.execute_context`, so callbacks bound with `bind`
    /// remain visible, but `deinit` on both would release that context twice.
    pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
        var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
            .input_shapes = self.input_shapes,
            .result_shapes = self.result_shapes,
            .n_devices = self.num_devices,
        });
        // `init` created a fresh context that is about to be replaced by the shared one:
        // release it instead of leaking it.
        if (exe.execute_context) |fresh| fresh.deinit(self.platform.pjrt_api);
        exe.execute_context = self.execute_context;
        return exe;
    }
};
/// Represents a ZML function, compiled into a PJRT executable.
/// The signature of the Exe reflects the arguments that are needed for `call`.
pub fn Exe(ArgsT: type, ReturnT: type) type {
    return struct {
        const Self = @This();

        /// The raw untyped compiled module.
        inner: BaseExe,

        /// Releases the underlying `BaseExe`.
        pub fn deinit(self: Self) void {
            self.inner.deinit();
        }

        /// Hardcode the first argument of the function to the given buffers.
        /// Returns an Exe with one less argument in `call`.
        /// In functional languages this is known as partial application.
        ///
        /// **Warning:** the new Exe reuses the underlying memory of the previous one.
        /// The caller is responsible to come up with a strategy to call `deinit` exactly once.
        pub fn prepare(self: Self, first_arg: Bufferized(stdx.meta.Head(ArgsT))) Exe(stdx.meta.Tail(ArgsT), ReturnT) {
            var new: Exe(stdx.meta.Tail(ArgsT), ReturnT) = .{ .inner = self.inner };
            new.inner.prepare(first_arg);
            return new;
        }

        /// For a given customCall inside this executable,
        /// provide a pointer to runtime data.
        /// The caller keeps memory ownership and needs to ensure that the value
        /// stays alive as long as the executable.
        pub fn bind(self: Self, comptime T: type, value: *T) !void {
            try self.inner.bind(T, value);
        }

        /// Writes the serialized executable to `writer`.
        pub fn serialize(self: Self, writer: anytype) !void {
            return try self.inner.serialize(writer);
        }

        /// The platform this executable was compiled for.
        pub fn platform(self: Self) Platform {
            return self.inner.platform;
        }

        /// Feeds `args` after the already-prepared inputs, runs the executable,
        /// and returns its outputs.
        pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
            const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
            // Same check as `BaseExe.call`, with a message telling how many inputs are missing.
            stdx.debug.assert(total_ready == self.inner.input_buffer_count, "Exe isn't ready to be called, expected {} buffer inputs got {}", .{ self.inner.input_buffer_count, total_ready });
            self.inner._unsafeCall();
            var result: Bufferized(ReturnT) = undefined;
            self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
            return result;
        }
    };
}
/// Cuts `buffer` into consecutive sub-slices of the comptime-known count of `lengths`.
/// The lengths must add up exactly to `buffer.len`.
fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
    var parts: [lengths.len][]T = undefined;
    var rest = buffer;
    inline for (lengths, 0..) |len, k| {
        parts[k] = rest[0..len];
        rest = rest[len..];
    }
    std.debug.assert(rest.len == 0);
    return parts;
}
/// Visits `v` and records the PJRT shards of every `Buffer` it contains into `buffers`.
///
/// Buffers are stored at consecutive argument indices starting at `start`:
/// `buffers[d][k]` receives the shard for device `d` of the `k`-th argument.
/// Each visited buffer must have one shard per entry of `buffers` and a shape equal
/// to `shapes[k]`. Returns the index one past the last argument written.
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
    const LocalContext = struct {
        index: u32,
        buffers: []const [*]*pjrt.Buffer,
        shapes: []const Shape,
    };
    var context: LocalContext = .{
        .index = start,
        .buffers = buffers,
        .shapes = shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
            // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
            // `ctx.buffers[d]` is an unchecked many-pointer: refuse to write past the
            // number of inputs the executable was compiled for.
            stdx.debug.assert(ctx.index < ctx.shapes.len, "Executable expects {} input buffers, but more were given", .{ctx.shapes.len});
            const model_sharding = ctx.buffers.len;
            stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {d}-sharded tensor into a {d}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
            stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {f}, got {f}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
            for (buffer._shards.constSlice(), 0..) |shard, d| {
                ctx.buffers[d][ctx.index] = shard;
            }
            ctx.index += 1;
        }
    }).cb, &context, v);
    return context.index;
}
2024-05-15 17:54:52 +00:00
/// Builds a human-readable name for `func`, used to name the compiled module.
///
/// Starts from `@typeName(@TypeOf(func))` and shortens noisy namespaces, in order:
/// `tensor.Tensor.X` -> `X`, then `tensor.Tensor` -> `Tensor`, then `shape.Shape` -> `Shape`.
///
/// Caller owns the returned list and must release it with `deinit(allocator)`.
fn prettyFnName(
    comptime func: anytype,
    allocator: std.mem.Allocator,
) !std.ArrayListUnmanaged(u8) {
    const full_noisy_name = @typeName(@TypeOf(func));

    const Rewrite = struct { verbose: []const u8, compact: []const u8 };
    // Applied in order. "tensor.Tensor." must come before "tensor.Tensor":
    // once the latter has been shortened, the former can never match.
    // Every `compact` is no longer than its `verbose` (checked at comptime below),
    // so the name only shrinks and can be rewritten in place.
    const rewrites = [_]Rewrite{
        .{ .verbose = "tensor.Tensor.", .compact = "" },
        .{ .verbose = "tensor.Tensor", .compact = "Tensor" },
        .{ .verbose = "shape.Shape", .compact = "Shape" },
    };

    const buffer = try allocator.alloc(u8, full_noisy_name.len);
    errdefer comptime unreachable; // No errors below this point.

    @memcpy(buffer, full_noisy_name);
    var out: []u8 = buffer;
    inline for (rewrites) |rw| {
        // Comptime subtraction: fails to compile if a rewrite would grow the name.
        const shrink = rw.verbose.len - rw.compact.len;
        const num_replacements = std.mem.replace(u8, out, rw.verbose, rw.compact, buffer);
        out.len -= num_replacements * shrink;
    }
    return .{ .items = out, .capacity = full_noisy_name.len };
}