From 5bd7f8aae915db41f9772d49ee383c382fcf4d92 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 25 Dec 2023 13:01:17 +0000 Subject: [PATCH] zml: HostBuffer.prettyPrint() Add pretty printing of HostBuffer. This will be leveraged by the debug helper `x.print()` It can also be used like this: `std.log.info("my buffer: {}", .{host_buffer.pretty()})` --- zml/hostbuffer.zig | 77 ++++++++++++++++++++++++++++++++++++++++++++++ zml/module.zig | 5 ++- zml/tensor.zig | 12 ++------ zml/testing.zig | 13 ++++---- 4 files changed, 90 insertions(+), 17 deletions(-) diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index aed629e..a477de4 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -235,6 +235,22 @@ pub const HostBuffer = struct { }; } + pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { + const ax = self._shape.axis(axis_); + stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); + + var _strides: ?[Shape.MAX_RANK]i64 = self._strides; + if (self._strides) |strydes| { + std.mem.copyForwards(i64, _strides.?[0 .. Shape.MAX_RANK - 1], strydes[1..]); + } + return .{ + ._shape = self.shape().drop(ax), + .data = self.data, + ._strides = _strides, + ._memory = self._memory, + }; + } + pub fn format( self: HostBuffer, comptime fmt: []const u8, @@ -245,6 +261,67 @@ pub const HostBuffer = struct { _ = options; try writer.print("HostBuffer(.{_})", .{self._shape}); } + + /// Formatter for a HostBuffer that also prints the values, not just the shape. 
+ /// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});` + pub fn pretty(self: HostBuffer) PrettyPrinter { + return .{ .x = self }; + } + + pub const PrettyPrinter = struct { + x: HostBuffer, + + pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { + _ = fmt; + _ = options; + try prettyPrint(self.x, writer); + } + }; + + pub fn prettyPrint(self: HostBuffer, writer: anytype) !void { + return self.prettyPrintIndented(4, 0, writer); + } + + fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void { + if (self.rank() == 1) { + try writer.writeByteNTimes(' ', indent_level); + switch (self.dtype()) { + inline else => |dt| { + const values = self.items(dt.toZigType()); + const n = @min(values.len, 1024); + try writer.print("{any},\n", .{values[0..n]}); + }, + } + return; + } + try writer.writeByteNTimes(' ', indent_level); + _ = try writer.write("{\n"); + defer { + writer.writeByteNTimes(' ', indent_level) catch {}; + _ = writer.write("},\n") catch {}; + } + + // Write first rows + const n: u64 = @intCast(self.dim(0)); + for (0..@min(num_rows, n)) |d| { + const di: i64 = @intCast(d); + const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0); + try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer); + } + + if (n < num_rows) return; + // Skip middle rows + if (n > 2 * num_rows) { + try writer.writeByteNTimes(' ', indent_level + 2); + _ = try writer.write("...\n"); + } + // Write last rows + for (@max(n - num_rows, num_rows)..n) |d| { + const di: i64 = @intCast(d); + const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0); + try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer); + } + } }; fn parseArrayInfo(T: type) Shape { diff --git a/zml/module.zig b/zml/module.zig index f7e6930..0587f88 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -106,7 +106,7 @@ pub const 
CompilationContext = struct { const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation }); const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3); - pub fn init(allocator_: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext { + pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext { const mlir_registry = mlir.Registry.init() catch unreachable; inline for (.{ "func", "stablehlo" }) |d| { mlir.DialectHandle.fromString(d).insertDialect(mlir_registry); @@ -114,6 +114,9 @@ pub const CompilationContext = struct { var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable; mlir_ctx.loadAllAvailableDialects(); + // Too long module names create too long file paths. + const name = full_name[0..@min(128, full_name.len)]; + const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main"); const module = mlir.Module.init(loc); module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?); diff --git a/zml/tensor.zig b/zml/tensor.zig index bc66a8c..c375e24 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1364,7 +1364,7 @@ pub const Tensor = struct { else toI64(axes__); - stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len }); + stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] }); if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) { return self; @@ -1621,7 +1621,7 @@ pub const Tensor = struct { /// Repeats a Tensor several times along the given axis. 
/// - /// * repeat1d(x, concat(&.{x, x, x, x}, axis); + /// * repeat1d(x, axis, 4) = concat(&.{x, x, x, x}, axis); /// * repeat1d([0, 1, 2, 3], 0, 2) = [0, 1, 2, 3, 0, 1, 2, 3] pub fn repeat1d(self: Tensor, axis_: anytype, n_rep: u63) Tensor { if (n_rep == 1) { @@ -3738,13 +3738,7 @@ pub const Tensor = struct { } fn printCallback(host_buffer: HostBuffer) void { - switch (host_buffer.dtype()) { - inline else => |dt| { - const items = host_buffer.items(dt.toZigType()); - const n = @min(items.len, 1024); - std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), items[0..n] }); - }, - } + std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); } }; diff --git a/zml/testing.zig b/zml/testing.zig index 9d7a47e..e686b60 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -13,19 +13,17 @@ var _platform: ?zml.Platform = null; pub fn env() zml.Platform { if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block"); if (_platform == null) { - _test_compile_opts = .{ - .sharding_enabled = true, - }; - var ctx = zml.Context.init() catch unreachable; - _platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts); + _platform = ctx.autoPlatform(.{}).withCompilationOptions(.{ + .xla_dump_to = "/tmp/zml/tests/", + .sharding_enabled = true, + .xla_dump_hlo_pass_re = ".*", + }); } return _platform.?; } -var _test_compile_opts: zml.CompilationOptions = .{}; - /// In neural network we generally care about the relative precision, /// but on a given dimension, if the output is close to 0, then the precision /// don't matter as much. 
@@ -53,6 +51,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { if (should_free_left) left.deinit(allocator); if (should_free_right) right.deinit(allocator); } + errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() }); if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) { log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() });