Add buffer and hostbuffer utilities with precise f32→bf16 conversion, type inference for loadBuffers, store expected input shapes, enhance meta.visit and JSON TaggedUnion support, and improve logging.

This commit is contained in:
Tarry Singh 2024-10-28 11:21:46 +00:00
parent 1540c6e85e
commit 3849eb10b7
14 changed files with 497 additions and 198 deletions

View File

@ -350,7 +350,7 @@ pub const Client = opaque {
} }
pub const BufferFromHostBufferArgs = struct { pub const BufferFromHostBufferArgs = struct {
data: []const u8, data: [*]const u8,
buffer_type: BufferType, buffer_type: BufferType,
dims: []const i64, dims: []const i64,
byte_strides: ?[]const i64, byte_strides: ?[]const i64,
@ -362,7 +362,7 @@ pub const Client = opaque {
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{ const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
.client = self.inner(), .client = self.inner(),
.data = @ptrCast(@constCast(args.data.ptr)), .data = @constCast(args.data),
.type = @intFromEnum(args.buffer_type), .type = @intFromEnum(args.buffer_type),
.dims = @ptrCast(@constCast(args.dims.ptr)), .dims = @ptrCast(@constCast(args.dims.ptr)),
.num_dims = args.dims.len, .num_dims = args.dims.len,

View File

@ -1,25 +1,46 @@
pub const std = @import("std"); pub const std = @import("std");
const ParseFromValueError = std.json.ParseFromValueError;
/// Handle json fields that can have different Zig types depending on the message.
/// Each union field should have a unique Zig type.
///
/// Example json:
///
/// ```json
/// [
/// { "question": "How old are you ?", "answer": 5 },
/// { "question": "Count to three.", "answer": [1, 2, 3] },
/// ]
/// ```
///
/// Corresponding Zig code:
///
/// ```zig
/// const Answer = union {
/// number: i32,
/// numbers: []const i32,
/// };
///
/// const Message = struct {
/// question: []const u8;
/// answer: stdx.json.Union(Answer);
/// }
/// ```
pub fn Union(comptime T: type) type { pub fn Union(comptime T: type) type {
return struct { return struct {
const Self = @This(); const Self = @This();
value: T, value: T,
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self { pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
return jsonParseFromValue( return jsonParseFromValue(
allocator, allocator,
try std.json.innerParse( try std.json.innerParse(std.json.Value, allocator, source, options),
std.json.Value,
allocator,
source,
options,
),
options, options,
); );
} }
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self { pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
switch (field.type) { switch (field.type) {
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) }, bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
@ -39,7 +60,67 @@ pub fn Union(comptime T: type) type {
}, },
} }
} }
return error.UnexpectedToken; return error.UnknownField;
}
};
}
/// Handle json fields that can have different Zig types depending on another field in the same message.
/// This is translated to a Zig tagged union.
///
/// Example json:
///
/// ```json
/// [
/// { "type": "faq", "question": "How old are you ?", "answer": 5 },
/// { "type": "address", "city": "NYC", "zipcode": "49130"},
/// ]
/// ```
///
/// Corresponding Zig struct:
///
/// ```zig
/// const Entry = union {
/// faq: struct { question: []const u8, answer: u32 },
/// address: struct { city: []const u8, zipcode: []const u8 },
/// };
///
/// const Message = []const stdx.json.TaggedUnion(Entry, "type");
/// ```
pub fn TaggedUnion(comptime T: type, comptime tag_name: [:0]const u8) type {
return struct {
const Self = @This();
value: T,
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
return jsonParseFromValue(
allocator,
try std.json.innerParse(std.json.Value, allocator, source, options),
options,
);
}
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
errdefer std.log.warn("failed to parse: {} as {s}", .{ source, @typeName(T) });
if (source != .object) return error.UnexpectedToken;
const o = source.object;
const tag = o.get(tag_name) orelse return error.MissingField;
for (o.keys(), o.values()) |k, v| {
std.log.warn("object['{s}'] = {}", .{ k, v });
}
if (tag != .string) return error.LengthMismatch;
inline for (std.meta.fields(T)) |field| {
if (std.mem.eql(u8, field.name, tag.string)) {
const inner_source = o.get(field.name) orelse return error.MissingField;
const inner: field.type = std.json.innerParseFromValue(field.type, allocator, inner_source, options) catch |err| {
std.log.warn("failed to interpret {s} as a {s}: {}", .{ tag.string, @typeName(field.type), err });
return err;
};
return .{ .value = @unionInit(T, field.name, inner) };
}
}
return error.InvalidEnumTag;
} }
}; };
} }

View File

