2023-01-02 14:28:25 +00:00
|
|
|
const std = @import("std");
|
2024-09-10 09:14:28 +00:00
|
|
|
|
2023-06-21 14:45:14 +00:00
|
|
|
const stdx = @import("stdx");
|
2023-01-02 14:28:25 +00:00
|
|
|
|
|
|
|
|
const Buffer = @import("buffer.zig").Buffer;
|
|
|
|
|
const DataType = @import("dtype.zig").DataType;
|
2024-10-18 15:05:08 +00:00
|
|
|
const floats = @import("floats.zig");
|
2023-01-02 14:28:25 +00:00
|
|
|
const Platform = @import("platform.zig").Platform;
|
2023-01-27 14:35:11 +00:00
|
|
|
const Shape = @import("shape.zig").Shape;
|
2023-01-02 14:28:25 +00:00
|
|
|
|
// Reference every public declaration of HostBuffer so that all of its
// nested `test` blocks (e.g. `test arange`) are compiled and run.
test {
    std.testing.refAllDecls(HostBuffer);
}
/// Represents a tensor with associated data allocated by user code.
/// If the memory is `.managed` it needs to be freed with `x.deinit(allocator)`
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
pub const HostBuffer = struct {
    // Logical dimensions, tags and element type of the tensor.
    _shape: Shape,
    // Per-axis strides in **bytes**: `slice1d` and `choose` multiply indices by
    // `_strides` and index directly into the `u8` data pointer. Only the first
    // `rank()` entries are meaningful; the rest are unspecified.
    _strides: [Shape.MAX_RANK]i64,
    // Pointer to the first byte of the data. Owned only when `_memory == .managed`.
    _data: [*]const u8,
    // Ownership marker: `.managed` records the allocation alignment so `deinit`
    // can free with the same alignment; `.unmanaged` means memory is tracked elsewhere.
    _memory: union(enum) {
        managed: std.mem.Alignment,
        unmanaged,
    } = .unmanaged,

    /// Allocates a HostBuffer with the given shape.
    /// The memory is left undefined.
    /// The caller owns the memory, and need to call `deinit()`.
    pub fn empty(allocator: std.mem.Allocator, sh: Shape) error{OutOfMemory}!HostBuffer {
        return .{
            ._shape = sh,
            ._strides = sh.computeStrides().buffer,
            // 64-byte aligned allocation; the same alignment is recorded in
            // `_memory` below so `deinit` frees with a matching alignment.
            ._data = (try allocator.alignedAlloc(u8, .@"64", sh.byteSize())).ptr,
            ._memory = .{ .managed = .@"64" },
        };
    }

    /// Wraps an exisiting slice of bytes into a HostBuffer.
    /// The returned HostBuffer doesn't take ownership of the slice
    /// that will still need to be deallocated.
    pub fn fromBytes(shape_: Shape, data_: []const u8) HostBuffer {
        stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
        return .{
            ._shape = shape_,
            ._strides = shape_.computeStrides().buffer,
            ._data = data_.ptr,
            ._memory = .unmanaged,
        };
    }

    /// Frees the underlying memory if we owned it, ie if we've been created with `HostBuffer.empty`.
    pub fn deinit(self: *const HostBuffer, allocator: std.mem.Allocator) void {
        // This means we don't own the data.
        if (self._memory == .unmanaged) return;
        // Free with the alignment recorded at allocation time.
        const log2_align = self._memory.managed;
        allocator.rawFree(self.mutBytes(), log2_align, @returnAddress());
    }

    /// Wraps an exisiting slice into a HostBuffer.
    /// The element type is inferred from the slice type.
    /// The returned HostBuffer doesn't take ownership of the slice
    /// that will still need to be deallocated.
    pub fn fromSlice(sh: anytype, s: anytype) HostBuffer {
        const shape_ = Shape.init(sh, DataType.fromSliceElementType(s));
        const raw_bytes = std.mem.sliceAsBytes(s);
        // The dims passed in `sh` must account for every element of `s`.
        std.debug.assert(shape_.byteSize() == raw_bytes.len);
        return .{
            ._shape = shape_,
            ._strides = shape_.computeStrides().buffer,
            ._data = raw_bytes.ptr,
            ._memory = .unmanaged,
        };
    }

    /// Wraps an exisiting slice into a HostBuffer.
    /// The element type is inferred from the slice type.
    /// The values in the slice doesn't need to be contiguous,
    /// strides can be specified.
    /// The returned HostBuffer doesn't take ownership of the slice.
    pub fn fromStridedSlice(sh: Shape, s: anytype, strides_: []const i64) HostBuffer {
        // std.debug.assert(sh.count() == s.len);
        // Entries past `strides_.len` are left undefined; downstream code only
        // reads the first `rank()` strides.
        // NOTE(review): assumes strides_.len == sh.rank() — confirm with callers.
        var tmp: [Shape.MAX_RANK]i64 = undefined;
        @memcpy(tmp[0..strides_.len], strides_);
        return .{
            ._shape = sh,
            ._data = @alignCast(std.mem.sliceAsBytes(s).ptr),
            ._strides = tmp,
            ._memory = .unmanaged,
        };
    }

    /// Creates a tensor from a **pointer** to a "multi dimension" array.
    /// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
    /// Typically this is use with constant arrays.
    pub fn fromArray(arr_ptr: anytype) HostBuffer {
        const T = @TypeOf(arr_ptr.*);
        // Derive the Shape (dims + dtype) from the nested array type itself.
        const sh = parseArrayInfo(T);
        std.debug.assert(sh.byteSize() == @sizeOf(T));
        return .{
            ._shape = sh,
            ._strides = sh.computeStrides().buffer,
            ._data = @ptrCast(arr_ptr),
            ._memory = .unmanaged,
        };
    }

    /// Returns a HostBuffer tagged with the tags in 'tagz'.
    pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer {
        var res = self;
        res._shape = self._shape.withTags(tagz);
        return res;
    }

    /// Half-open range `[start, end)` traversed with `step`, for `arange`.
    pub const ArangeArgs = struct {
        start: i64 = 0,
        end: i64,
        step: i64 = 1,
    };

    /// Allocates a HostBuffer with the given shape.
    /// The memory is initialized with increasing numbers.
    /// The caller owns the memory, and need to call `deinit()`.
    pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
        stdx.debug.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
        stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});

        const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
        const res = try empty(allocator, Shape.init(.{n_steps}, dt));
        switch (dt) {
            // Comptime dispatch over the dtype: non-integer dtypes fall into the
            // first branch, whose assert always fails with a descriptive message.
            inline else => |d| if (comptime d.class() != .integer) {
                stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
            } else {
                const Zt = d.toZigType();
                var j: i64 = args.start;
                for (res.mutItems(Zt)) |*val| {
                    val.* = @intCast(j);
                    j +%= args.step;
                }
            },
        }
        return res;
    }

    test arange {
        {
            var x = try arange(std.testing.allocator, .{ .end = 8 }, .i32);
            defer x.deinit(std.testing.allocator);
            try std.testing.expectEqualSlices(i32, &.{ 0, 1, 2, 3, 4, 5, 6, 7 }, x.items(i32));
        }
        {
            var x = try arange(std.testing.allocator, .{ .start = -3, .end = 12, .step = 2 }, .i32);
            defer x.deinit(std.testing.allocator);
            try std.testing.expectEqualSlices(i32, &.{ -3, -1, 1, 3, 5, 7, 9, 11 }, x.items(i32));
        }
    }

    /// Copies this HostBuffer to the given accelerator.
    pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
        return try self.toDeviceOpts(platform_, .{});
    }

    /// Copies this HostBuffer to the given accelerator (with options).
    pub fn toDeviceOpts(self: HostBuffer, platform_: Platform, opts: Buffer.FromOptions) !Buffer {
        return try Buffer.from(platform_, self, opts);
    }

    /// Interpret the underlying data as a contiguous slice.
    /// WARNING: It's only valid if the buffer is contiguous.
    /// Strided buffers can't use this method.
    pub fn items(self: HostBuffer, comptime T: type) []const T {
        // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
        stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {f} as {s}", .{ self, @typeName(T) });
        stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
        const ptr: [*]const T = @ptrCast(@alignCast(self._data));
        return ptr[0..self._shape.count()];
    }

    /// Like `items` but mutable. Caller must know the underlying memory is writable.
    pub fn mutItems(self: HostBuffer, comptime T: type) []T {
        return @constCast(self.items(T));
    }

    /// Raw bytes of the buffer. Only valid for contiguous buffers.
    pub fn bytes(self: HostBuffer) []const u8 {
        stdx.debug.assert(self.isContiguous(), "{f} isn't contiguous, can't interpret as []const u8", .{self});
        return self._data[0..self._shape.byteSize()];
    }

    /// Like `bytes` but mutable. Caller must know the underlying memory is writable.
    pub fn mutBytes(self: HostBuffer) []u8 {
        return @constCast(self.bytes());
    }

    /// The shape (dims + dtype) of this buffer.
    pub fn shape(self: HostBuffer) Shape {
        return self._shape;
    }

    /// The element type of this buffer.
    pub fn dtype(self: HostBuffer) DataType {
        return self._shape.dtype();
    }

    /// Byte strides for the first `rank()` axes.
    pub fn strides(self: *const HostBuffer) []const i64 {
        // Pass strides per pointer otherwise we return a pointer to this stack frame.
        return self._strides[0..self._shape.rank()];
    }

    // TODO: rename .data into ._data and make it a [*]u8
    // pub fn data(self: HostBuffer) []const u8 {
    //     return self.data;
    // }

    /// Number of axes of this buffer.
    pub inline fn rank(self: HostBuffer) u4 {
        return self._shape.rank();
    }

    /// Total number of elements.
    pub inline fn count(self: HostBuffer) usize {
        return self._shape.count();
    }

    /// Size of the given axis.
    pub fn dim(self: HostBuffer, axis_: anytype) i64 {
        return self._shape.dim(axis_);
    }

    /// Resolves an axis (index or tag) to its positional index.
    pub fn axis(self: HostBuffer, axis_: anytype) u3 {
        return self._shape.axis(axis_);
    }

    /// True when the data is laid out densely in row-major order,
    /// i.e. the actual strides match the shape's computed strides
    /// (axes of size 1 are exempt since their stride is never used).
    pub fn isContiguous(self: HostBuffer) bool {
        const _strides = self._strides;
        const cont_strides = self._shape.computeStrides();
        for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
            if (d != 1 and stride != cont_stride) return false;
        }
        return true;
    }

    /// Returns a view of the same data with a new shape.
    /// Only valid on contiguous buffers (strides are recomputed from the new shape).
    pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
        stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {f}", .{self});
        var res = self;
        res._shape = self._shape.reshape(shape_);
        res._strides = res._shape.computeStrides().buffer;
        return res;
    }

    /// Range for `slice1d`; negative `start`/`end` count from the end of the axis,
    /// a null `end` means "up to the end".
    pub const Slice = struct {
        start: i64 = 0,
        end: ?i64 = null,
    };

    /// Slices the input Tensor over the given axis_ using the given parameters.
    pub fn slice1d(self: HostBuffer, axis_: anytype, s: Slice) HostBuffer {
        const ax = self._shape.axis(axis_);
        const d = self.dim(ax);
        // Normalize negative indices relative to the axis length.
        const start: i64 = if (s.start < 0) s.start + d else s.start;
        var end = s.end orelse d;
        if (end < 0) end += d;
        stdx.debug.assert(start >= 0 and start < d, "slice1d({f}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
        stdx.debug.assert(end >= 1 and end <= d, "slice1d({f}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
        stdx.debug.assert(start < end, "slice1d({f}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });

        // Strides are in bytes, so this is a plain byte offset into `_data`.
        const offset: usize = @intCast(start * self._strides[ax]);
        const new_shape = self.shape().set(ax, end - start);
        return .{
            ._shape = new_shape,
            ._data = self._data[offset..],
            ._strides = self._strides,
            // The view never owns the memory, even when `self` does.
            ._memory = .unmanaged,
        };
    }

    /// Selects one index along an axis and drops that axis.
    pub fn choose1d(self: HostBuffer, axis_: anytype, start: i64) HostBuffer {
        const ax = self.axis(axis_);
        return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
    }

    /// Selects one index along each tagged axis in `offsets`, dropping those axes.
    pub fn choose(self: HostBuffer, offsets: anytype) HostBuffer {
        const off, const tags = Shape.parseDimensions(offsets);
        var sh = self._shape;
        var offset: i64 = 0;
        for (off.constSlice(), tags.constSlice()) |o, t| {
            const ax = sh.axis(t);
            // Accumulate the byte offset, and mark the chosen axis with dim 0
            // so the compaction loop below knows to drop it.
            offset += o * self._strides[ax];
            sh._dims.buffer[ax] = 0;
        }

        var new_strides: [Shape.MAX_RANK]i64 = @splat(self.dtype().sizeOf());

        // TODO rewrite with simd. This is a pshuf, but it's not supported by @shuffle.
        // Compact the remaining dims/tags/strides over the dropped axes.
        var res_ax: u32 = 0;
        for (0..self._shape.rank()) |ax| {
            if (sh._dims.buffer[ax] == 0) {
                continue;
            }

            sh._dims.buffer[res_ax] = self._shape._dims.buffer[ax];
            sh._tags.buffer[res_ax] = self._shape._tags.buffer[ax];
            new_strides[res_ax] = self._strides[ax];
            res_ax += 1;
        }
        sh._dims.len -= off.len;
        sh._tags.len -= off.len;

        return HostBuffer{
            ._shape = sh,
            ._strides = new_strides,
            ._data = self._data[@intCast(offset)..],
            ._memory = .unmanaged,
        };
    }

    /// Drops an axis of size 1, keeping the same data pointer.
    pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
        const ax = self._shape.axis(axis_);
        stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {f}", .{ ax, self });

        // Remove the stride of the dropped axis, shifting the rest left.
        var strd: stdx.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
        _ = strd.orderedRemove(ax);

        return .{
            ._shape = self.shape().drop(ax),
            ._data = self._data,
            ._strides = strd.buffer,
            // NOTE: the squeezed view inherits `_memory`, so if `self` was
            // `.managed`, calling `deinit` on the view frees the original data.
            ._memory = self._memory,
        };
    }

    /// Short, shape-only representation (does not print the data).
    pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
        try writer.print("HostBuffer(.{f})", .{self._shape});
    }

    /// Numeric formatting entry point: pretty prints up to 4 rows per axis.
    // NOTE(review): uses `std.io.Writer` while sibling methods use `std.Io.Writer`
    // — presumably aliases of the same type; confirm and unify.
    pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
        return self.prettyPrintIndented(writer, 4, 0, n);
    }

    /// Pretty prints the buffer content (up to 4 rows per axis).
    pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: std.fmt.Number) !void {
        return self.prettyPrintIndented(writer, 4, 0, options);
    }

    /// Recursive worker for `prettyPrint`: prints at most `num_rows` leading and
    /// trailing rows per axis (eliding the middle with "..."), recursing one axis
    /// at a time with `indent_level` increased by 2.
    fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void {
        if (self.rank() == 0) {
            // Special case input tensor is a scalar
            return switch (self.dtype()) {
                inline else => |dt| {
                    const val: dt.toZigType() = self.items(dt.toZigType())[0];
                    return switch (comptime dt.class()) {
                        // Since we have custom floats, we need to explicitly convert to float32 ourselves.
                        .float => stdx.fmt.formatFloat(floats.floatCast(f32, val), options, writer),
                        .integer => stdx.fmt.formatInt(val, options, writer),
                        .bool => stdx.fmt.formatBool(val, options, writer),
                        .complex => stdx.fmt.formatComplex(val, options, writer),
                    };
                },
            };
        }
        if (self.rank() == 1) {
            // Print a contiguous slice of items from the buffer in one line.
            // The number of items printed is controlled by the user through format syntax.
            try writer.splatByteAll(' ', indent_level);
            switch (self.dtype()) {
                inline else => |dt| {
                    const values = self.items(dt.toZigType());
                    switch (comptime dt.class()) {
                        .float => try stdx.fmt.formatFloatSlice(values, options, writer),
                        .integer => try stdx.fmt.formatIntSlice(values, options, writer),
                        .complex => try stdx.fmt.formatComplexSlice(values, options, writer),
                        .bool => try stdx.fmt.formatBoolSlice(values, options, writer),
                    }
                },
            }
            try writer.writeByte('\n');
            return;
        }
        // TODO: consider removing the \n if dim is 1 for this axis.
        try writer.splatByteAll(' ', indent_level);
        _ = try writer.write("{\n");
        defer {
            writer.splatByteAll(' ', indent_level) catch {};
            _ = writer.write("},\n") catch {};
        }

        // Write first rows
        const n: u64 = @intCast(self.dim(0));
        for (0..@min(num_rows, n)) |d| {
            const di: i64 = @intCast(d);
            const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
            try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
        }

        if (n < num_rows) return;
        // Skip middle rows
        if (n > 2 * num_rows) {
            try writer.splatByteAll(' ', indent_level + 2);
            _ = try writer.write("...\n");
        }
        // Write last rows
        for (@max(n - num_rows, num_rows)..n) |d| {
            const di: i64 = @intCast(d);
            const sliced_self = self.slice1d(0, .{ .start = di, .end = di + 1 }).squeeze(0);
            try sliced_self.prettyPrintIndented(writer, num_rows, indent_level + 2, options);
        }
    }
};
|
|
|
|
|
|
|
|
|
|
fn parseArrayInfo(T: type) Shape {
|
|
|
|
|
return switch (@typeInfo(T)) {
|
2024-07-02 14:19:04 +00:00
|
|
|
.array => |arr| {
|
2023-01-02 14:28:25 +00:00
|
|
|
const s = parseArrayInfo(arr.child);
|
|
|
|
|
return s.insert(0, .{arr.len});
|
|
|
|
|
},
|
|
|
|
|
else => .{ ._dtype = DataType.fromZigType(T) },
|
|
|
|
|
};
|
|
|
|
|
}
|