Fix a crash in for_ by ensuring values are pushed to their owning block before a new block is opened, adding assertions on block state, and guaranteeing that first_step is used. Also adjust the padding syntax to improve usability.

This commit is contained in:
Tarry Singh 2023-07-25 14:25:47 +00:00
parent 0fa258cd88
commit b53462b515
6 changed files with 206 additions and 83 deletions

View File

@ -1,5 +1,6 @@
const builtin = @import("builtin"); const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const log = std.log.scoped(.mlir);
const c = @import("c"); const c = @import("c");
@ -852,7 +853,7 @@ pub const Operation = struct {
@panic("Failed to create MLIR operation"); @panic("Failed to create MLIR operation");
}; };
if (args.verify and new_op.verify() == false) { if (args.verify and new_op.verify() == false) {
std.log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })}); log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
@panic("Failed to verify MLIR operation"); @panic("Failed to verify MLIR operation");
} }
return new_op; return new_op;
@ -1062,7 +1063,7 @@ pub const Operation = struct {
pub const OpPrintingFlags = struct { pub const OpPrintingFlags = struct {
elide_large_elements_attrs: ?usize = null, elide_large_elements_attrs: ?usize = null,
debug_info: bool = false, debug_info: bool = false,
debug_info_pretty_form: bool = false, debug_info_pretty_form: bool = true,
print_generic_op_form: bool = false, print_generic_op_form: bool = false,
use_local_scope: bool = false, use_local_scope: bool = false,
assume_verified: bool = false, assume_verified: bool = false,
@ -1184,20 +1185,40 @@ pub const Value = struct {
return c.mlirValueIsAOpResult(val.inner()); return c.mlirValueIsAOpResult(val.inner());
} }
pub const Kind = enum { pub const Kind = union(enum) {
unknown, block_argument: BlockArgument,
block_argument, op_result: Operation,
op_result, null,
}; };
pub fn kind(val: Value) Kind { pub fn kind(val: Value) Kind {
if (val.isAOpResult()) { if (val.isAOpResult()) {
return .op_result; return .{ .op_result = val.owner() };
} }
if (val.isABlockArgument()) { if (val.isABlockArgument()) {
return .block_argument; return .{ .block_argument = .{ ._inner = val._inner } };
} }
return .unknown; // From MLIR docs:
// https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#details
// > An SSA value is either a BlockArgument or the result of an operation.
return .null;
}
};
pub const BlockArgument = struct {
_inner: c.MlirValue,
pub fn block(arg: BlockArgument) Block {
return Block.wrap(c.mlirBlockArgumentGetOwner(arg._inner));
}
pub fn number(arg: BlockArgument) usize {
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
}
pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
const value = Value{ ._inner = self._inner };
return value.format(fmt, options, writer);
} }
}; };
@ -1686,20 +1707,4 @@ pub const Block = struct {
c.mlirBlockAppendOwnedOperation(self.inner(), op.inner()); c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
} }
} }
pub fn addOperationsRecursive(block: *Block, op_or_result: anytype) void {
const op: Operation = switch (@TypeOf(op_or_result)) {
Operation => op_or_result,
Value => if (op_or_result.kind() == .op_result) op_or_result.owner() else return,
else => |t| @compileError("can either be an operation or a value, not " ++ @typeName(t)),
};
if (op.block()) |prev_block| {
std.debug.assert(prev_block.equals(block.*));
return;
}
for (0..op.numOperands()) |i| {
block.addOperationsRecursive(op.operand(i));
}
block.appendOperation(op);
}
}; };

View File

@ -239,7 +239,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
const type_info_v = @typeInfo(T); const type_info_v = @typeInfo(T);
const K = switch (@typeInfo(FnParam(cb, 1))) { const K = switch (@typeInfo(FnParam(cb, 1))) {
.Pointer => |info| info.child, .Pointer => |info| info.child,
else => stdx.debug.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found {}", .{FnParam(cb, 1)}), else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
}; };
if (type_info_v != .Pointer) { if (type_info_v != .Pointer) {
@ -307,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
} }
} }
}, },
else => stdx.debug.compileError("Only single pointer and slice are supported. Received {}", .{T}), else => {},
} }
} }

View File

