From c68ec4bc5c53876e41f01f21fe818c689ef8ec12 Mon Sep 17 00:00:00 2001
From: Tarry Singh
Date: Thu, 25 May 2023 16:02:11 +0000
Subject: [PATCH] async: implement default threaded backend using a thread pool.

Backend selectable via @zml//async:impl flag (threaded or zigcoro).
Provides workaround for environments where io_uring is unavailable.
---
 async/BUILD.bazel                |  46 ++++++-
 async/meta.zig                   |  35 +++++
 async/threaded.zig               | 228 +++++++++++++++++++++++++++++++
 async/{async.zig => zigcoro.zig} |  51 ++-----
 zml/aio.zig                      |   2 +-
 zml/module.zig                   |  10 +-
 6 files changed, 324 insertions(+), 48 deletions(-)
 create mode 100644 async/meta.zig
 create mode 100644 async/threaded.zig
 rename async/{async.zig => zigcoro.zig} (87%)

diff --git a/async/BUILD.bazel b/async/BUILD.bazel
index d7ac4a5..262466f 100644
--- a/async/BUILD.bazel
+++ b/async/BUILD.bazel
@@ -1,11 +1,51 @@
+load("@bazel_skylib//rules:common_settings.bzl", "string_flag")
 load("@rules_zig//zig:defs.bzl", "zig_library")
 
+IMPL = [
+    "threaded",
+    "zigcoro",
+]
+
+string_flag(
+    name = "impl",
+    build_setting_default = "threaded",
+    values = IMPL,
+)
+
+[
+    config_setting(
+        name = "impl.{}".format(impl),
+        flag_values = {":impl": impl},
+    )
+    for impl in IMPL
+]
+
 zig_library(
-    name = "async",
-    main = "async.zig",
-    visibility = ["//visibility:public"],
+    name = "zigcoro",
+    srcs = ["meta.zig"],
+    import_name = "async",
+    main = "zigcoro.zig",
     deps = [
+        "@libxev//:xev",
         "@zigcoro//:libcoro",
+    ],
+)
+
+zig_library(
+    name = "threaded",
+    srcs = ["meta.zig"],
+    import_name = "async",
+    main = "threaded.zig",
+    deps = [
         "@libxev//:xev",
     ],
 )
