pjrt: add FFI bindings for custom calls

This commit is contained in:
Tarry Singh 2024-09-10 09:14:28 +00:00
parent 1f5ff96c10
commit aec7072837
25 changed files with 1319 additions and 376 deletions

View File

@ -24,7 +24,7 @@ bazel_dep(name = "rules_zig", version = "20250314.0-b9739c6")
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a") bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
bazel_dep(name = "toolchains_protoc", version = "0.3.7") bazel_dep(name = "toolchains_protoc", version = "0.3.7")
bazel_dep(name = "with_cfg.bzl", version = "0.9.1") bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
bazel_dep(name = "xla", version = "20250317.0-71c67e2") bazel_dep(name = "xla", version = "20250317.1-71c67e2")
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e") bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf") bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")

View File

@ -739,18 +739,13 @@ pub const CustomCallOpts = struct {
typed_ffi = 4, typed_ffi = 4,
}; };
pub const BackendConfig = union(enum) {
string: [:0]const u8,
dict: mlir.DictionaryAttribute,
};
call_target_name: [:0]const u8, call_target_name: [:0]const u8,
has_side_effect: bool, has_side_effect: bool,
backend_config: BackendConfig = .{ .string = &.{} }, backend_config: ?mlir.Attribute,
operand_layouts: ?[]const []const usize = null, operand_layouts: ?[]const []const usize = null,
result_layouts: ?[]const []const usize = null, result_layouts: ?[]const []const usize = null,
output_operand_aliases: []const i64 = &.{}, output_operand_aliases: []const i64 = &.{},
addional_attributes: []const mlir.AttrTuple = &.{}, additional_attributes: []const mlir.AttrTuple = &.{},
api_version: ApiVersion, api_version: ApiVersion,
}; };
@ -758,68 +753,56 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
const MAX_OPERANDS = 64; const MAX_OPERANDS = 64;
const MAX_RESULTS = 16; const MAX_RESULTS = 16;
const output_operand_aliases = blk: { const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, "");
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
for (opts.output_operand_aliases) |alias| { stdx.debug.assert(
ret.appendAssumeCapacity( backend_config.is_a(mlir.StringAttribute),
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), "API version < 4 requires a string as backend_config, got {}",
.{backend_config},
);
} else {
stdx.debug.assert(
backend_config.is_a(mlir.DictionaryAttribute),
"API version >= 4 requires a dictionary as backend_config, got {}",
.{backend_config},
); );
} }
break :blk ret;
};
const backend_config: mlir.Attribute = switch (opts.backend_config) {
.string => blk: {
stdx.debug.assert(
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
"Only API version of less than 4 is supported for backend_config as string",
.{},
);
break :blk .string(ctx, opts.backend_config.string);
},
.dict => blk: {
stdx.debug.assert(
opts.api_version == .typed_ffi,
"Only API version 4 is supported for backend_config as dictionary",
.{},
);
break :blk opts.backend_config.dict.as(mlir.Attribute);
},
};
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
.{ "call_target_name", .string(ctx, opts.call_target_name) }, .{ "call_target_name", .string(ctx, opts.call_target_name) },
.{ "has_side_effect", .boolean(ctx, opts.has_side_effect) }, .{ "has_side_effect", .boolean(ctx, opts.has_side_effect) },
.{ "backend_config", backend_config }, .{ "backend_config", backend_config },
.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) },
}); });
if (opts.operand_layouts) |layouts| { {
const operand_layouts = blk: { var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; for (opts.output_operand_aliases) |alias| {
for (layouts) |ol| { output_operand_aliases.appendAssumeCapacity(
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
);
}
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
}
if (opts.operand_layouts) |layouts| {
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
for (layouts) |ol| {
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
} }
break :blk ret;
};
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
} }
if (opts.result_layouts) |layouts| { if (opts.result_layouts) |layouts| {
const result_layouts = blk: { var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
for (layouts) |rl| { for (layouts) |rl| {
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
} }
break :blk ret;
};
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
} }
attrs.appendSliceAssumeCapacity(opts.addional_attributes); attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs, .operands = inputs,
@ -829,22 +812,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
}); });
} }
// todo: move out of stablehlo.zig when we start to implement the frontend
/// Wraps `inputs` in a `stablehlo.custom_call` targeting "annotate_device_placement",
/// carrying `memory_kind` under the `_xla_buffer_placement` frontend attribute
/// (presumably consumed by XLA to pick the buffer's memory space — confirm against XLA docs).
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
    // Single-entry dictionary: { "_xla_buffer_placement": memory_kind }.
    const frontend_attributes = mlir.DictionaryAttribute.init(
        ctx,
        &.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())},
    ).asAttr();
    return custom_call(ctx, inputs, .{
        .call_target_name = "annotate_device_placement",
        .has_side_effect = true,
        // Empty string backend config (required for api_version == .original).
        .backend_config = .{ .string = &.{} },
        // NOTE: "addional" is a pre-existing typo in the CustomCallOpts field name.
        .addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
        .api_version = .original,
    }, res_types, location);
}
pub const DotDimensionNumbersAttribute = struct { pub const DotDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,

View File

@ -414,6 +414,18 @@ pub const Attribute = struct {
pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute { pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
return NamedAttribute.init(Identifier.get(ctx, name), attr); return NamedAttribute.init(Identifier.get(ctx, name), attr);
} }
/// Builds an `Attribute` wrapping a `DictionaryAttribute` from (name, attribute) tuples.
/// Uses a fixed stack buffer, so it supports at most 32 entries; asserts otherwise.
pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute {
    var attr_buf: [32]NamedAttribute = undefined;
    // Fixed grammar in the assert message ("only support up to {} attribute" ->
    // "only supports up to {} attributes").
    stdx.debug.assert(named_attrs.len <= attr_buf.len, ".dict helper only supports up to {} attributes, got {}. You will need to call mlir.DictionaryAttribute directly", .{ attr_buf.len, named_attrs.len });
    const attrs = attr_buf[0..named_attrs.len];
    for (attrs, named_attrs) |*attr, tuple| {
        attr.* = .named(ctx, tuple[0], tuple[1]);
    }
    return DictionaryAttribute.init(ctx, attrs).asAttr();
}
}; };
pub const NamedAttribute = extern struct { pub const NamedAttribute = extern struct {

View File

@ -1,3 +1,4 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_library")
load("@zml//bazel:zig.bzl", "zig_cc_binary") load("@zml//bazel:zig.bzl", "zig_cc_binary")
load("//bazel:zig_proto_library.bzl", "zig_proto_library") load("//bazel:zig_proto_library.bzl", "zig_proto_library")
@ -12,7 +13,12 @@ cc_library(
zig_library( zig_library(
name = "pjrt", name = "pjrt",
srcs = ["profiler.zig"] + glob(["convert/*.zig"]), srcs = [
"ffi.zig",
"profiler.zig",
"convert/trace_container.zig",
"convert/xplane_schema.zig"
],
main = "pjrt.zig", main = "pjrt.zig",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [
@ -20,9 +26,12 @@ zig_library(
":trace_events_proto", ":trace_events_proto",
":xplane_proto", ":xplane_proto",
"//stdx", "//stdx",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
] + select({ ] + select({
"@platforms//os:linux": [":dlfcn"], "@platforms//os:linux": [":dlfcn"],
"//conditions:default": [], "//conditions:default": [],
@ -49,7 +58,7 @@ zig_proto_library(
zig_cc_binary( zig_cc_binary(
name = "xspace_to_json", name = "xspace_to_json",
srcs = glob(["convert/*.zig"]), srcs = ["convert/trace_container.zig", "convert/xplane_schema.zig"],
main = "xspace_to_json.zig", main = "xspace_to_json.zig",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = [ deps = [

View File

@ -1,8 +1,9 @@
const std = @import("std"); const std = @import("std");
const trace_events_proto = @import("//tsl:trace_events_proto"); const trace_events_proto = @import("//tsl:trace_events_proto");
const xplane_proto = @import("//tsl:xplane_proto"); const xplane_proto = @import("//tsl:xplane_proto");
const xplane_schema = @import("xplane_schema.zig"); const xplane_schema = @import("xplane_schema.zig");
const xplane_visitor = @import("xplane_visitor.zig");
// Constants used as trace_viewer PID (device_id in trace_events.proto). // Constants used as trace_viewer PID (device_id in trace_events.proto).
// PID 0 is unused. // PID 0 is unused.
@ -87,7 +88,7 @@ pub const TraceContainer = struct {
} }
} }
fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const xplane_visitor.XPlaneVisitor) !void { fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const XPlaneHashed) !void {
// Convert devices and resources. // Convert devices and resources.
const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id }); const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id });
var device = device_entry.value_ptr.*; var device = device_entry.value_ptr.*;
@ -156,7 +157,7 @@ pub const TraceContainer = struct {
fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void { fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void {
if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| { if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| {
const xplane = try xplane_visitor.XPlaneVisitor.init(allocator, hp); const xplane = try XPlaneHashed.init(allocator, hp);
try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane); try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane);
} }
@ -173,7 +174,7 @@ pub const TraceContainer = struct {
} }
for (device_planes.items) |dp| { for (device_planes.items) |dp| {
var xplane = try xplane_visitor.XPlaneVisitor.init(allocator, dp); var xplane = try XPlaneHashed.init(allocator, dp);
defer xplane.deinit(allocator); defer xplane.deinit(allocator);
const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id)); const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id));
try self.xplaneToTraceEvents(allocator, device_id, &xplane); try self.xplaneToTraceEvents(allocator, device_id, &xplane);
@ -298,3 +299,68 @@ pub const TraceContainer = struct {
fn picoToMicro(p: anytype) f64 { fn picoToMicro(p: anytype) f64 {
return @as(f64, @floatFromInt(p)) / 1E6; return @as(f64, @floatFromInt(p)) / 1E6;
} }
/// Read-only index over an `XPlane` proto: builds id -> metadata hash maps once
/// so event/stat metadata lookups are O(1) during trace conversion.
/// NOTE: stores pointers into `plane`'s metadata arrays — `plane` must outlive
/// this struct and its arrays must not be reallocated while it is in use.
pub const XPlaneHashed = struct {
    plane: *const xplane_proto.XPlane,
    // XEventMetadata.key -> pointer into plane.event_metadata.
    event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{},
    // XStatMetadata.key -> pointer into plane.stat_metadata.
    stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{},

    /// Builds both metadata maps. Assumes every metadata entry has a non-null
    /// `value` (`.?` below panics otherwise) — TODO confirm against the proto producer.
    pub fn init(
        allocator: std.mem.Allocator,
        plane: *const xplane_proto.XPlane,
    ) !XPlaneHashed {
        var res: XPlaneHashed = .{ .plane = plane };
        try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len));
        // build event metadata map
        for (plane.event_metadata.items) |*event_metadata| {
            res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?);
        }
        // build stat metadata map
        try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len));
        for (plane.stat_metadata.items) |*stat_metadata| {
            res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?);
        }
        return res;
    }

    /// Frees the hash maps; does not touch `plane`.
    pub fn deinit(self: *XPlaneHashed, allocator: std.mem.Allocator) void {
        self.stat_metadata_by_id.deinit(allocator);
        self.event_metadata_by_id.deinit(allocator);
    }

    /// Plane name as stored in the proto.
    pub fn name(self: XPlaneHashed) []const u8 {
        return self.plane.name.getSlice();
    }

    /// Resolves an event metadata id to its host-event type; `.unknown` when absent.
    pub fn getEventType(self: XPlaneHashed, event_metadata_id: i64) xplane_schema.HostEventType {
        if (self.event_metadata_by_id.get(event_metadata_id)) |v| {
            return xplane_schema.HostEventType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Name of the stat metadata with the given id; empty slice when absent.
    pub fn getStatMetadataName(self: XPlaneHashed, stat_metadata_id: i64) []const u8 {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return v.name.getSlice();
        } else return &[_]u8{};
    }

    /// Resolves a stat metadata id to its stat type; `.unknown` when absent.
    pub fn getStatType(self: XPlaneHashed, stat_metadata_id: i64) xplane_schema.StatType {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return xplane_schema.StatType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Writes a single XStat's value to `writer`. Ref values are resolved through
    /// the stat metadata map; a stat with no value writes nothing.
    pub fn xstatToString(self: XPlaneHashed, stat: xplane_proto.XStat, writer: anytype) !void {
        if (stat.value == null) return;
        switch (stat.value.?) {
            inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
            .str_value => |*v| try writer.writeAll(v.getSlice()),
            .bytes_value => try writer.writeAll("<opaque bytes>"),
            .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))),
        }
    }
};