@ -1,11 +1,9 @@
const asynk = @import("async");
const builtin = @import("builtin");
const c = @import("c");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const builtin = @import("builtin");
const zml = @import("zml.zig"); const asynk = @import("async");
const posix = @import("posix.zig"); const c = @import("c");
const stdx = @import("stdx");
pub const gguf = @import("aio/gguf.zig"); pub const gguf = @import("aio/gguf.zig");
pub const nemo = @import("aio/nemo.zig"); pub const nemo = @import("aio/nemo.zig");
@ -13,10 +11,11 @@ pub const safetensors = @import("aio/safetensors.zig");
pub const tinyllama = @import("aio/tinyllama.zig"); pub const tinyllama = @import("aio/tinyllama.zig");
pub const torch = @import("aio/torch.zig"); pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig"); pub const yaml = @import("aio/yaml.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const posix = @import("posix.zig");
const zml = @import("zml.zig");
pub const log = std.log.scoped(.@"zml/aio"); pub const log = std.log.scoped(.@"zml/aio");
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
std.testing.refAllDecls(gguf); std.testing.refAllDecls(gguf);
@ -26,6 +25,8 @@ test {
std.testing.refAllDecls(yaml); std.testing.refAllDecls(yaml);
} }
// TODO error set for weight loading
/// Detects the format of the model file (base on filename) and open it. /// Detects the format of the model file (base on filename) and open it.
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore { pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
return if (std.mem.endsWith(u8, model_path, ".safetensors")) return if (std.mem.endsWith(u8, model_path, ".safetensors"))
@ -422,7 +423,7 @@ fn _populateStruct(
return true; return true;
}, },
.float => { .float => {
obj.* = undefined; obj.* = std.math.nan(@TypeOf(obj.*));
return true; return true;
}, },
.void => true, .void => true,
@ -450,7 +451,7 @@ test populateModel {
// Create a fake HostBuffer, we use the given integer to identify the created buffer. // Create a fake HostBuffer, we use the given integer to identify the created buffer.
fn _newHostBuffer(n: u32) zml.HostBuffer { fn _newHostBuffer(n: u32) zml.HostBuffer {
return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined }; return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined };
} }
}; };
@ -500,7 +501,7 @@ test populateModel {
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function. /// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
pub fn loadBuffers( pub fn loadBuffers(
comptime Model: type, comptime Model: type,
init_args: anytype, init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
buffer_store: BufferStore, buffer_store: BufferStore,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
platform: zml.Platform, platform: zml.Platform,
@ -513,8 +514,6 @@ pub fn loadBuffers(
// If the Model has a "init" function, call it with the given parameters. // If the Model has a "init" function, call it with the given parameters.
if (@hasDecl(Model, "init")) { if (@hasDecl(Model, "init")) {
@call(.auto, Model.init, .{&model} ++ init_args); @call(.auto, Model.init, .{&model} ++ init_args);
} else {
stdx.debug.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
} }
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");

View File

@ -44,23 +44,6 @@ pub const Buffer = struct {
} }
}; };
pub const Shard = struct {
api: *const pjrt.Api,
buffer: *pjrt.Buffer,
ready_event: ?*pjrt.Event = null,
ready: bool = false,
pub fn awaitt(self: *Shard) !void {
if (self.ready) {
return;
}
if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| {
try ev.awaitt(self.api);
}
self.ready = true;
}
};
_shape: Shape, _shape: Shape,
_api: *const pjrt.Api, _api: *const pjrt.Api,
_shards: Shards, _shards: Shards,
@ -88,7 +71,7 @@ pub const Buffer = struct {
} else 0; } else 0;
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); const byte_strides = host_buffer.strides();
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
const devices = platform.getDevices(); const devices = platform.getDevices();
@ -103,7 +86,7 @@ pub const Buffer = struct {
platform.pjrt_client, platform.pjrt_client,
platform.pjrt_api, platform.pjrt_api,
pjrt.Client.BufferFromHostBufferArgs{ pjrt.Client.BufferFromHostBufferArgs{
.data = buf.data, .data = buf._data,
.buffer_type = buffer_type, .buffer_type = buffer_type,
.dims = buf.shape().dims(), .dims = buf.shape().dims(),
.byte_strides = byte_strides, .byte_strides = byte_strides,
@ -155,6 +138,14 @@ pub const Buffer = struct {
return try from(platform, host_buffer); return try from(platform, host_buffer);
} }
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
// TODO restore assert
// const memory = self.getMemory().kind(self._api);
// stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
}
/// Creates a Buffer with a single element. /// Creates a Buffer with a single element.
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer { pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
const x = dtype_.constant(val); const x = dtype_.constant(val);
@ -182,8 +173,8 @@ pub const Buffer = struct {
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) { if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
const host_buffer: HostBuffer = .{ const host_buffer: HostBuffer = .{
._shape = shape_, ._shape = shape_,
._strides = [1]i64{0} ** Shape.MAX_RANK, ._strides = @splat(0),
.data = x.constSlice(), ._data = x.constSlice().ptr,
}; };
return try from(platform, host_buffer); return try from(platform, host_buffer);
} }
@ -207,7 +198,7 @@ pub const Buffer = struct {
}, },
else => unreachable, else => unreachable,
} }
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes }; const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
return try from(platform, host_buffer); return try from(platform, host_buffer);
} }
@ -228,12 +219,12 @@ pub const Buffer = struct {
/// could lead to crashes and operations on the buffer will be slower. /// could lead to crashes and operations on the buffer will be slower.
/// Tested on Cuda 12.4. /// Tested on Cuda 12.4.
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer { pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr))); return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(buf._data));
} }
/// Creates a Buffer from a pointer into device memory. /// Creates a Buffer from a pointer into device memory.
/// This allows to interface with other libraries producing buffers. /// This allows to interface with other libraries producing buffers.
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer { pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer {
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
var res: [Shape.MAX_RANK]i64 = undefined; var res: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| { for (0..Shape.MAX_RANK) |i| {
@ -255,7 +246,7 @@ pub const Buffer = struct {
.tile_dims_sizes = &.{}, .tile_dims_sizes = &.{},
}, },
}, },
.stream = @bitCast(@as(usize, @intFromPtr(stream))), .stream = stream,
}) catch @panic("failed to createViewOfDeviceBuffer"); }) catch @panic("failed to createViewOfDeviceBuffer");
var shards: Shards = .{}; var shards: Shards = .{};
@ -296,7 +287,7 @@ pub const Buffer = struct {
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
const output = try HostBuffer.empty(allocator, self.shape()); const output = try HostBuffer.empty(allocator, self.shape());
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.bytes()));
if (maybe_event) |event| { if (maybe_event) |event| {
try event.await_(self._api); try event.await_(self._api);
} }

View File