+
+alias(
+    name = "async",
+    actual = select({
+        ":impl.threaded": ":threaded",
+        ":impl.zigcoro": ":zigcoro",
+    }),
+    visibility = ["//visibility:public"],
+)
diff --git a/async/meta.zig b/async/meta.zig
new file mode 100644
index 0000000..aba974b
--- /dev/null
+++ b/async/meta.zig
@@ -0,0 +1,35 @@
+const std = @import("std");
+
+/// Comptime description of the signature of `func`.
+///
+/// `func` is a function or a function type. A non-null `argsT` replaces the
+/// argument tuple type otherwise derived with `std.meta.ArgsTuple`.
+pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
+    return struct {
+        /// Function type of `func`.
+        pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
+        /// Argument tuple type; the empty tuple for functions without parameters.
+        pub const ArgsT = blk: {
+            if (@typeInfo(FuncT).Fn.params.len == 0) {
+                break :blk @TypeOf(.{});
+            }
+            break :blk argsT orelse std.meta.ArgsTuple(FuncT);
+        };
+        /// Return type of calling `func` with `ArgsT`.
+        pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
+        /// `ReturnT` without its error union, if it has one.
+        pub const ReturnPayloadT = blk: {
+            break :blk switch (@typeInfo(ReturnT)) {
+                .ErrorUnion => |u| u.payload,
+                else => ReturnT,
+            };
+        };
+        /// Error set of `ReturnT`, or null if it is not an error union.
+        pub const ReturnErrorSet: ?type = blk: {
+            break :blk switch (@typeInfo(ReturnT)) {
+                .ErrorUnion => |u| u.error_set,
+                else => null,
+            };
+        };
+    };
+}
diff --git a/async/threaded.zig b/async/threaded.zig
new file mode 100644
index 0000000..9d54392
--- /dev/null
+++ b/async/threaded.zig
@@ -0,0 +1,228 @@
+//! Thread-pool backed implementation of the `async` module.
+//!
+//! `asyncc` schedules a function on an `xev.ThreadPool`; `await_` blocks
+//! the calling thread until that function has returned.
+
+const std = @import("std");
+const xev = @import("xev");
+
+const FnSignature = @import("meta.zig").FnSignature;
+
+/// Frame type for a call to `func` using its declared argument tuple.
+pub fn Frame(comptime func: anytype) type {
+    const Signature = FnSignature(func, null);
+    return FrameEx(func, Signature.ArgsT);
+}
+
+/// Frame type for a call to `func` with an argument tuple of type `argsT`.
+pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
+    return struct {
+        const Self = @This();
+        const Signature = FnSignature(func, argsT);
+        const Task = struct {
+            _task: xev.ThreadPool.Task = .{ .callback = &Self.run },
+            event: std.Thread.ResetEvent = .{},
+            args: Signature.ArgsT,
+            result: Signature.ReturnT = undefined,
+        };
+
+        _task: *Task,
+
+        /// Thread pool callback: runs `func`, stores its result, then
+        /// signals completion.
+        fn run(task_: *xev.ThreadPool.Task) void {
+            const task: *Task = @alignCast(@fieldParentPtr("_task", task_));
+            task.result = @call(.auto, func, task.args);
+            task.event.set();
+        }
+
+        /// Blocks until the call has finished and returns its result.
+        /// The task is destroyed on return, so the frame must not be
+        /// awaited again.
+        pub fn await_(self: *Self) Signature.ReturnT {
+            defer {
+                AsyncThread.current.mutex.lock();
+                AsyncThread.current.allocator.destroy(self._task);
+                AsyncThread.current.mutex.unlock();
+            }
+            self._task.event.wait();
+            return self._task.result;
+        }
+    };
+}
+
+/// Schedules `func` with `args` on the thread pool and returns its frame.
+pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) {
+    return asyncGeneric(func, args);
+}
+
+/// Same as `asyncc`, but accepts an argument tuple of any type.
+pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
+    const FrameT = FrameEx(func, @TypeOf(args));
+
+    // Task allocation and scheduling both happen under the shared mutex.
+    AsyncThread.current.mutex.lock();
+    defer AsyncThread.current.mutex.unlock();
+
+    const task = try AsyncThread.current.allocator.create(FrameT.Task);
+    task.* = .{
+        .args = args,
+    };
+
+    AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&task._task));
+    return .{ ._task = task };
+}
+
+/// Calls `func` with `args` directly on the calling thread.
+pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) {
+    return callBlockingGeneric(func, args);
+}
+
+/// Same as `callBlocking`, but accepts an argument tuple of any type.
+pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
+    return @call(.auto, func, args);
+}
+
+/// Blocks the calling thread for `ms` milliseconds.
+pub fn sleep(ms: u64) !void {
+    // Saturate rather than overflow for very large durations.
+    std.time.sleep(ms *| std.time.ns_per_ms);
+}
+
+/// Global state shared by every frame: the allocator, the thread pool and
+/// the mutex taken around allocation and scheduling.
+pub const AsyncThread = struct {
+    var current: AsyncThread = undefined;
+
+    allocator: std.mem.Allocator,
+    thread_pool: xev.ThreadPool,
+    mutex: std.Thread.Mutex,
+
+    /// Initializes `current`, runs `func` with `args` on the calling
+    /// thread, then shuts down and deinitializes the thread pool.
+    pub fn main(allocator_: std.mem.Allocator, comptime func: anytype, args: anytype) !void {
+        current = .{
+            .allocator = allocator_,
+            .thread_pool = xev.ThreadPool.init(.{}),
+            .mutex = .{},
+        };
+
+        defer {
+            current.thread_pool.shutdown();
+            current.thread_pool.deinit();
+        }
+
+        return @call(.auto, func, args);
+    }
+};
+
+/// Standard input wrapped in a `File`.
+pub fn StdIn() !File {
+    return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
+}
+
+/// Standard output wrapped in a `File`.
+pub fn StdOut() File {
+    return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
+}
+
+/// Standard error wrapped in a `File`.
+pub fn StdErr() File {
+    return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
+}
+
+/// Blocking file wrapper around `std.fs.File`.
+pub const File = struct {
+    pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?;
+    pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?;
+    pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read);
+    pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write);
+    pub const SeekableStream = std.io.SeekableStream(
+        File,
+        SeekError,
+        GetSeekPosError,
+        seekTo,
+        seekBy,
+        getPos,
+        getEndPos,
+    );
+
+    inner: std.fs.File,
+
+    fn asFile(self: File) std.fs.File {
+        return self.inner;
+    }
+
+    /// Returns the handle of the wrapped `std.fs.File`.
+    pub fn handle(self: File) std.fs.File.Handle {
+        return self.inner.handle;
+    }
+
+    pub fn init(file_: std.fs.File) !File {
+        return .{ .inner = file_ };
+    }
+
+    /// Opens `path` relative to the current working directory.
+    pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
+        return init(try std.fs.cwd().openFile(path, flags));
+    }
+
+    /// Checks accessibility of `path` relative to the current working directory.
+    pub fn access(path: []const u8, flags: std.fs.File.OpenFlags) !void {
+        return try std.fs.cwd().access(path, flags);
+    }
+
+    pub fn read(self: File, buf: []u8) !usize {
+        return try self.inner.read(buf);
+    }
+
+    pub fn pread(self: File, buf: []u8, offset: u64) !usize {
+        return try self.inner.pread(buf, offset);
+    }
+
+    pub fn write(self: File, buf: []const u8) !usize {
+        return try self.inner.write(buf);
+    }
+
+    pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
+        return try self.inner.pwrite(buf, offset);
+    }
+
+    pub fn close(self: File) !void {
+        return self.inner.close();
+    }
+
+    pub fn reader(self: File) Reader {
+        return .{ .context = self };
+    }
+
+    pub fn seekableStream(file: File) SeekableStream {
+        return .{ .context = file };
+    }
+
+    pub fn writer(self: File) Writer {
+        return .{ .context = self };
+    }
+
+    pub fn stat(self: File) !std.fs.File.Stat {
+        return try self.inner.stat();
+    }
+
+    pub fn seekBy(self: File, offset: i64) !void {
+        try self.inner.seekBy(offset);
+    }
+
+    pub fn seekTo(self: File, offset: u64) !void {
+        try self.inner.seekTo(offset);
+    }
+
+    pub fn getPos(self: File) !u64 {
+        return try self.inner.getPos();
+    }
+
+    pub fn getEndPos(self: File) !u64 {
+        return try self.inner.getEndPos();
+    }
+};
+
+pub const Mutex = std.Thread.Mutex;
diff --git a/async/async.zig b/async/zigcoro.zig
similarity index 87%
rename from async/async.zig
rename to async/zigcoro.zig
index c743d02..cc5bc18 100644
--- a/async/async.zig
+++ b/async/zigcoro.zig
@@ -3,41 +3,7 @@ const xev = @import("xev");
 const libcoro = @import("libcoro");
 const aio = libcoro.asyncio;
 
