Update FnCache to copy and reuse non‑tensor fields in fixed‑size structs, preventing undefined memory in core modules.

This commit is contained in:
Tarry Singh 2024-05-15 17:54:52 +00:00
parent dfe55b0d34
commit 05944b5cc9
7 changed files with 179 additions and 77 deletions

View File

@ -131,6 +131,12 @@ pub const Buffer = struct {
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
} }
/// Copies raw bytes from host memory to the accelerator and
/// returns a Buffer with the provided shape.
pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer {
    const host_view = HostBuffer.fromBytes(sh, data);
    return from(platform, host_view);
}
/// Copies the given Zig array to the accelerator memory and /// Copies the given Zig array to the accelerator memory and
/// return a Buffer using the array shape. /// return a Buffer using the array shape.
pub fn fromArray(platform: Platform, arr: anytype) !Buffer { pub fn fromArray(platform: Platform, arr: anytype) !Buffer {

View File

@ -74,8 +74,9 @@ pub fn compileFn(
args: ShapeOf(stdx.meta.FnArgs(func)), args: ShapeOf(stdx.meta.FnArgs(func)),
platform: Platform, platform: Platform,
) !FnExe(func) { ) !FnExe(func) {
const name = @typeName(@TypeOf(func)); var pretty_name = try prettyFnName(func, allocator);
var context = try CompilationContext.init(allocator, name, platform); defer pretty_name.deinit(allocator);
var context = try CompilationContext.init(allocator, pretty_name.items, platform);
defer context.deinit(); defer context.deinit();
return .{ .inner = try context.compileInternal(allocator, func, args) }; return .{ .inner = try context.compileInternal(allocator, func, args) };
@ -353,3 +354,37 @@ fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Bu
}).cb, &local_ctx, v); }).cb, &local_ctx, v);
stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
} }
/// Builds a shortened, human-readable name for `func` by compacting the
/// noisy fully-qualified type names produced by `@typeName`.
/// Caller owns the returned list and must `deinit` it with `allocator`.
fn prettyFnName(
    comptime func: anytype,
    allocator: std.mem.Allocator,
) !std.ArrayListUnmanaged(u8) {
    const full_noisy_name = @typeName(@TypeOf(func));
    const og_len = full_noisy_name.len;
    const buffer = try allocator.alloc(u8, og_len);
    errdefer comptime unreachable; // No errors below this point.

    // Replacement order matters: the fully-qualified prefix
    // "tensor.Tensor." must be stripped *before* the shorter pattern
    // "tensor.Tensor" runs, otherwise the shorter pattern rewrites it to
    // "Tensor." first and the prefix strip never matches (dead pass).
    const replacements = [_]struct { verbose: []const u8, compact: []const u8 }{
        .{ .verbose = "tensor.Tensor.", .compact = "" },
        .{ .verbose = "tensor.Tensor", .compact = "Tensor" },
        .{ .verbose = "shape.Shape", .compact = "Shape" },
    };

    // Seed the scratch buffer, then shrink it in place. Every `compact` is
    // shorter than (or equal to) its `verbose`, so the forward copy done by
    // `std.mem.replace` is safe even when input and output alias.
    @memcpy(buffer, full_noisy_name);
    var out: []u8 = buffer;
    for (replacements) |r| {
        const num_replacements = std.mem.replace(u8, out, r.verbose, r.compact, buffer);
        out.len -= num_replacements * (r.verbose.len - r.compact.len);
    }

    // `out` points into `buffer`, whose full size stays `og_len`.
    return .{ .items = out, .capacity = og_len };
}

View File