View File

@ -1,68 +0,0 @@
const std = @import("std");
const xplane_proto = @import("//tsl:xplane_proto");
const xplane_schema = @import("xplane_schema.zig");
/// Read-only index over an `XPlane` proto: id -> metadata hash maps for O(1)
/// lookups. Stores pointers into `plane`'s arrays — `plane` must outlive this.
pub const XPlaneVisitor = struct {
    plane: *const xplane_proto.XPlane,
    // XEventMetadata.key -> pointer into plane.event_metadata.
    event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{},
    // XStatMetadata.key -> pointer into plane.stat_metadata.
    stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{},

    /// Builds both metadata maps. Assumes every metadata entry has a non-null
    /// `value` (`.?` below panics otherwise) — TODO confirm against the proto producer.
    pub fn init(
        allocator: std.mem.Allocator,
        plane: *const xplane_proto.XPlane,
    ) !XPlaneVisitor {
        var res: XPlaneVisitor = .{ .plane = plane };
        try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len));
        // build event metadata map
        for (plane.event_metadata.items) |*event_metadata| {
            res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?);
        }
        // build stat metadata map
        try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len));
        for (plane.stat_metadata.items) |*stat_metadata| {
            res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?);
        }
        return res;
    }

    /// Frees the hash maps; does not touch `plane`.
    pub fn deinit(self: *XPlaneVisitor, allocator: std.mem.Allocator) void {
        self.stat_metadata_by_id.deinit(allocator);
        self.event_metadata_by_id.deinit(allocator);
    }

    /// Plane name as stored in the proto.
    pub fn name(self: XPlaneVisitor) []const u8 {
        return self.plane.name.getSlice();
    }

    /// Resolves an event metadata id to its host-event type; `.unknown` when absent.
    pub fn getEventType(self: XPlaneVisitor, event_metadata_id: i64) xplane_schema.HostEventType {
        if (self.event_metadata_by_id.get(event_metadata_id)) |v| {
            return xplane_schema.HostEventType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Name of the stat metadata with the given id; empty slice when absent.
    pub fn getStatMetadataName(self: XPlaneVisitor, stat_metadata_id: i64) []const u8 {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return v.name.getSlice();
        } else return &[_]u8{};
    }

    /// Resolves a stat metadata id to its stat type; `.unknown` when absent.
    pub fn getStatType(self: XPlaneVisitor, stat_metadata_id: i64) xplane_schema.StatType {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return xplane_schema.StatType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Writes a single XStat's value to `writer`. Ref values are resolved through
    /// the stat metadata map; a stat with no value writes nothing.
    pub fn xstatToString(self: XPlaneVisitor, stat: xplane_proto.XStat, writer: anytype) !void {
        if (stat.value == null) return;
        switch (stat.value.?) {
            inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
            .str_value => |*v| try writer.writeAll(v.getSlice()),
            .bytes_value => try writer.writeAll("<opaque bytes>"),
            .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))),
        }
    }
};

517
pjrt/ffi.zig Normal file
View File

@ -0,0 +1,517 @@
/// Bindings for PJRT custom call declaration / execution.
const std = @import("std");
const c = @import("c");
const stdx = @import("stdx");
const pjrtStruct = @import("pjrt.zig").pjrtStruct;
const log = std.log.scoped(.pjrt);
/// Mirrors `XLA_FFI_Api_Version`. Filled in during the registration call so
/// XLA can check which FFI API version this plugin was compiled against.
pub const ApiVersion = extern struct {
    /// Version constants baked in at compile time from the XLA FFI headers.
    pub const major = c.XLA_FFI_API_MAJOR;
    pub const minor = c.XLA_FFI_API_MINOR;

    struct_size: usize,
    extension_start: ?*ExtensionBase,
    major_version: i32,
    minor_version: i32,
};

/// Known FFI extension kinds (mirrors `XLA_FFI_Extension_Type`).
pub const ExtensionType = enum(c.XLA_FFI_Extension_Type) {
    metadata = c.XLA_FFI_Extension_Metadata,
};

/// Common header of every FFI extension; extensions chain via `next`.
pub const ExtensionBase = extern struct {
    struct_size: usize,
    type: ExtensionType,
    next: ?*ExtensionBase,
};

// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
pub const HandlerTraits = packed struct(u32) {
    /// Calls to FFI handler are safe to trace into the command buffer.
    /// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values)
    /// that can be captured and then replayed.
    command_buffer_compatible: u1,
    __unassigned__: u31,
};

/// Handler metadata reported back to XLA (mirrors `XLA_FFI_Metadata`).
pub const Metadata = extern struct {
    struct_size: usize,
    api_version: ApiVersion,
    traits: HandlerTraits,
};

/// Extension carrying a `Metadata` pointer during the registration call.
pub const MetadataExtension = extern struct {
    extension_base: ExtensionBase,
    metadata: ?*Metadata,
};

