stdx.fmt: add slice formatting support, improving on previous prettyPrinter implementation by leveraging internal fmt mechanisms.

This commit is contained in:
Tarry Singh 2024-10-18 15:05:08 +00:00
parent fe55c600d4
commit 4ef81b89ea
4 changed files with 196 additions and 20 deletions

View File

@ -5,6 +5,7 @@ zig_library(
srcs = [ srcs = [
"debug.zig", "debug.zig",
"flags.zig", "flags.zig",
"fmt.zig",
"io.zig", "io.zig",
"json.zig", "json.zig",
"math.zig", "math.zig",

155
stdx/fmt.zig Normal file
View File

@ -0,0 +1,155 @@
const std = @import("std");
/// A parsed format specifier, tagged by the class of value it applies to.
pub const Fmt = union(enum) {
    /// Base and letter case used when printing integers.
    int: IntFmt,
    /// Notation used when printing floats.
    float: FloatFmt,
    /// Every other type: the specifier string is not inspected.
    generic: void,

    /// Interprets `fmt_` according to the type class of `T`.
    ///
    /// Integer and float types delegate to `IntFmt.parseComptime` and
    /// `FloatFmt.parseComptime`, which panic on an unrecognized specifier.
    /// Any other type yields `.generic` regardless of `fmt_`.
    pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
        switch (@typeInfo(T)) {
            .int, .comptime_int => return .{ .int = IntFmt.parseComptime(fmt_) },
            .float, .comptime_float => return .{ .float = FloatFmt.parseComptime(fmt_) },
            else => return .generic,
        }
    }
};
/// Bundles the parsed specifier with the standard formatting options so both
/// can be handed to the `format*Value` / `format*Slice` functions as one argument.
pub const FullFormatOptions = struct {
    /// Parsed specifier; the formatters in this file read the variant matching
    /// the value class they print (`.int` for integers, `.float` for floats).
    fmt: Fmt,
    /// Standard options (width, precision, fill, alignment) forwarded to `std.fmt`.
    /// `formatSliceCustom` additionally reads `width` as the maximum number of
    /// elements to print.
    options: std.fmt.FormatOptions,
};
/// Integer rendering settings: numeric base plus the case used for alphabetic digits.
pub const IntFmt = struct {
    base: u8,
    case: std.fmt.Case = .lower,

    /// Like `parse`, for a comptime-known specifier.
    /// Panics with a message naming the specifier when it is not recognized.
    pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
        return parse(fmt_) catch {
            @panic("invalid fmt for int: " ++ fmt_);
        };
    }

    /// Maps a specifier to an `IntFmt`:
    /// `""` and `"d"` are decimal, `"x"` / `"X"` are lower/upper-case hex, `"o"` is octal.
    /// Anything else returns `error.InvalidArgument`.
    pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
        // The empty specifier is the decimal default.
        if (fmt_.len == 0) return .{ .base = 10, .case = .lower };
        // Every recognized specifier is exactly one byte long.
        if (fmt_.len != 1) return error.InvalidArgument;
        return switch (fmt_[0]) {
            'd' => .{ .base = 10, .case = .lower },
            'x' => .{ .base = 16, .case = .lower },
            'X' => .{ .base = 16, .case = .upper },
            'o' => .{ .base = 8, .case = .upper },
            // TODO: unicode/ascii
            else => error.InvalidArgument,
        };
    }
};
/// Float notation. The first two tags reuse the integer values of
/// `std.fmt.format_float.Format`; `hex` takes the next value after `decimal`.
pub const FloatFmt = enum(u8) {
    scientific = @intFromEnum(std.fmt.format_float.Format.scientific),
    decimal = @intFromEnum(std.fmt.format_float.Format.decimal),
    hex,

    /// Like `parse`, for a comptime-known specifier.
    /// Panics with a message naming the specifier when it is not recognized.
    pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
        return parse(fmt_) catch {
            @panic("invalid fmt for float: " ++ fmt_);
        };
    }

    /// Maps a specifier to a notation:
    /// `""` and `"e"` are scientific, `"d"` is decimal, `"x"` is hexadecimal.
    /// Anything else returns `error.InvalidArgument`.
    pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
        // The empty specifier is the scientific default.
        if (fmt_.len == 0) return .scientific;
        // Every recognized specifier is exactly one byte long.
        if (fmt_.len != 1) return error.InvalidArgument;
        return switch (fmt_[0]) {
            'e' => .scientific,
            'd' => .decimal,
            'x' => .hex,
            else => error.InvalidArgument,
        };
    }
};
/// Formats a single value, picking the formatter from the type class of `value`:
/// floats go to `formatFloatValue`, integers to `formatIntValue`, and everything
/// else to `formatAnyValue`. Errors from the chosen formatter are propagated.
pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
    const Value = @TypeOf(value);
    switch (@typeInfo(Value)) {
        .int, .comptime_int => return formatIntValue(value, full, writer),
        .float, .comptime_float => return formatFloatValue(value, full, writer),
        else => return formatAnyValue(value, full, writer),
    }
}
/// Formats one float using the notation in `full.fmt.float` and the width,
/// precision, fill and alignment in `full.options`.
///
/// `value` may be a builtin float, a `comptime_float` (printed as `f64`), or a
/// struct exposing `toF32()` (used for custom float types). Any other type is a
/// compile error. `full.fmt` must hold the `.float` variant.
///
/// When the rendered text does not fit the internal scratch buffer (for example
/// with a very large precision), the placeholder `(float)` is written instead.
/// Only errors from `writer` are returned.
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
    const formatFloat = std.fmt.format_float.formatFloat;
    // Written in place of the number whenever it cannot be rendered into `buf`.
    const placeholder = "(float)";
    var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined;
    const x = switch (@typeInfo(@TypeOf(value))) {
        .@"struct" => value.toF32(),
        .float => value,
        // `formatValue` routes comptime floats here; give them a concrete type.
        .comptime_float => @as(f64, value),
        else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
    };
    const s: []const u8 = switch (full.fmt.float) {
        .scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }) catch placeholder,
        .decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }) catch placeholder,
        .hex => hex: {
            var buf_stream = std.io.fixedBufferStream(&buf);
            // The fixed buffer can overflow when a large precision is requested,
            // so fall back to the placeholder like the other notations do
            // instead of asserting that this cannot fail.
            std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch break :hex placeholder;
            break :hex buf_stream.getWritten();
        },
    };
    // Padding and alignment are applied once, on the fully rendered text.
    return std.fmt.formatBuf(s, full.options, writer);
}
/// Formats one integer using the base and case in `full.fmt.int` and the
/// width, fill and alignment in `full.options`.
///
/// `value` must be an integer or a `comptime_int`; any other type is a compile
/// error. `full.fmt` must hold the `.int` variant. Errors from `writer` are
/// propagated.
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
    switch (@typeInfo(@TypeOf(value))) {
        // `formatValue` routes comptime ints here as well, and
        // `std.fmt.formatInt` accepts them, so both are allowed.
        .int, .comptime_int => {},
        else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
    }
    return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer);
}
/// Formats a value of arbitrary type with the `{any}` specifier, then applies
/// the width, fill and alignment in `full.options`. `full.fmt` is not read.
///
/// The text is rendered into a 48-byte scratch buffer. When it does not fit,
/// the last three bytes are replaced by `...` and all 48 bytes are written.
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
    var scratch: [48]u8 = undefined;
    const rendered: []const u8 = if (std.fmt.bufPrint(&scratch, "{any}", .{value})) |text|
        text
    else |_| truncated: {
        // Presumably `scratch` now holds the first 48 bytes of the output;
        // mark the cut with an ellipsis at the tail.
        scratch[45..].* = "...".*;
        break :truncated scratch[0..];
    };
    return std.fmt.formatBuf(rendered, full.options, writer);
}
/// Writes `values` as a single brace-delimited, comma-separated row, formatting
/// each element with `fmt_func(element, full, writer)`.
///
/// At most `full.options.width` elements are printed (12 when no width is set).
/// Longer inputs are abbreviated to the first and last `width / 2` elements
/// with ` ..., ` in between. Note that `full` is passed unchanged to `fmt_func`,
/// so the same width also reaches the per-element formatter.
///
/// Errors from `fmt_func` and from `writer` are propagated.
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
    const num_cols: usize = full.options.width orelse 12;
    const n: usize = values.len;
    // `writeAll` rather than `write`: `write` may accept only part of the bytes,
    // and ignoring its return value would silently truncate the delimiters.
    try writer.writeAll("{");
    if (n <= num_cols) {
        for (values, 0..) |v, i| {
            // Force inlining so that the switch and the buffer can be done once.
            try @call(.always_inline, fmt_func, .{ v, full, writer });
            if (i < n - 1) try writer.writeAll(",");
        }
    } else {
        // Too many elements: print the head, an ellipsis, then the tail.
        const half = @divFloor(num_cols, 2);
        for (values[0..half]) |v| {
            try @call(.always_inline, fmt_func, .{ v, full, writer });
            try writer.writeAll(",");
        }
        try writer.writeAll(" ..., ");
        for (values[n - half ..], 0..) |v, i| {
            try @call(.always_inline, fmt_func, .{ v, full, writer });
            if (i < half - 1) try writer.writeAll(",");
        }
    }
    try writer.writeAll("}");
}
/// Writes a sequence of arbitrary values as one abbreviated row, formatting each
/// element with `formatAnyValue`. Performs the same call as `formatAnySlice`.
pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
    try formatSliceCustom(formatAnyValue, values, full, writer);
}
/// Writes a sequence of floats as one abbreviated row, formatting each element
/// with `formatFloatValue`.
pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
    try formatSliceCustom(formatFloatValue, values, full, writer);
}
/// Writes a sequence of integers as one abbreviated row, formatting each element
/// with `formatIntValue`.
pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
    try formatSliceCustom(formatIntValue, values, full, writer);
}
/// Writes a sequence of arbitrary values as one abbreviated row, formatting each
/// element with `formatAnyValue`.
pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
    try formatSliceCustom(formatAnyValue, values, full, writer);
}

