zml: HostBuffer.prettyPrint()
Add pretty printing of HostBuffer.
This will be leveraged by the debug helper `x.print()`
It can also be used like this: `std.log.info("my buffer: {}",
.{host_buffer.pretty()})`
This commit is contained in:
parent
5ddd034d2c
commit
5bd7f8aae9
@ -235,6 +235,22 @@ pub const HostBuffer = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
||||
const ax = self._shape.axis(axis_);
|
||||
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
||||
|
||||
var _strides: ?[Shape.MAX_RANK]i64 = self._strides;
|
||||
if (self._strides) |strydes| {
|
||||
std.mem.copyForwards(i64, _strides.?[0 .. Shape.MAX_RANK - 1], strydes[1..]);
|
||||
}
|
||||
return .{
|
||||
._shape = self.shape().drop(ax),
|
||||
.data = self.data,
|
||||
._strides = _strides,
|
||||
._memory = self._memory,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn format(
|
||||
self: HostBuffer,
|
||||
comptime fmt: []const u8,
|
||||
@ -245,6 +261,67 @@ pub const HostBuffer = struct {
|
||||
_ = options;
|
||||
try writer.print("HostBuffer(.{_})", .{self._shape});
|
||||
}
|
||||
|
||||
/// Formatter for a HostBuffer that also print the values not just the shape.
|
||||
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
|
||||
/// Returns a wrapper whose `format` implementation prints the buffer's
/// values, not just its shape.
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
pub fn pretty(self: HostBuffer) PrettyPrinter {
    return PrettyPrinter{ .x = self };
}
|
||||
|
||||
/// Adapter returned by `pretty()`: formatting it writes the buffer's
/// values via `prettyPrint` instead of the shape-only default `format`.
pub const PrettyPrinter = struct {
    x: HostBuffer,

    pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
        // The format string and options are ignored: the output layout is
        // fully determined by the buffer itself.
        _ = fmt;
        _ = options;
        try self.x.prettyPrint(writer);
    }
};
|
||||
|
||||
/// Writes the buffer's values to `writer`, nested per axis,
/// showing at most 4 leading and 4 trailing rows per dimension.
pub fn prettyPrint(self: HostBuffer, writer: anytype) !void {
    const rows_per_dim: u8 = 4;
    try self.prettyPrintIndented(rows_per_dim, 0, writer);
}
|
||||
|
||||
/// Recursively pretty-prints the buffer.
/// Shows up to `num_rows` leading and `num_rows` trailing entries of the
/// outermost axis, eliding the middle with "..." when there are more than
/// `2 * num_rows` entries. Each nesting level adds 2 spaces of indent.
/// NOTE(review): assumes rank >= 1; a rank-0 buffer would fall through to
/// `self.dim(0)` below — confirm callers never pass scalars.
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
    // Base case: a 1-D buffer is printed as a single `{any}` slice line,
    // capped at 1024 elements.
    if (self.rank() == 1) {
        try writer.writeByteNTimes(' ', indent_level);
        // `inline else` instantiates this branch for every dtype so that
        // `items` can be called with the matching comptime Zig type.
        switch (self.dtype()) {
            inline else => |dt| {
                const values = self.items(dt.toZigType());
                const n = @min(values.len, 1024);
                try writer.print("{any},\n", .{values[0..n]});
            },
        }
        return;
    }
    try writer.writeByteNTimes(' ', indent_level);
    _ = try writer.write("{\n");
    // The closing brace is emitted on every exit path (including the early
    // `return` below). Write errors in the defer are deliberately dropped:
    // a defer cannot propagate them.
    defer {
        writer.writeByteNTimes(' ', indent_level) catch {};
        _ = writer.write("},\n") catch {};
    }

    // Write first rows: each row is a rank-1-lower view obtained by slicing
    // one index off axis 0 and squeezing it away (no copy).
    const n: u64 = @intCast(self.dim(0));
    for (0..@min(num_rows, n)) |d| {
        const di: i64 = @intCast(d);
        const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
        try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
    }

    // Fewer rows than the window: everything was already printed above.
    if (n < num_rows) return;
    // Skip middle rows: only elide when head and tail windows cannot
    // cover the axis.
    if (n > 2 * num_rows) {
        try writer.writeByteNTimes(' ', indent_level + 2);
        _ = try writer.write("...\n");
    }
    // Write last rows. The `@max` keeps the start at `num_rows` when
    // `n <= 2 * num_rows`, so no row printed by the head loop is repeated.
    for (@max(n - num_rows, num_rows)..n) |d| {
        const di: i64 = @intCast(d);
        const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
        try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
    }
}
|
||||
};
|
||||
|
||||
fn parseArrayInfo(T: type) Shape {
|
||||
|
||||
@ -106,7 +106,7 @@ pub const CompilationContext = struct {
|
||||
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
||||
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
||||
|
||||
pub fn init(allocator_: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
|
||||
pub fn init(allocator_: std.mem.Allocator, full_name: []const u8, platform: Platform) !CompilationContext {
|
||||
const mlir_registry = mlir.Registry.init() catch unreachable;
|
||||
inline for (.{ "func", "stablehlo" }) |d| {
|
||||
mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
|
||||
@ -114,6 +114,9 @@ pub const CompilationContext = struct {
|
||||
var mlir_ctx = mlir.Context.initWithRegistry(mlir_registry, false) catch unreachable;
|
||||
mlir_ctx.loadAllAvailableDialects();
|
||||
|
||||
// Too long module names create too long file paths.
|
||||
const name = full_name[0..@min(128, full_name.len)];
|
||||
|
||||
const loc = mlir_ctx.location(@src()).named(mlir_ctx, "main");
|
||||
const module = mlir.Module.init(loc);
|
||||
module.op().setAttributeByName("sym_name", mlir.StringAttribute.init(mlir_ctx, name).as(mlir.Attribute).?);
|
||||
|
||||
@ -1364,7 +1364,7 @@ pub const Tensor = struct {
|
||||
else
|
||||
toI64(axes__);
|
||||
|
||||
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {} and {}", .{ self.rank(), permutation.len });
|
||||
stdx.debug.assert(permutation.len == self.rank(), "transpose expects input tensor rank and 'axes_' length to be equal, got {_} and {d}", .{ self, permutation[0..@min(permutation.len, MAX_RANK + 2)] });
|
||||
|
||||
if (std.mem.eql(i64, permutation, no_op[0..self.rank()])) {
|
||||
return self;
|
||||
@ -1621,7 +1621,7 @@ pub const Tensor = struct {
|
||||
|
||||
/// Repeats a Tensor several times along the given axis.
|
||||
///
|
||||
/// * repeat1d(x, concat(&.{x, x, x, x}, axis);
|
||||
/// * repeat1d(x, axis, 4) = concat(&.{x, x, x, x}, axis);
|
||||
/// * repeat1d([0, 1, 2, 3], 0, 2) = [0, 1, 2, 3, 0, 1, 2, 3]
|
||||
pub fn repeat1d(self: Tensor, axis_: anytype, n_rep: u63) Tensor {
|
||||
if (n_rep == 1) {
|
||||
@ -3738,13 +3738,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
/// Debug callback: dumps the buffer's shape and values to stderr.
fn printCallback(host_buffer: HostBuffer) void {
    const shape_ = host_buffer.shape();
    std.debug.print("Device buffer: {}: {}", .{ shape_, host_buffer.pretty() });
}
|
||||
};
|
||||
|
||||
|
||||
@ -13,19 +13,17 @@ var _platform: ?zml.Platform = null;
|
||||
/// Lazily initializes and returns the shared test platform.
/// Only usable inside `test` blocks (enforced at comptime).
/// NOTE(review): not thread-safe — relies on tests initializing `_platform`
/// without racing; confirm the test runner is single-threaded here.
pub fn env() zml.Platform {
    if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
    if (_platform == null) {
        var ctx = zml.Context.init() catch unreachable;
        // NOTE(review): `xla_dump_to = "/tmp/zml/tests/"` and
        // `xla_dump_hlo_pass_re = ".*"` dump HLO for every pass on every
        // test run — these look like local debug settings committed by
        // accident; confirm they are intended for the shared test env.
        _platform = ctx.autoPlatform(.{}).withCompilationOptions(.{
            .xla_dump_to = "/tmp/zml/tests/",
            .sharding_enabled = true,
            .xla_dump_hlo_pass_re = ".*",
        });
    }

    return _platform.?;
}
|
||||
|
||||
var _test_compile_opts: zml.CompilationOptions = .{};
|
||||
|
||||
/// In neural network we generally care about the relative precision,
|
||||
/// but on a given dimension, if the output is close to 0, then the precision
|
||||
/// don't matter as much.
|
||||
@ -53,6 +51,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
if (should_free_left) left.deinit(allocator);
|
||||
if (should_free_right) right.deinit(allocator);
|
||||
}
|
||||
errdefer log.err("\n--> Left: {}\n--> Right: {}", .{ left.pretty(), right.pretty() });
|
||||
|
||||
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
||||
log.err("left.shape() {} != right.shape() {}", .{ left.shape(), right.shape() });
|
||||
|
||||
Loading…
Reference in New Issue
Block a user