2023-10-13 16:08:08 +00:00
const std = @import ( " std " ) ;
2023-06-21 14:45:14 +00:00
const asynk = @import ( " async " ) ;
const dialect = @import ( " mlir/dialects " ) ;
2024-02-06 09:31:48 +00:00
const runfiles = @import ( " runfiles " ) ;
2023-06-21 14:45:14 +00:00
const stdx = @import ( " stdx " ) ;
2023-01-02 14:28:25 +00:00
const xla_pb = @import ( " //xla:xla_proto " ) ;
2023-06-21 14:45:14 +00:00
2023-01-02 14:28:25 +00:00
const meta = @import ( " meta.zig " ) ;
const mlir = @import ( " mlir.zig " ) ;
const ops = @import ( " ops.zig " ) ;
2023-05-26 15:54:15 +00:00
const pjrt = @import ( " pjrtx.zig " ) ;
2023-01-02 14:28:25 +00:00
2023-10-13 16:08:08 +00:00
const BaseExe = @import ( " exe.zig " ) . BaseExe ;
2023-06-21 14:45:14 +00:00
const Buffer = @import ( " buffer.zig " ) . Buffer ;
const Bufferized = @import ( " tensor.zig " ) . Bufferized ;
2023-01-02 14:28:25 +00:00
const Location = mlir . Location ;
const Platform = @import ( " platform.zig " ) . Platform ;
2023-06-21 14:45:14 +00:00
const Shape = @import ( " shape.zig " ) . Shape ;
const ShapeOf = @import ( " tensor.zig " ) . ShapeOf ;
2023-01-02 14:28:25 +00:00
const Target = @import ( " platform.zig " ) . Target ;
const Tensor = @import ( " tensor.zig " ) . Tensor ;
const Tracer = @import ( " tools/tracer.zig " ) . Tracer ;
2023-06-21 14:45:14 +00:00
const log = std . log . scoped ( . @ " zml/module " ) ;
2023-01-02 14:28:25 +00:00
2023-01-23 16:28:19 +00:00
// Reference all declarations so the nested `test` blocks in this file are
// discovered and run by `zig test`.
test {
    std.testing.refAllDecls(@This());
}
2023-07-25 14:25:47 +00:00
/// Discriminates how strictly a block isolates the values created inside it.
pub const BlockKind = enum { open, hermetic };

/// Wrapper around `mlir.Block` tracking whether the block may reference
/// values from enclosing blocks (`open`) or must be self-contained (`hermetic`).
const Block = union(BlockKind) {
    open: mlir.Block,
    hermetic: mlir.Block,

    /// Returns the underlying `mlir.Block`, regardless of kind.
    pub fn block(self: Block) mlir.Block {
        return switch (self) {
            inline .open, .hermetic => |t| t,
        };
    }

    /// Appends the producer of tensor `x` (and transitively its operands) to this block.
    fn appendTensorRecursive(self: Block, x: *const Tensor) void {
        self.appendValueRecursive(x.value());
    }

    /// Appends the producer of `value` (and transitively its operands) to this block.
    fn appendValueRecursive(self: Block, value: mlir.Value) void {
        switch (value.kind()) {
            .op_result => |parent_op| self.appendOperationRecursive(parent_op),
            .block_argument => |arg| {
                // Hermetic blocks are not allowed to use arguments from other blocks.
                stdx.debug.assert(self == .open or self.block().eql(arg.block()), "Can't add {} from {?x} block to {?x} block", .{ arg, arg.block()._inner.ptr, self.block()._inner.ptr });
            },
            .null => @panic("InvalidMlir"),
        }
    }

    /// Appends `op` to this block unless it is already attached to one,
    /// first appending the producers of all its operands.
    fn appendOperationRecursive(self: Block, op: mlir.Operation) void {
        if (op.block()) |prev_block| {
            // Hermetic blocks are not allowed to reference values from other blocks.
            // Fix: use `eql` for block-identity, consistent with the check in
            // `appendValueRecursive` and in `closeBlock`.
            std.debug.assert(self == .open or prev_block.eql(self.block()));
            return;
        }
        for (0..op.numOperands()) |i| {
            self.appendValueRecursive(op.operand(i));
        }
        self.block().appendOperation(op);
    }
};
2023-10-13 16:08:08 +00:00
/// Result of lowering a ZML function into an MLIR `func.func` operation.
pub const MlirFn = struct {
    /// Symbol name of the generated MLIR function.
    name: []const u8,
    /// Number of tensor arguments the function takes.
    num_args: u32,
    /// Type-erased pointer to the result tensors returned by the ZML function
    /// (allocated with the compilation context's allocator).
    res_tensors: *const anyopaque,
    /// MLIR types of the results.
    res_types: []mlir.Type,
    /// Shapes of the results.
    res_shapes: []Shape,
    /// For each result, which input buffer (if any) it reuses.
    res_donations: []Tensor._Donation,
    /// The generated `func.func` operation itself.
    mlir_fn: mlir.Operation,

    /// `main` is the entry point (gets donation/sharding attributes),
    /// `private` is used for internal helper functions.
    pub const Kind = enum {
        main,
        private,
    };
};
2023-01-02 14:28:25 +00:00
/// Holds all the state needed to lower a ZML function into an MLIR module
/// and compile that module into a PJRT executable.
pub const CompilationContext = struct {
    _platform: Platform,

    // Module name, truncated so it is safe to use in dump-file paths.
    _name: []const u8,

    // Arena backing every allocation that lives as long as the compilation.
    _arena: std.heap.ArenaAllocator,

    _mlir_ctx: mlir.Context,
    _mlir_registry: mlir.Registry,
    _mlir_canonicalizer: mlir.PassManager,
    _module: mlir.Module,

    // Stack of currently open MLIR blocks, innermost last.
    _blocks: std.BoundedArray(Block, 64) = .{},
    // Cache of already-emitted functions (see `callFunc`).
    _fn_cache: FnCache = .{},

    // Maps tensor ids to their MLIR block argument and donation info.
    _block_args: TensorToBlockArg = .{},

    _unique_id: u64 = 10000,
    _tracer: Tracer,

    // Context that was active before `activate`, restored on `deactivate`.
    _previous: ?*CompilationContext = null,

    // Context currently active on this thread; compilation relies on this
    // threadlocal, hence `compileInternal` runs emission on a dedicated thread.
    threadlocal var _current: ?*CompilationContext = null;

    const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
    const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
2023-12-25 13:01:17 +00:00
/// Creates a compilation context for the given platform.
/// `full_name` is truncated so it can safely be embedded in dump-file paths.
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
    const mlir_registry = mlir.Registry.init() catch unreachable;
    inline for (.{ "func", "stablehlo" }) |d| {
        mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
    }
    var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
    mlir_ctx.loadAllAvailableDialects();

    // Too long module names create too long file paths and files failed to create.
    // * leave half of the space for parent folder and XLA generated filename,
    // * leave 17 bytes for the module hash (16 + 1 for underscore).
    const max_name_len = @divFloor(std.fs.max_path_bytes, 2) - 17;
    const name = full_name[0..@min(max_name_len, full_name.len)];

    const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
    const module = mlir.Module.init(loc);
    module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?);

    var canonicalizer = try mlir.PassManager.init(mlir_ctx);
    {
        var opm = canonicalizer.asOpPassManager();
        try opm.addPipeline("canonicalize");
        try opm.addPipeline("cse");
        try opm.addPipeline("canonicalize");
    }

    var arena = std.heap.ArenaAllocator.init(allocator_);
    // Pre-heat the arena with a first chunk, then reset while retaining
    // capacity, so early small allocations don't each hit the child allocator.
    _ = try arena.allocator().alloc(u8, 4096);
    _ = arena.reset(.retain_capacity);

    return .{
        ._platform = platform,
        ._name = try arena.allocator().dupe(u8, name),
        ._mlir_ctx = mlir_ctx,
        ._mlir_registry = mlir_registry,
        ._mlir_canonicalizer = canonicalizer,
        ._module = module,
        ._blocks = .{},
        ._fn_cache = .{},
        ._arena = arena,
        ._tracer = Tracer.init("ai.zml.compilation"),
    };
}
/// Releases the MLIR objects and the arena owned by this context.
pub fn deinit(self: *CompilationContext) void {
    // No need to deinit self._fn_cache cause it uses our arena
    self._mlir_ctx.deinit();
    self._mlir_registry.deinit();
    self._arena.deinit();
}

