Update ZML build configuration to replace zig-protobuf with upb library.
This commit is contained in:
parent
cba3be4859
commit
1b58c34b8c
20
MODULE.bazel
20
MODULE.bazel
@ -10,7 +10,7 @@ bazel_dep(name = "libxev", version = "20250718.0-9f785d2")
|
||||
bazel_dep(name = "patchelf", version = "0.18.0")
|
||||
bazel_dep(name = "pcre2", version = "10.43")
|
||||
bazel_dep(name = "platforms", version = "0.0.11")
|
||||
bazel_dep(name = "protobuf", version = "29.2")
|
||||
bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf")
|
||||
bazel_dep(name = "rules_cc", version = "0.1.1")
|
||||
bazel_dep(name = "rules_distroless", version = "0.5.1")
|
||||
bazel_dep(name = "rules_proto", version = "7.1.0")
|
||||
@ -20,27 +20,12 @@ bazel_dep(name = "rules_uv", version = "0.65.0")
|
||||
bazel_dep(name = "rules_zig", version = "20250714.0-b14a4f1")
|
||||
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
|
||||
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
|
||||
bazel_dep(name = "toolchains_protoc", version = "0.4.1")
|
||||
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
|
||||
bazel_dep(name = "xla", version = "20250718.0-6319f0d")
|
||||
bazel_dep(name = "zig-protobuf", version = "20250716.0-97f1e31")
|
||||
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
|
||||
|
||||
bazel_dep(name = "buildifier_prebuilt", version = "8.0.3", dev_dependency = True)
|
||||
|
||||
# Optional: choose a version of protoc rather than the latest.
|
||||
protoc = use_extension("@toolchains_protoc//protoc:extensions.bzl", "protoc")
|
||||
protoc.toolchain(
|
||||
# Creates a repository to satisfy well-known-types dependencies such as
|
||||
# deps=["@com_google_protobuf//:any_proto"]
|
||||
google_protobuf = "com_google_protobuf",
|
||||
# Pin to any version of protoc
|
||||
version = "v29.2",
|
||||
)
|
||||
use_repo(protoc, "com_google_protobuf", "toolchains_protoc_hub")
|
||||
|
||||
register_toolchains("@toolchains_protoc_hub//:all")
|
||||
|
||||
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
|
||||
zig.index(file = "//bazel:zig_index.json")
|
||||
zig.toolchain(zig_version = "0.14.1")
|
||||
@ -88,7 +73,6 @@ pip.parse(
|
||||
)
|
||||
use_repo(pip, "huggingface_hub")
|
||||
|
||||
|
||||
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
|
||||
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
|
||||
|
||||
@ -160,12 +144,14 @@ apt.install(
|
||||
manifest = "//runtimes/cuda:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_cuda")
|
||||
|
||||
apt.install(
|
||||
name = "apt_rocm",
|
||||
lock = "//runtimes/rocm:packages.lock.json",
|
||||
manifest = "//runtimes/rocm:packages.yaml",
|
||||
)
|
||||
use_repo(apt, "apt_rocm")
|
||||
|
||||
apt.install(
|
||||
name = "apt_neuron",
|
||||
lock = "//runtimes/neuron:packages.lock.json",
|
||||
|
||||
4712
MODULE.bazel.lock
4712
MODULE.bazel.lock
File diff suppressed because one or more lines are too long
@ -1,6 +1,6 @@
|
||||
"""Starlark implementation of zig_proto_library"""
|
||||
|
||||
load("@protobuf//bazel/common:proto_info.bzl", "ProtoInfo")
|
||||
load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo")
|
||||
load("@rules_proto//proto:defs.bzl", "proto_common")
|
||||
load(
|
||||
"@rules_zig//zig/private/providers:zig_module_info.bzl",
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
|
||||
load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
|
||||
load("@protobuf//bazel:proto_library.bzl", "proto_library")
|
||||
load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
|
||||
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
|
||||
|
||||
package(
|
||||
default_visibility = ["//visibility:public"],
|
||||
@ -90,7 +90,7 @@ cc_library(
|
||||
":darts_clone",
|
||||
":sentencepiece_cc_proto",
|
||||
":sentencepiece_model_cc_proto",
|
||||
"@protobuf//:protobuf_lite",
|
||||
"@com_google_protobuf//:protobuf_lite",
|
||||
"@zml//third_party/com_google_sentencepiece:absl",
|
||||
],
|
||||
)
|
||||
|
||||
12
upb/BUILD.bazel
Normal file
12
upb/BUILD.bazel
Normal file
@ -0,0 +1,12 @@
|
||||
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")

# Source-less C target listed as the only dependency of the Zig library below.
# NOTE(review): presumably here so rules_zig provides the `c` module that
# upb.zig imports; the actual upb headers would then come from the consumer's
# own cc deps - confirm.
cc_library(
    name = "empty",
)

# Zig wrapper around the upb C API, with upb.zig as its root source file.
zig_library(
    name = "upb",
    main = "upb.zig",
    visibility = ["//visibility:public"],
    deps = [":empty"],
)
|
||||
149
upb/upb.zig
Normal file
149
upb/upb.zig
Normal file
@ -0,0 +1,149 @@
|
||||
const std = @import("std");
|
||||
|
||||
const c = @import("c");
|
||||
|
||||
/// Exports a C-ABI symbol carrying upb's internal "strong reference" name.
///
/// The body only routes `mt` through a pointer-to-volatile and discards it.
/// NOTE(review): presumably defined here because the C header's inline version
/// is not emitted as a linkable symbol when going through translate-c - confirm.
export fn _upb_MiniTable_StrongReference_dont_copy_me__upb_internal_use_only(mt: *c.upb_MiniTable) callconv(.c) void {
    const keep_alive: *volatile c.upb_MiniTable = mt;
    _ = &keep_alive;
}
|
||||
|
||||
/// Option flags for `serializeEx`, laid out as a bitfield backed by `c_int`.
///
/// `serializeEx` bit-casts this value and passes it as the options argument of
/// `upb_Encode`. Fields are packed starting from the least significant bit.
pub const SerializeOptions = packed struct(c_int) {
    /// Bit 0.
    deterministic: bool = false,
    /// Bit 1.
    skip_unknown: bool = false,
    /// Bit 2.
    check_required: bool = false,
    /// Remaining 29 bits; left at zero by default.
    _ignored: u29 = 0,
};
|
||||
|
||||
/// Errors returned by `serializeEx` / `serialize`: allocator failure plus the
/// serialization-specific failures below.
pub const SerializeError = std.mem.Allocator.Error || error{
    MaxDepthExceeded,
    MissingRequired,
    /// Any encode status that has no dedicated error above.
    Unknown,
};
|
||||
|
||||
/// Option flags for `parseEx`, laid out as a bitfield backed by `c_int`.
///
/// `parseEx` bit-casts this value and passes it as the options argument of
/// `upb_Decode`. The backing integer is pinned to `c_int` (as for
/// `SerializeOptions`) so a field change that breaks the size is rejected at
/// the declaration rather than at the `@bitCast` call site.
/// Fields are packed starting from the least significant bit.
pub const ParseOptions = packed struct(c_int) {
    /// Bit 0.
    alias_string: bool = false,
    /// Bit 1.
    check_required: bool = false,
    /// Bit 2.
    experimental_allow_unlinked: bool = false,
    /// Bit 3.
    always_validate_utf8: bool = false,
    /// Bit 4.
    disable_fast_table: bool = false,
    /// Remaining 27 bits; left at zero by default.
    _ignored: u27 = 0,
};
|
||||
|
||||
/// Errors returned by `parseEx` / `parse`: allocator failure plus the
/// parse-specific failures below.
pub const ParseError = std.mem.Allocator.Error || error{
    Malformed,
    BadUtf8,
    MaxDepthExceeded,
    MissingRequired,
    UnlinkedSubMessage,
    /// Any decode status that has no dedicated error above.
    Unknown,
};
|
||||
|
||||
/// Converts an optional Zig byte slice into a `upb_StringView`.
///
/// A `null` input yields a default-initialized view. The returned view points
/// at the caller's bytes; nothing is copied here.
pub fn stringView(data: ?[]const u8) c.upb_StringView {
    const bytes = data orelse return .{};
    return c.upb_StringView_FromDataAndSize(bytes.ptr, bytes.len);
}
|
||||
|
||||
/// Converts a `upb_StringView` into a Zig byte slice of `sv.size` bytes.
///
/// Returns `null` when the view's data pointer is null. The slice refers to
/// the view's memory; nothing is copied here.
pub fn slice(sv: c.upb_StringView) ?[]const u8 {
    const base = sv.data orelse return null;
    return base[0..sv.size];
}
|
||||
|
||||
/// Returns the part of `@typeName(UpbType)` that follows the first
/// ".struct_" occurrence, e.g. a type named `...struct_xla_Foo` gives
/// "xla_Foo". Emits a compile error when the marker is absent.
///
/// Intended for comptime evaluation (callers concatenate the result with `++`).
fn ProtoName(comptime UpbType: type) []const u8 {
    const marker = ".struct_";
    const full_name = @typeName(UpbType);
    if (std.mem.indexOf(u8, full_name, marker)) |start| {
        return full_name[start + marker.len ..];
    }
    @compileError("Type name is invalid");
}
|
||||
|
||||
/// Resolves, at compile time, the `*_msg_init` declaration in `c` that
/// corresponds to `UpbType` and returns a pointer to it.
///
/// The lookup walks the '_'-separated tokens of `ProtoName(UpbType)`; after
/// each token it inserts one extra '_' at the token's end, appends
/// "_msg_init", and checks whether `c` declares that name (for example
/// "xla_Foo" is tried as "xla__Foo_msg_init"). The first match wins; if no
/// candidate exists, compilation fails.
fn Minitable(comptime UpbType: type) *const c.upb_MiniTable {
    const symbol = comptime find: {
        const proto_name = ProtoName(UpbType);
        var tokens = std.mem.tokenizeScalar(u8, proto_name, '_');
        while (tokens.next() != null) {
            // `tokens.index` is the offset just past the token consumed above.
            const split = tokens.index;
            const candidate = proto_name[0..split] ++ "_" ++ proto_name[split..] ++ "_msg_init";
            if (@hasDecl(c, candidate)) break :find candidate;
        }
        @compileError("Unable to find minitable for type:" ++ @typeName(UpbType));
    };
    return &@field(c, symbol);
}
|
||||
|
||||
/// Encodes the message pointed to by `ptr` with `upb_Encode`, using `opts`.
///
/// `ptr` must be a pointer to a upb message struct for which `Minitable` can
/// find a mini-table. `arena` is passed to `upb_Encode`, which also produces
/// the returned buffer.
/// NOTE(review): the returned bytes presumably live in `arena` and stay valid
/// only as long as it does - confirm against upb's encode documentation.
///
/// Errors: `OutOfMemory`, `MaxDepthExceeded`, `MissingRequired`, or `Unknown`
/// for any other non-OK status.
pub fn serializeEx(ptr: anytype, arena: *c.upb_Arena, opts: SerializeOptions) SerializeError![]const u8 {
    // Both out-parameters are written by upb_Encode; only read on success.
    var out_ptr: [*c]u8 = undefined;
    var out_len: usize = undefined;
    const status = c.upb_Encode(@ptrCast(ptr), Minitable(@TypeOf(ptr.*)), @bitCast(opts), arena, &out_ptr, &out_len);
    switch (status) {
        c.kUpb_EncodeStatus_Ok => return out_ptr[0..out_len],
        c.kUpb_EncodeStatus_OutOfMemory => return std.mem.Allocator.Error.OutOfMemory,
        c.kUpb_EncodeStatus_MaxDepthExceeded => return SerializeError.MaxDepthExceeded,
        c.kUpb_EncodeStatus_MissingRequired => return SerializeError.MissingRequired,
        else => return SerializeError.Unknown,
    }
}
|
||||
|
||||
/// Encodes the message pointed to by `ptr` with default `SerializeOptions`.
/// See `serializeEx` for the argument and error contract.
pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 {
    const default_options: SerializeOptions = .{};
    return serializeEx(ptr, arena, default_options);
}
|
||||
|
||||
/// Creates a new `UpbType` message in `arena` and decodes `data` into it with
/// `upb_Decode`, using `opts`. No extension registry is supplied (null).
///
/// Returns the populated message on success.
/// Errors: `OutOfMemory` (from `new` or from decoding), `Malformed`,
/// `BadUtf8`, `MaxDepthExceeded`, `MissingRequired`, `UnlinkedSubMessage`, or
/// `Unknown` for any other non-OK status.
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
    const msg = try new(UpbType, arena);
    const status = c.upb_Decode(
        @ptrCast(@constCast(data.ptr)),
        data.len,
        @alignCast(@ptrCast(msg)),
        Minitable(UpbType),
        null,
        @bitCast(opts),
        arena,
    );
    switch (status) {
        c.kUpb_DecodeStatus_Ok => return msg,
        c.kUpb_DecodeStatus_Malformed => return ParseError.Malformed,
        c.kUpb_DecodeStatus_OutOfMemory => return std.mem.Allocator.Error.OutOfMemory,
        c.kUpb_DecodeStatus_BadUtf8 => return ParseError.BadUtf8,
        c.kUpb_DecodeStatus_MaxDepthExceeded => return ParseError.MaxDepthExceeded,
        c.kUpb_DecodeStatus_MissingRequired => return ParseError.MissingRequired,
        c.kUpb_DecodeStatus_UnlinkedSubMessage => return ParseError.UnlinkedSubMessage,
        else => return ParseError.Unknown,
    }
}
|
||||
|
||||
/// Decodes `data` into a new `UpbType` message allocated in `arena`, using
/// default `ParseOptions`. See `parseEx` for the error contract.
pub fn parse(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8) ParseError!*UpbType {
    // `parseEx` takes the message type as its first argument; the previous
    // call omitted it, so any use of `parse` failed to compile.
    return parseEx(UpbType, arena, data, .{});
}
|
||||
|
||||
/// Creates a new `UpbType` message by calling the `<ProtoName>_new` function
/// declared in `c`, passing it `arena`.
///
/// Returns `OutOfMemory` when that function returns null.
pub fn new(comptime UpbType: type, arena: *c.upb_Arena) std.mem.Allocator.Error!*UpbType {
    const constructor = @field(c, ProtoName(UpbType) ++ "_new");
    const msg = constructor(arena) orelse return std.mem.Allocator.Error.OutOfMemory;
    return @ptrCast(msg);
}
|
||||
|
||||
/// Adapter that exposes a `std.mem.Allocator` through upb's `upb_alloc`
/// callback interface.
///
/// The callback recovers the adapter from the `upb_alloc` pointer with
/// `@fieldParentPtr`, so an `Allocator` must stay at a stable address for as
/// long as the pointer returned by `inner` is in use (do not copy or move it).
pub const Allocator = struct {
    upb_alloc: c.upb_alloc,
    allocator: std.mem.Allocator,

    /// Wraps `allocator`; no memory is allocated here.
    pub fn init(allocator: std.mem.Allocator) Allocator {
        return .{
            .upb_alloc = .{
                .func = &alloc_impl,
            },
            .allocator = allocator,
        };
    }

    /// Returns the embedded `upb_alloc` to hand to upb APIs.
    pub fn inner(self: *Allocator) *c.upb_alloc {
        return &self.upb_alloc;
    }

    /// Callback installed in `upb_alloc.func`. Behaviour by arguments:
    /// - `ptr == null`: allocate `size` bytes.
    /// - `ptr != null`, `size == 0`: free the `oldsize`-byte block, return null.
    /// - `ptr != null`, `size == oldsize`: nothing to do, return `ptr`.
    /// - `ptr != null`, otherwise: reallocate from `oldsize` to `size` bytes.
    /// Returns null when the underlying allocator fails or `alloc` is null.
    /// When `actual_size` is non-null it is set to `size` on every path past
    /// the null-`alloc` check.
    fn alloc_impl(alloc: [*c]c.upb_alloc, ptr: ?*anyopaque, oldsize: usize, size: usize, actual_size: [*c]usize) callconv(.c) ?*anyopaque {
        const upb_alloc: *c.upb_alloc = alloc orelse return null;
        const self: *Allocator = @fieldParentPtr("upb_alloc", upb_alloc);
        defer {
            if (actual_size) |as| {
                as.* = size;
            }
        }

        const old_ptr = ptr orelse {
            return (self.allocator.alloc(u8, size) catch return null).ptr;
        };

        const old_bytes: []u8 = @as([*c]u8, @ptrCast(old_ptr))[0..oldsize];
        if (size == 0) {
            self.allocator.free(old_bytes);
            return null;
        }
        // Same-size request: the existing block already satisfies it. This
        // used to panic, which would abort the process from inside a C callback.
        if (size == oldsize) {
            return old_ptr;
        }
        return (self.allocator.realloc(old_bytes, size) catch return null).ptr;
    }
};
|
||||
@ -1,9 +1,19 @@
|
||||
load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library")
|
||||
load("@rules_cc//cc:defs.bzl", "cc_library")
|
||||
load("@rules_zig//zig:defs.bzl", "zig_library")
|
||||
load("//bazel:zig.bzl", "zig_cc_test")
|
||||
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
|
||||
load("//bazel:zig_srcs.bzl", "zig_srcs")
|
||||
|
||||
upb_c_proto_library(
|
||||
name = "xla_data_upb",
|
||||
deps = ["@xla//xla:xla_data_proto"],
|
||||
)
|
||||
|
||||
upb_c_proto_library(
|
||||
name = "xla_compile_options_upb",
|
||||
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "posix",
|
||||
hdrs = ["posix.h"],
|
||||
@ -25,26 +35,21 @@ zig_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":posix",
|
||||
":xla_proto",
|
||||
":xla_compile_options_upb",
|
||||
":xla_data_upb",
|
||||
"//async",
|
||||
"//mlir",
|
||||
"//mlir/dialects",
|
||||
"//pjrt",
|
||||
"//runtimes",
|
||||
"//stdx",
|
||||
"//upb",
|
||||
"//zml/tokenizer",
|
||||
"//zml/tools",
|
||||
"@rules_zig//zig/runfiles",
|
||||
],
|
||||
)
|
||||
|
||||
zig_proto_library(
|
||||
name = "xla_proto",
|
||||
import_name = "//xla:xla_proto",
|
||||
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
|
||||
)
|
||||
|
||||
|
||||
# All ZML Tests
|
||||
|
||||
zig_cc_test(
|
||||
|
||||
138
zml/module.zig
138
zml/module.zig
@ -1,10 +1,11 @@
|
||||
const std = @import("std");
|
||||
|
||||
const asynk = @import("async");
|
||||
const c = @import("c");
|
||||
const dialect = @import("mlir/dialects");
|
||||
const mlir = @import("mlir");
|
||||
const stdx = @import("stdx");
|
||||
const xla_pb = @import("//xla:xla_proto");
|
||||
const upb = @import("upb");
|
||||
|
||||
const BaseExe = @import("exe.zig").BaseExe;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
@ -834,6 +835,28 @@ fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecut
|
||||
try loaded_executable_file.writeAll(serialize_result.bytes);
|
||||
}
|
||||
|
||||
/// Stores `flag -> value` in the upb map `map`, wrapping `value` in a newly
/// created `xla_OptionOverrideProto` allocated from `upb_arena`.
///
/// The proto field is chosen from the type of `value`: bool -> bool_field,
/// integer -> int_field, float -> double_field, anything else is passed to
/// `upb.stringView` and stored in string_field.
///
/// Returns `OutOfMemory` when `upb_Map_Set` reports failure; errors from
/// `upb.new` are propagated.
fn setXlaOverrideFlag(map: *c.upb_Map, flag: []const u8, value: anytype, upb_arena: *c.upb_Arena) !void {
    const override = try upb.new(c.xla_OptionOverrideProto, upb_arena);
    switch (@typeInfo(@TypeOf(value))) {
        .bool => c.xla_OptionOverrideProto_set_bool_field(override, value),
        .comptime_int, .int => c.xla_OptionOverrideProto_set_int_field(override, @intCast(value)),
        .comptime_float, .float => c.xla_OptionOverrideProto_set_double_field(override, @floatCast(value)),
        else => c.xla_OptionOverrideProto_set_string_field(override, upb.stringView(value)),
    }

    const inserted = c.upb_Map_Set(
        map,
        .{ .str_val = upb.stringView(flag) },
        .{ .msg_val = @ptrCast(override) },
        upb_arena,
    );
    if (!inserted) return std.mem.Allocator.Error.OutOfMemory;
}
|
||||
|
||||
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
|
||||
const tracer = Tracer.init("ai.zml.compilation");
|
||||
const compile_frame = tracer.frameStart("pjrt compilation");
|
||||
@ -841,86 +864,83 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
|
||||
const sharding = platform.sharding();
|
||||
|
||||
var options: xla_pb.CompileOptionsProto = .{
|
||||
.executable_build_options = .{
|
||||
.device_ordinal = -1,
|
||||
.num_replicas = sharding.num_replicas,
|
||||
.num_partitions = sharding.num_partitions,
|
||||
.use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1,
|
||||
.device_assignment = .{
|
||||
.replica_count = sharding.num_replicas,
|
||||
.computation_count = sharding.num_partitions,
|
||||
// Filled below.
|
||||
.computation_devices = .{},
|
||||
},
|
||||
},
|
||||
};
|
||||
const computation_devices = &options.executable_build_options.?.device_assignment.?.computation_devices;
|
||||
try computation_devices.ensureTotalCapacity(arena, sharding.num_partitions);
|
||||
const replica_device_ids = try arena.alloc(i64, sharding.num_partitions);
|
||||
for (0..sharding.num_partitions) |i| {
|
||||
replica_device_ids[i] = @intCast(i);
|
||||
computation_devices.appendAssumeCapacity(.{ .replica_device_ids = .fromOwnedSlice(replica_device_ids[i .. i + 1]) });
|
||||
}
|
||||
var upb_alloc: upb.Allocator = .init(arena);
|
||||
const upb_arena = c.upb_Arena_Init(null, 0, upb_alloc.inner());
|
||||
defer c.upb_Arena_Free(upb_arena);
|
||||
|
||||
// Let the arena deinit, zig-protobuf deinit is very slow.
|
||||
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
|
||||
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||
setFlag(&options, "xla_dump_to", xla_dump_to);
|
||||
setFlag(&options, "xla_dump_hlo_as_proto", true);
|
||||
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
||||
setFlag(&options, "xla_dump_fusion_visualization", true);
|
||||
}
|
||||
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
|
||||
setFlag(&options, "xla_dump_hlo_pass_re", re);
|
||||
}
|
||||
const options = blk: {
|
||||
const options = try upb.new(c.xla_CompileOptionsProto, upb_arena);
|
||||
c.xla_CompileOptionsProto_set_executable_build_options(options, executable_build_options_blk: {
|
||||
const exec_build_options = try upb.new(c.xla_ExecutableBuildOptionsProto, upb_arena);
|
||||
|
||||
c.xla_ExecutableBuildOptionsProto_set_device_ordinal(exec_build_options, -1);
|
||||
c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
|
||||
c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
|
||||
c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);
|
||||
|
||||
c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
|
||||
const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);
|
||||
|
||||
c.xla_DeviceAssignmentProto_set_replica_count(device_assignment, sharding.num_replicas);
|
||||
c.xla_DeviceAssignmentProto_set_computation_count(device_assignment, sharding.num_partitions);
|
||||
|
||||
const computation_devices = c.xla_DeviceAssignmentProto_resize_computation_devices(device_assignment, sharding.num_partitions, upb_arena);
|
||||
for (computation_devices[0..sharding.num_partitions], 0..) |*computation_device, i| {
|
||||
computation_device.* = try upb.new(c.xla_DeviceAssignmentProto_ComputationDevice, upb_arena);
|
||||
_ = c.xla_DeviceAssignmentProto_ComputationDevice_add_replica_device_ids(computation_device.*, @intCast(i), upb_arena);
|
||||
}
|
||||
break :device_assignment_blk device_assignment;
|
||||
});
|
||||
|
||||
break :executable_build_options_blk exec_build_options;
|
||||
});
|
||||
|
||||
const overrides_map = c._xla_CompileOptionsProto_env_option_overrides_mutable_upb_map(options, upb_arena);
|
||||
switch (platform.target) {
|
||||
.cuda => {
|
||||
// NVIDIA recommends these settings
|
||||
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
||||
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
|
||||
setFlag(&options, "xla_gpu_enable_llvm_module_compilation_parallelism", true);
|
||||
setFlag(&options, "xla_gpu_enable_libnvptxcompiler", true);
|
||||
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
|
||||
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
|
||||
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
|
||||
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
|
||||
// setFlags(&options, "xla_gpu_enable_address_computation_fusion", true);
|
||||
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
|
||||
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_libnvptxcompiler", true, upb_arena);
|
||||
},
|
||||
.rocm => {
|
||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||
// enabled on CDNA and it's used on RDNA. Disable it altogether.
|
||||
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
|
||||
// Use lld from libllvm instead of invoking the ld.lld binary.
|
||||
// This saves us from having to sandbox it.
|
||||
setFlag(&options, "xla_gpu_use_inprocess_lld", true);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_gpu_use_inprocess_lld", true, upb_arena);
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
|
||||
const options_bytes = try options.encode(arena);
|
||||
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||
try setXlaOverrideFlag(overrides_map, "xla_dump_to", xla_dump_to, upb_arena);
|
||||
try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_as_proto", true, upb_arena);
|
||||
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
||||
try setXlaOverrideFlag(overrides_map, "xla_dump_fusion_visualization", true, upb_arena);
|
||||
}
|
||||
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
|
||||
try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_pass_re", re, upb_arena);
|
||||
}
|
||||
}
|
||||
|
||||
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
|
||||
break :blk options;
|
||||
};
|
||||
|
||||
const loaded_executable = try platform.pjrt_client.compile(
|
||||
platform.pjrt_api,
|
||||
arena,
|
||||
module,
|
||||
try upb.serialize(options, upb_arena),
|
||||
);
|
||||
errdefer loaded_executable.deinit();
|
||||
|
||||
return loaded_executable;
|
||||
}
|
||||
|
||||
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
||||
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
||||
.bool => .{ .value = .{ .bool_field = value } },
|
||||
.comptime_int, .int => .{ .value = .{ .int_field = value } },
|
||||
.comptime_float, .float => .{ .value = .{ .double_field = value } },
|
||||
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
||||
};
|
||||
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
||||
}
|
||||
|
||||
/// Visit the given struct and recursively counts the number of tensors found.
|
||||
pub fn countTensors(v: anytype) usize {
|
||||
const LocalContext = struct {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user