Add experimental zml.callback API (renamed from custom_call) and fix tensor.print(); update PJRT bindings, host buffer utilities, and related core ZML modules.
This commit is contained in:
parent
1fa056a790
commit
cc969bd532
130
pjrt/ffi.zig
130
pjrt/ffi.zig
@ -2,12 +2,21 @@
|
||||
const std = @import("std");
|
||||
|
||||
const c = @import("c");
|
||||
pub const TypeId = c.XLA_FFI_TypeId;
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const pjrtStruct = @import("pjrt.zig").pjrtStruct;
|
||||
const pjrt = @import("pjrt.zig");
|
||||
const Stream = @import("pjrt.zig").Stream;
|
||||
|
||||
const log = std.log.scoped(.pjrt);
|
||||
|
||||
comptime {
|
||||
if (@typeInfo(TypeId).@"struct".fields.len != 1) @compileError("TypeId has changed");
|
||||
}
|
||||
|
||||
/// The signature of a generic custom call.
|
||||
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
|
||||
|
||||
pub const ApiVersion = extern struct {
|
||||
pub const major = c.XLA_FFI_API_MAJOR;
|
||||
pub const minor = c.XLA_FFI_API_MINOR;
|
||||
@ -29,13 +38,13 @@ pub const ExtensionBase = extern struct {
|
||||
};
|
||||
|
||||
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
|
||||
pub const HandlerTraits = packed struct(u32) {
|
||||
pub const HandlerTraits = packed struct(c_uint) {
|
||||
/// Calls to FFI handler are safe to trace into the command buffer.
|
||||
/// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values)
|
||||
/// that can be captured and then replayed.
|
||||
command_buffer_compatible: u1,
|
||||
command_buffer_compatible: bool,
|
||||
|
||||
__unassigned__: u31,
|
||||
__unassigned__: u31 = 0,
|
||||
};
|
||||
|
||||
pub const Metadata = extern struct {
|
||||
@ -49,25 +58,6 @@ pub const MetadataExtension = extern struct {
|
||||
metadata: ?*Metadata,
|
||||
};
|
||||
|
||||
pub const ApiError = error{
|
||||
Cancelled,
|
||||
Unknown,
|
||||
InvalidArgument,
|
||||
DeadlineExceeded,
|
||||
NotFound,
|
||||
AlreadyExists,
|
||||
PermissionDenied,
|
||||
ResourceExhausted,
|
||||
FailedPrecondition,
|
||||
Aborted,
|
||||
OutOfRange,
|
||||
Unimplemented,
|
||||
Internal,
|
||||
Unavailable,
|
||||
DataLoss,
|
||||
Unauthenticated,
|
||||
};
|
||||
|
||||
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
||||
return struct {
|
||||
pub fn to(self: anytype) switch (@TypeOf(self)) {
|
||||
@ -91,8 +81,8 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
|
||||
pub const Api = opaque {
|
||||
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
|
||||
|
||||
pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
||||
pub fn stream(self: *const Api, context: *const ExecutionContext) *pjrt.Stream {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Stream_Get_Args{
|
||||
.ctx = @constCast(context.inner()),
|
||||
});
|
||||
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
|
||||
@ -107,8 +97,8 @@ pub const Api = opaque {
|
||||
return @ptrCast(ret.stream.?);
|
||||
}
|
||||
|
||||
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
|
||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
||||
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) pjrt.ApiError!*anyopaque {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
|
||||
.ctx = @constCast(context.inner()),
|
||||
.size = size,
|
||||
.alignment = alignment,
|
||||
@ -127,8 +117,8 @@ pub const Api = opaque {
|
||||
return ret.data.?;
|
||||
}
|
||||
|
||||
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
|
||||
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
||||
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) pjrt.ApiError!void {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
|
||||
.ctx = @constCast(context.inner()),
|
||||
.size = size,
|
||||
.data = data,
|
||||
@ -163,17 +153,17 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
|
||||
pub const ExecutionContext = opaque {
|
||||
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
|
||||
|
||||
pub fn Context(comptime T: type) type {
|
||||
return struct {
|
||||
pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
|
||||
const type_id: TypeId = .{ .type_id = T.type_id };
|
||||
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
|
||||
.ctx = @constCast(self.inner()),
|
||||
.type_id = @constCast(&type_id.toCStruct()),
|
||||
});
|
||||
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
||||
pub fn getContext(self: *const ExecutionContext, type_id: TypeId, api: *const Api) pjrt.ApiError!*anyopaque {
|
||||
var ret: c.XLA_FFI_ExecutionContext_Get_Args = .{
|
||||
.struct_size = pjrt.pjrtStructSize(c.XLA_FFI_ExecutionContext_Get_Args),
|
||||
.extension_start = api.inner().extension_start,
|
||||
.ctx = @ptrCast(@constCast(self)),
|
||||
.type_id = @constCast(&type_id),
|
||||
.data = undefined, // set by XLA_FFI_ExecutionContext_Get.
|
||||
};
|
||||
const maybe_err = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
|
||||
|
||||
if (result) |ffi_error| {
|
||||
if (maybe_err) |ffi_error| {
|
||||
const err = Error.fromInner(ffi_error);
|
||||
defer err.destroy(api);
|
||||
log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
|
||||
@ -183,13 +173,11 @@ pub const ExecutionContext = opaque {
|
||||
}
|
||||
|
||||
if (ret.data == null) return error.NotFound;
|
||||
return @ptrCast(@alignCast(ret.data.?));
|
||||
}
|
||||
};
|
||||
return ret.data.?;
|
||||
}
|
||||
|
||||
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
|
||||
var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
|
||||
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) pjrt.ApiError!i32 {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
|
||||
.ctx = @constCast(self.inner()),
|
||||
});
|
||||
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
|
||||
@ -206,8 +194,10 @@ pub const ExecutionContext = opaque {
|
||||
return ret.device_ordinal;
|
||||
}
|
||||
|
||||
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
|
||||
var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
|
||||
const Task = fn (*anyopaque) void;
|
||||
|
||||
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) pjrt.ApiError!void {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
|
||||
.ctx = @constCast(self.inner()),
|
||||
.task = @ptrCast(@alignCast(task)),
|
||||
.data = @ptrCast(@alignCast(data)),
|
||||
@ -225,23 +215,9 @@ pub const ExecutionContext = opaque {
|
||||
return error.Unknown;
|
||||
}
|
||||
}
|
||||
|
||||
fn getTypeId(type_name: []const u8) TypeId {
|
||||
const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name));
|
||||
|
||||
return .{
|
||||
.type_id = id,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
const TypeId = c.XLA_FFI_TypeId;
|
||||
|
||||
const Task = fn (*anyopaque) void;
|
||||
|
||||
const Stream = @import("pjrt.zig").Stream;
|
||||
|
||||
const ByteSpan = extern struct {
|
||||
pub const ByteSpan = extern struct {
|
||||
ptr: [*]const u8,
|
||||
len: usize,
|
||||
|
||||
@ -252,7 +228,7 @@ const ByteSpan = extern struct {
|
||||
|
||||
pub const DataType = enum(c.XLA_FFI_DataType) {
|
||||
invalid = c.XLA_FFI_DataType_INVALID,
|
||||
pred = c.XLA_FFI_DataType_PRED,
|
||||
bool = c.XLA_FFI_DataType_PRED,
|
||||
i8 = c.XLA_FFI_DataType_S8,
|
||||
i16 = c.XLA_FFI_DataType_S16,
|
||||
i32 = c.XLA_FFI_DataType_S32,
|
||||
@ -399,6 +375,8 @@ pub const Attrs = extern struct {
|
||||
}
|
||||
};
|
||||
|
||||
/// All the information needed by the user callback,
|
||||
/// including the list of input/output buffers to work on.
|
||||
pub const CallFrame = extern struct {
|
||||
struct_size: usize,
|
||||
extension_start: ?*ExtensionBase,
|
||||
@ -422,9 +400,11 @@ pub const CallFrame = extern struct {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
|
||||
pub fn stream(call_frame: CallFrame) ?*const pjrt.Stream {
|
||||
return call_frame.api.stream(call_frame.ctx);
|
||||
}
|
||||
};
|
||||
|
||||
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
||||
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
|
||||
@ -444,7 +424,7 @@ pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
|
||||
data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
|
||||
unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,
|
||||
|
||||
pub fn toApiError(code: ErrorCode) ApiError {
|
||||
pub fn toApiError(code: ErrorCode) pjrt.ApiError {
|
||||
return switch (code) {
|
||||
.cancelled => error.Cancelled,
|
||||
.unknown => error.Unknown,
|
||||
@ -470,8 +450,10 @@ pub const Error = opaque {
|
||||
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
|
||||
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
|
||||
|
||||
pub const ok: ?*Error = null;
|
||||
|
||||
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Create_Args{
|
||||
.message = message.ptr,
|
||||
.errc = @intFromEnum(error_code),
|
||||
});
|
||||
@ -479,12 +461,12 @@ pub const Error = opaque {
|
||||
}
|
||||
|
||||
pub fn destroy(err: *Error, api: *const Api) void {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
|
||||
api.inner().XLA_FFI_Error_Destroy.?(&ret);
|
||||
}
|
||||
|
||||
pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
|
||||
.@"error" = err.inner(),
|
||||
});
|
||||
api.inner().XLA_FFI_Error_GetMessage.?(&ret);
|
||||
@ -496,8 +478,8 @@ pub const Future = opaque {
|
||||
pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
|
||||
pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;
|
||||
|
||||
pub fn create(api: *const Api) ApiError!*Future {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{});
|
||||
pub fn create(api: *const Api) pjrt.ApiError!*Future {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_Create_Args{});
|
||||
const result = api.inner().XLA_FFI_Future_Create.?(&ret);
|
||||
|
||||
if (result) |ffi_error| {
|
||||
@ -512,8 +494,8 @@ pub const Future = opaque {
|
||||
return fromInner(ret.future.?);
|
||||
}
|
||||
|
||||
pub fn setAvailable(self: *Future, api: *const Api) ApiError!void {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
|
||||
pub fn setAvailable(self: *Future, api: *const Api) pjrt.ApiError!void {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
|
||||
.future = self.inner(),
|
||||
});
|
||||
|
||||
@ -529,8 +511,8 @@ pub const Future = opaque {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void {
|
||||
var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{
|
||||
pub fn setError(self: *Future, api: *const Api, err: *Error) pjrt.ApiError!void {
|
||||
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetError_Args{
|
||||
.future = self.inner(),
|
||||
.@"error" = err.inner(),
|
||||
});
|
||||
|
||||
@ -20,7 +20,7 @@ test {
|
||||
// as the way PJRT does it is not very robust.
|
||||
//
|
||||
// 1. https://github.com/openxla/xla/issues/10032
|
||||
fn pjrtStructSize(comptime T: type) usize {
|
||||
pub fn pjrtStructSize(comptime T: type) usize {
|
||||
// unsafe on purpose, we want this to fail if that ever changes
|
||||
const typedef_name = comptime blk: {
|
||||
const needle = ".struct_";
|
||||
@ -164,7 +164,7 @@ pub const Api = struct {
|
||||
return @ptrCast(ret.context.?);
|
||||
}
|
||||
|
||||
pub fn ffi(api: *const Api) ?FFI {
|
||||
pub fn ffi(api: *const Api) ?Ffi {
|
||||
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
||||
return .{ .inner = ext };
|
||||
}
|
||||
@ -1278,7 +1278,7 @@ pub const NamedValue = extern struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const FFI = extern struct {
|
||||
pub const Ffi = extern struct {
|
||||
inner: *const c.PJRT_FFI,
|
||||
|
||||
pub const UserData = extern struct {
|
||||
@ -1293,19 +1293,15 @@ pub const FFI = extern struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const RegisterFfiOptions = struct {
|
||||
traits: RegisterHandlerTraits = @enumFromInt(0),
|
||||
};
|
||||
|
||||
// TODO: support the missing handlers available in the GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||
pub fn register(
|
||||
self: *const FFI,
|
||||
self: *const Ffi,
|
||||
api: *const Api,
|
||||
target_name: []const u8,
|
||||
platform_name: []const u8,
|
||||
func: *const ffi.Handler,
|
||||
options: RegisterFfiOptions,
|
||||
traits: ffi.HandlerTraits,
|
||||
) ApiError!void {
|
||||
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
||||
.target_name = target_name.ptr,
|
||||
@ -1313,7 +1309,7 @@ pub const FFI = extern struct {
|
||||
.handler = @ptrCast(@constCast(func)),
|
||||
.platform_name = platform_name.ptr,
|
||||
.platform_name_size = platform_name.len,
|
||||
.traits = @intFromEnum(options.traits),
|
||||
.traits = @bitCast(traits),
|
||||
});
|
||||
const result = self.inner.register_handler.?(&ret);
|
||||
if (result) |pjrt_c_error| {
|
||||
@ -1323,8 +1319,7 @@ pub const FFI = extern struct {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
|
||||
const type_name = @typeName(T);
|
||||
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId {
|
||||
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
|
||||
.type_name = type_name.ptr,
|
||||
.type_name_size = type_name.len,
|
||||
@ -1336,10 +1331,10 @@ pub const FFI = extern struct {
|
||||
return pjrt_error.getCode(api).toApiError();
|
||||
}
|
||||
|
||||
T.type_id = ret.type_id;
|
||||
return .{ .type_id = ret.type_id };
|
||||
}
|
||||
|
||||
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
|
||||
pub fn addUserData(self: *const Ffi, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
|
||||
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
|
||||
.context = @ptrCast(context),
|
||||
.user_data = user_data.toCStruct(),
|
||||
@ -1352,12 +1347,3 @@ pub const FFI = extern struct {
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
|
||||
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
|
||||
_,
|
||||
};
|
||||
|
||||
pub const CustomCallRegistry = extern struct {
|
||||
inner: *const c.PJRT_FFI_Register_Handler,
|
||||
};
|
||||
|
||||
@ -5,7 +5,7 @@ zig_shared_library(
|
||||
name = "zmlxcuda",
|
||||
# Use Clang's compiler-rt, but disable stack checking
|
||||
# to avoid requiring the _zig_probe_stack symbol.
|
||||
copts = ["-fno-stack-check"],
|
||||
copts = ["-fno-stack-check", "-fllvm"],
|
||||
main = "zmlxcuda.zig",
|
||||
shared_lib_name = "libzmlxcuda.so.0",
|
||||
visibility = ["@libpjrt_cuda//:__subpackages__"],
|
||||
|
||||
156
stdx/fmt.zig
156
stdx/fmt.zig
@ -1,145 +1,117 @@
|
||||
const std = @import("std");
|
||||
|
||||
pub const Fmt = union(enum) {
|
||||
int: IntFmt,
|
||||
float: FloatFmt,
|
||||
generic: void,
|
||||
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
|
||||
return .{ .slice = any_slice };
|
||||
}
|
||||
|
||||
pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
|
||||
fn FmtSlice(T: type) type {
|
||||
return struct {
|
||||
slice: []const T,
|
||||
|
||||
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||
return switch (@typeInfo(T)) {
|
||||
.float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) },
|
||||
.int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) },
|
||||
else => .{ .generic = {} },
|
||||
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
|
||||
.comptime_int, .int => try formatIntSlice(f.slice, n, writer),
|
||||
.bool => try formatBoolSlice(f.slice, n, writer),
|
||||
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
|
||||
try formatComplexSlice(f.slice, n, writer);
|
||||
} else if (@hasDecl(T, "toF32")) {
|
||||
try formatFloatSlice(f.slice, n, writer);
|
||||
} else {
|
||||
try formatSliceAny(f.slice, n, writer);
|
||||
},
|
||||
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const FullFormatOptions = struct {
|
||||
fmt: Fmt,
|
||||
options: std.fmt.FormatOptions,
|
||||
};
|
||||
|
||||
pub const IntFmt = struct {
|
||||
base: u8,
|
||||
case: std.fmt.Case = .lower,
|
||||
|
||||
pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
|
||||
return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_);
|
||||
}
|
||||
|
||||
pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
|
||||
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d"))
|
||||
.{ .base = 10, .case = .lower }
|
||||
else if (std.mem.eql(u8, fmt_, "x"))
|
||||
.{ .base = 16, .case = .lower }
|
||||
else if (std.mem.eql(u8, fmt_, "X"))
|
||||
.{ .base = 16, .case = .upper }
|
||||
else if (std.mem.eql(u8, fmt_, "o"))
|
||||
.{ .base = 8, .case = .upper }
|
||||
else
|
||||
// TODO: unicode/ascii
|
||||
error.InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
pub const FloatFmt = enum(u8) {
|
||||
scientific = @intFromEnum(std.fmt.Number.Mode.scientific),
|
||||
decimal = @intFromEnum(std.fmt.Number.Mode.decimal),
|
||||
hex,
|
||||
|
||||
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
|
||||
return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_);
|
||||
}
|
||||
|
||||
pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
|
||||
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "e"))
|
||||
.scientific
|
||||
else if (std.mem.eql(u8, fmt_, "d"))
|
||||
.decimal
|
||||
else if (std.mem.eql(u8, fmt_, "x"))
|
||||
.hex
|
||||
else
|
||||
error.InvalidArgument;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
return switch (@typeInfo(@TypeOf(value))) {
|
||||
.comptime_float, .float => try formatFloatValue(value, full, writer),
|
||||
.comptime_int, .int => try formatIntValue(value, full, writer),
|
||||
else => try formatAnyValue(value, full, writer),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||
pub fn formatFloat(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
const x = switch (@typeInfo(@TypeOf(value))) {
|
||||
.@"struct" => value.toF32(),
|
||||
.float => value,
|
||||
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
|
||||
};
|
||||
try switch (full.fmt.float) {
|
||||
.scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }),
|
||||
.decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }),
|
||||
.hex => writer.printFloatHexOptions(x, .{ .mode = .hex }),
|
||||
else => @compileError("formatFloat expects a float, got: " ++ @typeName(@TypeOf(value))),
|
||||
};
|
||||
return writer.printFloat(x, spec);
|
||||
}
|
||||
|
||||
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||
pub fn formatInt(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
switch (@typeInfo(@TypeOf(value))) {
|
||||
.int => {},
|
||||
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
|
||||
else => @compileError("formatInt expects an int, got: " ++ @typeName(@TypeOf(value))),
|
||||
}
|
||||
return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options);
|
||||
return writer.printInt(value, spec.mode.base().?, spec.case, .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||
}
|
||||
|
||||
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
|
||||
pub fn formatComplex(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
try writer.writeAll(".{.re=");
|
||||
try writer.printFloat(value.re, spec);
|
||||
try writer.writeAll(", .im=");
|
||||
try writer.printFloat(value.im, spec);
|
||||
try writer.writeAll("}");
|
||||
}
|
||||
|
||||
pub fn formatBool(value: bool, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
try writer.alignBufferOptions(if (value) "1" else "0", .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||
}
|
||||
|
||||
pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
var buf: [48]u8 = undefined;
|
||||
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
|
||||
const T = @TypeOf(value);
|
||||
const fmt = if (@hasDecl(T, "formatNumber")) "{d}" else "{f}";
|
||||
|
||||
const s = std.fmt.bufPrint(&buf, fmt, .{value}) catch blk: {
|
||||
buf[45..].* = "...".*;
|
||||
break :blk buf[0..];
|
||||
};
|
||||
return try writer.alignBufferOptions(s, full.options);
|
||||
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||
}
|
||||
|
||||
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
|
||||
// Write first rows
|
||||
const num_cols: usize = full.options.width orelse 12;
|
||||
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
// use the format "width" for the number of columns instead of individual width.
|
||||
const num_cols: usize = spec.width orelse 12;
|
||||
var my_options = spec;
|
||||
my_options.width = null;
|
||||
const n: usize = values.len;
|
||||
|
||||
_ = try writer.write("{");
|
||||
if (n <= num_cols) {
|
||||
for (values, 0..) |v, i| {
|
||||
// Force inlining so that the switch and the buffer can be done once.
|
||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||
if (i < n - 1) _ = try writer.write(",");
|
||||
}
|
||||
} else {
|
||||
const half = @divFloor(num_cols, 2);
|
||||
for (values[0..half]) |v| {
|
||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||
_ = try writer.write(",");
|
||||
}
|
||||
_ = try writer.write(" ..., ");
|
||||
for (values[n - half ..], 0..) |v, i| {
|
||||
try @call(.always_inline, fmt_func, .{ v, full, writer });
|
||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
||||
if (i < half - 1) _ = try writer.write(",");
|
||||
}
|
||||
}
|
||||
_ = try writer.write("}");
|
||||
}
|
||||
|
||||
pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
||||
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
return try formatSliceCustom(formatAny, values, spec, writer);
|
||||
}
|
||||
|
||||
pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
return try formatSliceCustom(formatFloatValue, values, full, writer);
|
||||
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
return try formatSliceCustom(formatFloat, values, spec, writer);
|
||||
}
|
||||
|
||||
pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
return try formatSliceCustom(formatIntValue, values, full, writer);
|
||||
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
return try formatSliceCustom(formatInt, values, spec, writer);
|
||||
}
|
||||
|
||||
pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
|
||||
return try formatSliceCustom(formatAnyValue, values, full, writer);
|
||||
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
return try formatSliceCustom(formatComplex, values, spec, writer);
|
||||
}
|
||||
|
||||
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
||||
return try formatSliceCustom(formatBool, values, spec, writer);
|
||||
}
|
||||
|
||||
@ -31,6 +31,7 @@ zig_library(
|
||||
"aio/torch/py.zig",
|
||||
"buffer.zig",
|
||||
"context.zig",
|
||||
"callback.zig",
|
||||
"dtype.zig",
|
||||
"exe.zig",
|
||||
"floats.zig",
|
||||
@ -53,7 +54,7 @@ zig_library(
|
||||
"torch.zig",
|
||||
"zml.zig",
|
||||
],
|
||||
copts = ["-lc"],
|
||||
copts = ["-lc", "-freference-trace=20"],
|
||||
main = "zml.zig",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
|
||||
@ -49,7 +49,7 @@ pub const Buffer = struct {
|
||||
|
||||
pub const FromOptions = struct {
|
||||
wait: bool = true,
|
||||
memory: ?pjrt.Memory.Kind = null,
|
||||
memory: ?Memory = null,
|
||||
};
|
||||
|
||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||
@ -89,15 +89,20 @@ pub const Buffer = struct {
|
||||
.byte_strides = byte_strides,
|
||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||
};
|
||||
if (opts.memory) |memory_kind| {
|
||||
const memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||
const memory = for (memories) |m| {
|
||||
const kind = m.kind(platform.pjrt_api);
|
||||
if (kind == memory_kind) break m;
|
||||
} else return error.NotFound;
|
||||
args.memory = memory;
|
||||
} else {
|
||||
if (platform.target == .cpu or opts.memory == null) {
|
||||
args.device = devices[i];
|
||||
} else {
|
||||
const memory = opts.memory.?;
|
||||
const device_memories = try devices[i].addressableMemories(platform.pjrt_api);
|
||||
// TODO measure the cost of this and consider caching on Zig side inside the platform.
|
||||
const selected_memory = for (device_memories) |m| {
|
||||
const kind = m.kind(platform.pjrt_api);
|
||||
if (kind == memory.toPjrtMemory()) break m;
|
||||
} else {
|
||||
log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) });
|
||||
return error.NotFound;
|
||||
};
|
||||
args.memory = selected_memory;
|
||||
}
|
||||
|
||||
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
|
||||
@ -179,10 +184,10 @@ pub const Buffer = struct {
|
||||
return try from(platform, host_buffer, opts);
|
||||
}
|
||||
|
||||
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
|
||||
// TODO restore assert
|
||||
pub fn asHostBuffer(self: Buffer) HostBuffer {
|
||||
// TODO: skip this check on cpu
|
||||
// const memory = self.getMemory().kind(self._api);
|
||||
// stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
|
||||
// stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
|
||||
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
|
||||
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
|
||||
}
|
||||
@ -299,6 +304,12 @@ pub const Buffer = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 {
|
||||
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
|
||||
const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
|
||||
return @ptrCast(opaque_ptr);
|
||||
}
|
||||
|
||||
/// Fetches the content of the given buffer into a stack variable of the given type.
|
||||
pub fn getValue(self: Buffer, T: type) !T {
|
||||
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
|
||||
@ -390,13 +401,31 @@ pub const Buffer = struct {
|
||||
return @reduce(.Or, self._shape._sharding_info);
|
||||
}
|
||||
|
||||
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
|
||||
pub const CopyToMemoryOpts = struct {
|
||||
wait: bool = true,
|
||||
};
|
||||
|
||||
pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer {
|
||||
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory());
|
||||
if (pjrt_memory == null) {
|
||||
stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self });
|
||||
}
|
||||
|
||||
var new_shards: Buffer.Shards = .{};
|
||||
for (self._shards.slice()) |shard| {
|
||||
const new_shard = try shard.copyToMemory(self._api, memory);
|
||||
const new_shard = try shard.copyToMemory(self._api, pjrt_memory.?);
|
||||
new_shards.appendAssumeCapacity(new_shard);
|
||||
}
|
||||
|
||||
if (opts.wait) {
|
||||
for (new_shards.constSlice()) |shard| {
|
||||
const event = shard.getReadyEvent(self._api);
|
||||
if (event) |e| {
|
||||
try e.awaitBlocking(self._api);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||
}
|
||||
|
||||
|
||||
313
zml/callback.zig
Normal file
313
zml/callback.zig
Normal file
@ -0,0 +1,313 @@
|
||||
const std = @import("std");
|
||||
|
||||
const asynk = @import("async");
|
||||
const mlir = @import("mlir");
|
||||
const pjrt = @import("pjrt");
|
||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const mlirx = @import("mlirx.zig");
|
||||
const pjrtx = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
|
||||
const log = std.log.scoped(.@"zml/callback");
|
||||
|
||||
/// Inserts a user-defined callback into the computation graph.
|
||||
/// The callback is defined as a struct that stores the runtime information needed by the callback.
|
||||
///
|
||||
/// ❗Experimental API❗
|
||||
///
|
||||
/// ```zig
|
||||
/// pub const MyCallback = struct {
|
||||
/// // a unique type_id will be set by the PJRT plugin during registration.
|
||||
/// pub var type_id: pjrt.ffi.TypeId = undefined;
|
||||
///
|
||||
/// pub const callback_config: zml.callback.Config = .{
|
||||
/// // assumption this custom call makes about the input / output buffers
|
||||
/// };
|
||||
///
|
||||
/// // Required, this will tell the callback in which env it runs.
|
||||
/// platform: zml.Platform,
|
||||
/// // data needed by the callback
|
||||
/// my_data: []const u8,
|
||||
///
|
||||
/// // storage modified by the runtime to tell the callback where it should write its results.
|
||||
/// // Normally the callback doesn't need to allocate as the input and output buffers are given.
|
||||
/// results: [1]Buffer = undefined,
|
||||
///
|
||||
/// pub fn init(my_data: []const u8) !MyCallback {
|
||||
/// return .{ .my_data = my_data };
|
||||
/// }
|
||||
///
|
||||
/// pub fn call(callback: *MyCallback, input: Buffer) !void {
|
||||
/// // Do something with `input` and `callback.my_data`, write the results inside `callback.results[0]`
|
||||
/// }
|
||||
/// };
|
||||
/// ```
|
||||
///
|
||||
/// See eg the implementation of the `zml.callback.Print` callback, for a practical example.
|
||||
///
|
||||
/// Note calling this during the compilation of a module, isn't enough:
|
||||
///
|
||||
/// * backend need to be made aware of the callback, see `zml.Platform.registerCallback`
|
||||
/// * executable need to know the specific data needed by `MyCallback`, see `zml.Exe.bind`
|
||||
pub fn call(
|
||||
comptime Callback: type,
|
||||
inputs: TensorArgs(Callback),
|
||||
output_shapes: []const Shape,
|
||||
) []Tensor {
|
||||
checkIsValidCallback(Callback);
|
||||
|
||||
const ctx = CompilationContext.current();
|
||||
const allocator = ctx.allocator();
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
const platform = ctx._platform;
|
||||
const pjrt_api = platform.pjrt_api;
|
||||
|
||||
if (pjrt_api.ffi() == null) {
|
||||
stdx.debug.panic("Custom calls are not supported for target {s}", .{@tagName(platform.target)});
|
||||
}
|
||||
|
||||
const output_tensors = allocator.alloc(Tensor, output_shapes.len) catch @panic("OOM");
|
||||
// Note: we don't always free output_tensor, because it's returned to the caller.
|
||||
// It's also why we allocate it first so that it doesn't fragment the arena.
|
||||
errdefer allocator.free(output_tensors);
|
||||
|
||||
const output_types = allocator.alloc(mlir.Type, output_shapes.len) catch @panic("OOM");
|
||||
defer allocator.free(output_types);
|
||||
for (output_types, output_shapes) |*output_type, output_shape| {
|
||||
output_type.* = mlirx.tensorType(mlir_ctx, output_shape);
|
||||
}
|
||||
const input_values = allocator.alloc(mlir.Value, inputs.len) catch @panic("OOM");
|
||||
defer allocator.free(input_values);
|
||||
for (input_values, inputs) |*input_value, input_tensor| {
|
||||
input_value.* = input_tensor.value();
|
||||
}
|
||||
|
||||
const target_name = "zml$" ++ @typeName(Callback);
|
||||
const op = stablehlo.custom_call(
|
||||
mlir_ctx,
|
||||
input_values,
|
||||
.{
|
||||
.call_target_name = target_name,
|
||||
.api_version = .typed_ffi,
|
||||
.backend_config = .dict(mlir_ctx, &.{}),
|
||||
.additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
|
||||
.has_side_effect = true,
|
||||
.output_operand_aliases = Callback.callback_config.output_operand_aliases,
|
||||
},
|
||||
output_types,
|
||||
mlir_ctx.location(@src()),
|
||||
);
|
||||
|
||||
for (output_tensors, output_shapes, 0..) |*output_tensor, output_shape, i| {
|
||||
output_tensor.* = Tensor._result(output_shape, op.result(i));
|
||||
}
|
||||
return output_tensors;
|
||||
}
|
||||
|
||||
/// Static properties of a callback, exposed by each callback type as `callback_config`.
pub const Config = struct {
    /// The callback reuses input buffers to write its outputs.
    output_operand_aliases: []const i64 = &[_]i64{},

    /// The callback needs to work on host visible buffers.
    copy_inputs_to_host_pinned: bool = false,

    /// PJRT specified properties of the callback.
    // TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
    traits: pjrt.ffi.HandlerTraits = pjrt.ffi.HandlerTraits{ .command_buffer_compatible = false },

    // TODO: handle sharded inputs
};
|
||||
|
||||
/// Compile-time check that a callback type provides everything the runtime requires.
///
/// A valid callback:
/// * has a `call` method whose arguments after the first one are all `zml.Buffer`,
/// * declares `pub var type_id: pjrt.ffi.TypeId`,
/// * declares `pub const callback_config: zml.callback.Config`,
/// * has a `platform` and a `results` field, both are accessed by the runtime before invoking `call`.
pub fn checkIsValidCallback(Callback: type) void {
    stdx.debug.assertComptime(@hasDecl(Callback, "call"), "Expected callback {} to have a call method", .{Callback});
    const ArgsT = stdx.meta.FnArgs(Callback.call);
    // The first argument is the callback itself, all the other ones are the input buffers.
    inline for (@typeInfo(ArgsT).@"struct".fields[1..]) |field| {
        stdx.debug.assertComptime(field.type == Buffer, "Expected callback {}.call arguments to be of type zml.Buffer, got {}", .{ Callback, field.type });
    }

    stdx.debug.assertComptime(@hasDecl(Callback, "type_id") and @TypeOf(Callback.type_id) == pjrt.ffi.TypeId, "Expected callback {} to have a declaration `pub var type_id: pjrt.ffi.TypeId`", .{Callback});
    stdx.debug.assertComptime(@hasDecl(Callback, "callback_config") and @TypeOf(Callback.callback_config) == Config, "Expected callback {} to have a declaration `pub const callback_config: zml.callback.Config`", .{Callback});

    // The runtime reads `platform` and writes `results` on the user data before calling `Callback.call`.
    stdx.debug.assertComptime(@hasField(Callback, "platform"), "Expected callback {} to have a field `platform: zml.Platform`", .{Callback});
    stdx.debug.assertComptime(@hasField(Callback, "results"), "Expected callback {} to have a field `results` holding the output buffers", .{Callback});
}
|
||||
|
||||
/// Makes the given platform aware of `Callback`.
///
/// Registers a type id for the callback user data (stored in `Callback.type_id`),
/// then registers the FFI handler under the target name "zml$" ++ @typeName(Callback).
/// Returns `error.Unavailable` when the PJRT plugin has no FFI extension.
pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
    checkIsValidCallback(Callback);

    const api = platform.pjrt_api;
    const ffi = api.ffi() orelse return error.Unavailable;

    const callback_name = @typeName(Callback);
    const target_name = "zml$" ++ callback_name;
    const handler = proxy(Callback);

    Callback.type_id = try ffi.registerTypeId(api, callback_name);
    try ffi.register(
        api,
        target_name,
        @tagName(platform.target),
        &handler,
        Callback.callback_config.traits,
    );
    log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
}
|
||||
|
||||
/// Builds the C-ABI FFI handler for `Callback`: a thin trampoline forwarding to `CallbackImpl`.
fn proxy(Callback: type) pjrt.ffi.Handler {
    const Trampoline = struct {
        pub fn cb(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
            return CallbackImpl(Callback, call_frame);
        }
    };
    return Trampoline.cb;
}
|
||||
|
||||
/// Generic FFI entry point shared by all callbacks, `proxy` wraps it into a C-ABI handler.
///
/// When `call_frame.registeringHook()` is true, returns `null` right away. Otherwise:
/// 1. fetches the user data bound to `Callback.type_id` in the execution context,
/// 2. wraps the FFI argument buffers into `zml.Buffer`, copying them to host pinned memory
///    when `callback_config.copy_inputs_to_host_pinned` is set and the target isn't the CPU,
/// 3. makes `user_ctx.results` point to the FFI result buffers,
/// 4. invokes `Callback.call`.
///
/// Any failure is logged and reported to the caller as an FFI error.
fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt.ffi.Error {
    if (call_frame.registeringHook()) return null;

    const opts = Callback.callback_config;

    const execution_context = call_frame.ctx;
    log.debug("Custom call {s} called !", .{@typeName(Callback)});
    const user_ctx_opaque = execution_context.getContext(Callback.type_id, call_frame.api) catch {
        log.err("{} user data was never given for current executable", .{Callback});
        return .create(call_frame.api, .failed_precondition, "failed to fetch user context for " ++ @typeName(Callback));
    };
    const user_ctx: *Callback = @ptrCast(@alignCast(user_ctx_opaque));
    // We actually have one more constraint here, we force the Callback to have a platform field,
    // and to correctly set it.
    // Is this good ? We could also simplify this by registering ourselves the `Platform` type id.
    const platform: Platform = user_ctx.platform;

    // Hook to get a cuda stream in the callback.
    if (@hasField(Callback, "stream") and platform.target != .cpu) {
        const stream = call_frame.api.stream(execution_context);
        user_ctx.stream = stream;
    }

    var callback_args: std.meta.ArgsTuple(@TypeOf(Callback.call)) = undefined;
    callback_args[0] = user_ctx;

    inline for (1..callback_args.len, call_frame.args.buffers()) |i, ffi_buffer| {
        const shape = shapeFromFfi(ffi_buffer);
        var zml_buffer: Buffer = if (platform.target == .cpu)
            .asViewOfHostBuffer(platform, .fromBytes(shape, ffi_buffer.data[0..shape.byteSize()]))
        else
            .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
            zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
                log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err });
                // The `defer` releasing the copies is only registered after this loop,
                // so the copies already made for the previous arguments are released here.
                inline for (1..i) |j| callback_args[j].deinit();
                return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
            };
            // Only the pointer is logged: the element type of the buffer is not known here,
            // so its content must not be read through a fixed type.
            log.debug("--> {f} {*}", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
        }
        callback_args[i] = zml_buffer;
    }

    // Release the host pinned copies of the inputs once the callback is done.
    defer {
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            inline for (1..callback_args.len) |i| callback_args[i].deinit();
        }
    }

    // Tell the callback where it should write its results.
    for (0..call_frame.results.len) |i| {
        const ffi_buffer = call_frame.results.buffers()[i];
        const ffi_buffer_shape = shapeFromFfi(ffi_buffer);

        if (platform.target == .cpu) {
            user_ctx.results[i] = Buffer.asViewOfHostBuffer(platform, HostBuffer.fromBytes(ffi_buffer_shape, ffi_buffer.data[0..ffi_buffer_shape.byteSize()]));
        } else {
            user_ctx.results[i] = Buffer.asViewOfDeviceBuffer(platform, ffi_buffer_shape, null, ffi_buffer.data);
        }
    }

    @call(.auto, Callback.call, callback_args) catch |err| {
        log.err("Callback {} failed with {}", .{ Callback, err });
        return .create(call_frame.api, .internal, "internal callback error");
    };

    return .ok;
}
|
||||
|
||||
/// Callbacks provided by the library itself.
/// They are used internally and are not meant to be used directly by users.
pub const internal_callbacks = [_]type{Print};
|
||||
|
||||
/// Registers every callback of `internal_callbacks` on the given platform.
/// Stops at, and returns, the first registration error.
pub fn registerInternalCallbacks(platform: Platform) !void {
    inline for (internal_callbacks) |InternalCallback| try register(InternalCallback, platform);
}
|
||||
|
||||
/// Allocates the user data needed by the callbacks provided by ZML,
/// and binds it to the given execution context.
///
/// The user data is allocated with `arena` and is not freed by this function.
pub fn bindInternalCallbacks(
    arena: std.mem.Allocator,
    platform: Platform,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
) (std.mem.Allocator.Error || pjrt.ApiError)!void {
    // There is currently no mechanism to detect which ZML callbacks an executable needs,
    // so the user data of every internal callback is always allocated.

    // Print
    const print_data = try arena.create(Print);
    print_data.* = try Print.init(platform);
    try addUserData(Print, platform.pjrt_api, ffi, execute_context, print_data);
}
|
||||
|
||||
/// Binds `user_data` to the execution context under the type id of `Callback`.
/// The pointer is stored as is: `user_data` is not copied by this function.
pub fn addUserData(
    Callback: type,
    api: *const pjrt.Api,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
    user_data: *Callback,
) pjrt.ApiError!void {
    const raw_type_id = Callback.type_id.type_id;
    try ffi.addUserData(api, execute_context, .{
        .type_id = raw_type_id,
        .user_data = @ptrCast(user_data),
    });
    log.debug("Bound {s}@{x} with type id {d} on {any}", .{
        @typeName(Callback),
        @intFromPtr(user_data),
        raw_type_id,
        execute_context,
    });
}
|
||||
|
||||
/// Callback logging the content of its input buffer.
pub const Print = struct {
    /// Unique id, set by the PJRT plugin during registration.
    pub var type_id: pjrt.ffi.TypeId = undefined;

    pub const callback_config: Config = .{
        // Print declares its output as an alias of input 0, and returns it unmodified.
        .output_operand_aliases = &.{0},
        // The data needs to be copied to the host first so that it can be printed.
        .copy_inputs_to_host_pinned = true,
        // Print is not flagged as command buffer compatible.
        // NOTE(review): it may be safe to capture Print in an execution graph, confirm before enabling.
        .traits = .{ .command_buffer_compatible = false },
    };

    platform: Platform,
    results: [1]Buffer = undefined,

    /// Creates the user data of the callback for the given platform.
    pub fn init(platform: Platform) !Print {
        return Print{ .platform = platform };
    }

    /// Logs the buffer description followed by its values.
    pub fn call(_: *Print, input: Buffer) !void {
        const host_values = input.asHostBuffer();
        std.log.defaultLog(.info, .zml, "Device buffer: {f}: {d:20.3}", .{ input, host_values });
    }
};
|
||||
|
||||
/// Converts the dimensions and dtype of an FFI buffer description into a `Shape`.
/// Panics on the invalid dtype, and on the dtypes that have no `DataType` equivalent here.
fn shapeFromFfi(ffi_buffer: *const pjrt.ffi.Buffer) Shape {
    const ffi_dtype = ffi_buffer.dtype;
    const dtype: DataType = switch (ffi_dtype) {
        .invalid => stdx.debug.panic("Invalid FFI dtype {any} used by {any}", .{ ffi_dtype, ffi_buffer }),
        .token, .f8e4m3, .f8e3m4 => stdx.debug.panic("Unsupported FFI dtype {any} used by {any}", .{ ffi_dtype, ffi_buffer }),
        // Every other FFI dtype maps to the `DataType` tag with the same name.
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return Shape.init(ffi_buffer.dims(), dtype);
}
|
||||
|
||||
/// Type of the tensor inputs expected by `call` for a given callback:
/// one `Tensor` per argument of `Callback.call`, not counting the first one.
fn TensorArgs(Callback: type) type {
    const call_params = @typeInfo(stdx.meta.FnArgs(Callback.call)).@"struct".fields;
    return [call_params.len - 1]Tensor;
}
|
||||
172
zml/context.zig
172
zml/context.zig
@ -7,16 +7,19 @@ const runfiles = @import("runfiles");
|
||||
const runtimes = @import("runtimes");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Target = @import("platform.zig").Target;
|
||||
const zml_platform = @import("platform.zig");
|
||||
|
||||
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
||||
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
||||
const zml = struct {
|
||||
pub const callback = @import("callback.zig");
|
||||
pub const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
pub const Platform = @import("platform.zig").Platform;
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const Shape = @import("shape.zig").Shape;
|
||||
pub const Target = @import("platform.zig").Target;
|
||||
};
|
||||
|
||||
const PjrtApiMap = std.EnumArray(zml.Target, ?*const pjrt.Api);
|
||||
const PlatformsMap = std.EnumArray(zml.Target, ?zml.Platform);
|
||||
const log = std.log.scoped(.@"zml/context");
|
||||
|
||||
test {
|
||||
@ -94,7 +97,7 @@ pub const Context = struct {
|
||||
return .{ .platforms = PlatformsMap.initFill(null) };
|
||||
}
|
||||
|
||||
fn platformToLibrary(comptime target: Target) []const u8 {
|
||||
fn platformToLibrary(comptime target: zml.Target) []const u8 {
|
||||
const ext = switch (builtin.os.tag) {
|
||||
.windows => ".dll",
|
||||
.macos, .ios, .watchos => ".dylib",
|
||||
@ -105,7 +108,7 @@ pub const Context = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn pjrtApi(target: Target) *const pjrt.Api {
|
||||
pub fn pjrtApi(target: zml.Target) *const pjrt.Api {
|
||||
return Context.apis.get(target).?;
|
||||
}
|
||||
|
||||
@ -119,12 +122,12 @@ pub const Context = struct {
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
|
||||
const prefered_targets = [_]zml.Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
|
||||
|
||||
/// Automatically selects the best Platform loaded in the current Context.
|
||||
///
|
||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||
pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform {
|
||||
pub fn autoPlatform(self: *Context, opts: zml.Platform.CreateOptions) zml.Platform {
|
||||
stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
|
||||
|
||||
return self.platformByPreferences(opts, &prefered_targets);
|
||||
@ -133,7 +136,7 @@ pub const Context = struct {
|
||||
/// Given a list of preferred targets to select the best Platform
|
||||
///
|
||||
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
|
||||
pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform {
|
||||
pub fn platformByPreferences(self: *Context, opts: zml.Platform.CreateOptions, prefered: []const zml.Target) zml.Platform {
|
||||
// Try prefered targets.
|
||||
for (prefered) |target| {
|
||||
if (apis.get(target) == null) continue;
|
||||
@ -150,7 +153,7 @@ pub const Context = struct {
|
||||
// CPU should only be use as fallback.
|
||||
if (target == .cpu) continue;
|
||||
if (entry.value.* == null) continue;
|
||||
if (std.mem.indexOfScalar(Target, prefered, target) != null) continue;
|
||||
if (std.mem.indexOfScalar(zml.Target, prefered, target) != null) continue;
|
||||
return self.platform(target, opts) catch |err| {
|
||||
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
|
||||
continue;
|
||||
@ -164,25 +167,25 @@ pub const Context = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform {
|
||||
pub fn platform(self: *Context, target: zml.Target, opts: zml.Platform.CreateOptions) !zml.Platform {
|
||||
if (self.platforms.get(target)) |p| {
|
||||
return p;
|
||||
}
|
||||
const api = Context.apis.get(target);
|
||||
if (api == null) return error.PlatformNotCompiled;
|
||||
const p = try Platform.init(target, api.?, opts);
|
||||
const p = try zml.Platform.init(target, api.?, opts);
|
||||
if (p.getDevices().len == 0) {
|
||||
log.err("No device found for platform {} !", .{target});
|
||||
return error.NoDevicesFound;
|
||||
}
|
||||
|
||||
try CustomCall.registerZmlCustomCalls(p);
|
||||
|
||||
self.platforms.set(target, p);
|
||||
try zml.callback.registerInternalCallbacks(p);
|
||||
|
||||
return p;
|
||||
}
|
||||
|
||||
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {
|
||||
pub fn printAvailablePlatforms(self: Context, selected: zml.Platform) void {
|
||||
// List available targets
|
||||
log.info("Available Platforms:", .{});
|
||||
const selected_prefix = "✅";
|
||||
@ -190,7 +193,7 @@ pub const Context = struct {
|
||||
const selected_postfix = "(AUTO-SELECTED)";
|
||||
const not_selected_postfix = "";
|
||||
|
||||
for (zml_platform.available_targets) |target| {
|
||||
for (zml.platform.available_targets) |target| {
|
||||
log.info(" {s} {s} {s}", .{
|
||||
if (target == selected.target) selected_prefix else not_selected_prefix,
|
||||
@tagName(target),
|
||||
@ -211,133 +214,4 @@ pub const Context = struct {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void;
|
||||
};
|
||||
|
||||
const CustomCall = struct {
    /// Registers the "zmlHostBufferCallback" custom call on the given platform.
    /// Only logs a warning when the PJRT plugin has no FFI extension.
    pub fn registerZmlCustomCalls(platform: Platform) !void {
        if (platform.pjrt_api.ffi()) |ffi| {
            try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
        } else {
            log.warn("Registering custom calls failed: No FFI Extension found in {s} PJRT Plugin.", .{@tagName(platform.target)});
        }
    }

    /// FFI handler calling the host function whose address is stored in the "callback" attribute,
    /// with the pointer stored in the "user_context" attribute and the input / output buffers
    /// viewed as `HostBuffer`.
    fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
        if (call_frame.registeringHook()) return null;

        // Both pointers are passed as u64 scalar attributes.
        const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
        std.debug.assert(callback_attr.dtype == .u64);
        const host_fn: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize));

        const user_ctx_attr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable;
        std.debug.assert(user_ctx_attr.dtype == .u64);
        const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_attr.get(usize));

        const inputs = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
        for (0..inputs.len) |i| {
            inputs[i] = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]);
        }

        const outputs = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
        for (0..outputs.len) |i| {
            outputs[i] = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]);
        }

        host_fn(user_ctx, inputs, outputs);
        return null;
    }
};
|
||||
|
||||
/// Converts the dimensions and dtype of an FFI buffer description into a `Shape`.
/// Panics on the invalid dtype and on the unsupported ones.
fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
    const dtype: DataType = switch (buffer_desc.dtype) {
        .invalid => @panic("invalid ffi"),
        .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
        // Only dtype named differently on the two sides.
        .pred => .bool,
        // Every other FFI dtype maps to the `DataType` tag with the same name.
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return Shape.init(buffer_desc.dims(), dtype);
}
|
||||
|
||||
/// Creates a HostBuffer from an FFI description of a buffer.
/// Normally the FFI describes device buffers, but we assume they are located in pinned memory,
/// and therefore that the data pointer is readable both from the host and from the device.
fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
    const shape = getShape(buffer_desc);
    const bytes = buffer_desc.data[0..shape.byteSize()];
    return HostBuffer.fromBytes(shape, bytes);
}
|
||||
|
||||
/// Thin bindings to a few functions of the CUDA runtime, resolved at runtime by `init`.
///
/// Until `init` succeeds, every function pointer holds a placeholder address and must not be called.
pub const cuda = struct {
    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);

    /// Direction of a copy, passed as the `kind` argument of the memcpy functions.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Opens "libcudart.so.12" and resolves the function pointers of this namespace.
    /// Logs an error and leaves the placeholders in place if the library can't be opened,
    /// panics if one of the symbols is missing.
    pub fn init() void {
        // The library handle is intentionally never closed: the function pointers resolved below
        // are stored in globals and are only valid as long as the library stays loaded.
        var cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };

        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
            @panic("cudaMemcpyAsync not found");
        };
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
            @panic("cudaMemcpy not found");
        };
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
            @panic("cudaStreamSynchronize not found");
        };
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
            @panic("cudaLaunchHostFunc not found");
        };
    }

    /// Copies `dst.len` bytes from the device pointer `src` to `dst`, panics on error.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
        check(err);
    }

    /// Copies `src` to the device pointer `dst`, panics on error.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
        check(err);
    }

    /// Enqueues on `stream` a copy of `src` to the device pointer `dst`, panics on error.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
        check(err);
    }

    /// Enqueues on `stream` a copy of `dst.len` bytes from the device pointer `src` to `dst`, panics on error.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
        check(err);
    }

    /// Panics with the given error code, unless it is 0.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};
|
||||
|
||||
49
zml/exe.zig
49
zml/exe.zig
@ -5,6 +5,7 @@ const stdx = @import("stdx");
|
||||
const aio = @import("aio.zig");
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const callback = @import("callback.zig");
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const meta = @import("meta.zig");
|
||||
const pjrt = @import("pjrtx.zig");
|
||||
@ -154,7 +155,7 @@ pub const BaseExe = struct {
|
||||
exe: *pjrt.LoadedExecutable,
|
||||
|
||||
/// The execution context for this executable.
|
||||
context: ?*pjrt.ExecuteContext = null,
|
||||
execute_context: ?*pjrt.ExecuteContext,
|
||||
|
||||
/// Pre-allocated slice of buffers to use as inputs when the module is called.
|
||||
input_per_device: []const [*]*pjrt.Buffer,
|
||||
@ -205,9 +206,18 @@ pub const BaseExe = struct {
|
||||
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
|
||||
@memcpy(all_shapes[0..n_in], args.input_shapes);
|
||||
@memcpy(all_shapes[n_in..], args.result_shapes);
|
||||
|
||||
var execute_context: ?*pjrt.ExecuteContext = null;
|
||||
if (platform.pjrt_api.ffi()) |ffi| {
|
||||
log.info("Created context execution {*} for {*}", .{ execute_context, exe });
|
||||
execute_context = try platform.pjrt_api.createExecuteContext();
|
||||
try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
|
||||
}
|
||||
|
||||
return .{
|
||||
.platform = platform,
|
||||
.exe = exe,
|
||||
.execute_context = execute_context,
|
||||
.ready_buffer_count = 0,
|
||||
.input_buffer_count = @intCast(n_in),
|
||||
.num_devices = args.n_devices,
|
||||
@ -220,7 +230,7 @@ pub const BaseExe = struct {
|
||||
}
|
||||
|
||||
pub fn deinit(self: BaseExe) void {
|
||||
if (self.context) |ctx| {
|
||||
if (self.execute_context) |ctx| {
|
||||
ctx.deinit(self.platform.pjrt_api);
|
||||
}
|
||||
self._arena.deinit();
|
||||
@ -244,16 +254,16 @@ pub const BaseExe = struct {
|
||||
// even if it has been marked as "can be donated" during compilation.
|
||||
// TODO: expose it ?
|
||||
.non_donatable_input_indices = &.{},
|
||||
.context = self.context,
|
||||
.context = self.execute_context,
|
||||
}) catch |err| {
|
||||
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
|
||||
};
|
||||
|
||||
for (events[0..sharding.num_partitions]) |e| {
|
||||
if (e) |ev| {
|
||||
ev.await_(self.platform.pjrt_api) catch unreachable;
|
||||
}
|
||||
}
|
||||
// for (events[0..sharding.num_partitions]) |e| {
|
||||
// if (e) |ev| {
|
||||
// ev.await_(self.platform.pjrt_api) catch unreachable;
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
|
||||
@ -285,6 +295,17 @@ pub const BaseExe = struct {
|
||||
stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index });
|
||||
}
|
||||
|
||||
/// Binds the user data `op` of the callback `Callback` to the execution context of this executable.
/// Asserts that the executable has an execution context,
/// and panics if the platform has no FFI extension.
pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
    stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});

    const api = exe.platform.pjrt_api;
    const ffi = api.ffi() orelse
        stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});

    try callback.addUserData(Callback, api, ffi, exe.execute_context.?, op);
}
|
||||
|
||||
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
||||
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
||||
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||
@ -314,11 +335,11 @@ pub const BaseExe = struct {
|
||||
|
||||
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
|
||||
var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
|
||||
.n_in = self.input_buffer_count,
|
||||
.input_shapes = self.input_shapes,
|
||||
.result_shapes = self.result_shapes,
|
||||
.n_devices = self.num_devices,
|
||||
});
|
||||
exe.context = self.context;
|
||||
exe.execute_context = self.execute_context;
|
||||
return exe;
|
||||
}
|
||||
};
|
||||
@ -348,6 +369,14 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
||||
return new;
|
||||
}
|
||||
|
||||
/// Provides a pointer to the runtime data of a given callback used inside this executable.
/// The caller keeps the ownership of the memory, and needs to ensure that the value
/// stays alive as long as the executable.
pub fn bind(self: Self, comptime T: type, value: *T) !void {
    return self.inner.bind(T, value);
}
|
||||
|
||||
pub fn serialize(self: Self, writer: anytype) !void {
|
||||
return try self.inner.serialize(writer);
|
||||
}
|
||||
|
||||
@ -325,37 +325,18 @@ pub const HostBuffer = struct {
|
||||
self: HostBuffer,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
// TODO debug option
|
||||
// try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
|
||||
try writer.print("HostBuffer(.{f})", .{self._shape});
|
||||
}
|
||||
|
||||
/// Formatter for a HostBuffer that also print the values not just the shape.
|
||||
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
|
||||
pub fn pretty(self: HostBuffer) PrettyPrinter {
|
||||
return .{ .x = self };
|
||||
pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||
return self.prettyPrintIndented(writer, 4, 0, n);
|
||||
}
|
||||
|
||||
pub const PrettyPrinter = struct {
|
||||
x: HostBuffer,
|
||||
|
||||
// TODO(0.15.0) revisit pretty printer
|
||||
pub fn format(self: PrettyPrinter, writer: anytype) !void {
|
||||
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
|
||||
.integer => .parse(i32, "d"),
|
||||
.float => .parse(f32, "d"),
|
||||
else => .parse(void, ""),
|
||||
};
|
||||
const options: std.fmt.FormatOptions = .{};
|
||||
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
|
||||
}
|
||||
};
|
||||
|
||||
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void {
|
||||
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: std.fmt.Number) !void {
|
||||
return self.prettyPrintIndented(writer, 4, 0, options);
|
||||
}
|
||||
|
||||
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
|
||||
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void {
|
||||
if (self.rank() == 0) {
|
||||
// Special case input tensor is a scalar
|
||||
return switch (self.dtype()) {
|
||||
@ -363,9 +344,10 @@ pub const HostBuffer = struct {
|
||||
const val: dt.toZigType() = self.items(dt.toZigType())[0];
|
||||
return switch (comptime dt.class()) {
|
||||
// Since we have custom floats, we need to explicitly convert to float32 ourselves.
|
||||
.float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
|
||||
.integer => stdx.fmt.formatIntValue(val, options, writer),
|
||||
.bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
|
||||
.float => stdx.fmt.formatFloat(floats.floatCast(f32, val), options, writer),
|
||||
.integer => stdx.fmt.formatInt(val, options, writer),
|
||||
.bool => stdx.fmt.formatBool(val, options, writer),
|
||||
.complex => stdx.fmt.formatComplex(val, options, writer),
|
||||
};
|
||||
},
|
||||
};
|
||||
@ -380,7 +362,8 @@ pub const HostBuffer = struct {
|
||||
switch (comptime dt.class()) {
|
||||
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
|
||||
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
|
||||
.bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
|
||||
.complex => try stdx.fmt.formatComplexSlice(values, options, writer),
|
||||
.bool => try stdx.fmt.formatBoolSlice(values, options, writer),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
@ -1178,9 +1178,9 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
|
||||
.@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
|
||||
.pointer => |info| switch (info.size) {
|
||||
.one => switch (strat) {
|
||||
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||
.deep => hash(hasher, key.*, .Shallow),
|
||||
.deeprecursive => switch (@typeInfo(info.child)) {
|
||||
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||
.Deep => hash(hasher, key.*, .Shallow),
|
||||
.DeepRecursive => switch (@typeInfo(info.child)) {
|
||||
.@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
|
||||
else => hash(hasher, key.*, .DeepRecursive),
|
||||
},
|
||||
@ -1196,7 +1196,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
|
||||
.many,
|
||||
.c,
|
||||
=> switch (strat) {
|
||||
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
|
||||
else => @compileError(
|
||||
\\ unknown-length pointers and C pointers cannot be hashed deeply.
|
||||
\\ Consider providing your own hash function.
|
||||
|
||||
27
zml/ops.zig
27
zml/ops.zig
@ -764,33 +764,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Options accepted by `addHostCallback`; both fields are forwarded unchanged
/// to the options of the underlying `customCall`.
pub const HostCallbackOpt = struct {
    /// Forwarded as the custom call's `has_side_effect` option. Defaults to `false`.
    has_side_effect: bool = false,
    /// Forwarded as the custom call's `output_operand_aliases` option. Defaults to an empty slice.
    // NOTE(review): presumably lists which operands each output aliases; confirm against `customCall`.
    output_operand_aliases: []const i64 = &[_]i64{},
};
|
||||
|
||||
/// Emits a custom call targeting "zmlHostBufferCallback".
///
/// The addresses of `callback` and `blkctx` are passed to `customCall` as the
/// integer fields `callback` and `user_context`; `opts` is forwarded field by field.
/// Returns whatever `customCall` returns for `inputs` and `output_shapes`.
pub fn addHostCallback(
    callback: *const Context.HostCallback,
    blkctx: ?*anyopaque,
    inputs: []const Tensor,
    output_shapes: []const Shape,
    opts: HostCallbackOpt,
) []Tensor {
    // Pointers travel through the custom call as plain integer addresses.
    const callback_addr = @intFromPtr(callback);
    const user_context_addr = @intFromPtr(blkctx);

    return customCall("zmlHostBufferCallback", inputs, output_shapes, .{
        .callback = callback_addr,
        .user_context = user_context_addr,
    }, .{
        .has_side_effect = opts.has_side_effect,
        .output_operand_aliases = opts.output_operand_aliases,
    });
}
|
||||
|
||||
pub const TritonOps = struct {
|
||||
debug: bool = false,
|
||||
name: [:0]const u8,
|
||||
|
||||
@ -207,6 +207,13 @@ pub const Event = opaque {
|
||||
return self.inner().getEventError(api);
|
||||
}
|
||||
|
||||
/// Waits for the event to complete.
///
/// When `isReady` already reports the event as ready, nothing else is done;
/// otherwise the call is delegated to the inner event's `await_`, and any
/// error it returns is propagated.
pub fn awaitBlocking(self: *Event, api: *const Api) ApiError!void {
    const already_done = self.isReady(api);
    if (!already_done) {
        try self.inner().await_(api);
    }
}
|
||||
|
||||
pub fn await_(self: *Event, api: *const Api) ApiError!void {
|
||||
defer self.deinit(api);
|
||||
|
||||
@ -264,14 +271,14 @@ pub const LoadedExecutable = opaque {
|
||||
};
|
||||
|
||||
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
|
||||
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{
|
||||
try self.inner().execute(api, pjrt.LoadedExecutable.ExecuteArgs{
|
||||
.num_args = args.num_args,
|
||||
.arguments = @ptrCast(args.arguments),
|
||||
.results = @ptrCast(args.results),
|
||||
.events = @ptrCast(args.events),
|
||||
.non_donatable_input_indices = args.non_donatable_input_indices,
|
||||
.context = args.context,
|
||||
} });
|
||||
});
|
||||
}
|
||||
|
||||
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {
|
||||
|
||||
@ -22,6 +22,11 @@ pub const Platform = struct {
|
||||
target: Target,
|
||||
pjrt_api: *const pjrt.Api,
|
||||
pjrt_client: *pjrt.Client,
|
||||
|
||||
// This makes the pjrt struct quite fat, but it is only used during compilation.
|
||||
// TODO: Reconsider having it here, and maybe pass explicitly to compile,
|
||||
// or create an intermediary struct:
|
||||
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
|
||||
compilation_options: CompilationOptions = .{},
|
||||
|
||||
pub const MAX_NUM_DEVICES: u8 = 32;
|
||||
@ -71,17 +76,6 @@ pub const Platform = struct {
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Registers `T` as an FFI type through the platform's PJRT FFI extension.
///
/// Panics if the PJRT API exposes no FFI extension for this target, or if
/// `T` has no `type_id` declaration. Errors from `registerTypeId` are returned.
pub fn registerFFIType(self: Platform, comptime T: type) !void {
    if (self.pjrt_api.ffi()) |ffi_api| {
        // The declaration check only happens once the FFI extension is known to exist.
        if (!@hasDecl(T, "type_id")) {
            stdx.debug.panic("registerFFIType requires type {s} to have a `type_id` i64 field ", .{@typeName(T)});
        }
        return ffi_api.registerTypeId(self.pjrt_api, T);
    }

    stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
}
|
||||
|
||||
/// Releases the platform's PJRT client by calling its `deinit` with the platform's PJRT API.
pub fn deinit(self: *Platform) void {
    const api = self.pjrt_api;
    self.pjrt_client.deinit(api);
}
|
||||
|
||||
@ -5,6 +5,7 @@ const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const callback = @import("callback.zig");
|
||||
const CompilationContext = @import("module.zig").CompilationContext;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
@ -3824,22 +3825,7 @@ pub const Tensor = struct {
|
||||
/// Only for debug purpose, it inserts device to host synchronization
|
||||
/// so it will slow down the program execution.
|
||||
pub fn print(input: Tensor) Tensor {
|
||||
// TODO: find a way of doing print that doesn't involve a H2D copy.
|
||||
return ops.addHostCallback(
|
||||
&printCallback,
|
||||
null,
|
||||
&.{input},
|
||||
&.{input.shape()},
|
||||
.{ .output_operand_aliases = &.{0} },
|
||||
)[0];
|
||||
}
|
||||
|
||||
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
||||
const host_buffer = inputs[0];
|
||||
std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||
// This is true because of the operand aliases.
|
||||
// Since the result is already pointing to the input we don't need to modify the buffer.
|
||||
std.debug.assert(host_buffer._data == outputs[0]._data);
|
||||
return callback.call(callback.Print, .{input}, &.{input.shape()})[0];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@ -51,14 +51,13 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
if (should_free_left) left.deinit(allocator);
|
||||
if (should_free_right) right.deinit(allocator);
|
||||
}
|
||||
errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
|
||||
|
||||
errdefer log.err("\n--> Left: {0f}{0d:24.3}\n--> Right: {1f}{1d:24.3}", .{ left, right });
|
||||
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
|
||||
log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
|
||||
return error.TestUnexpectedResult;
|
||||
}
|
||||
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
|
||||
log.err("left.dtype ({}) != right.dtype ({})", .{ left.dtype(), right.dtype() });
|
||||
log.err("left.dtype ({f}) != right.dtype ({f})", .{ left.shape(), right.shape() });
|
||||
return error.TestUnexpectedResult;
|
||||
}
|
||||
switch (left.dtype()) {
|
||||
@ -89,7 +88,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
|
||||
const right_data = right.items(R);
|
||||
for (left_data, right_data, 0..) |l, r, i| {
|
||||
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
|
||||
log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
|
||||
log.err("left.data != right_data.\n < {d:40.3} \n > {d:40.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ stdx.fmt.slice(center(left_data, i)), stdx.fmt.slice(center(right_data, i)), i, left_data[i], right_data[i] });
|
||||
return error.TestUnexpectedResult;
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,11 +6,13 @@
|
||||
// Namespaces
|
||||
const std = @import("std");
|
||||
|
||||
pub const platform_specific = @import("c");
|
||||
pub const tokenizer = @import("zml/tokenizer");
|
||||
|
||||
pub const aio = @import("aio.zig");
|
||||
pub const Buffer = @import("buffer.zig").Buffer;
|
||||
pub const Bufferized = @import("tensor.zig").Bufferized;
|
||||
pub const callback = @import("callback.zig");
|
||||
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
|
||||
pub const context = @import("context.zig");
|
||||
pub const Context = @import("context.zig").Context;
|
||||
@ -43,7 +45,6 @@ pub const Tensor = @import("tensor.zig").Tensor;
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
// pub const tokenizer = @import("tokenizer.zig");
|
||||
pub const tools = struct {
|
||||
pub const Tracer = @import("tools/tracer.zig").Tracer;
|
||||
};
|
||||
|
||||
Loading…
Reference in New Issue
Block a user