@ -1,4 +1,5 @@
const std = @import("std"); const std = @import("std");
const floats = @import("floats.zig"); const floats = @import("floats.zig");
const C64 = std.math.Complex(f32); const C64 = std.math.Complex(f32);
@ -111,9 +112,7 @@ pub const DataType = enum(u8) {
} }
pub fn toZigType(comptime dtype: DataType) type { pub fn toZigType(comptime dtype: DataType) type {
return switch (dtype) { return @FieldType(Data, @tagName(dtype));
inline else => |tag| std.meta.TagPayload(Data, tag),
};
} }
pub fn isSignedInt(dtype: DataType) bool { pub fn isSignedInt(dtype: DataType) bool {
@ -125,19 +124,19 @@ pub const DataType = enum(u8) {
pub fn sizeOf(self: DataType) u16 { pub fn sizeOf(self: DataType) u16 {
return switch (self) { return switch (self) {
inline else => |tag| @sizeOf(std.meta.TagPayload(Data, tag)), inline else => |tag| @sizeOf(tag.toZigType()),
}; };
} }
pub fn bitSizeOf(self: DataType) u16 { pub fn bitSizeOf(self: DataType) u16 {
return switch (self) { return switch (self) {
inline else => |tag| @bitSizeOf(std.meta.TagPayload(Data, tag)), inline else => |tag| @bitSizeOf(tag.toZigType()),
}; };
} }
pub fn alignOf(self: DataType) u29 { pub fn alignOf(self: DataType) u29 {
return switch (self) { return switch (self) {
inline else => |tag| @alignOf(std.meta.TagPayload(Data, tag)), inline else => |tag| @alignOf(tag.toZigType()),
}; };
} }

View File

@ -1,13 +1,13 @@
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const aio = @import("aio.zig"); const aio = @import("aio.zig");
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized; const Bufferized = @import("tensor.zig").Bufferized;
const CompilationContext = @import("module.zig").CompilationContext; const CompilationContext = @import("module.zig").CompilationContext;
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf; const ShapeOf = @import("tensor.zig").ShapeOf;
@ -147,6 +147,7 @@ pub const BaseExe = struct {
/// Total number of buffers needed by this executable. /// Total number of buffers needed by this executable.
input_buffer_count: u32, input_buffer_count: u32,
input_shapes: []Shape,
result_shapes: []Shape, result_shapes: []Shape,
/// Num devices used (>1 for sharded executable) /// Num devices used (>1 for sharded executable)
@ -155,34 +156,44 @@ pub const BaseExe = struct {
/// Allocator backing memory /// Allocator backing memory
_arena: std.heap.ArenaAllocator, _arena: std.heap.ArenaAllocator,
pub fn init(parent_allocator: std.mem.Allocator, platform: Platform, exe: *pjrt.LoadedExecutable, args: struct { n_in: u32, result_shapes: []const Shape, n_devices: u8 }) !BaseExe { pub fn init(
parent_allocator: std.mem.Allocator,
platform: Platform,
exe: *pjrt.LoadedExecutable,
args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
) !BaseExe {
var arena = std.heap.ArenaAllocator.init(parent_allocator); var arena = std.heap.ArenaAllocator.init(parent_allocator);
errdefer arena.deinit(); errdefer arena.deinit();
const allocator = arena.allocator(); const allocator = arena.allocator();
const n_in = args.input_shapes.len;
const n_out = args.result_shapes.len; const n_out = args.result_shapes.len;
const n_devices = args.n_devices; const n_devices = args.n_devices;
// Allocate once for all the *pjrt.Buffer we need to store ... // Allocate once for all the *pjrt.Buffer we need to store ...
const all_buffers = try allocator.alloc(*pjrt.Buffer, (args.n_in + n_out) * n_devices); const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ args.n_in * n_devices, n_out * n_devices }); const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });
// ... and once for all the [*]*pjrt.Buffer. // ... and once for all the [*]*pjrt.Buffer.
const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices); const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices }); const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
for (0..n_devices) |i| { for (0..n_devices) |i| {
input_per_device[i] = all_input_buffers[i * args.n_in ..].ptr; input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
output_per_device[i] = all_output_buffers[i * n_out ..].ptr; output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
} }
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
@memcpy(all_shapes[0..n_in], args.input_shapes);
@memcpy(all_shapes[n_in..], args.result_shapes);
return .{ return .{
.platform = platform, .platform = platform,
.exe = exe, .exe = exe,
.ready_buffer_count = 0, .ready_buffer_count = 0,
.input_buffer_count = args.n_in, .input_buffer_count = @intCast(n_in),
.num_devices = args.n_devices, .num_devices = args.n_devices,
.input_per_device = input_per_device, .input_per_device = input_per_device,
.output_per_device = output_per_device, .output_per_device = output_per_device,
.result_shapes = try allocator.dupe(Shape, args.result_shapes), .input_shapes = all_shapes[0..n_in],
.result_shapes = all_shapes[n_in..],
._arena = arena, ._arena = arena,
}; };
} }
@ -209,7 +220,9 @@ pub const BaseExe = struct {
// even if it has been marked as "can be donated" during compilation. // even if it has been marked as "can be donated" during compilation.
// TODO: expose it ? // TODO: expose it ?
.non_donatable_input_indices = &.{}, .non_donatable_input_indices = &.{},
}) catch unreachable; }) catch |err| {
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
};
for (events[0..sharding.num_partitions]) |e| { for (events[0..sharding.num_partitions]) |e| {
if (e) |ev| { if (e) |ev| {
@ -232,7 +245,7 @@ pub const BaseExe = struct {
// } // }
pub fn prepare(self: *BaseExe, x: anytype) void { pub fn prepare(self: *BaseExe, x: anytype) void {
const n = fillBuffers(&x, self.input_per_device, self.ready_buffer_count); const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
self.ready_buffer_count += n; self.ready_buffer_count += n;
} }
@ -244,6 +257,14 @@ pub const BaseExe = struct {
return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice()); return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
} }
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
return .init(parent_allocator, self.platform, self.exe, .{
.input_shapes = self.input_shapes,
.result_shapes = self.result_shapes,
.n_devices = self.num_devices,
});
}
}; };
/// Represents a ZML function, compiled into a PJRT executable. /// Represents a ZML function, compiled into a PJRT executable.
@ -280,7 +301,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
} }
pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) { pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
const total_ready = fillBuffers(&args, self.inner.input_per_device, self.inner.ready_buffer_count); const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
std.debug.assert(total_ready == self.inner.input_buffer_count); std.debug.assert(total_ready == self.inner.input_buffer_count);
self.inner._unsafeCall(); self.inner._unsafeCall();
var result: Bufferized(ReturnT) = undefined; var result: Bufferized(ReturnT) = undefined;
@ -302,20 +323,23 @@ fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
} }
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor. /// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32) u32 { fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
const LocalContext = struct { const LocalContext = struct {
index: u32, index: u32,
buffers: []const [*]*pjrt.Buffer, buffers: []const [*]*pjrt.Buffer,
shapes: []const Shape,
}; };
var context: LocalContext = .{ var context: LocalContext = .{
.index = start, .index = start,
.buffers = buffers, .buffers = buffers,
.shapes = shapes,
}; };
meta.visit((struct { meta.visit((struct {
fn cb(ctx: *LocalContext, buffer: *const Buffer) void { fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
const model_sharding = ctx.buffers.len; const model_sharding = ctx.buffers.len;
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
for (buffer._shards.constSlice(), 0..) |shard, d| { for (buffer._shards.constSlice(), 0..) |shard, d| {
ctx.buffers[d][ctx.index] = shard; ctx.buffers[d][ctx.index] = shard;
} }

View File

@ -305,11 +305,23 @@ pub const BFloat16 = packed struct(u16) {
pub fn isInf(self: BFloat16) bool { pub fn isInf(self: BFloat16) bool {
return allBitsOne(self.exponent) and self.mantissa == 0; return allBitsOne(self.exponent) and self.mantissa == 0;
} }
pub fn toF32(self: BFloat16) f32 {
// Pad the BF16 with zeros 0
return @bitCast([2]u16{ 0, @bitCast(self) });
}
pub fn fromF32(float32: f32) BFloat16 {
var int: u32 = @bitCast(float32);
// Round up if needed.
int += 0x8000;
const parts: [2]u16 = @bitCast(int);
return @bitCast(parts[1]);
}
const Helpers = FloatHelpers(@This()); const Helpers = FloatHelpers(@This());
pub const zero = Helpers.zero; pub const zero = Helpers.zero;
pub const neg = Helpers.neg; pub const neg = Helpers.neg;
pub const fromF32 = Helpers.fromF32;
pub const toF32 = Helpers.toF32;
pub const format = Helpers.format; pub const format = Helpers.format;
}; };
@ -317,7 +329,7 @@ test BFloat16 {
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples // From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 }); try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 });
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 }); try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 }); try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 66 });
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 }); try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff }); try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32))); try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));

View File

