From 4ef81b89ea06bc609104adbcb552f35316506ffd Mon Sep 17 00:00:00 2001
From: Tarry Singh
Date: Fri, 18 Oct 2024 15:05:08 +0000
Subject: [PATCH] stdx.fmt: add slice formatting support, improving on previous prettyPrinter implementation by leveraging internal fmt mechanisms.

---
 stdx/BUILD.bazel   |   1 +
 stdx/fmt.zig       | 191 +++++++++++++++++++++++++++++++++++++++++++++
 stdx/stdx.zig      |   1 +
 zml/hostbuffer.zig |  62 +++++++++++------
 4 files changed, 235 insertions(+), 20 deletions(-)
 create mode 100644 stdx/fmt.zig

diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel
index b170125..f2fb4ea 100644
--- a/stdx/BUILD.bazel
+++ b/stdx/BUILD.bazel
@@ -5,6 +5,7 @@ zig_library(
     srcs = [
         "debug.zig",
         "flags.zig",
+        "fmt.zig",
         "io.zig",
         "json.zig",
         "math.zig",
diff --git a/stdx/fmt.zig b/stdx/fmt.zig
new file mode 100644
index 0000000..3934de1
--- /dev/null
+++ b/stdx/fmt.zig
@@ -0,0 +1,191 @@
+//! Formatting helpers for scalars and slices, built on top of `std.fmt`.
+//!
+//! A format specifier is parsed once into a `Fmt` and then applied to every element,
+//! together with the regular `std.fmt.FormatOptions`.
+const std = @import("std");
+
+/// Element format parsed from a format specifier, tagged by the class of value it applies to.
+pub const Fmt = union(enum) {
+    int: IntFmt,
+    float: FloatFmt,
+    generic: void,
+
+    /// Parses `fmt_` according to the type class of `T`: floats use `FloatFmt`,
+    /// ints use `IntFmt`, and any other type yields `.generic` without inspecting `fmt_`.
+    /// Panics if `fmt_` is not a valid specifier for that class.
+    pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
+        return switch (@typeInfo(T)) {
+            .float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) },
+            .int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) },
+            else => .{ .generic = {} },
+        };
+    }
+};
+
+/// Element format bundled with the standard `std.fmt.FormatOptions`.
+pub const FullFormatOptions = struct {
+    fmt: Fmt,
+    options: std.fmt.FormatOptions,
+};
+
+/// Integer format: numeric base and digit case.
+pub const IntFmt = struct {
+    base: u8,
+    case: std.fmt.Case = .lower,
+
+    /// Same as `parse`, but panics on an invalid specifier.
+    pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
+        return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_);
+    }
+
+    /// Parses an integer specifier: "" or "d" (decimal), "x" / "X" (lower / upper hex), "o" (octal).
+    /// Returns `error.InvalidArgument` for anything else.
+    pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
+        return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d"))
+            .{ .base = 10, .case = .lower }
+        else if (std.mem.eql(u8, fmt_, "x"))
+            .{ .base = 16, .case = .lower }
+        else if (std.mem.eql(u8, fmt_, "X"))
+            .{ .base = 16, .case = .upper }
+        else if (std.mem.eql(u8, fmt_, "o"))
+            // Octal has no letter digits, so the case is irrelevant; use the default.
+            .{ .base = 8, .case = .lower }
+        else
+            // TODO: unicode/ascii
+            error.InvalidArgument;
+    }
+};
+
+/// Float format. `scientific` and `decimal` share their values with `std.fmt.format_float.Format`.
+pub const FloatFmt = enum(u8) {
+    scientific = @intFromEnum(std.fmt.format_float.Format.scientific),
+    decimal = @intFromEnum(std.fmt.format_float.Format.decimal),
+    hex,
+
+    /// Same as `parse`, but panics on an invalid specifier.
+    pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
+        return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_);
+    }
+
+    /// Parses a float specifier: "" or "e" (scientific), "d" (decimal), "x" (hexadecimal).
+    /// Returns `error.InvalidArgument` for anything else.
+    pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
+        return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "e"))
+            .scientific
+        else if (std.mem.eql(u8, fmt_, "d"))
+            .decimal
+        else if (std.mem.eql(u8, fmt_, "x"))
+            .hex
+        else
+            error.InvalidArgument;
+    }
+};
+
+/// Formats a single value, dispatching on its type: floats, ints, or anything else via `{any}`.
+pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
+    return switch (@typeInfo(@TypeOf(value))) {
+        .comptime_float, .float => try formatFloatValue(value, full, writer),
+        .comptime_int, .int => try formatIntValue(value, full, writer),
+        else => try formatAnyValue(value, full, writer),
+    };
+}
+
+/// Formats a float according to `full.fmt.float`, then pads it according to `full.options`.
+/// `value` is either a builtin float, a comptime float, or a struct exposing `toF32()`.
+/// If the text does not fit in the internal buffer, "(float)" is written instead.
+pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
+    const formatFloat = std.fmt.format_float.formatFloat;
+    const placeholder = "(float)";
+    var buf: [std.fmt.format_float.bufferSize(.decimal, f64)]u8 = undefined;
+
+    const x = switch (@typeInfo(@TypeOf(value))) {
+        .@"struct" => value.toF32(),
+        .float => value,
+        // `formatValue` forwards comptime floats here; give them a concrete type.
+        .comptime_float => @as(f64, value),
+        else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
+    };
+    const s: []const u8 = switch (full.fmt.float) {
+        .scientific => formatFloat(&buf, x, .{ .mode = .scientific, .precision = full.options.precision }) catch placeholder,
+        .decimal => formatFloat(&buf, x, .{ .mode = .decimal, .precision = full.options.precision }) catch placeholder,
+        .hex => hex: {
+            var buf_stream = std.io.fixedBufferStream(&buf);
+            // A large requested precision can overflow `buf`, so this is not `unreachable`.
+            std.fmt.formatFloatHexadecimal(x, full.options, buf_stream.writer()) catch break :hex placeholder;
+            break :hex buf_stream.getWritten();
+        },
+    };
+
+    return std.fmt.formatBuf(s, full.options, writer);
+}
+
+/// Formats an integer using the base and case from `full.fmt.int`, padded according to `full.options`.
+pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
+    switch (@typeInfo(@TypeOf(value))) {
+        // `formatValue` forwards comptime ints here, and `std.fmt.formatInt` accepts them.
+        .int, .comptime_int => {},
+        else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
+    }
+    return std.fmt.formatInt(value, full.fmt.int.base, full.fmt.int.case, full.options, writer);
+}
+
+/// Formats any value with `{any}` into a 48-byte buffer, then pads it according to `full.options`.
+/// Output that does not fit is cut, and the last three bytes of the buffer are replaced with "...".
+pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
+    var buf: [48]u8 = undefined;
+    const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
+        buf[45..].* = "...".*;
+        break :blk buf[0..];
+    };
+    return std.fmt.formatBuf(s, full.options, writer);
+}
+
+/// Writes `values` as `{a,b,c}`, formatting each element with `fmt_func`.
+/// At most `full.options.width` elements (12 when unset) are shown in full; longer slices
+/// print the first and last `width / 2` elements separated by " ..., ".
+/// Note that `full.options` (width included) is also forwarded to `fmt_func` for every element.
+pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
+    const num_cols: usize = full.options.width orelse 12;
+    const n: usize = values.len;
+    // `writeAll` rather than `write`: `write` may perform a short write.
+    try writer.writeAll("{");
+    if (n <= num_cols) {
+        for (values, 0..) |v, i| {
+            // Force inlining so that the switch and the buffer can be done once.
+            try @call(.always_inline, fmt_func, .{ v, full, writer });
+            if (i < n - 1) try writer.writeAll(",");
+        }
+    } else {
+        const half = @divFloor(num_cols, 2);
+        for (values[0..half]) |v| {
+            try @call(.always_inline, fmt_func, .{ v, full, writer });
+            try writer.writeAll(",");
+        }
+        try writer.writeAll(" ..., ");
+        for (values[n - half ..], 0..) |v, i| {
+            try @call(.always_inline, fmt_func, .{ v, full, writer });
+            if (i < half - 1) try writer.writeAll(",");
+        }
+    }
+    try writer.writeAll("}");
+}
+
+/// Formats a slice of arbitrary values; does the same as `formatAnySlice`.
+pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
+    return try formatSliceCustom(formatAnyValue, values, full, writer);
+}
+
+/// Formats a slice of floats with `formatSliceCustom`, using `formatFloatValue` per element.
+pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
+    return try formatSliceCustom(formatFloatValue, values, full, writer);
+}
+
+/// Formats a slice of ints with `formatSliceCustom`, using `formatIntValue` per element.
+pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
+    return try formatSliceCustom(formatIntValue, values, full, writer);
+}
+
+/// Formats a slice of arbitrary values with `formatSliceCustom`, using `formatAnyValue` per element.
+pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
+    return try formatSliceCustom(formatAnyValue, values, full, writer);
+}
diff --git a/stdx/stdx.zig b/stdx/stdx.zig
index 96707c8..226c122 100644
--- a/stdx/stdx.zig
+++ b/stdx/stdx.zig
@@ -1,5 +1,6 @@
 pub const debug = @import("debug.zig");
 pub const flags = @import("flags.zig");