-/// Normalize from a real tuple to a generic tuple. This is needed because
-/// real tuples are reifed tuples are not the same.
-fn NormalizedTuple(comptime T: type) type {
-    const ti = @typeInfo(T).Struct;
-    var types: [ti.fields.len]type = undefined;
-    inline for (ti.fields, 0..) |field, i| {
-        types[i] = field.type;
-    }
-    return std.meta.Tuple(&types);
-}
-
-pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
-    return struct {
-        pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
-        pub const ArgsT = blk: {
-            if (@typeInfo(FuncT).Fn.params.len == 0) {
-                break :blk @TypeOf(.{});
-            }
-            break :blk argsT orelse std.meta.ArgsTuple(FuncT);
-        };
-        pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
-        pub const ReturnPayloadT = blk: {
-            break :blk switch (@typeInfo(ReturnT)) {
-                .ErrorUnion => |u| u.payload,
-                else => ReturnT,
-            };
-        };
-        pub const ReturnErrorSet: ?type = blk: {
-            break :blk switch (@typeInfo(ReturnT)) {
-                .ErrorUnion => |u| u.error_set,
-                else => null,
-            };
-        };
-    };
-}
+const FnSignature = @import("meta.zig").FnSignature;
 
 pub fn Frame(comptime func: anytype) type {
     const Signature = FnSignature(func, null);
@@ -151,9 +117,7 @@ pub const AsyncThread = struct {
     loop: *xev.Loop,
     thread_pool: *xev.ThreadPool,
 
-    pub fn main(allocator: std.mem.Allocator, comptime func: anytype, args: anytype) !FnSignature(func, NormalizedTuple(@TypeOf(args))).ReturnPayloadT {
-        const Signature = FnSignature(func, NormalizedTuple(@TypeOf(args)));
-
+    pub fn main(allocator: std.mem.Allocator, comptime func: anytype, args: anytype) !void {
         var thread_pool = xev.ThreadPool.init(.{});
         defer {
             thread_pool.shutdown();
@@ -178,11 +142,7 @@ pub const AsyncThread = struct {
             .default_stack_size = 16 * 1024 * 1024,
         });
 
-        if (Signature.ReturnErrorSet) |_| {
-            return try aio.run(&executor, func, args, null);
-        } else {
-            return aio.run(&executor, func, args, null);
-        }
+        return try aio.run(&executor, func, args, null);
     }
 };
 
@@ -219,6 +179,11 @@ pub const File = struct {
         return .{ .handle = self.inner.file.fd };
     }
 
+    /// Returns the file descriptor of the wrapped file as a `std.fs.File.Handle`.
+    pub fn handle(self: File) std.fs.File.Handle {
+        return self.inner.file.fd;
+    }
+
     pub fn init(file_: std.fs.File) !File {
         return .{ .inner = aio.File.init(AsyncThread.current.executor, try xev.File.init(file_)) };
     }
diff --git a/zml/aio.zig b/zml/aio.zig
index c853f96..595a6b4 100644
--- a/zml/aio.zig
+++ b/zml/aio.zig
@@ -252,7 +252,7 @@ pub const MemoryMappedFile = struct {
             data_len,
             std.posix.PROT.READ,
             .{ .TYPE = .PRIVATE },
-            file.inner.file.fd,
+            file.handle(),
             0,
         });
 
diff --git a/zml/module.zig b/zml/module.zig
index 11cd146..3410117 100644
--- a/zml/module.zig
+++ b/zml/module.zig
@@ -49,6 +49,8 @@ pub const CompilationContext = struct {
     _unique_id: u64 = 10000,
     _tracer: Tracer,
 
+    /// Context that was current when this one was activated; restored by `deactivate`.
+    _previous: ?*CompilationContext = null,
     threadlocal var _current: ?*CompilationContext = null;
 
     const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
@@ -95,12 +97,18 @@ pub const CompilationContext = struct {
     }
 
+    /// Makes `self` the current context of this thread, remembering the
+    /// previously current one so that activations can nest.
     pub fn activate(self: *CompilationContext) void {
+        self._previous = _current;
         _current = self;
     }
 
+    /// Restores the context that was current before `self` was activated.
+    /// `self` must be the current context.
     pub fn deactivate(self: *CompilationContext) void {
         std.debug.assert(_current != null and _current.? == self);
-        _current = null;
+        _current = self._previous;
+        self._previous = null;
     }
 
     pub fn current() *CompilationContext {