diff --git a/zml/aio.zig b/zml/aio.zig index 175ffb3..27abbd8 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -574,6 +574,12 @@ pub fn unloadBuffers(model: anytype) void { }).cb, {}, model); } +/// deinit all buffers in the given struct +pub fn awaitAll(buffers: anytype) !void { + // TODO: implement once we have async buffers. + _ = buffers; +} + fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *PrefixBuilder, buffer_store: BufferStore, obj: anytype, platform: zml.Platform) !void { const err_msg = "visitStructAndLoadBuffer must be called with a pointer to type. Received "; const type_info, const T = switch (@typeInfo(@TypeOf(obj))) { diff --git a/zml/module.zig b/zml/module.zig index 0c1bc19..253125c 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -891,7 +891,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m try options.env_option_overrides.ensureUnusedCapacity(arena, 16); if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| { setFlag(&options, "xla_dump_to", xla_dump_to); - setFlag(&options, "xla_dump_hlo_as_dot", true); + setFlag(&options, "xla_dump_hlo_as_proto", true); if (platform.compilation_options.xla_dump_fusion_visualization) { setFlag(&options, "xla_dump_fusion_visualization", true); } diff --git a/zml/nn.zig b/zml/nn.zig index 6f756e9..7d10a3a 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -934,7 +934,7 @@ pub const PartialSoftmax = struct { /// Returns intermediary results to allow aggregating later. pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax { const a = self.axis(axis); - const max_val = self.max(a); + const max_val = self.max(a).maximum(Tensor.scalar(-1e16, self.dtype())); const out = self.sub(max_val.broad(self.shape())).exp(); return .{ .values = out, diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index ae17f51..a642979 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -63,7 +63,7 @@ pub const Client = opaque { } pub const BufferFromHostBufferArgs = pjrt.Client.BufferFromHostBufferArgs; - pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) !*Buffer { + pub fn bufferFromHostBuffer(self: *const Client, api: *const Api, args: BufferFromHostBufferArgs) ApiError!*Buffer { const buffer, const event_ = try asynk.callBlocking(pjrt.Client.bufferFromHostBuffer, .{ self.inner(), api, args }); if (event_) |event__| { const event: *Event = @ptrCast(event__); @@ -211,7 +211,7 @@ pub const Event = opaque { return self.inner().getEventError(api); } - pub fn await_(self: *Event, api: *const Api) !void { + pub fn await_(self: *Event, api: *const Api) ApiError!void { defer self.deinit(api); if (self.isReady(api)) { diff --git a/zml/tensor.zig b/zml/tensor.zig index 23ab003..5529b6b 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -2420,11 +2420,12 @@ pub const Tensor = struct { update_fn: *const fn (Tensor, Tensor) Tensor = increment, pub fn increment(old_value: Tensor, new_value: Tensor) Tensor { - return old_value.add(new_value.convert(old_value.dtype())); + return old_value.add(new_value); } pub fn override(old_value: Tensor, new_value: Tensor) Tensor { - return new_value.convert(old_value.dtype()); + _ = old_value; + return new_value; } };