From aec7072837f807796faceddc231b89589528f0fe Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 10 Sep 2024 09:14:28 +0000 Subject: [PATCH] pjrt: add FFI bindings for custom calls --- MODULE.bazel | 2 +- mlir/dialects/stablehlo.zig | 103 ++-- mlir/mlir.zig | 12 + pjrt/BUILD.bazel | 13 +- pjrt/convert/trace_container.zig | 74 ++- pjrt/convert/xplane_visitor.zig | 68 --- pjrt/ffi.zig | 517 ++++++++++++++++++ pjrt/pjrt.zig | 58 +- stdx/stdx.zig | 6 + .../xla/20250317.1-71c67e2/MODULE.bazel | 37 ++ .../20250317.1-71c67e2/overlay/MODULE.bazel | 37 ++ .../xla/20250317.1-71c67e2/overlay/tsl.bzl | 19 + .../20250317.1-71c67e2/overlay/workspace.bzl | 60 ++ .../0001-bazel-migration-to-bazel-8.1.1.patch | 41 ++ ...ler-registration-API-to-the-FFI-PjRt.patch | 131 +++++ .../xla/20250317.1-71c67e2/source.json | 15 + zml/buffer.zig | 47 +- zml/context.zig | 139 +++-- zml/hostbuffer.zig | 13 + zml/meta.zig | 5 +- zml/module.zig | 39 +- zml/nn/cuda.zig | 14 +- zml/ops.zig | 153 +++--- zml/pjrtx.zig | 28 +- zml/tensor.zig | 64 ++- 25 files changed, 1319 insertions(+), 376 deletions(-) delete mode 100644 pjrt/convert/xplane_visitor.zig create mode 100644 pjrt/ffi.zig create mode 100644 third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel create mode 100644 third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel create mode 100644 third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl create mode 100644 third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl create mode 100644 third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch create mode 100644 third_party/modules/xla/20250317.1-71c67e2/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch create mode 100644 third_party/modules/xla/20250317.1-71c67e2/source.json diff --git a/MODULE.bazel b/MODULE.bazel index fd405c9..1bcf271 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -24,7 +24,7 @@ bazel_dep(name = "rules_zig", version = "20250314.0-b9739c6") 
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a") bazel_dep(name = "toolchains_protoc", version = "0.3.7") bazel_dep(name = "with_cfg.bzl", version = "0.9.1") -bazel_dep(name = "xla", version = "20250317.0-71c67e2") +bazel_dep(name = "xla", version = "20250317.1-71c67e2") bazel_dep(name = "zig-protobuf", version = "20250318.0-930153e") bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf") diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index c5870a2..6328321 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -739,18 +739,13 @@ pub const CustomCallOpts = struct { typed_ffi = 4, }; - pub const BackendConfig = union(enum) { - string: [:0]const u8, - dict: mlir.DictionaryAttribute, - }; - call_target_name: [:0]const u8, has_side_effect: bool, - backend_config: BackendConfig = .{ .string = &.{} }, + backend_config: ?mlir.Attribute, operand_layouts: ?[]const []const usize = null, result_layouts: ?[]const []const usize = null, output_operand_aliases: []const i64 = &.{}, - addional_attributes: []const mlir.AttrTuple = &.{}, + additional_attributes: []const mlir.AttrTuple = &.{}, api_version: ApiVersion, }; @@ -758,68 +753,56 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa const MAX_OPERANDS = 64; const MAX_RESULTS = 16; - const output_operand_aliases = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (opts.output_operand_aliases) |alias| { - ret.appendAssumeCapacity( - OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), - ); - } - break :blk ret; - }; - - const backend_config: mlir.Attribute = switch (opts.backend_config) { - .string => blk: { - stdx.debug.assert( - @intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi), - "Only API version of less than 4 is supported for backend_config as string", - .{}, - ); - break :blk .string(ctx, opts.backend_config.string); - }, - .dict => blk: { 
- stdx.debug.assert( - opts.api_version == .typed_ffi, - "Only API version 4 is supported for backend_config as dictionary", - .{}, - ); - break :blk opts.backend_config.dict.as(mlir.Attribute); - }, - }; + const backend_config = opts.backend_config orelse mlir.Attribute.string(ctx, ""); + if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { + stdx.debug.assert( + backend_config.is_a(mlir.StringAttribute), + "API version < 4 requires a string as backend_config, got {}", + .{backend_config}, + ); + } else { + stdx.debug.assert( + backend_config.is_a(mlir.DictionaryAttribute), + "API version >= 4 requires a dictionary as backend_config, got {}", + .{backend_config}, + ); + } var attrs: std.BoundedArray(mlir.AttrTuple, 32) = .{}; - attrs.appendSliceAssumeCapacity(&[_]mlir.AttrTuple{ .{ "api_version", .int(ctx, .i32, @intFromEnum(opts.api_version)) }, .{ "call_target_name", .string(ctx, opts.call_target_name) }, .{ "has_side_effect", .boolean(ctx, opts.has_side_effect) }, .{ "backend_config", backend_config }, - .{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }, }); + { + var output_operand_aliases: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (opts.output_operand_aliases) |alias| { + output_operand_aliases.appendAssumeCapacity( + OutputOperandAliasAttribute.init(ctx, &.{}, alias, &.{}).as(mlir.Attribute), + ); + } + attrs.appendAssumeCapacity(.{ "output_operand_aliases", .array(ctx, output_operand_aliases.constSlice()) }); + } + if (opts.operand_layouts) |layouts| { - const operand_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; - for (layouts) |ol| { - ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); - } - break :blk ret; - }; + var operand_layouts: std.BoundedArray(mlir.Attribute, MAX_OPERANDS) = .{}; + for (layouts) |ol| { + operand_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(ol.len)}, .index, ol)); + } 
attrs.appendAssumeCapacity(.{ "operand_layouts", .array(ctx, operand_layouts.constSlice()) }); } if (opts.result_layouts) |layouts| { - const result_layouts = blk: { - var ret: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; - for (layouts) |rl| { - ret.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); - } - break :blk ret; - }; + var result_layouts: std.BoundedArray(mlir.Attribute, MAX_RESULTS) = .{}; + for (layouts) |rl| { + result_layouts.appendAssumeCapacity(.denseElements(ctx, &.{@intCast(rl.len)}, .index, rl)); + } attrs.appendAssumeCapacity(.{ "result_layouts", .array(ctx, result_layouts.constSlice()) }); } - attrs.appendSliceAssumeCapacity(opts.addional_attributes); + attrs.appendSlice(opts.additional_attributes) catch @panic("Too many additional_attributes"); return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ .operands = inputs, @@ -829,22 +812,6 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa }); } -// todo: move out of stablehlo.zig when we start to implement the frontend -pub fn annotate_device_placement(ctx: mlir.Context, inputs: []const mlir.Value, memory_kind: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { - const frontend_attributes = mlir.DictionaryAttribute.init( - ctx, - &.{.named(ctx, "_xla_buffer_placement", memory_kind.asAttr())}, - ).asAttr(); - - return custom_call(ctx, inputs, .{ - .call_target_name = "annotate_device_placement", - .has_side_effect = true, - .backend_config = .{ .string = &.{} }, - .addional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }}, - .api_version = .original, - }, res_types, location); -} - pub const DotDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 4e2336a..a4f378c 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -414,6 +414,18 @@ pub const Attribute = struct { pub fn named(attr: Attribute, ctx: Context, 
name: [:0]const u8) NamedAttribute { return NamedAttribute.init(Identifier.get(ctx, name), attr); } + + pub fn dict(ctx: Context, named_attrs: []const AttrTuple) Attribute { + var attr_buf: [32]NamedAttribute = undefined; + stdx.debug.assert(named_attrs.len <= attr_buf.len, ".dict helper only support up to {} attribute, got {}. You will need to call mlir.DictionaryAttribute directly", .{ attr_buf.len, named_attrs.len }); + + const attrs = attr_buf[0..named_attrs.len]; + for (attrs, named_attrs) |*attr, tuple| { + attr.* = .named(ctx, tuple[0], tuple[1]); + } + + return DictionaryAttribute.init(ctx, attrs).asAttr(); + } }; pub const NamedAttribute = extern struct { diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 08d803e..1b424ed 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -1,3 +1,4 @@ +load("@rules_cc//cc:defs.bzl", "cc_library") load("@rules_zig//zig:defs.bzl", "zig_library") load("@zml//bazel:zig.bzl", "zig_cc_binary") load("//bazel:zig_proto_library.bzl", "zig_proto_library") @@ -12,7 +13,12 @@ cc_library( zig_library( name = "pjrt", - srcs = ["profiler.zig"] + glob(["convert/*.zig"]), + srcs = [ + "ffi.zig", + "profiler.zig", + "convert/trace_container.zig", + "convert/xplane_schema.zig" + ], main = "pjrt.zig", visibility = ["//visibility:public"], deps = [ @@ -20,9 +26,12 @@ zig_library( ":trace_events_proto", ":xplane_proto", "//stdx", + "@xla//xla/ffi/api:c_api", + "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", + "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", ] + select({ "@platforms//os:linux": [":dlfcn"], "//conditions:default": [], @@ -49,7 +58,7 @@ zig_proto_library( zig_cc_binary( name = "xspace_to_json", - srcs = glob(["convert/*.zig"]), + srcs = ["convert/trace_container.zig", "convert/xplane_schema.zig"], main = "xspace_to_json.zig", visibility = ["//visibility:public"], deps = [ diff 
--git a/pjrt/convert/trace_container.zig b/pjrt/convert/trace_container.zig index 7156a1b..6346547 100644 --- a/pjrt/convert/trace_container.zig +++ b/pjrt/convert/trace_container.zig @@ -1,8 +1,9 @@ const std = @import("std"); + const trace_events_proto = @import("//tsl:trace_events_proto"); const xplane_proto = @import("//tsl:xplane_proto"); + const xplane_schema = @import("xplane_schema.zig"); -const xplane_visitor = @import("xplane_visitor.zig"); // Constants used as trace_viewer PID (device_id in trace_events.proto). // PID 0 is unused. @@ -87,7 +88,7 @@ pub const TraceContainer = struct { } } - fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const xplane_visitor.XPlaneVisitor) !void { + fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const XPlaneHashed) !void { // Convert devices and resources. const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id }); var device = device_entry.value_ptr.*; @@ -156,7 +157,7 @@ pub const TraceContainer = struct { fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void { if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| { - const xplane = try xplane_visitor.XPlaneVisitor.init(allocator, hp); + const xplane = try XPlaneHashed.init(allocator, hp); try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane); } @@ -173,7 +174,7 @@ pub const TraceContainer = struct { } for (device_planes.items) |dp| { - var xplane = try xplane_visitor.XPlaneVisitor.init(allocator, dp); + var xplane = try XPlaneHashed.init(allocator, dp); defer xplane.deinit(allocator); const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id)); try self.xplaneToTraceEvents(allocator, device_id, &xplane); @@ -298,3 +299,68 @@ pub const TraceContainer = struct { fn picoToMicro(p: anytype) 
f64 { return @as(f64, @floatFromInt(p)) / 1E6; } + +pub const XPlaneHashed = struct { + plane: *const xplane_proto.XPlane, + event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{}, + stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{}, + + pub fn init( + allocator: std.mem.Allocator, + plane: *const xplane_proto.XPlane, + ) !XPlaneHashed { + var res: XPlaneHashed = .{ .plane = plane }; + + try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len)); + // build event metadata map + for (plane.event_metadata.items) |*event_metadata| { + res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?); + } + + // build stat metadata map + try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len)); + for (plane.stat_metadata.items) |*stat_metadata| { + res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?); + } + + return res; + } + + pub fn deinit(self: *XPlaneHashed, allocator: std.mem.Allocator) void { + self.stat_metadata_by_id.deinit(allocator); + self.event_metadata_by_id.deinit(allocator); + } + + pub fn name(self: XPlaneHashed) []const u8 { + return self.plane.name.getSlice(); + } + + pub fn getEventType(self: XPlaneHashed, event_metadata_id: i64) xplane_schema.HostEventType { + if (self.event_metadata_by_id.get(event_metadata_id)) |v| { + return xplane_schema.HostEventType.fromString(v.name.getSlice()); + } else return .unknown; + } + + pub fn getStatMetadataName(self: XPlaneHashed, stat_metadata_id: i64) []const u8 { + if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { + return v.name.getSlice(); + } else return &[_]u8{}; + } + + pub fn getStatType(self: XPlaneHashed, stat_metadata_id: i64) xplane_schema.StatType { + if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { + return xplane_schema.StatType.fromString(v.name.getSlice()); + } 
else return .unknown; + } + + pub fn xstatToString(self: XPlaneHashed, stat: xplane_proto.XStat, writer: anytype) !void { + if (stat.value == null) return; + + switch (stat.value.?) { + inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}), + .str_value => |*v| try writer.writeAll(v.getSlice()), + .bytes_value => try writer.writeAll(""), + .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))), + } + } +}; diff --git a/pjrt/convert/xplane_visitor.zig b/pjrt/convert/xplane_visitor.zig deleted file mode 100644 index a2f8eba..0000000 --- a/pjrt/convert/xplane_visitor.zig +++ /dev/null @@ -1,68 +0,0 @@ -const std = @import("std"); -const xplane_proto = @import("//tsl:xplane_proto"); -const xplane_schema = @import("xplane_schema.zig"); - -pub const XPlaneVisitor = struct { - plane: *const xplane_proto.XPlane, - event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{}, - stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{}, - - pub fn init( - allocator: std.mem.Allocator, - plane: *const xplane_proto.XPlane, - ) !XPlaneVisitor { - var res: XPlaneVisitor = .{ .plane = plane }; - - try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len)); - // build event metadata map - for (plane.event_metadata.items) |*event_metadata| { - res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?); - } - - // build stat metadata map - try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len)); - for (plane.stat_metadata.items) |*stat_metadata| { - res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?); - } - - return res; - } - - pub fn deinit(self: *XPlaneVisitor, allocator: std.mem.Allocator) void { - self.stat_metadata_by_id.deinit(allocator); - self.event_metadata_by_id.deinit(allocator); - } - - pub fn 
name(self: XPlaneVisitor) []const u8 { - return self.plane.name.getSlice(); - } - - pub fn getEventType(self: XPlaneVisitor, event_metadata_id: i64) xplane_schema.HostEventType { - if (self.event_metadata_by_id.get(event_metadata_id)) |v| { - return xplane_schema.HostEventType.fromString(v.name.getSlice()); - } else return .unknown; - } - - pub fn getStatMetadataName(self: XPlaneVisitor, stat_metadata_id: i64) []const u8 { - if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { - return v.name.getSlice(); - } else return &[_]u8{}; - } - - pub fn getStatType(self: XPlaneVisitor, stat_metadata_id: i64) xplane_schema.StatType { - if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { - return xplane_schema.StatType.fromString(v.name.getSlice()); - } else return .unknown; - } - - pub fn xstatToString(self: XPlaneVisitor, stat: xplane_proto.XStat, writer: anytype) !void { - if (stat.value == null) return; - - switch (stat.value.?) { - inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}), - .str_value => |*v| try writer.writeAll(v.getSlice()), - .bytes_value => try writer.writeAll(""), - .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))), - } - } -}; diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig new file mode 100644 index 0000000..39e7c24 --- /dev/null +++ b/pjrt/ffi.zig @@ -0,0 +1,517 @@ +/// Bindings for PJRT custom call declaration / execution. 
+const std = @import("std"); + +const c = @import("c"); +const stdx = @import("stdx"); + +const pjrtStruct = @import("pjrt.zig").pjrtStruct; + +const log = std.log.scoped(.pjrt); + +pub const ApiVersion = extern struct { + pub const major = c.XLA_FFI_API_MAJOR; + pub const minor = c.XLA_FFI_API_MINOR; + + struct_size: usize, + extension_start: ?*ExtensionBase, + major_version: i32, + minor_version: i32, +}; + +pub const ExtensionType = enum(c.XLA_FFI_Extension_Type) { + metadata = c.XLA_FFI_Extension_Metadata, +}; + +pub const ExtensionBase = extern struct { + struct_size: usize, + type: ExtensionType, + next: ?*ExtensionBase, +}; + +// Based of https://github.com/openxla/xla/blob/145f836bd5175dc5dd262f716a0c59af2b0297a0/xla/ffi/api/c_api.h#L449 +pub const HandlerTraits = packed struct(u32) { + /// Calls to FFI handler are safe to trace into the command buffer. + /// It means that calls to FFI handler always launch exactly the same device operations (can depend on attribute values) + /// that can be captured and then replayed. 
+ command_buffer_compatible: u1, + + __unassigned__: u31, +}; + +pub const Metadata = extern struct { + struct_size: usize, + api_version: ApiVersion, + traits: HandlerTraits, +}; + +pub const MetadataExtension = extern struct { + extension_base: ExtensionBase, + metadata: ?*Metadata, +}; + +pub const ApiError = error{ + Cancelled, + Unknown, + InvalidArgument, + DeadlineExceeded, + NotFound, + AlreadyExists, + PermissionDenied, + ResourceExhausted, + FailedPrecondition, + Aborted, + OutOfRange, + Unimplemented, + Internal, + Unavailable, + DataLoss, + Unauthenticated, +}; + +fn TransmuteMixin(comptime T: type, comptime InnerT: type) type { + return struct { + pub fn to(self: anytype) switch (@TypeOf(self)) { + *T => *InnerT, + *const T => *const InnerT, + else => unreachable, + } { + return @ptrCast(@alignCast(self)); + } + + pub fn from(self: anytype) switch (@TypeOf(self)) { + *InnerT => *T, + *const InnerT => *const T, + else => unreachable, + } { + return @ptrCast(@alignCast(self)); + } + }; +} + +pub const Api = opaque { + pub const inner = TransmuteMixin(Api, c.XLA_FFI_Api).to; + + pub fn getStream(self: *const Api, context: ?*ExecutionContext) ApiError!*anyopaque { + var ret = pjrtStruct(c.XLA_FFI_Stream_Get_Args{ + .ctx = if (context) |ctx| ctx.inner() else null, + }); + const result = self.inner().XLA_FFI_Stream_Get.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(self); + log.err("[Api.getStream] {s}", .{err.getMessage(self)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. 
+ return error.Unknown; + } + + return ret.stream.?; + } + + pub fn allocateDeviceMemory(self: *const Api, context: ?*ExecutionContext, size: usize, alignment: usize) ApiError!*anyopaque { + var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Allocate_Args{ + .ctx = if (context) |ctx| ctx.inner() else null, + .size = size, + .alignment = alignment, + }); + const result = self.inner().XLA_FFI_DeviceMemory_Allocate.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(self); + log.err("[Api.allocateDeviceMemory] {s}", .{err.getMessage(self)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + + return ret.data.?; + } + + pub fn freeDeviceMemory(self: *const Api, context: ?*ExecutionContext, data: *anyopaque, size: usize) ApiError!void { + var ret = pjrtStruct(c.XLA_FFI_DeviceMemory_Free_Args{ + .ctx = if (context) |ctx| ctx.inner() else null, + .size = size, + .data = data, + }); + const result = self.inner().XLA_FFI_DeviceMemory_Free.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(self); + log.err("[Api.freeDeviceMemory] {s}", .{err.getMessage(self)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. 
+ return error.Unknown; + } + } + + // TODO(Corentin): Implement remaining methods if needed: + // * XLA_FFI_ThreadPool_Schedule + // * XLA_FFI_Handler_Register + // * XLA_FFI_TypeId_Register + // * XLA_FFI_State_Set + // * XLA_FFI_State_Get +}; + +pub const ExecutionStage = enum(c.XLA_FFI_ExecutionStage) { + instantiate = c.XLA_FFI_ExecutionStage_INSTANTIATE, + prepare = c.XLA_FFI_ExecutionStage_PREPARE, + initialize = c.XLA_FFI_ExecutionStage_INITIALIZE, + execute = c.XLA_FFI_ExecutionStage_EXECUTE, +}; + +pub const ExecutionContext = opaque { + pub const inner = TransmuteMixin(ExecutionContext, c.XLA_FFI_ExecutionContext).to; + + // pub fn attach(self: *ExecutionContext, api: *const Api, value: anytype) ApiError!void { + // // register type id ==> typeid + // const typename_ = "zml." ++ @typeName(@TypeOf(value)); + + // var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{ + // .ctx = self.inner(), + // .handler = @ptrCast(@alignCast(handler)), + // }); + // const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret); + + // var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Register_Args{ + // .ctx = self.inner(), + // .handler = @ptrCast(@alignCast(handler)), + // }); + // const result = api.inner().XLA_FFI_ExecutionContext_Register.?(&ret); + + // if (result) |ffi_error| { + // const err = Error.fromInner(ffi_error); + // defer err.destroy(api); + // log.err("[ExecutionContext.register] {s}", .{err.getMessage(api)}); + + // // TODO(Corentin): Retrieve error code from Error when implemented in XLA. 
+ // return error.Unknown; + // } + // } + + pub fn get(self: *ExecutionContext, api: *const Api, type_id: *TypeId) ApiError!*anyopaque { + var ret = pjrtStruct(c.XLA_FFI_ExecutionContext_Get_Args{ + .ctx = self.inner(), + .type_id = @ptrCast(@alignCast(type_id)), + }); + const result = api.inner().XLA_FFI_ExecutionContext_Get.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + log.err("[ExecutionContext.get] {s}", .{err.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + + return ret.data.?; + } + + // TODO getDeviceOrdinal() +}; + +const ByteSpan = extern struct { + ptr: [*]const u8, + len: usize, + + pub fn slice(self: ByteSpan) []const u8 { + return self.ptr[0..self.len]; + } +}; + +pub const TypeId = extern struct { + type_id: i64, +}; + +pub const DataType = enum(c.XLA_FFI_DataType) { + invalid = c.XLA_FFI_DataType_INVALID, + pred = c.XLA_FFI_DataType_PRED, + s8 = c.XLA_FFI_DataType_S8, + s16 = c.XLA_FFI_DataType_S16, + s32 = c.XLA_FFI_DataType_S32, + s64 = c.XLA_FFI_DataType_S64, + u8 = c.XLA_FFI_DataType_U8, + u16 = c.XLA_FFI_DataType_U16, + u32 = c.XLA_FFI_DataType_U32, + u64 = c.XLA_FFI_DataType_U64, + f16 = c.XLA_FFI_DataType_F16, + f32 = c.XLA_FFI_DataType_F32, + f64 = c.XLA_FFI_DataType_F64, + bf16 = c.XLA_FFI_DataType_BF16, + c64 = c.XLA_FFI_DataType_C64, + c128 = c.XLA_FFI_DataType_C128, + token = c.XLA_FFI_DataType_TOKEN, + f8e5m2 = c.XLA_FFI_DataType_F8E5M2, + f8e3m4 = c.XLA_FFI_DataType_F8E3M4, + f8e4m3 = c.XLA_FFI_DataType_F8E4M3, + f8e4m3fn = c.XLA_FFI_DataType_F8E4M3FN, + f8e4m3b11fnuz = c.XLA_FFI_DataType_F8E4M3B11FNUZ, + f8e5m2fnuz = c.XLA_FFI_DataType_F8E5M2FNUZ, + f8e4m3fnuz = c.XLA_FFI_DataType_F8E4M3FNUZ, +}; + +pub const Buffer = extern struct { + struct_size: usize, + extension_start: ?*c.XLA_FFI_Extension_Base, + dtype: DataType, + data: [*]u8, + rank: u64, + _dims: [*]const i64, + + pub fn dims(self: 
Buffer) []const i64 { + return self._dims[0..self.rank]; + } + + pub fn format( + buffer: Buffer, + comptime fmt: []const u8, + options: std.fmt.FormatOptions, + writer: anytype, + ) !void { + _ = fmt; + _ = options; + + try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), @tagName(buffer.dtype), @intFromPtr(buffer.data) }); + } +}; + +pub const Args = extern struct { + struct_size: usize, + extension_start: ?*const c.XLA_FFI_Extension_Base, + len: u64, + types: [*]const Type, + ptr: [*]*const Buffer, + + pub const Type = enum(c.XLA_FFI_ArgType) { + buffer = c.XLA_FFI_ArgType_BUFFER, + }; + + pub fn get(self: Args, i: usize) *const Buffer { + std.debug.assert(self.types[0..self.len][i] == .buffer); + return self.ptr[0..self.len][i]; + } +}; + +pub const Rets = extern struct { + struct_size: usize, + extension_start: ?*const c.XLA_FFI_Extension_Base, + len: u64, + types: [*]const Type, + ptr: [*]*const Buffer, + + pub const Type = enum(c.XLA_FFI_RetType) { + buffer = c.XLA_FFI_RetType_BUFFER, + }; + + pub fn get(self: Rets, i: usize) *const Buffer { + std.debug.assert(self.types[0..self.len][i] == .buffer); + return self.ptr[0..self.len][i]; + } +}; + +pub const AttrType = enum(c.XLA_FFI_AttrType) { + array = c.XLA_FFI_AttrType_ARRAY, + dictionary = c.XLA_FFI_AttrType_DICTIONARY, + scalar = c.XLA_FFI_AttrType_SCALAR, + string = c.XLA_FFI_AttrType_STRING, +}; + +pub const Attrs = extern struct { + struct_size: usize, + extension_start: ?*ExtensionBase, + len: u64, + types: [*]const AttrType, + names: [*]const *const ByteSpan, + ptr: [*]const *const Attr, + + const Attr = extern union { + scalar: Scalar, + array: Array, + }; + + pub const Scalar = extern struct { + dtype: DataType, + value: *const anyopaque, + + pub fn get(self: Scalar, T: type) T { + const ptr: *const T = @alignCast(@ptrCast(self.value)); + return ptr.*; + } + }; + + pub const Array = extern struct { + dtype: DataType, + len: usize, + data: [*]const u8, + }; + + pub fn getByIndex(self: 
Attrs, comptime attr_type: AttrType, index: usize) ?*const @FieldType(Attr, @tagName(attr_type)) { + const attr = self.ptr[0..self.len][index]; + const actual_type = self.types[index]; + if (actual_type != attr_type) return null; + return @ptrCast(attr); + } + + pub fn getByName(self: Attrs, comptime attr_type: AttrType, name: []const u8) ?*const @FieldType(Attr, @tagName(attr_type)) { + const names = self.names[0..self.len]; + for (0.., names) |i, attr_name| { + if (std.mem.eql(u8, attr_name.slice(), name)) { + return self.getByIndex(attr_type, i); + } + } + + return null; + } +}; + +pub const CallFrame = extern struct { + struct_size: usize, + extension_start: ?*ExtensionBase, + api: ?*const Api, + ctx: ?*const ExecutionContext, + stage: ExecutionStage, + args: Args, + results: Rets, + attrs: Attrs, + future: ?*Future, + + /// The registry mechanism will first call the custom call in registration mode, + /// and expects us to indicate which version of XLA we have been compiled against. + /// Returns true if we registered ourselves and if the caller custom call should return early.
+ pub fn registeringHook(call_frame: *CallFrame) bool { + if (call_frame.extension_start != null and call_frame.extension_start.?.type == .metadata) { + const metadata_extension: *MetadataExtension = @fieldParentPtr("extension_base", call_frame.extension_start.?); + metadata_extension.metadata.?.api_version.major_version = ApiVersion.major; + metadata_extension.metadata.?.api_version.minor_version = ApiVersion.minor; + return true; + } + return false; + } +}; + +pub const Handler = fn (*CallFrame) callconv(.C) ?*Error; + +pub const ErrorCode = enum(c.XLA_FFI_Error_Code) { + cancelled = c.XLA_FFI_Error_Code_CANCELLED, + unknown = c.XLA_FFI_Error_Code_UNKNOWN, + invalid_argument = c.XLA_FFI_Error_Code_INVALID_ARGUMENT, + deadline_exceeded = c.XLA_FFI_Error_Code_DEADLINE_EXCEEDED, + not_found = c.XLA_FFI_Error_Code_NOT_FOUND, + already_exists = c.XLA_FFI_Error_Code_ALREADY_EXISTS, + permission_denied = c.XLA_FFI_Error_Code_PERMISSION_DENIED, + resource_exhausted = c.XLA_FFI_Error_Code_RESOURCE_EXHAUSTED, + failed_precondition = c.XLA_FFI_Error_Code_FAILED_PRECONDITION, + aborted = c.XLA_FFI_Error_Code_ABORTED, + out_of_range = c.XLA_FFI_Error_Code_OUT_OF_RANGE, + unimplemented = c.XLA_FFI_Error_Code_UNIMPLEMENTED, + internal = c.XLA_FFI_Error_Code_INTERNAL, + unavailable = c.XLA_FFI_Error_Code_UNAVAILABLE, + data_loss = c.XLA_FFI_Error_Code_DATA_LOSS, + unauthenticated = c.XLA_FFI_Error_Code_UNAUTHENTICATED, + + pub fn toApiError(code: ErrorCode) ApiError { + return switch (code) { + .cancelled => error.Cancelled, + .unknown => error.Unknown, + .invalid_argument => error.InvalidArgument, + .deadline_exceeded => error.DeadlineExceeded, + .not_found => error.NotFound, + .already_exists => error.AlreadyExists, + .permission_denied => error.PermissionDenied, + .resource_exhausted => error.ResourceExhausted, + .failed_precondition => error.FailedPrecondition, + .aborted => error.Aborted, + .out_of_range => error.OutOfRange, + .unimplemented => error.Unimplemented, +
.internal => error.Internal, + .unavailable => error.Unavailable, + .data_loss => error.DataLoss, + .unauthenticated => error.Unauthenticated, + }; + } +}; + +pub const Error = opaque { + pub const inner = TransmuteMixin(Error, c.XLA_FFI_Error).to; + pub const fromInner = TransmuteMixin(Error, c.XLA_FFI_Error).from; + + pub fn create(api: *const Api, error_code: ErrorCode, message: [:0]const u8) *Error { + var ret = pjrtStruct(c.XLA_FFI_Error_Create_Args{ + .message = message.ptr, + .errc = @intFromEnum(error_code), + }); + return fromInner(api.inner().XLA_FFI_Error_Create.?(&ret).?); + } + + pub fn destroy(err: *Error, api: *const Api) void { + var ret = pjrtStruct(c.XLA_FFI_Error_Destroy_Args{ .@"error" = err.inner() }); + api.inner().XLA_FFI_Error_Destroy.?(&ret); + } + + pub fn getMessage(err: *Error, api: *const Api) [:0]const u8 { + var ret = pjrtStruct(c.XLA_FFI_Error_GetMessage_Args{ + .@"error" = err.inner(), + }); + api.inner().XLA_FFI_Error_GetMessage.?(&ret); + return std.mem.span(ret.message); + } +}; + +pub const Future = opaque { + pub const inner = TransmuteMixin(Future, c.XLA_FFI_Future).to; + pub const fromInner = TransmuteMixin(Future, c.XLA_FFI_Future).from; + + pub fn create(api: *const Api) ApiError!*Future { + var ret = pjrtStruct(c.XLA_FFI_Future_Create_Args{}); + const result = api.inner().XLA_FFI_Future_Create.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + log.err("[Future.create] {s}", .{err.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. 
+ return error.Unknown; + } + + return fromInner(ret.future.?); + } + + pub fn setAvailable(self: *Future, api: *const Api) ApiError!void { + var ret = pjrtStruct(c.XLA_FFI_Future_SetAvailable_Args{ + .future = self.inner(), + }); + + const result = api.inner().XLA_FFI_Future_SetAvailable.?(&ret); + + if (result) |ffi_error| { + const err = Error.fromInner(ffi_error); + defer err.destroy(api); + log.err("[Future.setAvailable] {s}", .{err.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + } + + pub fn setError(self: *Future, api: *const Api, err: *Error) ApiError!void { + var ret = pjrtStruct(c.XLA_FFI_Future_SetError_Args{ + .future = self.inner(), + .@"error" = err.inner(), + }); + + const result = api.inner().XLA_FFI_Future_SetError.?(&ret); + + if (result) |ffi_error| { + const err2 = Error.fromInner(ffi_error); + defer err2.destroy(api); + log.err("[Future.setError] {s}", .{err2.getMessage(api)}); + + // TODO(Corentin): Retrieve error code from Error when implemented in XLA. + return error.Unknown; + } + } +}; diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 500db40..9dbd5aa 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -1,13 +1,14 @@ -const builtin = @import("builtin"); const std = @import("std"); -const stdx = @import("stdx"); +const builtin = @import("builtin"); const c = @import("c"); +const stdx = @import("stdx"); + +pub const ffi = @import("ffi.zig"); +pub const Profiler = @import("profiler.zig").Profiler; const log = std.log.scoped(.pjrt); -pub const Profiler = @import("profiler.zig").Profiler; - test { std.testing.refAllDecls(@This()); } @@ -160,10 +161,9 @@ pub const Api = struct { } pub fn customCallRegistry(api: *const Api) ?CustomCallRegistry { - if (api.lookupExtension(c.PJRT_Gpu_Custom_Call, c.PJRT_Extension_Type_Gpu_Custom_Call)) |ext| { - return .{ .inner = ext.custom_call.? 
}; + if (api.lookupExtension(c.PJRT_FFI_Extension, c.PJRT_Extension_Type_FFI)) |ext| { + return .{ .inner = ext.register_handler.? }; } - // log.warn("No Custom Call registry found for platform: {}", .{self}); return null; } @@ -405,7 +405,7 @@ pub const Client = opaque { } pub const CreateViewOfDeviceBufferArgs = struct { - data: []const u8, + data: *anyopaque, dims: []const i64, element_type: BufferType, layout: MemoryLayout, @@ -421,7 +421,7 @@ pub const Client = opaque { const layout = args.layout.toCStruct(); const ret = try api.call(.PJRT_Client_CreateViewOfDeviceBuffer, .{ .client = self.inner(), - .device_buffer_ptr = @ptrCast(@constCast(args.data.ptr)), + .device_buffer_ptr = @ptrCast(@constCast(args.data)), .dims = args.dims.ptr, .num_dims = args.dims.len, .element_type = @intFromEnum(args.element_type), @@ -919,18 +919,14 @@ pub const Memory = opaque { const inner = InnerMixin(c.PJRT_Memory).inner; pub fn id(self: *const Memory, api: *const Api) usize { - const ret = api.call(.PJRT_Memory_Id, .{ - .memory = self.inner(), - }) catch unreachable; + const ret = api.call(.PJRT_Memory_Id, .{ .memory = self.inner() }) catch unreachable; return @intCast(ret.id); } pub fn kind(self: *const Memory, api: *const Api) Kind { - const ret = api.call(.PJRT_Memory_Kind, .{ - .memory = self.inner(), - }) catch unreachable; - const kind_ = ret.kind orelse unreachable; - return std.meta.stringToEnum(Kind, kind_[0..ret.kind_size]) orelse unreachable; + const ret = api.call(.PJRT_Memory_Kind, .{ .memory = self.inner() }) catch unreachable; + const kind_ = (ret.kind orelse unreachable)[0..ret.kind_size]; + return std.meta.stringToEnum(Kind, kind_) orelse unreachable; } pub fn kindId(self: *const Memory, api: *const Api) u32 { @@ -1161,23 +1157,25 @@ pub const NamedValue = extern struct { } }; -/// Custom call signature arguments are: -/// * a pointer to a platform specific stream handle -/// * a pointer to an unspecified list of platform specific buffer handle -/// * a context
struct passed as a slice of bytes -pub const CustomCall = fn (*anyopaque, [*]*anyopaque, [*]const u8, usize) callconv(.C) void; - // todo : support all missing handlers available in GPU plugin extension: handler_instantiate, handler_prepare, handler_initialize // introduced by https://github.com/openxla/xla/commit/ef85a7bcc308313492ebc50295a8a08b4e51b8f5 pub const CustomCallRegistry = extern struct { - inner: *const c.PJRT_Gpu_Register_Custom_Call, + inner: *const c.PJRT_FFI_Register_Handler, - pub fn register(self: *const CustomCallRegistry, api: *const Api, api_version: usize, name: []const u8, func: *const CustomCall) ApiError!void { - var ret = pjrtStruct(c.PJRT_Gpu_Register_Custom_Call_Args{ - .function_name = name.ptr, - .function_name_size = name.len, - .api_version = @intCast(api_version), - .handler_execute = @ptrCast(@constCast(func)), + pub fn registerFfi( + self: *const CustomCallRegistry, + api: *const Api, + target_name: []const u8, + platform_name: []const u8, + func: *const ffi.Handler, + ) ApiError!void { + var ret = pjrtStruct(c.PJRT_FFI_Register_Handler_Args{ + .api_version = 1, + .target_name = target_name.ptr, + .target_name_size = target_name.len, + .handler = @ptrCast(@constCast(func)), + .platform_name = platform_name.ptr, + .platform_name_size = platform_name.len, }); const result = self.inner(&ret); if (result) |pjrt_c_error| { diff --git a/stdx/stdx.zig b/stdx/stdx.zig index 2cbc84a..96707c8 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -6,3 +6,9 @@ pub const math = @import("math.zig"); pub const meta = @import("meta.zig"); pub const queue = @import("queue.zig"); pub const time = @import("time.zig"); + +pub inline fn stackSlice(comptime max_len: usize, T: type, len: usize) []T { + debug.assert(len <= max_len, "stackSlice can only create a slice of up to {} elements, got: {}", .{ max_len, len }); + var storage: [max_len]T = undefined; + return storage[0..len]; +} diff --git a/third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel 
b/third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel new file mode 100644 index 0000000..516cef1 --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/MODULE.bazel @@ -0,0 +1,37 @@ +module( + name = "xla", + version = "20250317.1-71c67e2", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl", "python_version_repo") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "llvm-raw", + "stablehlo", +) + +llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") +llvm_configure(name = "llvm-project") diff --git a/third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel b/third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel new file mode 100644 index 0000000..516cef1 --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/overlay/MODULE.bazel @@ -0,0 +1,37 @@ +module( + name = "xla", + version = "20250317.1-71c67e2", + compatibility_level = 1, +) + +bazel_dep(name = "platforms", version = "0.0.8") +bazel_dep(name = "bazel_skylib", 
version = "1.5.0") +bazel_dep(name = "rules_cc", version = "0.0.17") +bazel_dep(name = "rules_apple", version = "3.17.0", repo_name = "build_bazel_rules_apple") +bazel_dep(name = "abseil-cpp", version = "20240116.0", repo_name = "com_google_absl") +bazel_dep(name = "rules_python", version = "0.29.0") +bazel_dep(name = "rules_proto", version = "6.0.0-rc1") +bazel_dep(name = "rules_java", version = "7.3.2") +bazel_dep(name = "rules_pkg", version = "0.9.1") +bazel_dep(name = "zlib", version = "1.2.13") +bazel_dep(name = "re2", version = "2024-07-02.bcr.1", repo_name = "com_googlesource_code_re2") +bazel_dep(name = "rules_license", version = "0.0.8") + +tsl = use_extension("//:tsl.bzl", "tsl") +use_repo(tsl, "tsl", "python_version_repo") + +xla_workspace = use_extension("//:workspace.bzl", "xla_workspace") +use_repo( + xla_workspace, + "com_github_grpc_grpc", + "com_google_protobuf", + "local_config_cuda", + "local_config_remote_execution", + "local_config_rocm", + "local_config_tensorrt", + "llvm-raw", + "stablehlo", +) + +llvm_configure = use_repo_rule("@llvm-raw//utils/bazel:configure.bzl", "llvm_configure") +llvm_configure(name = "llvm-project") diff --git a/third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl b/third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl new file mode 100644 index 0000000..dc25c9d --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/overlay/tsl.bzl @@ -0,0 +1,19 @@ +load("//third_party:repo.bzl", "tf_vendored") +load("//third_party/py:python_init_repositories.bzl", "python_init_repositories") + +def _tsl_impl(mctx): + python_init_repositories( + requirements = { + "3.11": "//:requirements_lock_3_11.txt", + }, + ) + tf_vendored(name = "tsl", relpath = "third_party/tsl") + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = ["tsl"], + root_module_direct_dev_deps = [], + ) + +tsl = module_extension( + implementation = _tsl_impl, +) diff --git 
a/third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl b/third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl new file mode 100644 index 0000000..192e06f --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/overlay/workspace.bzl @@ -0,0 +1,60 @@ +load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") +load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") +load("//third_party/gpus:rocm_configure.bzl", "rocm_configure") +load("//third_party/llvm:workspace.bzl", llvm = "repo") +load("//third_party/pybind11_bazel:workspace.bzl", pybind11_bazel = "repo") +load("//third_party/stablehlo:workspace.bzl", stablehlo = "repo") +load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") +load("//third_party/triton:workspace.bzl", triton = "repo") +load("//tools/toolchains/remote:configure.bzl", "remote_execution_configure") + +def _xla_workspace_impl(mctx): + cuda_configure(name = "local_config_cuda") + remote_execution_configure(name = "local_config_remote_execution") + rocm_configure(name = "local_config_rocm") + tensorrt_configure(name = "local_config_tensorrt") + pybind11_bazel() + triton() + llvm("llvm-raw") + stablehlo() + tf_http_archive( + name = "com_github_grpc_grpc", + sha256 = "b956598d8cbe168b5ee717b5dafa56563eb5201a947856a6688bbeac9cac4e1f", + strip_prefix = "grpc-b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd", + system_build_file = "//third_party/systemlibs:grpc.BUILD", + patch_file = [ + "//third_party/grpc:generate_cc_env_fix.patch", + "//third_party/grpc:register_go_toolchain.patch", + ], + system_link_files = { + "//third_party/systemlibs:BUILD.bazel": "bazel/BUILD.bazel", + "//third_party/systemlibs:grpc.BUILD": "src/compiler/BUILD", + "//third_party/systemlibs:grpc.bazel.grpc_deps.bzl": "bazel/grpc_deps.bzl", + "//third_party/systemlibs:grpc.bazel.grpc_extra_deps.bzl": "bazel/grpc_extra_deps.bzl", + "//third_party/systemlibs:grpc.bazel.cc_grpc_library.bzl": "bazel/cc_grpc_library.bzl", + 
"//third_party/systemlibs:grpc.bazel.generate_cc.bzl": "bazel/generate_cc.bzl", + "//third_party/systemlibs:grpc.bazel.protobuf.bzl": "bazel/protobuf.bzl", + }, + urls = tf_mirror_urls("https://github.com/grpc/grpc/archive/b54a5b338637f92bfcf4b0bc05e0f57a5fd8fadd.tar.gz"), + ) + tf_http_archive( + name = "com_google_protobuf", + patch_file = ["//third_party/protobuf:protobuf.patch"], + sha256 = "f66073dee0bc159157b0bd7f502d7d1ee0bc76b3c1eac9836927511bdc4b3fc1", + strip_prefix = "protobuf-3.21.9", + system_build_file = "//third_party/systemlibs:protobuf.BUILD", + system_link_files = { + "//third_party/systemlibs:protobuf.bzl": "protobuf.bzl", + "//third_party/systemlibs:protobuf_deps.bzl": "protobuf_deps.bzl", + }, + urls = tf_mirror_urls("https://github.com/protocolbuffers/protobuf/archive/v3.21.9.zip"), + ) + return mctx.extension_metadata( + reproducible = True, + root_module_direct_deps = "all", + root_module_direct_dev_deps = [], + ) + +xla_workspace = module_extension( + implementation = _xla_workspace_impl, +) diff --git a/third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch b/third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch new file mode 100644 index 0000000..8924cf4 --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/patches/0001-bazel-migration-to-bazel-8.1.1.patch @@ -0,0 +1,41 @@ +From 6cf475b500521c1b8be06f590fdbc1818f0dc44b Mon Sep 17 00:00:00 2001 +From: Jean-Baptiste Dalido +Date: Mon, 6 Jan 2025 13:33:13 +0100 +Subject: [PATCH] bazel: migration to bazel 8.0.1 + +--- + .bazelversion | 2 +- + third_party/tsl/third_party/gpus/cuda_configure.bzl | 4 ++-- + 2 files changed, 3 insertions(+), 3 deletions(-) + +diff --git a/.bazelversion b/.bazelversion +index f22d756da3..fa5fce04b3 100644 +--- a/.bazelversion ++++ b/.bazelversion +@@ -1 +1 @@ +-7.4.1 ++8.1.1 +\ No newline at end of file +diff --git a/third_party/gpus/cuda_configure.bzl 
b/third_party/gpus/cuda_configure.bzl +index d62531152d..71d80a5a99 100644 +--- a/third_party/gpus/cuda_configure.bzl ++++ b/third_party/gpus/cuda_configure.bzl +@@ -33,14 +33,14 @@ NB: DEPRECATED! Use `hermetic/cuda_configure` rule instead. + load( + "@bazel_tools//tools/cpp:lib_cc_configure.bzl", + "escape_string", +- "get_env_var", + ) + load( + "@bazel_tools//tools/cpp:windows_cc_configure.bzl", +- "find_msvc_tool", + "find_vc_path", + "setup_vc_env_vars", + ) ++load("@rules_cc//cc/private/toolchain:windows_cc_configure.bzl", "find_msvc_tool") ++load("@rules_cc//cc/private/toolchain:lib_cc_configure.bzl", "get_env_var") + load("//third_party/clang_toolchain:download_clang.bzl", "download_clang") + load( + "//third_party/remote_config:common.bzl", +-- +2.39.3 (Apple Git-146) diff --git a/third_party/modules/xla/20250317.1-71c67e2/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch b/third_party/modules/xla/20250317.1-71c67e2/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch new file mode 100644 index 0000000..58e3874 --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/patches/0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch @@ -0,0 +1,131 @@ +From 367df40470c00b9a4f83e3c5bc5553e7b0878351 Mon Sep 17 00:00:00 2001 +From: Hugo Mano +Date: Wed, 5 Feb 2025 19:25:03 +0100 +Subject: [PATCH 1/8] Added FFI handler registration API to the FFI PjRt + +PR: https://github.com/openxla/xla/pull/13420 +--- + xla/pjrt/c/BUILD | 5 ++++ + xla/pjrt/c/pjrt_c_api_ffi_extension.h | 16 ++++++++++++ + xla/pjrt/c/pjrt_c_api_ffi_internal.cc | 35 +++++++++++++++++++++++++-- + 3 files changed, 54 insertions(+), 2 deletions(-) + +diff --git a/xla/pjrt/c/BUILD b/xla/pjrt/c/BUILD +index ad2ed95bce..0e7c35c30f 100644 +--- a/xla/pjrt/c/BUILD ++++ b/xla/pjrt/c/BUILD +@@ -69,7 +69,12 @@ cc_library( + ":pjrt_c_api_wrapper_impl", + "//xla/ffi:execution_context", + "//xla/ffi:type_id_registry", ++ "//xla/ffi:ffi_api", ++ 
"//xla/ffi/api:c_api", ++ "//xla/ffi/api:ffi", ++ "//xla/service:custom_call_target_registry", + "@com_google_absl//absl/status", ++ "@com_google_absl//absl/strings:str_format", + ], + ) + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_extension.h b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +index a33bd4aa9c..3309194538 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_extension.h ++++ b/xla/pjrt/c/pjrt_c_api_ffi_extension.h +@@ -66,12 +66,28 @@ PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_UserData_Add_Args, user_data); + // Adds a user data to the execute context. + typedef PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args); + ++struct PJRT_FFI_Register_Handler_Args { ++ size_t struct_size; ++ const char* target_name; ++ size_t target_name_size; ++ int api_version; // 0 for an untyped call, 1 -- for typed ++ void* handler; ++ const char* platform_name; ++ size_t platform_name_size; ++}; ++PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Register_Handler_Args, platform_name_size); ++ ++// Registers an FFI call handler for a specific platform. ++typedef PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args); ++ + typedef struct PJRT_FFI_Extension { + size_t struct_size; + PJRT_Extension_Type type; + PJRT_Extension_Base* next; + PJRT_FFI_TypeID_Register* type_id_register; + PJRT_FFI_UserData_Add* user_data_add; ++ PJRT_FFI_Register_Handler* register_handler; + } PJRT_FFI; + PJRT_DEFINE_STRUCT_TRAITS(PJRT_FFI_Extension, user_data_add); + +diff --git a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +index 0375b39d0b..3527a0756e 100644 +--- a/xla/pjrt/c/pjrt_c_api_ffi_internal.cc ++++ b/xla/pjrt/c/pjrt_c_api_ffi_internal.cc +@@ -13,15 +13,20 @@ See the License for the specific language governing permissions and + limitations under the License. 
+ ==============================================================================*/ + +-#include "xla/pjrt/c/pjrt_c_api_ffi_internal.h" ++#include <string> + + #include "absl/status/status.h" ++#include "absl/strings/str_format.h" ++#include "xla/ffi/api/c_api.h" ++#include "xla/ffi/api/ffi.h" + #include "xla/ffi/execution_context.h" +-#include "xla/ffi/type_id_registry.h" ++ #include "xla/ffi/type_id_registry.h" ++#include "xla/ffi/ffi_api.h" + #include "xla/pjrt/c/pjrt_c_api.h" + #include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" + #include "xla/pjrt/c/pjrt_c_api_helpers.h" + #include "xla/pjrt/c/pjrt_c_api_wrapper_impl.h" ++#include "xla/service/custom_call_target_registry.h" + + namespace pjrt { + +@@ -55,6 +60,31 @@ static PJRT_Error* PJRT_FFI_UserData_Add(PJRT_FFI_UserData_Add_Args* args) { + return nullptr; + } + ++static PJRT_Error* PJRT_FFI_Register_Handler( ++ PJRT_FFI_Register_Handler_Args* args) { ++ PJRT_RETURN_IF_ERROR(ActualStructSizeIsGreaterOrEqual( ++ "PJRT_FFI_Register_Handler_Args", ++ PJRT_FFI_Register_Handler_Args_STRUCT_SIZE, args->struct_size)); ++ std::string target_name(args->target_name, args->target_name_size); ++ std::string platform_name(args->platform_name, args->platform_name_size); ++ switch (args->api_version) { ++ case 0: ++ xla::CustomCallTargetRegistry::Global()->Register( ++ target_name, args->handler, platform_name); ++ return nullptr; ++ case 1: ++ xla::ffi::Ffi::RegisterStaticHandler( ++ xla::ffi::GetXlaFfiApi(), target_name, platform_name, ++ reinterpret_cast<XLA_FFI_Handler*>(args->handler)); ++ return nullptr; ++ default: ++ return new PJRT_Error{absl::UnimplementedError( ++ absl::StrFormat("API version %d not supported for PJRT GPU plugin. 
" ++ "Supported versions are 0 and 1.", ++ args->api_version))}; ++ } ++} ++ + PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + return { + /*struct_size=*/PJRT_FFI_Extension_STRUCT_SIZE, +@@ -62,6 +92,7 @@ PJRT_FFI_Extension CreateFfiExtension(PJRT_Extension_Base* next) { + /*next=*/next, + /*type_id_register=*/PJRT_FFI_TypeID_Register, + /*user_data_add=*/PJRT_FFI_UserData_Add, ++ /*register_handler=*/PJRT_FFI_Register_Handler, + }; + } + +-- +2.43.0 diff --git a/third_party/modules/xla/20250317.1-71c67e2/source.json b/third_party/modules/xla/20250317.1-71c67e2/source.json new file mode 100644 index 0000000..5577d24 --- /dev/null +++ b/third_party/modules/xla/20250317.1-71c67e2/source.json @@ -0,0 +1,15 @@ +{ + "strip_prefix": "xla-71c67e2a4f40267115a0d4ea7c36748bbe7e750e", + "url": "https://github.com/openxla/xla/archive/71c67e2a4f40267115a0d4ea7c36748bbe7e750e.tar.gz", + "integrity": "sha256-j6D1MC7+WsbZ+Ve3hPmDlCzZX1yV6RIP2BWDgQcbcYc=", + "overlay": { + "tsl.bzl": "", + "workspace.bzl": "", + "MODULE.bazel": "" + }, + "patch_strip": 1, + "patches": { + "0001-bazel-migration-to-bazel-8.1.1.patch": "", + "0002-Added-FFI-handler-registration-API-to-the-FFI-PjRt.patch": "" + } +} diff --git a/zml/buffer.zig b/zml/buffer.zig index 4249aab..e0829d3 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -1,16 +1,15 @@ -const asynk = @import("async"); const std = @import("std"); -const stdx = @import("stdx"); - -const meta = @import("meta.zig"); -const pjrt = @import("pjrtx.zig"); - const testing = std.testing; +const asynk = @import("async"); +const stdx = @import("stdx"); + const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const meta = @import("meta.zig"); +const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; @@ -27,10 +26,22 @@ const 
log = std.log.scoped(.zml); /// * loading weights from disk directly to the `device zml.aio.loadBuffers` /// * can be created by calling `HostBuffer.toDevice(platform)`. pub const Buffer = struct { - pub const Memory = enum(@typeInfo(pjrt.Memory.Kind).@"enum".tag_type) { - host = @intFromEnum(pjrt.Memory.Kind.unpinned_host), - host_pinned = @intFromEnum(pjrt.Memory.Kind.pinned_host), - device = @intFromEnum(pjrt.Memory.Kind.device), + pub const Memory = enum { + host, + host_pinned, + device, + + pub fn toPjrtMemory(self: Memory) pjrt.Memory.Kind { + return switch (self) { + .host => .unpinned_host, + .host_pinned => .pinned_host, + .device => .device, + }; + } + + pub fn pjrtName(self: Memory) []const u8 { + return @tagName(self.toPjrtMemory()); + } }; pub const Shard = struct { @@ -216,13 +227,13 @@ pub const Buffer = struct { /// and it might not work on all platforms, /// could lead to crashes and operations on the buffer will be slower. /// Tested on Cuda 12.4. - pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) !Buffer { + pub fn asViewOfHostBuffer(platform: Platform, buf: HostBuffer) Buffer { return asViewOfDeviceBuffer(platform, buf.shape(), null, @constCast(@ptrCast(buf.data.ptr))); } /// Creates a Buffer from a pointer into device memory. /// This allows to interface with other libraries producing buffers. 
- pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) !Buffer { + pub fn asViewOfDeviceBuffer(platform: Platform, shape_: Shape, stream: ?*const anyopaque, device_data: *anyopaque) Buffer { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { var res: [Shape.MAX_RANK]i64 = undefined; for (0..Shape.MAX_RANK) |i| { @@ -231,9 +242,8 @@ pub const Buffer = struct { break :blk res; }; - const device_bytes: [*]u8 = @ptrCast(device_data); - const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ - .data = device_bytes[0..shape_.byteSize()], + const pjrt_buffer = platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ + .data = device_data, .element_type = bufferTypeFromDtype(shape_.dtype()), .dims = shape_.dims(), // TODO: exposes sharding in the API. @@ -246,7 +256,7 @@ pub const Buffer = struct { }, }, .stream = @bitCast(@as(usize, @intFromPtr(stream))), - }); + }) catch @panic("failed to createViewOfDeviceBuffer"); var shards: Shards = .{}; shards.appendAssumeCapacity(pjrt_buffer); @@ -342,6 +352,11 @@ pub const Buffer = struct { try writer.print("Buffer({_})", .{self._shape}); } + pub fn getMemory(self: Buffer) *const pjrt.Memory { + const shard = self._shards.get(0); + return shard.memory(self._api); + } + fn hasShardedAxis(self: Buffer) bool { if (self._shards.len == 1) return false; return @reduce(.Or, self._shape._sharding_info); diff --git a/zml/context.zig b/zml/context.zig index 7e36b6b..03de62a 100644 --- a/zml/context.zig +++ b/zml/context.zig @@ -1,21 +1,23 @@ +const std = @import("std"); const builtin = @import("builtin"); + const c = @import("c"); const mlir = @import("mlir"); const runfiles = @import("runfiles"); const runtimes = @import("runtimes"); -const std = @import("std"); const stdx = @import("stdx"); -const zml_platform = @import("platform.zig"); -const pjrt = @import("pjrtx.zig"); - +const Buffer = @import("buffer.zig").Buffer; 
+const DataType = @import("dtype.zig").DataType; const HostBuffer = @import("hostbuffer.zig").HostBuffer; -const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); +const pjrt = @import("pjrtx.zig"); const Platform = @import("platform.zig").Platform; -const PlatformsMap = std.EnumArray(Target, ?Platform); +const Shape = @import("shape.zig").Shape; const Target = @import("platform.zig").Target; +const zml_platform = @import("platform.zig"); -const available_targets = @import("platform.zig").available_targets; +const PjrtApiMap = std.EnumArray(Target, ?*const pjrt.Api); +const PlatformsMap = std.EnumArray(Target, ?Platform); const log = std.log.scoped(.@"zml/context"); test { @@ -174,10 +176,8 @@ pub const Context = struct { log.err("No device found for platform {} !", .{target}); return error.NoDevicesFound; } - // TODO: should this be moved to platform.zig ? - if (target == .cuda) { - try cuda.registerZmlCustomCalls(p); - } + + try CustomCall.registerZmlCustomCalls(p); self.platforms.set(target, p); return p; @@ -213,77 +213,68 @@ pub const Context = struct { } } - pub const HostCallbackCtx = struct { - host: HostBuffer, - mutex: std.Thread.Mutex = std.Thread.Mutex{}, - }; - pub const HostCallback = fn (HostBuffer) void; + pub const HostCallback = fn (?*anyopaque, []const HostBuffer, []const HostBuffer) void; }; -const cuda = struct { - var runtime: Runtime = undefined; +const CustomCall = struct { + pub fn registerZmlCustomCalls(platform: Platform) !void { + const registry = platform.pjrt_api.customCallRegistry(); - pub fn registerZmlCustomCalls(cuda_platform: Platform) !void { - std.debug.assert(cuda_platform.target == .cuda); - - cuda.runtime = try Runtime.init(); - const registry = cuda_platform.pjrt_api.customCallRegistry().?; - try registry.register(cuda_platform.pjrt_api, 0, "zmlHostBufferCallback", &hostBufferCallback); - } - - pub const Stream = opaque {}; - pub const MemcpyKind = enum(c_int) { - host_to_host = 0, - host_to_device = 1, - device_to_host 
= 2, - device_to_device = 3, - default = 4, - }; - - pub const Runtime = struct { - memcpyAsync: MemcpyAsync, - streamSynchronize: StreamSynchronize, - - const MemcpyAsync = *const fn (dst: *anyopaque, src: *const anyopaque, count: usize, kind: MemcpyKind, stream: *Stream) callconv(.C) c_int; - const StreamSynchronize = *const fn (stream: *Stream) callconv(.C) c_int; - - pub fn init() !Runtime { - var cudart = try std.DynLib.open("libcudart.so.12"); - defer cudart.close(); - - return .{ - .memcpyAsync = cudart.lookup(Runtime.MemcpyAsync, "cudaMemcpyAsync") orelse return error.NotFound, - .streamSynchronize = cudart.lookup(Runtime.StreamSynchronize, "cudaStreamSynchronize") orelse return error.NotFound, - }; + if (registry) |reg| { + try reg.registerFfi(platform.pjrt_api, "zmlHostBufferCallback", @tagName(platform.target), &hostBufferCallback); + } else { + stdx.debug.panic("Registering custom calls failed", .{}); } - }; - - fn getContext(args: [*]const u8, args_len: usize) struct { *const Context.HostCallback, *Context.HostCallbackCtx } { - std.debug.assert(args_len == @sizeOf(*anyopaque) * 2); - - const raw_fn_ptr: usize = @bitCast(args[0..@sizeOf(*anyopaque)].*); - const fn_ptr: *const Context.HostCallback = @ptrFromInt(raw_fn_ptr); - - const raw_ctx_ptr: usize = @bitCast(args[@sizeOf(*anyopaque)..][0..@sizeOf(*anyopaque)].*); - const ctx_ptr: *Context.HostCallbackCtx = @ptrFromInt(raw_ctx_ptr); - return .{ fn_ptr, ctx_ptr }; } - fn hostBufferCallback(opaque_stream: *anyopaque, buffers: [*]*anyopaque, args: [*]const u8, args_len: usize) callconv(.C) void { - const stream: *Stream = @ptrCast(opaque_stream); - const src: *anyopaque = buffers[0]; - const callback, const ctx = getContext(args, args_len); + fn hostBufferCallback(call_frame: *pjrt.ffi.CallFrame) callconv(.C) ?*pjrt.ffi.Error { + if (call_frame.registeringHook()) return null; - // Add synchronization because this is called from the device driver. 
- ctx.mutex.lock(); - defer ctx.mutex.unlock(); + const callback_attr = call_frame.attrs.getByName(.scalar, "callback") orelse unreachable; + std.debug.assert(callback_attr.dtype == .u64); + const callback: *const Context.HostCallback = @ptrFromInt(callback_attr.get(usize)); - const host_dst: []u8 = @constCast(ctx.host.data); - const memcpy_result = cuda.runtime.memcpyAsync(host_dst.ptr, src, host_dst.len, .device_to_host, stream); - _ = memcpy_result; - const synchronize_result = cuda.runtime.streamSynchronize(stream); - _ = synchronize_result; + const user_ctx_ptr = call_frame.attrs.getByName(.scalar, "user_context") orelse unreachable; + std.debug.assert(user_ctx_ptr.dtype == .u64); + const user_ctx: ?*anyopaque = @ptrFromInt(user_ctx_ptr.get(usize)); - callback(ctx.host); + const input_buffers = stdx.stackSlice(8, HostBuffer, call_frame.args.len); + for (input_buffers, 0..) |*b, i| { + b.* = hostBufferFromPinnedBuffer(call_frame.args.get(i)); + } + + const output_buffers = stdx.stackSlice(8, HostBuffer, call_frame.results.len); + for (output_buffers, 0..) |*b, i| { + b.* = hostBufferFromPinnedBuffer(call_frame.results.get(i)); + } + + callback(user_ctx, input_buffers, output_buffers); + return null; } }; + +fn getShape(buffer_desc: *const pjrt.ffi.Buffer) Shape { + // log.warn("received buffer {}", .{buffer_desc}); + const dt: DataType = switch (buffer_desc.dtype) { + .invalid => @panic("invalid ffi"), + .pred => .bool, + .s8 => .i8, + .s16 => .i16, + .s32 => .i32, + .s64 => .i64, + .token, .f8e4m3, .f8e3m4 => @panic("Unsupported ffi type"), + inline else => |t| @field(DataType, @tagName(t)), + }; + return Shape.init(buffer_desc.dims(), dt); +} + +/// Create a HostBuffer from a ffi description of a buffer. +/// Normally the ffi describe device buffer but we assume they are located in pinned memory, +/// and therefore the data pointer is readable both from host and from device. 
+fn hostBufferFromPinnedBuffer(buffer_desc: *const pjrt.ffi.Buffer) HostBuffer { + const buffer_shape = getShape(buffer_desc); + return HostBuffer.fromBytes( + buffer_shape, + buffer_desc.data[0..buffer_shape.byteSize()], + ); +} diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 56bd9a8..1885a00 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -1,4 +1,5 @@ const std = @import("std"); + const stdx = @import("stdx"); const Buffer = @import("buffer.zig").Buffer; @@ -98,6 +99,13 @@ pub const HostBuffer = struct { }; } + /// Returns a HostBuffer tagged with the tags in 'tagz'. + pub fn withTags(self: HostBuffer, tagz: anytype) HostBuffer { + var res = self; + res._shape = self._shape.withTags(tagz); + return res; + } + pub const ArangeArgs = struct { start: i64 = 0, end: i64, @@ -240,6 +248,11 @@ pub const HostBuffer = struct { }; } + pub fn choose1d(self: HostBuffer, axis_: anytype, start: i64) HostBuffer { + const ax = self.axis(axis_); + return self.slice1d(ax, .{ .start = start, .end = start + 1 }).squeeze(ax); + } + pub fn squeeze(self: HostBuffer, axis_: anytype) HostBuffer { const ax = self._shape.axis(axis_); stdx.debug.assert(self.dim(ax) == 1, "squeeze expects a 1-d axis got {} in {}", .{ ax, self }); diff --git a/zml/meta.zig b/zml/meta.zig index f69a1e0..16680eb 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -1,11 +1,10 @@ const std = @import("std"); -const stdx = @import("stdx"); +const testing = std.testing; +const stdx = @import("stdx"); const FnParam = stdx.meta.FnParam; const asSlice = stdx.meta.asSlice; -const testing = std.testing; - test { std.testing.refAllDecls(@This()); } diff --git a/zml/module.zig b/zml/module.zig index beb7a01..011e6fe 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -367,6 +367,7 @@ pub const CompilationContext = struct { const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count); const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count); const fn_res_donations = try 
res_allocator.alloc(Tensor._Donation, out_tensor_count); + const fn_res_output_memory_kind = try res_allocator.alloc(Buffer.Memory, out_tensor_count); var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable; { defer self.closeBlock(fn_body); @@ -382,7 +383,7 @@ pub const CompilationContext = struct { }; var fn_res_values: [out_tensor_count]mlir.Value = undefined; - self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); + self.extractValuesAndTypes(fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations, fn_res_output_memory_kind); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); fn_body[0].appendOperationRecursive(fn_ret, fn_body[1]); @@ -396,6 +397,7 @@ pub const CompilationContext = struct { if (opts.kind == .main) { self.addDonationsAttributes(arg_attrs, fn_res_donations); + self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind); if (self._platform.sharding().num_partitions > 1) { self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); } @@ -433,6 +435,20 @@ pub const CompilationContext = struct { }; } + fn addOutputMemoryKindAttributes(self: CompilationContext, attributes: []AttributeList, output_memory_kind: []const Buffer.Memory) void { + const mlir_ctx = self.mlirCtx(); + for (attributes, output_memory_kind) |*attr, memory_kind| { + // .device is the default output, don't explicitly emit the attribute + if (memory_kind == .device) continue; + + attr.appendAssumeCapacity(.named( + mlir_ctx, + "mhlo.memory_kind", + .string(mlir_ctx, memory_kind.pjrtName()), + )); + } + } + /// Given a list of donations mapping output buffers to input buffers, /// generate donation attribute for each `n_args` input argument. 
fn addDonationsAttributes(self: CompilationContext, attributes: []AttributeList, donations: []const Tensor._Donation) void { @@ -712,7 +728,15 @@ pub const CompilationContext = struct { } /// Visit the given struct and extract the mlir.Value and mlir.Type associated with each tensor found. - pub fn extractValuesAndTypes(self: *const CompilationContext, v: anytype, values: []mlir.Value, types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation) void { + pub fn extractValuesAndTypes( + self: *const CompilationContext, + v: anytype, + values: []mlir.Value, + types: []mlir.Type, + shapes: []Shape, + donations: []Tensor._Donation, + output_memory_kind: []Buffer.Memory, + ) void { std.debug.assert(values.len == types.len); const LocalContext = struct { self: *const CompilationContext, @@ -721,8 +745,16 @@ pub const CompilationContext = struct { types: []mlir.Type, shapes: []Shape, donations: []Tensor._Donation, + output_memory_kind: []Buffer.Memory, + }; + var context = LocalContext{ + .self = self, + .values = values, + .types = types, + .shapes = shapes, + .donations = donations, + .output_memory_kind = output_memory_kind, }; - var context = LocalContext{ .self = self, .values = values, .types = types, .shapes = shapes, .donations = donations }; meta.visit((struct { fn cb(ctx: *LocalContext, tensor: *const Tensor) void { const value, const donation = ctx.self.getValueAndDonation(tensor.*); @@ -730,6 +762,7 @@ pub const CompilationContext = struct { ctx.types[ctx.index] = value.getType(); ctx.shapes[ctx.index] = tensor._shape; ctx.donations[ctx.index] = donation; + ctx.output_memory_kind[ctx.index] = tensor._output_memory_kind; ctx.index += 1; } }).cb, &context, v); diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 4d5f4f0..8652f9b 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -1,16 +1,16 @@ const std = @import("std"); -const Context = @import("../context.zig").Context; -const module = @import("../module.zig"); -const mlir = @import("../mlir.zig"); 
const dialect = @import("mlir/dialects"); -const Tensor = @import("../tensor.zig").Tensor; -const Shape = @import("../shape.zig").Shape; -const SdpaOpts = @import("../nn.zig").SdpaOpts; +const Context = @import("../context.zig").Context; const DataType = @import("../dtype.zig").DataType; const Data = @import("../dtype.zig").Data; +const mlir = @import("../mlir.zig"); +const module = @import("../module.zig"); const CompilationContext = module.CompilationContext; +const SdpaOpts = @import("../nn.zig").SdpaOpts; +const Shape = @import("../shape.zig").Shape; +const Tensor = @import("../tensor.zig").Tensor; pub fn canUseCudnnSdpa(q_shape: Shape) bool { const ctx = CompilationContext.current(); @@ -125,7 +125,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { &.{ q.value(), k.value(), v.value(), bias.value() }, .{ .call_target_name = "__cudnn$fmhaScaleBiasSoftmax", - .backend_config = .{ .string = backend_config }, + .backend_config = .string(mlir_ctx, backend_config), .has_side_effect = false, .api_version = .original, }, diff --git a/zml/ops.zig b/zml/ops.zig index 6ab3bad..1765888 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -1,28 +1,31 @@ const std = @import("std"); +const assert = std.debug.assert; + const stdx = @import("stdx"); +const _collectAxes = @import("tensor.zig")._collectAxes; const buffer = @import("buffer.zig"); -const helpers = @import("helpers.zig"); -const meta = @import("meta.zig"); -const mlir = @import("mlir.zig"); -const module = @import("module.zig"); - const Buffer = buffer.Buffer; -const CompilationContext = module.CompilationContext; +const Bufferized = @import("tensor.zig").Bufferized; const Context = @import("context.zig").Context; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; -const EnumLiteral = @TypeOf(.enum_literal); +const helpers = @import("helpers.zig"); const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const meta = @import("meta.zig"); +const mlir = 
@import("mlir.zig"); +const module = @import("module.zig"); +const CompilationContext = module.CompilationContext; +const Platform = @import("platform.zig").Platform; const Shape = @import("shape.zig").Shape; +const ShapeOf = @import("tensor.zig").ShapeOf; const Tensor = @import("tensor.zig").Tensor; -const _collectAxes = @import("tensor.zig")._collectAxes; +const EnumLiteral = @TypeOf(.enum_literal); const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; -const assert = std.debug.assert; const log = std.log.scoped(.@"zml/tensor"); test { @@ -766,50 +769,56 @@ pub fn fromMlirOperationWithTags(op: mlir.Operation, base: anytype) @TypeOf(base return res; } -/// At runtime the given tensor will be materialized and copied to host, -/// and the callback will be called on it. +pub const HostCallbackOpt = struct { + has_side_effect: bool = false, + output_operand_aliases: []const i64 = &.{}, +}; + pub fn addHostCallback( - callback: *const fn (HostBuffer) void, - input: Tensor, -) Tensor { - // TODO: implement addCallback that exposes a pjrt.Buffer, so that the user can decide if they need to copy. - if (input.getContext().target() != .cuda) return input; - - const len = input.byteSize(); - // Reserve memory to be able to log the runtime Buffer later during the computation. - // This memory is leaked, we currently have no way to tie this lifetime to the lifetime of the module being compiled. - const HostCallbackCtx = Context.HostCallbackCtx; - const full_data = std.heap.page_allocator.alignedAlloc(u8, 32, len + 2 * @sizeOf(HostCallbackCtx)) catch { - log.err("Failed to pre-allocate buffer to print {}.", .{input}); - return input; - }; - - // Save the HostBuffer inside the same memory slice, so that it's still present at runtime. - // Use an fba to have the stable buffer at an aligned offset. 
- var fba = std.heap.FixedBufferAllocator.init(full_data[len..]); - const stable_ctx_ptr = fba.allocator().create(HostCallbackCtx) catch unreachable; - stable_ctx_ptr.* = .{ - .host = HostBuffer.fromBytes(input.shape(), full_data[0..len]), - }; - - const backend_config: [2:null]?*const anyopaque = .{ callback, stable_ctx_ptr }; + callback: *const Context.HostCallback, + blkctx: ?*anyopaque, + inputs: []const Tensor, + output_shapes: []const Shape, + opts: HostCallbackOpt, +) []Tensor { const ctx = CompilationContext.current(); + const mlir_ctx = ctx.mlirCtx(); + const backend_config = mlir.Attribute.dict(mlir_ctx, &.{ + .{ "callback", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(callback))) }, + .{ "user_context", .int(mlir_ctx, .u64, @bitCast(@intFromPtr(blkctx))) }, + }); + + const values = stdx.stackSlice(8, mlir.Value, inputs.len); + for (inputs, values) |i, *v| { + v.* = ctx.getValue(i.toMemory(.host_pinned)); + } + const res_types = stdx.stackSlice(8, mlir.Type, output_shapes.len); + for (res_types, output_shapes) |*r, o| { + r.* = mlir.ext.RankedTensorType.fromShape(mlir_ctx, o).as(mlir.Type); + } + const loc = ctx.mlirCtx().location(@src()); const op = dialect.stablehlo.custom_call( ctx.mlirCtx(), - &.{input.value()}, + values, .{ - .has_side_effect = false, .call_target_name = "zmlHostBufferCallback", - .backend_config = .{ .string = @ptrCast(std.mem.sliceAsBytes(&backend_config)) }, - .output_operand_aliases = &.{0}, - .api_version = .original, + .api_version = .typed_ffi, + .backend_config = backend_config, + .has_side_effect = opts.has_side_effect, + .output_operand_aliases = opts.output_operand_aliases, }, - &.{input.value().getType()}, + res_types, loc, ); - return Tensor._result(input.shape(), op.result(0)); + + const res = ctx.allocator().alloc(Tensor, output_shapes.len) catch @panic("OOM"); + for (res, output_shapes, 0..) 
|*r, o, i| { + r.* = Tensor._result(o, op.result(i)).toMemory(.device); + } + + return res; } pub const TritonOps = struct { @@ -834,46 +843,32 @@ pub fn triton(inputs: anytype, outputs: anytype, opts: TritonOps) [outputs.len]T res_types[i] = mlir.ext.mlirType(ctx.mlirCtx(), output); } - const attrs = mlir.DictionaryAttribute.init(ctx.mlirCtx(), &.{ - .named(ctx.mlirCtx(), "name", .string(ctx.mlirCtx(), opts.name)), - .named(ctx.mlirCtx(), "ir", .string(ctx.mlirCtx(), opts.ir)), - .named(ctx.mlirCtx(), "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0])), - .named(ctx.mlirCtx(), "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1])), - .named(ctx.mlirCtx(), "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2])), - .named(ctx.mlirCtx(), "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages)), - .named(ctx.mlirCtx(), "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps)), + const backend_config = mlir.Attribute.dict(ctx.mlirCtx(), &.{ + .{ "name", .string(ctx.mlirCtx(), opts.name) }, + .{ "ir", .string(ctx.mlirCtx(), opts.ir) }, + .{ "grid_x", .int(ctx.mlirCtx(), .i32, opts.grid[0]) }, + .{ "grid_y", .int(ctx.mlirCtx(), .i32, opts.grid[1]) }, + .{ "grid_z", .int(ctx.mlirCtx(), .i32, opts.grid[2]) }, + .{ "num_stages", .int(ctx.mlirCtx(), .i32, opts.num_stages) }, + .{ "num_warps", .int(ctx.mlirCtx(), .i32, opts.num_warps) }, }); - const MINOR_TO_MAJOR = blk: { - var ret: [Shape.MAX_RANK]usize = undefined; - for (0..Shape.MAX_RANK) |i| { - ret[i] = @intCast(Shape.MAX_RANK - i - 1); - } - break :blk ret; - }; + var operands_layouts: [inputs.len][]const usize = undefined; + inline for (inputs, 0..) |input, i| { + operands_layouts[i] = minorToMajor(input.rank()); + } - const operands_layouts = blk: { - var ret: [inputs.len][]const usize = undefined; - inline for (inputs, 0..) 
|input, i| { - ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - input.rank() ..]; - } - break :blk ret; - }; - - const results_layouts = blk: { - var ret: [outputs.len][]const usize = undefined; - inline for (outputs, 0..) |output, i| { - ret[i] = MINOR_TO_MAJOR[MINOR_TO_MAJOR.len - output.rank() ..]; - } - break :blk ret; - }; + var results_layouts: [outputs.len][]const usize = undefined; + inline for (outputs, 0..) |output, i| { + results_layouts[i] = minorToMajor(output.rank()); + } const op = dialect.stablehlo.custom_call( ctx.mlirCtx(), &values, .{ .call_target_name = "__gpu$xla.gpu.triton", - .backend_config = .{ .dict = attrs }, + .backend_config = backend_config, .has_side_effect = false, .api_version = .typed_ffi, .operand_layouts = &operands_layouts, @@ -1256,3 +1251,15 @@ inline fn toI64(values: anytype) []i64 { for (values, 0..) |val, i| res[i] = @intCast(val); return res[0..values.len]; } + +const _MINOR_TO_MAJOR = blk: { + var ret: [Shape.MAX_RANK]usize = undefined; + for (0..Shape.MAX_RANK) |i| { + ret[i] = @intCast(Shape.MAX_RANK - i - 1); + } + break :blk ret; +}; + +fn minorToMajor(rank: u8) []const usize { + return _MINOR_TO_MAJOR[_MINOR_TO_MAJOR.len - rank ..]; +} diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 382a5d9..bd094ea 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -1,19 +1,10 @@ +const std = @import("std"); + const asynk = @import("async"); -const builtin = @import("builtin"); const dialects = @import("mlir/dialects"); const mlir = @import("mlir"); const pjrt = @import("pjrt"); -const std = @import("std"); -const stdx = @import("stdx"); -const c = @import("c"); - -const dtype = @import("dtype.zig"); -const meta = @import("meta.zig"); - -const Target = @import("platform.zig").Target; - -const log = std.log.scoped(.zml); - +pub const ffi = pjrt.ffi; pub const Profiler = pjrt.Profiler; pub const ApiError = pjrt.ApiError; pub const ErrorCode = pjrt.ErrorCode; @@ -23,7 +14,6 @@ pub const DeviceDescription = pjrt.DeviceDescription; pub const 
Api = pjrt.Api; pub const NamedValue = pjrt.NamedValue; pub const ClientInitError = pjrt.ClientInitError; -pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError; pub const Error = pjrt.Error; pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; pub const SerializeResult = pjrt.SerializeResult; @@ -31,6 +21,10 @@ pub const Executable = pjrt.Executable; pub const ExecuteError = ApiError; pub const Memory = pjrt.Memory; +const log = std.log.scoped(.zml); + +pub const CompileError = std.mem.Allocator.Error || error{InvalidMlirBytecodeVersion} || ApiError; + fn InnerMixin(comptime innerT: type) type { return struct { inline fn inner(self: anytype) if (@typeInfo(@TypeOf(self)).pointer.is_const) *const innerT else *innerT { @@ -159,6 +153,10 @@ pub const Buffer = opaque { return self.inner().isOnCpu(api); } + pub fn memory(self: *const Buffer, api: *const Api) *const Memory { + return self.inner().memory(api); + } + pub fn toHostBuffer(self: *const Buffer, api: *const Api, dst: []u8) ApiError!?*Event { return @ptrCast(try self.inner().toHostBuffer(api, dst)); } @@ -183,8 +181,8 @@ pub const Buffer = opaque { return @ptrCast(self.inner().copyToDevice(api, device)); } - pub fn copyToMemory(self: *const Buffer, api: *const Api, memory: *const Memory) ApiError!*Buffer { - return @ptrCast(self.inner().copyToMemory(api, memory)); + pub fn copyToMemory(self: *const Buffer, api: *const Api, memory_: *const Memory) ApiError!*Buffer { + return @ptrCast(self.inner().copyToMemory(api, memory_)); } pub fn getReadyEvent(self: *const Buffer, api: *const Api) ?*Event { diff --git a/zml/tensor.zig b/zml/tensor.zig index 49381e9..b248ea2 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -9,6 +9,7 @@ const Buffer = @import("buffer.zig").Buffer; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const Memory = @import("buffer.zig").Buffer.Memory; 
const meta = @import("meta.zig"); const mlir = @import("mlir.zig"); const Location = mlir.Location; @@ -41,10 +42,10 @@ pub const Tensor = struct { _shape: Shape, _id: _Id, _donation: _Donation = .no_buffer, + _output_memory_kind: Memory = .device, pub const _Donation = union(enum) { no_buffer, input_buffer, arg: u16 }; pub const _Id = union(enum) { mlir: mlir.Value, buffer_id: u64, arg_id: u64 }; - pub const MAX_RANK = Shape.MAX_RANK; /// Returns the current compilation context. @@ -171,20 +172,22 @@ pub const Tensor = struct { return switch (self._id) { .arg_id, .mlir => { const ctx = self.getContext(); + const mlir_ctx = ctx.mlirCtx(); var res = self; res._shape = self._shape.withSharding(axes_); const op = dialect.stablehlo.custom_call( - ctx.mlirCtx(), + mlir_ctx, &.{self.value()}, .{ .call_target_name = "Sharding", .has_side_effect = false, - .addional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }}, + .backend_config = null, + .additional_attributes = &.{.{ "mhlo.sharding", ctx.getShardingAttr(res._shape) }}, .api_version = .original, }, &.{self.value().getType()}, - ctx.mlirCtx().location(@src()), + mlir_ctx.location(@src()), ); return _result(res._shape, op.result(0)); @@ -197,6 +200,39 @@ pub const Tensor = struct { }; } + pub fn toMemory(self: Tensor, kind: Memory) Tensor { + return switch (self._id) { + .arg_id, .mlir => { + const ctx = self.getContext(); + const mlir_ctx = ctx.mlirCtx(); + if (ctx.target() == .cpu) return self; + var res = self; + res._output_memory_kind = kind; + + const memory_kind = @tagName(kind.toPjrtMemory()); + + const frontend_attributes = mlir.Attribute.dict(mlir_ctx, &.{ + .{ "_xla_buffer_placement", .string(mlir_ctx, memory_kind) }, + }); + + const op = dialect.stablehlo.custom_call(mlir_ctx, &.{self.value()}, .{ + .call_target_name = "annotate_device_placement", + .has_side_effect = true, + .backend_config = null, + .additional_attributes = &.{.{ "mhlo.frontend_attributes", frontend_attributes }}, + 
.api_version = .original, + }, &.{self.value().getType()}, mlir_ctx.location(@src())); + + return _result(res._shape, op.result(0)); + }, + .buffer_id => { + var res = self; + res._output_memory_kind = kind; + return res; + }, + }; + } + /// Returns a Tensor with new tag names. pub fn rename(self: Tensor, renames: anytype) Tensor { var res = self; @@ -3747,18 +3783,22 @@ pub const Tensor = struct { } /// Insert code that will print the content of the given buffer at runtime. - /// Only for debug purpose, it has the following limitations: - /// * only supported on Cuda atm - /// * only prints the first 1024 values - /// * pre allocates a buffer on the host to copy the content of the device buffer, - /// this buffer won't be freed. You will have one buffer per "print" call in the IR. - /// * does device to host synchronization so it will slow down the program execution. + /// Only for debug purpose, it inserts device to host synchronization + /// so it will slow down the program execution. pub fn print(input: Tensor) Tensor { - return ops.addHostCallback(&printCallback, input); + return ops.addHostCallback( + &printCallback, + null, + &.{input}, + &.{input.shape()}, + .{ .output_operand_aliases = &.{0} }, + )[0]; } - fn printCallback(host_buffer: HostBuffer) void { + fn printCallback(_: ?*anyopaque, inputs: []const HostBuffer, outputs: []const HostBuffer) void { + const host_buffer = inputs[0]; std.debug.print("Device buffer: {}: {}", .{ host_buffer.shape(), host_buffer.pretty() }); + std.debug.assert(host_buffer.data.ptr == outputs[0].data.ptr); } };