From f675a203c2e5c2501ff09152c08034da3bbd4d8d Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Fri, 21 Jul 2023 09:01:01 +0000 Subject: [PATCH] =?UTF-8?q?zml.ops.makeBlock=20now=20returns=20the=20inner?= =?UTF-8?q?=20tensor=20to=20propagate=20tags.=20The=20function=20returns?= =?UTF-8?q?=20both=20the=20created=20mlir.Block=20and=20tensors=20from=20t?= =?UTF-8?q?he=20supplied=20function,=20allowing=20shape=20and=20tag=20prop?= =?UTF-8?q?agation=20without=20exposing=20mlir.Values.=20Updated=20tests?= =?UTF-8?q?=20to=20run=20on=20non=E2=80=91CPU=20platforms.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zml/meta.zig | 22 ++++++++++++++++++++ zml/module.zig | 39 ++++++++++++++--------------------- zml/ops.zig | 55 ++++++++++++++++++++++++++++++++++--------------- zml/tensor.zig | 12 ++++++----- zml/testing.zig | 10 +++------ 5 files changed, 85 insertions(+), 53 deletions(-) diff --git a/zml/meta.zig b/zml/meta.zig index d527765..2666ebb 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -479,6 +479,28 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(s if (context.oom) return error.OutOfMemory; } +/// Given a func(X) -> Y or a func(Ctx, X) -> Y, +/// finds all X in the given object, and write the result of func(X) into an arraylist. +pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void { + stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)}); + const LocalContext = struct { + func_ctx: _CollectCtx(func), + out: @TypeOf(out), + idx: usize = 0, + }; + var context = LocalContext{ .func_ctx = func_ctx, .out = out }; + visit((struct { + fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void { + if (ctx.idx >= ctx.out.len) return; + + const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*); + ctx.out[ctx.idx] = res; + ctx.idx += 1; + } + }).cb, &context, obj); + std.debug.assert(context.idx == context.out.len); +} + fn _CollectCtx(func: anytype) type { const params = @typeInfo(@TypeOf(func)).Fn.params; if (params.len == 1) return void; diff --git a/zml/module.zig b/zml/module.zig index 2fcdb66..20da81d 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -1,7 +1,6 @@ const asynk = @import("async"); const builtin = @import("builtin"); const dialect = @import("mlir/dialects"); -const protobuf = @import("io/protobuf"); const runfiles = @import("runfiles"); const std = @import("std"); const stdx = @import("stdx"); @@ -142,13 +141,17 @@ pub const CompilationContext = struct { /// Transform a Tensor -> Tensor function into an Mlir block. /// `blkctx` represents values from outside the block that can be accessed inside the block. + /// Returns both the mlir.Block created and also the Tensors returned by `func`. + /// The returned tensors should not be returned to the user, + /// because their `mlir.Value` must not escape the block that created them. + /// But their shapes/tags can be safely propagated further. pub fn makeBlock( self: *CompilationContext, comptime S: ops.BlockSignature, func: *const S.Fn, blkctx: S.BlkCtx, args: S.Args, - ) mlir.Block { + ) struct { mlir.Block, S.Return } { const N = S.nIn; const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N; var input_types: [N]mlir.Type = undefined; @@ -172,7 +175,7 @@ pub const CompilationContext = struct { const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc); block.addOperationsRecursive(block_ret); - return block; + return .{ block, block_res }; } /// Generate an MLIR function from a ZML function. @@ -502,12 +505,12 @@ pub const CompilationContext = struct { const loc = self.mlirCtx().location(@src()); const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable; - extractValues(model, values[0..function.n_model]); - extractValues(args, values[function.n_model..]); + self.extractValues(&model, values[0..function.n_model]); + self.extractValues(&args, values[function.n_model..]); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); var res: stdx.meta.FnResult(func) = undefined; - assignResults(&res, function.res_shapes, op); + assignResults(op, &res, function.res_shapes); return res; } @@ -595,24 +598,12 @@ pub const CompilationContext = struct { }; } - /// Visit the given struct and copies the mlir.Value associated with each tensor found. - pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void { - const LocalContext = struct { - self: *const CompilationContext, - index: usize = 0, - values: []mlir.Value, - }; - var context = LocalContext{ .self = self, .values = values }; - meta.visit((struct { - fn cb(ctx: *LocalContext, tensor: *const Tensor) void { - const value, const donation = ctx.self.getValueAndDonation(tensor.*); - _ = donation; + fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value { + return self.getValueAndDonation(tensor)[0]; + } - ctx.values[ctx.index] = value; - ctx.index += 1; - } - }).cb, &context, v); - assert(context.index == values.len); + pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void { + meta.collectBuf(getValue, self, v, values); } }; @@ -721,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr } /// Visit the given struct and assign op results to each tensor found. -pub fn assignResults(v: anytype, shapes: ?[]Shape, op: mlir.Operation) void { +fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void { const LocalContext = struct { index: usize, op: mlir.Operation, diff --git a/zml/ops.zig b/zml/ops.zig index e171047..170b4ed 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -47,8 +47,9 @@ pub fn while_( @compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); } const ctx = CompilationContext.current(); - const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); - const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs); + const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); + + const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs); var input_values: [BodyS.nIn]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -63,9 +64,7 @@ pub fn while_( .location = loc, }); - var res: BodyS.Args = inputs; - module.assignResults(&res, null, op); - return res; + return fromMlirOperationWithTags(op, body_res); } test "simple while" { @@ -139,7 +138,7 @@ pub fn reduce( var init_values: [N]mlir.Value = undefined; ctx.extractValues(&inits, &init_values); - const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); + const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const loc = ctx.mlirCtx().location(@src()); @@ -228,7 +227,7 @@ pub fn reduceWindow( if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); } const ctx = CompilationContext.current(); - const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); + const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const N = comptime @divExact(BodyS.nIn, 2); var input_values: [N]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -255,9 +254,7 @@ pub fn reduceWindow( .location = loc, }); - var res: BodyS.Return = inputs; - module.assignResults(&res, null, op); - return res; + return fromMlirOperationWithTags(op, inputs); } /// Runs a given function for several steps, and returns a stack of each step output. @@ -407,10 +404,11 @@ pub fn if_( @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); } const ctx = CompilationContext.current(); - const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); - const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); - const loc = ctx.mlirCtx().location(@src()); + const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); + const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); + stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res }); + const loc = ctx.mlirCtx().location(@src()); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ .operands = &.{pred.value()}, .result_type_inference = true, @@ -420,9 +418,7 @@ pub fn if_( .location = loc, }); - var res: TrueBlockSignature.Return = undefined; - module.assignResults(&res, null, op); - return res; + return fromMlirOperationWithTags(op, true_branch_res); } test "if" { @@ -470,7 +466,7 @@ pub fn sort( inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; } const ctx = CompilationContext.current(); - const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); + const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; ctx.extractValues(&inputs, &input_values); @@ -618,6 +614,31 @@ pub fn staticCountTensors(comptime T: type) ?usize { }; } +/// Create a Tensor struct similar to base, keeping base tags, +/// but using mlir value and dims from the mlir operation. +pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base) { + const LocalContext = struct { + index: usize, + op: mlir.Operation, + }; + var context = LocalContext{ .index = 0, .op = op }; + var res = base; + meta.visit((struct { + fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { + var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); + stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new }); + // copy tags and sharding info over + // some ops can change dims eg reduceWindow, so we trust mlir here. + new._shape._tags = tensor._shape._tags; + new._shape._sharding_info = tensor._shape._sharding_info; + tensor.* = new; + inner_ctx.index += 1; + } + }).cb, &context, &res); + assert(context.index == op.numResults()); + return res; +} + /// Produces a custom call to `name` that takes a tensor and returns it. /// /// For example, this can be used to extract tokens quickly if they run on a loop on the diff --git a/zml/tensor.zig b/zml/tensor.zig index 0af4559..e62965e 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1306,7 +1306,7 @@ pub const Tensor = struct { var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK; padding[a] = .{ self.dim(a) - 1, 0 }; - var res = ops.reduceWindow( + return ops.reduceWindow( Tensor.add, self, Tensor.scalar(0, self.dtype()), @@ -1318,8 +1318,6 @@ pub const Tensor = struct { .padding = padding[0..rk], }, ); - res._shape = self._shape; - return res; } test cumulativeSum { @@ -1328,7 +1326,11 @@ pub const Tensor = struct { const Local = struct { pub fn _cumsum(input: Tensor) Tensor { - return input.withPartialTags(.{.n}).cumulativeSum(.n); + const x = input.withPartialTags(.{.n}); + const y = x.cumulativeSum(.n); + // Check that tags are propagated + std.debug.assert(y.shape().eqlWithTags(x.shape())); + return y; } }; @@ -2457,7 +2459,7 @@ pub const Tensor = struct { const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; const UpdateS = ops.BlockSign(ScatterOpts.increment); - const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); + const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); const op = dialect.stablehlo.scatter( mlir_ctx, diff --git a/zml/testing.zig b/zml/testing.zig index 52b1b87..ba7710a 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -23,7 +23,8 @@ pub fn env() zml.Platform { _ctx = zml.Context.init() catch unreachable; } - return _ctx.?.platforms.get(.cpu).?.withCompilationOptions(_test_compile_opts); + + return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts); } var _test_compile_opts: zml.CompilationOptions = .{}; @@ -108,12 +109,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { }, inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| { const T = t.toZigType(); - const left_data = left.items(T); - const right_data = right.items(T); - if (!std.mem.eql(T, left_data, right_data)) { - log.err("left.data ({any}) != right.data ({any})", .{ left_data[0..10], right_data[0..10] }); - return error.TestUnexpectedResult; - } + return std.testing.expectEqualSlices(T, left.items(T), right.items(T)); }, .c64, .c128 => @panic("TODO: support comparison of complex"), }