From 5543c8192ff4cbc9653d77c19bcfbd3ab931bfaf Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 4 May 2023 14:44:12 +0000 Subject: [PATCH] Rename async_ to asyncc and add Generic async slugs in async.zig, aio.zig, and module.zig. --- async/async.zig | 30 +++++++++++++++++++----------- zml/aio.zig | 4 ++-- zml/module.zig | 4 ++-- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/async/async.zig b/async/async.zig index 8e8cf74..c743d02 100644 --- a/async/async.zig +++ b/async/async.zig @@ -70,16 +70,20 @@ fn FrameExx(comptime func: anytype, comptime argsT: type) type { }; } -pub fn async_(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { +pub fn asyncc(comptime func: anytype, args: FnSignature(func, null).ArgsT) !FrameEx(func, @TypeOf(args)) { + return asyncGeneric(func, args); +} + +pub fn asyncGeneric(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { const frame = try aio.xasync(func, args, null); return FrameEx(func, @TypeOf(args)).from(frame); } -pub fn call(comptime func: anytype, args: std.meta.ArgsTuple(@TypeOf(func))) @TypeOf(callGeneric(func, args)) { - return callGeneric(func, args); +pub fn callBlocking(comptime func: anytype, args: FnSignature(func, null).ArgsT) @TypeOf(callBlockingGeneric(func, args)) { + return callBlockingGeneric(func, args); } -pub fn callGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { +pub fn callBlockingGeneric(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { const Signature = FnSignature(func, @TypeOf(args)); const TaskT = struct { @@ -224,7 +228,11 @@ pub const File = struct { } pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File { - return init(try call(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags })); + return init(try callBlocking(std.fs.Dir.openFile, .{ std.fs.cwd(), path, flags })); + } + + pub fn access(path: []const u8, flags: std.fs.File.OpenFlags) !void { + return try callBlocking(std.fs.Dir.access, .{ std.fs.cwd(), path, flags }); } pub fn read(self: File, buf: []u8) !usize { @@ -274,23 +282,23 @@ pub const File = struct { } pub fn stat(self: File) !std.fs.File.Stat { - return try call(std.fs.File.stat, .{self.asFile()}); + return try callBlocking(std.fs.File.stat, .{self.asFile()}); } pub fn seekBy(self: File, offset: i64) !void { - try call(std.fs.File.seekBy, .{ self.asFile(), offset }); + try callBlocking(std.fs.File.seekBy, .{ self.asFile(), offset }); } pub fn seekTo(self: File, offset: u64) !void { - try call(std.fs.File.seekTo, .{ self.asFile(), offset }); + try callBlocking(std.fs.File.seekTo, .{ self.asFile(), offset }); } pub fn getPos(self: File) !u64 { - return try call(std.fs.File.getPos, .{self.asFile()}); + return try callBlocking(std.fs.File.getPos, .{self.asFile()}); } pub fn getEndPos(self: File) !u64 { - return try call(std.fs.File.getEndPos, .{self.asFile()}); + return try callBlocking(std.fs.File.getEndPos, .{self.asFile()}); } }; @@ -342,7 +350,7 @@ pub const Socket = struct { addr: std.net.Address, }; pub const Writer = std.io.GenericWriter(WriterContext, FnSignature(UDP.write, null).ReturnErrorSet.?, struct { - fn call(self: WriterContext, buf: []const u8) !usize { + fn callBlocking(self: WriterContext, buf: []const u8) !usize { return self.file.write(self.addr, buf); } }.call); diff --git a/zml/aio.zig b/zml/aio.zig index 81631a0..c853f96 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -247,7 +247,7 @@ pub const MemoryMappedFile = struct { pub fn init(file: asynk.File) !MemoryMappedFile { const data_len: usize = (try file.stat()).size; - const data_ = try asynk.call(std.posix.mmap, .{ + const data_ = try asynk.callBlocking(std.posix.mmap, .{ null, data_len, std.posix.PROT.READ, @@ -256,7 +256,7 @@ pub const MemoryMappedFile = struct { 0, }); - try asynk.call(posix.madvise, .{ data_.ptr, @intCast(data_.len), @intCast(c.MADV_SEQUENTIAL) }); + try asynk.callBlocking(posix.madvise, .{ data_.ptr, @intCast(data_.len), @intCast(c.MADV_SEQUENTIAL) }); return .{ .file = file, diff --git a/zml/module.zig b/zml/module.zig index fb5e514..fbc15ce 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -901,7 +901,7 @@ fn compileInternal( var timer = std.time.Timer.start() catch null; const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); + const f = try asynk.callBlockingGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); context._module.getBody().appendOperation(f.mlir_fn); const sharding = context._platform.sharding(); @@ -1170,7 +1170,7 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m // Note: we may need to restore IR downgrade if we need to support old pjrt plugins. module.op().writeBytecode(mlir_bytecode.writer()); - const loaded_executable = try asynk.call(pjrt.Client.compile, .{ + const loaded_executable = try asynk.callBlocking(pjrt.Client.compile, .{ platform.pjrt_client, platform.pjrt_api, .{ .bytecode = mlir_bytecode.items, .bytecode_format = .mlir,