stdx.fmt: add slice formatting support, improving on previous prettyPrinter implementation by leveraging internal fmt mechanisms.
This commit is contained in:
parent
fe55c600d4
commit
4ef81b89ea
@ -5,6 +5,7 @@ zig_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"debug.zig",
|
"debug.zig",
|
||||||
"flags.zig",
|
"flags.zig",
|
||||||
|
"fmt.zig",
|
||||||
"io.zig",
|
"io.zig",
|
||||||
"json.zig",
|
"json.zig",
|
||||||
"math.zig",
|
"math.zig",
|
||||||
|
|||||||
155
stdx/fmt.zig
Normal file
155
stdx/fmt.zig
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub const Fmt = union(enum) {
|
||||||
|
int: IntFmt,
|
||||||
|
float: FloatFmt,
|
||||||
|
generic: void,
|
||||||
|
|
||||||
|
pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
|
||||||
|
return switch (@typeInfo(T)) {
|
||||||
|
.float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) },
|
||||||
|
.int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) },
|
||||||
|
else => .{ .generic = {} },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const FullFormatOptions = struct {
|
||||||
|
fmt: Fmt,
|
||||||
|
options: std.fmt.FormatOptions,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const IntFmt = struct {
|
||||||
|
base: u8,
|
||||||
|
case: std.fmt.Case = .lower,
|
||||||
|
|
||||||
|
pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
|
||||||
|
return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
|
||||||
|
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d"))
|
||||||
|
.{ .base = 10, .case = .lower }
|
||||||
|
else if (std.mem.eql(u8, fmt_, "x"))
|
||||||
|
.{ .base = 16, .case = .lower }
|
||||||
|
else if (std.mem.eql(u8, fmt_, "X"))
|
||||||
|
.{ .base = 16, .case = .upper }
|
||||||
|
else if (std.mem.eql(u8, fmt_, "o"))
|
||||||
|
.{ .base = 8, .case = .upper }
|
||||||
|
else
|
||||||
|
// TODO: unicode/ascii
|
||||||
|
error.InvalidArgument;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub const FloatFmt = enum(u8) {
|
||||||
|
scientific = @intFromEnum(std.fmt.format_float.Format.scientific),
|
||||||
|
decimal = @intFromEnum(std.fmt.format_float.Format.decimal),
|
||||||
|
hex,
|
||||||
|
|
||||||
|
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
|
||||||
|
return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
|
||||||
|
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "e"))
|
||||||
|
.scientific
|
||||||
|
else if (std.mem.eql(u8, fmt_, "d"))
|
||||||
|
.decimal
|
||||||
|
else if (std.mem.eql(u8, fmt_, "x"))
|
||||||
|
.hex
|
||||||
|
else
|
||||||
|
error.InvalidArgument;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
return switch (@typeInfo(@TypeOf(value))) {
|
||||||
|
.comptime_float, .float => try formatFloatValue(value, full, writer),
|
||||||
|
.comptime_int, .int => try formatIntValue(value, full, writer),
|
||||||
|
else => try formatAnyValue(value, full, writer),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
const formatFloat = std.fmt.format_float.formatFloat;
|
||||||
|
var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined;
|
||||||
|
|
||||||
|
const x = switch (@typeInfo(@TypeOf(value))) {
|
||||||
|
.@"struct" => value.toF32(),
|
||||||
|
.float => value,
|
||||||
|
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
|
||||||
|
};
|
||||||
|
const s_or_err = switch (full.fmt.float) {
|
||||||
|
.scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }),
|
||||||
|
.decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }),
|
||||||
|
.hex => hex: {
|
||||||
|
var buf_stream = std.io.fixedBufferStream(&buf);
|
||||||
|
std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch unreachable;
|
||||||
|
break :hex buf_stream.getWritten();
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const s = s_or_err catch "(float)";
|
||||||
|
return std.fmt.formatBuf(s, full.options, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
switch (@typeInfo(@TypeOf(value))) {
|
||||||
|
.int => {},
|
||||||
|
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
|
||||||
|
}
|
||||||
|
return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
var buf: [48]u8 = undefined;
|
||||||
|
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
|
||||||
|
buf[45..].* = "...".*;
|
||||||
|
break :blk buf[0..];
|
||||||
|
};
|
||||||
|
return std.fmt.formatBuf(s, full.options, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
|
||||||
|
// Write first rows
|
||||||
|
const num_cols: usize = full.options.width orelse 12;
|
||||||
|
const n: usize = values.len;
|
||||||
|
_ = try writer.write("{");
|
||||||
|
if (n <= num_cols) {
|
||||||
|
for (values, 0..) |v, i| {
|
||||||
|
// Force inlining so that the switch and the buffer can be done once.
|
||||||
|
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||||
|
if (i < n - 1) _ = try writer.write(",");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const half = @divFloor(num_cols, 2);
|
||||||
|
for (values[0..half]) |v| {
|
||||||
|
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||||
|
_ = try writer.write(",");
|
||||||
|
}
|
||||||
|
_ = try writer.write(" ..., ");
|
||||||
|
for (values[n - half ..], 0..) |v, i| {
|
||||||
|
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||||
|
if (i < half - 1) _ = try writer.write(",");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = try writer.write("}");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
return try formatSliceCustom(formatFloatValue, values, full, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
return try formatSliceCustom(formatIntValue, values, full, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||||
|
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
||||||
|
}
|
||||||
@ -1,5 +1,6 @@
|
|||||||
pub const debug = @import("debug.zig");
|
pub const debug = @import("debug.zig");
|
||||||
pub const flags = @import("flags.zig");
|
pub const flags = @import("flags.zig");
|
||||||
|
pub const fmt = @import("fmt.zig");
|
||||||
pub const io = @import("io.zig");
|
pub const io = @import("io.zig");
|
||||||
pub const json = @import("json.zig");
|
pub const json = @import("json.zig");
|
||||||
pub const math = @import("math.zig");
|
pub const math = @import("math.zig");
|
||||||
|
|||||||
@ -5,6 +5,7 @@ const stdx = @import("stdx");
|
|||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
|
const floats = @import("floats.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
|
||||||
@ -290,34 +291,52 @@ pub const HostBuffer = struct {
|
|||||||
x: HostBuffer,
|
x: HostBuffer,
|
||||||
|
|
||||||
pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
||||||
_ = fmt;
|
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
|
||||||
_ = options;
|
.integer => .parse(i32, fmt),
|
||||||
try prettyPrint(self.x, writer);
|
.float => .parse(f32, fmt),
|
||||||
|
else => .parse(void, fmt),
|
||||||
|
};
|
||||||
|
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn prettyPrint(self: HostBuffer, writer: anytype) !void {
|
pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void {
|
||||||
return self.prettyPrintIndented(4, 0, writer);
|
return self.prettyPrintIndented(writer, 4, 0, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
|
fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
|
||||||
if (self.rank() == 1) {
|
if (self.rank() == 0) {
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
// Special case input tensor is a scalar
|
||||||
return switch (self.dtype()) {
|
return switch (self.dtype()) {
|
||||||
inline else => |dt| {
|
inline else => |dt| {
|
||||||
const values = self.items(dt.toZigType());
|
const val: dt.toZigType() = self.items(dt.toZigType())[0];
|
||||||
// Write first rows
|
return switch (comptime dt.class()) {
|
||||||
const num_cols: u32 = 12;
|
// Since we have custom floats, we need to explicitly convert to float32 ourselves.
|
||||||
const n: u64 = @intCast(self.dim(0));
|
.float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
|
||||||
if (n <= num_cols) {
|
.integer => stdx.fmt.formatIntValue(val, options, writer),
|
||||||
try writer.print("{any},\n", .{values[0..n]});
|
.bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
|
||||||
} else {
|
};
|
||||||
const half = @divExact(num_cols, 2);
|
|
||||||
try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
if (self.rank() == 1) {
|
||||||
|
// Print a contiguous slice of items from the buffer in one line.
|
||||||
|
// The number of items printed is controlled by the user through format syntax.
|
||||||
|
try writer.writeByteNTimes(' ', indent_level);
|
||||||
|
switch (self.dtype()) {
|
||||||
|
inline else => |dt| {
|
||||||
|
const values = self.items(dt.toZigType());
|
||||||
|
switch (comptime dt.class()) {
|
||||||
|
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
|
||||||
|
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
|
||||||
|
.bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
try writer.writeByte('\n');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// TODO: consider removing the \n if dim is 1 for this axis.
|
||||||
try writer.writeByteNTimes(' ', indent_level);
|
try writer.writeByteNTimes(' ', indent_level);
|
||||||
_ = try writer.write("{\n");
|
_ = try writer.write("{\n");
|
||||||
defer {
|
defer {
|
||||||
@ -330,7 +349,7 @@ pub const HostBuffer = struct {
|
|||||||
for (0..@min(num_rows, n)) |d| {
|
for (0..@min(num_rows, n)) |d| {
|
||||||
const di: i64 = @intCast(d);
|
const di: i64 = @intCast(d);
|
||||||
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
|
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
|
||||||
try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
|
try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n < num_rows) return;
|
if (n < num_rows) return;
|
||||||
@ -343,7 +362,7 @@ pub const HostBuffer = struct {
|
|||||||
for (@max(n - num_rows, num_rows)..n) |d| {
|
for (@max(n - num_rows, num_rows)..n) |d| {
|
||||||
const di: i64 = @intCast(d);
|
const di: i64 = @intCast(d);
|
||||||
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
|
const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
|
||||||
try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
|
try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user