// Radix/zml/module.zig

// 1262 lines
// 53 KiB
// Zig
// Raw Normal View History

const std = @import("std");
const asynk = @import("async");
const c = @import("c");
const dialect = @import("mlir/dialects");
// 2025-01-28 09:35:58 +00:00
const mlir = @import("mlir");
const stdx = @import("stdx");
const upb = @import("upb");
const BaseExe = @import("exe.zig").BaseExe;
const Buffer = @import("buffer.zig").Buffer;
const meta = @import("meta.zig");
// 2025-01-28 09:35:58 +00:00
const mlirx = @import("mlirx.zig");
const ops = @import("ops.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const Target = @import("platform.zig").Target;
const Tensor = @import("tensor.zig").Tensor;
const Tracer = @import("tools/tracer.zig").Tracer;
const log = std.log.scoped(.@"zml/module");
// Reference every declaration of this file so that the `test` blocks nested
// inside its declarations are picked up by the test runner.
test {
std.testing.refAllDecls(@This());
}
/// Description of one MLIR function produced by `CompilationContext.emitMlir`.
pub const MlirFn = struct {
/// Symbol name the function was emitted with.
name: []const u8,
/// Shape of each input tensor, in argument order.
args_shapes: []Shape,
/// Type-erased pointer to the value returned by the traced Zig function;
/// `callFunc` casts it back to the function's result type.
res_tensors: *const anyopaque,
/// MLIR type of each result value.
res_types: []mlir.Type,
/// Shape of each result tensor.
res_shapes: []Shape,
/// Donation recorded for each result tensor.
res_donations: []Tensor._Donation,
/// The generated `func.func` operation.
mlir_fn: mlir.Operation,
/// `.main` functions receive extra argument/result attributes in `emitMlir`
/// (donation, output memory kind, sharding); `.private` ones do not.
pub const Kind = enum {
main,
private,
};
};
/// State used while tracing ZML functions into a single MLIR module and
/// compiling that module for `_platform`.
pub const CompilationContext = struct {
/// Platform the module is compiled for.
_platform: Platform,
/// Module name (truncated in `init`); part of the dump directory name.
_name: []const u8,
/// Arena behind `allocator()`; released by `deinit`.
_arena: std.heap.ArenaAllocator,
_mlir_ctx: mlir.Context,
_mlir_registry: mlir.Registry,
/// Pass manager set up in `init` (canonicalize, cse, canonicalize) and run on each emitted function.
_mlir_canonicalizer: mlir.PassManager,
/// MLIR module that receives the emitted functions.
_module: mlir.Module,
/// Stack of the blocks currently open (see `openBlock` / `closeBlock`); holds at most 64 entries.
_blocks: stdx.BoundedArray(TaggedBlock, 64) = .{},
/// Functions already emitted through `callFunc`.
_fn_cache: FnCache = .{},
/// Maps the id of an argument tensor to its block argument value and donation.
_block_args: TensorToBlockArg = .{},
/// Next argument id handed out by `tensorFromShapes`.
_unique_id: u64 = 10000,
_tracer: Tracer,
/// Context that was current when `activate` was called; restored by `deactivate`.
_previous: ?*CompilationContext = null,
/// Context currently active on the calling thread, if any.
threadlocal var _current: ?*CompilationContext = null;
/// A block together with the options used when appending operations/values to it.
const TaggedBlock = struct { mlir.Block, mlir.Block.RecursiveOpts };
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
/// Per-argument / per-result attribute list; holds at most 3 attributes.
const AttributeList = stdx.BoundedArray(mlir.NamedAttribute, 3);
/// Creates a compilation context for `platform`.
///
/// Registers and loads the `func` and `stablehlo` dialects, creates an empty
/// module whose `sym_name` is "zml", and prepares the canonicalization pass
/// manager (canonicalize, cse, canonicalize).
///
/// * `allocator_` backs the internal arena.
/// * `full_name` is truncated so that dump paths derived from it stay short enough.
///
/// If an error is returned, the registry, the MLIR context and the arena created
/// so far are released (the same resources `deinit` releases).
/// On success the caller must call `deinit`.
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
    var mlir_registry = mlir.Registry.init() catch unreachable;
    errdefer mlir_registry.deinit();
    inline for (.{ "func", "stablehlo" }) |d| {
        mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
    }

    var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
    errdefer mlir_ctx.deinit();
    mlir_ctx.loadAllAvailableDialects();

    // Too long module names create too long file paths and files failed to create.
    // * leave half of the space for parent folder and XLA generated filename,
    // * leave 17 bytes for the module hash (16 + 1 for underscore).
    const max_name_len = @divFloor(std.fs.max_path_bytes, 2) - 17;
    const name = full_name[0..@min(max_name_len, full_name.len)];

    const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
    const module = mlir.Module.init(loc);
    module.op().setAttributeByName("sym_name", .string(mlir_ctx, "zml"));

    var canonicalizer = try mlir.PassManager.init(mlir_ctx);
    {
        var opm = canonicalizer.asOpPassManager();
        try opm.addPipeline("canonicalize");
        try opm.addPipeline("cse");
        try opm.addPipeline("canonicalize");
    }

    var arena = std.heap.ArenaAllocator.init(allocator_);
    errdefer arena.deinit();
    // Pre-grow the arena, then reset it while keeping the capacity.
    _ = try arena.allocator().alloc(u8, 4096);
    _ = arena.reset(.retain_capacity);

    return .{
        ._platform = platform,
        ._name = try arena.allocator().dupe(u8, name),
        ._mlir_ctx = mlir_ctx,
        ._mlir_registry = mlir_registry,
        ._mlir_canonicalizer = canonicalizer,
        ._module = module,
        ._blocks = .{},
        ._fn_cache = .{},
        ._arena = arena,
        ._tracer = Tracer.init("ai.zml.compilation"),
    };
}
/// Releases the MLIR context, the dialect registry and the arena (in that order).
/// Everything allocated through `allocator()` becomes invalid afterwards.
pub fn deinit(self: *CompilationContext) void {
// No need to deinit self._fn_cache because it uses our arena
self._mlir_ctx.deinit();
self._mlir_registry.deinit();
self._arena.deinit();
}
/// Returns the allocator backed by this context's arena.
/// Memory obtained from it is released by `deinit`.
pub fn allocator(self: *CompilationContext) std.mem.Allocator {
    const arena = &self._arena;
    return arena.allocator();
}
/// Makes `self` the current context of the calling thread,
/// remembering the previously current one so `deactivate` can restore it.
pub fn activate(self: *CompilationContext) void {
    const outer = _current;
    _current = self;
    self._previous = outer;
}
/// Restores the context that was current before `activate` was called on `self`.
/// Asserts that `self` is the current context.
pub fn deactivate(self: *CompilationContext) void {
    const active = _current;
    std.debug.assert(active != null and active.? == self);
    _current = self._previous;
}
/// Returns the context active on the calling thread.
/// Must only be called while a context is active.
pub fn current() *CompilationContext {
    const active = _current;
    return active.?;
}
/// Returns the target of the platform this context compiles for.
pub fn target(self: *const CompilationContext) Target {
    const platform = &self._platform;
    return platform.target;
}
/// Returns the MLIR context owned by this compilation context.
pub fn mlirCtx(self: *const CompilationContext) mlir.Context {
    const ctx = self._mlir_ctx;
    return ctx;
}
/// Builds an MLIR location for the Zig source location `src`,
/// named with `name` formatted with `args`.
pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location {
    const ctx = self._mlir_ctx;
    const source_loc = ctx.location(src);
    return source_loc.namedFmt(ctx, name, args);
}
/// Compiles the given function with the given arguments.
/// This is the untyped API and is not meant to be used directly.
///
/// * `allocator_` is used to allocate the result Exe
/// * `args` can contain a mix of tensors and shapes, allowing to pass a "model struct" containing tensors.
///
/// When the platform compilation options set `xla_dump_to`, the generated MLIR is
/// written to `module.mlir` inside a per-module directory, and the compiled
/// executable is loaded from / stored to `module.pjrt` in the same directory.
///
/// Returns an error when allocation, dump directory creation or PJRT compilation fails.
pub fn compileInternal(
    self: *CompilationContext,
    allocator_: std.mem.Allocator,
    comptime func: anytype,
    args: anytype,
) !BaseExe {
    const arena = self.allocator();

    var timer = std.time.Timer.start() catch null;
    const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);

    // Run in a dedicated thread because compilation relies on `threadlocal`.
    const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, CompilationContext.EmitMlirOpts{ .name = "main", .kind = .main } });
    const module = self._module;
    module.getBody().appendOperation(f.mlir_fn);

    const sharding = self._platform.sharding();
    const mlir_ctx = self._mlir_ctx;
    module.op().setAttributeByName("mhlo.num_replicas", .int(mlir_ctx, .i32, sharding.num_replicas));
    module.op().setAttributeByName("mhlo.num_partitions", .int(mlir_ctx, .i32, sharding.num_partitions));

    const module_hash = computeModuleHash(self._platform, module);
    var module_dir: ?[]const u8 = null;
    var pjrt_location: ?[:0]const u8 = null;

    if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| {
        const sep = std.fs.path.sep_str;
        const module_dir_name = try std.fmt.allocPrint(arena, "{s}{s}{s}{s}{s}_{x}", .{ xla_dump_to, sep, @tagName(self._platform.target), sep, self._name, module_hash });
        try std.fs.cwd().makePath(module_dir_name);
        module_dir = try std.fs.cwd().realpathAlloc(arena, module_dir_name);

        // The directory handle is only needed to create the MLIR dump below.
        var cache_dir = try std.fs.cwd().openDir(module_dir.?, .{});
        defer cache_dir.close();

        // Write the mlir to a file. All errors are discarded, since this is for debugging only.
        const mlir_name = "module.mlir";
        if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
            defer file.close();
            var write_buf: [4096]u8 = undefined;
            var writer = file.writer(&write_buf);
            module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
            // The writer is buffered: without an explicit flush the tail of the dump
            // (or all of it, for small modules) could stay in `write_buf`.
            if (writer.interface.flush()) |_| {
                log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
            } else |err| {
                log.warn("Failed to write {s}: {}", .{ mlir_name, err });
            }
        } else |_| {
            log.warn("Failed to open {s}", .{mlir_name});
        }

        pjrt_location = try std.fs.path.joinZ(arena, &.{ module_dir.?, "module.pjrt" });
    }

    const loaded_executable: *pjrt.LoadedExecutable = blk: {
        // First try to reuse a previously stored executable.
        if (pjrt_location) |pjrt_loc| {
            if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
                log.info("Loaded pre-compiled module from {s} (generated from {s}/module.mlir)", .{ pjrt_loc, module_dir.? });
                break :blk exe;
            } else |err| {
                if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
            }
        }

        const compiled = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
            log.err("pjrt-{s} failed to compile: {}", .{ @tagName(self._platform.target), err });
            if (module_dir) |dir| log.err("mlir can be found at {s}/module.mlir", .{dir});
            return err;
        };

        // Storing the executable is best effort: a failure only costs a recompilation next time.
        if (pjrt_location) |pjrt_loc| {
            storePjrtExecutable(self._platform, compiled, pjrt_loc) catch |err| {
                log.warn("Failed to store compiled module: {} at {s}", .{ err, pjrt_loc });
            };
        }
        break :blk compiled;
    };

    log.debug("******** ZML generated MLIR ********", .{});
    log.debug("{f}", .{module.op().mlirFormatter(.{})});

    if (timer) |*t| {
        const time_ms = @divFloor(t.lap(), std.time.ns_per_ms);
        if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{stdx.math.divFloat(f32, time_ms, 1000)});
    }

    return BaseExe.init(
        allocator_,
        self._platform,
        loaded_executable,
        .{
            .input_shapes = f.args_shapes,
            .result_shapes = f.res_shapes,
            .n_devices = sharding.num_replicas * sharding.num_partitions,
        },
    );
}
/// Returns the innermost open block, or null when no block is open.
fn currentBlock(self: *const CompilationContext) ?TaggedBlock {
    const depth = self._blocks.len;
    if (depth == 0) return null;
    return self._blocks.get(depth - 1);
}
/// Creates a block with the given argument types/locations and pushes it
/// on the stack of open blocks. Must be paired with `closeBlock`.
pub fn openBlock(self: *CompilationContext, kind: mlir.Block.RecursiveOpts, args: []const mlir.Type, locs: []const mlir.Location) !TaggedBlock {
    const raw_block = try mlir.Block.init(args, locs);
    const tagged: TaggedBlock = .{ raw_block, kind };
    self.pushBlock(tagged);
    return tagged;
}
/// Pops the innermost open block and asserts it is `block`.
pub fn closeBlock(self: *CompilationContext, block: TaggedBlock) void {
    const expected = block[0];
    const top = self._blocks.pop();
    std.debug.assert(expected.eql(top.?[0]));
}
/// Pushes `block` on the stack of open blocks; the stack capacity is assumed sufficient.
fn pushBlock(self: *CompilationContext, block: TaggedBlock) void {
    const stack = &self._blocks;
    stack.appendAssumeCapacity(block);
}
/// Transform a Tensor -> Tensor function into an Mlir block.
/// `blkctx` represents values from outside the block that can be accessed inside the block.
/// Returns both the mlir.Block created and also the Tensors returned by `func`.
/// The returned tensors should not be returned to the user,
/// because their `mlir.Value` must not escape the block that created them.
/// But their shapes/tags can be safely propagated further.
pub fn makeBlock(
    self: *CompilationContext,
    kind: mlir.Block.RecursiveOpts,
    comptime S: ops.BlockSignature,
    func: *const S.Fn,
    blkctx: S.BlkCtx,
    args: S.Args,
) struct { mlir.Block, S.Return } {
    const N = S.nIn;
    const loc = self.mlirCtx().location(@src());
    const locations = .{loc} ** N;
    var input_types: [N]mlir.Type = undefined;
    fillMlirTypes(&args, self.mlirCtx(), &input_types);

    // Before creating a new block, assign all received values to previous block,
    // otherwise they will be assigned to this block
    if (self.currentBlock()) |prev_block| {
        meta.visit(_appendTensorRecursive, prev_block, &blkctx);
    }

    const block = self.openBlock(kind, &input_types, &locations) catch unreachable;
    defer self.closeBlock(block);

    // Here we want to create the block with the correct mlir types.
    // but we don't want to use the values themselves.
    // So we create a copy of the arguments, and replace values
    // by the block arguments.
    var blk_args = args;
    // Keep the call outside of the assert: it performs the actual assignment.
    const assigned_count = assignBlockArguments(&blk_args, block[0], 0);
    std.debug.assert(assigned_count == N);

    const block_res = @call(.auto, func, S.blkArgs(blkctx, blk_args));
    var block_res_values: [S.nOut]mlir.Value = undefined;
    self.extractValues(&block_res, &block_res_values);
    const block_ret = dialect.stablehlo.returns_(self.mlirCtx(), &block_res_values, loc);
    block[0].appendOperationRecursive(block_ret, block[1]);

    return .{ block[0], block_res };
}
/// Visitor callback: appends the value of `x` to the block of `tagged_block`,
/// using the options stored next to the block.
fn _appendTensorRecursive(tagged_block: TaggedBlock, x: *const Tensor) void {
    const opts = tagged_block[1];
    tagged_block[0].appendValueRecursive(x.value(), opts);
}
/// Options for `emitMlir`.
pub const EmitMlirOpts = struct {
/// Symbol name given to the generated MLIR function.
name: []const u8,
/// `.main` functions additionally get donation, output memory kind and sharding attributes.
kind: MlirFn.Kind = .private,
};
/// Generate an MLIR function from a ZML function.
/// The caller is responsible to have properly created the input
/// tensors with unique tensor ids.
///
/// * `args` points to the arguments `func` is traced with; every tensor found in it
///   becomes one argument of the generated function.
/// * `opts.name` is the symbol name; `opts.kind == .main` additionally attaches
///   donation, output memory kind and (when partitioned) sharding attributes.
///
/// The returned `MlirFn` references memory owned by the context allocator.
/// The generated function is canonicalized before being returned; invalid MLIR panics.
pub fn emitMlir(
    self: *CompilationContext,
    comptime func: anytype,
    args: *const stdx.meta.FnArgs(func),
    opts: EmitMlirOpts,
) error{OutOfMemory}!MlirFn {
    const frame = self._tracer.frameStart("emitMlir.emit");
    errdefer self._tracer.frameEnd(frame, "emitMlir.emit");

    const res_allocator = self.allocator();
    // Note: only temp allocations are done in the arena,
    // the other allocations are in the context allocator.
    var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    const tensor_count = meta.count(Tensor, args);
    const mlir_ctx = self.mlirCtx();
    const loc = mlir_ctx.location(@src());

    const locations = try arena.alloc(mlir.Location, tensor_count);
    @memset(locations, mlir.Location.unknown(mlir_ctx));

    var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count);
    meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
    stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});

    const input_types = try arena.alloc(mlir.Type, tensor_count);
    for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);

    // Tensor -> block argument mapping is per function: save the caller's mapping
    // and restore it on exit.
    const og_block_args = self._block_args;
    defer {
        self._block_args.deinit(self.allocator());
        self._block_args = og_block_args;
    }
    // Reset the buffer -> assignment
    self._block_args = .{};

    // Note: this isn't strictly necessary. We call `countTensor` on `fn_res`.
    // But it forces user to have simpler function.
    const ReturnT = stdx.meta.FnResult(func);
    const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");

    // Those are returned to caller so we don't put them in the arena, but in the module allocator.
    const fn_res = try res_allocator.create(ReturnT);
    const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
    const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
    const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
    // Only needed to build the result attributes below, so it lives in the temporary arena.
    const fn_res_output_memory_kind = try arena.alloc(Buffer.Memory, out_tensor_count);

    var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
    {
        defer self.closeBlock(fn_body);

        try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count));
        const assigned_args_count = self.mapBlockArguments(args, fn_body[0], 0);
        std.debug.assert(assigned_args_count == tensor_count);

        // Trace the user function with this context being the current one.
        fn_res.* = forward: {
            self.activate();
            defer self.deactivate();
            break :forward @call(.auto, func, args.*);
        };

        var fn_res_values: [out_tensor_count]mlir.Value = undefined;
        self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations, fn_res_output_memory_kind);

        const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
        fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
    }

    const arg_attrs = try arena.alloc(AttributeList, tensor_count);
    @memset(arg_attrs, .{});
    const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
    @memset(res_attrs, .{});

    if (opts.kind == .main) {
        self.addDonationsAttributes(arg_attrs, fn_res_donations);
        self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
        if (self._platform.sharding().num_partitions > 1) {
            self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
        }
    }

    const mlir_fn = dialect.func.func(self.mlirCtx(), .{
        .sym_name = opts.name,
        .args = input_types,
        .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
        .results = fn_res_types,
        .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
        .block = fn_body[0],
        .location = loc,
    });

    self._tracer.frameEnd(frame, "emitMlir.emit");
    const canonicalize_frame = self._tracer.frameStart("emitMlir.canonicalize");
    defer self._tracer.frameEnd(canonicalize_frame, "emitMlir.canonicalize");
    self._mlir_canonicalizer.runOnOp(mlir_fn) catch |err| switch (err) {
        error.InvalidMlir => {
            log.err("Failed to canonicalize invalid mlir: {f}", .{mlir_fn.mlirFormatter(.{})});
            // user errors should have triggered a panic before we reach this.
            @panic("ZML generated invalid mlir. Please open a bug report");
        },
    };

    return .{
        .mlir_fn = mlir_fn,
        .name = opts.name,
        .args_shapes = input_shapes.items,
        .res_tensors = fn_res,
        .res_types = fn_res_types,
        .res_shapes = fn_res_shapes,
        .res_donations = fn_res_donations,
    };
}
/// Appends a `mhlo.memory_kind` attribute to the attribute list of every result
/// whose requested memory kind is not `.device`.
fn addOutputMemoryKindAttributes(self: CompilationContext, attributes: []AttributeList, output_memory_kind: []const Buffer.Memory) void {
    const ctx = self.mlirCtx();
    for (output_memory_kind, attributes) |kind, *attrs| {
        // .device is the default output, don't explicitly emit the attribute
        if (kind != .device) {
            const memory_kind_attr = mlir.NamedAttribute.named(
                ctx,
                "mhlo.memory_kind",
                .string(ctx, kind.pjrtName()),
            );
            attrs.appendAssumeCapacity(memory_kind_attr);
        }
    }
}
/// Given a list of donations mapping output buffers to input buffers,
/// generate the `tf.aliasing_output` attribute on each donated input argument.
///
/// * `attributes` holds one attribute list per input argument.
/// * `donations` holds one entry per output; `.arg` entries name the donated input.
///
/// Asserts that no argument is donated to more than one output.
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
    const ctx = self.mlirCtx();
    for (donations, 0..) |donation, index| {
        switch (donation) {
            .no_buffer => {},
            // This is an input buffer that has been returned,
            // but without explicitly calling `reuseBuffer`.
            // So we assume the intent was to return a new buffer.
            .input_buffer => {},
            .arg => |a| {
                // This will break the day we write another attribute before donation.
                // When the time comes, do a more fancy lookup here to check if an argument
                // is donated twice.
                stdx.debug.assert(attributes[a].len == 0, "Donation error ! Argument {d} has been donated twice ! To {d} and to {any}", .{ a, index, attributes[a].buffer[0] });
                attributes[a].appendAssumeCapacity(.named(ctx, "tf.aliasing_output", .int(ctx, .i32, @intCast(index))));
            },
        }
    }
}
// Emits a small model as a `.main` function and checks that the printed IR
// contains a `tf.aliasing_output` attribute for outputs 0 and 1.
test addDonationsAttributes {
const zml = @import("zml.zig");
const platform = zml.testing.env();
// NOTE(review): `arena` is not used by the rest of this test.
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena.deinit();
const s = Shape.init(.{8}, .f16);
const Local = struct {
bias: Tensor,
// Calls `_inner` twice; the first result reuses the buffer of `y`.
pub fn _fwd(self: @This(), x: Tensor, y: Tensor) [2]Tensor {
const x1 = zml.ops.call(self, ._inner, .{x});
const x2 = zml.ops.call(self, ._inner, .{x1});
return .{ x1.reuseBuffer(y), x2 };
}
// Adds the bias and reuses the buffer of the input for the result.
pub fn _inner(self: @This(), x: Tensor) Tensor {
const y = x.add(self.bias);
return y.reuseBuffer(x);
}
};
const model: Local = .{
.bias = zml.Tensor{ ._shape = s, ._id = .{ .buffer_id = 0 } },
};
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
defer comp.deinit();
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
// Print the generated function to text so that attributes can be searched for.
var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator);
defer mlir_bytecode.deinit();
try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})});
// Check that the `x` input argument gives its buffer to the result tensor.
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
try std.testing.expectEqual(3, f.args_shapes.len);
// We should have two buffers being donated.
const template = "tf.aliasing_output = {d} : i32";
// `buf` only needs to be as large as the template: "{d}" is replaced by a single digit.
var buf = template.*;
for (0..2) |i| {
const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
return err;
};
}
}
/// Appends `mhlo.layout_mode = "default"` and a `mhlo.sharding` attribute to every
/// argument and result attribute list. Does nothing when sharding is disabled
/// in the platform compilation options.
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
    if (!self._platform.compilation_options.sharding_enabled) return;

    const ctx = self.mlirCtx();
    const layout_attr = mlir.NamedAttribute.named(ctx, "mhlo.layout_mode", .string(ctx, "default"));

    const Local = struct {
        /// Annotates each attribute list with the layout and the sharding of the matching shape.
        fn annotate(comp: CompilationContext, mlir_ctx: mlir.Context, layout: mlir.NamedAttribute, lists: []AttributeList, shapes: []const Shape) void {
            for (lists, shapes) |*list, shape| {
                list.appendAssumeCapacity(layout);
                list.appendAssumeCapacity(.named(mlir_ctx, "mhlo.sharding", comp.getShardingAttr(shape)));
            }
        }
    };

    // Arguments first, then results.
    Local.annotate(self, ctx, layout_attr, arg_attrs, input_shapes);
    Local.annotate(self, ctx, layout_attr, res_attrs, output_shapes);
}
/// Returns the number of partitions of the platform sharding configuration.
pub fn numPartitions(self: CompilationContext) u8 {
    const sharding = self._platform.sharding();
    return sharding.num_partitions;
}
/// Returns the sharding of `shape` as an MLIR string attribute,
/// in the textual form produced by `writeShardingRepresentation`.
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
    var repr: stdx.BoundedArray(u8, 128) = .{};
    const partitions = self.numPartitions();
    writeShardingRepresentation(shape, partitions, repr.writer()) catch unreachable;
    return mlir.Attribute.string(self.mlirCtx(), repr.constSlice());
}
/// Writes the textual sharding of `shape` to `writer`.
///
/// Writes `{replicated}` when no axis is sharded or when there is a single partition.
/// Otherwise writes `{devices=[d0,d1,...]<=[num_partitions]}` where `di` is
/// `num_partitions` for sharded axes and 1 for the others.
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
    const sharded_axes: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
    const replicated = sharded_axes == 0 or num_partitions == 1;
    if (replicated) return writer.writeAll("{replicated}");

    try writer.writeAll("{devices=[");
    for (0..shape.rank()) |axis| {
        // Separator goes before every entry but the first one.
        if (axis != 0) try writer.writeByte(',');
        const tile: u8 = if (shape._sharding_info[axis]) num_partitions else 1;
        try writer.print("{d}", .{tile});
    }
    try writer.print("]<=[{d}]}}", .{num_partitions});
}
test writeShardingRepresentation {
    const x = Shape.init(.{ 16, 8 }, .f32);

    const Case = struct { shape: Shape, num_partitions: u8, expected: []const u8 };
    const cases = [_]Case{
        // By default tensors are replicated.
        .{ .shape = x, .num_partitions = 4, .expected = "{replicated}" },
        // Shard along first axis.
        .{ .shape = x.withSharding(.{0}), .num_partitions = 4, .expected = "{devices=[4,1]<=[4]}" },
        // Also shard along second axis.
        .{ .shape = x.withSharding(.{ 0, 1 }), .num_partitions = 2, .expected = "{devices=[2,2]<=[2]}" },
    };

    var storage: [64]u8 = undefined;
    for (cases) |case| {
        var stream = std.io.fixedBufferStream(&storage);
        try writeShardingRepresentation(case.shape, case.num_partitions, stream.writer());
        try std.testing.expectEqualStrings(case.expected, stream.getWritten());
    }
}
/// Converts each attribute list into an MLIR dictionary attribute.
/// The returned slice is allocated with `allocator_` and has one entry per input list.
fn finalizeAttributeList(allocator_: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
    const dicts = try allocator_.alloc(mlir.Attribute, attributes.len);
    for (attributes, 0..) |list, i| {
        const dict = mlir.DictionaryAttribute.init(mlir_ctx, list.constSlice());
        dicts[i] = dict.asAttr();
    }
    return dicts;
}
/// Generates an MLIR `func.call` of the given function.
/// If the function has not been seen yet, we generate MLIR for it,
/// in an independent function.
/// The main benefit of this is to generate MLIR that maps more closely
/// to the Zig code, but compilation speed stays similar.
///
/// Emitted functions are cached by function pointer and hash of `args`,
/// so calling the same function again with equally-hashing arguments reuses the MLIR function.
/// Returns the result of `func` where each tensor is bound to a result of the call operation.
pub fn callFunc(
self: *CompilationContext,
func_name: [:0]const u8,
comptime func: anytype,
args: stdx.meta.FnArgs(func),
) error{OutOfMemory}!stdx.meta.FnResult(func) {
var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
defer arena_state.deinit();
// This arena is used for allocations which won't outlive the function call,
// but the function creation uses `self.allocator()` which will live for the duration of the compilation.
const arena = arena_state.allocator();
// first, do the "compile" and check the bytecode
// the result of this will also have the correct tags of the result shapes
const args_hash = hashArgs(args);
const key: FnKey = .{ .fn_ptr = &func, .input_hash = args_hash };
const function = self._fn_cache.get(key) orelse b: {
// Cache miss: emit the function under a name made unique by the argument hash ("main" is kept as is).
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
try self.allocator().dupeZ(u8, func_name)
else
try std.fmt.allocPrintSentinel(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }, 0);
// Replace every tensor of `args` by a placeholder identified by its position,
// and marked as donating that same argument.
var arg_id: u16 = 0;
var tensor_args: @TypeOf(args) = args;
try meta.mapAlloc(struct {
fn cb(arg_id_: *u16, x: Tensor) Tensor {
const a = arg_id_.*;
arg_id_.* += 1;
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
}
}.cb, arena, &arg_id, args, &tensor_args);
const f = try self.emitMlir(
func,
&tensor_args,
.{ .name = full_name },
);
self._module.getBody().appendOperation(f.mlir_fn);
try self._fn_cache.putNoClobber(self.allocator(), key, f);
break :b f;
};
const loc = self.mlirCtx().location(@src());
// Gather the values and the donations of the actual arguments at the call site.
const num_args = function.args_shapes.len;
const values = try arena.alloc(mlir.Value, num_args);
self.extractValues(&args, values);
const donations = try arena.alloc(Tensor._Donation, num_args);
meta.collectBuf(struct {
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
return ctx.getValueAndDonation(x)[1];
}
}.cb, self, &args, donations);
const op = dialect.func.call(self.mlirCtx(), @ptrCast(function.name), values, function.res_types, loc);
// Create the result tensor object by combining the operand results,
// as well as the registered shapes and donations.
// Note: this assumes res can be stack-allocated.
var res = @as(*const stdx.meta.FnResult(func), @ptrCast(@alignCast(function.res_tensors))).*;
const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation };
var context: LocalContext = .{ .op = op, .function = function, .donations = donations };
meta.visit((struct {
fn cb(ctx: *LocalContext, tensor: *Tensor) void {
const i = ctx.index;
ctx.index += 1;
var new = Tensor.fromMlirValue(ctx.op.result(i));
new._shape = ctx.function.res_shapes[i];
// A result donated from argument `k` inside the function inherits the
// donation of the `k`-th argument at the call site.
new._donation = switch (ctx.function.res_donations[i]) {
.no_buffer => .no_buffer,
.arg => |input_arg| ctx.donations[input_arg],
.input_buffer => .no_buffer, // user escaped the sandbox
};
tensor.* = new;
}
}).cb, &context, &res);
std.debug.assert(context.index == op.numResults());
return res;
}
/// Walks `v` and registers, for each tensor found, the matching argument of `block`
/// in `self._block_args`. The first tensor maps to argument `start`, the next one
/// to `start + 1`, and so on.
///
/// Panics when two tensors share the same id.
/// The caller must have reserved enough capacity in `self._block_args`.
/// Returns the index following the last mapped argument.
pub fn mapBlockArguments(self: *CompilationContext, v: anytype, block: mlir.Block, start: usize) usize {
    const Cursor = struct {
        next: usize,
        target_block: mlir.Block,
        comp: *CompilationContext,

        fn bind(cursor: *@This(), tensor: *const Tensor) void {
            const position = cursor.next;
            const block_arg = cursor.target_block.argument(position);
            const slot = cursor.comp._block_args.getOrPutAssumeCapacity(tensor._id);
            if (slot.found_existing) {
                stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {f} and {f} at index {} ({}).", .{ slot.value_ptr.*[0], tensor, position, tensor._id });
            } else {
                slot.value_ptr.* = .{ block_arg, .{ .arg = @intCast(position) } };
            }
            cursor.next = position + 1;
        }
    };

    var cursor: Cursor = .{ .comp = self, .target_block = block, .next = start };
    meta.visit(Cursor.bind, &cursor, v);
    return cursor.next;
}
/// Builds a value of type `ArgsT` where each shape of `args_shapes` is replaced by a tensor.
/// Every created tensor gets the next id of this CompilationContext and is marked
/// as an input buffer.
pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator_: std.mem.Allocator, args_shapes: anytype) !ArgsT {
    const Factory = struct {
        fn make(next_id: *u64, shape: Shape) Tensor {
            const id = next_id.*;
            const tensor = Tensor{
                ._shape = shape,
                ._id = .{ .arg_id = id },
                ._donation = .input_buffer,
            };
            next_id.* += 1;
            return tensor;
        }
    };

    var result: ArgsT = undefined;
    try meta.mapAlloc(Factory.make, allocator_, &self._unique_id, args_shapes, &result);
    return result;
}
/// Walks `v` and, for the i-th tensor found, stores its MLIR value, MLIR type,
/// shape, donation and output memory kind at index i of the given slices.
/// Asserts that `values` and `types` have the same length and that exactly
/// `values.len` tensors were visited.
pub fn extractValuesAndTypes(
    self: *const CompilationContext,
    v: anytype,
    values: []mlir.Value,
    types: []mlir.Type,
    shapes: []Shape,
    donations: []Tensor._Donation,
    output_memory_kind: []Buffer.Memory,
) void {
    std.debug.assert(values.len == types.len);

    const Sink = struct {
        comp: *const CompilationContext,
        count: usize = 0,
        values: []mlir.Value,
        types: []mlir.Type,
        shapes: []Shape,
        donations: []Tensor._Donation,
        output_memory_kind: []Buffer.Memory,

        fn record(sink: *@This(), tensor: *const Tensor) void {
            const i = sink.count;
            const resolved = sink.comp.getValueAndDonation(tensor.*);
            sink.values[i] = resolved[0];
            sink.types[i] = resolved[0].getType();
            sink.shapes[i] = tensor._shape;
            sink.donations[i] = resolved[1];
            sink.output_memory_kind[i] = tensor._output_memory_kind;
            sink.count = i + 1;
        }
    };

    var sink: Sink = .{
        .comp = self,
        .values = values,
        .types = types,
        .shapes = shapes,
        .donations = donations,
        .output_memory_kind = output_memory_kind,
    };
    meta.visit(Sink.record, &sink, v);
    std.debug.assert(sink.count == values.len);
}
/// Returns the MLIR value and the donation associated with `tensor`.
/// Tensors identified by a buffer or argument id are looked up in the registered
/// block arguments; an unknown id panics. Other tensors carry their value directly.
pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } {
    switch (tensor._id) {
        .mlir => |v| return .{ v, tensor._donation },
        .buffer_id, .arg_id => {
            const known = self._block_args.get(tensor._id) orelse {
                log.err("Found unknown tensor id {f}({})", .{ tensor, tensor._id });
                @panic("Found unknown tensor id");
            };
            return .{ known[0], known[1] };
        },
    }
}
/// Returns the MLIR value associated with `tensor` (see `getValueAndDonation`).
pub fn getValue(self: *const CompilationContext, tensor: Tensor) mlir.Value {
    const value_and_donation = self.getValueAndDonation(tensor);
    return value_and_donation[0];
}
/// Walks `v` and writes the MLIR value of each tensor found into `values`.
pub fn extractValues(self: *const CompilationContext, v: anytype, values: []mlir.Value) void {
    meta.collectBuf(CompilationContext.getValue, self, v, values);
}
};
/// Hashes the module together with the PJRT platform name and the PJRT API
/// version (major, minor).
fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
    var state = std.hash.XxHash64.init(0);
    module.hash(&state);

    const platform_name = platform.pjrt_client.getPlatformName(platform.pjrt_api);
    state.update(platform_name);

    const version = platform.pjrt_api.version();
    const version_words = [_]i64{ version.major, version.minor };
    state.update(std.mem.sliceAsBytes(&version_words));

    return state.final();
}
/// Buffer size used to read a stored executable when the file size can't be queried.
const max_pjrt_executable_size = 400 * 1024 * 1024;

