From b53462b51510377cc4c41d07d4672a0fbd53987c Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 25 Jul 2023 14:25:47 +0000 Subject: [PATCH] Fix crash in for_ by ensuring values are pushed to their block before opening a new block, adding asserts for block state, and guaranteeing first_step is used. Adjust padding syntax to improve usability. --- mlir/mlir.zig | 55 +++++++++++++------------ zml/meta.zig | 4 +- zml/module.zig | 108 +++++++++++++++++++++++++++++++++++++++---------- zml/ops.zig | 74 +++++++++++++++++++++++++-------- zml/shape.zig | 46 +++++++++++++-------- zml/tensor.zig | 2 +- 6 files changed, 206 insertions(+), 83 deletions(-) diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 1bb1727..05aef14 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const log = std.log.scoped(.mlir); const c = @import("c"); @@ -852,7 +853,7 @@ pub const Operation = struct { @panic("Failed to create MLIR operation"); }; if (args.verify and new_op.verify() == false) { - std.log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })}); + log.err("Failed to verify MLIR operation:\n{}", .{new_op.mlirFormatter(.{ .debug_info = true })}); @panic("Failed to verify MLIR operation"); } return new_op; @@ -1062,7 +1063,7 @@ pub const Operation = struct { pub const OpPrintingFlags = struct { elide_large_elements_attrs: ?usize = null, debug_info: bool = false, - debug_info_pretty_form: bool = false, + debug_info_pretty_form: bool = true, print_generic_op_form: bool = false, use_local_scope: bool = false, assume_verified: bool = false, @@ -1184,20 +1185,40 @@ pub const Value = struct { return c.mlirValueIsAOpResult(val.inner()); } - pub const Kind = enum { - unknown, - block_argument, - op_result, + pub const Kind = union(enum) { + block_argument: BlockArgument, + op_result: Operation, + null, }; pub fn kind(val: Value) Kind { if (val.isAOpResult()) { - return .op_result; + return .{ .op_result = val.owner() }; } if (val.isABlockArgument()) { - return .block_argument; + return .{ .block_argument = .{ ._inner = val._inner } }; } - return .unknown; + // From MLIR docs: + // https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#details + // > An SSA value is either a BlockArgument or the result of an operation. + return .null; + } +}; + +pub const BlockArgument = struct { + _inner: c.MlirValue, + + pub fn block(arg: BlockArgument) Block { + return Block.wrap(c.mlirBlockArgumentGetOwner(arg._inner)); + } + + pub fn number(arg: BlockArgument) usize { + return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner)); + } + + pub fn format(self: BlockArgument, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + const value = Value{ ._inner = self._inner }; + return value.format(fmt, options, writer); } }; @@ -1686,20 +1707,4 @@ pub const Block = struct { c.mlirBlockAppendOwnedOperation(self.inner(), op.inner()); } } - - pub fn addOperationsRecursive(block: *Block, op_or_result: anytype) void { - const op: Operation = switch (@TypeOf(op_or_result)) { - Operation => op_or_result, - Value => if (op_or_result.kind() == .op_result) op_or_result.owner() else return, - else => |t| @compileError("can either be an operation or a value, not " ++ @typeName(t)), - }; - if (op.block()) |prev_block| { - std.debug.assert(prev_block.equals(block.*)); - return; - } - for (0..op.numOperands()) |i| { - block.addOperationsRecursive(op.operand(i)); - } - block.appendOperation(op); - } }; diff --git a/zml/meta.zig b/zml/meta.zig index 2666ebb..171bbb2 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -239,7 +239,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { const type_info_v = @typeInfo(T); const K = switch (@typeInfo(FnParam(cb, 1))) { .Pointer => |info| info.child, - else => stdx.debug.compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found {}", .{FnParam(cb, 1)}), + else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}), }; if (type_info_v != .Pointer) { @@ -307,7 +307,7 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { } } }, - else => stdx.debug.compileError("Only single pointer and slice are supported. Received {}", .{T}), + else => {}, } } diff --git a/zml/module.zig b/zml/module.zig index 20da81d..6d204b3 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -30,6 +30,46 @@ test { std.testing.refAllDecls(@This()); } +pub const BlockKind = enum { open, hermetic }; + +const Block = union(BlockKind) { + open: mlir.Block, + hermetic: mlir.Block, + + pub fn block(self: Block) mlir.Block { + return switch (self) { + inline .open, .hermetic => |t| t, + }; + } + + fn appendTensorRecursive(self: Block, x: *const Tensor) void { + self.appendValueRecursive(x.value()); + } + + fn appendValueRecursive(self: Block, value: mlir.Value) void { + switch (value.kind()) { + .op_result => |parent_op| self.appendOperationRecursive(parent_op), + .block_argument => |arg| { + // Hermetic blocks are not allowed to use arguments from other blocks. + std.debug.assert(self == .open or self.block().eql(arg.block())); + }, + .null => @panic("InvalidMlir"), + } + } + + fn appendOperationRecursive(self: Block, op: mlir.Operation) void { + if (op.block()) |prev_block| { + // Hermetic blocks are not allowed to reference values from other blocks. + std.debug.assert(self == .open or prev_block.equals(self.block())); + return; + } + for (0..op.numOperands()) |i| { + self.appendValueRecursive(op.operand(i)); + } + self.block().appendOperation(op); + } +}; + pub const CompilationContext = struct { _platform: Platform, @@ -39,7 +79,7 @@ pub const CompilationContext = struct { _module: mlir.Module, - _blocks: std.BoundedArray(mlir.Block, 64), + _blocks: std.BoundedArray(Block, 64), _fn_cache: FnCache, _allocator: std.mem.Allocator, @@ -120,22 +160,26 @@ pub const CompilationContext = struct { return self._mlir_ctx; } - pub fn currentBlock(self: *const CompilationContext) ?mlir.Block { + pub fn currentBlock(self: *const CompilationContext) ?Block { return if (self._blocks.len > 0) self._blocks.get(self._blocks.len - 1) else null; } - pub fn openBlock(self: *CompilationContext, args: []const mlir.Type, locs: []const mlir.Location) !mlir.Block { - const block = try mlir.Block.init(args, locs); + pub fn openBlock(self: *CompilationContext, kind: BlockKind, args: []const mlir.Type, locs: []const mlir.Location) !Block { + const mlir_block = try mlir.Block.init(args, locs); + const block: Block = switch (kind) { + .open => .{ .open = mlir_block }, + .hermetic => .{ .hermetic = mlir_block }, + }; self.pushBlock(block); return block; } - pub fn closeBlock(self: *CompilationContext, block: *mlir.Block) void { + pub fn closeBlock(self: *CompilationContext, block: Block) void { const popped = self._blocks.pop(); - std.debug.assert(block.equals(popped)); + std.debug.assert(block.block().eql(popped.block())); } - fn pushBlock(self: *CompilationContext, block: mlir.Block) void { + fn pushBlock(self: *CompilationContext, block: Block) void { self._blocks.appendAssumeCapacity(block); } @@ -147,35 +191,41 @@ pub const CompilationContext = struct { /// But their shapes/tags can be safely propagated further. pub fn makeBlock( self: *CompilationContext, + kind: BlockKind, comptime S: ops.BlockSignature, func: *const S.Fn, blkctx: S.BlkCtx, args: S.Args, ) struct { mlir.Block, S.Return } { const N = S.nIn; - const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N; + const loc = self.mlirCtx().location(@src()); + const locations = .{loc} ** N; var input_types: [N]mlir.Type = undefined; fillMlirTypes(&args, self.mlirCtx(), &input_types); - var block = self.openBlock(&input_types, &locations) catch unreachable; - defer self.closeBlock(&block); + // Before creating a new block, assign all received values to previous block, + // otherwise they will be assign to this block + if (self.currentBlock()) |prev_block| { + meta.visit(Block.appendTensorRecursive, prev_block, &blkctx); + } + + const block = self.openBlock(kind, &input_types, &locations) catch unreachable; + defer self.closeBlock(block); // Here we want to create the block with the correct mlir types. // but we don't want to use the values themselves. // So we create a copy of the arguments, and replace values // by the block arguments. var blk_args = args; - std.debug.assert(assignBlockArguments(&blk_args, block, 0) == N); + std.debug.assert(assignBlockArguments(&blk_args, block.block(), 0) == N); - const loc = self.mlirCtx().location(@src()); const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args)); - var block_res_values: [S.nOut]mlir.Value = undefined; self.extractValues(&block_res, &block_res_values); const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc); - block.addOperationsRecursive(block_ret); + block.appendOperationRecursive(block_ret); - return .{ block, block_res }; + return .{ block.block(), block_res }; } /// Generate an MLIR function from a ZML function. @@ -225,9 +275,9 @@ pub const CompilationContext = struct { const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count); const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count); const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count); - var fn_body = self.openBlock(input_types, locations) catch unreachable; + var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; { - defer self.closeBlock(&fn_body); + defer self.closeBlock(fn_body); // Note: we could shrink self._buffer_to_arg once we called `func`. // But for now we are only compiling one function per CompilationContext. // So we don't need to do this since we won't reuse self._buffer_to_arg anyway. @@ -235,8 +285,8 @@ pub const CompilationContext = struct { // defer self._buffer_to_arg.shrinkRetainingCapacity(n); try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count)); - const assigned_model_count = self.mapBlockArguments(model, fn_body, 0); - const assigned_args_count = self.mapBlockArguments(args, fn_body, assigned_model_count); + const assigned_model_count = self.mapBlockArguments(model, fn_body.block(), 0); + const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), assigned_model_count); assert(assigned_model_count == model_tensor_count); assert(assigned_args_count == tensor_count); @@ -250,7 +300,7 @@ pub const CompilationContext = struct { self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); - fn_body.addOperationsRecursive(fn_ret); + fn_body.appendOperationRecursive(fn_ret); } const arg_attrs = try arena.alloc(AttributeList, tensor_count); @@ -271,7 +321,7 @@ pub const CompilationContext = struct { .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .results = fn_res_types, .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs), - .block = fn_body, + .block = fn_body.block(), .location = loc, }); @@ -921,6 +971,22 @@ fn compileInternal( context._module.op().setAttributeByName("mhlo.num_replicas", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_replicas).asAttr()); context._module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); + if (context._platform.compilation_options.xla_dump_to) |xla_dump_to| { + // Write the mlir to a file. All errors are discarded, since this is for debugging only. + if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| { + const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?; + const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value(); + if (dir.createFile(file_name, .{ .truncate = true })) |file| { + context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = true }); + log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name }); + } else |_| { + log.warn("Failed to open {s}", .{file_name}); + } + } else |_| { + log.warn("Folder not found {s}", .{xla_dump_to}); + } + } + const loaded_executable = loadOrCompilePjrtExecutable(arena, context._platform, context._module) catch |err| { log.err( "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", diff --git a/zml/ops.zig b/zml/ops.zig index 170b4ed..aa17038 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -47,9 +47,8 @@ pub fn while_( @compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); } const ctx = CompilationContext.current(); - const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); - - const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs); + const cond_block, _ = ctx.makeBlock(.open, CondS, &cond_fn, blkctx, inputs); + const body_block, const body_res = ctx.makeBlock(.open, BodyS, &body_fn, blkctx, inputs); var input_values: [BodyS.nIn]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -138,7 +137,7 @@ pub fn reduce( var init_values: [N]mlir.Value = undefined; ctx.extractValues(&inits, &init_values); - const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); + const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits }); const loc = ctx.mlirCtx().location(@src()); @@ -227,7 +226,7 @@ pub fn reduceWindow( if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); } const ctx = CompilationContext.current(); - const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); + const body_block, _ = ctx.makeBlock(.hermetic, BodyS, &body_fn, {}, .{ inits, inits }); const N = comptime @divExact(BodyS.nIn, 2); var input_values: [N]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -269,7 +268,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: const ForBlk = struct { blk_ctx: S.BlkCtx, - step_tag: @TypeOf(step_tag), // This is a Shape.Tag, but we rather keep it private + step_tag: Shape.Tag, num_steps: u32, const Self = @This(); @@ -295,11 +294,12 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: } /// Prepare buffer to store all results steps. - fn prep(self: Self, x: Tensor) Tensor { - var shape = x.shape(); - shape._dims.insert(0, self.num_steps) catch unreachable; - shape._tags.insert(0, self.step_tag) catch unreachable; - return Tensor.constant(shape, x.dtype().zero()); + fn prep(self: Self, first_step: Tensor) Tensor { + const shape = first_step.shape().insertTag(0, 1, self.step_tag); + // Reuse the first step Tensor. + // TODO: this is needed because of https://github.com/zml/zml/issues/97 + // Normally I'd rather NOT reuse first_step to streamline the stablehlo IR. + return first_step.reshape(shape).pad(0, .{ ._0 = .{ .high = self.num_steps - 1 } }); } fn wrapFirstStep(tag_: @TypeOf(step_tag), x: Tensor) Tensor { @@ -310,8 +310,9 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: } }; - // This first step won't appear in the generated MLIR, - // it's only used to infer the output shapes. + // Compute first step to infer the output shapes. + // Normally this shouldn't be reused apart from the unrolled cases, + // but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop. const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) }); log.debug("for_ first_step: {}", .{first_step}); const allocator = CompilationContext.current()._allocator; @@ -343,7 +344,8 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_: for_blk, .{ result_buffers, - Tensor.scalar(0, .i32), + // First step is already done + Tensor.scalar(1, .i32), }, )[0]; } @@ -388,6 +390,44 @@ test for_ { } } +test "nested for" { + const OuterProd = struct { + const OuterProd = @This(); + + x: Tensor, + x_row: Tensor, + + pub fn forward(x: Tensor) Tensor { + return for_(OuterProd.scanRow, x, .{x.dim(0)}); + } + + pub fn scanRow(x: Tensor, i: Tensor) Tensor { + const row = x.dynamicSlice(.{.{ .start = i, .len = 1 }}); + return for_(OuterProd.scanCol, .{ .x = x, .x_row = row }, .{x.dim(0)}); + } + + pub fn scanCol(self: OuterProd, j: Tensor) Tensor { + const col = self.x.dynamicSlice(.{.{ .start = j, .len = 1 }}); + return self.x_row.mul(col); + } + }; + + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + // 5 to prevent inlining + const x = try zml.Buffer.fromArray(platform, [5]f32{ 0, 1.0, -1.0, 2.0, -2.0 }); + const outer_prod = try zml.testing.compileAndCall(platform, OuterProd.forward, .{x}); + const expected: [5][5]f32 = .{ + .{ 0, 0, 0, 0, 0 }, + .{ 0, 1.0, -1.0, 2.0, -2.0 }, + .{ 0, -1.0, 1.0, -2.0, 2.0 }, + .{ 0, 2.0, -2.0, 4.0, -4.0 }, + .{ 0, -2.0, 2.0, -4.0, 4.0 }, + }; + try std.testing.expectEqual(expected, outer_prod.getValue(@TypeOf(expected))); +} + pub fn if_2(pred: Tensor, comptime Closure: type, blkctx: BlockSignNoArgs(@field(Closure, "then")).BlkCtx) BlockSignNoArgs(@field(Closure, "then")).Return { return if_(pred, @field(Closure, "then"), @field(Closure, "else_"), blkctx); } @@ -404,8 +444,8 @@ pub fn if_( @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); } const ctx = CompilationContext.current(); - const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); - const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); + const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); + const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res }); const loc = ctx.mlirCtx().location(@src()); @@ -466,7 +506,7 @@ pub fn sort( inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; } const ctx = CompilationContext.current(); - const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); + const block, _ = ctx.makeBlock(.hermetic, BodyS, &comp_fn, blkctx, inits); var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); diff --git a/zml/shape.zig b/zml/shape.zig index 2528364..f9643e9 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -292,33 +292,45 @@ pub const Shape = struct { stdx.debug.compileError("axes expects an int-tuple or a tuple of enum literal, got {}", .{T}); } - fn axisFromInt(self: Shape, d: isize) u3 { + fn axisFromInt(self: Shape, a: isize) u3 { const rk: i8 = self.rank(); - if (d < -rk or d > rk) { - stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, d }); + if (a < -rk or a > rk) { + stdx.debug.panic("Tensor {} doesn't have dimension: {d}", .{ self, a }); } - return if (d < 0) - @intCast(d + rk) + return if (a < 0) + @intCast(a + rk) else - @intCast(d); + @intCast(a); } - fn axisFromTagMaybe(self: Shape, d: Tag) ?u3 { - if (d == TagUnknown) { - return null; - } + fn axisFromTagMaybe(self: Shape, t: Tag) ?u3 { + if (t == TagUnknown) return null; + + if (axisFromLiteralInt(t)) |ax| return ax; + if (@inComptime()) { - for (0.., self.tags()) |tagIndex, t| { - const a: []const u8 = std.mem.span(t); - const b: []const u8 = std.mem.span(d); - if (std.mem.eql(u8, a, b)) { - return @intCast(tagIndex); + // At comptime two duplicated strings may have two different representations + const t_bytes: []const u8 = std.mem.span(t); + for (self.tags(), 0..) |self_tag, ax| { + if (std.mem.eql(u8, t_bytes, std.mem.span(self_tag))) { + return @truncate(ax); } } return null; } - if (std.mem.indexOfScalar(Tag, self.tags(), d)) |d_| { - return @intCast(d_); + + // But at runtime the comptime strings have been deduplicated and ptr match is enough. + if (std.mem.indexOfScalar(Tag, self.tags(), t)) |ax| { + return @truncate(ax); + } + return null; + } + + /// Handle .{ ._0 = x } syntax. + fn axisFromLiteralInt(t: Tag) ?u3 { + // match .{ '_', '0-9', null } + if (t[0] == '_' and t[1] >= '0' and t[1] < '8' and t[2] == 0) { + return @intCast(t[1] - '0'); } return null; } diff --git a/zml/tensor.zig b/zml/tensor.zig index e62965e..5eedd72 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -2459,7 +2459,7 @@ pub const Tensor = struct { const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; const UpdateS = ops.BlockSign(ScatterOpts.increment); - const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); + const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); const op = dialect.stablehlo.scatter( mlir_ctx,