+pub const fmt = @import("fmt.zig");
 pub const io = @import("io.zig");
 pub const json = @import("json.zig");
 pub const math = @import("math.zig");
diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig
index 1885a00..39c6d04 100644
--- a/zml/hostbuffer.zig
+++ b/zml/hostbuffer.zig
@@ -5,6 +5,7 @@ const stdx = @import("stdx");
 const Buffer = @import("buffer.zig").Buffer;
 const Data = @import("dtype.zig").Data;
 const DataType = @import("dtype.zig").DataType;
+const floats = @import("floats.zig");
 const Platform = @import("platform.zig").Platform;
 const Shape = @import("shape.zig").Shape;
 
@@ -290,34 +291,55 @@ pub const HostBuffer = struct {
         x: HostBuffer,
 
         pub fn format(self: PrettyPrinter, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
-            _ = fmt;
-            _ = options;
-            try prettyPrint(self.x, writer);
+            const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
+                .integer => .parse(i32, fmt),
+                .float => .parse(f32, fmt),
+                else => .parse(void, fmt),
+            };
+            try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
         }
     };
 
-    pub fn prettyPrint(self: HostBuffer, writer: anytype) !void {
-        return self.prettyPrintIndented(4, 0, writer);
+    /// Pretty-prints the buffer to `writer`; `options` controls how each element is formatted.
+    pub fn prettyPrint(self: HostBuffer, writer: anytype, options: stdx.fmt.FullFormatOptions) !void {
+        return self.prettyPrintIndented(writer, 4, 0, options);
     }
 
-    fn prettyPrintIndented(self: HostBuffer, num_rows: u8, indent_level: u8, writer: anytype) !void {
-        if (self.rank() == 1) {
-            try writer.writeByteNTimes(' ', indent_level);
+    /// Recursive worker for `prettyPrint`: scalars and 1-d buffers are formatted through `stdx.fmt`,
+    /// higher ranks recurse on slices along axis 0 with `indent_level` increased by 2.
+    fn prettyPrintIndented(self: HostBuffer, writer: anytype, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
+        if (self.rank() == 0) {
+            // Special case: the input tensor is a scalar.
             return switch (self.dtype()) {
                 inline else => |dt| {
-                    const values = self.items(dt.toZigType());
-                    // Write first rows
-                    const num_cols: u32 = 12;
-                    const n: u64 = @intCast(self.dim(0));
-                    if (n <= num_cols) {
-                        try writer.print("{any},\n", .{values[0..n]});
-                    } else {
-                        const half = @divExact(num_cols, 2);
-                        try writer.print("{any}, ..., {any},\n", .{ values[0..half], values[n - half ..] });
-                    }
+                    const val: dt.toZigType() = self.items(dt.toZigType())[0];
+                    return switch (comptime dt.class()) {
+                        // Since we have custom floats, we need to explicitly convert to float32 ourselves.
+                        .float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
+                        .integer => stdx.fmt.formatIntValue(val, options, writer),
+                        .bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
+                    };
                 },
             };
         }
+        if (self.rank() == 1) {
+            // Print a contiguous slice of items from the buffer in one line.
+            // The number of items printed is controlled by the user through format syntax.
+            try writer.writeByteNTimes(' ', indent_level);
+            switch (self.dtype()) {
+                inline else => |dt| {
+                    const values = self.items(dt.toZigType());
+                    switch (comptime dt.class()) {
+                        .float => try stdx.fmt.formatFloatSlice(values, options, writer),
+                        .integer => try stdx.fmt.formatIntSlice(values, options, writer),
+                        .bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
+                    }
+                },
+            }
+            try writer.writeByte('\n');
+            return;
+        }
+        // TODO: consider removing the \n if dim is 1 for this axis.
         try writer.writeByteNTimes(' ', indent_level);
         _ = try writer.write("{\n");
         defer {
@@ -330,7 +352,7 @@ pub const HostBuffer = struct {
         for (0..@min(num_rows, n)) |d| {
             const di: i64 = @intCast(d);
             const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
-            try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
+            try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
         }
 
         if (n < num_rows) return;
@@ -343,7 +365,7 @@ pub const HostBuffer = struct {
         for (@max(n - num_rows, num_rows)..n) |d| {
             const di: i64 = @intCast(d);
             const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
-            try sliced_self.prettyPrintIndented(num_rows, indent_level + 2, writer);
+            try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
         }
     }
 };