@ -18,8 +18,8 @@ test {
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere). /// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
pub const HostBuffer = struct { pub const HostBuffer = struct {
_shape: Shape, _shape: Shape,
_strides: ?[Shape.MAX_RANK]i64 = null, _strides: [Shape.MAX_RANK]i64,
data: []const u8, _data: [*]const u8,
_memory: union(enum) { _memory: union(enum) {
managed: std.mem.Alignment, managed: std.mem.Alignment,
unmanaged, unmanaged,
@ -28,10 +28,11 @@ pub const HostBuffer = struct {
/// Allocates a HostBuffer with the given shape. /// Allocates a HostBuffer with the given shape.
/// The memory is left undefined. /// The memory is left undefined.
/// The caller owns the memory, and need to call `deinit()`. /// The caller owns the memory, and need to call `deinit()`.
pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer { pub fn empty(allocator: std.mem.Allocator, sh: Shape) error{OutOfMemory}!HostBuffer {
return .{ return .{
._shape = sh, ._shape = sh,
.data = try allocator.alignedAlloc(u8, 64, sh.byteSize()), ._strides = sh.computeStrides().buffer,
._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr,
._memory = .{ .managed = .@"64" }, ._memory = .{ .managed = .@"64" },
}; };
} }
@ -43,7 +44,8 @@ pub const HostBuffer = struct {
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len }); stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
return .{ return .{
._shape = shape_, ._shape = shape_,
.data = data_, ._strides = shape_.computeStrides().buffer,
._data = data_.ptr,
._memory = .unmanaged, ._memory = .unmanaged,
}; };
} }
@ -53,7 +55,7 @@ pub const HostBuffer = struct {
// This means we don't own the data. // This means we don't own the data.
if (self._memory == .unmanaged) return; if (self._memory == .unmanaged) return;
const log2_align = self._memory.managed; const log2_align = self._memory.managed;
allocator.rawFree(@constCast(self.data), log2_align, @returnAddress()); allocator.rawFree(self.mutBytes(), log2_align, @returnAddress());
} }
/// Wraps an exisiting slice into a HostBuffer. /// Wraps an exisiting slice into a HostBuffer.
@ -62,10 +64,12 @@ pub const HostBuffer = struct {
/// that will still need to be deallocated. /// that will still need to be deallocated.
pub fn fromSlice(sh: anytype, s: anytype) HostBuffer { pub fn fromSlice(sh: anytype, s: anytype) HostBuffer {
const shape_ = Shape.init(sh, DataType.fromSliceElementType(s)); const shape_ = Shape.init(sh, DataType.fromSliceElementType(s));
std.debug.assert(shape_.count() == s.len); const raw_bytes = std.mem.sliceAsBytes(s);
std.debug.assert(shape_.byteSize() == raw_bytes.len);
return .{ return .{
._shape = shape_, ._shape = shape_,
.data = @alignCast(std.mem.sliceAsBytes(s)), ._strides = shape_.computeStrides().buffer,
._data = raw_bytes.ptr,
._memory = .unmanaged, ._memory = .unmanaged,
}; };
} }
@ -81,7 +85,7 @@ pub const HostBuffer = struct {
@memcpy(tmp[0..strides_.len], strides_); @memcpy(tmp[0..strides_.len], strides_);
return .{ return .{
._shape = sh, ._shape = sh,
.data = @alignCast(std.mem.sliceAsBytes(s)), ._data = @alignCast(std.mem.sliceAsBytes(s).ptr),
._strides = tmp, ._strides = tmp,
._memory = .unmanaged, ._memory = .unmanaged,
}; };
@ -89,13 +93,15 @@ pub const HostBuffer = struct {
/// Creates a tensor from a **pointer** to a "multi dimension" array. /// Creates a tensor from a **pointer** to a "multi dimension" array.
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object. /// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
/// Typically this is use with constant arrays.
pub fn fromArray(arr_ptr: anytype) HostBuffer { pub fn fromArray(arr_ptr: anytype) HostBuffer {
const T = @TypeOf(arr_ptr.*); const T = @TypeOf(arr_ptr.*);
const sh = parseArrayInfo(T); const sh = parseArrayInfo(T);
std.debug.assert(sh.byteSize() == @sizeOf(T));
return .{ return .{
._shape = sh, ._shape = sh,
.data = @alignCast(std.mem.sliceAsBytes(arr_ptr)), ._strides = sh.computeStrides().buffer,
// Array are typically stack allocated and don't need to be freed. ._data = @ptrCast(arr_ptr),
._memory = .unmanaged, ._memory = .unmanaged,
}; };
} }
@ -121,16 +127,15 @@ pub const HostBuffer = struct {
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
const b = dt.sizeOf();
const res = try empty(allocator, Shape.init(.{n_steps}, dt)); const res = try empty(allocator, Shape.init(.{n_steps}, dt));
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
var data_ = @constCast(res.data);
switch (dt) { switch (dt) {
inline else => { inline else => |d| if (comptime d.class() != .integer) {
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
} else {
const Zt = d.toZigType();
var j: i64 = args.start; var j: i64 = args.start;
for (0..@intCast(n_steps)) |i| { for (res.mutItems(Zt)) |*val| {
var v = Data.init(dt, j); val.* = @intCast(j);
@memcpy(data_[i * b .. (i + 1) * b], v.constSlice());
j +%= args.step; j +%= args.step;
} }
}, },
@ -160,16 +165,26 @@ pub const HostBuffer = struct {
/// WARNING: It's only valid if the buffer is contiguous. /// WARNING: It's only valid if the buffer is contiguous.
/// Strided buffers can't use this method. /// Strided buffers can't use this method.
pub fn items(self: HostBuffer, comptime T: type) []const T { pub fn items(self: HostBuffer, comptime T: type) []const T {
if (DataType.fromZigType(T) != self.dtype()) { // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) }); stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) });
} stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
if (!self.isContiguous()) { const ptr: [*]const T = @alignCast(@ptrCast(self._data));
std.debug.panic("{} isn't contiguous", .{self});
}
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
return ptr[0..self._shape.count()]; return ptr[0..self._shape.count()];
} }
pub fn mutItems(self: HostBuffer, comptime T: type) []T {
return @constCast(self.items(T));
}
pub fn bytes(self: HostBuffer) []const u8 {
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
return self._data[0..self._shape.byteSize()];
}
pub fn mutBytes(self: HostBuffer) []u8 {
return @constCast(self.bytes());
}
pub fn shape(self: HostBuffer) Shape { pub fn shape(self: HostBuffer) Shape {
return self._shape; return self._shape;
} }
@ -178,9 +193,9 @@ pub const HostBuffer = struct {
return self._shape.dtype(); return self._shape.dtype();
} }
pub fn strides(self: *const HostBuffer) ?[]const i64 { pub fn strides(self: *const HostBuffer) []const i64 {
// Pass strides per pointer otherwise we return a pointer to this stack frame. // Pass strides per pointer otherwise we return a pointer to this stack frame.
return if (self._strides) |*strd| strd[0..self.rank()] else null; return self._strides[0..self._shape.rank()];
} }
// TODO: rename .data into ._data and make it a [*]u8 // TODO: rename .data into ._data and make it a [*]u8
@ -205,7 +220,7 @@ pub const HostBuffer = struct {
} }
pub fn isContiguous(self: HostBuffer) bool { pub fn isContiguous(self: HostBuffer) bool {
const _strides = self._strides orelse return true; const _strides = self._strides;
const cont_strides = self._shape.computeStrides(); const cont_strides = self._shape.computeStrides();
for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| { for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
if (d != 1 and stride != cont_stride) return false; if (d != 1 and stride != cont_stride) return false;
@ -217,6 +232,7 @@ pub const HostBuffer = struct {
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
var res = self; var res = self;
res._shape = self._shape.reshape(shape_); res._shape = self._shape.reshape(shape_);
res._strides = res._shape.computeStrides().buffer;
return res; return res;
} }
@ -236,15 +252,12 @@ pub const HostBuffer = struct {
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s }); stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s }); stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
// If strides weren't set it means original buffer is contiguous. const offset: usize = @intCast(start * self._strides[ax]);
// But it won't be anymore after slicing. The strides don't change though. const new_shape = self.shape().set(ax, end - start);
const _strides = self._strides orelse self._shape.computeStrides().buffer;
const offset: usize = @intCast(start * _strides[ax]);
return .{ return .{
._shape = self.shape().set(ax, end - start), ._shape = new_shape,
.data = self.data[offset..], ._data = self._data[offset..],
// When axis is 0, we stay contiguous. ._strides = self._strides,
._strides = if (ax == 0) self._strides else _strides,
._memory = .unmanaged, ._memory = .unmanaged,
}; };
} }
@ -254,18 +267,52 @@ pub const HostBuffer = struct {
return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax); return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
} }
pub fn choose(self: HostBuffer, offsets: anytype) HostBuffer {
const off, const tags = Shape.parseDimensions(offsets);
var sh = self._shape;
var offset: i64 = 0;
for (off.constSlice(), tags.constSlice()) |o, t| {
const ax = sh.axis(t);
offset += o * self._strides[ax];
sh._dims.buffer[ax] = 0;
}
var new_strides: [Shape.MAX_RANK]i64 = @splat(self.dtype().sizeOf());
// TODO rewrite with simd. This is a pshuf, but it's not supported by @shuffle.
var res_ax: u32 = 0;
for (0..self._shape.rank()) |ax| {
if (sh._dims.buffer[ax] == 0) {
continue;
}
sh._dims.buffer[res_ax] = self._shape._dims.buffer[ax];
sh._tags.buffer[res_ax] = self._shape._tags.buffer[ax];
new_strides[res_ax] = self._strides[ax];
res_ax += 1;
}
sh._dims.len -= off.len;
sh._tags.len -= off.len;
return HostBuffer{
._shape = sh,
._strides = new_strides,
._data = self._data[@intCast(offset)..],
._memory = .unmanaged,
};
}
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
const ax = self._shape.axis(axis_); const ax = self._shape.axis(axis_);
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
var _strides: ?[Shape.MAX_RANK]i64 = self._strides; var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
if (self._strides) |strydes| { _ = strd.orderedRemove(ax);
std.mem.copyForwards(i64, _strides.?[0 .. Shape.MAX_RANK - 1], strydes[1..]);
}
return .{ return .{
._shape = self.shape().drop(ax), ._shape = self.shape().drop(ax),
.data = self.data, ._data = self._data,
._strides = _strides, ._strides = strd.buffer,
._memory = self._memory, ._memory = self._memory,
}; };
} }
@ -276,10 +323,13 @@ pub const HostBuffer = struct {
options: std.fmt.FormatOptions, options: std.fmt.FormatOptions,
writer: anytype, writer: anytype,
) !void { ) !void {
_ = fmt;
_ = options; _ = options;
if (std.mem.eql(u8, fmt, "v")) {
try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
} else {
try writer.print("HostBuffer(.{_})", .{self._shape}); try writer.print("HostBuffer(.{_})", .{self._shape});
} }
}
/// Formatter for a HostBuffer that also print the values not just the shape. /// Formatter for a HostBuffer that also print the values not just the shape.
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});` /// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`

