zml.ops.makeBlock now returns the inner tensors to propagate tags. The function returns both the created mlir.Block and the tensors returned by the supplied function, so that shapes and tags can be propagated without exposing the block's mlir.Values. Updated tests to run on non-CPU platforms.

This commit is contained in:
Tarry Singh 2023-07-21 09:01:01 +00:00
parent be8aa4fa8e
commit f675a203c2
5 changed files with 85 additions and 53 deletions

View File

@ -479,6 +479,28 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(s
if (context.oom) return error.OutOfMemory; if (context.oom) return error.OutOfMemory;
} }
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and writes the result of func(X) into the `out` buffer,
/// in visit order.
///
/// `out` must have exactly one slot per X found in `obj`: this is checked by an assert
/// once the visit is done. `func` is never called for an X that has no slot in `out`.
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
    stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
    const LocalContext = struct {
        func_ctx: _CollectCtx(func),
        out: @TypeOf(out),
        // Number of X visited so far, which is also the next slot to fill in `out`.
        idx: usize = 0,
    };
    var context = LocalContext{ .func_ctx = func_ctx, .out = out };
    visit((struct {
        fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
            // Never write past the end of `out`, but keep counting every X visited,
            // so that the final assert also catches a buffer that is too small
            // instead of silently dropping the extra results.
            if (ctx.idx < ctx.out.len) {
                ctx.out[ctx.idx] = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
            }
            ctx.idx += 1;
        }
    }).cb, &context, obj);
    // `idx` is the number of X found in `obj`: it must match the buffer size exactly.
    std.debug.assert(context.idx == context.out.len);
}
fn _CollectCtx(func: anytype) type { fn _CollectCtx(func: anytype) type {
const params = @typeInfo(@TypeOf(func)).Fn.params; const params = @typeInfo(@TypeOf(func)).Fn.params;
if (params.len == 1) return void; if (params.len == 1) return void;

View File

@ -1,7 +1,6 @@
const asynk = @import("async"); const asynk = @import("async");
const builtin = @import("builtin"); const builtin = @import("builtin");
const dialect = @import("mlir/dialects"); const dialect = @import("mlir/dialects");
const protobuf = @import("io/protobuf");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
@ -142,13 +141,17 @@ pub const CompilationContext = struct {
/// Transform a Tensor -> Tensor function into an Mlir block. /// Transform a Tensor -> Tensor function into an Mlir block.
/// `blkctx` represents values from outside the block that can be accessed inside the block. /// `blkctx` represents values from outside the block that can be accessed inside the block.
/// Returns both the mlir.Block created and also the Tensors returned by `func`.
/// The returned tensors should not be returned to the user,
/// because their `mlir.Value` must not escape the block that created them.
/// But their shapes/tags can be safely propagated further.
pub fn makeBlock( pub fn makeBlock(
self: *CompilationContext, self: *CompilationContext,
comptime S: ops.BlockSignature, comptime S: ops.BlockSignature,
func: *const S.Fn, func: *const S.Fn,
blkctx: S.BlkCtx, blkctx: S.BlkCtx,
args: S.Args, args: S.Args,
) mlir.Block { ) struct { mlir.Block, S.Return } {
const N = S.nIn; const N = S.nIn;
const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N; const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N;
var input_types: [N]mlir.Type = undefined; var input_types: [N]mlir.Type = undefined;
@ -172,7 +175,7 @@ pub const CompilationContext = struct {
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc); const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
block.addOperationsRecursive(block_ret); block.addOperationsRecursive(block_ret);
return block; return .{ block, block_res };
} }
/// Generate an MLIR function from a ZML function. /// Generate an MLIR function from a ZML function.
@ -502,12 +505,12 @@ pub const CompilationContext = struct {
const loc = self.mlirCtx().location(@src()); const loc = self.mlirCtx().location(@src());
const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable; const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable;
extractValues(model, values[0..function.n_model]); self.extractValues(&model, values[0..function.n_model]);
extractValues(args, values[function.n_model..]); self.extractValues(&args, values[function.n_model..]);
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc); const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
var res: stdx.meta.FnResult(func) = undefined; var res: stdx.meta.FnResult(func) = undefined;
assignResults(&res, function.res_shapes, op); assignResults(op, &res, function.res_shapes);
return res; return res;
} }
@ -595,24 +598,12 @@ pub const CompilationContext = struct {
}; };
} }
/// Visit the given struct and copies the mlir.Value associated with each tensor found. fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value {
pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void { return self.getValueAndDonation(tensor)[0];
const LocalContext = struct { }
self: *const CompilationContext,
index: usize = 0,
values: []mlir.Value,
};
var context = LocalContext{ .self = self, .values = values };
meta.visit((struct {
fn cb(ctx: *LocalContext, tensor: *const Tensor) void {
const value, const donation = ctx.self.getValueAndDonation(tensor.*);
_ = donation;
ctx.values[ctx.index] = value; pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void {
ctx.index += 1; meta.collectBuf(getValue, self, v, values);
}
}).cb, &context, v);
assert(context.index == values.len);
} }
}; };
@ -721,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
} }
/// Visit the given struct and assign op results to each tensor found. /// Visit the given struct and assign op results to each tensor found.
pub fn assignResults(v: anytype, shapes: ?[]Shape, op: mlir.Operation) void { fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void {
const LocalContext = struct { const LocalContext = struct {
index: usize, index: usize,
op: mlir.Operation, op: mlir.Operation,

View File

@ -47,8 +47,9 @@ pub fn while_(
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn))); @compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs); const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
var input_values: [BodyS.nIn]mlir.Value = undefined; var input_values: [BodyS.nIn]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -63,9 +64,7 @@ pub fn while_(
.location = loc, .location = loc,
}); });
var res: BodyS.Args = inputs; return fromMlirOperationWithTags(op, body_res);
module.assignResults(&res, null, op);
return res;
} }
test "simple while" { test "simple while" {
@ -139,7 +138,7 @@ pub fn reduce(
var init_values: [N]mlir.Value = undefined; var init_values: [N]mlir.Value = undefined;
ctx.extractValues(&inits, &init_values); ctx.extractValues(&inits, &init_values);
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
@ -228,7 +227,7 @@ pub fn reduceWindow(
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn)); if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits }); const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
const N = comptime @divExact(BodyS.nIn, 2); const N = comptime @divExact(BodyS.nIn, 2);
var input_values: [N]mlir.Value = undefined; var input_values: [N]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -255,9 +254,7 @@ pub fn reduceWindow(
.location = loc, .location = loc,
}); });
var res: BodyS.Return = inputs; return fromMlirOperationWithTags(op, inputs);
module.assignResults(&res, null, op);
return res;
} }
/// Runs a given function for several steps, and returns a stack of each step output. /// Runs a given function for several steps, and returns a stack of each step output.
@ -407,10 +404,11 @@ pub fn if_(
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {}); const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
const loc = ctx.mlirCtx().location(@src()); stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
const loc = ctx.mlirCtx().location(@src());
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
.operands = &.{pred.value()}, .operands = &.{pred.value()},
.result_type_inference = true, .result_type_inference = true,
@ -420,9 +418,7 @@ pub fn if_(
.location = loc, .location = loc,
}); });
var res: TrueBlockSignature.Return = undefined; return fromMlirOperationWithTags(op, true_branch_res);
module.assignResults(&res, null, op);
return res;
} }
test "if" { test "if" {
@ -470,7 +466,7 @@ pub fn sort(
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer }; inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
} }
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits); const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined; var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
ctx.extractValues(&inputs, &input_values); ctx.extractValues(&inputs, &input_values);
@ -618,6 +614,31 @@ pub fn staticCountTensors(comptime T: type) ?usize {
}; };
} }
/// Returns a copy of `base` in which every Tensor found is rebound to a result of `op`.
/// The i-th Tensor visited takes its mlir value and dims from the i-th result of `op`,
/// while keeping the tags and sharding info it had in `base`.
pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base) {
    const Cursor = struct {
        op: mlir.Operation,
        // Index of the next op result to hand out.
        next: usize = 0,
    };
    const Visitor = struct {
        fn rebind(cur: *Cursor, slot: *Tensor) void {
            var fresh = Tensor.fromMlirValue(cur.op.result(cur.next));
            stdx.debug.internalAssert(fresh.rank() == slot.rank(), "expected operand result to have rank {} but got {}", .{ slot.rank(), fresh });
            // Dims come from mlir, because some ops (eg reduceWindow) can change them.
            // Only the tags and sharding info are carried over from the base tensor.
            fresh._shape._tags = slot._shape._tags;
            fresh._shape._sharding_info = slot._shape._sharding_info;
            slot.* = fresh;
            cur.next += 1;
        }
    };

    var cursor: Cursor = .{ .op = op };
    var out = base;
    meta.visit(Visitor.rebind, &cursor, &out);
    // Every result of the op must have been consumed by exactly one tensor.
    assert(cursor.next == op.numResults());
    return out;
}
/// Produces a custom call to `name` that takes a tensor and returns it. /// Produces a custom call to `name` that takes a tensor and returns it.
/// ///
/// For example, this can be used to extract tokens quickly if they run on a loop on the /// For example, this can be used to extract tokens quickly if they run on a loop on the

