diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index 9dbd5aa..2b0e65b 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -502,6 +502,13 @@ pub const Device = opaque { }) catch unreachable; return @intCast(ret.local_hardware_id); } + + pub fn addressableMemories(self: *const Device, api: *const Api) ApiError![]const *Memory { + const ret = try api.call(.PJRT_Device_AddressableMemories, .{ + .device = self.inner(), + }); + return @ptrCast(ret.memories[0..ret.num_memories]); + } }; pub const DeviceDescription = opaque {