pjrt: add FFI bindings for custom calls
This commit is contained in:
parent
1f5ff96c10
commit
aec7072837
@ -24,7 +24,7 @@ bazel_dep(name = "rules_zig", version = "20250314.0-b9739c6")
|
|||||||
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
||||||
bazel_dep(name = "toolchains_protoc", version = "0.3.7")
|
bazel_dep(name = "toolchains_protoc", version = "0.3.7")
|
||||||
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
||||||
bazel_dep(name = "xla", version = "20250317.0-71c67e2")
|
bazel_dep(name = "xla", version = "20250317.1-71c67e2")
|
||||||
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
|
bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e")
|
||||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||||
|
|
||||||
|
|||||||
@ -739,18 +739,13 @@ pub const CustomCallOpts = struct {
|
|||||||
typed_ffi = 4,
|
typed_ffi = 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const BackendConfig = union(enum) {
|
|
||||||
string: [:0]const u8,
|
|
||||||
dict: mlir.DictionaryAttribute,
|
|
||||||
};
|
|
||||||
|
|
||||||
call_target_name: [:0]const u8,
|
call_target_name: [:0]const u8,
|
||||||
has_side_effect: bool,
|
has_side_effect: bool,
|
||||||
backend_config: BackendConfig = .{ .string = &.{} },
|
backend_config: ?mlir.Attribute,
|
||||||
operand_layouts: ?[]const []const usize = null,
|
operand_layouts: ?[]const []const usize = null,
|
||||||
result_layouts: ?[]const []const usize = null,
|
result_layouts: ?[]const []const usize = null,
|
||||||
output_operand_aliases: []const i64 = &.{},
|
output_operand_aliases: []const i64 = &.{},
|
||||||
addional_attributes: []const mlir.AttrTuple = &.{},
|
additional_attributes: []const mlir.AttrTuple = &.{},
|
||||||
api_version: ApiVersion,
|
api_version: ApiVersion,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -758,68 +753,56 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
const MAX_OPERANDS = 64;
|
const MAX_OPERANDS = 64;
|
||||||
const MAX_RESULTS = 16;
|
const MAX_RESULTS = 16;
|
||||||
|
|
||||||
const output_operand_aliases = blk: {
|
const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, "");
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
|
||||||
for (opts.output_operand_aliases) |alias| {
|
stdx.debug.assert(
|
||||||
ret.appendAssumeCapacity(
|
backend_config.is_a(mlir.StringAttribute),
|
||||||
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
|
"API version < 4 requires a string as backend_config, got {}",
|
||||||
);
|
.{backend_config},
|
||||||
}
|
);
|
||||||
break :blk ret;
|
} else {
|
||||||
};
|
stdx.debug.assert(
|
||||||
|
backend_config.is_a(mlir.DictionaryAttribute),
|
||||||
const backend_config: mlir.Attribute = switch (opts.backend_config) {
|
"API version >= 4 requires a dictionary as backend_config, got {}",
|
||||||
.string => blk: {
|
.{backend_config},
|
||||||
stdx.debug.assert(
|
);
|
||||||
@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi),
|
}
|
||||||
"Only API version of less than 4 is supported for backend_config as string",
|
|
||||||
.{},
|
|
||||||
);
|
|
||||||
break :blk .string(ctx, opts.backend_config.string);
|
|
||||||
},
|
|
||||||
.dict => blk: {
|
|
||||||
stdx.debug.assert(
|
|
||||||
opts.api_version == .typed_ffi,
|
|
||||||
"Only API version 4 is supported for backend_config as dictionary",
|
|
||||||
.{},
|
|
||||||
);
|
|
||||||
break :blk opts.backend_config.dict.as(mlir.Attribute);
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{};
|
||||||
|
|
||||||
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{
|
||||||
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
|
.{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) },
|
||||||
.{ "call_target_name", .string(ctx, opts.call_target_name) },
|
.{ "call_target_name", .string(ctx, opts.call_target_name) },
|
||||||
.{ "has_side_effect", .boolean(ctx, opts.has_side_effect) },
|
.{ "has_side_effect", .boolean(ctx, opts.has_side_effect) },
|
||||||
.{ "backend_config", backend_config },
|
.{ "backend_config", backend_config },
|
||||||
.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) },
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
|
for (opts.output_operand_aliases) |alias| {
|
||||||
|
output_operand_aliases.appendAssumeCapacity(
|
||||||
|
OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) });
|
||||||
|
}
|
||||||
|
|
||||||
if (opts.operand_layouts) |layouts| {
|
if (opts.operand_layouts) |layouts| {
|
||||||
const operand_layouts = blk: {
|
var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{};
|
for (layouts) |ol| {
|
||||||
for (layouts) |ol| {
|
operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
||||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol));
|
}
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) });
|
||||||
}
|
}
|
||||||
|
|
||||||
if (opts.result_layouts) |layouts| {
|
if (opts.result_layouts) |layouts| {
|
||||||
const result_layouts = blk: {
|
var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
||||||
var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{};
|
for (layouts) |rl| {
|
||||||
for (layouts) |rl| {
|
result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
||||||
ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl));
|
}
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) });
|
||||||
}
|
}
|
||||||
|
|
||||||
attrs.appendSliceAssumeCapacity(opts.addional_attributes);
|
attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes");
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||||
.operands = inputs,
|
.operands = inputs,
|
||||||
@ -829,22 +812,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: move out of stablehlo.zig when we start to implement the frontend
/// Builds a `stablehlo.custom_call` targeting "annotate_device_placement".
///
/// `memory_kind` is wrapped in a dictionary under the "_xla_buffer_placement" key and
/// attached to the op as the "mhlo.frontend_attributes" attribute. The call is emitted
/// with side effects enabled, an empty string backend config and the `.original` API version.
pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
    const placement = mlir.DictionaryAttribute.init(
        ctx,
        &.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())},
    );

    return custom_call(ctx, inputs, .{
        .call_target_name = "annotate_device_placement",
        .has_side_effect = true,
        .backend_config = .{ .string = &.{} },
        .addional_attributes = &.{.{ "mhlo.frontend_attributes", placement.asAttr() }},
        .api_version = .original,
    }, res_types, location);
}
|
|
||||||
|
|
||||||
pub const DotDimensionNumbersAttribute = struct {
|
pub const DotDimensionNumbersAttribute = struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
|
|
||||||
|
|||||||
@ -414,6 +414,18 @@ pub const Attribute = struct {
|
|||||||
pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
|
pub fn named(attr: Attribute, ctx: Context, name: [:0]const u8) NamedAttribute {
|
||||||
return NamedAttribute.init(Identifier.get(ctx, name), attr);
|
return NamedAttribute.init(Identifier.get(ctx, name), attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds a dictionary attribute from `(name, attribute)` tuples.
///
/// The named attributes are staged in a fixed stack buffer of 32 entries, so no
/// allocator is needed; passing more tuples than that trips the assertion below.
pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute {
    var storage: [32]NamedAttribute = undefined;
    stdx.debug.assert(named_attrs.len <= storage.len, ".dict helper only support up to {} attribute, got {}. You will need to call mlir.DictionaryAttribute directly", .{ storage.len, named_attrs.len });

    // Only the first `named_attrs.len` slots are ever initialized and read.
    const staged = storage[0..named_attrs.len];
    for (named_attrs, staged) |tuple, *slot| {
        slot.* = .named(ctx, tuple[0], tuple[1]);
    }

    return DictionaryAttribute.init(ctx, staged).asAttr();
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const NamedAttribute = extern struct {
|
pub const NamedAttribute = extern struct {
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
load("@rules_cc//cc:defs.bzl", "cc_library")
|
||||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||||
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
load("@zml//bazel:zig.bzl", "zig_cc_binary")
|
||||||
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
|
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
|
||||||
@ -12,7 +13,12 @@ cc_library(
|
|||||||
|
|
||||||
zig_library(
|
zig_library(
|
||||||
name = "pjrt",
|
name = "pjrt",
|
||||||
srcs = ["profiler.zig"] + glob(["convert/*.zig"]),
|
srcs = [
|
||||||
|
"ffi.zig",
|
||||||
|
"profiler.zig",
|
||||||
|
"convert/trace_container.zig",
|
||||||
|
"convert/xplane_schema.zig"
|
||||||
|
],
|
||||||
main = "pjrt.zig",
|
main = "pjrt.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -20,9 +26,12 @@ zig_library(
|
|||||||
":trace_events_proto",
|
":trace_events_proto",
|
||||||
":xplane_proto",
|
":xplane_proto",
|
||||||
"//stdx",
|
"//stdx",
|
||||||
|
"@xla//xla/ffi/api:c_api",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs",
|
||||||
|
"@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs",
|
||||||
] + select({
|
] + select({
|
||||||
"@platforms//os:linux": [":dlfcn"],
|
"@platforms//os:linux": [":dlfcn"],
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
@ -49,7 +58,7 @@ zig_proto_library(
|
|||||||
|
|
||||||
zig_cc_binary(
|
zig_cc_binary(
|
||||||
name = "xspace_to_json",
|
name = "xspace_to_json",
|
||||||
srcs = glob(["convert/*.zig"]),
|
srcs = ["convert/trace_container.zig", "convert/xplane_schema.zig"],
|
||||||
main = "xspace_to_json.zig",
|
main = "xspace_to_json.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const trace_events_proto = @import("//tsl:trace_events_proto");
|
const trace_events_proto = @import("//tsl:trace_events_proto");
|
||||||
const xplane_proto = @import("//tsl:xplane_proto");
|
const xplane_proto = @import("//tsl:xplane_proto");
|
||||||
|
|
||||||
const xplane_schema = @import("xplane_schema.zig");
|
const xplane_schema = @import("xplane_schema.zig");
|
||||||
const xplane_visitor = @import("xplane_visitor.zig");
|
|
||||||
|
|
||||||
// Constants used as trace_viewer PID (device_id in trace_events.proto).
|
// Constants used as trace_viewer PID (device_id in trace_events.proto).
|
||||||
// PID 0 is unused.
|
// PID 0 is unused.
|
||||||
@ -87,7 +88,7 @@ pub const TraceContainer = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const xplane_visitor.XPlaneVisitor) !void {
|
fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const XPlaneHashed) !void {
|
||||||
// Convert devices and resources.
|
// Convert devices and resources.
|
||||||
const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id });
|
const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id });
|
||||||
var device = device_entry.value_ptr.*;
|
var device = device_entry.value_ptr.*;
|
||||||
@ -156,7 +157,7 @@ pub const TraceContainer = struct {
|
|||||||
|
|
||||||
fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void {
|
fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void {
|
||||||
if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| {
|
if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| {
|
||||||
const xplane = try xplane_visitor.XPlaneVisitor.init(allocator, hp);
|
const xplane = try XPlaneHashed.init(allocator, hp);
|
||||||
try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane);
|
try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -173,7 +174,7 @@ pub const TraceContainer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (device_planes.items) |dp| {
|
for (device_planes.items) |dp| {
|
||||||
var xplane = try xplane_visitor.XPlaneVisitor.init(allocator, dp);
|
var xplane = try XPlaneHashed.init(allocator, dp);
|
||||||
defer xplane.deinit(allocator);
|
defer xplane.deinit(allocator);
|
||||||
const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id));
|
const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id));
|
||||||
try self.xplaneToTraceEvents(allocator, device_id, &xplane);
|
try self.xplaneToTraceEvents(allocator, device_id, &xplane);
|
||||||
@ -298,3 +299,68 @@ pub const TraceContainer = struct {
|
|||||||
fn picoToMicro(p: anytype) f64 {
|
fn picoToMicro(p: anytype) f64 {
|
||||||
return @as(f64, @floatFromInt(p)) / 1E6;
|
return @as(f64, @floatFromInt(p)) / 1E6;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Id-indexed lookup tables over the metadata of one `XPlane`.
///
/// Both maps store pointers into `plane` (nothing is copied), so `plane` must stay
/// alive for as long as this value is used. Call `deinit` to release the tables.
pub const XPlaneHashed = struct {
    plane: *const xplane_proto.XPlane,
    event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{},
    stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{},

    /// Indexes the event and stat metadata of `plane` by key.
    ///
    /// Returns an allocation error if either table cannot be reserved; in that case
    /// nothing stays allocated. Every metadata entry is unwrapped with `.?`, i.e. it is
    /// expected to carry a value.
    pub fn init(
        allocator: std.mem.Allocator,
        plane: *const xplane_proto.XPlane,
    ) !XPlaneHashed {
        var res: XPlaneHashed = .{ .plane = plane };
        // Without this, a failure while reserving the stat table would leak the
        // already-populated event table.
        errdefer res.deinit(allocator);

        // build event metadata map
        try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len));
        for (plane.event_metadata.items) |*event_metadata| {
            res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?);
        }

        // build stat metadata map
        try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len));
        for (plane.stat_metadata.items) |*stat_metadata| {
            res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?);
        }

        return res;
    }

    /// Frees both lookup tables. `allocator` must be the one given to `init`.
    pub fn deinit(self: *XPlaneHashed, allocator: std.mem.Allocator) void {
        self.stat_metadata_by_id.deinit(allocator);
        self.event_metadata_by_id.deinit(allocator);
    }

    /// Name of the underlying plane.
    pub fn name(self: XPlaneHashed) []const u8 {
        return self.plane.name.getSlice();
    }

    /// Maps an event metadata id to its host event type; `.unknown` when the id is not indexed.
    pub fn getEventType(self: XPlaneHashed, event_metadata_id: i64) xplane_schema.HostEventType {
        if (self.event_metadata_by_id.get(event_metadata_id)) |v| {
            return xplane_schema.HostEventType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Name of the stat metadata with the given id; empty when the id is not indexed.
    pub fn getStatMetadataName(self: XPlaneHashed, stat_metadata_id: i64) []const u8 {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return v.name.getSlice();
        } else return &[_]u8{};
    }

    /// Maps a stat metadata id to its stat type; `.unknown` when the id is not indexed.
    pub fn getStatType(self: XPlaneHashed, stat_metadata_id: i64) xplane_schema.StatType {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return xplane_schema.StatType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Writes a textual form of `stat` to `writer`; writes nothing when the stat has no value.
    /// Reference values are resolved to the name of the stat metadata they point at.
    pub fn xstatToString(self: XPlaneHashed, stat: xplane_proto.XStat, writer: anytype) !void {
        if (stat.value == null) return;

        switch (stat.value.?) {
            inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
            .str_value => |*v| try writer.writeAll(v.getSlice()),
            .bytes_value => try writer.writeAll("<opaque bytes>"),
            .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))),
        }
    }
};
|
||||||
|
|||||||
@ -1,68 +0,0 @@
|
|||||||
const std = @import("std");
|
|
||||||
const xplane_proto = @import("//tsl:xplane_proto");
|
|
||||||
const xplane_schema = @import("xplane_schema.zig");
|
|
||||||
|
|
||||||
/// Id-indexed lookup tables over the metadata of one `XPlane`.
///
/// Both maps store pointers into `plane` (nothing is copied), so `plane` must stay
/// alive for as long as this value is used. Call `deinit` to release the tables.
pub const XPlaneVisitor = struct {
    plane: *const xplane_proto.XPlane,
    event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{},
    stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{},

    /// Indexes the event and stat metadata of `plane` by key.
    ///
    /// Returns an allocation error if either table cannot be reserved; in that case
    /// nothing stays allocated. Every metadata entry is unwrapped with `.?`, i.e. it is
    /// expected to carry a value.
    pub fn init(
        allocator: std.mem.Allocator,
        plane: *const xplane_proto.XPlane,
    ) !XPlaneVisitor {
        var res: XPlaneVisitor = .{ .plane = plane };
        // Without this, a failure while reserving the stat table would leak the
        // already-populated event table.
        errdefer res.deinit(allocator);

        // build event metadata map
        try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len));
        for (plane.event_metadata.items) |*event_metadata| {
            res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?);
        }

        // build stat metadata map
        try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len));
        for (plane.stat_metadata.items) |*stat_metadata| {
            res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?);
        }

        return res;
    }

    /// Frees both lookup tables. `allocator` must be the one given to `init`.
    pub fn deinit(self: *XPlaneVisitor, allocator: std.mem.Allocator) void {
        self.stat_metadata_by_id.deinit(allocator);
        self.event_metadata_by_id.deinit(allocator);
    }

    /// Name of the underlying plane.
    pub fn name(self: XPlaneVisitor) []const u8 {
        return self.plane.name.getSlice();
    }

    /// Maps an event metadata id to its host event type; `.unknown` when the id is not indexed.
    pub fn getEventType(self: XPlaneVisitor, event_metadata_id: i64) xplane_schema.HostEventType {
        if (self.event_metadata_by_id.get(event_metadata_id)) |v| {
            return xplane_schema.HostEventType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Name of the stat metadata with the given id; empty when the id is not indexed.
    pub fn getStatMetadataName(self: XPlaneVisitor, stat_metadata_id: i64) []const u8 {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return v.name.getSlice();
        } else return &[_]u8{};
    }

    /// Maps a stat metadata id to its stat type; `.unknown` when the id is not indexed.
    pub fn getStatType(self: XPlaneVisitor, stat_metadata_id: i64) xplane_schema.StatType {
        if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| {
            return xplane_schema.StatType.fromString(v.name.getSlice());
        } else return .unknown;
    }

    /// Writes a textual form of `stat` to `writer`; writes nothing when the stat has no value.
    /// Reference values are resolved to the name of the stat metadata they point at.
    pub fn xstatToString(self: XPlaneVisitor, stat: xplane_proto.XStat, writer: anytype) !void {
        if (stat.value == null) return;

        switch (stat.value.?) {
            inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
            .str_value => |*v| try writer.writeAll(v.getSlice()),
            .bytes_value => try writer.writeAll("<opaque bytes>"),
            .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))),
        }
    }
};
|
|
||||||
517
pjrt/ffi.zig
Normal file
517
pjrt/ffi.zig
Normal file
@ -0,0 +1,517 @@
|
|||||||
|
//! Bindings for PJRT custom call declaration / execution.
|
||||||
|
const std = @import("std");
|
||||||
|
|
||||||
|
const c = @import("c");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const pjrtStruct = @import("pjrt.zig").pjrtStruct;
|
||||||
|
|
||||||
|
const log = std.log.scoped(.pjrt);
|
||||||
|
|
||||||
|
/// Version record exchanged with the FFI runtime.
///
/// Declared `extern` so the fields keep C layout and declaration order.
/// NOTE(review): presumably mirrors `XLA_FFI_Api_Version` from the XLA FFI C API — confirm.
pub const ApiVersion = extern struct {
    /// Major FFI API version exposed by the imported C headers.
    pub const major = c.XLA_FFI_API_MAJOR;
    /// Minor FFI API version exposed by the imported C headers.
    pub const minor = c.XLA_FFI_API_MINOR;

    struct_size: usize,
    extension_start: ?*ExtensionBase,
    major_version: i32,
    minor_version: i32,
};
|
||||||
|
|
||||||
|
/// Kind tag carried by an `ExtensionBase`; values come from `c.XLA_FFI_Extension_Type`.
pub const ExtensionType = enum(c.XLA_FFI_Extension_Type) {
    metadata = c.XLA_FFI_Extension_Metadata,
};
|
||||||
|
|
||||||
|
/// Header shared by extension structs: its size, its kind, and an optional link to
/// the next extension.
pub const ExtensionBase = extern struct {
    struct_size: usize,
    type: ExtensionType,
    next: ?*ExtensionBase,
};
|
||||||
|
|
||||||
|
// Based on https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449
/// Bit flags describing an FFI handler, packed into a single `u32`.
pub const HandlerTraits = packed struct(u32) {
    /// Calls to FFI handler are safe to trace into the command buffer.
    /// It means that calls to FFI handler always launch exactly the same device operations
    /// (can depend on attribute values) that can be captured and then replayed.
    command_buffer_compatible: u1,

    /// Remaining 31 bits, currently without an assigned meaning.
    __unassigned__: u31,
};
|
||||||
|
|
||||||
|
/// Handler metadata: the API version it was built against and its traits.
pub const Metadata = extern struct {
    struct_size: usize,
    api_version: ApiVersion,
    traits: HandlerTraits,
};
|
||||||
|
|
||||||
|
/// Extension record that carries an optional pointer to a `Metadata` struct.
pub const MetadataExtension = extern struct {
    extension_base: ExtensionBase,
    metadata: ?*Metadata,
};
|
||||||
|
|
||||||
|
/// Error set returned by the FFI wrappers in this file.
///
/// The wrappers visible here currently only ever return `error.Unknown`
/// (see the TODOs next to each call site).
pub const ApiError = error{
    Cancelled,
    Unknown,
    InvalidArgument,
    DeadlineExceeded,
    NotFound,
    AlreadyExists,
    PermissionDenied,
    ResourceExhausted,
    FailedPrecondition,
    Aborted,
    OutOfRange,
    Unimplemented,
    Internal,
    Unavailable,
    DataLoss,
    Unauthenticated,
};
|
||||||
|
|
||||||
|
/// Generates pointer reinterpretation helpers between a wrapper type `T` and the
/// type `InnerT` it stands for.
///
/// * `to` turns `*T` / `*const T` into `*InnerT` / `*const InnerT`.
/// * `from` does the reverse.
///
/// Constness is preserved in both directions. Passing any other type is rejected at
/// compile time with a message naming the accepted pointer types.
fn TransmuteMixin(comptime T: type, comptime InnerT: type) type {
    return struct {
        pub fn to(self: anytype) switch (@TypeOf(self)) {
            *T => *InnerT,
            *const T => *const InnerT,
            else => @compileError("expected *" ++ @typeName(T) ++ " or *const " ++ @typeName(T) ++ ", got " ++ @typeName(@TypeOf(self))),
        } {
            return @ptrCast(@alignCast(self));
        }

        pub fn from(self: anytype) switch (@TypeOf(self)) {
            *InnerT => *T,
            *const InnerT => *const T,
            else => @compileError("expected *" ++ @typeName(InnerT) ++ " or *const " ++ @typeName(InnerT) ++ ", got " ++ @typeName(@TypeOf(self))),
        } {
            return @ptrCast(@alignCast(self));
        }
    };
}
|
||||||
|
|
||||||
|
/// Opaque handle standing for a `c.XLA_FFI_Api`.
///
/// Every wrapper below follows the same pattern: build the argument struct with
/// `pjrtStruct`, invoke the matching function pointer, and when an error object comes
/// back, log its message, destroy it and return `error.Unknown`.
pub const Api = opaque {
    /// Reinterprets this handle as a pointer to the underlying C struct.
    pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to;

    /// Calls `XLA_FFI_Stream_Get` for `context` and returns the resulting stream pointer.
    pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque {
        var args = pjrtStruct(c.XLA_FFI_Stream_Get_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
        });

        if (self.inner().XLA_FFI_Stream_Get.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.getStream] {s}", .{err.getMessage(self)});

            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }

        return args.stream.?;
    }

    /// Calls `XLA_FFI_DeviceMemory_Allocate` with the given `size` and `alignment`
    /// and returns the resulting data pointer.
    pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque {
        var args = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
            .size = size,
            .alignment = alignment,
        });

        if (self.inner().XLA_FFI_DeviceMemory_Allocate.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.allocateDeviceMemory] {s}", .{err.getMessage(self)});

            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }

        return args.data.?;
    }

    /// Calls `XLA_FFI_DeviceMemory_Free` for `data` with the given `size`.
    pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void {
        var args = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{
            .ctx = if (context) |ctx| ctx.inner() else null,
            .size = size,
            .data = data,
        });

        if (self.inner().XLA_FFI_DeviceMemory_Free.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(self);
            log.err("[Api.freeDeviceMemory] {s}", .{err.getMessage(self)});

            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }
    }

    // TODO(Corentin): Implement remaining methods if needed:
    // * XLA_FFI_ThreadPool_Schedule
    // * XLA_FFI_Handler_Register
    // * XLA_FFI_TypeId_Register
    // * XLA_FFI_State_Set
    // * XLA_FFI_State_Get
};
|
||||||
|
|
||||||
|
/// Stage for which a handler is being invoked; values come from `c.XLA_FFI_ExecutionStage`.
pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) {
    instantiate = c.XLA_FFI_ExecutionStage_INSTANTIATE,
    prepare = c.XLA_FFI_ExecutionStage_PREPARE,
    initialize = c.XLA_FFI_ExecutionStage_INITIALIZE,
    execute = c.XLA_FFI_ExecutionStage_EXECUTE,
};
|
||||||
|
|
||||||
|
/// Opaque handle standing for a `c.XLA_FFI_ExecutionContext`.
pub const ExecutionContext = opaque {
    /// Reinterprets this handle as a pointer to the underlying C struct.
    pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to;

    // TODO: implement `attach(self, api, value)`. The earlier draft (never compiled)
    // intended to register a type id named "zml." ++ @typeName(@TypeOf(value)) and then
    // register the value on this context, logging failures as "[ExecutionContext.register]".

    /// Calls `XLA_FFI_ExecutionContext_Get` for `type_id` and returns the resulting data pointer.
    /// When an error object comes back it is logged, destroyed and reported as `error.Unknown`.
    pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque {
        var args = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{
            .ctx = self.inner(),
            .type_id = @ptrCast(@alignCast(type_id)),
        });

        if (api.inner().XLA_FFI_ExecutionContext_Get.?(&args)) |ffi_error| {
            const err = Error.fromInner(ffi_error);
            defer err.destroy(api);
            log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)});

            // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
            return error.Unknown;
        }

        return args.data.?;
    }

    // TODO getDeviceOrdinal()
};
|
||||||
|
|
||||||
|
/// Pointer + length pair with C layout, viewable as a Zig byte slice.
const ByteSpan = extern struct {
    ptr: [*]const u8,
    len: usize,

    /// Returns the `len` bytes starting at `ptr` as a slice (no copy).
    pub fn slice(self: ByteSpan) []const u8 {
        const bytes: []const u8 = self.ptr[0..self.len];
        return bytes;
    }
};
|
||||||
|
|
||||||
|
/// Wrapper with C layout around a 64-bit type identifier.
pub const TypeId = extern struct {
    type_id: i64,
};
|
||||||
|
|
||||||
|
/// Element types reported for buffers and attributes; values come from `c.XLA_FFI_DataType`.
pub const DataType = enum(c.XLA_FFI_DataType) {
    invalid = c.XLA_FFI_DataType_INVALID,
    pred = c.XLA_FFI_DataType_PRED,

    // Signed integers.
    s8 = c.XLA_FFI_DataType_S8,
    s16 = c.XLA_FFI_DataType_S16,
    s32 = c.XLA_FFI_DataType_S32,
    s64 = c.XLA_FFI_DataType_S64,

    // Unsigned integers.
    u8 = c.XLA_FFI_DataType_U8,
    u16 = c.XLA_FFI_DataType_U16,
    u32 = c.XLA_FFI_DataType_U32,
    u64 = c.XLA_FFI_DataType_U64,

    // Floating point and complex.
    f16 = c.XLA_FFI_DataType_F16,
    f32 = c.XLA_FFI_DataType_F32,
    f64 = c.XLA_FFI_DataType_F64,
    bf16 = c.XLA_FFI_DataType_BF16,
    c64 = c.XLA_FFI_DataType_C64,
    c128 = c.XLA_FFI_DataType_C128,

    token = c.XLA_FFI_DataType_TOKEN,

    // 8-bit float variants.
    f8e5m2 = c.XLA_FFI_DataType_F8E5M2,
    f8e3m4 = c.XLA_FFI_DataType_F8E3M4,
    f8e4m3 = c.XLA_FFI_DataType_F8E4M3,
    f8e4m3fn = c.XLA_FFI_DataType_F8E4M3FN,
    f8e4m3b11fnuz = c.XLA_FFI_DataType_F8E4M3B11FNUZ,
    f8e5m2fnuz = c.XLA_FFI_DataType_F8E5M2FNUZ,
    f8e4m3fnuz = c.XLA_FFI_DataType_F8E4M3FNUZ,
};
|
||||||
|
|
||||||
|
/// Buffer descriptor with C layout: element type, data pointer, rank and dimensions.
pub const Buffer = extern struct {
    struct_size: usize,
    extension_start: ?*c.XLA_FFI_Extension_Base,
    dtype: DataType,
    data: [*]u8,
    rank: u64,
    _dims: [*]const i64,

    /// Returns the `rank` dimensions as a slice (no copy).
    pub fn dims(self: Buffer) []const i64 {
        const shape: []const i64 = self._dims[0..self.rank];
        return shape;
    }

    /// `std.fmt` hook; prints the dimensions, the dtype tag name and the data address in hex.
    /// The format string and options are ignored.
    pub fn format(
        buffer: Buffer,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;

        const dtype_name = @tagName(buffer.dtype);
        const address = @intFromPtr(buffer.data);
        try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), dtype_name, address });
    }
};
|
||||||
|
|
||||||
|
/// Argument list of a call with C layout: `len` entries described by the parallel
/// `types` and `ptr` arrays.
pub const Args = extern struct {
    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    /// Kind of an argument entry; values come from `c.XLA_FFI_ArgType`.
    pub const Type = enum(c.XLA_FFI_ArgType) {
        buffer = c.XLA_FFI_ArgType_BUFFER,
    };

    /// Returns the `i`-th argument, asserting that it is tagged as a buffer.
    pub fn get(self: Args, i: usize) *const Buffer {
        const kinds = self.types[0..self.len];
        std.debug.assert(kinds[i] == .buffer);

        const buffers = self.ptr[0..self.len];
        return buffers[i];
    }
};
|
||||||
|
|
||||||
|
/// Result list of a call with C layout: `len` entries described by the parallel
/// `types` and `ptr` arrays.
pub const Rets = extern struct {
    struct_size: usize,
    extension_start: ?*const c.XLA_FFI_Extension_Base,
    len: u64,
    types: [*]const Type,
    ptr: [*]*const Buffer,

    /// Kind of a result entry; values come from `c.XLA_FFI_RetType`.
    pub const Type = enum(c.XLA_FFI_RetType) {
        buffer = c.XLA_FFI_RetType_BUFFER,
    };

    /// Returns the `i`-th result, asserting that it is tagged as a buffer.
    pub fn get(self: Rets, i: usize) *const Buffer {
        const kinds = self.types[0..self.len];
        std.debug.assert(kinds[i] == .buffer);

        const buffers = self.ptr[0..self.len];
        return buffers[i];
    }
};
|
||||||
|
|
||||||
|
/// Kind of an attribute entry; values come from `c.XLA_FFI_AttrType`.
pub const AttrType = enum(c.XLA_FFI_AttrType) {
    array = c.XLA_FFI_AttrType_ARRAY,
    dictionary = c.XLA_FFI_AttrType_DICTIONARY,
    scalar = c.XLA_FFI_AttrType_SCALAR,
    string = c.XLA_FFI_AttrType_STRING,
};
|
||||||
|
|
||||||
|
/// Attribute list of a call with C layout: `len` entries described by the parallel
/// `types`, `names` and `ptr` arrays.
///
/// The lookup helpers derive their return type from a field of `Attr`, which only has
/// `scalar` and `array`; they therefore only compile for those two attribute kinds.
pub const Attrs = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    len: u64,
    types: [*]const AttrType,
    names: [*]const *const ByteSpan,
    ptr: [*]const *const Attr,

    /// Payload an attribute pointer may refer to.
    const Attr = extern union {
        scalar: Scalar,
        array: Array,
    };

    /// Single value attribute: an element type and an untyped pointer to the value.
    pub const Scalar = extern struct {
        dtype: DataType,
        value: *const anyopaque,

        /// Reads the value as a `T`. `dtype` is not checked against `T`.
        pub fn get(self: Scalar, T: type) T {
            const typed: *const T = @ptrCast(@alignCast(self.value));
            return typed.*;
        }
    };

    /// Array attribute: an element type, a length and a pointer to the raw bytes.
    pub const Array = extern struct {
        dtype: DataType,
        len: usize,
        data: [*]const u8,
    };

    /// Returns the attribute at `index` when its recorded kind equals `attr_type`, null otherwise.
    pub fn getByIndex(self: Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) {
        const entry = self.ptr[0..self.len][index];
        const recorded_kind = self.types[index];
        if (recorded_kind != attr_type) return null;
        return @ptrCast(entry);
    }

    /// Returns the first attribute called `name` when its kind equals `attr_type`.
    /// Null when no attribute has that name or when the first match has another kind.
    pub fn getByName(self: Attrs, comptime attr_type: AttrType, name: []const u8) ?*const @FieldType(Attr, @tagName(attr_type)) {
        for (self.names[0..self.len], 0..) |attr_name, i| {
            if (std.mem.eql(u8, attr_name.slice(), name)) {
                return self.getByIndex(attr_type, i);
            }
        }

        return null;
    }
};
|
||||||
|
|
||||||
|
/// C-ABI view of the frame handed to every FFI handler invocation (see `Handler`).
pub const CallFrame = extern struct {
    struct_size: usize,
    extension_start: ?*ExtensionBase,
    api: ?*const Api,
    ctx: ?*const ExecutionContext,
    stage: ExecutionStage,
    args: Args,
    results: Rets,
    attrs: Attrs,
    future: ?*Future,

    /// The registry mechanism first invokes the custom call in registration mode
    /// and expects us to report which XLA version we were compiled against.
    ///
    /// When the frame carries a metadata extension, fills in its API version and
    /// returns true: the calling custom call should then return early.
    /// Returns false for every other frame.
    pub fn registeringHook(call_frame: *CallFrame) bool {
        const extension = call_frame.extension_start orelse return false;
        if (extension.type != .metadata) return false;

        // The extension base is embedded in the metadata extension: recover the parent.
        const metadata_extension: *MetadataExtension = @fieldParentPtr("extension_base", extension);
        metadata_extension.metadata.?.api_version.major_version = ApiVersion.major;
        metadata_extension.metadata.?.api_version.minor_version = ApiVersion.minor;
        return true;
    }
};
|
||||||
|
|
||||||
|
/// Signature of an FFI handler: a C-calling-convention function that receives
/// the call frame and returns either null or a pointer to an `Error`.
pub const Handler = fn (*CallFrame) callconv(.C) ?*Error;
|
||||||
|
|
||||||
|
/// Status code attached to an FFI `Error`.
/// Each member takes its value from the matching `XLA_FFI_Error_Code_*` C constant.
pub const ErrorCode = enum(c.XLA_FFI_Error_Code) {
    cancelled = c.XLA_FFI_Error_Code_CANCELLED,
    unknown = c.XLA_FFI_Error_Code_UNKNOWN,
    invalid_argument = c.XLA_FFI_Error_Code_INVALID_ARGUMENT,
    deadline_exceeded = c.XLA_FFI_Error_Code_DEADLINE_EXCEEDED,
    not_found = c.XLA_FFI_Error_Code_NOT_FOUND,
    already_exists = c.XLA_FFI_Error_Code_ALREADY_EXISTS,
    permission_denied = c.XLA_FFI_Error_Code_PERMISSION_DENIED,
    resource_exhausted = c.XLA_FFI_Error_Code_RESOURCE_EXHAUSTED,
    failed_precondition = c.XLA_FFI_Error_Code_FAILED_PRECONDITION,
    aborted = c.XLA_FFI_Error_Code_ABORTED,
    out_of_range = c.XLA_FFI_Error_Code_OUT_OF_RANGE,
    unimplemented = c.XLA_FFI_Error_Code_UNIMPLEMENTED,
    internal = c.XLA_FFI_Error_Code_INTERNAL,
    unavailable = c.XLA_FFI_Error_Code_UNAVAILABLE,
    data_loss = c.XLA_FFI_Error_Code_DATA_LOSS,
    unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED,

    /// Translates the status code into the corresponding `ApiError` member.
    /// Every tag maps to the error of the same name, except `not_found`,
    /// which maps to `error.FfiNotFound`.
    pub fn toApiError(code: ErrorCode) ApiError {
        // Exhaustive on purpose: a new tag must be given an explicit mapping.
        return switch (code) {
            .cancelled => error.Cancelled,
            .unknown => error.Unknown,
            .invalid_argument => error.InvalidArgument,
            .deadline_exceeded => error.DeadlineExceeded,
            .not_found => error.FfiNotFound,
            .already_exists => error.AlreadyExists,
            .permission_denied => error.PermissionDenied,
            .resource_exhausted => error.ResourceExhausted,
            .failed_precondition => error.FailedPrecondition,
            .aborted => error.Aborted,
            .out_of_range => error.OutOfRange,
            .unimplemented => error.Unimplemented,
            .internal => error.Internal,
            .unavailable => error.Unavailable,
            .data_loss => error.DataLoss,
            .unauthenticated => error.Unauthenticated,
        };
    }
};
|
||||||
|
|
||||||
|
/// Opaque handle standing in for a `c.XLA_FFI_Error`.
/// All operations go through the function table exposed by `Api`.
pub const Error = opaque {
    pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to;
    pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from;

    /// Builds a new error carrying `error_code` and `message`.
    /// Release it with `destroy` once it is no longer needed.
    pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error {
        var args = pjrtStruct(c.XLA_FFI_Error_Create_Args{
            .message = message.ptr,
            .errc = @intFromEnum(error_code),
        });
        const created = api.inner().XLA_FFI_Error_Create.?(&args);
        return fromInner(created.?);
    }

    /// Releases the error through `XLA_FFI_Error_Destroy`.
    pub fn destroy(err: *Error, api: *const Api) void {
        var args = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() });
        api.inner().XLA_FFI_Error_Destroy.?(&args);
    }

    /// Returns the message attached to the error.
    /// NOTE(review): the returned slice presumably points into storage owned by
    /// the error and must not outlive `destroy` — confirm.
    pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 {
        var args = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{ .@"error" = err.inner() });
        api.inner().XLA_FFI_Error_GetMessage.?(&args);
        return std.mem.span(args.message);
    }
};
|
||||||
|
|
||||||
|
/// Opaque handle standing in for a `c.XLA_FFI_Future`.
/// All operations go through the function table exposed by `Api`.
pub const Future = opaque {
    pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to;
    pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from;

    /// Shared failure path of every `Future` operation: logs the message carried
    /// by `err` (prefixed with the name of the failing operation), destroys
    /// `err`, and yields the error the caller should return.
    fn reportFailure(api: *const Api, comptime operation: []const u8, err: *Error) ApiError {
        // The message is read before the deferred destroy runs.
        defer err.destroy(api);
        log.err("[Future." ++ operation ++ "] {s}", .{err.getMessage(api)});

        // TODO(Corentin): Retrieve error code from Error when implemented in XLA.
        return error.Unknown;
    }

    /// Creates a new future.
    /// Returns `error.Unknown` (after logging the message) when XLA reports a failure.
    pub fn create(api: *const Api) ApiError!*Future {
        var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{});
        if (api.inner().XLA_FFI_Future_Create.?(&ret)) |ffi_error| {
            return reportFailure(api, "create", Error.fromInner(ffi_error));
        }
        return fromInner(ret.future.?);
    }

    /// Marks the future as available.
    /// Returns `error.Unknown` (after logging the message) when XLA reports a failure.
    pub fn setAvailable(self: *Future, api: *const Api) ApiError!void {
        var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{
            .future = self.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetAvailable.?(&ret)) |ffi_error| {
            return reportFailure(api, "setAvailable", Error.fromInner(ffi_error));
        }
    }

    /// Completes the future with `err`.
    /// Returns `error.Unknown` (after logging the message) when XLA reports a failure.
    /// NOTE(review): whether XLA takes ownership of `err` is not visible from
    /// here — confirm before destroying it on the caller side.
    pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void {
        var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{
            .future = self.inner(),
            .@"error" = err.inner(),
        });
        if (api.inner().XLA_FFI_Future_SetError.?(&ret)) |ffi_error| {
            return reportFailure(api, "setError", Error.fromInner(ffi_error));
        }
    }
};
|
||||||
@ -1,13 +1,14 @@
|
|||||||
const builtin = @import("builtin");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
pub const ffi = @import("ffi.zig");
|
||||||
|
pub const Profiler = @import("profiler.zig").Profiler;
|
||||||
|
|
||||||
const log = std.log.scoped(.pjrt);
|
const log = std.log.scoped(.pjrt);
|
||||||
|
|
||||||
pub const Profiler = @import("profiler.zig").Profiler;
|
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
@ -160,10 +161,9 @@ pub const Api = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry {
|
||||||
if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| {
|
if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| {
|
||||||
return .{ .inner = ext.custom_call.? };
|
return .{ .inner = ext.register_handler.? };
|
||||||
}
|
}
|
||||||
// log.warn("No Custom Call registry found for platform: {}", .{self});
|
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -405,7 +405,7 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const CreateViewOfDeviceBufferArgs = struct {
|
pub const CreateViewOfDeviceBufferArgs = struct {
|
||||||
data: []const u8,
|
data: *anyopaque,
|
||||||
dims: []const i64,
|
dims: []const i64,
|
||||||
element_type: BufferType,
|
element_type: BufferType,
|
||||||
layout: MemoryLayout,
|
layout: MemoryLayout,
|
||||||
@ -421,7 +421,7 @@ pub const Client = opaque {
|
|||||||
const layout = args.layout.toCStruct();
|
const layout = args.layout.toCStruct();
|
||||||
const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{
|
const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{
|
||||||
.client = self.inner(),
|
.client = self.inner(),
|
||||||
.device_buffer_ptr = @ptrCast(@constCast(args.data.ptr)),
|
.device_buffer_ptr = @ptrCast(@constCast(args.data)),
|
||||||
.dims = args.dims.ptr,
|
.dims = args.dims.ptr,
|
||||||
.num_dims = args.dims.len,
|
.num_dims = args.dims.len,
|
||||||
.element_type = @intFromEnum(args.element_type),
|
.element_type = @intFromEnum(args.element_type),
|
||||||
@ -919,18 +919,14 @@ pub const Memory = opaque {
|
|||||||
const inner = InnerMixin(c.PJRT_Memory).inner;
|
const inner = InnerMixin(c.PJRT_Memory).inner;
|
||||||
|
|
||||||
pub fn id(self: *const Memory, api: *const Api) usize {
|
pub fn id(self: *const Memory, api: *const Api) usize {
|
||||||
const ret = api.call(.PJRT_Memory_Id, .{
|
const ret = api.call(.PJRT_Memory_Id, .{ .memory = self.inner() }) catch unreachable;
|
||||||
.memory = self.inner(),
|
|
||||||
}) catch unreachable;
|
|
||||||
return @intCast(ret.id);
|
return @intCast(ret.id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the kind of this memory space, as reported by `PJRT_Memory_Kind`.
///
/// The plugin reports the kind as a pointer plus a length. The pointer is
/// unwrapped first and only then sliced to `kind_size`: in
/// `ret.kind orelse unreachable[0..ret.kind_size]` the slice binds to
/// `unreachable` rather than to the unwrapped pointer, so the length would
/// never be applied.
/// A failed call, a null name, or a name that is not a `Kind` tag are all
/// treated as unreachable.
pub fn kind(self: *const Memory, api: *const Api) Kind {
    const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable;
    const kind_ptr = ret.kind orelse unreachable;
    return std.meta.stringToEnum(Kind, kind_ptr[0..ret.kind_size]) orelse unreachable;
}
|
||||||
|
|
||||||
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
pub fn kindId(self: *const Memory, api: *const Api) u32 {
|
||||||
@ -1161,23 +1157,25 @@ pub const NamedValue = extern struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Custom call signature arguments are:
|
|
||||||
/// * a pointer to a platform specific stream handle
|
|
||||||
/// * a pointer to an unspecified list of platform specific buffer handle
|
|
||||||
/// * a context struct passed as a slice of bytes
|
|
||||||
pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void;
|
|
||||||
|
|
||||||
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
// todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize
|
||||||
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
// introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5
|
||||||
pub const CustomCallRegistry = extern struct {
|
pub const CustomCallRegistry = extern struct {
|
||||||
inner: *const c.PJRT_Gpu_Register_Custom_Call,
|
inner: *const c.PJRT_FFI_Register_Handler,
|
||||||
|
|
||||||
pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void {
|
pub fn registerFfi(
|
||||||
var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{
|
self: *const CustomCallRegistry,
|
||||||
.function_name = name.ptr,
|
api: *const Api,
|
||||||
.function_name_size = name.len,
|
target_name: []const u8,
|
||||||
.api_version = @intCast(api_version),
|
platform_name: []const u8,
|
||||||
.handler_execute = @ptrCast(@constCast(func)),
|
func: *const ffi.Handler,
|
||||||
|
) ApiError!void {
|
||||||
|
var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{
|
||||||
|
.api_version = 1,
|
||||||
|
.target_name = target_name.ptr,
|
||||||
|
.target_name_size = target_name.len,
|
||||||
|
.handler = @ptrCast(@constCast(func)),
|
||||||
|
.platform_name = platform_name.ptr,
|
||||||
|
.platform_name_size = platform_name.len,
|
||||||
});
|
});
|
||||||
const result = self.inner(&ret);
|
const result = self.inner(&ret);
|
||||||
if (result) |pjrt_c_error| {
|
if (result) |pjrt_c_error| {
|
||||||
|
|||||||
@ -6,3 +6,9 @@ pub const math = @import("math.zig");
|
|||||||
pub const meta = @import("meta.zig");
|
pub const meta = @import("meta.zig");
|
||||||
pub const queue = @import("queue.zig");
|
pub const queue = @import("queue.zig");
|
||||||
pub const time = @import("time.zig");
|
pub const time = @import("time.zig");
|
||||||
|
|
||||||
|
/// Hands out an uninitialized slice of `len` elements of `T`, carved from a
/// fixed array of `max_len` elements. Asserts that `len <= max_len`.
///
/// NOTE(review): the returned slice points at a local of this function. This
/// presumably relies on `inline` placing that local in the caller's frame —
/// confirm the slice is never used after the calling function returns.
pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T {
    debug.assert(len <= max_len, "stackSlice can only create a slice of up to {} elements, got: {}", .{ max_len, len });
    var backing: [max_len]T = undefined;
    return backing[0..len];
}
|
||||||
|
|||||||
37
third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel
vendored
Normal file
37
third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel
vendored
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250317.1-71c67e2",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl", "python_version_repo")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
"llvm-raw",
|
||||||
|
"stablehlo",
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
|
||||||
|
llvm_configure(name = "llvm-project")
|
||||||
37
third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel
vendored
Normal file
37
third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel
vendored
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
module(
|
||||||
|
name = "xla",
|
||||||
|
version = "20250317.1-71c67e2",
|
||||||
|
compatibility_level = 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.8")
|
||||||
|
bazel_dep(name = "bazel_skylib", version = "1.5.0")
|
||||||
|
bazel_dep(name = "rules_cc", version = "0.0.17")
|
||||||
|
bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple")
|
||||||
|
bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl")
|
||||||
|
bazel_dep(name = "rules_python", version = "0.29.0")
|
||||||
|
bazel_dep(name = "rules_proto", version = "6.0.0-rc1")
|
||||||
|
bazel_dep(name = "rules_java", version = "7.3.2")
|
||||||
|
bazel_dep(name = "rules_pkg", version = "0.9.1")
|
||||||
|
bazel_dep(name = "zlib", version = "1.2.13")
|
||||||
|
bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2")
|
||||||
|
bazel_dep(name = "rules_license", version = "0.0.8")
|
||||||
|
|
||||||
|
tsl = use_extension("//:tsl.bzl", "tsl")
|
||||||
|
use_repo(tsl, "tsl", "python_version_repo")
|
||||||
|
|
||||||
|
xla_workspace = use_extension("//:workspace.bzl", "xla_workspace")
|
||||||
|
use_repo(
|
||||||
|
xla_workspace,
|
||||||
|
"com_github_grpc_grpc",
|
||||||
|
"com_google_protobuf",
|
||||||
|
"local_config_cuda",
|
||||||
|
"local_config_remote_execution",
|
||||||
|
"local_config_rocm",
|
||||||
|
"local_config_tensorrt",
|
||||||
|
"llvm-raw",
|
||||||
|
"stablehlo",
|
||||||
|
)
|
||||||
|
|
||||||
|
llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure")
|
||||||
|
llvm_configure(name = "llvm-project")
|
||||||
19
third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl
vendored
Normal file
19
third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
load("//third_party:repo.bzl", "tf_vendored")
|
||||||
|
load("//third_party/py:python_init_repositories.bzl", "python_init_repositories")
|
||||||
|
|
||||||
|
def _tsl_impl(mctx):
    """Implementation of the `tsl` module extension.

    Initializes the Python repositories from the 3.11 requirements lock file,
    declares the vendored `tsl` repository, and reports `tsl` as the only
    direct dependency of the root module.

    Args:
        mctx: the module extension context.

    Returns:
        The extension metadata, marked reproducible.
    """
    python_init_repositories(
        requirements = {
            "3.11": "//:requirements_lock_3_11.txt",
        },
    )
    tf_vendored(name = "tsl", relpath = "third_party/tsl")
    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = ["tsl"],
        root_module_direct_dev_deps = [],
    )
|
||||||
|
|
||||||
|
tsl = module_extension(
|
||||||
|
implementation = _tsl_impl,
|
||||||
|
)
|
||||||
60
third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl
vendored
Normal file
60
third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl
vendored
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
|
||||||
|
load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
|
||||||
|
load("//third_party/gpus:rocm_configure.bzl", "rocm_configure")
|
||||||
|
load("//third_party/llvm:workspace.bzl", llvm = "repo")
|
||||||
|
load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo")
|
||||||
|
load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo")
|
||||||
|
load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure")
|
||||||
|
load("//third_party/triton:workspace.bzl", triton = "repo")
|
||||||
|
load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure")
|
||||||
|
|
||||||
|
def _xla_workspace_impl(mctx):
    """Implementation of the `xla_workspace` module extension.

    Declares the local configuration repositories (CUDA, remote execution,
    ROCm, TensorRT), the pybind11_bazel, triton, llvm-raw and stablehlo
    repositories, and the pinned gRPC and protobuf archives.

    Args:
        mctx: the module extension context.

    Returns:
        The extension metadata, marked reproducible, with every generated
        repository reported as a direct dependency of the root module.
    """

    # Local toolchain / accelerator configuration repositories.
    cuda_configure(name = "local_config_cuda")
    remote_execution_configure(name = "local_config_remote_execution")
    rocm_configure(name = "local_config_rocm")
    tensorrt_configure(name = "local_config_tensorrt")

    # Third-party repositories declared through their own workspace macros.
    pybind11_bazel()
    triton()
    llvm("llvm-raw")
    stablehlo()

    # gRPC, pinned to a commit archive.
    tf_http_archive(
        name = "com_github_grpc_grpc",
        sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f",
        strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd",
        system_build_file = "//third_party/systemlibs:grpc.BUILD",
        patch_file = [
            "//third_party/grpc:generate_cc_env_fix.patch",
            "//third_party/grpc:register_go_toolchain.patch",
        ],
        system_link_files = {
            "//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel",
            "//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD",
            "//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl",
            "//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl",
            "//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl",
            "//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl",
            "//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl",
        },
        urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"),
    )

    # Protobuf, pinned to the v3.21.9 release archive.
    tf_http_archive(
        name = "com_google_protobuf",
        patch_file = ["//third_party/protobuf:protobuf.patch"],
        sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1",
        strip_prefix = "protobuf-3.21.9",
        system_build_file = "//third_party/systemlibs:protobuf.BUILD",
        system_link_files = {
            "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl",
            "//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl",
        },
        urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"),
    )

    return mctx.extension_metadata(
        reproducible = True,
        root_module_direct_deps = "all",
        root_module_direct_dev_deps = [],
    )
|
||||||
|
|
||||||
|
xla_workspace = module_extension(
|
||||||
|
implementation = _xla_workspace_impl,
|
||||||
|
)
|
||||||
41
third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
41
third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch
vendored
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001
|
||||||
|
From: Jean-Baptiste Dalido <jb@zml.ai>
|
||||||
|
Date: Mon, 6 Jan 2025 13:33:13 +0100
|
||||||
|
Subject: [PATCH] bazel: migration to bazel 8.1.1
|
||||||
|
|
||||||
|
---
|
||||||
|
.bazelversion | 2 +-
|
||||||
|
third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++--
|
||||||
|
2 files changed, 3 insertions(+), 3 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/.bazelversion b/.bazelversion
|
||||||
|
index f22d756da3..fa5fce04b3 100644
|
||||||
|
--- a/.bazelversion
|
||||||
|
+++ b/.bazelversion
|
||||||
|
@@ -1 +1 @@
|
||||||
|
-7.4.1
|
||||||
|
+8.1.1
|
||||||
|
\ No newline at end of file
|
||||||
|
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
|
||||||
|
index d62531152d..71d80a5a99 100644
|
||||||
|
--- a/third_party/gpus/cuda_configure.bzl
|
||||||
|
+++ b/third_party/gpus/cuda_configure.bzl
|
||||||
|
@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead.
|
||||||
|
load(
|
||||||
|
"@bazel_tools//tools/cpp:lib_cc_configure.bzl",
|
||||||
|
"escape_string",
|
||||||
|
- "get_env_var",
|
||||||
|
)
|
||||||
|
load(
|
||||||
|
"@bazel_tools//tools/cpp:windows_cc_configure.bzl",
|
||||||
|
- "find_msvc_tool",
|
||||||
|
"find_vc_path",
|
||||||
|
"setup_vc_env_vars",
|
||||||
|
)
|
||||||
|
+load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool")
|
||||||
|
+load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var")
|
||||||
|
load("//third_party/clang_toolchain:download_clang.bzl", "download_clang")
|
||||||
|
load(
|
||||||
|
"//third_party/remote_config:common.bzl",
|
||||||
|
--
|
||||||
|
2.39.3 (Apple Git-146)
|
||||||
@ -0,0 +1,131 @@
|
|||||||
|
From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001
|
||||||
|
From: Hugo Mano <hugo@zml.ai>
|
||||||
|
Date: Wed, 5 Feb 2025 19:25:03 +0100
|
||||||
|
Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt
|
||||||
|
|
||||||
|
PR: https://github.com/openxla/xla/pull/13420
|
||||||
|
---
|
||||||
|
xla/pjrt/c/BUILD | 5 ++++
|
||||||
|
xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++
|
||||||
|
xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++--
|
||||||
|
3 files changed, 54 insertions(+), 2 deletions(-)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD
|
||||||
|
index ad2ed95bce..0e7c35c30f 100644
|
||||||
|
--- a/xla/pjrt/c/BUILD
|
||||||
|
+++ b/xla/pjrt/c/BUILD
|
||||||
|
@@ -69,7 +69,12 @@ cc_library(
|
||||||
|
":pjrt_c_api_wrapper_impl",
|
||||||
|
"//xla/ffi:execution_context",
|
||||||
|
"//xla/ffi:type_id_registry",
|
||||||
|
+ "//xla/ffi:ffi_api",
|
||||||
|
+ "//xla/ffi/api:c_api",
|
||||||
|
+ "//xla/ffi/api:ffi",
|
||||||
|
+ "//xla/service:custom_call_target_registry",
|
||||||
|
"@com_google_absl//absl/status",
|
||||||
|
+ "@com_google_absl//absl/strings:str_format",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
index a33bd4aa9c..3309194538 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h
|
||||||
|
@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data);
|
||||||
|
// Adds a user data to the execute context.
|
||||||
|
typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args);
|
||||||
|
|
||||||
|
+struct PJRT_FFI_Register_Handler_Args {
|
||||||
|
+ size_t struct_size;
|
||||||
|
+ const char* target_name;
|
||||||
|
+ size_t target_name_size;
|
||||||
|
+ int api_version; // 0 for an untyped call, 1 -- for typed
|
||||||
|
+ void* handler;
|
||||||
|
+ const char* platform_name;
|
||||||
|
+ size_t platform_name_size;
|
||||||
|
+};
|
||||||
|
+PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, handler);
|
||||||
|
+
|
||||||
|
+// Registers an FFI call handler for a specific platform.
|
||||||
|
+typedef PJRT_Error* PJRT_FFI_Register_Handler(
|
||||||
|
+ PJRT_FFI_Register_Handler_Args* args);
|
||||||
|
+
|
||||||
|
typedef struct PJRT_FFI_Extension {
|
||||||
|
size_t struct_size;
|
||||||
|
PJRT_Extension_Type type;
|
||||||
|
PJRT_Extension_Base* next;
|
||||||
|
PJRT_FFI_TypeID_Register* type_id_register;
|
||||||
|
PJRT_FFI_UserData_Add* user_data_add;
|
||||||
|
+ PJRT_FFI_Register_Handler* register_handler;
|
||||||
|
} PJRT_FFI;
|
||||||
|
PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add);
|
||||||
|
|
||||||
|
diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
index 0375b39d0b..3527a0756e 100644
|
||||||
|
--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
+++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc
|
||||||
|
@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h"
|
||||||
|
+#include <string>
|
||||||
|
|
||||||
|
#include "absl/status/status.h"
|
||||||
|
+#include "absl/strings/str_format.h"
|
||||||
|
+#include "xla/ffi/api/c_api.h"
|
||||||
|
+#include "xla/ffi/api/ffi.h"
|
||||||
|
#include "xla/ffi/execution_context.h"
|
||||||
|
-#include "xla/ffi/type_id_registry.h"
|
||||||
|
+ #include "xla/ffi/type_id_registry.h"
|
||||||
|
+#include "xla/ffi/ffi_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
|
||||||
|
#include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h"
|
||||||
|
+#include "xla/service/custom_call_target_registry.h"
|
||||||
|
|
||||||
|
namespace pjrt {
|
||||||
|
|
||||||
|
@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
+static PJRT_Error* PJRT_FFI_Register_Handler(
|
||||||
|
+ PJRT_FFI_Register_Handler_Args* args) {
|
||||||
|
+ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual(
|
||||||
|
+ "PJRT_FFI_Register_Handler_Args",
|
||||||
|
+ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size));
|
||||||
|
+ std::string target_name(args->target_name, args->target_name_size);
|
||||||
|
+ std::string platform_name(args->platform_name, args->platform_name_size);
|
||||||
|
+ switch (args->api_version) {
|
||||||
|
+ case 0:
|
||||||
|
+ xla::CustomCallTargetRegistry::Global()->Register(
|
||||||
|
+ target_name, args->handler, platform_name);
|
||||||
|
+ return nullptr;
|
||||||
|
+ case 1:
|
||||||
|
+ xla::ffi::Ffi::RegisterStaticHandler(
|
||||||
|
+ xla::ffi::GetXlaFfiApi(), target_name, platform_name,
|
||||||
|
+ reinterpret_cast<XLA_FFI_Handler*>(args->handler));
|
||||||
|
+ return nullptr;
|
||||||
|
+ default:
|
||||||
|
+ return new PJRT_Error{absl::UnimplementedError(
|
||||||
|
+ absl::StrFormat("API version %d not supported for PJRT GPU plugin. "
|
||||||
|
+ "Supported versions are 0 and 1.",
|
||||||
|
+ args->api_version))};
|
||||||
|
+ }
|
||||||
|
+}
|
||||||
|
+
|
||||||
|
PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||||
|
return {
|
||||||
|
/*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE,
|
||||||
|
@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) {
|
||||||
|
/*next=*/next,
|
||||||
|
/*type_id_register=*/PJRT_FFI_TypeID_Register,
|
||||||
|
/*user_data_add=*/PJRT_FFI_UserData_Add,
|
||||||
|
+ /*register_handler=*/PJRT_FFI_Register_Handler,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
--
|
||||||
|
2.43.0
|
||||||
15
third_party/modules/xla/20250317.1-71c67e2/source.json
vendored
Normal file
15
third_party/modules/xla/20250317.1-71c67e2/source.json
vendored
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
{
|
||||||
|
"strip_prefix": "xla-71c67e2a4f40267115a0d4ea7c36748bbe7e750e",
|
||||||
|
"url": "https://github.com/openxla/xla/archive/71c67e2a4f40267115a0d4ea7c36748bbe7e750e.tar.gz",
|
||||||
|
"integrity": "sha256-j6D1MC7+WsbZ+Ve3hPmDlCzZX1yV6RIP2BWDgQcbcYc=",
|
||||||
|
"overlay": {
|
||||||
|
"tsl.bzl": "",
|
||||||
|
"workspace.bzl": "",
|
||||||
|
"MODULE.bazel": ""
|
||||||
|
},
|
||||||
|
"patch_strip": 1,
|
||||||
|
"patches": {
|
||||||
|
"0001-bazel-migration-to-bazel-8.1.1.patch": "",
|
||||||
|
"0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": ""
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,16 +1,15 @@
|
|||||||
const asynk = @import("async");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
|
||||||
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const pjrt = @import("pjrtx.zig");
|
|
||||||
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
|
const asynk = @import("async");
|
||||||
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const meta = @import("meta.zig");
|
||||||
|
const pjrt = @import("pjrtx.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
|
||||||
@ -27,10 +26,22 @@ const log = std.log.scoped(.zml);
|
|||||||
/// * loading weights from disk directly to the device with `zml.aio.loadBuffers`
|
/// * loading weights from disk directly to the device with `zml.aio.loadBuffers`
|
||||||
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
/// * can be created by calling `HostBuffer.toDevice(platform)`.
|
||||||
pub const Buffer = struct {
|
pub const Buffer = struct {
|
||||||
pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).@"enum".tag_type) {
|
pub const Memory = enum {
|
||||||
host = @intFromEnum(pjrt.Memory.Kind.unpinned_host),
|
host,
|
||||||
host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host),
|
host_pinned,
|
||||||
device = @intFromEnum(pjrt.Memory.Kind.device),
|
device,
|
||||||
|
|
||||||
|
pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind {
|
||||||
|
return switch (self) {
|
||||||
|
.host => .unpinned_host,
|
||||||
|
.host_pinned => .pinned_host,
|
||||||
|
.device => .device,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pjrtName(self: Memory) []const u8 {
|
||||||
|
return @tagName(self.toPjrtMemory());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Shard = struct {
|
pub const Shard = struct {
|
||||||
@ -216,13 +227,13 @@ pub const Buffer = struct {
|
|||||||
/// and it might not work on all platforms,
|
/// and it might not work on all platforms,
|
||||||
/// could lead to crashes and operations on the buffer will be slower.
|
/// could lead to crashes and operations on the buffer will be slower.
|
||||||
/// Tested on Cuda 12.4.
|
/// Tested on Cuda 12.4.
|
||||||
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) !Buffer {
|
pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer {
|
||||||
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
|
return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Buffer from a pointer into device memory.
|
/// Creates a Buffer from a pointer into device memory.
|
||||||
/// This allows interfacing with other libraries that produce buffers.
|
/// This allows interfacing with other libraries that produce buffers.
|
||||||
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) !Buffer {
|
pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer {
|
||||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
for (0..Shape.MAX_RANK) |i| {
|
||||||
@ -231,9 +242,8 @@ pub const Buffer = struct {
|
|||||||
break :blk res;
|
break :blk res;
|
||||||
};
|
};
|
||||||
|
|
||||||
const device_bytes: [*]u8 = @ptrCast(device_data);
|
const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||||
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
.data = device_data,
|
||||||
.data = device_bytes[0..shape_.byteSize()],
|
|
||||||
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
.element_type = bufferTypeFromDtype(shape_.dtype()),
|
||||||
.dims = shape_.dims(),
|
.dims = shape_.dims(),
|
||||||
// TODO: expose sharding in the API.
|
// TODO: expose sharding in the API.
|
||||||
@ -246,7 +256,7 @@ pub const Buffer = struct {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
.stream = @bitCast(@as(usize, @intFromPtr(stream))),
|
.stream = @bitCast(@as(usize, @intFromPtr(stream))),
|
||||||
});
|
}) catch @panic("failed to createViewOfDeviceBuffer");
|
||||||
|
|
||||||
var shards: Shards = .{};
|
var shards: Shards = .{};
|
||||||
shards.appendAssumeCapacity(pjrt_buffer);
|
shards.appendAssumeCapacity(pjrt_buffer);
|
||||||
@ -342,6 +352,11 @@ pub const Buffer = struct {
|
|||||||
try writer.print("Buffer({_})", .{self._shape});
|
try writer.print("Buffer({_})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn getMemory(self: Buffer) *const pjrt.Memory {
|
||||||
|
const shard = self._shards.get(0);
|
||||||
|
return shard.memory(self._api);
|
||||||
|
}
|
||||||
|
|
||||||
fn hasShardedAxis(self: Buffer) bool {
|
fn hasShardedAxis(self: Buffer) bool {
|
||||||
if (self._shards.len == 1) return false;
|
if (self._shards.len == 1) return false;
|
||||||
return @reduce(.Or, self._shape._sharding_info);
|
return @reduce(.Or, self._shape._sharding_info);
|
||||||
|
|||||||
139
zml/context.zig
139
zml/context.zig
@ -1,21 +1,23 @@
|
|||||||
|
const std = @import("std");
|
||||||
const builtin = @import("builtin");
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
const runfiles = @import("runfiles");
|
const runfiles = @import("runfiles");
|
||||||
const runtimes = @import("runtimes");
|
const runtimes = @import("runtimes");
|
||||||
const std = @import("std");
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml_platform = @import("platform.zig");
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
const pjrt = @import("pjrtx.zig");
|
const DataType = @import("dtype.zig").DataType;
|
||||||
|
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
const pjrt = @import("pjrtx.zig");
|
||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
const Shape = @import("shape.zig").Shape;
|
||||||
const Target = @import("platform.zig").Target;
|
const Target = @import("platform.zig").Target;
|
||||||
|
const zml_platform = @import("platform.zig");
|
||||||
|
|
||||||
const available_targets = @import("platform.zig").available_targets;
|
const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api);
|
||||||
|
const PlatformsMap = std.EnumArray(Target, ?Platform);
|
||||||
const log = std.log.scoped(.@"zml/context");
|
const log = std.log.scoped(.@"zml/context");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -174,10 +176,8 @@ pub const Context = struct {
|
|||||||
log.err("No device found for platform {} !", .{target});
|
log.err("No device found for platform {} !", .{target});
|
||||||
return error.NoDevicesFound;
|
return error.NoDevicesFound;
|
||||||
}
|
}
|
||||||
// TODO: should this be moved to platform.zig ?
|
|
||||||
if (target == .cuda) {
|
try CustomCall.registerZmlCustomCalls(p);
|
||||||
try cuda.registerZmlCustomCalls(p);
|
|
||||||
}
|
|
||||||
|
|
||||||
self.platforms.set(target, p);
|
self.platforms.set(target, p);
|
||||||
return p;
|
return p;
|
||||||
@ -213,77 +213,68 @@ pub const Context = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const HostCallbackCtx = struct {
|
pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void;
|
||||||
host: HostBuffer,
|
|
||||||
mutex: std.Thread.Mutex = std.Thread.Mutex{},
|
|
||||||
};
|
|
||||||
pub const HostCallback = fn (HostBuffer) void;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const cuda = struct {
|
const CustomCall = struct {
|
||||||
var runtime: Runtime = undefined;
|
pub fn registerZmlCustomCalls(platform: Platform) !void {
|
||||||
|
const registry = platform.pjrt_api.customCallRegistry();
|
||||||
|
|
||||||
pub fn registerZmlCustomCalls(cuda_platform: Platform) !void {
|
if (registry) |reg| {
|
||||||
std.debug.assert(cuda_platform.target == .cuda);
|
try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback);
|
||||||
|
} else {
|
||||||
cuda.runtime = try Runtime.init();
|
stdx.debug.panic("Registering custom calls failed", .{});
|
||||||
const registry = cuda_platform.pjrt_api.customCallRegistry().?;
|
|
||||||
try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub const Stream = opaque {};
|
|
||||||
pub const MemcpyKind = enum(c_int) {
|
|
||||||
host_to_host = 0,
|
|
||||||
host_to_device = 1,
|
|
||||||
device_to_host = 2,
|
|
||||||
device_to_device = 3,
|
|
||||||
default = 4,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub const Runtime = struct {
|
|
||||||
memcpyAsync: MemcpyAsync,
|
|
||||||
streamSynchronize: StreamSynchronize,
|
|
||||||
|
|
||||||
const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int;
|
|
||||||
const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int;
|
|
||||||
|
|
||||||
pub fn init() !Runtime {
|
|
||||||
var cudart = try std.DynLib.open("libcudart.so.12");
|
|
||||||
defer cudart.close();
|
|
||||||
|
|
||||||
return .{
|
|
||||||
.memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound,
|
|
||||||
.streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound,
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
};
|
|
||||||
|
|
||||||
fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } {
|
|
||||||
std.debug.assert(args_len == @sizeOf(*anyopaque) * 2);
|
|
||||||
|
|
||||||
const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*);
|
|
||||||
const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr);
|
|
||||||
|
|
||||||
const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*);
|
|
||||||
const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr);
|
|
||||||
return .{ fn_ptr, ctx_ptr };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void {
|
fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error {
|
||||||
const stream: *Stream = @ptrCast(opaque_stream);
|
if (call_frame.registeringHook()) return null;
|
||||||
const src: *anyopaque = buffers[0];
|
|
||||||
const callback, const ctx = getContext(args, args_len);
|
|
||||||
|
|
||||||
// Add synchronization because this is called from the device driver.
|
const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable;
|
||||||
ctx.mutex.lock();
|
std.debug.assert(callback_attr.dtype == .u64);
|
||||||
defer ctx.mutex.unlock();
|
const callback: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize));
|
||||||
|
|
||||||
const host_dst: []u8 = @constCast(ctx.host.data);
|
const user_ctx_ptr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable;
|
||||||
const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream);
|
std.debug.assert(user_ctx_ptr.dtype == .u64);
|
||||||
_ = memcpy_result;
|
const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_ptr.get(usize));
|
||||||
const synchronize_result = cuda.runtime.streamSynchronize(stream);
|
|
||||||
_ = synchronize_result;
|
|
||||||
|
|
||||||
callback(ctx.host);
|
const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len);
|
||||||
|
for (input_buffers, 0..) |*b, i| {
|
||||||
|
b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len);
|
||||||
|
for (output_buffers, 0..) |*b, i| {
|
||||||
|
b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
callback(user_ctx, input_buffers, output_buffers);
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape {
|
||||||
|
// log.warn("received buffer {}", .{buffer_desc});
|
||||||
|
const dt: DataType = switch (buffer_desc.dtype) {
|
||||||
|
.invalid => @panic("invalid ffi"),
|
||||||
|
.pred => .bool,
|
||||||
|
.s8 => .i8,
|
||||||
|
.s16 => .i16,
|
||||||
|
.s32 => .i32,
|
||||||
|
.s64 => .i64,
|
||||||
|
.token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"),
|
||||||
|
inline else => |t| @field(DataType, @tagName(t)),
|
||||||
|
};
|
||||||
|
return Shape.init(buffer_desc.dims(), dt);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a HostBuffer from a ffi description of a buffer.
|
||||||
|
/// Normally the ffi describes device buffers, but we assume they are located in pinned memory,
|
||||||
|
/// and therefore the data pointer is readable both from host and from device.
|
||||||
|
fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer {
|
||||||
|
const buffer_shape = getShape(buffer_desc);
|
||||||
|
return HostBuffer.fromBytes(
|
||||||
|
buffer_shape,
|
||||||
|
buffer_desc.data[0..buffer_shape.byteSize()],
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const Buffer = @import("buffer.zig").Buffer;
|
const Buffer = @import("buffer.zig").Buffer;
|
||||||
@ -98,6 +99,13 @@ pub const HostBuffer = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a HostBuffer tagged with the tags in 'tagz'.
|
||||||
|
pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer {
|
||||||
|
var res = self;
|
||||||
|
res._shape = self._shape.withTags(tagz);
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
pub const ArangeArgs = struct {
|
pub const ArangeArgs = struct {
|
||||||
start: i64 = 0,
|
start: i64 = 0,
|
||||||
end: i64,
|
end: i64,
|
||||||
@ -240,6 +248,11 @@ pub const HostBuffer = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn choose1d(self: HostBuffer, axis_: anytype, start: i64) HostBuffer {
|
||||||
|
const ax = self.axis(axis_);
|
||||||
|
return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer {
|
||||||
const ax = self._shape.axis(axis_);
|
const ax = self._shape.axis(axis_);
|
||||||
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self });
|
||||||
|
|||||||
@ -1,11 +1,10 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const stdx = @import("stdx");
|
const testing = std.testing;
|
||||||
|
|
||||||
|
const stdx = @import("stdx");
|
||||||
const FnParam = stdx.meta.FnParam;
|
const FnParam = stdx.meta.FnParam;
|
||||||
const asSlice = stdx.meta.asSlice;
|
const asSlice = stdx.meta.asSlice;
|
||||||
|
|
||||||
const testing = std.testing;
|
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -367,6 +367,7 @@ pub const CompilationContext = struct {
|
|||||||
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
|
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
|
||||||
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
|
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
|
||||||
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
|
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
|
||||||
|
const fn_res_output_memory_kind = try res_allocator.alloc(Buffer.Memory, out_tensor_count);
|
||||||
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
|
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
|
||||||
{
|
{
|
||||||
defer self.closeBlock(fn_body);
|
defer self.closeBlock(fn_body);
|
||||||
@ -382,7 +383,7 @@ pub const CompilationContext = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
||||||
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations, fn_res_output_memory_kind);
|
||||||
|
|
||||||
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||||
fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
|
fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]);
|
||||||
@ -396,6 +397,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
if (opts.kind == .main) {
|
if (opts.kind == .main) {
|
||||||
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
||||||
|
self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
|
||||||
if (self._platform.sharding().num_partitions > 1) {
|
if (self._platform.sharding().num_partitions > 1) {
|
||||||
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
|
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
|
||||||
}
|
}
|
||||||
@ -433,6 +435,20 @@ pub const CompilationContext = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn addOutputMemoryKindAttributes(self: CompilationContext, attributes: []AttributeList, output_memory_kind: []const Buffer.Memory) void {
|
||||||
|
const mlir_ctx = self.mlirCtx();
|
||||||
|
for (attributes, output_memory_kind) |*attr, memory_kind| {
|
||||||
|
// .device is the default output, don't explicitly emit the attribute
|
||||||
|
if (memory_kind == .device) continue;
|
||||||
|
|
||||||
|
attr.appendAssumeCapacity(.named(
|
||||||
|
mlir_ctx,
|
||||||
|
"mhlo.memory_kind",
|
||||||
|
.string(mlir_ctx, memory_kind.pjrtName()),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Given a list of donations mapping output buffers to input buffers,
|
/// Given a list of donations mapping output buffers to input buffers,
|
||||||
/// generate donation attribute for each `n_args` input argument.
|
/// generate donation attribute for each `n_args` input argument.
|
||||||
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
|
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void {
|
||||||
@ -712,7 +728,15 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found.
|
/// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found.
|
||||||
pub fn extractValuesAndTypes(self: *const CompilationContext, v: anytype, values: []mlir.Value, types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation) void {
|
pub fn extractValuesAndTypes(
|
||||||
|
self: *const CompilationContext,
|
||||||
|
v: anytype,
|
||||||
|
values: []mlir.Value,
|
||||||
|
types: []mlir.Type,
|
||||||
|
shapes: []Shape,
|
||||||
|
donations: []Tensor._Donation,
|
||||||
|
output_memory_kind: []Buffer.Memory,
|
||||||
|
) void {
|
||||||
std.debug.assert(values.len == types.len);
|
std.debug.assert(values.len == types.len);
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
self: *const CompilationContext,
|
self: *const CompilationContext,
|
||||||
@ -721,8 +745,16 @@ pub const CompilationContext = struct {
|
|||||||
types: []mlir.Type,
|
types: []mlir.Type,
|
||||||
shapes: []Shape,
|
shapes: []Shape,
|
||||||
donations: []Tensor._Donation,
|
donations: []Tensor._Donation,
|
||||||
|
output_memory_kind: []Buffer.Memory,
|
||||||
|
};
|
||||||
|
var context = LocalContext{
|
||||||
|
.self = self,
|
||||||
|
.values = values,
|
||||||
|
.types = types,
|
||||||
|
.shapes = shapes,
|
||||||
|
.donations = donations,
|
||||||
|
.output_memory_kind = output_memory_kind,
|
||||||
};
|
};
|
||||||
var context = LocalContext{ .self = self, .values = values, .types = types, .shapes = shapes, .donations = donations };
|
|
||||||
meta.visit((struct {
|
meta.visit((struct {
|
||||||
fn cb(ctx: *LocalContext, tensor: *const Tensor) void {
|
fn cb(ctx: *LocalContext, tensor: *const Tensor) void {
|
||||||
const value, const donation = ctx.self.getValueAndDonation(tensor.*);
|
const value, const donation = ctx.self.getValueAndDonation(tensor.*);
|
||||||
@ -730,6 +762,7 @@ pub const CompilationContext = struct {
|
|||||||
ctx.types[ctx.index] = value.getType();
|
ctx.types[ctx.index] = value.getType();
|
||||||
ctx.shapes[ctx.index] = tensor._shape;
|
ctx.shapes[ctx.index] = tensor._shape;
|
||||||
ctx.donations[ctx.index] = donation;
|
ctx.donations[ctx.index] = donation;
|
||||||
|
ctx.output_memory_kind[ctx.index] = tensor._output_memory_kind;
|
||||||
ctx.index += 1;
|
ctx.index += 1;
|
||||||
}
|
}
|
||||||
}).cb, &context, v);
|
}).cb, &context, v);
|
||||||
|
|||||||
@ -1,16 +1,16 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const Context = @import("../context.zig").Context;
|
|
||||||
const module = @import("../module.zig");
|
|
||||||
const mlir = @import("../mlir.zig");
|
|
||||||
const dialect = @import("mlir/dialects");
|
const dialect = @import("mlir/dialects");
|
||||||
|
|
||||||
const Tensor = @import("../tensor.zig").Tensor;
|
const Context = @import("../context.zig").Context;
|
||||||
const Shape = @import("../shape.zig").Shape;
|
|
||||||
const SdpaOpts = @import("../nn.zig").SdpaOpts;
|
|
||||||
const DataType = @import("../dtype.zig").DataType;
|
const DataType = @import("../dtype.zig").DataType;
|
||||||
const Data = @import("../dtype.zig").Data;
|
const Data = @import("../dtype.zig").Data;
|
||||||
|
const mlir = @import("../mlir.zig");
|
||||||
|
const module = @import("../module.zig");
|
||||||
const CompilationContext = module.CompilationContext;
|
const CompilationContext = module.CompilationContext;
|
||||||
|
const SdpaOpts = @import("../nn.zig").SdpaOpts;
|
||||||
|
const Shape = @import("../shape.zig").Shape;
|
||||||
|
const Tensor = @import("../tensor.zig").Tensor;
|
||||||
|
|
||||||
pub fn canUseCudnnSdpa(q_shape: Shape) bool {
|
pub fn canUseCudnnSdpa(q_shape: Shape) bool {
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
@ -125,7 +125,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
&.{ q.value(), k.value(), v.value(), bias.value() },
|
&.{ q.value(), k.value(), v.value(), bias.value() },
|
||||||
.{
|
.{
|
||||||
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
||||||
.backend_config = .{ .string = backend_config },
|
.backend_config = .string(mlir_ctx, backend_config),
|
||||||
.has_side_effect = false,
|
.has_side_effect = false,
|
||||||
.api_version = .original,
|
.api_version = .original,
|
||||||
},
|
},
|
||||||
|
|||||||
153
zml/ops.zig
153
zml/ops.zig
@ -1,28 +1,31 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
const assert = std.debug.assert;
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
|
const _collectAxes = @import("tensor.zig")._collectAxes;
|
||||||
const buffer = @import("buffer.zig");
|
const buffer = @import("buffer.zig");
|
||||||
const helpers = @import("helpers.zig");
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
const mlir = @import("mlir.zig");
|
|
||||||
const module = @import("module.zig");
|
|
||||||
|
|
||||||
const Buffer = buffer.Buffer;
|
const Buffer = buffer.Buffer;
|
||||||
const CompilationContext = module.CompilationContext;
|
const Bufferized = @import("tensor.zig").Bufferized;
|
||||||
const Context = @import("context.zig").Context;
|
const Context = @import("context.zig").Context;
|
||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
const EnumLiteral = @TypeOf(.enum_literal);
|
const helpers = @import("helpers.zig");
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const meta = @import("meta.zig");
|
||||||
|
const mlir = @import("mlir.zig");
|
||||||
|
const module = @import("module.zig");
|
||||||
|
const CompilationContext = module.CompilationContext;
|
||||||
|
const Platform = @import("platform.zig").Platform;
|
||||||
const Shape = @import("shape.zig").Shape;
|
const Shape = @import("shape.zig").Shape;
|
||||||
|
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||||
const Tensor = @import("tensor.zig").Tensor;
|
const Tensor = @import("tensor.zig").Tensor;
|
||||||
const _collectAxes = @import("tensor.zig")._collectAxes;
|
|
||||||
|
|
||||||
|
const EnumLiteral = @TypeOf(.enum_literal);
|
||||||
const dialect = struct {
|
const dialect = struct {
|
||||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||||
};
|
};
|
||||||
|
|
||||||
const assert = std.debug.assert;
|
|
||||||
const log = std.log.scoped(.@"zml/tensor");
|
const log = std.log.scoped(.@"zml/tensor");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -766,50 +769,56 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// At runtime the given tensor will be materialized and copied to host,
|
pub const HostCallbackOpt = struct {
|
||||||
/// and the callback will be called on it.
|
has_side_effect: bool = false,
|
||||||
|
output_operand_aliases: []const i64 = &.{},
|
||||||
|
};
|
||||||
|
|
||||||
pub fn addHostCallback(
|
pub fn addHostCallback(
|
||||||
callback: *const fn (HostBuffer) void,
|
callback: *const Context.HostCallback,
|
||||||
input: Tensor,
|
blkctx: ?*anyopaque,
|
||||||
) Tensor {
|
inputs: []const Tensor,
|
||||||
// TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy.
|
output_shapes: []const Shape,
|
||||||
if (input.getContext().target() != .cuda) return input;
|
opts: HostCallbackOpt,
|
||||||
|
) []Tensor {
|
||||||
const len = input.byteSize();
|
|
||||||
// Reserve memory to be able to log the runtime Buffer later during the computation.
|
|
||||||
// This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled.
|
|
||||||
const HostCallbackCtx = Context.HostCallbackCtx;
|
|
||||||
const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch {
|
|
||||||
log.err("Failed to pre-allocate buffer to print {}.", .{input});
|
|
||||||
return input;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Save the HostBuffer inside the same memory slice, so that it's still present at runtime.
|
|
||||||
// Use an fba to have the stable buffer at an aligned offset.
|
|
||||||
var fba = std.heap.FixedBufferAllocator.init(full_data[len..]);
|
|
||||||
const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable;
|
|
||||||
stable_ctx_ptr.* = .{
|
|
||||||
.host = HostBuffer.fromBytes(input.shape(), full_data[0..len]),
|
|
||||||
};
|
|
||||||
|
|
||||||
const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr };
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
|
|
||||||
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
|
const backend_config = mlir.Attribute.dict(mlir_ctx, &.{
|
||||||
|
.{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) },
|
||||||
|
.{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) },
|
||||||
|
});
|
||||||
|
|
||||||
|
const values = stdx.stackSlice(8, mlir.Value, inputs.len);
|
||||||
|
for (inputs, values) |i, *v| {
|
||||||
|
v.* = ctx.getValue(i.toMemory(.host_pinned));
|
||||||
|
}
|
||||||
|
const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len);
|
||||||
|
for (res_types, output_shapes) |*r, o| {
|
||||||
|
r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type);
|
||||||
|
}
|
||||||
|
|
||||||
const loc = ctx.mlirCtx().location(@src());
|
const loc = ctx.mlirCtx().location(@src());
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
&.{input.value()},
|
values,
|
||||||
.{
|
.{
|
||||||
.has_side_effect = false,
|
|
||||||
.call_target_name = "zmlHostBufferCallback",
|
.call_target_name = "zmlHostBufferCallback",
|
||||||
.backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) },
|
.api_version = .typed_ffi,
|
||||||
.output_operand_aliases = &.{0},
|
.backend_config = backend_config,
|
||||||
.api_version = .original,
|
.has_side_effect = opts.has_side_effect,
|
||||||
|
.output_operand_aliases = opts.output_operand_aliases,
|
||||||
},
|
},
|
||||||
&.{input.value().getType()},
|
res_types,
|
||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
return Tensor._result(input.shape(), op.result(0));
|
|
||||||
|
const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM");
|
||||||
|
for (res, output_shapes, 0..) |*r, o, i| {
|
||||||
|
r.* = Tensor._result(o, op.result(i)).toMemory(.device);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const TritonOps = struct {
|
pub const TritonOps = struct {
|
||||||
@ -834,46 +843,32 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T
|
|||||||
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output);
|
||||||
}
|
}
|
||||||
|
|
||||||
const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{
|
const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{
|
||||||
.named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)),
|
.{ "name", .string(ctx.mlirCtx(), opts.name) },
|
||||||
.named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)),
|
.{ "ir", .string(ctx.mlirCtx(), opts.ir) },
|
||||||
.named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])),
|
.{ "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0]) },
|
||||||
.named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])),
|
.{ "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1]) },
|
||||||
.named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])),
|
.{ "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2]) },
|
||||||
.named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)),
|
.{ "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages) },
|
||||||
.named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)),
|
.{ "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps) },
|
||||||
});
|
});
|
||||||
|
|
||||||
const MINOR_TO_MAJOR = blk: {
|
var operands_layouts: [inputs.len][]const usize = undefined;
|
||||||
var ret: [Shape.MAX_RANK]usize = undefined;
|
inline for (inputs, 0..) |input, i| {
|
||||||
for (0..Shape.MAX_RANK) |i| {
|
operands_layouts[i] = minorToMajor(input.rank());
|
||||||
ret[i] = @intCast(Shape.MAX_RANK - i - 1);
|
}
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
const operands_layouts = blk: {
|
var results_layouts: [outputs.len][]const usize = undefined;
|
||||||
var ret: [inputs.len][]const usize = undefined;
|
inline for (outputs, 0..) |output, i| {
|
||||||
inline for (inputs, 0..) |input, i| {
|
results_layouts[i] = minorToMajor(output.rank());
|
||||||
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..];
|
}
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
const results_layouts = blk: {
|
|
||||||
var ret: [outputs.len][]const usize = undefined;
|
|
||||||
inline for (outputs, 0..) |output, i| {
|
|
||||||
ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..];
|
|
||||||
}
|
|
||||||
break :blk ret;
|
|
||||||
};
|
|
||||||
|
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
&values,
|
&values,
|
||||||
.{
|
.{
|
||||||
.call_target_name = "__gpu$xla.gpu.triton",
|
.call_target_name = "__gpu$xla.gpu.triton",
|
||||||
.backend_config = .{ .dict = attrs },
|
.backend_config = backend_config,
|
||||||
.has_side_effect = false,
|
.has_side_effect = false,
|
||||||
.api_version = .typed_ffi,
|
.api_version = .typed_ffi,
|
||||||
.operand_layouts = &operands_layouts,
|
.operand_layouts = &operands_layouts,
|
||||||
@ -1256,3 +1251,15 @@ inline fn toI64(values: anytype) []i64 {
|
|||||||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||||||
return res[0..values.len];
|
return res[0..values.len];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Axis indices in descending order, `{ Shape.MAX_RANK - 1, ..., 1, 0 }`.
/// Built once at compile time; `minorToMajor` hands out suffixes of this table.
const _MINOR_TO_MAJOR = blk: {
    var descending: [Shape.MAX_RANK]usize = undefined;
    for (&descending, 0..) |*slot, idx| {
        slot.* = @intCast(Shape.MAX_RANK - idx - 1);
    }
    break :blk descending;
};
|
||||||
|
|
||||||
|
/// Returns the last `rank` entries of `_MINOR_TO_MAJOR` as a layout slice.
/// The returned slice borrows from the file-level table; nothing is allocated.
/// `rank` must not be larger than `_MINOR_TO_MAJOR.len`, otherwise the
/// subtraction below underflows.
fn minorToMajor(rank: u8) []const usize {
    const first = _MINOR_TO_MAJOR.len - rank;
    return _MINOR_TO_MAJOR[first..];
}
|
||||||
|
|||||||
@ -1,19 +1,10 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const builtin = @import("builtin");
|
|
||||||
const dialects = @import("mlir/dialects");
|
const dialects = @import("mlir/dialects");
|
||||||
const mlir = @import("mlir");
|
const mlir = @import("mlir");
|
||||||
const pjrt = @import("pjrt");
|
const pjrt = @import("pjrt");
|
||||||
const std = @import("std");
|
pub const ffi = pjrt.ffi;
|
||||||
const stdx = @import("stdx");
|
|
||||||
const c = @import("c");
|
|
||||||
|
|
||||||
const dtype = @import("dtype.zig");
|
|
||||||
const meta = @import("meta.zig");
|
|
||||||
|
|
||||||
const Target = @import("platform.zig").Target;
|
|
||||||
|
|
||||||
const log = std.log.scoped(.zml);
|
|
||||||
|
|
||||||
pub const Profiler = pjrt.Profiler;
|
pub const Profiler = pjrt.Profiler;
|
||||||
pub const ApiError = pjrt.ApiError;
|
pub const ApiError = pjrt.ApiError;
|
||||||
pub const ErrorCode = pjrt.ErrorCode;
|
pub const ErrorCode = pjrt.ErrorCode;
|
||||||
@ -23,7 +14,6 @@ pub const DeviceDescription = pjrt.DeviceDescription;
|
|||||||
pub const Api = pjrt.Api;
|
pub const Api = pjrt.Api;
|
||||||
pub const NamedValue = pjrt.NamedValue;
|
pub const NamedValue = pjrt.NamedValue;
|
||||||
pub const ClientInitError = pjrt.ClientInitError;
|
pub const ClientInitError = pjrt.ClientInitError;
|
||||||
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;
|
|
||||||
pub const Error = pjrt.Error;
|
pub const Error = pjrt.Error;
|
||||||
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
||||||
pub const SerializeResult = pjrt.SerializeResult;
|
pub const SerializeResult = pjrt.SerializeResult;
|
||||||
@ -31,6 +21,10 @@ pub const Executable = pjrt.Executable;
|
|||||||
pub const ExecuteError = ApiError;
|
pub const ExecuteError = ApiError;
|
||||||
pub const Memory = pjrt.Memory;
|
pub const Memory = pjrt.Memory;
|
||||||
|
|
||||||
|
const log = std.log.scoped(.zml);
|
||||||
|
|
||||||
|
pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError;
|
||||||
|
|
||||||
fn InnerMixin(comptime innerT: type) type {
|
fn InnerMixin(comptime innerT: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT {
|
inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT {
|
||||||
@ -159,6 +153,10 @@ pub const Buffer = opaque {
|
|||||||
return self.inner().isOnCpu(api);
|
return self.inner().isOnCpu(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the `Memory` reported for this buffer.
/// Thin forwarding wrapper: the lookup is delegated to the wrapped inner
/// buffer's `memory(api)`.
pub fn memory(self: *const Buffer, api: *const Api) *const Memory {
    const wrapped = self.inner();
    return wrapped.memory(api);
}
|
||||||
|
|
||||||
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
|
pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event {
|
||||||
return @ptrCast(try self.inner().toHostBuffer(api, dst));
|
return @ptrCast(try self.inner().toHostBuffer(api, dst));
|
||||||
}
|
}
|
||||||
@ -183,8 +181,8 @@ pub const Buffer = opaque {
|
|||||||
return @ptrCast(self.inner().copyToDevice(api, device));
|
return @ptrCast(self.inner().copyToDevice(api, device));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory: *const Memory) ApiError!*Buffer {
|
pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer {
|
||||||
return @ptrCast(self.inner().copyToMemory(api, memory));
|
return @ptrCast(self.inner().copyToMemory(api, memory_));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
|
pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event {
|
||||||
|
|||||||
@ -9,6 +9,7 @@ const Buffer = @import("buffer.zig").Buffer;
|
|||||||
const Data = @import("dtype.zig").Data;
|
const Data = @import("dtype.zig").Data;
|
||||||
const DataType = @import("dtype.zig").DataType;
|
const DataType = @import("dtype.zig").DataType;
|
||||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||||
|
const Memory = @import("buffer.zig").Buffer.Memory;
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
const mlir = @import("mlir.zig");
|
const mlir = @import("mlir.zig");
|
||||||
const Location = mlir.Location;
|
const Location = mlir.Location;
|
||||||
@ -41,10 +42,10 @@ pub const Tensor = struct {
|
|||||||
_shape: Shape,
|
_shape: Shape,
|
||||||
_id: _Id,
|
_id: _Id,
|
||||||
_donation: _Donation = .no_buffer,
|
_donation: _Donation = .no_buffer,
|
||||||
|
_output_memory_kind: Memory = .device,
|
||||||
|
|
||||||
pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 };
|
pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 };
|
||||||
pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 };
|
pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 };
|
||||||
|
|
||||||
pub const MAX_RANK = Shape.MAX_RANK;
|
pub const MAX_RANK = Shape.MAX_RANK;
|
||||||
|
|
||||||
/// Returns the current compilation context.
|
/// Returns the current compilation context.
|
||||||
@ -171,20 +172,22 @@ pub const Tensor = struct {
|
|||||||
return switch (self._id) {
|
return switch (self._id) {
|
||||||
.arg_id, .mlir => {
|
.arg_id, .mlir => {
|
||||||
const ctx = self.getContext();
|
const ctx = self.getContext();
|
||||||
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.withSharding(axes_);
|
res._shape = self._shape.withSharding(axes_);
|
||||||
|
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
mlir_ctx,
|
||||||
&.{self.value()},
|
&.{self.value()},
|
||||||
.{
|
.{
|
||||||
.call_target_name = "Sharding",
|
.call_target_name = "Sharding",
|
||||||
.has_side_effect = false,
|
.has_side_effect = false,
|
||||||
.addional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }},
|
.backend_config = null,
|
||||||
|
.additional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }},
|
||||||
.api_version = .original,
|
.api_version = .original,
|
||||||
},
|
},
|
||||||
&.{self.value().getType()},
|
&.{self.value().getType()},
|
||||||
ctx.mlirCtx().location(@src()),
|
mlir_ctx.location(@src()),
|
||||||
);
|
);
|
||||||
|
|
||||||
return _result(res._shape, op.result(0));
|
return _result(res._shape, op.result(0));
|
||||||
@ -197,6 +200,39 @@ pub const Tensor = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Requests that this tensor be placed in the memory space `kind`.
///
/// * `.arg_id` / `.mlir` tensors: emits a side-effecting
///   `annotate_device_placement` custom call whose `mhlo.frontend_attributes`
///   carry `_xla_buffer_placement = @tagName(kind.toPjrtMemory())`, and returns
///   a tensor wrapping the result of that call. When the compilation target is
///   `.cpu`, no call is emitted and `self` is returned unchanged.
/// * `.buffer_id` tensors: returns a copy of `self` with the requested kind
///   recorded; no MLIR is emitted.
///
/// In both non-CPU cases the returned tensor has `_output_memory_kind == kind`.
pub fn toMemory(self: Tensor, kind: Memory) Tensor {
    return switch (self._id) {
        .arg_id, .mlir => {
            const ctx = self.getContext();
            const mlir_ctx = ctx.mlirCtx();
            // CPU target: return the tensor untouched, no annotation is emitted.
            if (ctx.target() == .cpu) return self;

            const memory_kind = @tagName(kind.toPjrtMemory());
            const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{
                .{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) },
            });

            const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{
                .call_target_name = "annotate_device_placement",
                .has_side_effect = true,
                .backend_config = null,
                .additional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }},
                .api_version = .original,
            }, &.{self.value().getType()}, mlir_ctx.location(@src()));

            // `_result` only receives the shape and the MLIR value, so a kind
            // stored on a local copy of `self` would be lost. Record it on the
            // tensor that is actually returned, like the `.buffer_id` branch does.
            // NOTE(review): `_result` is defined outside this hunk — confirm it
            // does not already derive the memory kind by other means.
            var placed = _result(self._shape, op.result(0));
            placed._output_memory_kind = kind;
            return placed;
        },
        .buffer_id => {
            var tagged = self;
            tagged._output_memory_kind = kind;
            return tagged;
        },
    };
}
|
||||||
|
|
||||||
/// Returns a Tensor with new tag names.
|
/// Returns a Tensor with new tag names.
|
||||||
pub fn rename(self: Tensor, renames: anytype) Tensor {
|
pub fn rename(self: Tensor, renames: anytype) Tensor {
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -3747,18 +3783,22 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Insert code that will print the content of the given buffer at runtime.
|
/// Insert code that will print the content of the given buffer at runtime.
|
||||||
/// Only for debug purpose, it has the following limitations:
|
/// Only for debug purpose, it inserts device to host synchronization
|
||||||
/// * only supported on Cuda atm
|
/// so it will slow down the program execution.
|
||||||
/// * only prints the first 1024 values
|
|
||||||
/// * pre allocates a buffer on the host to copy the content of the device buffer,
|
|
||||||
/// this buffer won't be freed. You will have one buffer per "print" call in the IR.
|
|
||||||
/// * does device to host synchronization so it will slow down the program execution.
|
|
||||||
pub fn print(input: Tensor) Tensor {
|
pub fn print(input: Tensor) Tensor {
|
||||||
return ops.addHostCallback(&printCallback, input);
|
return ops.addHostCallback(
|
||||||
|
&printCallback,
|
||||||
|
null,
|
||||||
|
&.{input},
|
||||||
|
&.{input.shape()},
|
||||||
|
.{ .output_operand_aliases = &.{0} },
|
||||||
|
)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
fn printCallback(host_buffer: HostBuffer) void {
|
fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void {
|
||||||
|
const host_buffer = inputs[0];
|
||||||
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() });
|
||||||
|
std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user