From ff1433d99858d96c889b9c33116ce6e1ea089956 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 25 Feb 2025 10:37:45 +0000 Subject: [PATCH] pjrt: bind PJRT_Client_CreateUninitializedBuffer. --- pjrt/pjrt.zig | 22 +++++++++++++++++ zml/buffer.zig | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++ zml/pjrtx.zig | 6 +++++ 3 files changed, 93 insertions(+) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index e23c321..18260f9 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -483,6 +483,28 @@ pub const Client = opaque { }); return @ptrCast(ret.transfer_manager.?); } + + pub const CreateUninitializedBufferArgs = struct { + dims: []const i64, + element_type: BufferType, + layout: MemoryLayout, + device: ?*const Device = null, + memory: ?*const Memory = null, + }; + + pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer { + var layout = args.layout.toCStruct(); + const ret = try api.call(.PJRT_Client_CreateUninitializedBuffer, .{ + .client = self.inner(), + .shape_dims = args.dims.ptr, + .shape_num_dims = @intCast(args.dims.len), + .shape_element_type = @intFromEnum(args.element_type), + .shape_layout = @ptrCast(&layout), + .device = @ptrCast(@constCast(args.device)), + .memory = @ptrCast(@constCast(args.memory)), + }); + return @ptrCast(ret.buffer.?); + } }; pub const MemoryStats = struct { diff --git a/zml/buffer.zig b/zml/buffer.zig index 3854cc3..36f6125 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -403,6 +403,71 @@ pub const Buffer = struct { return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api }; } + + pub const UnitializedOptions = struct { + memory: ?pjrt.Memory.Kind = null, + }; + + pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer { + var res: Buffer = .{ + ._api = platform.pjrt_api, + ._shape = shape_, + ._shards = .{}, + }; + errdefer for (res._shards.slice()) |shard| { + shard.deinit(platform.pjrt_api); + }; + + const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { + var minor_to_major: [Shape.MAX_RANK]i64 = undefined; + for (0..Shape.MAX_RANK) |i| { + minor_to_major[i] = @intCast(Shape.MAX_RANK - i - 1); + } + break :blk minor_to_major; + }; + + // TODO: support more advanced sharding specs + stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()}); + const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info); + const n_partitions = platform.sharding().num_partitions; + const shard_shape = if (sharding_ax) |ax| s: { + // This kind of sharding error should be detected earlier on. + stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions }); + const shard_shape = shape_.set(ax, @divExact(shape_.dim(ax), n_partitions)); + break :s shard_shape; + } else shape_; + + const buffer_type = bufferTypeFromDtype(shape_.dtype()); + const devices = platform.getDevices(); + for (0..n_partitions) |i| { + var args = pjrt.Client.CreateUninitializedBufferArgs{ + .dims = shard_shape.dims(), + .element_type = buffer_type, + .layout = .{ + .tiled = .{ + .minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..], + .tile_dims = &.{}, + .tile_dims_sizes = &.{}, + }, + }, + }; + if (opts.memory) |memory_kind| { + const memories = try devices[i].addressableMemories(platform.pjrt_api); + const memory = for (memories) |m| { + const kind = m.kind(platform.pjrt_api); + if (kind == memory_kind) break m; + } else return error.NotFound; + args.memory = memory; + } else { + args.device = devices[i]; + } + const pjrt_buffer = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args); + + res._shards.appendAssumeCapacity(pjrt_buffer); + } + + return res; + } }; pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType { diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 22f4d7f..1795efc 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -128,6 +128,12 @@ pub const Client = opaque { } return null; } + + pub const CreateUninitializedBufferArgs = pjrt.Client.CreateUninitializedBufferArgs; + + pub fn createUnitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer { + return @ptrCast(try self.inner().createUninitializedBuffer(api, args)); + } }; pub const Buffer = opaque {