diff --git a/docs/huggingface-access-token.md b/docs/huggingface-access-token.md deleted file mode 100644 index 633a607..0000000 --- a/docs/huggingface-access-token.md +++ /dev/null @@ -1,37 +0,0 @@ -# Running Gated Huggingface Models with Token Authentication - -Some models have restrictions and may require some sort of approval or agreement -process, which, by consequence, **requires token-authentication with Huggingface**. - -The easiest way might be to use the `huggingface-cli login` command. - -Alternatively, here is how you can generate a **"read-only public repositories"** -access token to log into your account on Huggingface, directly from `bazel`, in order to download models. - -* log in at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). -* click on "Create new token" -* give the token a name, eg `zml_public_repos`, -* under _Repositories_, grant the following permission: "Read access to contents of all public gated repos you can access". -* at the bottom click on "Create token". -* copy the token by clicking `Copy`. **You won't be able to see it again.** -* the token looks something like `hf_abCdEfGhijKlM`. -* store the token on your machine (replace the placeholder with your actual token): - -You can use the `HUGGINGFACE_TOKEN` environment variable to store the token or use -its standard location: -``` -mkdir -p $HOME/.cache/huggingface/; echo > "$HOME/.cache/huggingface/token" -``` - -Now you're ready to download a gated model like `Meta-Llama-3-8b`! - -**Example:** - -``` -# requires token in $HOME/.cache/huggingface/token, as created by the -# `huggingface-cli login` command, or the `HUGGINGFACE_TOKEN` environment variable. -cd examples -bazel run --config=release //llama:Meta-Llama-3-8b -bazel run --config=release //llama:Meta-Llama-3-8b -- --promt="Once upon a time," -``` - diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 682b569..470a415 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -1,28 +1,17 @@ load("@rules_zig//zig:defs.bzl", "zig_library") -load("@zml//bazel:zig.bzl", "zig_cc_binary") load("@zml//bazel:zig_srcs.bzl", "zig_srcs") -load("@zml//bazel:zig_proto_library.bzl", "zig_proto_library") zig_library( name = "pjrt", - srcs = [ - "convert/trace_container.zig", - "convert/xplane_schema.zig", - "ffi.zig", - "profiler.zig", - ], + srcs = ["ffi.zig"], main = "pjrt.zig", visibility = ["//visibility:public"], deps = [ - ":profiler_options_proto", - ":trace_events_proto", - ":xplane_proto", "//stdx", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", - "@xla//xla/pjrt/c:pjrt_c_api_profiler_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_triton_extension_hdrs", ], ) @@ -31,36 +20,3 @@ zig_srcs( name = "sources", zig_lib = ":pjrt", ) - -zig_proto_library( - name = "profiler_options_proto", - import_name = "//tsl:profiler_options_proto", - deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:profiler_options_proto"], -) - -zig_proto_library( - name = "xplane_proto", - import_name = "//tsl:xplane_proto", - deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:xplane_proto"], -) - -zig_proto_library( - name = "trace_events_proto", - import_name = "//tsl:trace_events_proto", - deps = ["@xla//third_party/tsl/tsl/profiler/protobuf:trace_events_proto"], -) - -zig_cc_binary( - name = "xspace_to_json", - srcs = [ - "convert/trace_container.zig", - "convert/xplane_schema.zig", - ], - main = "xspace_to_json.zig", - visibility = ["//visibility:public"], - deps = [ - ":trace_events_proto", - ":xplane_proto", - "//stdx", - ], -) diff --git a/pjrt/convert/trace_container.zig b/pjrt/convert/trace_container.zig deleted file mode 100644 index 6346547..0000000 --- a/pjrt/convert/trace_container.zig +++ /dev/null @@ -1,366 +0,0 @@ -const std = @import("std"); - -const trace_events_proto = @import("//tsl:trace_events_proto"); -const xplane_proto = @import("//tsl:xplane_proto"); - -const xplane_schema = @import("xplane_schema.zig"); - -// Constants used as trace_viewer PID (device_id in trace_events.proto). -// PID 0 is unused. -// Support up to 500 accelerator devices. -const first_device_id = 1; -const last_device_id = 500; -// Support Upto 200 custom planes as fake devices (i.e., planes with a -// "/custom:" prefix). See `::custom_plane_prefix` for more -// information -const first_custom_plane_device_id = last_device_id + 1; -const max_custom_plane_devices_per_host = 200; -const last_custom_plane_device_id = first_custom_plane_device_id + max_custom_plane_devices_per_host - 1; - -// Host threads are shown as a single fake device. -pub const host_threads_device_id = last_custom_plane_device_id + 1; - -pub const xla_async_op_line_name = "Async XLA Ops"; - -pub const host_threads_plane_name = "/host:CPU"; -pub const gpu_plane_prefix = "/device:GPU:"; -pub const tpu_plane_prefix = "/device:TPU:"; -pub const custom_plane_prefix = "/device:CUSTOM:"; - -pub const TraceContainer = struct { - arena: std.heap.ArenaAllocator, - events: std.ArrayListUnmanaged(TraceEvent) = .{}, - devices: std.AutoArrayHashMapUnmanaged(u32, Device) = .{}, - - pub const Device = struct { - name: []const u8, - device_id: u32, - resources: std.AutoArrayHashMapUnmanaged(i64, Resource) = .{}, - }; - - pub const Resource = struct { - name: []const u8, - sort_index: i64, - }; - - pub const TraceEvent = struct { - device_id: u32 = 0, - resource_id: i64 = 0, - name: []const u8 = &[_]u8{}, - timestamp_ps: u128 = 0, - duration_ps: u64 = 0, - args: std.StringArrayHashMapUnmanaged([]const u8) = .{}, - }; - - pub fn init(allocator: std.mem.Allocator) TraceContainer { - return .{ - .arena = std.heap.ArenaAllocator.init(allocator), - }; - } - - pub fn deinit(self: *TraceContainer) void { - self.arena.deinit(); - } - - pub fn parseXSpaceBytes(self: *TraceContainer, pb_buffer: []const u8, max_events: ?usize) !void { - const arena = self.arena.allocator(); - - const xspace = try xplane_proto.XSpace.decode(pb_buffer, arena); - try self.fromXSpace(arena, xspace, max_events); - } - - fn findPlaneWithName(space: xplane_proto.XSpace, name: []const u8) ?*xplane_proto.XPlane { - for (space.planes.items) |*v| { - if (std.mem.eql(u8, v.name.getSlice(), name)) return v; - } - return null; - } - - fn findPlanesWithPrefix( - out: *std.ArrayList(*const xplane_proto.XPlane), - space: xplane_proto.XSpace, - prefix: []const u8, - ) !void { - for (space.planes.items) |*p| { - if (std.mem.startsWith(u8, p.name.getSlice(), prefix)) { - try out.append(p); - } - } - } - - fn xplaneToTraceEvents(self: *TraceContainer, allocator: std.mem.Allocator, device_id: u32, xplane: *const XPlaneHashed) !void { - // Convert devices and resources. - const device_entry = try self.devices.getOrPutValue(allocator, device_id, .{ .name = xplane.name(), .device_id = device_id }); - var device = device_entry.value_ptr.*; - defer device_entry.value_ptr.* = device; - - try device.resources.ensureUnusedCapacity(allocator, xplane.plane.lines.items.len); - const sort_by_ordinal = (device_id == host_threads_device_id); - - // Convert events. - for (xplane.plane.lines.items, 0..) |*xline, ordinal| { - const resource_id = if (xline.display_id != 0) xline.display_id else xline.id; - const resource_name = if (xline.display_name.isEmpty()) xline.name.getSlice() else xline.display_name.getSlice(); - device.resources.putAssumeCapacity(resource_id, .{ - .name = resource_name, - .sort_index = if (sort_by_ordinal) @intCast(ordinal) else resource_id, - }); - - if (std.mem.eql(u8, resource_name, xla_async_op_line_name)) continue; - - for (xline.events.items) |*xevent| { - const event_type = xplane.getEventType(xevent.metadata_id); - if (event_type.isInternalEvent()) continue; - var event = try self.createEvent(allocator); - event.device_id = device_id; - event.resource_id = resource_id; - - if (xplane.event_metadata_by_id.get(xevent.metadata_id)) |metadata| { - try event.args.ensureUnusedCapacity(allocator, 1 + metadata.stats.items.len); - - if (metadata.display_name != .Empty) { - event.name = metadata.display_name.getSlice(); - event.args.putAssumeCapacity("long_name", metadata.name.getSlice()); - } else { - event.name = metadata.name.getSlice(); - } - - event.timestamp_ps = (@as(u128, @intCast(xline.timestamp_ns)) * 1000) + @as(u128, @intCast(xevent.data.?.offset_ps)); - event.duration_ps = @intCast(xevent.duration_ps); - - for (metadata.stats.items) |xstat| { - if (xstat.value == null) continue; - var stat_buffer = std.ArrayList(u8).init(allocator); - try xplane.xstatToString(xstat, stat_buffer.writer().any()); - const stat_str = try stat_buffer.toOwnedSlice(); - const stat_type = xplane.getStatType(xstat.metadata_id); - if (stat_type.isInternalStat()) continue; - if (stat_type == .step_name) event.name = stat_str; - event.args.putAssumeCapacity(xplane.getStatMetadataName(xstat.metadata_id), stat_str); - } - } - - try event.args.ensureUnusedCapacity(allocator, xevent.stats.items.len); - for (xevent.stats.items) |xstat| { - if (xstat.value == null) continue; - var stat_buffer = std.ArrayList(u8).init(allocator); - try xplane.xstatToString(xstat, stat_buffer.writer().any()); - const stat_str = try stat_buffer.toOwnedSlice(); - const stat_type = xplane.getStatType(xstat.metadata_id); - if (stat_type.isInternalStat()) continue; - if (stat_type == .step_name) event.name = stat_str; - event.args.putAssumeCapacity(xplane.getStatMetadataName(xstat.metadata_id), stat_str); - } - } - } - } - - fn fromXSpace(self: *TraceContainer, allocator: std.mem.Allocator, xspace: xplane_proto.XSpace, max_events: ?usize) !void { - if (findPlaneWithName(xspace, host_threads_plane_name)) |hp| { - const xplane = try XPlaneHashed.init(allocator, hp); - try self.xplaneToTraceEvents(allocator, host_threads_device_id, &xplane); - } - - var device_planes = std.ArrayList(*const xplane_proto.XPlane).init(allocator); - defer device_planes.deinit(); - - try findPlanesWithPrefix(&device_planes, xspace, gpu_plane_prefix); - // We don't expect GPU and TPU planes and custom devices to be present in the same XSpace. - if (device_planes.items.len == 0) { - try findPlanesWithPrefix(&device_planes, xspace, tpu_plane_prefix); - } - if (device_planes.items.len == 0) { - try findPlanesWithPrefix(&device_planes, xspace, custom_plane_prefix); - } - - for (device_planes.items) |dp| { - var xplane = try XPlaneHashed.init(allocator, dp); - defer xplane.deinit(allocator); - const device_id: u32 = first_device_id + @as(u32, @intCast(xplane.plane.id)); - try self.xplaneToTraceEvents(allocator, device_id, &xplane); - } - - // Trace viewer (non-streaming) has scalability issues, we need to drop - // events to avoid loading failure for trace viewer. - if (max_events) |limit| self.capEvents(limit); - } - - pub fn createEvent(self: *TraceContainer, allocator: std.mem.Allocator) !*TraceEvent { - try self.events.append(allocator, .{}); - return &self.events.items[self.events.items.len - 1]; - } - - pub fn capEvents(self: *TraceContainer, max_count: u64) void { - const total_count = self.events.items.len; - if (total_count <= max_count) { - // Nothing to do. Events are not known sorted after return. - return; - } - // sort the events according to start time. - // TODO: partial sort would improve performance. - std.mem.sort(TraceEvent, self.events.items, {}, struct { - pub fn call(_: void, lhs: TraceEvent, rhs: TraceEvent) bool { - return lhs.timestamp_ps < rhs.timestamp_ps; - } - }.call); - self.events.shrinkRetainingCapacity(max_count); - } - - pub fn toJson(self: *TraceContainer, writer: anytype) !void { - try writer.writeAll( - \\{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},"traceEvents":[ - ); - - self.devices.sort(struct { - keys: []const u32, - pub fn lessThan(ctx: @This(), lhs: usize, rhs: usize) bool { - return ctx.keys[lhs] < ctx.keys[rhs]; - } - }{ .keys = self.devices.keys() }); - - for (self.devices.keys(), self.devices.values()) |device_id, *device| { - if (device.name.len != 0) { - try writer.print( - \\{{"ph":"M","pid":{d},"name":"process_name","args":{{"name":"{s}"}}}}, - , .{ device_id, device.name }); - } - try writer.print( - \\{{"ph":"M","pid":{d},"name":"process_sort_index","args":{{"sort_index":{d}}}}}, - , .{ - device_id, - device_id, - }); - - device.resources.sort(struct { - keys: []const i64, - pub fn lessThan(ctx: @This(), lhs: usize, rhs: usize) bool { - return ctx.keys[lhs] < ctx.keys[rhs]; - } - }{ .keys = device.resources.keys() }); - - for (device.resources.keys(), device.resources.values()) |resource_id, resource| { - if (resource.name.len != 0) { - try writer.print( - \\{{"ph":"M","pid":{d},"tid":{d},"name":"thread_name","args":{{"name":"{s}"}}}}, - , .{ - device_id, - resource_id, - resource.name, - }); - } - try writer.print( - \\{{"ph":"M","pid":{d},"tid":{d},"name":"thread_sort_index","args":{{"sort_index":{d}}}}}, - , .{ device_id, resource_id, resource.sort_index }); - } - } - - for (self.events.items) |*event| { - const duration_ps = @max(event.duration_ps, 1); - try writer.print( - \\{{"ph":"X","pid":{d},"tid":{d},"ts":{d:.17},"dur":{d:.17},"name":"{s}" - , .{ - event.device_id, - event.resource_id, - picoToMicro(event.timestamp_ps), - picoToMicro(duration_ps), - event.name, - }); - if (event.args.count() != 0) { - try writer.writeAll( - \\,"args":{ - ); - event.args.sort(struct { - keys: []const []const u8, - - pub fn lessThan(ctx: @This(), lhs: usize, rhs: usize) bool { - return std.mem.order(u8, ctx.keys[lhs], ctx.keys[rhs]).compare(std.math.CompareOperator.lt); - } - }{ .keys = event.args.keys() }); - - for (event.args.keys(), event.args.values(), 0..) |key, value, i| { - if (i < event.args.count() - 1) { - try writer.print( - \\"{s}":"{s}", - , .{ key, value }); - } else { - // Last item has closing bracket rather than trailing comma. - try writer.print( - \\"{s}":"{s}"}} - , .{ key, value }); - } - } - } - try writer.writeAll("},"); - } - try writer.writeAll("{}]}"); - } -}; - -fn picoToMicro(p: anytype) f64 { - return @as(f64, @floatFromInt(p)) / 1E6; -} - -pub const XPlaneHashed = struct { - plane: *const xplane_proto.XPlane, - event_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XEventMetadata) = .{}, - stat_metadata_by_id: std.AutoHashMapUnmanaged(i64, *const xplane_proto.XStatMetadata) = .{}, - - pub fn init( - allocator: std.mem.Allocator, - plane: *const xplane_proto.XPlane, - ) !XPlaneHashed { - var res: XPlaneHashed = .{ .plane = plane }; - - try res.event_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.event_metadata.items.len)); - // build event metadata map - for (plane.event_metadata.items) |*event_metadata| { - res.event_metadata_by_id.putAssumeCapacity(event_metadata.key, &event_metadata.value.?); - } - - // build stat metadata map - try res.stat_metadata_by_id.ensureUnusedCapacity(allocator, @intCast(plane.stat_metadata.items.len)); - for (plane.stat_metadata.items) |*stat_metadata| { - res.stat_metadata_by_id.putAssumeCapacity(stat_metadata.key, &stat_metadata.value.?); - } - - return res; - } - - pub fn deinit(self: *XPlaneHashed, allocator: std.mem.Allocator) void { - self.stat_metadata_by_id.deinit(allocator); - self.event_metadata_by_id.deinit(allocator); - } - - pub fn name(self: XPlaneHashed) []const u8 { - return self.plane.name.getSlice(); - } - - pub fn getEventType(self: XPlaneHashed, event_metadata_id: i64) xplane_schema.HostEventType { - if (self.event_metadata_by_id.get(event_metadata_id)) |v| { - return xplane_schema.HostEventType.fromString(v.name.getSlice()); - } else return .unknown; - } - - pub fn getStatMetadataName(self: XPlaneHashed, stat_metadata_id: i64) []const u8 { - if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { - return v.name.getSlice(); - } else return &[_]u8{}; - } - - pub fn getStatType(self: XPlaneHashed, stat_metadata_id: i64) xplane_schema.StatType { - if (self.stat_metadata_by_id.get(stat_metadata_id)) |v| { - return xplane_schema.StatType.fromString(v.name.getSlice()); - } else return .unknown; - } - - pub fn xstatToString(self: XPlaneHashed, stat: xplane_proto.XStat, writer: anytype) !void { - if (stat.value == null) return; - - switch (stat.value.?) { - inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}), - .str_value => |*v| try writer.writeAll(v.getSlice()), - .bytes_value => try writer.writeAll(""), - .ref_value => |v| try writer.writeAll(self.getStatMetadataName(@intCast(v))), - } - } -}; diff --git a/pjrt/convert/xplane_schema.zig b/pjrt/convert/xplane_schema.zig deleted file mode 100644 index ef70462..0000000 --- a/pjrt/convert/xplane_schema.zig +++ /dev/null @@ -1,297 +0,0 @@ -const std = @import("std"); - -// `HostEventType` uses the unconventional casing/formatting -// so that the string representation of the enum used in the -// protobuf encoding directly maps to the zig enum tag name. -pub const HostEventType = enum(u16) { - unknown = 0, - TraceContext, - SessionRun, - FunctionRun, - RunGraph, - RunGraphDone, - TfOpRun, - EagerExecute, - @"ExecutorState::Process", - ExecutorDoneCallback, - MemoryAllocation, - MemoryDeallocation, - // Performance counter related. - RemotePerfCounter, - // tf.data captured function events. - @"InstantiatedCapturedFunction::Run", - @"InstantiatedCapturedFunction::RunWithBorrowedArgs", - @"InstantiatedCapturedFunction::RunInstantiated", - @"InstantiatedCapturedFunction::RunAsync", - // Loop ops. - ParallelForOp, - ForeverOp, - @"WhileOp-EvalCond", - @"WhileOp-StartBody", - ForOp, - // tf.data related. - @"IteratorGetNextOp::DoCompute", - @"IteratorGetNextAsOptionalOp::DoCompute", - Iterator, - @"Iterator::Prefetch::Generator", - PrefetchProduce, - PrefetchConsume, - ParallelInterleaveProduce, - ParallelInterleaveConsume, - ParallelInterleaveInitializeInput, - ParallelMapProduce, - ParallelMapConsume, - MapAndBatchProduce, - MapAndBatchConsume, - ParseExampleProduce, - ParseExampleConsume, - ParallelBatchProduce, - ParallelBatchConsume, - // Batching related. - BatchingSessionRun, - ProcessBatch, - BrainSessionRun, - ConcatInputTensors, - MergeInputTensors, - ScheduleWithoutSplit, - ScheduleWithSplit, - ScheduleWithEagerSplit, - @"ASBSQueue::Schedule", - // TFRT related. - TfrtModelRun, - // Serving related. - ServingModelRun, - // GPU related. - KernelLaunch, - KernelExecute, - // TPU related - EnqueueRequestLocked, - RunProgramRequest, - HostCallbackRequest, - TransferH2DRequest, - TransferPreprocessedH2DRequest, - TransferD2HRequest, - OnDeviceSendRequest, - OnDeviceRecvRequest, - OnDeviceSendRecvLocalRequest, - CustomWait, - OnDeviceSendRequestMulti, - OnDeviceRecvRequestMulti, - PjrtAsyncWait, - DoEnqueueProgram, - DoEnqueueContinuationProgram, - WriteHbm, - ReadHbm, - TpuExecuteOp, - CompleteCallbacks, - @"tpu::System::TransferToDevice=>IssueEvent", - @"tpu::System::TransferToDevice=>IssueEvent=>Done", - @"tpu::System::TransferFromDevice=>IssueEvent", - @"tpu::System::TransferFromDevice=>IssueEvent=>Done", - @"tpu::System::Execute", - @"TPUPartitionedCallOp-InitializeVarOnTPU", - @"TPUPartitionedCallOp-ExecuteRemote", - @"TPUPartitionedCallOp-ExecuteLocal", - Linearize, - Delinearize, - @"TransferBufferFromDevice-FastPath", - - pub fn fromString(event_name: []const u8) HostEventType { - return std.meta.stringToEnum(HostEventType, event_name) orelse .unknown; - } - - pub fn isInternalEvent(event_type: HostEventType) bool { - // TODO(b/162102421): Introduce a prefix for internal event names. - return switch (event_type) { - .MemoryAllocation, - .MemoryDeallocation, - .PrefetchProduce, - .PrefetchConsume, - .ParallelInterleaveProduce, - .ParallelInterleaveConsume, - .ParallelInterleaveInitializeInput, - .ParallelMapProduce, - .ParallelMapConsume, - .MapAndBatchProduce, - .MapAndBatchConsume, - .ParseExampleProduce, - .ParseExampleConsume, - => true, - else => false, - }; - } -}; - -// `StatType` uses the unconventional casing/formatting -// so that the string representation of the enum used in the -// protobuf encoding directly maps to the zig enum tag name. -pub const StatType = enum(u16) { - unknown = 0, - // TraceMe arguments. - id, - device_ordinal, - chip_ordinal, - node_ordinal, - model_id, - queue_id, - queue_addr, - request_id, - run_id, - replica_id, - graph_type, - step_num, - iter_num, - index_on_host, - allocator_name, - bytes_reserved, - bytes_allocated, - bytes_available, - fragmentation, - peak_bytes_in_use, - requested_bytes, - allocation_bytes, - addr, - region_type, - data_type, - shape, - layout, - kpi_name, - kpi_value, - element_id, - parent_id, - core_type, - // XPlane semantics related. - _pt, - _ct, - _p, - _c, - _r, - _a, - // Device trace arguments. - device_id, - device_type_string, - context_id, - correlation_id, - // TODO(b/176137043): These "details" should differentiate between activity - // and API event sources. - memcpy_details, - memalloc_details, - MemFree_details, - Memset_details, - MemoryResidency_details, - nvtx_range, - kernel_details, - stream, - // Stats added when processing traces. - group_id, - flow, - step_name, - tf_op, - hlo_op, - deduplicated_name, - hlo_category, - hlo_module, - program_id, - equation, - is_eager, - is_func, - tf_function_call, - tracing_count, - flops, - model_flops, - bytes_accessed, - memory_access_breakdown, - source, - model_name, - model_version, - bytes_transferred, - queue, - dcn_collective_info, - // Performance counter related. - @"Raw Value", - @"Scaled Value", - @"Thread Id", - matrix_unit_utilization_percent, - // XLA metadata map related. - @"Hlo Proto", - // Device capability related. - clock_rate, - // For GPU, this is the number of SMs. - core_count, - memory_bandwidth, - memory_size, - compute_cap_major, - compute_cap_minor, - peak_teraflops_per_second, - peak_hbm_bw_gigabytes_per_second, - peak_sram_rd_bw_gigabytes_per_second, - peak_sram_wr_bw_gigabytes_per_second, - device_vendor, - // Batching related. - batch_size_after_padding, - padding_amount, - batching_input_task_size, - // GPU occupancy metrics - theoretical_occupancy_pct, - occupancy_min_grid_size, - occupancy_suggested_block_size, - // Aggregated Stats - self_duration_ps, - min_duration_ps, - total_profile_duration_ps, - max_iteration_num, - device_type, - uses_megacore, - symbol_id, - tf_op_name, - dma_stall_duration_ps, - key, - payload_size_bytes, - duration_us, - buffer_size, - transfers, - // Dcn message Stats - dcn_label, - dcn_source_slice_id, - dcn_source_per_slice_device_id, - dcn_destination_slice_id, - dcn_destination_per_slice_device_id, - dcn_chunk, - dcn_loop_index, - @"EdgeTPU Model information", - @"EdgeTPU Model Profile information", - @"EdgeTPU MLIR", - dropped_traces, - cuda_graph_id, - // Many events have `.cuda_graph_id`, such as graph sub events when tracing is in - // node level. Yet `.cuda_graph_exec_id` is used only for CudaGraphExecution events - // on the GPU device when tracing is in graph level. - cuda_graph_exec_id, - cuda_graph_orig_id, - step_idle_time_ps, - gpu_device_name, - source_stack, - device_offset_ps, - device_duration_ps, - - pub fn fromString(stat_name: []const u8) StatType { - return std.meta.stringToEnum(StatType, stat_name) orelse .unknown; - } - - pub fn isInternalStat(stat_type: StatType) bool { - return switch (stat_type) { - .kernel_details, - ._pt, - ._p, - ._ct, - ._c, - ._r, - .flops, - .bytes_accessed, - .program_id, - .symbol_id, - => true, - else => false, - }; - } -}; diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 58d9ac2..00749d5 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -5,7 +5,6 @@ const c = @import("c"); const stdx = @import("stdx"); pub const ffi = @import("ffi.zig"); -pub const Profiler = @import("profiler.zig").Profiler; const log = std.log.scoped(.pjrt); @@ -389,19 +388,6 @@ pub const Client = opaque { }; } - /// Returns the Profiler for this API. - /// Not all platform have a profiling api, for those the profiler object will do nothing. - /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: *const Client, api: *const Api, options: Profiler.Options) Profiler { - if (api.version().minor >= 45) { - if (api.lookupExtension(c.PJRT_Profiler_Extension, c.PJRT_Extension_Type_Profiler)) |ext| { - return Profiler.init(ext.profiler_api.*, options); - } - } - log.warn("No profiler found for platform: {}", .{self}); - return Profiler.init(null, null); - } - pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { const ret = try api.call(.PJRT_Executable_DeserializeAndLoad, .{ .client = self.inner(), diff --git a/pjrt/profiler.zig b/pjrt/profiler.zig deleted file mode 100644 index 77f1520..0000000 --- a/pjrt/profiler.zig +++ /dev/null @@ -1,194 +0,0 @@ -const std = @import("std"); -const c = @import("c"); -const tsl_proto = @import("//tsl:profiler_options_proto"); - -const log = std.log.scoped(.@"pjrt/profiler"); -const TraceContainer = @import("convert/trace_container.zig").TraceContainer; - -/// Pjrt Profiler extension -pub const Profiler = struct { - api: ?c.PLUGIN_Profiler_Api, - inner: *c.PLUGIN_Profiler, - last_error: ?*Error = null, - status: Status = .ready, - - pub const Status = enum { ready, started, stopped, done }; - pub const Error = c.PLUGIN_Profiler_Error; - pub const Options = tsl_proto.ProfileOptions; - - pub const default_options: Options = .{ - .version = 1, - .device_type = .UNSPECIFIED, // profile all devices - .include_dataset_ops = false, // tensorflow specific - .host_tracer_level = 2, - .device_tracer_level = 1, - .python_tracer_level = 0, - .enable_hlo_proto = true, - .start_timestamp_ns = 0, - .duration_ms = 0, - .repository_path = .Empty, - }; - - pub fn init(api: ?c.PLUGIN_Profiler_Api, options: ?Options) Profiler { - if (api == null) { - return .{ .api = null, .inner = undefined }; - } - var options_with_timestamp = options orelse default_options; - options_with_timestamp.start_timestamp_ns = @truncate(@max(0, std.time.nanoTimestamp())); - - var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined; - var fba = std.heap.FixedBufferAllocator.init(&buffer); - const byte_options = options_with_timestamp.encode(fba.allocator()) catch unreachable; - var args: c.PLUGIN_Profiler_Create_Args = .{ - .options = byte_options.ptr, - .options_size = byte_options.len, - .profiler = undefined, // out - }; - var res: Profiler = .{ .api = api, .inner = undefined }; - res.check(api.?.create.?(&args)) catch unreachable; - - res.inner = args.profiler.?; - return res; - } - - fn transition(self: *Profiler, fn_name: []const u8, expected: Status, next: Status) void { - if (self.status == expected) { - self.status = next; - return; - } - std.debug.panic("Profiler can't `{s}()`. Current status: {}, expected: {}", .{ fn_name, self.status, expected }); - } - - pub fn start(self: *Profiler) void { - self.transition("start", .ready, .started); - if (self.api == null) return; - var args: c.PLUGIN_Profiler_Start_Args = .{ .profiler = self.inner }; - self.check(self.api.?.start.?(&args)) catch unreachable; - } - - pub fn stop(self: *Profiler) void { - self.transition("stop", .started, .stopped); - if (self.api == null) return; - - var args: c.PLUGIN_Profiler_Stop_Args = .{ .profiler = self.inner }; - self.check(self.api.?.stop.?(&args)) catch unreachable; - } - - pub fn collectData(self: *Profiler, allocator: std.mem.Allocator) !ProfilingData { - self.transition("collect_data", .stopped, .done); - if (self.api == null) return .{ .external = &.{} }; - - var args: c.PLUGIN_Profiler_CollectData_Args = .{ - .struct_size = c.PLUGIN_Profiler_CollectData_Args_STRUCT_SIZE, - .profiler = self.inner, - .buffer = null, - .buffer_size_in_bytes = 0, - }; - try self.check(self.api.?.collect_data.?(&args)); - std.debug.assert(args.buffer_size_in_bytes > 0); - return if (args.buffer == null) blk: { - log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes}); - // The plugin want us to allocate memory for it: - const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes); - args.buffer = buffer.ptr; - try self.check(self.api.?.collect_data.?(&args)); - break :blk .{ .owned = buffer }; - } else blk: { - log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes}); - // Drop sentinel. The profiler plugin returns a null terminated string. - // But this is creating issues if we save the sentinel on disk, - // because it will trip up protobuf readers. - var data = args.buffer[0..args.buffer_size_in_bytes]; - data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data; - break :blk .{ .external = data }; - }; - } - - pub fn dumpDataTo( - self: *Profiler, - allocator: std.mem.Allocator, - dir: std.fs.Dir, - file_name: []const u8, - ) !void { - const profile_data = try self.collectData(allocator); - defer profile_data.free(allocator); - - if (profile_data.items().len == 0) return; - - const file = try dir.createFile(file_name, .{ .truncate = true }); - defer file.close(); - log.info("Writing profiling data to {s} ({} bytes)", .{ file_name, profile_data.items().len }); - return try file.writeAll(profile_data.items()); - } - - pub fn dumpAsJsonTo( - self: *Profiler, - allocator: std.mem.Allocator, - dir: std.fs.Dir, - file_name: []const u8, - ) !void { - log.info("Writing profiling data to {s}", .{file_name}); - var output_file = try dir.createFile(file_name, .{}); - defer output_file.close(); - var buffered_writer = std.io.bufferedWriter(output_file.writer()); - try self.dumpAsJsonToWriter(allocator, buffered_writer.writer()); - try buffered_writer.flush(); - } - - pub fn dumpAsJsonToWriter( - self: *Profiler, - allocator: std.mem.Allocator, - writer: anytype, - ) !void { - const profile_data = try self.collectData(allocator); - defer profile_data.free(allocator); - - if (profile_data.items().len == 0) { - log.warn("No profile data was collected: {}", .{self}); - return; - } - - var converter = TraceContainer.init(allocator); - defer converter.deinit(); - try converter.parseXSpaceBytes(profile_data.items(), 1_000_000); - - try converter.toJson(writer); - } - - fn check(self: *Profiler, c_error: ?*Error) !void { - if (c_error) |err| { - self.last_error = err; - return error.PjrtProfilerError; - } - } - - pub fn deinit(self: Profiler) void { - switch (self.status) { - .started => log.warn("Profiler was never stopped", .{}), - .stopped => log.warn("Profiler data was never collected", .{}), - else => {}, - } - if (self.api == null) return; - - var args: c.PLUGIN_Profiler_Destroy_Args = .{ .profiler = self.inner }; - _ = self.api.?.destroy.?(&args); - } -}; - -const ProfilingData = union(enum) { - owned: []const u8, - external: []const u8, - - pub fn items(self: ProfilingData) []const u8 { - return switch (self) { - inline else => |x| x, - }; - } - - pub fn free(self: ProfilingData, allocator: std.mem.Allocator) void { - switch (self) { - .owned => |data| allocator.free(data), - .external => {}, - } - } -}; diff --git a/pjrt/xspace_to_json.zig b/pjrt/xspace_to_json.zig deleted file mode 100644 index e5c0698..0000000 --- a/pjrt/xspace_to_json.zig +++ /dev/null @@ -1,47 +0,0 @@ -const std = @import("std"); -const stdx = @import("stdx"); -const flags = stdx.flags; - -const TraceContainer = @import("convert/trace_container.zig").TraceContainer; - -const CliArgs = struct { - pub const help = - \\ llama --path=path_to_profiling_data - ; - path: []const u8, - max_events: ?usize = null, -}; - -pub fn main() !void { - var gpa = std.heap.GeneralPurposeAllocator(.{ .thread_safe = true }){}; - defer _ = gpa.deinit(); - const allocator = gpa.allocator(); - - var args = std.process.args(); - const cli_args = flags.parse(&args, CliArgs); - - var fd = try std.fs.openFileAbsolute(cli_args.path, .{}); - defer fd.close(); - - const pb_buffer = try fd.readToEndAlloc(allocator, (try fd.stat()).size); - defer allocator.free(pb_buffer); - if (pb_buffer.len == 0) return error.EmptyBuffer; - - var converter = TraceContainer.init(allocator); - defer converter.deinit(); - try converter.parseXSpaceBytes(pb_buffer, cli_args.max_events); - - var path_buffer: [1028]u8 = undefined; - - const output_path = try std.fmt.bufPrint(&path_buffer, "{s}/{s}.json", .{ - std.fs.path.dirname(cli_args.path) orelse "", - std.fs.path.stem(cli_args.path), - }); - - var output_file = try std.fs.createFileAbsolute(output_path, .{}); - defer output_file.close(); - - try converter.toJson(output_file.writer().any()); - - std.debug.print("Wrote JSON to {s}\n", .{output_path}); -} diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 1795efc..b12bdb2 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -5,7 +5,6 @@ const dialects = @import("mlir/dialects"); const mlir = @import("mlir"); const pjrt = @import("pjrt"); pub const ffi = pjrt.ffi; -pub const Profiler = pjrt.Profiler; pub const ApiError = pjrt.ApiError; pub const ErrorCode = pjrt.ErrorCode; pub const ExecuteContext = pjrt.ExecuteContext; @@ -109,13 +108,6 @@ pub const Client = opaque { return try asynk.callBlocking(compileSync, .{ self, api, allocator, module, compile_options_pb }); } - /// Returns the Profiler for this API. - /// Not all platform have a profiling api, for those the profiler object will do nothing. - /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: *const Client, api: *const Api, options: pjrt.Profiler.Options) pjrt.Profiler { - return self.inner().getProfiler(api, options); - } - pub fn addressableMemories(self: *const Client, api: *const Api) []*const Memory { return self.inner().addressableMemories(api); } diff --git a/zml/platform.zig b/zml/platform.zig index a7bc87c..6f48137 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -85,13 +85,6 @@ pub const Platform = struct { pub fn deinit(self: *Platform) void { self.pjrt_client.deinit(self.pjrt_api); } - - /// Returns the Profiler for this API. - /// Not all platform have a profiling api, for those the profiler object will do nothing. - /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: Platform, options: ?pjrt.Profiler.Options) pjrt.Profiler { - return self.pjrt_client.getProfiler(self.pjrt_api, options orelse pjrt.Profiler.default_options); - } }; const _CreateOptions = struct {