@ -30,6 +30,46 @@ test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
/// Whether a block may reference values defined outside of itself:
/// `.open` blocks may capture values from enclosing blocks, while
/// `.hermetic` blocks may not (enforced by the asserts below).
pub const BlockKind = enum { open, hermetic };
/// An `mlir.Block` tagged with its capture policy.
const Block = union(BlockKind) {
open: mlir.Block,
hermetic: mlir.Block,
/// Returns the underlying `mlir.Block`, regardless of the tag.
pub fn block(self: Block) mlir.Block {
return switch (self) {
inline .open, .hermetic => |t| t,
};
}
/// Appends the operations producing `x`'s value (see `appendValueRecursive`).
fn appendTensorRecursive(self: Block, x: *const Tensor) void {
self.appendValueRecursive(x.value());
}
/// Ensures the operations that produce `value` live in this block.
/// Op results recurse into their producing operation; block arguments
/// are only validated, never moved.
fn appendValueRecursive(self: Block, value: mlir.Value) void {
switch (value.kind()) {
.op_result => |parent_op| self.appendOperationRecursive(parent_op),
.block_argument => |arg| {
// Hermetic blocks are not allowed to use arguments from other blocks.
std.debug.assert(self == .open or self.block().eql(arg.block()));
},
// An SSA value is either a block argument or an op result; anything
// else means the MLIR is malformed.
.null => @panic("InvalidMlir"),
}
}
/// Appends `op` to this block, first appending the producers of all of
/// its operands so that definitions precede uses. If `op` already
/// belongs to a block, it is left there (hermetic blocks assert that
/// the block is this one) and nothing is appended.
fn appendOperationRecursive(self: Block, op: mlir.Operation) void {
if (op.block()) |prev_block| {
// Hermetic blocks are not allowed to reference values from other blocks.
std.debug.assert(self == .open or prev_block.equals(self.block()));
return;
}
// Recurse into operands before appending `op` itself, so producing
// ops are emitted ahead of their users.
for (0..op.numOperands()) |i| {
self.appendValueRecursive(op.operand(i));
}
self.block().appendOperation(op);
}
};
pub const CompilationContext = struct { pub const CompilationContext = struct {
_platform: Platform, _platform: Platform,
@ -39,7 +79,7 @@ pub const CompilationContext = struct {
_module: mlir.Module, _module: mlir.Module,
_blocks: std.BoundedArray(mlir.Block, 64), _blocks: std.BoundedArray(Block, 64),
_fn_cache: FnCache, _fn_cache: FnCache,
_allocator: std.mem.Allocator, _allocator: std.mem.Allocator,
@ -120,22 +160,26 @@ pub const CompilationContext = struct {
return self._mlir_ctx; return self._mlir_ctx;
} }
pub fn currentBlock(self: *const CompilationContext) ?mlir.Block { pub fn currentBlock(self: *const CompilationContext) ?Block {
return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null; return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null;
} }
pub fn openBlock(self: *CompilationContext, args: []const mlir.Type, locs: []const mlir.Location) !mlir.Block { pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block {
const block = try mlir.Block.init(args, locs); const mlir_block = try mlir.Block.init(args, locs);
const block: Block = switch (kind) {
.open => .{ .open = mlir_block },
.hermetic => .{ .hermetic = mlir_block },
};
self.pushBlock(block); self.pushBlock(block);
return block; return block;
} }
pub fn closeBlock(self: *CompilationContext, block: *mlir.Block) void { pub fn closeBlock(self: *CompilationContext, block: Block) void {
const popped = self._blocks.pop(); const popped = self._blocks.pop();
std.debug.assert(block.equals(popped)); std.debug.assert(block.block().eql(popped.block()));
} }
fn pushBlock(self: *CompilationContext, block: mlir.Block) void { fn pushBlock(self: *CompilationContext, block: Block) void {
self._blocks.appendAssumeCapacity(block); self._blocks.appendAssumeCapacity(block);
} }
@ -147,35 +191,41 @@ pub const CompilationContext = struct {
/// But their shapes/tags can be safely propagated further. /// But their shapes/tags can be safely propagated further.
pub fn makeBlock( pub fn makeBlock(
self: *CompilationContext, self: *CompilationContext,
kind: BlockKind,
comptime S: ops.BlockSignature, comptime S: ops.BlockSignature,
func: *const S.Fn, func: *const S.Fn,
blkctx: S.BlkCtx, blkctx: S.BlkCtx,
args: S.Args, args: S.Args,
) struct { mlir.Block, S.Return } { ) struct { mlir.Block, S.Return } {
const N = S.nIn; const N = S.nIn;
const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N; const loc = self.mlirCtx().location(@src());
const locations = .{loc} ** N;
var input_types: [N]mlir.Type = undefined; var input_types: [N]mlir.Type = undefined;
fillMlirTypes(&args, self.mlirCtx(), &input_types); fillMlirTypes(&args, self.mlirCtx(), &input_types);
var block = self.openBlock(&input_types, &locations) catch unreachable; // Before creating a new block, assign all received values to previous block,
defer self.closeBlock(&block); // otherwise they will be assigned to this block
if (self.currentBlock()) |prev_block| {
meta.visit(Block.appendTensorRecursive, prev_block, &blkctx);
}
const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
defer self.closeBlock(block);
// Here we want to create the block with the correct mlir types. // Here we want to create the block with the correct mlir types.
// but we don't want to use the values themselves. // but we don't want to use the values themselves.
// So we create a copy of the arguments, and replace values // So we create a copy of the arguments, and replace values
// by the block arguments. // by the block arguments.
var blk_args = args; var blk_args = args;
std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N); std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N);
const loc = self.mlirCtx().location(@src());
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args)); const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
var block_res_values: [S.nOut]mlir.Value = undefined; var block_res_values: [S.nOut]mlir.Value = undefined;
self.extractValues(&block_res, &block_res_values); self.extractValues(&block_res, &block_res_values);
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc); const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
block.addOperationsRecursive(block_ret); block.appendOperationRecursive(block_ret);
return .{ block, block_res }; return .{ block.block(), block_res };
} }
/// Generate an MLIR function from a ZML function. /// Generate an MLIR function from a ZML function.
@ -225,9 +275,9 @@ pub const CompilationContext = struct {
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count); const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count); const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count); const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
var fn_body = self.openBlock(input_types, locations) catch unreachable; var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
{ {
defer self.closeBlock(&fn_body); defer self.closeBlock(fn_body);
// Note: we could shrink self._buffer_to_arg once we called `func`. // Note: we could shrink self._buffer_to_arg once we called `func`.
// But for now we are only compiling one function per CompilationContext. // But for now we are only compiling one function per CompilationContext.
// So we don't need to do this since we won't reuse self._buffer_to_arg anyway. // So we don't need to do this since we won't reuse self._buffer_to_arg anyway.
@ -235,8 +285,8 @@ pub const CompilationContext = struct {
// defer self._buffer_to_arg.shrinkRetainingCapacity(n); // defer self._buffer_to_arg.shrinkRetainingCapacity(n);
try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count));
const assigned_model_count = self.mapBlockArguments(model, fn_body, 0); const assigned_model_count = self.mapBlockArguments(model, fn_body.block(), 0);
const assigned_args_count = self.mapBlockArguments(args, fn_body, assigned_model_count); const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), assigned_model_count);
assert(assigned_model_count == model_tensor_count); assert(assigned_model_count == model_tensor_count);
assert(assigned_args_count == tensor_count); assert(assigned_args_count == tensor_count);
@ -250,7 +300,7 @@ pub const CompilationContext = struct {
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body.addOperationsRecursive(fn_ret); fn_body.appendOperationRecursive(fn_ret);
} }
const arg_attrs = try arena.alloc(AttributeList, tensor_count); const arg_attrs = try arena.alloc(AttributeList, tensor_count);
@ -271,7 +321,7 @@ pub const CompilationContext = struct {
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
.results = fn_res_types, .results = fn_res_types,
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs), .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
.block = fn_body, .block = fn_body.block(),
.location = loc, .location = loc,
}); });
@ -921,6 +971,22 @@ fn compileInternal(
context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
if (context._platform.compilation_options.xla_dump_to) |xla_dump_to| {
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| {
const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?;
const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value();
if (dir.createFile(file_name, .{ .truncate = true })) |file| {
context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = true });
log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name });
} else |_| {
log.warn("Failed to open {s}", .{file_name});
}
} else |_| {
log.warn("Folder not found {s}", .{xla_dump_to});
}
}
const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| { const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| {
log.err( log.err(
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",

View File

@ -47,9 +47,8 @@ pub fn while_(
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); @compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); const cond_block, _ = ctx.makeBlock(.open, CondS, &cond_fn, blkctx, inputs);
const body_block, const body_res = ctx.makeBlock(.open, BodyS, &body_fn, blkctx, inputs);
const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
var input_values: [BodyS.nIn]mlir.Value = undefined; var input_values: [BodyS.nIn]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -138,7 +137,7 @@ pub fn reduce(
var init_values: [N]mlir.Value = undefined; var init_values: [N]mlir.Value = undefined;
ctx.extractValues(&inits, &init_values); ctx.extractValues(&inits, &init_values);
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits });
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
@ -227,7 +226,7 @@ pub fn reduceWindow(
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits });
const N = comptime @divExact(BodyS.nIn, 2); const N = comptime @divExact(BodyS.nIn, 2);
var input_values: [N]mlir.Value = undefined; var input_values: [N]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -269,7 +268,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
const ForBlk = struct { const ForBlk = struct {
blk_ctx: S.BlkCtx, blk_ctx: S.BlkCtx,
step_tag: @TypeOf(step_tag), // This is a Shape.Tag, but we rather keep it private step_tag: Shape.Tag,
num_steps: u32, num_steps: u32,
const Self = @This(); const Self = @This();
@ -295,11 +294,12 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
} }
/// Prepare buffer to store all results steps. /// Prepare buffer to store all results steps.
fn prep(self: Self, x: Tensor) Tensor { fn prep(self: Self, first_step: Tensor) Tensor {
var shape = x.shape(); const shape = first_step.shape().insertTag(0, 1, self.step_tag);
shape._dims.insert(0, self.num_steps) catch unreachable; // Reuse the first step Tensor.
shape._tags.insert(0, self.step_tag) catch unreachable; // TODO: this is needed because of https://github.com/zml/zml/issues/97
return Tensor.constant(shape, x.dtype().zero()); // Normally I'd rather NOT reuse first_step to streamline the stablehlo IR.
return first_step.reshape(shape).pad(0, .{ ._0 = .{ .high = self.num_steps - 1 } });
} }
fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor { fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
@ -310,8 +310,9 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
} }
}; };
// This first step won't appear in the generated MLIR, // Compute first step to infer the output shapes.
// it's only used to infer the output shapes. // Normally this shouldn't be reused apart from the unrolled cases,
// but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop.
const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) }); const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
log.debug("for_ first_step: {}", .{first_step}); log.debug("for_ first_step: {}", .{first_step});
const allocator = CompilationContext.current()._allocator; const allocator = CompilationContext.current()._allocator;
@ -343,7 +344,8 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
for_blk, for_blk,
.{ .{
result_buffers, result_buffers,
Tensor.scalar(0, .i32), // First step is already done
Tensor.scalar(1, .i32),
}, },
)[0]; )[0];
} }
@ -388,6 +390,44 @@ test for_ {
} }
} }
// Computes the outer product of a vector with itself using two nested
// `for_` loops: the outer loop scans rows, the inner loop scans columns,
// so expected[i][j] == x[i] * x[j].
test "nested for" {
const OuterProd = struct {
const OuterProd = @This();
x: Tensor,
x_row: Tensor,
pub fn forward(x: Tensor) Tensor {
return for_(OuterProd.scanRow, x, .{x.dim(0)});
}
// Outer loop body: slice out row `i`, then scan all columns against it.
pub fn scanRow(x: Tensor, i: Tensor) Tensor {
const row = x.dynamicSlice(.{.{ .start = i, .len = 1 }});
return for_(OuterProd.scanCol, .{ .x = x, .x_row = row }, .{x.dim(0)});
}
// Inner loop body: multiply the captured row by column `j`.
pub fn scanCol(self: OuterProd, j: Tensor) Tensor {
const col = self.x.dynamicSlice(.{.{ .start = j, .len = 1 }});
return self.x_row.mul(col);
}
};
const zml = @import("zml.zig");
const platform = zml.testing.env();
// Length-5 input — per the original note, presumably large enough that
// for_ emits real loops instead of unrolling them. TODO confirm threshold.
const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 });
const outer_prod = try zml.testing.compileAndCall(platform, OuterProd.forward, .{x});
// expected[i][j] = x[i] * x[j].
const expected: [5][5]f32 = .{
.{ 0, 0, 0, 0, 0 },
.{ 0, 1.0, -1.0, 2.0, -2.0 },
.{ 0, -1.0, 1.0, -2.0, 2.0 },
.{ 0, 2.0, -2.0, 4.0, -4.0 },
.{ 0, -2.0, 2.0, -4.0, 4.0 },
};
try std.testing.expectEqual(expected, outer_prod.getValue(@TypeOf(expected)));
}
pub fn if_2(pred: Tensor, comptime Closure: type, blkctx: BlockSignNoArgs(@field(Closure, "then")).BlkCtx) BlockSignNoArgs(@field(Closure, "then")).Return { pub fn if_2(pred: Tensor, comptime Closure: type, blkctx: BlockSignNoArgs(@field(Closure, "then")).BlkCtx) BlockSignNoArgs(@field(Closure, "then")).Return {
return if_(pred, @field(Closure, "then"), @field(Closure, "else_"), blkctx); return if_(pred, @field(Closure, "then"), @field(Closure, "else_"), blkctx);
} }
@ -404,8 +444,8 @@ pub fn if_(
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res }); stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
@ -466,7 +506,7 @@ pub fn sort(
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); const block, _ = ctx.makeBlock(.hermetic, BodyS, &comp_fn, blkctx, inits);
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);