/// Zig-side error set mirroring `XLA_FFI_Error_Code` (minus the OK case).
pub const ApiError = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
/// Returns pointer-cast helpers between an opaque wrapper type `T` and the
/// underlying C struct `InnerT`. Purely a type-level reinterpretation: no data
/// is copied and constness is preserved by the switch on `@TypeOf(self)`.
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
    return struct {
        /// Casts a (const) `*T` to the matching (const) `*InnerT`.
        pub fn to(self: anytype) switch (@TypeOf(self)) {
            *T => *InnerT,
            *const T => *const InnerT,
            else => unreachable,
        } {
            return @ptrCast(@alignCast(self));
        }

        /// Casts a (const) `*InnerT` back to the matching (const) `*T`.
        pub fn from(self: anytype) switch (@TypeOf(self)) {
            *InnerT => *T,
            *const InnerT => *const T,
            else => unreachable,
        } {
            return @ptrCast(@alignCast(self));
        }
    };
}
/// Opaque handle to the XLA FFI API function table (`XLA_FFI_Api`).
/// Every method logs and collapses FFI errors into `error.Unknown` for now.
pub const Api = opaque {
    pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;

    /// Fetches the underlying platform stream for the given execution context.
    pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque {
        var args = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
        });
        if (self.inner().XLA_FFI_Stream_Get.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.getStream] {s}", .{err.getMessage(self)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return args.stream.?;
    }

    /// Asks XLA to allocate `size` bytes of device memory with the given alignment.
    pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
        var args = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
            .size = size,
            .alignment = alignment,
        });
        if (self.inner().XLA_FFI_DeviceMemory_Allocate.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.allocateDeviceMemory] {s}", .{err.getMessage(self)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return args.data.?;
    }

    /// Releases device memory previously obtained through `allocateDeviceMemory`.
    pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
        var args = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
            .size = size,
            .data = data,
        });
        if (self.inner().XLA_FFI_DeviceMemory_Free.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.freeDeviceMemory] {s}", .{err.getMessage(self)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }

    // TODO(Corentin): Implement remaining methods if needed:
    // * XLA_FFI_ThreadPool_Schedule
    // * XLA_FFI_Handler_Register
    // * XLA_FFI_TypeId_Register
    // * XLA_FFI_State_Set
    // * XLA_FFI_State_Get
};
/// Which phase of the custom-call lifecycle the handler is being invoked for
/// (mirrors `XLA_FFI_ExecutionStage`).
pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
    instantiate = c.XLA_FFI_ExecutionStage_INSTANTIATE,
    prepare = c.XLA_FFI_ExecutionStage_PREPARE,
    initialize = c.XLA_FFI_ExecutionStage_INITIALIZE,
    execute = c.XLA_FFI_ExecutionStage_EXECUTE,
};
/// Opaque handle to the per-call FFI execution context
/// (mirrors `XLA_FFI_ExecutionContext`).
pub const ExecutionContext = opaque {
    pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;

    // TODO: attach() — register a user value on the context.
    // (The original draft below was duplicated; the copy has been removed.)
    // pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void {
    //     // register type id ==> typeid
    //     const typename_ = "zml." ++ @typeName(@TypeOf(value));
    //     var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{
    //         .ctx = self.inner(),
    //         .handler = @ptrCast(@alignCast(handler)),
    //     });
    //     const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret);
    //     if (result) |ffi_error| {
    //         const err = Error.fromInner(ffi_error);
    //         defer err.destroy(api);
    //         log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)});
    //         // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
    //         return error.Unknown;
    //     }
    // }

    /// Looks up a value previously registered on the context under `type_id`.
    /// Logs and returns `error.Unknown` on failure.
    pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque {
        var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
            .ctx = self.inner(),
            .type_id = @ptrCast(@alignCast(type_id)),
        });
        const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret);
        if (result) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return ret.data.?;
    }

    // TODO getDeviceOrdinal()
};
/// FFI string view: pointer + length (mirrors `XLA_FFI_ByteSpan`).
const ByteSpan = extern struct {
    ptr: [*]const u8,
    len: usize,

    /// Returns the bytes as a Zig slice (no copy; lifetime follows `ptr`).
    pub fn slice(self: ByteSpan) []const u8 {
        return self.ptr[0..self.len];
    }
};

/// Id under which a user type/value is registered (mirrors `XLA_FFI_TypeId`).
pub const TypeId = extern struct {
    type_id: i64,
};

/// Element type of an FFI buffer (mirrors `XLA_FFI_DataType`).
pub const DataType = enum(c.XLA_FFI_DataType) {
    invalid = c.XLA_FFI_DataType_INVALID,
    pred = c.XLA_FFI_DataType_PRED,
    s8 = c.XLA_FFI_DataType_S8,
    s16 = c.XLA_FFI_DataType_S16,
    s32 = c.XLA_FFI_DataType_S32,
    s64 = c.XLA_FFI_DataType_S64,
    u8 = c.XLA_FFI_DataType_U8,
    u16 = c.XLA_FFI_DataType_U16,
    u32 = c.XLA_FFI_DataType_U32,
    u64 = c.XLA_FFI_DataType_U64,
    f16 = c.XLA_FFI_DataType_F16,
    f32 = c.XLA_FFI_DataType_F32,
    f64 = c.XLA_FFI_DataType_F64,
    bf16 = c.XLA_FFI_DataType_BF16,
    c64 = c.XLA_FFI_DataType_C64,
    c128 = c.XLA_FFI_DataType_C128,
    token = c.XLA_FFI_DataType_TOKEN,
    f8e5m2 = c.XLA_FFI_DataType_F8E5M2,
    f8e3m4 = c.XLA_FFI_DataType_F8E3M4,
    f8e4m3 = c.XLA_FFI_DataType_F8E4M3,
    f8e4m3fn = c.XLA_FFI_DataType_F8E4M3FN,
    f8e4m3b11fnuz = c.XLA_FFI_DataType_F8E4M3B11FNUZ,
    f8e5m2fnuz = c.XLA_FFI_DataType_F8E5M2FNUZ,
    f8e4m3fnuz = c.XLA_FFI_DataType_F8E4M3FNUZ,
};
/// Buffer passed through the FFI call frame (mirrors `XLA_FFI_Buffer`).
pub const Buffer = extern struct {
    struct_size: usize,
    extension_start: ?*c.XLA_FFI_Extension_Base,
    dtype: DataType,
    // NOTE(review): presumably a device pointer — do not assume it is
    // host-dereferenceable; confirm per platform.
    data: [*]u8,
    rank: u64,
    _dims: [*]const i64,

    /// Returns the dimensions as a slice of length `rank`.
    pub fn dims(self: Buffer) []const i64 {
        return self._dims[0..self.rank];
    }

    /// std.fmt integration: prints dims, dtype tag and the raw data address.
    pub fn format(
        buffer: Buffer,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), @tagName(buffer.dtype), @intFromPtr(buffer.data) });
    }
};
/// Input buffers of the custom call (mirrors `XLA_FFI_Args`): parallel arrays
/// of argument types and buffer pointers.
pub const Args = extern struct {
    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    pub const Type = enum(c.XLA_FFI_ArgType) {
        buffer = c.XLA_FFI_ArgType_BUFFER,
    };

    /// Returns the i-th input buffer; asserts the slot holds a buffer.
    pub fn get(self: Args, i: usize) *const Buffer {
        const arg_types = self.types[0..self.len];
        std.debug.assert(arg_types[i] == .buffer);
        const buffers = self.ptr[0..self.len];
        return buffers[i];
    }
};
/// Result buffers of the custom call (mirrors `XLA_FFI_Rets`): parallel arrays
/// of result types and buffer pointers.
pub const Rets = extern struct {
    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    pub const Type = enum(c.XLA_FFI_RetType) {
        buffer = c.XLA_FFI_RetType_BUFFER,
    };

    /// Returns the i-th result buffer; asserts the slot holds a buffer.
    pub fn get(self: Rets, i: usize) *const Buffer {
        const ret_types = self.types[0..self.len];
        std.debug.assert(ret_types[i] == .buffer);
        const buffers = self.ptr[0..self.len];
        return buffers[i];
    }
};
/// Kind tag of a custom-call attribute (mirrors `XLA_FFI_AttrType`).
pub const AttrType = enum(c.XLA_FFI_AttrType) {
    array = c.XLA_FFI_AttrType_ARRAY,
    dictionary = c.XLA_FFI_AttrType_DICTIONARY,
    scalar = c.XLA_FFI_AttrType_SCALAR,
    string = c.XLA_FFI_AttrType_STRING,
};

/// Attributes attached to the custom call, as parallel arrays of
/// (type, name, payload pointer) — mirrors `XLA_FFI_Attrs`.
pub const Attrs = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    len: u64,
    types: [*]const AttrType,
    names: [*]const *const ByteSpan,
    ptr: [*]const *const Attr,

    const Attr = extern union {
        scalar: Scalar,
        array: Array,
    };

    /// Scalar attribute payload: a dtype tag plus a pointer to the raw value.
    pub const Scalar = extern struct {
        dtype: DataType,
        value: *const anyopaque,

        /// Reinterprets the payload as a `T`; caller must pick the `T`
        /// that matches `dtype`.
        pub fn get(self: Scalar, T: type) T {
            const typed: *const T = @alignCast(@ptrCast(self.value));
            return typed.*;
        }
    };

    /// Array attribute payload: dtype tag, element count, raw bytes.
    pub const Array = extern struct {
        dtype: DataType,
        len: usize,
        data: [*]const u8,
    };

    /// Returns the payload at `index` as the requested kind, or null when the
    /// attribute stored there has a different kind.
    pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
        if (self.types[index] != attr_type) return null;
        return @ptrCast(self.ptr[0..self.len][index]);
    }

    /// Linear scan over attribute names; returns the first match as the
    /// requested kind, or null when the name is absent or the kind differs.
    pub fn getByName(self: Attrs, comptime attr_type: AttrType, name: []const u8) ?*const @FieldType(Attr, @tagName(attr_type)) {
        for (self.names[0..self.len], 0..) |attr_name, index| {
            if (std.mem.eql(u8, attr_name.slice(), name)) {
                return self.getByIndex(attr_type, index);
            }
        }
        return null;
    }
};
/// Call frame handed to an FFI handler (mirrors `XLA_FFI_CallFrame`).
pub const CallFrame = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    api: ?*const Api,
    ctx: ?*const ExecutionContext,
    stage: ExecutionStage,
    args: Args,
    results: Rets,
    attrs: Attrs,
    future: ?*Future,

    /// The registry mechanism will first call the custom call in registration mode,
    /// and expects us to indicate which version of XLA we have been compiled against.
    /// Returns true if we registered ourselves and the caller custom call should return early.
    pub fn registeringHook(call_frame: *CallFrame) bool {
        const extension = call_frame.extension_start orelse return false;
        if (extension.type != .metadata) return false;
        const metadata_extension: *MetadataExtension = @fieldParentPtr("extension_base", extension);
        const metadata = metadata_extension.metadata.?;
        metadata.api_version.major_version = ApiVersion.major;
        metadata.api_version.minor_version = ApiVersion.minor;
        return true;
    }
};