@ -13,7 +13,7 @@ test {
} }
const ShapeError = error{ DimMismatch, NotFound }; const ShapeError = error{ DimMismatch, NotFound };
const NOT_SET: i64 = 0; const NOT_SET: i64 = -2;
const DIM_MISMATCH: i64 = -1; const DIM_MISMATCH: i64 = -1;
/// Collect the given dimensions inside a struct containing tagged tensors. /// Collect the given dimensions inside a struct containing tagged tensors.
@ -26,10 +26,12 @@ pub fn collectDims(
res: ShapeStruct(dims), res: ShapeStruct(dims),
mode: @TypeOf(mode), mode: @TypeOf(mode),
}; };
var context = LocalContext{ var context = LocalContext{
.res = std.mem.zeroes(ShapeStruct(dims)), .res = undefined,
.mode = mode, .mode = mode,
}; };
@memset(std.mem.bytesAsSlice(i64, std.mem.asBytes(&context.res)), NOT_SET);
meta.visit((struct { meta.visit((struct {
fn cb(ctx: *LocalContext, shape: *const Shape) void { fn cb(ctx: *LocalContext, shape: *const Shape) void {
@ -96,7 +98,7 @@ test collectDims {
collectDims(.{ .b, .d }, &model, .strict), collectDims(.{ .b, .d }, &model, .strict),
error.DimMismatch, error.DimMismatch,
); );
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = -1, .d = 5 }); try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = DIM_MISMATCH, .d = 5 });
} }
{ {
var model: Model = .{ var model: Model = .{
@ -105,7 +107,7 @@ test collectDims {
.bias = Shape.init(.{5}, .f32).withTags(.{.d}), .bias = Shape.init(.{5}, .f32).withTags(.{.d}),
}; };
try std.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .strict), error.NotFound); try std.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .strict), error.NotFound);
try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = 0 }); try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = NOT_SET });
} }
{ {
var model: Model = .{ var model: Model = .{
@ -114,7 +116,7 @@ test collectDims {
.bias = Shape.init(.{7}, .f32).withTags(.{.d}), .bias = Shape.init(.{7}, .f32).withTags(.{.d}),
}; };
try std.testing.expectEqual(collectDims(.{ .b, .d }, &model, .strict), error.DimMismatch); try std.testing.expectEqual(collectDims(.{ .b, .d }, &model, .strict), error.DimMismatch);
try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = -1 }); try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = DIM_MISMATCH });
} }
} }

View File

