diff --git a/zml/platform.zig b/zml/platform.zig index 410d60b..c5872a5 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -89,7 +89,9 @@ pub const Platform = struct { const memory_target: pjrt.Memory.Kind = switch (memory) { .host_unpinned => switch (platform.target) { // Cuda doesn't have host_unpinned. - .cuda => .host_pinned, + // ROCm doesn't seem to have it either. + // TODO(gwenzek): investigate why it was not forced before. + .cuda, .rocm => .host_pinned, else => .host_unpinned, }, inline else => |t| t,