Fix pickle loader to use takeDelimiterInclusive for Zig 0.15.2 and update ZLS runner, buffer, callback, and tracer utilities accordingly.

This commit is contained in:
Tarry Singh 2025-12-15 11:08:19 +00:00
parent 1b8d0ac627
commit a3abf148b4
7 changed files with 175 additions and 170 deletions

View File

@ -8,7 +8,7 @@ pub fn main() !void {
defer std.process.argsFree(gpa, args); defer std.process.argsFree(gpa, args);
const file_path = args[1]; const file_path = args[1];
var file = try std.fs.cwd().openFile(file_path, .{ .mode = .read_only}); var file = try std.fs.cwd().openFile(file_path, .{ .mode = .read_only });
defer file.close(); defer file.close();
if (builtin.zig_version.major == 0 and builtin.zig_version.minor >= 15) { if (builtin.zig_version.major == 0 and builtin.zig_version.minor >= 15) {

View File

@ -1,5 +1,5 @@
/// Custom ZLS launcher. /// Custom ZLS launcher.
/// ///
/// Sets up paths to ZLS dependencies with Bazel runfiles. /// Sets up paths to ZLS dependencies with Bazel runfiles.
/// ///
/// This file is used as a template by `zls_write_runner_zig_src.bzl`. /// This file is used as a template by `zls_write_runner_zig_src.bzl`.

View File

@ -762,9 +762,13 @@ pub const Op = union(enum) {
} }
}; };
pub const ParseError = error{
UnknownPickleOp,
} || std.fmt.ParseIntError || std.Io.Reader.DelimiterError || std.mem.Allocator.Error;
/// Read a stream of bytes, and interpret it as a stream of Pickle operators. /// Read a stream of bytes, and interpret it as a stream of Pickle operators.
/// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies. /// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies.
pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ParseError![]const Op {
// It's not very efficient to interleave the results with the data copied from the stream, // It's not very efficient to interleave the results with the data copied from the stream,
// because growth event in the results ArrayList will lead to fragmentation. // because growth event in the results ArrayList will lead to fragmentation.
// Trying to mitigate that by using a generous default size. // Trying to mitigate that by using a generous default size.
@ -776,7 +780,8 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
const code: OpCode = @enumFromInt(try reader.takeByte()); const code: OpCode = @enumFromInt(try reader.takeByte());
const op: Op = switch (code) { const op: Op = switch (code) {
.int => int: { .int => int: {
const bytes = try reader.takeDelimiterExclusive('\n'); const bytes_with_ln = try reader.takeDelimiterInclusive('\n');
const bytes = bytes_with_ln[0 .. bytes_with_ln.len - 1];
// Legacy hack, see OpCode.int documentation // Legacy hack, see OpCode.int documentation
// We do this parsing right away to simplify downstream code. // We do this parsing right away to simplify downstream code.
break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0') break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0')
@ -830,17 +835,10 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
.dup => .dup, .dup => .dup,
.mark => .mark, .mark => .mark,
.pop_mark => .pop_mark, .pop_mark => .pop_mark,
// If we fail to parse delay the error to the evaluation. .get => .{ .get = try readTextU32(reader) },
.get => get: {
const digits = try reader.takeDelimiterExclusive('\n');
break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
.binget => .{ .get = try reader.takeByte() }, .binget => .{ .get = try reader.takeByte() },
.long_binget => .{ .get = try reader.takeInt(u32, .little) }, .long_binget => .{ .get = try reader.takeInt(u32, .little) },
.put => put: { .put => .{ .put = try readTextU32(reader) },
const digits = try reader.takeDelimiterExclusive('\n');
break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) };
},
.binput => .{ .put = try reader.takeByte() }, .binput => .{ .put = try reader.takeByte() },
.long_binput => .{ .put = try reader.takeInt(u32, .little) }, .long_binput => .{ .put = try reader.takeInt(u32, .little) },
.memoize => .memoize, .memoize => .memoize,
@ -882,7 +880,7 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
.binpersid => .binpersid, .binpersid => .binpersid,
_ => |unk_tag| { _ => |unk_tag| {
log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag}); log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag});
return error.NotSupported; return error.UnknownPickleOp;
}, },
}; };
try results.append(arena, op); try results.append(arena, op);
@ -891,154 +889,153 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
return results.items; return results.items;
} }
//TODO(gwenzek): re-enable these tests when the bug has been fixed. test "parse protocol 4" {
// test "parse protocol 4" { var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); defer arena.deinit();
// defer arena.deinit();
// const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only });
// var read_buffer: [1024]u8 = undefined; var read_buffer: [1024]u8 = undefined;
// var reader = file.reader(&read_buffer); var reader = file.reader(&read_buffer);
// const ops = try parse(arena.allocator(), &reader.interface); const ops = try parse(arena.allocator(), &reader.interface);
// // this can be obtained by running: `python -m pickletools simple_test_4.pickle` // this can be obtained by running: `python -m pickletools simple_test_4.pickle`
// const expected: []const Op = &.{ const expected: []const Op = &.{
// .{ .proto = 4 }, .{ .proto = 4 },
// .{ .frame = 119 }, .{ .frame = 119 },
// .empty_dict, .empty_dict,
// .memoize, .memoize,
// .mark, .mark,
// .{ .unicode = "hello" }, .{ .unicode = "hello" },
// .memoize, .memoize,
// .{ .unicode = "world" }, .{ .unicode = "world" },
// .memoize, .memoize,
// .{ .unicode = "int" }, .{ .unicode = "int" },
// .memoize, .memoize,
// .{ .int = 1 }, .{ .int = 1 },
// .{ .unicode = "float" }, .{ .unicode = "float" },
// .memoize, .memoize,
// .{ .binfloat = 3.141592 }, .{ .binfloat = 3.141592 },
// .{ .unicode = "list" }, .{ .unicode = "list" },
// .memoize, .memoize,
// .empty_list, .empty_list,
// .memoize, .memoize,
// .mark, .mark,
// .{ .int = 255 }, .{ .int = 255 },
// .{ .int = 1234 }, .{ .int = 1234 },
// .{ .int = -123 }, .{ .int = -123 },
// .{ .int = 1_000_000_000 }, .{ .int = 1_000_000_000 },
// .{ .binlong = &writeIntBuff(u48, 999_000_000_000) }, .{ .binlong = &writeIntBuff(u48, 999_000_000_000) },
// .{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) }, .{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) },
// .appends, .appends,
// .{ .unicode = "bool" }, .{ .unicode = "bool" },
// .memoize, .memoize,
// .{ .bool = false }, .{ .bool = false },
// .{ .unicode = "tuple" }, .{ .unicode = "tuple" },
// .memoize, .memoize,
// .{ .unicode = "a" }, .{ .unicode = "a" },
// .memoize, .memoize,
// .{ .int = 10 }, .{ .int = 10 },
// .tuple2, .tuple2,
// .memoize, .memoize,
// .setitems, .setitems,
// .stop, .stop,
// }; };
// try std.testing.expectEqualDeep(expected, ops); try std.testing.expectEqualDeep(expected, ops);
// } }
// test "parse protocol 0" { test "parse protocol 0" {
// // We also test protocol 0, cause it's more text oriented. // We also test protocol 0, cause it's more text oriented.
// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); var arena: std.heap.ArenaAllocator = .init(std.testing.allocator);
// defer arena.deinit(); defer arena.deinit();
// const pickle_0 = const pickle_0 =
// \\(dp0 \\(dp0
// \\Vhello \\Vhello
// \\p1 \\p1
// \\Vworld \\Vworld
// \\p2 \\p2
// \\sVint \\sVint
// \\p3 \\p3
// \\I1 \\I1
// \\sVfloat \\sVfloat
// \\p4 \\p4
// \\F3.141592 \\F3.141592
// \\sVlist \\sVlist
// \\p5 \\p5
// \\(lp6 \\(lp6
// \\I255 \\I255
// \\aI1234 \\aI1234
// \\aI-123 \\aI-123
// \\aI1000000000 \\aI1000000000
// \\aL999000000000L \\aL999000000000L
// \\aL999000000000000000000000000000L \\aL999000000000000000000000000000L
// \\asVbool \\asVbool
// \\p7 \\p7
// \\I00 \\I00
// \\sVtuple \\sVtuple
// \\p8 \\p8
// \\(Va \\(Va
// \\p9 \\p9
// \\I10 \\I10
// \\tp10 \\tp10
// \\s. \\s.
// ; ;
// var reader: std.Io.Reader = .fixed(pickle_0); var reader: std.Io.Reader = .fixed(pickle_0);
// const ops = try parse(arena.allocator(), &reader); const ops = try parse(arena.allocator(), &reader);
// var expected = [_]Op{ var expected = [_]Op{
// .mark, .mark,
// .dict, .dict,
// .{ .put = 0 }, .{ .put = 0 },
// .{ .unicode = "hello" }, .{ .unicode = "hello" },
// .{ .put = 1 }, .{ .put = 1 },
// .{ .unicode = "world" }, .{ .unicode = "world" },
// .{ .put = 2 }, .{ .put = 2 },
// .setitem, .setitem,
// .{ .unicode = "int" }, .{ .unicode = "int" },
// .{ .put = 3 }, .{ .put = 3 },
// .{ .int = 1 }, .{ .int = 1 },
// .setitem, .setitem,
// .{ .unicode = "float" }, .{ .unicode = "float" },
// .{ .put = 4 }, .{ .put = 4 },
// .{ .float = "3.141592" }, .{ .float = "3.141592" },
// .setitem, .setitem,
// .{ .unicode = "list" }, .{ .unicode = "list" },
// .{ .put = 5 }, .{ .put = 5 },
// .mark, .mark,
// .list, .list,
// .{ .put = 6 }, .{ .put = 6 },
// .{ .int = 255 }, .{ .int = 255 },
// .append, .append,
// .{ .int = 1234 }, .{ .int = 1234 },
// .append, .append,
// .{ .int = -123 }, .{ .int = -123 },
// .append, .append,
// .{ .int = 1_000_000_000 }, .{ .int = 1_000_000_000 },
// .append, .append,
// .{ .long = "999000000000L" }, .{ .long = "999000000000L" },
// .append, .append,
// .{ .long = "999000000000000000000000000000L" }, .{ .long = "999000000000000000000000000000L" },
// .append, .append,
// .setitem, .setitem,
// .{ .unicode = "bool" }, .{ .unicode = "bool" },
// .{ .put = 7 }, .{ .put = 7 },
// .{ .bool = false }, .{ .bool = false },
// .setitem, .setitem,
// .{ .unicode = "tuple" }, .{ .unicode = "tuple" },
// .{ .put = 8 }, .{ .put = 8 },
// .mark, .mark,
// .{ .unicode = "a" }, .{ .unicode = "a" },
// .{ .put = 9 }, .{ .put = 9 },
// .{ .int = 10 }, .{ .int = 10 },
// .tuple, .tuple,
// .{ .put = 10 }, .{ .put = 10 },
// .setitem, .setitem,
// .stop, .stop,
// }; };
// try std.testing.expectEqualDeep(&expected, ops); try std.testing.expectEqualDeep(&expected, ops);
// } }
fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 {
const T = std.meta.Int(.unsigned, 8 * len_bytes); const T = std.meta.Int(.unsigned, 8 * len_bytes);
@ -1054,8 +1051,11 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8)
return res; return res;
} }
fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 { fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ParseError![]const u8 {
const n = try reader.streamDelimiter(&alloc_writer.writer, '\n'); const n = reader.streamDelimiter(&alloc_writer.writer, '\n') catch |err| switch (err) {
error.WriteFailed => return error.OutOfMemory,
else => |e| return e,
};
std.debug.assert(try reader.takeByte() == '\n'); std.debug.assert(try reader.takeByte() == '\n');
const w = &alloc_writer.writer; const w = &alloc_writer.writer;
std.debug.assert(w.end == n); std.debug.assert(w.end == n);
@ -1064,3 +1064,9 @@ fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]
w.end = 0; w.end = 0;
return items; return items;
} }
fn readTextU32(reader: *std.Io.Reader) ParseError!u32 {
// Note we use takeDelimiterInclusive because the newline must always be there.
const digits = try reader.takeDelimiterInclusive('\n');
return try std.fmt.parseInt(u32, digits[0 .. digits.len - 1], 10);
}