/// Signature of an FFI handler as seen by the C API.
pub const Handler = fn (*CallFrame) callconv(.C) ?*Error;
/// Error codes shared with the XLA FFI C API (mirrors `XLA_FFI_Error_Code`).
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
    cancelled = c.XLA_FFI_Error_Code_CANCELLED,
    unknown = c.XLA_FFI_Error_Code_UNKNOWN,
    invalid_argument = c.XLA_FFI_Error_Code_INVALID_ARGUMENT,
    deadline_exceeded = c.XLA_FFI_Error_Code_DEADLINE_EXCEEDED,
    not_found = c.XLA_FFI_Error_Code_NOT_FOUND,
    already_exists = c.XLA_FFI_Error_Code_ALREADY_EXISTS,
    permission_denied = c.XLA_FFI_Error_Code_PERMISSION_DENIED,
    resource_exhausted = c.XLA_FFI_Error_Code_RESOURCE_EXHAUSTED,
    failed_precondition = c.XLA_FFI_Error_Code_FAILED_PRECONDITION,
    aborted = c.XLA_FFI_Error_Code_ABORTED,
    out_of_range = c.XLA_FFI_Error_Code_OUT_OF_RANGE,
    unimplemented = c.XLA_FFI_Error_Code_UNIMPLEMENTED,
    internal = c.XLA_FFI_Error_Code_INTERNAL,
    unavailable = c.XLA_FFI_Error_Code_UNAVAILABLE,
    data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
    unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,

    /// Maps the C error code onto the corresponding Zig `ApiError`.
    pub fn toApiError(code: ErrorCode) ApiError {
        return switch (code) {
            .cancelled => error.Cancelled,
            .unknown => error.Unknown,
            .invalid_argument => error.InvalidArgument,
            .deadline_exceeded => error.DeadlineExceeded,
            // Fixed: was `error.FfiNotFound`, which is not a member of the
            // declared `ApiError` error set and would not compile/coerce.
            .not_found => error.NotFound,
            .already_exists => error.AlreadyExists,
            .permission_denied => error.PermissionDenied,
            .resource_exhausted => error.ResourceExhausted,
            .failed_precondition => error.FailedPrecondition,
            .aborted => error.Aborted,
            .out_of_range => error.OutOfRange,
            .unimplemented => error.Unimplemented,
            .internal => error.Internal,
            .unavailable => error.Unavailable,
            .data_loss => error.DataLoss,
            .unauthenticated => error.Unauthenticated,
        };
    }
};
/// Opaque handle to an `XLA_FFI_Error` owned by the XLA runtime.
pub const Error = opaque {
    pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
    pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;

    /// Creates a new FFI error with the given code and NUL-terminated message.
    /// NOTE(review): message lifetime requirements are not documented here —
    /// confirm against the XLA FFI C API before passing short-lived strings.
    pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error {
        var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{
            .message = message.ptr,
            .errc = @intFromEnum(error_code),
        });
        return fromInner(api.inner().XLA_FFI_Error_Create.?(&ret).?);
    }

    /// Destroys the error; `err` must not be used afterwards.
    pub fn destroy(err: *Error, api: *const Api) void {
        var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
        api.inner().XLA_FFI_Error_Destroy.?(&ret);
    }

    /// Returns the error's message. Presumably the slice points into memory
    /// owned by the error — do not use it after `destroy` (TODO confirm).
    pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
        var ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{
            .@"error" = err.inner(),
        });
        api.inner().XLA_FFI_Error_GetMessage.?(&ret);
        return std.mem.span(ret.message);
    }
};
/// Opaque handle to an async completion token (`XLA_FFI_Future`).
/// Failures are logged and collapsed into `error.Unknown` for now.
pub const Future = opaque {
    pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
    pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;

    /// Creates a new future through the FFI API.
    pub fn create(api: *const Api) ApiError!*Future {
        var args = pjrtStruct(c.XLA_FFI_Future_Create_Args{});
        if (api.inner().XLA_FFI_Future_Create.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[Future.create] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
        return fromInner(args.future.?);
    }

    /// Marks the future as successfully completed.
    pub fn setAvailable(self: *Future, api: *const Api) ApiError!void {
        var args = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
            .future = self.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetAvailable.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[Future.setAvailable] {s}", .{err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }

    /// Completes the future with the given error.
    pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void {
        var args = pjrtStruct(c.XLA_FFI_Future_SetError_Args{
            .future = self.inner(),
            .@"error" = err.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetError.?(&args)) |ffi_error| {
            const set_err = Error.fromInner(ffi_error);
            defer set_err.destroy(api);
            log.err("[Future.setError] {s}", .{set_err.getMessage(api)});
            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }
};

View File

@ -1,13 +1,14 @@
const builtin = @import("builtin");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const builtin = @import("builtin");
const c = @import("c"); const c = @import("c");
const stdx = @import("stdx");
pub const ffi = @import("ffi.zig");
pub const Profiler = @import("profiler.zig").Profiler;
const log = std.log.scoped(.pjrt); const log = std.log.scoped(.pjrt);
pub const Profiler = @import("profiler.zig").Profiler;
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }
@ -160,10 +161,9 @@ pub const Api = struct {
} }
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
return .{ .inner = ext.custom_call.? }; return .{ .inner = ext.register_handler.? };
} }
// log.warn("No Custom Call registry found for platform: {}", .{self});
return null; return null;
} }
@ -405,7 +405,7 @@ pub const Client = opaque {
} }
pub const CreateViewOfDeviceBufferArgs = struct { pub const CreateViewOfDeviceBufferArgs = struct {
data: []const u8, data: *anyopaque,
dims: []const i64, dims: []const i64,
element_type: BufferType, element_type: BufferType,
layout: MemoryLayout, layout: MemoryLayout,
@ -421,7 +421,7 @@ pub const Client = opaque {
const layout = args.layout.toCStruct(); const layout = args.layout.toCStruct();
const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{ const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{
.client = self.inner(), .client = self.inner(),
.device_buffer_ptr = @ptrCast(@constCast(args.data.ptr)), .device_buffer_ptr = @ptrCast(@constCast(args.data)),
.dims = args.dims.ptr, .dims = args.dims.ptr,
.num_dims = args.dims.len, .num_dims = args.dims.len,
.element_type = @intFromEnum(args.element_type), .element_type = @intFromEnum(args.element_type),
@ -919,18 +919,14 @@ pub const Memory = opaque {
const inner = InnerMixin(c.PJRT_Memory).inner; const inner = InnerMixin(c.PJRT_Memory).inner;
pub fn id(self: *const Memory, api: *const Api) usize { pub fn id(self: *const Memory, api: *const Api) usize {
const ret = api.call(.PJRT_Memory_Id, .{ const ret = api.call(.PJRT_Memory_Id, .{ .memory = self.inner() }) catch unreachable;
.memory = self.inner(),
}) catch unreachable;
return @intCast(ret.id); return @intCast(ret.id);
} }
pub fn kind(self: *const Memory, api: *const Api) Kind { pub fn kind(self: *const Memory, api: *const Api) Kind {
const ret = api.call(.PJRT_Memory_Kind, .{ const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
.memory = self.inner(), const kind_ = ret.kind orelse unreachable[0..ret.kind_size];
}) catch unreachable; return std.meta.stringToEnum(Kind, kind_) orelse unreachable;
const kind_ = ret.kind orelse unreachable;
return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable;
} }
pub fn kindId(self: *const Memory, api: *const Api) u32 { pub fn kindId(self: *const Memory, api: *const Api) u32 {
@ -1161,23 +1157,25 @@ pub const NamedValue = extern struct {
} }
}; };
/// Custom call signature arguments are:
/// * a pointer to a platform specific stream handle
/// * a pointer to an unspecified list of platform specific buffer handle
/// * a context struct passed as a slice of bytes
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
pub const CustomCallRegistry = extern struct { pub const CustomCallRegistry = extern struct {
inner: *const c.PJRT_Gpu_Register_Custom_Call, inner: *const c.PJRT_FFI_Register_Handler,
pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void { pub fn registerFfi(
var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{ self: *const CustomCallRegistry,
.function_name = name.ptr, api: *const Api,
.function_name_size = name.len, target_name: []const u8,
.api_version = @intCast(api_version), platform_name: []const u8,
.handler_execute = @ptrCast(@constCast(func)), func: *const ffi.Handler,
) ApiError!void {
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
.api_version = 1,
.target_name = target_name.ptr,
.target_name_size = target_name.len,
.handler = @ptrCast(@constCast(func)),
.platform_name = platform_name.ptr,
.platform_name_size = platform_name.len,
}); });
const result = self.inner(&ret); const result = self.inner(&ret);
if (result) |pjrt_c_error| { if (result) |pjrt_c_error| {

View File

@ -6,3 +6,9 @@ pub const math = @import("math.zig");
pub const meta = @import("meta.zig"); pub const meta = @import("meta.zig");
pub const queue = @import("queue.zig"); pub const queue = @import("queue.zig");
pub const time = @import("time.zig"); pub const time = @import("time.zig");
/// Returns a mutable slice of `len` elements of type `T` backed by a
/// fixed-capacity array of `max_len` elements. Asserts `len <= max_len`.
/// NOTE(review): `storage` is a local of this function; the returned slice
/// only stays valid after return because the function is `inline`, so the
/// array is expected to be materialized in the caller's stack frame —
/// confirm this lifetime guarantee holds for the targeted Zig version.
pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T {
    debug.assert(len <= max_len, "stackSlice can only create a slice of up to {} elements, got: {}", .{ max_len, len });
    // Uninitialized storage: the caller is expected to fill the returned slice.
    var storage: [max_len]T = undefined;
    return storage[0..len];
}

View File

@ -0,0 +1,37 @@
module(
name = "xla",
version = "20250317.1-71c67e2",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.17")
bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl", "python_version_repo")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
"llvm-raw",
"stablehlo",
)
llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project")

View File

@ -0,0 +1,37 @@
module(
name = "xla",
version = "20250317.1-71c67e2",
compatibility_level = 1,
)
bazel_dep(name = "platforms", version = "0.0.8")
bazel_dep(name = "bazel_skylib", version = "1.5.0")
bazel_dep(name = "rules_cc", version = "0.0.17")
bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple")
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
bazel_dep(name = "rules_python", version = "0.29.0")
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
bazel_dep(name = "rules_java", version = "7.3.2")
bazel_dep(name = "rules_pkg", version = "0.9.1")
bazel_dep(name = "zlib", version = "1.2.13")
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
bazel_dep(name = "rules_license", version = "0.0.8")
tsl = use_extension("//:tsl.bzl", "tsl")
use_repo(tsl, "tsl", "python_version_repo")
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
use_repo(
xla_workspace,
"com_github_grpc_grpc",
"com_google_protobuf",
"local_config_cuda",
"local_config_remote_execution",
"local_config_rocm",
"local_config_tensorrt",
"llvm-raw",
"stablehlo",
)
llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
llvm_configure(name = "llvm-project")

View File

@ -0,0 +1,19 @@
load("//third_party:repo.bzl", "tf_vendored")
load("//third_party/py:python_init_repositories.bzl", "python_init_repositories")
# Module extension that sets up Python requirement repositories and vendors
# TSL out of the XLA tree so other modules can depend on it as `@tsl`.
def _tsl_impl(mctx):
    # Pin the hermetic Python toolchain's requirements lock file.
    python_init_repositories(
        requirements = {
            "3.11": "//:requirements_lock_3_11.txt",
        },
    )
    # Expose third_party/tsl as its own repository, `@tsl`.
    tf_vendored(name = "tsl", relpath = "third_party/tsl")
    return mctx.extension_metadata(
        reproducible = True,
        # NOTE(review): only "tsl" is declared here, but the MODULE.bazel in
        # this commit also does `use_repo(tsl, "tsl", "python_version_repo")`
        # — verify Bazel accepts that mismatch or add "python_version_repo".
        root_module_direct_deps = ["tsl"],
        root_module_direct_dev_deps = [],
    )

tsl = module_extension(
    implementation = _tsl_impl,
)

View File

@ -0,0 +1,60 @@
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
load("//third_party/llvm:workspace.bzl", llvm = "repo")
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
load("//third_party/triton:workspace.bzl", triton = "repo")
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
# Module extension replacing XLA's legacy WORKSPACE setup: configures the
# local GPU/remote-execution toolchains and fetches vendored dependencies.
def _xla_workspace_impl(mctx):
    # Auto-detected local toolchain configuration repositories.
    cuda_configure(name = "local_config_cuda")
    remote_execution_configure(name = "local_config_remote_execution")
    rocm_configure(name = "local_config_rocm")
    tensorrt_configure(name = "local_config_tensorrt")
    # Vendored third-party repositories with their own workspace macros.
    pybind11_bazel()
    triton()
    llvm("llvm-raw")
    stablehlo()
    # gRPC at a pinned commit, with patches and system-library overrides.
    tf_http_archive(
        name = "com_github_grpc_grpc",
        sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
        strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
        system_build_file = "//third_party/systemlibs:grpc.BUILD",
        patch_file = [
            "//third_party/grpc:generate_cc_env_fix.patch",
            "//third_party/grpc:register_go_toolchain.patch",
        ],
        system_link_files = {
            "//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel",
            "//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
            "//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
            "//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
            "//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
            "//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
            "//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
        },
        urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
    )
    # protobuf pinned to v3.21.9, patched, with system-library overrides.
    tf_http_archive(
        name = "com_google_protobuf",
        patch_file = ["//third_party/protobuf:protobuf.patch"],
        sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
        strip_prefix = "protobuf-3.21.9",
        system_build_file = "//third_party/systemlibs:protobuf.BUILD",
        system_link_files = {
            "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
            "//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
        },
        urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
    )
    return mctx.extension_metadata(
        reproducible = True,
        # "all": every repo defined above may be used by the root module.
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )

xla_workspace = module_extension(
    implementation = _xla_workspace_impl,
)

View File

@ -0,0 +1,41 @@
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
From: Jean-Baptiste Dalido <jb@zml.ai>
Date: Mon, 6 Jan 2025 13:33:13 +0100
Subject: [PATCH] bazel: migration to bazel 8.1.1
---
.bazelversion | 2 +-
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/.bazelversion b/.bazelversion
index f22d756da3..fa5fce04b3 100644
--- a/.bazelversion
+++ b/.bazelversion
@@ -1 +1 @@
-7.4.1
+8.1.1
\ No newline at end of file
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index d62531152d..71d80a5a99 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
load(
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
"escape_string",
- "get_env_var",
)
load(
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
- "find_msvc_tool",
"find_vc_path",
"setup_vc_env_vars",
)
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
load(
"//third_party/remote_config:common.bzl",
--
2.39.3 (Apple Git-146)

View File

@ -0,0 +1,131 @@
From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001
From: Hugo Mano <hugo@zml.ai>
Date: Wed, 5 Feb 2025 19:25:03 +0100
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
PR: https://github.com/openxla/xla/pull/13420
---
xla/pjrt/c/BUILD | 5 ++++
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++--
3 files changed, 54 insertions(+), 2 deletions(-)
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
index ad2ed95bce..0e7c35c30f 100644
--- a/xla/pjrt/c/BUILD
+++ b/xla/pjrt/c/BUILD
@@ -69,7 +69,12 @@ cc_library(
":pjrt_c_api_wrapper_impl",
"//xla/ffi:execution_context",
"//xla/ffi:type_id_registry",
+ "//xla/ffi:ffi_api",
+ "//xla/ffi/api:c_api",
+ "//xla/ffi/api:ffi",
+ "//xla/service:custom_call_target_registry",
"@com_google_absl//absl/status",
+ "@com_google_absl//absl/strings:str_format",
],
)
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
index a33bd4aa9c..3309194538 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
// Adds a user data to the execute context.
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
+struct PJRT_FFI_Register_Handler_Args {
+ size_t struct_size;
+ const char* target_name;
+ size_t target_name_size;
+ int api_version; // 0 for an untyped call, 1 -- for typed
+ void* handler;
+ const char* platform_name;
+ size_t platform_name_size;
+};
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, handler);
+
+// Registers an FFI call handler for a specific platform.
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args);
+
typedef struct PJRT_FFI_Extension {
size_t struct_size;
PJRT_Extension_Type type;
PJRT_Extension_Base* next;
PJRT_FFI_TypeID_Register* type_id_register;
PJRT_FFI_UserData_Add* user_data_add;
+ PJRT_FFI_Register_Handler* register_handler;
} PJRT_FFI;
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
index 0375b39d0b..3527a0756e 100644
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
+#include <string>
#include "absl/status/status.h"
+#include "absl/strings/str_format.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
#include "xla/ffi/execution_context.h"
-#include "xla/ffi/type_id_registry.h"
+ #include "xla/ffi/type_id_registry.h"
+#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
+#include "xla/service/custom_call_target_registry.h"
namespace pjrt {
@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
return nullptr;
}
+static PJRT_Error* PJRT_FFI_Register_Handler(
+ PJRT_FFI_Register_Handler_Args* args) {
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
+ "PJRT_FFI_Register_Handler_Args",
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
+ std::string target_name(args->target_name, args->target_name_size);
+ std::string platform_name(args->platform_name, args->platform_name_size);
+ switch (args->api_version) {
+ case 0:
+ xla::CustomCallTargetRegistry::Global()->Register(
+ target_name, args->handler, platform_name);
+ return nullptr;
+ case 1:
+ xla::ffi::Ffi::RegisterStaticHandler(
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
+ return nullptr;
+ default:
+ return new PJRT_Error{absl::UnimplementedError(
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
+ "Supported versions are 0 and 1.",
+ args->api_version))};
+ }
+}
+
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
return {
/*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE,
@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
/*next=*/next,
/*type_id_register=*/PJRT_FFI_TypeID_Register,
/*user_data_add=*/PJRT_FFI_UserData_Add,
+ /*register_handler=*/PJRT_FFI_Register_Handler,
};
}
--
2.43.0

View File

@ -0,0 +1,15 @@
{
"strip_prefix": "xla-71c67e2a4f40267115a0d4ea7c36748bbe7e750e",
"url": "https://github.com/openxla/xla/archive/71c67e2a4f40267115a0d4ea7c36748bbe7e750e.tar.gz",
"integrity": "sha256-j6D1MC7+WsbZ+Ve3hPmDlCzZX1yV6RIP2BWDgQcbcYc=",
"overlay": {
"tsl.bzl": "",
"workspace.bzl": "",
"MODULE.bazel": ""
},
"patch_strip": 1,
"patches": {
"0001-bazel-migration-to-bazel-8.1.1.patch": "",
"0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": ""
}
}

View File

@ -1,16 +1,15 @@
const asynk = @import("async");
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx");
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const testing = std.testing; const testing = std.testing;
const asynk = @import("async");
const stdx = @import("stdx");
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const meta = @import("meta.zig");
const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
@ -27,10 +26,22 @@ const log = std.log.scoped(.zml);
/// * loading weights from disk directly to the `device zml.aio.loadBuffers` /// * loading weights from disk directly to the `device zml.aio.loadBuffers`
/// * can be created by calling `HostBuffer.toDevice(platform)`. /// * can be created by calling `HostBuffer.toDevice(platform)`.
pub const Buffer = struct { pub const Buffer = struct {
pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).@"enum".tag_type) { pub const Memory = enum {
host = @intFromEnum(pjrt.Memory.Kind.unpinned_host), host,
host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host), host_pinned,
device = @intFromEnum(pjrt.Memory.Kind.device), device,
pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind {
return switch (self) {
.host => .unpinned_host,
.host_pinned => .pinned_host,
.device => .device,
};
}
pub fn pjrtName(self: Memory) []const u8 {
return @tagName(self.toPjrtMemory());
}
}; };
pub const Shard = struct { pub const Shard = struct {
@ -216,13 +227,13 @@ pub const Buffer = struct {
/// and it might not work on all platforms, /// and it might not work on all platforms,
/// could lead to crashes and operations on the buffer will be slower. /// could lead to crashes and operations on the buffer will be slower.
/// Tested on Cuda 12.4. /// Tested on Cuda 12.4.
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) !Buffer { pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr))); return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
} }
/// Creates a Buffer from a pointer into device memory. /// Creates a Buffer from a pointer into device memory.
/// This allows to interface with other libraries producing buffers. /// This allows to interface with other libraries producing buffers.
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) !Buffer { pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer {
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
var res: [Shape.MAX_RANK]i64 = undefined; var res: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| { for (0..Shape.MAX_RANK) |i| {
@ -231,9 +242,8 @@ pub const Buffer = struct {
break :blk res; break :blk res;
}; };
const device_bytes: [*]u8 = @ptrCast(device_data); const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ .data = device_data,
.data = device_bytes[0..shape_.byteSize()],
.element_type = bufferTypeFromDtype(shape_.dtype()), .element_type = bufferTypeFromDtype(shape_.dtype()),
.dims = shape_.dims(), .dims = shape_.dims(),
// TODO: exposes sharding in the API. // TODO: exposes sharding in the API.
@ -246,7 +256,7 @@ pub const Buffer = struct {
}, },
}, },
.stream = @bitCast(@as(usize, @intFromPtr(stream))), .stream = @bitCast(@as(usize, @intFromPtr(stream))),
}); }) catch @panic("failed to createViewOfDeviceBuffer");
var shards: Shards = .{}; var shards: Shards = .{};
shards.appendAssumeCapacity(pjrt_buffer); shards.appendAssumeCapacity(pjrt_buffer);
@ -342,6 +352,11 @@ pub const Buffer = struct {
try writer.print("Buffer({_})", .{self._shape}); try writer.print("Buffer({_})", .{self._shape});
} }
/// Returns the PJRT memory that the first shard of this buffer resides in.
pub fn getMemory(self: Buffer) *const pjrt.Memory {
    return self._shards.get(0).memory(self._api);
}
fn hasShardedAxis(self: Buffer) bool { fn hasShardedAxis(self: Buffer) bool {
if (self._shards.len == 1) return false; if (self._shards.len == 1) return false;
return @reduce(.Or, self._shape._sharding_info); return @reduce(.Or, self._shape._sharding_info);

View File

@ -1,21 +1,23 @@
const std = @import("std");
const builtin = @import("builtin"); const builtin = @import("builtin");
const c = @import("c"); const c = @import("c");
const mlir = @import("mlir"); const mlir = @import("mlir");
const runfiles = @import("runfiles"); const runfiles = @import("runfiles");
const runtimes = @import("runtimes"); const runtimes = @import("runtimes");
const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const zml_platform = @import("platform.zig"); const Buffer = @import("buffer.zig").Buffer;
const pjrt = @import("pjrtx.zig"); const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); const pjrt = @import("pjrtx.zig");
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const PlatformsMap = std.EnumArray(Target, ?Platform); const Shape = @import("shape.zig").Shape;
const Target = @import("platform.zig").Target; const Target = @import("platform.zig").Target;
const zml_platform = @import("platform.zig");
const available_targets = @import("platform.zig").available_targets; const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
const PlatformsMap = std.EnumArray(Target, ?Platform);
const log = std.log.scoped(.@"zml/context"); const log = std.log.scoped(.@"zml/context");
test { test {
@ -174,10 +176,8 @@ pub const Context = struct {
log.err("No device found for platform {} !", .{target}); log.err("No device found for platform {} !", .{target});
return error.NoDevicesFound; return error.NoDevicesFound;
} }
// TODO: should this be moved to platform.zig ?
if (target == .cuda) { try CustomCall.registerZmlCustomCalls(p);
try cuda.registerZmlCustomCalls(p);
}
self.platforms.set(target, p); self.platforms.set(target, p);
return p; return p;
@ -213,77 +213,68 @@ pub const Context = struct {
} }
} }
pub const HostCallbackCtx = struct { pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void;
host: HostBuffer,
mutex: std.Thread.Mutex = std.Thread.Mutex{},
};
pub const HostCallback = fn (HostBuffer) void;
}; };
const cuda = struct { const CustomCall = struct {
var runtime: Runtime = undefined; pub fn registerZmlCustomCalls(platform: Platform) !void {
const registry = platform.pjrt_api.customCallRegistry();
pub fn registerZmlCustomCalls(cuda_platform: Platform) !void { if (registry) |reg| {
std.debug.assert(cuda_platform.target == .cuda); try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback);
} else {
cuda.runtime = try Runtime.init(); stdx.debug.panic("Registering custom calls failed", .{});
const registry = cuda_platform.pjrt_api.customCallRegistry().?; }
try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback);
} }
pub const Stream = opaque {}; fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error {
pub const MemcpyKind = enum(c_int) { if (call_frame.registeringHook()) return null;
host_to_host = 0,
host_to_device = 1,
device_to_host = 2,
device_to_device = 3,
default = 4,
};
pub const Runtime = struct { const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
memcpyAsync: MemcpyAsync, std.debug.assert(callback_attr.dtype == .u64);
streamSynchronize: StreamSynchronize, const callback: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize));
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int; const user_ctx_ptr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable;
const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int; std.debug.assert(user_ctx_ptr.dtype == .u64);
const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_ptr.get(usize));
pub fn init() !Runtime { const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
var cudart = try std.DynLib.open("libcudart.so.12"); for (input_buffers, 0..) |*b, i| {
defer cudart.close(); b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i));
return .{
.memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound,
.streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound,
};
}
};
fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
std.debug.assert(args_len == @sizeOf(*anyopaque) * 2);
const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*);
const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr);
const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*);
const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr);
return .{ fn_ptr, ctx_ptr };
} }
fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void { const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
const stream: *Stream = @ptrCast(opaque_stream); for (output_buffers, 0..) |*b, i| {
const src: *anyopaque = buffers[0]; b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i));
const callback, const ctx = getContext(args, args_len); }
// Add synchronization because this is called from the device driver. callback(user_ctx, input_buffers, output_buffers);
ctx.mutex.lock(); return null;
defer ctx.mutex.unlock();
const host_dst: []u8 = @constCast(ctx.host.data);
const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
_ = memcpy_result;
const synchronize_result = cuda.runtime.streamSynchronize(stream);
_ = synchronize_result;
callback(ctx.host);
} }
}; };
/// Builds a zml `Shape` (dims + dtype) from an FFI buffer descriptor.
fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
    // log.warn("received buffer {}", .{buffer_desc});
    // Translate the FFI dtype tag into zml's `DataType`.
    const dt: DataType = switch (buffer_desc.dtype) {
        .invalid => @panic("invalid ffi"),
        // Tags whose names differ between the FFI enum and `DataType`
        // are mapped explicitly.
        .pred => .bool,
        .s8 => .i8,
        .s16 => .i16,
        .s32 => .i32,
        .s64 => .i64,
        .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
        // Every remaining tag shares its name with a `DataType` member,
        // so resolve it by name at comptime.
        inline else => |t| @field(DataType, @tagName(t)),
    };
    return Shape.init(buffer_desc.dims(), dt);
}
/// Wraps the memory described by an FFI buffer descriptor into a `HostBuffer`.
/// The FFI normally describes device buffers, but we assume they live in
/// pinned memory, so the same data pointer is readable from the host as well.
fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
    const shape = getShape(buffer_desc);
    const bytes = buffer_desc.data[0..shape.byteSize()];
    return HostBuffer.fromBytes(shape, bytes);
}

