diff --git a/async/coro.zig b/async/coro.zig index a4d0345..6704296 100644 --- a/async/coro.zig +++ b/async/coro.zig @@ -159,26 +159,21 @@ const Coro = struct { fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn { const from_coro: *Coro = @fieldParentPtr("impl", from); const this_coro: *Coro = @fieldParentPtr("impl", this); - log(.debug, "coro start {any}", .{this_coro.id}); + log(.debug, "coro start {f}", .{this_coro.*}); @call(.auto, this_coro.func, .{}); this_coro.status = .Done; thread_state.switchOut(from_coro); // Never returns - stdx.debug.panic("Cannot resume an already completed coroutine {any}", .{this_coro.id}); + stdx.debug.panic("Cannot resume an already completed coroutine {f}", .{this_coro.*}); } pub fn getStorage(self: Coro, comptime T: type) *T { return @ptrCast(@alignCast(self.storage)); } - pub fn format(self: Coro, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { - _ = fmt; - _ = options; - try writer.print("Coro{{.id = {any}, .status = {s}}}", .{ - self.id, - @tagName(self.status), - }); + pub fn format(self: Coro, writer: *std.Io.Writer) !void { + try writer.print("Coro{{.id = {any}, .status = {t}}}", .{ self.id, self.status }); } }; @@ -292,7 +287,7 @@ const ThreadState = struct { /// Called from resume fn switchIn(self: *ThreadState, target: Frame) void { - log(.debug, "coro resume {any} from {any}", .{ target.id, self.current().id }); + log(.debug, "coro resume {f} from {f}", .{ target.id, self.current().id }); // Switch to target, setting this coro as the resumer. self.switchTo(target, true); @@ -307,7 +302,7 @@ const ThreadState = struct { /// Called from suspend fn switchOut(self: *ThreadState, target: Frame) void { - log(.debug, "coro suspend {any} to {any}", .{ self.current().id, target.id }); + log(.debug, "coro suspend {f} to {f}", .{ self.current().id, target.id }); self.switchTo(target, false); } @@ -384,13 +379,8 @@ const CoroId = struct { self.invocation += 1; } - pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { - _ = fmt; - _ = options; - try writer.print("CoroId{{.cid={d}, .i={d}}}", .{ - self.id.coro, - self.invocation, - }); + pub fn format(self: @This(), writer: *std.Io.Writer) !void { + try writer.print("CoroId{{.cid={d}, .i={d}}}", .{ self.id.coro, self.invocation }); } }; }; diff --git a/mlir/dialects/stablehlo/stablehlo.zig b/mlir/dialects/stablehlo/stablehlo.zig index 0c60fd0..7c7c139 100644 --- a/mlir/dialects/stablehlo/stablehlo.zig +++ b/mlir/dialects/stablehlo/stablehlo.zig @@ -750,13 +750,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { stdx.debug.assert( backend_config.isA(mlir.StringAttribute), - "API version < 4 requires a string as backend_config, got {}", + "API version < 4 requires a string as backend_config, got {f}", .{backend_config}, ); } else { stdx.debug.assert( backend_config.isA(mlir.DictionaryAttribute), - "API version >= 4 requires a dictionary as backend_config, got {}", + "API version >= 4 requires a dictionary as backend_config, got {f}", .{backend_config}, ); } @@ -1285,20 +1285,21 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo } pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 { - var buf: [32]u8 = undefined; + const Cmp = struct { + v1: []const u8, + v1_is_smaller: bool = undefined, - var stream = std.io.fixedBufferStream(&buf); - var context = .{ .writer = stream.writer() }; - const WriterContext = @TypeOf(context); - - _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + pub fn smallerCb(smaller_version: c.MlirStringRef, opaque_cmp: ?*anyopaque) callconv(.c) void { + var cmp: *@This() = @ptrCast(@alignCast(opaque_cmp)); + cmp.v1_is_smaller = std.mem.eql(u8, cmp.v1, smaller_version.data[0..smaller_version.length]); } - }).callback, &context); + }; - return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2; + var cmp_ctx: Cmp = .{ .v1 = version1 }; + const cmp_res = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), Cmp.smallerCb, &cmp_ctx); + + std.debug.assert(cmp_res.value != 0); + return if (cmp_ctx.v1_is_smaller) version1 else version2; } pub fn getCurrentVersion() []const u8 { @@ -1308,18 +1309,9 @@ pub fn getCurrentVersion() []const u8 { var once = std.once(call); fn call() void { - var stream = std.io.fixedBufferStream(&buf); - var writer_ = stream.writer(); - const ContextWriter = @TypeOf(writer_); - - c.stablehloGetCurrentVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { - const writer: *ContextWriter = @ptrCast(@alignCast(userdata)); - _ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &writer_); - - str = buf[0..stream.pos]; + var writer: std.Io.Writer = .fixed(&buf); + c.stablehloGetCurrentVersion(printCallbackNoFail, &writer); + str = writer.buffered(); } }; @@ -1334,18 +1326,9 @@ pub fn getMinimumVersion() []const u8 { var once = std.once(call); fn call() void { - var stream = std.io.fixedBufferStream(&buf); - var context = .{ .writer = stream.writer() }; - const WriterContext = @TypeOf(context); - - c.stablehloGetMinimumVersion((struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context); - - str = buf[0..stream.pos]; + var writer: std.Io.Writer = .fixed(&buf); + c.stablehloGetMinimumVersion(printCallbackNoFail, &writer); + str = writer.buffered(); } }; @@ -1353,14 +1336,25 @@ pub fn getMinimumVersion() []const u8 { return state.str; } -pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void { - var context = .{ .writer = writer }; - const WriterContext = @TypeOf(context); - - try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { - pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { - const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); - _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; - } - }).callback, &context), error.InvalidMlirBytecodeVersion); +pub fn serializePortableArtifact( + bytecode: []const u8, + target_version: []const u8, + writer: *std.Io.Writer, +) error{ InvalidMlirBytecodeVersion, WriteFailed }!void { + var writer_err: mlir.WriterWithErr = .{ .writer = writer }; + try mlir.successOr( + c.stablehloSerializePortableArtifactFromStringRef( + mlir.stringRef(bytecode), + mlir.stringRef(target_version), + mlir.WriterWithErr.printCallback, + &writer_err, + ), + error.InvalidMlirBytecodeVersion, + ); + return try writer_err.check(); +} + +fn printCallbackNoFail(mlir_str: c.MlirStringRef, opaque_writer: ?*anyopaque) callconv(.c) void { + const writer: *std.Io.Writer = @ptrCast(@alignCast(opaque_writer)); + writer.writeAll(mlir.fromStringRef(mlir_str)) catch @panic("Failed to write MLIR"); } diff --git a/mlir/mlir.zig b/mlir/mlir.zig index a672191..15dbdab 100755 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -7,18 +7,20 @@ const stdx = @import("stdx"); const log = std.log.scoped(.mlir); test { - std.testing.refAllDecls(@This()); + std.testing.refAllDeclsRecursive(@This()); _ = try Context.init(); } -const Error = error{ +pub const Error = error{ /// Invalid Mlir was created. InvalidMlir, /// Another Mlir error. Check the log for more context. MlirUnexpected, /// A resource/executor was not found. NotFound, + /// Bytecode version incompatibility. + InvalidMlirBytecodeVersion, OutOfMemory, }; @@ -35,64 +37,61 @@ pub fn registerPasses(comptime passes: []const u8) void { @field(c, "mlirRegister" ++ passes ++ "Passes")(); } -pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void { - return if (res.value == 0) err; +pub fn successOr(res: c.MlirLogicalResult, err: anytype) @TypeOf(err)!void { + return if (res.value == 0) err else {}; } -/// Alternative to MlirWrapperType -pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void; - pub const Registry = struct { _inner: c.MlirDialectRegistry, pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy); pub fn init() !Registry { - return helpers.init(Registry, c.mlirDialectRegistryCreate(), c.mlirDialectRegistryIsNull) orelse Error.MlirUnexpected; + const registry = c.mlirDialectRegistryCreate(); + return .{ ._inner = .{ .ptr = registry.ptr orelse return Error.MlirUnexpected } }; } }; pub const Context = struct { _inner: c.MlirContext, - const Self = Context; pub const deinit = helpers.deinit(Context, c.mlirContextDestroy); pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull); - pub fn init() !Self { - return Self.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected; + pub fn init() !Context { + return Context.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected; } - pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self { - return Self.wrapOr( + pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Context { + return Context.wrapOr( c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled), ) orelse Error.InvalidMlir; } - pub fn setMultiThreading(self: *Self, enabled: bool) void { + pub fn setMultiThreading(self: *Context, enabled: bool) void { c.mlirContextEnableMultithreading(self._inner, enabled); } - pub fn appendDialectRegistry(self: *Self, registry: Registry) void { + pub fn appendDialectRegistry(self: *Context, registry: Registry) void { c.mlirContextAppendDialectRegistry(self._inner, registry._inner); } - pub fn loadAllAvailableDialects(self: *Self) void { + pub fn loadAllAvailableDialects(self: *Context) void { c.mlirContextLoadAllAvailableDialects(self._inner); } - pub fn numRegisteredDialects(self: Self) usize { + pub fn numRegisteredDialects(self: Context) usize { return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner)); } - pub fn numLoadedDialects(self: Self) usize { + pub fn numLoadedDialects(self: Context) usize { return @intCast(c.mlirContextGetNumLoadedDialects(self._inner)); } - pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool { + pub fn isRegisteredOperation(self: Context, op: [:0]const u8) bool { return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op)); } - pub fn location(self: Self, src: std.builtin.SourceLocation) Location { + pub fn location(self: Context, src: std.builtin.SourceLocation) Location { return Location.fromSrc(self, src); } }; @@ -103,9 +102,7 @@ pub const Module = struct { pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy); pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull); - const Self = Module; - - pub fn init(loc: Location) Self { + pub fn init(loc: Location) Module { return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) }; } @@ -142,29 +139,26 @@ pub const PassManager = struct { pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy); pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull); - const Self = PassManager; - - pub fn init(ctx: Context) !Self { - return Self.wrapOr( + pub fn init(ctx: Context) !PassManager { + return PassManager.wrapOr( c.mlirPassManagerCreate(ctx._inner), ) orelse Error.MlirUnexpected; } - pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self { - return Self.wrapOr( + pub fn initOnOperation(ctx: Context, op: [:0]const u8) !PassManager { + return PassManager.wrapOr( c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)), ) orelse Error.MlirUnexpected; } - pub fn asOpPassManager(self: Self) OpPassManager { + pub fn asOpPassManager(self: PassManager) OpPassManager { return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) }; } - pub fn enableIRPrinting(self: *Self) void { - c.mlirPassManagerEnableIRPrinting(self._inner); - } + // TODO mlirPassManagerEnableIRPrinting + // pub fn enableIRPrinting(self: *PassManager) void {} - pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void { + pub fn runOnOp(self: *PassManager, op: Operation) error{InvalidMlir}!void { if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) { return Error.InvalidMlir; } @@ -193,21 +187,20 @@ pub const OpPassManager = struct { pub const Identifier = struct { _inner: c.MlirIdentifier, - const Self = Identifier; - pub fn get(ctx: Context, str_: [:0]const u8) Self { + pub fn get(ctx: Context, str_: [:0]const u8) Identifier { return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) }; } - pub fn context(self: Self) Context { + pub fn context(self: Identifier) Context { return .{ ._inner = c.mlirIdentifierGetContext(self._inner) }; } - pub fn str(self: Self) []const u8 { + pub fn str(self: Identifier) []const u8 { return fromStringRef(c.mlirIdentifierStr(self._inner)); } - pub fn equals(self: Self, other: Self) bool { + pub fn equals(self: Identifier, other: Identifier) bool { return c.mlirIdentifierEqual(self._inner, other._inner); } }; @@ -322,6 +315,14 @@ pub const Attribute = struct { return DictionaryAttribute.init(ctx, attrs).asAttr(); } + + pub fn eqlAny(Attr: type) fn (Attr, Attr) bool { + return struct { + fn eql(a: Attr, b: Attr) bool { + return a.asAttr().eql(b.asAttr()); + } + }.eql; + } }; pub const NamedAttribute = extern struct { @@ -349,15 +350,14 @@ pub const NamedAttribute = extern struct { pub const StringAttribute = struct { _inner: c.MlirAttribute, pub const is_a_fn = c.mlirAttributeIsAString; - const Self = StringAttribute; - pub const asAttr = Attribute.fromAny(Self); - pub const eql = Attribute.eqlAny(Self); + pub const asAttr = Attribute.fromAny(StringAttribute); + pub const eql = Attribute.eqlAny(StringAttribute); - pub fn init(ctx: Context, str: []const u8) Self { + pub fn init(ctx: Context, str: []const u8) StringAttribute { return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) }; } - pub fn value(self: Self) []const u8 { + pub fn value(self: StringAttribute) []const u8 { return fromStringRef(c.mlirStringAttrGetValue(self._inner)); } }; @@ -365,15 +365,14 @@ pub const StringAttribute = struct { pub const BoolAttribute = struct { _inner: c.MlirAttribute, pub const is_a_fn = c.mlirAttributeIsABool; - const Self = BoolAttribute; - pub const asAttr = Attribute.fromAny(Self); - pub const eql = Attribute.eqlAny(Self); + pub const asAttr = Attribute.fromAny(BoolAttribute); + pub const eql = Attribute.eqlAny(BoolAttribute); - pub fn init(ctx: Context, value_: bool) Self { + pub fn init(ctx: Context, value_: bool) BoolAttribute { return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) }; } - pub fn value(self: Self) bool { + pub fn value(self: BoolAttribute) bool { return c.mlirBoolAttrGetValue(self._inner); } }; @@ -397,19 +396,18 @@ pub const TypeAttribute = struct { pub const ArrayAttribute = struct { _inner: c.MlirAttribute, pub const is_a_fn = c.mlirAttributeIsAArray; - const Self = ArrayAttribute; - pub const asAttr = Attribute.fromAny(Self); - pub const eql = Attribute.eqlAny(Self); + pub const asAttr = Attribute.fromAny(ArrayAttribute); + pub const eql = Attribute.eqlAny(ArrayAttribute); - pub fn init(ctx: Context, attrs: []const Attribute) Self { + pub fn init(ctx: Context, attrs: []const Attribute) ArrayAttribute { return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) }; } - pub fn size(self: Self) usize { + pub fn size(self: ArrayAttribute) usize { return @intCast(c.mlirArrayAttrGetNumElements(self._inner)); } - pub fn get(self: Self, index: usize) Attribute { + pub fn get(self: ArrayAttribute, index: usize) Attribute { return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) }; } }; @@ -590,13 +588,13 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr { const raw_bytes = std.mem.sliceAsBytes(slice); - const res: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet( + const attr: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet( shaped_type._inner, @intCast(raw_bytes.len), @ptrCast(raw_bytes.ptr), ) }; - std.debug.assert(Attribute.wrapOr(res._inner) != null); - return res; + std.debug.assert(attr._inner.ptr != null); + return attr; } pub fn len(self: Attr) usize { @@ -717,7 +715,7 @@ pub const DictionaryAttribute = struct { } pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute { - return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, name)); + return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, stringRef(name))); } }; @@ -728,8 +726,7 @@ pub const Operation = struct { pub const dump = helpers.dump(Operation, c.mlirOperationDestroy); pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy); pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull); - - pub const eql = Attribute.eqlAny(Self); + pub const eql = helpers.eql(Operation, c.mlirOperationEqual); pub fn init(state: *OperationState) !Self { return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir; @@ -881,52 +878,33 @@ pub const Operation = struct { return .{ ._inner = c.mlirOperationGetContext(self._inner) }; } - pub fn writeBytecode(self: Self, writer: anytype) void { - var writer_context = .{ .writer = writer }; - const WriterContext = @TypeOf(writer_context); - + pub fn writeBytecode(self: Self, writer: *std.Io.Writer) std.Io.Writer.Error!void { + var writer_with_err: WriterWithErr = .{ .writer = writer }; c.mlirOperationWriteBytecode( self._inner, - (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { - const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); - _ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable; - } - }).callback, - &writer_context, + WriterWithErr.printCallback, + &writer_with_err, ); + return writer_with_err.check(); } - pub fn writeBytecodeWithConfig(self: Self, writer: anytype, config: struct { + pub fn writeBytecodeWithConfig(self: Self, writer: *std.Io.Writer, config: struct { desiredEmitedVersion: ?i64 = null, - }) !void { + }) error{ InvalidMlirBytecodeVersion, WriteFailed }!void { const cfg = c.mlirBytecodeWriterConfigCreate(); defer c.mlirBytecodeWriterConfigDestroy(cfg); if (config.desiredEmitedVersion) |v| { c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v); } - const WriterContext = struct { - writer: @TypeOf(writer), - write_error: ?@TypeOf(writer).Error = null, - }; - var writer_context: WriterContext = .{ .writer = writer }; - + var writer_with_err: WriterWithErr = .{ .writer = writer }; try successOr(c.mlirOperationWriteBytecodeWithConfig( self._inner, cfg, - (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { - const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_)); - _ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| { - inner_writer_context.write_error = err; - }; - } - }).callback, - &writer_context, + &WriterWithErr.printCallback, + &writer_with_err, ), error.InvalidMlirBytecodeVersion); - - if (writer_context.write_error) |err| return err; + return writer_with_err.check(); } /// Enable a full dump of the IR. @@ -939,26 +917,18 @@ pub const Operation = struct { op: Operation, flags: OpPrintingFlags, - pub fn format(self: @This(), writer: anytype) !void { - self.op.print(writer, self.flags); + pub fn format(self: MlirFormatter, writer: *std.Io.Writer) !void { + try self.op.print(writer, self.flags); } }; - pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void { + pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) std.Io.Writer.Error!void { const pflags = flags.create(); defer c.mlirOpPrintingFlagsDestroy(pflags); - c.mlirOperationPrintWithFlags( - self._inner, - pflags, - (struct { - pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { - const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_)); - _writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed"); - } - }).callback, - writer, - ); + var writer_err: WriterWithErr = .{ .writer = writer }; + c.mlirOperationPrintWithFlags(self._inner, pflags, WriterWithErr.printCallback, &writer_err); + return writer_err.check(); } pub fn verify(self: Self) bool { @@ -1065,27 +1035,25 @@ pub const OpPrintingFlags = struct { pub const OpOperand = struct { _inner: c.MlirOpOperand, - const Self = OpOperand; + pub const wrapOr = helpers.wrapOr(OpOperand, c.mlirOpOperandIsNull); - pub fn owner(self: Self) Operation { + pub fn owner(self: OpOperand) Operation { return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) }; } - pub fn number(self: Self) usize { + pub fn number(self: OpOperand) usize { return @intCast(c.mlirOpOperandGetOperandNumber(self._inner)); } - pub fn nextUse(self: Self) ?Self { - return Self.wrapOr( - c.mlirOpOperandGetNextUse(self._inner), - ); + pub fn nextUse(self: OpOperand) ?OpOperand { + return wrapOr(c.mlirOpOperandGetNextUse(self._inner)); } }; pub const Region = struct { _inner: c.MlirRegion, - pub const eql = helpers.eql(Region, c.mlirBlockEqual); + pub const eql = helpers.eql(Region, c.mlirRegionEqual); pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy); pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull); @@ -1121,7 +1089,7 @@ pub const Value = struct { pub const dump = helpers.dump(Value, c.mlirValueDump); pub const eql = helpers.eql(Value, c.mlirValueEqual); - pub const format = helpers.format(Value, c.mlirValuePrint).format; + pub const format = helpers.format(Value, c.mlirValuePrint); pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull); pub fn getType(val: Value) Type { @@ -1183,7 +1151,7 @@ pub const BlockArgument = struct { return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner)); } - pub fn format(self: BlockArgument, writer: anytype) !void { + pub fn format(self: BlockArgument, writer: *std.Io.Writer) !void { const value = Value{ ._inner = self._inner }; return value.format(writer); } @@ -1192,7 +1160,7 @@ pub const BlockArgument = struct { pub const Type = struct { _inner: c.MlirType, - pub const dump = helpers.eql(Type, c.mlirTypeDump); + pub const dump = helpers.dump(Type, c.mlirTypeDump); pub const eql = helpers.eql(Type, c.mlirTypeEqual); pub const format = helpers.format(Type, c.mlirTypePrint); pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull); @@ -1230,14 +1198,6 @@ pub const Type = struct { }.eql; } - pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type { - return struct { - pub fn format(self: SpecificType, writer: anytype) !void { - return try Type.format(self.asType(), writer); - } - }; - } - pub fn index(ctx: Context) Type { return IndexType.init(ctx).asType(); } @@ -1280,7 +1240,7 @@ pub const IndexType = struct { pub const asType = Type.fromAny(IndexType); pub const eql = Type.eqlAny(IndexType); - pub const format = Type.formatAny(IndexType).format; + pub const format = helpers.format(IndexType, c.mlirTypePrint); pub fn init(ctx: Context) IndexType { return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) }; @@ -1452,7 +1412,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type { pub const asType = Type.fromAny(Complex); pub const eql = Type.eqlAny(Complex); - pub const format = Type.formatAny(Complex).format; + pub const format = helpers.format(Complex, c.mlirTypePrint); pub const ComplexTypeType: ComplexTypes = ct; pub const init = if (ct != .unknown) struct { @@ -1468,51 +1428,50 @@ pub const TupleType = struct { _inner: c.MlirType, pub const is_a_fn = c.mlirTypeIsATuple; - const Self = TupleType; - - pub fn init(ctx: Context, elements: []const Type) !Self { - return Self.wrapOr(c.mlirTupleTypeGet( + pub fn init(ctx: Context, elements: []const Type) !TupleType { + const tuple_type = c.mlirTupleTypeGet( ctx._inner, @intCast(elements.len), @ptrCast(elements.ptr), - )) orelse Error.InvalidMlir; + ); + return .{ ._inner = .{ .ptr = tuple_type.ptr orelse return error.InvalidMlir } }; } - pub fn getNumTypes(self: Self) usize { + pub fn getNumTypes(self: TupleType) usize { return @intCast(c.mlirTupleTypeGetNumTypes(self._inner)); } - pub fn getElementType(self: Self, index: usize) Type { + pub fn getElementType(self: TupleType, index: usize) Type { return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) }; } - pub const asType = Type.fromAny(Self); + pub const asType = Type.fromAny(TupleType); }; pub const FunctionType = struct { _inner: c.MlirType, pub const is_a_fn = c.mlirTypeIsAFunction; - const Self = FunctionType; - pub const asType = Type.fromAny(Self); - pub const eql = Type.eqlAny(IndexType); + pub const asType = Type.fromAny(FunctionType); + pub const eql = Type.eqlAny(FunctionType); - pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self { - const func = Type.wrapOr(c.mlirFunctionTypeGet( + pub fn init(ctx: Context, args: []const Type, results: []const Type) !FunctionType { + const func_type = c.mlirFunctionTypeGet( ctx._inner, @intCast(args.len), @ptrCast(args.ptr), @intCast(results.len), @ptrCast(results.ptr), - )) orelse return Error.InvalidMlir; - return func.as(Self).?; + ); + return .{ ._inner = .{ .ptr = func_type.ptr orelse return error.InvalidMlir } }; } }; pub const RankedTensorType = struct { _inner: c.MlirType, pub const is_a_fn = c.mlirTypeIsARankedTensor; + pub const asType = Type.fromAny(RankedTensorType); pub const eql = Type.eqlAny(RankedTensorType); - pub const format = helpers.format(Type, c.mlirTypePrint); + pub const format = helpers.format(RankedTensorType, c.mlirTypePrint); pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType { return .{ ._inner = c.mlirRankedTensorTypeGet( @@ -1534,20 +1493,16 @@ pub const RankedTensorType = struct { pub fn getDimension(self: RankedTensorType, dim: usize) i64 { return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim)); } - - pub const asType = Type.fromAny(RankedTensorType); }; pub const Dialect = struct { _inner: c.MlirDialect, - const Self = Dialect; - - pub fn getContext(self: Self) Context { + pub fn getContext(self: Dialect) Context { return .{ ._inner = c.mlirDialectGetContext(self._inner) }; } - pub fn getNamespace(self: Self) []const u8 { + pub fn getNamespace(self: Dialect) []const u8 { return fromStringRef(c.mlirDialectGetNamespace(self._inner)); } }; @@ -1579,7 +1534,7 @@ pub const DialectHandle = struct { pub const Location = struct { _inner: c.MlirLocation, - pub const eql = helpers.eql(Type, c.mlirLocationEqual); + pub const eql = helpers.eql(Location, c.mlirLocationEqual); pub const format = helpers.format(Location, c.mlirLocationPrint); pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location { @@ -1613,17 +1568,17 @@ pub const Location = struct { ) }; } - pub fn named(loc: Location, ctx: Context, loc_name: [:0]const u8) Location { + pub fn named(loc: Location, ctx: Context, loc_name: []const u8) Location { return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) }; } pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location { var buf: [256]u8 = undefined; - var stream = std.io.fixedBufferStream(&buf); - std.fmt.format(stream.writer(), fmt, args) catch { + var writer: std.Io.Writer = .fixed(&buf); + writer.print(fmt, args) catch { buf[256 - 3 ..].* = "...".*; }; - return loc.named(ctx, @ptrCast(stream.getWritten())); + return loc.named(ctx, writer.buffered()); } pub fn unknown(ctx: Context) Location { @@ -1636,7 +1591,6 @@ pub const Block = struct { pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull); pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy); - pub const eql = helpers.eql(Block, c.mlirBlockEqual); pub fn init(args: []const Type, locs: []const Location) !Block { @@ -1736,27 +1690,15 @@ pub const helpers = struct { }.isNull; } - pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type { + pub fn format( + Any: type, + print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void, + ) fn (Any, *std.Io.Writer) std.Io.Writer.Error!void { return struct { - pub fn format(self: Any, writer: *std.Io.Writer) !void { - const WriterWithErr = struct { - writer: *std.Io.Writer, - err: ?std.Io.Writer.Error = null, - fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void { - var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx)); - if (ctx.err) |_| return; - _ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| { - ctx.err = err; - return; - }; - } - }; - - var context: WriterWithErr = .{ .writer = writer }; - print_fn(self._inner, &WriterWithErr.printCallback, &context); - if (context.err) |err| return err; + pub fn format(self: Any, writer: *std.Io.Writer) std.Io.Writer.Error!void { + try callPrintFn(Any, self, print_fn, writer); } - }; + }.format; } pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T { @@ -1767,9 +1709,35 @@ pub const helpers = struct { } }.wrapOr; } +}; - pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T { - if (is_null_fn(inner)) return null; - return .{ ._inner = inner }; +pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void; + +pub fn callPrintFn( + T: type, + value: T, + print_fn: fn (@FieldType(T, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void, + writer: *std.Io.Writer, +) std.Io.Writer.Error!void { + var writer_with_err: WriterWithErr = .{ .writer = writer }; + print_fn(value._inner, &WriterWithErr.printCallback, &writer_with_err); + return writer_with_err.check(); +} + +pub const WriterWithErr = struct { + writer: *std.Io.Writer, + err: ?std.Io.Writer.Error = null, + + pub fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void { + var ctx: *WriterWithErr = @ptrCast(@alignCast(opaque_ctx)); + if (ctx.err) |_| return; + ctx.writer.writeAll(fromStringRef(mlir_str)) catch |err| { + ctx.err = err; + return; + }; + } + + pub fn check(self: WriterWithErr) std.Io.Writer.Error!void { + if (self.err) |err| return err; } }; diff --git a/pjrt/ffi.zig b/pjrt/ffi.zig index 2ca639a..6a49294 100644 --- a/pjrt/ffi.zig +++ b/pjrt/ffi.zig @@ -265,16 +265,8 @@ pub const Buffer = extern struct { return self._dims[0..self.rank]; } - pub fn format( - buffer: Buffer, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; - - try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), @tagName(buffer.dtype), @intFromPtr(buffer.data) }); + pub fn format(buffer: Buffer, writer: *std.Io.Writer) !void { + try writer.print("FfiBuffer({any}, .{t})@0x{x}", .{ buffer.dims(), buffer.dtype, @intFromPtr(buffer.data) }); } }; diff --git a/pjrt/pjrt.zig b/pjrt/pjrt.zig index d4d49dc..25e4551 100644 --- a/pjrt/pjrt.zig +++ b/pjrt/pjrt.zig @@ -1257,14 +1257,7 @@ pub const NamedValue = extern struct { }) }; } - pub fn format( - self: NamedValue, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; + pub fn format(self: NamedValue, writer: *std.Io.Writer) !void { try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] }); const u = self.inner.unnamed_0; switch (self.kind()) { diff --git a/runtimes/neuron/libneuronxla.zig b/runtimes/neuron/libneuronxla.zig index 0010d28..59d8ae6 100644 --- a/runtimes/neuron/libneuronxla.zig +++ b/runtimes/neuron/libneuronxla.zig @@ -117,8 +117,8 @@ fn wrapNeffAsCustomCall(allocator: std.mem.Allocator, hlo_code: []const u8, neff }; { - var operand_ids: std.ArrayListUnmanaged(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]); - var new_instructions: std.ArrayListUnmanaged(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1])); + var operand_ids: std.ArrayList(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]); + var new_instructions: std.ArrayList(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1])); for (entry_instructions) |instruction| { if (std.mem.eql(u8, upb.slice(c.xla_HloInstructionProto_opcode(instruction)) orelse continue, "parameter")) { const id = c.xla_HloInstructionProto_id(instruction); diff --git a/stdx/fmt.zig b/stdx/fmt.zig index 03bf7a3..4e38488 100644 --- a/stdx/fmt.zig +++ b/stdx/fmt.zig @@ -9,6 +9,10 @@ fn FmtSlice(T: type) type { return struct { slice: []const T, + pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void { + return try formatSliceAny(f.slice, .{}, writer); + } + pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { return switch (@typeInfo(T)) { .comptime_float, .float => try formatFloatSlice(f.slice, n, writer), diff --git a/zml/aio.zig b/zml/aio.zig index 38a9914..47cc75f 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -315,14 +315,7 @@ pub const Metadata = union(enum) { }; } - pub fn format( - self: Metadata, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; + pub fn format(self: Metadata, writer: *std.Io.Writer) !void { switch (self) { .null => _ = try writer.write("null"), inline .bool, .array_bool => |b| try writer.print("{any}", .{b}), diff --git a/zml/aio/json.zig b/zml/aio/json.zig index 2c504bb..6f46513 100644 --- a/zml/aio/json.zig +++ b/zml/aio/json.zig @@ -1,10 +1,10 @@ -const async = @import("async"); const std = @import("std"); + +const async = @import("async"); + const zml = @import("../zml.zig"); -const StringBuilder = std.ArrayListUnmanaged(u8); -const Allocator = std.mem.Allocator; - +const StringBuilder = std.ArrayList(u8); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { const file = try std.fs.cwd().openFile(path, .{}); defer file.close(); @@ -26,7 +26,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore return res; } -pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void { +pub fn parseMetadata(allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void { const metadata = &store._metadata; const key = prefix.items; return switch (val) { diff --git a/zml/aio/safetensors.zig b/zml/aio/safetensors.zig index f1cdf29..b49928a 100644 --- a/zml/aio/safetensors.zig +++ b/zml/aio/safetensors.zig @@ -9,7 +9,6 @@ const zml = @import("../zml.zig"); const HostBuffer = zml.HostBuffer; const json = @import("json.zig"); -const StringBuilder = std.ArrayListUnmanaged(u8); const log = std.log.scoped(.@"zml/io"); pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { @@ -17,19 +16,18 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore errdefer res.arena.deinit(); const arena = res.arena.allocator(); - var files = std.array_list.Managed(MemoryMappedFile).init(arena); - errdefer files.deinit(); + var files: std.ArrayList(MemoryMappedFile) = .empty; if (std.mem.endsWith(u8, path, ".safetensors.index.json")) { try loadFromIndex(arena, &res, &files, path); } else { try loadFile(arena, &res, &files, path); } - res.files = try files.toOwnedSlice(); + res.files = try files.toOwnedSlice(allocator); return res; } -fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { +fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { const file = async.File.open(path, .{}) catch |err| { log.err("Failed to open {s}: {}", .{ path, err }); return err; @@ -61,11 +59,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std. if (index.object.get("__metadata__")) |metadata| { var prefix_buf: [1024]u8 = undefined; - try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata); + try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), metadata); } } -fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { +fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void { const file = async.File.open(path, .{}) catch |err| { log.err("Failed to open {s}: {}", .{ path, err }); return err; @@ -87,7 +85,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array errdefer buffer_file.deinit(); buffer_file.data_offset = 8 + json_header_length; - try files.append(buffer_file); + try files.append(allocator, buffer_file); errdefer _ = files.pop(); var it = metadata.object.iterator(); @@ -95,7 +93,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array const key = entry.key_ptr.*; if (std.mem.eql(u8, key, "__metadata__")) { var prefix_buf: [1024]u8 = undefined; - try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*); + try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), entry.value_ptr.*); continue; } const val = entry.value_ptr.*; diff --git a/zml/aio/torch.zig b/zml/aio/torch.zig index fe48721..7c244f6 100644 --- a/zml/aio/torch.zig +++ b/zml/aio/torch.zig @@ -6,7 +6,6 @@ const zml = @import("../zml.zig"); const eval = @import("torch/eval.zig"); const File = @import("torch/file.zig").File; -const StringBuilder = std.ArrayListUnmanaged(u8); const log = std.log.scoped(.@"zml/aio"); test { diff --git a/zml/aio/torch/eval.zig b/zml/aio/torch/eval.zig index 42565ae..01604e7 100644 --- a/zml/aio/torch/eval.zig +++ b/zml/aio/torch/eval.zig @@ -348,7 +348,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo } fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void { - var array_list = std.ArrayListUnmanaged(py.Any).fromOwnedSlice(current.*); + var array_list = std.ArrayList(py.Any).fromOwnedSlice(current.*); try array_list.appendSlice(allocator, values); current.* = array_list.items; } diff --git a/zml/aio/torch/file.zig b/zml/aio/torch/file.zig index 1afafdd..c1588b4 100644 --- a/zml/aio/torch/file.zig +++ b/zml/aio/torch/file.zig @@ -13,7 +13,7 @@ const py = @import("py.zig"); const log = std.log.scoped(.@"zml/aio"); // TODO(cryptodeal): use zml.aio.PrefixBuilder instead -const StringBuilder = std.ArrayListUnmanaged(u8); +const StringBuilder = std.ArrayList(u8); test { std.testing.refAllDecls(@This()); @@ -191,7 +191,7 @@ pub const File = struct { .boolval => bool, else => unreachable, }; - var values: std.ArrayListUnmanaged(ItemType) = .{}; + var values: std.ArrayList(ItemType) = .{}; try values.append(allocator, val0); for (seq.values[1..], 1..) |val, i| { if (std.meta.activeTag(val) != tag) valid_slice = false; diff --git a/zml/aio/torch/pickle.zig b/zml/aio/torch/pickle.zig index 9174eb1..5724ab0 100644 --- a/zml/aio/torch/pickle.zig +++ b/zml/aio/torch/pickle.zig @@ -768,7 +768,7 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op { // It's not very efficient to interleave the results with the data copied from the stream, // because growth event in the results ArrayList will lead to fragmentation. // Trying to mitigate that by using a generous default size. - var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512); + var results: std.ArrayList(Op) = try .initCapacity(arena, 512); errdefer results.deinit(arena); var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512); diff --git a/zml/aio/torch/py.zig b/zml/aio/torch/py.zig index c523f28..539c28a 100644 --- a/zml/aio/torch/py.zig +++ b/zml/aio/torch/py.zig @@ -1,9 +1,10 @@ const std = @import("std"); const math = std.math; -const log = std.log.scoped(.@"zml/aio"); const pickle = @import("pickle.zig"); +const log = std.log.scoped(.@"zml/aio"); + /// Correspond to a function/constructor call pub const Object = struct { member: Any, @@ -206,12 +207,16 @@ pub const Any = union(Kind) { self.* = undefined; } - inline fn writeIndents(indents: usize, writer: anytype) !void { + inline fn writeIndents(indents: usize, writer: *std.Io.Writer) !void { try writer.writeBytesNTimes(" ", indents); // resolve tab = 2 spaces // try writer.writeByteNTimes('\t'); } - fn internalFormat(value: Any, indents: usize, writer: anytype) !void { + pub fn format(self: Any, writer: *std.Io.Writer) !void { + return internalFormat(self, 0, writer); + } + + fn internalFormat(value: Any, indents: usize, writer: *std.Io.Writer) !void { try writeIndents(indents, writer); try writer.writeAll(".{\n"); try writeIndents(indents + 1, writer); @@ -303,10 +308,6 @@ pub const Any = union(Kind) { try writer.writeByte('}'); } - pub fn format(self: Any, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void { - return internalFormat(self, 0, writer); - } - pub fn clone(self: Any, allocator: std.mem.Allocator) !Any { return switch (self) { inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)), diff --git a/zml/buffer.zig b/zml/buffer.zig index 6dcc42f..36a55d8 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -384,10 +384,7 @@ pub const Buffer = struct { return res; } - pub fn format( - self: Buffer, - writer: anytype, - ) !void { + pub fn format(self: Buffer, writer: *std.Io.Writer) !void { try writer.print("Buffer({f})", .{self._shape}); } diff --git a/zml/exe.zig b/zml/exe.zig index 9e3d35e..37c9826 100644 --- a/zml/exe.zig +++ b/zml/exe.zig @@ -95,6 +95,8 @@ pub fn compileFn( ) !FnExe(func) { var pretty_name = try prettyFnName(func, allocator); defer pretty_name.deinit(allocator); + log.info("Compiling {s} with {f}", .{ pretty_name.items, stdx.fmt.any(args) }); + var context = try CompilationContext.init(allocator, pretty_name.items, platform); defer context.deinit(); @@ -306,7 +308,7 @@ pub const BaseExe = struct { } } - pub fn serialize(self: BaseExe, writer: anytype) !void { + pub fn serialize(self: BaseExe, writer: *std.Io.Writer) !void { var executable = try self.exe.getExecutable(self.platform.pjrt_api); var serialize_result = try executable.serialize(self.platform.pjrt_api); defer serialize_result.deinit(); @@ -377,7 +379,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type { try self.inner.bind(T, value); } - pub fn serialize(self: Self, writer: anytype) !void { + pub fn serialize(self: Self, writer: *std.Io.Writer) !void { return try self.inner.serialize(writer); } @@ -437,7 +439,7 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff fn prettyFnName( comptime func: anytype, allocator: std.mem.Allocator, -) !std.ArrayListUnmanaged(u8) { +) !std.ArrayList(u8) { const full_noisy_name = @typeName(@TypeOf(func)); const og_len = full_noisy_name.len; const buffer = try allocator.alloc(u8, og_len); diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 848fee8..88419b7 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -321,10 +321,7 @@ pub const HostBuffer = struct { }; } - pub fn format( - self: HostBuffer, - writer: anytype, - ) !void { + pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void { try writer.print("HostBuffer(.{f})", .{self._shape}); } diff --git a/zml/meta.zig b/zml/meta.zig index 855318a..891a27e 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -679,28 +679,38 @@ test zip { /// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// finds all X in the given object, and write the result of func(X) into an arraylist. -pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { +pub fn collectAlloc( + func: anytype, + func_ctx: _CollectCtx(func), + allocator: std.mem.Allocator, + obj: anytype, +) std.mem.Allocator.Error![]stdx.meta.FnSignature(func, null).ReturnT { stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); - const LocalContext = struct { + + const CollectAllocCtx = struct { func_ctx: _CollectCtx(func), - out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), + allocator: std.mem.Allocator, + out: std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT) = .empty, oom: bool = false, - }; - var context = LocalContext{ .func_ctx = func_ctx, .out = out }; - visit((struct { - fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void { + + fn cb(ctx: *@This(), val: *const _CollectArg(func)) void { if (ctx.oom) return; const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*); - ctx.out.append(res) catch { + ctx.out.append(ctx.allocator, res) catch { ctx.oom = true; }; } - }).cb, &context, obj); + }; + var context = CollectAllocCtx{ .func_ctx = func_ctx, .allocator = allocator }; + visit(CollectAllocCtx.cb, &context, obj); if (context.oom) return error.OutOfMemory; + + return context.out.toOwnedSlice(allocator); } /// Given a func(X) -> Y or a func(Ctx, X) -> Y, -/// finds all X in the given object, and write the result of func(X) into an arraylist. +/// finds all X in the given object, and write the result of func(X) into a given slice. +/// Asserts that the number of X found is equal to the slice len. pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void { stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)}); const LocalContext = struct { diff --git a/zml/mlirx.zig b/zml/mlirx.zig index 28f4038..f50a2ef 100644 --- a/zml/mlirx.zig +++ b/zml/mlirx.zig @@ -96,6 +96,6 @@ pub const Type = struct { } } - std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); + std.debug.panic("Could not convert mlir.Type to DataType: {f}", .{mlir_type}); } }; diff --git a/zml/module.zig b/zml/module.zig index 294d76b..5c27607 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -187,7 +187,8 @@ pub const CompilationContext = struct { if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| { var write_buf: [4096]u8 = undefined; var writer = file.writer(&write_buf); - module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false }); + try module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false }); + try writer.interface.flush(); log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name }); } else |_| { log.warn("Failed to open {s}", .{mlir_name}); @@ -341,12 +342,11 @@ pub const CompilationContext = struct { const locations = try arena.alloc(mlir.Location, tensor_count); @memset(locations, mlir.Location.unknown(mlir_ctx)); - var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count); - meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; - stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{}); + const input_shapes = try res_allocator.alloc(Shape, tensor_count); + meta.collectBuf(Tensor.shape, {}, args, input_shapes); const input_types = try arena.alloc(mlir.Type, tensor_count); - for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh); + for (input_types, input_shapes) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh); const og_block_args = self._block_args; defer { @@ -399,7 +399,7 @@ pub const CompilationContext = struct { self.addDonationsAttributes(arg_attrs, fn_res_donations); self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind); if (self._platform.sharding().num_partitions > 1) { - self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); + self.addShardingAttributes(arg_attrs, res_attrs, input_shapes, fn_res_shapes); } } @@ -427,7 +427,7 @@ pub const CompilationContext = struct { return .{ .mlir_fn = mlir_fn, .name = opts.name, - .args_shapes = input_shapes.items, + .args_shapes = input_shapes, .res_tensors = fn_res, .res_types = fn_res_types, .res_shapes = fn_res_shapes, @@ -506,9 +506,9 @@ pub const CompilationContext = struct { var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); - var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator); - defer mlir_bytecode.deinit(); - try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})}); + var mlir_code: std.Io.Writer.Allocating = .init(std.testing.allocator); + defer mlir_code.deinit(); + try f.mlir_fn.print(&mlir_code.writer, .{}); // Check that the `x` input argument gives its buffer to the result tensor. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. @@ -518,8 +518,8 @@ pub const CompilationContext = struct { var buf = template.*; for (0..2) |i| { const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable; - std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| { - log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items}); + std.testing.expect(std.mem.indexOf(u8, mlir_code.written(), alias_attr) != null) catch |err| { + log.warn("Didn't produced the expected IR:\n{s}", .{mlir_code.written()}); return err; }; } @@ -547,12 +547,14 @@ pub const CompilationContext = struct { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { const ctx = self.mlirCtx(); const num_partitions = self.numPartitions(); - var sharding_str: stdx.BoundedArray(u8, 128) = .{}; - writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; - return mlir.Attribute.string(ctx, sharding_str.constSlice()); + // This is big enough, see test below for examples values + var sharding_str_buf: [64]u8 = undefined; + var writer: std.Io.Writer = .fixed(&sharding_str_buf); + writeShardingRepresentation(shape, num_partitions, &writer) catch unreachable; + return mlir.Attribute.string(ctx, writer.buffered()); } - fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void { + fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: *std.Io.Writer) std.Io.Writer.Error!void { const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info))); if (n_sharded == 0 or num_partitions == 1) { try writer.writeAll("{replicated}"); @@ -567,26 +569,26 @@ pub const CompilationContext = struct { } test writeShardingRepresentation { - var rule: [64]u8 = undefined; + var attr_buf: [64]u8 = undefined; const x = Shape.init(.{ 16, 8 }, .f32); // By default tensors are replicated. { - var fbs = std.io.fixedBufferStream(&rule); - try writeShardingRepresentation(x, 4, fbs.writer()); - try std.testing.expectEqualStrings("{replicated}", fbs.getWritten()); + var writer: std.Io.Writer = .fixed(&attr_buf); + try writeShardingRepresentation(x, 4, &writer); + try std.testing.expectEqualStrings("{replicated}", writer.buffered()); } // Shard along first axis. { - var fbs = std.io.fixedBufferStream(&rule); - try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer()); - try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten()); + var writer: std.Io.Writer = .fixed(&attr_buf); + try writeShardingRepresentation(x.withSharding(.{0}), 4, &writer); + try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", writer.buffered()); } // Also shard along second axis. { - var fbs = std.io.fixedBufferStream(&rule); - try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer()); - try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten()); + var writer: std.Io.Writer = .fixed(&attr_buf); + try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, &writer); + try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", writer.buffered()); } } diff --git a/zml/ops.zig b/zml/ops.zig index ee12508..edbcd4a 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -443,29 +443,23 @@ pub fn if_( const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); - var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); - defer true_shapes.deinit(); - var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); - defer false_shapes.deinit(); + check: { + const arena = ctx.allocator(); + const true_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &true_branch_res) catch break :check; + defer arena.free(true_shapes); - var failed_to_collect = false; - meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch { - failed_to_collect = true; - }; - meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch { - failed_to_collect = true; - }; - if (!failed_to_collect) { - stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items }); - for (true_shapes.items, false_shapes.items) |true_shape, false_shape| { - stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items }); + const false_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &false_branch_res) catch break :check; + defer arena.free(false_shapes); + + stdx.debug.assert(true_shapes.len == false_shapes.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) }); + for (true_shapes, false_shapes) |true_shape, false_shape| { + stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) }); } } - const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0); const loc = ctx.mlirCtx().location(@src()); const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ - .operands = &.{scalar_pred.value()}, + .operands = &.{pred.asScalar().value()}, .result_type_inference = true, .blocks = &.{ true_branch_block, false_branch_block }, // We can't verify right away, cause the weights captured by the if haven't been added yet. @@ -958,18 +952,20 @@ pub fn scatter( const UpdateS = BlockSign(update_fn); const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar }); - var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM"); - defer input_values.deinit(); - meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable; - var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM"); - defer updates_values.deinit(); - meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable; + const arena = ctx.allocator(); + const input_values = arena.alloc(mlir.Value, n_inputs) catch @panic("OOM"); + defer arena.free(input_values); + meta.collectBuf(CompilationContext.getValue, ctx, &inputs, input_values); + + const updates_values = arena.alloc(mlir.Value, n_updates) catch @panic("OOM"); + defer arena.free(updates_values); + meta.collectBuf(CompilationContext.getValue, ctx, &updates, updates_values); const op = dialect.stablehlo.scatter( mlir_ctx, - input_values.items, + input_values, &.{indices.value()}, - updates_values.items, + updates_values, update_block, .{ .update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(), diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index f3c8c34..5932619 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -75,14 +75,17 @@ pub const Client = opaque { } fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { - var bytecode: std.array_list.Managed(u8) = .init(allocator); + var bytecode: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096); defer bytecode.deinit(); - module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| { + module.op().writeBytecodeWithConfig(&bytecode.writer, .{ .desiredEmitedVersion = 1 }) catch |err| { log.err("failed to write module bytecode: {}", .{err}); - return err; + return switch (err) { + std.Io.Writer.Error.WriteFailed => error.OutOfMemory, + else => |e| e, + }; }; - var serialized_buffer: std.array_list.Managed(u8) = .init(allocator); + var serialized_buffer: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096); defer serialized_buffer.deinit(); const stablehlo_version = blk: { @@ -92,13 +95,16 @@ pub const Client = opaque { break :blk dialects.stablehlo.getMinimumVersion(); }; - dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| { + dialects.stablehlo.serializePortableArtifact(bytecode.written(), stablehlo_version, &serialized_buffer.writer) catch |err| { log.err("failed to serialize to portable artifact: {}", .{err}); - return err; + return switch (err) { + std.Io.Writer.Error.WriteFailed => error.OutOfMemory, + else => |e| e, + }; }; return @ptrCast(try self.inner().compile(api, .{ - .bytecode = serialized_buffer.items, + .bytecode = serialized_buffer.written(), .bytecode_format = .mlir, .compile_options_pb = compile_options_pb, })); diff --git a/zml/platform.zig b/zml/platform.zig index 6e13d7f..a6f4e14 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -95,7 +95,7 @@ const _CreateOptions = struct { pub const Cpu = struct { device_count: u32, - fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { + fn writeNamedValues(self: Cpu, values: *std.ArrayList(pjrt.NamedValue)) void { values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count))); } }; @@ -124,7 +124,7 @@ const _CreateOptions = struct { }; }; - fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { + fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void { switch (self.allocator) { .platform => { values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); @@ -145,7 +145,7 @@ const _CreateOptions = struct { }; pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue { - var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out); + var values = std.ArrayList(pjrt.NamedValue).fromOwnedSlice(out); values.shrinkRetainingCapacity(0); switch (target) { .cpu => self.cpu.writeNamedValues(&values), diff --git a/zml/shape.zig b/zml/shape.zig index bd5f2a2..a9daffa 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -386,14 +386,8 @@ pub const Shape = struct { /// Format the shape. /// Default format: "Shape({.a=10, .b=20}, dtype=.f32)" /// Bare format {_}: "{.a=10, .b=20}, dtype=.f32" - pub fn format( - self: Shape, - writer: anytype, - ) !void { - // TODO: impl alternative format - // const bare_fmt = fmt.len == 1 and fmt[0] == '_'; - const bare_fmt = true; - _ = try writer.write(if (bare_fmt) "{" else "Shape({"); + pub fn format(self: Shape, writer: *std.Io.Writer) !void { + _ = try writer.writeByte('{'); var need_comma = false; for (self.dims(), 0..) |d, i| { @@ -411,7 +405,7 @@ pub const Shape = struct { } if (need_comma) try writer.writeByte(','); _ = try writer.write(@tagName(self.dtype())); - _ = try writer.write(if (bare_fmt) "}" else "})"); + _ = try writer.writeByte('}'); } /// Broadcasts a Tensor to the given shape, extending dimensions if needed. diff --git a/zml/tensor.zig b/zml/tensor.zig index 9d55ebf..50b8c56 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -52,10 +52,7 @@ pub const Tensor = struct { return CompilationContext.current(); } - pub fn format( - self: Tensor, - writer: anytype, - ) !void { + pub fn format(self: Tensor, writer: *std.Io.Writer) !void { // TODO(0.15.0) handle format // const bare_fmt = fmt.len == 1 and fmt[0] == '_'; const bare_fmt = false; @@ -1146,7 +1143,7 @@ pub const Tensor = struct { contracting_axes: []const [2]i8, batching_axes: []const [2]i8, ) Tensor { - stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); + stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {f} and {f}", .{ lhs, rhs }); const Axes = stdx.BoundedArray(i64, MAX_RANK); @@ -1156,7 +1153,7 @@ pub const Tensor = struct { var rhs_batching_axes: Axes = .{}; for (batching_axes) |b_axes| { const l, const r = b_axes; - stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs }); + stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {d} and {d} in {f} and {f}", .{ l, r, lhs, rhs }); var t = lhs._shape.tag(l); if (t == Shape.TagUnknown) t = rhs._shape.tag(r); res_shape = res_shape.appendDim(lhs._shape.dim(l), t); @@ -1521,14 +1518,7 @@ pub const Tensor = struct { const to_the_end = std.math.maxInt(i64); - pub fn format( - self: Slice, - comptime fmt: []const u8, - options: std.fmt.FormatOptions, - writer: anytype, - ) !void { - _ = fmt; - _ = options; + pub fn format(self: Slice, writer: *std.Io.Writer) !void { if (self.singleton) { try writer.print("[{}]", .{self.start}); } else if (self.end == to_the_end and self.step == 1) { @@ -2043,7 +2033,7 @@ pub const Tensor = struct { /// Converts the given 1 element Tensor into a 0-rank Tensor. pub fn asScalar(self: Tensor) Tensor { stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self}); - return self.reshape(.{}); + return if (self.rank() == 0) self else self.reshape(.{}); } pub const Pad = struct { @@ -2582,7 +2572,7 @@ pub const Tensor = struct { /// that requires host<->device synchronization. /// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR. pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor { - scoped_log.debug("scatterSlices({}, {any}, {})", .{ self, indices, updates }); + scoped_log.debug("scatterSlices({f}, {any}, {f})", .{ self, indices, updates }); const UpdateType = @TypeOf(ScatterOpts.increment); @@ -3087,7 +3077,7 @@ pub const Tensor = struct { const tail_chunk_size: i64 = @rem(d, chunk_size); const allocator = self.getContext().allocator(); - var chunks = std.ArrayListUnmanaged(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM"); + var chunks = std.ArrayList(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM"); for (0..n_chunks) |i| { const start: i64 = @as(i64, @intCast(i)) * chunk_size; chunks.appendAssumeCapacity( diff --git a/zml/testing.zig b/zml/testing.zig index 59778ea..c9ad9ef 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -229,7 +229,7 @@ pub fn testLayerOut( const FetchCtx = struct { store: zml.aio.BufferStore, index: u32, - prefix: std.ArrayListUnmanaged(u8), + prefix: std.ArrayList(u8), platform: zml.Platform, fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer { diff --git a/zml/tokenizer/BUILD.bazel b/zml/tokenizer/BUILD.bazel index 9b2f9c7..b7af9e2 100644 --- a/zml/tokenizer/BUILD.bazel +++ b/zml/tokenizer/BUILD.bazel @@ -1,4 +1,4 @@ -load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library") +load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_test") zig_library( name = "tokenizer", @@ -25,3 +25,14 @@ zig_binary( "//stdx", ], ) + +zig_test( + name = "test", + main = "homemade.zig", + visibility = ["//visibility:public"], + deps = [ + "//async", + "//ffi:zig", + "//stdx", + ], +) diff --git a/zml/tokenizer/hftokenizers/hftokenizers.zig b/zml/tokenizer/hftokenizers/hftokenizers.zig index 7b36dbe..01b9c99 100644 --- a/zml/tokenizer/hftokenizers/hftokenizers.zig +++ b/zml/tokenizer/hftokenizers/hftokenizers.zig @@ -110,6 +110,7 @@ pub const HFTokenizer = opaque { } pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 { - return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token)); + const id = c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token)); + return if (id == std.math.maxInt(u32)) null else id; } }; diff --git a/zml/tokenizer/homemade.zig b/zml/tokenizer/homemade.zig index 9b5898d..93383b0 100644 --- a/zml/tokenizer/homemade.zig +++ b/zml/tokenizer/homemade.zig @@ -189,10 +189,9 @@ pub const Tokenizer = struct { // Step by step visualization of the progress. if (options.debug) { var _debug_buf: [256]u8 = undefined; - var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); - var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator()); - self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; - log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items }); + var debug_progress: std.Io.Writer = .fixed(&_debug_buf); + self.decodeWithOpts(tok_buff[0..num_tokens], &debug_progress, .{ .sep = "|" }) catch {}; + log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.buffered() }); } var best_score: f32 = -1e10; var best_token: u32 = 0; @@ -311,22 +310,19 @@ pub const Tokenizer = struct { /// Converts the given slice of tokens back into bytes. /// Note that if the tokenizer allows sub-unicode bytes, it's possible /// the output is not valid utf8. - pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 { - var output = std.array_list.Managed(u8).init(allocator); - errdefer output.deinit(); - - try self.decodeWithOpts(&output, input, .{}); - return output.toOwnedSlice(); + pub fn decode(self: *const Tokenizer, input: []const u32, writer: *std.Io.Writer) std.Io.Writer.Error!void { + try self.decodeWithOpts(input, writer, .{}); } pub fn decodeWithOpts( self: *const Tokenizer, - output: *std.array_list.Managed(u8), input: []const u32, + output: *std.Io.Writer, opts: struct { sep: []const u8 = "" }, - ) error{OutOfMemory}!void { + ) std.Io.Writer.Error!void { const escaped = if (self.normalizer) |n| n.escapedSpace() else null; // Flag used to indicate if the first dummy whitespace has been consumed. + var first_token: bool = true; for (input) |id| { // Retrieve the slice corresponding to the id. var piece = self.lookupPiece(id); @@ -337,12 +333,13 @@ pub const Tokenizer = struct { while (std.mem.startsWith(u8, piece, escspc)) { piece = piece[escspc.len..]; // don't output a space at beginning of text. - if (output.items.len > 0) try output.append(' '); + if (!first_token) try output.writeByte(' '); } } - try output.appendSlice(piece); - if (opts.sep.len > 0) try output.appendSlice(opts.sep); + try output.writeAll(piece); + first_token = false; + try output.writeAll(opts.sep); } } @@ -472,10 +469,13 @@ pub const Decoder = struct { } pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 { + // TODO: revisite tokenizer api to use std.Io.Writer. self.reset(); - const res = try self.inner.decode(self.arena.allocator(), ids); - self.current_string = res; - return res; + // Reuse the same warmup than in init, to maximize contiguous writes. + var arena_writer = std.Io.Writer.Allocating.initCapacity(self.arena.allocator(), 4096) catch unreachable; + try self.inner.decode(ids, &arena_writer.writer); + self.current_string = arena_writer.written(); + return self.current_string.?; } pub fn string(self: *const Decoder) []const u8 { @@ -623,9 +623,9 @@ pub const Normalizer = struct { return if (self._whitespace.len > 1) self._whitespace.constSlice() else null; } - fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void { - try normalized.appendSlice(data); - for (data) |_| try normalized_to_origin.append(consumed); + fn addSlice(allocator: std.mem.Allocator, data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void { + try normalized.appendSlice(allocator, data); + for (data) |_| try normalized_to_origin.append(allocator, consumed); } pub const Result = struct { @@ -673,13 +673,13 @@ pub const Normalizer = struct { // Pre-allocate outputs const space = self.escapedSpace() orelse " "; const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len; - var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); - errdefer normalized.deinit(); - var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity); - errdefer normalized_to_origin.deinit(); + var normalized: std.ArrayList(u8) = try .initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); + errdefer normalized.deinit(allocator); + var normalized_to_origin: std.ArrayList(usize) = try .initCapacity(allocator, normalized.capacity); + errdefer normalized_to_origin.deinit(allocator); // If the spec asks for it, add a whitespace at the beginning. - if (self.flags.add_dummy_prefix) try addSlice(space, consumed, &normalized, &normalized_to_origin); + if (self.flags.add_dummy_prefix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin); var is_prev_space: bool = true; var is_prev_word: bool = false; @@ -706,23 +706,23 @@ pub const Normalizer = struct { var byte = slice[0]; if (self.escapedSpace() != null and byte == ' ') { // replace the space token by the special token - try addSlice(space, origin, &normalized, &normalized_to_origin); + try addSlice(allocator, space, origin, &normalized, &normalized_to_origin); is_prev_word = false; break :ascii; } else if (self.flags.split_on_punct_ascii) { if (is_prev_word and isPunct(slice)) { // Insert a space, but continue handling the rest - try addSlice(space, origin, &normalized, &normalized_to_origin); + try addSlice(allocator, space, origin, &normalized, &normalized_to_origin); } } if (self.flags.lower_case_ascii) { byte = std.ascii.toLower(byte); } - try normalized.append(byte); - try normalized_to_origin.append(origin); + try normalized.append(allocator, byte); + try normalized_to_origin.append(allocator, origin); } else { // we can safely copy to the output. - try addSlice(slice, origin, &normalized, &normalized_to_origin); + try addSlice(allocator, slice, origin, &normalized, &normalized_to_origin); } is_prev_word = !is_prev_space and !isPunct(slice); } @@ -732,20 +732,20 @@ pub const Normalizer = struct { while (std.mem.endsWith(u8, normalized.items, space)) { const length = normalized.items.len - space.len; consumed = normalized_to_origin.items[length]; - try normalized.resize(length); - try normalized_to_origin.resize(length); + try normalized.resize(allocator, length); + try normalized_to_origin.resize(allocator, length); } } - try normalized_to_origin.append(consumed); + try normalized_to_origin.append(allocator, consumed); std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1); - if (self.flags.add_dummy_suffix) try addSlice(space, consumed, &normalized, &normalized_to_origin); + if (self.flags.add_dummy_suffix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin); return .{ - .normalized = try normalized.toOwnedSlice(), - .normalized_to_origin = try normalized_to_origin.toOwnedSlice(), + .normalized_to_origin = try normalized_to_origin.toOwnedSlice(allocator), + .normalized = try normalized.toOwnedSlice(allocator), }; } @@ -1006,8 +1006,7 @@ pub const Gpt2TextDecoder = struct { /// Transform bytes representing text under the gpt2 encoding, /// and write to the `unicode` buffer utf-8 bytes. - pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 { - const start = unicode.items.len; + pub fn decode(self: Gpt2TextDecoder, bytes: []const u8, writer: *std.Io.Writer) (error{InvalidInput} || std.Io.Writer.Error)!void { var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes }; while (it.nextCodepointSlice()) |codepoint| { const code: Code = switch (codepoint.len) { @@ -1016,9 +1015,8 @@ pub const Gpt2TextDecoder = struct { else => return error.InvalidInput, }; const byte = self.code_to_byte.get(code) orelse return error.InvalidInput; - try unicode.append(byte); + try writer.writeByte(byte); } - return unicode.items[start..]; } inline fn isPrintableByte(c: u8) bool { @@ -1030,15 +1028,33 @@ test Gpt2TextDecoder { var decoder = try Gpt2TextDecoder.init(testing.allocator); defer decoder.deinit(); - var out = std.array_list.Managed(u8).init(testing.allocator); - defer out.deinit(); + var buf: [128]u8 = undefined; // Ascii is not changed. - try testing.expectEqualStrings("getTitle", try decoder.decode(&out, "getTitle")); + { + var out: std.Io.Writer = .fixed(&buf); + try decoder.decode("getTitle", &out); + try testing.expectEqualStrings("getTitle", out.buffered()); + } + + { + var out: std.Io.Writer = .fixed(&buf); + try decoder.decode("Ċ", &out); + try testing.expectEqualStrings("\n", out.buffered()); + } + // Leading space are represented with 'Ġ' - try testing.expectEqualStrings(" UINavigationController", try decoder.decode(&out, "ĠUINavigationController")); + { + var out: std.Io.Writer = .fixed(&buf); + try decoder.decode("ĠUINavigationController", &out); + try testing.expectEqualStrings(" UINavigationController", out.buffered()); + } // Russian is wild - try testing.expectEqualStrings(" работ", try decoder.decode(&out, "ĠÑĢабоÑĤ")); + { + var out: std.Io.Writer = .fixed(&buf); + try decoder.decode("ĠÑĢабоÑĤ", &out); + try testing.expectEqualStrings(" работ", out.buffered()); + } } /// Open a json file in HF format and load the vocab from it. @@ -1077,12 +1093,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok // Buffer containing all concatenated tokens. // Reserve a big chunk, to avoid grow event, but release over-allocated memory. - var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); - const original_alloc = all_tokens.items.ptr; + var all_tokens: std.Io.Writer.Allocating = try .initCapacity(tokenizer.arena_state.allocator(), file_content.len); + const original_alloc = all_tokens.writer.buffer.ptr; // A re-alloc event here means we have invalidated all slices inside the tokenizer. // If this is too annoying we could switch to a custom type instead of slices. defer { - std.debug.assert(all_tokens.items.ptr == original_alloc); + std.debug.assert(all_tokens.writer.buffer.ptr == original_alloc); } // gpt2 based tokenizer got a special way of encoding unicode. @@ -1094,16 +1110,18 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok defer gpt2_decoder.deinit(); var it = vocab.iterator(); while (it.next()) |kv| { - const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| { + const n = all_tokens.writer.buffered().len; + gpt2_decoder.decode(kv.key_ptr.*, &all_tokens.writer) catch |err| { switch (err) { error.InvalidInput => { is_gpt2_vocab = false; break; }, - else => return err, + error.WriteFailed => return error.OutOfMemory, } }; const idx: u32 = @intCast(kv.value_ptr.*.integer); + const token = all_tokens.written()[n..]; tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token); } @@ -1126,15 +1144,16 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok if (token_obj != .object) return error.InvalidFormat; const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat; const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat); - const token = try if (is_gpt2_vocab) - gpt2_decoder.decode(&all_tokens, v) + const n = all_tokens.written().len; + try if (is_gpt2_vocab) + gpt2_decoder.decode(v, &all_tokens.writer) else - dup(&all_tokens, v); - + all_tokens.writer.writeAll(v); + const token = all_tokens.written()[n..]; tokenizer.addOwnedTokenByIndex(id, 0, token); } // We won't add more tokens here, let release. - all_tokens.shrinkAndFree(all_tokens.items.len); + _ = try all_tokens.toOwnedSlice(); var unk = tokenizer.lookup(""); if (objectGet(model, .integer, "unk_token")) |unk_tok| { @@ -1165,10 +1184,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok } /// Returns a copy of the given string, stored inside the given ArrayList. -fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 { - const n = buffer.items.len; - try buffer.appendSlice(str); - return buffer.items[n..]; +fn dup(allocating: *std.Io.Writer.Allocating, str: []const u8) std.Io.Writer.Error![]const u8 { + const n = allocating.written().len; + try allocating.writer.writeAll(str); + return allocating.written()[n..]; } /// Returns the given entry in a json object only if it has the right type. diff --git a/zml/tokenizer/main.zig b/zml/tokenizer/main.zig index fc56838..5d78cb8 100644 --- a/zml/tokenizer/main.zig +++ b/zml/tokenizer/main.zig @@ -48,11 +48,11 @@ pub fn asyncMain() !void { } if (args.expected.len > 0) { - var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len); + var expected: std.ArrayList(u32) = try .initCapacity(allocator, args.prompt.len); var it = std.mem.splitSequence(u8, args.expected, ","); while (it.next()) |int_token| { const tok = try std.fmt.parseInt(u32, int_token, 10); - try expected.append(tok); + try expected.append(allocator, tok); } if (!std.mem.eql(u32, expected.items, prompt_tok)) { log.err("Doesn't match expected: {any}", .{expected.items});