Add experimental zml.callback API (renamed from custom_call) and fix tensor.print(); update PJRT bindings, host buffer utilities, and related core ZML modules.

This commit is contained in:
Tarry Singh 2025-08-20 10:27:54 +00:00
parent 1fa056a790
commit cc969bd532
17 changed files with 596 additions and 467 deletions

View File

@ -2,12 +2,21 @@
const std = @import("std");
const c = @import("c");
pub const TypeId = c.XLA_FFI_TypeId;
const stdx = @import("stdx");
const pjrtStruct = @import("pjrt.zig").pjrtStruct;
const pjrt = @import("pjrt.zig");
const Stream = @import("pjrt.zig").Stream;
const log = std.log.scoped(.pjrt);
comptime {
if (@typeInfo(TypeId).@"struct".fields.len != 1) @compileError("TypeId has changed");
}
/// The signature of a generic custom call.
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
pub const ApiVersion = extern struct {
pub const major = c.XLA_FFI_API_MAJOR;
pub const minor = c.XLA_FFI_API_MINOR;
@ -29,13 +38,13 @@ pub const ExtensionBase = extern struct {
};
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
pub const HandlerTraits = packed struct(u32) {
pub const HandlerTraits = packed struct(c_uint) {
/// Calls to the FFI handler are safe to trace into the command buffer.
/// It means that calls to the FFI handler always launch exactly the same device operations (which can depend on attribute values)
/// that can be captured and then replayed.
command_buffer_compatible: u1,
command_buffer_compatible: bool,
__unassigned__: u31,
__unassigned__: u31 = 0,
};
pub const Metadata = extern struct {
@ -49,25 +58,6 @@ pub const MetadataExtension = extern struct {
metadata: ?*Metadata,
};
pub const ApiError = error{
Cancelled,
Unknown,
InvalidArgument,
DeadlineExceeded,
NotFound,
AlreadyExists,
PermissionDenied,
ResourceExhausted,
FailedPrecondition,
Aborted,
OutOfRange,
Unimplemented,
Internal,
Unavailable,
DataLoss,
Unauthenticated,
};
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
return struct {
pub fn to(self: anytype) switch (@TypeOf(self)) {
@ -91,8 +81,8 @@ fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
pub const Api = opaque {
pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;
pub fn stream(self: *const Api, context: *const ExecutionContext) *Stream {
var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
pub fn stream(self: *const Api, context: *const ExecutionContext) *pjrt.Stream {
var ret = pjrt.pjrtStruct(c.XLA_FFI_Stream_Get_Args{
.ctx = @constCast(context.inner()),
});
const result = self.inner().XLA_FFI_Stream_Get.?(&ret);
@ -107,8 +97,8 @@ pub const Api = opaque {
return @ptrCast(ret.stream.?);
}
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
pub fn allocateDeviceMemory(self: *const Api, context: *const ExecutionContext, size: usize, alignment: usize) pjrt.ApiError!*anyopaque {
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
.ctx = @constCast(context.inner()),
.size = size,
.alignment = alignment,
@ -127,8 +117,8 @@ pub const Api = opaque {
return ret.data.?;
}
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
pub fn freeDeviceMemory(self: *const Api, context: *const ExecutionContext, data: *anyopaque, size: usize) pjrt.ApiError!void {
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
.ctx = @constCast(context.inner()),
.size = size,
.data = data,
@ -163,17 +153,17 @@ pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
pub const ExecutionContext = opaque {
pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;
pub fn Context(comptime T: type) type {
return struct {
pub fn get(self: *const ExecutionContext, api: *const Api) ApiError!*T {
const type_id: TypeId = .{ .type_id = T.type_id };
var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
.ctx = @constCast(self.inner()),
.type_id = @constCast(&type_id.toCStruct()),
});
const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
pub fn getContext(self: *const ExecutionContext, type_id: TypeId, api: *const Api) pjrt.ApiError!*anyopaque {
var ret: c.XLA_FFI_ExecutionContext_Get_Args = .{
.struct_size = pjrt.pjrtStructSize(c.XLA_FFI_ExecutionContext_Get_Args),
.extension_start = api.inner().extension_start,
.ctx = @ptrCast(@constCast(self)),
.type_id = @constCast(&type_id),
.data = undefined, // set by XLA_FFI_ExecutionContext_Get.
};
const maybe_err = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
if (result) |ffi_error| {
if (maybe_err) |ffi_error| {
const err = Error.fromInner(ffi_error);
defer err.destroy(api);
log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
@ -183,13 +173,11 @@ pub const ExecutionContext = opaque {
}
if (ret.data == null) return error.NotFound;
return @ptrCast(@alignCast(ret.data.?));
}
};
return ret.data.?;
}
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) ApiError!i32 {
var ret = pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
pub fn getDeviceOrdinal(self: *const ExecutionContext, api: *const Api) pjrt.ApiError!i32 {
var ret = pjrt.pjrtStruct(c.XLA_FFI_DeviceOrdinal_Get_Args{
.ctx = @constCast(self.inner()),
});
const result = api.inner().XLA_FFI_DeviceOrdinal_Get.?(&ret);
@ -206,8 +194,10 @@ pub const ExecutionContext = opaque {
return ret.device_ordinal;
}
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
const Task = fn (*anyopaque) void;
pub fn scheduleTask(self: *const ExecutionContext, api: *const Api, task: *const Task, data: *anyopaque) pjrt.ApiError!void {
var ret = pjrt.pjrtStruct(c.XLA_FFI_ThreadPool_Schedule_Args{
.ctx = @constCast(self.inner()),
.task = @ptrCast(@alignCast(task)),
.data = @ptrCast(@alignCast(data)),
@ -225,23 +215,9 @@ pub const ExecutionContext = opaque {
return error.Unknown;
}
}
fn getTypeId(type_name: []const u8) TypeId {
const id: i64 = @bitCast(std.hash.Fnv1a_64.hash(type_name));
return .{
.type_id = id,
};
}
};
const TypeId = c.XLA_FFI_TypeId;
const Task = fn (*anyopaque) void;
const Stream = @import("pjrt.zig").Stream;
const ByteSpan = extern struct {
pub const ByteSpan = extern struct {
ptr: [*]const u8,
len: usize,
@ -252,7 +228,7 @@ const ByteSpan = extern struct {
pub const DataType = enum(c.XLA_FFI_DataType) {
invalid = c.XLA_FFI_DataType_INVALID,
pred = c.XLA_FFI_DataType_PRED,
bool = c.XLA_FFI_DataType_PRED,
i8 = c.XLA_FFI_DataType_S8,
i16 = c.XLA_FFI_DataType_S16,
i32 = c.XLA_FFI_DataType_S32,
@ -399,6 +375,8 @@ pub const Attrs = extern struct {
}
};
/// All information needed by the user callback,
/// including the list of input/output buffers to work on.
pub const CallFrame = extern struct {
struct_size: usize,
extension_start: ?*ExtensionBase,
@ -422,9 +400,11 @@ pub const CallFrame = extern struct {
}
return false;
}
};
pub const Handler = fn (*CallFrame) callconv(.c) ?*Error;
pub fn stream(call_frame: CallFrame) ?*const pjrt.Stream {
return call_frame.api.stream(call_frame.ctx);
}
};
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
cancelled = c.XLA_FFI_Error_Code_CANCELLED,
@ -444,7 +424,7 @@ pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,
pub fn toApiError(code: ErrorCode) ApiError {
pub fn toApiError(code: ErrorCode) pjrt.ApiError {
return switch (code) {
.cancelled => error.Cancelled,
.unknown => error.Unknown,
@ -470,8 +450,10 @@ pub const Error = opaque {
pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;
pub const ok: ?*Error = null;
pub fn create(api: *const Api, error_code: ErrorCode, message: []const u8) *Error {
var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Create_Args{
.message = message.ptr,
.errc = @intFromEnum(error_code),
});
@ -479,12 +461,12 @@ pub const Error = opaque {
}
pub fn destroy(err: *Error, api: *const Api) void {
var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
api.inner().XLA_FFI_Error_Destroy.?(&ret);
}
pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
var ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
var ret = pjrt.pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
.@"error" = err.inner(),
});
api.inner().XLA_FFI_Error_GetMessage.?(&ret);
@ -496,8 +478,8 @@ pub const Future = opaque {
pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;
pub fn create(api: *const Api) ApiError!*Future {
var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{});
pub fn create(api: *const Api) pjrt.ApiError!*Future {
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_Create_Args{});
const result = api.inner().XLA_FFI_Future_Create.?(&ret);
if (result) |ffi_error| {
@ -512,8 +494,8 @@ pub const Future = opaque {
return fromInner(ret.future.?);
}
pub fn setAvailable(self: *Future, api: *const Api) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
pub fn setAvailable(self: *Future, api: *const Api) pjrt.ApiError!void {
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
.future = self.inner(),
});
@ -529,8 +511,8 @@ pub const Future = opaque {
}
}
pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void {
var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{
pub fn setError(self: *Future, api: *const Api, err: *Error) pjrt.ApiError!void {
var ret = pjrt.pjrtStruct(c.XLA_FFI_Future_SetError_Args{
.future = self.inner(),
.@"error" = err.inner(),
});

View File

@ -20,7 +20,7 @@ test {
// as the way PJRT does it is not very robust.
//
// 1. https://github.com/openxla/xla/issues/10032
fn pjrtStructSize(comptime T: type) usize {
pub fn pjrtStructSize(comptime T: type) usize {
// unsafe on purpose, we want this to fail if that ever changes
const typedef_name = comptime blk: {
const needle = ".struct_";
@ -164,7 +164,7 @@ pub const Api = struct {
return @ptrCast(ret.context.?);
}
pub fn ffi(api: *const Api) ?FFI {
pub fn ffi(api: *const Api) ?Ffi {
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
return .{ .inner = ext };
}
@ -1278,7 +1278,7 @@ pub const NamedValue = extern struct {
}
};
pub const FFI = extern struct {
pub const Ffi = extern struct {
inner: *const c.PJRT_FFI,
pub const UserData = extern struct {
@ -1293,19 +1293,15 @@ pub const FFI = extern struct {
}
};
pub const RegisterFfiOptions = struct {
traits: RegisterHandlerTraits = @enumFromInt(0),
};
// TODO: support all missing handlers available in the GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub fn register(
self: *const FFI,
self: *const Ffi,
api: *const Api,
target_name: []const u8,
platform_name: []const u8,
func: *const ffi.Handler,
options: RegisterFfiOptions,
traits: ffi.HandlerTraits,
) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
.target_name = target_name.ptr,
@ -1313,7 +1309,7 @@ pub const FFI = extern struct {
.handler = @ptrCast(@constCast(func)),
.platform_name = platform_name.ptr,
.platform_name_size = platform_name.len,
.traits = @intFromEnum(options.traits),
.traits = @bitCast(traits),
});
const result = self.inner.register_handler.?(&ret);
if (result) |pjrt_c_error| {
@ -1323,8 +1319,7 @@ pub const FFI = extern struct {
}
}
pub fn registerTypeId(self: *const FFI, api: *const Api, T: type) ApiError!void {
const type_name = @typeName(T);
pub fn registerTypeId(self: *const Ffi, api: *const Api, type_name: []const u8) ApiError!ffi.TypeId {
var ret = pjrtStruct(c.PJRT_FFI_TypeID_Register_Args{
.type_name = type_name.ptr,
.type_name_size = type_name.len,
@ -1336,10 +1331,10 @@ pub const FFI = extern struct {
return pjrt_error.getCode(api).toApiError();
}
T.type_id = ret.type_id;
return .{ .type_id = ret.type_id };
}
pub fn addUserData(self: *const FFI, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
pub fn addUserData(self: *const Ffi, api: *const Api, context: *ExecuteContext, user_data: UserData) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_UserData_Add_Args{
.context = @ptrCast(context),
.user_data = user_data.toCStruct(),
@ -1352,12 +1347,3 @@ pub const FFI = extern struct {
}
}
};
pub const RegisterHandlerTraits = enum(c.PJRT_FFI_Handler_TraitsBits) {
command_buffer_compatible = c.PJRT_FFI_HANDLER_TRAITS_COMMAND_BUFFER_COMPATIBLE,
_,
};
pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_FFI_Register_Handler,
};

View File

@ -5,7 +5,7 @@ zig_shared_library(
name = "zmlxcuda",
# Use Clang's compiler-rt, but disable stack checking
# to avoid requiring on the _zig_probe_stack symbol.
copts = ["-fno-stack-check"],
copts = ["-fno-stack-check", "-fllvm"],
main = "zmlxcuda.zig",
shared_lib_name = "libzmlxcuda.so.0",
visibility = ["@libpjrt_cuda//:__subpackages__"],

View File

@ -1,145 +1,117 @@
const std = @import("std");
pub const Fmt = union(enum) {
int: IntFmt,
float: FloatFmt,
generic: void,
pub fn slice(any_slice: anytype) FmtSlice(std.meta.Elem(@TypeOf(any_slice))) {
return .{ .slice = any_slice };
}
pub fn parse(T: type, comptime fmt_: []const u8) Fmt {
fn FmtSlice(T: type) type {
return struct {
slice: []const T,
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
return switch (@typeInfo(T)) {
.float, .comptime_float => .{ .float = FloatFmt.parseComptime(fmt_) },
.int, .comptime_int => .{ .int = IntFmt.parseComptime(fmt_) },
else => .{ .generic = {} },
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
.comptime_int, .int => try formatIntSlice(f.slice, n, writer),
.bool => try formatBoolSlice(f.slice, n, writer),
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
try formatComplexSlice(f.slice, n, writer);
} else if (@hasDecl(T, "toF32")) {
try formatFloatSlice(f.slice, n, writer);
} else {
try formatSliceAny(f.slice, n, writer);
},
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
};
}
};
pub const FullFormatOptions = struct {
fmt: Fmt,
options: std.fmt.FormatOptions,
};
pub const IntFmt = struct {
base: u8,
case: std.fmt.Case = .lower,
pub fn parseComptime(comptime fmt_: []const u8) IntFmt {
return parse(fmt_) catch @panic("invalid fmt for int: " ++ fmt_);
}
pub fn parse(fmt_: []const u8) error{InvalidArgument}!IntFmt {
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "d"))
.{ .base = 10, .case = .lower }
else if (std.mem.eql(u8, fmt_, "x"))
.{ .base = 16, .case = .lower }
else if (std.mem.eql(u8, fmt_, "X"))
.{ .base = 16, .case = .upper }
else if (std.mem.eql(u8, fmt_, "o"))
.{ .base = 8, .case = .upper }
else
// TODO: unicode/ascii
error.InvalidArgument;
}
};
pub const FloatFmt = enum(u8) {
scientific = @intFromEnum(std.fmt.Number.Mode.scientific),
decimal = @intFromEnum(std.fmt.Number.Mode.decimal),
hex,
pub fn parseComptime(comptime fmt_: []const u8) FloatFmt {
return parse(fmt_) catch @panic("invalid fmt for float: " ++ fmt_);
}
pub fn parse(fmt_: []const u8) error{InvalidArgument}!FloatFmt {
return if (fmt_.len == 0 or std.mem.eql(u8, fmt_, "e"))
.scientific
else if (std.mem.eql(u8, fmt_, "d"))
.decimal
else if (std.mem.eql(u8, fmt_, "x"))
.hex
else
error.InvalidArgument;
}
};
pub fn formatValue(value: anytype, full: FullFormatOptions, writer: anytype) !void {
return switch (@typeInfo(@TypeOf(value))) {
.comptime_float, .float => try formatFloatValue(value, full, writer),
.comptime_int, .int => try formatIntValue(value, full, writer),
else => try formatAnyValue(value, full, writer),
};
}
pub fn formatFloatValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
pub fn formatFloat(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
const x = switch (@typeInfo(@TypeOf(value))) {
.@"struct" => value.toF32(),
.float => value,
else => @compileError("formatFloatValue expects a float, got: " ++ @typeName(@TypeOf(value))),
};
try switch (full.fmt.float) {
.scientific => writer.printFloat(x, .{ .mode = .scientific, .precision = full.options.precision }),
.decimal => writer.printFloat(x, .{ .mode = .decimal, .precision = full.options.precision }),
.hex => writer.printFloatHexOptions(x, .{ .mode = .hex }),
else => @compileError("formatFloat expects a float, got: " ++ @typeName(@TypeOf(value))),
};
return writer.printFloat(x, spec);
}
pub fn formatIntValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
pub fn formatInt(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
switch (@typeInfo(@TypeOf(value))) {
.int => {},
else => @compileError("formatIntValue expects an int, got: " ++ @typeName(@TypeOf(value))),
else => @compileError("formatInt expects an int, got: " ++ @typeName(@TypeOf(value))),
}
return writer.printInt(value, full.fmt.int.base, full.fmt.int.case, full.options);
return writer.printInt(value, spec.mode.base().?, spec.case, .{ .alignment = spec.alignment, .fill = spec.fill });
}
pub fn formatAnyValue(value: anytype, full: FullFormatOptions, writer: *std.Io.Writer) !void {
pub fn formatComplex(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
try writer.writeAll(".{.re=");
try writer.printFloat(value.re, spec);
try writer.writeAll(", .im=");
try writer.printFloat(value.im, spec);
try writer.writeAll("}");
}
pub fn formatBool(value: bool, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
try writer.alignBufferOptions(if (value) "1" else "0", .{ .alignment = spec.alignment, .fill = spec.fill });
}
pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
var buf: [48]u8 = undefined;
const s = std.fmt.bufPrint(&buf, "{any}", .{value}) catch blk: {
const T = @TypeOf(value);
const fmt = if (@hasDecl(T, "formatNumber")) "{d}" else "{f}";
const s = std.fmt.bufPrint(&buf, fmt, .{value}) catch blk: {
buf[45..].* = "...".*;
break :blk buf[0..];
};
return try writer.alignBufferOptions(s, full.options);
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
}
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, full: FullFormatOptions, writer: anytype) !void {
// Write first rows
const num_cols: usize = full.options.width orelse 12;
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
// use the format "width" for the number of columns instead of individual width.
const num_cols: usize = spec.width orelse 12;
var my_options = spec;
my_options.width = null;
const n: usize = values.len;
_ = try writer.write("{");
if (n <= num_cols) {
for (values, 0..) |v, i| {
// Force inlining so that the switch and the buffer can be done once.
try @call(.always_inline, fmt_func, .{ v, full, writer });
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
if (i < n - 1) _ = try writer.write(",");
}
} else {
const half = @divFloor(num_cols, 2);
for (values[0..half]) |v| {
try @call(.always_inline, fmt_func, .{ v, full, writer });
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
_ = try writer.write(",");
}
_ = try writer.write(" ..., ");
for (values[n - half ..], 0..) |v, i| {
try @call(.always_inline, fmt_func, .{ v, full, writer });
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
if (i < half - 1) _ = try writer.write(",");
}
}
_ = try writer.write("}");
}
pub fn formatAny(values: anytype, full: FullFormatOptions, writer: anytype) !void {
return try formatSliceCustom(formatAnyValue, values, full, writer);
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatAny, values, spec, writer);
}
pub fn formatFloatSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
return try formatSliceCustom(formatFloatValue, values, full, writer);
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatFloat, values, spec, writer);
}
pub fn formatIntSlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
return try formatSliceCustom(formatIntValue, values, full, writer);
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatInt, values, spec, writer);
}
pub fn formatAnySlice(values: anytype, full: FullFormatOptions, writer: anytype) !void {
return try formatSliceCustom(formatAnyValue, values, full, writer);
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatComplex, values, spec, writer);
}
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
return try formatSliceCustom(formatBool, values, spec, writer);
}

View File

@ -31,6 +31,7 @@ zig_library(
"aio/torch/py.zig",
"buffer.zig",
"context.zig",
"callback.zig",
"dtype.zig",
"exe.zig",
"floats.zig",
@ -53,7 +54,7 @@ zig_library(
"torch.zig",
"zml.zig",
],
copts = ["-lc"],
copts = ["-lc", "-freference-trace=20"],
main = "zml.zig",
visibility = ["//visibility:public"],
deps = [

View File

@ -49,7 +49,7 @@ pub const Buffer = struct {
pub const FromOptions = struct {
wait: bool = true,
memory: ?pjrt.Memory.Kind = null,
memory: ?Memory = null,
};
/// Copies the content of the given buffer from host memory to the accelerator memory.
@ -89,15 +89,20 @@ pub const Buffer = struct {
.byte_strides = byte_strides,
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
};
if (opts.memory) |memory_kind| {
const memories = try devices[i].addressableMemories(platform.pjrt_api);
const memory = for (memories) |m| {
const kind = m.kind(platform.pjrt_api);
if (kind == memory_kind) break m;
} else return error.NotFound;
args.memory = memory;
} else {
if (platform.target == .cpu or opts.memory == null) {
args.device = devices[i];
} else {
const memory = opts.memory.?;
const device_memories = try devices[i].addressableMemories(platform.pjrt_api);
// TODO: measure the cost of this and consider caching on the Zig side inside the platform.
const selected_memory = for (device_memories) |m| {
const kind = m.kind(platform.pjrt_api);
if (kind == memory.toPjrtMemory()) break m;
} else {
log.warn("Platform {s} doesn't have memory {s}", .{ @tagName(platform.target), @tagName(memory) });
return error.NotFound;
};
args.memory = selected_memory;
}
const pjrt_buffer, const event = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, args);
@ -179,10 +184,10 @@ pub const Buffer = struct {
return try from(platform, host_buffer, opts);
}
pub fn asPinnedHostBuffer(self: Buffer) HostBuffer {
// TODO restore assert
pub fn asHostBuffer(self: Buffer) HostBuffer {
// TODO: skip this check on cpu
// const memory = self.getMemory().kind(self._api);
// stdx.debug.assert(memory == .pinned_host, "asPinnedHostBuffer({}) expects a buffer allocated on host memory, got {}. see `toMemory`", .{ self, memory });
// stdx.debug.assert((memory == .pinned_host) or (memory == .unpinned_host), "asHostBuffer({f}) expects a buffer allocated on host memory, got {t}. see `copyToMemory`", .{ self, memory });
const ptr: [*]u8 = @ptrCast(self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable);
return HostBuffer.fromBytes(self._shape, ptr[0..self._shape.byteSize()]);
}
@ -299,6 +304,12 @@ pub const Buffer = struct {
};
}
pub fn opaqueDeviceMemoryDataPointer(self: Buffer) [*]u8 {
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
return @ptrCast(opaque_ptr);
}
/// Fetches the content of the given buffer into a stack variable of the given type.
pub fn getValue(self: Buffer, T: type) !T {
stdx.debug.assert(self._shape.byteSize() == @sizeOf(T), "Buffer {f} has {d} bytes of data, can't load it to a {s} with {d} bytes", .{ self, self._shape.byteSize(), @typeName(T), @sizeOf(T) });
@ -390,13 +401,31 @@ pub const Buffer = struct {
return @reduce(.Or, self._shape._sharding_info);
}
pub fn copyToMemory(self: Buffer, memory: *const pjrt.Memory) !Buffer {
pub const CopyToMemoryOpts = struct {
wait: bool = true,
};
pub fn copyToMemory(self: Buffer, platform: Platform, memory: Memory, opts: CopyToMemoryOpts) !Buffer {
const pjrt_memory = platform.pjrt_client.memoryByKind(self._api, memory.toPjrtMemory());
if (pjrt_memory == null) {
stdx.debug.panic("Memory destination `{s}` for {f}", .{ memory.pjrtName(), self });
}
var new_shards: Buffer.Shards = .{};
for (self._shards.slice()) |shard| {
const new_shard = try shard.copyToMemory(self._api, memory);
const new_shard = try shard.copyToMemory(self._api, pjrt_memory.?);
new_shards.appendAssumeCapacity(new_shard);
}
if (opts.wait) {
for (new_shards.constSlice()) |shard| {
const event = shard.getReadyEvent(self._api);
if (event) |e| {
try e.awaitBlocking(self._api);
}
}
}
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
}

313
zml/callback.zig Normal file
View File

@ -0,0 +1,313 @@
const std = @import("std");
const asynk = @import("async");
const mlir = @import("mlir");
const pjrt = @import("pjrt");
const stablehlo = @import("mlir/dialects").stablehlo;
const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer;
const CompilationContext = @import("module.zig").CompilationContext;
const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const mlirx = @import("mlirx.zig");
const pjrtx = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const log = std.log.scoped(.@"zml/callback");
/// Inserts a user-defined callback into the computation graph.
/// The callback is defined with a struct that stores runtime information needed by the callback.
///
/// Experimental API
///
/// ```zig
/// pub const MyCallback = struct {
/// // a unique type_id will be set by the PJRT plugin during registration.
/// pub var type_id: pjrt.ffi.TypeId = undefined;
///
/// pub const callback_config: zml.callback.Config = .{
/// // assumptions this custom call makes about the input/output buffers
/// };
///
/// // Required, this will tell the callback in which env it runs.
/// platform: zml.Platform,
/// // data needed by the callback
/// my_data: []const u8,
///
/// // storage modified by the runtime to tell the callback where it should write its results.
/// // Normally the callback doesn't need to allocate as the input and output buffers are given.
/// results: [1]Buffer = undefined,
///
/// pub fn init(my_data: []const u8) !MyCallback {
/// return .{ .my_data = my_data };
/// }
///
/// pub fn call(callback: *MyCallback, input: Buffer) !void {
/// // Do something with `input` and `callback.my_data`, write the results inside `callback.results[0]`
/// }
/// };
/// ```
///
/// See eg the implementation of the `zml.callback.Print` callback, for a practical example.
///
/// Note: calling this during the compilation of a module isn't enough:
///
/// * the backend needs to be made aware of the callback, see `zml.Platform.registerCallback`
/// * the executable needs to know the specific data needed by `MyCallback`, see `zml.Exe.bind`
pub fn call(
comptime Callback: type,
inputs: TensorArgs(Callback),
output_shapes: []const Shape,
) []Tensor {
checkIsValidCallback(Callback);
const ctx = CompilationContext.current();
const allocator = ctx.allocator();
const mlir_ctx = ctx.mlirCtx();
const platform = ctx._platform;
const pjrt_api = platform.pjrt_api;
if (pjrt_api.ffi() == null) {
stdx.debug.panic("Custom calls are not supported for target {s}", .{@tagName(platform.target)});
}
const output_tensors = allocator.alloc(Tensor, output_shapes.len) catch @panic("OOM");
// Note: we don't always free output_tensors, because it's returned to the caller.
// It's also why we allocate it first so that it doesn't fragment the arena.
errdefer allocator.free(output_tensors);
const output_types = allocator.alloc(mlir.Type, output_shapes.len) catch @panic("OOM");
defer allocator.free(output_types);
for (output_types, output_shapes) |*output_type, output_shape| {
output_type.* = mlirx.tensorType(mlir_ctx, output_shape);
}
const input_values = allocator.alloc(mlir.Value, inputs.len) catch @panic("OOM");
defer allocator.free(input_values);
for (input_values, inputs) |*input_value, input_tensor| {
input_value.* = input_tensor.value();
}
const target_name = "zml$" ++ @typeName(Callback);
const op = stablehlo.custom_call(
mlir_ctx,
input_values,
.{
.call_target_name = target_name,
.api_version = .typed_ffi,
.backend_config = .dict(mlir_ctx, &.{}),
.additional_attributes = &.{.{ "mhlo.frontend_attributes", .dict(mlir_ctx, &.{}) }},
.has_side_effect = true,
.output_operand_aliases = Callback.callback_config.output_operand_aliases,
},
output_types,
mlir_ctx.location(@src()),
);
for (output_tensors, output_shapes, 0..) |*output_tensor, output_shape, i| {
output_tensor.* = Tensor._result(output_shape, op.result(i));
}
return output_tensors;
}
/// Describe properties of a callback
///
/// * output_operand_aliases: the callback reuses input buffers to write the outputs
/// * copy_inputs_to_host_pinned: the callback needs to work on host-visible buffers
/// * traits: PJRT-specified properties of the callback
pub const Config = struct {
output_operand_aliases: []const i64 = &.{},
copy_inputs_to_host_pinned: bool = false,
// TODO: document precisely what `command_buffer_compatible` is doing and its limitations.
traits: pjrt.ffi.HandlerTraits = .{ .command_buffer_compatible = false },
// TODO: handle sharded inputs
};
/// Compile-time check that a callback has all the information we require.
pub fn checkIsValidCallback(Callback: type) void {
stdx.debug.assertComptime(@hasDecl(Callback, "call"), "Expected callback {} to have a call method", .{Callback});
const ArgsT = stdx.meta.FnArgs(Callback.call);
inline for (@typeInfo(ArgsT).@"struct".fields[1..]) |field| {
stdx.debug.assertComptime(field.type == Buffer, "Expected callback {}.call arguments to be of type zml.Buffer, got {}", .{ Callback, field.type });
}
stdx.debug.assertComptime(@hasDecl(Callback, "type_id") and @TypeOf(Callback.type_id) == pjrt.ffi.TypeId, "Expected callback {} to have a field `pub var type_id: pjrt.ffi.TypeId`", .{Callback});
stdx.debug.assertComptime(@hasDecl(Callback, "callback_config") and @TypeOf(Callback.callback_config) == Config, "Expected callback {} to have a field `pub const callback_config: zml.CustomCallOptions`", .{Callback});
}
pub fn register(Callback: type, platform: Platform) pjrt.ApiError!void {
checkIsValidCallback(Callback);
const ffi = platform.pjrt_api.ffi() orelse return error.Unavailable;
const target_name = "zml$" ++ @typeName(Callback);
const proxy_cb = proxy(Callback);
Callback.type_id = try ffi.registerTypeId(platform.pjrt_api, @typeName(Callback));
try ffi.register(platform.pjrt_api, target_name, @tagName(platform.target), &proxy_cb, Callback.callback_config.traits);
log.debug("Registered custom call {} with target name \"{s}\"", .{ Callback, target_name });
}
fn proxy(Callback: type) pjrt.ffi.Handler {
return struct {
pub fn cb(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
return CallbackImpl(Callback, call_frame);
}
}.cb;
}
/// Runtime implementation shared by all callback proxies.
/// Fetches the user context registered for `Callback` on the execution context,
/// wraps the FFI argument/result buffers into zml Buffers, and invokes `Callback.call`.
/// NOTE(review): despite the TitleCase name this returns a value, not a type.
fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt.ffi.Error {
    // During registration XLA invokes the handler once as a metadata probe; nothing to run.
    if (call_frame.registeringHook()) return null;
    const opts = Callback.callback_config;
    const execution_context = call_frame.ctx;
    log.debug("Custom call {s} called !", .{@typeName(Callback)});
    // The user data must have been attached to the execute context beforehand (see addUserData).
    const user_ctx_opaque = execution_context.getContext(Callback.type_id, call_frame.api) catch {
        log.err("{} user data was never given for current executable", .{Callback});
        return .create(call_frame.api, .failed_precondition, "failed to fetch user context" ++ @typeName(Callback));
    };
    const user_ctx: *Callback = @ptrCast(@alignCast(user_ctx_opaque));
    // We actually have one more constraint here, we force the Callback to have a platform field,
    // and to correctly set it.
    // Is this good ? We could also simplify this by registering ourselves the `Platform` type id.
    const platform: Platform = user_ctx.platform;
    // Hook to get a cuda stream in the callback.
    if (@hasField(Callback, "stream") and platform.target != .cpu) {
        const stream = call_frame.api.stream(execution_context);
        user_ctx.stream = stream;
    }
    var callback_args: std.meta.ArgsTuple(@TypeOf(Callback.call)) = undefined;
    callback_args[0] = user_ctx;
    // Wrap each FFI input buffer into a zml.Buffer view (no copy), one per `call` argument.
    inline for (1..callback_args.len, call_frame.args.buffers()) |i, ffi_buffer| {
        const shape = shapeFromFfi(ffi_buffer);
        var zml_buffer: Buffer = if (platform.target == .cpu)
            .asViewOfHostBuffer(platform, .fromBytes(shape, ffi_buffer.data[0..shape.byteSize()]))
        else
            .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
        // Optionally stage inputs in host-pinned memory so the callback can read them from the host.
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            log.debug("Copying argument {d} {f} {*} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer() });
            zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
                log.err("Failed to copy input buffer {d} {f} {*} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), err });
                return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
            };
            log.debug("--> {f} {*} ({})", .{ zml_buffer, zml_buffer.opaqueDeviceMemoryDataPointer(), @as(*const f32, @ptrCast(@alignCast(zml_buffer.opaqueDeviceMemoryDataPointer()))).* });
        }
        callback_args[i] = zml_buffer;
    }
    // Release the host-pinned copies once `call` has returned (defer runs at scope exit,
    // after the @call below).
    defer {
        if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
            inline for (1..callback_args.len) |i| callback_args[i].deinit();
        }
    }
    // Expose result buffers through the callback context.
    // Assumes `Callback.results` holds at least `call_frame.results.len` entries — TODO confirm.
    for (0..call_frame.results.len) |i| {
        const ffi_buffer = call_frame.results.buffers()[i];
        const ffi_buffer_shape = shapeFromFfi(ffi_buffer);
        if (platform.target == .cpu) {
            user_ctx.results[i] = Buffer.asViewOfHostBuffer(platform, HostBuffer.fromBytes(ffi_buffer_shape, ffi_buffer.data[0..ffi_buffer_shape.byteSize()]));
        } else {
            user_ctx.results[i] = Buffer.asViewOfDeviceBuffer(platform, shapeFromFfi(ffi_buffer), null, ffi_buffer.data);
        }
    }
    // Any Zig error from the user callback is reported to PJRT as an internal error.
    @call(.auto, Callback.call, callback_args) catch |err| {
        log.err("Callback {} failed with {}", .{ Callback, err });
        return .create(call_frame.api, .internal, "internal callback error");
    };
    return .ok;
}
/// Internal custom calls.
/// These are not meant to be used by users, but rather by the library itself.
/// Registered on every platform by `registerInternalCallbacks`.
pub const internal_callbacks = [_]type{
    Print,
};
/// Register every ZML-internal callback (see `internal_callbacks`) on the given platform.
pub fn registerInternalCallbacks(platform: Platform) !void {
    inline for (internal_callbacks) |InternalCallback| {
        try register(InternalCallback, platform);
    }
}
/// Allocate the user data needed by the ZML provided custom calls,
/// and attach it to the given execute context.
/// Allocations live in `arena`; they are freed when the arena is.
pub fn bindInternalCallbacks(
    arena: std.mem.Allocator,
    platform: Platform,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
) (std.mem.Allocator.Error || pjrt.ApiError)!void {
    // Atm we don't have a mechanism to detect which ZML callbacks the executable needs,
    // so we always allocate.
    {
        // Print
        const print_ptr = try arena.create(Print);
        print_ptr.* = try .init(platform);
        try addUserData(Print, platform.pjrt_api, ffi, execute_context, print_ptr);
    }
}
/// Attach `user_data` to `execute_context` under `Callback.type_id`,
/// so `CallbackImpl` can retrieve it at execution time.
/// The caller keeps ownership; the pointer must outlive the executable.
pub fn addUserData(
    Callback: type,
    api: *const pjrt.Api,
    ffi: pjrt.Ffi,
    execute_context: *pjrt.ExecuteContext,
    user_data: *Callback,
) pjrt.ApiError!void {
    const type_id = Callback.type_id.type_id;
    try ffi.addUserData(api, execute_context, .{
        .type_id = type_id,
        .user_data = @ptrCast(user_data),
    });
    log.debug("Bound {s}@{x} with type id {d} on {any}", .{ @typeName(Callback), @intFromPtr(user_data), type_id, execute_context });
}
/// The print callback: logs the content of a device buffer from the host.
pub const Print = struct {
    // a unique type_id will be set by the PJRT plugin during registration.
    pub var type_id: pjrt.ffi.TypeId = undefined;

    pub const callback_config: Config = .{
        // Print callback pretends to modify the given input buffer, but just returns it unmodified.
        .output_operand_aliases = &.{0},
        // It also needs PJRT to copy the data on the host first so it can print it.
        .copy_inputs_to_host_pinned = true,
        // Printing is a host-side effect that must happen on every call,
        // so it is NOT safe to capture it into a command buffer / execution graph.
        .traits = .{ .command_buffer_compatible = false },
    };

    // Required by CallbackImpl: it reads `platform` to build buffer views.
    platform: Platform,
    // Written by CallbackImpl before `call` runs (aliases the input, see output_operand_aliases).
    results: [1]Buffer = undefined,

    pub fn init(platform: Platform) !Print {
        return .{ .platform = platform };
    }

    /// Logs the buffer values. The input was staged in host-pinned memory by
    /// CallbackImpl (copy_inputs_to_host_pinned), so `asHostBuffer` can read it.
    pub fn call(_: *Print, input: Buffer) !void {
        std.log.defaultLog(.info, .zml, "Device buffer: {f}: {d:20.3}", .{ input, input.asHostBuffer() });
    }
};
/// Translate an FFI buffer description into a zml Shape.
/// Panics on dtypes that have no DataType equivalent.
fn shapeFromFfi(ffi_buffer: *const pjrt.ffi.Buffer) Shape {
    const dtype: DataType = switch (ffi_buffer.dtype) {
        .invalid => stdx.debug.panic("Invalid FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }),
        .token, .f8e4m3, .f8e3m4 => stdx.debug.panic("Unsupported FFI dtype {any} used by {any}", .{ ffi_buffer.dtype, ffi_buffer }),
        // Every other FFI tag shares its name with a DataType tag.
        inline else => |tag| @field(DataType, @tagName(tag)),
    };
    return Shape.init(ffi_buffer.dims(), dtype);
}
/// Array type holding one Tensor per buffer argument of `Callback.call`
/// (the leading context argument is excluded).
fn TensorArgs(Callback: type) type {
    const fields = @typeInfo(stdx.meta.FnArgs(Callback.call)).@"struct".fields;
    return [fields.len - 1]Tensor;
}

View File

@ -7,16 +7,19 @@ const runfiles = @import("runfiles");
const runtimes = @import("runtimes");
const stdx = @import("stdx");
const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
const Target = @import("platform.zig").Target;
const zml_platform = @import("platform.zig");
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
const PlatformsMap = std.EnumArray(Target, ?Platform);
const zml = struct {
pub const callback = @import("callback.zig");
pub const HostBuffer = @import("hostbuffer.zig").HostBuffer;
pub const Platform = @import("platform.zig").Platform;
pub const platform = @import("platform.zig");
pub const Shape = @import("shape.zig").Shape;
pub const Target = @import("platform.zig").Target;
};
const PjrtApiMap = std.EnumArray(zml.Target, ?*const pjrt.Api);
const PlatformsMap = std.EnumArray(zml.Target, ?zml.Platform);
const log = std.log.scoped(.@"zml/context");
test {
@ -94,7 +97,7 @@ pub const Context = struct {
return .{ .platforms = PlatformsMap.initFill(null) };
}
fn platformToLibrary(comptime target: Target) []const u8 {
fn platformToLibrary(comptime target: zml.Target) []const u8 {
const ext = switch (builtin.os.tag) {
.windows => ".dll",
.macos, .ios, .watchos => ".dylib",
@ -105,7 +108,7 @@ pub const Context = struct {
};
}
pub fn pjrtApi(target: Target) *const pjrt.Api {
pub fn pjrtApi(target: zml.Target) *const pjrt.Api {
return Context.apis.get(target).?;
}
@ -119,12 +122,12 @@ pub const Context = struct {
self.* = undefined;
}
const prefered_targets = [_]Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
const prefered_targets = [_]zml.Target{ .tpu, .neuron, .cuda, .rocm, .cpu };
/// Automatically selects the best Platform loaded in the current Context.
///
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
pub fn autoPlatform(self: *Context, opts: Platform.CreateOptions) Platform {
pub fn autoPlatform(self: *Context, opts: zml.Platform.CreateOptions) zml.Platform {
stdx.debug.assert(prefered_targets.len == apis.values.len, "New target need to be inserted inside `zml.Context.preferred_targets`", .{});
return self.platformByPreferences(opts, &prefered_targets);
@ -133,7 +136,7 @@ pub const Context = struct {
/// Given a list of preferred targets to select the best Platform
///
/// For example, if supported, this will select a platform corresponding to an accelerator (GPU, TPU, ...).
pub fn platformByPreferences(self: *Context, opts: Platform.CreateOptions, prefered: []const Target) Platform {
pub fn platformByPreferences(self: *Context, opts: zml.Platform.CreateOptions, prefered: []const zml.Target) zml.Platform {
// Try prefered targets.
for (prefered) |target| {
if (apis.get(target) == null) continue;
@ -150,7 +153,7 @@ pub const Context = struct {
// CPU should only be use as fallback.
if (target == .cpu) continue;
if (entry.value.* == null) continue;
if (std.mem.indexOfScalar(Target, prefered, target) != null) continue;
if (std.mem.indexOfScalar(zml.Target, prefered, target) != null) continue;
return self.platform(target, opts) catch |err| {
log.err("Failed to load platform .{s}: {}", .{ @tagName(target), err });
continue;
@ -164,25 +167,25 @@ pub const Context = struct {
};
}
pub fn platform(self: *Context, target: Target, opts: Platform.CreateOptions) !Platform {
pub fn platform(self: *Context, target: zml.Target, opts: zml.Platform.CreateOptions) !zml.Platform {
if (self.platforms.get(target)) |p| {
return p;
}
const api = Context.apis.get(target);
if (api == null) return error.PlatformNotCompiled;
const p = try Platform.init(target, api.?, opts);
const p = try zml.Platform.init(target, api.?, opts);
if (p.getDevices().len == 0) {
log.err("No device found for platform {} !", .{target});
return error.NoDevicesFound;
}
try CustomCall.registerZmlCustomCalls(p);
self.platforms.set(target, p);
try zml.callback.registerInternalCallbacks(p);
return p;
}
pub fn printAvailablePlatforms(self: Context, selected: Platform) void {
pub fn printAvailablePlatforms(self: Context, selected: zml.Platform) void {
// List available targets
log.info("Available Platforms:", .{});
const selected_prefix = "";
@ -190,7 +193,7 @@ pub const Context = struct {
const selected_postfix = "(AUTO-SELECTED)";
const not_selected_postfix = "";
for (zml_platform.available_targets) |target| {
for (zml.platform.available_targets) |target| {
log.info(" {s} {s} {s}", .{
if (target == selected.target) selected_prefix else not_selected_prefix,
@tagName(target),
@ -211,133 +214,4 @@ pub const Context = struct {
}
}
}
pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void;
};
/// Legacy host-callback plumbing (being replaced by the zml.callback API).
const CustomCall = struct {
    /// Registers the generic "zmlHostBufferCallback" FFI handler for this platform.
    /// Logs and returns silently when the plugin has no FFI extension.
    pub fn registerZmlCustomCalls(platform: Platform) !void {
        const ffi = platform.pjrt_api.ffi() orelse {
            log.warn("Registering custom calls failed: No FFI Extension found in {s} PJRT Plugin.", .{@tagName(platform.target)});
            return;
        };
        try ffi.register(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback, .{});
    }

    /// Generic FFI handler: decodes the user callback pointer and its context
    /// from the call attributes, wraps args/results as HostBuffers, and invokes it.
    fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.c) ?*pjrt.ffi.Error {
        // Registration-time metadata probe: nothing to execute.
        if (call_frame.registeringHook()) return null;
        // The callback function pointer was smuggled through a u64 attribute.
        const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
        std.debug.assert(callback_attr.dtype == .u64);
        const callback: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize));
        // Same trick for the opaque user context pointer (may be null).
        const user_ctx_ptr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable;
        std.debug.assert(user_ctx_ptr.dtype == .u64);
        const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_ptr.get(usize));
        // NOTE(review): stackSlice caps inputs/outputs at 8 — presumably enough for
        // all current call sites; verify this limit is enforced upstream.
        const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
        for (input_buffers, 0..) |*b, i| {
            b.* = hostBufferFromPinnedBuffer(call_frame.args.buffers()[i]);
        }
        const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
        for (output_buffers, 0..) |*b, i| {
            b.* = hostBufferFromPinnedBuffer(call_frame.results.buffers()[i]);
        }
        callback(user_ctx, input_buffers, output_buffers);
        return null;
    }
};
/// Translate an FFI buffer description into a zml Shape.
/// Panics on dtypes that have no DataType equivalent.
fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
    // log.warn("received buffer {}", .{buffer_desc});
    const dt: DataType = switch (buffer_desc.dtype) {
        .invalid => @panic("invalid ffi"),
        // `pred` is the XLA spelling of booleans.
        .pred => .bool,
        // NOTE(review): these integer arms look redundant with the generic
        // mapping below (the tag names already match) — presumably kept for
        // explicitness; verify before removing.
        .i8 => .i8,
        .i16 => .i16,
        .i32 => .i32,
        .i64 => .i64,
        .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
        // Remaining tags share their name with DataType and map generically.
        inline else => |t| @field(DataType, @tagName(t)),
    };
    return Shape.init(buffer_desc.dims(), dt);
}
/// Create a HostBuffer from an ffi description of a buffer.
/// Normally the ffi describes a device buffer, but we assume it is located in pinned memory,
/// and therefore the data pointer is readable both from host and from device.
fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
    const buffer_shape = getShape(buffer_desc);
    // No copy: the HostBuffer aliases the (assumed pinned) buffer memory.
    return HostBuffer.fromBytes(
        buffer_shape,
        buffer_desc.data[0..buffer_shape.byteSize()],
    );
}
/// Minimal dynamically-loaded bindings to the CUDA runtime (libcudart),
/// used by host callbacks to move data between device and host.
pub const cuda = struct {
    // Poison pointers: calling any binding before init() faults at a
    // recognizable address instead of silently misbehaving.
    pub var streamSynchronize: StreamSynchronize = @ptrFromInt(0xdeadc00da00);
    pub var cuLaunchHostFunc: CuLaunchHostFunc = @ptrFromInt(0xdeadc00da00);
    var _memcpyAsync: MemcpyAsync = @ptrFromInt(0xdeadc00da00);
    var _memcpyBlocking: MemcpyBlocking = @ptrFromInt(0xdeadc00da00);

    /// Mirrors cudaMemcpyKind from the CUDA runtime API.
    pub const MemcpyKind = enum(c_int) {
        host_to_host = 0,
        host_to_device = 1,
        device_to_host = 2,
        device_to_device = 3,
        inferred = 4,
    };

    const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: ?*anyopaque) callconv(.c) c_int;
    const MemcpyBlocking = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind) callconv(.c) c_int;
    const StreamSynchronize = *const fn (stream: *anyopaque) callconv(.c) c_int;
    const CuLaunchHostFunc = *const fn (stream: *anyopaque, host_func: *const fn (user_data: *const anyopaque) callconv(.c) void, user_data: *const anyopaque) callconv(.c) c_int;

    /// Looks up the required cudart symbols. Logs and leaves the poison
    /// pointers in place if libcudart is missing; panics if an individual
    /// symbol cannot be found.
    pub fn init() void {
        var cudart = std.DynLib.open("libcudart.so.12") catch {
            log.err("cudart not found, callback will segfault", .{});
            return;
        };
        // NOTE(review): the library is closed right after lookup; the symbols
        // stay usable only if libcudart remains loaded elsewhere in the
        // process (e.g. by the PJRT CUDA plugin) — verify.
        defer cudart.close();
        _memcpyAsync = cudart.lookup(MemcpyAsync, "cudaMemcpyAsync") orelse {
            @panic("cudaMemcpyAsync not found");
        };
        _memcpyBlocking = cudart.lookup(MemcpyBlocking, "cudaMemcpy") orelse {
            @panic("cudaMemcpy not found");
        };
        streamSynchronize = cudart.lookup(StreamSynchronize, "cudaStreamSynchronize") orelse {
            @panic("cudaStreamSynchronize not found");
        };
        cuLaunchHostFunc = cudart.lookup(CuLaunchHostFunc, "cudaLaunchHostFunc") orelse {
            @panic("cudaLaunchHostFunc not found");
        };
    }

    /// Blocking device-to-host copy of dst.len bytes.
    pub fn memcpyToHostBlocking(dst: []u8, src: *const anyopaque) void {
        const err = _memcpyBlocking(dst.ptr, src, dst.len, .device_to_host);
        check(err);
    }

    /// Blocking host-to-device copy of src.len bytes.
    pub fn memcpyToDeviceBlocking(dst: *anyopaque, src: []const u8) void {
        const err = _memcpyBlocking(dst, src.ptr, src.len, .host_to_device);
        check(err);
    }

    /// Asynchronous host-to-device copy on the given stream.
    pub fn memcpyToDeviceAsync(dst: *anyopaque, src: []const u8, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst, src.ptr, src.len, .host_to_device, stream);
        check(err);
    }

    /// Asynchronous device-to-host copy on the given stream.
    pub fn memcpyToHostAsync(dst: []u8, src: *const anyopaque, stream: ?*anyopaque) void {
        const err = _memcpyAsync(dst.ptr, src, dst.len, .device_to_host, stream);
        check(err);
    }

    /// Panics on any nonzero CUDA error code.
    pub fn check(err: c_int) void {
        if (err == 0) return;
        stdx.debug.panic("CUDA error: {d}", .{err});
    }
};

View File

@ -5,6 +5,7 @@ const stdx = @import("stdx");
const aio = @import("aio.zig");
const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("tensor.zig").Bufferized;
const callback = @import("callback.zig");
const CompilationContext = @import("module.zig").CompilationContext;
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
@ -154,7 +155,7 @@ pub const BaseExe = struct {
exe: *pjrt.LoadedExecutable,
/// The execution context for this executable.
context: ?*pjrt.ExecuteContext = null,
execute_context: ?*pjrt.ExecuteContext,
/// Pre-allocated slice of buffers to use as inputs when the module is called.
input_per_device: []const [*]*pjrt.Buffer,
@ -205,9 +206,18 @@ pub const BaseExe = struct {
const all_shapes = try allocator.alloc(Shape, n_in + n_out);
@memcpy(all_shapes[0..n_in], args.input_shapes);
@memcpy(all_shapes[n_in..], args.result_shapes);
var execute_context: ?*pjrt.ExecuteContext = null;
if (platform.pjrt_api.ffi()) |ffi| {
log.info("Created context execution {*} for {*}", .{ execute_context, exe });
execute_context = try platform.pjrt_api.createExecuteContext();
try callback.bindInternalCallbacks(allocator, platform, ffi, execute_context.?);
}
return .{
.platform = platform,
.exe = exe,
.execute_context = execute_context,
.ready_buffer_count = 0,
.input_buffer_count = @intCast(n_in),
.num_devices = args.n_devices,
@ -220,7 +230,7 @@ pub const BaseExe = struct {
}
pub fn deinit(self: BaseExe) void {
if (self.context) |ctx| {
if (self.execute_context) |ctx| {
ctx.deinit(self.platform.pjrt_api);
}
self._arena.deinit();
@ -244,16 +254,16 @@ pub const BaseExe = struct {
// even if it has been marked as "can be donated" during compilation.
// TODO: expose it ?
.non_donatable_input_indices = &.{},
.context = self.context,
.context = self.execute_context,
}) catch |err| {
std.debug.panic("PJRT_LoadedExecutable_Execute failed with: {}", .{err});
};
for (events[0..sharding.num_partitions]) |e| {
if (e) |ev| {
ev.await_(self.platform.pjrt_api) catch unreachable;
}
}
// for (events[0..sharding.num_partitions]) |e| {
// if (e) |ev| {
// ev.await_(self.platform.pjrt_api) catch unreachable;
// }
// }
}
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
@ -285,6 +295,17 @@ pub const BaseExe = struct {
stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.output_per_device.len, @typeName(T), local_ctx.index });
}
pub fn bind(exe: BaseExe, Callback: type, op: *Callback) !void {
stdx.debug.assert(exe.execute_context != null, "Exe doesn't have an execution context", .{});
const pjrt_api = exe.platform.pjrt_api;
if (pjrt_api.ffi()) |ffi| {
try callback.addUserData(Callback, pjrt_api, ffi, exe.execute_context.?, op);
} else {
stdx.debug.panic("Callbacks are not supported for target {s}", .{@tagName(exe.platform.target)});
}
}
pub fn serialize(self: BaseExe, writer: anytype) !void {
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
var serialize_result = try executable.serialize(self.platform.pjrt_api);
@ -314,11 +335,11 @@ pub const BaseExe = struct {
pub fn clone(self: BaseExe, parent_allocator: std.mem.Allocator) !BaseExe {
var exe: BaseExe = try .init(parent_allocator, self.platform, self.exe, .{
.n_in = self.input_buffer_count,
.input_shapes = self.input_shapes,
.result_shapes = self.result_shapes,
.n_devices = self.num_devices,
});
exe.context = self.context;
exe.execute_context = self.execute_context;
return exe;
}
};
@ -348,6 +369,14 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
return new;
}
/// For a given customCall inside this executable,
/// provide a pointer to runtime data.
/// The caller keeps memory ownership and need to ensure that the value
/// stays alive as long as the executable.
pub fn bind(self: Self, comptime T: type, value: *T) !void {
try self.inner.bind(T, value);
}
pub fn serialize(self: Self, writer: anytype) !void {
return try self.inner.serialize(writer);
}

View File

@ -325,37 +325,18 @@ pub const HostBuffer = struct {
self: HostBuffer,
writer: anytype,
) !void {
// TODO debug option
// try writer.print("HostBuffer(.{f})@0x{x}", .{ self._shape, @intFromPtr(self._data) });
try writer.print("HostBuffer(.{f})", .{self._shape});
}
/// Formatter for a HostBuffer that also print the values not just the shape.
/// Usage: `std.log.info("my buffer: {}", .{buffer.pretty()});`
pub fn pretty(self: HostBuffer) PrettyPrinter {
return .{ .x = self };
pub fn formatNumber(self: HostBuffer, writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
return self.prettyPrintIndented(writer, 4, 0, n);
}
pub const PrettyPrinter = struct {
x: HostBuffer,
// TODO(0.15.0) revisit pretty printer
pub fn format(self: PrettyPrinter, writer: anytype) !void {
const fmt_: stdx.fmt.Fmt = switch (self.x.dtype().class()) {
.integer => .parse(i32, "d"),
.float => .parse(f32, "d"),
else => .parse(void, ""),
};
const options: std.fmt.FormatOptions = .{};
try prettyPrint(self.x, writer, .{ .fmt = fmt_, .options = options });
}
};
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: stdx.fmt.FullFormatOptions) !void {
pub fn prettyPrint(self: HostBuffer, writer: *std.Io.Writer, options: std.fmt.Number) !void {
return self.prettyPrintIndented(writer, 4, 0, options);
}
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: stdx.fmt.FullFormatOptions) !void {
fn prettyPrintIndented(self: HostBuffer, writer: *std.Io.Writer, num_rows: u8, indent_level: u8, options: std.fmt.Number) !void {
if (self.rank() == 0) {
// Special case input tensor is a scalar
return switch (self.dtype()) {
@ -363,9 +344,10 @@ pub const HostBuffer = struct {
const val: dt.toZigType() = self.items(dt.toZigType())[0];
return switch (comptime dt.class()) {
// Since we have custom floats, we need to explicitly convert to float32 ourselves.
.float => stdx.fmt.formatFloatValue(floats.floatCast(f32, val), options, writer),
.integer => stdx.fmt.formatIntValue(val, options, writer),
.bool, .complex => stdx.fmt.formatAnyValue(val, options, writer),
.float => stdx.fmt.formatFloat(floats.floatCast(f32, val), options, writer),
.integer => stdx.fmt.formatInt(val, options, writer),
.bool => stdx.fmt.formatBool(val, options, writer),
.complex => stdx.fmt.formatComplex(val, options, writer),
};
},
};
@ -380,7 +362,8 @@ pub const HostBuffer = struct {
switch (comptime dt.class()) {
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
.bool, .complex => try stdx.fmt.formatAnySlice(values, options, writer),
.complex => try stdx.fmt.formatComplexSlice(values, options, writer),
.bool => try stdx.fmt.formatBoolSlice(values, options, writer),
}
},
}

View File

@ -1178,9 +1178,9 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
.@"anyframe", .@"fn" => hash(hasher, @intFromPtr(key), strat),
.pointer => |info| switch (info.size) {
.one => switch (strat) {
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
.deep => hash(hasher, key.*, .Shallow),
.deeprecursive => switch (@typeInfo(info.child)) {
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
.Deep => hash(hasher, key.*, .Shallow),
.DeepRecursive => switch (@typeInfo(info.child)) {
.@"opaque", .@"fn" => hash(hasher, @intFromPtr(key), .Shallow),
else => hash(hasher, key.*, .DeepRecursive),
},
@ -1196,7 +1196,7 @@ pub fn hash(hasher: *std.hash.Wyhash, key: anytype, comptime strat: std.hash.Str
.many,
.c,
=> switch (strat) {
.shallow => hash(hasher, @intFromPtr(key), .Shallow),
.Shallow => hash(hasher, @intFromPtr(key), .Shallow),
else => @compileError(
\\ unknown-length pointers and C pointers cannot be hashed deeply.
\\ Consider providing your own hash function.

View File

@ -764,33 +764,6 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
return res;
}
pub const HostCallbackOpt = struct {
has_side_effect: bool = false,
output_operand_aliases: []const i64 = &.{},
};
pub fn addHostCallback(
callback: *const Context.HostCallback,
blkctx: ?*anyopaque,
inputs: []const Tensor,
output_shapes: []const Shape,
opts: HostCallbackOpt,
) []Tensor {
return customCall(
"zmlHostBufferCallback",
inputs,
output_shapes,
.{
.callback = @intFromPtr(callback),
.user_context = @intFromPtr(blkctx),
},
.{
.has_side_effect = opts.has_side_effect,
.output_operand_aliases = opts.output_operand_aliases,
},
);
}
pub const TritonOps = struct {
debug: bool = false,
name: [:0]const u8,

View File

@ -207,6 +207,13 @@ pub const Event = opaque {
return self.inner().getEventError(api);
}
pub fn awaitBlocking(self: *Event, api: *const Api) ApiError!void {
if (self.isReady(api)) {
return;
}
try self.inner().await_(api);
}
pub fn await_(self: *Event, api: *const Api) ApiError!void {
defer self.deinit(api);
@ -264,14 +271,14 @@ pub const LoadedExecutable = opaque {
};
pub fn execute(self: *const LoadedExecutable, api: *const Api, args: ExecuteArgs) ExecuteError!void {
try asynk.callBlocking(pjrt.LoadedExecutable.execute, .{ self.inner(), api, pjrt.LoadedExecutable.ExecuteArgs{
try self.inner().execute(api, pjrt.LoadedExecutable.ExecuteArgs{
.num_args = args.num_args,
.arguments = @ptrCast(args.arguments),
.results = @ptrCast(args.results),
.events = @ptrCast(args.events),
.non_donatable_input_indices = args.non_donatable_input_indices,
.context = args.context,
} });
});
}
pub fn getExecutable(self: *LoadedExecutable, api: *const Api) ApiError!*Executable {

View File

@ -22,6 +22,11 @@ pub const Platform = struct {
target: Target,
pjrt_api: *const pjrt.Api,
pjrt_client: *pjrt.Client,
// This make the pjrt struct quite fat, but is only used during compilation.
// TODO: Reconsider having it here, and maybe pass explicitly to compile,
// or create an intermediary struct:
// `const comp = platform.compiler(compile_opts); const exe = comp.compile(...);`
compilation_options: CompilationOptions = .{},
pub const MAX_NUM_DEVICES: u8 = 32;
@ -71,17 +76,6 @@ pub const Platform = struct {
return res;
}
pub fn registerFFIType(self: Platform, comptime T: type) !void {
if (self.pjrt_api.ffi()) |ffi| {
if (!@hasDecl(T, "type_id")) {
stdx.debug.panic("registerFFIType requires type {s} to have a `type_id` i64 field ", .{@typeName(T)});
}
try ffi.registerTypeId(self.pjrt_api, T);
} else {
stdx.debug.panic("registerFFIType is not available for target {s}", .{@tagName(self.target)});
}
}
pub fn deinit(self: *Platform) void {
self.pjrt_client.deinit(self.pjrt_api);
}

View File

@ -5,6 +5,7 @@ const mlir = @import("mlir");
const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer;
const callback = @import("callback.zig");
const CompilationContext = @import("module.zig").CompilationContext;
const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType;
@ -3824,22 +3825,7 @@ pub const Tensor = struct {
/// Only for debug purpose, it inserts device to host synchronization
/// so it will slow down the program execution.
pub fn print(input: Tensor) Tensor {
// TODO: find a way of doing print that doesn't involve a H2D copy.
return ops.addHostCallback(
&printCallback,
null,
&.{input},
&.{input.shape()},
.{ .output_operand_aliases = &.{0} },
)[0];
}
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
const host_buffer = inputs[0];
std.log.defaultLog(.info, .zml, "Device buffer: {f}: {f}", .{ host_buffer.shape(), host_buffer.pretty() });
// This is true because of the operand aliases.
// Since the result is already pointing to the input we don't need to modify the buffer.
std.debug.assert(host_buffer._data == outputs[0]._data);
return callback.call(callback.Print, .{input}, &.{input.shape()})[0];
}
};

View File

@ -51,14 +51,13 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
if (should_free_left) left.deinit(allocator);
if (should_free_right) right.deinit(allocator);
}
errdefer log.err("\n--> Left: {f}\n--> Right: {f}", .{ left.pretty(), right.pretty() });
errdefer log.err("\n--> Left: {0f}{0d:24.3}\n--> Right: {1f}{1d:24.3}", .{ left, right });
if (!std.mem.eql(i64, left.shape().dims(), right.shape().dims())) {
log.err("left.shape() {f} != right.shape() {f}", .{ left.shape(), right.shape() });
return error.TestUnexpectedResult;
}
if (left.dtype() != right.dtype() and !(left.dtype() == .f16 and right.dtype() == .bf16)) {
log.err("left.dtype ({}) != right.dtype ({})", .{ left.dtype(), right.dtype() });
log.err("left.dtype ({f}) != right.dtype ({f})", .{ left.shape(), right.shape() });
return error.TestUnexpectedResult;
}
switch (left.dtype()) {
@ -89,7 +88,7 @@ pub fn expectClose(left_: anytype, right_: anytype, tolerance: f32) !void {
const right_data = right.items(R);
for (left_data, right_data, 0..) |l, r, i| {
if (!approxEq(f32, zml.floats.floatCast(f32, l), zml.floats.floatCast(f32, r), tolerance)) {
log.err("left.data != right_data.\n < {any:.3} \n > {any:.3}\n error at idx {any}: {any:.3} != {any:.3}", .{ center(left_data, i), center(right_data, i), i, left_data[i], right_data[i] });
log.err("left.data != right_data.\n < {d:40.3} \n > {d:40.3}\n error at idx {d}: {d:.3} != {d:.3}", .{ stdx.fmt.slice(center(left_data, i)), stdx.fmt.slice(center(right_data, i)), i, left_data[i], right_data[i] });
return error.TestUnexpectedResult;
}
}

View File

@ -6,11 +6,13 @@
// Namespaces
const std = @import("std");
pub const platform_specific = @import("c");
pub const tokenizer = @import("zml/tokenizer");
pub const aio = @import("aio.zig");
pub const Buffer = @import("buffer.zig").Buffer;
pub const Bufferized = @import("tensor.zig").Bufferized;
pub const callback = @import("callback.zig");
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
pub const context = @import("context.zig");
pub const Context = @import("context.zig").Context;
@ -43,7 +45,6 @@ pub const Tensor = @import("tensor.zig").Tensor;
pub const testing = @import("testing.zig");
pub const torch = @import("torch.zig");
// pub const tokenizer = @import("tokenizer.zig");
pub const tools = struct {
pub const Tracer = @import("tools/tracer.zig").Tracer;
};