pjrt: bind PJRT_Client_CreateUninitializedBuffer.
This commit is contained in:
parent
8456a0d073
commit
ff1433d998
@ -483,6 +483,28 @@ pub const Client = opaque {
|
|||||||
});
|
});
|
||||||
return @ptrCast(ret.transfer_manager.?);
|
return @ptrCast(ret.transfer_manager.?);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Arguments accepted by `createUninitializedBuffer`.
///
/// Each field is forwarded to the matching field of the
/// `PJRT_Client_CreateUninitializedBuffer` call.
pub const CreateUninitializedBufferArgs = struct {
    /// Dimensions of the buffer to create; passed as `shape_dims` / `shape_num_dims`.
    dims: []const i64,
    /// Element type of the buffer; passed as `shape_element_type`.
    element_type: BufferType,
    /// Memory layout; converted with `toCStruct()` and passed as `shape_layout`.
    layout: MemoryLayout,
    /// Target device, or null to leave the `device` field unset.
    /// NOTE(review): presumably at least one of `device` / `memory` must be
    /// provided; confirm against the PJRT C API.
    device: ?*const Device = null,
    /// Target memory space, or null to leave the `memory` field unset.
    memory: ?*const Memory = null,
};
|
||||||
|
|
||||||
|
/// Binds `PJRT_Client_CreateUninitializedBuffer`.
///
/// Converts `args.layout` to its C representation, forwards the shape,
/// element type, layout, device and memory to the PJRT API, and returns the
/// buffer handle produced by the call.
///
/// Errors reported by `api.call` are propagated. The returned handle is
/// unwrapped with `.?`, i.e. a successful call is assumed to always yield a
/// non-null buffer.
pub fn createUninitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
    // The C layout struct must outlive the call since it is passed by pointer.
    var c_layout = args.layout.toCStruct();

    const result = try api.call(.PJRT_Client_CreateUninitializedBuffer, .{
        .client = self.inner(),
        .shape_dims = args.dims.ptr,
        .shape_num_dims = @intCast(args.dims.len),
        .shape_element_type = @intFromEnum(args.element_type),
        .shape_layout = @ptrCast(&c_layout),
        // The C API takes mutable pointers, hence the const casts.
        .device = @ptrCast(@constCast(args.device)),
        .memory = @ptrCast(@constCast(args.memory)),
    });

    return @ptrCast(result.buffer.?);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const MemoryStats = struct {
|
pub const MemoryStats = struct {
|
||||||
|
|||||||
@ -403,6 +403,71 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
return Buffer{ ._shape = self._shape, ._shards = new_shards, ._api = self._api };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Options for `Buffer.uninitialized`.
///
/// NOTE(review): the type name is spelled "Unitialized" (missing an "i");
/// it is kept as-is because callers reference it by this name.
pub const UnitializedOptions = struct {
    /// Memory kind to place each shard in. When null, the shard is created
    /// on the device itself rather than in a specific memory.
    memory: ?pjrt.Memory.Kind = null,
};
|
||||||
|
|
||||||
|
/// Creates a `Buffer` of the given shape backed by uninitialized PJRT buffers,
/// one per partition of the platform's sharding.
///
/// If `shape_` has a sharding axis, that axis is divided evenly between the
/// partitions; otherwise every partition receives a buffer of the full shape.
/// When `opts.memory` is set, each shard is placed in the device's addressable
/// memory of that kind, and `error.NotFound` is returned if the device exposes
/// no such memory. Otherwise the shard is created directly on the device.
///
/// On error, shards that were already created are released.
pub fn uninitialized(platform: Platform, shape_: Shape, opts: UnitializedOptions) !Buffer {
    const api = platform.pjrt_api;

    var res: Buffer = .{
        ._api = api,
        ._shape = shape_,
        ._shards = .{},
    };
    // Release whatever shards were created if a later step fails.
    errdefer for (res._shards.slice()) |shard| {
        shard.deinit(api);
    };

    // Axis indices in descending order: MAX_RANK - 1, ..., 1, 0. The last
    // `rank` entries form the minor-to-major order used for the layout below.
    const descending_axes: [Shape.MAX_RANK]i64 = comptime blk: {
        var table: [Shape.MAX_RANK]i64 = undefined;
        for (&table, 0..) |*slot, i| {
            slot.* = @intCast(Shape.MAX_RANK - i - 1);
        }
        break :blk table;
    };

    // TODO: support more advanced sharding specs
    stdx.debug.assert(platform.sharding().num_replicas == 1, "ZML doesn't support num_replicas > 1 for now, got: {}", .{platform.sharding()});
    const sharding_ax: ?u3 = std.simd.firstTrue(shape_._sharding_info);
    const n_partitions = platform.sharding().num_partitions;

    // Shape of a single shard: the sharded axis (if any) is split evenly.
    const shard_shape = blk: {
        const ax = sharding_ax orelse break :blk shape_;
        // This kind of sharding error should be detected earlier on.
        stdx.debug.assert(@rem(shape_.dim(ax), n_partitions) == 0, "Buffer.uninitialized() expects the sharding axis {} to have a dimension divisble by the number of devices ({}).", .{ ax, n_partitions });
        break :blk shape_.set(ax, @divExact(shape_.dim(ax), n_partitions));
    };

    const buffer_type = bufferTypeFromDtype(shape_.dtype());
    const devices = platform.getDevices();

    for (0..n_partitions) |part| {
        var args = pjrt.Client.CreateUninitializedBufferArgs{
            .dims = shard_shape.dims(),
            .element_type = buffer_type,
            .layout = .{
                .tiled = .{
                    .minor_to_major = descending_axes[Shape.MAX_RANK - shape_.rank() ..],
                    .tile_dims = &.{},
                    .tile_dims_sizes = &.{},
                },
            },
        };

        if (opts.memory) |wanted_kind| {
            // Pick the first addressable memory of the requested kind.
            const candidates = try devices[part].addressableMemories(api);
            const found = for (candidates) |candidate| {
                if (candidate.kind(api) == wanted_kind) break candidate;
            } else return error.NotFound;
            args.memory = found;
        } else {
            args.device = devices[part];
        }

        res._shards.appendAssumeCapacity(try platform.pjrt_client.createUnitializedBuffer(api, args));
    }

    return res;
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
pub fn bufferTypeFromDtype(dt: DataType) pjrt.BufferType {
|
||||||
|
|||||||
@ -128,6 +128,12 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Re-export of the pjrt-level argument struct used by `createUnitializedBuffer`.
pub const CreateUninitializedBufferArgs = pjrt.Client.CreateUninitializedBufferArgs;
|
||||||
|
|
||||||
|
/// Creates an uninitialized buffer by delegating to the inner client's
/// `createUninitializedBuffer`, then casts the returned handle to this
/// module's `Buffer` type. Errors from the inner call are propagated.
///
/// NOTE(review): the function name is spelled "Unitialized" (missing an "i");
/// it is kept as-is because callers reference it by this name.
pub fn createUnitializedBuffer(self: *const Client, api: *const Api, args: CreateUninitializedBufferArgs) ApiError!*Buffer {
    const raw_buffer = try self.inner().createUninitializedBuffer(api, args);
    return @ptrCast(raw_buffer);
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Buffer = opaque {
|
pub const Buffer = opaque {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user