Add buffer and hostbuffer utilities with precise f32→bf16 conversion, type inference for loadBuffers, store expected input shapes, enhance meta.visit and JSON TaggedUnion support, and improve logging.
This commit is contained in:
parent
1540c6e85e
commit
3849eb10b7
@ -350,7 +350,7 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const BufferFromHostBufferArgs = struct {
|
pub const BufferFromHostBufferArgs = struct {
|
||||||
data: []const u8,
|
data: [*]const u8,
|
||||||
buffer_type: BufferType,
|
buffer_type: BufferType,
|
||||||
dims: []const i64,
|
dims: []const i64,
|
||||||
byte_strides: ?[]const i64,
|
byte_strides: ?[]const i64,
|
||||||
@ -362,7 +362,7 @@ pub const Client = opaque {
|
|||||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||||
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
|
const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
.data = @ptrCast(@constCast(args.data.ptr)),
|
.data = @constCast(args.data),
|
||||||
.type = @intFromEnum(args.buffer_type),
|
.type = @intFromEnum(args.buffer_type),
|
||||||
.dims = @ptrCast(@constCast(args.dims.ptr)),
|
.dims = @ptrCast(@constCast(args.dims.ptr)),
|
||||||
.num_dims = args.dims.len,
|
.num_dims = args.dims.len,
|
||||||
|
|||||||
@ -1,25 +1,46 @@
|
|||||||
pub const std = @import("std");
|
pub const std = @import("std");
|
||||||
|
const ParseFromValueError = std.json.ParseFromValueError;
|
||||||
|
|
||||||
|
/// Handle json fields that can have different Zig types depending on the message.
|
||||||
|
/// Each union field should have a unique Zig type.
|
||||||
|
///
|
||||||
|
/// Example json:
|
||||||
|
///
|
||||||
|
/// ```json
|
||||||
|
/// [
|
||||||
|
/// { "question": "How old are you ?", "answer": 5 },
|
||||||
|
/// { "question": "Count to three.", "answer": [1, 2, 3] },
|
||||||
|
/// ]
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Corresponding Zig code:
|
||||||
|
///
|
||||||
|
/// ```zig
|
||||||
|
/// const Answer = union {
|
||||||
|
/// number: i32,
|
||||||
|
/// numbers: []const i32,
|
||||||
|
/// };
|
||||||
|
///
|
||||||
|
/// const Message = struct {
|
||||||
|
/// question: []const u8,
|
||||||
|
/// answer: stdx.json.Union(Answer),
|
||||||
|
/// };
|
||||||
|
/// ```
|
||||||
pub fn Union(comptime T: type) type {
|
pub fn Union(comptime T: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
const Self = @This();
|
const Self = @This();
|
||||||
|
|
||||||
value: T,
|
value: T,
|
||||||
|
|
||||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self {
|
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
|
||||||
return jsonParseFromValue(
|
return jsonParseFromValue(
|
||||||
allocator,
|
allocator,
|
||||||
try std.json.innerParse(
|
try std.json.innerParse(std.json.Value, allocator, source, options),
|
||||||
std.json.Value,
|
|
||||||
allocator,
|
|
||||||
source,
|
|
||||||
options,
|
|
||||||
),
|
|
||||||
options,
|
options,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self {
|
pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
switch (field.type) {
|
switch (field.type) {
|
||||||
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
|
bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) },
|
||||||
@ -39,7 +60,67 @@ pub fn Union(comptime T: type) type {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return error.UnexpectedToken;
|
return error.UnknownField;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handle json objects whose payload type is selected by another field of the same object.
/// This is translated to a Zig tagged union.
///
/// The string stored under `tag_name` must be the name of one of the fields of `T`.
/// The payload is then parsed from the entry of the *same* object whose key is that field name.
///
/// Example json:
///
/// ```json
/// [
///   { "type": "faq", "faq": { "question": "How old are you ?", "answer": 5 } },
///   { "type": "address", "address": { "city": "NYC", "zipcode": "49130" } }
/// ]
/// ```
///
/// Corresponding Zig struct:
///
/// ```zig
/// const Entry = union(enum) {
///     faq: struct { question: []const u8, answer: u32 },
///     address: struct { city: []const u8, zipcode: []const u8 },
/// };
///
/// const Message = []const stdx.json.TaggedUnion(Entry, "type");
/// ```
///
/// NOTE(review): an earlier version of this comment showed the payload fields inlined next to
/// the tag (`{ "type": "faq", "question": ..., "answer": 5 }`). The implementation looks the
/// payload up under the key named by the tag, so that layout fails with `error.MissingField`.
/// Confirm which layout is intended.
///
/// Errors returned by `jsonParseFromValue`:
/// - `error.UnexpectedToken`: the value is not an object, or the tag is not a string.
/// - `error.MissingField`: the tag key or the payload key is absent.
/// - `error.InvalidEnumTag`: the tag doesn't name any field of `T`.
/// - any error produced while parsing the payload itself.
pub fn TaggedUnion(comptime T: type, comptime tag_name: [:0]const u8) type {
    return struct {
        const Self = @This();

        value: T,

        pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self {
            // Parse into a dynamic value first: the tag can appear after the payload in the json text.
            const tree = try std.json.innerParse(std.json.Value, allocator, source, options);
            return jsonParseFromValue(allocator, tree, options);
        }

        pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self {
            errdefer std.log.warn("failed to parse: {} as {s}", .{ source, @typeName(T) });

            const object = switch (source) {
                .object => |obj| obj,
                else => return error.UnexpectedToken,
            };

            // A tag of the wrong json type is reported like any other value of the wrong type.
            const tag = switch (object.get(tag_name) orelse return error.MissingField) {
                .string => |name| name,
                else => return error.UnexpectedToken,
            };

            inline for (std.meta.fields(T)) |field| {
                if (std.mem.eql(u8, field.name, tag)) {
                    const inner_source = object.get(field.name) orelse return error.MissingField;
                    const inner: field.type = std.json.innerParseFromValue(field.type, allocator, inner_source, options) catch |err| {
                        std.log.warn("failed to interpret {s} as a {s}: {}", .{ tag, @typeName(field.type), err });
                        return err;
                    };
                    return .{ .value = @unionInit(T, field.name, inner) };
                }
            }
            return error.InvalidEnumTag;
        }
    };
}
|
||||||
|
|||||||
25
zml/aio.zig
25
zml/aio.zig
@ -1,11 +1,9 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const builtin = @import("builtin");
|
|
||||||
const c = @import("c");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
const asynk = @import("async");
|
||||||
const posix = @import("posix.zig");
|
const c = @import("c");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
pub const gguf = @import("aio/gguf.zig");
|
pub const gguf = @import("aio/gguf.zig");
|
||||||
pub const nemo = @import("aio/nemo.zig");
|
pub const nemo = @import("aio/nemo.zig");
|
||||||
@ -13,10 +11,11 @@ pub const safetensors = @import("aio/safetensors.zig");
|
|||||||
pub const tinyllama = @import("aio/tinyllama.zig");
|
pub const tinyllama = @import("aio/tinyllama.zig");
|
||||||
pub const torch = @import("aio/torch.zig");
|
pub const torch = @import("aio/torch.zig");
|
||||||
pub const yaml = @import("aio/yaml.zig");
|
pub const yaml = @import("aio/yaml.zig");
|
||||||
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const posix = @import("posix.zig");
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
pub const log = std.log.scoped(.@"zml/aio");
|
pub const log = std.log.scoped(.@"zml/aio");
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
std.testing.refAllDecls(gguf);
|
std.testing.refAllDecls(gguf);
|
||||||
@ -26,6 +25,8 @@ test {
|
|||||||
std.testing.refAllDecls(yaml);
|
std.testing.refAllDecls(yaml);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO error set for weight loading
|
||||||
|
|
||||||
/// Detects the format of the model file (based on filename) and opens it.
|
/// Detects the format of the model file (based on filename) and opens it.
|
||||||
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
|
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
|
||||||
return if (std.mem.endsWith(u8, model_path, ".safetensors"))
|
return if (std.mem.endsWith(u8, model_path, ".safetensors"))
|
||||||
@ -422,7 +423,7 @@ fn _populateStruct(
|
|||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.float => {
|
.float => {
|
||||||
obj.* = undefined;
|
obj.* = std.math.nan(@TypeOf(obj.*));
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
.void => true,
|
.void => true,
|
||||||
@ -450,7 +451,7 @@ test populateModel {
|
|||||||
|
|
||||||
// Create a fake HostBuffer, we use the given integer to identify the created buffer.
|
// Create a fake HostBuffer, we use the given integer to identify the created buffer.
|
||||||
fn _newHostBuffer(n: u32) zml.HostBuffer {
|
fn _newHostBuffer(n: u32) zml.HostBuffer {
|
||||||
return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined };
|
return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -500,7 +501,7 @@ test populateModel {
|
|||||||
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
|
/// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function.
|
||||||
pub fn loadBuffers(
|
pub fn loadBuffers(
|
||||||
comptime Model: type,
|
comptime Model: type,
|
||||||
init_args: anytype,
|
init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void,
|
||||||
buffer_store: BufferStore,
|
buffer_store: BufferStore,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
@ -513,8 +514,6 @@ pub fn loadBuffers(
|
|||||||
// If the Model has a "init" function, call it with the given parameters.
|
// If the Model has a "init" function, call it with the given parameters.
|
||||||
if (@hasDecl(Model, "init")) {
|
if (@hasDecl(Model, "init")) {
|
||||||
@call(.auto, Model.init, .{&model} ++ init_args);
|
@call(.auto, Model.init, .{&model} ++ init_args);
|
||||||
} else {
|
|
||||||
stdx.debug.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||||
|
|||||||
@ -44,23 +44,6 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Shard = struct {
|
|
||||||
api: *const pjrt.Api,
|
|
||||||
buffer: *pjrt.Buffer,
|
|
||||||
ready_event: ?*pjrt.Event = null,
|
|
||||||
ready: bool = false,
|
|
||||||
|
|
||||||
pub fn awaitt(self: *Shard) !void {
|
|
||||||
if (self.ready) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| {
|
|
||||||
try ev.awaitt(self.api);
|
|
||||||
}
|
|
||||||
self.ready = true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
_shape: Shape,
|
_shape: Shape,
|
||||||
_api: *const pjrt.Api,
|
_api: *const pjrt.Api,
|
||||||
_shards: Shards,
|
_shards: Shards,
|
||||||
@ -88,7 +71,7 @@ pub const Buffer = struct {
|
|||||||
} else 0;
|
} else 0;
|
||||||
|
|
||||||
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype());
|
||||||
const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice();
|
const byte_strides = host_buffer.strides();
|
||||||
|
|
||||||
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{};
|
||||||
const devices = platform.getDevices();
|
const devices = platform.getDevices();
|
||||||
@ -103,7 +86,7 @@ pub const Buffer = struct {
|
|||||||
platform.pjrt_client,
|
platform.pjrt_client,
|
||||||
platform.pjrt_api,
|
platform.pjrt_api,
|
||||||
pjrt.Client.BufferFromHostBufferArgs{
|
pjrt.Client.BufferFromHostBufferArgs{
|
||||||
.data = buf.data,
|
.data = buf._data,
|
||||||
.buffer_type = buffer_type,
|
.buffer_type = buffer_type,
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
.byte_strides = byte_strides,
|
.byte_strides = byte_strides,
|
||||||
@ -155,6 +138,14 @@ pub const Buffer = struct {
|
|||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a `HostBuffer` built directly on top of the memory pointer PJRT reports for shard 0.
/// No bytes are copied: the result covers `self._shape.byteSize()` bytes starting at that pointer.
///
/// NOTE(review): presumably only valid for buffers living in pinned host memory; the assert
/// checking this is currently disabled (see TODO below). Only the first shard is looked at.
///
/// Panics if PJRT can't provide the memory pointer.
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
    // TODO restore assert
    // const memory = self.getMemory().kind(self._api);
    // stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });

    // This call goes through the PJRT plugin and can fail: report the error instead of
    // declaring it unreachable, which would be undefined behavior in release builds.
    const raw = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch |err| {
        std.debug.panic("asPinnedHostBuffer: failed to get the memory pointer of the buffer: {}", .{err});
    };
    const ptr: [*]u8 = @ptrCast(raw);
    return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
}
|
||||||
|
|
||||||
/// Creates a Buffer with a single element.
|
/// Creates a Buffer with a single element.
|
||||||
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||||
const x = dtype_.constant(val);
|
const x = dtype_.constant(val);
|
||||||
@ -182,8 +173,8 @@ pub const Buffer = struct {
|
|||||||
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
|
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
|
||||||
const host_buffer: HostBuffer = .{
|
const host_buffer: HostBuffer = .{
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
._strides = [1]i64{0} ** Shape.MAX_RANK,
|
._strides = @splat(0),
|
||||||
.data = x.constSlice(),
|
._data = x.constSlice().ptr,
|
||||||
};
|
};
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer);
|
||||||
}
|
}
|
||||||
@ -207,7 +198,7 @@ pub const Buffer = struct {
|
|||||||
},
|
},
|
||||||
else => unreachable,
|
else => unreachable,
|
||||||
}
|
}
|
||||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes };
|
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -228,12 +219,12 @@ pub const Buffer = struct {
|
|||||||
/// could lead to crashes and operations on the buffer will be slower.
|
/// could lead to crashes and operations on the buffer will be slower.
|
||||||
/// Tested on Cuda 12.4.
|
/// Tested on Cuda 12.4.
|
||||||
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
|
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
|
||||||
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
|
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(buf._data));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Buffer from a pointer into device memory.
|
/// Creates a Buffer from a pointer into device memory.
|
||||||
/// This allows to interface with other libraries producing buffers.
|
/// This allows to interface with other libraries producing buffers.
|
||||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer {
|
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer {
|
||||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
for (0..Shape.MAX_RANK) |i| {
|
||||||
@ -255,7 +246,7 @@ pub const Buffer = struct {
|
|||||||
.tile_dims_sizes = &.{},
|
.tile_dims_sizes = &.{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
.stream = @bitCast(@as(usize, @intFromPtr(stream))),
|
.stream = stream,
|
||||||
}) catch @panic("failed to createViewOfDeviceBuffer");
|
}) catch @panic("failed to createViewOfDeviceBuffer");
|
||||||
|
|
||||||
var shards: Shards = .{};
|
var shards: Shards = .{};
|
||||||
@ -296,7 +287,7 @@ pub const Buffer = struct {
|
|||||||
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer {
|
||||||
const output = try HostBuffer.empty(allocator, self.shape());
|
const output = try HostBuffer.empty(allocator, self.shape());
|
||||||
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{});
|
||||||
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data));
|
const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.bytes()));
|
||||||
if (maybe_event) |event| {
|
if (maybe_event) |event| {
|
||||||
try event.await_(self._api);
|
try event.await_(self._api);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const floats = @import("floats.zig");
|
const floats = @import("floats.zig");
|
||||||
|
|
||||||
const C64 = std.math.Complex(f32);
|
const C64 = std.math.Complex(f32);
|
||||||
@ -111,9 +112,7 @@ pub const DataType = enum(u8) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn toZigType(comptime dtype: DataType) type {
|
pub fn toZigType(comptime dtype: DataType) type {
|
||||||
return switch (dtype) {
|
return @FieldType(Data, @tagName(dtype));
|
||||||
inline else => |tag| std.meta.TagPayload(Data, tag),
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isSignedInt(dtype: DataType) bool {
|
pub fn isSignedInt(dtype: DataType) bool {
|
||||||
@ -125,19 +124,19 @@ pub const DataType = enum(u8) {
|
|||||||
|
|
||||||
pub fn sizeOf(self: DataType) u16 {
|
pub fn sizeOf(self: DataType) u16 {
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
inline else => |tag| @sizeOf(std.meta.TagPayload(Data, tag)),
|
inline else => |tag| @sizeOf(tag.toZigType()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn bitSizeOf(self: DataType) u16 {
|
pub fn bitSizeOf(self: DataType) u16 {
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
inline else => |tag| @bitSizeOf(std.meta.TagPayload(Data, tag)),
|
inline else => |tag| @bitSizeOf(tag.toZigType()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn alignOf(self: DataType) u29 {
|
pub fn alignOf(self: DataType) u29 {
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
inline else => |tag| @alignOf(std.meta.TagPayload(Data, tag)),
|
inline else => |tag| @alignOf(tag.toZigType()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
50
zml/exe.zig
50
zml/exe.zig
@ -1,13 +1,13 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const aio = @import("aio.zig");
|
const aio = @import("aio.zig");
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const Bufferized = @import("tensor.zig").Bufferized;
|
const Bufferized = @import("tensor.zig").Bufferized;
|
||||||
const CompilationContext = @import("module.zig").CompilationContext;
|
const CompilationContext = @import("module.zig").CompilationContext;
|
||||||
|
const meta = @import("meta.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||||
@ -147,6 +147,7 @@ pub const BaseExe = struct {
|
|||||||
/// Total number of buffers needed by this executable.
|
/// Total number of buffers needed by this executable.
|
||||||
input_buffer_count: u32,
|
input_buffer_count: u32,
|
||||||
|
|
||||||
|
input_shapes: []Shape,
|
||||||
result_shapes: []Shape,
|
result_shapes: []Shape,
|
||||||
|
|
||||||
/// Num devices used (>1 for sharded executable)
|
/// Num devices used (>1 for sharded executable)
|
||||||
@ -155,34 +156,44 @@ pub const BaseExe = struct {
|
|||||||
/// Allocator backing memory
|
/// Allocator backing memory
|
||||||
_arena: std.heap.ArenaAllocator,
|
_arena: std.heap.ArenaAllocator,
|
||||||
|
|
||||||
pub fn init(parent_allocator: std.mem.Allocator, platform: Platform, exe: *pjrt.LoadedExecutable, args: struct { n_in: u32, result_shapes: []const Shape, n_devices: u8 }) !BaseExe {
|
pub fn init(
|
||||||
|
parent_allocator: std.mem.Allocator,
|
||||||
|
platform: Platform,
|
||||||
|
exe: *pjrt.LoadedExecutable,
|
||||||
|
args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 },
|
||||||
|
) !BaseExe {
|
||||||
var arena = std.heap.ArenaAllocator.init(parent_allocator);
|
var arena = std.heap.ArenaAllocator.init(parent_allocator);
|
||||||
errdefer arena.deinit();
|
errdefer arena.deinit();
|
||||||
const allocator = arena.allocator();
|
const allocator = arena.allocator();
|
||||||
|
const n_in = args.input_shapes.len;
|
||||||
const n_out = args.result_shapes.len;
|
const n_out = args.result_shapes.len;
|
||||||
const n_devices = args.n_devices;
|
const n_devices = args.n_devices;
|
||||||
// Allocate once for all the *pjrt.Buffer we need to store ...
|
// Allocate once for all the *pjrt.Buffer we need to store ...
|
||||||
const all_buffers = try allocator.alloc(*pjrt.Buffer, (args.n_in + n_out) * n_devices);
|
const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices);
|
||||||
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ args.n_in * n_devices, n_out * n_devices });
|
const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices });
|
||||||
|
|
||||||
// ... and once for all the [*]*pjrt.Buffer.
|
// ... and once for all the [*]*pjrt.Buffer.
|
||||||
const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
|
const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices);
|
||||||
const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
|
const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices });
|
||||||
|
|
||||||
for (0..n_devices) |i| {
|
for (0..n_devices) |i| {
|
||||||
input_per_device[i] = all_input_buffers[i * args.n_in ..].ptr;
|
input_per_device[i] = all_input_buffers[i * n_in ..].ptr;
|
||||||
output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
|
output_per_device[i] = all_output_buffers[i * n_out ..].ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
|
||||||
|
@memcpy(all_shapes[0..n_in], args.input_shapes);
|
||||||
|
@memcpy(all_shapes[n_in..], args.result_shapes);
|
||||||
return .{
|
return .{
|
||||||
.platform = platform,
|
.platform = platform,
|
||||||
.exe = exe,
|
.exe = exe,
|
||||||
.ready_buffer_count = 0,
|
.ready_buffer_count = 0,
|
||||||
.input_buffer_count = args.n_in,
|
.input_buffer_count = @intCast(n_in),
|
||||||
.num_devices = args.n_devices,
|
.num_devices = args.n_devices,
|
||||||
.input_per_device = input_per_device,
|
.input_per_device = input_per_device,
|
||||||
.output_per_device = output_per_device,
|
.output_per_device = output_per_device,
|
||||||
.result_shapes = try allocator.dupe(Shape, args.result_shapes),
|
.input_shapes = all_shapes[0..n_in],
|
||||||
|
.result_shapes = all_shapes[n_in..],
|
||||||
._arena = arena,
|
._arena = arena,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -209,7 +220,9 @@ pub const BaseExe = struct {
|
|||||||
// even if it has been marked as "can be donated" during compilation.
|
// even if it has been marked as "can be donated" during compilation.
|
||||||
// TODO: expose it ?
|
// TODO: expose it ?
|
||||||
.non_donatable_input_indices = &.{},
|
.non_donatable_input_indices = &.{},
|
||||||
}) catch unreachable;
|
}) catch |err| {
|
||||||
|
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
||||||
|
};
|
||||||
|
|
||||||
for (events[0..sharding.num_partitions]) |e| {
|
for (events[0..sharding.num_partitions]) |e| {
|
||||||
if (e) |ev| {
|
if (e) |ev| {
|
||||||
@ -232,7 +245,7 @@ pub const BaseExe = struct {
|
|||||||
// }
|
// }
|
||||||
|
|
||||||
pub fn prepare(self: *BaseExe, x: anytype) void {
|
pub fn prepare(self: *BaseExe, x: anytype) void {
|
||||||
const n = fillBuffers(&x, self.input_per_device, self.ready_buffer_count);
|
const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count);
|
||||||
self.ready_buffer_count += n;
|
self.ready_buffer_count += n;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -244,6 +257,14 @@ pub const BaseExe = struct {
|
|||||||
|
|
||||||
return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
|
return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Creates another `BaseExe` for the same loaded executable by calling `init` again
/// with this executable's platform, `exe` pointer, shapes and device count.
///
/// The returned value is backed by a new arena created from `parent_allocator`,
/// and starts with no prepared input buffers (`init` resets `ready_buffer_count` to 0).
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
    const duplicate = try BaseExe.init(
        parent_allocator,
        self.platform,
        self.exe,
        .{
            .input_shapes = self.input_shapes,
            .result_shapes = self.result_shapes,
            .n_devices = self.num_devices,
        },
    );
    return duplicate;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents a ZML function, compiled into a PJRT executable.
|
/// Represents a ZML function, compiled into a PJRT executable.
|
||||||
@ -280,7 +301,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
|
pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) {
|
||||||
const total_ready = fillBuffers(&args, self.inner.input_per_device, self.inner.ready_buffer_count);
|
const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count);
|
||||||
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
||||||
self.inner._unsafeCall();
|
self.inner._unsafeCall();
|
||||||
var result: Bufferized(ReturnT) = undefined;
|
var result: Bufferized(ReturnT) = undefined;
|
||||||
@ -302,20 +323,23 @@ fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
|
/// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor.
|
||||||
fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
|
fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 {
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
index: u32,
|
index: u32,
|
||||||
buffers: []const [*]*pjrt.Buffer,
|
buffers: []const [*]*pjrt.Buffer,
|
||||||
|
shapes: []const Shape,
|
||||||
};
|
};
|
||||||
var context: LocalContext = .{
|
var context: LocalContext = .{
|
||||||
.index = start,
|
.index = start,
|
||||||
.buffers = buffers,
|
.buffers = buffers,
|
||||||
|
.shapes = shapes,
|
||||||
};
|
};
|
||||||
meta.visit((struct {
|
meta.visit((struct {
|
||||||
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
fn cb(ctx: *LocalContext, buffer: *const Buffer) void {
|
||||||
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
// stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index });
|
||||||
const model_sharding = ctx.buffers.len;
|
const model_sharding = ctx.buffers.len;
|
||||||
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len });
|
||||||
|
stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() });
|
||||||
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
for (buffer._shards.constSlice(), 0..) |shard, d| {
|
||||||
ctx.buffers[d][ctx.index] = shard;
|
ctx.buffers[d][ctx.index] = shard;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -305,11 +305,23 @@ pub const BFloat16 = packed struct(u16) {
|
|||||||
pub fn isInf(self: BFloat16) bool {
|
pub fn isInf(self: BFloat16) bool {
|
||||||
return allBitsOne(self.exponent) and self.mantissa == 0;
|
return allBitsOne(self.exponent) and self.mantissa == 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts to `f32`. This is exact: every BFloat16 value is representable as an f32.
pub fn toF32(self: BFloat16) f32 {
    // A BFloat16 holds the 16 most significant bits of an f32: put them back in place and
    // pad the low bits with zeros. Shifting (rather than bit-casting a `[2]u16`) gives the
    // same result whatever the byte order of the target.
    const high_bits: u16 = @bitCast(self);
    return @bitCast(@as(u32, high_bits) << 16);
}
|
||||||
|
|
||||||
|
/// Converts an `f32` to the closest `BFloat16`.
///
/// - Finite values are rounded to nearest, ties to even. Values too large for BFloat16
///   round to infinity.
/// - Infinities and signed zeros are preserved.
/// - NaN stays NaN: sign and upper payload bits are kept, and the quiet bit is set.
pub fn fromF32(float32: f32) BFloat16 {
    const bits: u32 = @bitCast(float32);

    // NaN must not go through rounding: adding the rounding bias could carry into the
    // exponent/sign bits (turning NaN into zero) or overflow the u32.
    // Forcing the quiet bit also guarantees the kept mantissa is non-zero,
    // otherwise dropping the low bits could turn a NaN into an infinity.
    const is_nan = (bits & 0x7FFF_FFFF) > 0x7F80_0000;
    if (is_nan) {
        const nan_bits: u16 = @truncate((bits >> 16) | 0x0040);
        return @bitCast(nan_bits);
    }

    // Round to nearest even: bias by just under half a unit of the kept precision,
    // plus one when the lowest kept bit is odd so that exact ties round towards even.
    // Can't overflow: the largest non-NaN pattern is 0xFF80_0000 (-inf).
    const lowest_kept_bit: u32 = (bits >> 16) & 1;
    const rounded: u32 = bits + 0x7FFF + lowest_kept_bit;
    const high_bits: u16 = @truncate(rounded >> 16);
    return @bitCast(high_bits);
}
|
||||||
|
|
||||||
const Helpers = FloatHelpers(@This());
|
const Helpers = FloatHelpers(@This());
|
||||||
pub const zero = Helpers.zero;
|
pub const zero = Helpers.zero;
|
||||||
pub const neg = Helpers.neg;
|
pub const neg = Helpers.neg;
|
||||||
pub const fromF32 = Helpers.fromF32;
|
|
||||||
pub const toF32 = Helpers.toF32;
|
|
||||||
pub const format = Helpers.format;
|
pub const format = Helpers.format;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -317,7 +329,7 @@ test BFloat16 {
|
|||||||
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
|
// From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
|
||||||
try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 });
|
try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 });
|
||||||
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
|
try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 });
|
||||||
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 });
|
try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 66 });
|
||||||
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
|
try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 });
|
||||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
|
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff });
|
||||||
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));
|
try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32)));
|
||||||
|
|||||||
@ -18,8 +18,8 @@ test {
|
|||||||
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
|
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
|
||||||
pub const HostBuffer = struct {
|
pub const HostBuffer = struct {
|
||||||
_shape: Shape,
|
_shape: Shape,
|
||||||
_strides: ?[Shape.MAX_RANK]i64 = null,
|
_strides: [Shape.MAX_RANK]i64,
|
||||||
data: []const u8,
|
_data: [*]const u8,
|
||||||
_memory: union(enum) {
|
_memory: union(enum) {
|
||||||
managed: std.mem.Alignment,
|
managed: std.mem.Alignment,
|
||||||
unmanaged,
|
unmanaged,
|
||||||
@ -28,10 +28,11 @@ pub const HostBuffer = struct {
|
|||||||
/// Allocates a HostBuffer with the given shape.
|
/// Allocates a HostBuffer with the given shape.
|
||||||
/// The memory is left undefined.
|
/// The memory is left undefined.
|
||||||
/// The caller owns the memory, and need to call `deinit()`.
|
/// The caller owns the memory, and need to call `deinit()`.
|
||||||
pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer {
|
pub fn empty(allocator: std.mem.Allocator, sh: Shape) error{OutOfMemory}!HostBuffer {
|
||||||
return .{
|
return .{
|
||||||
._shape = sh,
|
._shape = sh,
|
||||||
.data = try allocator.alignedAlloc(u8, 64, sh.byteSize()),
|
._strides = sh.computeStrides().buffer,
|
||||||
|
._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr,
|
||||||
._memory = .{ .managed = .@"64" },
|
._memory = .{ .managed = .@"64" },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -43,7 +44,8 @@ pub const HostBuffer = struct {
|
|||||||
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
|
stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len });
|
||||||
return .{
|
return .{
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
.data = data_,
|
._strides = shape_.computeStrides().buffer,
|
||||||
|
._data = data_.ptr,
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -53,7 +55,7 @@ pub const HostBuffer = struct {
|
|||||||
// This means we don't own the data.
|
// This means we don't own the data.
|
||||||
if (self._memory == .unmanaged) return;
|
if (self._memory == .unmanaged) return;
|
||||||
const log2_align = self._memory.managed;
|
const log2_align = self._memory.managed;
|
||||||
allocator.rawFree(@constCast(self.data), log2_align, @returnAddress());
|
allocator.rawFree(self.mutBytes(), log2_align, @returnAddress());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Wraps an exisiting slice into a HostBuffer.
|
/// Wraps an exisiting slice into a HostBuffer.
|
||||||
@ -62,10 +64,12 @@ pub const HostBuffer = struct {
|
|||||||
/// that will still need to be deallocated.
|
/// that will still need to be deallocated.
|
||||||
pub fn fromSlice(sh: anytype, s: anytype) HostBuffer {
|
pub fn fromSlice(sh: anytype, s: anytype) HostBuffer {
|
||||||
const shape_ = Shape.init(sh, DataType.fromSliceElementType(s));
|
const shape_ = Shape.init(sh, DataType.fromSliceElementType(s));
|
||||||
std.debug.assert(shape_.count() == s.len);
|
const raw_bytes = std.mem.sliceAsBytes(s);
|
||||||
|
std.debug.assert(shape_.byteSize() == raw_bytes.len);
|
||||||
return .{
|
return .{
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
.data = @alignCast(std.mem.sliceAsBytes(s)),
|
._strides = shape_.computeStrides().buffer,
|
||||||
|
._data = raw_bytes.ptr,
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -81,7 +85,7 @@ pub const HostBuffer = struct {
|
|||||||
@memcpy(tmp[0..strides_.len], strides_);
|
@memcpy(tmp[0..strides_.len], strides_);
|
||||||
return .{
|
return .{
|
||||||
._shape = sh,
|
._shape = sh,
|
||||||
.data = @alignCast(std.mem.sliceAsBytes(s)),
|
._data = @alignCast(std.mem.sliceAsBytes(s).ptr),
|
||||||
._strides = tmp,
|
._strides = tmp,
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
@ -89,13 +93,15 @@ pub const HostBuffer = struct {
|
|||||||
|
|
||||||
/// Creates a tensor from a **pointer** to a "multi dimension" array.
|
/// Creates a tensor from a **pointer** to a "multi dimension" array.
|
||||||
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
|
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object.
|
||||||
|
/// Typically this is use with constant arrays.
|
||||||
pub fn fromArray(arr_ptr: anytype) HostBuffer {
|
pub fn fromArray(arr_ptr: anytype) HostBuffer {
|
||||||
const T = @TypeOf(arr_ptr.*);
|
const T = @TypeOf(arr_ptr.*);
|
||||||
const sh = parseArrayInfo(T);
|
const sh = parseArrayInfo(T);
|
||||||
|
std.debug.assert(sh.byteSize() == @sizeOf(T));
|
||||||
return .{
|
return .{
|
||||||
._shape = sh,
|
._shape = sh,
|
||||||
.data = @alignCast(std.mem.sliceAsBytes(arr_ptr)),
|
._strides = sh.computeStrides().buffer,
|
||||||
// Array are typically stack allocated and don't need to be freed.
|
._data = @ptrCast(arr_ptr),
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -121,16 +127,15 @@ pub const HostBuffer = struct {
|
|||||||
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||||
|
|
||||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||||
const b = dt.sizeOf();
|
|
||||||
const res = try empty(allocator, Shape.init(.{n_steps}, dt));
|
const res = try empty(allocator, Shape.init(.{n_steps}, dt));
|
||||||
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
|
||||||
var data_ = @constCast(res.data);
|
|
||||||
switch (dt) {
|
switch (dt) {
|
||||||
inline else => {
|
inline else => |d| if (comptime d.class() != .integer) {
|
||||||
|
stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt});
|
||||||
|
} else {
|
||||||
|
const Zt = d.toZigType();
|
||||||
var j: i64 = args.start;
|
var j: i64 = args.start;
|
||||||
for (0..@intCast(n_steps)) |i| {
|
for (res.mutItems(Zt)) |*val| {
|
||||||
var v = Data.init(dt, j);
|
val.* = @intCast(j);
|
||||||
@memcpy(data_[i * b .. (i + 1) * b], v.constSlice());
|
|
||||||
j +%= args.step;
|
j +%= args.step;
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -160,16 +165,26 @@ pub const HostBuffer = struct {
|
|||||||
/// WARNING: It's only valid if the buffer is contiguous.
|
/// WARNING: It's only valid if the buffer is contiguous.
|
||||||
/// Strided buffers can't use this method.
|
/// Strided buffers can't use this method.
|
||||||
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
pub fn items(self: HostBuffer, comptime T: type) []const T {
|
||||||
if (DataType.fromZigType(T) != self.dtype()) {
|
// TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32.
|
||||||
std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) });
|
||||||
}
|
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
||||||
if (!self.isContiguous()) {
|
const ptr: [*]const T = @alignCast(@ptrCast(self._data));
|
||||||
std.debug.panic("{} isn't contiguous", .{self});
|
|
||||||
}
|
|
||||||
const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr));
|
|
||||||
return ptr[0..self._shape.count()];
|
return ptr[0..self._shape.count()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mutItems(self: HostBuffer, comptime T: type) []T {
|
||||||
|
return @constCast(self.items(T));
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bytes(self: HostBuffer) []const u8 {
|
||||||
|
stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self});
|
||||||
|
return self._data[0..self._shape.byteSize()];
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mutBytes(self: HostBuffer) []u8 {
|
||||||
|
return @constCast(self.bytes());
|
||||||
|
}
|
||||||
|
|
||||||
pub fn shape(self: HostBuffer) Shape {
|
pub fn shape(self: HostBuffer) Shape {
|
||||||
return self._shape;
|
return self._shape;
|
||||||
}
|
}
|
||||||
@ -178,9 +193,9 @@ pub const HostBuffer = struct {
|
|||||||
return self._shape.dtype();
|
return self._shape.dtype();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn strides(self: *const HostBuffer) ?[]const i64 {
|
pub fn strides(self: *const HostBuffer) []const i64 {
|
||||||
// Pass strides per pointer otherwise we return a pointer to this stack frame.
|
// Pass strides per pointer otherwise we return a pointer to this stack frame.
|
||||||
return if (self._strides) |*strd| strd[0..self.rank()] else null;
|
return self._strides[0..self._shape.rank()];
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: rename .data into ._data and make it a [*]u8
|
// TODO: rename .data into ._data and make it a [*]u8
|
||||||
@ -205,7 +220,7 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn isContiguous(self: HostBuffer) bool {
|
pub fn isContiguous(self: HostBuffer) bool {
|
||||||
const _strides = self._strides orelse return true;
|
const _strides = self._strides;
|
||||||
const cont_strides = self._shape.computeStrides();
|
const cont_strides = self._shape.computeStrides();
|
||||||
for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
|
for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
|
||||||
if (d != 1 and stride != cont_stride) return false;
|
if (d != 1 and stride != cont_stride) return false;
|
||||||
@ -217,6 +232,7 @@ pub const HostBuffer = struct {
|
|||||||
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self});
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.reshape(shape_);
|
res._shape = self._shape.reshape(shape_);
|
||||||
|
res._strides = res._shape.computeStrides().buffer;
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -236,15 +252,12 @@ pub const HostBuffer = struct {
|
|||||||
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
||||||
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
||||||
|
|
||||||
// If strides weren't set it means original buffer is contiguous.
|
const offset: usize = @intCast(start * self._strides[ax]);
|
||||||
// But it won't be anymore after slicing. The strides don't change though.
|
const new_shape = self.shape().set(ax, end - start);
|
||||||
const _strides = self._strides orelse self._shape.computeStrides().buffer;
|
|
||||||
const offset: usize = @intCast(start * _strides[ax]);
|
|
||||||
return .{
|
return .{
|
||||||
._shape = self.shape().set(ax, end - start),
|
._shape = new_shape,
|
||||||
.data = self.data[offset..],
|
._data = self._data[offset..],
|
||||||
// When axis is 0, we stay contiguous.
|
._strides = self._strides,
|
||||||
._strides = if (ax == 0) self._strides else _strides,
|
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -254,18 +267,52 @@ pub const HostBuffer = struct {
|
|||||||
return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
|
return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn choose(self: HostBuffer, offsets: anytype) HostBuffer {
|
||||||
|
const off, const tags = Shape.parseDimensions(offsets);
|
||||||
|
var sh = self._shape;
|
||||||
|
var offset: i64 = 0;
|
||||||
|
for (off.constSlice(), tags.constSlice()) |o, t| {
|
||||||
|
const ax = sh.axis(t);
|
||||||
|
offset += o * self._strides[ax];
|
||||||
|
sh._dims.buffer[ax] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
var new_strides: [Shape.MAX_RANK]i64 = @splat(self.dtype().sizeOf());
|
||||||
|
|
||||||
|
// TODO rewrite with simd. This is a pshuf, but it's not supported by @shuffle.
|
||||||
|
var res_ax: u32 = 0;
|
||||||
|
for (0..self._shape.rank()) |ax| {
|
||||||
|
if (sh._dims.buffer[ax] == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
sh._dims.buffer[res_ax] = self._shape._dims.buffer[ax];
|
||||||
|
sh._tags.buffer[res_ax] = self._shape._tags.buffer[ax];
|
||||||
|
new_strides[res_ax] = self._strides[ax];
|
||||||
|
res_ax += 1;
|
||||||
|
}
|
||||||
|
sh._dims.len -= off.len;
|
||||||
|
sh._tags.len -= off.len;
|
||||||
|
|
||||||
|
return HostBuffer{
|
||||||
|
._shape = sh,
|
||||||
|
._strides = new_strides,
|
||||||
|
._data = self._data[@intCast(offset)..],
|
||||||
|
._memory = .unmanaged,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
||||||
const ax = self._shape.axis(axis_);
|
const ax = self._shape.axis(axis_);
|
||||||
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
||||||
|
|
||||||
var _strides: ?[Shape.MAX_RANK]i64 = self._strides;
|
var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() };
|
||||||
if (self._strides) |strydes| {
|
_ = strd.orderedRemove(ax);
|
||||||
std.mem.copyForwards(i64, _strides.?[0 .. Shape.MAX_RANK - 1], strydes[1..]);
|
|
||||||
}
|
|
||||||
return .{
|
return .{
|
||||||
._shape = self.shape().drop(ax),
|
._shape = self.shape().drop(ax),
|
||||||
.data = self.data,
|
._data = self._data,
|
||||||
._strides = _strides,
|
._strides = strd.buffer,
|
||||||
._memory = self._memory,
|
._memory = self._memory,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -276,10 +323,13 @@ pub const HostBuffer = struct {
|
|||||||
options: std.fmt.FormatOptions,
|
options: std.fmt.FormatOptions,
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
_ = options;
|
||||||
|
if (std.mem.eql(u8, fmt, "v")) {
|
||||||
|
try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||||
|
} else {
|
||||||
try writer.print("HostBuffer(.{_})", .{self._shape});
|
try writer.print("HostBuffer(.{_})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Formatter for a HostBuffer that also print the values not just the shape.
|
/// Formatter for a HostBuffer that also print the values not just the shape.
|
||||||
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
|
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
|
||||||
|
|||||||
128
zml/meta.zig
128
zml/meta.zig
@ -237,42 +237,48 @@ test mapAlloc {
|
|||||||
/// Recursively visit the given struct and calls the callback for each K found.
|
/// Recursively visit the given struct and calls the callback for each K found.
|
||||||
/// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it.
|
/// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it.
|
||||||
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||||
const T = @TypeOf(v);
|
|
||||||
const type_info_v = @typeInfo(T);
|
|
||||||
const K = switch (@typeInfo(FnParam(cb, 1))) {
|
|
||||||
.pointer => |info| info.child,
|
|
||||||
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
|
||||||
};
|
|
||||||
|
|
||||||
if (type_info_v != .pointer) {
|
|
||||||
const Callback = @TypeOf(cb);
|
const Callback = @TypeOf(cb);
|
||||||
stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T });
|
const Ptr = @TypeOf(v);
|
||||||
|
const type_info_v = @typeInfo(Ptr);
|
||||||
|
if (type_info_v != .pointer) {
|
||||||
|
stdx.debug.compileError("zml.meta.visit({}) is expecting a pointer/slice input, but received: {}", .{ Callback, Ptr });
|
||||||
}
|
}
|
||||||
const ptr_info = type_info_v.pointer;
|
const ptr_info = type_info_v.pointer;
|
||||||
if (@typeInfo(ptr_info.child) == .@"fn") return;
|
const Child = ptr_info.child;
|
||||||
if (ptr_info.child == anyopaque) return;
|
|
||||||
// This is important, because with trivial types like void,
|
|
||||||
// Zig sometimes decide to call `visit` at comptime, but can't do
|
|
||||||
// the pointer wrangling logic at comptime.
|
|
||||||
// So we detect early this case and return.
|
|
||||||
if (@sizeOf(ptr_info.child) == 0) return;
|
|
||||||
|
|
||||||
|
const K, const mutating_cb = switch (@typeInfo(FnParam(cb, 1))) {
|
||||||
|
.pointer => |info| .{ info.child, !info.is_const },
|
||||||
|
else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}),
|
||||||
|
};
|
||||||
|
// Abort if v doesnt' contain any K.
|
||||||
|
if (comptime !Contains(Ptr, K)) return;
|
||||||
|
|
||||||
|
// Handle simple cases.
|
||||||
|
switch (Ptr) {
|
||||||
|
*const K, *K => return cb(ctx, v),
|
||||||
|
*const ?K, *?K => return if (v.*) |*val| cb(ctx, val) else {},
|
||||||
|
[]const K, []K => {
|
||||||
|
for (v) |*v_elem| cb(ctx, v_elem);
|
||||||
|
return;
|
||||||
|
},
|
||||||
|
else => {},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle std.BoundedArray that contains uninitalized data.
|
||||||
|
if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) {
|
||||||
|
return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively visit fields of v.
|
||||||
switch (ptr_info.size) {
|
switch (ptr_info.size) {
|
||||||
// If we have a single pointer, two cases:
|
.one => switch (@typeInfo(Child)) {
|
||||||
// * It's a pointer to K, in which case we call the callback.
|
.@"struct" => |s| inline for (s.fields) |field| {
|
||||||
// * It's a pointer to something else, in which case, we explore and recurse if needed.
|
if (field.is_comptime or comptime !Contains(field.type, K)) continue;
|
||||||
.one => if (ptr_info.child == K) {
|
const field_type_info = @typeInfo(field.type);
|
||||||
cb(ctx, v);
|
|
||||||
} else if (ptr_info.child == ?K) {
|
|
||||||
if (v.*) |*val| cb(ctx, val);
|
|
||||||
} else switch (@typeInfo(ptr_info.child)) {
|
|
||||||
.@"struct" => |s| inline for (s.fields) |field_info| {
|
|
||||||
if (field_info.is_comptime) continue;
|
|
||||||
const field_type_info = @typeInfo(field_info.type);
|
|
||||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||||
switch (field_type_info) {
|
switch (field_type_info) {
|
||||||
.pointer => visit(cb, ctx, @field(v, field_info.name)),
|
.pointer => visit(cb, ctx, @field(v, field.name)),
|
||||||
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field_info.name)),
|
.array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field.name)),
|
||||||
else => {},
|
else => {},
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -281,23 +287,19 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
.@"union" => switch (v.*) {
|
.@"union" => switch (v.*) {
|
||||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||||
},
|
},
|
||||||
else => {},
|
else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
|
||||||
},
|
},
|
||||||
// If we have a slice, two cases also:
|
|
||||||
// * It's a slice of K, in which case we call the callback for each element of the slice.
|
|
||||||
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
|
|
||||||
.slice => {
|
.slice => {
|
||||||
for (v) |*v_elem| {
|
for (v) |*v_elem| {
|
||||||
if (ptr_info.child == K) {
|
switch (@typeInfo(Child)) {
|
||||||
cb(ctx, v_elem);
|
.@"struct" => |s| inline for (s.fields) |field| {
|
||||||
} else switch (@typeInfo(ptr_info.child)) {
|
if (field.is_comptime or comptime !Contains(field.type, K)) continue;
|
||||||
.@"struct" => |s| inline for (s.fields) |field_info| {
|
const field_type_info = @typeInfo(field.type);
|
||||||
const field_type_info = @typeInfo(field_info.type);
|
|
||||||
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
|
||||||
if (field_type_info == .pointer) {
|
if (field_type_info == .pointer) {
|
||||||
visit(cb, ctx, @field(v_elem, field_info.name));
|
visit(cb, ctx, @field(v_elem, field.name));
|
||||||
} else {
|
} else {
|
||||||
visit(cb, ctx, &@field(v_elem, field_info.name));
|
visit(cb, ctx, &@field(v_elem, field.name));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
.array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
.array => |_| for (v) |*elem| visit(cb, ctx, elem),
|
||||||
@ -305,11 +307,11 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
|||||||
.@"union" => switch (v_elem.*) {
|
.@"union" => switch (v_elem.*) {
|
||||||
inline else => |*v_field| visit(cb, ctx, v_field),
|
inline else => |*v_field| visit(cb, ctx, v_field),
|
||||||
},
|
},
|
||||||
else => {},
|
else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
else => {},
|
.many, .c => stdx.debug.compileError("zml.meta.visit({}) doesn't support [*] style pointers, got: {}", .{ Callback, Ptr }),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,7 +322,7 @@ test visit {
|
|||||||
const NestedAttrOptional = struct { nested: ?Attr };
|
const NestedAttrOptional = struct { nested: ?Attr };
|
||||||
const SimpleStruct = struct { prop: Attr };
|
const SimpleStruct = struct { prop: Attr };
|
||||||
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
|
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
|
||||||
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional };
|
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) };
|
||||||
|
|
||||||
const LocalContext = struct { result: usize };
|
const LocalContext = struct { result: usize };
|
||||||
|
|
||||||
@ -374,11 +376,16 @@ test visit {
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
var context: LocalContext = .{ .result = 0 };
|
var context: LocalContext = .{ .result = 0 };
|
||||||
|
const prop5: std.BoundedArray(Attr, 8) = .{
|
||||||
|
.buffer = @splat(.{ .data = 4 }),
|
||||||
|
.len = 2,
|
||||||
|
};
|
||||||
const container: NestedTypesStruct = .{
|
const container: NestedTypesStruct = .{
|
||||||
.prop1 = .{ .data = 1 },
|
.prop1 = .{ .data = 1 },
|
||||||
.prop2 = .{ .other = "hello" },
|
.prop2 = .{ .other = "hello" },
|
||||||
.prop3 = .{ .nested = .{ .data = 2 } },
|
.prop3 = .{ .nested = .{ .data = 2 } },
|
||||||
.prop4 = .{ .nested = .{ .data = 3 } },
|
.prop4 = .{ .nested = .{ .data = 3 } },
|
||||||
|
.prop5 = prop5, // 4 will be counted twice.
|
||||||
};
|
};
|
||||||
|
|
||||||
visit((struct {
|
visit((struct {
|
||||||
@ -387,7 +394,7 @@ test visit {
|
|||||||
}
|
}
|
||||||
}).cb, &context, &container);
|
}).cb, &context, &container);
|
||||||
|
|
||||||
try std.testing.expectEqual(6, context.result);
|
try std.testing.expectEqual(14, context.result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -533,3 +540,36 @@ fn _CollectArg(func: anytype) type {
|
|||||||
const params = @typeInfo(@TypeOf(func)).@"fn".params;
|
const params = @typeInfo(@TypeOf(func)).@"fn".params;
|
||||||
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
return params[params.len - 1].type orelse @compileError("anytype not supported in collect");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn Contains(Haystack: type, T: type) bool {
|
||||||
|
switch (Haystack) {
|
||||||
|
T, ?T => return true,
|
||||||
|
*T, ?*T => return true,
|
||||||
|
*const T, ?*const T => return true,
|
||||||
|
[]const T, ?[]const T => return true,
|
||||||
|
anyopaque => return false,
|
||||||
|
else => {},
|
||||||
|
}
|
||||||
|
|
||||||
|
return switch (@typeInfo(Haystack)) {
|
||||||
|
.@"struct" => |info| {
|
||||||
|
inline for (info.fields) |field| {
|
||||||
|
if (Contains(field.type, T))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
.@"union" => |info| {
|
||||||
|
inline for (info.fields) |field| {
|
||||||
|
if (Contains(field.type, T))
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
},
|
||||||
|
.array => |info| Contains(info.child, T),
|
||||||
|
.pointer => |info| Contains(info.child, T),
|
||||||
|
.optional => |info| Contains(info.child, T),
|
||||||
|
.vector => |info| Contains(info.child, T),
|
||||||
|
else => false,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|||||||
@ -29,7 +29,7 @@ test {
|
|||||||
|
|
||||||
pub const MlirFn = struct {
|
pub const MlirFn = struct {
|
||||||
name: []const u8,
|
name: []const u8,
|
||||||
num_args: u32,
|
args_shapes: []Shape,
|
||||||
res_tensors: *const anyopaque,
|
res_tensors: *const anyopaque,
|
||||||
res_types: []mlir.Type,
|
res_types: []mlir.Type,
|
||||||
res_shapes: []Shape,
|
res_shapes: []Shape,
|
||||||
@ -199,7 +199,7 @@ pub const CompilationContext = struct {
|
|||||||
const loaded_executable: *pjrt.LoadedExecutable = blk: {
|
const loaded_executable: *pjrt.LoadedExecutable = blk: {
|
||||||
if (pjrt_location) |pjrt_loc| {
|
if (pjrt_location) |pjrt_loc| {
|
||||||
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
|
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
|
||||||
log.info("Loaded pre-compiled module from {s}", .{pjrt_loc});
|
log.info("Loaded pre-compiled module from {s} (generated from {s}/module.mlir)", .{ pjrt_loc, module_dir.? });
|
||||||
break :blk exe;
|
break :blk exe;
|
||||||
} else |err| {
|
} else |err| {
|
||||||
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
|
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
|
||||||
@ -233,7 +233,7 @@ pub const CompilationContext = struct {
|
|||||||
self._platform,
|
self._platform,
|
||||||
loaded_executable,
|
loaded_executable,
|
||||||
.{
|
.{
|
||||||
.n_in = f.num_args,
|
.input_shapes = f.args_shapes,
|
||||||
.result_shapes = f.res_shapes,
|
.result_shapes = f.res_shapes,
|
||||||
.n_devices = sharding.num_replicas * sharding.num_partitions,
|
.n_devices = sharding.num_replicas * sharding.num_partitions,
|
||||||
},
|
},
|
||||||
@ -341,7 +341,7 @@ pub const CompilationContext = struct {
|
|||||||
const locations = try arena.alloc(mlir.Location, tensor_count);
|
const locations = try arena.alloc(mlir.Location, tensor_count);
|
||||||
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
||||||
|
|
||||||
var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count);
|
var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count);
|
||||||
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
||||||
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
||||||
|
|
||||||
@ -427,7 +427,7 @@ pub const CompilationContext = struct {
|
|||||||
return .{
|
return .{
|
||||||
.mlir_fn = mlir_fn,
|
.mlir_fn = mlir_fn,
|
||||||
.name = opts.name,
|
.name = opts.name,
|
||||||
.num_args = @intCast(tensor_count),
|
.args_shapes = input_shapes.items,
|
||||||
.res_tensors = fn_res,
|
.res_tensors = fn_res,
|
||||||
.res_types = fn_res_types,
|
.res_types = fn_res_types,
|
||||||
.res_shapes = fn_res_shapes,
|
.res_shapes = fn_res_shapes,
|
||||||
@ -512,7 +512,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
// Check that the `x` input argument gives its buffer to the result tensor.
|
// Check that the `x` input argument gives its buffer to the result tensor.
|
||||||
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
||||||
try std.testing.expectEqual(3, f.num_args);
|
try std.testing.expectEqual(3, f.args_shapes.len);
|
||||||
// We should have two buffers being donated.
|
// We should have two buffers being donated.
|
||||||
const template = "tf.aliasing_output = {d} : i32";
|
const template = "tf.aliasing_output = {d} : i32";
|
||||||
var buf = template.*;
|
var buf = template.*;
|
||||||
@ -540,9 +540,13 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn numPartitions(self: CompilationContext) u8 {
|
||||||
|
return self._platform.sharding().num_partitions;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
||||||
const ctx = self.mlirCtx();
|
const ctx = self.mlirCtx();
|
||||||
const num_partitions = self._platform.sharding().num_partitions;
|
const num_partitions = self.numPartitions();
|
||||||
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
||||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||||
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
||||||
@ -645,10 +649,11 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
const loc = self.mlirCtx().location(@src());
|
const loc = self.mlirCtx().location(@src());
|
||||||
|
|
||||||
const values = try arena.alloc(mlir.Value, function.num_args);
|
const num_args = function.args_shapes.len;
|
||||||
|
const values = try arena.alloc(mlir.Value, num_args);
|
||||||
self.extractValues(&args, values);
|
self.extractValues(&args, values);
|
||||||
|
|
||||||
const donations = try arena.alloc(Tensor._Donation, function.num_args);
|
const donations = try arena.alloc(Tensor._Donation, num_args);
|
||||||
meta.collectBuf(struct {
|
meta.collectBuf(struct {
|
||||||
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
|
pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation {
|
||||||
return ctx.getValueAndDonation(x)[1];
|
return ctx.getValueAndDonation(x)[1];
|
||||||
|
|||||||
@ -176,6 +176,8 @@ pub const RopeOpts = struct {
|
|||||||
/// Read a Rope scaling config from HF config.json format.
|
/// Read a Rope scaling config from HF config.json format.
|
||||||
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
|
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling {
|
||||||
const content = try std.json.Value.jsonParse(allocator, source, options);
|
const content = try std.json.Value.jsonParse(allocator, source, options);
|
||||||
|
if (content == .null) return .default;
|
||||||
|
|
||||||
if (content != .object) return error.InvalidEnumTag;
|
if (content != .object) return error.InvalidEnumTag;
|
||||||
|
|
||||||
const obj = content.object;
|
const obj = content.object;
|
||||||
|
|||||||
@ -58,10 +58,10 @@ pub const Shape = struct {
|
|||||||
const fv = @field(v, field.name);
|
const fv = @field(v, field.name);
|
||||||
if (comptime stdx.meta.isInteger(field.type)) {
|
if (comptime stdx.meta.isInteger(field.type)) {
|
||||||
dims_.appendAssumeCapacity(@intCast(fv));
|
dims_.appendAssumeCapacity(@intCast(fv));
|
||||||
} else if (comptime isAutoDim(fv)) {
|
} else if (@TypeOf(fv) == EnumLiteral and comptime isAutoDim(fv)) {
|
||||||
dims_.appendAssumeCapacity(-1);
|
dims_.appendAssumeCapacity(-1);
|
||||||
} else {
|
} else {
|
||||||
stdx.debug.compileError("Field {s} should be an integer or an auto dimension", .{field.name});
|
stdx.debug.compileError("Field {s} should be an integer or an auto dimension, got {}", .{ field.name, field.type });
|
||||||
}
|
}
|
||||||
if (comptime stdx.meta.isTuple(T)) {
|
if (comptime stdx.meta.isTuple(T)) {
|
||||||
tags_.appendAssumeCapacity(TagUnknown);
|
tags_.appendAssumeCapacity(TagUnknown);
|
||||||
@ -186,7 +186,7 @@ pub const Shape = struct {
|
|||||||
EnumLiteral => @tagName(v).ptr,
|
EnumLiteral => @tagName(v).ptr,
|
||||||
std.builtin.Type.StructField => v.name.ptr,
|
std.builtin.Type.StructField => v.name.ptr,
|
||||||
Tag => v,
|
Tag => v,
|
||||||
else => stdx.debug.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
else => stdx.debug.compileError("Shape tag should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -581,6 +581,41 @@ pub const Shape = struct {
|
|||||||
try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());
|
try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn removeMany(self: Shape, axes_: anytype) Shape {
|
||||||
|
var to_remove = self.axes(axes_);
|
||||||
|
if (to_remove.len == 0) return self;
|
||||||
|
std.mem.sort(u3, to_remove.slice(), {}, std.sort.asc(u3));
|
||||||
|
|
||||||
|
var sh: Shape = self;
|
||||||
|
const rk = self.rank();
|
||||||
|
var res_ax: u32 = 0;
|
||||||
|
for (0..rk) |ax| {
|
||||||
|
if (std.mem.indexOfScalar(u3, to_remove.constSlice(), @intCast(ax))) |_| {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
sh._dims.buffer[res_ax] = self._dims.buffer[ax];
|
||||||
|
sh._tags.buffer[res_ax] = self._tags.buffer[ax];
|
||||||
|
res_ax += 1;
|
||||||
|
}
|
||||||
|
sh._dims.len = rk - to_remove.len;
|
||||||
|
sh._tags.len = rk - to_remove.len;
|
||||||
|
return sh;
|
||||||
|
}
|
||||||
|
|
||||||
|
test removeMany {
|
||||||
|
try std.testing.expectEqualSlices(
|
||||||
|
i64,
|
||||||
|
&.{12},
|
||||||
|
Shape.init(.{ 10, 11, 12 }, .f32).removeMany(.{ 0, 1 }).dims(),
|
||||||
|
);
|
||||||
|
try std.testing.expectEqualSlices(
|
||||||
|
i64,
|
||||||
|
&.{ 10, 11 },
|
||||||
|
Shape.init(.{ 10, 11, 12, 13 }, .f32).removeMany(.{ -1, -2 }).dims(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn transpose(self: Shape, permutations: anytype) Shape {
|
pub fn transpose(self: Shape, permutations: anytype) Shape {
|
||||||
std.debug.assert(self.rank() == permutations.len);
|
std.debug.assert(self.rank() == permutations.len);
|
||||||
const permutations_ = self.axes(permutations);
|
const permutations_ = self.axes(permutations);
|
||||||
@ -729,7 +764,9 @@ pub const Shape = struct {
|
|||||||
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
|
stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T});
|
||||||
var res = self;
|
var res = self;
|
||||||
inline for (std.meta.fields(T)) |field| {
|
inline for (std.meta.fields(T)) |field| {
|
||||||
res._tags.set(self.axis(field), toTag(@field(renames, field.name)));
|
const new_field = @field(renames, field.name);
|
||||||
|
stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field });
|
||||||
|
res._tags.set(self.axis(field), toTag(new_field));
|
||||||
}
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
@ -749,15 +786,20 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
|
pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) {
|
||||||
const base_stride = self.dtype().sizeOf();
|
|
||||||
const rk = self.rank();
|
const rk = self.rank();
|
||||||
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) };
|
var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk };
|
||||||
if (rk == 0) return strides;
|
if (rk == 0) return strides;
|
||||||
strides.buffer[rk - 1] = base_stride;
|
|
||||||
for (1..rk) |i| {
|
const V = @Vector(MAX_RANK, i64);
|
||||||
const j = @as(usize, rk) - 1 - i;
|
const rank_mask = std.simd.iota(u8, MAX_RANK) < @as(@Vector(MAX_RANK, u8), @splat(rk));
|
||||||
strides.buffer[j] = self._dims.get(j + 1) * strides.buffer[j + 1];
|
// For each axis compute the product of all following dimensions
|
||||||
}
|
// and the element size in bytes.
|
||||||
|
var d: V = @bitCast(self._dims.buffer);
|
||||||
|
d = @select(i64, rank_mask, d, @as(V, @splat(1)));
|
||||||
|
d = std.simd.shiftElementsLeft(d, 1, self.dtype().sizeOf());
|
||||||
|
d = std.simd.prefixScan(.Mul, -1, d);
|
||||||
|
|
||||||
|
strides.buffer = @bitCast(d);
|
||||||
return strides;
|
return strides;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -176,6 +176,7 @@ pub const Tensor = struct {
|
|||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.withSharding(axes_);
|
res._shape = self._shape.withSharding(axes_);
|
||||||
|
|
||||||
|
if (ctx.numPartitions() <= 1) return self;
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
&.{self.value()},
|
&.{self.value()},
|
||||||
@ -1279,9 +1280,9 @@ pub const Tensor = struct {
|
|||||||
/// see: https://paperswithcode.com/method/gelu
|
/// see: https://paperswithcode.com/method/gelu
|
||||||
pub fn gelu(x: Tensor) Tensor {
|
pub fn gelu(x: Tensor) Tensor {
|
||||||
const scaled_x_cube = x.mul(x).mul(x).scale(0.044715);
|
const scaled_x_cube = x.mul(x).mul(x).scale(0.044715);
|
||||||
const one = Tensor.constant(x._shape, x.dtype().one());
|
const beta = std.math.sqrt(2.0 / std.math.pi);
|
||||||
const one_plus_tanh = Tensor.add(x, scaled_x_cube).scale(std.math.sqrt(2.0 / std.math.pi)).tanh().add(one);
|
const tanh_ = x.add(scaled_x_cube).scale(beta).tanh();
|
||||||
return one_plus_tanh.mul(x).scale(0.5);
|
return tanh_.addConstant(1).mul(x).scale(0.5);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
|
/// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor.
|
||||||
@ -1526,8 +1527,34 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
pub const Slice = struct {
|
pub const Slice = struct {
|
||||||
start: i64 = 0,
|
start: i64 = 0,
|
||||||
end: ?i64 = null,
|
end: i64 = to_the_end,
|
||||||
step: i64 = 1,
|
step: i32 = 1,
|
||||||
|
singleton: bool = false,
|
||||||
|
|
||||||
|
pub fn single(offset: i64) Slice {
|
||||||
|
return .{ .start = offset, .end = offset + 1, .singleton = true };
|
||||||
|
}
|
||||||
|
|
||||||
|
const to_the_end = std.math.maxInt(i64);
|
||||||
|
|
||||||
|
pub fn format(
|
||||||
|
self: Slice,
|
||||||
|
comptime fmt: []const u8,
|
||||||
|
options: std.fmt.FormatOptions,
|
||||||
|
writer: anytype,
|
||||||
|
) !void {
|
||||||
|
_ = fmt;
|
||||||
|
_ = options;
|
||||||
|
if (self.singleton) {
|
||||||
|
try writer.print("[{}]", .{self.start});
|
||||||
|
} else if (self.end == to_the_end and self.step == 1) {
|
||||||
|
try writer.print("[{}..]", .{self.start});
|
||||||
|
} else if (self.step == 1) {
|
||||||
|
try writer.print("[{}..{}]", .{ self.start, self.end });
|
||||||
|
} else {
|
||||||
|
try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step });
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Slices the input Tensor over the given axis using the given parameters.
|
/// Slices the input Tensor over the given axis using the given parameters.
|
||||||
@ -1549,13 +1576,13 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const args: Slice = .{
|
const args: Slice = .{
|
||||||
.start = self.wrapIndex(a, s.start),
|
.start = self.wrapIndex(a, s.start),
|
||||||
.end = if (s.end) |end| self.wrapIndex(a, end) else self.dim(a),
|
.end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end),
|
||||||
.step = s.step,
|
.step = s.step,
|
||||||
};
|
};
|
||||||
start_indices[a] = args.start;
|
start_indices[a] = args.start;
|
||||||
limit_indices[a] = args.end.?;
|
limit_indices[a] = args.end;
|
||||||
strides[a] = args.step;
|
strides[a] = args.step;
|
||||||
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable);
|
res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable);
|
||||||
}
|
}
|
||||||
|
|
||||||
const mlir_ctx = self.getContext().mlirCtx();
|
const mlir_ctx = self.getContext().mlirCtx();
|
||||||
@ -1571,7 +1598,12 @@ pub const Tensor = struct {
|
|||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
|
|
||||||
return _result(res_shape, slice_op.result(0));
|
var res = _result(res_shape, slice_op.result(0));
|
||||||
|
var to_remove: Shape.AxesArray = .{};
|
||||||
|
for (slices, 0..) |s, a| {
|
||||||
|
if (s.singleton) to_remove.appendAssumeCapacity(@intCast(a));
|
||||||
|
}
|
||||||
|
return res.reshape(res_shape.removeMany(to_remove.constSlice()));
|
||||||
}
|
}
|
||||||
|
|
||||||
test slice {
|
test slice {
|
||||||
@ -1606,8 +1638,17 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor {
|
pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor {
|
||||||
// TODO: this use case could be handled directly by slice if we added a .single field
|
return self.slice1d(axis_, .single(i));
|
||||||
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_);
|
}
|
||||||
|
|
||||||
|
pub fn choose(self: Tensor, offsets: anytype) Tensor {
|
||||||
|
const off, const tags = Shape.parseDimensions(offsets);
|
||||||
|
var slices = [_]Slice{.{}} ** MAX_RANK;
|
||||||
|
for (off.constSlice(), tags.constSlice()) |o, t| {
|
||||||
|
const ax = self.axis(t);
|
||||||
|
slices[ax] = .single(o);
|
||||||
|
}
|
||||||
|
return self.slice(slices[0..self.rank()]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Concatenates the input Tensors along the given axis.
|
/// Concatenates the input Tensors along the given axis.
|
||||||
@ -1866,7 +1907,12 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a 0-rank Tensor with the given value.
|
/// Returns a 0-rank Tensor with the given value.
|
||||||
pub fn scalar(val: anytype, dt: DataType) Tensor {
|
pub fn scalar(val: anytype, dt: DataType) Tensor {
|
||||||
return Tensor.constant(.{}, Data.init(dt, val));
|
const data = Data.init(dt, val);
|
||||||
|
switch (dt.class()) {
|
||||||
|
.float => stdx.debug.assert(!std.math.isNan(val), "scalar(NaN) is probably due to compiling a model with an uninitialized field", .{}),
|
||||||
|
else => {},
|
||||||
|
}
|
||||||
|
return Tensor.constant(.{}, data);
|
||||||
}
|
}
|
||||||
|
|
||||||
test scalar {
|
test scalar {
|
||||||
@ -1913,7 +1959,7 @@ pub const Tensor = struct {
|
|||||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
||||||
const loc = ctx.location(@src());
|
const loc = ctx.location(@src());
|
||||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
||||||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc);
|
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc);
|
||||||
return _result(val.shape(), constant_op.result(0));
|
return _result(val.shape(), constant_op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3786,6 +3832,7 @@ pub const Tensor = struct {
|
|||||||
/// Only for debug purpose, it inserts device to host synchronization
|
/// Only for debug purpose, it inserts device to host synchronization
|
||||||
/// so it will slow down the program execution.
|
/// so it will slow down the program execution.
|
||||||
pub fn print(input: Tensor) Tensor {
|
pub fn print(input: Tensor) Tensor {
|
||||||
|
// TODO: find a way of doing print that doesn't involve a H2D copy.
|
||||||
return ops.addHostCallback(
|
return ops.addHostCallback(
|
||||||
&printCallback,
|
&printCallback,
|
||||||
null,
|
null,
|
||||||
@ -3797,8 +3844,10 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
||||||
const host_buffer = inputs[0];
|
const host_buffer = inputs[0];
|
||||||
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||||
std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr);
|
// This is true because of the operand aliases.
|
||||||
|
// Since the result is already pointing to the input we don't need to modify the buffer.
|
||||||
|
std.debug.assert(host_buffer._data == outputs[0]._data);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3918,6 +3967,10 @@ test "Tensor.maxPool2d" {
|
|||||||
|
|
||||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||||
pub fn Bufferized(comptime T: type) type {
|
pub fn Bufferized(comptime T: type) type {
|
||||||
|
// TODO: we should strip out the non-buffer fields.
|
||||||
|
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
||||||
|
// Also it will simplify the layout of the Bufferized struct.
|
||||||
|
// accelerating the calls to execute.
|
||||||
return meta.MapType(Tensor, Buffer).map(T);
|
return meta.MapType(Tensor, Buffer).map(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const shapesOf = @import("tensor.zig").shapesOf;
|
const shapesOf = @import("tensor.zig").shapesOf;
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
const log = std.log.scoped(.@"zml/testing");
|
const log = std.log.scoped(.@"zml/testing");
|
||||||
|
|
||||||
@ -35,7 +36,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool
|
|||||||
/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the
|
/// Testing utility. Accepts both Tensor and HostBuffer but Tensor will be copied to the
|
||||||
/// host for comparison !
|
/// host for comparison !
|
||||||
pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||||
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.page_allocator;
|
const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator;
|
||||||
var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer)
|
var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer)
|
||||||
.{ try left_.toHostAlloc(allocator), true }
|
.{ try left_.toHostAlloc(allocator), true }
|
||||||
else
|
else
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user