View File

@ -1,4 +1,5 @@
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const stdx = @import("stdx");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
@ -98,6 +99,13 @@ pub const HostBuffer = struct {
}; };
} }
/// Returns a copy of this HostBuffer whose shape carries the tags in `tagz`.
/// Only the shape metadata changes; the underlying data is shared.
pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer {
    var tagged = self;
    tagged._shape = self._shape.withTags(tagz);
    return tagged;
}
pub const ArangeArgs = struct { pub const ArangeArgs = struct {
start: i64 = 0, start: i64 = 0,
end: i64, end: i64,
@ -240,6 +248,11 @@ pub const HostBuffer = struct {
}; };
} }
/// Selects the element at index `start` along `axis_` and drops that axis:
/// a one-element `slice1d` followed by a `squeeze`.
pub fn choose1d(self: HostBuffer, axis_: anytype, start: i64) HostBuffer {
    const resolved_axis = self.axis(axis_);
    const single = self.slice1d(resolved_axis, .{ .start = start, .end = start + 1 });
    return single.squeeze(resolved_axis);
}
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
const ax = self._shape.axis(axis_); const ax = self._shape.axis(axis_);
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });

View File

@ -1,11 +1,10 @@
const std = @import("std"); const std = @import("std");
const stdx = @import("stdx"); const testing = std.testing;
const stdx = @import("stdx");
const FnParam = stdx.meta.FnParam; const FnParam = stdx.meta.FnParam;
const asSlice = stdx.meta.asSlice; const asSlice = stdx.meta.asSlice;
const testing = std.testing;
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
} }