View File

@ -280,10 +280,9 @@ pub const Buffer = struct {
}; };
} }
pub fn devicePtr(self: Buffer) u64 { pub fn devicePtr(self: Buffer) *anyopaque {
stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{}); stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{});
const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable; return self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable;
return @intFromPtr(opaque_ptr);
} }
/// Fetches the content of the given buffer into a stack variable of the given type. /// Fetches the content of the given buffer into a stack variable of the given type.
@ -362,7 +361,7 @@ pub const Buffer = struct {
} }
pub fn format(self: Buffer, writer: *std.Io.Writer) !void { pub fn format(self: Buffer, writer: *std.Io.Writer) !void {
try writer.print("Buffer({f})@{x}", .{ self._shape, self.devicePtr() }); try writer.print("Buffer({f})@{x}", .{ self._shape, @intFromPtr(self.devicePtr()) });
} }
pub fn getMemory(self: Buffer) *const pjrt.Memory { pub fn getMemory(self: Buffer) *const pjrt.Memory {
@ -470,7 +469,7 @@ pub const Buffer = struct {
const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned }; const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned };
for (host_visible_memories) |memory| { for (host_visible_memories) |memory| {
const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory }); const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory });
const x_ptr: [*]u8 = @ptrFromInt(x.devicePtr()); const x_ptr: [*]u8 = @ptrCast(x.devicePtr());
@memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 }); @memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 });
const y = try x.getValue([6]u8); const y = try x.getValue([6]u8);

