From bcde3962ce07570df0ce4333f489c2d049ae19e8 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 1 Aug 2023 11:35:04 +0000 Subject: [PATCH] =?UTF-8?q?Rework=20async=20runtime=20with=20coroutine=20s?= =?UTF-8?q?upport,=20rename=20async=20API=20(async=5F=E2=86=92asyncc,=20aw?= =?UTF-8?q?ait=5F=E2=86=92awaitt),=20improve=20type=20inference,=20bump=20?= =?UTF-8?q?libxev=20(default=20epoll)=20and=20update=20related=20stdx=20an?= =?UTF-8?q?d=20zml=20modules.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- BUILD.bazel | 10 + MODULE.bazel | 2 +- async/BUILD.bazel | 47 +-- async/queue_mpsc.zig | 116 +++++++ async/zigcoro.zig | 321 +++++++++++++----- bazel/zig.bzl | 22 +- stdx/BUILD.bazel | 1 + stdx/io.zig | 4 + stdx/stdx.zig | 3 +- .../libxev/20241119.0-6afcde9/MODULE.bazel | 7 + .../20241119.0-6afcde9/overlay/BUILD.bazel | 13 + .../20241119.0-6afcde9/overlay/MODULE.bazel | 7 + .../20241119.0-6afcde9/overlay/main2.zig | 22 ++ .../20241119.0-6afcde9/patches/128.patch | 119 +++++++ .../libxev/20241119.0-6afcde9/source.json | 14 + zml/BUILD.bazel | 7 +- zml/buffer.zig | 2 +- zml/pjrtx.zig | 15 +- 18 files changed, 587 insertions(+), 145 deletions(-) create mode 100644 async/queue_mpsc.zig create mode 100644 stdx/io.zig create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/MODULE.bazel create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/overlay/BUILD.bazel create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/overlay/MODULE.bazel create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/overlay/main2.zig create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/patches/128.patch create mode 100644 third_party/modules/libxev/20241119.0-6afcde9/source.json diff --git a/BUILD.bazel b/BUILD.bazel index e69de29..580e4db 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -0,0 +1,10 @@ +load("@rules_zig//zig:defs.bzl", "zls_completion") + +zls_completion( + name = "completion", + deps = [ + 
"//async", + "//stdx", + "//zml", + ], +) diff --git a/MODULE.bazel b/MODULE.bazel index e9a8b7c..4b027c8 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -54,7 +54,7 @@ use_repo(zls, "zls_aarch64-macos", "zls_x86_64-linux") register_toolchains("//third_party/zls:all") -bazel_dep(name = "libxev", version = "20240910.0-a2d9b31") +bazel_dep(name = "libxev", version = "20241119.0-6afcde9") bazel_dep(name = "llvm-raw", version = "20240919.0-94c024a") llvm = use_extension("@llvm-raw//utils/bazel:extension.bzl", "llvm") diff --git a/async/BUILD.bazel b/async/BUILD.bazel index 262466f..933b166 100644 --- a/async/BUILD.bazel +++ b/async/BUILD.bazel @@ -1,51 +1,16 @@ -load("@bazel_skylib//rules:common_settings.bzl", "string_flag") load("@rules_zig//zig:defs.bzl", "zig_library") -IMPL = [ - "threaded", - "zigcoro", -] - -string_flag( - name = "impl", - build_setting_default = "threaded", - values = IMPL, -) - -[ - config_setting( - name = "impl.{}".format(impl), - flag_values = {":impl": impl}, - ) - for impl in IMPL -] - zig_library( - name = "zigcoro", - srcs = ["meta.zig"], + name = "async", + srcs = [ + "queue_mpsc.zig", + ], import_name = "async", main = "zigcoro.zig", + visibility = ["//visibility:public"], deps = [ + "//stdx", "@libxev//:xev", "@zigcoro//:libcoro", ], ) - -zig_library( - name = "threaded", - srcs = ["meta.zig"], - import_name = "async", - main = "threaded.zig", - deps = [ - "@libxev//:xev", - ], -) - -alias( - name = "async", - actual = select({ - ":impl.threaded": ":threaded", - ":impl.zigcoro": ":zigcoro", - }), - visibility = ["//visibility:public"], -) diff --git a/async/queue_mpsc.zig b/async/queue_mpsc.zig new file mode 100644 index 0000000..68f4e8f --- /dev/null +++ b/async/queue_mpsc.zig @@ -0,0 +1,116 @@ +const std = @import("std"); +const assert = std.debug.assert; + +/// An intrusive MPSC (multi-provider, single consumer) queue implementation. +/// The type T must have a field "next" of type `?*T`. 
+/// +/// This is an implementation of a Vyukov Queue[1]. +/// TODO(mitchellh): I haven't audited yet if I got all the atomic operations +/// correct. I was short term more focused on getting something that seemed +/// to work; I need to make sure it actually works. +/// +/// For those unaware, an intrusive variant of a data structure is one in which +/// the data type in the list has the pointer to the next element, rather +/// than a higher level "node" or "container" type. The primary benefit +/// of this (and the reason we implement this) is that it defers all memory +/// management to the caller: the data structure implementation doesn't need +/// to allocate "nodes" to contain each element. Instead, the caller provides +/// the element and how it's allocated is up to them. +/// +/// [1]: https://www.1024cores.net/home/lock-free-algorithms/queues/intrusive-mpsc-node-based-queue +pub fn Intrusive(comptime T: type) type { + return struct { + const Self = @This(); + + /// Head is the producer end (most recently pushed element) and tail is the consumer end (next element to pop). + head: *T, + tail: *T, + stub: T, + + /// Initialize the queue. This requires a stable pointer to itself. + /// This must be called before the queue is used concurrently. + pub fn init(self: *Self) void { + self.head = &self.stub; + self.tail = &self.stub; + self.stub.next = null; + } + + /// Push an item onto the queue. This can be called by any number + /// of producers. + pub fn push(self: *Self, v: *T) void { + @atomicStore(?*T, &v.next, null, .unordered); + const prev = @atomicRmw(*T, &self.head, .Xchg, v, .acq_rel); + @atomicStore(?*T, &prev.next, v, .release); + } + + /// Pop the first element from the queue. This must be called + /// by only a single consumer at any given time.
+ pub fn pop(self: *Self) ?*T { + var tail = @atomicLoad(*T, &self.tail, .unordered); + var next_ = @atomicLoad(?*T, &tail.next, .acquire); + if (tail == &self.stub) { + const next = next_ orelse return null; + @atomicStore(*T, &self.tail, next, .unordered); + tail = next; + next_ = @atomicLoad(?*T, &tail.next, .acquire); + } + + if (next_) |next| { + @atomicStore(*T, &self.tail, next, .release); + tail.next = null; + return tail; + } + + const head = @atomicLoad(*T, &self.head, .unordered); + if (tail != head) return null; + self.push(&self.stub); + + next_ = @atomicLoad(?*T, &tail.next, .acquire); + if (next_) |next| { + @atomicStore(*T, &self.tail, next, .unordered); + tail.next = null; + return tail; + } + + return null; + } + }; +} + +test Intrusive { + const testing = std.testing; + + // Types + const Elem = struct { + const Self = @This(); + next: ?*Self = null, + }; + const Queue = Intrusive(Elem); + var q: Queue = undefined; + q.init(); + + // Elems + var elems: [10]Elem = .{.{}} ** 10; + + // One + try testing.expect(q.pop() == null); + q.push(&elems[0]); + try testing.expect(q.pop().? == &elems[0]); + try testing.expect(q.pop() == null); + + // Two + try testing.expect(q.pop() == null); + q.push(&elems[0]); + q.push(&elems[1]); + try testing.expect(q.pop().? == &elems[0]); + try testing.expect(q.pop().? == &elems[1]); + try testing.expect(q.pop() == null); + + // // Interleaved + try testing.expect(q.pop() == null); + q.push(&elems[0]); + try testing.expect(q.pop().? == &elems[0]); + q.push(&elems[1]); + try testing.expect(q.pop().? 
== &elems[1]); + try testing.expect(q.pop() == null); +} diff --git a/async/zigcoro.zig b/async/zigcoro.zig index bf8b959..13fdbf3 100644 --- a/async/zigcoro.zig +++ b/async/zigcoro.zig @@ -1,123 +1,202 @@ const std = @import("std"); +const stdx = @import("stdx"); const xev = @import("xev"); const libcoro = @import("libcoro"); const aio = libcoro.asyncio; - -const FnSignature = @import("meta.zig").FnSignature; +const queue_mpsc = @import("queue_mpsc.zig"); pub fn Frame(comptime func: anytype) type { - const Signature = FnSignature(func, null); - return FrameEx(func, Signature.ArgsT); + const Signature = stdx.meta.FnSignature(func, null); + return FrameExx(func, Signature.ArgsT, Signature.ReturnT); } pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { - return FrameExx(func, argsT); + const Signature = stdx.meta.FnSignature(func, argsT); + return FrameExx(func, Signature.ArgsT, Signature.ReturnT); } -fn FrameExx(comptime func: anytype, comptime argsT: type) type { +fn FrameExx(comptime func: anytype, comptime argsT: type, comptime returnT: type) type { return struct { const Self = @This(); - const Signature = FnSignature(func, argsT); - const FrameT = libcoro.FrameT(func, .{ .ArgsT = Signature.ArgsT }); + const FrameT = libcoro.FrameT(func, .{ .ArgsT = argsT }); inner: FrameT, - pub fn await_(self: *Self) Signature.ReturnT { + pub const wait = await_; + pub const await_ = awaitt; + pub fn awaitt(self: *Self) returnT { defer { self.inner.deinit(); self.* = undefined; } return libcoro.xawait(self.inner); } - - fn from(other: anytype) !Self { - return .{ .inner = FrameT.wrap(other.frame()) }; - } }; } -pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) { - return asyncGeneric(func, args); +pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { + const Signature = stdx.meta.FnSignature(func, @TypeOf(args)); + return .{ + .inner = try aio.xasync(func, @as(Signature.ArgsT, 
args), null), + }; } -pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { - const frame = try aio.xasync(func, args, null); - return FrameEx(func, @TypeOf(args)).from(frame); -} - -pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) { - return callBlockingGeneric(func, args); -} - -pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { - const Signature = FnSignature(func, @TypeOf(args)); +pub fn callBlocking(comptime func: anytype, args: anytype) stdx.meta.FnSignature(func, @TypeOf(args)).ReturnT { + const Signature = stdx.meta.FnSignature(func, @TypeOf(args)); const TaskT = struct { const Self = @This(); _task: xev.ThreadPool.Task = .{ .callback = &Self.run }, - notif: Notification, - args: *const Signature.ArgsT, + event: threading.ResetEventSingle = .{}, + args: Signature.ArgsT, result: Signature.ReturnT = undefined, pub fn run(task_: *xev.ThreadPool.Task) void { const task: *Self = @alignCast(@fieldParentPtr("_task", task_)); - task.result = @call(.auto, func, task.args.*); - task.notif.notify() catch @panic("Unable to notify"); + task.result = @call(.auto, func, task.args); + task.event.set(); } }; var newtask: TaskT = .{ - .notif = Notification.init() catch @panic("Notification.init failed"), - .args = &args, + .args = args, }; - defer newtask.notif.deinit(); - AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&newtask._task)); - newtask.notif.wait() catch @panic("Unable to wait for notification"); - return newtask.result; -} + newtask.event.wait(); -pub fn tick() void { - AsyncThread.current.executor.exec.tick(); + return newtask.result; } pub fn sleep(ms: u64) !void { try aio.sleep(AsyncThread.current.executor, ms); } -pub const Notification = struct { - inner: aio.AsyncNotification, +pub const threading = struct { + const Waiter = struct { + frame: libcoro.Frame, + thread: 
*const AsyncThread, + next: ?*Waiter = null, + }; - pub fn init() !Notification { + const WaiterQueue = queue_mpsc.Intrusive(Waiter); + + pub const ResetEventSingle = struct { + const State = union(enum) { + unset, + waiting: *Waiter, + set, + + const unset_state: State = .unset; + const set_state: State = .set; + }; + + waiter: std.atomic.Value(*const State) = std.atomic.Value(*const State).init(&State.unset_state), + + pub fn isSet(self: *ResetEventSingle) bool { + return self.waiter.load(.acquire) == &State.set_state; + } + + pub fn reset(self: *ResetEventSingle) void { + self.waiter.store(&State.unset_state, .monotonic); + } + + pub fn set(self: *ResetEventSingle) void { + switch (self.waiter.swap(&State.set_state, .acq_rel).*) { + .waiting => |waiter| { + waiter.thread.waiters_queue.push(waiter); + waiter.thread.wake(); + }, + else => {}, + } + } + + pub fn wait(self: *ResetEventSingle) void { + var waiter: Waiter = .{ + .frame = libcoro.xframe(), + .thread = AsyncThread.current, + }; + var new_state: State = .{ + .waiting = &waiter, + }; + if (self.waiter.cmpxchgStrong(&State.unset_state, &new_state, .acq_rel, .acquire) == null) { + libcoro.xsuspend(); + } + } + }; +}; + +pub const FrameAllocator = struct { + const Item = [1 * 1024 * 1024]u8; + const FramePool = std.heap.MemoryPool(Item); + + pool: FramePool, + + pub fn init(allocator_: std.mem.Allocator) !FrameAllocator { return .{ - .inner = aio.AsyncNotification.init(AsyncThread.current.executor, try xev.Async.init()), + .pool = FramePool.init(allocator_), }; } - pub fn notify(self: *Notification) !void { - try self.inner.notif.notify(); + pub fn allocator(self: *FrameAllocator) std.mem.Allocator { + return .{ + .ptr = self, + .vtable = &.{ + .alloc = alloc, + .resize = resize, + .free = free, + }, + }; } - pub fn wait(self: *Notification) !void { - try self.inner.wait(); + fn alloc(ctx: *anyopaque, len: usize, ptr_align: u8, ret_addr: usize) ?[*]u8 { + _ = ptr_align; + _ =
ret_addr; + stdx.debug.assert(len <= Item.len, "Should always pass a length of less than {d} bytes", .{Item.len}); + const self: *FrameAllocator = @ptrCast(@alignCast(ctx)); + const stack = self.pool.create() catch return null; + return @ptrCast(stack); } - pub fn deinit(self: *Notification) void { - self.inner.notif.deinit(); - self.* = undefined; + fn resize(ctx: *anyopaque, buf: []u8, buf_align: u8, new_len: usize, ret_addr: usize) bool { + _ = ctx; + _ = buf; + _ = buf_align; + _ = ret_addr; + return new_len <= Item.len; + } + + fn free(ctx: *anyopaque, buf: []u8, buf_align: u8, ret_addr: usize) void { + _ = buf_align; + _ = ret_addr; + const self: *FrameAllocator = @ptrCast(@alignCast(ctx)); + const v: *align(8) Item = @ptrCast(@alignCast(buf.ptr)); + self.pool.destroy(v); } }; pub const AsyncThread = struct { - threadlocal var current: AsyncThread = undefined; + threadlocal var current: *const AsyncThread = undefined; executor: *aio.Executor, loop: *xev.Loop, thread_pool: *xev.ThreadPool, + async_notifier: *xev.Async, + waiters_queue: *threading.WaiterQueue, - pub fn main(allocator: std.mem.Allocator, comptime func: anytype, args: anytype) !void { + pub fn wake(self: *const AsyncThread) void { + self.async_notifier.notify() catch {}; + } + + fn waker_cb(q: ?*threading.WaiterQueue, _: *xev.Loop, _: *xev.Completion, _: xev.Async.WaitError!void) xev.CallbackAction { + while (q.?.pop()) |waiter| { + libcoro.xresume(waiter.frame); + } + return .rearm; + } + + pub fn main(allocator: std.mem.Allocator, comptime mainFunc: fn () anyerror!void) !void { var thread_pool = xev.ThreadPool.init(.{}); defer { thread_pool.shutdown(); @@ -127,42 +206,54 @@ pub const AsyncThread = struct { var loop = try xev.Loop.init(.{ .thread_pool = &thread_pool, }); + defer loop.deinit(); var executor = aio.Executor.init(&loop); - AsyncThread.current = .{ - .executor = &executor, - .loop = &loop, - .thread_pool = &thread_pool, - }; + var async_notifier = try xev.Async.init(); + defer 
async_notifier.deinit(); + + var waiters_queue: threading.WaiterQueue = undefined; + waiters_queue.init(); + + var c: xev.Completion = undefined; + async_notifier.wait(&loop, &c, threading.WaiterQueue, &waiters_queue, &waker_cb); aio.initEnv(.{ .stack_allocator = allocator, - .default_stack_size = 16 * 1024 * 1024, + .default_stack_size = 1 * 1024 * 1024, }); - return try aio.run(&executor, func, args, null); + AsyncThread.current = &.{ + .executor = &executor, + .loop = &loop, + .thread_pool = &thread_pool, + .async_notifier = &async_notifier, + .waiters_queue = &waiters_queue, + }; + + return try aio.run(&executor, mainFunc, .{}, null); } }; -pub fn StdIn() !File { +pub fn getStdIn() !File { return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); } -pub fn StdOut() File { +pub fn getStdOut() File { return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout"); } -pub fn StdErr() File { +pub fn getStdErr() File { return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr"); } pub const File = struct { - pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?; - pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?; - pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read); - pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write); + pub const SeekError = stdx.meta.FnSignature(File.seekTo, null).ReturnErrorSet.? 
|| stdx.meta.FnSignature(File.seekBy, null).ReturnErrorSet.?; + pub const GetSeekPosError = SeekError || stdx.meta.FnSignature(File.stat, null).ReturnErrorSet.?; + pub const Reader = std.io.GenericReader(File, stdx.meta.FnSignature(File.read, null).ReturnErrorSet.?, File.read); + pub const Writer = std.io.GenericWriter(File, stdx.meta.FnSignature(File.write, null).ReturnErrorSet.?, File.write); pub const SeekableStream = std.io.SeekableStream( File, SeekError, @@ -187,10 +278,6 @@ pub const File = struct { return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)) }; } - pub fn fromFd(fd: std.fs.File.Handle) !File { - return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.initFd(fd)) }; - } - pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File { return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags })); } @@ -201,7 +288,9 @@ pub const File = struct { pub fn read(self: File, buf: []u8) !usize { // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend. - if (buf.len == 0) return 0; + if (buf.len == 0) { + return 0; + } return self.inner.read(.{ .slice = buf }) catch |err| switch (err) { // NOTE(Corentin): read shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File. @@ -212,7 +301,9 @@ pub const File = struct { pub fn pread(self: File, buf: []u8, offset: u64) !usize { // NOTE(Corentin): Early return is required to avoid error with xev on Linux with io_uring backend. - if (buf.len == 0) return 0; + if (buf.len == 0) { + return 0; + } return self.inner.pread(.{ .slice = buf }, offset) catch |err| switch (err) { // NOTE(Corentin): pread shouldn't return an error on EOF, but a read length of 0 instead. This is to be iso with std.fs.File. 
@@ -267,53 +358,86 @@ pub const File = struct { }; pub const Socket = struct { + pub fn Listener(comptime T: type) type { + return struct { + const Self = @This(); + + inner: T.Inner, + + pub fn accept(self: *Self) !T { + return .{ .inner = try self.inner.accept() }; + } + + pub fn close(self: *Self) !void { + return self.inner.close(); + } + + pub fn deinit(self: *Self) !void { + self.inner.shutdown(); + } + }; + } + pub const TCP = struct { - pub const Reader = std.io.GenericReader(TCP, FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read); - pub const Writer = std.io.GenericWriter(TCP, FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write); + const Inner = aio.TCP; + + pub const Reader = std.io.GenericReader(TCP, stdx.meta.FnSignature(TCP.read, null).ReturnErrorSet.?, TCP.read); + pub const Writer = std.io.GenericWriter(TCP, stdx.meta.FnSignature(TCP.write, null).ReturnErrorSet.?, TCP.write); inner: aio.TCP, - pub fn init(addr: std.net.Address) !TCP { - return .{ .inner = aio.TCP.init(AsyncThread.current.executor, try xev.TCP.init(addr)) }; + pub fn listen(addr: std.net.Address) !Listener(TCP) { + var self: Listener(TCP) = .{ + .inner = aio.TCP.init(AsyncThread.current.executor, try xev.TCP.init(addr)), + }; + try self.inner.tcp.bind(addr); + try self.inner.tcp.listen(1024); + return self; } pub fn deinit(self: *TCP) void { self.inner.shutdown(); } + pub fn accept(self: *TCP) !TCP { + return .{ .inner = try self.inner.accept() }; + } + pub fn connect(self: *TCP, addr: std.net.Address) !void { return self.inner.connect(addr); } - pub fn read(self: *TCP, buf: []u8) !usize { + pub fn read(self: TCP, buf: []u8) !usize { return self.inner.read(.{ .slice = buf }); } - pub fn write(self: *TCP, buf: []const u8) !usize { + pub fn write(self: TCP, buf: []const u8) !usize { return self.inner.write(.{ .slice = buf }); } - pub fn close(self: *TCP) !void { - defer self.* = undefined; + pub fn close(self: TCP) !void { + // defer self.* = undefined; return 
self.inner.close(); } - pub fn reader(self: File) Reader { + pub fn reader(self: TCP) Reader { return .{ .context = self }; } - pub fn writer(self: File) Writer { + pub fn writer(self: TCP) Writer { return .{ .context = self }; } }; pub const UDP = struct { - pub const Reader = std.io.GenericReader(UDP, FnSignature(UDP.read, null).ReturnErrorSet.?, UDP.read); + const Inner = aio.UDP; + + pub const Reader = std.io.GenericReader(UDP, stdx.meta.FnSignature(UDP.read, null).ReturnErrorSet.?, UDP.read); pub const WriterContext = struct { file: UDP, addr: std.net.Address, }; - pub const Writer = std.io.GenericWriter(WriterContext, FnSignature(UDP.write, null).ReturnErrorSet.?, struct { + pub const Writer = std.io.GenericWriter(WriterContext, stdx.meta.FnSignature(UDP.write, null).ReturnErrorSet.?, struct { fn callBlocking(self: WriterContext, buf: []const u8) !usize { return self.file.write(self.addr, buf); } @@ -321,8 +445,13 @@ pub const Socket = struct { inner: aio.UDP, - pub fn init(addr: std.net.Address) !UDP { - return .{ .inner = aio.UDP.init(AsyncThread.current.executor, try xev.UDP.init(addr)) }; + pub fn listen(addr: std.net.Address) !Listener(UDP) { + var self: Listener(UDP) = .{ + .inner = aio.UDP.init(AsyncThread.current.executor, try xev.UDP.init(addr)), + }; + try self.inner.udp.bind(addr); + try self.inner.udp.listen(1024); + return self; } pub fn read(self: UDP, buf: []u8) !usize { @@ -370,3 +499,23 @@ pub const Mutex = struct { _ = self.inner.recv(); } }; + +pub fn logFn( + comptime message_level: std.log.Level, + comptime scope: @Type(.EnumLiteral), + comptime format: []const u8, + args: anytype, +) void { + const level_txt = comptime message_level.asText(); + const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): "; + const stderr = getStdErr().writer(); + var bw = std.io.bufferedWriter(stderr); + const writer = bw.writer(); + + std.debug.lockStdErr(); + defer std.debug.unlockStdErr(); + nosuspend { + writer.print(level_txt ++
prefix2 ++ format ++ "\n", args) catch return; + bw.flush() catch return; + } +} diff --git a/bazel/zig.bzl b/bazel/zig.bzl index ac7da82..7d5fe4b 100644 --- a/bazel/zig.bzl +++ b/bazel/zig.bzl @@ -1,9 +1,18 @@ load("@rules_zig//zig:defs.bzl", "BINARY_KIND", "zig_binary") -def zig_cc_binary(name, args = None, env = None, data = [], deps = [], visibility = None, **kwargs): +def zig_cc_binary( + name, + copts = [], + args = None, + env = None, + data = [], + deps = [], + visibility = None, + **kwargs): zig_binary( name = "{}_lib".format(name), kind = BINARY_KIND.static_lib, + copts = copts + ["-lc", "-fcompiler-rt"], deps = deps + [ "@rules_zig//zig/lib:libc", ], @@ -18,11 +27,20 @@ def zig_cc_binary(name, args = None, env = None, data = [], deps = [], visibilit visibility = visibility, ) -def zig_cc_test(name, env = None, data = [], deps = [], test_runner = None, visibility = None, **kwargs): +def zig_cc_test( + name, + copts = [], + env = None, + data = [], + deps = [], + test_runner = None, + visibility = None, + **kwargs): zig_binary( name = "{}_test_lib".format(name), kind = BINARY_KIND.test_lib, test_runner = test_runner, + copts = copts + ["-lc", "-fcompiler-rt"], deps = deps + [ "@rules_zig//zig/lib:libc", ], diff --git a/stdx/BUILD.bazel b/stdx/BUILD.bazel index 26c190e..56b9683 100644 --- a/stdx/BUILD.bazel +++ b/stdx/BUILD.bazel @@ -4,6 +4,7 @@ zig_library( name = "stdx", srcs = [ "debug.zig", + "io.zig", "math.zig", "meta.zig", "signature.zig", diff --git a/stdx/io.zig b/stdx/io.zig new file mode 100644 index 0000000..390647b --- /dev/null +++ b/stdx/io.zig @@ -0,0 +1,4 @@ +const std = @import("std"); + +pub const BufferedAnyWriter = std.io.BufferedWriter(4096, std.io.AnyWriter); +pub const BufferedAnyReader = std.io.BufferedReader(4096, std.io.AnyReader); diff --git a/stdx/stdx.zig b/stdx/stdx.zig index b8447fe..820e7aa 100644 --- a/stdx/stdx.zig +++ b/stdx/stdx.zig @@ -1,3 +1,4 @@ +pub const debug = @import("debug.zig"); +pub const io = 
@import("io.zig"); pub const math = @import("math.zig"); pub const meta = @import("meta.zig"); -pub const debug = @import("debug.zig"); diff --git a/third_party/modules/libxev/20241119.0-6afcde9/MODULE.bazel b/third_party/modules/libxev/20241119.0-6afcde9/MODULE.bazel new file mode 100644 index 0000000..c9b0306 --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/MODULE.bazel @@ -0,0 +1,7 @@ +module( + name = "libxev", + version = "20241119.0-6afcde9", + compatibility_level = 1, +) + +bazel_dep(name = "rules_zig", version = "20240904.0-010da15") diff --git a/third_party/modules/libxev/20241119.0-6afcde9/overlay/BUILD.bazel b/third_party/modules/libxev/20241119.0-6afcde9/overlay/BUILD.bazel new file mode 100644 index 0000000..7c0c4ce --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/overlay/BUILD.bazel @@ -0,0 +1,13 @@ +load("@rules_zig//zig:defs.bzl", "zig_library") + +zig_library( + name = "xev", + srcs = glob([ + "src/*.zig", + "src/backend/*.zig", + "src/linux/*.zig", + "src/watcher/*.zig", + ]), + main = "main2.zig", + visibility = ["//visibility:public"], +) diff --git a/third_party/modules/libxev/20241119.0-6afcde9/overlay/MODULE.bazel b/third_party/modules/libxev/20241119.0-6afcde9/overlay/MODULE.bazel new file mode 100644 index 0000000..c9b0306 --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/overlay/MODULE.bazel @@ -0,0 +1,7 @@ +module( + name = "libxev", + version = "20241119.0-6afcde9", + compatibility_level = 1, +) + +bazel_dep(name = "rules_zig", version = "20240904.0-010da15") diff --git a/third_party/modules/libxev/20241119.0-6afcde9/overlay/main2.zig b/third_party/modules/libxev/20241119.0-6afcde9/overlay/main2.zig new file mode 100644 index 0000000..4166925 --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/overlay/main2.zig @@ -0,0 +1,22 @@ +const builtin = @import("builtin"); +const root = @import("root"); + +const main = @import("src/main.zig"); + +pub const ThreadPool = 
main.ThreadPool; +pub const stream = main.stream; + +pub const Options = struct { + linux_backend: main.Backend = .epoll, +}; + +pub const options: Options = if (@hasDecl(root, "xev_options")) root.xev_options else .{}; + +const default: main.Backend = switch (builtin.os.tag) { + .ios, .macos => .kqueue, + .linux => options.linux_backend, + .wasi => .wasi_poll, + .windows => .iocp, + else => @compileError("Unsupported OS"), +}; +pub usingnamespace default.Api(); diff --git a/third_party/modules/libxev/20241119.0-6afcde9/patches/128.patch b/third_party/modules/libxev/20241119.0-6afcde9/patches/128.patch new file mode 100644 index 0000000..6643419 --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/patches/128.patch @@ -0,0 +1,119 @@ +From 0d1c2f8258072148459d3114b9ccaf43c02e0958 Mon Sep 17 00:00:00 2001 +From: Steeve Morin +Date: Tue, 19 Nov 2024 16:14:14 +0100 +Subject: [PATCH] backend/epoll: implement eventfd wakeup notification + +Tries to mimic what happens in backend/kqueue. + +Closes #4 +--- + src/backend/epoll.zig | 42 ++++++++++++++++++++++++++++++++++++++++++ + 1 file changed, 42 insertions(+) + +diff --git a/src/backend/epoll.zig b/src/backend/epoll.zig +index ae4ec7d..f44d326 100644 +--- a/src/backend/epoll.zig ++++ b/src/backend/epoll.zig +@@ -21,6 +21,12 @@ pub const Loop = struct { + + fd: posix.fd_t, + ++ /// The eventfd that this epoll queue always has a filter for. Writing ++ /// an empty message to this eventfd can be used to wake up the loop ++ /// at any time. Waking up the loop via this eventfd won't trigger any ++ /// particular completion, it just forces tick to cycle. ++ eventfd: xev.Async, ++ + /// The number of active completions. This DOES NOT include completions that + /// are queued in the submissions queue. 
+ active: usize = 0, +@@ -56,8 +62,12 @@ pub const Loop = struct { + } = .{}, + + pub fn init(options: xev.Options) !Loop { ++ var eventfd = try xev.Async.init(); ++ errdefer eventfd.deinit(); ++ + var res: Loop = .{ + .fd = try posix.epoll_create1(std.os.linux.EPOLL.CLOEXEC), ++ .eventfd = eventfd, + .thread_pool = options.thread_pool, + .thread_pool_completions = undefined, + .cached_now = undefined, +@@ -68,6 +78,7 @@ pub const Loop = struct { + + pub fn deinit(self: *Loop) void { + posix.close(self.fd); ++ self.eventfd.deinit(); + } + + /// Run the event loop. See RunMode documentation for details on modes. +@@ -262,9 +273,26 @@ pub const Loop = struct { + // Initialize + if (!self.flags.init) { + self.flags.init = true; ++ + if (self.thread_pool != null) { + self.thread_pool_completions.init(); + } ++ ++ var ev: linux.epoll_event = .{ ++ .events = linux.EPOLL.IN | linux.EPOLL.RDHUP, ++ .data = .{ .ptr = 0 }, ++ }; ++ posix.epoll_ctl( ++ self.fd, ++ linux.EPOLL.CTL_ADD, ++ self.eventfd.fd, ++ &ev, ++ ) catch |err| { ++ // We reset initialization because we can't do anything ++ // safely unless we get this mach port registered! ++ self.flags.init = false; ++ return err; ++ }; + } + + // Submit all the submissions. We copy the submission queue so that +@@ -369,6 +397,10 @@ pub const Loop = struct { + + // Process all our events and invoke their completion handlers + for (events[0..n]) |ev| { ++ // Zero data values are internal events that we do nothing ++ // on such as the eventfd wakeup. 
++ if (ev.data.ptr == 0) continue; ++ + const c: *Completion = @ptrFromInt(@as(usize, @intCast(ev.data.ptr))); + + // We get the fd and mark this as in progress we can properly +@@ -415,6 +447,7 @@ pub const Loop = struct { + const pool = self.thread_pool orelse return error.ThreadPoolRequired; + + // Setup our completion state so that thread_perform can do stuff ++ c.task_loop = self; + c.task_completions = &self.thread_pool_completions; + c.task = .{ .callback = Loop.thread_perform }; + +@@ -436,6 +469,14 @@ pub const Loop = struct { + + // Add to our completion queue + c.task_completions.push(c); ++ ++ // Wake up our main loop ++ c.task_loop.wakeup() catch {}; ++ } ++ ++ /// Sends an empty message to this loop's eventfd so that it wakes up. ++ fn wakeup(self: *Loop) !void { ++ try self.eventfd.notify(); + } + + fn start(self: *Loop, completion: *Completion) void { +@@ -800,6 +841,7 @@ pub const Completion = struct { + /// reliable way to get access to the loop and shouldn't be used + /// except internally. 
+ task: ThreadPool.Task = undefined, ++ task_loop: *Loop = undefined, + task_completions: *Loop.TaskCompletionQueue = undefined, + task_result: Result = undefined, + diff --git a/third_party/modules/libxev/20241119.0-6afcde9/source.json b/third_party/modules/libxev/20241119.0-6afcde9/source.json new file mode 100644 index 0000000..de3c241 --- /dev/null +++ b/third_party/modules/libxev/20241119.0-6afcde9/source.json @@ -0,0 +1,14 @@ +{ + "strip_prefix": "libxev-690c76fd792f001c5776716f1e7b04be2cc50b52", + "url": "https://github.com/zml/libxev/archive/690c76fd792f001c5776716f1e7b04be2cc50b52.tar.gz", + "integrity": "sha256-DV66ic8PcRnG3EdimswCleiHo/dDztgebz/1EY5XDXg=", + "overlay": { + "MODULE.bazel": "", + "BUILD.bazel": "", + "main2.zig": "" + }, + "patches": { + "128.patch": "" + }, + "patch_strip": 1 +} diff --git a/zml/BUILD.bazel b/zml/BUILD.bazel index d44cbd0..39a8fd7 100644 --- a/zml/BUILD.bazel +++ b/zml/BUILD.bazel @@ -1,5 +1,5 @@ load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar") -load("@rules_zig//zig:defs.bzl", "zig_library", "zls_completion") +load("@rules_zig//zig:defs.bzl", "zig_library") load("//bazel:zig.bzl", "zig_cc_test") load("//bazel:zig_proto_library.bzl", "zig_proto_library") @@ -40,11 +40,6 @@ zig_library( ], ) -zls_completion( - name = "completion", - deps = [":zml"], -) - zig_proto_library( name = "xla_proto", import_name = "//xla:xla_proto", diff --git a/zml/buffer.zig b/zml/buffer.zig index 8a59aad..1fffbaf 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -82,7 +82,7 @@ pub const Buffer = struct { } for (frames.slice()) |*frame| { - const pjrt_buffer = try frame.await_(); + const pjrt_buffer = try frame.awaitt(); res._shards.appendAssumeCapacity(pjrt_buffer); } return res; diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index d50724f..b2f4b4f 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -137,23 +137,24 @@ pub const Event = opaque { pub fn await_(self: *Event, api: *const Api) !void { defer self.deinit(api); + if 
(self.isReady(api)) { + return; + } + var ctx = struct { err: ?*pjrt.Error = null, - notif: asynk.Notification, - }{ - .notif = try asynk.Notification.init(), - }; - defer ctx.notif.deinit(); + event: asynk.threading.ResetEventSingle = .{}, + }{}; try self.inner().onReady(api, &(struct { fn call(err: ?*pjrt.Error, user_arg: ?*anyopaque) callconv(.C) void { const ctx_: *@TypeOf(ctx) = @ptrCast(@alignCast(user_arg.?)); ctx_.err = err; - ctx_.notif.notify() catch @panic("Unable to notify"); + ctx_.event.set(); } }.call), &ctx); + ctx.event.wait(); - try ctx.notif.wait(); if (ctx.err) |e| { defer e.deinit(api); return e.getCode(api).toApiError();