From 37725cdaa65fb0bf57ccbf878bd5e530d2b0aec2 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 4 Dec 2023 10:38:10 +0000 Subject: [PATCH] =?UTF-8?q?Update=20PJRT,=20runtime,=20and=20ZML=20modules?= =?UTF-8?q?=20to=20use=20per=E2=80=91target=20output=20folders=20and=20exp?= =?UTF-8?q?ose=20`profiler.dumpDataAsJson`=20for=20JSON=20profiling=20outp?= =?UTF-8?q?ut.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pjrt/BUILD.bazel | 4 +- pjrt/convert/trace_container.zig | 4 +- pjrt/pjrt.zig | 2 +- pjrt/profiler.zig | 58 ++++++++++--- runtimes/neuron/neuron.zig | 2 + zml/module.zig | 142 ++++++++++++++----------------- zml/platform.zig | 5 +- zml/testing.zig | 18 +--- 8 files changed, 124 insertions(+), 111 deletions(-) diff --git a/pjrt/BUILD.bazel b/pjrt/BUILD.bazel index 7dc5927..cf00cc2 100644 --- a/pjrt/BUILD.bazel +++ b/pjrt/BUILD.bazel @@ -10,11 +10,13 @@ cc_library( zig_library( name = "pjrt", - srcs = ["profiler.zig"], + srcs = ["profiler.zig"] + glob(["convert/*.zig"]), main = "pjrt.zig", visibility = ["//visibility:public"], deps = [ ":profiler_options_proto", + ":trace_events_proto", + ":xplane_proto", "//stdx", "@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", diff --git a/pjrt/convert/trace_container.zig b/pjrt/convert/trace_container.zig index 022d0e6..a967087 100644 --- a/pjrt/convert/trace_container.zig +++ b/pjrt/convert/trace_container.zig @@ -76,7 +76,7 @@ pub const TraceContainer = struct { }; } - fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: std.io.AnyWriter) !void { + fn xstatValueToString(stat: *const xplane_proto.XStat, plane: *const xplane_visitor.XPlaneVisitor, writer: anytype) !void { if (stat.value) |val| { switch (val) { inline .int64_value, .uint64_value, .double_value => |v| try writer.print("{d}", .{v}), @@ -231,7 +231,7 @@ pub const TraceContainer = struct { self.events.shrinkRetainingCapacity(max_count); } - pub fn toJson(self: *TraceContainer, writer: std.io.AnyWriter) !void { + pub fn toJson(self: *TraceContainer, writer: anytype) !void { try writer.writeAll( \\{"displayTimeUnit":"ns","metadata":{"highres-ticks":true},"traceEvents":[ ); diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 69bd22a..9cfa5fb 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -359,7 +359,7 @@ pub const Client = opaque { } } log.warn("No profiler found for platform: {}", .{self}); - return Profiler.init(null, options); + return Profiler.init(null, null); } pub fn deserializeAndLoad(self: *const Client, api: *const Api, bytes: []const u8) ApiError!*LoadedExecutable { diff --git a/pjrt/profiler.zig b/pjrt/profiler.zig index 6265e64..93883b2 100644 --- a/pjrt/profiler.zig +++ b/pjrt/profiler.zig @@ -2,7 +2,8 @@ const std = @import("std"); const c = @import("c"); const tsl_proto = @import("//tsl:profiler_options_proto"); -const log = std.log.scoped(.@"zml/profiler"); +const log = std.log.scoped(.@"pjrt/profiler"); +const TraceContainer = @import("convert/trace_container.zig").TraceContainer; /// Pjrt Profiler extension pub const Profiler = struct { @@ -15,20 +16,35 @@ pub const Profiler = struct { pub const Error = c.PLUGIN_Profiler_Error; pub const Options = tsl_proto.ProfileOptions; - pub fn init(api: ?c.PLUGIN_Profiler_Api, options: Options) Profiler { + pub const default_options: Options = .{ + .version = 1, + .device_type = .UNSPECIFIED, // profile all devices + .include_dataset_ops = false, // tensorflow specific + .host_tracer_level = 2, + .device_tracer_level = 1, + .python_tracer_level = 0, + .enable_hlo_proto = true, + .start_timestamp_ns = 0, + .duration_ms = 0, + .repository_path = .Empty, + }; + + pub fn init(api: ?c.PLUGIN_Profiler_Api, options: ?Options) Profiler { if (api == null) { return .{ .api = null, .inner = undefined }; } + var options_with_timestamp = options orelse default_options; + options_with_timestamp.start_timestamp_ns = @truncate(@max(0, std.time.nanoTimestamp())); var buffer: [std.fs.max_path_bytes + @sizeOf(Options) * 4]u8 = undefined; var fba = std.heap.FixedBufferAllocator.init(&buffer); - const byte_options = options.encode(fba.allocator()) catch unreachable; - var res: Profiler = .{ .api = api, .inner = undefined }; + const byte_options = options_with_timestamp.encode(fba.allocator()) catch unreachable; var args: c.PLUGIN_Profiler_Create_Args = .{ .options = byte_options.ptr, .options_size = byte_options.len, .profiler = undefined, // out }; + var res: Profiler = .{ .api = api, .inner = undefined }; res.check(api.?.create.?(&args)) catch unreachable; res.inner = args.profiler.?; @@ -70,15 +86,15 @@ pub const Profiler = struct { }; try self.check(self.api.?.collect_data.?(&args)); std.debug.assert(args.buffer_size_in_bytes > 0); - const buffer: ProfilingData = if (args.buffer == null) blk: { - std.log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes}); + return if (args.buffer == null) blk: { + log.debug("Plugin profiler wants us to allocate {d} bytes for profile data", .{args.buffer_size_in_bytes}); // The plugin want us to allocate memory for it: const buffer = try allocator.alloc(u8, args.buffer_size_in_bytes); args.buffer = buffer.ptr; try self.check(self.api.?.collect_data.?(&args)); break :blk .{ .owned = buffer }; } else blk: { - std.log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes}); + log.debug("Plugin profiler has {d} bytes of profile data", .{args.buffer_size_in_bytes}); // Drop sentinel. The profiler plugin returns a null terminated string. // But this is creating issues if we save the sentinel on disk, // because it will trip up protobuf readers. @@ -86,9 +102,6 @@ pub const Profiler = struct { data = if (data.len > 0 and data[data.len - 1] == 0) data[0 .. data.len - 1] else data; break :blk .{ .external = data }; }; - - // printDataAsXSpace(allocator, buffer.items()); - return buffer; } pub fn dumpDataTo( @@ -108,6 +121,31 @@ pub const Profiler = struct { return try file.writeAll(profile_data.items()); } + pub fn dumpAsJsonTo( + self: *Profiler, + allocator: std.mem.Allocator, + dir: std.fs.Dir, + file_name: []const u8, + ) !void { + const profile_data = try self.collectData(allocator); + defer profile_data.free(allocator); + + if (profile_data.items().len == 0) { + log.warn("No profile data was collected: {}", .{self}); + return; + } + + var converter = try TraceContainer.init(allocator, profile_data.items(), null); + defer converter.deinit(); + + var output_file = try dir.createFile(file_name, .{}); + defer output_file.close(); + var buffered_writer = std.io.bufferedWriter(output_file.writer()); + log.info("Writing profiling data to {s}", .{file_name}); + try converter.toJson(buffered_writer.writer()); + try buffered_writer.flush(); + } + fn check(self: *Profiler, c_error: ?*Error) !void { if (c_error) |err| { self.last_error = err; diff --git a/runtimes/neuron/neuron.zig b/runtimes/neuron/neuron.zig index b1139ac..4b7e141 100644 --- a/runtimes/neuron/neuron.zig +++ b/runtimes/neuron/neuron.zig @@ -103,6 +103,8 @@ fn comptimeStrJoin(comptime separator: [:0]const u8, comptime slices: []const [: } pub fn setNeuronCCFlags() void { + // See neuronxcc reference: + // https://awsdocs-neuron.readthedocs-hosted.com/en/latest/compiler/neuronx-cc/api-reference-guide/neuron-compiler-cli-reference-guide.html#neuron-compiler-cli-reference-guide _ = c.setenv("NEURON_CC_FLAGS", comptimeStrJoin(" ", &.{ // 30% faster, no visible speed difference on llama "--optlevel=1", diff --git a/zml/module.zig b/zml/module.zig index 9eaff9b..81ec32d 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -203,35 +203,39 @@ pub const CompilationContext = struct { module.op().setAttributeByName("mhlo.num_partitions", mlir.IntegerAttribute(.i32).init(mlir_ctx, sharding.num_partitions).asAttr()); const module_hash = computeModuleHash(self._platform, module); + var module_dir: ?[]const u8 = null; + var pjrt_location: ?[:0]const u8 = null; + if (self._platform.compilation_options.xla_dump_to) |xla_dump_to| { + const sep = std.fs.path.sep_str; + const module_dir_name = try std.fmt.allocPrint(arena, "{s}{s}{s}{s}{s}_{x}", .{ xla_dump_to, sep, @tagName(self._platform.target), sep, self._name, module_hash }); + try std.fs.cwd().makePath(module_dir_name); + module_dir = try std.fs.cwd().realpathAlloc(arena, module_dir_name); + const cache_dir = try std.fs.cwd().openDir(module_dir.?, .{}); + // Write the mlir to a file. All errors are discarded, since this is for debugging only. - if (std.fs.openDirAbsolute(xla_dump_to, .{})) |dir| { - const name = self._name; - const file_name = std.fmt.allocPrint(arena, "{s}_{x}.mlir", .{ name, module_hash }) catch name; - if (dir.createFile(file_name, .{ .truncate = true })) |file| { - module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); - log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name }); - } else |_| { - log.warn("Failed to open {s}", .{file_name}); - } + const mlir_name = "module.mlir"; + if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| { + module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); + log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name }); } else |_| { - log.warn("Folder not found {s}", .{xla_dump_to}); + log.warn("Failed to open {s}", .{mlir_name}); } + + pjrt_location = try std.fs.path.joinZ(arena, &.{ module_dir.?, "module.pjrt" }); } - const tracer = Tracer.init("ai.zml.compilation"); - const compile_frame = tracer.frameStart("pjrt cached compilation"); - defer tracer.frameEnd(compile_frame, "pjrt cached compilation"); - const loaded_executable: *pjrt.LoadedExecutable = blk: { - const cache_location = try absoluteCacheFileZ(arena, self._platform.compilation_options.cache_location, module_hash); - if (cache_location) |cache_file| { - if (loadPjrtExecutable(arena, self._platform, cache_file)) |exe| { + if (pjrt_location) |pjrt_loc| { + if (loadPjrtExecutable(arena, self._platform, pjrt_loc)) |exe| { + log.info("Loaded pre-compiled module from {s}", .{pjrt_loc}); break :blk exe; - } else |_| {} + } else |err| { + if (err != error.FileNotFound) log.warn("Failed to load pre-compiled module: {} at {s}", .{ err, pjrt_loc }); + } } - const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module) catch |err| { + const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir.?) catch |err| { log.err( "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", .{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err }, @@ -239,9 +243,9 @@ pub const CompilationContext = struct { return err; }; - if (cache_location) |cache_file| { - storePjrtExecutable(self._platform, loaded_executable, cache_file) catch |err| { - log.debug("Failed to store module: {}", .{err}); + if (pjrt_location) |pjrt_loc| { + storePjrtExecutable(self._platform, loaded_executable, pjrt_loc) catch |err| { + log.warn("Failed to store compiled module: {} at {s}", .{ err, pjrt_loc }); }; } break :blk loaded_executable; @@ -813,22 +817,13 @@ fn computeModuleHash(platform: Platform, module: mlir.Module) u64 { return hasher.final(); } -fn absoluteCacheFileZ(arena: std.mem.Allocator, cache_path: ?[]const u8, module_hash: u64) !?[:0]const u8 { - if (cache_path == null) return null; - const resolved_path = try std.fs.cwd().realpathAlloc(arena, cache_path.?); - std.fs.makeDirAbsolute(resolved_path) catch |err| switch (err) { - error.PathAlreadyExists => {}, - else => return err, - }; - - var buf: [24]u8 = undefined; - const module_name = std.fmt.bufPrint(&buf, "{x}.pjrt", .{module_hash}) catch unreachable; - return try std.fs.path.joinZ(arena, &.{ resolved_path, module_name }); -} - const max_pjrt_executable_size = 400 * 1024 * 1024; fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_file: [:0]const u8) !*pjrt.LoadedExecutable { + const tracer = Tracer.init("ai.zml.load_exe"); + const compile_frame = tracer.frameStart("pjrt load executable"); + defer tracer.frameEnd(compile_frame, "pjrt load executable"); + const loaded_executable_file = try std.fs.openFileAbsoluteZ(absolute_file, .{}); defer loaded_executable_file.close(); @@ -837,12 +832,7 @@ fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, absolute_fil defer arena.free(bytes); const size = try loaded_executable_file.readAll(bytes); - - log.info("Loading module from {s}", .{absolute_file}); - return platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]) catch |err| { - log.warn("Failed to load module: {}", .{err}); - return err; - }; + return try platform.pjrt_client.deserializeAndLoad(platform.pjrt_api, bytes[0..size]); } fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecutable, absolute_file: [:0]const u8) !void { @@ -856,10 +846,13 @@ fn storePjrtExecutable(platform: Platform, loaded_executable: *pjrt.LoadedExecut defer serialize_result.deinit(); try loaded_executable_file.writeAll(serialize_result.bytes); - log.info("Stored module to {s}", .{absolute_file}); } -fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module) !*pjrt.LoadedExecutable { +fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module: mlir.Module, xla_dump_to_: ?[]const u8) !*pjrt.LoadedExecutable { + const tracer = Tracer.init("ai.zml.compilation"); + const compile_frame = tracer.frameStart("pjrt compilation"); + defer tracer.frameEnd(compile_frame, "pjrt compilation"); + const sharding = platform.sharding(); // NOTE(Corendos): Hack needed because Protobuf struct are not public. @@ -887,38 +880,27 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m }; // Let the arena deinit, zig-protobuf deinit is very slow. - if (platform.compilation_options.xla_dump_to) |xla_dump_to| { - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_dump_to" }, - .value = .{ .value = .{ .string_field = .{ .Const = xla_dump_to } } }, - }); + try options.env_option_overrides.ensureUnusedCapacity(arena, 16); + if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| { + setFlag(&options, "xla_dump_to", xla_dump_to); if (platform.compilation_options.xla_dump_fusion_visualization) { - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_dump_hlo_as_html" }, - .value = .{ .value = .{ .bool_field = true } }, - }); - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_dump_hlo_as_dot" }, - .value = .{ .value = .{ .bool_field = true } }, - }); - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_dump_fusion_visualization" }, - .value = .{ .value = .{ .bool_field = true } }, - }); + setFlag(&options, "xla_dump_hlo_as_html", true); + setFlag(&options, "xla_dump_hlo_as_dot", true); + setFlag(&options, "xla_dump_fusion_visualization", true); } } switch (platform.target) { .cuda => cuda_dir: { // NVIDIA recommends to disable Triton GEMM on JAX: // https://github.com/NVIDIA/JAX-Toolbox?tab=readme-ov-file#environment-variables - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_gpu_enable_triton_gemm" }, - .value = .{ .value = .{ .bool_field = false } }, - }); - // try options.env_option_overrides.append(arena, .{ - // .key = .{ .Const = "xla_gpu_enable_latency_hiding_scheduler" }, - // .value = .{ .value = .{ .bool_field = true } }, - // }); + setFlag(&options, "xla_gpu_enable_triton_gemm", false); + // setFlag(&options, "xla_gpu_enable_cudnn_fmha", true); + // setFlag(&options, "xla_gpu_fused_attention_use_cudnn_rng", true); + // setFlag(&options, "xla_gpu_enable_cudnn_layer_norm", true); + // setFlag(&options, "xla_gpu_enable_custom_fusions", true); + // setFlag(&options, "xla_gpu_enable_dynamic_slice_fusion", true); + // setFlag(&options, "xla_gpu_use_runtime_fusion", true); + // setFlag(&options, "xla_gpu_enable_latency_hiding_scheduler", true); var r_ = try runfiles.Runfiles.create(.{ .allocator = arena }) orelse { log.warn("Bazel runfile not found !", .{}); break :cuda_dir; @@ -928,22 +910,12 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m const r = r_.withSourceRepo(source_repo); const cuda_data_dir = (try r.rlocationAlloc(arena, "libpjrt_cuda/sandbox")).?; log.info("xla_gpu_cuda_data_dir: {s}", .{cuda_data_dir}); - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_gpu_cuda_data_dir" }, - .value = .{ - .value = .{ - .string_field = .{ .Const = cuda_data_dir }, - }, - }, - }); + setFlag(&options, "xla_gpu_cuda_data_dir", cuda_data_dir); }, .rocm => { // Disable Triton GEMM on ROCM. For some reason it's much, much slower when // enabled on CDNA and it's used on RDNA. Disable it altogether. - try options.env_option_overrides.append(arena, .{ - .key = .{ .Const = "xla_gpu_enable_triton_gemm" }, - .value = .{ .value = .{ .bool_field = false } }, - }); + setFlag(&options, "xla_gpu_enable_triton_gemm", false); }, else => {}, } @@ -956,6 +928,16 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m return loaded_executable; } +fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void { + const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) { + .Bool => .{ .value = .{ .bool_field = value } }, + .Int => .{ .value = .{ .int_field = value } }, + .Float => .{ .value = .{ .double_field = value } }, + else => .{ .value = .{ .string_field = .{ .Const = value } } }, + }; + options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option }); +} + /// Visit the given struct and recursively counts the number of tensors found. pub fn countTensors(v: anytype) usize { const LocalContext = struct { diff --git a/zml/platform.zig b/zml/platform.zig index ae515d6..6c13597 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -17,7 +17,6 @@ pub const available_targets = std.enums.values(Target); pub const CompilationOptions = struct { xla_dump_to: ?[]const u8 = null, xla_dump_fusion_visualization: bool = false, - cache_location: ?[]const u8 = null, sharding_enabled: bool = false, sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{}, }; @@ -82,8 +81,8 @@ pub const Platform = struct { /// Returns the Profiler for this API. /// Not all platform have a profiling api, for those the profiler object will do nothing. /// Platforms with known profiler extensions: cuda, xpu - pub fn getProfiler(self: Platform, options: pjrt.Profiler.Options) pjrt.Profiler { - return self.pjrt_client.getProfiler(self.pjrt_api, options); + pub fn getProfiler(self: Platform, options: ?pjrt.Profiler.Options) pjrt.Profiler { + return self.pjrt_client.getProfiler(self.pjrt_api, options orelse pjrt.Profiler.default_options); } }; diff --git a/zml/testing.zig b/zml/testing.zig index 92e2b70..58b10cd 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -13,14 +13,10 @@ var _platform: ?zml.Platform = null; pub fn env() zml.Platform { if (!builtin.is_test) @compileError("Cannot use zml.testing.env outside of a test block"); if (_platform == null) { - _test_compile_opts = if (initCacheDir()) - .{ - .cache_location = "/tmp/zml/tests/cache", - .xla_dump_to = "/tmp/zml/tests/", - .sharding_enabled = true, - } - else - .{}; + _test_compile_opts = .{ + .xla_dump_to = "/tmp/zml/tests/", + .sharding_enabled = true, + }; var ctx = zml.Context.init() catch unreachable; _platform = ctx.autoPlatform(.{}).withCompilationOptions(_test_compile_opts); @@ -31,12 +27,6 @@ pub fn env() zml.Platform { var _test_compile_opts: zml.CompilationOptions = .{}; -fn initCacheDir() bool { - const tmp = std.fs.openDirAbsolute("/tmp", .{}) catch return false; - tmp.makePath("zml/tests/cache") catch return false; - return true; -} - /// In neural network we generally care about the relative precision, /// but on a given dimension, if the output is close to 0, then the precision /// don't matter as much.