Add platform tag to buffers for target identification and safety checks; include workaround for PJRT uninitialized memory handling.
This commit is contained in:
parent
9aeb4e9cd0
commit
29bd1242ba
@ -360,9 +360,11 @@ pub const Client = opaque {
|
|||||||
buffer_type: BufferType,
|
buffer_type: BufferType,
|
||||||
dims: []const i64,
|
dims: []const i64,
|
||||||
byte_strides: ?[]const i64,
|
byte_strides: ?[]const i64,
|
||||||
device: ?*const Device = null,
|
|
||||||
host_buffer_semantics: HostBufferSemantics,
|
host_buffer_semantics: HostBufferSemantics,
|
||||||
memory: ?*const Memory = null,
|
dst: union(enum) {
|
||||||
|
device: *const Device,
|
||||||
|
memory: *const Memory,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!struct { *Buffer, ?*Event } {
|
||||||
@ -375,11 +377,11 @@ pub const Client = opaque {
|
|||||||
.byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null,
|
.byte_strides = if (args.byte_strides) |bs| @ptrCast(@constCast(bs.ptr)) else null,
|
||||||
.num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
|
.num_byte_strides = if (args.byte_strides) |bs| bs.len else 0,
|
||||||
.host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
|
.host_buffer_semantics = @intFromEnum(args.host_buffer_semantics),
|
||||||
.device = @ptrCast(@constCast(args.device)),
|
.device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null,
|
||||||
.memory = @ptrCast(@constCast(args.memory)),
|
.memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null,
|
||||||
.device_layout = null, // TODO
|
.device_layout = null, // TODO
|
||||||
.done_with_host_buffer = null,
|
.done_with_host_buffer = null, // out
|
||||||
.buffer = null,
|
.buffer = null, // out
|
||||||
});
|
});
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
@ -430,7 +432,7 @@ pub const Client = opaque {
|
|||||||
pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
|
pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory {
|
||||||
const ret = api.call(.PJRT_Client_AddressableMemories, .{
|
const ret = api.call(.PJRT_Client_AddressableMemories, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
}) catch unreachable;
|
}) catch return &.{};
|
||||||
if (ret.addressable_memories) |memories| {
|
if (ret.addressable_memories) |memories| {
|
||||||
return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
|
return @ptrCast(@constCast(memories[0..ret.num_addressable_memories]));
|
||||||
}
|
}
|
||||||
@ -474,8 +476,10 @@ pub const Client = opaque {
|
|||||||
dims: []const i64,
|
dims: []const i64,
|
||||||
element_type: BufferType,
|
element_type: BufferType,
|
||||||
layout: MemoryLayout,
|
layout: MemoryLayout,
|
||||||
device: ?*const Device = null,
|
dst: union(enum) {
|
||||||
memory: ?*const Memory = null,
|
device: *const Device,
|
||||||
|
memory: *const Memory,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
|
pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
|
||||||
@ -486,8 +490,8 @@ pub const Client = opaque {
|
|||||||
.shape_num_dims = @intCast(args.dims.len),
|
.shape_num_dims = @intCast(args.dims.len),
|
||||||
.shape_element_type = @intFromEnum(args.element_type),
|
.shape_element_type = @intFromEnum(args.element_type),
|
||||||
.shape_layout = @ptrCast(&layout),
|
.shape_layout = @ptrCast(&layout),
|
||||||
.device = @ptrCast(@constCast(args.device)),
|
.device = if (args.dst == .device) @ptrCast(@constCast(args.dst.device)) else null,
|
||||||
.memory = @ptrCast(@constCast(args.memory)),
|
.memory = if (args.dst == .memory) @ptrCast(@constCast(args.dst.memory)) else null,
|
||||||
});
|
});
|
||||||
return @ptrCast(ret.buffer.?);
|
return @ptrCast(ret.buffer.?);
|
||||||
}
|
}
|
||||||
@ -530,6 +534,8 @@ pub const MemoryStats = struct {
|
|||||||
pool_bytes_is_set: bool, // out
|
pool_bytes_is_set: bool, // out
|
||||||
peak_pool_bytes: u64, // out
|
peak_pool_bytes: u64, // out
|
||||||
peak_pool_bytes_is_set: bool, // out
|
peak_pool_bytes_is_set: bool, // out
|
||||||
|
|
||||||
|
pub const zeroes = std.mem.zeroes(MemoryStats);
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Device = opaque {
|
pub const Device = opaque {
|
||||||
@ -556,10 +562,11 @@ pub const Device = opaque {
|
|||||||
return @intCast(ret.local_hardware_id);
|
return @intCast(ret.local_hardware_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lists the memories the PJRT plugin reports as addressable from this device.
/// When the `PJRT_Device_AddressableMemories` call fails, an empty slice is returned
/// instead of an error.
pub fn addressableMemories(self: *const Device, api: *const Api) []const *Memory {
    const ret = api.call(.PJRT_Device_AddressableMemories, .{
        .device = self.inner(),
    }) catch {
        return &.{};
    };
    const memories = ret.memories[0..ret.num_memories];
    return @ptrCast(memories);
}
|
||||||
|
|
||||||
@ -728,7 +735,6 @@ pub const LoadedExecutable = opaque {
|
|||||||
_ = api.call(.PJRT_LoadedExecutable_Destroy, .{
|
_ = api.call(.PJRT_LoadedExecutable_Destroy, .{
|
||||||
.executable = self.inner(),
|
.executable = self.inner(),
|
||||||
}) catch {};
|
}) catch {};
|
||||||
self.* = undefined;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
pub fn delete(self: *LoadedExecutable, api: *const Api) void {
|
||||||
@ -759,6 +765,7 @@ pub const LoadedExecutable = opaque {
|
|||||||
non_donatable_input_indices: []const i64 = &.{},
|
non_donatable_input_indices: []const i64 = &.{},
|
||||||
context: ?*ExecuteContext,
|
context: ?*ExecuteContext,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ApiError!void {
|
||||||
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
var options = pjrtStruct(c.PJRT_ExecuteOptions{
|
||||||
.send_callbacks = null,
|
.send_callbacks = null,
|
||||||
@ -1048,8 +1055,16 @@ pub const Event = opaque {
|
|||||||
pub const Memory = opaque {
|
pub const Memory = opaque {
|
||||||
/// Memory kinds modeled by this wrapper.
pub const Kind = enum {
    device,
    host_pinned,
    host_unpinned,

    /// Returns the PJRT spelling of this kind.
    /// For the two host kinds it differs from the Zig tag name
    /// (`host_pinned` -> "pinned_host", `host_unpinned` -> "unpinned_host").
    pub fn pjrtName(k: Kind) []const u8 {
        const name: []const u8 = switch (k) {
            .device => "device",
            .host_pinned => "pinned_host",
            .host_unpinned => "unpinned_host",
        };
        return name;
    }
};
|
||||||
|
|
||||||
const inner = InnerMixin(c.PJRT_Memory).inner;
|
const inner = InnerMixin(c.PJRT_Memory).inner;
|
||||||
@ -1061,8 +1076,12 @@ pub const Memory = opaque {
|
|||||||
|
|
||||||
/// Returns the kind of this memory, as reported by `PJRT_Memory_Kind`.
///
/// The reported name is compared byte-for-byte against the PJRT spelling of every
/// `Kind` (see `Kind.pjrtName`). Dispatching on the length of the name alone would
/// silently mislabel any other kind whose name happens to be 6, 11 or 13 bytes long.
///
/// Panics when the plugin reports no name, or a name that `Kind` does not model.
pub fn kind(self: *const Memory, api: *const Api) Kind {
    const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
    const kind_ptr = ret.kind orelse @panic("Memory kind not supported");
    const name = kind_ptr[0..ret.kind_size];
    // Unrolled at comptime: one string comparison per enum tag.
    inline for (comptime std.enums.values(Kind)) |candidate| {
        if (std.mem.eql(u8, name, candidate.pjrtName())) return candidate;
    }
    @panic("Memory kind not supported");
}
|
||||||
|
|
||||||
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
||||||
|
|||||||
@ -41,8 +41,9 @@
|
|||||||
//! caller to manage the lifetime. The caller should be skipping program name.
|
//! caller to manage the lifetime. The caller should be skipping program name.
|
||||||
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const builtin = @import("builtin");
|
|
||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const debug = @import("debug.zig");
|
const debug = @import("debug.zig");
|
||||||
|
|
||||||
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
/// Format and print an error message to stderr, then exit with an exit code of 1.
|
||||||
@ -285,7 +286,7 @@ fn parse_flags(args: *std.process.ArgIterator, comptime Flags: type) Flags {
|
|||||||
|
|
||||||
fn assert_valid_value_type(comptime T: type) void {
|
fn assert_valid_value_type(comptime T: type) void {
|
||||||
comptime {
|
comptime {
|
||||||
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int) return;
|
if (T == []const u8 or T == [:0]const u8 or T == ByteSize or @typeInfo(T) == .int or @typeInfo(T) == .float) return;
|
||||||
|
|
||||||
if (@typeInfo(T) == .@"enum") {
|
if (@typeInfo(T) == .@"enum") {
|
||||||
const info = @typeInfo(T).@"enum";
|
const info = @typeInfo(T).@"enum";
|
||||||
@ -347,6 +348,7 @@ fn parse_value(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
|||||||
if (V == []const u8 or V == [:0]const u8) return value;
|
if (V == []const u8 or V == [:0]const u8) return value;
|
||||||
if (V == ByteSize) return parse_value_size(flag, value);
|
if (V == ByteSize) return parse_value_size(flag, value);
|
||||||
if (@typeInfo(V) == .int) return parse_value_int(V, flag, value);
|
if (@typeInfo(V) == .int) return parse_value_int(V, flag, value);
|
||||||
|
if (@typeInfo(V) == .float) return parse_value_float(V, flag, value);
|
||||||
if (@typeInfo(V) == .@"enum") return parse_value_enum(V, flag, value);
|
if (@typeInfo(V) == .@"enum") return parse_value_enum(V, flag, value);
|
||||||
comptime unreachable;
|
comptime unreachable;
|
||||||
}
|
}
|
||||||
@ -515,6 +517,20 @@ fn parse_value_int(comptime T: type, flag: []const u8, value: [:0]const u8) T {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Converts the textual `value` given for `flag` into a float of type `T`.
/// On malformed input, reports the offending flag and value through `fatal`
/// instead of returning an error.
fn parse_value_float(comptime T: type, flag: []const u8, value: [:0]const u8) T {
    const starts_with_dashes = flag[0] == '-' and flag[1] == '-';
    assert(starts_with_dashes or flag[0] == '<');

    return std.fmt.parseFloat(T, value) catch |err| switch (err) {
        error.InvalidCharacter => fatal(
            "{s}: expected a decimal value, but found '{s}' (invalid character)",
            .{ flag, value },
        ),
    };
}
|
||||||
|
|
||||||
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
|
fn parse_value_enum(comptime E: type, flag: []const u8, value: [:0]const u8) E {
|
||||||
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
assert((flag[0] == '-' and flag[1] == '-') or flag[0] == '<');
|
||||||
comptime assert(@typeInfo(E).@"enum".is_exhaustive);
|
comptime assert(@typeInfo(E).@"enum".is_exhaustive);
|
||||||
|
|||||||
@ -116,12 +116,17 @@ pub const BufferStore = struct {
|
|||||||
if (id < self._unique_id or self._unique_id + _store_id_range <= id) {
|
if (id < self._unique_id or self._unique_id + _store_id_range <= id) {
|
||||||
@panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
|
@panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`, using the same store object.");
|
||||||
}
|
}
|
||||||
|
if (platform.target != .cpu) mem_debug: {
|
||||||
|
const stats = platform.getDevices()[0].memoryStats(platform.pjrt_api) catch break :mem_debug;
|
||||||
|
log.debug("Loading {s} -> {f} {d:>10} bytes ({d:>10} allocated / {d:>10} reserved)", .{ self.buffers.keys()[id - self._unique_id], x._shape, x.shape().byteSize(), stats.bytes_in_use, stats.bytes_reserved });
|
||||||
|
}
|
||||||
break :hb self.buffers.values()[id - self._unique_id];
|
break :hb self.buffers.values()[id - self._unique_id];
|
||||||
},
|
},
|
||||||
else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
|
else => @panic("`store.loadBufferById()` only works on Tensor created by `store.getTensor()`"),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Use the sharding information stored in the tensor.
|
// Use the sharding information stored in the tensor.
|
||||||
|
std.debug.assert(host_buffer.shape().eql(x.shape()));
|
||||||
host_buffer._shape = x.shape();
|
host_buffer._shape = x.shape();
|
||||||
return try host_buffer.toDevice(platform);
|
return try host_buffer.toDevice(platform);
|
||||||
}
|
}
|
||||||
@ -703,7 +708,7 @@ pub fn loadModelBuffersWithPrefix(
|
|||||||
var res: zml.Bufferized(Model) = undefined;
|
var res: zml.Bufferized(Model) = undefined;
|
||||||
try zml.meta.mapAlloc(struct {
|
try zml.meta.mapAlloc(struct {
|
||||||
pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
|
pub fn initBuffer(_: void, tensor: zml.Tensor) zml.Buffer {
|
||||||
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined };
|
return .{ ._shape = tensor.shape(), ._api = undefined, ._shards = undefined, ._target = undefined };
|
||||||
}
|
}
|
||||||
}.initBuffer, allocator, {}, model, &res);
|
}.initBuffer, allocator, {}, model, &res);
|
||||||
|
|
||||||
|
|||||||
195
zml/buffer.zig
195
zml/buffer.zig
@ -8,6 +8,7 @@ const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
|||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const Target = @import("platform.zig").Target;
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -22,40 +23,22 @@ const log = std.log.scoped(.zml);
|
|||||||
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
/// * loading weights from disk directly to the `device zml.aio.loadBuffers`
|
||||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||||
pub const Buffer = struct {
|
pub const Buffer = struct {
|
||||||
pub const Memory = enum {
|
|
||||||
host,
|
|
||||||
host_pinned,
|
|
||||||
device,
|
|
||||||
|
|
||||||
pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind {
|
|
||||||
return switch (self) {
|
|
||||||
.host => .unpinned_host,
|
|
||||||
.host_pinned => .pinned_host,
|
|
||||||
.device => .device,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn pjrtName(self: Memory) []const u8 {
|
|
||||||
return @tagName(self.toPjrtMemory());
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
_shape: Shape,
|
_shape: Shape,
|
||||||
_api: *const pjrt.Api,
|
_api: *const pjrt.Api,
|
||||||
|
_target: Target,
|
||||||
_shards: Shards,
|
_shards: Shards,
|
||||||
|
|
||||||
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
pub const MAX_NUM_SHARDS: u8 = Platform.MAX_NUM_DEVICES;
|
||||||
pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
pub const Shards = stdx.BoundedArray(*pjrt.Buffer, MAX_NUM_SHARDS);
|
||||||
|
|
||||||
/// Memory kinds a buffer can be placed in; alias of the PJRT memory kind enum.
pub const Memory = pjrt.Memory.Kind;

/// Options accepted by `Buffer.from` and `Buffer.constant`.
pub const FromOptions = struct {
    // NOTE(review): presumably controls whether the call blocks until the
    // transfer is done — confirm in `from`.
    wait: bool = true,
    /// Memory kind the data should be placed in. Defaults to device memory.
    memory: Memory = .device,
};
|
|
||||||
|
|
||||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||||
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
|
pub fn from(platform: Platform, host_buffer: HostBuffer, opts: FromOptions) !Buffer {
|
||||||
var res: Buffer = .{
|
var res: Buffer = .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
|
._target = platform.target,
|
||||||
._shape = host_buffer.shape(),
|
._shape = host_buffer.shape(),
|
||||||
._shards = .{},
|
._shards = .{},
|
||||||
};
|
};
|
||||||
@ -82,35 +65,22 @@ pub const Buffer = struct {
|
|||||||
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
break :buf host_buffer.slice1d(ax, .{ .start = start, .end = start + chunk_size });
|
||||||
} else host_buffer;
|
} else host_buffer;
|
||||||
|
|
||||||
var args = pjrt.Client.BufferFromHostBufferArgs{
|
const args = pjrt.Client.BufferFromHostBufferArgs{
|
||||||
.data = buf._data,
|
.data = buf._data,
|
||||||
.buffer_type = buffer_type,
|
.buffer_type = buffer_type,
|
||||||
.dims = buf.shape().dims(),
|
.dims = buf.shape().dims(),
|
||||||
.byte_strides = byte_strides,
|
.byte_strides = byte_strides,
|
||||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||||
|
// CPU has no distinctions between memories.
|
||||||
|
.dst = if (platform.target == .cpu or opts.memory == .device)
|
||||||
|
.{ .device = devices[i] }
|
||||||
|
else
|
||||||
|
.{ .memory = platform.memoryForDevice(opts.memory, devices[i]) },
|
||||||
};
|
};
|
||||||
if (platform.target == .cpu or opts.memory == null) {
|
|
||||||
args.device = devices[i];
|
|
||||||
} else {
|
|
||||||
const memory = opts.memory.?;
|
|
||||||
const device_memories = try devices[i].addressableMemories(platform.pjrt_api);
|
|
||||||
// TODO measure the cost of this and consider caching on Zig side inside the platform.
|
|
||||||
const selected_memory = for (device_memories) |m| {
|
|
||||||
const kind = m.kind(platform.pjrt_api);
|
|
||||||
if (kind == memory.toPjrtMemory()) break m;
|
|
||||||
} else {
|
|
||||||
log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) });
|
|
||||||
return error.NotFound;
|
|
||||||
};
|
|
||||||
args.memory = selected_memory;
|
|
||||||
}
|
|
||||||
|
|
||||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||||
|
|
||||||
if (event) |ev| {
|
if (event) |ev| ev.deinit(platform.pjrt_api);
|
||||||
ev.deinit(platform.pjrt_api);
|
|
||||||
}
|
|
||||||
|
|
||||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
res._shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +101,15 @@ pub const Buffer = struct {
|
|||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Blocks on the ready event of every shard that exposes one, then returns `self`
/// so the call can be chained. Propagates the first error reported while waiting.
pub fn awaitBlocking(self: Buffer) !Buffer {
    for (self._shards.constSlice()) |shard| {
        // Shards without a ready event have nothing to wait for.
        const ready = shard.getReadyEvent(self._api) orelse continue;
        try ready.awaitBlocking(self._api);
    }
    return self;
}
|
||||||
|
|
||||||
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
/// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||||
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||||
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
stdx.debug.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||||
@ -139,6 +118,7 @@ pub const Buffer = struct {
|
|||||||
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
||||||
return .{
|
return .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
|
._target = platform.target,
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
._shards = shards,
|
._shards = shards,
|
||||||
};
|
};
|
||||||
@ -185,9 +165,10 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Exposes the data of the first shard as a `HostBuffer` with the same shape.
/// On non-CPU targets, asserts that the buffer was allocated in host memory
/// (pinned or unpinned); use `copyToMemory` first otherwise.
pub fn asHostBuffer(self: Buffer) HostBuffer {
    if (self._target != .cpu) {
        const memory = self.getMemory().kind(self._api);
        const on_host = memory == .host_pinned or memory == .host_unpinned;
        stdx.debug.assert(on_host, "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
    }
    const raw = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
    const bytes: [*]u8 = @ptrCast(raw);
    const byte_count = self._shape.byteSize();
    return HostBuffer.fromBytes(self._shape, bytes[0..byte_count]);
}
|
||||||
@ -200,7 +181,7 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Buffer with a single element repeated manytime.
|
/// Creates a Buffer with a single element repeated manytime.
|
||||||
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
|
pub fn constant(platform: Platform, shape_: Shape, val: anytype, opts: FromOptions) !Buffer {
|
||||||
var start = try std.time.Timer.start();
|
var start = try std.time.Timer.start();
|
||||||
defer {
|
defer {
|
||||||
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
||||||
@ -210,6 +191,8 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Constant is always blocking because it uses pointer to stack memory.
|
||||||
|
const cst_opts: FromOptions = .{ .memory = opts.memory, .wait = true };
|
||||||
// Convert val to the requested dtype.
|
// Convert val to the requested dtype.
|
||||||
const x = shape_.dtype().constant(val);
|
const x = shape_.dtype().constant(val);
|
||||||
const byte_size = shape_.dtype().sizeOf();
|
const byte_size = shape_.dtype().sizeOf();
|
||||||
@ -222,7 +205,7 @@ pub const Buffer = struct {
|
|||||||
._strides = @splat(0),
|
._strides = @splat(0),
|
||||||
._data = x.constSlice().ptr,
|
._data = x.constSlice().ptr,
|
||||||
};
|
};
|
||||||
return try from(platform, host_buffer, .{ .wait = true });
|
return try from(platform, host_buffer, cst_opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
// To speed up copies, duplicate the scalar value into a vector,
|
// To speed up copies, duplicate the scalar value into a vector,
|
||||||
@ -245,14 +228,14 @@ pub const Buffer = struct {
|
|||||||
else => unreachable,
|
else => unreachable,
|
||||||
}
|
}
|
||||||
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, ._data = &bytes };
|
||||||
return try from(platform, host_buffer, .{ .wait = true });
|
return try from(platform, host_buffer, cst_opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
test constant {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    // Every one of the 4 * 3 * 2 elements must hold the requested value.
    const element_count = 4 * 3 * 2;
    const expected: [element_count]u16 = @splat(42);

    const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42, .{ .wait = true });
    const y = try x.getValue([element_count]u16);
    try std.testing.expectEqual(expected, y);
}
|
||||||
@ -271,14 +254,6 @@ pub const Buffer = struct {
|
|||||||
/// Creates a Buffer from a pointer into device memory.
|
/// Creates a Buffer from a pointer into device memory.
|
||||||
/// This allows to interface with other libraries producing buffers.
|
/// This allows to interface with other libraries producing buffers.
|
||||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
|
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const pjrt.Stream, device_data: *anyopaque) Buffer {
|
||||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
|
||||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
|
||||||
res[i] = @intCast(Shape.MAX_RANK - i - 1);
|
|
||||||
}
|
|
||||||
break :blk res;
|
|
||||||
};
|
|
||||||
|
|
||||||
const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||||
.data = device_data,
|
.data = device_data,
|
||||||
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||||
@ -287,7 +262,7 @@ pub const Buffer = struct {
|
|||||||
.device = platform.getDevices()[0],
|
.device = platform.getDevices()[0],
|
||||||
.layout = .{
|
.layout = .{
|
||||||
.tiled = .{
|
.tiled = .{
|
||||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
.minor_to_major = minorToMajor(shape_.rank()),
|
||||||
.tile_dims = &.{},
|
.tile_dims = &.{},
|
||||||
.tile_dims_sizes = &.{},
|
.tile_dims_sizes = &.{},
|
||||||
},
|
},
|
||||||
@ -299,15 +274,16 @@ pub const Buffer = struct {
|
|||||||
shards.appendAssumeCapacity(pjrt_buffer);
|
shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
return .{
|
return .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
|
._target = platform.target,
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
._shards = shards,
|
._shards = shards,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns, as an integer, the opaque data pointer PJRT reports for the first shard.
/// Sharded buffers are not supported yet and trip an internal assertion.
pub fn devicePtr(self: Buffer) u64 {
    stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
    const first_shard = self._shards.get(0);
    const opaque_ptr: *anyopaque = first_shard.getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
    return @intFromPtr(opaque_ptr);
}
|
||||||
|
|
||||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||||
@ -350,6 +326,7 @@ pub const Buffer = struct {
|
|||||||
/// Depending on the platform, the memory is typically not released to the OS
|
/// Depending on the platform, the memory is typically not released to the OS
|
||||||
/// but just marked as available in the memory pool.
|
/// but just marked as available in the memory pool.
|
||||||
pub fn deinit(self: *const Buffer) void {
|
pub fn deinit(self: *const Buffer) void {
|
||||||
|
// log.warn("Unloading {f} {d} bytes", .{ self._shape, self._shape.byteSize() });
|
||||||
for (self._shards.constSlice()) |buffer| {
|
for (self._shards.constSlice()) |buffer| {
|
||||||
buffer.deinit(self._api);
|
buffer.deinit(self._api);
|
||||||
}
|
}
|
||||||
@ -385,7 +362,7 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Writes `Buffer(<shape>)@<address>` for unsharded buffers and `Buffer(<shape>)`
/// for sharded ones.
///
/// `devicePtr()` asserts that the buffer has no sharded axis, so calling it
/// unconditionally would make formatting a sharded buffer trip that assertion.
/// Formatting is also used to build assert and panic messages, which must not
/// fail themselves.
pub fn format(self: Buffer, writer: *std.Io.Writer) !void {
    try writer.print("Buffer({f})", .{self._shape});
    if (self.hasShardedAxis()) return;
    try writer.print("@{x}", .{self.devicePtr()});
}
|
||||||
|
|
||||||
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
||||||
@ -402,10 +379,10 @@ pub const Buffer = struct {
|
|||||||
wait: bool = true,
|
wait: bool = true,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer {
|
pub fn copyToMemory(self: Buffer, platform: Platform, memory: pjrt.Memory.Kind, opts: CopyToMemoryOpts) !Buffer {
|
||||||
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory());
|
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory);
|
||||||
if (pjrt_memory == null) {
|
if (pjrt_memory == null) {
|
||||||
stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self });
|
stdx.debug.panic("Memory destination `{t}` for {f}", .{ memory, self });
|
||||||
}
|
}
|
||||||
|
|
||||||
var new_shards: Buffer.Shards = .{};
|
var new_shards: Buffer.Shards = .{};
|
||||||
@ -423,35 +400,34 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
return Buffer{ ._api = self._api, ._target = platform.target, ._shape = self._shape, ._shards = new_shards };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Options accepted by `Buffer.uninitialized`.
pub const UnitializedOptions = struct {
    /// Memory kind the buffer should be allocated in. Defaults to device memory.
    memory: Memory = .device,
};
|
|
||||||
|
|
||||||
pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer {
|
pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer {
|
||||||
|
if (opts.memory != .device) {
|
||||||
|
// XLA uninitialized doesn't respect memory see https://github.com/openxla/xla/pull/31292
|
||||||
|
// TODO: use uninitialized when it works again.
|
||||||
|
const host_buffer: HostBuffer = try .empty(std.heap.smp_allocator, shape_);
|
||||||
|
defer host_buffer.deinit(std.heap.smp_allocator);
|
||||||
|
return try .from(platform, host_buffer, .{ .wait = true, .memory = opts.memory });
|
||||||
|
}
|
||||||
|
|
||||||
var res: Buffer = .{
|
var res: Buffer = .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
._shape = shape_,
|
._shape = shape_,
|
||||||
._shards = .{},
|
._shards = .{},
|
||||||
|
._target = platform.target,
|
||||||
};
|
};
|
||||||
errdefer for (res._shards.slice()) |shard| {
|
errdefer for (res._shards.slice()) |shard| {
|
||||||
shard.deinit(platform.pjrt_api);
|
shard.deinit(platform.pjrt_api);
|
||||||
};
|
};
|
||||||
|
|
||||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
|
||||||
var minor_to_major: [Shape.MAX_RANK]i64 = undefined;
|
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
|
||||||
minor_to_major[i] = @intCast(Shape.MAX_RANK - i - 1);
|
|
||||||
}
|
|
||||||
break :blk minor_to_major;
|
|
||||||
};
|
|
||||||
|
|
||||||
// TODO: support more advanced sharding specs
|
|
||||||
stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
|
||||||
const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info);
|
const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info);
|
||||||
const n_partitions = platform.sharding().num_partitions;
|
const n_partitions = platform.sharding().num_partitions;
|
||||||
|
|
||||||
const shard_shape = if (sharding_ax) |ax| s: {
|
const shard_shape = if (sharding_ax) |ax| s: {
|
||||||
// This kind of sharding error should be detected earlier on.
|
// This kind of sharding error should be detected earlier on.
|
||||||
stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions });
|
stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions });
|
||||||
@ -459,37 +435,56 @@ pub const Buffer = struct {
|
|||||||
break :s shard_shape;
|
break :s shard_shape;
|
||||||
} else shape_;
|
} else shape_;
|
||||||
|
|
||||||
const buffer_type = bufferTypeFromDtype(shape_.dtype());
|
|
||||||
const devices = platform.getDevices();
|
|
||||||
for (0..n_partitions) |i| {
|
|
||||||
var args = pjrt.Client.CreateUninitializedBufferArgs{
|
var args = pjrt.Client.CreateUninitializedBufferArgs{
|
||||||
.dims = shard_shape.dims(),
|
.dims = shard_shape.dims(),
|
||||||
.element_type = buffer_type,
|
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||||
.layout = .{
|
.layout = .{
|
||||||
.tiled = .{
|
.tiled = .{
|
||||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - shape_.rank() ..],
|
.minor_to_major = minorToMajor(shape_.rank()),
|
||||||
.tile_dims = &.{},
|
.tile_dims = &.{},
|
||||||
.tile_dims_sizes = &.{},
|
.tile_dims_sizes = &.{},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// set per device, see below.
|
||||||
|
.dst = undefined,
|
||||||
};
|
};
|
||||||
if (opts.memory) |memory_kind| {
|
|
||||||
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
|
||||||
const memory = for (memories) |m| {
|
|
||||||
const kind = m.kind(platform.pjrt_api);
|
|
||||||
if (kind == memory_kind) break m;
|
|
||||||
} else return error.NotFound;
|
|
||||||
args.memory = memory;
|
|
||||||
} else {
|
|
||||||
args.device = devices[i];
|
|
||||||
}
|
|
||||||
const pjrt_buffer = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args);
|
|
||||||
|
|
||||||
res._shards.appendAssumeCapacity(pjrt_buffer);
|
const devices = platform.getDevices();
|
||||||
|
for (0..n_partitions) |i| {
|
||||||
|
args.dst = if (platform.target == .cpu or opts.memory == .device)
|
||||||
|
.{ .device = devices[i] }
|
||||||
|
else
|
||||||
|
.{ .memory = platform.memoryForDevice(opts.memory, devices[i]) };
|
||||||
|
|
||||||
|
const shard = try platform.pjrt_client.createUnitializedBuffer(platform.pjrt_api, args);
|
||||||
|
res._shards.appendAssumeCapacity(shard);
|
||||||
}
|
}
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test uninitialized {
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned };
|
||||||
|
for (host_visible_memories) |memory| {
|
||||||
|
const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory });
|
||||||
|
const x_ptr: [*]u8 = @ptrFromInt(x.devicePtr());
|
||||||
|
@memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 });
|
||||||
|
|
||||||
|
const y = try x.getValue([6]u8);
|
||||||
|
try std.testing.expectEqualStrings("hello!", &y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn isDeleted(self: Buffer) bool {
|
||||||
|
const deleted: bool = self._shards.get(0).isDeleted(self._api);
|
||||||
|
for (self._shards.slice()[1..]) |shard| {
|
||||||
|
stdx.debug.assert(shard.isDeleted(self._api) == deleted, "Buffer has some shards deleted but not others.", .{});
|
||||||
|
}
|
||||||
|
return deleted;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||||
@ -517,3 +512,15 @@ test bufferTypeFromDtype {
|
|||||||
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
try std.testing.expectEqual(dt, bufferTypeFromDtype(dtypeFromBufferType(dt)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const _MINOR_TO_MAJOR = blk: {
|
||||||
|
var ret: [Shape.MAX_RANK]i64 = undefined;
|
||||||
|
for (0..Shape.MAX_RANK) |i| {
|
||||||
|
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
|
||||||
|
}
|
||||||
|
break :blk ret;
|
||||||
|
};
|
||||||
|
|
||||||
|
fn minorToMajor(rank: u8) []const i64 {
|
||||||
|
return _MINOR_TO_MAJOR[_MINOR_TO_MAJOR.len - rank ..];
|
||||||
|
}
|
||||||
|
|||||||
@ -99,7 +99,7 @@ pub fn call(
|
|||||||
.api_version = .typed_ffi,
|
.api_version = .typed_ffi,
|
||||||
.backend_config = .dict(mlir_ctx, &.{}),
|
.backend_config = .dict(mlir_ctx, &.{}),
|
||||||
.additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
|
.additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
|
||||||
.has_side_effect = true,
|
.has_side_effect = Callback.callback_config.has_side_effect,
|
||||||
.output_operand_aliases = Callback.callback_config.output_operand_aliases,
|
.output_operand_aliases = Callback.callback_config.output_operand_aliases,
|
||||||
},
|
},
|
||||||
output_types,
|
output_types,
|
||||||
@ -123,6 +123,7 @@ pub const Config = struct {
|
|||||||
// TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
|
// TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
|
||||||
traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
|
traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
|
||||||
// TODO: handle sharded inputs
|
// TODO: handle sharded inputs
|
||||||
|
has_side_effect: bool = true,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Compile-time check that a callback has all informations we require.
|
/// Compile-time check that a callback has all informations we require.
|
||||||
@ -190,12 +191,12 @@ fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt
|
|||||||
else
|
else
|
||||||
.asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
|
.asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
|
||||||
if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
|
if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
|
||||||
log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
|
log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.devicePtr() });
|
||||||
zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
|
zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
|
||||||
log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err });
|
log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.devicePtr(), err });
|
||||||
return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
|
return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
|
||||||
};
|
};
|
||||||
log.debug("--> {f} {*} ({})", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), @as(*const f32, @ptrCast(@alignCast(zml_buffer.opaqueDeviceMemoryDataPointer()))).* });
|
log.debug("--> {f} {x}", .{ zml_buffer, zml_buffer.devicePtr() });
|
||||||
}
|
}
|
||||||
callback_args[i] = zml_buffer;
|
callback_args[i] = zml_buffer;
|
||||||
}
|
}
|
||||||
@ -282,6 +283,7 @@ pub const Print = struct {
|
|||||||
.copy_inputs_to_host_pinned = true,
|
.copy_inputs_to_host_pinned = true,
|
||||||
// Print is fairly predictable and can be captured in an execution graph.
|
// Print is fairly predictable and can be captured in an execution graph.
|
||||||
.traits = .{ .command_buffer_compatible = false },
|
.traits = .{ .command_buffer_compatible = false },
|
||||||
|
.has_side_effect = false,
|
||||||
};
|
};
|
||||||
|
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
|
|||||||
@ -235,6 +235,7 @@ pub const BaseExe = struct {
|
|||||||
if (self.execute_context) |ctx| {
|
if (self.execute_context) |ctx| {
|
||||||
ctx.deinit(self.platform.pjrt_api);
|
ctx.deinit(self.platform.pjrt_api);
|
||||||
}
|
}
|
||||||
|
self.exe.deinit(self.platform.pjrt_api);
|
||||||
self._arena.deinit();
|
self._arena.deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -395,6 +396,10 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
|||||||
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
|
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn clone(self: Self, allocator: std.mem.Allocator) !Self {
|
||||||
|
return .{ .inner = try self.inner.clone(allocator) };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -322,7 +322,7 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
|
pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
|
||||||
try writer.print("HostBuffer(.{f})", .{self._shape});
|
try writer.print("HostBuffer(.{f})@{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||||
|
|||||||
@ -221,6 +221,12 @@ pub const CompilationContext = struct {
|
|||||||
break :blk loaded_executable;
|
break :blk loaded_executable;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
const exe = try loaded_executable.getExecutable(self._platform.pjrt_api);
|
||||||
|
const stats = try exe.getCompiledMemoryStats(self._platform.pjrt_api);
|
||||||
|
log.debug("Compiled {s}: {any}", .{ self._name, stats });
|
||||||
|
}
|
||||||
|
|
||||||
log.debug("******** ZML generated MLIR ********", .{});
|
log.debug("******** ZML generated MLIR ********", .{});
|
||||||
log.debug("{f}", .{module.op().mlirFormatter(.{})});
|
log.debug("{f}", .{module.op().mlirFormatter(.{})});
|
||||||
|
|
||||||
@ -881,6 +887,9 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
|
c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
|
||||||
c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
|
c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
|
||||||
c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);
|
c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);
|
||||||
|
if (platform.compilation_options.device_memory_size > 0) {
|
||||||
|
c.xla_ExecutableBuildOptionsProto_set_device_memory_size(exec_build_options, @intCast(platform.compilation_options.device_memory_size));
|
||||||
|
}
|
||||||
|
|
||||||
c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
|
c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
|
||||||
const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);
|
const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);
|
||||||
@ -895,7 +904,6 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
}
|
}
|
||||||
break :device_assignment_blk device_assignment;
|
break :device_assignment_blk device_assignment;
|
||||||
});
|
});
|
||||||
|
|
||||||
break :executable_build_options_blk exec_build_options;
|
break :executable_build_options_blk exec_build_options;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@ -243,7 +243,9 @@ pub const Event = opaque {
|
|||||||
|
|
||||||
if (ctx.err) |e| {
|
if (ctx.err) |e| {
|
||||||
defer e.deinit(api);
|
defer e.deinit(api);
|
||||||
return e.getCode(api).toApiError();
|
const err_code = e.getCode(api).toApiError();
|
||||||
|
log.err("{t} {s}", .{ err_code, e.getMessage(api) });
|
||||||
|
return err_code;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -10,12 +10,17 @@ const log = std.log.scoped(.zml);
|
|||||||
|
|
||||||
pub const available_targets = std.enums.values(Target);
|
pub const available_targets = std.enums.values(Target);
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
pub const CompilationOptions = struct {
|
pub const CompilationOptions = struct {
|
||||||
xla_dump_to: ?[]const u8 = null,
|
xla_dump_to: ?[]const u8 = null,
|
||||||
xla_dump_fusion_visualization: bool = false,
|
xla_dump_fusion_visualization: bool = false,
|
||||||
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
||||||
sharding_enabled: bool = false,
|
sharding_enabled: bool = false,
|
||||||
sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
|
sharding_axes: stdx.BoundedArray([*:0]const u8, 8) = .{},
|
||||||
|
device_memory_size: u64 = 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Platform = struct {
|
pub const Platform = struct {
|
||||||
@ -29,7 +34,7 @@ pub const Platform = struct {
|
|||||||
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
||||||
compilation_options: CompilationOptions = .{},
|
compilation_options: CompilationOptions = .{},
|
||||||
|
|
||||||
pub const MAX_NUM_DEVICES: u8 = 32;
|
pub const MAX_NUM_DEVICES: u8 = if (runtimes.isEnabled(.tpu)) 32 else 8;
|
||||||
pub const CreateOptions = _CreateOptions;
|
pub const CreateOptions = _CreateOptions;
|
||||||
|
|
||||||
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
|
pub fn init(target: Target, api: *const pjrt.Api, options: CreateOptions) !Platform {
|
||||||
@ -79,6 +84,44 @@ pub const Platform = struct {
|
|||||||
pub fn deinit(self: *Platform) void {
|
pub fn deinit(self: *Platform) void {
|
||||||
self.pjrt_client.deinit(self.pjrt_api);
|
self.pjrt_client.deinit(self.pjrt_api);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn memoryForDevice(platform: Platform, memory: pjrt.Memory.Kind, device: *const pjrt.Device) *const pjrt.Memory {
|
||||||
|
const memory_target: pjrt.Memory.Kind = switch (memory) {
|
||||||
|
.host_unpinned => switch (platform.target) {
|
||||||
|
// Cuda doesn't have host_unpinned.
|
||||||
|
.cuda => .host_pinned,
|
||||||
|
else => .host_unpinned,
|
||||||
|
},
|
||||||
|
inline else => |t| t,
|
||||||
|
};
|
||||||
|
// TODO measure the cost of this and consider caching.
|
||||||
|
const device_memories = device.addressableMemories(platform.pjrt_api);
|
||||||
|
for (device_memories) |m| {
|
||||||
|
if (memory_target == m.kind(platform.pjrt_api)) {
|
||||||
|
return m;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.err("Platform {t} doesn't have memory {t}", .{ platform.target, memory });
|
||||||
|
@panic("Memory kind not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
test memoryForDevice {
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
const memory_fields = @typeInfo(pjrt.Memory.Kind).@"enum".fields;
|
||||||
|
inline for (memory_fields) |field| {
|
||||||
|
for (platform.getDevices()) |dev| {
|
||||||
|
_ = platform.memoryForDevice(@field(pjrt.Memory.Kind, field.name), dev);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn memoryStats(platform: Platform, device_id: usize) pjrt.MemoryStats {
|
||||||
|
if (platform.target == .cpu) return .zeroes;
|
||||||
|
|
||||||
|
const device = platform.getDevices()[device_id];
|
||||||
|
return device.memoryStats(platform.pjrt_api) catch .zeroes;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const _CreateOptions = struct {
|
const _CreateOptions = struct {
|
||||||
@ -127,17 +170,21 @@ const _CreateOptions = struct {
|
|||||||
fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
|
fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
|
||||||
switch (self.allocator) {
|
switch (self.allocator) {
|
||||||
.platform => {
|
.platform => {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
values.appendAssumeCapacity(.fromString("allocator", "platform"));
|
||||||
},
|
},
|
||||||
.bfc, .async => |opt| {
|
.bfc, .async => |opt| {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.from("allocator", self.allocator));
|
values.appendAssumeCapacity(.fromString("allocator", switch (self.allocator) {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.from("preallocate", opt.preallocate));
|
.bfc => "bfc",
|
||||||
|
.async => "cuda_async",
|
||||||
|
.platform => unreachable,
|
||||||
|
}));
|
||||||
|
values.appendAssumeCapacity(.from("preallocate", opt.preallocate));
|
||||||
if (opt.memory_fraction > 0) {
|
if (opt.memory_fraction > 0) {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.from("memory_fraction", opt.memory_fraction));
|
values.appendAssumeCapacity(.from("memory_fraction", opt.memory_fraction));
|
||||||
}
|
}
|
||||||
if (opt.collective_memory_size_mb > 0) {
|
if (opt.collective_memory_size_mb > 0) {
|
||||||
const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
|
const collective = @as(i64, opt.collective_memory_size_mb) * 1024 * 1024;
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.from("collective_memory_size", collective));
|
values.appendAssumeCapacity(.from("collective_memory_size", collective));
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -202,10 +202,8 @@ pub const Tensor = struct {
|
|||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
if (ctx.target() == .cpu) return self;
|
if (ctx.target() == .cpu) return self;
|
||||||
|
|
||||||
const memory_kind = @tagName(kind.toPjrtMemory());
|
|
||||||
|
|
||||||
const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{
|
const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{
|
||||||
.{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) },
|
.{ "_xla_buffer_placement", .string(mlir_ctx, kind.pjrtName()) },
|
||||||
});
|
});
|
||||||
|
|
||||||
const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{
|
const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{
|
||||||
@ -311,6 +309,11 @@ pub const Tensor = struct {
|
|||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the given tensor as one contiguous buffer of bytes.
|
||||||
|
pub fn bytes(self: Tensor) Tensor {
|
||||||
|
return self.bitCast(.u8).flattenAll().withTags(.{.bytes});
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
|
/// Returns a Tensor containing the element-wise number of leading 0 bits in the input Tensor.
|
||||||
pub fn countLeadingZeros(self: Tensor) Tensor {
|
pub fn countLeadingZeros(self: Tensor) Tensor {
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().mlirCtx().location(@src());
|
||||||
@ -2683,7 +2686,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
{
|
{
|
||||||
// Test with actual values and batching along axis .a
|
// Test with actual values and batching along axis .a
|
||||||
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0);
|
const operand = try zml.Buffer.constant(platform, Shape.init(.{ .a = 2, .b = 3, .c = 4, .d = 2 }, .u16), 0, .{});
|
||||||
defer operand.deinit();
|
defer operand.deinit();
|
||||||
const start_indices = (try zml.Buffer.fromArray(
|
const start_indices = (try zml.Buffer.fromArray(
|
||||||
platform,
|
platform,
|
||||||
@ -2704,6 +2707,7 @@ pub const Tensor = struct {
|
|||||||
platform,
|
platform,
|
||||||
Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
|
Shape.init(.{ .n = 2, .a = 2, .m = 3, .c = 2, .d = 2 }, .u16),
|
||||||
1,
|
1,
|
||||||
|
.{},
|
||||||
);
|
);
|
||||||
defer values.deinit();
|
defer values.deinit();
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user