Fix llama token handling and remove redundant prompt token reuse in core Zig modules (aio, module, nn, pjrtx, tensor)
This commit is contained in:
parent
394e63e273
commit
a34190679b
@ -574,6 +574,12 @@ pub fn unloadBuffers(model: anytype) void {
|
|||||||
}).cb, {}, model);
|
}).cb, {}, model);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// deinit all buffers in the given struct
|
||||||
|
pub fn awaitAll(buffers: anytype) !void {
|
||||||
|
// TODO: implement once we have async buffers.
|
||||||
|
_ = buffers;
|
||||||
|
}
|
||||||
|
|
||||||
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void {
|
||||||
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received ";
|
||||||
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
const type_info, const T = switch (@typeInfo(@TypeOf(obj))) {
|
||||||
|
|||||||
@ -891,7 +891,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
|
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
|
||||||
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||||
setFlag(&options, "xla_dump_to", xla_dump_to);
|
setFlag(&options, "xla_dump_to", xla_dump_to);
|
||||||
setFlag(&options, "xla_dump_hlo_as_dot", true);
|
setFlag(&options, "xla_dump_hlo_as_proto", true);
|
||||||
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
||||||
setFlag(&options, "xla_dump_fusion_visualization", true);
|
setFlag(&options, "xla_dump_fusion_visualization", true);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -934,7 +934,7 @@ pub const PartialSoftmax = struct {
|
|||||||
/// Returns intermediary results to allow aggregating later.
|
/// Returns intermediary results to allow aggregating later.
|
||||||
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
|
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
|
||||||
const a = self.axis(axis);
|
const a = self.axis(axis);
|
||||||
const max_val = self.max(a);
|
const max_val = self.max(a).maximum(Tensor.scalar(-1e16, self.dtype()));
|
||||||
const out = self.sub(max_val.broad(self.shape())).exp();
|
const out = self.sub(max_val.broad(self.shape())).exp();
|
||||||
return .{
|
return .{
|
||||||
.values = out,
|
.values = out,
|
||||||
|
|||||||
@ -63,7 +63,7 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs;
|
||||||
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer {
|
pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer {
|
||||||
const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args });
|
const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args });
|
||||||
if (event_) |event__| {
|
if (event_) |event__| {
|
||||||
const event: *Event = @ptrCast(event__);
|
const event: *Event = @ptrCast(event__);
|
||||||
@ -211,7 +211,7 @@ pub const Event = opaque {
|
|||||||
return self.inner().getEventError(api);
|
return self.inner().getEventError(api);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn await_(self: *Event, api: *const Api) !void {
|
pub fn await_(self: *Event, api: *const Api) ApiError!void {
|
||||||
defer self.deinit(api);
|
defer self.deinit(api);
|
||||||
|
|
||||||
if (self.isReady(api)) {
|
if (self.isReady(api)) {
|
||||||
|
|||||||
@ -2420,11 +2420,12 @@ pub const Tensor = struct {
|
|||||||
update_fn: *const fn (Tensor, Tensor) Tensor = increment,
|
update_fn: *const fn (Tensor, Tensor) Tensor = increment,
|
||||||
|
|
||||||
pub fn increment(old_value: Tensor, new_value: Tensor) Tensor {
|
pub fn increment(old_value: Tensor, new_value: Tensor) Tensor {
|
||||||
return old_value.add(new_value.convert(old_value.dtype()));
|
return old_value.add(new_value);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn override(old_value: Tensor, new_value: Tensor) Tensor {
|
pub fn override(old_value: Tensor, new_value: Tensor) Tensor {
|
||||||
return new_value.convert(old_value.dtype());
|
_ = old_value;
|
||||||
|
return new_value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user