View File

@ -367,6 +367,7 @@ pub const CompilationContext = struct {
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count); const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count); const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count); const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
const fn_res_output_memory_kind = try res_allocator.alloc(Buffer.Memory, out_tensor_count);
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
{ {
defer self.closeBlock(fn_body); defer self.closeBlock(fn_body);
@ -382,7 +383,7 @@ pub const CompilationContext = struct {
}; };
var fn_res_values: [out_tensor_count]mlir.Value = undefined; var fn_res_values: [out_tensor_count]mlir.Value = undefined;
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations, fn_res_output_memory_kind);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]); fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
@ -396,6 +397,7 @@ pub const CompilationContext = struct {
if (opts.kind == .main) { if (opts.kind == .main) {
self.addDonationsAttributes(arg_attrs, fn_res_donations); self.addDonationsAttributes(arg_attrs, fn_res_donations);
self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
if (self._platform.sharding().num_partitions > 1) { if (self._platform.sharding().num_partitions > 1) {
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
} }
@ -433,6 +435,20 @@ pub const CompilationContext = struct {
}; };
} }
fn addOutputMemoryKindAttributes(self: CompilationContext, attributes: []AttributeList, output_memory_kind: []const Buffer.Memory) void {
const mlir_ctx = self.mlirCtx();
for (attributes, output_memory_kind) |*attr, memory_kind| {
// .device is the default output, don't explicitly emit the attribute
if (memory_kind == .device) continue;
attr.appendAssumeCapacity(.named(
mlir_ctx,
"mhlo.memory_kind",
.string(mlir_ctx, memory_kind.pjrtName()),
));
}
}
/// Given a list of donations mapping output buffers to input buffers, /// Given a list of donations mapping output buffers to input buffers,
/// generate donation attribute for each `n_args` input argument. /// generate donation attribute for each `n_args` input argument.
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void { fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
@ -712,7 +728,15 @@ pub const CompilationContext = struct {
} }
/// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found. /// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found.
pub fn extractValuesAndTypes(self: *const CompilationContext, v: anytype, values: []mlir.Value, types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation) void { pub fn extractValuesAndTypes(
self: *const CompilationContext,
v: anytype,
values: []mlir.Value,
types: []mlir.Type,
shapes: []Shape,
donations: []Tensor._Donation,
output_memory_kind: []Buffer.Memory,
) void {
std.debug.assert(values.len == types.len); std.debug.assert(values.len == types.len);
const LocalContext = struct { const LocalContext = struct {
self: *const CompilationContext, self: *const CompilationContext,
@ -721,8 +745,16 @@ pub const CompilationContext = struct {
types: []mlir.Type, types: []mlir.Type,
shapes: []Shape, shapes: []Shape,
donations: []Tensor._Donation, donations: []Tensor._Donation,
output_memory_kind: []Buffer.Memory,
};
var context = LocalContext{
.self = self,
.values = values,
.types = types,
.shapes = shapes,
.donations = donations,
.output_memory_kind = output_memory_kind,
}; };
var context = LocalContext{ .self = self, .values = values, .types = types, .shapes = shapes, .donations = donations };
meta.visit((struct { meta.visit((struct {
fn cb(ctx: *LocalContext, tensor: *const Tensor) void { fn cb(ctx: *LocalContext, tensor: *const Tensor) void {
const value, const donation = ctx.self.getValueAndDonation(tensor.*); const value, const donation = ctx.self.getValueAndDonation(tensor.*);
@ -730,6 +762,7 @@ pub const CompilationContext = struct {
ctx.types[ctx.index] = value.getType(); ctx.types[ctx.index] = value.getType();
ctx.shapes[ctx.index] = tensor._shape; ctx.shapes[ctx.index] = tensor._shape;
ctx.donations[ctx.index] = donation; ctx.donations[ctx.index] = donation;
ctx.output_memory_kind[ctx.index] = tensor._output_memory_kind;
ctx.index += 1; ctx.index += 1;
} }
}).cb, &context, v); }).cb, &context, v);

View File

@ -1,16 +1,16 @@
const std = @import("std"); const std = @import("std");
const Context = @import("../context.zig").Context;
const module = @import("../module.zig");
const mlir = @import("../mlir.zig");
const dialect = @import("mlir/dialects"); const dialect = @import("mlir/dialects");
const Tensor = @import("../tensor.zig").Tensor; const Context = @import("../context.zig").Context;
const Shape = @import("../shape.zig").Shape;
const SdpaOpts = @import("../nn.zig").SdpaOpts;
const DataType = @import("../dtype.zig").DataType; const DataType = @import("../dtype.zig").DataType;
const Data = @import("../dtype.zig").Data; const Data = @import("../dtype.zig").Data;
const mlir = @import("../mlir.zig");
const module = @import("../module.zig");
const CompilationContext = module.CompilationContext; const CompilationContext = module.CompilationContext;
const SdpaOpts = @import("../nn.zig").SdpaOpts;
const Shape = @import("../shape.zig").Shape;
const Tensor = @import("../tensor.zig").Tensor;
pub fn canUseCudnnSdpa(q_shape: Shape) bool { pub fn canUseCudnnSdpa(q_shape: Shape) bool {
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
@ -125,7 +125,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
&.{ q.value(), k.value(), v.value(), bias.value() }, &.{ q.value(), k.value(), v.value(), bias.value() },
.{ .{
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax", .call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
.backend_config = .{ .string = backend_config }, .backend_config = .string(mlir_ctx, backend_config),
.has_side_effect = false, .has_side_effect = false,
.api_version = .original, .api_version = .original,
}, },

View File

@ -1,28 +1,31 @@
const std = @import("std"); const std = @import("std");
const assert = std.debug.assert;
const stdx = @import("stdx"); const stdx = @import("stdx");
const _collectAxes = @import("tensor.zig")._collectAxes;
const buffer = @import("buffer.zig"); const buffer = @import("buffer.zig");
const helpers = @import("helpers.zig");
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const module = @import("module.zig");
const Buffer = buffer.Buffer; const Buffer = buffer.Buffer;
const CompilationContext = module.CompilationContext; const Bufferized = @import("tensor.zig").Bufferized;
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const EnumLiteral = @TypeOf(.enum_literal); const helpers = @import("helpers.zig");
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const meta = @import("meta.zig");
const mlir = @import("mlir.zig");
const module = @import("module.zig");
const CompilationContext = module.CompilationContext;
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const ShapeOf = @import("tensor.zig").ShapeOf;
const Tensor = @import("tensor.zig").Tensor; const Tensor = @import("tensor.zig").Tensor;
const _collectAxes = @import("tensor.zig")._collectAxes;
const EnumLiteral = @TypeOf(.enum_literal);
const dialect = struct { const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo; const stablehlo = @import("mlir/dialects").stablehlo;
}; };
const assert = std.debug.assert;
const log = std.log.scoped(.@"zml/tensor"); const log = std.log.scoped(.@"zml/tensor");
test { test {
@ -766,50 +769,56 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
return res; return res;
} }
/// At runtime the given tensor will be materialized and copied to host, pub const HostCallbackOpt = struct {
/// and the callback will be called on it. has_side_effect: bool = false,
output_operand_aliases: []const i64 = &.{},
};
pub fn addHostCallback( pub fn addHostCallback(
callback: *const fn (HostBuffer) void, callback: *const Context.HostCallback,
input: Tensor, blkctx: ?*anyopaque,
) Tensor { inputs: []const Tensor,
// TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy. output_shapes: []const Shape,
if (input.getContext().target() != .cuda) return input; opts: HostCallbackOpt,
) []Tensor {
const len = input.byteSize();
// Reserve memory to be able to log the runtime Buffer later during the computation.
// This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled.
const HostCallbackCtx = Context.HostCallbackCtx;
const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch {
log.err("Failed to pre-allocate buffer to print {}.", .{input});
return input;
};
// Save the HostBuffer inside the same memory slice, so that it's still present at runtime.
// Use an fba to have the stable buffer at an aligned offset.
var fba = std.heap.FixedBufferAllocator.init(full_data[len..]);
const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable;
stable_ctx_ptr.* = .{
.host = HostBuffer.fromBytes(input.shape(), full_data[0..len]),
};
const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr };
const ctx = CompilationContext.current(); const ctx = CompilationContext.current();
const mlir_ctx = ctx.mlirCtx();
const backend_config = mlir.Attribute.dict(mlir_ctx, &.{
.{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) },
.{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) },
});
const values = stdx.stackSlice(8, mlir.Value, inputs.len);
for (inputs, values) |i, *v| {
v.* = ctx.getValue(i.toMemory(.host_pinned));
}
const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len);
for (res_types, output_shapes) |*r, o| {
r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type);
}
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
const op = dialect.stablehlo.custom_call( const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(), ctx.mlirCtx(),
&.{input.value()}, values,
.{ .{
.has_side_effect = false,
.call_target_name = "zmlHostBufferCallback", .call_target_name = "zmlHostBufferCallback",
.backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) }, .api_version = .typed_ffi,
.output_operand_aliases = &.{0}, .backend_config = backend_config,
.api_version = .original, .has_side_effect = opts.has_side_effect,
.output_operand_aliases = opts.output_operand_aliases,
}, },
&.{input.value().getType()}, res_types,
loc, loc,
); );
return Tensor._result(input.shape(), op.result(0));
const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM");
for (res, output_shapes, 0..) |*r, o, i| {
r.* = Tensor._result(o, op.result(i)).toMemory(.device);
}
return res;
} }
pub const TritonOps = struct { pub const TritonOps = struct {
@ -834,46 +843,32 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
} }
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{ const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{
.named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)), .{ "name", .string(ctx.mlirCtx(), opts.name) },
.named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)), .{ "ir", .string(ctx.mlirCtx(), opts.ir) },
.named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])), .{ "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0]) },
.named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])), .{ "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1]) },
.named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])), .{ "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2]) },
.named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)), .{ "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages) },
.named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)), .{ "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps) },
}); });
const MINOR_TO_MAJOR = blk: { var operands_layouts: [inputs.len][]const usize = undefined;
var ret: [Shape.MAX_RANK]usize = undefined;
for (0..Shape.MAX_RANK) |i| {
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
}
break :blk ret;
};
const operands_layouts = blk: {
var ret: [inputs.len][]const usize = undefined;
inline for (inputs, 0..) |input, i| { inline for (inputs, 0..) |input, i| {
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..]; operands_layouts[i] = minorToMajor(input.rank());
} }
break :blk ret;
};
const results_layouts = blk: { var results_layouts: [outputs.len][]const usize = undefined;
var ret: [outputs.len][]const usize = undefined;
inline for (outputs, 0..) |output, i| { inline for (outputs, 0..) |output, i| {
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..]; results_layouts[i] = minorToMajor(output.rank());
} }
break :blk ret;
};
const op = dialect.stablehlo.custom_call( const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(), ctx.mlirCtx(),
&values, &values,
.{ .{
.call_target_name = "__gpu$xla.gpu.triton", .call_target_name = "__gpu$xla.gpu.triton",
.backend_config = .{ .dict = attrs }, .backend_config = backend_config,
.has_side_effect = false, .has_side_effect = false,
.api_version = .typed_ffi, .api_version = .typed_ffi,
.operand_layouts = &operands_layouts, .operand_layouts = &operands_layouts,
@ -1256,3 +1251,15 @@ inline fn toI64(values: anytype) []i64 {
for (values, 0..) |val, i| res[i] = @intCast(val); for (values, 0..) |val, i| res[i] = @intCast(val);
return res[0..values.len]; return res[0..values.len];
} }
const _MINOR_TO_MAJOR = blk: {
var ret: [Shape.MAX_RANK]usize = undefined;
for (0..Shape.MAX_RANK) |i| {
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
}
break :blk ret;
};
fn minorToMajor(rank: u8) []const usize {
return _MINOR_TO_MAJOR[_MINOR_TO_MAJOR.len - rank ..];
}

View File

@ -1,19 +1,10 @@
const std = @import("std");
const asynk = @import("async"); const asynk = @import("async");
const builtin = @import("builtin");
const dialects = @import("mlir/dialects"); const dialects = @import("mlir/dialects");
const mlir = @import("mlir"); const mlir = @import("mlir");
const pjrt = @import("pjrt"); const pjrt = @import("pjrt");
const std = @import("std"); pub const ffi = pjrt.ffi;
const stdx = @import("stdx");
const c = @import("c");
const dtype = @import("dtype.zig");
const meta = @import("meta.zig");
const Target = @import("platform.zig").Target;
const log = std.log.scoped(.zml);
pub const Profiler = pjrt.Profiler; pub const Profiler = pjrt.Profiler;
pub const ApiError = pjrt.ApiError; pub const ApiError = pjrt.ApiError;
pub const ErrorCode = pjrt.ErrorCode; pub const ErrorCode = pjrt.ErrorCode;
@ -23,7 +14,6 @@ pub const DeviceDescription = pjrt.DeviceDescription;
pub const Api = pjrt.Api; pub const Api = pjrt.Api;
pub const NamedValue = pjrt.NamedValue; pub const NamedValue = pjrt.NamedValue;
pub const ClientInitError = pjrt.ClientInitError; pub const ClientInitError = pjrt.ClientInitError;
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;
pub const Error = pjrt.Error; pub const Error = pjrt.Error;
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
pub const SerializeResult = pjrt.SerializeResult; pub const SerializeResult = pjrt.SerializeResult;
@ -31,6 +21,10 @@ pub const Executable = pjrt.Executable;
pub const ExecuteError = ApiError; pub const ExecuteError = ApiError;
pub const Memory = pjrt.Memory; pub const Memory = pjrt.Memory;
const log = std.log.scoped(.zml);
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;
fn InnerMixin(comptime innerT: type) type { fn InnerMixin(comptime innerT: type) type {
return struct { return struct {
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT { inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT {
@ -159,6 +153,10 @@ pub const Buffer = opaque {
return self.inner().isOnCpu(api); return self.inner().isOnCpu(api);
} }
pub fn memory(self: *const Buffer, api: *const Api) *const Memory {
return self.inner().memory(api);
}
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event { pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
return @ptrCast(try self.inner().toHostBuffer(api, dst)); return @ptrCast(try self.inner().toHostBuffer(api, dst));
} }
@ -183,8 +181,8 @@ pub const Buffer = opaque {
return @ptrCast(self.inner().copyToDevice(api, device)); return @ptrCast(self.inner().copyToDevice(api, device));
} }
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory: *const Memory) ApiError!*Buffer { pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
return @ptrCast(self.inner().copyToMemory(api, memory)); return @ptrCast(self.inner().copyToMemory(api, memory_));
} }
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event { pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {

View File

@ -9,6 +9,7 @@ const Buffer = @import("buffer.zig").Buffer;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Memory = @import("buffer.zig").Buffer.Memory;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const mlir = @import("mlir.zig"); const mlir = @import("mlir.zig");
const Location = mlir.Location; const Location = mlir.Location;
@ -41,10 +42,10 @@ pub const Tensor = struct {
_shape: Shape, _shape: Shape,
_id: _Id, _id: _Id,
_donation: _Donation = .no_buffer, _donation: _Donation = .no_buffer,
_output_memory_kind: Memory = .device,
pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 }; pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 };
pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 }; pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 };
pub const MAX_RANK = Shape.MAX_RANK; pub const MAX_RANK = Shape.MAX_RANK;
/// Returns the current compilation context. /// Returns the current compilation context.
@ -171,20 +172,22 @@ pub const Tensor = struct {
return switch (self._id) { return switch (self._id) {
.arg_id, .mlir => { .arg_id, .mlir => {
const ctx = self.getContext(); const ctx = self.getContext();
const mlir_ctx = ctx.mlirCtx();
var res = self; var res = self;
res._shape = self._shape.withSharding(axes_); res._shape = self._shape.withSharding(axes_);
const op = dialect.stablehlo.custom_call( const op = dialect.stablehlo.custom_call(
ctx.mlirCtx(), mlir_ctx,
&.{self.value()}, &.{self.value()},
.{ .{
.call_target_name = "Sharding", .call_target_name = "Sharding",
.has_side_effect = false, .has_side_effect = false,
.addional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }}, .backend_config = null,
.additional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }},
.api_version = .original, .api_version = .original,
}, },
&.{self.value().getType()}, &.{self.value().getType()},
ctx.mlirCtx().location(@src()), mlir_ctx.location(@src()),
); );
return _result(res._shape, op.result(0)); return _result(res._shape, op.result(0));
@ -197,6 +200,39 @@ pub const Tensor = struct {
}; };
} }
pub fn toMemory(self: Tensor, kind: Memory) Tensor {
return switch (self._id) {
.arg_id, .mlir => {
const ctx = self.getContext();
const mlir_ctx = ctx.mlirCtx();
if (ctx.target() == .cpu) return self;
var res = self;
res._output_memory_kind = kind;
const memory_kind = @tagName(kind.toPjrtMemory());
const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{
.{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) },
});
const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{
.call_target_name = "annotate_device_placement",
.has_side_effect = true,
.backend_config = null,
.additional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
.api_version = .original,
}, &.{self.value().getType()}, mlir_ctx.location(@src()));
return _result(res._shape, op.result(0));
},
.buffer_id => {
var res = self;
res._output_memory_kind = kind;
return res;
},
};
}
/// Returns a Tensor with new tag names. /// Returns a Tensor with new tag names.
pub fn rename(self: Tensor, renames: anytype) Tensor { pub fn rename(self: Tensor, renames: anytype) Tensor {
var res = self; var res = self;
@ -3747,18 +3783,22 @@ pub const Tensor = struct {
} }
/// Insert code that will print the content of the given buffer at runtime. /// Insert code that will print the content of the given buffer at runtime.
/// Only for debug purpose, it has the following limitations: /// Only for debug purpose, it inserts device to host synchronization
/// * only supported on Cuda atm /// so it will slow down the program execution.
/// * only prints the first 1024 values
/// * pre allocates a buffer on the host to copy the content of the device buffer,
/// this buffer won't be freed. You will have one buffer per "print" call in the IR.
/// * does device to host synchronization so it will slow down the program execution.
pub fn print(input: Tensor) Tensor { pub fn print(input: Tensor) Tensor {
return ops.addHostCallback(&printCallback, input); return ops.addHostCallback(
&printCallback,
null,
&.{input},
&.{input.shape()},
.{ .output_operand_aliases = &.{0} },
)[0];
} }
fn printCallback(host_buffer: HostBuffer) void { fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
const host_buffer = inputs[0];
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr);
} }
}; };