View File

@ -1306,7 +1306,7 @@ pub const Tensor = struct {
var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK; var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK;
padding[a] = .{ self.dim(a) - 1, 0 }; padding[a] = .{ self.dim(a) - 1, 0 };
var res = ops.reduceWindow( return ops.reduceWindow(
Tensor.add, Tensor.add,
self, self,
Tensor.scalar(0, self.dtype()), Tensor.scalar(0, self.dtype()),
@ -1318,8 +1318,6 @@ pub const Tensor = struct {
.padding = padding[0..rk], .padding = padding[0..rk],
}, },
); );
res._shape = self._shape;
return res;
} }
test cumulativeSum { test cumulativeSum {
@ -1328,7 +1326,11 @@ pub const Tensor = struct {
const Local = struct { const Local = struct {
pub fn _cumsum(input: Tensor) Tensor { pub fn _cumsum(input: Tensor) Tensor {
return input.withPartialTags(.{.n}).cumulativeSum(.n); const x = input.withPartialTags(.{.n});
const y = x.cumulativeSum(.n);
// Check that tags are propagated
std.debug.assert(y.shape().eqlWithTags(x.shape()));
return y;
} }
}; };
@ -2457,7 +2459,7 @@ pub const Tensor = struct {
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined }; const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
const UpdateS = ops.BlockSign(ScatterOpts.increment); const UpdateS = ops.BlockSign(ScatterOpts.increment);
const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar }); const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
const op = dialect.stablehlo.scatter( const op = dialect.stablehlo.scatter(
mlir_ctx, mlir_ctx,

View File

@ -23,7 +23,8 @@ pub fn env() zml.Platform {
_ctx = zml.Context.init() catch unreachable; _ctx = zml.Context.init() catch unreachable;
} }
return _ctx.?.platforms.get(.cpu).?.withCompilationOptions(_test_compile_opts);
return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts);
} }
var _test_compile_opts: zml.CompilationOptions = .{}; var _test_compile_opts: zml.CompilationOptions = .{};
@ -108,12 +109,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
}, },
inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| { inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| {
const T = t.toZigType(); const T = t.toZigType();
const left_data = left.items(T); return std.testing.expectEqualSlices(T, left.items(T), right.items(T));
const right_data = right.items(T);
if (!std.mem.eql(T, left_data, right_data)) {
log.err("left.data ({any}) != right.data ({any})", .{ left_data[0..10], right_data[0..10] });
return error.TestUnexpectedResult;
}
}, },
.c64, .c128 => @panic("TODO: support comparison of complex"), .c64, .c128 => @panic("TODO: support comparison of complex"),
} }