//! Radix/pjrt/profiler.zig — Zig wrapper around the PJRT plugin profiler C API.

const std = @import("std");
const c = @import("c");
const tsl_proto = @import("//tsl:profiler_options_proto");
const log = std.log.scoped(.@"pjrt/profiler");
const TraceContainer = @import("convert/trace_container.zig").TraceContainer;
/// Zig wrapper around the PJRT plugin profiler extension (`PLUGIN_Profiler_Api`).
///
/// The `status` field enforces the call order
/// `init` -> `start` -> `stop` -> `collectData` (directly or through a `dump*`
/// helper) -> `deinit`; calling a step out of order panics.
/// When `api` is null every plugin call is skipped and collected data is empty.
pub const Profiler = struct {
    /// Plugin function table, or null when no profiler API is available.
    api: ?c.PLUGIN_Profiler_Api,
    /// Plugin-side profiler handle. Left `undefined` when `api` is null.
    inner: *c.PLUGIN_Profiler,
    /// Most recent error pointer reported by the plugin (set by `check`).
    last_error: ?*Error = null,
    /// Current position in the profiler lifecycle.
    status: Status = .ready,

    pub const Status = enum { ready, started, stopped, done };
    pub const Error = c.PLUGIN_Profiler_Error;
    pub const Options = tsl_proto.ProfileOptions;

    /// Options used by `init` when the caller passes `null`.
    pub const default_options: Options = .{
        .version = 1,
        .device_type = .UNSPECIFIED, // profile all devices
        .include_dataset_ops = false, // tensorflow specific
        .host_tracer_level = 2,
        .device_tracer_level = 1,
        .python_tracer_level = 0,
        .enable_hlo_proto = true,
        .start_timestamp_ns = 0,
        .duration_ms = 0,
        .repository_path = .Empty,
    };

    /// Creates a profiler through the plugin `create` entry point.
    ///
    /// `options` defaults to `default_options`; `start_timestamp_ns` is always
    /// overwritten with the current time. With a null `api` a no-op profiler is
    /// returned. Panics if the options cannot be encoded or the plugin reports
    /// an error, since the signature has no error channel.
    pub fn init(api: ?c.PLUGIN_Profiler_Api, options: ?Options) Profiler {
        const plugin_api = api orelse return .{ .api = null, .inner = undefined };

        var options_with_timestamp = options orelse default_options;
        options_with_timestamp.start_timestamp_ns = @truncate(@max(0, std.time.nanoTimestamp()));

        // The encoded options only need to outlive the `create` call,
        // so a stack buffer is enough.
        var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buffer);
        const byte_options = options_with_timestamp.encode(fba.allocator()) catch |err|
            std.debug.panic("Failed to encode profiler options: {s}", .{@errorName(err)});

        var args: c.PLUGIN_Profiler_Create_Args = .{
            .options = byte_options.ptr,
            .options_size = byte_options.len,
            .profiler = undefined, // out
        };
        var res: Profiler = .{ .api = api, .inner = undefined };
        res.check(plugin_api.create.?(&args)) catch |err|
            std.debug.panic("PJRT profiler `create` failed: {s}", .{@errorName(err)});
        res.inner = args.profiler.?;
        return res;
    }

    /// Moves `status` from `expected` to `next`, panicking when the profiler
    /// is in any other state. `fn_name` is only used in the panic message.
    fn transition(self: *Profiler, fn_name: []const u8, expected: Status, next: Status) void {
        if (self.status == expected) {
            self.status = next;
            return;
        }
        std.debug.panic("Profiler can't `{s}()`. Current status: {}, expected: {}", .{ fn_name, self.status, expected });
    }

    /// Starts profiling. Requires status `.ready`. Panics on plugin error.
    pub fn start(self: *Profiler) void {
        self.transition("start", .ready, .started);
        const plugin_api = self.api orelse return;
        var args: c.PLUGIN_Profiler_Start_Args = .{ .profiler = self.inner };
        self.check(plugin_api.start.?(&args)) catch |err|
            std.debug.panic("PJRT profiler `start` failed: {s}", .{@errorName(err)});
    }

    /// Stops profiling. Requires status `.started`. Panics on plugin error.
    pub fn stop(self: *Profiler) void {
        self.transition("stop", .started, .stopped);
        const plugin_api = self.api orelse return;
        var args: c.PLUGIN_Profiler_Stop_Args = .{ .profiler = self.inner };
        self.check(plugin_api.stop.?(&args)) catch |err|
            std.debug.panic("PJRT profiler `stop` failed: {s}", .{@errorName(err)});
    }

    /// Fetches the collected profile bytes. Requires status `.stopped`.
    ///
    /// The plugin is first queried with a null buffer. If it answers with a
    /// size only, a buffer of that size is allocated with `allocator` and the
    /// call is repeated (`.owned` result). If it hands back its own buffer,
    /// that memory is returned as `.external`. The caller must call
    /// `ProfilingData.free` with the same allocator. Returns empty data when
    /// `api` is null or the plugin reports zero bytes.
    ///
    /// Errors: `error.PjrtProfilerError` from the plugin, or an allocation failure.
    pub fn collectData(self: *Profiler, allocator: std.mem.Allocator) !ProfilingData {
        self.transition("collect_data", .stopped, .done);
        const plugin_api = self.api orelse return .{ .external = &.{} };

        var args: c.PLUGIN_Profiler_CollectData_Args = .{
            .struct_size = c.PLUGIN_Profiler_CollectData_Args_STRUCT_SIZE,
            .profiler = self.inner,
            .buffer = null,
            .buffer_size_in_bytes = 0,
        };
        try self.check(plugin_api.collect_data.?(&args));

        // The size comes from the plugin; treat "nothing collected" as empty
        // data instead of asserting, which would be UB in release builds.
        if (args.buffer_size_in_bytes == 0) return .{ .external = &.{} };

        if (args.buffer == null) {
            log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes});
            // The plugin wants us to allocate memory for it:
            const owned_buffer = try allocator.alloc(u8, args.buffer_size_in_bytes);
            // Don't leak the buffer if the second call fails.
            errdefer allocator.free(owned_buffer);
            args.buffer = owned_buffer.ptr;
            try self.check(plugin_api.collect_data.?(&args));
            return .{ .owned = owned_buffer };
        }

        log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes});
        // Drop sentinel. The profiler plugin returns a null terminated string.
        // But this is creating issues if we save the sentinel on disk,
        // because it will trip up protobuf readers.
        const raw = args.buffer[0..args.buffer_size_in_bytes];
        const data = if (raw.len > 0 and raw[raw.len - 1] == 0) raw[0 .. raw.len - 1] else raw;
        return .{ .external = data };
    }

    /// Collects the profile data and writes the raw bytes to `file_name`
    /// inside `dir`. No file is created when no data was collected.
    pub fn dumpDataTo(
        self: *Profiler,
        allocator: std.mem.Allocator,
        dir: std.fs.Dir,
        file_name: []const u8,
    ) !void {
        const profile_data = try self.collectData(allocator);
        defer profile_data.free(allocator);

        if (profile_data.items().len == 0) return;

        const file = try dir.createFile(file_name, .{ .truncate = true });
        defer file.close();
        log.info("Writing profiling data to {s} ({} bytes)", .{ file_name, profile_data.items().len });
        try file.writeAll(profile_data.items());
    }

    /// Collects the profile data and writes it as JSON to `file_name` inside
    /// `dir`, through a buffered writer. The file is created before the data
    /// is collected.
    pub fn dumpAsJsonTo(
        self: *Profiler,
        allocator: std.mem.Allocator,
        dir: std.fs.Dir,
        file_name: []const u8,
    ) !void {
        log.info("Writing profiling data to {s}", .{file_name});
        var output_file = try dir.createFile(file_name, .{});
        defer output_file.close();
        var buffered_writer = std.io.bufferedWriter(output_file.writer());
        try self.dumpAsJsonToWriter(allocator, buffered_writer.writer());
        try buffered_writer.flush();
    }

    /// Collects the profile data, parses it with `TraceContainer` and writes
    /// the JSON form to `writer`. Logs a warning and writes nothing when no
    /// data was collected.
    pub fn dumpAsJsonToWriter(
        self: *Profiler,
        allocator: std.mem.Allocator,
        writer: anytype,
    ) !void {
        const profile_data = try self.collectData(allocator);
        defer profile_data.free(allocator);

        if (profile_data.items().len == 0) {
            log.warn("No profile data was collected: {}", .{self});
            return;
        }

        var converter = TraceContainer.init(allocator);
        defer converter.deinit();
        try converter.parseXSpaceBytes(profile_data.items(), 1_000_000);
        try converter.toJson(writer);
    }

    /// Turns a non-null plugin error pointer into `error.PjrtProfilerError`,
    /// remembering the pointer in `last_error`.
    fn check(self: *Profiler, c_error: ?*Error) !void {
        if (c_error) |err| {
            self.last_error = err;
            return error.PjrtProfilerError;
        }
    }

    /// Destroys the plugin-side profiler. Warns when the profiler was started
    /// but never stopped, or stopped but never collected. Never fails; a
    /// plugin error from `destroy` is only logged.
    pub fn deinit(self: Profiler) void {
        switch (self.status) {
            .started => log.warn("Profiler was never stopped", .{}),
            .stopped => log.warn("Profiler data was never collected", .{}),
            else => {},
        }
        const plugin_api = self.api orelse return;
        var args: c.PLUGIN_Profiler_Destroy_Args = .{ .profiler = self.inner };
        const destroy_error: ?*Error = plugin_api.destroy.?(&args);
        if (destroy_error != null) log.warn("Profiler destroy reported an error", .{});
    }
};
/// Profile bytes tagged by who is responsible for the backing memory.
const ProfilingData = union(enum) {
    /// Bytes that `free` releases through the given allocator.
    owned: []const u8,
    /// Bytes that `free` leaves untouched.
    external: []const u8,

    /// Returns the payload bytes, whichever tag is active.
    pub fn items(self: ProfilingData) []const u8 {
        switch (self) {
            .owned => |bytes| return bytes,
            .external => |bytes| return bytes,
        }
    }

    /// Releases `.owned` bytes with `allocator`; does nothing for `.external`.
    pub fn free(self: ProfilingData, allocator: std.mem.Allocator) void {
        if (self == .owned) allocator.free(self.owned);
    }
};