zml.ops.makeBlock now returns the inner tensor to propagate tags. The function returns both the created mlir.Block and tensors from the supplied function, allowing shape and tag propagation without exposing mlir.Values. Updated tests to run on non‑CPU platforms.
This commit is contained in:
parent
be8aa4fa8e
commit
f675a203c2
22
zml/meta.zig
22
zml/meta.zig
@ -479,6 +479,28 @@ pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.ArrayList(s
|
||||
if (context.oom) return error.OutOfMemory;
|
||||
}
|
||||
|
||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
|
||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
|
||||
const LocalContext = struct {
|
||||
func_ctx: _CollectCtx(func),
|
||||
out: @TypeOf(out),
|
||||
idx: usize = 0,
|
||||
};
|
||||
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
||||
visit((struct {
|
||||
fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
|
||||
if (ctx.idx >= ctx.out.len) return;
|
||||
|
||||
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
|
||||
ctx.out[ctx.idx] = res;
|
||||
ctx.idx += 1;
|
||||
}
|
||||
}).cb, &context, obj);
|
||||
std.debug.assert(context.idx == context.out.len);
|
||||
}
|
||||
|
||||
fn _CollectCtx(func: anytype) type {
|
||||
const params = @typeInfo(@TypeOf(func)).Fn.params;
|
||||
if (params.len == 1) return void;
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
const asynk = @import("async");
|
||||
const builtin = @import("builtin");
|
||||
const dialect = @import("mlir/dialects");
|
||||
const protobuf = @import("io/protobuf");
|
||||
const runfiles = @import("runfiles");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
@ -142,13 +141,17 @@ pub const CompilationContext = struct {
|
||||
|
||||
/// Transform a Tensor -> Tensor function into an Mlir block.
|
||||
/// `blkctx` represents values from outside the block that can be accessed inside the block.
|
||||
/// Returns both the mlir.Block created and also the Tensors returned by `func`.
|
||||
/// The returned tensors should not be returned to the user,
|
||||
/// because their `mlir.Value` must not escape the block that created them.
|
||||
/// But their shapes/tags can be safely propagated further.
|
||||
pub fn makeBlock(
|
||||
self: *CompilationContext,
|
||||
comptime S: ops.BlockSignature,
|
||||
func: *const S.Fn,
|
||||
blkctx: S.BlkCtx,
|
||||
args: S.Args,
|
||||
) mlir.Block {
|
||||
) struct { mlir.Block, S.Return } {
|
||||
const N = S.nIn;
|
||||
const locations = .{mlir.Location.unknown(self.mlirCtx())} ** N;
|
||||
var input_types: [N]mlir.Type = undefined;
|
||||
@ -172,7 +175,7 @@ pub const CompilationContext = struct {
|
||||
const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
|
||||
block.addOperationsRecursive(block_ret);
|
||||
|
||||
return block;
|
||||
return .{ block, block_res };
|
||||
}
|
||||
|
||||
/// Generate an MLIR function from a ZML function.
|
||||
@ -502,12 +505,12 @@ pub const CompilationContext = struct {
|
||||
const loc = self.mlirCtx().location(@src());
|
||||
|
||||
const values = arena.alloc(mlir.Value, function.n_model + function.n_args) catch unreachable;
|
||||
extractValues(model, values[0..function.n_model]);
|
||||
extractValues(args, values[function.n_model..]);
|
||||
self.extractValues(&model, values[0..function.n_model]);
|
||||
self.extractValues(&args, values[function.n_model..]);
|
||||
|
||||
const op = dialect.func.call(self.mlirCtx(), function.name, values, function.res_types, loc);
|
||||
var res: stdx.meta.FnResult(func) = undefined;
|
||||
assignResults(&res, function.res_shapes, op);
|
||||
assignResults(op, &res, function.res_shapes);
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -595,24 +598,12 @@ pub const CompilationContext = struct {
|
||||
};
|
||||
}
|
||||
|
||||
/// Visit the given struct and copies the mlir.Value associated with each tensor found.
|
||||
pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void {
|
||||
const LocalContext = struct {
|
||||
self: *const CompilationContext,
|
||||
index: usize = 0,
|
||||
values: []mlir.Value,
|
||||
};
|
||||
var context = LocalContext{ .self = self, .values = values };
|
||||
meta.visit((struct {
|
||||
fn cb(ctx: *LocalContext, tensor: *const Tensor) void {
|
||||
const value, const donation = ctx.self.getValueAndDonation(tensor.*);
|
||||
_ = donation;
|
||||
|
||||
ctx.values[ctx.index] = value;
|
||||
ctx.index += 1;
|
||||
fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value {
|
||||
return self.getValueAndDonation(tensor)[0];
|
||||
}
|
||||
}).cb, &context, v);
|
||||
assert(context.index == values.len);
|
||||
|
||||
pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void {
|
||||
meta.collectBuf(getValue, self, v, values);
|
||||
}
|
||||
};
|
||||
|
||||
@ -721,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
|
||||
}
|
||||
|
||||
/// Visit the given struct and assign op results to each tensor found.
|
||||
pub fn assignResults(v: anytype, shapes: ?[]Shape, op: mlir.Operation) void {
|
||||
fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void {
|
||||
const LocalContext = struct {
|
||||
index: usize,
|
||||
op: mlir.Operation,
|
||||
|
||||
55
zml/ops.zig
55
zml/ops.zig
@ -47,8 +47,9 @@ pub fn while_(
|
||||
@compileError("cond_fn and body_fn signatures don't match ! " ++ @typeName(@TypeOf(cond_fn)) ++ " and " ++ @typeName(@TypeOf(body_fn)));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const cond_block = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
|
||||
const cond_block, _ = ctx.makeBlock(CondS, &cond_fn, blkctx, inputs);
|
||||
|
||||
const body_block, const body_res = ctx.makeBlock(BodyS, &body_fn, blkctx, inputs);
|
||||
var input_values: [BodyS.nIn]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
@ -63,9 +64,7 @@ pub fn while_(
|
||||
.location = loc,
|
||||
});
|
||||
|
||||
var res: BodyS.Args = inputs;
|
||||
module.assignResults(&res, null, op);
|
||||
return res;
|
||||
return fromMlirOperationWithTags(op, body_res);
|
||||
}
|
||||
|
||||
test "simple while" {
|
||||
@ -139,7 +138,7 @@ pub fn reduce(
|
||||
var init_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inits, &init_values);
|
||||
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
|
||||
@ -228,7 +227,7 @@ pub fn reduceWindow(
|
||||
if (BodyS.Return != @TypeOf(inputs)) @compileError("reduce body function need to have the following signature `fn (left: T, right: T) T`, got: " ++ @typeName(body_fn));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const body_block = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const body_block, _ = ctx.makeBlock(BodyS, &body_fn, {}, .{ inits, inits });
|
||||
const N = comptime @divExact(BodyS.nIn, 2);
|
||||
var input_values: [N]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
@ -255,9 +254,7 @@ pub fn reduceWindow(
|
||||
.location = loc,
|
||||
});
|
||||
|
||||
var res: BodyS.Return = inputs;
|
||||
module.assignResults(&res, null, op);
|
||||
return res;
|
||||
return fromMlirOperationWithTags(op, inputs);
|
||||
}
|
||||
|
||||
/// Runs a given function for several steps, and returns a stack of each step output.
|
||||
@ -407,10 +404,11 @@ pub fn if_(
|
||||
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const true_branch_block = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
const true_branch_block, const true_branch_res = ctx.makeBlock(TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block, const false_branch_res = ctx.makeBlock(TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
|
||||
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
|
||||
.operands = &.{pred.value()},
|
||||
.result_type_inference = true,
|
||||
@ -420,9 +418,7 @@ pub fn if_(
|
||||
.location = loc,
|
||||
});
|
||||
|
||||
var res: TrueBlockSignature.Return = undefined;
|
||||
module.assignResults(&res, null, op);
|
||||
return res;
|
||||
return fromMlirOperationWithTags(op, true_branch_res);
|
||||
}
|
||||
|
||||
test "if" {
|
||||
@ -470,7 +466,7 @@ pub fn sort(
|
||||
inits[i * 2 + 1] = Tensor{ ._shape = arg_shape, ._id = undefined, ._donation = .no_buffer };
|
||||
}
|
||||
const ctx = CompilationContext.current();
|
||||
const block = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
|
||||
const block, _ = ctx.makeBlock(BodyS, &comp_fn, blkctx, inits);
|
||||
var input_values: [@divExact(BodyS.nIn, 2)]mlir.Value = undefined;
|
||||
ctx.extractValues(&inputs, &input_values);
|
||||
|
||||
@ -618,6 +614,31 @@ pub fn staticCountTensors(comptime T: type) ?usize {
|
||||
};
|
||||
}
|
||||
|
||||
/// Create a Tensor struct similar to base, keeping base tags,
|
||||
/// but using mlir value and dims from the mlir operation.
|
||||
pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base) {
|
||||
const LocalContext = struct {
|
||||
index: usize,
|
||||
op: mlir.Operation,
|
||||
};
|
||||
var context = LocalContext{ .index = 0, .op = op };
|
||||
var res = base;
|
||||
meta.visit((struct {
|
||||
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
|
||||
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
|
||||
stdx.debug.internalAssert(new.rank() == tensor.rank(), "expected operand result to have rank {} but got {}", .{ tensor.rank(), new });
|
||||
// copy tags and sharding info over
|
||||
// some ops can change dims eg reduceWindow, so we trust mlir here.
|
||||
new._shape._tags = tensor._shape._tags;
|
||||
new._shape._sharding_info = tensor._shape._sharding_info;
|
||||
tensor.* = new;
|
||||
inner_ctx.index += 1;
|
||||
}
|
||||
}).cb, &context, &res);
|
||||
assert(context.index == op.numResults());
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Produces a custom call to `name` that takes a tensor and returns it.
|
||||
///
|
||||
/// For example, this can be used to extract tokens quickly if they run on a loop on the
|
||||
|
||||
@ -1306,7 +1306,7 @@ pub const Tensor = struct {
|
||||
var padding = [_][2]i64{.{ 0, 0 }} ** MAX_RANK;
|
||||
padding[a] = .{ self.dim(a) - 1, 0 };
|
||||
|
||||
var res = ops.reduceWindow(
|
||||
return ops.reduceWindow(
|
||||
Tensor.add,
|
||||
self,
|
||||
Tensor.scalar(0, self.dtype()),
|
||||
@ -1318,8 +1318,6 @@ pub const Tensor = struct {
|
||||
.padding = padding[0..rk],
|
||||
},
|
||||
);
|
||||
res._shape = self._shape;
|
||||
return res;
|
||||
}
|
||||
|
||||
test cumulativeSum {
|
||||
@ -1328,7 +1326,11 @@ pub const Tensor = struct {
|
||||
|
||||
const Local = struct {
|
||||
pub fn _cumsum(input: Tensor) Tensor {
|
||||
return input.withPartialTags(.{.n}).cumulativeSum(.n);
|
||||
const x = input.withPartialTags(.{.n});
|
||||
const y = x.cumulativeSum(.n);
|
||||
// Check that tags are propagated
|
||||
std.debug.assert(y.shape().eqlWithTags(x.shape()));
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
@ -2457,7 +2459,7 @@ pub const Tensor = struct {
|
||||
|
||||
const _scalar: Tensor = .{ ._shape = Shape.init(.{}, self.dtype()), ._id = undefined };
|
||||
const UpdateS = ops.BlockSign(ScatterOpts.increment);
|
||||
const update_block = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||||
const update_block, _ = ctx.makeBlock(UpdateS, opts.update_fn, opts.update_fn_ctx, .{ _scalar, _scalar });
|
||||
|
||||
const op = dialect.stablehlo.scatter(
|
||||
mlir_ctx,
|
||||
|
||||
@ -23,7 +23,8 @@ pub fn env() zml.Platform {
|
||||
|
||||
_ctx = zml.Context.init() catch unreachable;
|
||||
}
|
||||
return _ctx.?.platforms.get(.cpu).?.withCompilationOptions(_test_compile_opts);
|
||||
|
||||
return _ctx.?.autoPlatform().withCompilationOptions(_test_compile_opts);
|
||||
}
|
||||
|
||||
var _test_compile_opts: zml.CompilationOptions = .{};
|
||||
@ -108,12 +109,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
},
|
||||
inline .bool, .u4, .u8, .u16, .u32, .u64, .i4, .i8, .i16, .i32, .i64 => |t| {
|
||||
const T = t.toZigType();
|
||||
const left_data = left.items(T);
|
||||
const right_data = right.items(T);
|
||||
if (!std.mem.eql(T, left_data, right_data)) {
|
||||
log.err("left.data ({any}) != right.data ({any})", .{ left_data[0..10], right_data[0..10] });
|
||||
return error.TestUnexpectedResult;
|
||||
}
|
||||
return std.testing.expectEqualSlices(T, left.items(T), right.items(T));
|
||||
},
|
||||
.c64, .c128 => @panic("TODO: support comparison of complex"),
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user