Fix llama token handling and remove redundant prompt token reuse in core Zig modules (aio, module, nn, pjrtx, tensor)

This commit is contained in:
Tarry Singh 2024-05-02 17:10:11 +00:00
parent 394e63e273
commit a34190679b
5 changed files with 13 additions and 6 deletions

View File

@ -574,6 +574,12 @@ pub fn unloadBuffers(model: anytype) void {
}).cb, {}, model);
}
/// Wait until all pending operations on the given buffers have completed.
/// Currently a no-op: async buffers are not implemented yet, so there is
/// nothing to wait on. The `buffers` argument is accepted (and discarded)
/// so callers can adopt the API now without changes later.
pub fn awaitAll(buffers: anytype) !void {
    // TODO: implement once we have async buffers.
    _ = buffers;
}
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {

View File

@ -891,7 +891,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
setFlag(&options, "xla_dump_to", xla_dump_to);
setFlag(&options, "xla_dump_hlo_as_dot", true);
setFlag(&options, "xla_dump_hlo_as_proto", true);
if (platform.compilation_options.xla_dump_fusion_visualization) {
setFlag(&options, "xla_dump_fusion_visualization", true);
}

View File

@ -934,7 +934,7 @@ pub const PartialSoftmax = struct {
/// Returns intermediary results to allow aggregating later.
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
const a = self.axis(axis);
const max_val = self.max(a);
const max_val = self.max(a).maximum(Tensor.scalar(-1e16, self.dtype()));
const out = self.sub(max_val.broad(self.shape())).exp();
return .{
.values = out,

View File

@ -63,7 +63,7 @@ pub const Client = opaque {
}
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer {
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer {
const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args });
if (event_) |event__| {
const event: *Event = @ptrCast(event__);
@ -211,7 +211,7 @@ pub const Event = opaque {
return self.inner().getEventError(api);
}
pub fn await_(self: *Event, api: *const Api) !void {
pub fn await_(self: *Event, api: *const Api) ApiError!void {
defer self.deinit(api);
if (self.isReady(api)) {

View File

@ -2420,11 +2420,12 @@ pub const Tensor = struct {
update_fn: *const fn (Tensor, Tensor) Tensor = increment,
pub fn increment(old_value: Tensor, new_value: Tensor) Tensor {
return old_value.add(new_value.convert(old_value.dtype()));
return old_value.add(new_value);
}
pub fn override(old_value: Tensor, new_value: Tensor) Tensor {
return new_value.convert(old_value.dtype());
_ = old_value;
return new_value;
}
};