Update PJRT, runtime, and ZML modules to use per-target output folders and expose profiler.dumpAsJsonTo for JSON profiling output.

This commit is contained in:
Tarry Singh 2023-12-04 10:38:10 +00:00
parent 22a846de72
commit 37725cdaa6
8 changed files with 124 additions and 111 deletions

View File

@ -10,11 +10,13 @@ cc_library(
zig_library(
name = "pjrt",
srcs = ["profiler.zig"],
srcs = ["profiler.zig"] + glob(["convert/*.zig"]),
main = "pjrt.zig",
visibility = ["//visibility:public"],
deps = [
":profiler_options_proto",
":trace_events_proto",
":xplane_proto",
"//stdx",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",

View File

@ -76,7 +76,7 @@ pub const TraceContainer = struct {
};
}
fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: std.io.AnyWriter) !void {
fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: anytype) !void {
if (stat.value) |val| {
switch (val) {
inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}),
@ -231,7 +231,7 @@ pub const TraceContainer = struct {
self.events.shrinkRetainingCapacity(max_count);
}
pub fn toJson(self: *TraceContainer, writer: std.io.AnyWriter) !void {
pub fn toJson(self: *TraceContainer, writer: anytype) !void {
try writer.writeAll(
\\{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},"traceEvents":[
);

View File

@ -359,7 +359,7 @@ pub const Client = opaque {
}
}
log.warn("No profiler found for platform: {}", .{self});
return Profiler.init(null, options);
return Profiler.init(null, null);
}
pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable {

View File

@ -2,7 +2,8 @@ const std = @import("std");
const c = @import("c");
const tsl_proto = @import("//tsl:profiler_options_proto");
const log = std.log.scoped(.@"zml/profiler");
const log = std.log.scoped(.@"pjrt/profiler");
const TraceContainer = @import("convert/trace_container.zig").TraceContainer;
/// Pjrt Profiler extension
pub const Profiler = struct {
@ -15,20 +16,35 @@ pub const Profiler = struct {
pub const Error = c.PLUGIN_Profiler_Error;
pub const Options = tsl_proto.ProfileOptions;
pub fn init(api: ?c.PLUGIN_Profiler_Api, options: Options) Profiler {
/// Default profiling options, used by `init` when the caller passes `null`.
/// Field semantics come from tsl's ProfileOptions proto — confirm levels
/// against the tsl profiler docs for the pinned XLA version.
pub const default_options: Options = .{
.version = 1,
.device_type = .UNSPECIFIED, // profile all devices
.include_dataset_ops = false, // tensorflow specific
.host_tracer_level = 2,
.device_tracer_level = 1,
.python_tracer_level = 0, // no Python runtime here, keep disabled
.enable_hlo_proto = true,
.start_timestamp_ns = 0, // overwritten with the current time in `init`
.duration_ms = 0,
.repository_path = .Empty,
};
pub fn init(api: ?c.PLUGIN_Profiler_Api, options: ?Options) Profiler {
if (api == null) {
return .{ .api = null, .inner = undefined };
}
var options_with_timestamp = options orelse default_options;
options_with_timestamp.start_timestamp_ns = @truncate(@max(0, std.time.nanoTimestamp()));
var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
var fba = std.heap.FixedBufferAllocator.init(&buffer);
const byte_options = options.encode(fba.allocator()) catch unreachable;
var res: Profiler = .{ .api = api, .inner = undefined };
const byte_options = options_with_timestamp.encode(fba.allocator()) catch unreachable;
var args: c.PLUGIN_Profiler_Create_Args = .{
.options = byte_options.ptr,
.options_size = byte_options.len,
.profiler = undefined, // out
};
var res: Profiler = .{ .api = api, .inner = undefined };
res.check(api.?.create.?(&args)) catch unreachable;
res.inner = args.profiler.?;
@ -70,15 +86,15 @@ pub const Profiler = struct {
};
try self.check(self.api.?.collect_data.?(&args));
std.debug.assert(args.buffer_size_in_bytes > 0);
const buffer: ProfilingData = if (args.buffer == null) blk: {
std.log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
return if (args.buffer == null) blk: {
log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
// The plugin want us to allocate memory for it:
const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
args.buffer = buffer.ptr;
try self.check(self.api.?.collect_data.?(&args));
break :blk .{ .owned = buffer };
} else blk: {
std.log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
// Drop sentinel. The profiler plugin returns a null terminated string.
// But this is creating issues if we save the sentinel on disk,
// because it will trip up protobuf readers.
@ -86,9 +102,6 @@ pub const Profiler = struct {
data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data;
break :blk .{ .external = data };
};
// printDataAsXSpace(allocator, buffer.items());
return buffer;
}
pub fn dumpDataTo(
@ -108,6 +121,31 @@ pub const Profiler = struct {
return try file.writeAll(profile_data.items());
}
/// Collects the current profiling data and writes it into `dir` under
/// `file_name` as trace-event JSON ("traceEvents" format).
/// Logs a warning and returns without creating a file when the plugin
/// produced no data.
pub fn dumpAsJsonTo(
    self: *Profiler,
    allocator: std.mem.Allocator,
    dir: std.fs.Dir,
    file_name: []const u8,
) !void {
    // Pull the raw serialized profile out of the plugin.
    const profile_data = try self.collectData(allocator);
    defer profile_data.free(allocator);

    const raw_bytes = profile_data.items();
    if (raw_bytes.len == 0) {
        log.warn("No profile data was collected: {}", .{self});
        return;
    }

    // Convert the raw profile into the JSON trace representation.
    var container = try TraceContainer.init(allocator, raw_bytes, null);
    defer container.deinit();

    var out_file = try dir.createFile(file_name, .{});
    defer out_file.close();

    // Buffer writes: toJson emits many small prints.
    var out = std.io.bufferedWriter(out_file.writer());
    log.info("Writing profiling data to {s}", .{file_name});
    try container.toJson(out.writer());
    try out.flush();
}
fn check(self: *Profiler, c_error: ?*Error) !void {
if (c_error) |err| {
self.last_error = err;

View File

@ -103,6 +103,8 @@ fn comptimeStrJoin(comptime separator: [:0]const u8, comptime slices: []const [:
}
pub fn setNeuronCCFlags() void {
// See neuronxcc reference:
// https://awsdocs-neuron.readthedocs-hosted.com/en/latest/compiler/neuronx-cc/api-reference-guide/neuron-compiler-cli-reference-guide.html#neuron-compiler-cli-reference-guide
_ = c.setenv("NEURON_CC_FLAGS", comptimeStrJoin(" ", &.{
// 30% faster, no visible speed difference on llama
"--optlevel=1",

View File

@ -203,35 +203,39 @@ pub const CompilationContext = struct {
module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr());
const module_hash = computeModuleHash(self._platform, module);
var module_dir: ?[]const u8 = null;
var pjrt_location: ?[:0]const u8 = null;
if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| {
const sep = std.fs.path.sep_str;
const module_dir_name = try std.fmt.allocPrint(arena, "{s}{s}{s}{s}{s}_{x}", .{ xla_dump_to, sep, @tagName(self._platform.target), sep, self._name, module_hash });
try std.fs.cwd().makePath(module_dir_name);
module_dir = try std.fs.cwd().realpathAlloc(arena, module_dir_name);
const cache_dir = try std.fs.cwd().openDir(module_dir.?, .{});
// Write the mlir to a file. All errors are discarded, since this is for debugging only.
if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| {
const name = self._name;
const file_name = std.fmt.allocPrint(arena, "{s}_{x}.mlir", .{ name, module_hash }) catch name;
if (dir.createFile(file_name, .{ .truncate = true })) |file| {
const mlir_name = "module.mlir";
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name });
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
} else |_| {
log.warn("Failed to open {s}", .{file_name});
}
} else |_| {
log.warn("Folder not found {s}", .{xla_dump_to});
}
log.warn("Failed to open {s}", .{mlir_name});
}
const tracer = Tracer.init("ai.zml.compilation");
const compile_frame = tracer.frameStart("pjrt cached compilation");
defer tracer.frameEnd(compile_frame, "pjrt cached compilation");
pjrt_location = try std.fs.path.joinZ(arena, &.{ module_dir.?, "module.pjrt" });
}
const loaded_executable: *pjrt.LoadedExecutable = blk: {
const cache_location = try absoluteCacheFileZ(arena, self._platform.compilation_options.cache_location, module_hash);
if (cache_location) |cache_file| {
if (loadPjrtExecutable(arena, self._platform, cache_file)) |exe| {
if (pjrt_location) |pjrt_loc| {
if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| {
log.info("Loaded pre-compiled module from {s}", .{pjrt_loc});
break :blk exe;
} else |_| {}
} else |err| {
if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc });
}
}
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module) catch |err| {
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir.?) catch |err| {
log.err(
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
.{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
@ -239,9 +243,9 @@ pub const CompilationContext = struct {
return err;
};
if (cache_location) |cache_file| {
storePjrtExecutable(self._platform, loaded_executable, cache_file) catch |err| {
log.debug("Failed to store module: {}", .{err});
if (pjrt_location) |pjrt_loc| {
storePjrtExecutable(self._platform, loaded_executable, pjrt_loc) catch |err| {
log.warn("Failed to store compiled module: {} at {s}", .{ err, pjrt_loc });
};
}
break :blk loaded_executable;
@ -813,22 +817,13 @@ fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
return hasher.final();
}
/// Resolves `cache_path` to an absolute directory and returns the
/// null-terminated path of the cache file named "<module_hash:x>.pjrt"
/// inside it, or null when no cache path was configured.
/// The returned path is allocated from `arena`; the arena owns it.
/// NOTE(review): realpathAlloc fails if `cache_path` does not exist yet,
/// so the makeDirAbsolute below can only ever hit PathAlreadyExists —
/// confirm whether creating a missing cache dir was intended.
fn absoluteCacheFileZ(arena: std.mem.Allocator, cache_path: ?[]const u8, module_hash: u64) !?[:0]const u8 {
if (cache_path == null) return null;
const resolved_path = try std.fs.cwd().realpathAlloc(arena, cache_path.?);
// Ensure the cache directory exists; an existing directory is fine.
std.fs.makeDirAbsolute(resolved_path) catch |err| switch (err) {
error.PathAlreadyExists => {},
else => return err,
};
// 16 hex digits + ".pjrt" fits comfortably in 24 bytes.
var buf: [24]u8 = undefined;
const module_name = std.fmt.bufPrint(&buf, "{x}.pjrt", .{module_hash}) catch unreachable;
return try std.fs.path.joinZ(arena, &.{ resolved_path, module_name });
}
const max_pjrt_executable_size = 400 * 1024 * 1024;
fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable {
const tracer = Tracer.init("ai.zml.load_exe");
const compile_frame = tracer.frameStart("pjrt load executable");
defer tracer.frameEnd(compile_frame, "pjrt load executable");
const loaded_executable_file = try std.fs.openFileAbsoluteZ(absolute_file, .{});
defer loaded_executable_file.close();
@ -837,12 +832,7 @@ fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_fil
defer arena.free(bytes);
const size = try loaded_executable_file.readAll(bytes);
log.info("Loading module from {s}", .{absolute_file});
return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]) catch |err| {
log.warn("Failed to load module: {}", .{err});
return err;
};
return try platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]);
}
fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void {
@ -856,10 +846,13 @@ fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecut
defer serialize_result.deinit();
try loaded_executable_file.writeAll(serialize_result.bytes);
log.info("Stored module to {s}", .{absolute_file});
}
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module) !*pjrt.LoadedExecutable {
fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable {
const tracer = Tracer.init("ai.zml.compilation");
const compile_frame = tracer.frameStart("pjrt compilation");
defer tracer.frameEnd(compile_frame, "pjrt compilation");
const sharding = platform.sharding();
// NOTE(Corendos): Hack needed because Protobuf struct are not public.
@ -887,38 +880,27 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
};
// Let the arena deinit, zig-protobuf deinit is very slow.
if (platform.compilation_options.xla_dump_to) |xla_dump_to| {
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_dump_to" },
.value = .{ .value = .{ .string_field = .{ .Const = xla_dump_to } } },
});
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
setFlag(&options, "xla_dump_to", xla_dump_to);
if (platform.compilation_options.xla_dump_fusion_visualization) {
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_dump_hlo_as_html" },
.value = .{ .value = .{ .bool_field = true } },
});
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_dump_hlo_as_dot" },
.value = .{ .value = .{ .bool_field = true } },
});
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_dump_fusion_visualization" },
.value = .{ .value = .{ .bool_field = true } },
});
setFlag(&options, "xla_dump_hlo_as_html", true);
setFlag(&options, "xla_dump_hlo_as_dot", true);
setFlag(&options, "xla_dump_fusion_visualization", true);
}
}
switch (platform.target) {
.cuda => cuda_dir: {
// NVIDIA recommends to disable Triton GEMM on JAX:
// https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_gpu_enable_triton_gemm" },
.value = .{ .value = .{ .bool_field = false } },
});
// try options.env_option_overrides.append(arena, .{
// .key = .{ .Const = "xla_gpu_enable_latency_hiding_scheduler" },
// .value = .{ .value = .{ .bool_field = true } },
// });
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
// setFlag(&options, "xla_gpu_enable_cudnn_fmha", true);
// setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true);
// setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true);
// setFlag(&options, "xla_gpu_enable_custom_fusions", true);
// setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true);
// setFlag(&options, "xla_gpu_use_runtime_fusion", true);
// setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true);
var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse {
log.warn("Bazel runfile not found !", .{});
break :cuda_dir;
@ -928,22 +910,12 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
const r = r_.withSourceRepo(source_repo);
const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?;
log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir});
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_gpu_cuda_data_dir" },
.value = .{
.value = .{
.string_field = .{ .Const = cuda_data_dir },
},
},
});
setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir);
},
.rocm => {
// Disable Triton GEMM on ROCM. For some reason it's much, much slower when
// enabled on CDNA and it's used on RDNA. Disable it altogether.
try options.env_option_overrides.append(arena, .{
.key = .{ .Const = "xla_gpu_enable_triton_gemm" },
.value = .{ .value = .{ .bool_field = false } },
});
setFlag(&options, "xla_gpu_enable_triton_gemm", false);
},
else => {},
}
@ -956,6 +928,16 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
return loaded_executable;
}
/// Appends an XLA debug-option override (`flag` = `value`) to `options`.
/// The protobuf field is selected from the Zig type of `value`: bools,
/// ints and floats map to their dedicated fields; anything else is
/// treated as a string (`[]const u8`/literal).
/// Uses appendAssumeCapacity: the caller must have reserved capacity
/// beforehand (see the `ensureUnusedCapacity` call at the call site).
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
    const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
        .Bool => .{ .value = .{ .bool_field = value } },
        // Include comptime literals: `setFlag(&o, "x", 5)` has type
        // comptime_int, which previously fell into the string branch.
        .Int, .ComptimeInt => .{ .value = .{ .int_field = value } },
        .Float, .ComptimeFloat => .{ .value = .{ .double_field = value } },
        else => .{ .value = .{ .string_field = .{ .Const = value } } },
    };
    options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
}
/// Visit the given struct and recursively counts the number of tensors found.
pub fn countTensors(v: anytype) usize {
const LocalContext = struct {

View File

@ -17,7 +17,6 @@ pub const available_targets = std.enums.values(Target);
pub const CompilationOptions = struct {
xla_dump_to: ?[]const u8 = null,
xla_dump_fusion_visualization: bool = false,
cache_location: ?[]const u8 = null,
sharding_enabled: bool = false,
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
};
@ -82,8 +81,8 @@ pub const Platform = struct {
/// Returns the Profiler for this API.
/// Not all platform have a profiling api, for those the profiler object will do nothing.
/// Platforms with known profiler extensions: cuda, xpu
pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler {
return self.pjrt_client.getProfiler(self.pjrt_api, options);
pub fn getProfiler(self: Platform, options: ?pjrt.Profiler.Options) pjrt.Profiler {
// `null` selects pjrt.Profiler.default_options; resolved here so the
// underlying client always receives a concrete Options value.
return self.pjrt_client.getProfiler(self.pjrt_api, options orelse pjrt.Profiler.default_options);
}
};

View File

@ -13,14 +13,10 @@ var _platform: ?zml.Platform = null;
pub fn env() zml.Platform {
if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block");
if (_platform == null) {
_test_compile_opts = if (initCacheDir())
.{
.cache_location = "/tmp/zml/tests/cache",
_test_compile_opts = .{
.xla_dump_to = "/tmp/zml/tests/",
.sharding_enabled = true,
}
else
.{};
};
var ctx = zml.Context.init() catch unreachable;
_platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts);
@ -31,12 +27,6 @@ pub fn env() zml.Platform {
var _test_compile_opts: zml.CompilationOptions = .{};
/// Best-effort creation of the test cache tree "/tmp/zml/tests/cache".
/// Returns false when /tmp cannot be opened or the path cannot be created.
/// NOTE(review): the /tmp dir handle is never closed — harmless for a
/// one-shot test helper, but `defer` + close (requires `var`) would be cleaner.
fn initCacheDir() bool {
const tmp = std.fs.openDirAbsolute("/tmp", .{}) catch return false;
tmp.makePath("zml/tests/cache") catch return false;
return true;
}
/// In neural network we generally care about the relative precision,
/// but on a given dimension, if the output is close to 0, then the precision
/// don't matter as much.