From fbf1ecb8b748b1644281fbe5e76f8613ef9dc9e5 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 2 Jan 2025 16:36:13 +0000 Subject: [PATCH] Introduce Executable.getCompiledMemoryStats in PJRT. --- pjrt/pjrt.zig | 37 +++++++++++++++++++++++++++++++++++++ zml/pjrtx.zig | 1 + 2 files changed, 38 insertions(+) diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index c834878..2d28461 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -606,6 +606,43 @@ pub const Executable = opaque { .deleter = @ptrCast(ret.serialized_executable_deleter.?), }; } + + pub fn getCompiledMemoryStats(self: *const Executable, api: *const Api) ApiError!CompiledMemoryStats { + const ret = try api.call(.PJRT_Executable_GetCompiledMemoryStats, .{ + .executable = self.inner(), + }); + + return .{ + .generated_code_size_in_bytes = @intCast(ret.generated_code_size_in_bytes), + .argument_size_in_bytes = @intCast(ret.argument_size_in_bytes), + .output_size_in_bytes = @intCast(ret.output_size_in_bytes), + .alias_size_in_bytes = @intCast(ret.alias_size_in_bytes), + .temp_size_in_bytes = @intCast(ret.temp_size_in_bytes), + .host_generated_code_size_in_bytes = @intCast(ret.host_generated_code_size_in_bytes), + .host_argument_size_in_bytes = @intCast(ret.host_argument_size_in_bytes), + .host_output_size_in_bytes = @intCast(ret.host_output_size_in_bytes), + .host_alias_size_in_bytes = @intCast(ret.host_alias_size_in_bytes), + .host_temp_size_in_bytes = @intCast(ret.host_temp_size_in_bytes), + }; + } +}; + +pub const CompiledMemoryStats = struct { + // Mirrors xla::CompiledMemoryStats. + // Device default memory (e.g., HBM for GPU/TPU) usage stats. + generated_code_size_in_bytes: u64, + argument_size_in_bytes: u64, + output_size_in_bytes: u64, + // much: How argument is reused for output. + alias_size_in_bytes: u64, + temp_size_in_bytes: u64, + + // memory: Host usage stats. + host_generated_code_size_in_bytes: u64, + host_argument_size_in_bytes: u64, + host_output_size_in_bytes: u64, + host_alias_size_in_bytes: u64, + host_temp_size_in_bytes: u64, }; pub const LoadedExecutable = opaque { diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 13ef6b6..32a3d79 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -19,6 +19,7 @@ pub const Error = pjrt.Error; pub const GetCostAnalysisError = pjrt.GetCostAnalysisError; pub const SerializeResult = pjrt.SerializeResult; pub const Executable = pjrt.Executable; +pub const CompiledMemoryStats = pjrt.CompiledMemoryStats; pub const ExecuteError = ApiError; pub const Memory = pjrt.Memory; pub const Stream = pjrt.Stream;