/// Allocator whose allocations live for the whole compilation (arena-backed,
/// freed all at once in `deinit`).
pub fn allocator(self: *CompilationContext) std.mem.Allocator {
    return self._arena.allocator();
}
/// Makes this context the active one on the current thread,
/// remembering the previously active context.
pub fn activate(self: *CompilationContext) void {
    self._previous = _current;
    _current = self;
}

/// Restores the previously active context. Must be paired with `activate`.
pub fn deactivate(self: *CompilationContext) void {
    std.debug.assert(_current != null and _current.? == self);
    _current = self._previous;
}

/// Returns the context active on this thread. Panics if none is active.
pub fn current() *CompilationContext {
    return _current.?;
}

/// Target (backend) of the platform being compiled for.
pub fn target(self: *const CompilationContext) Target {
    return self._platform.target;
}

/// MLIR context owned by this compilation context.
pub fn mlirCtx(self: *const CompilationContext) mlir.Context {
    return self._mlir_ctx;
}

/// Creates a named MLIR location from a Zig source location plus a format string.
pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location {
    return self._mlir_ctx.location(src).namedFmt(self._mlir_ctx, name, args);
}
2023-10-13 16:08:08 +00:00
/// Compiles the given function with the given arguments.
/// This is the untyped API and is not meant to be use directly.
///
/// * allocator is used to allocate the result Exe
/// * args can contain a mix of tensors and shapes, allowing to pass a "model struct" containig tensors.
pub fn compileInternal(
    self: *CompilationContext,
    allocator_: std.mem.Allocator,
    comptime func: anytype,
    args: anytype,
) !BaseExe {
    const arena = self.allocator();
    var timer = std.time.Timer.start() catch null;
    const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
    // Run in a dedicated thread because compilation relies on `threadlocal`.
    const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, CompilationContext.EmitMlirOpts{ .name = "main", .kind = .main } });
    const module = self._module;
    module.getBody().appendOperation(f.mlir_fn);
    const sharding = self._platform.sharding();
    const mlir_ctx = self._mlir_ctx;
    module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
    module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
    const module_hash = computeModuleHash(self._platform, module);

    var module_dir: ?[]const u8 = null;
    var pjrt_location: ?[:0]const u8 = null;
    if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| {
        const sep = std.fs.path.sep_str;
        const module_dir_name = try std.fmt.allocPrint(arena, "{s}{s}{s}{s}{s}_{x}", .{ xla_dump_to, sep, @tagName(self._platform.target), sep, self._name, module_hash });
        try std.fs.cwd().makePath(module_dir_name);
        module_dir = try std.fs.cwd().realpathAlloc(arena, module_dir_name);
        // Fix: close the dump directory handle when done (it was leaked).
        var cache_dir = try std.fs.cwd().openDir(module_dir.?, .{});
        defer cache_dir.close();
        // Write the mlir to a file. All errors are discarded, since this is for debugging only.
        const mlir_name = "module.mlir";
        if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
            // Fix: close the dump file handle (it was leaked).
            defer file.close();
            module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
            log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
        } else |_| {
            log.warn("Failed to open {s}", .{mlir_name});
        }
        pjrt_location = try std.fs.path.joinZ(arena, &.{ module_dir.?, "module.pjrt" });
    }

    const loaded_executable: *pjrt.LoadedExecutable = blk: {
        // Reuse a previously compiled executable from disk when available.
        if (pjrt_location) |pjrt_loc| {
            if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
                log.info("Loaded pre-compiled module from {s}", .{pjrt_loc});
                break :blk exe;
            } else |err| {
                if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
            }
        }
        const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
            log.err("pjrt-{s} failed to compile: {}", .{ @tagName(self._platform.target), err });
            if (module_dir) |dir| log.err("mlir can be found at {s}/module.mlir", .{dir});
            return err;
        };
        // Best effort: persist the compiled executable for future runs.
        if (pjrt_location) |pjrt_loc| {
            storePjrtExecutable(self._platform, loaded_executable, pjrt_loc) catch |err| {
                log.warn("Failed to store compiled module: {} at {s}", .{ err, pjrt_loc });
            };
        }
        break :blk loaded_executable;
    };
    log.debug("******** ZML generated MLIR ********", .{});
    log.debug("{}", .{module.op().mlirFormatter(.{})});
    if (timer) |*t| {
        const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
        if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloat(f32, time_ms, 1000)});
    }
    return BaseExe.init(
        allocator_,
        self._platform,
        loaded_executable,
        .{
            .n_in = f.num_args,
            .result_shapes = f.res_shapes,
            .n_devices = sharding.num_replicas * sharding.num_partitions,
        },
    );
}
/// Innermost currently-open block, or null if no block is open.
fn currentBlock(self: *const CompilationContext) ?Block {
    const depth = self._blocks.len;
    if (depth == 0) return null;
    return self._blocks.get(depth - 1);
}
2023-07-25 14:25:47 +00:00
/// Creates a fresh MLIR block of the given kind with the given argument
/// types/locations, and pushes it onto the block stack.
pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block {
    const inner = try mlir.Block.init(args, locs);
    const wrapped: Block = switch (kind) {
        .open => .{ .open = inner },
        .hermetic => .{ .hermetic = inner },
    };
    self.pushBlock(wrapped);
    return wrapped;
}
2023-07-25 14:25:47 +00:00
/// Pops the innermost block from the stack, asserting it is the one given
/// (blocks must be closed in LIFO order).
pub fn closeBlock(self: *CompilationContext, block: Block) void {
    const popped = self._blocks.pop();
    std.debug.assert(block.block().eql(popped.?.block()));
}
2023-07-25 14:25:47 +00:00
/// Pushes a block onto the (fixed-capacity) stack; asserts it isn't full.
fn pushBlock(self: *CompilationContext, block: Block) void {
    self._blocks.appendAssumeCapacity(block);
}
/// Transform a Tensor -> Tensor function into an Mlir block.
/// `blkctx` represents values from outside the block that can be accessed inside the block.
/// Returns both the mlir.Block created and also the Tensors returned by `func`.
/// The returned tensors should not be returned to the user,
/// because their `mlir.Value` must not escape the block that created them.
/// But their shapes/tags can be safely propagated further.
pub fn makeBlock(
    self: *CompilationContext,
    kind: BlockKind,
    comptime S: ops.BlockSignature,
    func: *const S.Fn,
    blkctx: S.BlkCtx,
    args: S.Args,
) struct { mlir.Block, S.Return } {
    const N = S.nIn;
    const loc = self.mlirCtx().location(@src());
    const locations = .{loc} ** N;
    var input_types: [N]mlir.Type = undefined;
    fillMlirTypes(&args, self.mlirCtx(), &input_types);

    // Before creating a new block, assign all received values to previous block,
    // otherwise they will be assign to this block
    if (self.currentBlock()) |prev_block| {
        meta.visit(Block.appendTensorRecursive, prev_block, &blkctx);
    }
    const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
    defer self.closeBlock(block);

    // Here we want to create the block with the correct mlir types.
    // but we don't want to use the values themselves.
    // So we create a copy of the arguments, and replace values
    // by the block arguments.
    var blk_args = args;
    std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N);

    const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
    var block_res_values: [S.nOut]mlir.Value = undefined;
    self.extractValues(&block_res, &block_res_values);
    const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
    block.appendOperationRecursive(block_ret);

    return .{ block.block(), block_res };
}
2024-07-02 14:19:04 +00:00
/// Options controlling how `emitMlir` names and classifies the generated function.
pub const EmitMlirOpts = struct {
    /// Symbol name of the generated MLIR function.
    name: []const u8,
    /// `main` additionally receives donation/sharding attributes.
    kind: MlirFn.Kind = .private,
};
2023-01-02 14:28:25 +00:00
/// Generate an MLIR function from a ZML function.
/// The caller is responsible to have properly created the input
/// tensors with unique tensor ids.
pub fn emitMlir(
    self: *CompilationContext,
    comptime func: anytype,
    args: *const stdx.meta.FnArgs(func),
    opts: EmitMlirOpts,
) error{OutOfMemory}!MlirFn {
    const frame = self._tracer.frameStart("emitMlir.emit");
    errdefer self._tracer.frameEnd(frame, "emitMlir.emit");

    const res_allocator = self.allocator();
    // Note: only temp allocations are done in the arena,
    // the other allocations are in the context allocator.
    var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    const tensor_count = meta.count(Tensor, args);

    const mlir_ctx = self.mlirCtx();
    const loc = mlir_ctx.location(@src());

    const locations = try arena.alloc(mlir.Location, tensor_count);
    @memset(locations, mlir.Location.unknown(mlir_ctx));
    var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
    meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
    stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
    const input_types = try arena.alloc(mlir.Type, tensor_count);
    for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);

    // Save/restore the tensor -> block-argument mapping so that nested
    // `emitMlir` calls don't clobber the caller's mapping.
    const og_block_args = self._block_args;
    defer {
        self._block_args.deinit(self.allocator());
        self._block_args = og_block_args;
    }
    // Reset the buffer -> assignement
    self._block_args = .{};

    // Note: this isn't stricly necessary. We call `countTensor` on `fn_res`.
    // But it forces user to have simpler function.
    const ReturnT = stdx.meta.FnResult(func);
    const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");

    // Those are returned to caller so we don't put them in the arena, but in the module allocator.
    const fn_res = try res_allocator.create(ReturnT);
    const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
    const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
    const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);

    var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
    {
        defer self.closeBlock(fn_body);

        try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count));
        const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
        std.debug.assert(assigned_args_count == tensor_count);

        // Run the user function with this context active, tracing ops into fn_body.
        fn_res.* = forward: {
            self.activate();
            defer self.deactivate();
            break :forward @call(.auto, func, args.*);
        };
        var fn_res_values: [out_tensor_count]mlir.Value = undefined;
        self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);

        const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
        fn_body.appendOperationRecursive(fn_ret);
    }

    const arg_attrs = try arena.alloc(AttributeList, tensor_count);
    @memset(arg_attrs, .{});
    const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
    @memset(res_attrs, .{});
    // Donation and sharding attributes only make sense on the entry point.
    if (opts.kind == .main) {
        self.addDonationsAttributes(arg_attrs, fn_res_donations);
        if (self._platform.sharding().num_partitions > 1) {
            self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
        }
    }

    const mlir_fn = dialect.func.func(self.mlirCtx(), .{
        .sym_name = opts.name,
        .args = input_types,
        .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
        .results = fn_res_types,
        .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
        .block = fn_body.block(),
        .location = loc,
    });

    self._tracer.frameEnd(frame, "emitMlir.emit");
    const canonicalize_frame = self._tracer.frameStart("emitMlir.canonicalize");
    defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
    self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
        error.InvalidMlir => {
            log.err("Failed to canonicalize invalid mlir: {}", .{mlir_fn.mlirFormatter(.{})});
            // user errors should have triggered a panic before we reach this.
            @panic("ZML generated invalid mlir. Please open a bug report");
        },
    };
    return .{
        .mlir_fn = mlir_fn,
        .name = opts.name,
        .num_args = @intCast(tensor_count),
        .res_tensors = fn_res,
        .res_types = fn_res_types,
        .res_shapes = fn_res_shapes,
        .res_donations = fn_res_donations,
    };
}
/// Given a list of donations mapping output buffers to input buffers,
/// generate donation attribute for each `n_args` input argument.
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
    // Fix: dropped the `n_donations` counter, which was incremented but never read.
    for (donations, 0..) |donation, index| {
        switch (donation) {
            .no_buffer => {},
            // This is an input buffer that has been returned,
            // but without explicitly calling `reuseBuffer`.
            // So we assume the intent was to return a new buffer.
            .input_buffer => {},
            .arg => |a| {
                // This will break the day we write another attribute before donation.
                // When the time comes, do a more fancy lookup here to check if an argument
                // is donated twice.
                stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {} has been donated twice ! To {} and to {}", .{ a, index, attributes[a].buffer[0] });
                attributes[a].appendAssumeCapacity(
                    mlir.NamedAttribute.init(
                        mlir.Identifier.get(self.mlirCtx(), "tf.aliasing_output"),
                        mlir.IntegerAttribute(.i32).init(self.mlirCtx(), @intCast(index)).as(mlir.Attribute).?,
                    ),
                );
                // log.debug("attribute: {}", .{attributes[a].constSlice()});
            },
        }
    }
}
2023-02-24 17:33:14 +00:00
// End-to-end check that `reuseBuffer` donations surface as `tf.aliasing_output`
// attributes in the emitted MLIR.
test addDonationsAttributes {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();
    const s = Shape.init(.{8}, .f16);
    const Local = struct {
        bias: Tensor,

        pub fn _fwd(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
            const x1 = zml.ops.call(self, ._inner, .{x});
            const x2 = zml.ops.call(self, ._inner, .{x1});
            return .{ x1.reuseBuffer(y), x2 };
        }

        pub fn _inner(self: @This(), x: Tensor) Tensor {
            const y = x.add(self.bias);
            return y.reuseBuffer(x);
        }
    };
    const model: Local = .{
        .bias = zml.Tensor{ ._shape = s, ._id = .{ .buffer_id = 0 } },
    };

    var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
    defer comp.deinit();
    var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
    const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });

    var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator);
    defer mlir_bytecode.deinit();
    try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})});

    // Check that the `x` input argument gives its buffer to the result tensor.
    // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
    try std.testing.expectEqual(3, f.num_args);
    // We should have two buffers being donated.
    const template = "tf.aliasing_output = {d} : i32";
    var buf = template.*;
    for (0..2) |i| {
        const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
        std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
            log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
            return err;
        };
    }
}
2023-03-21 10:50:39 +00:00
/// Builds the GSPMD sharding string attribute for `shape`
/// (e.g. "{replicated}" or "{devices=[4,1]<=[4]}").
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
    const mlir_ctx = self.mlirCtx();
    const num_partitions = self._platform.sharding().num_partitions;
    var sharding_str: std.BoundedArray(u8, 128) = .{};
    writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
    return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice());
}
/// Attaches the default layout-mode attribute and the GSPMD sharding attribute
/// to every function argument and result. No-op when sharding is disabled.
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
    const mlir_ctx = self.mlirCtx();
    if (!self._platform.compilation_options.sharding_enabled) return;
    const mhlo_default_layout = mlir.NamedAttribute.init(
        mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
        mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
    );
    // Arguments and results receive exactly the same pair of attributes,
    // so process both lists with one comptime-unrolled loop.
    inline for (.{ .{ arg_attrs, input_shapes }, .{ res_attrs, output_shapes } }) |pair| {
        const attr_list, const shapes = pair;
        for (attr_list, shapes) |*attr, shape| {
            attr.appendAssumeCapacity(mhlo_default_layout);
            const sharding_attr = self.getShardingAttr(shape);
            attr.appendAssumeCapacity(mlir.NamedAttribute.init(
                mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
                sharding_attr.asAttr(),
            ));
        }
    }
}
/// Writes the GSPMD sharding string for `shape` to `writer`:
/// "{replicated}" when nothing is sharded (or only one partition exists),
/// otherwise "{devices=[d0,d1,...]<=[num_partitions]}".
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
    const sharded_axes: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
    if (sharded_axes == 0 or num_partitions == 1) {
        return writer.writeAll("{replicated}");
    }
    try writer.writeAll("{devices=[");
    const rank = shape.rank();
    for (0..rank) |axis| {
        // Comma-separate entries (separator before each element but the first).
        if (axis > 0) try writer.writeByte(',');
        const devices_on_axis: u8 = if (shape._sharding_info[axis]) num_partitions else 1;
        try writer.print("{d}", .{devices_on_axis});
    }
    try writer.print("]<=[{d}]}}", .{num_partitions});
}
// Pins the exact sharding strings produced for replicated and sharded shapes.
test writeShardingRepresentation {
    var rule: [64]u8 = undefined;
    const x = Shape.init(.{ 16, 8 }, .f32);
    // By default tensors are replicated.
    {
        var fbs = std.io.fixedBufferStream(&rule);
        try writeShardingRepresentation(x, 4, fbs.writer());
        try std.testing.expectEqualStrings("{replicated}", fbs.getWritten());
    }
    // Shard along first axis.
    {
        var fbs = std.io.fixedBufferStream(&rule);
        try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer());
        try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten());
    }
    // Also shard along second axis.
    {
        var fbs = std.io.fixedBufferStream(&rule);
        try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer());
        try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten());
    }
}
2023-11-16 15:11:23 +00:00
/// Wraps each per-argument attribute list into an MLIR dictionary attribute,
/// allocating the result with `allocator_`.
fn finalizeAttributeList(allocator_: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
    const out = try allocator_.alloc(mlir.Attribute, attributes.len);
    for (out, 0..) |*slot, i| {
        slot.* = mlir.DictionaryAttribute.init(mlir_ctx, attributes[i].constSlice()).asAttr();
    }
    return out;
}
2023-01-02 14:28:25 +00:00
/// Generates an MLIR `func.call` of the given function.
/// If the function has not been seen yet, we generate MLIR for it,
/// in a independent function.
/// The main benefit of this is to generate MLIR that maps more closely
/// to the Zig code, but compilation speed stays similar.
///
/// Cache key is the function pointer plus a hash of the argument shapes,
/// so the same Zig function called with differently-shaped args emits
/// separate MLIR functions.
pub fn callFunc(
    self: *CompilationContext,
    func_name: [:0]const u8,
    comptime func: anytype,
    args: stdx.meta.FnArgs(func),
) error{OutOfMemory}!stdx.meta.FnResult(func) {
    var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
    defer arena_state.deinit();
    // This arena is used for allocations which won't outlive the function call,
    // but the function creation uses `self.allocator()` which we'll live for the duration of the compilation.
    const arena = arena_state.allocator();
    // first, do the "compile" and check the bytecode
    // the result of this will also have the correct tags of the result shapes
    const args_hash = hashArgs(args);
    const key: FnKey = .{ .fn_ptr = &func, .input_hash = args_hash };

    // Cache miss: emit the function body once and register it on the module.
    const function = self._fn_cache.get(key) orelse b: {
        // "main" keeps its name; other functions get a shape-hash suffix to
        // disambiguate several specializations of the same Zig function.
        const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
            try self.allocator().dupeZ(u8, func_name)
        else
            try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash });

        // Rewrite every tensor in `args` into a placeholder tensor carrying only
        // its shape and a fresh sequential arg id; the placeholders become the
        // block arguments of the emitted function.
        var arg_id: u16 = 0;
        var tensor_args: @TypeOf(args) = args;
        try meta.mapAlloc(struct {
            fn cb(arg_id_: *u16, x: Tensor) Tensor {
                const a = arg_id_.*;
                arg_id_.* += 1;
                return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
            }
        }.cb, arena, &arg_id, args, &tensor_args);

        const f = try self.emitMlir(
            func,
            &tensor_args,
            .{ .name = full_name },
        );
        self._module.getBody().appendOperation(f.mlir_fn);

        try self._fn_cache.putNoClobber(self.allocator(), key, f);
        break :b f;
    };
    const loc = self.mlirCtx().location(@src());
    // Gather the MLIR values of the *actual* (caller-side) arguments.
    const values = try arena.alloc(mlir.Value, function.num_args);
    self.extractValues(&args, values);

    // Gather the donation status of each caller-side argument so donations can
    // be forwarded through the call below.
    const donations = try arena.alloc(Tensor._Donation, function.num_args);
    meta.collectBuf(struct {
        pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
            return ctx.getValueAndDonation(x)[1];
        }
    }.cb, self, &args, donations);
    const op = dialect.func.call(self.mlirCtx(), @ptrCast(function.name), values, function.res_types, loc);
    // Create the result tensor object by combining the operand results,
    // as well as the registered shapes and donations.
    // Note: this assume res can be stack-allocated.
    var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*;
    const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
    var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, tensor: *Tensor) void {
            const i = ctx.index;
            ctx.index += 1;
            var new = Tensor.fromMlirValue(ctx.op.result(i));
            new._shape = ctx.function.res_shapes[i];
            // Translate callee-relative donations into caller-relative ones.
            new._donation = switch (ctx.function.res_donations[i]) {
                .no_buffer => .no_buffer,
                .arg => |input_arg| ctx.donations[input_arg],
                .input_buffer => .no_buffer, // user escaped the sandbox
            };
            tensor.* = new;
        }
    }).cb, &context, &res);
    std.debug.assert(context.index == op.numResults());
    return res;
}
/// Visit the given struct and recursively associate the `block` arguments with the `value` field of each encountered Tensor.
///
/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors
/// stored in the Module.
/// Caller need to allocate required memory in self._block_args.
pub fn mapBlockArguments(self: *CompilationContext, v: anytype, block: mlir.Block, start: usize) usize {
    const VisitState = struct {
        self: *CompilationContext,
        block: mlir.Block,
        index: usize,
    };
    var state: VisitState = .{ .self = self, .block = block, .index = start };
    meta.visit((struct {
        fn cb(ctx: *VisitState, tensor: *const Tensor) void {
            const arg_value = ctx.block.argument(ctx.index);
            // log.debug("mapping {} to arg {}", .{ tensor._id, ctx.index });
            const entry = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id);
            if (!entry.found_existing) {
                entry.value_ptr.* = .{ arg_value, .{ .arg = @intCast(ctx.index) } };
            } else {
                stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ entry.value_ptr.*[0], tensor, ctx.index, tensor._id });
            }
            ctx.index += 1;
        }
    }).cb, &state, v);
    return state.index;
}
/// Create tensor from the given shapes.
/// Each created tensor will receive a unique id, local to this CompilationContext.
pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator_: std.mem.Allocator, args_shapes: anytype) !ArgsT {
    const makeTensor = struct {
        fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
            const id = arg_id.*;
            arg_id.* += 1;
            return .{
                ._shape = shape,
                ._id = .{ .arg_id = id },
                ._donation = .input_buffer,
            };
        }
    }.tensorFromShape;
    var tensor_args: ArgsT = undefined;
    try meta.mapAlloc(makeTensor, allocator_, &self._unique_id, args_shapes, &tensor_args);
    return tensor_args;
}
/// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found.
/// All output slices must be pre-sized to exactly the number of tensors in `v`.
pub fn extractValuesAndTypes(self: *const CompilationContext, v: anytype, values: []mlir.Value, types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation) void {
    std.debug.assert(values.len == types.len);
    const Collector = struct {
        self: *const CompilationContext,
        index: usize = 0,
        values: []mlir.Value,
        types: []mlir.Type,
        shapes: []Shape,
        donations: []Tensor._Donation,

        fn cb(ctx: *@This(), tensor: *const Tensor) void {
            const value, const donation = ctx.self.getValueAndDonation(tensor.*);
            const i = ctx.index;
            ctx.values[i] = value;
            ctx.types[i] = value.getType();
            ctx.shapes[i] = tensor._shape;
            ctx.donations[i] = donation;
            ctx.index = i + 1;
        }
    };
    var collector = Collector{ .self = self, .values = values, .types = types, .shapes = shapes, .donations = donations };
    meta.visit(Collector.cb, &collector, v);
    std.debug.assert(collector.index == values.len);
}
/// Resolves the MLIR value and donation status for `tensor`.
/// Panics on an id that was never registered with this compilation context.
pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } {
    switch (tensor._id) {
        .buffer_id, .arg_id => {
            const res = self._block_args.get(tensor._id) orelse {
                log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id });
                @panic("Found unknown tensor id");
            };
            return .{ res[0], res[1] };
        },
        .mlir => |v| return .{ v, tensor._donation },
    }
}
2024-01-08 17:55:20 +00:00
/// Resolves only the MLIR value for `tensor`, discarding the donation info.
pub fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value {
    const value_and_donation = self.getValueAndDonation(tensor);
    return value_and_donation[0];
}
2023-01-02 14:28:25 +00:00
2023-07-21 09:01:01 +00:00
/// Visit `v` and collect the mlir.Value of every Tensor found into `values`.
/// `values` must be pre-sized to exactly the number of tensors in `v`.
pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void {
    meta.collectBuf(getValue, self, v, values);
}
} ;
/// Computes a stable hash of the compiled module for a given platform.
/// The platform name and PJRT API version are mixed in so a cached artifact
/// is never reused across backends or ABI changes.
fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
    var hasher = std.hash.XxHash64.init(0);
    var hasher_writer = xxHash64Writer(&hasher);
    const writer = hasher_writer.writer();
    // Hash the canonicalized IR, without debug information that can change across builds.
    module.op().print(writer, .{ .debug_info = false });
    // Note: before we where using module.op().writeBytecode(writer),
    // but it crashes on some inputs, notably for unused variables.
    // So we use the text representation of the mlir.
    // See https://github.com/zml/zml/issues/97.
    // Writes can't fail because we are writing to a hasher.
    writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable;
    const api_version = platform.pjrt_api.version();
    writer.writeInt(i64, api_version.major, .little) catch unreachable;
    writer.writeInt(i64, api_version.minor, .little) catch unreachable;
    return hasher.final();
}
2023-07-14 17:58:22 +00:00
/// Upper bound (400 MiB) used as a fallback read-buffer size when the
/// executable file's size cannot be determined via stat().
const max_pjrt_executable_size = 400 * 1024 * 1024;
2023-01-02 14:28:25 +00:00
2023-10-13 16:08:08 +00:00
/// Reads a previously serialized executable from `absolute_file` and loads it
/// into the platform's PJRT client. Scratch allocations come from `arena`.
fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable {
    const tracer = Tracer.init("ai.zml.load_exe");
    const frame = tracer.frameStart("pjrt load executable");
    defer tracer.frameEnd(frame, "pjrt load executable");

    const file = try std.fs.openFileAbsoluteZ(absolute_file, .{});
    defer file.close();

    // Fall back to a generous fixed capacity when the file size is unknown.
    const capacity = if (file.stat()) |stat| stat.size else |_| max_pjrt_executable_size;
    const buffer = try arena.alloc(u8, capacity);
    defer arena.free(buffer);

    const n_read = try file.readAll(buffer);
    return try platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, buffer[0..n_read]);
}
2023-01-02 14:28:25 +00:00
2023-10-13 16:08:08 +00:00
/// Serializes `loaded_executable` and writes it to `absolute_file`,
/// creating (or truncating) the file.
fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void {
    const file = try std.fs.createFileAbsoluteZ(absolute_file, .{});
    defer file.close();

    var executable = try loaded_executable.getExecutable(platform.pjrt_api);
    defer executable.deinit(platform.pjrt_api);

    var serialized = try executable.serialize(platform.pjrt_api);
    defer serialized.deinit();

    try file.writeAll(serialized.bytes);
}
2023-12-04 10:38:10 +00:00
/// Compiles an MLIR module into a loaded PJRT executable for `platform`,
/// wiring up sharding, device assignment and XLA debug/perf flags.
/// `xla_dump_to_` overrides the platform-level dump directory when set.
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
    const tracer = Tracer.init("ai.zml.compilation");
    const compile_frame = tracer.frameStart("pjrt compilation");
    defer tracer.frameEnd(compile_frame, "pjrt compilation");

    const sharding = platform.sharding();

    // NOTE(Corendos): Hack needed because Protobuf struct are not public.
    const DeviceAssignmentProto = @TypeOf(xla_pb.CompileOptionsProto.init().executable_build_options.?.device_assignment.?);
    var options: xla_pb.CompileOptionsProto = .{
        .executable_build_options = .{
            // -1: let PJRT pick the device ordinal.
            .device_ordinal = -1,
            .num_replicas = sharding.num_replicas,
            .num_partitions = sharding.num_partitions,
            .use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1,
            .device_assignment = .{
                .replica_count = sharding.num_replicas,
                .computation_count = sharding.num_partitions,
                // One computation per partition, each pinned to device id `i`.
                .computation_devices = blk: {
                    var computation_devices = try std.ArrayListUnmanaged(DeviceAssignmentProto.ComputationDevice).initCapacity(arena, sharding.num_partitions);
                    for (0..sharding.num_partitions) |i| {
                        var replica_device_ids = std.ArrayListUnmanaged(i64).initCapacity(arena, 1) catch unreachable;
                        replica_device_ids.appendAssumeCapacity(@intCast(i));
                        computation_devices.appendAssumeCapacity(.{ .replica_device_ids = replica_device_ids });
                    }
                    break :blk computation_devices;
                },
            },
        },
    };

    // Let the arena deinit, zig-protobuf deinit is very slow.
    // Reserve room up front: setFlag below uses appendAssumeCapacity.
    try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
    if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
        setFlag(&options, "xla_dump_to", xla_dump_to);
        setFlag(&options, "xla_dump_hlo_as_proto", true);
        if (platform.compilation_options.xla_dump_fusion_visualization) {
            setFlag(&options, "xla_dump_fusion_visualization", true);
        }
        if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
            setFlag(&options, "xla_dump_hlo_pass_re", re);
        }
    }
    switch (platform.target) {
        .cuda => {
            // NVIDIA recommends these settings
            // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
            setFlag(&options, "xla_gpu_enable_triton_gemm", false);
            setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
            setFlag(&options, "xla_gpu_enable_llvm_module_compilation_parallelism", true);
            setFlag(&options, "xla_gpu_enable_libnvptxcompiler", true);
            // setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
            // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
            // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
            // setFlag(&options, "xla_gpu_enable_custom_fusions", true);
            // setFlags(&options, "xla_gpu_enable_address_computation_fusion", true);
            // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
            // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
            // setFlag(&options, "xla_gpu_use_runtime_fusion", true);
        },
        .rocm => {
            // Disable Triton GEMM on ROCM. For some reason it's much, much slower when
            // enabled on CDNA and it's used on RDNA. Disable it altogether.
            setFlag(&options, "xla_gpu_enable_triton_gemm", false);
        },
        else => {},
    }
    const options_bytes = try options.encode(arena);

    const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
    errdefer loaded_executable.deinit();

    return loaded_executable;
}
2023-12-04 10:38:10 +00:00
/// Appends an XLA env-option override, choosing the proto field
/// (bool/int/double/string) from the Zig type of `value`.
/// Assumes `env_option_overrides` capacity was reserved by the caller.
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
    const override: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
        .bool => .{ .value = .{ .bool_field = value } },
        .int, .comptime_int => .{ .value = .{ .int_field = value } },
        .float, .comptime_float => .{ .value = .{ .double_field = value } },
        else => .{ .value = .{ .string_field = .{ .Const = value } } },
    };
    options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = override });
}
2023-10-13 16:08:08 +00:00
/// Visit the given struct and recursively counts the number of tensors found.
pub fn countTensors(v: anytype) usize {
    const Counter = struct {
        count: usize = 0,

        fn cb(self: *@This(), _: *const Tensor) void {
            self.count += 1;
        }
    };
    var counter: Counter = .{};
    meta.visit(Counter.cb, &counter, v);
    return counter.count;
}
/// Visit the given struct and recursively fill the `types` slice with the mlir.Type associated with encountered Tensor.
/// `types` must be pre-sized to exactly the number of tensors in `v`.
pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) void {
    const Filler = struct {
        index: usize = 0,
        mlir_ctx: mlir.Context,
        types: []mlir.Type,

        fn cb(self: *@This(), tensor: *const Tensor) void {
            self.types[self.index] = mlir.ext.mlirType(self.mlir_ctx, tensor.shape());
            self.index += 1;
        }
    };
    var filler = Filler{ .mlir_ctx = mlir_ctx, .types = types };
    meta.visit(Filler.cb, &filler, v);
    // Every slot must have been written exactly once.
    std.debug.assert(filler.index == types.len);
}
/// Visit the given struct and recursively associate the `block` arguments with the `value` field of each encountered Tensor.
///
/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors
/// stored in the Module.
fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
    const Assigner = struct {
        index: usize,
        block: mlir.Block,

        fn cb(self: *@This(), tensor: *Tensor) void {
            tensor._id = .{ .mlir = self.block.argument(self.index) };
            tensor._donation = .{ .arg = @intCast(self.index) };
            self.index += 1;
        }
    };
    var assigner = Assigner{ .block = block, .index = start };
    meta.visit(Assigner.cb, &assigner, v);
    return assigner.index;
}
2023-01-02 14:28:25 +00:00
/// Adapter exposing a `std.io.Writer` interface over an XxHash64 hasher:
/// every byte written is fed into the hash and writes never fail.
pub const XxHash64Writer = struct {
    hasher: *std.hash.XxHash64,

    /// Hashing cannot fail, hence the empty error set.
    pub const Error = error{};
    pub const Writer = std.io.Writer(*XxHash64Writer, Error, write);

    pub fn write(self: *XxHash64Writer, bytes: []const u8) Error!usize {
        self.hasher.update(bytes);
        // Report the full length as consumed so writeAll never loops.
        return bytes.len;
    }

    pub fn writer(self: *XxHash64Writer) Writer {
        return .{ .context = self };
    }
};
/// Convenience constructor for an `XxHash64Writer` wrapping `hasher`.
pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
    return XxHash64Writer{ .hasher = hasher };
}
2024-05-15 17:54:52 +00:00
/// Cache of already-emitted MLIR functions, keyed by Zig function identity
/// plus a hash of the call's arguments (see `hashArgs`).
pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
/// Cache key: the function pointer and a hash of its (tensor-shape-aware) arguments.
pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
2023-01-02 14:28:25 +00:00
// Checks that routing layer calls through the function cache (ops.call)
// produces the same numerical result as calling the layers directly.
test FnCache {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const Layer = struct {
        const Layer_ = @This();
        w: Tensor,
        b: Tensor,

        pub fn _fwd(self: Layer_, x: Tensor) Tensor {
            const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
            return wx.add(self.b.broad(wx.shape())).relu();
        }
    };

    const NN = struct {
        const NN_ = @This();
        layers: [3]Layer,

        // Cached path: each layer call goes through ops.call / the FnCache.
        pub fn _fwd(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            for (self.layers) |layer| {
                x = ops.call(layer, ._fwd, .{x});
            }
            return x;
        }

        // Reference path: plain direct calls, no caching involved.
        pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            for (self.layers) |layer| {
                x = layer._fwd(x);
            }
            return x;
        }
    };
    const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
    const nn: zml.Bufferized(NN) = .{
        .layers = .{
            .{
                .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
                .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
            },
            .{
                .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
                .b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
            },
            // third layer is different
            .{
                .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
                .b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
            },
        },
    };

    const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
    const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
    try zml.testing.expectClose(expected, res, 1e-4);
}
2024-05-15 17:54:52 +00:00
// Checks FnCache behavior when the forward function returns a mix of tensors
// and plain integers: same-shaped layers hit the cache (the function body is
// not re-traced), while a differently-shaped layer forces a fresh trace.
test "FnCache with mixed integer/tensor" {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();
    const Layer = struct {
        const Layer_ = @This();
        // Global trace counter; incremented once per actual trace of _fwd.
        var num_call: u32 = 0;
        w: Tensor,
        pub fn _fwd(self: Layer_, x: Tensor) struct { Tensor, usize } {
            const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
            // Note: this is for testing only, it's a bad idea to mutate global state
            // from a forward function because it can mess with caching.
            num_call += 1;
            return .{ wx.addConstant(num_call), num_call };
        }
    };
    const NN = struct {
        const NN_ = @This();
        layers: [3]Layer,
        pub fn _fwd(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            var y: usize = 0;
            x, y = ops.call(self.layers[0], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 1);
            std.debug.assert(y == 1);
            // Here we call a second time but since first two layers have the same shape,
            // We hit the function cache, and "num_call" is not incremented.
            x, y = ops.call(self.layers[1], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 1);
            std.debug.assert(y == 1);
            x, y = ops.call(self.layers[2], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 2);
            std.debug.assert(y == 2);
            return x;
        }
        // Reference path: replays the biases the cached path is expected to bake in.
        pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            for (self.layers, &[_]u32{ 1, 1, 2 }) |layer, bias| {
                const wx = layer.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
                x = wx.addConstant(bias);
            }
            return x;
        }
    };
    const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
    const nn: zml.Bufferized(NN) = .{
        .layers = .{
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
            // third layer has different shape
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
        },
    };
    const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
    const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
    try zml.testing.expectClose(expected, res, 1e-4);
}
2023-01-02 14:28:25 +00:00
/// Hashes a full argument struct, recursing through pointers.
/// Tensors are hashed by shape metadata only (see `hashShape`), never by data.
pub fn hashArgs(mod: anytype) u64 {
    var hasher = std.hash.Wyhash.init(0);
    tensorAwareHash(&hasher, mod, .DeepRecursive);
    return hasher.final();
}
2023-10-13 16:08:08 +00:00
/// Hashes a Shape: its dims, dtype, sharding info and tag identities.
pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
    // Note: if we enforced 0-init dims then we could hash dims instead.
    hashArray(hasher, shape.dims(), .Shallow);
    hash(hasher, shape._dtype, .Shallow);
    hash(hasher, shape._sharding_info, .Shallow);
    // Tags are hashed by pointer identity — assumes equal tags share a stable
    // pointer (interned strings); verify against Shape's tag storage.
    for (shape.tags()) |tag| {
        hash(hasher, @intFromPtr(tag), .Shallow);
    }
}
const HashStrategy = std.hash.Strategy;
const tensorAwareHash = hash; // alias for when "hash" is ambiguous