View File

@ -1,5 +1,6 @@
pub const debug = @import("debug.zig"); pub const debug = @import("debug.zig");
pub const flags = @import("flags.zig"); pub const flags = @import("flags.zig");
pub const fmt = @import("fmt.zig");
pub const io = @import("io.zig"); pub const io = @import("io.zig");
pub const json = @import("json.zig"); pub const json = @import("json.zig");
pub const math = @import("math.zig"); pub const math = @import("math.zig");

View File

@ -5,6 +5,7 @@ const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const floats = @import("floats.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
@ -290,34 +291,52 @@ pub const HostBuffer = struct {
x: HostBuffer, x: HostBuffer,
pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
_ = fmt; const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
_ = options; .integer => .parse(i32, fmt),
try prettyPrint(self.x, writer); .float => .parse(f32, fmt),
else => .parse(void, fmt),
};
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
} }
}; };
pub fn prettyPrint(self: HostBuffer, writer: anytype) !void { pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void {
return self.prettyPrintIndented(4, 0, writer); return self.prettyPrintIndented(writer, 4, 0, options);
} }
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void { fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
if (self.rank() == 1) { if (self.rank() == 0) {
try writer.writeByteNTimes(' ', indent_level); // Special case input tensor is a scalar
return switch (self.dtype()) { return switch (self.dtype()) {
inline else => |dt| { inline else => |dt| {
const values = self.items(dt.toZigType()); const val: dt.toZigType() = self.items(dt.toZigType())[0];
// Write first rows return switch (comptime dt.class()) {
const num_cols: u32 = 12; // Since we have custom floats, we need to explicitly convert to float32 ourselves.
const n: u64 = @intCast(self.dim(0)); .float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
if (n <= num_cols) { .integer => stdx.fmt.formatIntValue(val, options, writer),
try writer.print("{any},\n", .{values[0..n]}); .bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
} else { };
const half = @divExact(num_cols, 2);
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
}
}, },
}; };
} }
if (self.rank() == 1) {
// Print a contiguous slice of items from the buffer in one line.
// The number of items printed is controlled by the user through format syntax.
try writer.writeByteNTimes(' ', indent_level);
switch (self.dtype()) {
inline else => |dt| {
const values = self.items(dt.toZigType());
switch (comptime dt.class()) {
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
.bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
}
},
}
try writer.writeByte('\n');
return;
}
// TODO: consider removing the \n if dim is 1 for this axis.
try writer.writeByteNTimes(' ', indent_level); try writer.writeByteNTimes(' ', indent_level);
_ = try writer.write("{\n"); _ = try writer.write("{\n");
defer { defer {
@ -330,7 +349,7 @@ pub const HostBuffer = struct {
for (0..@min(num_rows, n)) |d| { for (0..@min(num_rows, n)) |d| {
const di: i64 = @intCast(d); const di: i64 = @intCast(d);
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0); const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer); try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
} }
if (n < num_rows) return; if (n < num_rows) return;
@ -343,7 +362,7 @@ pub const HostBuffer = struct {
for (@max(n - num_rows, num_rows)..n) |d| { for (@max(n - num_rows, num_rows)..n) |d| {
const di: i64 = @intCast(d); const di: i64 = @intCast(d);
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0); const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer); try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
} }
} }
}; };