Add buffer and hostbuffer utilities with precise f32→bf16 conversion, type inference for loadBuffers, store expected input shapes, enhance meta.visit and JSON TaggedUnion support, and improve logging.
This commit is contained in:
parent
1540c6e85e
commit
3849eb10b7
@ -350,7 +350,7 @@ pub const Client = opaque {
|
||||
}
|
||||
|
||||
pub const BufferFromHostBufferArgs = struct {
|
||||
data: []const u8,
|
||||
data: [*]const u8,
|
||||
buffer_type: BufferType,
|
||||
dims: []const i64,
|
||||
byte_strides: ?[]const i64,
|
||||
@ -362,7 +362,7 @@ pub const Client = opaque {
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
|
||||
.client = self.inner(),
|
||||
.data = @ptrCast(@constCast(args.data.ptr)),
|
||||
.data = @constCast(args.data),
|
||||
.type = @intFromEnum(args.buffer_type),
|
||||
.dims = @ptrCast(@constCast(args.dims.ptr)),
|
||||
.num_dims = args.dims.len,
|
||||
|
||||
@ -1,25 +1,46 @@
|
||||
pub const std = @import("std");
|
||||
const ParseFromValueError = std.json.ParseFromValueError;
|
||||
|
||||
/// Handle json fields that can have different Zig types depending on the message.
|
||||
/// Each union field should have a unique Zig type.
|
||||
///
|
||||
/// Example json:
|
||||
///
|
||||
/// ```json
|
||||
/// [
|
||||
/// { "question": "How old are you ?", "answer": 5 },
|
||||
/// { "question": "Count to three.", "answer": [1, 2, 3] },
|
||||
/// ]
|
||||
/// ```
|
||||
///
|
||||
/// Corresponding Zig code:
|
||||
///
|
||||
/// ```zig
|
||||
/// const Answer = union {
|
||||
/// number: i32,
|
||||
/// numbers: []const i32,
|
||||
/// };
|
||||
///
|
||||
/// const Message = struct {
|
||||
/// question: []const u8;
|
||||
/// answer: stdx.json.Union(Answer);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn Union(comptime T: type) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
|
||||
value: T,
|
||||
|
||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
|
||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
|
||||
return jsonParseFromValue(
|
||||
allocator,
|
||||
try std.json.innerParse(
|
||||
std.json.Value,
|
||||
allocator,
|
||||
source,
|
||||
options,
|
||||
),
|
||||
try std.json.innerParse(std.json.Value, allocator, source, options),
|
||||
options,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
|
||||
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
|
||||
inline for (std.meta.fields(T)) |field| {
|
||||
switch (field.type) {
|
||||
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
|
||||
@ -39,7 +60,67 @@ pub fn Union(comptime T: type) type {
|
||||
},
|
||||
}
|
||||
}
|
||||
return error.UnexpectedToken;
|
||||
return error.UnknownField;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Handle json fields that can have different Zig types depending on another field in the same message.
|
||||
/// This is translated to a Zig tagged union.
|
||||
///
|
||||
/// Example json:
|
||||
///
|
||||
/// ```json
|
||||
/// [
|
||||
/// { "type": "faq", "question": "How old are you ?", "answer": 5 },
|
||||
/// { "type": "address", "city": "NYC", "zipcode": "49130"},
|
||||
/// ]
|
||||
/// ```
|
||||
///
|
||||
/// Corresponding Zig struct:
|
||||
///
|
||||
/// ```zig
|
||||
/// const Entry = union {
|
||||
/// faq: struct { question: []const u8, answer: u32 },
|
||||
/// address: struct { city: []const u8, zipcode: []const u8 },
|
||||
/// };
|
||||
///
|
||||
/// const Message = []const stdx.json.TaggedUnion(Entry, "type");
|
||||
/// ```
|
||||
pub fn TaggedUnion(comptime T: type, comptime tag_name: [:0]const u8) type {
|
||||
return struct {
|
||||
const Self = @This();
|
||||
|
||||
value: T,
|
||||
|
||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
|
||||
return jsonParseFromValue(
|
||||
allocator,
|
||||
try std.json.innerParse(std.json.Value, allocator, source, options),
|
||||
options,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
|
||||
errdefer std.log.warn("failed to parse: {} as {s}", .{ source, @typeName(T) });
|
||||
if (source != .object) return error.UnexpectedToken;
|
||||
const o = source.object;
|
||||
const tag = o.get(tag_name) orelse return error.MissingField;
|
||||
for (o.keys(), o.values()) |k, v| {
|
||||
std.log.warn("object['{s}'] = {}", .{ k, v });
|
||||
}
|
||||
if (tag != .string) return error.LengthMismatch;
|
||||
inline for (std.meta.fields(T)) |field| {
|
||||
if (std.mem.eql(u8, field.name, tag.string)) {
|
||||
const inner_source = o.get(field.name) orelse return error.MissingField;
|
||||
const inner: field.type = std.json.innerParseFromValue(field.type, allocator, inner_source, options) catch |err| {
|
||||
std.log.warn("failed to interpret {s} as a {s}: {}", .{ tag.string, @typeName(field.type), err });
|
||||
return err;
|
||||
};
|
||||
return .{ .value = @unionInit(T, field.name, inner) };
|
||||
}
|
||||
}
|
||||
return error.InvalidEnumTag;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
25
zml/aio.zig
25
zml/aio.zig
@ -1,11 +1,9 @@
|
||||
const asynk = @import("async");
|
||||
const builtin = @import("builtin");
|
||||
const c = @import("c");
|
||||
const std = @import("std");
|
||||
const stdx = @import("stdx");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const zml = @import("zml.zig");
|
||||
const posix = @import("posix.zig");
|
||||
const asynk = @import("async");
|
||||
const c = @import("c");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
pub const gguf = @import("aio/gguf.zig");
|
||||
pub const nemo = @import("aio/nemo.zig");
|
||||
@ -13,10 +11,11 @@ pub const safetensors = @import("aio/safetensors.zig");
|
||||
pub const tinyllama = @import("aio/tinyllama.zig");
|
||||
pub const torch = @import("aio/torch.zig");
|
||||
pub const yaml = @import("aio/yaml.zig");
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const posix = @import("posix.zig");
|
||||
const zml = @import("zml.zig");
|
||||
|
||||
pub const log = std.log.scoped(.@"zml/aio");
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(gguf);
|
||||
@ -26,6 +25,8 @@ test {
|
||||
std.testing.refAllDecls(yaml);
|
||||
}
|
||||
|
||||
// TODO error set for weight loading
|
||||
|
||||
/// Detects the format of the model file (base on filename) and open it.
|
||||
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
|
||||
return if (std.mem.endsWith(u8, model_path, ".safetensors"))
|
||||
@ -422,7 +423,7 @@ fn _populateStruct(
|
||||
return true;
|
||||
},
|
||||
.float => {
|
||||
obj.* = undefined;
|
||||
obj.* = std.math.nan(@TypeOf(obj.*));
|
||||
return true;
|
||||
},
|
||||
.void => true,
|
||||
@ -450,7 +451,7 @@ test populateModel {
|
||||
|
||||
// Create a fake HostBuffer, we use the given integer to identify the created buffer.
|
||||
fn _newHostBuffer(n: u32) zml.HostBuffer {
|
||||
return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined };
|
||||
return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined };
|
||||
}
|
||||
};
|
||||
|
||||
@ -500,7 +501,7 @@ test populateModel {
|
||||
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
|
||||
pub fn loadBuffers(
|
||||
comptime Model: type,
|
||||
init_args: anytype,
|
||||
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
||||
buffer_store: BufferStore,
|
||||
allocator: std.mem.Allocator,
|
||||
platform: zml.Platform,
|
||||
@ -513,8 +514,6 @@ pub fn loadBuffers(
|
||||
// If the Model has a "init" function, call it with the given parameters.
|
||||
if (@hasDecl(Model, "init")) {
|
||||
@call(.auto, Model.init, .{&model} ++ init_args);
|
||||
} else {
|
||||
stdx.debug.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
|
||||
}
|
||||
|
||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||
|
||||
@ -44,23 +44,6 @@ pub const Buffer = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const Shard = struct {
|
||||
api: *const pjrt.Api,
|
||||
buffer: *pjrt.Buffer,
|
||||
ready_event: ?*pjrt.Event = null,
|
||||
ready: bool = false,
|
||||
|
||||
pub fn awaitt(self: *Shard) !void {
|
||||
if (self.ready) {
|
||||
return;
|
||||
}
|
||||
if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| {
|
||||
try ev.awaitt(self.api);
|
||||
}
|
||||
self.ready = true;
|
||||
}
|
||||
};
|
||||
|
||||
_shape: Shape,
|
||||
_api: *const pjrt.Api,
|
||||
_shards: Shards,
|
||||
@ -88,7 +71,7 @@ pub const Buffer = struct {
|
||||
} else 0;
|
||||
|
||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
||||
const byte_strides = host_buffer.strides();
|
||||
|
||||
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
||||
const devices = platform.getDevices();
|
||||
@ -103,7 +86,7 @@ pub const Buffer = struct {
|
||||
platform.pjrt_client,
|
||||
platform.pjrt_api,
|
||||
pjrt.Client.BufferFromHostBufferArgs{
|
||||
.data = buf.data,
|
||||
.data = buf._data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
@ -155,6 +138,14 @@ pub const Buffer = struct {
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
|
||||
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
||||
// TODO restore assert
|
||||
// const memory = self.getMemory().kind(self._api);
|
||||
// stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
|
||||
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
|
||||
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
|
||||
}
|
||||
|
||||
/// Creates a Buffer with a single element.
|
||||
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||
const x = dtype_.constant(val);
|
||||
@ -182,8 +173,8 @@ pub const Buffer = struct {
|
||||
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
|
||||
const host_buffer: HostBuffer = .{
|
||||
._shape = shape_,
|
||||
._strides = [1]i64{0} ** Shape.MAX_RANK,
|
||||
.data = x.constSlice(),
|
||||
._strides = @splat(0),
|
||||
._data = x.constSlice().ptr,
|
||||
};
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
@ -207,7 +198,7 @@ pub const Buffer = struct {
|
||||
},
|
||||
else => unreachable,
|
||||
}
|
||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes };
|
||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
|
||||
@ -228,12 +219,12 @@ pub const Buffer = struct {
|
||||
/// could lead to crashes and operations on the buffer will be slower.
|
||||
/// Tested on Cuda 12.4.
|
||||
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
|
||||
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
|
||||
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(buf._data));
|
||||
}
|
||||
|
||||
/// Creates a Buffer from a pointer into device memory.
|
||||
/// This allows to interface with other libraries producing buffers.
|
||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer {
|
||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer {
|
||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
@ -255,7 +246,7 @@ pub const Buffer = struct {
|
||||
.tile_dims_sizes = &.{},
|
||||
},
|
||||
},
|
||||
.stream = @bitCast(@as(usize, @intFromPtr(stream))),
|
||||
.stream = stream,
|
||||
}) catch @panic("failed to createViewOfDeviceBuffer");
|
||||
|
||||
var shards: Shards = .{};
|
||||
@ -296,7 +287,7 @@ pub const Buffer = struct {
|
||||
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
||||
const output = try HostBuffer.empty(allocator, self.shape());
|
||||
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.bytes()));
|
||||
if (maybe_event) |event| {
|
||||
try event.await_(self._api);
|
||||
}
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
const std = @import("std");
|
||||
|
||||
const floats = @import("floats.zig");
|
||||
|
||||
const C64 = std.math.Complex(f32);
|
||||
@ -111,9 +112,7 @@ pub const DataType = enum(u8) {
|
||||
}
|
||||
|
||||
pub fn toZigType(comptime dtype: DataType) type {
|
||||
return switch (dtype) {
|
||||
inline else => |tag| std.meta.TagPayload(Data, tag),
|
||||
};
|
||||
return @FieldType(Data, @tagName(dtype));
|
||||
}
|
||||
|
||||
pub fn isSignedInt(dtype: DataType) bool {
|
||||
@ -125,19 +124,19 @@ pub const DataType = enum(u8) {
|
||||
|
||||
pub fn sizeOf(self: DataType) u16 {
|
||||
return switch (self) {
|
||||
inline else => |tag| @sizeOf(std.meta.TagPayload(Data, tag)),
|
||||
inline else => |tag| @sizeOf(tag.toZigType()),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn bitSizeOf(self: DataType) u16 {
|
||||
return switch (self) {
|
||||
inline else => |tag| @bitSizeOf(std.meta.TagPayload(Data, tag)),
|
||||
inline else => |tag| @bitSizeOf(tag.toZigType()),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn alignOf(self: DataType) u29 {
|
||||
return switch (self) {
|
||||
inline else => |tag| @alignOf(std.meta.TagPayload(Data, tag)),
|
||||
inline else => |tag| @alignOf(tag.toZigType()),
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
50
zml/exe.zig
50
zml/exe.zig
@ -1,13 +1,13 @@
|
||||
const std = @import("std");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const aio = @import("aio.zig");
|
||||
const meta = @import("meta.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const meta = @import("meta.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
@ -147,6 +147,7 @@ pub const BaseExe = struct {
|
||||
/// Total number of buffers needed by this executable.
|
||||
input_buffer_count: u32,
|
||||
|
||||
input_shapes: []Shape,
|
||||
result_shapes: []Shape,
|
||||
|
||||
/// Num devices used (>1 for sharded executable)
|
||||
@ -155,34 +156,44 @@ pub const BaseExe = struct {
|
||||
/// Allocator backing memory
|
||||
_arena: std.heap.ArenaAllocator,
|
||||
|
||||
pub fn init(parent_allocator: std.mem.Allocator, platform: Platform, exe: *pjrt.LoadedExecutable, args: struct { n_in: u32, result_shapes: []const Shape, n_devices: u8 }) !BaseExe {
|
||||
pub fn init(
|
||||
parent_allocator: std.mem.Allocator,
|
||||
platform: Platform,
|
||||
exe: *pjrt.LoadedExecutable,
|
||||
args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
|
||||
) !BaseExe {
|
||||
var arena = std.heap.ArenaAllocator.init(parent_allocator);
|
||||
errdefer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
const n_in = args.input_shapes.len;
|
||||
const n_out = args.result_shapes.len;
|
||||
const n_devices = args.n_devices;
|
||||
// Allocate once for all the *pjrt.Buffer we need to store ...
|
||||
const all_buffers = try allocator.alloc(*pjrt.Buffer, (args.n_in + n_out) * n_devices);
|
||||
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ args.n_in * n_devices, n_out * n_devices });
|
||||
const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
|
||||
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });
|
||||
|
||||
// ... and once for all the [*]*pjrt.Buffer.
|
||||
const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
|
||||
const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
|
||||
|
||||
for (0..n_devices) |i| {
|
||||
input_per_device[i] = all_input_buffers[i * args.n_in ..].ptr;
|
||||
input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
|
||||
output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
|
||||
}
|
||||
|
||||
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
|
||||
@memcpy(all_shapes[0..n_in], args.input_shapes);
|
||||
@memcpy(all_shapes[n_in..], args.result_shapes);
|
||||
return .{
|
||||
.platform = platform,
|
||||
.exe = exe,
|
||||
.ready_buffer_count = 0,
|
||||
.input_buffer_count = args.n_in,
|
||||
.input_buffer_count = @intCast(n_in),
|
||||
.num_devices = args.n_devices,
|
||||
.input_per_device = input_per_device,
|
||||
.output_per_device = output_per_device,
|
||||
.result_shapes = try allocator.dupe(Shape, args.result_shapes),
|
||||
.input_shapes = all_shapes[0..n_in],
|
||||
.result_shapes = all_shapes[n_in..],
|
||||
._arena = arena,
|
||||
};
|
||||
}
|
||||
@ -209,7 +220,9 @@ pub const BaseExe = struct {
|
||||
// even if it has been marked as "can be donated" during compilation.
|
||||
// TODO: expose it ?
|
||||
.non_donatable_input_indices = &.{},
|
||||
}) catch unreachable;
|
||||
}) catch |err| {
|
||||
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
||||
};
|
||||
|
||||
for (events[0..sharding.num_partitions]) |e| {
|
||||
if (e) |ev| {
|
||||
@ -232,7 +245,7 @@ pub const BaseExe = struct {
|
||||
// }
|
||||
|
||||
pub fn prepare(self: *BaseExe, x: anytype) void {
|
||||
const n = fillBuffers(&x, self.input_per_device, self.ready_buffer_count);
|
||||
const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
|
||||
self.ready_buffer_count += n;
|
||||
}
|
||||
|
||||
@ -244,6 +257,14 @@ pub const BaseExe = struct {
|
||||
|
||||
return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
|
||||
}
|
||||
|
||||
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
|
||||
return .init(parent_allocator, self.platform, self.exe, .{
|
||||
.input_shapes = self.input_shapes,
|
||||
.result_shapes = self.result_shapes,
|
||||
.n_devices = self.num_devices,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/// Represents a ZML function, compiled into a PJRT executable.
|
||||
@ -280,7 +301,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
||||
}
|
||||
|
||||
pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
|
||||
const total_ready = fillBuffers(&args, self.inner.input_per_device, self.inner.ready_buffer_count);
|
||||
const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
|
||||
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
||||
self.inner._unsafeCall();
|
||||
var result: Bufferized(ReturnT) = undefined;
|
||||
@ -302,20 +323,23 @@ fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
|
||||
}
|
||||
|
||||
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
|
||||
fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
|
||||
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
|
||||
const LocalContext = struct {
|
||||
index: u32,
|
||||
buffers: []const [*]*pjrt.Buffer,
|
||||
shapes: []const Shape,
|
||||
};
|
||||
var context: LocalContext = .{
|
||||
.index = start,
|
||||
.buffers = buffers,
|
||||
.shapes = shapes,
|
||||
};
|
||||
meta.visit((struct {
|
||||
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
||||
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
||||
const model_sharding = ctx.buffers.len;
|
||||
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
||||
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
|
||||
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
||||
ctx.buffers[d][ctx.index] = shard;
|
||||
}
|
||||
|
||||
@ -305,11 +305,23 @@ pub const BFloat16 = packed struct(u16) {
|
||||
pub fn isInf(self: BFloat16) bool {
|
||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||
}
|
||||
|
||||
pub fn toF32(self: BFloat16) f32 {
|
||||
// Pad the BF16 with zeros 0
|
||||
return @bitCast([2]u16{ 0, @bitCast(self) });
|
||||
}
|
||||
|
||||
pub fn fromF32(float32: f32) BFloat16 {
|
||||
var int: u32 = @bitCast(float32);
|
||||
// Round up if needed.
|
||||
int += 0x8000;
|
||||
const parts: [2]u16 = @bitCast(int);
|
||||
return @bitCast(parts[1]);
|
||||
}
|
||||
|
||||
const Helpers = FloatHelpers(@This());
|
||||
pub const zero = Helpers.zero;
|
||||
pub const neg = Helpers.neg;
|
||||
pub const fromF32 = Helpers.fromF32;
|
||||
pub const toF32 = Helpers.toF32;
|
||||
pub const format = Helpers.format;
|
||||
};
|
||||
|
||||
@ -317,7 +329,7 @@ test BFloat16 {
|
||||
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
|
||||
try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 66 });
|
||||
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
|
||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
|
||||
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));
|
||||
|
||||
@ -18,8 +18,8 @@ test {
|
||||
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
|
||||
pub const HostBuffer = struct {
|
||||
_shape: Shape,
|
||||
_strides: ?[Shape.MAX_RANK]i64 = null,
|
||||
data: []const u8,
|
||||
_strides: [Shape.MAX_RANK]i64,
|
||||
_data: [*]const u8,
|
||||
_memory: union(enum) {
|
||||
managed: std.mem.Alignment,
|
||||
unmanaged,
|
||||
@ -28,10 +28,11 @@ pub const HostBuffer = struct {
|
||||
/// Allocates a HostBuffer with the given shape.
|
||||
/// The memory is left undefined.
|
||||
/// The caller owns the memory, and need to call `deinit()`.
|
||||
pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer {
|
||||
pub fn empty(allocator: std.mem.Allocator, sh: Shape) error{OutOfMemory}!HostBuffer {
|
||||
return .{
|
||||
._shape = sh,
|
||||
.data = try allocator.alignedAlloc(u8, 64, sh.byteSize()),
|
||||
._strides = sh.computeStrides().buffer,
|
||||
._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr,
|
||||
._memory = .{ .managed = .@"64" },
|
||||
};
|
||||
}
|
||||
@ -43,7 +44,8 @@ pub const HostBuffer = struct {
|
||||
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
|
||||
return .{
|
||||
._shape = shape_,
|
||||
.data = data_,
|
||||
._strides = shape_.computeStrides().buffer,
|
||||
._data = data_.ptr,
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
}
|
||||
@ -53,7 +55,7 @@ pub const HostBuffer = struct {
|
||||
// This means we don't own the data.
|
||||
if (self._memory == .unmanaged) return;
|
||||
const log2_align = self._memory.managed;
|
||||
allocator.rawFree(@constCast(self.data), log2_align, @returnAddress());
|
||||
allocator.rawFree(self.mutBytes(), log2_align, @returnAddress());
|
||||
}
|
||||
|
||||
/// Wraps an exisiting slice into a HostBuffer.
|
||||
@ -62,10 +64,12 @@ pub const HostBuffer = struct {
|
||||
/// that will still need to be deallocated.
|
||||
pub fn fromSlice(sh: anytype, s: anytype) HostBuffer {
|
||||
const shape_ = Shape.init(sh, DataType.fromSliceElementType(s));
|
||||
std.debug.assert(shape_.count() == s.len);
|
||||
const raw_bytes = std.mem.sliceAsBytes(s);
|
||||
std.debug.assert(shape_.byteSize() == raw_bytes.len);
|
||||
return .{
|
||||
._shape = shape_,
|
||||
.data = @alignCast(std.mem.sliceAsBytes(s)),
|
||||
._strides = shape_.computeStrides().buffer,
|
||||
._data = raw_bytes.ptr,
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
}
|
||||
@ -81,7 +85,7 @@ pub const HostBuffer = struct {
|
||||
@memcpy(tmp[0..strides_.len], strides_);
|
||||
return .{
|
||||
._shape = sh,
|
||||
.data = @alignCast(std.mem.sliceAsBytes(s)),
|
||||
._data = @alignCast(std.mem.sliceAsBytes(s).ptr),
|
||||
._strides = tmp,
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
@ -89,13 +93,15 @@ pub const HostBuffer = struct {
|
||||
|
||||
/// Creates a tensor from a **pointer** to a "multi dimension" array.
|
||||
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
|
||||
/// Typically this is use with constant arrays.
|
||||
pub fn fromArray(arr_ptr: anytype) HostBuffer {
|
||||
const T = @TypeOf(arr_ptr.*);
|
||||
const sh = parseArrayInfo(T);
|
||||
std.debug.assert(sh.byteSize() == @sizeOf(T));
|
||||
return .{
|
||||
._shape = sh,
|
||||
.data = @alignCast(std.mem.sliceAsBytes(arr_ptr)),
|
||||
// Array are typically stack allocated and don't need to be freed.
|
||||
._strides = sh.computeStrides().buffer,
|
||||
._data = @ptrCast(arr_ptr),
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
}
|
||||
@ -121,16 +127,15 @@ pub const HostBuffer = struct {
|
||||
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||
|
||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||
const b = dt.sizeOf();
|
||||
const res = try empty(allocator, Shape.init(.{n_steps}, dt));
|
||||
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
||||
var data_ = @constCast(res.data);
|
||||
switch (dt) {
|
||||
inline else => {
|
||||
inline else => |d| if (comptime d.class() != .integer) {
|
||||
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
||||
} else {
|
||||
const Zt = d.toZigType();
|
||||
var j: i64 = args.start;
|
||||
for (0..@intCast(n_steps)) |i| {
|
||||
var v = Data.init(dt, j);
|
||||
@memcpy(data_[i * b .. (i + 1) * b], v.constSlice());
|
||||
for (res.mutItems(Zt)) |*val| {
|
||||
val.* = @intCast(j);
|
||||
j +%= args.step;
|
||||
}
|
||||
},
|
||||
@ -160,16 +165,26 @@ pub const HostBuffer = struct {
|
||||
/// WARNING: It's only valid if the buffer is contiguous.
|
||||
/// Strided buffers can't use this method.
|
||||
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
||||
if (DataType.fromZigType(T) != self.dtype()) {
|
||||
std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
||||
}
|
||||
if (!self.isContiguous()) {
|
||||
std.debug.panic("{} isn't contiguous", .{self});
|
||||
}
|
||||
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
|
||||
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
|
||||
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
||||
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
||||
const ptr: [*]const T = @alignCast(@ptrCast(self._data));
|
||||
return ptr[0..self._shape.count()];
|
||||
}
|
||||
|
||||
pub fn mutItems(self: HostBuffer, comptime T: type) []T {
|
||||
return @constCast(self.items(T));
|
||||
}
|
||||
|
||||
pub fn bytes(self: HostBuffer) []const u8 {
|
||||
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
||||
return self._data[0..self._shape.byteSize()];
|
||||
}
|
||||
|
||||
pub fn mutBytes(self: HostBuffer) []u8 {
|
||||
return @constCast(self.bytes());
|
||||
}
|
||||
|
||||
pub fn shape(self: HostBuffer) Shape {
|
||||
return self._shape;
|
||||
}
|
||||
@ -178,9 +193,9 @@ pub const HostBuffer = struct {
|
||||
return self._shape.dtype();
|
||||
}
|
||||
|
||||
pub fn strides(self: *const HostBuffer) ?[]const i64 {
|
||||
pub fn strides(self: *const HostBuffer) []const i64 {
|
||||
// Pass strides per pointer otherwise we return a pointer to this stack frame.
|
||||
return if (self._strides) |*strd| strd[0..self.rank()] else null;
|
||||
return self._strides[0..self._shape.rank()];
|
||||
}
|
||||
|
||||
// TODO: rename .data into ._data and make it a [*]u8
|
||||
@ -205,7 +220,7 @@ pub const HostBuffer = struct {
|
||||
}
|
||||
|
||||
pub fn isContiguous(self: HostBuffer) bool {
|
||||
const _strides = self._strides orelse return true;
|
||||
const _strides = self._strides;
|
||||
const cont_strides = self._shape.computeStrides();
|
||||
for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
|
||||
if (d != 1 and stride != cont_stride) return false;
|
||||
@ -217,6 +232,7 @@ pub const HostBuffer = struct {
|
||||
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
||||
var res = self;
|
||||
res._shape = self._shape.reshape(shape_);
|
||||
res._strides = res._shape.computeStrides().buffer;
|
||||
return res;
|
||||
}
|
||||
|
||||
@ -236,15 +252,12 @@ pub const HostBuffer = struct {
|
||||
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
||||
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
||||
|
||||
// If strides weren't set it means original buffer is contiguous.
|
||||
// But it won't be anymore after slicing. The strides don't change though.
|
||||
const _strides = self._strides orelse self._shape.computeStrides().buffer;
|
||||
const offset: usize = @intCast(start * _strides[ax]);
|
||||
const offset: usize = @intCast(start * self._strides[ax]);
|
||||
const new_shape = self.shape().set(ax, end - start);
|
||||
return .{
|
||||
._shape = self.shape().set(ax, end - start),
|
||||
.data = self.data[offset..],
|
||||
// When axis is 0, we stay contiguous.
|
||||
._strides = if (ax == 0) self._strides else _strides,
|
||||
._shape = new_shape,
|
||||
._data = self._data[offset..],
|
||||
._strides = self._strides,
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
}
|
||||
@ -254,18 +267,52 @@ pub const HostBuffer = struct {
|
||||
return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
|
||||
}
|
||||
|
||||
pub fn choose(self: HostBuffer, offsets: anytype) HostBuffer {
|
||||
const off, const tags = Shape.parseDimensions(offsets);
|
||||
var sh = self._shape;
|
||||
var offset: i64 = 0;
|
||||
for (off.constSlice(), tags.constSlice()) |o, t| {
|
||||
const ax = sh.axis(t);
|
||||
offset += o * self._strides[ax];
|
||||
sh._dims.buffer[ax] = 0;
|
||||
}
|
||||
|
||||
var new_strides: [Shape.MAX_RANK]i64 = @splat(self.dtype().sizeOf());
|
||||
|
||||
// TODO rewrite with simd. This is a pshuf, but it's not supported by @shuffle.
|
||||
var res_ax: u32 = 0;
|
||||
for (0..self._shape.rank()) |ax| {
|
||||
if (sh._dims.buffer[ax] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
sh._dims.buffer[res_ax] = self._shape._dims.buffer[ax];
|
||||
sh._tags.buffer[res_ax] = self._shape._tags.buffer[ax];
|
||||
new_strides[res_ax] = self._strides[ax];
|
||||
res_ax += 1;
|
||||
}
|
||||
sh._dims.len -= off.len;
|
||||
sh._tags.len -= off.len;
|
||||
|
||||
return HostBuffer{
|
||||
._shape = sh,
|
||||
._strides = new_strides,
|
||||
._data = self._data[@intCast(offset)..],
|
||||
._memory = .unmanaged,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
||||
const ax = self._shape.axis(axis_);
|
||||
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
||||
|
||||
var _strides: ?[Shape.MAX_RANK]i64 = self._strides;
|
||||
if (self._strides) |strydes| {
|
||||
std.mem.copyForwards(i64, _strides.?[0 .. Shape.MAX_RANK - 1], strydes[1..]);
|
||||
}
|
||||
var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
|
||||
_ = strd.orderedRemove(ax);
|
||||
|
||||
return .{
|
||||
._shape = self.shape().drop(ax),
|
||||
.data = self.data,
|
||||
._strides = _strides,
|
||||
._data = self._data,
|
||||
._strides = strd.buffer,
|
||||
._memory = self._memory,
|
||||
};
|
||||
}
|
||||
@ -276,10 +323,13 @@ pub const HostBuffer = struct {
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
if (std.mem.eql(u8, fmt, "v")) {
|
||||
try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||
} else {
|
||||
try writer.print("HostBuffer(.{_})", .{self._shape});
|
||||
}
|
||||
}
|
||||
|
||||
/// Formatter for a HostBuffer that also print the values not just the shape.
|
||||
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
|
||||
|
||||
128
zml/meta.zig
128
zml/meta.zig
@ -237,42 +237,48 @@ test mapAlloc {
|
||||
/// Recursively visit the given struct and calls the callback for each K found.
|
||||
/// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it.
|
||||
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
const T = @TypeOf(v);
|
||||
const type_info_v = @typeInfo(T);
|
||||
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
||||
.pointer => |info| info.child,
|
||||
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
||||
};
|
||||
|
||||
if (type_info_v != .pointer) {
|
||||
const Callback = @TypeOf(cb);
|
||||
stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T });
|
||||
const Ptr = @TypeOf(v);
|
||||
const type_info_v = @typeInfo(Ptr);
|
||||
if (type_info_v != .pointer) {
|
||||
stdx.debug.compileError("zml.meta.visit({}) is expecting a pointer/slice input, but received: {}", .{ Callback, Ptr });
|
||||
}
|
||||
const ptr_info = type_info_v.pointer;
|
||||
if (@typeInfo(ptr_info.child) == .@"fn") return;
|
||||
if (ptr_info.child == anyopaque) return;
|
||||
// This is important, because with trivial types like void,
|
||||
// Zig sometimes decide to call `visit` at comptime, but can't do
|
||||
// the pointer wrangling logic at comptime.
|
||||
// So we detect early this case and return.
|
||||
if (@sizeOf(ptr_info.child) == 0) return;
|
||||
const Child = ptr_info.child;
|
||||
|
||||
const K, const mutating_cb = switch (@typeInfo(FnParam(cb, 1))) {
|
||||
.pointer => |info| .{ info.child, !info.is_const },
|
||||
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
||||
};
|
||||
// Abort if v doesnt' contain any K.
|
||||
if (comptime !Contains(Ptr, K)) return;
|
||||
|
||||
// Handle simple cases.
|
||||
switch (Ptr) {
|
||||
*const K, *K => return cb(ctx, v),
|
||||
*const ?K, *?K => return if (v.*) |*val| cb(ctx, val) else {},
|
||||
[]const K, []K => {
|
||||
for (v) |*v_elem| cb(ctx, v_elem);
|
||||
return;
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
|
||||
// Handle std.BoundedArray that contains uninitalized data.
|
||||
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
|
||||
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
|
||||
}
|
||||
|
||||
// Recursively visit fields of v.
|
||||
switch (ptr_info.size) {
|
||||
// If we have a single pointer, two cases:
|
||||
// * It's a pointer to K, in which case we call the callback.
|
||||
// * It's a pointer to something else, in which case, we explore and recurse if needed.
|
||||
.one => if (ptr_info.child == K) {
|
||||
cb(ctx, v);
|
||||
} else if (ptr_info.child == ?K) {
|
||||
if (v.*) |*val| cb(ctx, val);
|
||||
} else switch (@typeInfo(ptr_info.child)) {
|
||||
.@"struct" => |s| inline for (s.fields) |field_info| {
|
||||
if (field_info.is_comptime) continue;
|
||||
const field_type_info = @typeInfo(field_info.type);
|
||||
.one => switch (@typeInfo(Child)) {
|
||||
.@"struct" => |s| inline for (s.fields) |field| {
|
||||
if (field.is_comptime or comptime !Contains(field.type, K)) continue;
|
||||
const field_type_info = @typeInfo(field.type);
|
||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||
switch (field_type_info) {
|
||||
.pointer => visit(cb, ctx, @field(v, field_info.name)),
|
||||
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field_info.name)),
|
||||
.pointer => visit(cb, ctx, @field(v, field.name)),
|
||||
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field.name)),
|
||||
else => {},
|
||||
}
|
||||
},
|
||||
@ -281,23 +287,19 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
.@"union" => switch (v.*) {
|
||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||
},
|
||||
else => {},
|
||||
else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
|
||||
},
|
||||
// If we have a slice, two cases also:
|
||||
// * It's a slice of K, in which case we call the callback for each element of the slice.
|
||||
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
|
||||
.slice => {
|
||||
for (v) |*v_elem| {
|
||||
if (ptr_info.child == K) {
|
||||
cb(ctx, v_elem);
|
||||
} else switch (@typeInfo(ptr_info.child)) {
|
||||
.@"struct" => |s| inline for (s.fields) |field_info| {
|
||||
const field_type_info = @typeInfo(field_info.type);
|
||||
switch (@typeInfo(Child)) {
|
||||
.@"struct" => |s| inline for (s.fields) |field| {
|
||||
if (field.is_comptime or comptime !Contains(field.type, K)) continue;
|
||||
const field_type_info = @typeInfo(field.type);
|
||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||
if (field_type_info == .pointer) {
|
||||
visit(cb, ctx, @field(v_elem, field_info.name));
|
||||
visit(cb, ctx, @field(v_elem, field.name));
|
||||
} else {
|
||||
visit(cb, ctx, &@field(v_elem, field_info.name));
|
||||
visit(cb, ctx, &@field(v_elem, field.name));
|
||||
}
|
||||
},
|
||||
.array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
||||
@ -305,11 +307,11 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
.@"union" => switch (v_elem.*) {
|
||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||
},
|
||||
else => {},
|
||||
else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
|
||||
}
|
||||
}
|
||||
},
|
||||
else => {},
|
||||
.many, .c => stdx.debug.compileError("zml.meta.visit({}) doesn't support [*] style pointers, got: {}", .{ Callback, Ptr }),
|
||||
}
|
||||
}
|
||||
|
||||
@ -320,7 +322,7 @@ test visit {
|
||||
const NestedAttrOptional = struct { nested: ?Attr };
|
||||
const SimpleStruct = struct { prop: Attr };
|
||||
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
|
||||
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional };
|
||||
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) };
|
||||
|
||||
const LocalContext = struct { result: usize };
|
||||
|
||||
@ -374,11 +376,16 @@ test visit {
|
||||
}
|
||||
{
|
||||
var context: LocalContext = .{ .result = 0 };
|
||||
const prop5: std.BoundedArray(Attr, 8) = .{
|
||||
.buffer = @splat(.{ .data = 4 }),
|
||||
.len = 2,
|
||||
};
|
||||
const container: NestedTypesStruct = .{
|
||||
.prop1 = .{ .data = 1 },
|
||||
.prop2 = .{ .other = "hello" },
|
||||
.prop3 = .{ .nested = .{ .data = 2 } },
|
||||
.prop4 = .{ .nested = .{ .data = 3 } },
|
||||
.prop5 = prop5, // 4 will be counted twice.
|
||||
};
|
||||
|
||||
visit((struct {
|
||||
@ -387,7 +394,7 @@ test visit {
|
||||
}
|
||||
}).cb, &context, &container);
|
||||
|
||||
try std.testing.expectEqual(6, context.result);
|
||||
try std.testing.expectEqual(14, context.result);
|
||||
}
|
||||
}
|
||||
|
||||
@ -533,3 +540,36 @@ fn _CollectArg(func: anytype) type {
|
||||
const params = @typeInfo(@TypeOf(func)).@"fn".params;
|
||||
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
||||
}
|
||||
|
||||
pub fn Contains(Haystack: type, T: type) bool {
|
||||
switch (Haystack) {
|
||||
T, ?T => return true,
|
||||
*T, ?*T => return true,
|
||||
*const T, ?*const T => return true,
|
||||
[]const T, ?[]const T => return true,
|
||||
anyopaque => return false,
|
||||
else => {},
|
||||
}
|
||||
|
||||
return switch (@typeInfo(Haystack)) {
|
||||
.@"struct" => |info| {
|
||||
inline for (info.fields) |field| {
|
||||
if (Contains(field.type, T))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
.@"union" => |info| {
|
||||
inline for (info.fields) |field| {
|
||||
if (Contains(field.type, T))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
.array => |info| Contains(info.child, T),
|
||||
.pointer => |info| Contains(info.child, T),
|
||||
.optional => |info| Contains(info.child, T),
|
||||
.vector => |info| Contains(info.child, T),
|
||||
else => false,
|
||||
};
|
||||
}
|
||||
|
||||
@ -29,7 +29,7 @@ test {
|
||||
|
||||
pub const MlirFn = struct {
|
||||
name: []const u8,
|
||||
num_args: u32,
|
||||
args_shapes: []Shape,
|
||||
res_tensors: *const anyopaque,
|
||||
res_types: []mlir.Type,
|
||||
res_shapes: []Shape,
|
||||
@ -199,7 +199,7 @@ pub const CompilationContext = struct {
|
||||
const loaded_executable: *pjrt.LoadedExecutable = blk: {
|
||||
if (pjrt_location) |pjrt_loc| {
|
||||
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
|
||||
log.info("Loaded pre-compiled module from {s}", .{pjrt_loc});
|
||||
log.info("Loaded pre-compiled module from {s} (generated from {s}/module.mlir)", .{ pjrt_loc, module_dir.? });
|
||||
break :blk exe;
|
||||
} else |err| {
|
||||
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
|
||||
@ -233,7 +233,7 @@ pub const CompilationContext = struct {
|
||||
self._platform,
|
||||
loaded_executable,
|
||||
.{
|
||||
.n_in = f.num_args,
|
||||
.input_shapes = f.args_shapes,
|
||||
.result_shapes = f.res_shapes,
|
||||
.n_devices = sharding.num_replicas * sharding.num_partitions,
|
||||
},
|
||||
@ -341,7 +341,7 @@ pub const CompilationContext = struct {
|
||||
const locations = try arena.alloc(mlir.Location, tensor_count);
|
||||
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
||||
|
||||
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
|
||||
var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count);
|
||||
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
||||
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||
|
||||
@ -427,7 +427,7 @@ pub const CompilationContext = struct {
|
||||
return .{
|
||||
.mlir_fn = mlir_fn,
|
||||
.name = opts.name,
|
||||
.num_args = @intCast(tensor_count),
|
||||
.args_shapes = input_shapes.items,
|
||||
.res_tensors = fn_res,
|
||||
.res_types = fn_res_types,
|
||||
.res_shapes = fn_res_shapes,
|
||||
@ -512,7 +512,7 @@ pub const CompilationContext = struct {
|
||||
|
||||
// Check that the `x` input argument gives its buffer to the result tensor.
|
||||
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
||||
try std.testing.expectEqual(3, f.num_args);
|
||||
try std.testing.expectEqual(3, f.args_shapes.len);
|
||||
// We should have two buffers being donated.
|
||||
const template = "tf.aliasing_output = {d} : i32";
|
||||
var buf = template.*;
|
||||
@ -540,9 +540,13 @@ pub const CompilationContext = struct {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn numPartitions(self: CompilationContext) u8 {
|
||||
return self._platform.sharding().num_partitions;
|
||||
}
|
||||
|
||||
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
||||
const ctx = self.mlirCtx();
|
||||
const num_partitions = self._platform.sharding().num_partitions;
|
||||
const num_partitions = self.numPartitions();
|
||||
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
||||
@ -645,10 +649,11 @@ pub const CompilationContext = struct {
|
||||
|
||||
const loc = self.mlirCtx().location(@src());
|
||||
|
||||
const values = try arena.alloc(mlir.Value, function.num_args);
|
||||
const num_args = function.args_shapes.len;
|
||||
const values = try arena.alloc(mlir.Value, num_args);
|
||||
self.extractValues(&args, values);
|
||||
|
||||
const donations = try arena.alloc(Tensor._Donation, function.num_args);
|
||||
const donations = try arena.alloc(Tensor._Donation, num_args);
|
||||
meta.collectBuf(struct {
|
||||
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
|
||||
return ctx.getValueAndDonation(x)[1];
|
||||
|
||||
@ -176,6 +176,8 @@ pub const RopeOpts = struct {
|
||||
/// Read a Rope scaling config from HF config.json format.
|
||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
|
||||
const content = try std.json.Value.jsonParse(allocator, source, options);
|
||||
if (content == .null) return .default;
|
||||
|
||||
if (content != .object) return error.InvalidEnumTag;
|
||||
|
||||
const obj = content.object;
|
||||
|
||||
@ -58,10 +58,10 @@ pub const Shape = struct {
|
||||
const fv = @field(v, field.name);
|
||||
if (comptime stdx.meta.isInteger(field.type)) {
|
||||
dims_.appendAssumeCapacity(@intCast(fv));
|
||||
} else if (comptime isAutoDim(fv)) {
|
||||
} else if (@TypeOf(fv) == EnumLiteral and comptime isAutoDim(fv)) {
|
||||
dims_.appendAssumeCapacity(-1);
|
||||
} else {
|
||||
stdx.debug.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
|
||||
stdx.debug.compileError("Field {s} should be an integer or an auto dimension, got {}", .{ field.name, field.type });
|
||||
}
|
||||
if (comptime stdx.meta.isTuple(T)) {
|
||||
tags_.appendAssumeCapacity(TagUnknown);
|
||||
@ -186,7 +186,7 @@ pub const Shape = struct {
|
||||
EnumLiteral => @tagName(v).ptr,
|
||||
std.builtin.Type.StructField => v.name.ptr,
|
||||
Tag => v,
|
||||
else => stdx.debug.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
||||
else => stdx.debug.compileError("Shape tag should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
||||
};
|
||||
}
|
||||
|
||||
@ -581,6 +581,41 @@ pub const Shape = struct {
|
||||
try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());
|
||||
}
|
||||
|
||||
pub fn removeMany(self: Shape, axes_: anytype) Shape {
|
||||
var to_remove = self.axes(axes_);
|
||||
if (to_remove.len == 0) return self;
|
||||
std.mem.sort(u3, to_remove.slice(), {}, std.sort.asc(u3));
|
||||
|
||||
var sh: Shape = self;
|
||||
const rk = self.rank();
|
||||
var res_ax: u32 = 0;
|
||||
for (0..rk) |ax| {
|
||||
if (std.mem.indexOfScalar(u3, to_remove.constSlice(), @intCast(ax))) |_| {
|
||||
continue;
|
||||
}
|
||||
|
||||
sh._dims.buffer[res_ax] = self._dims.buffer[ax];
|
||||
sh._tags.buffer[res_ax] = self._tags.buffer[ax];
|
||||
res_ax += 1;
|
||||
}
|
||||
sh._dims.len = rk - to_remove.len;
|
||||
sh._tags.len = rk - to_remove.len;
|
||||
return sh;
|
||||
}
|
||||
|
||||
test removeMany {
|
||||
try std.testing.expectEqualSlices(
|
||||
i64,
|
||||
&.{12},
|
||||
Shape.init(.{ 10, 11, 12 }, .f32).removeMany(.{ 0, 1 }).dims(),
|
||||
);
|
||||
try std.testing.expectEqualSlices(
|
||||
i64,
|
||||
&.{ 10, 11 },
|
||||
Shape.init(.{ 10, 11, 12, 13 }, .f32).removeMany(.{ -1, -2 }).dims(),
|
||||
);
|
||||
}
|
||||
|
||||
pub fn transpose(self: Shape, permutations: anytype) Shape {
|
||||
std.debug.assert(self.rank() == permutations.len);
|
||||
const permutations_ = self.axes(permutations);
|
||||
@ -729,7 +764,9 @@ pub const Shape = struct {
|
||||
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
|
||||
var res = self;
|
||||
inline for (std.meta.fields(T)) |field| {
|
||||
res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
|
||||
const new_field = @field(renames, field.name);
|
||||
stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
|
||||
res._tags.set(self.axis(field), toTag(new_field));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@ -749,15 +786,20 @@ pub const Shape = struct {
|
||||
}
|
||||
|
||||
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
|
||||
const base_stride = self.dtype().sizeOf();
|
||||
const rk = self.rank();
|
||||
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
|
||||
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
|
||||
if (rk == 0) return strides;
|
||||
strides.buffer[rk - 1] = base_stride;
|
||||
for (1..rk) |i| {
|
||||
const j = @as(usize, rk) - 1 - i;
|
||||
strides.buffer[j] = self._dims.get(j + 1) * strides.buffer[j + 1];
|
||||
}
|
||||
|
||||
const V = @Vector(MAX_RANK, i64);
|
||||
const rank_mask = std.simd.iota(u8, MAX_RANK) < @as(@Vector(MAX_RANK, u8), @splat(rk));
|
||||
// For each axis compute the product of all following dimensions
|
||||
// and the element size in bytes.
|
||||
var d: V = @bitCast(self._dims.buffer);
|
||||
d = @select(i64, rank_mask, d, @as(V, @splat(1)));
|
||||
d = std.simd.shiftElementsLeft(d, 1, self.dtype().sizeOf());
|
||||
d = std.simd.prefixScan(.Mul, -1, d);
|
||||
|
||||
strides.buffer = @bitCast(d);
|
||||
return strides;
|
||||
}
|
||||
|
||||
|
||||
@ -176,6 +176,7 @@ pub const Tensor = struct {
|
||||
var res = self;
|
||||
res._shape = self._shape.withSharding(axes_);
|
||||
|
||||
if (ctx.numPartitions() <= 1) return self;
|
||||
const op = dialect.stablehlo.custom_call(
|
||||
mlir_ctx,
|
||||
&.{self.value()},
|
||||
@ -1279,9 +1280,9 @@ pub const Tensor = struct {
|
||||
/// see: https://paperswithcode.com/method/gelu
|
||||
pub fn gelu(x: Tensor) Tensor {
|
||||
const scaled_x_cube = x.mul(x).mul(x).scale(0.044715);
|
||||
const one = Tensor.constant(x._shape, x.dtype().one());
|
||||
const one_plus_tanh = Tensor.add(x, scaled_x_cube).scale(std.math.sqrt(2.0 / std.math.pi)).tanh().add(one);
|
||||
return one_plus_tanh.mul(x).scale(0.5);
|
||||
const beta = std.math.sqrt(2.0 / std.math.pi);
|
||||
const tanh_ = x.add(scaled_x_cube).scale(beta).tanh();
|
||||
return tanh_.addConstant(1).mul(x).scale(0.5);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
|
||||
@ -1526,8 +1527,34 @@ pub const Tensor = struct {
|
||||
|
||||
pub const Slice = struct {
|
||||
start: i64 = 0,
|
||||
end: ?i64 = null,
|
||||
step: i64 = 1,
|
||||
end: i64 = to_the_end,
|
||||
step: i32 = 1,
|
||||
singleton: bool = false,
|
||||
|
||||
pub fn single(offset: i64) Slice {
|
||||
return .{ .start = offset, .end = offset + 1, .singleton = true };
|
||||
}
|
||||
|
||||
const to_the_end = std.math.maxInt(i64);
|
||||
|
||||
pub fn format(
|
||||
self: Slice,
|
||||
comptime fmt: []const u8,
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
if (self.singleton) {
|
||||
try writer.print("[{}]", .{self.start});
|
||||
} else if (self.end == to_the_end and self.step == 1) {
|
||||
try writer.print("[{}..]", .{self.start});
|
||||
} else if (self.step == 1) {
|
||||
try writer.print("[{}..{}]", .{ self.start, self.end });
|
||||
} else {
|
||||
try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step });
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// Slices the input Tensor over the given axis using the given parameters.
|
||||
@ -1549,13 +1576,13 @@ pub const Tensor = struct {
|
||||
|
||||
const args: Slice = .{
|
||||
.start = self.wrapIndex(a, s.start),
|
||||
.end = if (s.end) |end| self.wrapIndex(a, end) else self.dim(a),
|
||||
.end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end),
|
||||
.step = s.step,
|
||||
};
|
||||
start_indices[a] = args.start;
|
||||
limit_indices[a] = args.end.?;
|
||||
limit_indices[a] = args.end;
|
||||
strides[a] = args.step;
|
||||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
||||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable);
|
||||
}
|
||||
|
||||
const mlir_ctx = self.getContext().mlirCtx();
|
||||
@ -1571,7 +1598,12 @@ pub const Tensor = struct {
|
||||
loc,
|
||||
);
|
||||
|
||||
return _result(res_shape, slice_op.result(0));
|
||||
var res = _result(res_shape, slice_op.result(0));
|
||||
var to_remove: Shape.AxesArray = .{};
|
||||
for (slices, 0..) |s, a| {
|
||||
if (s.singleton) to_remove.appendAssumeCapacity(@intCast(a));
|
||||
}
|
||||
return res.reshape(res_shape.removeMany(to_remove.constSlice()));
|
||||
}
|
||||
|
||||
test slice {
|
||||
@ -1606,8 +1638,17 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor {
|
||||
// TODO: this use case could be handled directly by slice if we added a .single field
|
||||
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_);
|
||||
return self.slice1d(axis_, .single(i));
|
||||
}
|
||||
|
||||
pub fn choose(self: Tensor, offsets: anytype) Tensor {
|
||||
const off, const tags = Shape.parseDimensions(offsets);
|
||||
var slices = [_]Slice{.{}} ** MAX_RANK;
|
||||
for (off.constSlice(), tags.constSlice()) |o, t| {
|
||||
const ax = self.axis(t);
|
||||
slices[ax] = .single(o);
|
||||
}
|
||||
return self.slice(slices[0..self.rank()]);
|
||||
}
|
||||
|
||||
/// Concatenates the input Tensors along the given axis.
|
||||
@ -1866,7 +1907,12 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a 0-rank Tensor with the given value.
|
||||
pub fn scalar(val: anytype, dt: DataType) Tensor {
|
||||
return Tensor.constant(.{}, Data.init(dt, val));
|
||||
const data = Data.init(dt, val);
|
||||
switch (dt.class()) {
|
||||
.float => stdx.debug.assert(!std.math.isNan(val), "scalar(NaN) is probably due to compiling a model with an uninitialized field", .{}),
|
||||
else => {},
|
||||
}
|
||||
return Tensor.constant(.{}, data);
|
||||
}
|
||||
|
||||
test scalar {
|
||||
@ -1913,7 +1959,7 @@ pub const Tensor = struct {
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
||||
const loc = ctx.location(@src());
|
||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
||||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc);
|
||||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc);
|
||||
return _result(val.shape(), constant_op.result(0));
|
||||
}
|
||||
|
||||
@ -3786,6 +3832,7 @@ pub const Tensor = struct {
|
||||
/// Only for debug purpose, it inserts device to host synchronization
|
||||
/// so it will slow down the program execution.
|
||||
pub fn print(input: Tensor) Tensor {
|
||||
// TODO: find a way of doing print that doesn't involve a H2D copy.
|
||||
return ops.addHostCallback(
|
||||
&printCallback,
|
||||
null,
|
||||
@ -3797,8 +3844,10 @@ pub const Tensor = struct {
|
||||
|
||||
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
||||
const host_buffer = inputs[0];
|
||||
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||
std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr);
|
||||
std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||
// This is true because of the operand aliases.
|
||||
// Since the result is already pointing to the input we don't need to modify the buffer.
|
||||
std.debug.assert(host_buffer._data == outputs[0]._data);
|
||||
}
|
||||
};
|
||||
|
||||
@ -3918,6 +3967,10 @@ test "Tensor.maxPool2d" {
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
// TODO: we should strip out the non-buffer fields.
|
||||
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
||||
// Also it will simplify the layout of the Bufferized struct.
|
||||
// accelerating the calls to execute.
|
||||
return meta.MapType(Tensor, Buffer).map(T);
|
||||
}
|
||||
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
const builtin = @import("builtin");
|
||||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const zml = @import("zml.zig");
|
||||
const meta = @import("meta.zig");
|
||||
const shapesOf = @import("tensor.zig").shapesOf;
|
||||
const zml = @import("zml.zig");
|
||||
|
||||
const log = std.log.scoped(.@"zml/testing");
|
||||
|
||||
@ -35,7 +36,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool
|
||||
/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the
|
||||
/// host for comparison !
|
||||
pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.page_allocator;
|
||||
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator;
|
||||
var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer)
|
||||
.{ try left_.toHostAlloc(allocator), true }
|
||||
else
|
||||
|
||||
Loading…
Reference in New Issue
Block a user