const std = @import("std");
const stdx = @import("stdx");

const buffer = @import("buffer.zig");
const helpers = @import("helpers.zig");
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const module = @import("module.zig");

const Buffer = buffer.Buffer;
const CompilationContext = module.CompilationContext;
const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal);
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;

const dialect = struct {
    const stablehlo = @import("mlir/dialects").stablehlo;
};
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/tensor");

test {
    std.testing.refAllDecls(@This());
}

/// Generate an MLIR call to the given member function with the given tensors.
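///
/// A minimal usage sketch (the `layer` value and its `forward` member function are
/// hypothetical, not defined in this file):
///
///     const y = call(layer, .forward, .{x});
///     // equivalent to: const y = layer.forward(x);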
pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(stdx.meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
    // TODO: this should use `self.getContext().callFunc(self, args)`
    return @call(.auto, @field(@TypeOf(self), @tagName(func)), .{self} ++ args);
}
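
/// Emits a `stablehlo.while` loop built from the given condition and body functions.
///
/// `cond_fn` and `body_fn` must take the same arguments: `blkctx` followed by the
/// loop-carried tensors. `body_fn` returns the updated loop-carried tensors, and the
/// final values are returned once `cond_fn` evaluates to false.
///
/// A minimal sketch, modeled on the "simple while" test below (`CountInts` provides
/// the `hasNext`/`next` member functions):
///
///     const last_i, const sum = while_(CountInts.hasNext, CountInts.next, self, .{ i0, sum0 });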
pub fn while_(
    comptime cond_fn: anytype,
    comptime body_fn: anytype,
    blkctx: BlockSign(body_fn).BlkCtx,
    inputs: BlockSign(body_fn).Args,
) BlockSign(body_fn).Return {
    const CondS = comptime BlockSign(cond_fn);
    const BodyS = comptime BlockSign(body_fn);
    if (CondS.Args != BodyS.Args) {
        @compileError("cond_fn and body_fn signatures don't match! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
    }
    const ctx = CompilationContext.current();

    const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
    const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);

    var input_values: [BodyS.nIn]mlir.Value = undefined;
    ctx.extractValues(&inputs, &input_values);
    const loc = ctx.mlirCtx().location(@src());
    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.while", .{
        .variadic_operands = &.{&input_values},
        .result_type_inference = true,
        .blocks = &.{ cond_block, body_block },
        // We can't verify right away, because the weights captured by the while haven't been added yet.
        .verify = false,
        .location = loc,
    });

    return fromMlirOperationWithTags(op, body_res);
}
test " simple while " {
const CountInts = struct {
step : Tensor ,
end : Tensor ,
const CountInts = @This ( ) ;
pub fn hasNext ( self : CountInts , i : Tensor , sum : Tensor ) Tensor {
_ = sum ;
return i . cmp ( . LT , self . end ) ;
}
pub fn next ( self : CountInts , i : Tensor , sum : Tensor ) [ 2 ] Tensor {
const r1 = i . add ( self . step ) ;
const r2 = sum . add ( i ) ;
return . { r1 , r2 } ;
}
pub fn forward ( self : CountInts , init_i : Tensor , init_sum : Tensor ) [ 2 ] Tensor {
const x = init_i . scale ( 2 ) ;
return while_ ( CountInts . hasNext , CountInts . next , self , . { x , init_sum } ) ;
}
pub fn zigForward ( step : i64 , end : i64 , init_i : i64 , init_sum : i64 ) [ 2 ] i64 {
const x = init_i * 2 ;
var i = x ;
var sum = init_sum ;
while ( i < end ) {
const r1 = i + step ;
const r2 = sum + i ;
i , sum = . { r1 , r2 } ;
}
return . { i , sum } ;
}
} ;
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const init_i = try zml . Buffer . fromSlice ( platform , . { } , & [ _ ] i64 { 0 } ) ;
const init_sum = try zml . Buffer . fromSlice ( platform , . { } , & [ _ ] i64 { 0 } ) ;
const counter = . {
. step = try zml . Buffer . fromSlice ( platform , . { } , & [ _ ] i64 { 1 } ) ,
. end = try zml . Buffer . fromSlice ( platform , . { } , & [ _ ] i64 { 10 } ) ,
} ;
const res0 , const res1 = try zml . testing . compileAndCall ( platform , CountInts . forward , . { counter , init_i , init_sum } ) ;
const last_i = try res0 . getValue ( i64 ) ;
const sum = try res1 . getValue ( i64 ) ;
try std . testing . expectEqual ( 10 , last_i ) ;
try std . testing . expectEqual ( 45 , sum ) ;
try std . testing . expectEqual ( . { 10 , 45 } , CountInts . zigForward ( 1 , 10 , 0 , 0 ) ) ;
}
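
/// Emits a `stablehlo.reduce` over the given axes, using `body_fn` as the combining
/// function. Unlike raw `stablehlo.reduce`, the reduced axes are kept with size 1,
/// so tags are preserved (see the broadcast logic below).
///
/// A minimal sum-reduction sketch (illustrative only; `x` is a hypothetical tensor):
///
///     const summed = reduce(Tensor.add, x, Tensor.scalar(0, x.dtype()), &.{0});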
pub fn reduce(
    comptime body_fn: anytype,
    inputs: stdx.meta.FnParam(body_fn, 0),
    inits: stdx.meta.FnParam(body_fn, 0),
    axes: []const i64,
) BlockSignNoCtx(body_fn).Return {
    // TODO: actualAxes
    const BodyS = comptime BlockSignNoCtx(body_fn);
    comptime {
        if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function needs to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
    }
    const ctx = CompilationContext.current();
    const N = comptime @divExact(BodyS.nIn, 2);
    var input_values: [N]mlir.Value = undefined;
    ctx.extractValues(&inputs, &input_values);
    var init_values: [N]mlir.Value = undefined;
    ctx.extractValues(&inits, &init_values);

    const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });

    const loc = ctx.mlirCtx().location(@src());
    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce", .{
        .variadic_operands = &.{ &input_values, &init_values },
        .result_type_inference = true,
        .blocks = &.{body_block},
        .attributes = &.{
            .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), axes).as(mlir.Attribute).? },
        },
        // We can't verify right away, because the weights captured by the reduce haven't been added yet.
        .verify = false,
        .location = loc,
    });

    // `stablehlo.reduce` drops the reduced axes. We want to avoid that so tags propagate,
    // so we broadcast each output of `stablehlo.reduce` back to the input rank, keeping
    // the reduced axes with size 1.
    // To that end, we initialize `result` to `inputs`, then use `meta.visit`
    // to find each mlir.Value, broadcasting it before creating the final Tensor struct.
    var broadcasting_axes: std.BoundedArray(i64, Tensor.MAX_RANK) = .{};
    for (0..Tensor.MAX_RANK) |i| {
        if (std.mem.indexOfScalar(i64, axes, @intCast(i)) == null) {
            broadcasting_axes.append(@intCast(i)) catch unreachable;
        }
    }
    var res: BodyS.Return = inputs;
    const LocalContext = struct {
        axes: []const i64,
        broadcasting_axes: []const i64,
        n_reduced: u8,
        op: mlir.Operation,
        loc: mlir.Location,
        index: usize = 0,
    };
    var local_context = LocalContext{
        .axes = axes,
        .broadcasting_axes = broadcasting_axes.constSlice(),
        .n_reduced = @intCast(axes.len),
        .op = op,
        .loc = loc,
    };
    meta.visit((struct {
        fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
            const val = inner_ctx.op.result(inner_ctx.index);
            // Compute the target reduced shape (reduced axes kept with size 1).
            var reduced_shape = tensor.shape();
            for (inner_ctx.axes) |a| {
                reduced_shape = reduced_shape.setDim(a, 1);
            }
            const mlir_ctx = CompilationContext.current().mlirCtx();
            const broad_val = dialect.stablehlo.broadcast_in_dim(
                mlir_ctx,
                val,
                inner_ctx.broadcasting_axes[0 .. tensor.rank() - inner_ctx.n_reduced],
                mlir.ext.RankedTensorType.fromShape(mlir_ctx, reduced_shape).as(mlir.Type).?,
                inner_ctx.loc,
            );
            tensor.* = Tensor._result(reduced_shape, broad_val.result(0));
            inner_ctx.index += 1;
        }
    }).cb, &local_context, &res);
    assert(local_context.index == op.numResults());
    return res;
}

pub const ReduceWindowOpts = struct {
    // TODO: replace with Shape
    window_dimensions: []const i64,
    window_strides: []const i64,
    base_dilations: []const i64,
    window_dilations: []const i64,
    padding: []const [2]i64,
};
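
/// Emits a `stablehlo.reduce_window`, e.g. for pooling, using `body_fn` as the
/// combining function.
///
/// A minimal 2x2 max-pooling sketch (illustrative only; `x` and the `Tensor.maximum`
/// combinator are assumptions, not defined in this file):
///
///     const pooled = reduceWindow(Tensor.maximum, x, Tensor.scalar(-std.math.inf(f32), .f32), .{
///         .window_dimensions = &.{ 2, 2 },
///         .window_strides = &.{ 2, 2 },
///         .base_dilations = &.{ 1, 1 },
///         .window_dilations = &.{ 1, 1 },
///         .padding = &.{ .{ 0, 0 }, .{ 0, 0 } },
///     });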
pub fn reduceWindow(
    comptime body_fn: anytype,
    inputs: stdx.meta.FnParam(body_fn, 0),
    inits: stdx.meta.FnParam(body_fn, 0),
    opts: ReduceWindowOpts,
) stdx.meta.FnResult(body_fn) {
    const BodyS = comptime BlockSignNoCtx(body_fn);
    comptime {
        if (BodyS.Return != @TypeOf(inputs)) @compileError("reduceWindow body function needs to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
    }
    const ctx = CompilationContext.current();

    const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });

    const N = comptime @divExact(BodyS.nIn, 2);
    var input_values: [N]mlir.Value = undefined;
    ctx.extractValues(&inputs, &input_values);
    var init_values: [N]mlir.Value = undefined;
    ctx.extractValues(&inits, &init_values);
    const loc = ctx.mlirCtx().location(@src());

    // `padding` is encoded as a dense [n, 2] tensor of (low, high) pairs.
    const pad_shape = mlir.RankedTensorType.init(
        &.{ @intCast(opts.padding.len), 2 },
        mlir.ext.Type.fromDType(ctx.mlirCtx(), .i64),
    ).as(mlir.Type).?;

    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.reduce_window", .{
        .variadic_operands = &.{ input_values[0..], init_values[0..] },
        .result_type_inference = true,
        .blocks = &.{body_block},
        .attributes = &.{
            .{ "window_dimensions", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dimensions).as(mlir.Attribute).? },
            .{ "window_strides", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_strides).as(mlir.Attribute).? },
            .{ "base_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.base_dilations).as(mlir.Attribute).? },
            .{ "window_dilations", mlir.DenseArrayAttribute(.i64).init(ctx.mlirCtx(), opts.window_dilations).as(mlir.Attribute).? },
            .{ "padding", mlir.DenseIntOrFPElementsAttribute(.i64).init(pad_shape, std.mem.sliceAsBytes(opts.padding)).as(mlir.Attribute).? },
        },
        .location = loc,
    });

    return fromMlirOperationWithTags(op, inputs);
}

/// Runs the given function for several steps and returns a stack of each step's output.
/// The step outputs are stacked along a new first axis.
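///
/// A minimal sketch, modeled on the `for_` test below (`Squares.sq` takes the block
/// context and a scalar step index, and returns that step's output):
///
///     // Stacks 10 step outputs of `sq` along a new leading axis.
///     const squares = for_(Squares.sq, .{}, .{10});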
pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: anytype) BlockSign(func).Return {
    const num_steps: u32, const step_tag = blk: {
        const dims, const tags = Shape.parseDimensions(num_steps_);
        stdx.debug.assert(dims.len == 1, "zml.for_ only supports one num_step, received: {any}", .{num_steps_});
        break :blk .{ @intCast(dims.get(0)), tags.get(0) };
    };
    const S = comptime BlockSign(func);
    const ForBlk = struct {
        blk_ctx: S.BlkCtx,
        step_tag: @TypeOf(step_tag), // This is a Shape.Tag, but we'd rather keep that type private.
        num_steps: u32,

        const Self = @This();

        fn next(self: Self, res: S.Return, idx: Tensor) struct { S.Return, Tensor } {
            const step_res = @call(.auto, func, .{ self.blk_ctx, idx });
            var buf: [@sizeOf(S.Return) * 2]u8 = undefined;
            var fba = std.heap.FixedBufferAllocator.init(&buf);
            return .{
                meta.zip(updateResBuffer, fba.allocator(), &[_]S.Return{ res, step_res }, .{idx}) catch unreachable,
                idx.addConstant(1),
            };
        }

        fn done(self: Self, res: S.Return, idx: Tensor) Tensor {
            _ = res;
            return idx.cmp(.LT, Tensor.scalar(self.num_steps, idx.dtype()));
        }

        fn updateResBuffer(inputs: []const Tensor, idx: Tensor) Tensor {
            stdx.debug.internalAssert(inputs.len == 2, "too many tensors", .{});
            const res, const step_res = inputs[0..2].*;
            return res.dynamicUpdateSlice1d(step_res.insertAxes(0, .{._}), 0, idx);
        }

        /// Prepares a buffer big enough to store the results of all steps.
        fn prep(self: Self, x: Tensor) Tensor {
            var shape = x.shape();
            shape._dims.insert(0, self.num_steps) catch unreachable;
            shape._tags.insert(0, self.step_tag) catch unreachable;
            return Tensor.constant(shape, x.dtype().zero());
        }

        fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
            var shape = x.shape();
            shape._dims.insert(0, 1) catch unreachable;
            shape._tags.insert(0, tag_) catch unreachable;
            return x.reshape(shape);
        }
    };

    // This first step won't appear in the generated MLIR;
    // it's only used to infer the output shapes.
    const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
    log.debug("for_ first_step: {}", .{first_step});

    const allocator = CompilationContext.current()._allocator;

    // Optimize for small step counts: unroll instead of emitting a `stablehlo.while`.
    if (num_steps == 1) {
        var res = first_step;
        meta.mapAlloc(ForBlk.wrapFirstStep, allocator, step_tag, first_step, &res) catch unreachable;
        return res;
    }
    if (num_steps <= 4) {
        var steps: [4]S.Return = undefined;
        steps[0] = first_step;
        for (1..num_steps) |i| {
            steps[i] = @call(.auto, func, .{ blk_ctx, Tensor.scalar(i, .i32) });
        }
        const res = meta.zip(Tensor.stack, allocator, steps[0..num_steps], .{ 0, step_tag }) catch unreachable;
        return res;
    }

    const for_blk: ForBlk = .{ .blk_ctx = blk_ctx, .step_tag = step_tag, .num_steps = num_steps };
    var result_buffers: @TypeOf(first_step) = undefined;
    // Note: the return type isn't an error union, so `try` can't be used here.
    meta.mapAlloc(ForBlk.prep, allocator, for_blk, first_step, &result_buffers) catch unreachable;
    return while_(
        ForBlk.done,
        ForBlk.next,
        for_blk,
        .{
            result_buffers,
            Tensor.scalar(0, .i32),
        },
    )[0];
}

test for_ {
    const Squares = struct {
        const Squares = @This();

        pub fn sq(self: Squares, i: Tensor) Tensor {
            _ = self;
            const f = i.convert(.f32);
            return f.mul(f);
        }

        pub fn forward(num_steps: u63) Tensor {
            return for_(Squares.sq, .{}, .{num_steps});
        }
    };
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    // Just one baby step.
    {
        const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{1});
        try zml.testing.expectEqualShapes(Shape.init(.{1}, .f32), squares.shape());
        try std.testing.expectEqual(0, try squares.getValue(f32));
    }
    // Wow, 4 in a row!
    {
        const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{4});
        try zml.testing.expectEqualShapes(Shape.init(.{4}, .f32), squares.shape());
        try std.testing.expectEqual([_]f32{ 0, 1, 4, 9 }, try squares.getValue([4]f32));
    }
    // AGI is coming: computing 10 squares like it's nothing.
    {
        const squares = try zml.testing.compileAndCall(platform, Squares.forward, .{10});
        try zml.testing.expectEqualShapes(Shape.init(.{10}, .f32), squares.shape());
        try std.testing.expectEqual(
            [_]f32{ 0, 1, 4, 9, 16, 25, 36, 49, 64, 81 },
            try squares.getValue([10]f32),
        );
    }
}
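
/// Variant of `if_` that reads both branches from a single `Closure` type,
/// using its `then` and `else_` declarations.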
pub fn if_2(pred: Tensor, comptime Closure: type, blkctx: BlockSignNoArgs(@field(Closure, "then")).BlkCtx) BlockSignNoArgs(@field(Closure, "then")).Return {
    return if_(pred, @field(Closure, "then"), @field(Closure, "else_"), blkctx);
}
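
/// Emits a `stablehlo.if` that selects between two branches based on `pred`.
///
/// Both branch functions take only `blkctx` and must return values of the same shape.
/// A minimal sketch, modeled on the "if" test below (`condTrue`/`condFalse` are the
/// branch functions, `pred` is a boolean tensor):
///
///     const result = if_(pred, condTrue, condFalse, .{ a, b });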
pub fn if_(
    pred: Tensor,
    comptime true_branch_fn: anytype,
    comptime false_branch_fn: anytype,
    blkctx: BlockSignNoArgs(true_branch_fn).BlkCtx,
) BlockSignNoArgs(true_branch_fn).Return {
    const TrueBlockSignature = comptime BlockSignNoArgs(true_branch_fn);
    const FalseBlockSignature = comptime BlockSignNoArgs(false_branch_fn);
    if (TrueBlockSignature.Return != FalseBlockSignature.Return) {
        @compileError("true_branch_fn and false_branch_fn return types don't match! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
    }
    const ctx = CompilationContext.current();

    const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
    const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
    stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects the true and false branches to produce outputs of the same shape, but got true={} and false={}", .{ true_branch_res, false_branch_res });

    const loc = ctx.mlirCtx().location(@src());
    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
        .operands = &.{pred.value()},
        .result_type_inference = true,
        .blocks = &.{ true_branch_block, false_branch_block },
        // We can't verify right away, because the weights captured by the if haven't been added yet.
        .verify = false,
        .location = loc,
    });

    return fromMlirOperationWithTags(op, true_branch_res);
}
test " if " {
const zml = @import ( " zml.zig " ) ;
const platform = zml . testing . env ( ) ;
const allocator = std . testing . allocator ;
const IfMod = struct {
pub fn forward ( pred : Tensor , a : Tensor , b : Tensor ) Tensor {
const result = if_ ( pred . convert ( . bool ) , condTrue , condFalse , . { a , b } ) ;
return result ;
}
pub fn condTrue ( a : Tensor , b : Tensor ) Tensor {
return a . matmul ( b ) ;
}
pub fn condFalse ( a : Tensor , b : Tensor ) Tensor {
return b . matmul ( a ) ;
}
} ;
{
const pred = Shape . init ( . { } , . i32 ) ;
const a = Shape . init ( . { 4 , 4 } , . f32 ) ;
const b = Shape . init ( . { 4 , 4 } , . f32 ) ;
const mod = try zml . compileFn ( allocator , IfMod . forward , . { pred , a , b } , platform ) ;
defer mod . deinit ( ) ;
}
}
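
/// Emits a `stablehlo.sort` of `inputs` along `dimension`, ordered by `comp_fn`.
///
/// `comp_fn` receives two scalars per input tensor and returns a boolean tensor that
/// is true when the left element should come first. A minimal descending-order sketch
/// (illustrative only; `greaterThan` and `x` are hypothetical):
///
///     fn greaterThan(_: void, a: Tensor, b: Tensor) Tensor {
///         return a.cmp(.GT, b);
///     }
///     const sorted = sort(greaterThan, {}, .{x}, 0, true);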
pub fn sort(
    comptime comp_fn: anytype,
    blkctx: BlockSign(comp_fn).BlkCtx,
    inputs: [@divExact(BlockSign(comp_fn).nIn, 2)]Tensor,
    dimension: i64,
    is_stable: bool,
) [@divExact(BlockSign(comp_fn).nIn, 2)]Tensor {
    const BodyS = comptime BlockSign(comp_fn);
    var inits: BlockSign(comp_fn).Args = undefined;
    inline for (0..@divExact(BlockSign(comp_fn).nIn, 2)) |i| {
        const arg_shape = Shape.init(.{}, inputs[i].dtype());
        // Note: the id doesn't matter because makeBlock will correctly fill it.
        inits[i * 2] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
        inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
    }
    const ctx = CompilationContext.current();

    const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);

    var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
    ctx.extractValues(&inputs, &input_values);
    const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "sort(dimension={d}, is_stable={})", .{ dimension, is_stable });
    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.sort", .{
        .variadic_operands = &.{&input_values},
        .result_type_inference = true,
        .blocks = &.{block},
        .attributes = &.{
            .{ "dimension", mlir.IntegerAttribute(.i64).init(ctx.mlirCtx(), dimension).as(mlir.Attribute).? },
            .{ "is_stable", mlir.BoolAttribute.init(ctx.mlirCtx(), is_stable).as(mlir.Attribute).? },
        },
        .location = loc,
    });
    var res: [@divExact(BlockSign(comp_fn).nIn, 2)]Tensor = undefined;
    inline for (0..@divExact(BlockSign(comp_fn).nIn, 2)) |i| {
        res[i] = Tensor._result(inputs[i].shape(), op.result(i));
    }
    return res;
}
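
/// Comptime description of a block function's signature: its context type, argument
/// types, full argument tuple, return type, and how many tensors flow in (`nIn`) and
/// out (`nOut`) of the generated MLIR block.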
pub const BlockSignature = struct {
    Fn: type,
    BlkCtx: type,
    Args: type,
    FullArgs: type,
    Return: type,
    nIn: usize,
    nOut: usize,

    pub inline fn blkArgs(self: BlockSignature, blk_ctx: self.BlkCtx, args: self.Args) self.FullArgs {
        if (self.BlkCtx == void) return args;
        if (self.Args == void) return blk_ctx;
        return .{blk_ctx} ++ args;
    }
};

const BlockType = enum { default, no_ctx, no_args };

pub inline fn BlockSign(comptime func: anytype) BlockSignature {
    return _BlockSign(func, .default);
}

pub inline fn BlockSignNoCtx(comptime func: anytype) BlockSignature {
    return _BlockSign(func, .no_ctx);
}

pub inline fn BlockSignNoArgs(comptime func: anytype) BlockSignature {
    return _BlockSign(func, .no_args);
}

pub fn fnInfo(comptime func: anytype) std.builtin.Type.Fn {
    if (@TypeOf(func) == type) {
        if (@typeInfo(func) == .Struct and @hasDecl(func, "forward")) {
            return fnInfo(func.forward);
        }
        @compileError("Given type doesn't have a forward function: " ++ @typeName(func));
    }
    const type_info = @typeInfo(@TypeOf(func));
    const err_msg = "`func` must be a function and return one or more `Tensor`. Got: ";
    if (type_info != .Fn or type_info.Fn.return_type == null) {
        @compileError(err_msg ++ @typeName(@TypeOf(func)));
    }
    return type_info.Fn;
}

fn _BlockSign(comptime func: anytype, blk_type: BlockType) BlockSignature {
    const fn_info = fnInfo(func);
    const err_msg = "`func` must be a function and return one or more `Tensor`. Got: ";
    var full_args: [fn_info.params.len]type = undefined;
    const arg_start = switch (blk_type) {
        .default => 1,
        .no_ctx => 0,
        .no_args => fn_info.params.len,
    };
    var n_tensors: usize = 0;
    inline for (fn_info.params, 0..) |arg, i| {
        const ArgType = if (arg.type) |T| T else @compileError(err_msg ++ @typeName(@TypeOf(func)));
        full_args[i] = ArgType;
        if (i >= arg_start) {
            n_tensors += staticCountTensors(ArgType) orelse @compileError("Can't use " ++ @typeName(ArgType) ++ " in an MLIR function, because it has a variable number of tensors");
        }
    }
    const FullArgs = std.meta.Tuple(&full_args);
    const BlkCtx = switch (blk_type) {
        .default => full_args[0],
        .no_ctx => void,
        .no_args => FullArgs,
    };
    const Args = switch (blk_type) {
        .default => std.meta.Tuple(full_args[1..]),
        .no_ctx => FullArgs,
        .no_args => void,
    };
    return .{
        .Fn = @TypeOf(func),
        .BlkCtx = BlkCtx,
        .Args = Args,
        .FullArgs = FullArgs,
        .Return = fn_info.return_type.?,
        .nIn = n_tensors,
        .nOut = staticCountTensors(fn_info.return_type.?) orelse @compileError("Can't use " ++ @typeName(fn_info.return_type.?) ++ " in an MLIR function, because it has a variable number of tensors"),
    };
}

pub fn staticIsOnlyTensors(comptime T: type) bool {
    if (T == Tensor) return true;
    return switch (@typeInfo(T)) {
        .Array => |array_info| staticIsOnlyTensors(array_info.child),
        .Pointer => |ptr_info| ptr_info.size == .One and staticIsOnlyTensors(ptr_info.child),
        .Struct => |struct_info| {
            inline for (struct_info.fields) |field| {
                if (!staticIsOnlyTensors(field.type)) return false;
            }
            return true;
        },
        else => false,
    };
}
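
/// Counts at comptime how many `Tensor`s a type contains, recursing through arrays,
/// single-item pointers, and structs. Returns null when the count isn't statically
/// known (e.g. a slice of tensors).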
pub fn staticCountTensors(comptime T: type) ?usize {
    if (T == Tensor) return 1;
    return switch (@typeInfo(T)) {
        .Array => |array_info| array_info.len * (staticCountTensors(array_info.child) orelse return null),
        .Pointer => |ptr_info| {
            const n = staticCountTensors(ptr_info.child) orelse return null;
            if (ptr_info.size != .One and n > 0) return null;
            return n;
        },
        .Struct => |struct_info| {
            var count: usize = 0;
            inline for (struct_info.fields) |field| {
                count += staticCountTensors(field.type) orelse return null;
            }
            return count;
        },
        else => 0,
    };
}

/// Creates a Tensor struct similar to `base`, keeping the tags of `base`
/// but taking the MLIR values and dims from the results of the given MLIR operation.
pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base) {
    const LocalContext = struct {
        index: usize,
        op: mlir.Operation,
    };
    var context = LocalContext{ .index = 0, .op = op };
    var res = base;
    meta.visit((struct {
        fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
            var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
            stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new });
            // Copy tags and sharding info over.
            // Some ops can change dims (e.g. reduceWindow), so we trust MLIR here.
            new._shape._tags = tensor._shape._tags;
            new._shape._sharding_info = tensor._shape._sharding_info;
            tensor.* = new;
            inner_ctx.index += 1;
        }
    }).cb, &context, &res);
    assert(context.index == op.numResults());
    return res;
}

/// Produces a custom call to `name` that takes a tensor and returns it unchanged.
///
/// For example, this can be used to extract tokens quickly when they are produced
/// by a loop running on the GPU.
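///
/// A minimal sketch (illustrative only; `"zmlTokenCallback"`, `tokens` and `ctx_ptr`
/// are hypothetical names):
///
///     const out = identityCustomCall("zmlTokenCallback", tokens, ctx_ptr);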
pub fn identityCustomCall(name: [:0]const u8, input: Tensor, context: *anyopaque) Tensor {
    // Smuggle the context pointer to the runtime callback through the backend_config bytes.
    const address: [8]u8 = @bitCast(@intFromPtr(context));
    var backend_config: [8:0]u8 = undefined;
    @memcpy(backend_config[0..8], address[0..8]);

    const ctx = CompilationContext.current();
    const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "custom_call({s})", .{name});
    const op = dialect.stablehlo.custom_call(
        ctx.mlirCtx(),
        &.{input.value()},
        .{
            .api_version = 1,
            .has_side_effect = false,
            .call_target_name = name,
            .backend_config = backend_config[0..],
            .output_operand_aliases = &.{0},
        },
        &.{input.value().getType()},
        loc,
    );
    return Tensor._result(input.shape(), op.result(0));
}

/// At runtime the given tensor will be materialized, copied to the host,
/// and the callback will be called on it.
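///
/// A minimal sketch (illustrative only; `dumpBuffer` and `activations` are
/// hypothetical):
///
///     fn dumpBuffer(buf: HostBuffer) void {
///         std.log.info("device buffer copied to host: {}", .{buf});
///     }
///     const watched = addHostCallback(&dumpBuffer, activations);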
pub fn addHostCallback(
    callback: *const fn (HostBuffer) void,
    input: Tensor,
) Tensor {
    // TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy.
    if (input.getContext().target() != .cuda) return input;

    const len = input.byteSize();
    // Reserve memory so we can log the runtime Buffer later during the computation.
    // This memory is leaked; we currently have no way to tie its lifetime to the lifetime of the module being compiled.
    const HostCallbackCtx = Context.HostCallbackCtx;
    const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch {
        log.err("Failed to pre-allocate buffer to print {}.", .{input});
        return input;
    };
    // Save the HostBuffer inside the same memory slice, so that it's still present at runtime.
    // Use an fba to place the stable buffer at an aligned offset.
    var fba = std.heap.FixedBufferAllocator.init(full_data[len..]);
    const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable;
    stable_ctx_ptr.* = .{
        .host = HostBuffer.fromBytes(input.shape(), full_data[0..len]),
    };
    const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr };

    const ctx = CompilationContext.current();
    const loc = ctx.mlirCtx().location(@src());
    const op = dialect.stablehlo.custom_call(
        ctx.mlirCtx(),
        &.{input.value()},
        .{
            .api_version = 1,
            .has_side_effect = false,
            .call_target_name = "zmlHostBufferCallback",
            .backend_config = @ptrCast(std.mem.sliceAsBytes(&backend_config)),
            .output_operand_aliases = &.{0},
        },
        &.{input.value().getType()},
        loc,
    );
    return Tensor._result(input.shape(), op.result(0));
}