/// Reads the serialized executable stored at `absolute_file` and loads it on `platform`.
/// `arena` is used for the temporary read buffer, which is freed before returning.
/// Fails when the file can't be opened or read, or when deserialization fails.
fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable {
    const tracer = Tracer.init("ai.zml.load_exe");
    const frame = tracer.frameStart("pjrt load executable");
    defer tracer.frameEnd(frame, "pjrt load executable");

    const file = try std.fs.openFileAbsoluteZ(absolute_file, .{});
    defer file.close();

    const capacity = size: {
        const stat = file.stat() catch break :size max_pjrt_executable_size;
        break :size stat.size;
    };

    const buffer = try arena.alloc(u8, capacity);
    defer arena.free(buffer);

    const n_read = try file.readAll(buffer);
    return try platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, buffer[0..n_read]);
}
/// Serializes `loaded_executable` and writes it to `absolute_file`.
///
/// The executable is serialized before the file is created, and the file is removed
/// again if writing fails, so a failure does not leave an empty or partial file behind.
/// Fails when the executable can't be retrieved or serialized, or when the file can't
/// be created or written.
fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void {
    var executable = try loaded_executable.getExecutable(platform.pjrt_api);
    defer executable.deinit(platform.pjrt_api);

    var serialize_result = try executable.serialize(platform.pjrt_api);
    defer serialize_result.deinit();

    const loaded_executable_file = try std.fs.createFileAbsoluteZ(absolute_file, .{});
    // Registered before the `defer` below so that, on error, the file is closed first
    // and then removed. Removal is best effort.
    errdefer std.fs.deleteFileAbsoluteZ(absolute_file) catch {};
    defer loaded_executable_file.close();

    try loaded_executable_file.writeAll(serialize_result.bytes);
}
/// Stores `value` under the key `flag` in `map`, wrapped in an `xla_OptionOverrideProto`.
/// Booleans, integers and floats are stored in their dedicated field;
/// any other value is stored as a string.
/// Returns `error.OutOfMemory` when the map insertion fails.
fn setXlaOverrideFlag(map: *c.upb_Map, flag: []const u8, value: anytype, upb_arena: *c.upb_Arena) !void {
    const override = try upb.new(c.xla_OptionOverrideProto, upb_arena);
    switch (@typeInfo(@TypeOf(value))) {
        .bool => c.xla_OptionOverrideProto_set_bool_field(override, value),
        .comptime_int, .int => c.xla_OptionOverrideProto_set_int_field(override, @intCast(value)),
        .comptime_float, .float => c.xla_OptionOverrideProto_set_double_field(override, @floatCast(value)),
        else => c.xla_OptionOverrideProto_set_string_field(override, upb.stringView(value)),
    }

    const inserted = c.upb_Map_Set(
        map,
        .{ .str_val = upb.stringView(flag) },
        .{ .msg_val = @ptrCast(override) },
        upb_arena,
    );
    if (!inserted) return error.OutOfMemory;
}
/// Compiles `module` for `platform` through PJRT and returns the loaded executable.
///
/// An `xla_CompileOptionsProto` is assembled in a temporary upb arena backed by `arena`:
/// - executable build options: replica/partition counts come from `platform.sharding()`,
///   SPMD partitioning is enabled as soon as there is more than one partition or replica,
///   and the device assignment gives computation `i` the single replica device id `i`;
/// - target specific XLA flag overrides (CUDA, ROCm);
/// - dump flags when `xla_dump_to_` or `platform.compilation_options.xla_dump_to` is set
///   (the argument takes precedence over the platform option).
///
/// Errors from proto allocation, option serialization and `pjrt_client.compile` are propagated.
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
    const tracer = Tracer.init("ai.zml.compilation");
    const compile_frame = tracer.frameStart("pjrt compilation");
    defer tracer.frameEnd(compile_frame, "pjrt compilation");

    const sharding = platform.sharding();

    // Every proto below is allocated from this arena, which is freed when the function returns.
    var upb_alloc: upb.Allocator = .init(arena);
    const upb_arena = c.upb_Arena_Init(null, 0, upb_alloc.inner());
    defer c.upb_Arena_Free(upb_arena);

    const options = blk: {
        const compile_options = try upb.new(c.xla_CompileOptionsProto, upb_arena);

        c.xla_CompileOptionsProto_set_executable_build_options(compile_options, executable_build_options_blk: {
            const exec_build_options = try upb.new(c.xla_ExecutableBuildOptionsProto, upb_arena);

            c.xla_ExecutableBuildOptionsProto_set_device_ordinal(exec_build_options, -1);
            c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
            c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
            c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);

            c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
                const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);

                c.xla_DeviceAssignmentProto_set_replica_count(device_assignment, sharding.num_replicas);
                c.xla_DeviceAssignmentProto_set_computation_count(device_assignment, sharding.num_partitions);

                const computation_devices = c.xla_DeviceAssignmentProto_resize_computation_devices(device_assignment, sharding.num_partitions, upb_arena);
                // The resize accessor yields a null pointer when the arena cannot grow the
                // repeated field; report it instead of slicing a null pointer below.
                if (computation_devices == null) return std.mem.Allocator.Error.OutOfMemory;

                for (computation_devices[0..sharding.num_partitions], 0..) |*computation_device, i| {
                    computation_device.* = try upb.new(c.xla_DeviceAssignmentProto_ComputationDevice, upb_arena);
                    _ = c.xla_DeviceAssignmentProto_ComputationDevice_add_replica_device_ids(computation_device.*, @intCast(i), upb_arena);
                }
                break :device_assignment_blk device_assignment;
            });

            break :executable_build_options_blk exec_build_options;
        });

        const overrides_map = c._xla_CompileOptionsProto_env_option_overrides_mutable_upb_map(compile_options, upb_arena);
        switch (platform.target) {
            .cuda => {
                // NVIDIA recommends these settings
                // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
                try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
                try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena);
                try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena);
            },
            .rocm => {
                // Disable Triton GEMM on ROCM. For some reason it's much, much slower when
                // enabled on CDNA and it's used on RDNA. Disable it altogether.
                try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
                // Use lld from libllvm instead of invoking the ld.lld binary.
                // This saves us from having to sandbox it.
                try setXlaOverrideFlag(overrides_map, "xla_gpu_use_inprocess_lld", true, upb_arena);
            },
            else => {},
        }

        // The explicit argument wins over the platform-wide dump directory.
        if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
            try setXlaOverrideFlag(overrides_map, "xla_dump_to", xla_dump_to, upb_arena);
            try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_as_proto", true, upb_arena);
            if (platform.compilation_options.xla_dump_fusion_visualization) {
                try setXlaOverrideFlag(overrides_map, "xla_dump_fusion_visualization", true, upb_arena);
            }
            if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
                try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_pass_re", re, upb_arena);
            }
        }

        break :blk compile_options;
    };

    // Nothing fallible happens after a successful compile, so no error cleanup is needed here.
    return try platform.pjrt_client.compile(
        platform.pjrt_api,
        arena,
        module,
        try upb.serialize(options, upb_arena),
    );
}
/// Walks `v` with `meta.visit` and returns how many `Tensor`s were encountered.
pub fn countTensors(v: anytype) usize {
    const Counter = struct {
        total: usize = 0,

        fn onTensor(self: *@This(), _: *const Tensor) void {
            self.total += 1;
        }
    };

    var counter: Counter = .{};
    meta.visit(Counter.onTensor, &counter, v);
    return counter.total;
}
/// Visit the given struct and recursively fill the `types` slice with the mlir.Type associated with each encountered Tensor.
///
/// Types are written in visit order, one per tensor, starting at index 0.
/// `types.len` is expected to match the number of tensors found in `v`;
/// this is asserted once the visit is done.
pub fn fillMlirTypes(v: anytype, mlir_ctx: mlir.Context, types: []mlir.Type) void {
    const LocalContext = struct {
        index: usize = 0,
        mlir_ctx: mlir.Context,
        types: []mlir.Type,
    };
    var context = LocalContext{ .mlir_ctx = mlir_ctx, .types = types };
    meta.visit((struct {
        fn cb(inner_context: *LocalContext, tensor: *const Tensor) void {
            inner_context.types[inner_context.index] = mlirx.tensorType(inner_context.mlir_ctx, tensor.shape());
            inner_context.index += 1;
        }
    }).cb, &context, v);
    std.debug.assert(context.index == types.len);
}
/// Binds every Tensor found in `v` to consecutive arguments of `block`, beginning at index `start`.
///
/// For each visited tensor, `_id` is set to the matching block argument and `_donation`
/// records that argument's index. This gives a mapping between the arguments of the
/// kernel associated with a module and the actual Tensors stored in the Module.
///
/// Returns the index right after the last argument that was assigned.
fn assignBlockArguments(v: anytype, block: mlir.Block, start: usize) usize {
    const Cursor = struct {
        next: usize,
        block: mlir.Block,

        fn bind(self: *@This(), tensor: *Tensor) void {
            const arg_index = self.next;
            tensor._id = .{ .mlir = self.block.argument(arg_index) };
            tensor._donation = .{ .arg = @intCast(arg_index) };
            self.next = arg_index + 1;
        }
    };

    var cursor: Cursor = .{ .block = block, .next = start };
    meta.visit(Cursor.bind, &cursor, v);
    return cursor.next;
}
/// Hash map from a `FnKey` to the `MlirFn` stored for it.
pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn);
/// Key of `FnCache`: a type-erased function pointer paired with a 64-bit hash of the inputs.
/// NOTE(review): `input_hash` is presumably produced by `hashArgs` — confirm against the call site.
pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 };
// Compiles the same 3-layer network twice: once with each layer invoked through `ops.call`
// (`NN._fwd`) and once with each layer's `_fwd` called directly (`NN._forwardRefImpl`),
// then checks that both outputs agree within 1e-4.
test FnCache {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const Layer = struct {
const Layer_ = @This();
w: Tensor,
b: Tensor,
// relu(w . x + b), with `b` broadcast to the shape of the product.
pub fn _fwd(self: Layer_, x: Tensor) Tensor {
const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
return wx.add(self.b.broad(wx.shape())).relu();
}
};
const NN = struct {
const NN_ = @This();
layers: [3]Layer,
// Implementation under test: every layer goes through `ops.call`.
pub fn _fwd(self: NN_, x0: Tensor) Tensor {
var x = x0;
for (self.layers) |layer| {
x = ops.call(layer, ._fwd, .{x});
}
return x;
}
// Reference implementation: layers are called directly, without `ops.call`.
pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
var x = x0;
for (self.layers) |layer| {
x = layer._fwd(x);
}
return x;
}
};
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
const nn: zml.testing.BufferizedWithArgs(NN) = .{
.layers = .{
// first two layers share the same shapes ([2][2] weights, [2] bias) but hold different values
.{
.w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }),
.b = try .fromArray(platform, [2]f16{ 0, 0 }),
},
.{
.w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }),
.b = try .fromArray(platform, [2]f16{ 10, 10 }),
},
// third layer is different
.{
.w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }),
.b = try .fromArray(platform, [3]f16{ -10, -10, -10 }),
},
},
};
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
try zml.testing.expectClose(expected, res, 1e-4);
}
// Exercises `ops.call` with a function returning both a Tensor and a plain integer.
// `Layer._fwd` bumps a container-level counter each time its body actually runs, which lets
// the test observe when a call was served without re-running the function body.
test "FnCache with mixed integer/tensor" {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const Layer = struct {
const Layer_ = @This();
// Container-level state shared by every `Layer` value; counts executions of `_fwd`.
var num_call: u32 = 0;
w: Tensor,
pub fn _fwd(self: Layer_, x: Tensor) struct { Tensor, usize } {
const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
// Note: this is for testing only, it's a bad idea to mutate global state
// from a forward function because it can mess with caching.
num_call += 1;
return .{ wx.addConstant(num_call), num_call };
}
};
const NN = struct {
const NN_ = @This();
layers: [3]Layer,
pub fn _fwd(self: NN_, x0: Tensor) Tensor {
var x = x0;
var y: usize = 0;
x, y = ops.call(self.layers[0], ._fwd, .{x});
std.debug.assert(Layer.num_call == 1);
std.debug.assert(y == 1);
// Here we call a second time but since first two layers have the same shape,
// We hit the function cache, and "num_call" is not incremented.
x, y = ops.call(self.layers[1], ._fwd, .{x});
std.debug.assert(Layer.num_call == 1);
std.debug.assert(y == 1);
// Third layer has a different shape: the body runs again and the counter moves to 2.
x, y = ops.call(self.layers[2], ._fwd, .{x});
std.debug.assert(Layer.num_call == 2);
std.debug.assert(y == 2);
return x;
}
// Reference implementation: applies the constants 1, 1, 2 that `_fwd` is expected to add
// for the three layers, without going through `ops.call`.
pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor {
var x = x0;
for (self.layers, &[_]u32{ 1, 1, 2 }) |layer, bias| {
const wx = layer.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{});
x = wx.addConstant(bias);
}
return x;
}
};
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
const nn: zml.testing.BufferizedWithArgs(NN) = .{
.layers = .{
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }) },
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }) },
// third layer has different shape
.{ .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }) },
},
};
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x });
try zml.testing.expectClose(expected, res, 1e-4);
}
/// Returns a 64-bit Wyhash (seed 0) of `mod`, computed with the tensor-aware `hash`
/// using the `.DeepRecursive` strategy.
pub fn hashArgs(mod: anytype) u64 {
    var wyhash: std.hash.Wyhash = .init(0);
    hash(&wyhash, mod, .DeepRecursive);
    return wyhash.final();
}
/// Feeds `shape` into `hasher`: each dim, then the dtype, the sharding info and the tags.
/// Tags are hashed by pointer address, not by content.
pub fn hashShape(hasher: *std.hash.Wyhash, shape: Shape) void {
    // Note: if we enforced 0-init dims then we could hash dims instead.
    for (shape.dims()) |dim| hash(hasher, dim, .Shallow);

    hash(hasher, shape._dtype, .Shallow);
    hash(hasher, shape._sharding_info, .Shallow);

    for (shape.tags()) |tag| hash(hasher, @intFromPtr(tag), .Shallow);
}
/// Alias of `hash`, for use in scopes where the name "hash" is ambiguous.
const tensorAwareHash = hash;
/// Provides generic hashing for any eligible type.
/// Strategy is provided to determine if pointers should be followed or not.
///
/// `Tensor` and `Shape` keys are special-cased and go through `hashShape`
/// (a `Tensor` only contributes its shape). With the `.Shallow` strategy, keys
/// whose type has a unique representation are hashed directly from their bytes.
/// Types carrying no runtime data (void, type, comptime numbers, ...) contribute nothing,
/// and a null optional contributes nothing either.
///
/// Hashing many-item or C pointers with a non-shallow strategy, or hashing an
/// untagged union, is a compile error.
pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Strategy) void {
    const Key = @TypeOf(key);
    if (Key == Tensor) return hashShape(hasher, key.shape());
    if (Key == Shape) return hashShape(hasher, key);

    if (strat == .Shallow and std.meta.hasUniqueRepresentation(Key)) {
        hasher.update(std.mem.asBytes(&key));
        return;
    }

    switch (@typeInfo(Key)) {
        .noreturn, .@"opaque", .undefined, .null, .comptime_float, .comptime_int, .type, .enum_literal, .frame, .void => return,

        // Help the optimizer see that hashing an int is easy by inlining!
        // TODO Check if the situation is better after #561 is resolved.
        .int => |int| switch (int.signedness) {
            // Signed integers are reinterpreted as unsigned ones of the same width.
            .signed => hash(hasher, @as(@Type(.{ .int = .{
                .bits = int.bits,
                .signedness = .unsigned,
            } }), @bitCast(key)), strat),
            .unsigned => {
                if (std.meta.hasUniqueRepresentation(Key)) {
                    hasher.update(std.mem.asBytes(&key));
                } else {
                    // Take only the part containing the key value, the remaining
                    // bytes are undefined and must not be hashed!
                    const byte_size = comptime std.math.divCeil(comptime_int, @bitSizeOf(Key), 8) catch unreachable;
                    hasher.update(std.mem.asBytes(&key)[0..byte_size]);
                }
            },
        },

        // Note: contrary to Zig we accept hashing floats.
        // Typically the float we are going to hash here are hyperparameters,
        // and not the result of an operation, so bytes should be the same everytime.
        .float => hasher.update(std.mem.asBytes(&key)),

        .bool => hash(hasher, @intFromBool(key), strat),
        .@"enum" => hash(hasher, @intFromEnum(key), strat),
        .error_set => hash(hasher, @intFromError(key), strat),
        .@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),

        .pointer => |info| switch (info.size) {
            // Note: the prongs below must use the exact member names of `std.hash.Strategy`
            // (`Shallow`, `Deep`, `DeepRecursive`), like the rest of this file does.
            .one => switch (strat) {
                .Shallow => hash(hasher, @intFromPtr(key), .Shallow),
                // `.Deep` follows exactly one level of indirection.
                .Deep => hash(hasher, key.*, .Shallow),
                .DeepRecursive => switch (@typeInfo(info.child)) {
                    // Opaque and function pointees cannot be dereferenced: hash the address.
                    .@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
                    else => hash(hasher, key.*, .DeepRecursive),
                },
            },
            .slice => {
                switch (strat) {
                    .Shallow => hash(hasher, @intFromPtr(key.ptr), .Shallow),
                    .Deep => hashArray(hasher, key, .Shallow),
                    .DeepRecursive => hashArray(hasher, key, .DeepRecursive),
                }
                hash(hasher, key.len, .Shallow);
            },
            .many,
            .c,
            => switch (strat) {
                .Shallow => hash(hasher, @intFromPtr(key), .Shallow),
                else => @compileError(
                    \\ unknown-length pointers and C pointers cannot be hashed deeply.
                    \\ Consider providing your own hash function.
                ),
            },
        },

        .optional => if (key) |k| hash(hasher, k, strat),

        .array => hashArray(hasher, key, strat),

        .vector => |info| {
            if (std.meta.hasUniqueRepresentation(Key)) {
                hasher.update(std.mem.asBytes(&key));
            } else {
                comptime var i = 0;
                inline while (i < info.len) : (i += 1) {
                    hash(hasher, key[i], strat);
                }
            }
        },

        .@"struct" => |info| {
            inline for (info.fields) |field| {
                // We reuse the hash of the previous field as the seed for the
                // next one so that they're dependant.
                hash(hasher, @field(key, field.name), strat);
            }
        },

        .@"union" => |info| {
            if (info.tag_type) |tag_type| {
                const tag = std.meta.activeTag(key);
                hash(hasher, tag, strat);
                inline for (info.fields) |field| {
                    if (@field(tag_type, field.name) == tag) {
                        if (field.type != void) {
                            hash(hasher, @field(key, field.name), strat);
                        }
                        // TODO use a labelled break when it does not crash the compiler. cf #2908
                        // break :blk;
                        return;
                    }
                }
                unreachable;
            } else @compileError("cannot hash untagged union type: " ++ @typeName(Key) ++ ", provide your own hash function");
        },

        .error_union => blk: {
            const payload = key catch |err| {
                hash(hasher, err, strat);
                break :blk;
            };
            hash(hasher, payload, strat);
        },
    }
}
/// Hashes the elements of `key` (an array or a slice) one by one, in order,
/// forwarding `strat` to `hash` for each of them.
fn hashArray(hasher: anytype, key: anytype, comptime strat: std.hash.Strategy) void {
    for (0..key.len) |i| {
        hash(hasher, key[i], strat);
    }
}