Update PJRT, runtime, and ZML modules to use per‑target output folders and expose profiler.dumpDataAsJson for JSON profiling output.
This commit is contained in:
parent
22a846de72
commit
37725cdaa6
@ -10,11 +10,13 @@ cc_library(
|
|||||||
|
|
||||||
zig_library(
|
zig_library(
|
||||||
name = "pjrt",
|
name = "pjrt",
|
||||||
srcs = ["profiler.zig"],
|
srcs = ["profiler.zig"] + glob(["convert/*.zig"]),
|
||||||
main = "pjrt.zig",
|
main = "pjrt.zig",
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
":profiler_options_proto",
|
":profiler_options_proto",
|
||||||
|
":trace_events_proto",
|
||||||
|
":xplane_proto",
|
||||||
"//stdx",
|
"//stdx",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
|
||||||
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
|
||||||
|
|||||||
@ -76,7 +76,7 @@ pub const TraceContainer = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: std.io.AnyWriter) !void {
|
fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: anytype) !void {
|
||||||
if (stat.value) |val| {
|
if (stat.value) |val| {
|
||||||
switch (val) {
|
switch (val) {
|
||||||
inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
|
inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
|
||||||
@ -231,7 +231,7 @@ pub const TraceContainer = struct {
|
|||||||
self.events.shrinkRetainingCapacity(max_count);
|
self.events.shrinkRetainingCapacity(max_count);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn toJson(self: *TraceContainer, writer: std.io.AnyWriter) !void {
|
pub fn toJson(self: *TraceContainer, writer: anytype) !void {
|
||||||
try writer.writeAll(
|
try writer.writeAll(
|
||||||
\\{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},"traceEvents":[
|
\\{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},"traceEvents":[
|
||||||
);
|
);
|
||||||
|
|||||||
@ -359,7 +359,7 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
log.warn("No profiler found for platform: {}", .{self});
|
log.warn("No profiler found for platform: {}", .{self});
|
||||||
return Profiler.init(null, options);
|
return Profiler.init(null, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {
|
||||||
|
|||||||
@ -2,7 +2,8 @@ const std = @import("std");
|
|||||||
const c = @import("c");
|
const c = @import("c");
|
||||||
const tsl_proto = @import("//tsl:profiler_options_proto");
|
const tsl_proto = @import("//tsl:profiler_options_proto");
|
||||||
|
|
||||||
const log = std.log.scoped(.@"zml/profiler");
|
const log = std.log.scoped(.@"pjrt/profiler");
|
||||||
|
const TraceContainer = @import("convert/trace_container.zig").TraceContainer;
|
||||||
|
|
||||||
/// Pjrt Profiler extension
|
/// Pjrt Profiler extension
|
||||||
pub const Profiler = struct {
|
pub const Profiler = struct {
|
||||||
@ -15,20 +16,35 @@ pub const Profiler = struct {
|
|||||||
pub const Error = c.PLUGIN_Profiler_Error;
|
pub const Error = c.PLUGIN_Profiler_Error;
|
||||||
pub const Options = tsl_proto.ProfileOptions;
|
pub const Options = tsl_proto.ProfileOptions;
|
||||||
|
|
||||||
pub fn init(api: ?c.PLUGIN_Profiler_Api, options: Options) Profiler {
|
pub const default_options: Options = .{
|
||||||
|
.version = 1,
|
||||||
|
.device_type = .UNSPECIFIED, // profile all devices
|
||||||
|
.include_dataset_ops = false, // tensorflow specific
|
||||||
|
.host_tracer_level = 2,
|
||||||
|
.device_tracer_level = 1,
|
||||||
|
.python_tracer_level = 0,
|
||||||
|
.enable_hlo_proto = true,
|
||||||
|
.start_timestamp_ns = 0,
|
||||||
|
.duration_ms = 0,
|
||||||
|
.repository_path = .Empty,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub fn init(api: ?c.PLUGIN_Profiler_Api, options: ?Options) Profiler {
|
||||||
if (api == null) {
|
if (api == null) {
|
||||||
return .{ .api = null, .inner = undefined };
|
return .{ .api = null, .inner = undefined };
|
||||||
}
|
}
|
||||||
|
var options_with_timestamp = options orelse default_options;
|
||||||
|
options_with_timestamp.start_timestamp_ns = @truncate(@max(0, std.time.nanoTimestamp()));
|
||||||
|
|
||||||
var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
|
var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
|
||||||
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
var fba = std.heap.FixedBufferAllocator.init(&buffer);
|
||||||
const byte_options = options.encode(fba.allocator()) catch unreachable;
|
const byte_options = options_with_timestamp.encode(fba.allocator()) catch unreachable;
|
||||||
var res: Profiler = .{ .api = api, .inner = undefined };
|
|
||||||
var args: c.PLUGIN_Profiler_Create_Args = .{
|
var args: c.PLUGIN_Profiler_Create_Args = .{
|
||||||
.options = byte_options.ptr,
|
.options = byte_options.ptr,
|
||||||
.options_size = byte_options.len,
|
.options_size = byte_options.len,
|
||||||
.profiler = undefined, // out
|
.profiler = undefined, // out
|
||||||
};
|
};
|
||||||
|
var res: Profiler = .{ .api = api, .inner = undefined };
|
||||||
res.check(api.?.create.?(&args)) catch unreachable;
|
res.check(api.?.create.?(&args)) catch unreachable;
|
||||||
|
|
||||||
res.inner = args.profiler.?;
|
res.inner = args.profiler.?;
|
||||||
@ -70,15 +86,15 @@ pub const Profiler = struct {
|
|||||||
};
|
};
|
||||||
try self.check(self.api.?.collect_data.?(&args));
|
try self.check(self.api.?.collect_data.?(&args));
|
||||||
std.debug.assert(args.buffer_size_in_bytes > 0);
|
std.debug.assert(args.buffer_size_in_bytes > 0);
|
||||||
const buffer: ProfilingData = if (args.buffer == null) blk: {
|
return if (args.buffer == null) blk: {
|
||||||
std.log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
|
log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
|
||||||
// The plugin want us to allocate memory for it:
|
// The plugin want us to allocate memory for it:
|
||||||
const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
|
const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
|
||||||
args.buffer = buffer.ptr;
|
args.buffer = buffer.ptr;
|
||||||
try self.check(self.api.?.collect_data.?(&args));
|
try self.check(self.api.?.collect_data.?(&args));
|
||||||
break :blk .{ .owned = buffer };
|
break :blk .{ .owned = buffer };
|
||||||
} else blk: {
|
} else blk: {
|
||||||
std.log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
|
log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
|
||||||
// Drop sentinel. The profiler plugin returns a null terminated string.
|
// Drop sentinel. The profiler plugin returns a null terminated string.
|
||||||
// But this is creating issues if we save the sentinel on disk,
|
// But this is creating issues if we save the sentinel on disk,
|
||||||
// because it will trip up protobuf readers.
|
// because it will trip up protobuf readers.
|
||||||
@ -86,9 +102,6 @@ pub const Profiler = struct {
|
|||||||
data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data;
|
data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data;
|
||||||
break :blk .{ .external = data };
|
break :blk .{ .external = data };
|
||||||
};
|
};
|
||||||
|
|
||||||
// printDataAsXSpace(allocator, buffer.items());
|
|
||||||
return buffer;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dumpDataTo(
|
pub fn dumpDataTo(
|
||||||
@ -108,6 +121,31 @@ pub const Profiler = struct {
|
|||||||
return try file.writeAll(profile_data.items());
|
return try file.writeAll(profile_data.items());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dumpAsJsonTo(
|
||||||
|
self: *Profiler,
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
dir: std.fs.Dir,
|
||||||
|
file_name: []const u8,
|
||||||
|
) !void {
|
||||||
|
const profile_data = try self.collectData(allocator);
|
||||||
|
defer profile_data.free(allocator);
|
||||||
|
|
||||||
|
if (profile_data.items().len == 0) {
|
||||||
|
log.warn("No profile data was collected: {}", .{self});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
var converter = try TraceContainer.init(allocator, profile_data.items(), null);
|
||||||
|
defer converter.deinit();
|
||||||
|
|
||||||
|
var output_file = try dir.createFile(file_name, .{});
|
||||||
|
defer output_file.close();
|
||||||
|
var buffered_writer = std.io.bufferedWriter(output_file.writer());
|
||||||
|
log.info("Writing profiling data to {s}", .{file_name});
|
||||||
|
try converter.toJson(buffered_writer.writer());
|
||||||
|
try buffered_writer.flush();
|
||||||
|
}
|
||||||
|
|
||||||
fn check(self: *Profiler, c_error: ?*Error) !void {
|
fn check(self: *Profiler, c_error: ?*Error) !void {
|
||||||
if (c_error) |err| {
|
if (c_error) |err| {
|
||||||
self.last_error = err;
|
self.last_error = err;
|
||||||
|
|||||||
@ -103,6 +103,8 @@ fn comptimeStrJoin(comptime separator: [:0]const u8, comptime slices: []const [:
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn setNeuronCCFlags() void {
|
pub fn setNeuronCCFlags() void {
|
||||||
|
// See neuronxcc reference:
|
||||||
|
// https://awsdocs-neuron.readthedocs-hosted.com/en/latest/compiler/neuronx-cc/api-reference-guide/neuron-compiler-cli-reference-guide.html#neuron-compiler-cli-reference-guide
|
||||||
_ = c.setenv("NEURON_CC_FLAGS", comptimeStrJoin(" ", &.{
|
_ = c.setenv("NEURON_CC_FLAGS", comptimeStrJoin(" ", &.{
|
||||||
// 30% faster, no visible speed difference on llama
|
// 30% faster, no visible speed difference on llama
|
||||||
"--optlevel=1",
|
"--optlevel=1",
|
||||||
|
|||||||
140
zml/module.zig
140
zml/module.zig
@ -203,35 +203,39 @@ pub const CompilationContext = struct {
|
|||||||
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
|
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
|
||||||
|
|
||||||
const module_hash = computeModuleHash(self._platform, module);
|
const module_hash = computeModuleHash(self._platform, module);
|
||||||
|
var module_dir: ?[]const u8 = null;
|
||||||
|
var pjrt_location: ?[:0]const u8 = null;
|
||||||
|
|
||||||
if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||||
|
const sep = std.fs.path.sep_str;
|
||||||
|
const module_dir_name = try std.fmt.allocPrint(arena, "{s}{s}{s}{s}{s}_{x}", .{ xla_dump_to, sep, @tagName(self._platform.target), sep, self._name, module_hash });
|
||||||
|
try std.fs.cwd().makePath(module_dir_name);
|
||||||
|
module_dir = try std.fs.cwd().realpathAlloc(arena, module_dir_name);
|
||||||
|
const cache_dir = try std.fs.cwd().openDir(module_dir.?, .{});
|
||||||
|
|
||||||
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
|
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
|
||||||
if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| {
|
const mlir_name = "module.mlir";
|
||||||
const name = self._name;
|
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
|
||||||
const file_name = std.fmt.allocPrint(arena, "{s}_{x}.mlir", .{ name, module_hash }) catch name;
|
|
||||||
if (dir.createFile(file_name, .{ .truncate = true })) |file| {
|
|
||||||
module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
|
module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
|
||||||
log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name });
|
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
|
||||||
} else |_| {
|
} else |_| {
|
||||||
log.warn("Failed to open {s}", .{file_name});
|
log.warn("Failed to open {s}", .{mlir_name});
|
||||||
}
|
|
||||||
} else |_| {
|
|
||||||
log.warn("Folder not found {s}", .{xla_dump_to});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const tracer = Tracer.init("ai.zml.compilation");
|
pjrt_location = try std.fs.path.joinZ(arena, &.{ module_dir.?, "module.pjrt" });
|
||||||
const compile_frame = tracer.frameStart("pjrt cached compilation");
|
}
|
||||||
defer tracer.frameEnd(compile_frame, "pjrt cached compilation");
|
|
||||||
|
|
||||||
const loaded_executable: *pjrt.LoadedExecutable = blk: {
|
const loaded_executable: *pjrt.LoadedExecutable = blk: {
|
||||||
const cache_location = try absoluteCacheFileZ(arena, self._platform.compilation_options.cache_location, module_hash);
|
if (pjrt_location) |pjrt_loc| {
|
||||||
if (cache_location) |cache_file| {
|
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
|
||||||
if (loadPjrtExecutable(arena, self._platform, cache_file)) |exe| {
|
log.info("Loaded pre-compiled module from {s}", .{pjrt_loc});
|
||||||
break :blk exe;
|
break :blk exe;
|
||||||
} else |_| {}
|
} else |err| {
|
||||||
|
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module) catch |err| {
|
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir.?) catch |err| {
|
||||||
log.err(
|
log.err(
|
||||||
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
||||||
.{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
|
.{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
|
||||||
@ -239,9 +243,9 @@ pub const CompilationContext = struct {
|
|||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (cache_location) |cache_file| {
|
if (pjrt_location) |pjrt_loc| {
|
||||||
storePjrtExecutable(self._platform, loaded_executable, cache_file) catch |err| {
|
storePjrtExecutable(self._platform, loaded_executable, pjrt_loc) catch |err| {
|
||||||
log.debug("Failed to store module: {}", .{err});
|
log.warn("Failed to store compiled module: {} at {s}", .{ err, pjrt_loc });
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
break :blk loaded_executable;
|
break :blk loaded_executable;
|
||||||
@ -813,22 +817,13 @@ fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
|
|||||||
return hasher.final();
|
return hasher.final();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn absoluteCacheFileZ(arena: std.mem.Allocator, cache_path: ?[]const u8, module_hash: u64) !?[:0]const u8 {
|
|
||||||
if (cache_path == null) return null;
|
|
||||||
const resolved_path = try std.fs.cwd().realpathAlloc(arena, cache_path.?);
|
|
||||||
std.fs.makeDirAbsolute(resolved_path) catch |err| switch (err) {
|
|
||||||
error.PathAlreadyExists => {},
|
|
||||||
else => return err,
|
|
||||||
};
|
|
||||||
|
|
||||||
var buf: [24]u8 = undefined;
|
|
||||||
const module_name = std.fmt.bufPrint(&buf, "{x}.pjrt", .{module_hash}) catch unreachable;
|
|
||||||
return try std.fs.path.joinZ(arena, &.{ resolved_path, module_name });
|
|
||||||
}
|
|
||||||
|
|
||||||
const max_pjrt_executable_size = 400 * 1024 * 1024;
|
const max_pjrt_executable_size = 400 * 1024 * 1024;
|
||||||
|
|
||||||
fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable {
|
fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable {
|
||||||
|
const tracer = Tracer.init("ai.zml.load_exe");
|
||||||
|
const compile_frame = tracer.frameStart("pjrt load executable");
|
||||||
|
defer tracer.frameEnd(compile_frame, "pjrt load executable");
|
||||||
|
|
||||||
const loaded_executable_file = try std.fs.openFileAbsoluteZ(absolute_file, .{});
|
const loaded_executable_file = try std.fs.openFileAbsoluteZ(absolute_file, .{});
|
||||||
defer loaded_executable_file.close();
|
defer loaded_executable_file.close();
|
||||||
|
|
||||||
@ -837,12 +832,7 @@ fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_fil
|
|||||||
defer arena.free(bytes);
|
defer arena.free(bytes);
|
||||||
|
|
||||||
const size = try loaded_executable_file.readAll(bytes);
|
const size = try loaded_executable_file.readAll(bytes);
|
||||||
|
return try platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]);
|
||||||
log.info("Loading module from {s}", .{absolute_file});
|
|
||||||
return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]) catch |err| {
|
|
||||||
log.warn("Failed to load module: {}", .{err});
|
|
||||||
return err;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void {
|
fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void {
|
||||||
@ -856,10 +846,13 @@ fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecut
|
|||||||
defer serialize_result.deinit();
|
defer serialize_result.deinit();
|
||||||
|
|
||||||
try loaded_executable_file.writeAll(serialize_result.bytes);
|
try loaded_executable_file.writeAll(serialize_result.bytes);
|
||||||
log.info("Stored module to {s}", .{absolute_file});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module) !*pjrt.LoadedExecutable {
|
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
|
||||||
|
const tracer = Tracer.init("ai.zml.compilation");
|
||||||
|
const compile_frame = tracer.frameStart("pjrt compilation");
|
||||||
|
defer tracer.frameEnd(compile_frame, "pjrt compilation");
|
||||||
|
|
||||||
const sharding = platform.sharding();
|
const sharding = platform.sharding();
|
||||||
|
|
||||||
// NOTE(Corendos): Hack needed because Protobuf struct are not public.
|
// NOTE(Corendos): Hack needed because Protobuf struct are not public.
|
||||||
@ -887,38 +880,27 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Let the arena deinit, zig-protobuf deinit is very slow.
|
// Let the arena deinit, zig-protobuf deinit is very slow.
|
||||||
if (platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
|
||||||
try options.env_option_overrides.append(arena, .{
|
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||||
.key = .{ .Const = "xla_dump_to" },
|
setFlag(&options, "xla_dump_to", xla_dump_to);
|
||||||
.value = .{ .value = .{ .string_field = .{ .Const = xla_dump_to } } },
|
|
||||||
});
|
|
||||||
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
||||||
try options.env_option_overrides.append(arena, .{
|
setFlag(&options, "xla_dump_hlo_as_html", true);
|
||||||
.key = .{ .Const = "xla_dump_hlo_as_html" },
|
setFlag(&options, "xla_dump_hlo_as_dot", true);
|
||||||
.value = .{ .value = .{ .bool_field = true } },
|
setFlag(&options, "xla_dump_fusion_visualization", true);
|
||||||
});
|
|
||||||
try options.env_option_overrides.append(arena, .{
|
|
||||||
.key = .{ .Const = "xla_dump_hlo_as_dot" },
|
|
||||||
.value = .{ .value = .{ .bool_field = true } },
|
|
||||||
});
|
|
||||||
try options.env_option_overrides.append(arena, .{
|
|
||||||
.key = .{ .Const = "xla_dump_fusion_visualization" },
|
|
||||||
.value = .{ .value = .{ .bool_field = true } },
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
switch (platform.target) {
|
switch (platform.target) {
|
||||||
.cuda => cuda_dir: {
|
.cuda => cuda_dir: {
|
||||||
// NVIDIA recommends to disable Triton GEMM on JAX:
|
// NVIDIA recommends to disable Triton GEMM on JAX:
|
||||||
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
|
||||||
try options.env_option_overrides.append(arena, .{
|
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||||
.key = .{ .Const = "xla_gpu_enable_triton_gemm" },
|
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
|
||||||
.value = .{ .value = .{ .bool_field = false } },
|
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
|
||||||
});
|
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
|
||||||
// try options.env_option_overrides.append(arena, .{
|
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
|
||||||
// .key = .{ .Const = "xla_gpu_enable_latency_hiding_scheduler" },
|
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
|
||||||
// .value = .{ .value = .{ .bool_field = true } },
|
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
|
||||||
// });
|
// setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
|
||||||
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
|
||||||
log.warn("Bazel runfile not found !", .{});
|
log.warn("Bazel runfile not found !", .{});
|
||||||
break :cuda_dir;
|
break :cuda_dir;
|
||||||
@ -928,22 +910,12 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
const r = r_.withSourceRepo(source_repo);
|
const r = r_.withSourceRepo(source_repo);
|
||||||
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
|
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
|
||||||
log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
|
log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
|
||||||
try options.env_option_overrides.append(arena, .{
|
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
|
||||||
.key = .{ .Const = "xla_gpu_cuda_data_dir" },
|
|
||||||
.value = .{
|
|
||||||
.value = .{
|
|
||||||
.string_field = .{ .Const = cuda_data_dir },
|
|
||||||
},
|
|
||||||
},
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
.rocm => {
|
.rocm => {
|
||||||
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
|
||||||
// enabled on CDNA and it's used on RDNA. Disable it altogether.
|
// enabled on CDNA and it's used on RDNA. Disable it altogether.
|
||||||
try options.env_option_overrides.append(arena, .{
|
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
|
||||||
.key = .{ .Const = "xla_gpu_enable_triton_gemm" },
|
|
||||||
.value = .{ .value = .{ .bool_field = false } },
|
|
||||||
});
|
|
||||||
},
|
},
|
||||||
else => {},
|
else => {},
|
||||||
}
|
}
|
||||||
@ -956,6 +928,16 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
return loaded_executable;
|
return loaded_executable;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
||||||
|
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
||||||
|
.Bool => .{ .value = .{ .bool_field = value } },
|
||||||
|
.Int => .{ .value = .{ .int_field = value } },
|
||||||
|
.Float => .{ .value = .{ .double_field = value } },
|
||||||
|
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
||||||
|
};
|
||||||
|
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
||||||
|
}
|
||||||
|
|
||||||
/// Visit the given struct and recursively counts the number of tensors found.
|
/// Visit the given struct and recursively counts the number of tensors found.
|
||||||
pub fn countTensors(v: anytype) usize {
|
pub fn countTensors(v: anytype) usize {
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
|
|||||||
@ -17,7 +17,6 @@ pub const available_targets = std.enums.values(Target);
|
|||||||
pub const CompilationOptions = struct {
|
pub const CompilationOptions = struct {
|
||||||
xla_dump_to: ?[]const u8 = null,
|
xla_dump_to: ?[]const u8 = null,
|
||||||
xla_dump_fusion_visualization: bool = false,
|
xla_dump_fusion_visualization: bool = false,
|
||||||
cache_location: ?[]const u8 = null,
|
|
||||||
sharding_enabled: bool = false,
|
sharding_enabled: bool = false,
|
||||||
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
|
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
|
||||||
};
|
};
|
||||||
@ -82,8 +81,8 @@ pub const Platform = struct {
|
|||||||
/// Returns the Profiler for this API.
|
/// Returns the Profiler for this API.
|
||||||
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
/// Not all platform have a profiling api, for those the profiler object will do nothing.
|
||||||
/// Platforms with known profiler extensions: cuda, xpu
|
/// Platforms with known profiler extensions: cuda, xpu
|
||||||
pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
|
pub fn getProfiler(self: Platform, options: ?pjrt.Profiler.Options) pjrt.Profiler {
|
||||||
return self.pjrt_client.getProfiler(self.pjrt_api, options);
|
return self.pjrt_client.getProfiler(self.pjrt_api, options orelse pjrt.Profiler.default_options);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -13,14 +13,10 @@ var _platform: ?zml.Platform = null;
|
|||||||
pub fn env() zml.Platform {
|
pub fn env() zml.Platform {
|
||||||
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
|
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
|
||||||
if (_platform == null) {
|
if (_platform == null) {
|
||||||
_test_compile_opts = if (initCacheDir())
|
_test_compile_opts = .{
|
||||||
.{
|
|
||||||
.cache_location = "/tmp/zml/tests/cache",
|
|
||||||
.xla_dump_to = "/tmp/zml/tests/",
|
.xla_dump_to = "/tmp/zml/tests/",
|
||||||
.sharding_enabled = true,
|
.sharding_enabled = true,
|
||||||
}
|
};
|
||||||
else
|
|
||||||
.{};
|
|
||||||
|
|
||||||
var ctx = zml.Context.init() catch unreachable;
|
var ctx = zml.Context.init() catch unreachable;
|
||||||
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
|
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
|
||||||
@ -31,12 +27,6 @@ pub fn env() zml.Platform {
|
|||||||
|
|
||||||
var _test_compile_opts: zml.CompilationOptions = .{};
|
var _test_compile_opts: zml.CompilationOptions = .{};
|
||||||
|
|
||||||
fn initCacheDir() bool {
|
|
||||||
const tmp = std.fs.openDirAbsolute("/tmp", .{}) catch return false;
|
|
||||||
tmp.makePath("zml/tests/cache") catch return false;
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// In neural network we generally care about the relative precision,
|
/// In neural network we generally care about the relative precision,
|
||||||
/// but on a given dimension, if the output is close to 0, then the precision
|
/// but on a given dimension, if the output is close to 0, then the precision
|
||||||
/// don't matter as much.
|
/// don't matter as much.
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user