@ -71,6 +71,7 @@ const Block = union(BlockKind) {
pub const MlirFn = struct { pub const MlirFn = struct {
name: []const u8, name: []const u8,
num_args: u32, num_args: u32,
res_tensors: *const anyopaque,
res_types: []mlir.Type, res_types: []mlir.Type,
res_shapes: []Shape, res_shapes: []Shape,
res_donations: []Tensor._Donation, res_donations: []Tensor._Donation,
@ -114,12 +115,15 @@ pub const CompilationContext = struct {
var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable; var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
mlir_ctx.loadAllAvailableDialects(); mlir_ctx.loadAllAvailableDialects();
// Too long module names create too long file paths. // Too long module names create too long file paths and files failed to create.
const name = full_name[0..@min(128, full_name.len)]; // * leave half of the space for parent folder and XLA generated filename,
// * leave 17 bytes for the module hash (16 + 1 for underscore).
const max_name_len = @divFloor(std.fs.max_path_bytes, 2) - 17;
const name = full_name[0..@min(max_name_len, full_name.len)];
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
const module = mlir.Module.init(loc); const module = mlir.Module.init(loc);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?); module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?);
var canonicalizer = try mlir.PassManager.init(mlir_ctx); var canonicalizer = try mlir.PassManager.init(mlir_ctx);
{ {
@ -130,12 +134,12 @@ pub const CompilationContext = struct {
} }
var arena = std.heap.ArenaAllocator.init(allocator_); var arena = std.heap.ArenaAllocator.init(allocator_);
_ = try arena.allocator().alloc(u8, std.mem.page_size); _ = try arena.allocator().alloc(u8, 4096);
_ = arena.reset(.retain_capacity); _ = arena.reset(.retain_capacity);
return .{ return .{
._platform = platform, ._platform = platform,
._name = name, ._name = try arena.allocator().dupe(u8, name),
._mlir_ctx = mlir_ctx, ._mlir_ctx = mlir_ctx,
._mlir_registry = mlir_registry, ._mlir_registry = mlir_registry,
._mlir_canonicalizer = canonicalizer, ._mlir_canonicalizer = canonicalizer,
@ -394,7 +398,9 @@ pub const CompilationContext = struct {
// But it forces user to have simpler function. // But it forces user to have simpler function.
const ReturnT = stdx.meta.FnResult(func); const ReturnT = stdx.meta.FnResult(func);
const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
// Those are returned to caller so we don't put them in the arena.
// Those are returned to caller so we don't put them in the arena, but in the module allocator.
const fn_res = try res_allocator.create(ReturnT);
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count); const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count); const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count); const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
@ -406,14 +412,14 @@ pub const CompilationContext = struct {
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
std.debug.assert(assigned_args_count == tensor_count); std.debug.assert(assigned_args_count == tensor_count);
const fn_res = forward: { fn_res.* = forward: {
self.activate(); self.activate();
defer self.deactivate(); defer self.deactivate();
break :forward @call(.auto, func, args.*); break :forward @call(.auto, func, args.*);
}; };
var fn_res_values: [out_tensor_count]mlir.Value = undefined; var fn_res_values: [out_tensor_count]mlir.Value = undefined;
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body.appendOperationRecursive(fn_ret); fn_body.appendOperationRecursive(fn_ret);
@ -457,6 +463,7 @@ pub const CompilationContext = struct {
.mlir_fn = mlir_fn, .mlir_fn = mlir_fn,
.name = opts.name, .name = opts.name,
.num_args = @intCast(tensor_count), .num_args = @intCast(tensor_count),
.res_tensors = fn_res,
.res_types = fn_res_types, .res_types = fn_res_types,
.res_shapes = fn_res_shapes, .res_shapes = fn_res_shapes,
.res_donations = fn_res_donations, .res_donations = fn_res_donations,
@ -639,46 +646,51 @@ pub const CompilationContext = struct {
func_name: [:0]const u8, func_name: [:0]const u8,
comptime func: anytype, comptime func: anytype,
args: stdx.meta.FnArgs(func), args: stdx.meta.FnArgs(func),
) stdx.meta.FnResult(func) { ) error{OutOfMemory}!stdx.meta.FnResult(func) {
var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator); var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
defer arena_state.deinit(); defer arena_state.deinit();
// This arena is used for allocations which won't outlive the function call,
// but the function creation uses `self.allocator()` which we'll live for the duration of the compilation.
const arena = arena_state.allocator(); const arena = arena_state.allocator();
// first, do the "compile" and check the bytecode // first, do the "compile" and check the bytecode
// the result of this will also have the correct tags of the result shapes // the result of this will also have the correct tags of the result shapes
const args_hash = hashArgs(args); const args_hash = hashArgs(args);
const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash }; const key: FnKey = .{ .fn_ptr = &func, .input_hash = args_hash };
const function = self._fn_cache.getEntry(key) orelse b: { const function = self._fn_cache.get(key) orelse b: {
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
arena.dupeZ(u8, func_name) catch unreachable try self.allocator().dupeZ(u8, func_name)
else else
std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable; try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash });
var arg_id: u16 = 0; var arg_id: u16 = 0;
var tensor_args: @TypeOf(args) = args; var tensor_args: @TypeOf(args) = args;
meta.mapAlloc(struct { try meta.mapAlloc(struct {
fn cb(arg_id_: *u16, x: Tensor) Tensor { fn cb(arg_id_: *u16, x: Tensor) Tensor {
const a = arg_id_.*; const a = arg_id_.*;
arg_id_.* += 1; arg_id_.* += 1;
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } }; return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
} }
}.cb, arena, &arg_id, args, &tensor_args) catch @panic("OutOfMemory"); }.cb, arena, &arg_id, args, &tensor_args);
const f = self.emitMlir(func, &tensor_args, .{ const f = try self.emitMlir(
.name = full_name, func,
}) catch @panic("OOM"); &tensor_args,
.{ .name = full_name },
);
self._module.getBody().appendOperation(f.mlir_fn); self._module.getBody().appendOperation(f.mlir_fn);
break :b self._fn_cache.addEntry(self.allocator(), key, f) catch unreachable; try self._fn_cache.putNoClobber(self.allocator(), key, f);
break :b f;
}; };
const loc = self.mlirCtx().location(@src()); const loc = self.mlirCtx().location(@src());
const values = arena.alloc(mlir.Value, function.num_args) catch unreachable; const values = try arena.alloc(mlir.Value, function.num_args);
self.extractValues(&args, values); self.extractValues(&args, values);
const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable; const donations = try arena.alloc(Tensor._Donation, function.num_args);
meta.collectBuf(struct { meta.collectBuf(struct {
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation { pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
return ctx.getValueAndDonation(x)[1]; return ctx.getValueAndDonation(x)[1];
@ -689,9 +701,7 @@ pub const CompilationContext = struct {
// Create the result tensor object by combining the operand results, // Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations. // as well as the registered shapes and donations.
// Note: this assume res can be stack-allocated. // Note: this assume res can be stack-allocated.
// Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us. var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*;
// But this is blocked on https://github.com/zml/zml/issues/97
var res: stdx.meta.FnResult(func) = undefined;
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation }; const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations }; var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct { meta.visit((struct {
@ -911,6 +921,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true); // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
// setFlag(&options, "xla_gpu_enable_custom_fusions", true); // setFlag(&options, "xla_gpu_enable_custom_fusions", true);
// setFlags(&options, "xla_gpu_enable_address_computation_fusion", true);
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true);
@ -1009,45 +1020,8 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
return .{ .hasher = hasher }; return .{ .hasher = hasher };
} }
pub const FnCache = struct { pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 }; pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
cache: std.AutoHashMapUnmanaged(Key, MlirFn) = .{},
pub fn deinit(self: FnCache, allocator: std.mem.Allocator) void {
self.cache.deinit(allocator);
}
pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn {
return self.cache.get(key);
}
pub fn addEntry(self: *FnCache, allocator: std.mem.Allocator, key: Key, value: MlirFn) !MlirFn {
const res_types_copy = try allocator.dupe(mlir.Type, value.res_types);
errdefer allocator.free(res_types_copy);
const res_shapes_copy = try allocator.dupe(Shape, value.res_shapes);
errdefer allocator.free(res_shapes_copy);
const res_donations_copy = try allocator.dupe(Tensor._Donation, value.res_donations);
errdefer allocator.free(res_donations_copy);
const name_copy = try allocator.dupeZ(u8, value.name);
errdefer allocator.free(name_copy);
const owned_value: MlirFn = .{
.name = name_copy,
.mlir_fn = value.mlir_fn,
.num_args = value.num_args,
.res_types = res_types_copy,
.res_shapes = res_shapes_copy,
.res_donations = res_donations_copy,
};
try self.cache.putNoClobber(allocator, key, owned_value);
return owned_value;
}
};
test FnCache { test FnCache {
const zml = @import("zml.zig"); const zml = @import("zml.zig");
@ -1109,6 +1083,70 @@ test FnCache {
try zml.testing.expectClose(expected, res, 1e-4); try zml.testing.expectClose(expected, res, 1e-4);
} }
// Verifies that the FnCache reuses non-tensor result fields: a cached entry
// must replay the *comptime/host* part of the return value (the `usize` here)
// exactly as produced by the original trace, not recompute it.
// NOTE(review): relies on a mutable global (`num_call`) to detect cache hits —
// intentionally so, as the inline comments below explain.
test "FnCache with mixed integer/tensor" {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const Layer = struct {
        const Layer_ = @This();
        // Global call counter used to observe whether _fwd was re-traced
        // (incremented) or served from the FnCache (unchanged).
        var num_call: u32 = 0;
        w: Tensor,

        pub fn _fwd(self: Layer_, x: Tensor) struct { Tensor, usize } {
            const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
            // Note: this is for testing only, it's a bad idea to mutate global state
            // from a forward function because it can mess with caching.
            num_call += 1;
            return .{ wx.addConstant(num_call), num_call };
        }
    };

    const NN = struct {
        const NN_ = @This();
        layers: [3]Layer,

        pub fn _fwd(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            var y: usize = 0;
            // First call: no cache entry yet, _fwd is traced once.
            x, y = ops.call(self.layers[0], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 1);
            std.debug.assert(y == 1);
            // Here we call a second time but since first two layers have the same shape,
            // We hit the function cache, and "num_call" is not incremented.
            x, y = ops.call(self.layers[1], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 1);
            std.debug.assert(y == 1);
            // Third layer has a different weight shape -> cache miss -> re-trace.
            x, y = ops.call(self.layers[2], ._fwd, .{x});
            std.debug.assert(Layer.num_call == 2);
            std.debug.assert(y == 2);
            return x;
        }

        // Cache-free reference: replays the biases (1, 1, 2) that the cached
        // _fwd path is expected to produce.
        pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
            var x = x0;
            for (self.layers, &[_]u32{ 1, 1, 2 }) |layer, bias| {
                const wx = layer.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
                x = wx.addConstant(bias);
            }
            return x;
        }
    };

    const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
    const nn: zml.Bufferized(NN) = .{
        .layers = .{
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
            // third layer has different shape
            .{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
        },
    };

    const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
    const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
    try zml.testing.expectClose(expected, res, 1e-4);
}
pub fn hashArgs(mod: anytype) u64 { pub fn hashArgs(mod: anytype) u64 {
var hasher = std.hash.Wyhash.init(0); var hasher = std.hash.Wyhash.init(0);
hash(&hasher, mod, .DeepRecursive); hash(&hasher, mod, .DeepRecursive);

View File

@ -34,7 +34,7 @@ pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), arg
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func); const name = @typeName(@TypeOf(self)) ++ "." ++ @tagName(func);
const actual_fn = @field(@TypeOf(self), @tagName(func)); const actual_fn = @field(@TypeOf(self), @tagName(func));
return ctx.callFunc(name, actual_fn, .{self} ++ args); return ctx.callFunc(name, actual_fn, .{self} ++ args) catch @panic("OOM");
} }
pub fn while_( pub fn while_(
@ -445,14 +445,36 @@ pub fn if_(
if (TrueBlockSignature.Return != FalseBlockSignature.Return) { if (TrueBlockSignature.Return != FalseBlockSignature.Return) {
@compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return));
} }
stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred});
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res });
var true_shapes = std.ArrayList(Shape).init(ctx.allocator());
defer true_shapes.deinit();
var false_shapes = std.ArrayList(Shape).init(ctx.allocator());
defer false_shapes.deinit();
var failed_to_collect = false;
meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch {
failed_to_collect = true;
};
meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch {
failed_to_collect = true;
};
if (!failed_to_collect) {
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items });
}
}
const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0);
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
.operands = &.{pred.value()}, .operands = &.{scalar_pred.value()},
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{ true_branch_block, false_branch_block }, .blocks = &.{ true_branch_block, false_branch_block },
// We can't verify right away, cause the weights captured by the if haven't been added yet. // We can't verify right away, cause the weights captured by the if haven't been added yet.
@ -791,7 +813,7 @@ pub fn scatter(
// validate coord axes: all coord_axes should exist inside self // validate coord axes: all coord_axes should exist inside self
for (indices_axes.constSlice()) |t| { for (indices_axes.constSlice()) |t| {
stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={any}", .{ self, indices_axes }); stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() });
} }
// Handle scalar indices by broadcasting them to the indices with the highest rank. // Handle scalar indices by broadcasting them to the indices with the highest rank.
@ -905,7 +927,7 @@ fn scatterConfig(
scatter_to_operand_axes.appendAssumeCapacity(op.axis(t)); scatter_to_operand_axes.appendAssumeCapacity(op.axis(t));
} }
for (indices.tags()) |t| { for (indices.tags()) |t| {
stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got updates={} and indices={s}", .{ update, indices_axes.constSlice() }); stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices });
updates_transpose.appendAssumeCapacity(update.axis(t)); updates_transpose.appendAssumeCapacity(update.axis(t));
} }

View File

@ -1743,7 +1743,7 @@ pub const Tensor = struct {
/// Returns a Tensor containing evenly spaced values within a given interval. /// Returns a Tensor containing evenly spaced values within a given interval.
pub fn arange(args: ArangeArgs, dt: DataType) Tensor { pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); stdx.debug.assert(args.start <= args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();

View File

@ -17,7 +17,6 @@ pub fn env() zml.Platform {
_platform = ctx.autoPlatform(.{}).withCompilationOptions(.{ _platform = ctx.autoPlatform(.{}).withCompilationOptions(.{
.xla_dump_to = "/tmp/zml/tests/", .xla_dump_to = "/tmp/zml/tests/",
.sharding_enabled = true, .sharding_enabled = true,
.xla_dump_hlo_pass_re = ".*",
}); });
} }