Introduce Executable.getCompiledMemoryStats in PJRT.

This commit is contained in:
Tarry Singh 2025-01-02 16:36:13 +00:00
parent 4b1a3ff48a
commit fbf1ecb8b7
2 changed files with 38 additions and 0 deletions

View File

@ -606,6 +606,43 @@ pub const Executable = opaque {
.deleter = @ptrCast(ret.serialized_executable_deleter.?),
};
}
/// Fetches the memory statistics recorded for this compiled executable.
///
/// Issues `PJRT_Executable_GetCompiledMemoryStats` through `api` and copies
/// every size it reports into a `CompiledMemoryStats` value. Any error from
/// the underlying API call is propagated to the caller.
///
/// Each value is converted with `@intCast`, so a reported size that does not
/// fit the destination `u64` field is illegal behavior (a panic in
/// safety-checked builds).
pub fn getCompiledMemoryStats(self: *const Executable, api: *const Api) ApiError!CompiledMemoryStats {
    const raw = try api.call(.PJRT_Executable_GetCompiledMemoryStats, .{ .executable = self.inner() });

    const stats: CompiledMemoryStats = .{
        // Device default memory figures.
        .generated_code_size_in_bytes = @intCast(raw.generated_code_size_in_bytes),
        .argument_size_in_bytes = @intCast(raw.argument_size_in_bytes),
        .output_size_in_bytes = @intCast(raw.output_size_in_bytes),
        .alias_size_in_bytes = @intCast(raw.alias_size_in_bytes),
        .temp_size_in_bytes = @intCast(raw.temp_size_in_bytes),

        // Host memory figures.
        .host_generated_code_size_in_bytes = @intCast(raw.host_generated_code_size_in_bytes),
        .host_argument_size_in_bytes = @intCast(raw.host_argument_size_in_bytes),
        .host_output_size_in_bytes = @intCast(raw.host_output_size_in_bytes),
        .host_alias_size_in_bytes = @intCast(raw.host_alias_size_in_bytes),
        .host_temp_size_in_bytes = @intCast(raw.host_temp_size_in_bytes),
    };
    return stats;
}
};
/// Memory footprint of a compiled executable, as returned by
/// `Executable.getCompiledMemoryStats`. Every field is a size in bytes.
pub const CompiledMemoryStats = struct {
// Mirrors xla::CompiledMemoryStats.
// Device default memory (e.g., HBM for GPU/TPU) usage stats.
generated_code_size_in_bytes: u64,
argument_size_in_bytes: u64,
output_size_in_bytes: u64,
// How much argument is reused for output.
alias_size_in_bytes: u64,
temp_size_in_bytes: u64,
// Host memory usage stats.
host_generated_code_size_in_bytes: u64,
host_argument_size_in_bytes: u64,
host_output_size_in_bytes: u64,
host_alias_size_in_bytes: u64,
host_temp_size_in_bytes: u64,
};
pub const LoadedExecutable = opaque {

View File

@ -19,6 +19,7 @@ pub const Error = pjrt.Error;
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
pub const SerializeResult = pjrt.SerializeResult;
pub const Executable = pjrt.Executable;
pub const CompiledMemoryStats = pjrt.CompiledMemoryStats;
pub const ExecuteError = ApiError;
pub const Memory = pjrt.Memory;
pub const Stream = pjrt.Stream;