View File

@ -292,33 +292,45 @@ pub const Shape = struct {
stdx.debug.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); stdx.debug.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
} }
fn axisFromInt(self: Shape, d: isize) u3 { fn axisFromInt(self: Shape, a: isize) u3 {
const rk: i8 = self.rank(); const rk: i8 = self.rank();
if (d < -rk or d > rk) { if (a < -rk or a > rk) {
stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a });
} }
return if (d < 0) return if (a < 0)
@intCast(d + rk) @intCast(a + rk)
else else
@intCast(d); @intCast(a);
} }
fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 { fn axisFromTagMaybe(self: Shape, t: Tag) ?u3 {
if (d == TagUnknown) { if (t == TagUnknown) return null;
return null;
} if (axisFromLiteralInt(t)) |ax| return ax;
if (@inComptime()) { if (@inComptime()) {
for (0.., self.tags()) |tagIndex, t| { // At comptime two duplicated strings may have two different representations
const a: []const u8 = std.mem.span(t); const t_bytes: []const u8 = std.mem.span(t);
const b: []const u8 = std.mem.span(d); for (self.tags(), 0..) |self_tag, ax| {
if (std.mem.eql(u8, a, b)) { if (std.mem.eql(u8, t_bytes, std.mem.span(self_tag))) {
return @intCast(tagIndex); return @truncate(ax);
} }
} }
return null; return null;
} }
if (std.mem.indexOfScalar(Tag, self.tags(), d)) |d_| {
return @intCast(d_); // But at runtime the comptime strings have been deduplicated and ptr match is enough.
if (std.mem.indexOfScalar(Tag, self.tags(), t)) |ax| {
return @truncate(ax);
}
return null;
}
/// Handle .{ ._0 = x } syntax.
fn axisFromLiteralInt(t: Tag) ?u3 {
// match .{ '_', '0'...'7', 0 } — only digits 0-7 fit in the u3 return type
if (t[0] == '_' and t[1] >= '0' and t[1] < '8' and t[2] == 0) {
return @intCast(t[1] - '0');
} }
return null; return null;
} }

View File

@ -2459,7 +2459,7 @@ pub const Tensor = struct {
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
const UpdateS = ops.BlockSign(ScatterOpts.increment); const UpdateS = ops.BlockSign(ScatterOpts.increment);
const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
const op = dialect.stablehlo.scatter( const op = dialect.stablehlo.scatter(
mlir_ctx, mlir_ctx,