Add experimental zml.callback API (renamed from custom_call) and fix tensor.print(); update PJRT bindings, host buffer utilities, and related core ZML modules.
This commit is contained in:
parent
1fa056a790
commit
cc969bd532
148
pjrt/ffi.zig
148
pjrt/ffi.zig
@ -2,12 +2,21 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
|
pub const TypeId = c.XLA_FFI_TypeId;
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const pjrtStruct = @import("pjrt.zig").pjrtStruct;
|
const pjrt = @import("pjrt.zig");
|
||||||
|
const Stream = @import("pjrt.zig").Stream;
|
||||||
|
|
||||||
const log = std.log.scoped(.pjrt);
|
const log = std.log.scoped(.pjrt);
|
||||||
|
|
||||||
|
comptime {
|
||||||
|
if (@typeInfo(TypeId).@"struct".fields.len != 1) @compileError("TypeId has changed");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The signature of a generic custom call.
|
||||||
|
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
|
||||||
|
|
||||||
pub const ApiVersion = extern struct {
|
pub const ApiVersion = extern struct {
|
||||||
pub const major = c.XLA_FFI_API_MAJOR;
|
pub const major = c.XLA_FFI_API_MAJOR;
|
||||||
pub const minor = c.XLA_FFI_API_MINOR;
|
pub const minor = c.XLA_FFI_API_MINOR;
|
||||||
@ -29,13 +38,13 @@ pub const ExtensionBase = extern struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
|
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
|
||||||
pub const HandlerTraits = packed struct(u32) {
|
pub const HandlerTraits = packed struct(c_uint) {
|
||||||
/// Calls to FFI handler are safe to trace into the command buffer.
|
/// Calls to FFI handler are safe to trace into the command buffer.
|
||||||
/// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values)
|
/// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values)
|
||||||
/// that can be captured and then replayed.
|
/// that can be captured and then replayed.
|
||||||
command_buffer_compatible: u1,
|
command_buffer_compatible: bool,
|
||||||
|
|
||||||
__unassigned__: u31,
|
__unassigned__: u31 = 0,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Metadata = extern struct {
|
pub const Metadata = extern struct {
|
||||||
@ -49,25 +58,6 @@ pub const MetadataExtension = extern struct {
|
|||||||
metadata: ?*Metadata,
|
metadata: ?*Metadata,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const ApiError = error{
|
|
||||||
Cancelled,
|
|
||||||
Unknown,
|
|
||||||
InvalidArgument,
|
|
||||||
DeadlineExceeded,
|
|
||||||
NotFound,
|
|
||||||
AlreadyExists,
|
|
||||||
PermissionDenied,
|
|
||||||
ResourceExhausted,
|
|
||||||
FailedPrecondition,
|
|
||||||
Aborted,
|
|
||||||
OutOfRange,
|
|
||||||
Unimplemented,
|
|
||||||
Internal,
|
|
||||||
Unavailable,
|
|
||||||
DataLoss,
|
|
||||||
Unauthenticated,
|
|
||||||
};
|
|
||||||
|
|
||||||
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn to(self: anytype) switch (@TypeOf(self)) {
|
pub fn to(self: anytype) switch (@TypeOf(self)) {
|
||||||
@ -91,8 +81,8 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
|||||||
pub const Api = opaque {
|
pub const Api = opaque {
|
||||||
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
|
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
|
||||||
|
|
||||||
pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
|
pub fn stream(self: *const Api, context: *const ExecutionContext) *pjrt.Stream {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
||||||
.ctx = @constCast(context.inner()),
|
.ctx = @constCast(context.inner()),
|
||||||
});
|
});
|
||||||
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
|
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
|
||||||
@ -107,8 +97,8 @@ pub const Api = opaque {
|
|||||||
return @ptrCast(ret.stream.?);
|
return @ptrCast(ret.stream.?);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
|
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) pjrt.ApiError!*anyopaque {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
||||||
.ctx = @constCast(context.inner()),
|
.ctx = @constCast(context.inner()),
|
||||||
.size = size,
|
.size = size,
|
||||||
.alignment = alignment,
|
.alignment = alignment,
|
||||||
@ -127,8 +117,8 @@ pub const Api = opaque {
|
|||||||
return ret.data.?;
|
return ret.data.?;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
|
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) pjrt.ApiError!void {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
||||||
.ctx = @constCast(context.inner()),
|
.ctx = @constCast(context.inner()),
|
||||||
.size = size,
|
.size = size,
|
||||||
.data = data,
|
.data = data,
|
||||||
@ -163,33 +153,31 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
|
|||||||
pub const ExecutionContext = opaque {
|
pub const ExecutionContext = opaque {
|
||||||
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
|
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
|
||||||
|
|
||||||
pub fn Context(comptime T: type) type {
|
pub fn getContext(self: *const ExecutionContext, type_id: TypeId, api: *const Api) pjrt.ApiError!*anyopaque {
|
||||||
return struct {
|
var ret: c.XLA_FFI_ExecutionContext_Get_Args = .{
|
||||||
pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
|
.struct_size = pjrt.pjrtStructSize(c.XLA_FFI_ExecutionContext_Get_Args),
|
||||||
const type_id: TypeId = .{ .type_id = T.type_id };
|
.extension_start = api.inner().extension_start,
|
||||||
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
|
.ctx = @ptrCast(@constCast(self)),
|
||||||
.ctx = @constCast(self.inner()),
|
.type_id = @constCast(&type_id),
|
||||||
.type_id = @constCast(&type_id.toCStruct()),
|
.data = undefined, // set by XLA_FFI_ExecutionContext_Get.
|
||||||
});
|
|
||||||
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
|
||||||
|
|
||||||
if (result) |ffi_error| {
|
|
||||||
const err = Error.fromInner(ffi_error);
|
|
||||||
defer err.destroy(api);
|
|
||||||
log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
|
|
||||||
|
|
||||||
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
|
||||||
return error.Unknown;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ret.data == null) return error.NotFound;
|
|
||||||
return @ptrCast(@alignCast(ret.data.?));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
const maybe_err = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
||||||
|
|
||||||
|
if (maybe_err) |ffi_error| {
|
||||||
|
const err = Error.fromInner(ffi_error);
|
||||||
|
defer err.destroy(api);
|
||||||
|
log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
|
||||||
|
|
||||||
|
// TODO(Corentin): Retrieve error code from Error when implemented in XLA.
|
||||||
|
return error.Unknown;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ret.data == null) return error.NotFound;
|
||||||
|
return ret.data.?;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
|
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) pjrt.ApiError!i32 {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
|
||||||
.ctx = @constCast(self.inner()),
|
.ctx = @constCast(self.inner()),
|
||||||
});
|
});
|
||||||
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
|
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
|
||||||
@ -206,8 +194,10 @@ pub const ExecutionContext = opaque {
|
|||||||
return ret.device_ordinal;
|
return ret.device_ordinal;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
|
const Task = fn (*anyopaque) void;
|
||||||
var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
|
|
||||||
|
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) pjrt.ApiError!void {
|
||||||
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
|
||||||
.ctx = @constCast(self.inner()),
|
.ctx = @constCast(self.inner()),
|
||||||
.task = @ptrCast(@alignCast(task)),
|
.task = @ptrCast(@alignCast(task)),
|
||||||
.data = @ptrCast(@alignCast(data)),
|
.data = @ptrCast(@alignCast(data)),
|
||||||
@ -225,23 +215,9 @@ pub const ExecutionContext = opaque {
|
|||||||
return error.Unknown;
|
return error.Unknown;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn getTypeId(type_name: []const u8) TypeId {
|
|
||||||
const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name));
|
|
||||||
|
|
||||||
return .{
|
|
||||||
.type_id = id,
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const TypeId = c.XLA_FFI_TypeId;
|
pub const ByteSpan = extern struct {
|
||||||
|
|
||||||
const Task = fn (*anyopaque) void;
|
|
||||||
|
|
||||||
const Stream = @import("pjrt.zig").Stream;
|
|
||||||
|
|
||||||
const ByteSpan = extern struct {
|
|
||||||
ptr: [*]const u8,
|
ptr: [*]const u8,
|
||||||
len: usize,
|
len: usize,
|
||||||
|
|
||||||
@ -252,7 +228,7 @@ const ByteSpan = extern struct {
|
|||||||
|
|
||||||
pub const DataType = enum(c.XLA_FFI_DataType) {
|
pub const DataType = enum(c.XLA_FFI_DataType) {
|
||||||
invalid = c.XLA_FFI_DataType_INVALID,
|
invalid = c.XLA_FFI_DataType_INVALID,
|
||||||
pred = c.XLA_FFI_DataType_PRED,
|
bool = c.XLA_FFI_DataType_PRED,
|
||||||
i8 = c.XLA_FFI_DataType_S8,
|
i8 = c.XLA_FFI_DataType_S8,
|
||||||
i16 = c.XLA_FFI_DataType_S16,
|
i16 = c.XLA_FFI_DataType_S16,
|
||||||
i32 = c.XLA_FFI_DataType_S32,
|
i32 = c.XLA_FFI_DataType_S32,
|
||||||
@ -399,6 +375,8 @@ pub const Attrs = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// All the information needed by the user callback,
|
||||||
|
/// including the list of input/output buffers to work on.
|
||||||
pub const CallFrame = extern struct {
|
pub const CallFrame = extern struct {
|
||||||
struct_size: usize,
|
struct_size: usize,
|
||||||
extension_start: ?*ExtensionBase,
|
extension_start: ?*ExtensionBase,
|
||||||
@ -422,9 +400,11 @@ pub const CallFrame = extern struct {
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
|
pub fn stream(call_frame: CallFrame) ?*const pjrt.Stream {
|
||||||
|
return call_frame.api.stream(call_frame.ctx);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
||||||
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
|
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
|
||||||
@ -444,7 +424,7 @@ pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
|||||||
data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
|
data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
|
||||||
unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,
|
unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,
|
||||||
|
|
||||||
pub fn toApiError(code: ErrorCode) ApiError {
|
pub fn toApiError(code: ErrorCode) pjrt.ApiError {
|
||||||
return switch (code) {
|
return switch (code) {
|
||||||
.cancelled => error.Cancelled,
|
.cancelled => error.Cancelled,
|
||||||
.unknown => error.Unknown,
|
.unknown => error.Unknown,
|
||||||
@ -470,8 +450,10 @@ pub const Error = opaque {
|
|||||||
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
|
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
|
||||||
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
|
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
|
||||||
|
|
||||||
|
pub const ok: ?*Error = null;
|
||||||
|
|
||||||
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
|
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
||||||
.message = message.ptr,
|
.message = message.ptr,
|
||||||
.errc = @intFromEnum(error_code),
|
.errc = @intFromEnum(error_code),
|
||||||
});
|
});
|
||||||
@ -479,12 +461,12 @@ pub const Error = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn destroy(err: *Error, api: *const Api) void {
|
pub fn destroy(err: *Error, api: *const Api) void {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
|
||||||
api.inner().XLA_FFI_Error_Destroy.?(&ret);
|
api.inner().XLA_FFI_Error_Destroy.?(&ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
|
pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
|
||||||
.@"error" = err.inner(),
|
.@"error" = err.inner(),
|
||||||
});
|
});
|
||||||
api.inner().XLA_FFI_Error_GetMessage.?(&ret);
|
api.inner().XLA_FFI_Error_GetMessage.?(&ret);
|
||||||
@ -496,8 +478,8 @@ pub const Future = opaque {
|
|||||||
pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
|
pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
|
||||||
pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;
|
pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;
|
||||||
|
|
||||||
pub fn create(api: *const Api) ApiError!*Future {
|
pub fn create(api: *const Api) pjrt.ApiError!*Future {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{});
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_Create_Args{});
|
||||||
const result = api.inner().XLA_FFI_Future_Create.?(&ret);
|
const result = api.inner().XLA_FFI_Future_Create.?(&ret);
|
||||||
|
|
||||||
if (result) |ffi_error| {
|
if (result) |ffi_error| {
|
||||||
@ -512,8 +494,8 @@ pub const Future = opaque {
|
|||||||
return fromInner(ret.future.?);
|
return fromInner(ret.future.?);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn setAvailable(self: *Future, api: *const Api) ApiError!void {
|
pub fn setAvailable(self: *Future, api: *const Api) pjrt.ApiError!void {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
|
||||||
.future = self.inner(),
|
.future = self.inner(),
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -529,8 +511,8 @@ pub const Future = opaque {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void {
|
pub fn setError(self: *Future, api: *const Api, err: *Error) pjrt.ApiError!void {
|
||||||
var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{
|
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetError_Args{
|
||||||
.future = self.inner(),
|
.future = self.inner(),
|
||||||
.@"error" = err.inner(),
|
.@"error" = err.inner(),
|
||||||
});
|
});
|
||||||
|
|||||||
@ -20,7 +20,7 @@ test {
|
|||||||
// as the way PJRT does it is not very robust.
|
// as the way PJRT does it is not very robust.
|
||||||
//
|
//
|
||||||
// 1. https://github.com/openxla/xla/issues/10032
|
// 1. https://github.com/openxla/xla/issues/10032
|
||||||
fn pjrtStructSize(comptime T: type) usize {
|
pub fn pjrtStructSize(comptime T: type) usize {
|
||||||
// unsafe on purpose, we want this to fail if that ever changes
|
// unsafe on purpose, we want this to fail if that ever changes
|
||||||
const typedef_name = comptime blk: {
|
const typedef_name = comptime blk: {
|
||||||
const needle = ".struct_";
|
const needle = ".struct_";
|
||||||
@ -164,7 +164,7 @@ pub const Api = struct {
|
|||||||
return @ptrCast(ret.context.?);
|
return @ptrCast(ret.context.?);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn ffi(api: *const Api) ?FFI {
|
pub fn ffi(api: *const Api) ?Ffi {
|
||||||
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
||||||
return .{ .inner = ext };
|
return .{ .inner = ext };
|
||||||
}
|
}
|
||||||
@ -1278,7 +1278,7 @@ pub const NamedValue = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const FFI = extern struct {
|
pub const Ffi = extern struct {
|
||||||
inner: *const c.PJRT_FFI,
|
inner: *const c.PJRT_FFI,
|
||||||
|
|
||||||
pub const UserData = extern struct {
|
pub const UserData = extern struct {
|
||||||
@ -1293,19 +1293,15 @@ pub const FFI = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const RegisterFfiOptions = struct {
|
|
||||||
traits: RegisterHandlerTraits = @enumFromInt(0),
|
|
||||||
};
|
|
||||||
|
|
||||||
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||||
pub fn register(
|
pub fn register(
|
||||||
self: *const FFI,
|
self: *const Ffi,
|
||||||
api: *const Api,
|
api: *const Api,
|
||||||
target_name: []const u8,
|
target_name: []const u8,
|
||||||
platform_name: []const u8,
|
platform_name: []const u8,
|
||||||
func: *const ffi.Handler,
|
func: *const ffi.Handler,
|
||||||
options: RegisterFfiOptions,
|
traits: ffi.HandlerTraits,
|
||||||
) ApiError!void {
|
) ApiError!void {
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
||||||
.target_name = target_name.ptr,
|
.target_name = target_name.ptr,
|
||||||
@ -1313,7 +1309,7 @@ pub const FFI = extern struct {
|
|||||||
.handler = @ptrCast(@constCast(func)),
|
.handler = @ptrCast(@constCast(func)),
|
||||||
.platform_name = platform_name.ptr,
|
.platform_name = platform_name.ptr,
|
||||||
.platform_name_size = platform_name.len,
|
.platform_name_size = platform_name.len,
|
||||||
.traits = @intFromEnum(options.traits),
|
.traits = @bitCast(traits),
|
||||||
});
|
});
|
||||||
const result = self.inner.register_handler.?(&ret);
|
const result = self.inner.register_handler.?(&ret);
|
||||||
if (result) |pjrt_c_error| {
|
if (result) |pjrt_c_error| {
|
||||||
@ -1323,8 +1319,7 @@ pub const FFI = extern struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
|
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId {
|
||||||
const type_name = @typeName(T);
|
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
|
||||||
.type_name = type_name.ptr,
|
.type_name = type_name.ptr,
|
||||||
.type_name_size = type_name.len,
|
.type_name_size = type_name.len,
|
||||||
@ -1336,10 +1331,10 @@ pub const FFI = extern struct {
|
|||||||
return pjrt_error.getCode(api).toApiError();
|
return pjrt_error.getCode(api).toApiError();
|
||||||
}
|
}
|
||||||
|
|
||||||
T.type_id = ret.type_id;
|
return .{ .type_id = ret.type_id };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
|
pub fn addUserData(self: *const Ffi, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
|
||||||
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
|
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
|
||||||
.context = @ptrCast(context),
|
.context = @ptrCast(context),
|
||||||
.user_data = user_data.toCStruct(),
|
.user_data = user_data.toCStruct(),
|
||||||
@ -1352,12 +1347,3 @@ pub const FFI = extern struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
|
|
||||||
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
|
|
||||||
_,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const CustomCallRegistry = extern struct {
|
|
||||||
inner: *const c.PJRT_FFI_Register_Handler,
|
|
||||||
};
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ zig_shared_library(
|
|||||||
name = "zmlxcuda",
|
name = "zmlxcuda",
|
||||||
# Use Clang's compiler-rt, but disable stack checking
|
# Use Clang's compiler-rt, but disable stack checking
|
||||||
# to avoid depending on the _zig_probe_stack symbol.
|
# to avoid depending on the _zig_probe_stack symbol.
|
||||||
copts = ["-fno-stack-check"],
|
copts = ["-fno-stack-check", "-fllvm"],
|
||||||
main = "zmlxcuda.zig",
|
main = "zmlxcuda.zig",
|
||||||
shared_lib_name = "libzmlxcuda.so.0",
|
shared_lib_name = "libzmlxcuda.so.0",
|
||||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||||
|
|||||||
160
stdx/fmt.zig
160
stdx/fmt.zig
@ -1,145 +1,117 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
pub const Fmt = union(enum) {
|
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
|
||||||
int: IntFmt,
|
return .{ .slice = any_slice };
|
||||||
float: FloatFmt,
|
}
|
||||||
generic: void,
|
|
||||||
|
|
||||||
pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
|
fn FmtSlice(T: type) type {
|
||||||
return switch (@typeInfo(T)) {
|
return struct {
|
||||||
.float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) },
|
slice: []const T,
|
||||||
.int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) },
|
|
||||||
else => .{ .generic = {} },
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const FullFormatOptions = struct {
|
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||||
fmt: Fmt,
|
return switch (@typeInfo(T)) {
|
||||||
options: std.fmt.FormatOptions,
|
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
|
||||||
};
|
.comptime_int, .int => try formatIntSlice(f.slice, n, writer),
|
||||||
|
.bool => try formatBoolSlice(f.slice, n, writer),
|
||||||
pub const IntFmt = struct {
|
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
|
||||||
base: u8,
|
try formatComplexSlice(f.slice, n, writer);
|
||||||
case: std.fmt.Case = .lower,
|
} else if (@hasDecl(T, "toF32")) {
|
||||||
|
try formatFloatSlice(f.slice, n, writer);
|
||||||
pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
|
} else {
|
||||||
return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_);
|
try formatSliceAny(f.slice, n, writer);
|
||||||
}
|
},
|
||||||
|
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
|
||||||
pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
|
};
|
||||||
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d"))
|
}
|
||||||
.{ .base = 10, .case = .lower }
|
|
||||||
else if (std.mem.eql(u8, fmt_, "x"))
|
|
||||||
.{ .base = 16, .case = .lower }
|
|
||||||
else if (std.mem.eql(u8, fmt_, "X"))
|
|
||||||
.{ .base = 16, .case = .upper }
|
|
||||||
else if (std.mem.eql(u8, fmt_, "o"))
|
|
||||||
.{ .base = 8, .case = .upper }
|
|
||||||
else
|
|
||||||
// TODO: unicode/ascii
|
|
||||||
error.InvalidArgument;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const FloatFmt = enum(u8) {
|
|
||||||
scientific = @intFromEnum(std.fmt.Number.Mode.scientific),
|
|
||||||
decimal = @intFromEnum(std.fmt.Number.Mode.decimal),
|
|
||||||
hex,
|
|
||||||
|
|
||||||
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
|
|
||||||
return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
|
|
||||||
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "e"))
|
|
||||||
.scientific
|
|
||||||
else if (std.mem.eql(u8, fmt_, "d"))
|
|
||||||
.decimal
|
|
||||||
else if (std.mem.eql(u8, fmt_, "x"))
|
|
||||||
.hex
|
|
||||||
else
|
|
||||||
error.InvalidArgument;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
|
||||||
return switch (@typeInfo(@TypeOf(value))) {
|
|
||||||
.comptime_float, .float => try formatFloatValue(value, full, writer),
|
|
||||||
.comptime_int, .int => try formatIntValue(value, full, writer),
|
|
||||||
else => try formatAnyValue(value, full, writer),
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
pub fn formatFloat(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
const x = switch (@typeInfo(@TypeOf(value))) {
|
const x = switch (@typeInfo(@TypeOf(value))) {
|
||||||
.@"struct" => value.toF32(),
|
.@"struct" => value.toF32(),
|
||||||
.float => value,
|
.float => value,
|
||||||
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
|
else => @compileError("formatFloat expects a float, got: " ++ @typeName(@TypeOf(value))),
|
||||||
};
|
|
||||||
try switch (full.fmt.float) {
|
|
||||||
.scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }),
|
|
||||||
.decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }),
|
|
||||||
.hex => writer.printFloatHexOptions(x, .{ .mode = .hex }),
|
|
||||||
};
|
};
|
||||||
|
return writer.printFloat(x, spec);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
pub fn formatInt(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
switch (@typeInfo(@TypeOf(value))) {
|
switch (@typeInfo(@TypeOf(value))) {
|
||||||
.int => {},
|
.int => {},
|
||||||
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
|
else => @compileError("formatInt expects an int, got: " ++ @typeName(@TypeOf(value))),
|
||||||
}
|
}
|
||||||
return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options);
|
return writer.printInt(value, spec.mode.base().?, spec.case, .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
pub fn formatComplex(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
|
try writer.writeAll(".{.re=");
|
||||||
|
try writer.printFloat(value.re, spec);
|
||||||
|
try writer.writeAll(", .im=");
|
||||||
|
try writer.printFloat(value.im, spec);
|
||||||
|
try writer.writeAll("}");
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatBool(value: bool, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
|
try writer.alignBufferOptions(if (value) "1" else "0", .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
var buf: [48]u8 = undefined;
|
var buf: [48]u8 = undefined;
|
||||||
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
|
const T = @TypeOf(value);
|
||||||
|
const fmt = if (@hasDecl(T, "formatNumber")) "{d}" else "{f}";
|
||||||
|
|
||||||
|
const s = std.fmt.bufPrint(&buf, fmt, .{value}) catch blk: {
|
||||||
buf[45..].* = "...".*;
|
buf[45..].* = "...".*;
|
||||||
break :blk buf[0..];
|
break :blk buf[0..];
|
||||||
};
|
};
|
||||||
return try writer.alignBufferOptions(s, full.options);
|
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
|
// use the format "width" for the number of columns instead of individual width.
|
||||||
// Write first rows
|
const num_cols: usize = spec.width orelse 12;
|
||||||
const num_cols: usize = full.options.width orelse 12;
|
var my_options = spec;
|
||||||
|
my_options.width = null;
|
||||||
const n: usize = values.len;
|
const n: usize = values.len;
|
||||||
|
|
||||||
_ = try writer.write("{");
|
_ = try writer.write("{");
|
||||||
if (n <= num_cols) {
|
if (n <= num_cols) {
|
||||||
for (values, 0..) |v, i| {
|
for (values, 0..) |v, i| {
|
||||||
// Force inlining so that the switch and the buffer can be done once.
|
// Force inlining so that the switch and the buffer can be done once.
|
||||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||||
if (i < n - 1) _ = try writer.write(",");
|
if (i < n - 1) _ = try writer.write(",");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const half = @divFloor(num_cols, 2);
|
const half = @divFloor(num_cols, 2);
|
||||||
for (values[0..half]) |v| {
|
for (values[0..half]) |v| {
|
||||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||||
_ = try writer.write(",");
|
_ = try writer.write(",");
|
||||||
}
|
}
|
||||||
_ = try writer.write(" ..., ");
|
_ = try writer.write(" ..., ");
|
||||||
for (values[n - half ..], 0..) |v, i| {
|
for (values[n - half ..], 0..) |v, i| {
|
||||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||||
if (i < half - 1) _ = try writer.write(",");
|
if (i < half - 1) _ = try writer.write(",");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = try writer.write("}");
|
_ = try writer.write("}");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
return try formatSliceCustom(formatAny, values, spec, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatFloatValue, values, full, writer);
|
return try formatSliceCustom(formatFloat, values, spec, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatIntValue, values, full, writer);
|
return try formatSliceCustom(formatInt, values, spec, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
return try formatSliceCustom(formatComplex, values, spec, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||||
|
return try formatSliceCustom(formatBool, values, spec, writer);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -31,6 +31,7 @@ zig_library(
|
|||||||
"aio/torch/py.zig",
|
"aio/torch/py.zig",
|
||||||
"buffer.zig",
|
"buffer.zig",
|
||||||
"context.zig",
|
"context.zig",
|
||||||
|
"callback.zig",
|
||||||
"dtype.zig",
|
"dtype.zig",
|
||||||
"exe.zig",
|
"exe.zig",
|
||||||
"floats.zig",
|
"floats.zig",
|
||||||
@ -53,7 +54,7 @@ zig_library(
|
|||||||
"torch.zig",
|
"torch.zig",
|
||||||
"zml.zig",
|
"zml.zig",
|
||||||
],
|
],
|
||||||
copts = ["-lc"],
|
copts = ["-lc", "-freference-trace=20"],
|
||||||
main = "zml.zig",
|
main = "zml.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|||||||
@ -49,7 +49,7 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
pub const FromOptions = struct {
|
pub const FromOptions = struct {
|
||||||
wait: bool = true,
|
wait: bool = true,
|
||||||
memory: ?pjrt.Memory.Kind = null,
|
memory: ?Memory = null,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||||
@ -89,15 +89,20 @@ pub const Buffer = struct {
|
|||||||
.byte_strides = byte_strides,
|
.byte_strides = byte_strides,
|
||||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||||
};
|
};
|
||||||
if (opts.memory) |memory_kind| {
|
if (platform.target == .cpu or opts.memory == null) {
|
||||||
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
|
||||||
const memory = for (memories) |m| {
|
|
||||||
const kind = m.kind(platform.pjrt_api);
|
|
||||||
if (kind == memory_kind) break m;
|
|
||||||
} else return error.NotFound;
|
|
||||||
args.memory = memory;
|
|
||||||
} else {
|
|
||||||
args.device = devices[i];
|
args.device = devices[i];
|
||||||
|
} else {
|
||||||
|
const memory = opts.memory.?;
|
||||||
|
const device_memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||||
|
// TODO measure the cost of this and consider caching on Zig side inside the platform.
|
||||||
|
const selected_memory = for (device_memories) |m| {
|
||||||
|
const kind = m.kind(platform.pjrt_api);
|
||||||
|
if (kind == memory.toPjrtMemory()) break m;
|
||||||
|
} else {
|
||||||
|
log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) });
|
||||||
|
return error.NotFound;
|
||||||
|
};
|
||||||
|
args.memory = selected_memory;
|
||||||
}
|
}
|
||||||
|
|
||||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||||
@ -179,10 +184,10 @@ pub const Buffer = struct {
|
|||||||
return try from(platform, host_buffer, opts);
|
return try from(platform, host_buffer, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
pub fn asHostBuffer(self: Buffer) HostBuffer {
|
||||||
// TODO restore assert
|
// TODO: skip this check on cpu
|
||||||
// const memory = self.getMemory().kind(self._api);
|
// const memory = self.getMemory().kind(self._api);
|
||||||
// stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
|
// stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
|
||||||
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
|
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
|
||||||
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
|
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
|
||||||
}
|
}
|
||||||
@ -299,6 +304,12 @@ pub const Buffer = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 {
|
||||||
|
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
|
||||||
|
const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
|
||||||
|
return @ptrCast(opaque_ptr);
|
||||||
|
}
|
||||||
|
|
||||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||||
pub fn getValue(self: Buffer, T: type) !T {
|
pub fn getValue(self: Buffer, T: type) !T {
|
||||||
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||||
@ -390,13 +401,31 @@ pub const Buffer = struct {
|
|||||||
return @reduce(.Or, self._shape._sharding_info);
|
return @reduce(.Or, self._shape._sharding_info);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
|
pub const CopyToMemoryOpts = struct {
|
||||||
|
wait: bool = true,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer {
|
||||||
|
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory());
|
||||||
|
if (pjrt_memory == null) {
|
||||||
|
stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self });
|
||||||
|
}
|
||||||
|
|
||||||
var new_shards: Buffer.Shards = .{};
|
var new_shards: Buffer.Shards = .{};
|
||||||
for (self._shards.slice()) |shard| {
|
for (self._shards.slice()) |shard| {
|
||||||
const new_shard = try shard.copyToMemory(self._api, memory);
|
const new_shard = try shard.copyToMemory(self._api, pjrt_memory.?);
|
||||||
new_shards.appendAssumeCapacity(new_shard);
|
new_shards.appendAssumeCapacity(new_shard);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (opts.wait) {
|
||||||
|
for (new_shards.constSlice()) |shard| {
|
||||||
|
const event = shard.getReadyEvent(self._api);
|
||||||
|
if (event) |e| {
|
||||||
|
try e.awaitBlocking(self._api);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
313
zml/callback.zig
Normal file
313
zml/callback.zig
Normal file
@ -0,0 +1,313 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
|
const mlir = @import("mlir");
|
||||||
|
const pjrt = @import("pjrt");
|
||||||
|
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
|
const CompilationContext = @import("module.zig").CompilationContext;
|
||||||
|
const DataType = @import("dtype.zig").DataType;
|
||||||
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const mlirx = @import("mlirx.zig");
|
||||||
|
const pjrtx = @import("pjrtx.zig");
|
||||||
|
const Platform = @import("platform.zig").Platform;
|
||||||
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/callback");
|
||||||
|
|
||||||
|
/// Inserts a user-defined callback into the computation graph.
/// The callback is described by a struct that stores the runtime information it needs.
///
/// ❗Experimental API❗
///
/// ```zig
/// pub const MyCallback = struct {
///     // a unique type_id will be set by the PJRT plugin during registration.
///     pub var type_id: pjrt.ffi.TypeId = undefined;
///
///     pub const callback_config: zml.callback.Config = .{
///         // assumptions this callback makes about the input / output buffers
///     };
///
///     // Required, this will tell the callback in which env it runs.
///     platform: zml.Platform,
///     // data needed by the callback
///     my_data: []const u8,
///
///     // storage modified by the runtime to tell the callback where it should write its results.
///     // Normally the callback doesn't need to allocate as the input and output buffers are given.
///     results: [1]Buffer = undefined,
///
///     pub fn init(platform: zml.Platform, my_data: []const u8) !MyCallback {
///         return .{ .platform = platform, .my_data = my_data };
///     }
///
///     pub fn call(callback: *MyCallback, input: Buffer) !void {
///         // Do something with `input` and `callback.my_data`, write the results inside `callback.results[0]`
///     }
/// };
/// ```
///
/// See e.g. the implementation of the `zml.callback.Print` callback for a practical example.
///
/// Note that calling this during the compilation of a module isn't enough:
///
/// * the backend needs to be made aware of the callback, see `zml.Platform.registerCallback`
/// * the executable needs to know the specific data needed by `MyCallback`, see `zml.Exe.bind`
///
/// Panics if the current platform exposes no FFI extension, or on allocation failure.
/// The returned slice is allocated with the compilation context allocator and is not freed here.
pub fn call(
    comptime Callback: type,
    inputs: TensorArgs(Callback),
    output_shapes: []const Shape,
) []Tensor {
    checkIsValidCallback(Callback);

    const ctx = CompilationContext.current();
    const allocator = ctx.allocator();
    const mlir_ctx = ctx.mlirCtx();
    const platform = ctx._platform;
    const pjrt_api = platform.pjrt_api;

    if (pjrt_api.ffi() == null) {
        stdx.debug.panic("Custom calls are not supported for target {s}", .{@tagName(platform.target)});
    }

    // `output_tensors` is handed back to the caller, so it is never freed here.
    // This function has no error return, hence no `errdefer` is needed either.
    // It is allocated first so that it doesn't fragment the arena.
    const output_tensors = allocator.alloc(Tensor, output_shapes.len) catch @panic("OOM");

    const output_types = allocator.alloc(mlir.Type, output_shapes.len) catch @panic("OOM");
    defer allocator.free(output_types);
    for (output_types, output_shapes) |*output_type, output_shape| {
        output_type.* = mlirx.tensorType(mlir_ctx, output_shape);
    }

    const input_values = allocator.alloc(mlir.Value, inputs.len) catch @panic("OOM");
    defer allocator.free(input_values);
    for (input_values, inputs) |*input_value, input_tensor| {
        input_value.* = input_tensor.value();
    }

    // The same target name is used by `register`, this is how the runtime finds the handler.
    const target_name = "zml$" ++ @typeName(Callback);
    const op = stablehlo.custom_call(
        mlir_ctx,
        input_values,
        .{
            .call_target_name = target_name,
            .api_version = .typed_ffi,
            .backend_config = .dict(mlir_ctx, &.{}),
            .additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
            .has_side_effect = true,
            .output_operand_aliases = Callback.callback_config.output_operand_aliases,
        },
        output_types,
        mlir_ctx.location(@src()),
    );

    for (output_tensors, output_shapes, 0..) |*output_tensor, output_shape, i| {
        output_tensor.* = Tensor._result(output_shape, op.result(i));
    }
    return output_tensors;
}
|
||||||
|
|
||||||
|
/// Properties of a callback, declared by the callback type as `callback_config`.
///
/// * output_operand_aliases: forwarded to the custom call op, for callbacks that
///   reuse an input buffer to write their output
/// * copy_inputs_to_host_pinned: the callback needs to work on host visible buffers,
///   inputs are copied to `host_pinned` memory before the call (skipped on cpu)
/// * traits: PJRT specified properties of the callback, passed at registration
pub const Config = struct {
    output_operand_aliases: []const i64 = &.{},
    copy_inputs_to_host_pinned: bool = false,
    // TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
    traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
    // TODO: handle sharded inputs
};
|
||||||
|
|
||||||
|
/// Compile-time check that a callback type provides everything this module relies on:
///
/// * a `call` method whose arguments after the first are all `zml.Buffer`
/// * a `pub var type_id: pjrt.ffi.TypeId` declaration (written by `register`)
/// * a `pub const callback_config: zml.callback.Config` declaration
/// * a `platform` field and a `results` field (read / written by `CallbackImpl`)
pub fn checkIsValidCallback(Callback: type) void {
    stdx.debug.assertComptime(@hasDecl(Callback, "call"), "Expected callback {} to have a call method", .{Callback});
    const ArgsT = stdx.meta.FnArgs(Callback.call);
    // The first argument is the callback instance itself, all others are input buffers.
    inline for (@typeInfo(ArgsT).@"struct".fields[1..]) |field| {
        stdx.debug.assertComptime(field.type == Buffer, "Expected callback {}.call arguments to be of type zml.Buffer, got {}", .{ Callback, field.type });
    }

    stdx.debug.assertComptime(@hasDecl(Callback, "type_id") and @TypeOf(Callback.type_id) == pjrt.ffi.TypeId, "Expected callback {} to have a declaration `pub var type_id: pjrt.ffi.TypeId`", .{Callback});
    stdx.debug.assertComptime(@hasDecl(Callback, "callback_config") and @TypeOf(Callback.callback_config) == Config, "Expected callback {} to have a declaration `pub const callback_config: zml.callback.Config`", .{Callback});

    // `CallbackImpl` unconditionally accesses these two fields, report their absence here
    // with an actionable message rather than from inside the proxy.
    stdx.debug.assertComptime(@hasField(Callback, "platform"), "Expected callback {} to have a field `platform: zml.Platform`", .{Callback});
    stdx.debug.assertComptime(@hasField(Callback, "results"), "Expected callback {} to have a field `results: [n]zml.Buffer`", .{Callback});
}
|
||||||
|
|
||||||
|
/// Registers `Callback` with the PJRT plugin of `platform`.
///
/// Stores the type id returned by the plugin into `Callback.type_id`, then registers
/// the proxy handler under the target name `"zml$" ++ @typeName(Callback)`.
/// Returns `error.Unavailable` when the plugin exposes no FFI extension.
pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
    checkIsValidCallback(Callback);

    const api = platform.pjrt_api;
    const ffi = api.ffi() orelse return error.Unavailable;

    const target_name = "zml$" ++ @typeName(Callback);
    const handler = proxy(Callback);

    Callback.type_id = try ffi.registerTypeId(api, @typeName(Callback));
    try ffi.register(api, target_name, @tagName(platform.target), &handler, Callback.callback_config.traits);

    log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
}
|
||||||
|
|
||||||
|
/// Builds the C-ABI handler for `Callback`: a thin trampoline forwarding the call frame
/// to `CallbackImpl` specialized for that callback type.
fn proxy(Callback: type) pjrt.ffi.Handler {
    const Trampoline = struct {
        pub fn cb(frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
            return CallbackImpl(Callback, frame);
        }
    };
    return Trampoline.cb;
}
|
||||||
|
|
||||||
|
/// Generic body of the FFI handler generated by `proxy` for `Callback`.
///
/// Steps:
/// 1. fetch the `*Callback` user data bound to the current execution context,
/// 2. wrap every FFI input buffer into a `zml.Buffer` (optionally copied to host pinned memory),
/// 3. point `user_ctx.results` at the FFI output buffers,
/// 4. invoke `Callback.call`.
///
/// Every failure is reported to the runtime as an FFI error instead of unwinding through C code.
fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt.ffi.Error {
    if (call_frame.registeringHook()) return null;

    const opts = Callback.callback_config;

    const execution_context = call_frame.ctx;
    log.debug("Custom call {s} called !", .{@typeName(Callback)});
    const user_ctx_opaque = execution_context.getContext(Callback.type_id, call_frame.api) catch {
        log.err("{} user data was never given for current executable", .{Callback});
        return .create(call_frame.api, .failed_precondition, "failed to fetch user context for " ++ @typeName(Callback));
    };
    const user_ctx: *Callback = @ptrCast(@alignCast(user_ctx_opaque));
    // We actually have one more constraint here, we force the Callback to have a platform field,
    // and to correctly set it.
    // Is this good ? We could also simplify this by registering ourselves the `Platform` type id.
    const platform: Platform = user_ctx.platform;

    // Hook to get a cuda stream in the callback.
    if (@hasField(Callback, "stream") and platform.target != .cpu) {
        const stream = call_frame.api.stream(execution_context);
        user_ctx.stream = stream;
    }

    var callback_args: std.meta.ArgsTuple(@TypeOf(Callback.call)) = undefined;
    callback_args[0] = user_ctx;

    // Validate arity up-front: a mismatch would otherwise trip a loop-length safety check
    // or index `user_ctx.results` out of bounds while running inside a C callback.
    if (call_frame.args.buffers().len != callback_args.len - 1) {
        log.err("Callback {} expects {d} input buffers, got {d}", .{ Callback, callback_args.len - 1, call_frame.args.buffers().len });
        return .create(call_frame.api, .failed_precondition, "unexpected number of input buffers");
    }
    if (call_frame.results.len > user_ctx.results.len) {
        log.err("Callback {} has room for {d} results, got {d}", .{ Callback, user_ctx.results.len, call_frame.results.len });
        return .create(call_frame.api, .failed_precondition, "unexpected number of output buffers");
    }

    inline for (1..callback_args.len, call_frame.args.buffers()) |i, ffi_buffer| {
        const shape = shapeFromFfi(ffi_buffer);
        var zml_buffer: Buffer = if (platform.target == .cpu)
            .asViewOfHostBuffer(platform, .fromBytes(shape, ffi_buffer.data[0..shape.byteSize()]))
        else
            .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
            zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
                log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err });
                // The cleanup `defer` below isn't registered yet:
                // release the host copies already made for the previous arguments.
                inline for (1..i) |j| callback_args[j].deinit();
                return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
            };
            // Only log the address: the element type and size of the buffer are arbitrary,
            // so its content must not be reinterpreted here.
            log.debug("--> {f} {*}", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
        }
        callback_args[i] = zml_buffer;
    }

    // Host pinned copies are owned by this function, free them once the callback returned.
    defer {
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            inline for (1..callback_args.len) |i| callback_args[i].deinit();
        }
    }

    // Tell the callback where to write its results.
    for (0..call_frame.results.len) |i| {
        const ffi_buffer = call_frame.results.buffers()[i];
        const ffi_buffer_shape = shapeFromFfi(ffi_buffer);

        if (platform.target == .cpu) {
            user_ctx.results[i] = Buffer.asViewOfHostBuffer(platform, HostBuffer.fromBytes(ffi_buffer_shape, ffi_buffer.data[0..ffi_buffer_shape.byteSize()]));
        } else {
            user_ctx.results[i] = Buffer.asViewOfDeviceBuffer(platform, ffi_buffer_shape, null, ffi_buffer.data);
        }
    }

    @call(.auto, Callback.call, callback_args) catch |err| {
        log.err("Callback {} failed with {}", .{ Callback, err });
        return .create(call_frame.api, .internal, "internal callback error");
    };

    return .ok;
}
|
||||||
|
|
||||||
|
/// Callbacks shipped with the library itself.
/// They are registered by `registerInternalCallbacks` and are not meant to be
/// registered by users.
pub const internal_callbacks = [_]type{Print};
|
||||||
|
|
||||||
|
/// Registers every callback listed in `internal_callbacks` on the given platform.
///
/// A plugin without FFI extension is not an error: the platform stays usable for
/// everything but callbacks (`call` panics at compile time of the module in that case),
/// so we only log a warning and skip the registration.
pub fn registerInternalCallbacks(platform: Platform) !void {
    if (platform.pjrt_api.ffi() == null) {
        log.warn("Registering internal callbacks failed: No FFI Extension found in {s} PJRT Plugin.", .{@tagName(platform.target)});
        return;
    }

    inline for (internal_callbacks) |Callback| {
        try register(Callback, platform);
    }
}
|
||||||
|
|
||||||
|
/// Allocates the user data needed by the ZML provided callbacks and binds it
/// to `execute_context`.
///
/// The user data is created with `arena` and is not freed by this function.
pub fn bindInternalCallbacks(
    arena: std.mem.Allocator,
    platform: Platform,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
) (std.mem.Allocator.Error || pjrt.ApiError)!void {
    // Atm we don't have a mechanism to detect which ZML callbacks the executable needs,
    // so we always allocate.

    // Print
    const print_data = try arena.create(Print);
    print_data.* = try Print.init(platform);
    try addUserData(Print, platform.pjrt_api, ffi, execute_context, print_data);
}
|
||||||
|
|
||||||
|
/// Binds `user_data` to `execute_context` under the type id of `Callback`,
/// which is the key `CallbackImpl` uses to fetch it back during execution.
pub fn addUserData(
    Callback: type,
    api: *const pjrt.Api,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
    user_data: *Callback,
) pjrt.ApiError!void {
    const type_id = Callback.type_id.type_id;
    try ffi.addUserData(api, execute_context, .{
        .type_id = type_id,
        .user_data = @ptrCast(user_data),
    });
    log.debug("Bound {s}@{x} with type id {d} on {any}", .{ @typeName(Callback), @intFromPtr(user_data), type_id, execute_context });
}
|
||||||
|
|
||||||
|
/// The print callback: logs the content of its single input buffer.
pub const Print = struct {
    // a unique type_id will be set by the PJRT plugin during registration.
    pub var type_id: pjrt.ffi.TypeId = undefined;

    pub const callback_config: Config = .{
        // Print callback pretends to modify the given input buffer, but just returns it unmodified.
        .output_operand_aliases = &.{0},
        // The input is copied to host pinned memory before the call, so that it can be read from the host.
        .copy_inputs_to_host_pinned = true,
        // NOTE(review): an earlier comment stated that Print is predictable enough to be
        // captured in an execution graph, yet the trait is disabled. Confirm which is intended.
        .traits = .{ .command_buffer_compatible = false },
    };

    platform: Platform,
    results: [1]Buffer = undefined,

    pub fn init(platform: Platform) !Print {
        return Print{ .platform = platform };
    }

    /// Logs the buffer description followed by its values, at info level under the `.zml` scope.
    pub fn call(_: *Print, input: Buffer) !void {
        const host_view = input.asHostBuffer();
        std.log.defaultLog(.info, .zml, "Device buffer: {f}: {d:20.3}", .{ input, host_view });
    }
};
|
||||||
|
|
||||||
|
/// Converts the dtype and dimensions of an FFI buffer description into a `Shape`.
/// Panics on the `invalid` dtype and on dtypes listed as unsupported below;
/// every other FFI dtype is mapped to the `DataType` tag of the same name.
fn shapeFromFfi(ffi_buffer: *const pjrt.ffi.Buffer) Shape {
    const ffi_dtype = ffi_buffer.dtype;
    const dtype: DataType = switch (ffi_dtype) {
        .invalid => stdx.debug.panic("Invalid FFI dtype {any} used by {any}", .{ ffi_dtype, ffi_buffer }),
        .token, .f8e4m3, .f8e3m4 => stdx.debug.panic("Unsupported FFI dtype {any} used by {any}", .{ ffi_dtype, ffi_buffer }),
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return .init(ffi_buffer.dims(), dtype);
}
|
||||||
|
|
||||||
|
/// Type of the `inputs` argument of `call` for a given callback:
/// one `Tensor` per argument of `Callback.call`, not counting the first one
/// (the callback instance).
fn TensorArgs(Callback: type) type {
    const call_params = @typeInfo(stdx.meta.FnArgs(Callback.call)).@"struct".fields;
    return [call_params.len - 1]Tensor;
}
|
||||||
172
zml/context.zig
172
zml/context.zig
@ -7,16 +7,19 @@ const runfiles = @import("runfiles");
|
|||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const DataType = @import("dtype.zig").DataType;
|
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
|
||||||
const Shape = @import("shape.zig").Shape;
|
|
||||||
const Target = @import("platform.zig").Target;
|
|
||||||
const zml_platform = @import("platform.zig");
|
|
||||||
|
|
||||||
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
const zml = struct {
|
||||||
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
pub const callback = @import("callback.zig");
|
||||||
|
pub const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
pub const Platform = @import("platform.zig").Platform;
|
||||||
|
pub const platform = @import("platform.zig");
|
||||||
|
pub const Shape = @import("shape.zig").Shape;
|
||||||
|
pub const Target = @import("platform.zig").Target;
|
||||||
|
};
|
||||||
|
|
||||||
|
const PjrtApiMap = std.EnumArray(zml.Target, ?*const pjrt.Api);
|
||||||
|
const PlatformsMap = std.EnumArray(zml.Target, ?zml.Platform);
|
||||||
const log = std.log.scoped(.@"zml/context");
|
const log = std.log.scoped(.@"zml/context");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -94,7 +97,7 @@ pub const Context = struct {
|
|||||||
return .{ .platforms = PlatformsMap.initFill(null) };
|
return .{ .platforms = PlatformsMap.initFill(null) };
|
||||||
}
|
}
|
||||||
|
|
||||||
fn platformToLibrary(comptime target: Target) []const u8 {
|
fn platformToLibrary(comptime target: zml.Target) []const u8 {
|
||||||
const ext = switch (builtin.os.tag) {
|
const ext = switch (builtin.os.tag) {
|
||||||
.windows => ".dll",
|
.windows => ".dll",
|
||||||
.macos, .ios, .watchos => ".dylib",
|
.macos, .ios, .watchos => ".dylib",
|
||||||
@ -105,7 +108,7 @@ pub const Context = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pjrtApi(target: Target) *const pjrt.Api {
|
pub fn pjrtApi(target: zml.Target) *const pjrt.Api {
|
||||||
return Context.apis.get(target).?;
|
return Context.apis.get(target).?;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -119,12 +122,12 @@ pub const Context = struct {
|
|||||||
self.* = undefined;
|
self.* = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
|
const prefered_targets = [_]zml.Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
|
||||||
|
|
||||||
/// Automatically selects the best Platform loaded in the current Context.
|
/// Automatically selects the best Platform loaded in the current Context.
|
||||||
///
|
///
|
||||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||||
pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform {
|
pub fn autoPlatform(self: *Context, opts: zml.Platform.CreateOptions) zml.Platform {
|
||||||
stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
|
stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
|
||||||
|
|
||||||
return self.platformByPreferences(opts, &prefered_targets);
|
return self.platformByPreferences(opts, &prefered_targets);
|
||||||
@ -133,7 +136,7 @@ pub const Context = struct {
|
|||||||
/// Given a list of preferred targets to select the best Platform
|
/// Given a list of preferred targets to select the best Platform
|
||||||
///
|
///
|
||||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||||
pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform {
|
pub fn platformByPreferences(self: *Context, opts: zml.Platform.CreateOptions, prefered: []const zml.Target) zml.Platform {
|
||||||
// Try prefered targets.
|
// Try prefered targets.
|
||||||
for (prefered) |target| {
|
for (prefered) |target| {
|
||||||
if (apis.get(target) == null) continue;
|
if (apis.get(target) == null) continue;
|
||||||
@ -150,7 +153,7 @@ pub const Context = struct {
|
|||||||
// CPU should only be use as fallback.
|
// CPU should only be use as fallback.
|
||||||
if (target == .cpu) continue;
|
if (target == .cpu) continue;
|
||||||
if (entry.value.* == null) continue;
|
if (entry.value.* == null) continue;
|
||||||
if (std.mem.indexOfScalar(Target, prefered, target) != null) continue;
|
if (std.mem.indexOfScalar(zml.Target, prefered, target) != null) continue;
|
||||||
return self.platform(target, opts) catch |err| {
|
return self.platform(target, opts) catch |err| {
|
||||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||||
continue;
|
continue;
|
||||||
@ -164,25 +167,25 @@ pub const Context = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform {
|
pub fn platform(self: *Context, target: zml.Target, opts: zml.Platform.CreateOptions) !zml.Platform {
|
||||||
if (self.platforms.get(target)) |p| {
|
if (self.platforms.get(target)) |p| {
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
const api = Context.apis.get(target);
|
const api = Context.apis.get(target);
|
||||||
if (api == null) return error.PlatformNotCompiled;
|
if (api == null) return error.PlatformNotCompiled;
|
||||||
const p = try Platform.init(target, api.?, opts);
|
const p = try zml.Platform.init(target, api.?, opts);
|
||||||
if (p.getDevices().len == 0) {
|
if (p.getDevices().len == 0) {
|
||||||
log.err("No device found for platform {} !", .{target});
|
log.err("No device found for platform {} !", .{target});
|
||||||
return error.NoDevicesFound;
|
return error.NoDevicesFound;
|
||||||
}
|
}
|
||||||
|
|
||||||
try CustomCall.registerZmlCustomCalls(p);
|
|
||||||
|
|
||||||
self.platforms.set(target, p);
|
self.platforms.set(target, p);
|
||||||
|
try zml.callback.registerInternalCallbacks(p);
|
||||||
|
|
||||||
return p;
|
return p;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {
|
pub fn printAvailablePlatforms(self: Context, selected: zml.Platform) void {
|
||||||
// List available targets
|
// List available targets
|
||||||
log.info("Available Platforms:", .{});
|
log.info("Available Platforms:", .{});
|
||||||
const selected_prefix = "✅";
|
const selected_prefix = "✅";
|
||||||
@ -190,7 +193,7 @@ pub const Context = struct {
|
|||||||
const selected_postfix = "(AUTO-SELECTED)";
|
const selected_postfix = "(AUTO-SELECTED)";
|
||||||
const not_selected_postfix = "";
|
const not_selected_postfix = "";
|
||||||
|
|
||||||
for (zml_platform.available_targets) |target| {
|
for (zml.platform.available_targets) |target| {
|
||||||
log.info(" {s} {s} {s}", .{
|
log.info(" {s} {s} {s}", .{
|
||||||
if (target == selected.target) selected_prefix else not_selected_prefix,
|
if (target == selected.target) selected_prefix else not_selected_prefix,
|
||||||
@tagName(target),
|
@tagName(target),
|
||||||
@ -211,133 +214,4 @@ pub const Context = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void;
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Namespace holding the FFI custom call that hands buffers to a Zig host callback.
const CustomCall = struct {
    /// Registers `hostBufferCallback` under the name "zmlHostBufferCallback" for the platform's target.
    /// When the PJRT plugin exposes no FFI extension, a warning is logged and nothing is registered.
    pub fn registerZmlCustomCalls(platform: Platform) !void {
        if (platform.pjrt_api.ffi()) |ffi| {
            try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
        } else {
            log.warn("Registering custom calls failed: No FFI Extension found in {s} PJRT Plugin.", .{@tagName(platform.target)});
        }
    }

    /// FFI entry point. Reads the "callback" and "user_context" scalar attributes as addresses,
    /// wraps every argument and result buffer in a `HostBuffer`, then invokes the callback.
    /// Always returns null.
    fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
        if (call_frame.registeringHook()) return null;

        // Address of the Zig callback, carried as a u64 scalar attribute.
        const fn_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
        std.debug.assert(fn_attr.dtype == .u64);
        const host_fn: *const Context.HostCallback = @ptrFromInt(fn_attr.get(usize));

        // Opaque user pointer, carried the same way.
        const ctx_attr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable;
        std.debug.assert(ctx_attr.dtype == .u64);
        const user_ctx: ?*anyopaque = @ptrFromInt(ctx_attr.get(usize));

        const inputs = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
        for (inputs, 0..) |*dst, idx| {
            dst.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[idx]);
        }

        const outputs = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
        for (outputs, 0..) |*dst, idx| {
            dst.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[idx]);
        }

        host_fn(user_ctx, inputs, outputs);
        return null;
    }
};
|
|
||||||
|
|
||||||
/// Builds the `Shape` (dimensions and dtype) described by an FFI buffer.
/// Panics on the `.invalid` dtype and on `.token`, `.f8e4m3` and `.f8e3m4`.
fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
    const dtype: DataType = switch (buffer_desc.dtype) {
        .invalid => @panic("invalid ffi"),
        .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
        // The FFI calls booleans `pred`.
        .pred => .bool,
        // Every remaining tag (including the integer ones) maps to the
        // `DataType` tag with the same name.
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return Shape.init(buffer_desc.dims(), dtype);
}
|
|
||||||
|
|
||||||
/// Creates a HostBuffer from an FFI description of a buffer.
/// Normally the FFI describes device buffers, but we assume they are located in pinned memory,
/// and therefore that the data pointer is readable both from the host and from the device.
fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
    const shape = getShape(buffer_desc);
    // View exactly `byteSize()` bytes starting at the FFI data pointer; nothing is copied here.
    const bytes = buffer_desc.data[0..shape.byteSize()];
    return HostBuffer.fromBytes(shape, bytes);
}
|
|
||||||
|
|
||||||
/// Runtime bindings to a handful of CUDA runtime (`libcudart`) entry points.
/// Every function pointer holds a poison address until `init` resolves the real symbol.
pub const cuda = struct {
    /// Poison address stored in each binding before `init` runs.
    const unresolved_addr = 0xdeadc00da00;

    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(unresolved_addr);
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(unresolved_addr);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(unresolved_addr);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(unresolved_addr);

    /// Copy direction passed as the `kind` argument of the memcpy bindings.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Opens "libcudart.so.12" and resolves the four bindings.
    /// If the library cannot be opened, an error is logged and the bindings keep their poison address.
    /// Panics if the library opens but one of the symbols is missing.
    pub fn init() void {
        var cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };
        // The handle is intentionally never closed: the function pointers resolved below
        // are stored in globals and used after `init` returns, and closing the library
        // could unload it and leave them dangling.

        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
            @panic("cudaMemcpyAsync not found");
        };
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
            @panic("cudaMemcpy not found");
        };
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
            @panic("cudaStreamSynchronize not found");
        };
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
            @panic("cudaLaunchHostFunc not found");
        };
    }

    /// Copies `dst.len` bytes from `src` into `dst` with kind `.device_to_host`; panics on a non-zero status.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
        check(err);
    }

    /// Copies `src.len` bytes from `src` into `dst` with kind `.host_to_device`; panics on a non-zero status.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
        check(err);
    }

    /// Same as `memcpyToDeviceBlocking` but goes through the async binding with the given `stream`.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
        check(err);
    }

    /// Same as `memcpyToHostBlocking` but goes through the async binding with the given `stream`.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
        check(err);
    }

    /// Panics with the numeric status when `err` is not zero.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};
|
};
|
||||||
|
|||||||
49
zml/exe.zig
49
zml/exe.zig
@ -5,6 +5,7 @@ const stdx = @import("stdx");
|
|||||||
const aio = @import("aio.zig");
|
const aio = @import("aio.zig");
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const Bufferized = @import("tensor.zig").Bufferized;
|
const Bufferized = @import("tensor.zig").Bufferized;
|
||||||
|
const callback = @import("callback.zig");
|
||||||
const CompilationContext = @import("module.zig").CompilationContext;
|
const CompilationContext = @import("module.zig").CompilationContext;
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const pjrt = @import("pjrtx.zig");
|
const pjrt = @import("pjrtx.zig");
|
||||||
@ -154,7 +155,7 @@ pub const BaseExe = struct {
|
|||||||
exe: *pjrt.LoadedExecutable,
|
exe: *pjrt.LoadedExecutable,
|
||||||
|
|
||||||
/// The execution context for this executable.
|
/// The execution context for this executable.
|
||||||
context: ?*pjrt.ExecuteContext = null,
|
execute_context: ?*pjrt.ExecuteContext,
|
||||||
|
|
||||||
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
||||||
input_per_device: []const [*]*pjrt.Buffer,
|
input_per_device: []const [*]*pjrt.Buffer,
|
||||||
@ -205,9 +206,18 @@ pub const BaseExe = struct {
|
|||||||
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
|
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
|
||||||
@memcpy(all_shapes[0..n_in], args.input_shapes);
|
@memcpy(all_shapes[0..n_in], args.input_shapes);
|
||||||
@memcpy(all_shapes[n_in..], args.result_shapes);
|
@memcpy(all_shapes[n_in..], args.result_shapes);
|
||||||
|
|
||||||
|
// An execute context is only created when the PJRT plugin exposes the FFI extension.
var execute_context: ?*pjrt.ExecuteContext = null;
if (platform.pjrt_api.ffi()) |ffi| {
    const ctx = try platform.pjrt_api.createExecuteContext();
    // Release the freshly created context if binding the internal callbacks fails.
    errdefer ctx.deinit(platform.pjrt_api);
    // Log after creation so the message shows the real pointer instead of null.
    log.info("Created context execution {*} for {*}", .{ ctx, exe });
    try callback.bindInternalCallbacks(allocator, platform, ffi, ctx);
    execute_context = ctx;
}
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
.platform = platform,
|
.platform = platform,
|
||||||
.exe = exe,
|
.exe = exe,
|
||||||
|
.execute_context = execute_context,
|
||||||
.ready_buffer_count = 0,
|
.ready_buffer_count = 0,
|
||||||
.input_buffer_count = @intCast(n_in),
|
.input_buffer_count = @intCast(n_in),
|
||||||
.num_devices = args.n_devices,
|
.num_devices = args.n_devices,
|
||||||
@ -220,7 +230,7 @@ pub const BaseExe = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: BaseExe) void {
|
pub fn deinit(self: BaseExe) void {
|
||||||
if (self.context) |ctx| {
|
if (self.execute_context) |ctx| {
|
||||||
ctx.deinit(self.platform.pjrt_api);
|
ctx.deinit(self.platform.pjrt_api);
|
||||||
}
|
}
|
||||||
self._arena.deinit();
|
self._arena.deinit();
|
||||||
@ -244,16 +254,16 @@ pub const BaseExe = struct {
|
|||||||
// even if it has been marked as "can be donated" during compilation.
|
// even if it has been marked as "can be donated" during compilation.
|
||||||
// TODO: expose it ?
|
// TODO: expose it ?
|
||||||
.non_donatable_input_indices = &.{},
|
.non_donatable_input_indices = &.{},
|
||||||
.context = self.context,
|
.context = self.execute_context,
|
||||||
}) catch |err| {
|
}) catch |err| {
|
||||||
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
||||||
};
|
};
|
||||||
|
|
||||||
for (events[0..sharding.num_partitions]) |e| {
|
// for (events[0..sharding.num_partitions]) |e| {
|
||||||
if (e) |ev| {
|
// if (e) |ev| {
|
||||||
ev.await_(self.platform.pjrt_api) catch unreachable;
|
// ev.await_(self.platform.pjrt_api) catch unreachable;
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
|
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
|
||||||
@ -285,6 +295,17 @@ pub const BaseExe = struct {
|
|||||||
stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index });
|
stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attaches `op` as user data of type `Callback` to this executable's execute context.
/// Asserts that an execute context exists, and panics when the platform has no FFI extension.
pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
    stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});

    const api = exe.platform.pjrt_api;
    if (api.ffi()) |ffi| {
        return callback.addUserData(Callback, api, ffi, exe.execute_context.?, op);
    }
    stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
}
|
||||||
|
|
||||||
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
||||||
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
||||||
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||||
@ -314,11 +335,11 @@ pub const BaseExe = struct {
|
|||||||
|
|
||||||
/// Builds a new `BaseExe` through `init` with the same platform, executable, shapes and
/// device count, then points its `execute_context` at this executable's context.
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
    var copy = try BaseExe.init(parent_allocator, self.platform, self.exe, .{
        .input_shapes = self.input_shapes,
        .result_shapes = self.result_shapes,
        .n_devices = self.num_devices,
    });
    // NOTE(review): the context pointer is shared, not duplicated. Confirm who is
    // responsible for calling `deinit` on it, and what happens to any context that
    // `init` may have created for `copy` before this assignment.
    copy.execute_context = self.execute_context;
    return copy;
}
|
||||||
};
|
};
|
||||||
@ -348,6 +369,14 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
|||||||
return new;
|
return new;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Provides a pointer to runtime data for a given custom call inside this executable.
/// The caller keeps ownership of the memory and must ensure that `value`
/// stays alive as long as the executable does.
pub fn bind(self: Self, comptime T: type, value: *T) !void {
    return self.inner.bind(T, value);
}
|
||||||
|
|
||||||
/// Serializes the underlying executable into `writer` by delegating to the inner `BaseExe`.
pub fn serialize(self: Self, writer: anytype) !void {
    try self.inner.serialize(writer);
}
|
||||||
|
|||||||
@ -325,37 +325,18 @@ pub const HostBuffer = struct {
|
|||||||
self: HostBuffer,
|
self: HostBuffer,
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
// TODO debug option
|
|
||||||
// try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
|
||||||
try writer.print("HostBuffer(.{f})", .{self._shape});
|
try writer.print("HostBuffer(.{f})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Number formatter for a HostBuffer: prints the values, not just the shape.
/// Forwards to `prettyPrintIndented` with 4 rows and no indentation.
pub fn formatNumber(self: HostBuffer, writer: *std.Io.Writer, n: std.fmt.Number) std.Io.Writer.Error!void {
    return self.prettyPrintIndented(writer, 4, 0, n);
}

/// Prints the buffer's values to `writer` using the number options `options`.
/// Forwards to `prettyPrintIndented` with 4 rows and no indentation.
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: std.fmt.Number) !void {
    return self.prettyPrintIndented(writer, 4, 0, options);
}
|
||||||
|
|
||||||
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
|
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void {
|
||||||
if (self.rank() == 0) {
|
if (self.rank() == 0) {
|
||||||
// Special case input tensor is a scalar
|
// Special case input tensor is a scalar
|
||||||
return switch (self.dtype()) {
|
return switch (self.dtype()) {
|
||||||
@ -363,9 +344,10 @@ pub const HostBuffer = struct {
|
|||||||
const val: dt.toZigType() = self.items(dt.toZigType())[0];
|
const val: dt.toZigType() = self.items(dt.toZigType())[0];
|
||||||
return switch (comptime dt.class()) {
|
return switch (comptime dt.class()) {
|
||||||
// Since we have custom floats, we need to explicitly convert to float32 ourselves.
|
// Since we have custom floats, we need to explicitly convert to float32 ourselves.
|
||||||
.float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
|
.float => stdx.fmt.formatFloat(floats.floatCast(f32, val), options, writer),
|
||||||
.integer => stdx.fmt.formatIntValue(val, options, writer),
|
.integer => stdx.fmt.formatInt(val, options, writer),
|
||||||
.bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
|
.bool => stdx.fmt.formatBool(val, options, writer),
|
||||||
|
.complex => stdx.fmt.formatComplex(val, options, writer),
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -380,7 +362,8 @@ pub const HostBuffer = struct {
|
|||||||
switch (comptime dt.class()) {
|
switch (comptime dt.class()) {
|
||||||
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
|
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
|
||||||
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
|
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
|
||||||
.bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
|
.complex => try stdx.fmt.formatComplexSlice(values, options, writer),
|
||||||
|
.bool => try stdx.fmt.formatBoolSlice(values, options, writer),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1178,9 +1178,9 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
|
|||||||
.@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
|
.@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
|
||||||
.pointer => |info| switch (info.size) {
|
.pointer => |info| switch (info.size) {
|
||||||
.one => switch (strat) {
|
.one => switch (strat) {
|
||||||
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
.deep => hash(hasher, key.*, .Shallow),
|
.Deep => hash(hasher, key.*, .Shallow),
|
||||||
.deeprecursive => switch (@typeInfo(info.child)) {
|
.DeepRecursive => switch (@typeInfo(info.child)) {
|
||||||
.@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
|
.@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
else => hash(hasher, key.*, .DeepRecursive),
|
else => hash(hasher, key.*, .DeepRecursive),
|
||||||
},
|
},
|
||||||
@ -1196,7 +1196,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
|
|||||||
.many,
|
.many,
|
||||||
.c,
|
.c,
|
||||||
=> switch (strat) {
|
=> switch (strat) {
|
||||||
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||||
else => @compileError(
|
else => @compileError(
|
||||||
\\ unknown-length pointers and C pointers cannot be hashed deeply.
|
\\ unknown-length pointers and C pointers cannot be hashed deeply.
|
||||||
\\ Consider providing your own hash function.
|
\\ Consider providing your own hash function.
|
||||||
|
|||||||
27
zml/ops.zig
27
zml/ops.zig
@ -764,33 +764,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Options accepted by `addHostCallback`; both fields are forwarded to `customCall` under the same names.
pub const HostCallbackOpt = struct {
|
|
||||||
// Defaults to false.
has_side_effect: bool = false,
|
|
||||||
// Defaults to an empty list.
output_operand_aliases: []const i64 = &.{},
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Emits a "zmlHostBufferCallback" custom call over `inputs`.
/// The callback and `blkctx` are passed as integer addresses in the `callback` and
/// `user_context` attributes. Returns whatever `customCall` returns for `output_shapes`.
pub fn addHostCallback(
    callback: *const Context.HostCallback,
    blkctx: ?*anyopaque,
    inputs: []const Tensor,
    output_shapes: []const Shape,
    opts: HostCallbackOpt,
) []Tensor {
    // Pointers travel through the custom call attributes as plain integers.
    const callback_addr = @intFromPtr(callback);
    const user_context_addr = @intFromPtr(blkctx);

    return customCall(
        "zmlHostBufferCallback",
        inputs,
        output_shapes,
        .{
            .callback = callback_addr,
            .user_context = user_context_addr,
        },
        .{
            .has_side_effect = opts.has_side_effect,
            .output_operand_aliases = opts.output_operand_aliases,
        },
    );
}
|
|
||||||
|
|
||||||
pub const TritonOps = struct {
|
pub const TritonOps = struct {
|
||||||
debug: bool = false,
|
debug: bool = false,
|
||||||
name: [:0]const u8,
|
name: [:0]const u8,
|
||||||
|
|||||||
@ -207,6 +207,13 @@ pub const Event = opaque {
|
|||||||
return self.inner().getEventError(api);
|
return self.inner().getEventError(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns immediately when the event is already ready; otherwise calls the inner event's `await_`.
pub fn awaitBlocking(self: *Event, api: *const Api) ApiError!void {
    if (!self.isReady(api)) {
        try self.inner().await_(api);
    }
}
|
||||||
|
|
||||||
pub fn await_(self: *Event, api: *const Api) ApiError!void {
|
pub fn await_(self: *Event, api: *const Api) ApiError!void {
|
||||||
defer self.deinit(api);
|
defer self.deinit(api);
|
||||||
|
|
||||||
@ -264,14 +271,14 @@ pub const LoadedExecutable = opaque {
|
|||||||
};
|
};
|
||||||
|
|
||||||
/// Runs the executable by calling the inner PJRT `execute` directly,
/// translating `args` field by field into the inner `ExecuteArgs`.
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
    const inner_args: pjrt.LoadedExecutable.ExecuteArgs = .{
        .num_args = args.num_args,
        // The pointer fields are reinterpreted as the inner types.
        .arguments = @ptrCast(args.arguments),
        .results = @ptrCast(args.results),
        .events = @ptrCast(args.events),
        .non_donatable_input_indices = args.non_donatable_input_indices,
        .context = args.context,
    };
    try self.inner().execute(api, inner_args);
}
|
||||||
|
|
||||||
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
||||||
|
|||||||
@ -22,6 +22,11 @@ pub const Platform = struct {
|
|||||||
target: Target,
|
target: Target,
|
||||||
pjrt_api: *const pjrt.Api,
|
pjrt_api: *const pjrt.Api,
|
||||||
pjrt_client: *pjrt.Client,
|
pjrt_client: *pjrt.Client,
|
||||||
|
|
||||||
|
// This makes the pjrt struct quite fat, but is only used during compilation.
|
||||||
|
// TODO: Reconsider having it here, and maybe pass explicitly to compile,
|
||||||
|
// or create an intermediary struct:
|
||||||
|
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
||||||
compilation_options: CompilationOptions = .{},
|
compilation_options: CompilationOptions = .{},
|
||||||
|
|
||||||
pub const MAX_NUM_DEVICES: u8 = 32;
|
pub const MAX_NUM_DEVICES: u8 = 32;
|
||||||
@ -71,17 +76,6 @@ pub const Platform = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Registers `T` as an FFI type through the plugin's FFI extension.
/// Panics when `T` has no `type_id` declaration, or when the plugin has no FFI extension.
pub fn registerFFIType(self: Platform, comptime T: type) !void {
    if (self.pjrt_api.ffi()) |ffi| {
        const has_type_id = comptime @hasDecl(T, "type_id");
        if (!has_type_id) {
            stdx.debug.panic("registerFFIType requires type {s} to have a `type_id` i64 field ", .{@typeName(T)});
        }
        try ffi.registerTypeId(self.pjrt_api, T);
        return;
    }
    stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
}
|
|
||||||
|
|
||||||
/// Deinitializes the platform's PJRT client through its PJRT API.
pub fn deinit(self: *Platform) void {
    const client = self.pjrt_client;
    client.deinit(self.pjrt_api);
}
|
||||||
|
|||||||
@ -5,6 +5,7 @@ const mlir = @import("mlir");
|
|||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
|
const callback = @import("callback.zig");
|
||||||
const CompilationContext = @import("module.zig").CompilationContext;
|
const CompilationContext = @import("module.zig").CompilationContext;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
@ -3824,22 +3825,7 @@ pub const Tensor = struct {
|
|||||||
/// Only for debug purpose, it inserts device to host synchronization
|
/// Only for debug purpose, it inserts device to host synchronization
|
||||||
/// so it will slow down the program execution.
|
/// so it will slow down the program execution.
|
||||||
pub fn print(input: Tensor) Tensor {
    // Route the tensor through the `callback.Print` callback, asking for a single
    // output with the same shape as the input, and hand that output back.
    const outputs = callback.call(callback.Print, .{input}, &.{input.shape()});
    return outputs[0];
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -51,14 +51,13 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
|||||||
if (should_free_left) left.deinit(allocator);
|
if (should_free_left) left.deinit(allocator);
|
||||||
if (should_free_right) right.deinit(allocator);
|
if (should_free_right) right.deinit(allocator);
|
||||||
}
|
}
|
||||||
errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
|
errdefer log.err("\n--> Left: {0f}{0d:24.3}\n--> Right: {1f}{1d:24.3}", .{ left, right });
|
||||||
|
|
||||||
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
||||||
log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
|
log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
|
||||||
return error.TestUnexpectedResult;
|
return error.TestUnexpectedResult;
|
||||||
}
|
}
|
||||||
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
|
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
|
||||||
log.err("left.dtype ({}) != right.dtype ({})", .{ left.dtype(), right.dtype() });
|
log.err("left.dtype ({f}) != right.dtype ({f})", .{ left.shape(), right.shape() });
|
||||||
return error.TestUnexpectedResult;
|
return error.TestUnexpectedResult;
|
||||||
}
|
}
|
||||||
switch (left.dtype()) {
|
switch (left.dtype()) {
|
||||||
@ -89,7 +88,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
|||||||
const right_data = right.items(R);
|
const right_data = right.items(R);
|
||||||
for (left_data, right_data, 0..) |l, r, i| {
|
for (left_data, right_data, 0..) |l, r, i| {
|
||||||
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
|
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
|
||||||
log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
|
log.err("left.data != right_data.\n < {d:40.3} \n > {d:40.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ stdx.fmt.slice(center(left_data, i)), stdx.fmt.slice(center(right_data, i)), i, left_data[i], right_data[i] });
|
||||||
return error.TestUnexpectedResult;
|
return error.TestUnexpectedResult;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -6,11 +6,13 @@
|
|||||||
// Namespaces
|
// Namespaces
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
|
pub const platform_specific = @import("c");
|
||||||
pub const tokenizer = @import("zml/tokenizer");
|
pub const tokenizer = @import("zml/tokenizer");
|
||||||
|
|
||||||
pub const aio = @import("aio.zig");
|
pub const aio = @import("aio.zig");
|
||||||
pub const Buffer = @import("buffer.zig").Buffer;
|
pub const Buffer = @import("buffer.zig").Buffer;
|
||||||
pub const Bufferized = @import("tensor.zig").Bufferized;
|
pub const Bufferized = @import("tensor.zig").Bufferized;
|
||||||
|
pub const callback = @import("callback.zig");
|
||||||
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
|
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
|
||||||
pub const context = @import("context.zig");
|
pub const context = @import("context.zig");
|
||||||
pub const Context = @import("context.zig").Context;
|
pub const Context = @import("context.zig").Context;
|
||||||
@ -43,7 +45,6 @@ pub const Tensor = @import("tensor.zig").Tensor;
|
|||||||
pub const testing = @import("testing.zig");
|
pub const testing = @import("testing.zig");
|
||||||
pub const torch = @import("torch.zig");
|
pub const torch = @import("torch.zig");
|
||||||
|
|
||||||
// pub const tokenizer = @import("tokenizer.zig");
|
|
||||||
pub const tools = struct {
|
pub const tools = struct {
|
||||||
pub const Tracer = @import("tools/tracer.zig").Tracer;
|
pub const Tracer = @import("tools/tracer.zig").Tracer;
|
||||||
};
|
};
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user