View File

@ -237,42 +237,48 @@ test mapAlloc {
/// Recursively visit the given struct and calls the callback for each K found. /// Recursively visit the given struct and calls the callback for each K found.
/// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it. /// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it.
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
const T = @TypeOf(v);
const type_info_v = @typeInfo(T);
const K = switch (@typeInfo(FnParam(cb, 1))) {
.pointer => |info| info.child,
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
};
if (type_info_v != .pointer) {
const Callback = @TypeOf(cb); const Callback = @TypeOf(cb);
stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T }); const Ptr = @TypeOf(v);
const type_info_v = @typeInfo(Ptr);
if (type_info_v != .pointer) {
stdx.debug.compileError("zml.meta.visit({}) is expecting a pointer/slice input, but received: {}", .{ Callback, Ptr });
} }
const ptr_info = type_info_v.pointer; const ptr_info = type_info_v.pointer;
if (@typeInfo(ptr_info.child) == .@"fn") return; const Child = ptr_info.child;
if (ptr_info.child == anyopaque) return;
// This is important, because with trivial types like void,
// Zig sometimes decide to call `visit` at comptime, but can't do
// the pointer wrangling logic at comptime.
// So we detect early this case and return.
if (@sizeOf(ptr_info.child) == 0) return;
const K, const mutating_cb = switch (@typeInfo(FnParam(cb, 1))) {
.pointer => |info| .{ info.child, !info.is_const },
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
};
// Abort if v doesnt' contain any K.
if (comptime !Contains(Ptr, K)) return;
// Handle simple cases.
switch (Ptr) {
*const K, *K => return cb(ctx, v),
*const ?K, *?K => return if (v.*) |*val| cb(ctx, val) else {},
[]const K, []K => {
for (v) |*v_elem| cb(ctx, v_elem);
return;
},
else => {},
}
// Handle std.BoundedArray that contains uninitalized data.
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
}
// Recursively visit fields of v.
switch (ptr_info.size) { switch (ptr_info.size) {
// If we have a single pointer, two cases: .one => switch (@typeInfo(Child)) {
// * It's a pointer to K, in which case we call the callback. .@"struct" => |s| inline for (s.fields) |field| {
// * It's a pointer to something else, in which case, we explore and recurse if needed. if (field.is_comptime or comptime !Contains(field.type, K)) continue;
.one => if (ptr_info.child == K) { const field_type_info = @typeInfo(field.type);
cb(ctx, v);
} else if (ptr_info.child == ?K) {
if (v.*) |*val| cb(ctx, val);
} else switch (@typeInfo(ptr_info.child)) {
.@"struct" => |s| inline for (s.fields) |field_info| {
if (field_info.is_comptime) continue;
const field_type_info = @typeInfo(field_info.type);
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field. // If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
switch (field_type_info) { switch (field_type_info) {
.pointer => visit(cb, ctx, @field(v, field_info.name)), .pointer => visit(cb, ctx, @field(v, field.name)),
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field_info.name)), .array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field.name)),
else => {}, else => {},
} }
}, },
@ -281,23 +287,19 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
.@"union" => switch (v.*) { .@"union" => switch (v.*) {
inline else => |*v_field| visit(cb, ctx, v_field), inline else => |*v_field| visit(cb, ctx, v_field),
}, },
else => {}, else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
}, },
// If we have a slice, two cases also:
// * It's a slice of K, in which case we call the callback for each element of the slice.
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
.slice => { .slice => {
for (v) |*v_elem| { for (v) |*v_elem| {
if (ptr_info.child == K) { switch (@typeInfo(Child)) {
cb(ctx, v_elem); .@"struct" => |s| inline for (s.fields) |field| {
} else switch (@typeInfo(ptr_info.child)) { if (field.is_comptime or comptime !Contains(field.type, K)) continue;
.@"struct" => |s| inline for (s.fields) |field_info| { const field_type_info = @typeInfo(field.type);
const field_type_info = @typeInfo(field_info.type);
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field. // If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
if (field_type_info == .pointer) { if (field_type_info == .pointer) {
visit(cb, ctx, @field(v_elem, field_info.name)); visit(cb, ctx, @field(v_elem, field.name));
} else { } else {
visit(cb, ctx, &@field(v_elem, field_info.name)); visit(cb, ctx, &@field(v_elem, field.name));
} }
}, },
.array => |_| for (v) |*elem| visit(cb, ctx, elem), .array => |_| for (v) |*elem| visit(cb, ctx, elem),
@ -305,11 +307,11 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
.@"union" => switch (v_elem.*) { .@"union" => switch (v_elem.*) {
inline else => |*v_field| visit(cb, ctx, v_field), inline else => |*v_field| visit(cb, ctx, v_field),
}, },
else => {}, else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
} }
} }
}, },
else => {}, .many, .c => stdx.debug.compileError("zml.meta.visit({}) doesn't support [*] style pointers, got: {}", .{ Callback, Ptr }),
} }
} }
@ -320,7 +322,7 @@ test visit {
const NestedAttrOptional = struct { nested: ?Attr }; const NestedAttrOptional = struct { nested: ?Attr };
const SimpleStruct = struct { prop: Attr }; const SimpleStruct = struct { prop: Attr };
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr }; const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional }; const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) };
const LocalContext = struct { result: usize }; const LocalContext = struct { result: usize };
@ -374,11 +376,16 @@ test visit {
} }
{ {
var context: LocalContext = .{ .result = 0 }; var context: LocalContext = .{ .result = 0 };
const prop5: std.BoundedArray(Attr, 8) = .{
.buffer = @splat(.{ .data = 4 }),
.len = 2,
};
const container: NestedTypesStruct = .{ const container: NestedTypesStruct = .{
.prop1 = .{ .data = 1 }, .prop1 = .{ .data = 1 },
.prop2 = .{ .other = "hello" }, .prop2 = .{ .other = "hello" },
.prop3 = .{ .nested = .{ .data = 2 } }, .prop3 = .{ .nested = .{ .data = 2 } },
.prop4 = .{ .nested = .{ .data = 3 } }, .prop4 = .{ .nested = .{ .data = 3 } },
.prop5 = prop5, // 4 will be counted twice.
}; };
visit((struct { visit((struct {
@ -387,7 +394,7 @@ test visit {
} }
}).cb, &context, &container); }).cb, &context, &container);
try std.testing.expectEqual(6, context.result); try std.testing.expectEqual(14, context.result);
} }
} }
@ -533,3 +540,36 @@ fn _CollectArg(func: anytype) type {
const params = @typeInfo(@TypeOf(func)).@"fn".params; const params = @typeInfo(@TypeOf(func)).@"fn".params;
return params[params.len - 1].type orelse @compileError("anytype not supported in collect"); return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
} }
pub fn Contains(Haystack: type, T: type) bool {
switch (Haystack) {
T, ?T => return true,
*T, ?*T => return true,
*const T, ?*const T => return true,
[]const T, ?[]const T => return true,
anyopaque => return false,
else => {},
}
return switch (@typeInfo(Haystack)) {
.@"struct" => |info| {
inline for (info.fields) |field| {
if (Contains(field.type, T))
return true;
}
return false;
},
.@"union" => |info| {
inline for (info.fields) |field| {
if (Contains(field.type, T))
return true;
}
return false;
},
.array => |info| Contains(info.child, T),
.pointer => |info| Contains(info.child, T),
.optional => |info| Contains(info.child, T),
.vector => |info| Contains(info.child, T),
else => false,
};
}

