diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 25e4551..a3ef83d 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -360,9 +360,11 @@ pub const Client = opaque { buffer_type: BufferType, dims: []const i64, byte_strides: ?[]const i64, - device: ?*const Device = null, host_buffer_semantics: HostBufferSemantics, - memory: ?*const Memory = null, + dst: union(enum) { + device: *const Device, + memory: *const Memory, + }, }; pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } { @@ -375,11 +377,11 @@ pub const Client = opaque { .byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null, .num_byte_strides = if (args.byte_strides) |bs| bs.len else 0, .host_buffer_semantics = @intFromEnum(args.host_buffer_semantics), - .device = @ptrCast(@constCast(args.device)), - .memory = @ptrCast(@constCast(args.memory)), + .device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null, + .memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null, .device_layout = null, // TODO - .done_with_host_buffer = null, - .buffer = null, + .done_with_host_buffer = null, // out + .buffer = null, // out }); return .{ @@ -430,7 +432,7 @@ pub const Client = opaque { pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory { const ret = api.call(.PJRT_Client_AddressableMemories, .{ .client = self.inner(), - }) catch unreachable; + }) catch return &.{}; if (ret.addressable_memories) |memories| { return @ptrCast(@constCast(memories[0..ret.num_addressable_memories])); } @@ -474,8 +476,10 @@ pub const Client = opaque { dims: []const i64, element_type: BufferType, layout: MemoryLayout, - device: ?*const Device = null, - memory: ?*const Memory = null, + dst: union(enum) { + device: *const Device, + memory: *const Memory, + }, }; pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer { @@ -486,8 +490,8 @@ pub const Client = opaque { .shape_num_dims = @intCast(args.dims.len), .shape_element_type = @intFromEnum(args.element_type), .shape_layout = @ptrCast(&layout), - .device = @ptrCast(@constCast(args.device)), - .memory = @ptrCast(@constCast(args.memory)), + .device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null, + .memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null, }); return @ptrCast(ret.buffer.?); } @@ -530,6 +534,8 @@ pub const MemoryStats = struct { pool_bytes_is_set: bool, // out peak_pool_bytes: u64, // out peak_pool_bytes_is_set: bool, // out + + pub const zeroes = std.mem.zeroes(MemoryStats); }; pub const Device = opaque { @@ -556,10 +562,11 @@ pub const Device = opaque { return @intCast(ret.local_hardware_id); } - pub fn addressableMemories(self: *const Device, api: *const Api) ApiError![]const *Memory { - const ret = try api.call(.PJRT_Device_AddressableMemories, .{ - .device = self.inner(), - }); + pub fn addressableMemories(self: *const Device, api: *const Api) []const *Memory { + const ret = api.call( + .PJRT_Device_AddressableMemories, + .{ .device = self.inner() }, + ) catch return &.{}; return @ptrCast(ret.memories[0..ret.num_memories]); } @@ -728,7 +735,6 @@ pub const LoadedExecutable = opaque { _ = api.call(.PJRT_LoadedExecutable_Destroy, .{ .executable = self.inner(), }) catch {}; - self.* = undefined; } pub fn delete(self: *LoadedExecutable, api: *const Api) void { @@ -759,6 +765,7 @@ pub const LoadedExecutable = opaque { non_donatable_input_indices: []const i64 = &.{}, context: ?*ExecuteContext, }; + pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void { var options = pjrtStruct(c.PJRT_ExecuteOptions{ .send_callbacks = null, @@ -1048,8 +1055,16 @@ pub const Event = opaque { pub const Memory = opaque { pub const Kind = enum { device, - pinned_host, - unpinned_host, + host_pinned, + host_unpinned, + + pub fn pjrtName(k: Kind) []const u8 { + return switch (k) { + .device => "device", + .host_pinned => "pinned_host", + .host_unpinned => "unpinned_host", + }; + } }; const inner = InnerMixin(c.PJRT_Memory).inner; @@ -1061,8 +1076,12 @@ pub const Memory = opaque { pub fn kind(self: *const Memory, api: *const Api) Kind { const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable; - const kind_ = ret.kind orelse unreachable; - return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable; + return switch (ret.kind_size) { + "device".len => .device, + "pinned_host".len => .host_pinned, + "unpinned_host".len => .host_unpinned, + else => @panic("Memory kind not supported"), + }; } pub fn kindId(self: *const Memory, api: *const Api) u32 { diff --git a/stdx/flags.zig b/stdx/flags.zig index 467cea6..667814f 100644 --- a/stdx/flags.zig +++ b/stdx/flags.zig @@ -41,8 +41,9 @@ //! caller to manage the lifetime. The caller should be skipping program name. const std = @import("std"); -const builtin = @import("builtin"); const assert = std.debug.assert; +const builtin = @import("builtin"); + const debug = @import("debug.zig"); /// Format and print an error message to stderr, then exit with an exit code of 1. @@ -285,7 +286,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags { fn assert_valid_value_type(comptime T: type) void { comptime { - if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int) return; + if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int or @typeInfo(T) == .float) return; if (@typeInfo(T) == .@"enum") { const info = @typeInfo(T).@"enum"; @@ -347,6 +348,7 @@ fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T { if (V == []const u8 or V == [:0]const u8) return value; if (V == ByteSize) return parse_value_size(flag, value); if (@typeInfo(V) == .int) return parse_value_int(V, flag, value); + if (@typeInfo(V) == .float) return parse_value_float(V, flag, value); if (@typeInfo(V) == .@"enum") return parse_value_enum(V, flag, value); comptime unreachable; } @@ -515,6 +517,20 @@ fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T { }; } +/// Parse string value into a float, providing a nice error message for the user. +fn parse_value_float(comptime T: type, flag: []const u8, value: [:0]const u8) T { + assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<'); + + return std.fmt.parseFloat(T, value) catch |err| { + switch (err) { + error.InvalidCharacter => fatal( + "{s}: expected a decimal value, but found '{s}' (invalid character)", + .{ flag, value }, + ), + } + }; +} + fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E { assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<'); comptime assert(@typeInfo(E).@"enum".is_exhaustive); diff --git a/zml/aio.zig b/zml/aio.zig index 47cc75f..8114e0a 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -116,12 +116,17 @@ pub const BufferStore = struct { if (id < self._unique_id or self._unique_id + _store_id_range <= id) { @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object."); } + if (platform.target != .cpu) mem_debug: { + const stats = platform.getDevices()[0].memoryStats(platform.pjrt_api) catch break :mem_debug; + log.debug("Loading {s} -> {f} {d:>10} bytes ({d:>10} allocated / {d:>10} reserved)", .{ self.buffers.keys()[id - self._unique_id], x._shape, x.shape().byteSize(), stats.bytes_in_use, stats.bytes_reserved }); + } break :hb self.buffers.values()[id - self._unique_id]; }, else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"), }; // Use the sharding information stored in the tensor. + std.debug.assert(host_buffer.shape().eql(x.shape())); host_buffer._shape = x.shape(); return try host_buffer.toDevice(platform); } @@ -703,7 +708,7 @@ pub fn loadModelBuffersWithPrefix( var res: zml.Bufferized(Model) = undefined; try zml.meta.mapAlloc(struct { pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer { - return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined }; + return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined, ._target = undefined }; } }.initBuffer, allocator, {}, model, &res); diff --git a/zml/buffer.zig b/zml/buffer.zig index 36a55d8..c9373e5 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -8,6 +8,7 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer; const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; +const Target = @import("platform.zig").Target; test { std.testing.refAllDecls(@This()); @@ -22,40 +23,22 @@ const log = std.log.scoped(.zml); /// * loading weights from disk directly to the `device zml.aio.loadBuffers` /// * can be created by calling `HostBuffer.toDevice(platform)`. pub const Buffer = struct { - pub const Memory = enum { - host, - host_pinned, - device, - - pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind { - return switch (self) { - .host => .unpinned_host, - .host_pinned => .pinned_host, - .device => .device, - }; - } - - pub fn pjrtName(self: Memory) []const u8 { - return @tagName(self.toPjrtMemory()); - } - }; - _shape: Shape, _api: *const pjrt.Api, + _target: Target, _shards: Shards, pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES; pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS); - pub const FromOptions = struct { - wait: bool = true, - memory: ?Memory = null, - }; + pub const Memory = pjrt.Memory.Kind; + pub const FromOptions = struct { wait: bool = true, memory: Memory = .device }; /// Copies the content of the given buffer from host memory to the accelerator memory. pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer { var res: Buffer = .{ ._api = platform.pjrt_api, + ._target = platform.target, ._shape = host_buffer.shape(), ._shards = .{}, }; @@ -82,35 +65,22 @@ pub const Buffer = struct { break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size }); } else host_buffer; - var args = pjrt.Client.BufferFromHostBufferArgs{ + const args = pjrt.Client.BufferFromHostBufferArgs{ .data = buf._data, .buffer_type = buffer_type, .dims = buf.shape().dims(), .byte_strides = byte_strides, .host_buffer_semantics = .ImmutableUntilTransferCompletes, + // CPU has no distinctions between memories. + .dst = if (platform.target == .cpu or opts.memory == .device) + .{ .device = devices[i] } + else + .{ .memory = platform.memoryForDevice(opts.memory, devices[i]) }, }; - if (platform.target == .cpu or opts.memory == null) { - args.device = devices[i]; - } else { - const memory = opts.memory.?; - const device_memories = try devices[i].addressableMemories(platform.pjrt_api); - // TODO measure the cost of this and consider caching on Zig side inside the platform. - const selected_memory = for (device_memories) |m| { - const kind = m.kind(platform.pjrt_api); - if (kind == memory.toPjrtMemory()) break m; - } else { - log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) }); - return error.NotFound; - }; - args.memory = selected_memory; - } const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args); - if (event) |ev| { - ev.deinit(platform.pjrt_api); - } - + if (event) |ev| ev.deinit(platform.pjrt_api); res._shards.appendAssumeCapacity(pjrt_buffer); } @@ -131,6 +101,15 @@ pub const Buffer = struct { return self; } + pub fn awaitBlocking(self: Buffer) !Buffer { + for (self._shards.constSlice()) |buffer| { + if (buffer.getReadyEvent(self._api)) |ev| { + try ev.awaitBlocking(self._api); + } + } + return self; + } + /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer { stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); @@ -139,6 +118,7 @@ pub const Buffer = struct { shards.appendSliceAssumeCapacity(pjrt_buffers); return .{ ._api = platform.pjrt_api, + ._target = platform.target, ._shape = shape_, ._shards = shards, }; @@ -185,9 +165,10 @@ pub const Buffer = struct { } pub fn asHostBuffer(self: Buffer) HostBuffer { - // TODO: skip this check on cpu - // const memory = self.getMemory().kind(self._api); - // stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory }); + if (self._target != .cpu) { + const memory = self.getMemory().kind(self._api); + stdx.debug.assert((memory == .host_pinned) or (memory == .host_unpinned), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory }); + } const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable); return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]); } @@ -200,7 +181,7 @@ pub const Buffer = struct { } /// Creates a Buffer with a single element repeated manytime. - pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer { + pub fn constant(platform: Platform, shape_: Shape, val: anytype, opts: FromOptions) !Buffer { var start = try std.time.Timer.start(); defer { const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms); @@ -210,6 +191,8 @@ pub const Buffer = struct { } } + // Constant is always blocking because it uses pointer to stack memory. + const cst_opts: FromOptions = .{ .memory = opts.memory, .wait = true }; // Convert val to the requested dtype. const x = shape_.dtype().constant(val); const byte_size = shape_.dtype().sizeOf(); @@ -222,7 +205,7 @@ pub const Buffer = struct { ._strides = @splat(0), ._data = x.constSlice().ptr, }; - return try from(platform, host_buffer, .{ .wait = true }); + return try from(platform, host_buffer, cst_opts); } // To speed up copies, duplicate the scalar value into a vector, @@ -245,14 +228,14 @@ pub const Buffer = struct { else => unreachable, } const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes }; - return try from(platform, host_buffer, .{ .wait = true }); + return try from(platform, host_buffer, cst_opts); } test constant { const zml = @import("zml.zig"); const platform = zml.testing.env(); - const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42); + const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42, .{ .wait = true }); const y = try x.getValue([4 * 3 * 2]u16); try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y); } @@ -271,14 +254,6 @@ pub const Buffer = struct { /// Creates a Buffer from a pointer into device memory. /// This allows to interface with other libraries producing buffers. pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer { - const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { - var res: [Shape.MAX_RANK]i64 = undefined; - for (0..Shape.MAX_RANK) |i| { - res[i] = @intCast(Shape.MAX_RANK - i - 1); - } - break :blk res; - }; - const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ .data = device_data, .element_type = bufferTypeFromDtype(shape_.dtype()), @@ -287,7 +262,7 @@ pub const Buffer = struct { .device = platform.getDevices()[0], .layout = .{ .tiled = .{ - .minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..], + .minor_to_major = minorToMajor(shape_.rank()), .tile_dims = &.{}, .tile_dims_sizes = &.{}, }, @@ -299,15 +274,16 @@ pub const Buffer = struct { shards.appendAssumeCapacity(pjrt_buffer); return .{ ._api = platform.pjrt_api, + ._target = platform.target, ._shape = shape_, ._shards = shards, }; } - pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 { + pub fn devicePtr(self: Buffer) u64 { stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{}); const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable; - return @ptrCast(opaque_ptr); + return @intFromPtr(opaque_ptr); } /// Fetches the content of the given buffer into a stack variable of the given type. @@ -350,6 +326,7 @@ pub const Buffer = struct { /// Depending on the platform, the memory is typically not released to the OS /// but just marked as available in the memory pool. pub fn deinit(self: *const Buffer) void { + // log.warn("Unloading {f} {d} bytes", .{ self._shape, self._shape.byteSize() }); for (self._shards.constSlice()) |buffer| { buffer.deinit(self._api); } @@ -385,7 +362,7 @@ pub const Buffer = struct { } pub fn format(self: Buffer, writer: *std.Io.Writer) !void { - try writer.print("Buffer({f})", .{self._shape}); + try writer.print("Buffer({f})@{x}", .{ self._shape, self.devicePtr() }); } pub fn getMemory(self: Buffer) *const pjrt.Memory { @@ -402,10 +379,10 @@ pub const Buffer = struct { wait: bool = true, }; - pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer { - const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory()); + pub fn copyToMemory(self: Buffer, platform: Platform, memory: pjrt.Memory.Kind, opts: CopyToMemoryOpts) !Buffer { + const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory); if (pjrt_memory == null) { - stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self }); + stdx.debug.panic("Memory destination `{t}` for {f}", .{ memory, self }); } var new_shards: Buffer.Shards = .{}; @@ -423,35 +400,34 @@ pub const Buffer = struct { } } - return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api }; + return Buffer{ ._api = self._api, ._target = platform.target, ._shape = self._shape, ._shards = new_shards }; } - pub const UnitializedOptions = struct { - memory: ?pjrt.Memory.Kind = null, - }; + pub const UnitializedOptions = struct { memory: Memory = .device }; pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer { + if (opts.memory != .device) { + // XLA uninitialized doesn't respect memory see https://github.com/openxla/xla/pull/31292 + // TODO: use uninitialized when it works again. + const host_buffer: HostBuffer = try .empty(std.heap.smp_allocator, shape_); + defer host_buffer.deinit(std.heap.smp_allocator); + return try .from(platform, host_buffer, .{ .wait = true, .memory = opts.memory }); + } + var res: Buffer = .{ ._api = platform.pjrt_api, ._shape = shape_, ._shards = .{}, + ._target = platform.target, }; errdefer for (res._shards.slice()) |shard| { shard.deinit(platform.pjrt_api); }; - const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { - var minor_to_major: [Shape.MAX_RANK]i64 = undefined; - for (0..Shape.MAX_RANK) |i| { - minor_to_major[i] = @intCast(Shape.MAX_RANK - i - 1); - } - break :blk minor_to_major; - }; - - // TODO: support more advanced sharding specs stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info); const n_partitions = platform.sharding().num_partitions; + const shard_shape = if (sharding_ax) |ax| s: { // This kind of sharding error should be detected earlier on. stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions }); @@ -459,37 +435,56 @@ pub const Buffer = struct { break :s shard_shape; } else shape_; - const buffer_type = bufferTypeFromDtype(shape_.dtype()); + var args = pjrt.Client.CreateUninitializedBufferArgs{ + .dims = shard_shape.dims(), + .element_type = bufferTypeFromDtype(shape_.dtype()), + .layout = .{ + .tiled = .{ + .minor_to_major = minorToMajor(shape_.rank()), + .tile_dims = &.{}, + .tile_dims_sizes = &.{}, + }, + }, + // set per device, see below. + .dst = undefined, + }; + const devices = platform.getDevices(); for (0..n_partitions) |i| { - var args = pjrt.Client.CreateUninitializedBufferArgs{ - .dims = shard_shape.dims(), - .element_type = buffer_type, - .layout = .{ - .tiled = .{ - .minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..], - .tile_dims = &.{}, - .tile_dims_sizes = &.{}, - }, - }, - }; - if (opts.memory) |memory_kind| { - const memories = try devices[i].addressableMemories(platform.pjrt_api); - const memory = for (memories) |m| { - const kind = m.kind(platform.pjrt_api); - if (kind == memory_kind) break m; - } else return error.NotFound; - args.memory = memory; - } else { - args.device = devices[i]; - } - const pjrt_buffer = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args); + args.dst = if (platform.target == .cpu or opts.memory == .device) + .{ .device = devices[i] } + else + .{ .memory = platform.memoryForDevice(opts.memory, devices[i]) }; - res._shards.appendAssumeCapacity(pjrt_buffer); + const shard = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args); + res._shards.appendAssumeCapacity(shard); } return res; } + + test uninitialized { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned }; + for (host_visible_memories) |memory| { + const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory }); + const x_ptr: [*]u8 = @ptrFromInt(x.devicePtr()); + @memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 }); + + const y = try x.getValue([6]u8); + try std.testing.expectEqualStrings("hello!", &y); + } + } + + pub fn isDeleted(self: Buffer) bool { + const deleted: bool = self._shards.get(0).isDeleted(self._api); + for (self._shards.slice()[1..]) |shard| { + stdx.debug.assert(shard.isDeleted(self._api) == deleted, "Buffer has some shards deleted but not others.", .{}); + } + return deleted; + } }; pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { @@ -517,3 +512,15 @@ test bufferTypeFromDtype { try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt))); } } + +const _MINOR_TO_MAJOR = blk: { + var ret: [Shape.MAX_RANK]i64 = undefined; + for (0..Shape.MAX_RANK) |i| { + ret[i] = @intCast(Shape.MAX_RANK - i - 1); + } + break :blk ret; +}; + +fn minorToMajor(rank: u8) []const i64 { + return _MINOR_TO_MAJOR[_MINOR_TO_MAJOR.len - rank ..]; +} diff --git a/zml/callback.zig b/zml/callback.zig index 4d06353..5af96d1 100644 --- a/zml/callback.zig +++ b/zml/callback.zig @@ -99,7 +99,7 @@ pub fn call( .api_version = .typed_ffi, .backend_config = .dict(mlir_ctx, &.{}), .additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }}, - .has_side_effect = true, + .has_side_effect = Callback.callback_config.has_side_effect, .output_operand_aliases = Callback.callback_config.output_operand_aliases, }, output_types, @@ -123,6 +123,7 @@ pub const Config = struct { // TODO: document precisely what `command_buffer_compatible` is doing and its limitations. traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false }, // TODO: handle sharded inputs + has_side_effect: bool = true, }; /// Compile-time check that a callback has all informations we require. @@ -190,12 +191,12 @@ fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt else .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data); if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) { - log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() }); + log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.devicePtr() }); zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| { - log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err }); + log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.devicePtr(), err }); return .create(call_frame.api, .resource_exhausted, "host pinned OOM"); }; - log.debug("--> {f} {*} ({})", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), @as(*const f32, @ptrCast(@alignCast(zml_buffer.opaqueDeviceMemoryDataPointer()))).* }); + log.debug("--> {f} {x}", .{ zml_buffer, zml_buffer.devicePtr() }); } callback_args[i] = zml_buffer; } @@ -282,6 +283,7 @@ pub const Print = struct { .copy_inputs_to_host_pinned = true, // Print is fairly predictable and can be captured in an execution graph. .traits = .{ .command_buffer_compatible = false }, + .has_side_effect = false, }; platform: Platform, diff --git a/zml/exe.zig b/zml/exe.zig index 37c9826..15d33b0 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -235,6 +235,7 @@ pub const BaseExe = struct { if (self.execute_context) |ctx| { ctx.deinit(self.platform.pjrt_api); } + self.exe.deinit(self.platform.pjrt_api); self._arena.deinit(); } @@ -395,6 +396,10 @@ pub fn Exe(ArgsT: type, ReturnT: type) type { self.inner._unsafeAssignResults(Bufferized(ReturnT), &result); return result; } + + pub fn clone(self: Self, allocator: std.mem.Allocator) !Self { + return .{ .inner = try self.inner.clone(allocator) }; + } }; } diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 88419b7..3185348 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -322,7 +322,7 @@ pub const HostBuffer = struct { } pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void { - try writer.print("HostBuffer(.{f})", .{self._shape}); + try writer.print("HostBuffer(.{f})@{x}", .{ self._shape, @intFromPtr(self._data) }); } pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { diff --git a/zml/module.zig b/zml/module.zig index 5c27607..7ba63b2 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -221,6 +221,12 @@ pub const CompilationContext = struct { break :blk loaded_executable; }; + { + const exe = try loaded_executable.getExecutable(self._platform.pjrt_api); + const stats = try exe.getCompiledMemoryStats(self._platform.pjrt_api); + log.debug("Compiled {s}: {any}", .{ self._name, stats }); + } + log.debug("******** ZML generated MLIR ********", .{}); log.debug("{f}", .{module.op().mlirFormatter(.{})}); @@ -881,6 +887,9 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas); c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions); c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1); + if (platform.compilation_options.device_memory_size > 0) { + c.xla_ExecutableBuildOptionsProto_set_device_memory_size(exec_build_options, @intCast(platform.compilation_options.device_memory_size)); + } c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: { const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena); @@ -895,7 +904,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m } break :device_assignment_blk device_assignment; }); - break :executable_build_options_blk exec_build_options; }); diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 5932619..46c4052 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -243,7 +243,9 @@ pub const Event = opaque { if (ctx.err) |e| { defer e.deinit(api); - return e.getCode(api).toApiError(); + const err_code = e.getCode(api).toApiError(); + log.err("{t} {s}", .{ err_code, e.getMessage(api) }); + return err_code; } } }; diff --git a/zml/platform.zig b/zml/platform.zig index a6f4e14..410d60b 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -10,12 +10,17 @@ const log = std.log.scoped(.zml); pub const available_targets = std.enums.values(Target); +test { + std.testing.refAllDecls(@This()); +} + pub const CompilationOptions = struct { xla_dump_to: ?[]const u8 = null, xla_dump_fusion_visualization: bool = false, xla_dump_hlo_pass_re: ?[]const u8 = null, sharding_enabled: bool = false, sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{}, + device_memory_size: u64 = 0, }; pub const Platform = struct { @@ -29,7 +34,7 @@ pub const Platform = struct { // `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);` compilation_options: CompilationOptions = .{}, - pub const MAX_NUM_DEVICES: u8 = 32; + pub const MAX_NUM_DEVICES: u8 = if (runtimes.isEnabled(.tpu)) 32 else 8; pub const CreateOptions = _CreateOptions; pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform { @@ -79,6 +84,44 @@ pub const Platform = struct { pub fn deinit(self: *Platform) void { self.pjrt_client.deinit(self.pjrt_api); } + + pub fn memoryForDevice(platform: Platform, memory: pjrt.Memory.Kind, device: *const pjrt.Device) *const pjrt.Memory { + const memory_target: pjrt.Memory.Kind = switch (memory) { + .host_unpinned => switch (platform.target) { + // Cuda doesn't have host_unpinned. + .cuda => .host_pinned, + else => .host_unpinned, + }, + inline else => |t| t, + }; + // TODO measure the cost of this and consider caching. + const device_memories = device.addressableMemories(platform.pjrt_api); + for (device_memories) |m| { + if (memory_target == m.kind(platform.pjrt_api)) { + return m; + } + } + log.err("Platform {t} doesn't have memory {t}", .{ platform.target, memory }); + @panic("Memory kind not found"); + } + + test memoryForDevice { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + const memory_fields = @typeInfo(pjrt.Memory.Kind).@"enum".fields; + inline for (memory_fields) |field| { + for (platform.getDevices()) |dev| { + _ = platform.memoryForDevice(@field(pjrt.Memory.Kind, field.name), dev); + } + } + } + + pub fn memoryStats(platform: Platform, device_id: usize) pjrt.MemoryStats { + if (platform.target == .cpu) return .zeroes; + + const device = platform.getDevices()[device_id]; + return device.memoryStats(platform.pjrt_api) catch .zeroes; + } }; const _CreateOptions = struct { @@ -127,17 +170,21 @@ const _CreateOptions = struct { fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void { switch (self.allocator) { .platform => { - values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); + values.appendAssumeCapacity(.fromString("allocator", "platform")); }, .bfc, .async => |opt| { - values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator)); - values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate)); + values.appendAssumeCapacity(.fromString("allocator", switch (self.allocator) { + .bfc => "bfc", + .async => "cuda_async", + .platform => unreachable, + })); + values.appendAssumeCapacity(.from("preallocate", opt.preallocate)); if (opt.memory_fraction > 0) { - values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction)); + values.appendAssumeCapacity(.from("memory_fraction", opt.memory_fraction)); } if (opt.collective_memory_size_mb > 0) { const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024; - values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective)); + values.appendAssumeCapacity(.from("collective_memory_size", collective)); } }, } diff --git a/zml/tensor.zig b/zml/tensor.zig index 50b8c56..ff12d91 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -202,10 +202,8 @@ pub const Tensor = struct { const mlir_ctx = ctx.mlirCtx(); if (ctx.target() == .cpu) return self; - const memory_kind = @tagName(kind.toPjrtMemory()); - const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{ - .{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) }, + .{ "_xla_buffer_placement", .string(mlir_ctx, kind.pjrtName()) }, }); const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{ @@ -311,6 +309,11 @@ pub const Tensor = struct { return _result(res_shape, op.result(0)); } + /// Returns the given tensor as one contiguous buffer of bytes. + pub fn bytes(self: Tensor) Tensor { + return self.bitCast(.u8).flattenAll().withTags(.{.bytes}); + } + /// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor. pub fn countLeadingZeros(self: Tensor) Tensor { const loc = self.getContext().mlirCtx().location(@src()); @@ -2683,7 +2686,7 @@ pub const Tensor = struct { } { // Test with actual values and batching along axis .a - const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0); + const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0, .{}); defer operand.deinit(); const start_indices = (try zml.Buffer.fromArray( platform, @@ -2704,6 +2707,7 @@ pub const Tensor = struct { platform, Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16), 1, + .{}, ); defer values.deinit();