/// Provides generic hashing for any eligible type.
/// Strategy is provided to determine if pointers should be followed or not.
/// Tensors and Shapes hash by shape metadata only (see `hashShape`), never by data.
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: HashStrategy) void {
    const Key = @TypeOf(key);
    if (Key == Tensor) return hashShape(hasher, key.shape());
    if (Key == Shape) return hashShape(hasher, key);
    // Fast path: byte-hash anything with a unique memory representation.
    if (strat == .Shallow and std.meta.hasUniqueRepresentation(Key)) {
        hasher.update(std.mem.asBytes(&key));
        return;
    }
    switch (@typeInfo(Key)) {
        .noreturn, .@"opaque", .undefined, .null, .comptime_float, .comptime_int, .type, .enum_literal, .frame, .void => return,

        // Help the optimizer see that hashing an int is easy by inlining!
        // TODO Check if the situation is better after #561 is resolved.
        .int => |int| switch (int.signedness) {
            .signed => hash(hasher, @as(@Type(.{ .int = .{
                .bits = int.bits,
                .signedness = .unsigned,
            } }), @bitCast(key)), strat),
            .unsigned => {
                if (std.meta.hasUniqueRepresentation(Key)) {
                    hasher.update(std.mem.asBytes(&key));
                } else {
                    // Take only the part containing the key value, the remaining
                    // bytes are undefined and must not be hashed!
                    const byte_size = comptime std.math.divCeil(comptime_int, @bitSizeOf(Key), 8) catch unreachable;
                    hasher.update(std.mem.asBytes(&key)[0..byte_size]);
                }
            },
        },
        // Note: contrary to Zig we accept hashing floats.
        // Typically the float we are going to hash here are hyperparameters,
        // and not the result of an operation, so bytes should be the same everytime.
        .float => hasher.update(std.mem.asBytes(&key)),
        .bool => hash(hasher, @intFromBool(key), strat),
        .@"enum" => hash(hasher, @intFromEnum(key), strat),
        .error_set => hash(hasher, @intFromError(key), strat),
        .@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
        .pointer => |info| switch (info.size) {
            // Fix: HashStrategy is std.hash.Strategy, whose fields are
            // .Shallow/.Deep/.DeepRecursive (as used everywhere else in this
            // file); the lowercase prongs previously written here did not
            // match that enum.
            .one => switch (strat) {
                .Shallow => hash(hasher, @intFromPtr(key), .Shallow),
                .Deep => hash(hasher, key.*, .Shallow),
                .DeepRecursive => switch (@typeInfo(info.child)) {
                    .@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
                    else => hash(hasher, key.*, .DeepRecursive),
                },
            },
            .slice => {
                switch (strat) {
                    .Shallow => hash(hasher, @intFromPtr(key.ptr), .Shallow),
                    .Deep => hashArray(hasher, key, .Shallow),
                    .DeepRecursive => hashArray(hasher, key, .DeepRecursive),
                }
                hash(hasher, key.len, .Shallow);
            },
            .many,
            .c,
            => switch (strat) {
                .Shallow => hash(hasher, @intFromPtr(key), .Shallow),
                else => @compileError(
                    \\ unknown-length pointers and C pointers cannot be hashed deeply.
                    \\ Consider providing your own hash function.
                ),
            },
        },
        .optional => if (key) |k| hash(hasher, k, strat),

        .array => hashArray(hasher, key, strat),

        .vector => |info| {
            if (std.meta.hasUniqueRepresentation(Key)) {
                hasher.update(std.mem.asBytes(&key));
            } else {
                comptime var i = 0;
                inline while (i < info.len) : (i += 1) {
                    hash(hasher, key[i], strat);
                }
            }
        },
        .@"struct" => |info| {
            inline for (info.fields) |field| {
                // We reuse the hash of the previous field as the seed for the
                // next one so that they're dependant.
                hash(hasher, @field(key, field.name), strat);
            }
        },
        .@"union" => |info| {
            if (info.tag_type) |tag_type| {
                const tag = std.meta.activeTag(key);
                hash(hasher, tag, strat);
                inline for (info.fields) |field| {
                    if (@field(tag_type, field.name) == tag) {
                        if (field.type != void) {
                            hash(hasher, @field(key, field.name), strat);
                        }
                        // TODO use a labelled break when it does not crash the compiler. cf #2908
                        // break :blk;
                        return;
                    }
                }
                unreachable;
            } else @compileError("cannot hash untagged union type: " ++ @typeName(Key) ++ ", provide your own hash function");
        },
        .error_union => blk: {
            const payload = key catch |err| {
                hash(hasher, err, strat);
                break :blk;
            };
            hash(hasher, payload, strat);
        },
    }
}
/// Feeds every element of `key`, in order, into the hasher with the given strategy.
fn hashArray(hasher: anytype, key: anytype, comptime strat: HashStrategy) void {
    for (key) |item| {
        hash(hasher, item, strat);
    }
}