View File

@ -29,7 +29,7 @@ test {
pub const MlirFn = struct { pub const MlirFn = struct {
name: []const u8, name: []const u8,
num_args: u32, args_shapes: []Shape,
res_tensors: *const anyopaque, res_tensors: *const anyopaque,
res_types: []mlir.Type, res_types: []mlir.Type,
res_shapes: []Shape, res_shapes: []Shape,
@ -199,7 +199,7 @@ pub const CompilationContext = struct {
const loaded_executable: *pjrt.LoadedExecutable = blk: { const loaded_executable: *pjrt.LoadedExecutable = blk: {
if (pjrt_location) |pjrt_loc| { if (pjrt_location) |pjrt_loc| {
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| { if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
log.info("Loaded pre-compiled module from {s}", .{pjrt_loc}); log.info("Loaded pre-compiled module from {s} (generated from {s}/module.mlir)", .{ pjrt_loc, module_dir.? });
break :blk exe; break :blk exe;
} else |err| { } else |err| {
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc }); if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
@ -233,7 +233,7 @@ pub const CompilationContext = struct {
self._platform, self._platform,
loaded_executable, loaded_executable,
.{ .{
.n_in = f.num_args, .input_shapes = f.args_shapes,
.result_shapes = f.res_shapes, .result_shapes = f.res_shapes,
.n_devices = sharding.num_replicas * sharding.num_partitions, .n_devices = sharding.num_replicas * sharding.num_partitions,
}, },
@ -341,7 +341,7 @@ pub const CompilationContext = struct {
const locations = try arena.alloc(mlir.Location, tensor_count); const locations = try arena.alloc(mlir.Location, tensor_count);
@memset(locations, mlir.Location.unknown(mlir_ctx)); @memset(locations, mlir.Location.unknown(mlir_ctx));
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count);
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
@ -427,7 +427,7 @@ pub const CompilationContext = struct {
return .{ return .{
.mlir_fn = mlir_fn, .mlir_fn = mlir_fn,
.name = opts.name, .name = opts.name,
.num_args = @intCast(tensor_count), .args_shapes = input_shapes.items,
.res_tensors = fn_res, .res_tensors = fn_res,
.res_types = fn_res_types, .res_types = fn_res_types,
.res_shapes = fn_res_shapes, .res_shapes = fn_res_shapes,
@ -512,7 +512,7 @@ pub const CompilationContext = struct {
// Check that the `x` input argument gives its buffer to the result tensor. // Check that the `x` input argument gives its buffer to the result tensor.
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
try std.testing.expectEqual(3, f.num_args); try std.testing.expectEqual(3, f.args_shapes.len);
// We should have two buffers being donated. // We should have two buffers being donated.
const template = "tf.aliasing_output = {d} : i32"; const template = "tf.aliasing_output = {d} : i32";
var buf = template.*; var buf = template.*;
@ -540,9 +540,13 @@ pub const CompilationContext = struct {
} }
} }
pub fn numPartitions(self: CompilationContext) u8 {
return self._platform.sharding().num_partitions;
}
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
const ctx = self.mlirCtx(); const ctx = self.mlirCtx();
const num_partitions = self._platform.sharding().num_partitions; const num_partitions = self.numPartitions();
var sharding_str: std.BoundedArray(u8, 128) = .{}; var sharding_str: std.BoundedArray(u8, 128) = .{};
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
return mlir.Attribute.string(ctx, sharding_str.constSlice()); return mlir.Attribute.string(ctx, sharding_str.constSlice());
@ -645,10 +649,11 @@ pub const CompilationContext = struct {
const loc = self.mlirCtx().location(@src()); const loc = self.mlirCtx().location(@src());
const values = try arena.alloc(mlir.Value, function.num_args); const num_args = function.args_shapes.len;
const values = try arena.alloc(mlir.Value, num_args);
self.extractValues(&args, values); self.extractValues(&args, values);
const donations = try arena.alloc(Tensor._Donation, function.num_args); const donations = try arena.alloc(Tensor._Donation, num_args);
meta.collectBuf(struct { meta.collectBuf(struct {
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation { pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
return ctx.getValueAndDonation(x)[1]; return ctx.getValueAndDonation(x)[1];

View File

@ -176,6 +176,8 @@ pub const RopeOpts = struct {
/// Read a Rope scaling config from HF config.json format. /// Read a Rope scaling config from HF config.json format.
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling { pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
const content = try std.json.Value.jsonParse(allocator, source, options); const content = try std.json.Value.jsonParse(allocator, source, options);
if (content == .null) return .default;
if (content != .object) return error.InvalidEnumTag; if (content != .object) return error.InvalidEnumTag;
const obj = content.object; const obj = content.object;

View File

@ -58,10 +58,10 @@ pub const Shape = struct {
const fv = @field(v, field.name); const fv = @field(v, field.name);
if (comptime stdx.meta.isInteger(field.type)) { if (comptime stdx.meta.isInteger(field.type)) {
dims_.appendAssumeCapacity(@intCast(fv)); dims_.appendAssumeCapacity(@intCast(fv));
} else if (comptime isAutoDim(fv)) { } else if (@TypeOf(fv) == EnumLiteral and comptime isAutoDim(fv)) {
dims_.appendAssumeCapacity(-1); dims_.appendAssumeCapacity(-1);
} else { } else {
stdx.debug.compileError("Field {s} should be an integer or an auto dimension", .{field.name}); stdx.debug.compileError("Field {s} should be an integer or an auto dimension, got {}", .{ field.name, field.type });
} }
if (comptime stdx.meta.isTuple(T)) { if (comptime stdx.meta.isTuple(T)) {
tags_.appendAssumeCapacity(TagUnknown); tags_.appendAssumeCapacity(TagUnknown);
@ -186,7 +186,7 @@ pub const Shape = struct {
EnumLiteral => @tagName(v).ptr, EnumLiteral => @tagName(v).ptr,
std.builtin.Type.StructField => v.name.ptr, std.builtin.Type.StructField => v.name.ptr,
Tag => v, Tag => v,
else => stdx.debug.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), else => stdx.debug.compileError("Shape tag should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
}; };
} }
@ -581,6 +581,41 @@ pub const Shape = struct {
try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims()); try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());
} }
pub fn removeMany(self: Shape, axes_: anytype) Shape {
var to_remove = self.axes(axes_);
if (to_remove.len == 0) return self;
std.mem.sort(u3, to_remove.slice(), {}, std.sort.asc(u3));
var sh: Shape = self;
const rk = self.rank();
var res_ax: u32 = 0;
for (0..rk) |ax| {
if (std.mem.indexOfScalar(u3, to_remove.constSlice(), @intCast(ax))) |_| {
continue;
}
sh._dims.buffer[res_ax] = self._dims.buffer[ax];
sh._tags.buffer[res_ax] = self._tags.buffer[ax];
res_ax += 1;
}
sh._dims.len = rk - to_remove.len;
sh._tags.len = rk - to_remove.len;
return sh;
}
test removeMany {
try std.testing.expectEqualSlices(
i64,
&.{12},
Shape.init(.{ 10, 11, 12 }, .f32).removeMany(.{ 0, 1 }).dims(),
);
try std.testing.expectEqualSlices(
i64,
&.{ 10, 11 },
Shape.init(.{ 10, 11, 12, 13 }, .f32).removeMany(.{ -1, -2 }).dims(),
);
}
pub fn transpose(self: Shape, permutations: anytype) Shape { pub fn transpose(self: Shape, permutations: anytype) Shape {
std.debug.assert(self.rank() == permutations.len); std.debug.assert(self.rank() == permutations.len);
const permutations_ = self.axes(permutations); const permutations_ = self.axes(permutations);
@ -729,7 +764,9 @@ pub const Shape = struct {
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T}); stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
var res = self; var res = self;
inline for (std.meta.fields(T)) |field| { inline for (std.meta.fields(T)) |field| {
res._tags.set(self.axis(field), toTag(@field(renames, field.name))); const new_field = @field(renames, field.name);
stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
res._tags.set(self.axis(field), toTag(new_field));
} }
return res; return res;
} }
@ -749,15 +786,20 @@ pub const Shape = struct {
} }
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) { pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
const base_stride = self.dtype().sizeOf();
const rk = self.rank(); const rk = self.rank();
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) }; var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
if (rk == 0) return strides; if (rk == 0) return strides;
strides.buffer[rk - 1] = base_stride;
for (1..rk) |i| { const V = @Vector(MAX_RANK, i64);
const j = @as(usize, rk) - 1 - i; const rank_mask = std.simd.iota(u8, MAX_RANK) < @as(@Vector(MAX_RANK, u8), @splat(rk));
strides.buffer[j] = self._dims.get(j + 1) * strides.buffer[j + 1]; // For each axis compute the product of all following dimensions
} // and the element size in bytes.
var d: V = @bitCast(self._dims.buffer);
d = @select(i64, rank_mask, d, @as(V, @splat(1)));
d = std.simd.shiftElementsLeft(d, 1, self.dtype().sizeOf());
d = std.simd.prefixScan(.Mul, -1, d);
strides.buffer = @bitCast(d);
return strides; return strides;
} }

View File

@ -176,6 +176,7 @@ pub const Tensor = struct {
var res = self; var res = self;
res._shape = self._shape.withSharding(axes_); res._shape = self._shape.withSharding(axes_);
if (ctx.numPartitions() <= 1) return self;
const op = dialect.stablehlo.custom_call( const op = dialect.stablehlo.custom_call(
mlir_ctx, mlir_ctx,
&.{self.value()}, &.{self.value()},
@ -1279,9 +1280,9 @@ pub const Tensor = struct {
/// see: https://paperswithcode.com/method/gelu /// see: https://paperswithcode.com/method/gelu
pub fn gelu(x: Tensor) Tensor { pub fn gelu(x: Tensor) Tensor {
const scaled_x_cube = x.mul(x).mul(x).scale(0.044715); const scaled_x_cube = x.mul(x).mul(x).scale(0.044715);
const one = Tensor.constant(x._shape, x.dtype().one()); const beta = std.math.sqrt(2.0 / std.math.pi);
const one_plus_tanh = Tensor.add(x, scaled_x_cube).scale(std.math.sqrt(2.0 / std.math.pi)).tanh().add(one); const tanh_ = x.add(scaled_x_cube).scale(beta).tanh();
return one_plus_tanh.mul(x).scale(0.5); return tanh_.addConstant(1).mul(x).scale(0.5);
} }
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor. /// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
@ -1526,8 +1527,34 @@ pub const Tensor = struct {
pub const Slice = struct { pub const Slice = struct {
start: i64 = 0, start: i64 = 0,
end: ?i64 = null, end: i64 = to_the_end,
step: i64 = 1, step: i32 = 1,
singleton: bool = false,
pub fn single(offset: i64) Slice {
return .{ .start = offset, .end = offset + 1, .singleton = true };
}
const to_the_end = std.math.maxInt(i64);
pub fn format(
self: Slice,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
if (self.singleton) {
try writer.print("[{}]", .{self.start});
} else if (self.end == to_the_end and self.step == 1) {
try writer.print("[{}..]", .{self.start});
} else if (self.step == 1) {
try writer.print("[{}..{}]", .{ self.start, self.end });
} else {
try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step });
}
}
}; };
/// Slices the input Tensor over the given axis using the given parameters. /// Slices the input Tensor over the given axis using the given parameters.
@ -1549,13 +1576,13 @@ pub const Tensor = struct {
const args: Slice = .{ const args: Slice = .{
.start = self.wrapIndex(a, s.start), .start = self.wrapIndex(a, s.start),
.end = if (s.end) |end| self.wrapIndex(a, end) else self.dim(a), .end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end),
.step = s.step, .step = s.step,
}; };
start_indices[a] = args.start; start_indices[a] = args.start;
limit_indices[a] = args.end.?; limit_indices[a] = args.end;
strides[a] = args.step; strides[a] = args.step;
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable); res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable);
} }
const mlir_ctx = self.getContext().mlirCtx(); const mlir_ctx = self.getContext().mlirCtx();
@ -1571,7 +1598,12 @@ pub const Tensor = struct {
loc, loc,
); );
return _result(res_shape, slice_op.result(0)); var res = _result(res_shape, slice_op.result(0));
var to_remove: Shape.AxesArray = .{};
for (slices, 0..) |s, a| {
if (s.singleton) to_remove.appendAssumeCapacity(@intCast(a));
}
return res.reshape(res_shape.removeMany(to_remove.constSlice()));
} }
test slice { test slice {
@ -1606,8 +1638,17 @@ pub const Tensor = struct {
} }
pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor { pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor {
// TODO: this use case could be handled directly by slice if we added a .single field return self.slice1d(axis_, .single(i));
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_); }
pub fn choose(self: Tensor, offsets: anytype) Tensor {
const off, const tags = Shape.parseDimensions(offsets);
var slices = [_]Slice{.{}} ** MAX_RANK;
for (off.constSlice(), tags.constSlice()) |o, t| {
const ax = self.axis(t);
slices[ax] = .single(o);
}
return self.slice(slices[0..self.rank()]);
} }
/// Concatenates the input Tensors along the given axis. /// Concatenates the input Tensors along the given axis.
@ -1866,7 +1907,12 @@ pub const Tensor = struct {
/// Returns a 0-rank Tensor with the given value. /// Returns a 0-rank Tensor with the given value.
pub fn scalar(val: anytype, dt: DataType) Tensor { pub fn scalar(val: anytype, dt: DataType) Tensor {
return Tensor.constant(.{}, Data.init(dt, val)); const data = Data.init(dt, val);
switch (dt.class()) {
.float => stdx.debug.assert(!std.math.isNan(val), "scalar(NaN) is probably due to compiling a model with an uninitialized field", .{}),
else => {},
}
return Tensor.constant(.{}, data);
} }
test scalar { test scalar {
@ -1913,7 +1959,7 @@ pub const Tensor = struct {
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape()); const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
const loc = ctx.location(@src()); const loc = ctx.location(@src());
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc); const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc);
return _result(val.shape(), constant_op.result(0)); return _result(val.shape(), constant_op.result(0));
} }
@ -3786,6 +3832,7 @@ pub const Tensor = struct {
/// Only for debug purpose, it inserts device to host synchronization /// Only for debug purpose, it inserts device to host synchronization
/// so it will slow down the program execution. /// so it will slow down the program execution.
pub fn print(input: Tensor) Tensor { pub fn print(input: Tensor) Tensor {
// TODO: find a way of doing print that doesn't involve a H2D copy.
return ops.addHostCallback( return ops.addHostCallback(
&printCallback, &printCallback,
null, null,
@ -3797,8 +3844,10 @@ pub const Tensor = struct {
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
const host_buffer = inputs[0]; const host_buffer = inputs[0];
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr); // This is true because of the operand aliases.
// Since the result is already pointing to the input we don't need to modify the buffer.
std.debug.assert(host_buffer._data == outputs[0]._data);
} }
}; };
@ -3918,6 +3967,10 @@ test "Tensor.maxPool2d" {
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. /// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized(comptime T: type) type { pub fn Bufferized(comptime T: type) type {
// TODO: we should strip out the non-buffer fields.
// Currently it's confusing cause the Bufferized struct contains field that are never read.
// Also it will simplify the layout of the Bufferized struct.
// accelerating the calls to execute.
return meta.MapType(Tensor, Buffer).map(T); return meta.MapType(Tensor, Buffer).map(T);
} }

View File

@ -1,10 +1,11 @@
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const builtin = @import("builtin");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml = @import("zml.zig");
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const shapesOf = @import("tensor.zig").shapesOf; const shapesOf = @import("tensor.zig").shapesOf;
const zml = @import("zml.zig");
const log = std.log.scoped(.@"zml/testing"); const log = std.log.scoped(.@"zml/testing");
@ -35,7 +36,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool
/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the /// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the
/// host for comparison ! /// host for comparison !
pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.page_allocator; const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator;
var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer) var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer)
.{ try left_.toHostAlloc(allocator), true } .{ try left_.toHostAlloc(allocator), true }
else else