View File

@ -197,12 +197,12 @@ fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt
else else
.asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data); .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data);
if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) { if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) {
log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.devicePtr() }); // log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()) });
zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| { zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| {
log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.devicePtr(), err }); log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()), err });
return .create(call_frame.api, .resource_exhausted, "host pinned OOM"); return .create(call_frame.api, .resource_exhausted, "host pinned OOM");
}; };
log.debug("--> {f} {x}", .{ zml_buffer, zml_buffer.devicePtr() }); // log.debug("--> {f} {x}", .{ zml_buffer, @intFromPtr(zml_buffer.devicePtr()) });
} }
callback_args[i] = zml_buffer; callback_args[i] = zml_buffer;
} }

View File

@ -17,9 +17,8 @@ zig_library(
main = "tools.zig", main = "tools.zig",
visibility = ["//visibility:public"], visibility = ["//visibility:public"],
deps = select({ deps = select({
"@platforms//os:macos": [ # TODO(cerisier): fix MacOsTracer
":macos_c", # "@platforms//os:macos": [ ":macos_c" ],
],
"//conditions:default": [], "//conditions:default": [],
}), }),
) )

View File

@ -3,7 +3,8 @@ const builtin = @import("builtin");
const c = @import("c"); const c = @import("c");
pub const Tracer = switch (builtin.os.tag) { pub const Tracer = switch (builtin.os.tag) {
.macos => MacOsTracer, // TODO(cerisier): fix MacOsTracer
// .macos => MacOsTracer,
.linux => if (@hasDecl(c, "ZML_RUNTIME_CUDA")) CudaTracer else FakeTracer, .linux => if (@hasDecl(c, "ZML_RUNTIME_CUDA")) CudaTracer else FakeTracer,
else => FakeTracer, else => FakeTracer,
}; };