Add event waiting when invoking a module and improve multi‑device sharding handling.
This commit is contained in:
parent
0189b71070
commit
8e43a45a3c
@ -1,262 +0,0 @@
|
|||||||
const std = @import("std");
|
|
||||||
|
|
||||||
/// Loosely-typed metadata value: a tagged union over scalars, big integers,
/// type-erased arrays (`Slice`) and strings.
pub const Value = union(enum) {
    /// Type-erased array payload: a buffer plus a tag naming the element type.
    pub const Slice = struct {
        /// Element types a `Slice` can be tagged with.
        pub const ItemType = enum {
            uint8,
            int8,
            uint16,
            int16,
            uint32,
            int32,
            uint64,
            int64,
            float16,
            float32,
            float64,
            boolval,
            string,
            // TODO (cryptodeal): gguf/torch/json (safetensors) in theory support nested arrays;
            // we should support them for the sake of completeness, but we have not yet
            // encountered a model containing these types.
            // TODO (cryptodeal): array,
        };

        /// Tag describing how `data` must be reinterpreted.
        item_type: ItemType,
        /// Backing storage of the elements.
        /// NOTE(review): `cast` treats this as raw bytes (length in bytes, not in
        /// elements) — confirm that every producer fills it that way.
        data: []u8,

        /// Reports whether `T` is a slice whose element type is itself a slice
        /// (e.g. `[][]f32`). Note that a slice of strings (`[][]const u8`) also
        /// counts as nested under this definition.
        /// Not referenced anywhere else in this type at the moment.
        fn isNestedSlice(comptime T: type) bool {
            const outer = @typeInfo(T);
            if (outer != .Pointer or outer.Pointer.size != .Slice) {
                return false;
            }
            const inner = @typeInfo(outer.Pointer.child);
            return inner == .Pointer and inner.Pointer.size == .Slice;
        }

        /// Maps a Zig element type to its `ItemType` tag.
        /// Panics when `T` has no corresponding tag.
        fn fromZigType(comptime T: type) ItemType {
            return switch (T) {
                u8 => .uint8,
                i8 => .int8,
                u16 => .uint16,
                i16 => .int16,
                u32 => .uint32,
                i32 => .int32,
                u64 => .uint64,
                i64 => .int64,
                f16 => .float16,
                f32 => .float32,
                f64 => .float64,
                bool => .boolval,
                []const u8 => .string,
                else => @panic("Unsupported type for LoaderValue.Slice: " ++ @typeName(T)),
            };
        }

        /// Maps an `ItemType` tag back to the Zig type it stands for
        /// (inverse of `fromZigType`).
        pub fn toZigType(comptime kind: ItemType) type {
            return switch (kind) {
                .uint8 => u8,
                .int8 => i8,
                .uint16 => u16,
                .int16 => i16,
                .uint32 => u32,
                .int32 => i32,
                .uint64 => u64,
                .int64 => i64,
                .float16 => f16,
                .float32 => f32,
                .float64 => f64,
                .boolval => bool,
                .string => []const u8,
            };
        }

        /// Reinterprets `data` as a slice of `T`.
        /// Panics when `T` does not match `item_type`.
        /// The returned slice aliases `data`; nothing is copied.
        pub fn cast(self: *Slice, comptime T: type) []T {
            if (fromZigType(T) != self.item_type) {
                @panic("Type mismatch in LoaderValue.Slice cast");
            }
            // `data.len` counts bytes, so the element count has to be scaled by
            // the element size; using the byte length directly would hand out a
            // slice running past the end of the buffer for any multi-byte `T`.
            const count = self.data.len / @sizeOf(T);
            const items: [*]T = @ptrCast(@alignCast(self.data.ptr));
            return items[0..count];
        }
    };

    // TODO: this is overkill we don't need that many different types
    // to represent metadata. bool, i64, f64, string are enough (like Json).
    null,
    uint8: u8,
    int8: i8,
    uint16: u16,
    int16: i16,
    uint32: u32,
    int32: i32,
    uint64: u64,
    int64: i64,
    float16: f16,
    float32: f32,
    float64: f64,
    bigint: std.math.big.int.Managed,
    boolval: bool,
    array: Slice,
    string: []const u8,

    /// Builds a `Value` from a scalar, bool or `[]const u8`, picking the variant
    /// from the static type of `x`. Panics for any other type.
    pub fn wrap(x: anytype) Value {
        const tag = switch (@TypeOf(x)) {
            u8 => .uint8,
            i8 => .int8,
            u16 => .uint16,
            i16 => .int16,
            u32 => .uint32,
            i32 => .int32,
            u64 => .uint64,
            i64 => .int64,
            f16 => .float16,
            f32 => .float32,
            f64 => .float64,
            bool => .boolval,
            []const u8 => .string,
            else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))),
        };
        return @unionInit(Value, @tagName(tag), x);
    }

    /// Shared implementation of the typed accessors below: returns the payload
    /// when `self` holds variant `tag`, otherwise logs the payload type that was
    /// actually found and returns `error.IncorrectType`.
    /// `expected` is the human-readable name spliced into the log message.
    fn expect(
        self: Value,
        comptime tag: std.meta.Tag(Value),
        comptime expected: []const u8,
    ) error{IncorrectType}!std.meta.TagPayload(Value, tag) {
        switch (self) {
            inline else => |v, active| {
                // `active` is comptime-known in an inline prong, so only one of
                // the two paths below is analysed for each variant.
                if (comptime active == tag) return v;
                std.log.err("Expected " ++ expected ++ ", actual value is {s}\n", .{@typeName(@TypeOf(v))});
                return error.IncorrectType;
            },
        }
    }

    /// Succeeds only when the value is `null`.
    pub fn @"null"(self: Value) error{IncorrectType}!void {
        return self.expect(.null, "`null`");
    }

    /// Returns the payload when the value is `uint8`.
    pub fn @"u8"(self: Value) error{IncorrectType}!u8 {
        return self.expect(.uint8, "`u8`");
    }

    /// Returns the payload when the value is `int8`.
    pub fn @"i8"(self: Value) error{IncorrectType}!i8 {
        return self.expect(.int8, "`i8`");
    }

    /// Returns the payload when the value is `uint16`.
    pub fn @"u16"(self: Value) error{IncorrectType}!u16 {
        return self.expect(.uint16, "`u16`");
    }

    /// Returns the payload when the value is `int16`.
    pub fn @"i16"(self: Value) error{IncorrectType}!i16 {
        return self.expect(.int16, "`i16`");
    }

    /// Returns the payload when the value is `uint32`.
    pub fn @"u32"(self: Value) error{IncorrectType}!u32 {
        return self.expect(.uint32, "`u32`");
    }

    /// Returns the payload when the value is `int32`.
    pub fn @"i32"(self: Value) error{IncorrectType}!i32 {
        return self.expect(.int32, "`i32`");
    }

    /// Returns the payload when the value is `uint64`.
    pub fn @"u64"(self: Value) error{IncorrectType}!u64 {
        return self.expect(.uint64, "`u64`");
    }

    /// Returns the payload when the value is `int64`.
    pub fn @"i64"(self: Value) error{IncorrectType}!i64 {
        return self.expect(.int64, "`i64`");
    }

    /// Returns the payload when the value is `float16`.
    pub fn @"f16"(self: Value) error{IncorrectType}!f16 {
        return self.expect(.float16, "`f16`");
    }

    /// Returns the payload when the value is `float32`.
    pub fn @"f32"(self: Value) error{IncorrectType}!f32 {
        return self.expect(.float32, "`f32`");
    }

    /// Returns the payload when the value is `float64`.
    pub fn @"f64"(self: Value) error{IncorrectType}!f64 {
        return self.expect(.float64, "`f64`");
    }

    /// Returns the payload when the value is `boolval`.
    pub fn @"bool"(self: Value) error{IncorrectType}!bool {
        return self.expect(.boolval, "`bool`");
    }

    /// Returns the payload when the value is `string`.
    pub fn string(self: Value) error{IncorrectType}![]const u8 {
        return self.expect(.string, "string (`[]const u8`)");
    }
};
|
|
||||||
@ -859,18 +859,23 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
|
|
||||||
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
pub fn call(self: Self, args: Bufferized(Signature.ArgsT)) Bufferized(Signature.ReturnT) {
|
||||||
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
fillBuffers(&args, self.input_per_device, self.inner.model_buffer_count, self.inner.args_buffer_count);
|
||||||
var event: [1]*pjrt.Event = undefined;
|
var events: [Platform.MAX_NUM_DEVICES]*pjrt.Event = undefined;
|
||||||
|
const sharding = self.platform().sharding();
|
||||||
|
|
||||||
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
|
self.inner.exe.execute(self.inner.platform.pjrt_api, .{
|
||||||
.arguments = self.input_per_device,
|
.arguments = self.input_per_device,
|
||||||
.num_args = self.inner.args_buffer_count + self.inner.model_buffer_count,
|
.num_args = self.inner.args_buffer_count + self.inner.model_buffer_count,
|
||||||
.results = self.output_per_device,
|
.results = self.output_per_device,
|
||||||
.events = &event,
|
.events = events[0..sharding.num_partitions],
|
||||||
// TODO: this allows to tell a specific buffer shouldn't be donated,
|
// TODO: this allows to tell a specific buffer shouldn't be donated,
|
||||||
// even if it has been marked as "can be donated" during compilation.
|
// even if it has been marked as "can be donated" during compilation.
|
||||||
.non_donatable_input_indices = &.{},
|
.non_donatable_input_indices = &.{},
|
||||||
}) catch unreachable;
|
}) catch unreachable;
|
||||||
|
|
||||||
|
for (events[0..sharding.num_partitions]) |e| {
|
||||||
|
e.await_(self.inner.platform.pjrt_api) catch unreachable;
|
||||||
|
}
|
||||||
|
|
||||||
var result: Bufferized(Signature.ReturnT) = undefined;
|
var result: Bufferized(Signature.ReturnT) = undefined;
|
||||||
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count);
|
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count);
|
||||||
return result;
|
return result;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user