diff --git a/third_party/zls/cat.zig b/third_party/zls/cat.zig index ebca6f5..acfa8ea 100644 --- a/third_party/zls/cat.zig +++ b/third_party/zls/cat.zig @@ -8,7 +8,7 @@ pub fn main() !void { defer std.process.argsFree(gpa, args); const file_path = args[1]; - var file = try std.fs.cwd().openFile(file_path, .{ .mode = .read_only}); + var file = try std.fs.cwd().openFile(file_path, .{ .mode = .read_only }); defer file.close(); if (builtin.zig_version.major == 0 and builtin.zig_version.minor >= 15) { diff --git a/third_party/zls/zls_runner.zig b/third_party/zls/zls_runner.zig index 81555e0..0f011ec 100644 --- a/third_party/zls/zls_runner.zig +++ b/third_party/zls/zls_runner.zig @@ -1,5 +1,5 @@ /// Custom ZLS launcher. -/// +/// /// Sets up paths to ZLS dependencies with Bazel runfiles. /// /// This file is used as a template by `zls_write_runner_zig_src.bzl`. diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 1a04541..8172ddc 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -762,9 +762,13 @@ pub const Op = union(enum) { } }; +pub const ParseError = error{ + UnknownPickleOp, +} || std.fmt.ParseIntError || std.Io.Reader.DelimiterError || std.mem.Allocator.Error; + /// Read a stream of bytes, and interpret it as a stream of Pickle operators. /// The given allocator needs to be an arena cause we are not aligning allocations to avoid copies. -pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { +pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ParseError![]const Op { // It's not very efficient to interleave the results with the data copied from the stream, // because growth event in the results ArrayList will lead to fragmentation. // Trying to mitigate that by using a generous default size. @@ -776,7 +780,8 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { const code: OpCode = @enumFromInt(try reader.takeByte()); const op: Op = switch (code) { .int => int: { - const bytes = try reader.takeDelimiterExclusive('\n'); + const bytes_with_ln = try reader.takeDelimiterInclusive('\n'); + const bytes = bytes_with_ln[0 .. bytes_with_ln.len - 1]; // Legacy hack, see OpCode.int documentation // We do this parsing right away to simplify downstream code. break :int if (bytes.len == 2 and bytes[0] == '0' and bytes[1] == '0') @@ -830,17 +835,10 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { .dup => .dup, .mark => .mark, .pop_mark => .pop_mark, - // If we fail to parse delay the error to the evaluation. - .get => get: { - const digits = try reader.takeDelimiterExclusive('\n'); - break :get .{ .get = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) }; - }, + .get => .{ .get = try readTextU32(reader) }, .binget => .{ .get = try reader.takeByte() }, .long_binget => .{ .get = try reader.takeInt(u32, .little) }, - .put => put: { - const digits = try reader.takeDelimiterExclusive('\n'); - break :put .{ .put = std.fmt.parseInt(u32, digits, 10) catch std.math.maxInt(u32) }; - }, + .put => .{ .put = try readTextU32(reader) }, .binput => .{ .put = try reader.takeByte() }, .long_binput => .{ .put = try reader.takeInt(u32, .little) }, .memoize => .memoize, @@ -882,7 +880,7 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { .binpersid => .binpersid, _ => |unk_tag| { log.err("Unknow pickle operator {}, note we are only supporting pickle protocol up to version 5.", .{unk_tag}); - return error.NotSupported; + return error.UnknownPickleOp; }, }; try results.append(arena, op); @@ -891,154 +889,153 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { return results.items; } -//TODO(gwenzek): re-enable these tests when the bug has been fixed. -// test "parse protocol 4" { -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); +test "parse protocol 4" { + var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); + defer arena.deinit(); -// const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); -// var read_buffer: [1024]u8 = undefined; -// var reader = file.reader(&read_buffer); -// const ops = try parse(arena.allocator(), &reader.interface); + const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test_4.pickle", .{ .mode = .read_only }); + var read_buffer: [1024]u8 = undefined; + var reader = file.reader(&read_buffer); + const ops = try parse(arena.allocator(), &reader.interface); -// // this can be obtained by running: `python -m pickletools simple_test_4.pickle` -// const expected: []const Op = &.{ -// .{ .proto = 4 }, -// .{ .frame = 119 }, -// .empty_dict, -// .memoize, -// .mark, -// .{ .unicode = "hello" }, -// .memoize, -// .{ .unicode = "world" }, -// .memoize, -// .{ .unicode = "int" }, -// .memoize, -// .{ .int = 1 }, -// .{ .unicode = "float" }, -// .memoize, -// .{ .binfloat = 3.141592 }, -// .{ .unicode = "list" }, -// .memoize, -// .empty_list, -// .memoize, -// .mark, -// .{ .int = 255 }, -// .{ .int = 1234 }, -// .{ .int = -123 }, -// .{ .int = 1_000_000_000 }, -// .{ .binlong = &writeIntBuff(u48, 999_000_000_000) }, -// .{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) }, -// .appends, -// .{ .unicode = "bool" }, -// .memoize, -// .{ .bool = false }, -// .{ .unicode = "tuple" }, -// .memoize, -// .{ .unicode = "a" }, -// .memoize, -// .{ .int = 10 }, -// .tuple2, -// .memoize, -// .setitems, -// .stop, -// }; -// try std.testing.expectEqualDeep(expected, ops); -// } + // this can be obtained by running: `python -m pickletools simple_test_4.pickle` + const expected: []const Op = &.{ + .{ .proto = 4 }, + .{ .frame = 119 }, + .empty_dict, + .memoize, + .mark, + .{ .unicode = "hello" }, + .memoize, + .{ .unicode = "world" }, + .memoize, + .{ .unicode = "int" }, + .memoize, + .{ .int = 1 }, + .{ .unicode = "float" }, + .memoize, + .{ .binfloat = 3.141592 }, + .{ .unicode = "list" }, + .memoize, + .empty_list, + .memoize, + .mark, + .{ .int = 255 }, + .{ .int = 1234 }, + .{ .int = -123 }, + .{ .int = 1_000_000_000 }, + .{ .binlong = &writeIntBuff(u48, 999_000_000_000) }, + .{ .binlong = &writeIntBuff(u104, 999_000_000_000_000_000_000_000_000_000) }, + .appends, + .{ .unicode = "bool" }, + .memoize, + .{ .bool = false }, + .{ .unicode = "tuple" }, + .memoize, + .{ .unicode = "a" }, + .memoize, + .{ .int = 10 }, + .tuple2, + .memoize, + .setitems, + .stop, + }; + try std.testing.expectEqualDeep(expected, ops); +} -// test "parse protocol 0" { -// // We also test protocol 0, cause it's more text oriented. -// var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); -// defer arena.deinit(); +test "parse protocol 0" { + // We also test protocol 0, cause it's more text oriented. + var arena: std.heap.ArenaAllocator = .init(std.testing.allocator); + defer arena.deinit(); -// const pickle_0 = -// \\(dp0 -// \\Vhello -// \\p1 -// \\Vworld -// \\p2 -// \\sVint -// \\p3 -// \\I1 -// \\sVfloat -// \\p4 -// \\F3.141592 -// \\sVlist -// \\p5 -// \\(lp6 -// \\I255 -// \\aI1234 -// \\aI-123 -// \\aI1000000000 -// \\aL999000000000L -// \\aL999000000000000000000000000000L -// \\asVbool -// \\p7 -// \\I00 -// \\sVtuple -// \\p8 -// \\(Va -// \\p9 -// \\I10 -// \\tp10 -// \\s. -// ; + const pickle_0 = + \\(dp0 + \\Vhello + \\p1 + \\Vworld + \\p2 + \\sVint + \\p3 + \\I1 + \\sVfloat + \\p4 + \\F3.141592 + \\sVlist + \\p5 + \\(lp6 + \\I255 + \\aI1234 + \\aI-123 + \\aI1000000000 + \\aL999000000000L + \\aL999000000000000000000000000000L + \\asVbool + \\p7 + \\I00 + \\sVtuple + \\p8 + \\(Va + \\p9 + \\I10 + \\tp10 + \\s. + ; -// var reader: std.Io.Reader = .fixed(pickle_0); -// const ops = try parse(arena.allocator(), &reader); + var reader: std.Io.Reader = .fixed(pickle_0); + const ops = try parse(arena.allocator(), &reader); -// var expected = [_]Op{ -// .mark, -// .dict, -// .{ .put = 0 }, -// .{ .unicode = "hello" }, -// .{ .put = 1 }, -// .{ .unicode = "world" }, -// .{ .put = 2 }, -// .setitem, -// .{ .unicode = "int" }, -// .{ .put = 3 }, -// .{ .int = 1 }, -// .setitem, -// .{ .unicode = "float" }, -// .{ .put = 4 }, -// .{ .float = "3.141592" }, -// .setitem, -// .{ .unicode = "list" }, -// .{ .put = 5 }, -// .mark, -// .list, -// .{ .put = 6 }, -// .{ .int = 255 }, -// .append, -// .{ .int = 1234 }, -// .append, -// .{ .int = -123 }, -// .append, -// .{ .int = 1_000_000_000 }, -// .append, -// .{ .long = "999000000000L" }, -// .append, -// .{ .long = "999000000000000000000000000000L" }, -// .append, -// .setitem, -// .{ .unicode = "bool" }, -// .{ .put = 7 }, -// .{ .bool = false }, -// .setitem, -// .{ .unicode = "tuple" }, -// .{ .put = 8 }, -// .mark, -// .{ .unicode = "a" }, -// .{ .put = 9 }, -// .{ .int = 10 }, -// .tuple, -// .{ .put = 10 }, -// .setitem, -// .stop, -// }; -// try std.testing.expectEqualDeep(&expected, ops); -// } + var expected = [_]Op{ + .mark, + .dict, + .{ .put = 0 }, + .{ .unicode = "hello" }, + .{ .put = 1 }, + .{ .unicode = "world" }, + .{ .put = 2 }, + .setitem, + .{ .unicode = "int" }, + .{ .put = 3 }, + .{ .int = 1 }, + .setitem, + .{ .unicode = "float" }, + .{ .put = 4 }, + .{ .float = "3.141592" }, + .setitem, + .{ .unicode = "list" }, + .{ .put = 5 }, + .mark, + .list, + .{ .put = 6 }, + .{ .int = 255 }, + .append, + .{ .int = 1234 }, + .append, + .{ .int = -123 }, + .append, + .{ .int = 1_000_000_000 }, + .append, + .{ .long = "999000000000L" }, + .append, + .{ .long = "999000000000000000000000000000L" }, + .append, + .setitem, + .{ .unicode = "bool" }, + .{ .put = 7 }, + .{ .bool = false }, + .setitem, + .{ .unicode = "tuple" }, + .{ .put = 8 }, + .mark, + .{ .unicode = "a" }, + .{ .put = 9 }, + .{ .int = 10 }, + .tuple, + .{ .put = 10 }, + .setitem, + .stop, + }; + try std.testing.expectEqualDeep(&expected, ops); +} fn _readSlice(reader: anytype, allocator: std.mem.Allocator, comptime len_bytes: u8) ![]u8 { const T = std.meta.Int(.unsigned, 8 * len_bytes); @@ -1054,8 +1051,11 @@ fn writeIntBuff(comptime T: type, value: T) [@divExact(@typeInfo(T).int.bits, 8) return res; } -fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![]const u8 { - const n = try reader.streamDelimiter(&alloc_writer.writer, '\n'); +fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ParseError![]const u8 { + const n = reader.streamDelimiter(&alloc_writer.writer, '\n') catch |err| switch (err) { + error.WriteFailed => return error.OutOfMemory, + else => |e| return e, + }; std.debug.assert(try reader.takeByte() == '\n'); const w = &alloc_writer.writer; std.debug.assert(w.end == n); @@ -1064,3 +1064,9 @@ fn readLine(reader: *std.Io.Reader, alloc_writer: *std.Io.Writer.Allocating) ![] w.end = 0; return items; } + +fn readTextU32(reader: *std.Io.Reader) ParseError!u32 { + // Note we use takeDelimiterInclusive because the newline must always be there. + const digits = try reader.takeDelimiterInclusive('\n'); + return try std.fmt.parseInt(u32, digits[0 .. digits.len - 1], 10); +} diff --git a/zml/buffer.zig b/zml/buffer.zig index a4a9b19..e76039d 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -280,10 +280,9 @@ pub const Buffer = struct { }; } - pub fn devicePtr(self: Buffer) u64 { + pub fn devicePtr(self: Buffer) *anyopaque { stdx.debug.internalAssert(!self.hasShardedAxis(), "TODO: support sharded Buffer", .{}); - const opaque_ptr: *anyopaque = self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable; - return @intFromPtr(opaque_ptr); + return self._shards.get(0).getOpaqueDeviceMemoryDataPointer(self._api) catch unreachable; } /// Fetches the content of the given buffer into a stack variable of the given type. @@ -362,7 +361,7 @@ pub const Buffer = struct { } pub fn format(self: Buffer, writer: *std.Io.Writer) !void { - try writer.print("Buffer({f})@{x}", .{ self._shape, self.devicePtr() }); + try writer.print("Buffer({f})@{x}", .{ self._shape, @intFromPtr(self.devicePtr()) }); } pub fn getMemory(self: Buffer) *const pjrt.Memory { @@ -470,7 +469,7 @@ pub const Buffer = struct { const host_visible_memories: []const Memory = &.{ .host_pinned, .host_unpinned }; for (host_visible_memories) |memory| { const x = try uninitialized(platform, .init(.{6}, .u8), .{ .memory = memory }); - const x_ptr: [*]u8 = @ptrFromInt(x.devicePtr()); + const x_ptr: [*]u8 = @ptrCast(x.devicePtr()); @memcpy(x_ptr, &[_]u8{ 104, 101, 108, 108, 111, 33 }); const y = try x.getValue([6]u8); diff --git a/zml/callback.zig b/zml/callback.zig index ea7a4b2..709e6c8 100644 --- a/zml/callback.zig +++ b/zml/callback.zig @@ -197,12 +197,12 @@ fn CallbackImpl(comptime Callback: type, call_frame: *pjrt.ffi.CallFrame) ?*pjrt else .asViewOfDeviceBuffer(platform, shape, null, ffi_buffer.data); if (opts.copy_inputs_to_host_pinned and platform.target != .cpu) { - log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, zml_buffer.devicePtr() }); + // log.debug("Copying argument {d} {f} {x} to host_pinned memory !", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()) }); zml_buffer = zml_buffer.copyToMemory(platform, .host_pinned, .{ .wait = true }) catch |err| { - log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, zml_buffer.devicePtr(), err }); + log.err("Failed to copy input buffer {d} {f} {x} to host_pinned: {}", .{ i, zml_buffer, @intFromPtr(zml_buffer.devicePtr()), err }); return .create(call_frame.api, .resource_exhausted, "host pinned OOM"); }; - log.debug("--> {f} {x}", .{ zml_buffer, zml_buffer.devicePtr() }); + // log.debug("--> {f} {x}", .{ zml_buffer, @intFromPtr(zml_buffer.devicePtr()) }); } callback_args[i] = zml_buffer; } diff --git a/zml/tools/BUILD.bazel b/zml/tools/BUILD.bazel index ff8c1b0..21ad565 100644 --- a/zml/tools/BUILD.bazel +++ b/zml/tools/BUILD.bazel @@ -17,9 +17,8 @@ zig_library( main = "tools.zig", visibility = ["//visibility:public"], deps = select({ - "@platforms//os:macos": [ - ":macos_c", - ], + # TODO(cerisier): fix MacOsTracer + # "@platforms//os:macos": [ ":macos_c" ], "//conditions:default": [], }), ) diff --git a/zml/tools/tracer.zig b/zml/tools/tracer.zig index d369b1e..3946cb6 100644 --- a/zml/tools/tracer.zig +++ b/zml/tools/tracer.zig @@ -3,7 +3,8 @@ const builtin = @import("builtin"); const c = @import("c"); pub const Tracer = switch (builtin.os.tag) { - .macos => MacOsTracer, + // TODO(cerisier): fix MacOsTracer + // .macos => MacOsTracer, .linux => if (@hasDecl(c, "ZML_RUNTIME_CUDA")) CudaTracer else FakeTracer, else => FakeTracer, };