Fix crash in for_ by ensuring values are pushed to their block before opening a new block, adding asserts for block state, and guaranteeing first_step is used. Adjust padding syntax to improve usability.
This commit is contained in:
parent
0fa258cd88
commit
b53462b515
@ -1,5 +1,6 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const log = std.log.scoped(.mlir);
|
||||
|
||||
const c = @import("c");
|
||||
|
||||
@ -852,7 +853,7 @@ pub const Operation = struct {
|
||||
@panic("Failed to create MLIR operation");
|
||||
};
|
||||
if (args.verify and new_op.verify() == false) {
|
||||
std.log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
|
||||
log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })});
|
||||
@panic("Failed to verify MLIR operation");
|
||||
}
|
||||
return new_op;
|
||||
@ -1062,7 +1063,7 @@ pub const Operation = struct {
|
||||
pub const OpPrintingFlags = struct {
|
||||
elide_large_elements_attrs: ?usize = null,
|
||||
debug_info: bool = false,
|
||||
debug_info_pretty_form: bool = false,
|
||||
debug_info_pretty_form: bool = true,
|
||||
print_generic_op_form: bool = false,
|
||||
use_local_scope: bool = false,
|
||||
assume_verified: bool = false,
|
||||
@ -1184,20 +1185,40 @@ pub const Value = struct {
|
||||
return c.mlirValueIsAOpResult(val.inner());
|
||||
}
|
||||
|
||||
pub const Kind = enum {
|
||||
unknown,
|
||||
block_argument,
|
||||
op_result,
|
||||
pub const Kind = union(enum) {
|
||||
block_argument: BlockArgument,
|
||||
op_result: Operation,
|
||||
null,
|
||||
};
|
||||
|
||||
pub fn kind(val: Value) Kind {
|
||||
if (val.isAOpResult()) {
|
||||
return .op_result;
|
||||
return .{ .op_result = val.owner() };
|
||||
}
|
||||
if (val.isABlockArgument()) {
|
||||
return .block_argument;
|
||||
return .{ .block_argument = .{ ._inner = val._inner } };
|
||||
}
|
||||
return .unknown;
|
||||
// From MLIR docs:
|
||||
// https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#details
|
||||
// > An SSA value is either a BlockArgument or the result of an operation.
|
||||
return .null;
|
||||
}
|
||||
};
|
||||
|
||||
pub const BlockArgument = struct {
|
||||
_inner: c.MlirValue,
|
||||
|
||||
pub fn block(arg: BlockArgument) Block {
|
||||
return Block.wrap(c.mlirBlockArgumentGetOwner(arg._inner));
|
||||
}
|
||||
|
||||
pub fn number(arg: BlockArgument) usize {
|
||||
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
|
||||
}
|
||||
|
||||
pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
||||
const value = Value{ ._inner = self._inner };
|
||||
return value.format(fmt, options, writer);
|
||||
}
|
||||
};
|
||||
|
||||
@ -1686,20 +1707,4 @@ pub const Block = struct {
|
||||
c.mlirBlockAppendOwnedOperation(self.inner(), op.inner());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn addOperationsRecursive(block: *Block, op_or_result: anytype) void {
|
||||
const op: Operation = switch (@TypeOf(op_or_result)) {
|
||||
Operation => op_or_result,
|
||||
Value => if (op_or_result.kind() == .op_result) op_or_result.owner() else return,
|
||||
else => |t| @compileError("can either be an operation or a value, not " ++ @typeName(t)),
|
||||
};
|
||||
if (op.block()) |prev_block| {
|
||||
std.debug.assert(prev_block.equals(block.*));
|
||||
return;
|
||||
}
|
||||
for (0..op.numOperands()) |i| {
|
||||
block.addOperationsRecursive(op.operand(i));
|
||||
}
|
||||
block.appendOperation(op);
|
||||
}
|
||||
};
|
||||
|
||||
@ -239,7 +239,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
const type_info_v = @typeInfo(T);
|
||||
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
||||
.Pointer => |info| info.child,
|
||||
else => stdx.debug.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found {}", .{FnParam(cb, 1)}),
|
||||
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
||||
};
|
||||
|
||||
if (type_info_v != .Pointer) {
|
||||
@ -307,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
}
|
||||
}
|
||||
},
|
||||
else => stdx.debug.compileError("Only single pointer and slice are supported. Received {}", .{T}),
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
108
zml/module.zig
108
zml/module.zig
@ -30,6 +30,46 @@ test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
pub const BlockKind = enum { open, hermetic };
|
||||
|
||||
const Block = union(BlockKind) {
|
||||
open: mlir.Block,
|
||||
hermetic: mlir.Block,
|
||||
|
||||
pub fn block(self: Block) mlir.Block {
|
||||
return switch (self) {
|
||||
inline .open, .hermetic => |t| t,
|
||||
};
|
||||
}
|
||||
|
||||
fn appendTensorRecursive(self: Block, x: *const Tensor) void {
|
||||
self.appendValueRecursive(x.value());
|
||||
}
|
||||
|
||||
fn appendValueRecursive(self: Block, value: mlir.Value) void {
|
||||
switch (value.kind()) {
|
||||
.op_result => |parent_op| self.appendOperationRecursive(parent_op),
|
||||
.block_argument => |arg| {
|
||||
// Hermetic blocks are not allowed to use arguments from other blocks.
|
||||
std.debug.assert(self == .open or self.block().eql(arg.block()));
|
||||
},
|
||||
.null => @panic("InvalidMlir"),
|
||||
}
|
||||
}
|
||||
|
||||
fn appendOperationRecursive(self: Block, op: mlir.Operation) void {
|
||||
if (op.block()) |prev_block| {
|
||||
// Hermetic blocks are not allowed to reference values from other blocks.
|
||||
std.debug.assert(self == .open or prev_block.equals(self.block()));
|
||||
return;
|
||||
}
|
||||
for (0..op.numOperands()) |i| {
|
||||
self.appendValueRecursive(op.operand(i));
|
||||
}
|
||||
self.block().appendOperation(op);
|
||||
}
|
||||
};
|
||||
|
||||
pub const CompilationContext = struct {
|
||||
_platform: Platform,
|
||||
|
||||
@ -39,7 +79,7 @@ pub const CompilationContext = struct {
|
||||
|
||||
_module: mlir.Module,
|
||||
|
||||
_blocks: std.BoundedArray(mlir.Block, 64),
|
||||
_blocks: std.BoundedArray(Block, 64),
|
||||
_fn_cache: FnCache,
|
||||
_allocator: std.mem.Allocator,
|
||||
|
||||
@ -120,22 +160,26 @@ pub const CompilationContext = struct {
|
||||
return self._mlir_ctx;
|
||||
}
|
||||
|
||||
pub fn currentBlock(self: *const CompilationContext) ?mlir.Block {
|
||||
pub fn currentBlock(self: *const CompilationContext) ?Block {
|
||||
return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null;
|
||||
}
|
||||
|
||||
pub fn openBlock(self: *CompilationContext, args: []const mlir.Type, locs: []const mlir.Location) !mlir.Block {
|
||||
const block = try mlir.Block.init(args, locs);
|
||||
pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block {
|
||||
const mlir_block = try mlir.Block.init(args, locs);
|
||||
const block: Block = switch (kind) {
|
||||
.open => .{ .open = mlir_block },
|
||||
.hermetic => .{ .hermetic = mlir_block },
|
||||
};
|
||||
self.pushBlock(block);
|
||||
return block;
|
||||
}
|
||||
|
||||
pub fn closeBlock(self: *CompilationContext, block: *mlir.Block) void {
|
||||
pub fn closeBlock(self: *CompilationContext, block: Block) void {
|
||||
const popped = self._blocks.pop();
|
||||
std.debug.assert(block.equals(popped));
|
||||
std.debug.assert(block.block().eql(popped.block()));
|
||||
}
|
||||
|
||||
fn pushBlock(self: *CompilationContext, block: mlir.Block) void {
|
||||
fn pushBlock(self: *CompilationContext, block: Block) void {
|
||||
self._blocks.appendAssumeCapacity(block);
|
||||
}
|
||||
|
||||
@ -147,35 +191,41 @@ pub const CompilationContext = struct {
|
||||
/// But their shapes/tags can be safely propagated further.
|
||||
pub fn makeBlock(
|
||||
self: *CompilationContext,
|
||||
kind: BlockKind,
|
||||
comptime S: ops.BlockSignature,
|
||||
func: *const S.Fn,
|
||||
blkctx: S.BlkCtx,
|
||||
args: S.Args,
|
||||
) struct { mlir.Block, S.Return } {
|
||||
const N = S.nIn;
|
||||
const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N;
|
||||
const loc = self.mlirCtx().location(@src());
|
||||
const locations = .{loc} ** N;
|
||||
var input_types: [N]mlir.Type = undefined;
|
||||
fillMlirTypes(&args, self.mlirCtx(), &input_types);
|
||||
|
||||
var block = self.openBlock(&input_types, &locations) catch unreachable;
|
||||
defer self.closeBlock(&block);
|
||||
// Before creating a new block, assign all received values to previous block,
|
||||
// otherwise they will be assign to this block
|
||||
if (self.currentBlock()) |prev_block| {
|
||||
meta.visit(Block.appendTensorRecursive, prev_block, &blkctx);
|
||||
}
|
||||
|
||||
const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
|
||||
defer self.closeBlock(block);
|
||||
|
||||
// Here we want to create the block with the correct mlir types.
|
||||
// but we don't want to use the values themselves.
|
||||
// So we create a copy of the arguments, and replace values
|
||||
// by the block arguments.
|
||||
var blk_args = args;
|
||||
std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N);
|
||||
std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N);
|
||||
|
||||
const loc = self.mlirCtx().location(@src());
|
||||
const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
|
||||
|
||||
var block_res_values: [S.nOut]mlir.Value = undefined;
|
||||
self.extractValues(&block_res, &block_res_values);
|
||||
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
|
||||
block.addOperationsRecursive(block_ret);
|
||||
block.appendOperationRecursive(block_ret);
|
||||
|
||||
return .{ block, block_res };
|
||||
return .{ block.block(), block_res };
|
||||
}
|
||||
|
||||
/// Generate an MLIR function from a ZML function.
|
||||
@ -225,9 +275,9 @@ pub const CompilationContext = struct {
|
||||
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
|
||||
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
|
||||
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
|
||||
var fn_body = self.openBlock(input_types, locations) catch unreachable;
|
||||
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
|
||||
{
|
||||
defer self.closeBlock(&fn_body);
|
||||
defer self.closeBlock(fn_body);
|
||||
// Note: we could shrink self._buffer_to_arg once we called `func`.
|
||||
// But for now we are only compiling one function per CompilationContext.
|
||||
// So we don't need to do this since we won't reuse self._buffer_to_arg anyway.
|
||||
@ -235,8 +285,8 @@ pub const CompilationContext = struct {
|
||||
// defer self._buffer_to_arg.shrinkRetainingCapacity(n);
|
||||
|
||||
try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count));
|
||||
const assigned_model_count = self.mapBlockArguments(model, fn_body, 0);
|
||||
const assigned_args_count = self.mapBlockArguments(args, fn_body, assigned_model_count);
|
||||
const assigned_model_count = self.mapBlockArguments(model, fn_body.block(), 0);
|
||||
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), assigned_model_count);
|
||||
assert(assigned_model_count == model_tensor_count);
|
||||
assert(assigned_args_count == tensor_count);
|
||||
|
||||
@ -250,7 +300,7 @@ pub const CompilationContext = struct {
|
||||
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||
|
||||
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||
fn_body.addOperationsRecursive(fn_ret);
|
||||
fn_body.appendOperationRecursive(fn_ret);
|
||||
}
|
||||
|
||||
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
|
||||
@ -271,7 +321,7 @@ pub const CompilationContext = struct {
|
||||
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
|
||||
.results = fn_res_types,
|
||||
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
|
||||
.block = fn_body,
|
||||
.block = fn_body.block(),
|
||||
.location = loc,
|
||||
});
|
||||
|
||||
@ -921,6 +971,22 @@ fn compileInternal(
|
||||
context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr());
|
||||
context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
|
||||
|
||||
if (context._platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
|
||||
if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| {
|
||||
const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?;
|
||||
const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value();
|
||||
if (dir.createFile(file_name, .{ .truncate = true })) |file| {
|
||||
context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = true });
|
||||
log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name });
|
||||
} else |_| {
|
||||
log.warn("Failed to open {s}", .{file_name});
|
||||
}
|
||||
} else |_| {
|
||||
log.warn("Folder not found {s}", .{xla_dump_to});
|
||||
}
|
||||
}
|
||||
|
||||
const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| {
|
||||
log.err(
|
||||
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
||||
|
||||
74
zml/ops.zig
74
zml/ops.zig
@ -47,9 +47,8 @@ pub fn while_(
|
||||
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
|
||||
|
||||
const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
|
||||
const cond_block, _ = ctx.makeBlock(.open, CondS, &cond_fn, blkctx, inputs);
|
||||
const body_block, const body_res = ctx.makeBlock(.open, BodyS, &body_fn, blkctx, inputs);
|
||||
var input_values: [BodyS.nIn]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
@ -138,7 +137,7 @@ pub fn reduce(
|
||||
var init_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inits, &init_values);
|
||||
|
||||
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits });
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
|
||||
@ -227,7 +226,7 @@ pub fn reduceWindow(
|
||||
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const N = comptime @divExact(BodyS.nIn, 2);
|
||||
var input_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
@ -269,7 +268,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
|
||||
const ForBlk = struct {
|
||||
blk_ctx: S.BlkCtx,
|
||||
step_tag: @TypeOf(step_tag), // This is a Shape.Tag, but we rather keep it private
|
||||
step_tag: Shape.Tag,
|
||||
num_steps: u32,
|
||||
const Self = @This();
|
||||
|
||||
@ -295,11 +294,12 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
}
|
||||
|
||||
/// Prepare buffer to store all results steps.
|
||||
fn prep(self: Self, x: Tensor) Tensor {
|
||||
var shape = x.shape();
|
||||
shape._dims.insert(0, self.num_steps) catch unreachable;
|
||||
shape._tags.insert(0, self.step_tag) catch unreachable;
|
||||
return Tensor.constant(shape, x.dtype().zero());
|
||||
fn prep(self: Self, first_step: Tensor) Tensor {
|
||||
const shape = first_step.shape().insertTag(0, 1, self.step_tag);
|
||||
// Reuse the first step Tensor.
|
||||
// TODO: this is needed because of https://github.com/zml/zml/issues/97
|
||||
// Normally I'd rather NOT reuse first_step to streamline the stablehlo IR.
|
||||
return first_step.reshape(shape).pad(0, .{ ._0 = .{ .high = self.num_steps - 1 } });
|
||||
}
|
||||
|
||||
fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor {
|
||||
@ -310,8 +310,9 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
}
|
||||
};
|
||||
|
||||
// This first step won't appear in the generated MLIR,
|
||||
// it's only used to infer the output shapes.
|
||||
// Compute first step to infer the output shapes.
|
||||
// Normally this shouldn't be reused apart from the unrolled cases,
|
||||
// but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop.
|
||||
const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
|
||||
log.debug("for_ first_step: {}", .{first_step});
|
||||
const allocator = CompilationContext.current()._allocator;
|
||||
@ -343,7 +344,8 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
||||
for_blk,
|
||||
.{
|
||||
result_buffers,
|
||||
Tensor.scalar(0, .i32),
|
||||
// First step is already done
|
||||
Tensor.scalar(1, .i32),
|
||||
},
|
||||
)[0];
|
||||
}
|
||||
@ -388,6 +390,44 @@ test for_ {
|
||||
}
|
||||
}
|
||||
|
||||
test "nested for" {
|
||||
const OuterProd = struct {
|
||||
const OuterProd = @This();
|
||||
|
||||
x: Tensor,
|
||||
x_row: Tensor,
|
||||
|
||||
pub fn forward(x: Tensor) Tensor {
|
||||
return for_(OuterProd.scanRow, x, .{x.dim(0)});
|
||||
}
|
||||
|
||||
pub fn scanRow(x: Tensor, i: Tensor) Tensor {
|
||||
const row = x.dynamicSlice(.{.{ .start = i, .len = 1 }});
|
||||
return for_(OuterProd.scanCol, .{ .x = x, .x_row = row }, .{x.dim(0)});
|
||||
}
|
||||
|
||||
pub fn scanCol(self: OuterProd, j: Tensor) Tensor {
|
||||
const col = self.x.dynamicSlice(.{.{ .start = j, .len = 1 }});
|
||||
return self.x_row.mul(col);
|
||||
}
|
||||
};
|
||||
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// 5 to prevent inlining
|
||||
const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 });
|
||||
const outer_prod = try zml.testing.compileAndCall(platform, OuterProd.forward, .{x});
|
||||
const expected: [5][5]f32 = .{
|
||||
.{ 0, 0, 0, 0, 0 },
|
||||
.{ 0, 1.0, -1.0, 2.0, -2.0 },
|
||||
.{ 0, -1.0, 1.0, -2.0, 2.0 },
|
||||
.{ 0, 2.0, -2.0, 4.0, -4.0 },
|
||||
.{ 0, -2.0, 2.0, -4.0, 4.0 },
|
||||
};
|
||||
try std.testing.expectEqual(expected, outer_prod.getValue(@TypeOf(expected)));
|
||||
}
|
||||
|
||||
pub fn if_2(pred: Tensor, comptime Closure: type, blkctx: BlockSignNoArgs(@field(Closure, "then")).BlkCtx) BlockSignNoArgs(@field(Closure, "then")).Return {
|
||||
return if_(pred, @field(Closure, "then"), @field(Closure, "else_"), blkctx);
|
||||
}
|
||||
@ -404,8 +444,8 @@ pub fn if_(
|
||||
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
@ -466,7 +506,7 @@ pub fn sort(
|
||||
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
|
||||
const block, _ = ctx.makeBlock(.hermetic, BodyS, &comp_fn, blkctx, inits);
|
||||
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
|
||||
@ -292,33 +292,45 @@ pub const Shape = struct {
|
||||
stdx.debug.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T});
|
||||
}
|
||||
|
||||
fn axisFromInt(self: Shape, d: isize) u3 {
|
||||
fn axisFromInt(self: Shape, a: isize) u3 {
|
||||
const rk: i8 = self.rank();
|
||||
if (d < -rk or d > rk) {
|
||||
stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d });
|
||||
if (a < -rk or a > rk) {
|
||||
stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a });
|
||||
}
|
||||
return if (d < 0)
|
||||
@intCast(d + rk)
|
||||
return if (a < 0)
|
||||
@intCast(a + rk)
|
||||
else
|
||||
@intCast(d);
|
||||
@intCast(a);
|
||||
}
|
||||
|
||||
fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 {
|
||||
if (d == TagUnknown) {
|
||||
return null;
|
||||
}
|
||||
fn axisFromTagMaybe(self: Shape, t: Tag) ?u3 {
|
||||
if (t == TagUnknown) return null;
|
||||
|
||||
if (axisFromLiteralInt(t)) |ax| return ax;
|
||||
|
||||
if (@inComptime()) {
|
||||
for (0.., self.tags()) |tagIndex, t| {
|
||||
const a: []const u8 = std.mem.span(t);
|
||||
const b: []const u8 = std.mem.span(d);
|
||||
if (std.mem.eql(u8, a, b)) {
|
||||
return @intCast(tagIndex);
|
||||
// At comptime two duplicated strings may have two different representations
|
||||
const t_bytes: []const u8 = std.mem.span(t);
|
||||
for (self.tags(), 0..) |self_tag, ax| {
|
||||
if (std.mem.eql(u8, t_bytes, std.mem.span(self_tag))) {
|
||||
return @truncate(ax);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (std.mem.indexOfScalar(Tag, self.tags(), d)) |d_| {
|
||||
return @intCast(d_);
|
||||
|
||||
// But at runtime the comptime strings have been deduplicated and ptr match is enough.
|
||||
if (std.mem.indexOfScalar(Tag, self.tags(), t)) |ax| {
|
||||
return @truncate(ax);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/// Handle .{ ._0 = x } syntax.
|
||||
fn axisFromLiteralInt(t: Tag) ?u3 {
|
||||
// match .{ '_', '0-9', null }
|
||||
if (t[0] == '_' and t[1] >= '0' and t[1] < '8' and t[2] == 0) {
|
||||
return @intCast(t[1] - '0');
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@ -2459,7 +2459,7 @@ pub const Tensor = struct {
|
||||
|
||||
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
|
||||
const UpdateS = ops.BlockSign(ScatterOpts.increment);
|
||||
const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||||
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||||
|
||||
const op = dialect.stablehlo.scatter(
|
||||
mlir_ctx,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user