diff --git a/zml/buffer.zig b/zml/buffer.zig index 2369138..20a6fcb 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -131,6 +131,12 @@ pub const Buffer = struct { return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); } + /// Copies the given Zig slice to the accelerator memory and + /// return a Buffer with the given dimensions. + pub fn fromBytes(platform: Platform, sh: Shape, data: []const u8) !Buffer { + return from(platform, HostBuffer.fromBytes(sh, data)); + } + /// Copies the given Zig array to the accelerator memory and /// return a Buffer using the array shape. pub fn fromArray(platform: Platform, arr: anytype) !Buffer { diff --git a/zml/exe.zig b/zml/exe.zig index b34d5e9..48329b9 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -74,8 +74,9 @@ pub fn compileFn( args: ShapeOf(stdx.meta.FnArgs(func)), platform: Platform, ) !FnExe(func) { - const name = @typeName(@TypeOf(func)); - var context = try CompilationContext.init(allocator, name, platform); + var pretty_name = try prettyFnName(func, allocator); + defer pretty_name.deinit(allocator); + var context = try CompilationContext.init(allocator, pretty_name.items, platform); defer context.deinit(); return .{ .inner = try context.compileInternal(allocator, func, args) }; @@ -353,3 +354,37 @@ fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Bu }).cb, &local_ctx, v); stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); } + +fn prettyFnName( + comptime func: anytype, + allocator: std.mem.Allocator, +) !std.ArrayListUnmanaged(u8) { + const full_noisy_name = @typeName(@TypeOf(func)); + const og_len = full_noisy_name.len; + const buffer = try allocator.alloc(u8, og_len); + errdefer comptime unreachable; // No errors below this point. 
+ var out: []u8 = buffer; + + { + const verbose = "tensor.Tensor."; + const compact = ""; + const num_replacements = std.mem.replace(u8, full_noisy_name, verbose, compact, buffer); + out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len; + } + + { + const verbose = "tensor.Tensor"; + const compact = "Tensor"; + const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer); + out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len; + } + + { + const verbose = "shape.Shape"; + const compact = "Shape"; + const num_replacements = std.mem.replace(u8, out, verbose, compact, buffer); + out.len = out.len + num_replacements * compact.len - num_replacements * verbose.len; + } + + return .{ .items = out, .capacity = og_len }; +} diff --git a/zml/helpers.zig b/zml/helpers.zig index c756825..6a8c31f 100644 --- a/zml/helpers.zig +++ b/zml/helpers.zig @@ -13,7 +13,7 @@ test { } const ShapeError = error{ DimMismatch, NotFound }; -const NOT_SET: i64 = 0; +const NOT_SET: i64 = -2; const DIM_MISMATCH: i64 = -1; /// Collect the given dimensions inside a struct containing tagged tensors.
@@ -26,10 +26,12 @@ pub fn collectDims( res: ShapeStruct(dims), mode: @TypeOf(mode), }; + var context = LocalContext{ - .res = std.mem.zeroes(ShapeStruct(dims)), + .res = undefined, .mode = mode, }; + @memset(std.mem.bytesAsSlice(i64, std.mem.asBytes(&context.res)), NOT_SET); meta.visit((struct { fn cb(ctx: *LocalContext, shape: *const Shape) void { @@ -96,7 +98,7 @@ test collectDims { collectDims(.{ .b, .d }, &model, .strict), error.DimMismatch, ); - try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = -1, .d = 5 }); + try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = DIM_MISMATCH, .d = 5 }); } { var model: Model = .{ @@ -105,7 +107,7 @@ test collectDims { .bias = Shape.init(.{5}, .f32).withTags(.{.d}), }; try std.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .strict), error.NotFound); - try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = 0 }); + try zml.testing.expectEqual(collectDims(.{ .b, .d, .c }, &model, .ignore_errors), .{ .b = 2, .d = 5, .c = NOT_SET }); } { var model: Model = .{ @@ -114,7 +116,7 @@ test collectDims { .bias = Shape.init(.{7}, .f32).withTags(.{.d}), }; try std.testing.expectEqual(collectDims(.{ .b, .d }, &model, .strict), error.DimMismatch); - try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = -1 }); + try zml.testing.expectEqual(collectDims(.{ .b, .d }, &model, .ignore_errors), .{ .b = 2, .d = DIM_MISMATCH }); } } diff --git a/zml/module.zig b/zml/module.zig index 253125c..453f0cb 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -71,6 +71,7 @@ const Block = union(BlockKind) { pub const MlirFn = struct { name: []const u8, num_args: u32, + res_tensors: *const anyopaque, res_types: []mlir.Type, res_shapes: []Shape, res_donations: []Tensor._Donation, @@ -114,12 +115,15 @@ pub const CompilationContext = struct { var mlir_ctx = 
mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable; mlir_ctx.loadAllAvailableDialects(); - // Too long module names create too long file paths. - const name = full_name[0..@min(128, full_name.len)]; + // Too long module names create too long file paths and files failed to create. + // * leave half of the space for parent folder and XLA generated filename, + // * leave 17 bytes for the module hash (16 + 1 for underscore). + const max_name_len = @divFloor(std.fs.max_path_bytes, 2) - 17; + const name = full_name[0..@min(max_name_len, full_name.len)]; const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const module = mlir.Module.init(loc); - module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?); + module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, "zml").as(mlir.Attribute).?); var canonicalizer = try mlir.PassManager.init(mlir_ctx); { @@ -130,12 +134,12 @@ pub const CompilationContext = struct { } var arena = std.heap.ArenaAllocator.init(allocator_); - _ = try arena.allocator().alloc(u8, std.mem.page_size); + _ = try arena.allocator().alloc(u8, 4096); _ = arena.reset(.retain_capacity); return .{ ._platform = platform, - ._name = name, + ._name = try arena.allocator().dupe(u8, name), ._mlir_ctx = mlir_ctx, ._mlir_registry = mlir_registry, ._mlir_canonicalizer = canonicalizer, @@ -394,7 +398,9 @@ pub const CompilationContext = struct { // But it forces user to have simpler function. const ReturnT = stdx.meta.FnResult(func); const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors"); - // Those are returned to caller so we don't put them in the arena. + + // Those are returned to caller so we don't put them in the arena, but in the module allocator. 
+ const fn_res = try res_allocator.create(ReturnT); const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count); const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count); const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count); @@ -406,14 +412,14 @@ const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0); std.debug.assert(assigned_args_count == tensor_count); - const fn_res = forward: { + fn_res.* = forward: { self.activate(); defer self.deactivate(); break :forward @call(.auto, func, args.*); }; var fn_res_values: [out_tensor_count]mlir.Value = undefined; - self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); + self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); fn_body.appendOperationRecursive(fn_ret); @@ -457,6 +463,7 @@ .mlir_fn = mlir_fn, .name = opts.name, .num_args = @intCast(tensor_count), + .res_tensors = fn_res, .res_types = fn_res_types, .res_shapes = fn_res_shapes, .res_donations = fn_res_donations, @@ -639,46 +646,51 @@ func_name: [:0]const u8, comptime func: anytype, args: stdx.meta.FnArgs(func), - ) stdx.meta.FnResult(func) { + ) error{OutOfMemory}!stdx.meta.FnResult(func) { var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator); defer arena_state.deinit(); + // This arena is used for allocations which won't outlive the function call, + // but the function creation uses `self.allocator()` which will live for the duration of the compilation. 
const arena = arena_state.allocator(); // first, do the "compile" and check the bytecode // the result of this will also have the correct tags of the result shapes const args_hash = hashArgs(args); - const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash }; + const key: FnKey = .{ .fn_ptr = &func, .input_hash = args_hash }; - const function = self._fn_cache.getEntry(key) orelse b: { + const function = self._fn_cache.get(key) orelse b: { const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name)) - arena.dupeZ(u8, func_name) catch unreachable + try self.allocator().dupeZ(u8, func_name) else - std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable; + try std.fmt.allocPrintZ(self.allocator(), "{s}_{x}", .{ func_name, key.input_hash }); var arg_id: u16 = 0; var tensor_args: @TypeOf(args) = args; - meta.mapAlloc(struct { + try meta.mapAlloc(struct { fn cb(arg_id_: *u16, x: Tensor) Tensor { const a = arg_id_.*; arg_id_.* += 1; return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } }; } - }.cb, arena, &arg_id, args, &tensor_args) catch @panic("OutOfMemory"); + }.cb, arena, &arg_id, args, &tensor_args); - const f = self.emitMlir(func, &tensor_args, .{ - .name = full_name, - }) catch @panic("OOM"); + const f = try self.emitMlir( + func, + &tensor_args, + .{ .name = full_name }, + ); self._module.getBody().appendOperation(f.mlir_fn); - break :b self._fn_cache.addEntry(self.allocator(), key, f) catch unreachable; + try self._fn_cache.putNoClobber(self.allocator(), key, f); + break :b f; }; const loc = self.mlirCtx().location(@src()); - const values = arena.alloc(mlir.Value, function.num_args) catch unreachable; + const values = try arena.alloc(mlir.Value, function.num_args); self.extractValues(&args, values); - const donations = arena.alloc(Tensor._Donation, function.num_args) catch unreachable; + const donations = try arena.alloc(Tensor._Donation, function.num_args); 
meta.collectBuf(struct { pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation { return ctx.getValueAndDonation(x)[1]; @@ -689,9 +701,7 @@ pub const CompilationContext = struct { // Create the result tensor object by combining the operand results, // as well as the registered shapes and donations. // Note: this assume res can be stack-allocated. - // Maybe it'd be simpler to just call the Zig function twice to do the shape/donation propagation for us. - // But this is blocked on https://github.com/zml/zml/issues/97 - var res: stdx.meta.FnResult(func) = undefined; + var res = @as(*const stdx.meta.FnResult(func), @alignCast(@ptrCast(function.res_tensors))).*; const LocalContext = struct { index: usize = 0, op: mlir.Operation, function: MlirFn, donations: []Tensor._Donation }; var context: LocalContext = .{ .op = op, .function = function, .donations = donations }; meta.visit((struct { @@ -911,6 +921,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true); // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); // setFlag(&options, "xla_gpu_enable_custom_fusions", true); + // setFlags(&options, "xla_gpu_enable_address_computation_fusion", true); // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); // setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true); // setFlag(&options, "xla_gpu_use_runtime_fusion", true); @@ -1009,45 +1020,8 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer { return .{ .hasher = hasher }; } -pub const FnCache = struct { - pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 }; - - cache: std.AutoHashMapUnmanaged(Key, MlirFn) = .{}, - - pub fn deinit(self: FnCache, allocator: std.mem.Allocator) void { - self.cache.deinit(allocator); - } - - pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn { - return self.cache.get(key); - } - - pub fn addEntry(self: *FnCache, 
allocator: std.mem.Allocator, key: Key, value: MlirFn) !MlirFn { - const res_types_copy = try allocator.dupe(mlir.Type, value.res_types); - errdefer allocator.free(res_types_copy); - - const res_shapes_copy = try allocator.dupe(Shape, value.res_shapes); - errdefer allocator.free(res_shapes_copy); - - const res_donations_copy = try allocator.dupe(Tensor._Donation, value.res_donations); - errdefer allocator.free(res_donations_copy); - - const name_copy = try allocator.dupeZ(u8, value.name); - errdefer allocator.free(name_copy); - - const owned_value: MlirFn = .{ - .name = name_copy, - .mlir_fn = value.mlir_fn, - .num_args = value.num_args, - .res_types = res_types_copy, - .res_shapes = res_shapes_copy, - .res_donations = res_donations_copy, - }; - - try self.cache.putNoClobber(allocator, key, owned_value); - return owned_value; - } -}; +pub const FnCache = std.AutoHashMapUnmanaged(FnKey, MlirFn); +pub const FnKey = struct { fn_ptr: *const anyopaque, input_hash: u64 }; test FnCache { const zml = @import("zml.zig"); @@ -1109,6 +1083,70 @@ test FnCache { try zml.testing.expectClose(expected, res, 1e-4); } +test "FnCache with mixed integer/tensor" { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Layer = struct { + const Layer_ = @This(); + var num_call: u32 = 0; + + w: Tensor, + + pub fn _fwd(self: Layer_, x: Tensor) struct { Tensor, usize } { + const wx = self.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); + // Note: this is for testing only, it's a bad idea to mutate global state + // from a forward function because it can mess with caching. 
+ num_call += 1; + return .{ wx.addConstant(num_call), num_call }; + } + }; + + const NN = struct { + const NN_ = @This(); + layers: [3]Layer, + + pub fn _fwd(self: NN_, x0: Tensor) Tensor { + var x = x0; + var y: usize = 0; + x, y = ops.call(self.layers[0], ._fwd, .{x}); + std.debug.assert(Layer.num_call == 1); + std.debug.assert(y == 1); + // Here we call a second time but since first two layers have the same shape, + // We hit the function cache, and "num_call" is not incremented. + x, y = ops.call(self.layers[1], ._fwd, .{x}); + std.debug.assert(Layer.num_call == 1); + std.debug.assert(y == 1); + x, y = ops.call(self.layers[2], ._fwd, .{x}); + std.debug.assert(Layer.num_call == 2); + std.debug.assert(y == 2); + return x; + } + + pub fn _forwardRefImpl(self: NN_, x0: Tensor) Tensor { + var x = x0; + for (self.layers, &[_]u32{ 1, 1, 2 }) |layer, bias| { + const wx = layer.w.dotGeneral(x, &.{.{ -1, 0 }}, &.{}); + x = wx.addConstant(bias); + } + return x; + } + }; + + const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 }); + const nn: zml.Bufferized(NN) = .{ + .layers = .{ + .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) }, + .{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) }, + // third layer has different shape + .{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) }, + }, + }; + const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x }); + const expected = try zml.testing.compileAndCall(platform, NN._forwardRefImpl, .{ nn, x }); + try zml.testing.expectClose(expected, res, 1e-4); +} + pub fn hashArgs(mod: anytype) u64 { var hasher = std.hash.Wyhash.init(0); hash(&hasher, mod, .DeepRecursive); diff --git a/zml/ops.zig b/zml/ops.zig index 21bcb9f..f314e01 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -34,7 +34,7 @@ pub fn call(self: anytype, comptime func: stdx.meta.DeclEnum(@TypeOf(self)), arg const ctx = CompilationContext.current(); const name 
= @typeName(@TypeOf(self)) ++ "." ++ @tagName(func); const actual_fn = @field(@TypeOf(self), @tagName(func)); - return ctx.callFunc(name, actual_fn, .{self} ++ args); + return ctx.callFunc(name, actual_fn, .{self} ++ args) catch @panic("OOM"); } pub fn while_( @@ -445,14 +445,36 @@ pub fn if_( if (TrueBlockSignature.Return != FalseBlockSignature.Return) { @compileError("true_branch_fn and false_branch_fn return types don't match ! " ++ @typeName(TrueBlockSignature.Return) ++ " and " ++ @typeName(FalseBlockSignature.Return)); } + + stdx.debug.assert(pred.dtype() == .bool and pred.count() == 1, "zml.ops.if_ expects the condition to have exactly one element of dtype .bool, got {}", .{pred}); + const ctx = CompilationContext.current(); const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); - stdx.debug.assert(false_branch_res.shape().eqlWithTags(true_branch_res.shape()), "zml.ops.if_ expects true and false branch to produce outputs of the same shape, but it produced true={} and false={}", .{ true_branch_res, false_branch_res }); + var true_shapes = std.ArrayList(Shape).init(ctx.allocator()); + defer true_shapes.deinit(); + var false_shapes = std.ArrayList(Shape).init(ctx.allocator()); + defer false_shapes.deinit(); + + var failed_to_collect = false; + meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch { + failed_to_collect = true; + }; + meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch { + failed_to_collect = true; + }; + if (!failed_to_collect) { + stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. 
Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items }); + for (true_shapes.items, false_shapes.items) |true_shape, false_shape| { + stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {_}\n -false branch: {_}", .{ true_shapes.items, false_shapes.items }); + } + } + + const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0); const loc = ctx.mlirCtx().location(@src()); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ - .operands = &.{pred.value()}, + .operands = &.{scalar_pred.value()}, .result_type_inference = true, .blocks = &.{ true_branch_block, false_branch_block }, // We can't verify right away, cause the weights captured by the if haven't been added yet. @@ -791,7 +813,7 @@ pub fn scatter( // validate coord axes: all coord_axes should exist inside self for (indices_axes.constSlice()) |t| { - stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={any}", .{ self, indices_axes }); + stdx.debug.assert(self._shape.hasTag(t) != null, "zml.ops.scatter expects axes of indices to be axes of inputs, got input={_} and indices={s}", .{ self, indices_axes.constSlice() }); } // Handle scalar indices by broadcasting them to the indices with the highest rank. 
@@ -905,7 +927,7 @@ fn scatterConfig( scatter_to_operand_axes.appendAssumeCapacity(op.axis(t)); } for (indices.tags()) |t| { - stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got updates={} and indices={s}", .{ update, indices_axes.constSlice() }); + stdx.debug.assert(update.hasTag(t) != null, "scatter expects 'updates' to have all axes of 'indices', got self={_}, updates={_} and indices={_}", .{ op, update, indices }); updates_transpose.appendAssumeCapacity(update.axis(t)); } diff --git a/zml/tensor.zig b/zml/tensor.zig index 5529b6b..15543d7 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1743,7 +1743,7 @@ pub const Tensor = struct { /// Returns a Tensor containing evenly spaced values within a given interval. pub fn arange(args: ArangeArgs, dt: DataType) Tensor { - stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); + stdx.debug.assert(args.start <= args.end, "arange expects 'args.start' to be less than or equal to 'args.end', got {} and {}", .{ args.start, args.end }); stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); const ctx = CompilationContext.current(); diff --git a/zml/testing.zig b/zml/testing.zig index e686b60..b0d6827 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -17,7 +17,6 @@ pub fn env() zml.Platform { _platform = ctx.autoPlatform(.{}).withCompilationOptions(.{ .xla_dump_to = "/tmp/zml/tests/", .sharding_enabled = true, - .xla_dump_hlo_pass_re = ".*", }); }