zml: HostBuffer.prettyPrint()

Add pretty printing of HostBuffer.

This will be leveraged by the debug helper `x.print()`.
It can also be used like this: `std.log.info("my buffer: {}",
.{host_buffer.pretty()})`
This commit is contained in:
Tarry Singh 2023-12-25 13:01:17 +00:00
parent 5ddd034d2c
commit 5bd7f8aae9
4 changed files with 90 additions and 17 deletions

View File

@ -235,6 +235,22 @@ pub const HostBuffer = struct {
};
}
/// Returns a view of this buffer with the given size-1 axis removed.
/// Asserts that the axis has dimension 1. The data pointer is shared,
/// not copied; only the shape and strides change.
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
    const ax = self._shape.axis(axis_);
    stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
    var _strides: ?[Shape.MAX_RANK]i64 = self._strides;
    if (self._strides) |strydes| {
        // Compact the strides by dropping the entry of the squeezed axis.
        // Strides before `ax` are already in place (copied above); only the
        // entries after `ax` shift left by one.
        // Bug fix: the previous code always copied `strydes[1..]` to index 0,
        // which is only correct when ax == 0 and corrupted the strides of the
        // leading axes for any other axis.
        const a: usize = @intCast(ax);
        std.mem.copyForwards(i64, _strides.?[a .. Shape.MAX_RANK - 1], strydes[a + 1 ..]);
    }
    return .{
        ._shape = self.shape().drop(ax),
        .data = self.data,
        ._strides = _strides,
        ._memory = self._memory,
    };
}
pub fn format(
self: HostBuffer,
comptime fmt: []const u8,
@ -245,6 +261,67 @@ pub const HostBuffer = struct {
_ = options;
try writer.print("HostBuffer(.{_})", .{self._shape});
}
/// Returns a formatter that prints the buffer's values, not just its shape.
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
pub fn pretty(self: HostBuffer) PrettyPrinter {
    return PrettyPrinter{ .x = self };
}
/// Wrapper whose `format` implementation renders the wrapped HostBuffer
/// with its contents via `std.fmt` (obtained through `pretty()`).
pub const PrettyPrinter = struct {
    x: HostBuffer,

    pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
        _ = fmt;
        _ = options;
        // Method-call form of `prettyPrint(self.x, writer)`.
        try self.x.prettyPrint(writer);
    }
};
/// Writes the buffer's values to `writer`, starting at indent 0 and showing
/// at most 4 leading and 4 trailing rows per axis.
pub fn prettyPrint(self: HostBuffer, writer: anytype) !void {
    try self.prettyPrintIndented(4, 0, writer);
}
/// Recursively prints the buffer's values, one axis per nesting level.
/// `num_rows` bounds how many leading and trailing rows are shown for each
/// axis; `indent_level` is the number of spaces prefixing the current level.
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
    // Base case: a rank-1 buffer is printed as a flat list of values.
    if (self.rank() == 1) {
        try writer.writeByteNTimes(' ', indent_level);
        switch (self.dtype()) {
            inline else => |dt| {
                const values = self.items(dt.toZigType());
                // Cap the number of printed elements to keep output bounded.
                const n = @min(values.len, 1024);
                try writer.print("{any},\n", .{values[0..n]});
            },
        }
        return;
    }
    try writer.writeByteNTimes(' ', indent_level);
    _ = try writer.write("{\n");
    // Always emit the matching closing brace, including on the early
    // `return` below and on error paths; write errors here are swallowed
    // since deferred code cannot propagate them.
    defer {
        writer.writeByteNTimes(' ', indent_level) catch {};
        _ = writer.write("},\n") catch {};
    }
    // Write first rows
    const n: u64 = @intCast(self.dim(0));
    for (0..@min(num_rows, n)) |d| {
        const di: i64 = @intCast(d);
        // Take row `d` along axis 0 and recurse on the rank-reduced view.
        const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
        try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
    }
    // Fewer rows than the window: the loop above already printed them all.
    // (Also guards the `n - num_rows` subtraction below against underflow.)
    if (n < num_rows) return;
    // Skip middle rows
    if (n > 2 * num_rows) {
        try writer.writeByteNTimes(' ', indent_level + 2);
        _ = try writer.write("...\n");
    }
    // Write last rows. `@max(n - num_rows, num_rows)` starts after the rows
    // already printed by the head loop, so no row is printed twice when
    // num_rows <= n <= 2 * num_rows.
    for (@max(n - num_rows, num_rows)..n) |d| {
        const di: i64 = @intCast(d);
        const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
        try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
    }
}
};
fn parseArrayInfo(T: type) Shape {

View File

@ -106,7 +106,7 @@ pub const CompilationContext = struct {
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
pub fn init(allocator_: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
const mlir_registry = mlir.Registry.init() catch unreachable;
inline for (.{ "func", "stablehlo" }) |d| {
mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
@ -114,6 +114,9 @@ pub const CompilationContext = struct {
var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
mlir_ctx.loadAllAvailableDialects();
// Too long module names create too long file paths.
const name = full_name[0..@min(128, full_name.len)];
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
const module = mlir.Module.init(loc);
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?);

View File

@ -1364,7 +1364,7 @@ pub const Tensor = struct {
else
toI64(axes__);
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
return self;
@ -1621,7 +1621,7 @@ pub const Tensor = struct {
/// Repeats a Tensor several times along the given axis.
///
/// * repeat1d(x, concat(&.{x, x, x, x}, axis);
/// * repeat1d(x, axis, 4) = concat(&.{x, x, x, x}, axis);
/// * repeat1d([0, 1, 2, 3], 0, 2) = [0, 1, 2, 3, 0, 1, 2, 3]
pub fn repeat1d(self: Tensor, axis_: anytype, n_rep: u63) Tensor {
if (n_rep == 1) {
@ -3738,13 +3738,7 @@ pub const Tensor = struct {
}
fn printCallback(host_buffer: HostBuffer) void {
switch (host_buffer.dtype()) {
inline else => |dt| {
const items = host_buffer.items(dt.toZigType());
const n = @min(items.len, 1024);
std.debug.print("Device buffer: {}: {any}\n", .{ host_buffer.shape(), items[0..n] });
},
}
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
}
};

View File

@ -13,19 +13,17 @@ var _platform: ?zml.Platform = null;
pub fn env() zml.Platform {
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
if (_platform == null) {
_test_compile_opts = .{
.sharding_enabled = true,
};
var ctx = zml.Context.init() catch unreachable;
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
_platform = ctx.autoPlatform(.{}).withCompilationOptions(.{
.xla_dump_to = "/tmp/zml/tests/",
.sharding_enabled = true,
.xla_dump_hlo_pass_re = ".*",
});
}
return _platform.?;
}
var _test_compile_opts: zml.CompilationOptions = .{};
/// In neural network we generally care about the relative precision,
/// but on a given dimension, if the output is close to 0, then the precision
/// don't matter as much.
@ -53,6 +51,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
if (should_free_left) left.deinit(allocator);
if (should_free_right) right.deinit(allocator);
}
errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() });
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() });