Introduce Executable.getCompiledMemoryStats in PJRT.
This commit is contained in:
parent
4b1a3ff48a
commit
fbf1ecb8b7
@ -606,6 +606,43 @@ pub const Executable = opaque {
|
||||
.deleter = @ptrCast(ret.serialized_executable_deleter.?),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn getCompiledMemoryStats(self: *const Executable, api: *const Api) ApiError!CompiledMemoryStats {
|
||||
const ret = try api.call(.PJRT_Executable_GetCompiledMemoryStats, .{
|
||||
.executable = self.inner(),
|
||||
});
|
||||
|
||||
return .{
|
||||
.generated_code_size_in_bytes = @intCast(ret.generated_code_size_in_bytes),
|
||||
.argument_size_in_bytes = @intCast(ret.argument_size_in_bytes),
|
||||
.output_size_in_bytes = @intCast(ret.output_size_in_bytes),
|
||||
.alias_size_in_bytes = @intCast(ret.alias_size_in_bytes),
|
||||
.temp_size_in_bytes = @intCast(ret.temp_size_in_bytes),
|
||||
.host_generated_code_size_in_bytes = @intCast(ret.host_generated_code_size_in_bytes),
|
||||
.host_argument_size_in_bytes = @intCast(ret.host_argument_size_in_bytes),
|
||||
.host_output_size_in_bytes = @intCast(ret.host_output_size_in_bytes),
|
||||
.host_alias_size_in_bytes = @intCast(ret.host_alias_size_in_bytes),
|
||||
.host_temp_size_in_bytes = @intCast(ret.host_temp_size_in_bytes),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub const CompiledMemoryStats = struct {
|
||||
// Mirrors xla::CompiledMemoryStats.
|
||||
// Device default memory (e.g., HBM for GPU/TPU) usage stats.
|
||||
generated_code_size_in_bytes: u64,
|
||||
argument_size_in_bytes: u64,
|
||||
output_size_in_bytes: u64,
|
||||
// much: How argument is reused for output.
|
||||
alias_size_in_bytes: u64,
|
||||
temp_size_in_bytes: u64,
|
||||
|
||||
// memory: Host usage stats.
|
||||
host_generated_code_size_in_bytes: u64,
|
||||
host_argument_size_in_bytes: u64,
|
||||
host_output_size_in_bytes: u64,
|
||||
host_alias_size_in_bytes: u64,
|
||||
host_temp_size_in_bytes: u64,
|
||||
};
|
||||
|
||||
pub const LoadedExecutable = opaque {
|
||||
|
||||
@ -19,6 +19,7 @@ pub const Error = pjrt.Error;
|
||||
pub const GetCostAnalysisError = pjrt.GetCostAnalysisError;
|
||||
pub const SerializeResult = pjrt.SerializeResult;
|
||||
pub const Executable = pjrt.Executable;
|
||||
pub const CompiledMemoryStats = pjrt.CompiledMemoryStats;
|
||||
pub const ExecuteError = ApiError;
|
||||
pub const Memory = pjrt.Memory;
|
||||
pub const Stream = pjrt.Stream;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user