Update ZML build configuration to replace zig-protobuf with upb library.

This commit is contained in:
Tarry Singh 2025-06-09 16:34:20 +00:00
parent cba3be4859
commit 1b58c34b8c
8 changed files with 565 additions and 4527 deletions

View File

@ -10,7 +10,7 @@ bazel_dep(name = "libxev", version = "20250718.0-9f785d2")
bazel_dep(name = "patchelf", version = "0.18.0")
bazel_dep(name = "pcre2", version = "10.43")
bazel_dep(name = "platforms", version = "0.0.11")
bazel_dep(name = "protobuf", version = "29.2")
bazel_dep(name = "protobuf", version = "32.0", repo_name = "com_google_protobuf")
bazel_dep(name = "rules_cc", version = "0.1.1")
bazel_dep(name = "rules_distroless", version = "0.5.1")
bazel_dep(name = "rules_proto", version = "7.1.0")
@ -20,27 +20,12 @@ bazel_dep(name = "rules_uv", version = "0.65.0")
bazel_dep(name = "rules_zig", version = "20250714.0-b14a4f1")
bazel_dep(name = "sentencepiece", version = "20240618.0-d7ace0a")
bazel_dep(name = "toolchains_llvm_bootstrapped", version = "0.2.4")
bazel_dep(name = "toolchains_protoc", version = "0.4.1")
bazel_dep(name = "with_cfg.bzl", version = "0.9.1")
bazel_dep(name = "xla", version = "20250718.0-6319f0d")
bazel_dep(name = "zig-protobuf", version = "20250716.0-97f1e31")
bazel_dep(name = "zig-yaml", version = "20240903.0-83d5fdf")
bazel_dep(name = "buildifier_prebuilt", version = "8.0.3", dev_dependency = True)
# Optional: choose a version of protoc rather than the latest.
protoc = use_extension("@toolchains_protoc//protoc:extensions.bzl", "protoc")
protoc.toolchain(
# Creates a repository to satisfy well-known-types dependencies such as
# deps=["@com_google_protobuf//:any_proto"]
google_protobuf = "com_google_protobuf",
# Pin to any version of protoc
version = "v29.2",
)
use_repo(protoc, "com_google_protobuf", "toolchains_protoc_hub")
register_toolchains("@toolchains_protoc_hub//:all")
zig = use_extension("@rules_zig//zig:extensions.bzl", "zig")
zig.index(file = "//bazel:zig_index.json")
zig.toolchain(zig_version = "0.14.1")
@ -88,7 +73,6 @@ pip.parse(
)
use_repo(pip, "huggingface_hub")
cpu = use_extension("//runtimes/cpu:cpu.bzl", "cpu_pjrt_plugin")
use_repo(cpu, "libpjrt_cpu_darwin_amd64", "libpjrt_cpu_darwin_arm64", "libpjrt_cpu_linux_amd64")
@ -160,12 +144,14 @@ apt.install(
manifest = "//runtimes/cuda:packages.yaml",
)
use_repo(apt, "apt_cuda")
apt.install(
name = "apt_rocm",
lock = "//runtimes/rocm:packages.lock.json",
manifest = "//runtimes/rocm:packages.yaml",
)
use_repo(apt, "apt_rocm")
apt.install(
name = "apt_neuron",
lock = "//runtimes/neuron:packages.lock.json",

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,6 @@
"""Starlark implementation of zig_proto_library"""
load("@protobuf//bazel/common:proto_info.bzl", "ProtoInfo")
load("@com_google_protobuf//bazel/common:proto_info.bzl", "ProtoInfo")
load("@rules_proto//proto:defs.bzl", "proto_common")
load(
"@rules_zig//zig/private/providers:zig_module_info.bzl",

View File

@ -1,6 +1,6 @@
load("@bazel_skylib//rules:copy_file.bzl", "copy_file")
load("@protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@protobuf//bazel:proto_library.bzl", "proto_library")
load("@com_google_protobuf//bazel:cc_proto_library.bzl", "cc_proto_library")
load("@com_google_protobuf//bazel:proto_library.bzl", "proto_library")
package(
default_visibility = ["//visibility:public"],
@ -90,7 +90,7 @@ cc_library(
":darts_clone",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@protobuf//:protobuf_lite",
"@com_google_protobuf//:protobuf_lite",
"@zml//third_party/com_google_sentencepiece:absl",
],
)

12
upb/BUILD.bazel Normal file
View File

@ -0,0 +1,12 @@
# BUILD file for the Zig upb bindings (upb/upb.zig).
load("@rules_zig//zig:defs.bzl", "zig_library")

# Intentionally empty C library used as a dep of the zig_library below.
# NOTE(review): presumably the actual upb C sources/headers are supplied by
# the consuming target or toolchain — confirm why this placeholder is needed.
cc_library(
    name = "empty",
)

zig_library(
    name = "upb",
    main = "upb.zig",
    deps = [":empty"],
    visibility = ["//visibility:public"],
)

149
upb/upb.zig Normal file
View File

@ -0,0 +1,149 @@
const std = @import("std");
const c = @import("c");
// Exported under the exact symbol name upb's internals reference. Routing the
// minitable pointer through a volatile local keeps the reference observable so
// the symbol (and the minitable it refers to) is not optimized/GC'd away.
// NOTE(review): assumes this matches the strong-reference symbol upb expects —
// confirm against the upb headers vendored by the build.
export fn _upb_MiniTable_StrongReference_dont_copy_me__upb_internal_use_only(mt: *c.upb_MiniTable) callconv(.c) void {
    const unused: *volatile c.upb_MiniTable = mt;
    _ = &unused;
}
/// Options for `serializeEx`. The bit layout follows the field order below
/// and is `@bitCast` directly into upb's C `int` encode-options argument.
/// NOTE(review): bit positions (deterministic = 1 << 0, skip_unknown = 1 << 1,
/// check_required = 1 << 2) are assumed to match upb's `kUpb_EncodeOption_*`
/// flags — confirm against the vendored upb headers.
pub const SerializeOptions = packed struct(c_int) {
    deterministic: bool = false,
    skip_unknown: bool = false,
    check_required: bool = false,
    _ignored: u29 = 0,
};

/// Errors surfaced by `serializeEx`/`serialize`, mapped from `upb_EncodeStatus`.
pub const SerializeError = error{
    MaxDepthExceeded,
    MissingRequired,
    Unknown,
} || std.mem.Allocator.Error;

/// Options for `parseEx`. `@bitCast` into upb's C `int` decode-options
/// argument. Backing integer made explicitly `c_int` for consistency with
/// `SerializeOptions` (the inferred backing was `u32`; both are 32 bits, so
/// the cast behavior is unchanged).
/// NOTE(review): bit positions are assumed to match upb's `kUpb_DecodeOption_*`
/// flags — confirm against the vendored upb headers.
pub const ParseOptions = packed struct(c_int) {
    alias_string: bool = false,
    check_required: bool = false,
    experimental_allow_unlinked: bool = false,
    always_validate_utf8: bool = false,
    disable_fast_table: bool = false,
    _ignored: u27 = 0,
};

/// Errors surfaced by `parseEx`/`parse`, mapped from `upb_DecodeStatus`.
pub const ParseError = error{
    Malformed,
    BadUtf8,
    MaxDepthExceeded,
    MissingRequired,
    UnlinkedSubMessage,
    Unknown,
} || std.mem.Allocator.Error;
/// Wrap an optional Zig slice as a `upb_StringView`.
/// A `null` input produces a zero-initialized (empty) view.
pub fn stringView(data: ?[]const u8) c.upb_StringView {
    const bytes = data orelse return .{};
    return c.upb_StringView_FromDataAndSize(bytes.ptr, bytes.len);
}
/// View a `upb_StringView` as a Zig slice.
/// A view whose data pointer is null maps back to `null`.
pub fn slice(sv: c.upb_StringView) ?[]const u8 {
    if (sv.data == null) return null;
    return sv.data[0..sv.size];
}
/// Strip the cImport prefix from a upb message type's `@typeName`, yielding
/// the generated proto identifier (everything after ".struct_").
/// Emits a compile error when the type was not produced by the C import.
fn ProtoName(comptime UpbType: type) []const u8 {
    const marker = ".struct_";
    const full_name = @typeName(UpbType);
    const pos = std.mem.indexOf(u8, full_name, marker) orelse @compileError("Type name is invalid");
    return full_name[pos + marker.len ..];
}
/// Resolve, at comptime, the generated upb minitable decl for a cImport'ed
/// upb message struct.
///
/// The decl name inserts an extra '_' somewhere inside the proto name
/// (presumably at the package/message boundary, which cannot be recovered
/// from the type name alone — confirm against the generated upb C code), so
/// we try each '_' position in turn and take the first candidate that exists
/// as a decl in `c`.
fn Minitable(comptime UpbType: type) *const c.upb_MiniTable {
    const field_name = comptime blk: {
        const name = ProtoName(UpbType);
        var it = std.mem.tokenizeScalar(u8, name, '_');
        while (it.next()) |_| {
            // After next(), `it.index` sits at the '_' separator (or the end),
            // so the extra '_' is spliced at the current token boundary.
            const new_name = name[0..it.index] ++ "_" ++ name[it.index..] ++ "_msg_init";
            if (@hasDecl(c, new_name)) {
                break :blk new_name;
            }
        } else {
            // while-else: every split position was tried without a match.
            @compileError("Unable to find minitable for type:" ++ @typeName(UpbType));
        }
    };
    return &@field(c, field_name);
}
/// Serialize a upb message to wire format with explicit options.
/// The returned bytes are allocated on (and owned by) `arena`.
pub fn serializeEx(ptr: anytype, arena: *c.upb_Arena, opts: SerializeOptions) SerializeError![]const u8 {
    var out: [*c]u8 = undefined;
    var out_len: usize = undefined;
    const status = c.upb_Encode(@ptrCast(ptr), Minitable(@TypeOf(ptr.*)), @bitCast(opts), arena, &out, &out_len);
    switch (status) {
        c.kUpb_EncodeStatus_Ok => return out[0..out_len],
        c.kUpb_EncodeStatus_OutOfMemory => return std.mem.Allocator.Error.OutOfMemory,
        c.kUpb_EncodeStatus_MaxDepthExceeded => return SerializeError.MaxDepthExceeded,
        c.kUpb_EncodeStatus_MissingRequired => return SerializeError.MissingRequired,
        else => return SerializeError.Unknown,
    }
}
/// Serialize a upb message with default options.
/// The returned bytes are allocated on (and owned by) `arena`.
pub fn serialize(ptr: anytype, arena: *c.upb_Arena) SerializeError![]const u8 {
    return serializeEx(ptr, arena, .{});
}
/// Parse wire-format `data` into a freshly allocated message of `UpbType`,
/// with explicit options. The message is owned by `arena`.
pub fn parseEx(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8, opts: ParseOptions) ParseError!*UpbType {
    const msg = try new(UpbType, arena);
    const status = c.upb_Decode(@ptrCast(@constCast(data)), data.len, @alignCast(@ptrCast(msg)), Minitable(UpbType), null, @bitCast(opts), arena);
    switch (status) {
        c.kUpb_DecodeStatus_Ok => return msg,
        c.kUpb_DecodeStatus_Malformed => return ParseError.Malformed,
        c.kUpb_DecodeStatus_OutOfMemory => return std.mem.Allocator.Error.OutOfMemory,
        c.kUpb_DecodeStatus_BadUtf8 => return ParseError.BadUtf8,
        c.kUpb_DecodeStatus_MaxDepthExceeded => return ParseError.MaxDepthExceeded,
        c.kUpb_DecodeStatus_MissingRequired => return ParseError.MissingRequired,
        c.kUpb_DecodeStatus_UnlinkedSubMessage => return ParseError.UnlinkedSubMessage,
        else => return ParseError.Unknown,
    }
}
/// Parse wire-format `data` into a freshly allocated message of `UpbType`
/// with default options. The message is owned by `arena`.
pub fn parse(comptime UpbType: type, arena: *c.upb_Arena, data: []const u8) ParseError!*UpbType {
    // Bug fix: `UpbType` was not forwarded — the old call was
    // `parseEx(arena, data, .{})`, but `parseEx` takes
    // `(comptime type, arena, data, opts)`, so instantiating `parse`
    // could never compile.
    return parseEx(UpbType, arena, data, .{});
}
/// Allocate a fresh message of `UpbType` on `arena` via the generated
/// `<Name>_new` constructor. Fails only when the arena is exhausted.
pub fn new(comptime UpbType: type, arena: *c.upb_Arena) std.mem.Allocator.Error!*UpbType {
    const ctor = @field(c, ProtoName(UpbType) ++ "_new");
    const raw = ctor(arena) orelse return std.mem.Allocator.Error.OutOfMemory;
    return @ptrCast(raw);
}
/// Adapter exposing a `std.mem.Allocator` through upb's `upb_alloc` vtable,
/// so a upb arena can draw its memory from a Zig allocator.
///
/// The `upb_alloc` field must remain embedded in this struct: `alloc_impl`
/// recovers the adapter from the vtable pointer via `@fieldParentPtr`, so a
/// `Allocator` value must stay alive and pinned while upb holds `inner()`.
pub const Allocator = struct {
    upb_alloc: c.upb_alloc,
    allocator: std.mem.Allocator,

    pub fn init(allocator: std.mem.Allocator) Allocator {
        return .{
            .upb_alloc = .{
                .func = &alloc_impl,
            },
            .allocator = allocator,
        };
    }

    /// Pointer to hand to upb APIs (e.g. `upb_Arena_Init`). Aliases `self`.
    pub fn inner(self: *Allocator) *c.upb_alloc {
        return &self.upb_alloc;
    }

    /// upb allocation callback: dispatches malloc / realloc / free onto the
    /// wrapped `std.mem.Allocator` based on `ptr`/`oldsize`/`size`.
    fn alloc_impl(alloc: [*c]c.upb_alloc, ptr: ?*anyopaque, oldsize: usize, size: usize, actual_size: [*c]usize) callconv(.c) ?*anyopaque {
        const upb_alloc: *c.upb_alloc = alloc orelse return null;
        const self: *Allocator = @fieldParentPtr("upb_alloc", upb_alloc);
        // Report the granted size back to upb; we always grant exactly `size`.
        defer {
            if (actual_size) |as| {
                as.* = size;
            }
        }
        if (ptr) |ptr_| {
            const old_mem: []u8 = @as([*c]u8, @ptrCast(ptr_))[0..oldsize];
            if (size == 0) {
                // Free request.
                self.allocator.free(old_mem);
                return null;
            }
            if (size == oldsize) {
                // Bug fix: a same-size realloc previously fell through to
                // `@panic("Unsupported case")`. It is a valid request and the
                // existing block already satisfies it, so return it as-is.
                return ptr_;
            }
            // Grow or shrink in place when possible, else move.
            return (self.allocator.realloc(old_mem, size) catch return null).ptr;
        }
        // Fresh allocation (ptr == null).
        return (self.allocator.alloc(u8, size) catch return null).ptr;
    }
};

View File

@ -1,9 +1,19 @@
load("@com_google_protobuf//bazel:upb_proto_library.bzl", "upb_c_proto_library")
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_zig//zig:defs.bzl", "zig_library")
load("//bazel:zig.bzl", "zig_cc_test")
load("//bazel:zig_proto_library.bzl", "zig_proto_library")
load("//bazel:zig_srcs.bzl", "zig_srcs")
upb_c_proto_library(
name = "xla_data_upb",
deps = ["@xla//xla:xla_data_proto"],
)
upb_c_proto_library(
name = "xla_compile_options_upb",
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
)
cc_library(
name = "posix",
hdrs = ["posix.h"],
@ -25,26 +35,21 @@ zig_library(
visibility = ["//visibility:public"],
deps = [
":posix",
":xla_proto",
":xla_compile_options_upb",
":xla_data_upb",
"//async",
"//mlir",
"//mlir/dialects",
"//pjrt",
"//runtimes",
"//stdx",
"//upb",
"//zml/tokenizer",
"//zml/tools",
"@rules_zig//zig/runfiles",
],
)
zig_proto_library(
name = "xla_proto",
import_name = "//xla:xla_proto",
deps = ["@xla//xla/pjrt/proto:compile_options_proto"],
)
# All ZML Tests
zig_cc_test(

View File

@ -1,10 +1,11 @@
const std = @import("std");
const asynk = @import("async");
const c = @import("c");
const dialect = @import("mlir/dialects");
const mlir = @import("mlir");
const stdx = @import("stdx");
const xla_pb = @import("//xla:xla_proto");
const upb = @import("upb");
const BaseExe = @import("exe.zig").BaseExe;
const Buffer = @import("buffer.zig").Buffer;
@ -834,6 +835,28 @@ fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecut
try loaded_executable_file.writeAll(serialize_result.bytes);
}
// Insert `flag` -> OptionOverrideProto(value) into the compile-options
// env_option_overrides upb map. The override message is allocated on
// `upb_arena`; `value` may be a bool, integer, or float — anything else is
// handed to `upb.stringView` and stored as a string field.
fn setXlaOverrideFlag(map: *c.upb_Map, flag: []const u8, value: anytype, upb_arena: *c.upb_Arena) !void {
    const result = c.upb_Map_Set(
        map,
        .{ .str_val = upb.stringView(flag) },
        .{ .msg_val = blk: {
            const field = try upb.new(c.xla_OptionOverrideProto, upb_arena);
            // Pick the oneof setter matching the comptime type of `value`.
            switch (@typeInfo(@TypeOf(value))) {
                .bool => c.xla_OptionOverrideProto_set_bool_field(field, value),
                .comptime_int, .int => c.xla_OptionOverrideProto_set_int_field(field, @intCast(value)),
                .comptime_float, .float => c.xla_OptionOverrideProto_set_double_field(field, @floatCast(value)),
                else => c.xla_OptionOverrideProto_set_string_field(field, upb.stringView(value)),
            }
            break :blk @ptrCast(field);
        } },
        upb_arena,
    );
    // A false return from upb_Map_Set is treated here as arena exhaustion.
    if (result == false) {
        return std.mem.Allocator.Error.OutOfMemory;
    }
}
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
const tracer = Tracer.init("ai.zml.compilation");
const compile_frame = tracer.frameStart("pjrt compilation");
@ -841,86 +864,83 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
const sharding = platform.sharding();
var options: xla_pb.CompileOptionsProto = .{
.executable_build_options = .{
.device_ordinal = -1,
.num_replicas = sharding.num_replicas,
.num_partitions = sharding.num_partitions,
.use_spmd_partitioning = sharding.num_partitions > 1 or sharding.num_replicas > 1,
.device_assignment = .{
.replica_count = sharding.num_replicas,
.computation_count = sharding.num_partitions,
// Filled below.
.computation_devices = .{},
var upb_alloc: upb.Allocator = .init(arena);
const upb_arena = c.upb_Arena_Init(null, 0, upb_alloc.inner());
defer c.upb_Arena_Free(upb_arena);
const options = blk: {
const options = try upb.new(c.xla_CompileOptionsProto, upb_arena);
c.xla_CompileOptionsProto_set_executable_build_options(options, executable_build_options_blk: {
const exec_build_options = try upb.new(c.xla_ExecutableBuildOptionsProto, upb_arena);
c.xla_ExecutableBuildOptionsProto_set_device_ordinal(exec_build_options, -1);
c.xla_ExecutableBuildOptionsProto_set_num_replicas(exec_build_options, sharding.num_replicas);
c.xla_ExecutableBuildOptionsProto_set_num_partitions(exec_build_options, sharding.num_partitions);
c.xla_ExecutableBuildOptionsProto_set_use_spmd_partitioning(exec_build_options, sharding.num_partitions > 1 or sharding.num_replicas > 1);
c.xla_ExecutableBuildOptionsProto_set_device_assignment(exec_build_options, device_assignment_blk: {
const device_assignment = try upb.new(c.xla_DeviceAssignmentProto, upb_arena);
c.xla_DeviceAssignmentProto_set_replica_count(device_assignment, sharding.num_replicas);
c.xla_DeviceAssignmentProto_set_computation_count(device_assignment, sharding.num_partitions);
const computation_devices = c.xla_DeviceAssignmentProto_resize_computation_devices(device_assignment, sharding.num_partitions, upb_arena);
for (computation_devices[0..sharding.num_partitions], 0..) |*computation_device, i| {
computation_device.* = try upb.new(c.xla_DeviceAssignmentProto_ComputationDevice, upb_arena);
_ = c.xla_DeviceAssignmentProto_ComputationDevice_add_replica_device_ids(computation_device.*, @intCast(i), upb_arena);
}
break :device_assignment_blk device_assignment;
});
break :executable_build_options_blk exec_build_options;
});
const overrides_map = c._xla_CompileOptionsProto_env_option_overrides_mutable_upb_map(options, upb_arena);
switch (platform.target) {
.cuda => {
// NVIDIA recommends these settings
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_latency_hiding_scheduler", true, upb_arena);
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_llvm_module_compilation_parallelism", true, upb_arena);
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_libnvptxcompiler", true, upb_arena);
},
},
.rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
// enabled on CDNA and it's used on RDNA. Disable it altogether.
try setXlaOverrideFlag(overrides_map, "xla_gpu_enable_triton_gemm", false, upb_arena);
// Use lld from libllvm instead of invoking the ld.lld binary.
// This saves us from having to sandbox it.
try setXlaOverrideFlag(overrides_map, "xla_gpu_use_inprocess_lld", true, upb_arena);
},
else => {},
}
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
try setXlaOverrideFlag(overrides_map, "xla_dump_to", xla_dump_to, upb_arena);
try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_as_proto", true, upb_arena);
if (platform.compilation_options.xla_dump_fusion_visualization) {
try setXlaOverrideFlag(overrides_map, "xla_dump_fusion_visualization", true, upb_arena);
}
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
try setXlaOverrideFlag(overrides_map, "xla_dump_hlo_pass_re", re, upb_arena);
}
}
break :blk options;
};
const computation_devices = &options.executable_build_options.?.device_assignment.?.computation_devices;
try computation_devices.ensureTotalCapacity(arena, sharding.num_partitions);
const replica_device_ids = try arena.alloc(i64, sharding.num_partitions);
for (0..sharding.num_partitions) |i| {
replica_device_ids[i] = @intCast(i);
computation_devices.appendAssumeCapacity(.{ .replica_device_ids = .fromOwnedSlice(replica_device_ids[i .. i + 1]) });
}
// Let the arena deinit, zig-protobuf deinit is very slow.
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
setFlag(&options, "xla_dump_to", xla_dump_to);
setFlag(&options, "xla_dump_hlo_as_proto", true);
if (platform.compilation_options.xla_dump_fusion_visualization) {
setFlag(&options, "xla_dump_fusion_visualization", true);
}
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
setFlag(&options, "xla_dump_hlo_pass_re", re);
}
}
switch (platform.target) {
.cuda => {
// NVIDIA recommends these settings
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
setFlag(&options, "xla_gpu_enable_llvm_module_compilation_parallelism", true);
setFlag(&options, "xla_gpu_enable_libnvptxcompiler", true);
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
// setFlags(&options, "xla_gpu_enable_address_computation_fusion", true);
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_enable_while_loop_double_buffering", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
},
.rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
// enabled on CDNA and it's used on RDNA. Disable it altogether.
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
// Use lld from libllvm instead of invoking the ld.lld binary.
// This saves us from having to sandbox it.
setFlag(&options, "xla_gpu_use_inprocess_lld", true);
},
else => {},
}
const options_bytes = try options.encode(arena);
const loaded_executable = try platform.pjrt_client.compile(platform.pjrt_api, arena, module, options_bytes);
const loaded_executable = try platform.pjrt_client.compile(
platform.pjrt_api,
arena,
module,
try upb.serialize(options, upb_arena),
);
errdefer loaded_executable.deinit();
return loaded_executable;
}
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
.bool => .{ .value = .{ .bool_field = value } },
.comptime_int, .int => .{ .value = .{ .int_field = value } },
.comptime_float, .float => .{ .value = .{ .double_field = value } },
else => .{ .value = .{ .string_field = .{ .Const = value } } },
};
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
}
/// Visit the given struct and recursively counts the number of tensors found.
pub fn countTensors(v: anytype) usize {
const LocalContext = struct {