Update FnCache to copy and reuse non‑tensor fields in fixed‑size structs, preventing undefined memory in core modules.
This commit is contained in:
parent
dfe55b0d34
commit
05944b5cc9
@ -131,6 +131,12 @@ pub const Buffer = struct {
|
||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
|
||||
}
|
||||
|
||||
/// Copies the given Zig slice to the accelerator memory and
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
|
||||
return from(platform, HostBuffer.fromBytes(sh, data));
|
||||
}
|
||||
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
||||
|
||||
39
zml/exe.zig
39
zml/exe.zig
@ -74,8 +74,9 @@ pub fn compileFn(
|
||||
args: ShapeOf(stdx.meta.FnArgs(func)),
|
||||
platform: Platform,
|
||||
) !FnExe(func) {
|
||||
const name = @typeName(@TypeOf(func));
|
||||
var context = try CompilationContext.init(allocator, name, platform);
|
||||
var pretty_name = try prettyFnName(func, allocator);
|
||||
defer pretty_name.deinit(allocator);
|
||||
var context = try CompilationContext.init(allocator, pretty_name.items, platform);
|
||||
defer context.deinit();
|
||||
|
||||
return .{ .inner = try context.compileInternal(allocator, func, args) };
|
||||
@ -353,3 +354,37 @@ fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Bu
|
||||
}).cb, &local_ctx, v);
|
||||
stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
||||
}
|
||||
|
||||
fn prettyFnName(
|
||||
comptime func: anytype,
|
||||
allocator: std.mem.Allocator,
|
||||
) !std.ArrayListUnmanaged(u8) {
|
||||
const full_noisy_name = @typeName(@TypeOf(func));
|
||||
const og_len = full_noisy_name.len;
|
||||
const buffer = try allocator.alloc(u8, og_len);
|
||||
errdefer comptime unreachable; // No errors below this point.
|
||||
var out: []u8 = buffer;
|
||||
|
||||
{
|
||||
const verbose = "tensor.Tensor";
|
||||
const compact = "Tensor";
|
||||
const num_replacements = std.mem.replace(u8, full_noisy_name, verbose, compact, buffer);
|
||||
out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
|
||||
}
|
||||
|
||||
{
|
||||
const verbose = "tensor.Tensor.";
|
||||
const compact = "";
|
||||
const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
|
||||
out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
|
||||
}
|
||||
|
||||
{
|
||||
const verbose = "shape.Shape";
|
||||
const compact = "Shape";
|
||||
const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer);
|
||||
out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len;
|
||||
}
|
||||
|
||||
return .{ .items = out, .capacity = og_len };
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ test {
|
||||
}
|
||||
|
||||
const ShapeError = error{ DimMismatch, NotFound };
|
||||
const NOT_SET: i64 = 0;
|
||||
const NOT_SET: i64 = -2;
|
||||
const DIM_MISMATCH: i64 = -1;
|
||||
|
||||
/// Collect the given dimensions inside a struct containing tagged tensors.
|
||||
@ -26,10 +26,12 @@ pub fn collectDims(
|
||||
res: ShapeStruct(dims),
|
||||
mode: @TypeOf(mode),
|
||||
};
|
||||
|
||||
var context = LocalContext{
|
||||
.res = std.mem.zeroes(ShapeStruct(dims)),
|
||||
.res = undefined,
|
||||
.mode = mode,
|
||||
};
|
||||
@memset(std.mem.bytesAsSlice(i64, std.mem.asBytes(&context.res)), NOT_SET);
|
||||
|
||||
meta.visit((struct {
|
||||
fn cb(ctx: *LocalContext, shape: *const Shape) void {
|
||||
@ -96,7 +98,7 @@ test collectDims {
|
||||
collectDims(.{ .b, .d }, &model, .strict),
|
||||
error.DimMismatch,
|
||||
);
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = -1, .d = 5 });
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = DIM_MISMATCH, .d = 5 });
|
||||
}
|
||||
{
|
||||
var model: Model = .{
|
||||
@ -105,7 +107,7 @@ test collectDims {
|
||||
.bias = Shape.init(.{5}, .f32).withTags(.{.d}),
|
||||
};
|
||||
try std.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .strict), error.NotFound);
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = 0 });
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = NOT_SET });
|
||||
}
|
||||
{
|
||||
var model: Model = .{
|
||||
@ -114,7 +116,7 @@ test collectDims {
|
||||
.bias = Shape.init(.{7}, .f32).withTags(.{.d}),
|
||||
};
|
||||
try std.testing.expectEqual(collectDims(.{ .b, .d }, &model, .strict), error.DimMismatch);
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = -1 });
|
||||
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = DIM_MISMATCH });
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
164
zml/module.zig
164
zml/module.zig
@ -71,6 +71,7 @@ const Block = union(BlockKind) {
|
||||
pub const MlirFn = struct {
|
||||
name: []const u8,
|
||||
num_args: u32,
|
||||
res_tensors: *const anyopaque,
|
||||
res_types: []mlir.Type,
|
||||
res_shapes: []Shape,
|
||||
res_donations: []Tensor._Donation,
|
||||
@ -114,12 +115,15 @@ pub const CompilationContext = struct {
|
||||
var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
|
||||
mlir_ctx.loadAllAvailableDialects();
|
||||
|
||||
// Too long module names create too long file paths.
|
||||
const name = full_name[0..@min(128, full_name.len)];
|
||||
// Too long module names create too long file paths and files failed to create.
|
||||
// * leave half of the space for parent folder and XLA generated filename,
|
||||
// * leave 17 bytes for the module hash (16 + 1 for underscore).
|
||||
const max_name_len = @divFloor(std.fs.max_path_bytes, 2) - 17;
|
||||
const name = full_name[0..@min(max_name_len, full_name.len)];
|
||||
|
||||
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
|
||||
const module = mlir.Module.init(loc);
|
||||
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?);
|
||||
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?);
|
||||
|
||||
var canonicalizer = try mlir.PassManager.init(mlir_ctx);
|
||||
{
|
||||
@ -130,12 +134,12 @@ pub const CompilationContext = struct {
|
||||
}
|
||||
|
||||
var arena = std.heap.ArenaAllocator.init(allocator_);
|
||||
_ = try arena.allocator().alloc(u8, std.mem.page_size);
|
||||
_ = try arena.allocator().alloc(u8, 4096);
|
||||
_ = arena.reset(.retain_capacity);
|
||||
|
||||
return .{
|
||||
._platform = platform,
|
||||
._name = name,
|
||||
._name = try arena.allocator().dupe(u8, name),
|
||||
._mlir_ctx = mlir_ctx,
|
||||
._mlir_registry = mlir_registry,
|
||||
._mlir_canonicalizer = canonicalizer,
|
||||
@ -394,7 +398,9 @@ pub const CompilationContext = struct {
|
||||
// But it forces user to have simpler function.
|
||||
const ReturnT = stdx.meta.FnResult(func);
|
||||
const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
|
||||
// Those are returned to caller so we don't put them in the arena.
|
||||
|
||||
// Those are returned to caller so we don't put them in the arena, but in the module allocator.
|
||||
const fn_res = try res_allocator.create(ReturnT);
|
||||
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
|
||||
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
|
||||
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
|
||||
@ -406,14 +412,14 @@ pub const CompilationContext = struct {
|
||||
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
|
||||
std.debug.assert(assigned_args_count == tensor_count);
|
||||
|
||||
const fn_res = forward: {
|
||||
fn_res.* = forward: {
|
||||
self.activate();
|
||||
defer self.deactivate();
|
||||
break :forward @call(.auto, func, args.*);
|
||||
};
|
||||
|
||||
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
||||
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||
|
||||
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||
fn_body.appendOperationRecursive(fn_ret);
|
||||
@ -457,6 +463,7 @@ pub const CompilationContext = struct {
|
||||
.mlir_fn = mlir_fn,
|
||||
.name = opts.name,
|
||||
.num_args = @intCast(tensor_count),
|
||||
.res_tensors = fn_res,
|
||||
.res_types = fn_res_types,
|
||||
.res_shapes = fn_res_shapes,
|
||||
.res_donations = fn_res_donations,
|
||||
@ -639,46 +646,51 @@ pub const CompilationContext = struct {
|
||||
func_name: [:0]const u8,
|
||||
comptime func: anytype,
|
||||
args: stdx.meta.FnArgs(func),
|
||||
) stdx.meta.FnResult(func) {
|
||||
) error{OutOfMemory}!stdx.meta.FnResult(func) {
|
||||
var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
|
||||
defer arena_state.deinit();
|
||||
// This arena is used for allocations which won't outlive the function call,
|
||||
// but the function creation uses `self.allocator()` which we'll live for the duration of the compilation.
|
||||
const arena = arena_state.allocator();
|
||||
|
||||
// first, do the "compile" and check the bytecode
|
||||
// the result of this will also have the correct tags of the result shapes
|
||||
const args_hash = hashArgs(args);
|
||||
const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
|
||||
const key: FnKey = .{ .fn_ptr = &func, .input_hash = args_hash };
|
||||
|
||||
const function = self._fn_cache.getEntry(key) orelse b: {
|
||||
const function = self._fn_cache.get(key) orelse b: {
|
||||
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
|
||||
arena.dupeZ(u8, func_name) catch unreachable
|
||||
try self.allocator().dupeZ(u8, func_name)
|
||||
else
|
||||
std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;
|
||||
try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash });
|
||||
|
||||
var arg_id: u16 = 0;
|
||||
var tensor_args: @TypeOf(args) = args;
|
||||
meta.mapAlloc(struct {
|
||||
try meta.mapAlloc(struct {
|
||||
fn cb(arg_id_: *u16, x: Tensor) Tensor {
|
||||
const a = arg_id_.*;
|
||||
arg_id_.* += 1;
|
||||
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
|
||||
}
|
||||
}.cb, arena, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
|
||||
}.cb, arena, &arg_id, args, &tensor_args);
|
||||
|
||||
const f = self.emitMlir(func, &tensor_args, .{
|
||||
.name = full_name,
|
||||
}) catch @panic("OOM");
|
||||
const f = try self.emitMlir(
|
||||
func,
|
||||
&tensor_args,
|
||||
.{ .name = full_name },
|
||||
);
|
||||
self._module.getBody().appendOperation(f.mlir_fn);
|
||||
|
||||
break :b self._fn_cache.addEntry(self.allocator(), key, f) catch unreachable;
|
||||
try self._fn_cache.putNoClobber(self.allocator(), key, f);
|
||||
break :b f;
|
||||
};
|
||||
|
||||
const loc = self.mlirCtx().location(@src());
|
||||
|
||||
const values = arena.alloc(mlir.Value, function.num_args) catch unreachable;
|
||||
const values = try arena.alloc(mlir.Value, function.num_args);
|
||||
self.extractValues(&args, values);
|
||||
|
||||
const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable;
|
||||
const donations = try arena.alloc(Tensor._Donation, function.num_args);
|
||||
meta.collectBuf(struct {
|
||||
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
|
||||
return ctx.getValueAndDonation(x)[1];
|
||||
@ -689,9 +701,7 @@ pub const CompilationContext = struct {
|
||||
// Create the result tensor object by combining the operand results,
|
||||
// as well as the registered shapes and donations.
|
||||
// Note: this assume res can be stack-allocated.
|
||||
// Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us.
|
||||
// But this is blocked on https://github.com/zml/zml/issues/97
|
||||
var res: stdx.meta.FnResult(func) = undefined;
|
||||
var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*;
|
||||
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
|
||||
var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
|
||||
meta.visit((struct {
|
||||
@ -911,6 +921,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
|
||||
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
|
||||
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
|
||||
// setFlags(&options, "xla_gpu_enable_address_computation_fusion", true);
|
||||
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
||||
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||
@ -1009,45 +1020,8 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
|
||||
return .{ .hasher = hasher };
|
||||
}
|
||||
|
||||
pub const FnCache = struct {
|
||||
pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 };
|
||||
|
||||
cache: std.AutoHashMapUnmanaged(Key, MlirFn) = .{},
|
||||
|
||||
pub fn deinit(self: FnCache, allocator: std.mem.Allocator) void {
|
||||
self.cache.deinit(allocator);
|
||||
}
|
||||
|
||||
pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn {
|
||||
return self.cache.get(key);
|
||||
}
|
||||
|
||||
pub fn addEntry(self: *FnCache, allocator: std.mem.Allocator, key: Key, value: MlirFn) !MlirFn {
|
||||
const res_types_copy = try allocator.dupe(mlir.Type, value.res_types);
|
||||
errdefer allocator.free(res_types_copy);
|
||||
|
||||
const res_shapes_copy = try allocator.dupe(Shape, value.res_shapes);
|
||||
errdefer allocator.free(res_shapes_copy);
|
||||
|
||||
const res_donations_copy = try allocator.dupe(Tensor._Donation, value.res_donations);
|
||||
errdefer allocator.free(res_donations_copy);
|
||||
|
||||
const name_copy = try allocator.dupeZ(u8, value.name);
|
||||
errdefer allocator.free(name_copy);
|
||||
|
||||
const owned_value: MlirFn = .{
|
||||
.name = name_copy,
|
||||
.mlir_fn = value.mlir_fn,
|
||||
.num_args = value.num_args,
|
||||
.res_types = res_types_copy,
|
||||
.res_shapes = res_shapes_copy,
|
||||
.res_donations = res_donations_copy,
|
||||
};
|
||||
|
||||
try self.cache.putNoClobber(allocator, key, owned_value);
|
||||
return owned_value;
|
||||
}
|
||||
};
|
||||
pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
|
||||
pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
|
||||
|
||||
test FnCache {
|
||||
const zml = @import("zml.zig");
|
||||
@ -1109,6 +1083,70 @@ test FnCache {
|
||||
try zml.testing.expectClose(expected, res, 1e-4);
|
||||
}
|
||||
|
||||
test "FnCache with mixed integer/tensor" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const Layer = struct {
|
||||
const Layer_ = @This();
|
||||
var num_call: u32 = 0;
|
||||
|
||||
w: Tensor,
|
||||
|
||||
pub fn _fwd(self: Layer_, x: Tensor) struct { Tensor, usize } {
|
||||
const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
|
||||
// Note: this is for testing only, it's a bad idea to mutate global state
|
||||
// from a forward function because it can mess with caching.
|
||||
num_call += 1;
|
||||
return .{ wx.addConstant(num_call), num_call };
|
||||
}
|
||||
};
|
||||
|
||||
const NN = struct {
|
||||
const NN_ = @This();
|
||||
layers: [3]Layer,
|
||||
|
||||
pub fn _fwd(self: NN_, x0: Tensor) Tensor {
|
||||
var x = x0;
|
||||
var y: usize = 0;
|
||||
x, y = ops.call(self.layers[0], ._fwd, .{x});
|
||||
std.debug.assert(Layer.num_call == 1);
|
||||
std.debug.assert(y == 1);
|
||||
// Here we call a second time but since first two layers have the same shape,
|
||||
// We hit the function cache, and "num_call" is not incremented.
|
||||
x, y = ops.call(self.layers[1], ._fwd, .{x});
|
||||
std.debug.assert(Layer.num_call == 1);
|
||||
std.debug.assert(y == 1);
|
||||
x, y = ops.call(self.layers[2], ._fwd, .{x});
|
||||
std.debug.assert(Layer.num_call == 2);
|
||||
std.debug.assert(y == 2);
|
||||
return x;
|
||||
}
|
||||
|
||||
pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
|
||||
var x = x0;
|
||||
for (self.layers, &[_]u32{ 1, 1, 2 }) |layer, bias| {
|
||||
const wx = layer.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
|
||||
x = wx.addConstant(bias);
|
||||
}
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
|
||||
const nn: zml.Bufferized(NN) = .{
|
||||
.layers = .{
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
|
||||
// third layer has different shape
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
|
||||
},
|
||||
};
|
||||
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
|
||||
const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
|
||||
try zml.testing.expectClose(expected, res, 1e-4);
|
||||
}
|
||||
|
||||
pub fn hashArgs(mod: anytype) u64 {
|
||||
var hasher = std.hash.Wyhash.init(0);
|
||||
hash(&hasher, mod, .DeepRecursive);
|
||||
|
||||
32
zml/ops.zig
32
zml/ops.zig
@ -34,7 +34,7 @@ pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), arg
|
||||
const ctx = CompilationContext.current();
|
||||
const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func);
|
||||
const actual_fn = @field(@TypeOf(self), @tagName(func));
|
||||
return ctx.callFunc(name, actual_fn, .{self} ++ args);
|
||||
return ctx.callFunc(name, actual_fn, .{self} ++ args) catch @panic("OOM");
|
||||
}
|
||||
|
||||
pub fn while_(
|
||||
@ -445,14 +445,36 @@ pub fn if_(
|
||||
if (TrueBlockSignature.Return != FalseBlockSignature.Return) {
|
||||
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
|
||||
}
|
||||
|
||||
stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred});
|
||||
|
||||
const ctx = CompilationContext.current();
|
||||
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||
stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
|
||||
|
||||
var true_shapes = std.ArrayList(Shape).init(ctx.allocator());
|
||||
defer true_shapes.deinit();
|
||||
var false_shapes = std.ArrayList(Shape).init(ctx.allocator());
|
||||
defer false_shapes.deinit();
|
||||
|
||||
var failed_to_collect = false;
|
||||
meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch {
|
||||
failed_to_collect = true;
|
||||
};
|
||||
meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch {
|
||||
failed_to_collect = true;
|
||||
};
|
||||
if (!failed_to_collect) {
|
||||
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
|
||||
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
|
||||
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
|
||||
}
|
||||
}
|
||||
|
||||
const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0);
|
||||
const loc = ctx.mlirCtx().location(@src());
|
||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
|
||||
.operands = &.{pred.value()},
|
||||
.operands = &.{scalar_pred.value()},
|
||||
.result_type_inference = true,
|
||||
.blocks = &.{ true_branch_block, false_branch_block },
|
||||
// We can't verify right away, cause the weights captured by the if haven't been added yet.
|
||||
@ -791,7 +813,7 @@ pub fn scatter(
|
||||
|
||||
// validate coord axes: all coord_axes should exist inside self
|
||||
for (indices_axes.constSlice()) |t| {
|
||||
stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={any}", .{ self, indices_axes });
|
||||
stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() });
|
||||
}
|
||||
|
||||
// Handle scalar indices by broadcasting them to the indices with the highest rank.
|
||||
@ -905,7 +927,7 @@ fn scatterConfig(
|
||||
scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
|
||||
}
|
||||
for (indices.tags()) |t| {
|
||||
stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got updates={} and indices={s}", .{ update, indices_axes.constSlice() });
|
||||
stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices });
|
||||
updates_transpose.appendAssumeCapacity(update.axis(t));
|
||||
}
|
||||
|
||||
|
||||
@ -1743,7 +1743,7 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a Tensor containing evenly spaced values within a given interval.
|
||||
pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
|
||||
stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||
stdx.debug.assert(args.start <= args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||
|
||||
const ctx = CompilationContext.current();
|
||||
|
||||
@ -17,7 +17,6 @@ pub fn env() zml.Platform {
|
||||
_platform = ctx.autoPlatform(.{}).withCompilationOptions(.{
|
||||
.xla_dump_to = "/tmp/zml/tests/",
|
||||
.sharding_enabled = true,
|
||||
.xla_dump_hlo_pass_re = ".*",
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user