Add platform tag to buffers for target identification and safety checks; include workaround for PJRT uninitialized memory handling.
This commit is contained in:
parent
9aeb4e9cd0
commit
29bd1242ba
@ -360,9 +360,11 @@ pub const Client = opaque {
|
||||
buffer_type: BufferType,
|
||||
dims: []const i64,
|
||||
byte_strides: ?[]const i64,
|
||||
device: ?*const Device = null,
|
||||
host_buffer_semantics: HostBufferSemantics,
|
||||
memory: ?*const Memory = null,
|
||||
dst: union(enum) {
|
||||
device: *const Device,
|
||||
memory: *const Memory,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||
@ -375,11 +377,11 @@ pub const Client = opaque {
|
||||
.byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null,
|
||||
.num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
|
||||
.host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
|
||||
.device = @ptrCast(@constCast(args.device)),
|
||||
.memory = @ptrCast(@constCast(args.memory)),
|
||||
.device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null,
|
||||
.memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null,
|
||||
.device_layout = null, // TODO
|
||||
.done_with_host_buffer = null,
|
||||
.buffer = null,
|
||||
.done_with_host_buffer = null, // out
|
||||
.buffer = null, // out
|
||||
});
|
||||
|
||||
return .{
|
||||
@ -430,7 +432,7 @@ pub const Client = opaque {
|
||||
pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
|
||||
const ret = api.call(.PJRT_Client_AddressableMemories, .{
|
||||
.client = self.inner(),
|
||||
}) catch unreachable;
|
||||
}) catch return &.{};
|
||||
if (ret.addressable_memories) |memories| {
|
||||
return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
|
||||
}
|
||||
@ -474,8 +476,10 @@ pub const Client = opaque {
|
||||
dims: []const i64,
|
||||
element_type: BufferType,
|
||||
layout: MemoryLayout,
|
||||
device: ?*const Device = null,
|
||||
memory: ?*const Memory = null,
|
||||
dst: union(enum) {
|
||||
device: *const Device,
|
||||
memory: *const Memory,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
|
||||
@ -486,8 +490,8 @@ pub const Client = opaque {
|
||||
.shape_num_dims = @intCast(args.dims.len),
|
||||
.shape_element_type = @intFromEnum(args.element_type),
|
||||
.shape_layout = @ptrCast(&layout),
|
||||
.device = @ptrCast(@constCast(args.device)),
|
||||
.memory = @ptrCast(@constCast(args.memory)),
|
||||
.device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null,
|
||||
.memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null,
|
||||
});
|
||||
return @ptrCast(ret.buffer.?);
|
||||
}
|
||||
@ -530,6 +534,8 @@ pub const MemoryStats = struct {
|
||||
pool_bytes_is_set: bool, // out
|
||||
peak_pool_bytes: u64, // out
|
||||
peak_pool_bytes_is_set: bool, // out
|
||||
|
||||
pub const zeroes = std.mem.zeroes(MemoryStats);
|
||||
};
|
||||
|
||||
pub const Device = opaque {
|
||||
@ -556,10 +562,11 @@ pub const Device = opaque {
|
||||
return @intCast(ret.local_hardware_id);
|
||||
}
|
||||
|
||||
pub fn addressableMemories(self: *const Device, api: *const Api) ApiError![]const *Memory {
|
||||
const ret = try api.call(.PJRT_Device_AddressableMemories, .{
|
||||
.device = self.inner(),
|
||||
});
|
||||
pub fn addressableMemories(self: *const Device, api: *const Api) []const *Memory {
|
||||
const ret = api.call(
|
||||
.PJRT_Device_AddressableMemories,
|
||||
.{ .device = self.inner() },
|
||||
) catch return &.{};
|
||||
return @ptrCast(ret.memories[0..ret.num_memories]);
|
||||
}
|
||||
|
||||
@ -728,7 +735,6 @@ pub const LoadedExecutable = opaque {
|
||||
_ = api.call(.PJRT_LoadedExecutable_Destroy, .{
|
||||
.executable = self.inner(),
|
||||
}) catch {};
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
||||
@ -759,6 +765,7 @@ pub const LoadedExecutable = opaque {
|
||||
non_donatable_input_indices: []const i64 = &.{},
|
||||
context: ?*ExecuteContext,
|
||||
};
|
||||
|
||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||
.send_callbacks = null,
|
||||
@ -1048,8 +1055,16 @@ pub const Event = opaque {
|
||||
pub const Memory = opaque {
|
||||
pub const Kind = enum {
|
||||
device,
|
||||
pinned_host,
|
||||
unpinned_host,
|
||||
host_pinned,
|
||||
host_unpinned,
|
||||
|
||||
pub fn pjrtName(k: Kind) []const u8 {
|
||||
return switch (k) {
|
||||
.device => "device",
|
||||
.host_pinned => "pinned_host",
|
||||
.host_unpinned => "unpinned_host",
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const inner = InnerMixin(c.PJRT_Memory).inner;
|
||||
@ -1061,8 +1076,12 @@ pub const Memory = opaque {
|
||||
|
||||
pub fn kind(self: *const Memory, api: *const Api) Kind {
|
||||
const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
|
||||
const kind_ = ret.kind orelse unreachable;
|
||||
return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
|
||||
return switch (ret.kind_size) {
|
||||
"device".len => .device,
|
||||
"pinned_host".len => .host_pinned,
|
||||
"unpinned_host".len => .host_unpinned,
|
||||
else => @panic("Memory kind not supported"),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
||||
|
||||
@ -41,8 +41,9 @@
|
||||
//! caller to manage the lifetime. The caller should be skipping program name.
|
||||
|
||||
const std = @import("std");
|
||||
const builtin = @import("builtin");
|
||||
const assert = std.debug.assert;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const debug = @import("debug.zig");
|
||||
|
||||
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
||||
@ -285,7 +286,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
||||
|
||||
fn assert_valid_value_type(comptime T: type) void {
|
||||
comptime {
|
||||
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int) return;
|
||||
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int or @typeInfo(T) == .float) return;
|
||||
|
||||
if (@typeInfo(T) == .@"enum") {
|
||||
const info = @typeInfo(T).@"enum";
|
||||
@ -347,6 +348,7 @@ fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
||||
if (V == []const u8 or V == [:0]const u8) return value;
|
||||
if (V == ByteSize) return parse_value_size(flag, value);
|
||||
if (@typeInfo(V) == .int) return parse_value_int(V, flag, value);
|
||||
if (@typeInfo(V) == .float) return parse_value_float(V, flag, value);
|
||||
if (@typeInfo(V) == .@"enum") return parse_value_enum(V, flag, value);
|
||||
comptime unreachable;
|
||||
}
|
||||
@ -515,6 +517,20 @@ fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
||||
};
|
||||
}
|
||||
|
||||
/// Parse string value into a float, providing a nice error message for the user.
|
||||
fn parse_value_float(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
||||
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
||||
|
||||
return std.fmt.parseFloat(T, value) catch |err| {
|
||||
switch (err) {
|
||||
error.InvalidCharacter => fatal(
|
||||
"{s}: expected a decimal value, but found '{s}' (invalid character)",
|
||||
.{ flag, value },
|
||||
),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
|
||||
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
||||
comptime assert(@typeInfo(E).@"enum".is_exhaustive);
|
||||
|
||||
@ -116,12 +116,17 @@ pub const BufferStore = struct {
|
||||
if (id < self._unique_id or self._unique_id + _store_id_range <= id) {
|
||||
@panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
|
||||
}
|
||||
if (platform.target != .cpu) mem_debug: {
|
||||
const stats = platform.getDevices()[0].memoryStats(platform.pjrt_api) catch break :mem_debug;
|
||||
log.debug("Loading {s} -> {f} {d:>10} bytes ({d:>10} allocated / {d:>10} reserved)", .{ self.buffers.keys()[id - self._unique_id], x._shape, x.shape().byteSize(), stats.bytes_in_use, stats.bytes_reserved });
|
||||
}
|
||||
break :hb self.buffers.values()[id - self._unique_id];
|
||||
},
|
||||
else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
|
||||
};
|
||||
|
||||
// Use the sharding information stored in the tensor.
|
||||
std.debug.assert(host_buffer.shape().eql(x.shape()));
|
||||
host_buffer._shape = x.shape();
|
||||
return try host_buffer.toDevice(platform);
|
||||
}
|
||||
@ -703,7 +708,7 @@ pub fn loadModelBuffersWithPrefix(
|
||||
var res: zml.Bufferized(Model) = undefined;
|
||||
try zml.meta.mapAlloc(struct {
|
||||
pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
|
||||
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
|
||||
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined, ._target = undefined };
|
||||
}
|
||||
}.initBuffer, allocator, {}, model, &res);
|
||||
|
||||
|
||||
209
zml/buffer.zig
209
zml/buffer.zig
@ -8,6 +8,7 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Target = @import("platform.zig").Target;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
@ -22,40 +23,22 @@ const log = std.log.scoped(.zml);
|
||||
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||
pub const Buffer = struct {
|
||||
pub const Memory = enum {
|
||||
host,
|
||||
host_pinned,
|
||||
device,
|
||||
|
||||
pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind {
|
||||
return switch (self) {
|
||||
.host => .unpinned_host,
|
||||
.host_pinned => .pinned_host,
|
||||
.device => .device,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn pjrtName(self: Memory) []const u8 {
|
||||
return @tagName(self.toPjrtMemory());
|
||||
}
|
||||
};
|
||||
|
||||
_shape: Shape,
|
||||
_api: *const pjrt.Api,
|
||||
_target: Target,
|
||||
_shards: Shards,
|
||||
|
||||
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||
pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||
|
||||
pub const FromOptions = struct {
|
||||
wait: bool = true,
|
||||
memory: ?Memory = null,
|
||||
};
|
||||
pub const Memory = pjrt.Memory.Kind;
|
||||
pub const FromOptions = struct { wait: bool = true, memory: Memory = .device };
|
||||
|
||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
|
||||
var res: Buffer = .{
|
||||
._api = platform.pjrt_api,
|
||||
._target = platform.target,
|
||||
._shape = host_buffer.shape(),
|
||||
._shards = .{},
|
||||
};
|
||||
@ -82,35 +65,22 @@ pub const Buffer = struct {
|
||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||
} else host_buffer;
|
||||
|
||||
var args = pjrt.Client.BufferFromHostBufferArgs{
|
||||
const args = pjrt.Client.BufferFromHostBufferArgs{
|
||||
.data = buf._data,
|
||||
.buffer_type = buffer_type,
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = byte_strides,
|
||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||
// CPU has no distinctions between memories.
|
||||
.dst = if (platform.target == .cpu or opts.memory == .device)
|
||||
.{ .device = devices[i] }
|
||||
else
|
||||
.{ .memory = platform.memoryForDevice(opts.memory, devices[i]) },
|
||||
};
|
||||
if (platform.target == .cpu or opts.memory == null) {
|
||||
args.device = devices[i];
|
||||
} else {
|
||||
const memory = opts.memory.?;
|
||||
const device_memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||
// TODO measure the cost of this and consider caching on Zig side inside the platform.
|
||||
const selected_memory = for (device_memories) |m| {
|
||||
const kind = m.kind(platform.pjrt_api);
|
||||
if (kind == memory.toPjrtMemory()) break m;
|
||||
} else {
|
||||
log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) });
|
||||
return error.NotFound;
|
||||
};
|
||||
args.memory = selected_memory;
|
||||
}
|
||||
|
||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||
|
||||
if (event) |ev| {
|
||||
ev.deinit(platform.pjrt_api);
|
||||
}
|
||||
|
||||
if (event) |ev| ev.deinit(platform.pjrt_api);
|
||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||
}
|
||||
|
||||
@ -131,6 +101,15 @@ pub const Buffer = struct {
|
||||
return self;
|
||||
}
|
||||
|
||||
pub fn awaitBlocking(self: Buffer) !Buffer {
|
||||
for (self._shards.constSlice()) |buffer| {
|
||||
if (buffer.getReadyEvent(self._api)) |ev| {
|
||||
try ev.awaitBlocking(self._api);
|
||||
}
|
||||
}
|
||||
return self;
|
||||
}
|
||||
|
||||
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||
@ -139,6 +118,7 @@ pub const Buffer = struct {
|
||||
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
||||
return .{
|
||||
._api = platform.pjrt_api,
|
||||
._target = platform.target,
|
||||
._shape = shape_,
|
||||
._shards = shards,
|
||||
};
|
||||
@ -185,9 +165,10 @@ pub const Buffer = struct {
|
||||
}
|
||||
|
||||
pub fn asHostBuffer(self: Buffer) HostBuffer {
|
||||
// TODO: skip this check on cpu
|
||||
// const memory = self.getMemory().kind(self._api);
|
||||
// stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
|
||||
if (self._target != .cpu) {
|
||||
const memory = self.getMemory().kind(self._api);
|
||||
stdx.debug.assert((memory == .host_pinned) or (memory == .host_unpinned), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
|
||||
}
|
||||
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
|
||||
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
|
||||
}
|
||||
@ -200,7 +181,7 @@ pub const Buffer = struct {
|
||||
}
|
||||
|
||||
/// Creates a Buffer with a single element repeated manytime.
|
||||
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
|
||||
pub fn constant(platform: Platform, shape_: Shape, val: anytype, opts: FromOptions) !Buffer {
|
||||
var start = try std.time.Timer.start();
|
||||
defer {
|
||||
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
||||
@ -210,6 +191,8 @@ pub const Buffer = struct {
|
||||
}
|
||||
}
|
||||
|
||||
// Constant is always blocking because it uses pointer to stack memory.
|
||||
const cst_opts: FromOptions = .{ .memory = opts.memory, .wait = true };
|
||||
// Convert val to the requested dtype.
|
||||
const x = shape_.dtype().constant(val);
|
||||
const byte_size = shape_.dtype().sizeOf();
|
||||
@ -222,7 +205,7 @@ pub const Buffer = struct {
|
||||
._strides = @splat(0),
|
||||
._data = x.constSlice().ptr,
|
||||
};
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
return try from(platform, host_buffer, cst_opts);
|
||||
}
|
||||
|
||||
// To speed up copies, duplicate the scalar value into a vector,
|
||||
@ -245,14 +228,14 @@ pub const Buffer = struct {
|
||||
else => unreachable,
|
||||
}
|
||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||
return try from(platform, host_buffer, .{ .wait = true });
|
||||
return try from(platform, host_buffer, cst_opts);
|
||||
}
|
||||
|
||||
test constant {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42);
|
||||
const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42, .{ .wait = true });
|
||||
const y = try x.getValue([4 * 3 * 2]u16);
|
||||
try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
|
||||
}
|
||||
@ -271,14 +254,6 @@ pub const Buffer = struct {
|
||||
/// Creates a Buffer from a pointer into device memory.
|
||||
/// This allows to interface with other libraries producing buffers.
|
||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
|
||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
res[i] = @intCast(Shape.MAX_RANK - i - 1);
|
||||
}
|
||||
break :blk res;
|
||||
};
|
||||
|
||||
const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||
.data = device_data,
|
||||
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||
@ -287,7 +262,7 @@ pub const Buffer = struct {
|
||||
.device = platform.getDevices()[0],
|
||||
.layout = .{
|
||||
.tiled = .{
|
||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
||||
.minor_to_major = minorToMajor(shape_.rank()),
|
||||
.tile_dims = &.{},
|
||||
.tile_dims_sizes = &.{},
|
||||
},
|
||||
@ -299,15 +274,16 @@ pub const Buffer = struct {
|
||||
shards.appendAssumeCapacity(pjrt_buffer);
|
||||
return .{
|
||||
._api = platform.pjrt_api,
|
||||
._target = platform.target,
|
||||
._shape = shape_,
|
||||
._shards = shards,
|
||||
};
|
||||
}
|
||||
|
||||
pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 {
|
||||
pub fn devicePtr(self: Buffer) u64 {
|
||||
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
|
||||
const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
|
||||
return @ptrCast(opaque_ptr);
|
||||
return @intFromPtr(opaque_ptr);
|
||||
}
|
||||
|
||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||
@ -350,6 +326,7 @@ pub const Buffer = struct {
|
||||
/// Depending on the platform, the memory is typically not released to the OS
|
||||
/// but just marked as available in the memory pool.
|
||||
pub fn deinit(self: *const Buffer) void {
|
||||
// log.warn("Unloading {f} {d} bytes", .{ self._shape, self._shape.byteSize() });
|
||||
for (self._shards.constSlice()) |buffer| {
|
||||
buffer.deinit(self._api);
|
||||
}
|
||||
@ -385,7 +362,7 @@ pub const Buffer = struct {
|
||||
}
|
||||
|
||||
pub fn format(self: Buffer, writer: *std.Io.Writer) !void {
|
||||
try writer.print("Buffer({f})", .{self._shape});
|
||||
try writer.print("Buffer({f})@{x}", .{ self._shape, self.devicePtr() });
|
||||
}
|
||||
|
||||
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
||||
@ -402,10 +379,10 @@ pub const Buffer = struct {
|
||||
wait: bool = true,
|
||||
};
|
||||
|
||||
pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer {
|
||||
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory());
|
||||
pub fn copyToMemory(self: Buffer, platform: Platform, memory: pjrt.Memory.Kind, opts: CopyToMemoryOpts) !Buffer {
|
||||
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory);
|
||||
if (pjrt_memory == null) {
|
||||
stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self });
|
||||
stdx.debug.panic("Memory destination `{t}` for {f}", .{ memory, self });
|
||||
}
|
||||
|
||||
var new_shards: Buffer.Shards = .{};
|
||||
@ -423,35 +400,34 @@ pub const Buffer = struct {
|
||||
}
|
||||
}
|
||||
|
||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||
return Buffer{ ._api = self._api, ._target = platform.target, ._shape = self._shape, ._shards = new_shards };
|
||||
}
|
||||
|
||||
pub const UnitializedOptions = struct {
|
||||
memory: ?pjrt.Memory.Kind = null,
|
||||
};
|
||||
pub const UnitializedOptions = struct { memory: Memory = .device };
|
||||
|
||||
pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer {
|
||||
if (opts.memory != .device) {
|
||||
// XLA uninitialized doesn't respect memory see https://github.com/openxla/xla/pull/31292
|
||||
// TODO: use uninitialized when it works again.
|
||||
const host_buffer: HostBuffer = try .empty(std.heap.smp_allocator, shape_);
|
||||
defer host_buffer.deinit(std.heap.smp_allocator);
|
||||
return try .from(platform, host_buffer, .{ .wait = true, .memory = opts.memory });
|
||||
}
|
||||
|
||||
var res: Buffer = .{
|
||||
._api = platform.pjrt_api,
|
||||
._shape = shape_,
|
||||
._shards = .{},
|
||||
._target = platform.target,
|
||||
};
|
||||
errdefer for (res._shards.slice()) |shard| {
|
||||
shard.deinit(platform.pjrt_api);
|
||||
};
|
||||
|
||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||
var minor_to_major: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
minor_to_major[i] = @intCast(Shape.MAX_RANK - i - 1);
|
||||
}
|
||||
break :blk minor_to_major;
|
||||
};
|
||||
|
||||
// TODO: support more advanced sharding specs
|
||||
stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
||||
const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info);
|
||||
const n_partitions = platform.sharding().num_partitions;
|
||||
|
||||
const shard_shape = if (sharding_ax) |ax| s: {
|
||||
// This kind of sharding error should be detected earlier on.
|
||||
stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions });
|
||||
@ -459,37 +435,56 @@ pub const Buffer = struct {
|
||||
break :s shard_shape;
|
||||
} else shape_;
|
||||
|
||||
const buffer_type = bufferTypeFromDtype(shape_.dtype());
|
||||
var args = pjrt.Client.CreateUninitializedBufferArgs{
|
||||
.dims = shard_shape.dims(),
|
||||
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||
.layout = .{
|
||||
.tiled = .{
|
||||
.minor_to_major = minorToMajor(shape_.rank()),
|
||||
.tile_dims = &.{},
|
||||
.tile_dims_sizes = &.{},
|
||||
},
|
||||
},
|
||||
// set per device, see below.
|
||||
.dst = undefined,
|
||||
};
|
||||
|
||||
const devices = platform.getDevices();
|
||||
for (0..n_partitions) |i| {
|
||||
var args = pjrt.Client.CreateUninitializedBufferArgs{
|
||||
.dims = shard_shape.dims(),
|
||||
.element_type = buffer_type,
|
||||
.layout = .{
|
||||
.tiled = .{
|
||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
||||
.tile_dims = &.{},
|
||||
.tile_dims_sizes = &.{},
|
||||
},
|
||||
},
|
||||
};
|
||||
if (opts.memory) |memory_kind| {
|
||||
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||
const memory = for (memories) |m| {
|
||||
const kind = m.kind(platform.pjrt_api);
|
||||
if (kind == memory_kind) break m;
|
||||
} else return error.NotFound;
|
||||
args.memory = memory;
|
||||
} else {
|
||||
args.device = devices[i];
|
||||
}
|
||||
const pjrt_buffer = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args);
|
||||
args.dst = if (platform.target == .cpu or opts.memory == .device)
|
||||
.{ .device = devices[i] }
|
||||
else
|
||||
.{ .memory = platform.memoryForDevice(opts.memory, devices[i]) };
|
||||
|
||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||
const shard = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args);
|
||||
res._shards.appendAssumeCapacity(shard);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
test uninitialized {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned };
|
||||
for (host_visible_memories) |memory| {
|
||||
const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory });
|
||||
const x_ptr: [*]u8 = @ptrFromInt(x.devicePtr());
|
||||
@memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 });
|
||||
|
||||
const y = try x.getValue([6]u8);
|
||||
try std.testing.expectEqualStrings("hello!", &y);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn isDeleted(self: Buffer) bool {
|
||||
const deleted: bool = self._shards.get(0).isDeleted(self._api);
|
||||
for (self._shards.slice()[1..]) |shard| {
|
||||
stdx.debug.assert(shard.isDeleted(self._api) == deleted, "Buffer has some shards deleted but not others.", .{});
|
||||
}
|
||||
return deleted;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||
@ -517,3 +512,15 @@ test bufferTypeFromDtype {
|
||||
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||
}
|
||||
}
|
||||
|
||||
const _MINOR_TO_MAJOR = blk: {
|
||||
var ret: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
|
||||
}
|
||||
break :blk ret;
|
||||
};
|
||||
|
||||
fn minorToMajor(rank: u8) []const i64 {
|
||||
return _MINOR_TO_MAJOR[_MINOR_TO_MAJOR.len - rank ..];
|
||||
}
|
||||
|
||||
@ -99,7 +99,7 @@ pub fn call(
|
||||
.api_version = .typed_ffi,
|
||||
.backend_config = .dict(mlir_ctx, &.{}),
|
||||
.additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
|
||||
.has_side_effect = true,
|
||||
.has_side_effect = Callback.callback_config.has_side_effect,
|
||||
.output_operand_aliases = Callback.callback_config.output_operand_aliases,
|
||||
},
|
||||
output_types,
|
||||
@ -123,6 +123,7 @@ pub const Config = struct {
|
||||
// TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
|
||||
traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
|
||||
// TODO: handle sharded inputs
|
||||
has_side_effect: bool = true,
|
||||
};
|
||||
|
||||
/// Compile-time check that a callback has all informations we require.
|
||||
@ -190,12 +191,12 @@ fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt
|
||||
else
|
||||
.asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
|
||||
if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
|
||||
log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
|
||||
log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.devicePtr() });
|
||||
zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
|
||||
log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err });
|
||||
log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.devicePtr(), err });
|
||||
return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
|
||||
};
|
||||
log.debug("--> {f} {*} ({})", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), @as(*const f32, @ptrCast(@alignCast(zml_buffer.opaqueDeviceMemoryDataPointer()))).* });
|
||||
log.debug("--> {f} {x}", .{ zml_buffer, zml_buffer.devicePtr() });
|
||||
}
|
||||
callback_args[i] = zml_buffer;
|
||||
}
|
||||
@ -282,6 +283,7 @@ pub const Print = struct {
|
||||
.copy_inputs_to_host_pinned = true,
|
||||
// Print is fairly predictable and can be captured in an execution graph.
|
||||
.traits = .{ .command_buffer_compatible = false },
|
||||
.has_side_effect = false,
|
||||
};
|
||||
|
||||
platform: Platform,
|
||||
|
||||
@ -235,6 +235,7 @@ pub const BaseExe = struct {
|
||||
if (self.execute_context) |ctx| {
|
||||
ctx.deinit(self.platform.pjrt_api);
|
||||
}
|
||||
self.exe.deinit(self.platform.pjrt_api);
|
||||
self._arena.deinit();
|
||||
}
|
||||
|
||||
@ -395,6 +396,10 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
||||
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn clone(self: Self, allocator: std.mem.Allocator) !Self {
|
||||
return .{ .inner = try self.inner.clone(allocator) };
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@ -322,7 +322,7 @@ pub const HostBuffer = struct {
|
||||
}
|
||||
|
||||
pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
|
||||
try writer.print("HostBuffer(.{f})", .{self._shape});
|
||||
try writer.print("HostBuffer(.{f})@{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||
}
|
||||
|
||||
pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||
|
||||
@ -221,6 +221,12 @@ pub const CompilationContext = struct {
|
||||
break :blk loaded_executable;
|
||||
};
|
||||
|
||||
{
|
||||
const exe = try loaded_executable.getExecutable(self._platform.pjrt_api);
|
||||
const stats = try exe.getCompiledMemoryStats(self._platform.pjrt_api);
|
||||
log.debug("Compiled {s}: {any}", .{ self._name, stats });
|
||||
}
|
||||
|
||||
log.debug("******** ZML generated MLIR ********", .{});
|
||||
log.debug("{f}", .{module.op().mlirFormatter(.{})});
|
||||
|
||||
@ -881,6 +887,9 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
|
||||
c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
|
||||
c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);
|
||||
if (platform.compilation_options.device_memory_size > 0) {
|
||||
c.xla_ExecutableBuildOptionsProto_set_device_memory_size(exec_build_options, @intCast(platform.compilation_options.device_memory_size));
|
||||
}
|
||||
|
||||
c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
|
||||
const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);
|
||||
@ -895,7 +904,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
}
|
||||
break :device_assignment_blk device_assignment;
|
||||
});
|
||||
|
||||
break :executable_build_options_blk exec_build_options;
|
||||
});
|
||||
|
||||
|
||||
@ -243,7 +243,9 @@ pub const Event = opaque {
|
||||
|
||||
if (ctx.err) |e| {
|
||||
defer e.deinit(api);
|
||||
return e.getCode(api).toApiError();
|
||||
const err_code = e.getCode(api).toApiError();
|
||||
log.err("{t} {s}", .{ err_code, e.getMessage(api) });
|
||||
return err_code;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -10,12 +10,17 @@ const log = std.log.scoped(.zml);
|
||||
|
||||
pub const available_targets = std.enums.values(Target);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
pub const CompilationOptions = struct {
|
||||
xla_dump_to: ?[]const u8 = null,
|
||||
xla_dump_fusion_visualization: bool = false,
|
||||
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
||||
sharding_enabled: bool = false,
|
||||
sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
|
||||
device_memory_size: u64 = 0,
|
||||
};
|
||||
|
||||
pub const Platform = struct {
|
||||
@ -29,7 +34,7 @@ pub const Platform = struct {
|
||||
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
||||
compilation_options: CompilationOptions = .{},
|
||||
|
||||
pub const MAX_NUM_DEVICES: u8 = 32;
|
||||
pub const MAX_NUM_DEVICES: u8 = if (runtimes.isEnabled(.tpu)) 32 else 8;
|
||||
pub const CreateOptions = _CreateOptions;
|
||||
|
||||
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
|
||||
@ -79,6 +84,44 @@ pub const Platform = struct {
|
||||
pub fn deinit(self: *Platform) void {
|
||||
self.pjrt_client.deinit(self.pjrt_api);
|
||||
}
|
||||
|
||||
pub fn memoryForDevice(platform: Platform, memory: pjrt.Memory.Kind, device: *const pjrt.Device) *const pjrt.Memory {
|
||||
const memory_target: pjrt.Memory.Kind = switch (memory) {
|
||||
.host_unpinned => switch (platform.target) {
|
||||
// Cuda doesn't have host_unpinned.
|
||||
.cuda => .host_pinned,
|
||||
else => .host_unpinned,
|
||||
},
|
||||
inline else => |t| t,
|
||||
};
|
||||
// TODO measure the cost of this and consider caching.
|
||||
const device_memories = device.addressableMemories(platform.pjrt_api);
|
||||
for (device_memories) |m| {
|
||||
if (memory_target == m.kind(platform.pjrt_api)) {
|
||||
return m;
|
||||
}
|
||||
}
|
||||
log.err("Platform {t} doesn't have memory {t}", .{ platform.target, memory });
|
||||
@panic("Memory kind not found");
|
||||
}
|
||||
|
||||
test memoryForDevice {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
const memory_fields = @typeInfo(pjrt.Memory.Kind).@"enum".fields;
|
||||
inline for (memory_fields) |field| {
|
||||
for (platform.getDevices()) |dev| {
|
||||
_ = platform.memoryForDevice(@field(pjrt.Memory.Kind, field.name), dev);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn memoryStats(platform: Platform, device_id: usize) pjrt.MemoryStats {
|
||||
if (platform.target == .cpu) return .zeroes;
|
||||
|
||||
const device = platform.getDevices()[device_id];
|
||||
return device.memoryStats(platform.pjrt_api) catch .zeroes;
|
||||
}
|
||||
};
|
||||
|
||||
const _CreateOptions = struct {
|
||||
@ -127,17 +170,21 @@ const _CreateOptions = struct {
|
||||
fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
|
||||
switch (self.allocator) {
|
||||
.platform => {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
||||
values.appendAssumeCapacity(.fromString("allocator", "platform"));
|
||||
},
|
||||
.bfc, .async => |opt| {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
|
||||
values.appendAssumeCapacity(.fromString("allocator", switch (self.allocator) {
|
||||
.bfc => "bfc",
|
||||
.async => "cuda_async",
|
||||
.platform => unreachable,
|
||||
}));
|
||||
values.appendAssumeCapacity(.from("preallocate", opt.preallocate));
|
||||
if (opt.memory_fraction > 0) {
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction));
|
||||
values.appendAssumeCapacity(.from("memory_fraction", opt.memory_fraction));
|
||||
}
|
||||
if (opt.collective_memory_size_mb > 0) {
|
||||
const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
|
||||
values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective));
|
||||
values.appendAssumeCapacity(.from("collective_memory_size", collective));
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -202,10 +202,8 @@ pub const Tensor = struct {
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
if (ctx.target() == .cpu) return self;
|
||||
|
||||
const memory_kind = @tagName(kind.toPjrtMemory());
|
||||
|
||||
const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{
|
||||
.{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) },
|
||||
.{ "_xla_buffer_placement", .string(mlir_ctx, kind.pjrtName()) },
|
||||
});
|
||||
|
||||
const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{
|
||||
@ -311,6 +309,11 @@ pub const Tensor = struct {
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
|
||||
/// Returns the given tensor as one contiguous buffer of bytes.
|
||||
pub fn bytes(self: Tensor) Tensor {
|
||||
return self.bitCast(.u8).flattenAll().withTags(.{.bytes});
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
|
||||
pub fn countLeadingZeros(self: Tensor) Tensor {
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
@ -2683,7 +2686,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
{
|
||||
// Test with actual values and batching along axis .a
|
||||
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
|
||||
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0, .{});
|
||||
defer operand.deinit();
|
||||
const start_indices = (try zml.Buffer.fromArray(
|
||||
platform,
|
||||
@ -2704,6 +2707,7 @@ pub const Tensor = struct {
|
||||
platform,
|
||||
Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
|
||||
1,
|
||||
.{},
|
||||
);
|
||||
defer values.deinit();
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user