From 3849eb10b75fb12369107e56c1d268ded37cc0db Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 28 Oct 2024 11:21:46 +0000 Subject: [PATCH] =?UTF-8?q?Add=20buffer=20and=20hostbuffer=20utilities=20w?= =?UTF-8?q?ith=20precise=20f32=E2=86=92bf16=20conversion,=20type=20inferen?= =?UTF-8?q?ce=20for=20loadBuffers,=20store=20expected=20input=20shapes,=20?= =?UTF-8?q?enhance=20meta.visit=20and=20JSON=20TaggedUnion=20support,=20an?= =?UTF-8?q?d=20improve=20logging.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pjrt/pjrt.zig | 4 +- stdx/json.zig | 99 +++++++++++++++++++++++++++++--- zml/aio.zig | 25 ++++---- zml/buffer.zig | 43 ++++++-------- zml/dtype.zig | 11 ++-- zml/exe.zig | 50 +++++++++++----- zml/floats.zig | 18 +++++- zml/hostbuffer.zig | 138 ++++++++++++++++++++++++++++++--------------- zml/meta.zig | 128 ++++++++++++++++++++++++++--------------- zml/module.zig | 23 +++++--- zml/nn.zig | 2 + zml/shape.zig | 64 +++++++++++++++++---- zml/tensor.zig | 83 ++++++++++++++++++++++----- zml/testing.zig | 7 ++- 14 files changed, 497 insertions(+), 198 deletions(-) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 2b0e65b..df1a7d4 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -350,7 +350,7 @@ pub const Client = opaque { } pub const BufferFromHostBufferArgs = struct { - data: []const u8, + data: [*]const u8, buffer_type: BufferType, dims: []const i64, byte_strides: ?[]const i64, @@ -362,7 +362,7 @@ pub const Client = opaque { pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { const ret = try api.call(.PJRT_Client_BufferFromHostBuffer, .{ .client = self.inner(), - .data = @ptrCast(@constCast(args.data.ptr)), + .data = @constCast(args.data), .type = @intFromEnum(args.buffer_type), .dims = @ptrCast(@constCast(args.dims.ptr)), .num_dims = args.dims.len, diff --git a/stdx/json.zig b/stdx/json.zig index 3b3f64d..c99cfb0 100644 --- 
a/stdx/json.zig +++ b/stdx/json.zig @@ -1,25 +1,46 @@ pub const std = @import("std"); +const ParseFromValueError = std.json.ParseFromValueError; +/// Handle json fields that can have different Zig types depending on the message. +/// Each union field should have a unique Zig type. +/// +/// Example json: +/// +/// ```json +/// [ +/// { "question": "How old are you ?", "answer": 5 }, +/// { "question": "Count to three.", "answer": [1, 2, 3] }, +/// ] +/// ``` +/// +/// Corresponding Zig code: +/// +/// ```zig +/// const Answer = union { +/// number: i32, +/// numbers: []const i32, +/// }; +/// +/// const Message = struct { +/// question: []const u8; +/// answer: stdx.json.Union(Answer); +/// } +/// ``` pub fn Union(comptime T: type) type { return struct { const Self = @This(); value: T, - pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Self { + pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self { return jsonParseFromValue( allocator, - try std.json.innerParse( - std.json.Value, - allocator, - source, - options, - ), + try std.json.innerParse(std.json.Value, allocator, source, options), options, ); } - pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) !Self { + pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self { inline for (std.meta.fields(T)) |field| { switch (field.type) { bool => if (source == .bool) return .{ .value = @unionInit(T, field.name, source.bool) }, @@ -39,7 +60,67 @@ pub fn Union(comptime T: type) type { }, } } - return error.UnexpectedToken; + return error.UnknownField; + } + }; +} + +/// Handle json fields that can have different Zig types depending on another field in the same message. +/// This is translated to a Zig tagged union. 
+/// +/// Example json: +/// +/// ```json +/// [ +/// { "type": "faq", "question": "How old are you ?", "answer": 5 }, +/// { "type": "address", "city": "NYC", "zipcode": "49130"}, +/// ] +/// ``` +/// +/// Corresponding Zig struct: +/// +/// ```zig +/// const Entry = union { +/// faq: struct { question: []const u8, answer: u32 }, +/// address: struct { city: []const u8, zipcode: []const u8 }, +/// }; +/// +/// const Message = []const stdx.json.TaggedUnion(Entry, "type"); +/// ``` +pub fn TaggedUnion(comptime T: type, comptime tag_name: [:0]const u8) type { + return struct { + const Self = @This(); + + value: T, + + pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) std.json.ParseError(@TypeOf(source.*))!Self { + return jsonParseFromValue( + allocator, + try std.json.innerParse(std.json.Value, allocator, source, options), + options, + ); + } + + pub fn jsonParseFromValue(allocator: std.mem.Allocator, source: std.json.Value, options: std.json.ParseOptions) ParseFromValueError!Self { + errdefer std.log.warn("failed to parse: {} as {s}", .{ source, @typeName(T) }); + if (source != .object) return error.UnexpectedToken; + const o = source.object; + const tag = o.get(tag_name) orelse return error.MissingField; + for (o.keys(), o.values()) |k, v| { + std.log.warn("object['{s}'] = {}", .{ k, v }); + } + if (tag != .string) return error.LengthMismatch; + inline for (std.meta.fields(T)) |field| { + if (std.mem.eql(u8, field.name, tag.string)) { + const inner_source = o.get(field.name) orelse return error.MissingField; + const inner: field.type = std.json.innerParseFromValue(field.type, allocator, inner_source, options) catch |err| { + std.log.warn("failed to interpret {s} as a {s}: {}", .{ tag.string, @typeName(field.type), err }); + return err; + }; + return .{ .value = @unionInit(T, field.name, inner) }; + } + } + return error.InvalidEnumTag; } }; } diff --git a/zml/aio.zig b/zml/aio.zig index 9829990..8ffa372 100644 --- 
a/zml/aio.zig +++ b/zml/aio.zig @@ -1,11 +1,9 @@ -const asynk = @import("async"); -const builtin = @import("builtin"); -const c = @import("c"); const std = @import("std"); -const stdx = @import("stdx"); +const builtin = @import("builtin"); -const zml = @import("zml.zig"); -const posix = @import("posix.zig"); +const asynk = @import("async"); +const c = @import("c"); +const stdx = @import("stdx"); pub const gguf = @import("aio/gguf.zig"); pub const nemo = @import("aio/nemo.zig"); @@ -13,10 +11,11 @@ pub const safetensors = @import("aio/safetensors.zig"); pub const tinyllama = @import("aio/tinyllama.zig"); pub const torch = @import("aio/torch.zig"); pub const yaml = @import("aio/yaml.zig"); +const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const posix = @import("posix.zig"); +const zml = @import("zml.zig"); pub const log = std.log.scoped(.@"zml/aio"); -const HostBuffer = @import("hostbuffer.zig").HostBuffer; - test { std.testing.refAllDecls(@This()); std.testing.refAllDecls(gguf); @@ -26,6 +25,8 @@ test { std.testing.refAllDecls(yaml); } +// TODO error set for weight loading + /// Detects the format of the model file (base on filename) and open it. pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore { return if (std.mem.endsWith(u8, model_path, ".safetensors")) @@ -422,7 +423,7 @@ fn _populateStruct( return true; }, .float => { - obj.* = undefined; + obj.* = std.math.nan(@TypeOf(obj.*)); return true; }, .void => true, @@ -450,7 +451,7 @@ test populateModel { // Create a fake HostBuffer, we use the given integer to identify the created buffer. fn _newHostBuffer(n: u32) zml.HostBuffer { - return .{ ._shape = zml.Shape.init(.{n}, .f16), .data = undefined }; + return .{ ._shape = zml.Shape.init(.{n}, .f16), ._strides = undefined, ._data = undefined }; } }; @@ -500,7 +501,7 @@ test populateModel { /// The `init_args` are used to initialize the non Buffer fields, using `Model.init` function. 
pub fn loadBuffers( comptime Model: type, - init_args: anytype, + init_args: if (@hasDecl(Model, "init")) stdx.meta.Tail(stdx.meta.FnArgs(Model.init)) else void, buffer_store: BufferStore, allocator: std.mem.Allocator, platform: zml.Platform, @@ -513,8 +514,6 @@ pub fn loadBuffers( // If the Model has a "init" function, call it with the given parameters. if (@hasDecl(Model, "init")) { @call(.auto, Model.init, .{&model} ++ init_args); - } else { - stdx.debug.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model}); } return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); diff --git a/zml/buffer.zig b/zml/buffer.zig index e0829d3..1fc57ef 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -44,23 +44,6 @@ pub const Buffer = struct { } }; - pub const Shard = struct { - api: *const pjrt.Api, - buffer: *pjrt.Buffer, - ready_event: ?*pjrt.Event = null, - ready: bool = false, - - pub fn awaitt(self: *Shard) !void { - if (self.ready) { - return; - } - if (self.ready_event orelse self.buffer.getReadyEvent(self.api)) |ev| { - try ev.awaitt(self.api); - } - self.ready = true; - } - }; - _shape: Shape, _api: *const pjrt.Api, _shards: Shards, @@ -88,7 +71,7 @@ pub const Buffer = struct { } else 0; const buffer_type = bufferTypeFromDtype(host_buffer.shape().dtype()); - const byte_strides = host_buffer.strides() orelse host_buffer.shape().computeStrides().constSlice(); + const byte_strides = host_buffer.strides(); var frames: std.BoundedArray(asynk.Frame(pjrt.Client.bufferFromHostBuffer), MAX_NUM_SHARDS) = .{}; const devices = platform.getDevices(); @@ -103,7 +86,7 @@ pub const Buffer = struct { platform.pjrt_client, platform.pjrt_api, pjrt.Client.BufferFromHostBufferArgs{ - .data = buf.data, + .data = buf._data, .buffer_type = buffer_type, .dims = buf.shape().dims(), .byte_strides = byte_strides, @@ -155,6 
+138,14 @@ pub const Buffer = struct { return try from(platform, host_buffer); } + pub fn asPinnedHostBuffer(self: Buffer) HostBuffer { + // TODO restore assert + // const memory = self.getMemory().kind(self._api); + // stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory }); + const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable); + return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]); + } + /// Creates a Buffer with a single element. pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer { const x = dtype_.constant(val); @@ -182,8 +173,8 @@ pub const Buffer = struct { if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) { const host_buffer: HostBuffer = .{ ._shape = shape_, - ._strides = [1]i64{0} ** Shape.MAX_RANK, - .data = x.constSlice(), + ._strides = @splat(0), + ._data = x.constSlice().ptr, }; return try from(platform, host_buffer); } @@ -207,7 +198,7 @@ pub const Buffer = struct { }, else => unreachable, } - const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes }; + const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes }; return try from(platform, host_buffer); } @@ -228,12 +219,12 @@ pub const Buffer = struct { /// could lead to crashes and operations on the buffer will be slower. /// Tested on Cuda 12.4. pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer { - return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr))); + return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(buf._data)); } /// Creates a Buffer from a pointer into device memory. /// This allows to interface with other libraries producing buffers. 
- pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer { + pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?isize, device_data: *anyopaque) Buffer { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { var res: [Shape.MAX_RANK]i64 = undefined; for (0..Shape.MAX_RANK) |i| { @@ -255,7 +246,7 @@ pub const Buffer = struct { .tile_dims_sizes = &.{}, }, }, - .stream = @bitCast(@as(usize, @intFromPtr(stream))), + .stream = stream, }) catch @panic("failed to createViewOfDeviceBuffer"); var shards: Shards = .{}; @@ -296,7 +287,7 @@ pub const Buffer = struct { pub fn toHostAlloc(self: Buffer, allocator: std.mem.Allocator) !HostBuffer { const output = try HostBuffer.empty(allocator, self.shape()); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer -> Host transfer", .{}); - const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.data)); + const maybe_event = try self._shards.get(0).toHostBuffer(self._api, @constCast(output.bytes())); if (maybe_event) |event| { try event.await_(self._api); } diff --git a/zml/dtype.zig b/zml/dtype.zig index 1f1c046..b0deb8f 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -1,4 +1,5 @@ const std = @import("std"); + const floats = @import("floats.zig"); const C64 = std.math.Complex(f32); @@ -111,9 +112,7 @@ pub const DataType = enum(u8) { } pub fn toZigType(comptime dtype: DataType) type { - return switch (dtype) { - inline else => |tag| std.meta.TagPayload(Data, tag), - }; + return @FieldType(Data, @tagName(dtype)); } pub fn isSignedInt(dtype: DataType) bool { @@ -125,19 +124,19 @@ pub const DataType = enum(u8) { pub fn sizeOf(self: DataType) u16 { return switch (self) { - inline else => |tag| @sizeOf(std.meta.TagPayload(Data, tag)), + inline else => |tag| @sizeOf(tag.toZigType()), }; } pub fn bitSizeOf(self: DataType) u16 { return switch (self) { - inline else => |tag| 
@bitSizeOf(std.meta.TagPayload(Data, tag)), + inline else => |tag| @bitSizeOf(tag.toZigType()), }; } pub fn alignOf(self: DataType) u29 { return switch (self) { - inline else => |tag| @alignOf(std.meta.TagPayload(Data, tag)), + inline else => |tag| @alignOf(tag.toZigType()), }; } diff --git a/zml/exe.zig b/zml/exe.zig index 946410c..2292b70 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -1,13 +1,13 @@ const std = @import("std"); + const stdx = @import("stdx"); const aio = @import("aio.zig"); -const meta = @import("meta.zig"); -const pjrt = @import("pjrtx.zig"); - const Buffer = @import("buffer.zig").Buffer; const Bufferized = @import("tensor.zig").Bufferized; const CompilationContext = @import("module.zig").CompilationContext; +const meta = @import("meta.zig"); +const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; const ShapeOf = @import("tensor.zig").ShapeOf; @@ -147,6 +147,7 @@ pub const BaseExe = struct { /// Total number of buffers needed by this executable. input_buffer_count: u32, + input_shapes: []Shape, result_shapes: []Shape, /// Num devices used (>1 for sharded executable) @@ -155,34 +156,44 @@ pub const BaseExe = struct { /// Allocator backing memory _arena: std.heap.ArenaAllocator, - pub fn init(parent_allocator: std.mem.Allocator, platform: Platform, exe: *pjrt.LoadedExecutable, args: struct { n_in: u32, result_shapes: []const Shape, n_devices: u8 }) !BaseExe { + pub fn init( + parent_allocator: std.mem.Allocator, + platform: Platform, + exe: *pjrt.LoadedExecutable, + args: struct { input_shapes: []const Shape, result_shapes: []const Shape, n_devices: u8 }, + ) !BaseExe { var arena = std.heap.ArenaAllocator.init(parent_allocator); errdefer arena.deinit(); const allocator = arena.allocator(); + const n_in = args.input_shapes.len; const n_out = args.result_shapes.len; const n_devices = args.n_devices; // Allocate once for all the *pjrt.Buffer we need to store ... 
- const all_buffers = try allocator.alloc(*pjrt.Buffer, (args.n_in + n_out) * n_devices); - const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ args.n_in * n_devices, n_out * n_devices }); + const all_buffers = try allocator.alloc(*pjrt.Buffer, (n_in + n_out) * n_devices); + const all_input_buffers, const all_output_buffers = splitBuffer(*pjrt.Buffer, all_buffers, .{ n_in * n_devices, n_out * n_devices }); // ... and once for all the [*]*pjrt.Buffer. const all_per_device = try allocator.alloc([*]*pjrt.Buffer, 2 * n_devices); const input_per_device, const output_per_device = splitBuffer([*]*pjrt.Buffer, all_per_device, .{ n_devices, n_devices }); for (0..n_devices) |i| { - input_per_device[i] = all_input_buffers[i * args.n_in ..].ptr; + input_per_device[i] = all_input_buffers[i * n_in ..].ptr; output_per_device[i] = all_output_buffers[i * n_out ..].ptr; } + const all_shapes = try allocator.alloc(Shape, n_in + n_out); + @memcpy(all_shapes[0..n_in], args.input_shapes); + @memcpy(all_shapes[n_in..], args.result_shapes); return .{ .platform = platform, .exe = exe, .ready_buffer_count = 0, - .input_buffer_count = args.n_in, + .input_buffer_count = @intCast(n_in), .num_devices = args.n_devices, .input_per_device = input_per_device, .output_per_device = output_per_device, - .result_shapes = try allocator.dupe(Shape, args.result_shapes), + .input_shapes = all_shapes[0..n_in], + .result_shapes = all_shapes[n_in..], ._arena = arena, }; } @@ -209,7 +220,9 @@ pub const BaseExe = struct { // even if it has been marked as "can be donated" during compilation. // TODO: expose it ? 
.non_donatable_input_indices = &.{}, - }) catch unreachable; + }) catch |err| { + std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err}); + }; for (events[0..sharding.num_partitions]) |e| { if (e) |ev| { @@ -232,7 +245,7 @@ pub const BaseExe = struct { // } pub fn prepare(self: *BaseExe, x: anytype) void { - const n = fillBuffers(&x, self.input_per_device, self.ready_buffer_count); + const n = fillBuffers(&x, self.input_shapes, self.input_per_device, self.ready_buffer_count); self.ready_buffer_count += n; } @@ -244,6 +257,14 @@ pub const BaseExe = struct { return Buffer.fromPjrtBuffers(self.platform, self.result_shapes[i], shards.constSlice()); } + + pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe { + return .init(parent_allocator, self.platform, self.exe, .{ + .input_shapes = self.input_shapes, + .result_shapes = self.result_shapes, + .n_devices = self.num_devices, + }); + } }; /// Represents a ZML function, compiled into a PJRT executable. @@ -280,7 +301,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type { } pub fn call(self: Self, args: Bufferized(ArgsT)) Bufferized(ReturnT) { - const total_ready = fillBuffers(&args, self.inner.input_per_device, self.inner.ready_buffer_count); + const total_ready = fillBuffers(&args, self.inner.input_shapes, self.inner.input_per_device, self.inner.ready_buffer_count); std.debug.assert(total_ready == self.inner.input_buffer_count); self.inner._unsafeCall(); var result: Bufferized(ReturnT) = undefined; @@ -302,20 +323,23 @@ fn splitBuffer(T: type, buffer: []T, lengths: anytype) [lengths.len][]T { } /// Visit the given struct and fill the `buffers` slice with the buffer associated with encountered Tensor. 
-fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32) u32 { +fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buffer, start: u32) u32 { const LocalContext = struct { index: u32, buffers: []const [*]*pjrt.Buffer, + shapes: []const Shape, }; var context: LocalContext = .{ .index = start, .buffers = buffers, + .shapes = shapes, }; meta.visit((struct { fn cb(ctx: *LocalContext, buffer: *const Buffer) void { // stdx.debug.assert(!buffer._data.isDeleted(), "Can't use {} (argument buffer {}) because its pjrt buffer has been donated", .{ buffer, ctx.index }); const model_sharding = ctx.buffers.len; stdx.debug.assert(buffer._shards.len == model_sharding, "Can't feed a {}-sharded tensor into a {}-sharded model", .{ buffer._shards.len, ctx.buffers.len }); + stdx.debug.assert(ctx.shapes[ctx.index].eql(buffer.shape()), "Executable expected argument {} to have shape {}, got {}", .{ ctx.index, ctx.shapes[ctx.index], buffer.shape() }); for (buffer._shards.constSlice(), 0..) |shard, d| { ctx.buffers[d][ctx.index] = shard; } diff --git a/zml/floats.zig b/zml/floats.zig index 5fa7cde..74a8b99 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -305,11 +305,23 @@ pub const BFloat16 = packed struct(u16) { pub fn isInf(self: BFloat16) bool { return allBitsOne(self.exponent) and self.mantissa == 0; } + + pub fn toF32(self: BFloat16) f32 { + // Pad the BF16 with zeros 0 + return @bitCast([2]u16{ 0, @bitCast(self) }); + } + + pub fn fromF32(float32: f32) BFloat16 { + var int: u32 = @bitCast(float32); + // Round up if needed. 
+ int += 0x8000; + const parts: [2]u16 = @bitCast(int); + return @bitCast(parts[1]); + } + const Helpers = FloatHelpers(@This()); pub const zero = Helpers.zero; pub const neg = Helpers.neg; - pub const fromF32 = Helpers.fromF32; - pub const toF32 = Helpers.toF32; pub const format = Helpers.format; }; @@ -317,7 +329,7 @@ test BFloat16 { // From https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples try std.testing.expectEqual(BFloat16.fromF32(0), BFloat16{ .sign = 0, .exponent = 0, .mantissa = 0 }); try std.testing.expectEqual(BFloat16.fromF32(-2), BFloat16{ .sign = 1, .exponent = 127 + 1, .mantissa = 0 }); - try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 65 }); + try std.testing.expectEqual(BFloat16.fromF32(3.02344107628), BFloat16{ .sign = 0, .exponent = 127 + 1, .mantissa = 66 }); try std.testing.expectEqual(BFloat16.fromF32(1.0 / 128.0), BFloat16{ .sign = 0, .exponent = 127 - 7, .mantissa = 0 }); try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf.neg()), [_]u8{ 0x80, 0xff }); try std.testing.expectEqual(BFloat16.inf, BFloat16.fromF32(std.math.inf(f32))); diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 39c6d04..5e8697c 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -18,8 +18,8 @@ test { /// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere). pub const HostBuffer = struct { _shape: Shape, - _strides: ?[Shape.MAX_RANK]i64 = null, - data: []const u8, + _strides: [Shape.MAX_RANK]i64, + _data: [*]const u8, _memory: union(enum) { managed: std.mem.Alignment, unmanaged, @@ -28,10 +28,11 @@ pub const HostBuffer = struct { /// Allocates a HostBuffer with the given shape. /// The memory is left undefined. /// The caller owns the memory, and need to call `deinit()`. 
- pub fn empty(allocator: std.mem.Allocator, sh: Shape) !HostBuffer { + pub fn empty(allocator: std.mem.Allocator, sh: Shape) error{OutOfMemory}!HostBuffer { return .{ ._shape = sh, - .data = try allocator.alignedAlloc(u8, 64, sh.byteSize()), + ._strides = sh.computeStrides().buffer, + ._data = (try allocator.alignedAlloc(u8, 64, sh.byteSize())).ptr, ._memory = .{ .managed = .@"64" }, }; } @@ -43,7 +44,8 @@ pub const HostBuffer = struct { stdx.debug.assert(shape_.byteSize() == data_.len, "shape {} and data {} don't match", .{ shape_.byteSize(), data_.len }); return .{ ._shape = shape_, - .data = data_, + ._strides = shape_.computeStrides().buffer, + ._data = data_.ptr, ._memory = .unmanaged, }; } @@ -53,7 +55,7 @@ pub const HostBuffer = struct { // This means we don't own the data. if (self._memory == .unmanaged) return; const log2_align = self._memory.managed; - allocator.rawFree(@constCast(self.data), log2_align, @returnAddress()); + allocator.rawFree(self.mutBytes(), log2_align, @returnAddress()); } /// Wraps an exisiting slice into a HostBuffer. @@ -62,10 +64,12 @@ pub const HostBuffer = struct { /// that will still need to be deallocated. pub fn fromSlice(sh: anytype, s: anytype) HostBuffer { const shape_ = Shape.init(sh, DataType.fromSliceElementType(s)); - std.debug.assert(shape_.count() == s.len); + const raw_bytes = std.mem.sliceAsBytes(s); + std.debug.assert(shape_.byteSize() == raw_bytes.len); return .{ ._shape = shape_, - .data = @alignCast(std.mem.sliceAsBytes(s)), + ._strides = shape_.computeStrides().buffer, + ._data = raw_bytes.ptr, ._memory = .unmanaged, }; } @@ -81,7 +85,7 @@ pub const HostBuffer = struct { @memcpy(tmp[0..strides_.len], strides_); return .{ ._shape = sh, - .data = @alignCast(std.mem.sliceAsBytes(s)), + ._data = @alignCast(std.mem.sliceAsBytes(s).ptr), ._strides = tmp, ._memory = .unmanaged, }; @@ -89,13 +93,15 @@ pub const HostBuffer = struct { /// Creates a tensor from a **pointer** to a "multi dimension" array. 
/// Note this doesn't copy, the pointee array need to survive the `HostBuffer` object. + /// Typically this is use with constant arrays. pub fn fromArray(arr_ptr: anytype) HostBuffer { const T = @TypeOf(arr_ptr.*); const sh = parseArrayInfo(T); + std.debug.assert(sh.byteSize() == @sizeOf(T)); return .{ ._shape = sh, - .data = @alignCast(std.mem.sliceAsBytes(arr_ptr)), - // Array are typically stack allocated and don't need to be freed. + ._strides = sh.computeStrides().buffer, + ._data = @ptrCast(arr_ptr), ._memory = .unmanaged, }; } @@ -121,16 +127,15 @@ pub const HostBuffer = struct { stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; - const b = dt.sizeOf(); const res = try empty(allocator, Shape.init(.{n_steps}, dt)); - stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt}); - var data_ = @constCast(res.data); switch (dt) { - inline else => { + inline else => |d| if (comptime d.class() != .integer) { + stdx.debug.assert(dt.class() == .integer, "arange expects type to be integer, got {} instead.", .{dt}); + } else { + const Zt = d.toZigType(); var j: i64 = args.start; - for (0..@intCast(n_steps)) |i| { - var v = Data.init(dt, j); - @memcpy(data_[i * b .. (i + 1) * b], v.constSlice()); + for (res.mutItems(Zt)) |*val| { + val.* = @intCast(j); j +%= args.step; } }, @@ -160,16 +165,26 @@ pub const HostBuffer = struct { /// WARNING: It's only valid if the buffer is contiguous. /// Strided buffers can't use this method. 
pub fn items(self: HostBuffer, comptime T: type) []const T { - if (DataType.fromZigType(T) != self.dtype()) { - std.debug.panic("Can't reinterpret {} as {s}", .{ self, @typeName(T) }); - } - if (!self.isContiguous()) { - std.debug.panic("{} isn't contiguous", .{self}); - } - const ptr: [*]const T = @alignCast(@ptrCast(self.data.ptr)); + // TODO we should allow interpreting the output as @Vector(8, f32) when the tensor is f32. + stdx.debug.assert(DataType.fromZigType(T) == self.dtype(), "Can't reinterpret {} as {s}", .{ self, @typeName(T) }); + stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); + const ptr: [*]const T = @alignCast(@ptrCast(self._data)); return ptr[0..self._shape.count()]; } + pub fn mutItems(self: HostBuffer, comptime T: type) []T { + return @constCast(self.items(T)); + } + + pub fn bytes(self: HostBuffer) []const u8 { + stdx.debug.assert(self.isContiguous(), "{} isn't contiguous, can't interpret as []const u8", .{self}); + return self._data[0..self._shape.byteSize()]; + } + + pub fn mutBytes(self: HostBuffer) []u8 { + return @constCast(self.bytes()); + } + pub fn shape(self: HostBuffer) Shape { return self._shape; } @@ -178,9 +193,9 @@ pub const HostBuffer = struct { return self._shape.dtype(); } - pub fn strides(self: *const HostBuffer) ?[]const i64 { + pub fn strides(self: *const HostBuffer) []const i64 { // Pass strides per pointer otherwise we return a pointer to this stack frame. 
- return if (self._strides) |*strd| strd[0..self.rank()] else null; + return self._strides[0..self._shape.rank()]; } // TODO: rename .data into ._data and make it a [*]u8 @@ -205,7 +220,7 @@ pub const HostBuffer = struct { } pub fn isContiguous(self: HostBuffer) bool { - const _strides = self._strides orelse return true; + const _strides = self._strides; const cont_strides = self._shape.computeStrides(); for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| { if (d != 1 and stride != cont_stride) return false; @@ -217,6 +232,7 @@ pub const HostBuffer = struct { stdx.debug.assert(self.isContiguous(), "reshape expects a contiguous tensor, got: {}", .{self}); var res = self; res._shape = self._shape.reshape(shape_); + res._strides = res._shape.computeStrides().buffer; return res; } @@ -236,15 +252,12 @@ pub const HostBuffer = struct { stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s }); stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s }); - // If strides weren't set it means original buffer is contiguous. - // But it won't be anymore after slicing. The strides don't change though. - const _strides = self._strides orelse self._shape.computeStrides().buffer; - const offset: usize = @intCast(start * _strides[ax]); + const offset: usize = @intCast(start * self._strides[ax]); + const new_shape = self.shape().set(ax, end - start); return .{ - ._shape = self.shape().set(ax, end - start), - .data = self.data[offset..], - // When axis is 0, we stay contiguous. 
- ._strides = if (ax == 0) self._strides else _strides, + ._shape = new_shape, + ._data = self._data[offset..], + ._strides = self._strides, ._memory = .unmanaged, }; } @@ -254,18 +267,52 @@ pub const HostBuffer = struct { return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax); } + pub fn choose(self: HostBuffer, offsets: anytype) HostBuffer { + const off, const tags = Shape.parseDimensions(offsets); + var sh = self._shape; + var offset: i64 = 0; + for (off.constSlice(), tags.constSlice()) |o, t| { + const ax = sh.axis(t); + offset += o * self._strides[ax]; + sh._dims.buffer[ax] = 0; + } + + var new_strides: [Shape.MAX_RANK]i64 = @splat(self.dtype().sizeOf()); + + // TODO rewrite with simd. This is a pshuf, but it's not supported by @shuffle. + var res_ax: u32 = 0; + for (0..self._shape.rank()) |ax| { + if (sh._dims.buffer[ax] == 0) { + continue; + } + + sh._dims.buffer[res_ax] = self._shape._dims.buffer[ax]; + sh._tags.buffer[res_ax] = self._shape._tags.buffer[ax]; + new_strides[res_ax] = self._strides[ax]; + res_ax += 1; + } + sh._dims.len -= off.len; + sh._tags.len -= off.len; + + return HostBuffer{ + ._shape = sh, + ._strides = new_strides, + ._data = self._data[@intCast(offset)..], + ._memory = .unmanaged, + }; + } + pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { const ax = self._shape.axis(axis_); stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); - var _strides: ?[Shape.MAX_RANK]i64 = self._strides; - if (self._strides) |strydes| { - std.mem.copyForwards(i64, _strides.?[0 .. 
Shape.MAX_RANK - 1], strydes[1..]); - } + var strd: std.BoundedArray(i64, Shape.MAX_RANK) = .{ .buffer = self._strides, .len = self.rank() }; + _ = strd.orderedRemove(ax); + return .{ ._shape = self.shape().drop(ax), - .data = self.data, - ._strides = _strides, + ._data = self._data, + ._strides = strd.buffer, ._memory = self._memory, }; } @@ -276,9 +323,12 @@ pub const HostBuffer = struct { options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = fmt; _ = options; - try writer.print("HostBuffer(.{_})", .{self._shape}); + if (std.mem.eql(u8, fmt, "v")) { + try writer.print("HostBuffer(.{_})@0x{x}", .{ self._shape, @intFromPtr(self._data) }); + } else { + try writer.print("HostBuffer(.{_})", .{self._shape}); + } } /// Formatter for a HostBuffer that also print the values not just the shape. diff --git a/zml/meta.zig b/zml/meta.zig index 16680eb..b3dcf68 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -237,42 +237,48 @@ test mapAlloc { /// Recursively visit the given struct and calls the callback for each K found. /// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it. 
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { - const T = @TypeOf(v); - const type_info_v = @typeInfo(T); - const K = switch (@typeInfo(FnParam(cb, 1))) { - .pointer => |info| info.child, - else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}), - }; - + const Callback = @TypeOf(cb); + const Ptr = @TypeOf(v); + const type_info_v = @typeInfo(Ptr); if (type_info_v != .pointer) { - const Callback = @TypeOf(cb); - stdx.debug.compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: {} but received: {}", .{ Callback, T }); + stdx.debug.compileError("zml.meta.visit({}) is expecting a pointer/slice input, but received: {}", .{ Callback, Ptr }); } const ptr_info = type_info_v.pointer; - if (@typeInfo(ptr_info.child) == .@"fn") return; - if (ptr_info.child == anyopaque) return; - // This is important, because with trivial types like void, - // Zig sometimes decide to call `visit` at comptime, but can't do - // the pointer wrangling logic at comptime. - // So we detect early this case and return. - if (@sizeOf(ptr_info.child) == 0) return; + const Child = ptr_info.child; + const K, const mutating_cb = switch (@typeInfo(FnParam(cb, 1))) { + .pointer => |info| .{ info.child, !info.is_const }, + else => stdx.debug.compileError("zml.meta.visit is expecting a callback with a pointer as second argument but found {}", .{FnParam(cb, 1)}), + }; + // Abort if v doesn't contain any K. + if (comptime !Contains(Ptr, K)) return; + + // Handle simple cases. + switch (Ptr) { + *const K, *K => return cb(ctx, v), + *const ?K, *?K => return if (v.*) |*val| cb(ctx, val) else {}, + []const K, []K => { + for (v) |*v_elem| cb(ctx, v_elem); + return; + }, + else => {}, + } + + // Handle std.BoundedArray that contains uninitialized data. 
+ if (@typeInfo(Child) == .@"struct" and @hasDecl(Child, "constSlice") and @hasDecl(Child, "slice")) { + return visit(cb, ctx, if (mutating_cb) v.slice() else v.constSlice()); + } + + // Recursively visit fields of v. switch (ptr_info.size) { - // If we have a single pointer, two cases: - // * It's a pointer to K, in which case we call the callback. - // * It's a pointer to something else, in which case, we explore and recurse if needed. - .one => if (ptr_info.child == K) { - cb(ctx, v); - } else if (ptr_info.child == ?K) { - if (v.*) |*val| cb(ctx, val); - } else switch (@typeInfo(ptr_info.child)) { - .@"struct" => |s| inline for (s.fields) |field_info| { - if (field_info.is_comptime) continue; - const field_type_info = @typeInfo(field_info.type); + .one => switch (@typeInfo(Child)) { + .@"struct" => |s| inline for (s.fields) |field| { + if (field.is_comptime or comptime !Contains(field.type, K)) continue; + const field_type_info = @typeInfo(field.type); // If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field. switch (field_type_info) { - .pointer => visit(cb, ctx, @field(v, field_info.name)), - .array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field_info.name)), + .pointer => visit(cb, ctx, @field(v, field.name)), + .array, .optional, .@"union", .@"struct" => visit(cb, ctx, &@field(v, field.name)), else => {}, } }, @@ -281,23 +287,19 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { .@"union" => switch (v.*) { inline else => |*v_field| visit(cb, ctx, v_field), }, - else => {}, + else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }), }, - // If we have a slice, two cases also: - // * It's a slice of K, in which case we call the callback for each element of the slice. - // * It's a slice to something else, in which case, for each element we explore and recurse if needed. 
.slice => { for (v) |*v_elem| { - if (ptr_info.child == K) { - cb(ctx, v_elem); - } else switch (@typeInfo(ptr_info.child)) { - .@"struct" => |s| inline for (s.fields) |field_info| { - const field_type_info = @typeInfo(field_info.type); + switch (@typeInfo(Child)) { + .@"struct" => |s| inline for (s.fields) |field| { + if (field.is_comptime or comptime !Contains(field.type, K)) continue; + const field_type_info = @typeInfo(field.type); // If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field. if (field_type_info == .pointer) { - visit(cb, ctx, @field(v_elem, field_info.name)); + visit(cb, ctx, @field(v_elem, field.name)); } else { - visit(cb, ctx, &@field(v_elem, field_info.name)); + visit(cb, ctx, &@field(v_elem, field.name)); } }, .array => |_| for (v) |*elem| visit(cb, ctx, elem), @@ -305,11 +307,11 @@ pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void { .@"union" => switch (v_elem.*) { inline else => |*v_field| visit(cb, ctx, v_field), }, - else => {}, + else => stdx.debug.compileError("zml.meta.visit({}) doesn't support fields of type: {}", .{ Callback, Child }), } } }, - else => {}, + .many, .c => stdx.debug.compileError("zml.meta.visit({}) doesn't support [*] style pointers, got: {}", .{ Callback, Ptr }), } } @@ -320,7 +322,7 @@ test visit { const NestedAttrOptional = struct { nested: ?Attr }; const SimpleStruct = struct { prop: Attr }; const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr }; - const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional }; + const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional, prop5: std.BoundedArray(Attr, 8) }; const LocalContext = struct { result: usize }; @@ -374,11 +376,16 @@ test visit { } { var context: LocalContext = .{ .result = 0 }; + const prop5: std.BoundedArray(Attr, 8) = .{ + .buffer = @splat(.{ 
.data = 4 }), + .len = 2, + }; const container: NestedTypesStruct = .{ .prop1 = .{ .data = 1 }, .prop2 = .{ .other = "hello" }, .prop3 = .{ .nested = .{ .data = 2 } }, .prop4 = .{ .nested = .{ .data = 3 } }, + .prop5 = prop5, // 4 will be counted twice. }; visit((struct { @@ -387,7 +394,7 @@ test visit { } }).cb, &context, &container); - try std.testing.expectEqual(6, context.result); + try std.testing.expectEqual(14, context.result); } } @@ -533,3 +540,36 @@ fn _CollectArg(func: anytype) type { const params = @typeInfo(@TypeOf(func)).@"fn".params; return params[params.len - 1].type orelse @compileError("anytype not supported in collect"); } + +pub fn Contains(Haystack: type, T: type) bool { + switch (Haystack) { + T, ?T => return true, + *T, ?*T => return true, + *const T, ?*const T => return true, + []const T, ?[]const T => return true, + anyopaque => return false, + else => {}, + } + + return switch (@typeInfo(Haystack)) { + .@"struct" => |info| { + inline for (info.fields) |field| { + if (Contains(field.type, T)) + return true; + } + return false; + }, + .@"union" => |info| { + inline for (info.fields) |field| { + if (Contains(field.type, T)) + return true; + } + return false; + }, + .array => |info| Contains(info.child, T), + .pointer => |info| Contains(info.child, T), + .optional => |info| Contains(info.child, T), + .vector => |info| Contains(info.child, T), + else => false, + }; +} diff --git a/zml/module.zig b/zml/module.zig index 011e6fe..798b81c 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -29,7 +29,7 @@ test { pub const MlirFn = struct { name: []const u8, - num_args: u32, + args_shapes: []Shape, res_tensors: *const anyopaque, res_types: []mlir.Type, res_shapes: []Shape, @@ -199,7 +199,7 @@ pub const CompilationContext = struct { const loaded_executable: *pjrt.LoadedExecutable = blk: { if (pjrt_location) |pjrt_loc| { if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| { - log.info("Loaded pre-compiled module from {s}", .{pjrt_loc}); + 
log.info("Loaded pre-compiled module from {s} (generated from {s}/module.mlir)", .{ pjrt_loc, module_dir.? }); break :blk exe; } else |err| { if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc }); @@ -233,7 +233,7 @@ pub const CompilationContext = struct { self._platform, loaded_executable, .{ - .n_in = f.num_args, + .input_shapes = f.args_shapes, .result_shapes = f.res_shapes, .n_devices = sharding.num_replicas * sharding.num_partitions, }, @@ -341,7 +341,7 @@ pub const CompilationContext = struct { const locations = try arena.alloc(mlir.Location, tensor_count); @memset(locations, mlir.Location.unknown(mlir_ctx)); - var input_shapes = try std.ArrayList(Shape).initCapacity(arena, tensor_count); + var input_shapes = try std.ArrayList(Shape).initCapacity(res_allocator, tensor_count); meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); @@ -427,7 +427,7 @@ pub const CompilationContext = struct { return .{ .mlir_fn = mlir_fn, .name = opts.name, - .num_args = @intCast(tensor_count), + .args_shapes = input_shapes.items, .res_tensors = fn_res, .res_types = fn_res_types, .res_shapes = fn_res_shapes, @@ -512,7 +512,7 @@ pub const CompilationContext = struct { // Check that the `x` input argument gives its buffer to the result tensor. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. - try std.testing.expectEqual(3, f.num_args); + try std.testing.expectEqual(3, f.args_shapes.len); // We should have two buffers being donated. 
const template = "tf.aliasing_output = {d} : i32"; var buf = template.*; @@ -540,9 +540,13 @@ pub const CompilationContext = struct { } } + pub fn numPartitions(self: CompilationContext) u8 { + return self._platform.sharding().num_partitions; + } + pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { const ctx = self.mlirCtx(); - const num_partitions = self._platform.sharding().num_partitions; + const num_partitions = self.numPartitions(); var sharding_str: std.BoundedArray(u8, 128) = .{}; writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; return mlir.Attribute.string(ctx, sharding_str.constSlice()); @@ -645,10 +649,11 @@ pub const CompilationContext = struct { const loc = self.mlirCtx().location(@src()); - const values = try arena.alloc(mlir.Value, function.num_args); + const num_args = function.args_shapes.len; + const values = try arena.alloc(mlir.Value, num_args); self.extractValues(&args, values); - const donations = try arena.alloc(Tensor._Donation, function.num_args); + const donations = try arena.alloc(Tensor._Donation, num_args); meta.collectBuf(struct { pub fn cb(ctx: *const CompilationContext, x: Tensor) Tensor._Donation { return ctx.getValueAndDonation(x)[1]; diff --git a/zml/nn.zig b/zml/nn.zig index 104d793..c6a8560 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -176,6 +176,8 @@ pub const RopeOpts = struct { /// Read a Rope scaling config from HF config.json format. 
pub fn jsonParse(allocator: std.mem.Allocator, source: anytype, options: std.json.ParseOptions) !Scaling { const content = try std.json.Value.jsonParse(allocator, source, options); + if (content == .null) return .default; + if (content != .object) return error.InvalidEnumTag; const obj = content.object; diff --git a/zml/shape.zig b/zml/shape.zig index 8da8519..ce16ec9 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -58,10 +58,10 @@ pub const Shape = struct { const fv = @field(v, field.name); if (comptime stdx.meta.isInteger(field.type)) { dims_.appendAssumeCapacity(@intCast(fv)); - } else if (comptime isAutoDim(fv)) { + } else if (@TypeOf(fv) == EnumLiteral and comptime isAutoDim(fv)) { dims_.appendAssumeCapacity(-1); } else { - stdx.debug.compileError("Field {s} should be an integer or an auto dimension", .{field.name}); + stdx.debug.compileError("Field {s} should be an integer or an auto dimension, got {}", .{ field.name, field.type }); } if (comptime stdx.meta.isTuple(T)) { tags_.appendAssumeCapacity(TagUnknown); @@ -186,7 +186,7 @@ pub const Shape = struct { EnumLiteral => @tagName(v).ptr, std.builtin.Type.StructField => v.name.ptr, Tag => v, - else => stdx.debug.compileError("Value should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), + else => stdx.debug.compileError("Shape tag should be an EnumLiteral, a Shape.Tag or a StructField, got {}", .{T}), }; } @@ -581,6 +581,41 @@ pub const Shape = struct { try std.testing.expectEqualSlices(i64, &.{ 10, 11, 12 }, Shape.init(.{ 10, 11, 12, 13 }, .f32).remove(-1).dims()); } + pub fn removeMany(self: Shape, axes_: anytype) Shape { + var to_remove = self.axes(axes_); + if (to_remove.len == 0) return self; + std.mem.sort(u3, to_remove.slice(), {}, std.sort.asc(u3)); + + var sh: Shape = self; + const rk = self.rank(); + var res_ax: u32 = 0; + for (0..rk) |ax| { + if (std.mem.indexOfScalar(u3, to_remove.constSlice(), @intCast(ax))) |_| { + continue; + } + + sh._dims.buffer[res_ax] = 
self._dims.buffer[ax]; + sh._tags.buffer[res_ax] = self._tags.buffer[ax]; + res_ax += 1; + } + sh._dims.len = rk - to_remove.len; + sh._tags.len = rk - to_remove.len; + return sh; + } + + test removeMany { + try std.testing.expectEqualSlices( + i64, + &.{12}, + Shape.init(.{ 10, 11, 12 }, .f32).removeMany(.{ 0, 1 }).dims(), + ); + try std.testing.expectEqualSlices( + i64, + &.{ 10, 11 }, + Shape.init(.{ 10, 11, 12, 13 }, .f32).removeMany(.{ -1, -2 }).dims(), + ); + } + pub fn transpose(self: Shape, permutations: anytype) Shape { std.debug.assert(self.rank() == permutations.len); const permutations_ = self.axes(permutations); @@ -729,7 +764,9 @@ pub const Shape = struct { stdx.debug.assertComptime(stdx.meta.isStructOfAny(T, isAxisConvertible), "Must pass a struct of enum literals. Passed: {any}", .{T}); var res = self; inline for (std.meta.fields(T)) |field| { - res._tags.set(self.axis(field), toTag(@field(renames, field.name))); + const new_field = @field(renames, field.name); + stdx.debug.assert(self.hasTag(new_field) == null, "{}.rename({any}) failed because of duplicated axis {}", .{ self, renames, new_field }); + res._tags.set(self.axis(field), toTag(new_field)); } return res; } @@ -749,15 +786,20 @@ pub const Shape = struct { } pub fn computeStrides(self: Shape) std.BoundedArray(i64, MAX_RANK) { - const base_stride = self.dtype().sizeOf(); const rk = self.rank(); - var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = @intCast(self.rank()) }; + var strides: std.BoundedArray(i64, MAX_RANK) = .{ .len = rk }; if (rk == 0) return strides; - strides.buffer[rk - 1] = base_stride; - for (1..rk) |i| { - const j = @as(usize, rk) - 1 - i; - strides.buffer[j] = self._dims.get(j + 1) * strides.buffer[j + 1]; - } + + const V = @Vector(MAX_RANK, i64); + const rank_mask = std.simd.iota(u8, MAX_RANK) < @as(@Vector(MAX_RANK, u8), @splat(rk)); + // For each axis compute the product of all following dimensions + // and the element size in bytes. 
+ var d: V = @bitCast(self._dims.buffer); + d = @select(i64, rank_mask, d, @as(V, @splat(1))); + d = std.simd.shiftElementsLeft(d, 1, self.dtype().sizeOf()); + d = std.simd.prefixScan(.Mul, -1, d); + + strides.buffer = @bitCast(d); return strides; } diff --git a/zml/tensor.zig b/zml/tensor.zig index 79c2f2c..ae06a77 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -176,6 +176,7 @@ pub const Tensor = struct { var res = self; res._shape = self._shape.withSharding(axes_); + if (ctx.numPartitions() <= 1) return self; const op = dialect.stablehlo.custom_call( mlir_ctx, &.{self.value()}, @@ -1279,9 +1280,9 @@ pub const Tensor = struct { /// see: https://paperswithcode.com/method/gelu pub fn gelu(x: Tensor) Tensor { const scaled_x_cube = x.mul(x).mul(x).scale(0.044715); - const one = Tensor.constant(x._shape, x.dtype().one()); - const one_plus_tanh = Tensor.add(x, scaled_x_cube).scale(std.math.sqrt(2.0 / std.math.pi)).tanh().add(one); - return one_plus_tanh.mul(x).scale(0.5); + const beta = std.math.sqrt(2.0 / std.math.pi); + const tanh_ = x.add(scaled_x_cube).scale(beta).tanh(); + return tanh_.addConstant(1).mul(x).scale(0.5); } /// Returns a Tensor containing an approximation of the Gaussian Error Linear Units (GeLU) activation function applied to each element of the input Tensor. 
@@ -1526,8 +1527,34 @@ pub const Tensor = struct { pub const Slice = struct { start: i64 = 0, - end: ?i64 = null, - step: i64 = 1, + end: i64 = to_the_end, + step: i32 = 1, + singleton: bool = false, + + pub fn single(offset: i64) Slice { + return .{ .start = offset, .end = offset + 1, .singleton = true }; + } + + const to_the_end = std.math.maxInt(i64); + + pub fn format( + self: Slice, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + if (self.singleton) { + try writer.print("[{}]", .{self.start}); + } else if (self.end == to_the_end and self.step == 1) { + try writer.print("[{}..]", .{self.start}); + } else if (self.step == 1) { + try writer.print("[{}..{}]", .{ self.start, self.end }); + } else { + try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step }); + } + } }; /// Slices the input Tensor over the given axis using the given parameters. @@ -1549,13 +1576,13 @@ pub const Tensor = struct { const args: Slice = .{ .start = self.wrapIndex(a, s.start), - .end = if (s.end) |end| self.wrapIndex(a, end) else self.dim(a), + .end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end), .step = s.step, }; start_indices[a] = args.start; - limit_indices[a] = args.end.?; + limit_indices[a] = args.end; strides[a] = args.step; - res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end.? - args.start, args.step) catch unreachable); + res_shape = res_shape.setDim(a, std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable); } const mlir_ctx = self.getContext().mlirCtx(); @@ -1571,7 +1598,12 @@ pub const Tensor = struct { loc, ); - return _result(res_shape, slice_op.result(0)); + var res = _result(res_shape, slice_op.result(0)); + var to_remove: Shape.AxesArray = .{}; + for (slices, 0..) 
|s, a| { + if (s.singleton) to_remove.appendAssumeCapacity(@intCast(a)); + } + return res.reshape(res_shape.removeMany(to_remove.constSlice())); } test slice { @@ -1606,8 +1638,17 @@ pub const Tensor = struct { } pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor { - // TODO: this use case could be handled directly by slice if we added a .single field - return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_); + return self.slice1d(axis_, .single(i)); + } + + pub fn choose(self: Tensor, offsets: anytype) Tensor { + const off, const tags = Shape.parseDimensions(offsets); + var slices = [_]Slice{.{}} ** MAX_RANK; + for (off.constSlice(), tags.constSlice()) |o, t| { + const ax = self.axis(t); + slices[ax] = .single(o); + } + return self.slice(slices[0..self.rank()]); } /// Concatenates the input Tensors along the given axis. @@ -1866,7 +1907,12 @@ pub const Tensor = struct { /// Returns a 0-rank Tensor with the given value. pub fn scalar(val: anytype, dt: DataType) Tensor { - return Tensor.constant(.{}, Data.init(dt, val)); + const data = Data.init(dt, val); + switch (dt.class()) { + .float => stdx.debug.assert(!std.math.isNan(val), "scalar(NaN) is probably due to compiling a model with an uninitialized field", .{}), + else => {}, + } + return Tensor.constant(.{}, data); } test scalar { @@ -1913,7 +1959,7 @@ pub const Tensor = struct { const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape()); const loc = ctx.location(@src()); const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); - const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc); + const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.bytes(), loc); return _result(val.shape(), constant_op.result(0)); } @@ -3786,6 +3832,7 @@ pub const Tensor = struct { /// Only for debug 
purpose, it inserts device to host synchronization /// so it will slow down the program execution. pub fn print(input: Tensor) Tensor { + // TODO: find a way of doing print that doesn't involve an H2D copy. return ops.addHostCallback( &printCallback, null, @@ -3797,8 +3844,10 @@ pub const Tensor = struct { fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { const host_buffer = inputs[0]; - std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); - std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr); + std.log.defaultLog(.info, .zml, "Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); + // This is true because of the operand aliases. + // Since the result is already pointing to the input we don't need to modify the buffer. + std.debug.assert(host_buffer._data == outputs[0]._data); } }; @@ -3918,6 +3967,10 @@ test "Tensor.maxPool2d" { /// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. pub fn Bufferized(comptime T: type) type { + // TODO: we should strip out the non-buffer fields. + // Currently it's confusing because the Bufferized struct contains fields that are never read. + // Also it will simplify the layout of the Bufferized struct, + // accelerating the calls to execute. return meta.MapType(Tensor, Buffer).map(T); } diff --git a/zml/testing.zig b/zml/testing.zig index b0d6827..8946f78 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -1,10 +1,11 @@ -const builtin = @import("builtin"); const std = @import("std"); +const builtin = @import("builtin"); + const stdx = @import("stdx"); -const zml = @import("zml.zig"); const meta = @import("meta.zig"); const shapesOf = @import("tensor.zig").shapesOf; +const zml = @import("zml.zig"); const log = std.log.scoped(.@"zml/testing"); @@ -35,7 +36,7 @@ pub fn approxEq(comptime Float: type, l: Float, r: Float, tolerance: Float) bool /// Testing utility. 
Accepts both Tensor and HostBuffer but Tensor will be copied to the /// host for comparison ! pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void { - const allocator = if (builtin.is_test) std.testing.allocator else std.heap.page_allocator; + const allocator = if (builtin.is_test) std.testing.allocator else std.heap.smp_allocator; var left: zml.HostBuffer, const should_free_left = if (@TypeOf(left_) == zml.Buffer) .{ try left_.toHostAlloc(allocator), true } else