Remove deprecated writer interface APIs from core ZML modules (async, MLIR, PJRT, runtime, fmt, aio, buffer, exe, hostbuffer, meta, mlirx).

This commit is contained in:
Tarry Singh 2025-09-04 14:03:09 +00:00
parent 090d7748d5
commit 3ed9bca5ad
31 changed files with 426 additions and 469 deletions

View File

@ -159,26 +159,21 @@ const Coro = struct {
fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn { fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn {
const from_coro: *Coro = @fieldParentPtr("impl", from); const from_coro: *Coro = @fieldParentPtr("impl", from);
const this_coro: *Coro = @fieldParentPtr("impl", this); const this_coro: *Coro = @fieldParentPtr("impl", this);
log(.debug, "coro start {any}", .{this_coro.id}); log(.debug, "coro start {f}", .{this_coro.*});
@call(.auto, this_coro.func, .{}); @call(.auto, this_coro.func, .{});
this_coro.status = .Done; this_coro.status = .Done;
thread_state.switchOut(from_coro); thread_state.switchOut(from_coro);
// Never returns // Never returns
stdx.debug.panic("Cannot resume an already completed coroutine {any}", .{this_coro.id}); stdx.debug.panic("Cannot resume an already completed coroutine {f}", .{this_coro.*});
} }
pub fn getStorage(self: Coro, comptime T: type) *T { pub fn getStorage(self: Coro, comptime T: type) *T {
return @ptrCast(@alignCast(self.storage)); return @ptrCast(@alignCast(self.storage));
} }
pub fn format(self: Coro, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { pub fn format(self: Coro, writer: *std.Io.Writer) !void {
_ = fmt; try writer.print("Coro{{.id = {any}, .status = {t}}}", .{ self.id, self.status });
_ = options;
try writer.print("Coro{{.id = {any}, .status = {s}}}", .{
self.id,
@tagName(self.status),
});
} }
}; };
@ -292,7 +287,7 @@ const ThreadState = struct {
/// Called from resume /// Called from resume
fn switchIn(self: *ThreadState, target: Frame) void { fn switchIn(self: *ThreadState, target: Frame) void {
log(.debug, "coro resume {any} from {any}", .{ target.id, self.current().id }); log(.debug, "coro resume {f} from {f}", .{ target.id, self.current().id });
// Switch to target, setting this coro as the resumer. // Switch to target, setting this coro as the resumer.
self.switchTo(target, true); self.switchTo(target, true);
@ -307,7 +302,7 @@ const ThreadState = struct {
/// Called from suspend /// Called from suspend
fn switchOut(self: *ThreadState, target: Frame) void { fn switchOut(self: *ThreadState, target: Frame) void {
log(.debug, "coro suspend {any} to {any}", .{ self.current().id, target.id }); log(.debug, "coro suspend {f} to {f}", .{ self.current().id, target.id });
self.switchTo(target, false); self.switchTo(target, false);
} }
@ -384,13 +379,8 @@ const CoroId = struct {
self.invocation += 1; self.invocation += 1;
} }
pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void { pub fn format(self: @This(), writer: *std.Io.Writer) !void {
_ = fmt; try writer.print("CoroId{{.cid={d}, .i={d}}}", .{ self.id.coro, self.invocation });
_ = options;
try writer.print("CoroId{{.cid={d}, .i={d}}}", .{
self.id.coro,
self.invocation,
});
} }
}; };
}; };

View File

@ -750,13 +750,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) { if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
stdx.debug.assert( stdx.debug.assert(
backend_config.isA(mlir.StringAttribute), backend_config.isA(mlir.StringAttribute),
"API version < 4 requires a string as backend_config, got {}", "API version < 4 requires a string as backend_config, got {f}",
.{backend_config}, .{backend_config},
); );
} else { } else {
stdx.debug.assert( stdx.debug.assert(
backend_config.isA(mlir.DictionaryAttribute), backend_config.isA(mlir.DictionaryAttribute),
"API version >= 4 requires a dictionary as backend_config, got {}", "API version >= 4 requires a dictionary as backend_config, got {f}",
.{backend_config}, .{backend_config},
); );
} }
@ -1285,20 +1285,21 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
} }
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 { pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
var buf: [32]u8 = undefined; const Cmp = struct {
v1: []const u8,
v1_is_smaller: bool = undefined,
var stream = std.io.fixedBufferStream(&buf); pub fn smallerCb(smaller_version: c.MlirStringRef, opaque_cmp: ?*anyopaque) callconv(.c) void {
var context = .{ .writer = stream.writer() }; var cmp: *@This() = @ptrCast(@alignCast(opaque_cmp));
const WriterContext = @TypeOf(context); cmp.v1_is_smaller = std.mem.eql(u8, cmp.v1, smaller_version.data[0..smaller_version.length]);
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
} }
}).callback, &context); };
return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2; var cmp_ctx: Cmp = .{ .v1 = version1 };
const cmp_res = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), Cmp.smallerCb, &cmp_ctx);
std.debug.assert(cmp_res.value != 0);
return if (cmp_ctx.v1_is_smaller) version1 else version2;
} }
pub fn getCurrentVersion() []const u8 { pub fn getCurrentVersion() []const u8 {
@ -1308,18 +1309,9 @@ pub fn getCurrentVersion() []const u8 {
var once = std.once(call); var once = std.once(call);
fn call() void { fn call() void {
var stream = std.io.fixedBufferStream(&buf); var writer: std.Io.Writer = .fixed(&buf);
var writer_ = stream.writer(); c.stablehloGetCurrentVersion(printCallbackNoFail, &writer);
const ContextWriter = @TypeOf(writer_); str = writer.buffered();
c.stablehloGetCurrentVersion((struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &writer_);
str = buf[0..stream.pos];
} }
}; };
@ -1334,18 +1326,9 @@ pub fn getMinimumVersion() []const u8 {
var once = std.once(call); var once = std.once(call);
fn call() void { fn call() void {
var stream = std.io.fixedBufferStream(&buf); var writer: std.Io.Writer = .fixed(&buf);
var context = .{ .writer = stream.writer() }; c.stablehloGetMinimumVersion(printCallbackNoFail, &writer);
const WriterContext = @TypeOf(context); str = writer.buffered();
c.stablehloGetMinimumVersion((struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &context);
str = buf[0..stream.pos];
} }
}; };
@ -1353,14 +1336,25 @@ pub fn getMinimumVersion() []const u8 {
return state.str; return state.str;
} }
pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void { pub fn serializePortableArtifact(
var context = .{ .writer = writer }; bytecode: []const u8,
const WriterContext = @TypeOf(context); target_version: []const u8,
writer: *std.Io.Writer,
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct { ) error{ InvalidMlirBytecodeVersion, WriteFailed }!void {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void { var writer_err: mlir.WriterWithErr = .{ .writer = writer };
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); try mlir.successOr(
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; c.stablehloSerializePortableArtifactFromStringRef(
} mlir.stringRef(bytecode),
}).callback, &context), error.InvalidMlirBytecodeVersion); mlir.stringRef(target_version),
mlir.WriterWithErr.printCallback,
&writer_err,
),
error.InvalidMlirBytecodeVersion,
);
return try writer_err.check();
}
fn printCallbackNoFail(mlir_str: c.MlirStringRef, opaque_writer: ?*anyopaque) callconv(.c) void {
const writer: *std.Io.Writer = @ptrCast(@alignCast(opaque_writer));
writer.writeAll(mlir.fromStringRef(mlir_str)) catch @panic("Failed to write MLIR");
} }

View File

@ -7,18 +7,20 @@ const stdx = @import("stdx");
const log = std.log.scoped(.mlir); const log = std.log.scoped(.mlir);
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDeclsRecursive(@This());
_ = try Context.init(); _ = try Context.init();
} }
const Error = error{ pub const Error = error{
/// Invalid Mlir was created. /// Invalid Mlir was created.
InvalidMlir, InvalidMlir,
/// Another Mlir error. Check the log for more context. /// Another Mlir error. Check the log for more context.
MlirUnexpected, MlirUnexpected,
/// A resource/executor was not found. /// A resource/executor was not found.
NotFound, NotFound,
/// Bytecode version incompatibility.
InvalidMlirBytecodeVersion,
OutOfMemory, OutOfMemory,
}; };
@ -35,64 +37,61 @@ pub fn registerPasses(comptime passes: []const u8) void {
@field(c, "mlirRegister" ++ passes ++ "Passes")(); @field(c, "mlirRegister" ++ passes ++ "Passes")();
} }
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void { pub fn successOr(res: c.MlirLogicalResult, err: anytype) @TypeOf(err)!void {
return if (res.value == 0) err; return if (res.value == 0) err else {};
} }
/// Alternative to MlirWrapperType
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;
pub const Registry = struct { pub const Registry = struct {
_inner: c.MlirDialectRegistry, _inner: c.MlirDialectRegistry,
pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy); pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy);
pub fn init() !Registry { pub fn init() !Registry {
return helpers.init(Registry, c.mlirDialectRegistryCreate(), c.mlirDialectRegistryIsNull) orelse Error.MlirUnexpected; const registry = c.mlirDialectRegistryCreate();
return .{ ._inner = .{ .ptr = registry.ptr orelse return Error.MlirUnexpected } };
} }
}; };
pub const Context = struct { pub const Context = struct {
_inner: c.MlirContext, _inner: c.MlirContext,
const Self = Context;
pub const deinit = helpers.deinit(Context, c.mlirContextDestroy); pub const deinit = helpers.deinit(Context, c.mlirContextDestroy);
pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull); pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull);
pub fn init() !Self { pub fn init() !Context {
return Self.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected; return Context.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected;
} }
pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self { pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Context {
return Self.wrapOr( return Context.wrapOr(
c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled), c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled),
) orelse Error.InvalidMlir; ) orelse Error.InvalidMlir;
} }
pub fn setMultiThreading(self: *Self, enabled: bool) void { pub fn setMultiThreading(self: *Context, enabled: bool) void {
c.mlirContextEnableMultithreading(self._inner, enabled); c.mlirContextEnableMultithreading(self._inner, enabled);
} }
pub fn appendDialectRegistry(self: *Self, registry: Registry) void { pub fn appendDialectRegistry(self: *Context, registry: Registry) void {
c.mlirContextAppendDialectRegistry(self._inner, registry._inner); c.mlirContextAppendDialectRegistry(self._inner, registry._inner);
} }
pub fn loadAllAvailableDialects(self: *Self) void { pub fn loadAllAvailableDialects(self: *Context) void {
c.mlirContextLoadAllAvailableDialects(self._inner); c.mlirContextLoadAllAvailableDialects(self._inner);
} }
pub fn numRegisteredDialects(self: Self) usize { pub fn numRegisteredDialects(self: Context) usize {
return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner)); return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner));
} }
pub fn numLoadedDialects(self: Self) usize { pub fn numLoadedDialects(self: Context) usize {
return @intCast(c.mlirContextGetNumLoadedDialects(self._inner)); return @intCast(c.mlirContextGetNumLoadedDialects(self._inner));
} }
pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool { pub fn isRegisteredOperation(self: Context, op: [:0]const u8) bool {
return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op)); return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op));
} }
pub fn location(self: Self, src: std.builtin.SourceLocation) Location { pub fn location(self: Context, src: std.builtin.SourceLocation) Location {
return Location.fromSrc(self, src); return Location.fromSrc(self, src);
} }
}; };
@ -103,9 +102,7 @@ pub const Module = struct {
pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy); pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy);
pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull); pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull);
const Self = Module; pub fn init(loc: Location) Module {
pub fn init(loc: Location) Self {
return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) }; return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) };
} }
@ -142,29 +139,26 @@ pub const PassManager = struct {
pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy); pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy);
pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull); pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull);
const Self = PassManager; pub fn init(ctx: Context) !PassManager {
return PassManager.wrapOr(
pub fn init(ctx: Context) !Self {
return Self.wrapOr(
c.mlirPassManagerCreate(ctx._inner), c.mlirPassManagerCreate(ctx._inner),
) orelse Error.MlirUnexpected; ) orelse Error.MlirUnexpected;
} }
pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self { pub fn initOnOperation(ctx: Context, op: [:0]const u8) !PassManager {
return Self.wrapOr( return PassManager.wrapOr(
c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)), c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)),
) orelse Error.MlirUnexpected; ) orelse Error.MlirUnexpected;
} }
pub fn asOpPassManager(self: Self) OpPassManager { pub fn asOpPassManager(self: PassManager) OpPassManager {
return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) }; return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) };
} }
pub fn enableIRPrinting(self: *Self) void { // TODO mlirPassManagerEnableIRPrinting
c.mlirPassManagerEnableIRPrinting(self._inner); // pub fn enableIRPrinting(self: *PassManager) void {}
}
pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void { pub fn runOnOp(self: *PassManager, op: Operation) error{InvalidMlir}!void {
if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) { if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) {
return Error.InvalidMlir; return Error.InvalidMlir;
} }
@ -193,21 +187,20 @@ pub const OpPassManager = struct {
pub const Identifier = struct { pub const Identifier = struct {
_inner: c.MlirIdentifier, _inner: c.MlirIdentifier,
const Self = Identifier;
pub fn get(ctx: Context, str_: [:0]const u8) Self { pub fn get(ctx: Context, str_: [:0]const u8) Identifier {
return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) }; return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) };
} }
pub fn context(self: Self) Context { pub fn context(self: Identifier) Context {
return .{ ._inner = c.mlirIdentifierGetContext(self._inner) }; return .{ ._inner = c.mlirIdentifierGetContext(self._inner) };
} }
pub fn str(self: Self) []const u8 { pub fn str(self: Identifier) []const u8 {
return fromStringRef(c.mlirIdentifierStr(self._inner)); return fromStringRef(c.mlirIdentifierStr(self._inner));
} }
pub fn equals(self: Self, other: Self) bool { pub fn equals(self: Identifier, other: Identifier) bool {
return c.mlirIdentifierEqual(self._inner, other._inner); return c.mlirIdentifierEqual(self._inner, other._inner);
} }
}; };
@ -322,6 +315,14 @@ pub const Attribute = struct {
return DictionaryAttribute.init(ctx, attrs).asAttr(); return DictionaryAttribute.init(ctx, attrs).asAttr();
} }
pub fn eqlAny(Attr: type) fn (Attr, Attr) bool {
return struct {
fn eql(a: Attr, b: Attr) bool {
return a.asAttr().eql(b.asAttr());
}
}.eql;
}
}; };
pub const NamedAttribute = extern struct { pub const NamedAttribute = extern struct {
@ -349,15 +350,14 @@ pub const NamedAttribute = extern struct {
pub const StringAttribute = struct { pub const StringAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub const is_a_fn = c.mlirAttributeIsAString; pub const is_a_fn = c.mlirAttributeIsAString;
const Self = StringAttribute; pub const asAttr = Attribute.fromAny(StringAttribute);
pub const asAttr = Attribute.fromAny(Self); pub const eql = Attribute.eqlAny(StringAttribute);
pub const eql = Attribute.eqlAny(Self);
pub fn init(ctx: Context, str: []const u8) Self { pub fn init(ctx: Context, str: []const u8) StringAttribute {
return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) }; return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) };
} }
pub fn value(self: Self) []const u8 { pub fn value(self: StringAttribute) []const u8 {
return fromStringRef(c.mlirStringAttrGetValue(self._inner)); return fromStringRef(c.mlirStringAttrGetValue(self._inner));
} }
}; };
@ -365,15 +365,14 @@ pub const StringAttribute = struct {
pub const BoolAttribute = struct { pub const BoolAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub const is_a_fn = c.mlirAttributeIsABool; pub const is_a_fn = c.mlirAttributeIsABool;
const Self = BoolAttribute; pub const asAttr = Attribute.fromAny(BoolAttribute);
pub const asAttr = Attribute.fromAny(Self); pub const eql = Attribute.eqlAny(BoolAttribute);
pub const eql = Attribute.eqlAny(Self);
pub fn init(ctx: Context, value_: bool) Self { pub fn init(ctx: Context, value_: bool) BoolAttribute {
return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) }; return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) };
} }
pub fn value(self: Self) bool { pub fn value(self: BoolAttribute) bool {
return c.mlirBoolAttrGetValue(self._inner); return c.mlirBoolAttrGetValue(self._inner);
} }
}; };
@ -397,19 +396,18 @@ pub const TypeAttribute = struct {
pub const ArrayAttribute = struct { pub const ArrayAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,
pub const is_a_fn = c.mlirAttributeIsAArray; pub const is_a_fn = c.mlirAttributeIsAArray;
const Self = ArrayAttribute; pub const asAttr = Attribute.fromAny(ArrayAttribute);
pub const asAttr = Attribute.fromAny(Self); pub const eql = Attribute.eqlAny(ArrayAttribute);
pub const eql = Attribute.eqlAny(Self);
pub fn init(ctx: Context, attrs: []const Attribute) Self { pub fn init(ctx: Context, attrs: []const Attribute) ArrayAttribute {
return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) }; return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) };
} }
pub fn size(self: Self) usize { pub fn size(self: ArrayAttribute) usize {
return @intCast(c.mlirArrayAttrGetNumElements(self._inner)); return @intCast(c.mlirArrayAttrGetNumElements(self._inner));
} }
pub fn get(self: Self, index: usize) Attribute { pub fn get(self: ArrayAttribute, index: usize) Attribute {
return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) }; return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) };
} }
}; };
@ -590,13 +588,13 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr { pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr {
const raw_bytes = std.mem.sliceAsBytes(slice); const raw_bytes = std.mem.sliceAsBytes(slice);
const res: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet( const attr: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet(
shaped_type._inner, shaped_type._inner,
@intCast(raw_bytes.len), @intCast(raw_bytes.len),
@ptrCast(raw_bytes.ptr), @ptrCast(raw_bytes.ptr),
) }; ) };
std.debug.assert(Attribute.wrapOr(res._inner) != null); std.debug.assert(attr._inner.ptr != null);
return res; return attr;
} }
pub fn len(self: Attr) usize { pub fn len(self: Attr) usize {
@ -717,7 +715,7 @@ pub const DictionaryAttribute = struct {
} }
pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute { pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute {
return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, name)); return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, stringRef(name)));
} }
}; };
@ -728,8 +726,7 @@ pub const Operation = struct {
pub const dump = helpers.dump(Operation, c.mlirOperationDestroy); pub const dump = helpers.dump(Operation, c.mlirOperationDestroy);
pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy); pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy);
pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull); pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull);
pub const eql = helpers.eql(Operation, c.mlirOperationEqual);
pub const eql = Attribute.eqlAny(Self);
pub fn init(state: *OperationState) !Self { pub fn init(state: *OperationState) !Self {
return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir; return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir;
@ -881,52 +878,33 @@ pub const Operation = struct {
return .{ ._inner = c.mlirOperationGetContext(self._inner) }; return .{ ._inner = c.mlirOperationGetContext(self._inner) };
} }
pub fn writeBytecode(self: Self, writer: anytype) void { pub fn writeBytecode(self: Self, writer: *std.Io.Writer) std.Io.Writer.Error!void {
var writer_context = .{ .writer = writer }; var writer_with_err: WriterWithErr = .{ .writer = writer };
const WriterContext = @TypeOf(writer_context);
c.mlirOperationWriteBytecode( c.mlirOperationWriteBytecode(
self._inner, self._inner,
(struct { WriterWithErr.printCallback,
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { &writer_with_err,
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
}
}).callback,
&writer_context,
); );
return writer_with_err.check();
} }
pub fn writeBytecodeWithConfig(self: Self, writer: anytype, config: struct { pub fn writeBytecodeWithConfig(self: Self, writer: *std.Io.Writer, config: struct {
desiredEmitedVersion: ?i64 = null, desiredEmitedVersion: ?i64 = null,
}) !void { }) error{ InvalidMlirBytecodeVersion, WriteFailed }!void {
const cfg = c.mlirBytecodeWriterConfigCreate(); const cfg = c.mlirBytecodeWriterConfigCreate();
defer c.mlirBytecodeWriterConfigDestroy(cfg); defer c.mlirBytecodeWriterConfigDestroy(cfg);
if (config.desiredEmitedVersion) |v| { if (config.desiredEmitedVersion) |v| {
c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v); c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v);
} }
const WriterContext = struct { var writer_with_err: WriterWithErr = .{ .writer = writer };
writer: @TypeOf(writer),
write_error: ?@TypeOf(writer).Error = null,
};
var writer_context: WriterContext = .{ .writer = writer };
try successOr(c.mlirOperationWriteBytecodeWithConfig( try successOr(c.mlirOperationWriteBytecodeWithConfig(
self._inner, self._inner,
cfg, cfg,
(struct { &WriterWithErr.printCallback,
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void { &writer_with_err,
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
inner_writer_context.write_error = err;
};
}
}).callback,
&writer_context,
), error.InvalidMlirBytecodeVersion); ), error.InvalidMlirBytecodeVersion);
return writer_with_err.check();
if (writer_context.write_error) |err| return err;
} }
/// Enable a full dump of the IR. /// Enable a full dump of the IR.
@ -939,26 +917,18 @@ pub const Operation = struct {
op: Operation, op: Operation,
flags: OpPrintingFlags, flags: OpPrintingFlags,
pub fn format(self: @This(), writer: anytype) !void { pub fn format(self: MlirFormatter, writer: *std.Io.Writer) !void {
self.op.print(writer, self.flags); try self.op.print(writer, self.flags);
} }
}; };
pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void { pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) std.Io.Writer.Error!void {
const pflags = flags.create(); const pflags = flags.create();
defer c.mlirOpPrintingFlagsDestroy(pflags); defer c.mlirOpPrintingFlagsDestroy(pflags);
c.mlirOperationPrintWithFlags( var writer_err: WriterWithErr = .{ .writer = writer };
self._inner, c.mlirOperationPrintWithFlags(self._inner, pflags, WriterWithErr.printCallback, &writer_err);
pflags, return writer_err.check();
(struct {
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_));
_writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed");
}
}).callback,
writer,
);
} }
pub fn verify(self: Self) bool { pub fn verify(self: Self) bool {
@ -1065,27 +1035,25 @@ pub const OpPrintingFlags = struct {
pub const OpOperand = struct { pub const OpOperand = struct {
_inner: c.MlirOpOperand, _inner: c.MlirOpOperand,
const Self = OpOperand; pub const wrapOr = helpers.wrapOr(OpOperand, c.mlirOpOperandIsNull);
pub fn owner(self: Self) Operation { pub fn owner(self: OpOperand) Operation {
return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) }; return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) };
} }
pub fn number(self: Self) usize { pub fn number(self: OpOperand) usize {
return @intCast(c.mlirOpOperandGetOperandNumber(self._inner)); return @intCast(c.mlirOpOperandGetOperandNumber(self._inner));
} }
pub fn nextUse(self: Self) ?Self { pub fn nextUse(self: OpOperand) ?OpOperand {
return Self.wrapOr( return wrapOr(c.mlirOpOperandGetNextUse(self._inner));
c.mlirOpOperandGetNextUse(self._inner),
);
} }
}; };
pub const Region = struct { pub const Region = struct {
_inner: c.MlirRegion, _inner: c.MlirRegion,
pub const eql = helpers.eql(Region, c.mlirBlockEqual); pub const eql = helpers.eql(Region, c.mlirRegionEqual);
pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy); pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy);
pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull); pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull);
@ -1121,7 +1089,7 @@ pub const Value = struct {
pub const dump = helpers.dump(Value, c.mlirValueDump); pub const dump = helpers.dump(Value, c.mlirValueDump);
pub const eql = helpers.eql(Value, c.mlirValueEqual); pub const eql = helpers.eql(Value, c.mlirValueEqual);
pub const format = helpers.format(Value, c.mlirValuePrint).format; pub const format = helpers.format(Value, c.mlirValuePrint);
pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull); pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull);
pub fn getType(val: Value) Type { pub fn getType(val: Value) Type {
@ -1183,7 +1151,7 @@ pub const BlockArgument = struct {
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner)); return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
} }
pub fn format(self: BlockArgument, writer: anytype) !void { pub fn format(self: BlockArgument, writer: *std.Io.Writer) !void {
const value = Value{ ._inner = self._inner }; const value = Value{ ._inner = self._inner };
return value.format(writer); return value.format(writer);
} }
@ -1192,7 +1160,7 @@ pub const BlockArgument = struct {
pub const Type = struct { pub const Type = struct {
_inner: c.MlirType, _inner: c.MlirType,
pub const dump = helpers.eql(Type, c.mlirTypeDump); pub const dump = helpers.dump(Type, c.mlirTypeDump);
pub const eql = helpers.eql(Type, c.mlirTypeEqual); pub const eql = helpers.eql(Type, c.mlirTypeEqual);
pub const format = helpers.format(Type, c.mlirTypePrint); pub const format = helpers.format(Type, c.mlirTypePrint);
pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull); pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull);
@ -1230,14 +1198,6 @@ pub const Type = struct {
}.eql; }.eql;
} }
pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type {
return struct {
pub fn format(self: SpecificType, writer: anytype) !void {
return try Type.format(self.asType(), writer);
}
};
}
pub fn index(ctx: Context) Type { pub fn index(ctx: Context) Type {
return IndexType.init(ctx).asType(); return IndexType.init(ctx).asType();
} }
@ -1280,7 +1240,7 @@ pub const IndexType = struct {
pub const asType = Type.fromAny(IndexType); pub const asType = Type.fromAny(IndexType);
pub const eql = Type.eqlAny(IndexType); pub const eql = Type.eqlAny(IndexType);
pub const format = Type.formatAny(IndexType).format; pub const format = helpers.format(IndexType, c.mlirTypePrint);
pub fn init(ctx: Context) IndexType { pub fn init(ctx: Context) IndexType {
return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) }; return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) };
@ -1452,7 +1412,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
pub const asType = Type.fromAny(Complex); pub const asType = Type.fromAny(Complex);
pub const eql = Type.eqlAny(Complex); pub const eql = Type.eqlAny(Complex);
pub const format = Type.formatAny(Complex).format; pub const format = helpers.format(Complex, c.mlirTypePrint);
pub const ComplexTypeType: ComplexTypes = ct; pub const ComplexTypeType: ComplexTypes = ct;
pub const init = if (ct != .unknown) struct { pub const init = if (ct != .unknown) struct {
@ -1468,51 +1428,50 @@ pub const TupleType = struct {
_inner: c.MlirType, _inner: c.MlirType,
pub const is_a_fn = c.mlirTypeIsATuple; pub const is_a_fn = c.mlirTypeIsATuple;
const Self = TupleType; pub fn init(ctx: Context, elements: []const Type) !TupleType {
const tuple_type = c.mlirTupleTypeGet(
pub fn init(ctx: Context, elements: []const Type) !Self {
return Self.wrapOr(c.mlirTupleTypeGet(
ctx._inner, ctx._inner,
@intCast(elements.len), @intCast(elements.len),
@ptrCast(elements.ptr), @ptrCast(elements.ptr),
)) orelse Error.InvalidMlir; );
return .{ ._inner = .{ .ptr = tuple_type.ptr orelse return error.InvalidMlir } };
} }
pub fn getNumTypes(self: Self) usize { pub fn getNumTypes(self: TupleType) usize {
return @intCast(c.mlirTupleTypeGetNumTypes(self._inner)); return @intCast(c.mlirTupleTypeGetNumTypes(self._inner));
} }
pub fn getElementType(self: Self, index: usize) Type { pub fn getElementType(self: TupleType, index: usize) Type {
return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) }; return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) };
} }
pub const asType = Type.fromAny(Self); pub const asType = Type.fromAny(TupleType);
}; };
pub const FunctionType = struct { pub const FunctionType = struct {
_inner: c.MlirType, _inner: c.MlirType,
pub const is_a_fn = c.mlirTypeIsAFunction; pub const is_a_fn = c.mlirTypeIsAFunction;
const Self = FunctionType; pub const asType = Type.fromAny(FunctionType);
pub const asType = Type.fromAny(Self); pub const eql = Type.eqlAny(FunctionType);
pub const eql = Type.eqlAny(IndexType);
pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self { pub fn init(ctx: Context, args: []const Type, results: []const Type) !FunctionType {
const func = Type.wrapOr(c.mlirFunctionTypeGet( const func_type = c.mlirFunctionTypeGet(
ctx._inner, ctx._inner,
@intCast(args.len), @intCast(args.len),
@ptrCast(args.ptr), @ptrCast(args.ptr),
@intCast(results.len), @intCast(results.len),
@ptrCast(results.ptr), @ptrCast(results.ptr),
)) orelse return Error.InvalidMlir; );
return func.as(Self).?; return .{ ._inner = .{ .ptr = func_type.ptr orelse return error.InvalidMlir } };
} }
}; };
pub const RankedTensorType = struct { pub const RankedTensorType = struct {
_inner: c.MlirType, _inner: c.MlirType,
pub const is_a_fn = c.mlirTypeIsARankedTensor; pub const is_a_fn = c.mlirTypeIsARankedTensor;
pub const asType = Type.fromAny(RankedTensorType);
pub const eql = Type.eqlAny(RankedTensorType); pub const eql = Type.eqlAny(RankedTensorType);
pub const format = helpers.format(Type, c.mlirTypePrint); pub const format = helpers.format(RankedTensorType, c.mlirTypePrint);
pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType { pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType {
return .{ ._inner = c.mlirRankedTensorTypeGet( return .{ ._inner = c.mlirRankedTensorTypeGet(
@ -1534,20 +1493,16 @@ pub const RankedTensorType = struct {
pub fn getDimension(self: RankedTensorType, dim: usize) i64 { pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim)); return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim));
} }
pub const asType = Type.fromAny(RankedTensorType);
}; };
pub const Dialect = struct { pub const Dialect = struct {
_inner: c.MlirDialect, _inner: c.MlirDialect,
const Self = Dialect; pub fn getContext(self: Dialect) Context {
pub fn getContext(self: Self) Context {
return .{ ._inner = c.mlirDialectGetContext(self._inner) }; return .{ ._inner = c.mlirDialectGetContext(self._inner) };
} }
pub fn getNamespace(self: Self) []const u8 { pub fn getNamespace(self: Dialect) []const u8 {
return fromStringRef(c.mlirDialectGetNamespace(self._inner)); return fromStringRef(c.mlirDialectGetNamespace(self._inner));
} }
}; };
@ -1579,7 +1534,7 @@ pub const DialectHandle = struct {
pub const Location = struct { pub const Location = struct {
_inner: c.MlirLocation, _inner: c.MlirLocation,
pub const eql = helpers.eql(Type, c.mlirLocationEqual); pub const eql = helpers.eql(Location, c.mlirLocationEqual);
pub const format = helpers.format(Location, c.mlirLocationPrint); pub const format = helpers.format(Location, c.mlirLocationPrint);
pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location { pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location {
@ -1613,17 +1568,17 @@ pub const Location = struct {
) }; ) };
} }
pub fn named(loc: Location, ctx: Context, loc_name: [:0]const u8) Location { pub fn named(loc: Location, ctx: Context, loc_name: []const u8) Location {
return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) }; return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) };
} }
pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location { pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location {
var buf: [256]u8 = undefined; var buf: [256]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf); var writer: std.Io.Writer = .fixed(&buf);
std.fmt.format(stream.writer(), fmt, args) catch { writer.print(fmt, args) catch {
buf[256 - 3 ..].* = "...".*; buf[256 - 3 ..].* = "...".*;
}; };
return loc.named(ctx, @ptrCast(stream.getWritten())); return loc.named(ctx, writer.buffered());
} }
pub fn unknown(ctx: Context) Location { pub fn unknown(ctx: Context) Location {
@ -1636,7 +1591,6 @@ pub const Block = struct {
pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull); pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull);
pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy); pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy);
pub const eql = helpers.eql(Block, c.mlirBlockEqual); pub const eql = helpers.eql(Block, c.mlirBlockEqual);
pub fn init(args: []const Type, locs: []const Location) !Block { pub fn init(args: []const Type, locs: []const Location) !Block {
@ -1736,27 +1690,15 @@ pub const helpers = struct {
}.isNull; }.isNull;
} }
pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type { pub fn format(
Any: type,
print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void,
) fn (Any, *std.Io.Writer) std.Io.Writer.Error!void {
return struct { return struct {
pub fn format(self: Any, writer: *std.Io.Writer) !void { pub fn format(self: Any, writer: *std.Io.Writer) std.Io.Writer.Error!void {
const WriterWithErr = struct { try callPrintFn(Any, self, print_fn, writer);
writer: *std.Io.Writer,
err: ?std.Io.Writer.Error = null,
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
if (ctx.err) |_| return;
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
ctx.err = err;
return;
};
} }
}; }.format;
var context: WriterWithErr = .{ .writer = writer };
print_fn(self._inner, &WriterWithErr.printCallback, &context);
if (context.err) |err| return err;
}
};
} }
pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T { pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
@ -1767,9 +1709,35 @@ pub const helpers = struct {
} }
}.wrapOr; }.wrapOr;
} }
};
pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T { pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;
if (is_null_fn(inner)) return null;
return .{ ._inner = inner }; pub fn callPrintFn(
T: type,
value: T,
print_fn: fn (@FieldType(T, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void,
writer: *std.Io.Writer,
) std.Io.Writer.Error!void {
var writer_with_err: WriterWithErr = .{ .writer = writer };
print_fn(value._inner, &WriterWithErr.printCallback, &writer_with_err);
return writer_with_err.check();
}
pub const WriterWithErr = struct {
writer: *std.Io.Writer,
err: ?std.Io.Writer.Error = null,
pub fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
var ctx: *WriterWithErr = @ptrCast(@alignCast(opaque_ctx));
if (ctx.err) |_| return;
ctx.writer.writeAll(fromStringRef(mlir_str)) catch |err| {
ctx.err = err;
return;
};
}
pub fn check(self: WriterWithErr) std.Io.Writer.Error!void {
if (self.err) |err| return err;
} }
}; };

View File

@ -265,16 +265,8 @@ pub const Buffer = extern struct {
return self._dims[0..self.rank]; return self._dims[0..self.rank];
} }
pub fn format( pub fn format(buffer: Buffer, writer: *std.Io.Writer) !void {
buffer: Buffer, try writer.print("FfiBuffer({any}, .{t})@0x{x}", .{ buffer.dims(), buffer.dtype, @intFromPtr(buffer.data) });
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), @tagName(buffer.dtype), @intFromPtr(buffer.data) });
} }
}; };

View File

@ -1257,14 +1257,7 @@ pub const NamedValue = extern struct {
}) }; }) };
} }
pub fn format( pub fn format(self: NamedValue, writer: *std.Io.Writer) !void {
self: NamedValue,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] }); try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] });
const u = self.inner.unnamed_0; const u = self.inner.unnamed_0;
switch (self.kind()) { switch (self.kind()) {

View File

@ -117,8 +117,8 @@ fn wrapNeffAsCustomCall(allocator: std.mem.Allocator, hlo_code: []const u8, neff
}; };
{ {
var operand_ids: std.ArrayListUnmanaged(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]); var operand_ids: std.ArrayList(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]);
var new_instructions: std.ArrayListUnmanaged(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1])); var new_instructions: std.ArrayList(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]));
for (entry_instructions) |instruction| { for (entry_instructions) |instruction| {
if (std.mem.eql(u8, upb.slice(c.xla_HloInstructionProto_opcode(instruction)) orelse continue, "parameter")) { if (std.mem.eql(u8, upb.slice(c.xla_HloInstructionProto_opcode(instruction)) orelse continue, "parameter")) {
const id = c.xla_HloInstructionProto_id(instruction); const id = c.xla_HloInstructionProto_id(instruction);

View File

@ -9,6 +9,10 @@ fn FmtSlice(T: type) type {
return struct { return struct {
slice: []const T, slice: []const T,
pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
return try formatSliceAny(f.slice, .{}, writer);
}
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void { pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
return switch (@typeInfo(T)) { return switch (@typeInfo(T)) {
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer), .comptime_float, .float => try formatFloatSlice(f.slice, n, writer),

View File

@ -315,14 +315,7 @@ pub const Metadata = union(enum) {
}; };
} }
pub fn format( pub fn format(self: Metadata, writer: *std.Io.Writer) !void {
self: Metadata,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
switch (self) { switch (self) {
.null => _ = try writer.write("null"), .null => _ = try writer.write("null"),
inline .bool, .array_bool => |b| try writer.print("{any}", .{b}), inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),

View File

@ -1,10 +1,10 @@
const async = @import("async");
const std = @import("std"); const std = @import("std");
const async = @import("async");
const zml = @import("../zml.zig"); const zml = @import("../zml.zig");
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayList(u8);
const Allocator = std.mem.Allocator;
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
const file = try std.fs.cwd().openFile(path, .{}); const file = try std.fs.cwd().openFile(path, .{});
defer file.close(); defer file.close();
@ -26,7 +26,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
return res; return res;
} }
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void { pub fn parseMetadata(allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
const metadata = &store._metadata; const metadata = &store._metadata;
const key = prefix.items; const key = prefix.items;
return switch (val) { return switch (val) {

View File

@ -9,7 +9,6 @@ const zml = @import("../zml.zig");
const HostBuffer = zml.HostBuffer; const HostBuffer = zml.HostBuffer;
const json = @import("json.zig"); const json = @import("json.zig");
const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.@"zml/io"); const log = std.log.scoped(.@"zml/io");
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore { pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
@ -17,19 +16,18 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
errdefer res.arena.deinit(); errdefer res.arena.deinit();
const arena = res.arena.allocator(); const arena = res.arena.allocator();
var files = std.array_list.Managed(MemoryMappedFile).init(arena); var files: std.ArrayList(MemoryMappedFile) = .empty;
errdefer files.deinit();
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) { if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
try loadFromIndex(arena, &res, &files, path); try loadFromIndex(arena, &res, &files, path);
} else { } else {
try loadFile(arena, &res, &files, path); try loadFile(arena, &res, &files, path);
} }
res.files = try files.toOwnedSlice(); res.files = try files.toOwnedSlice(allocator);
return res; return res;
} }
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
const file = async.File.open(path, .{}) catch |err| { const file = async.File.open(path, .{}) catch |err| {
log.err("Failed to open {s}: {}", .{ path, err }); log.err("Failed to open {s}: {}", .{ path, err });
return err; return err;
@ -61,11 +59,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
if (index.object.get("__metadata__")) |metadata| { if (index.object.get("__metadata__")) |metadata| {
var prefix_buf: [1024]u8 = undefined; var prefix_buf: [1024]u8 = undefined;
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata); try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), metadata);
} }
} }
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void { fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
const file = async.File.open(path, .{}) catch |err| { const file = async.File.open(path, .{}) catch |err| {
log.err("Failed to open {s}: {}", .{ path, err }); log.err("Failed to open {s}: {}", .{ path, err });
return err; return err;
@ -87,7 +85,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array
errdefer buffer_file.deinit(); errdefer buffer_file.deinit();
buffer_file.data_offset = 8 + json_header_length; buffer_file.data_offset = 8 + json_header_length;
try files.append(buffer_file); try files.append(allocator, buffer_file);
errdefer _ = files.pop(); errdefer _ = files.pop();
var it = metadata.object.iterator(); var it = metadata.object.iterator();
@ -95,7 +93,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array
const key = entry.key_ptr.*; const key = entry.key_ptr.*;
if (std.mem.eql(u8, key, "__metadata__")) { if (std.mem.eql(u8, key, "__metadata__")) {
var prefix_buf: [1024]u8 = undefined; var prefix_buf: [1024]u8 = undefined;
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*); try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), entry.value_ptr.*);
continue; continue;
} }
const val = entry.value_ptr.*; const val = entry.value_ptr.*;

View File

@ -6,7 +6,6 @@ const zml = @import("../zml.zig");
const eval = @import("torch/eval.zig"); const eval = @import("torch/eval.zig");
const File = @import("torch/file.zig").File; const File = @import("torch/file.zig").File;
const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.@"zml/aio"); const log = std.log.scoped(.@"zml/aio");
test { test {

View File

@ -348,7 +348,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
} }
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void { fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
var array_list = std.ArrayListUnmanaged(py.Any).fromOwnedSlice(current.*); var array_list = std.ArrayList(py.Any).fromOwnedSlice(current.*);
try array_list.appendSlice(allocator, values); try array_list.appendSlice(allocator, values);
current.* = array_list.items; current.* = array_list.items;
} }

View File

@ -13,7 +13,7 @@ const py = @import("py.zig");
const log = std.log.scoped(.@"zml/aio"); const log = std.log.scoped(.@"zml/aio");
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead // TODO(cryptodeal): use zml.aio.PrefixBuilder instead
const StringBuilder = std.ArrayListUnmanaged(u8); const StringBuilder = std.ArrayList(u8);
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(@This());
@ -191,7 +191,7 @@ pub const File = struct {
.boolval => bool, .boolval => bool,
else => unreachable, else => unreachable,
}; };
var values: std.ArrayListUnmanaged(ItemType) = .{}; var values: std.ArrayList(ItemType) = .{};
try values.append(allocator, val0); try values.append(allocator, val0);
for (seq.values[1..], 1..) |val, i| { for (seq.values[1..], 1..) |val, i| {
if (std.meta.activeTag(val) != tag) valid_slice = false; if (std.meta.activeTag(val) != tag) valid_slice = false;

View File

@ -768,7 +768,7 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
// It's not very efficient to interleave the results with the data copied from the stream, // It's not very efficient to interleave the results with the data copied from the stream,
// because growth event in the results ArrayList will lead to fragmentation. // because growth event in the results ArrayList will lead to fragmentation.
// Trying to mitigate that by using a generous default size. // Trying to mitigate that by using a generous default size.
var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512); var results: std.ArrayList(Op) = try .initCapacity(arena, 512);
errdefer results.deinit(arena); errdefer results.deinit(arena);
var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512); var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);

View File

@ -1,9 +1,10 @@
const std = @import("std"); const std = @import("std");
const math = std.math; const math = std.math;
const log = std.log.scoped(.@"zml/aio");
const pickle = @import("pickle.zig"); const pickle = @import("pickle.zig");
const log = std.log.scoped(.@"zml/aio");
/// Correspond to a function/constructor call /// Correspond to a function/constructor call
pub const Object = struct { pub const Object = struct {
member: Any, member: Any,
@ -206,12 +207,16 @@ pub const Any = union(Kind) {
self.* = undefined; self.* = undefined;
} }
inline fn writeIndents(indents: usize, writer: anytype) !void { inline fn writeIndents(indents: usize, writer: *std.Io.Writer) !void {
try writer.writeBytesNTimes(" ", indents); // resolve tab = 2 spaces try writer.writeBytesNTimes(" ", indents); // resolve tab = 2 spaces
// try writer.writeByteNTimes('\t'); // try writer.writeByteNTimes('\t');
} }
fn internalFormat(value: Any, indents: usize, writer: anytype) !void { pub fn format(self: Any, writer: *std.Io.Writer) !void {
return internalFormat(self, 0, writer);
}
fn internalFormat(value: Any, indents: usize, writer: *std.Io.Writer) !void {
try writeIndents(indents, writer); try writeIndents(indents, writer);
try writer.writeAll(".{\n"); try writer.writeAll(".{\n");
try writeIndents(indents + 1, writer); try writeIndents(indents + 1, writer);
@ -303,10 +308,6 @@ pub const Any = union(Kind) {
try writer.writeByte('}'); try writer.writeByte('}');
} }
pub fn format(self: Any, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
return internalFormat(self, 0, writer);
}
pub fn clone(self: Any, allocator: std.mem.Allocator) !Any { pub fn clone(self: Any, allocator: std.mem.Allocator) !Any {
return switch (self) { return switch (self) {
inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)), inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)),

View File

@ -384,10 +384,7 @@ pub const Buffer = struct {
return res; return res;
} }
pub fn format( pub fn format(self: Buffer, writer: *std.Io.Writer) !void {
self: Buffer,
writer: anytype,
) !void {
try writer.print("Buffer({f})", .{self._shape}); try writer.print("Buffer({f})", .{self._shape});
} }

View File

@ -95,6 +95,8 @@ pub fn compileFn(
) !FnExe(func) { ) !FnExe(func) {
var pretty_name = try prettyFnName(func, allocator); var pretty_name = try prettyFnName(func, allocator);
defer pretty_name.deinit(allocator); defer pretty_name.deinit(allocator);
log.info("Compiling {s} with {f}", .{ pretty_name.items, stdx.fmt.any(args) });
var context = try CompilationContext.init(allocator, pretty_name.items, platform); var context = try CompilationContext.init(allocator, pretty_name.items, platform);
defer context.deinit(); defer context.deinit();
@ -306,7 +308,7 @@ pub const BaseExe = struct {
} }
} }
pub fn serialize(self: BaseExe, writer: anytype) !void { pub fn serialize(self: BaseExe, writer: *std.Io.Writer) !void {
var executable = try self.exe.getExecutable(self.platform.pjrt_api); var executable = try self.exe.getExecutable(self.platform.pjrt_api);
var serialize_result = try executable.serialize(self.platform.pjrt_api); var serialize_result = try executable.serialize(self.platform.pjrt_api);
defer serialize_result.deinit(); defer serialize_result.deinit();
@ -377,7 +379,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
try self.inner.bind(T, value); try self.inner.bind(T, value);
} }
pub fn serialize(self: Self, writer: anytype) !void { pub fn serialize(self: Self, writer: *std.Io.Writer) !void {
return try self.inner.serialize(writer); return try self.inner.serialize(writer);
} }
@ -437,7 +439,7 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
fn prettyFnName( fn prettyFnName(
comptime func: anytype, comptime func: anytype,
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
) !std.ArrayListUnmanaged(u8) { ) !std.ArrayList(u8) {
const full_noisy_name = @typeName(@TypeOf(func)); const full_noisy_name = @typeName(@TypeOf(func));
const og_len = full_noisy_name.len; const og_len = full_noisy_name.len;
const buffer = try allocator.alloc(u8, og_len); const buffer = try allocator.alloc(u8, og_len);

View File

@ -321,10 +321,7 @@ pub const HostBuffer = struct {
}; };
} }
pub fn format( pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
self: HostBuffer,
writer: anytype,
) !void {
try writer.print("HostBuffer(.{f})", .{self._shape}); try writer.print("HostBuffer(.{f})", .{self._shape});
} }

View File

@ -679,28 +679,38 @@ test zip {
/// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and write the result of func(X) into an arraylist. /// finds all X in the given object, and write the result of func(X) into an arraylist.
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void { pub fn collectAlloc(
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)}); func: anytype,
const LocalContext = struct {
func_ctx: _CollectCtx(func), func_ctx: _CollectCtx(func),
out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), allocator: std.mem.Allocator,
obj: anytype,
) std.mem.Allocator.Error![]stdx.meta.FnSignature(func, null).ReturnT {
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
const CollectAllocCtx = struct {
func_ctx: _CollectCtx(func),
allocator: std.mem.Allocator,
out: std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT) = .empty,
oom: bool = false, oom: bool = false,
};
var context = LocalContext{ .func_ctx = func_ctx, .out = out }; fn cb(ctx: *@This(), val: *const _CollectArg(func)) void {
visit((struct {
fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
if (ctx.oom) return; if (ctx.oom) return;
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*); const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
ctx.out.append(res) catch { ctx.out.append(ctx.allocator, res) catch {
ctx.oom = true; ctx.oom = true;
}; };
} }
}).cb, &context, obj); };
var context = CollectAllocCtx{ .func_ctx = func_ctx, .allocator = allocator };
visit(CollectAllocCtx.cb, &context, obj);
if (context.oom) return error.OutOfMemory; if (context.oom) return error.OutOfMemory;
return context.out.toOwnedSlice(allocator);
} }
/// Given a func(X) -> Y or a func(Ctx, X) -> Y, /// Given a func(X) -> Y or a func(Ctx, X) -> Y,
/// finds all X in the given object, and write the result of func(X) into an arraylist. /// finds all X in the given object, and write the result of func(X) into a given slice.
/// Asserts that the number of X found is equal to the slice len.
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void { pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)}); stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
const LocalContext = struct { const LocalContext = struct {

View File

@ -96,6 +96,6 @@ pub const Type = struct {
} }
} }
std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type}); std.debug.panic("Could not convert mlir.Type to DataType: {f}", .{mlir_type});
} }
}; };

View File

@ -187,7 +187,8 @@ pub const CompilationContext = struct {
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| { if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
var write_buf: [4096]u8 = undefined; var write_buf: [4096]u8 = undefined;
var writer = file.writer(&write_buf); var writer = file.writer(&write_buf);
module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false }); try module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
try writer.interface.flush();
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name }); log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
} else |_| { } else |_| {
log.warn("Failed to open {s}", .{mlir_name}); log.warn("Failed to open {s}", .{mlir_name});
@ -341,12 +342,11 @@ pub const CompilationContext = struct {
const locations = try arena.alloc(mlir.Location, tensor_count); const locations = try arena.alloc(mlir.Location, tensor_count);
@memset(locations, mlir.Location.unknown(mlir_ctx)); @memset(locations, mlir.Location.unknown(mlir_ctx));
var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count); const input_shapes = try res_allocator.alloc(Shape, tensor_count);
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable; meta.collectBuf(Tensor.shape, {}, args, input_shapes);
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
const input_types = try arena.alloc(mlir.Type, tensor_count); const input_types = try arena.alloc(mlir.Type, tensor_count);
for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh); for (input_types, input_shapes) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
const og_block_args = self._block_args; const og_block_args = self._block_args;
defer { defer {
@ -399,7 +399,7 @@ pub const CompilationContext = struct {
self.addDonationsAttributes(arg_attrs, fn_res_donations); self.addDonationsAttributes(arg_attrs, fn_res_donations);
self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind); self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
if (self._platform.sharding().num_partitions > 1) { if (self._platform.sharding().num_partitions > 1) {
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); self.addShardingAttributes(arg_attrs, res_attrs, input_shapes, fn_res_shapes);
} }
} }
@ -427,7 +427,7 @@ pub const CompilationContext = struct {
return .{ return .{
.mlir_fn = mlir_fn, .mlir_fn = mlir_fn,
.name = opts.name, .name = opts.name,
.args_shapes = input_shapes.items, .args_shapes = input_shapes,
.res_tensors = fn_res, .res_tensors = fn_res,
.res_types = fn_res_types, .res_types = fn_res_types,
.res_shapes = fn_res_shapes, .res_shapes = fn_res_shapes,
@ -506,9 +506,9 @@ pub const CompilationContext = struct {
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } }; var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main }); const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator); var mlir_code: std.Io.Writer.Allocating = .init(std.testing.allocator);
defer mlir_bytecode.deinit(); defer mlir_code.deinit();
try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})}); try f.mlir_fn.print(&mlir_code.writer, .{});
// Check that the `x` input argument gives its buffer to the result tensor. // Check that the `x` input argument gives its buffer to the result tensor.
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`. // `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
@ -518,8 +518,8 @@ pub const CompilationContext = struct {
var buf = template.*; var buf = template.*;
for (0..2) |i| { for (0..2) |i| {
const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable; const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| { std.testing.expect(std.mem.indexOf(u8, mlir_code.written(), alias_attr) != null) catch |err| {
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items}); log.warn("Didn't produced the expected IR:\n{s}", .{mlir_code.written()});
return err; return err;
}; };
} }
@ -547,12 +547,14 @@ pub const CompilationContext = struct {
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
const ctx = self.mlirCtx(); const ctx = self.mlirCtx();
const num_partitions = self.numPartitions(); const num_partitions = self.numPartitions();
var sharding_str: stdx.BoundedArray(u8, 128) = .{}; // This is big enough, see test below for examples values
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; var sharding_str_buf: [64]u8 = undefined;
return mlir.Attribute.string(ctx, sharding_str.constSlice()); var writer: std.Io.Writer = .fixed(&sharding_str_buf);
writeShardingRepresentation(shape, num_partitions, &writer) catch unreachable;
return mlir.Attribute.string(ctx, writer.buffered());
} }
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void { fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: *std.Io.Writer) std.Io.Writer.Error!void {
const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info))); const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
if (n_sharded == 0 or num_partitions == 1) { if (n_sharded == 0 or num_partitions == 1) {
try writer.writeAll("{replicated}"); try writer.writeAll("{replicated}");
@ -567,26 +569,26 @@ pub const CompilationContext = struct {
} }
test writeShardingRepresentation { test writeShardingRepresentation {
var rule: [64]u8 = undefined; var attr_buf: [64]u8 = undefined;
const x = Shape.init(.{ 16, 8 }, .f32); const x = Shape.init(.{ 16, 8 }, .f32);
// By default tensors are replicated. // By default tensors are replicated.
{ {
var fbs = std.io.fixedBufferStream(&rule); var writer: std.Io.Writer = .fixed(&attr_buf);
try writeShardingRepresentation(x, 4, fbs.writer()); try writeShardingRepresentation(x, 4, &writer);
try std.testing.expectEqualStrings("{replicated}", fbs.getWritten()); try std.testing.expectEqualStrings("{replicated}", writer.buffered());
} }
// Shard along first axis. // Shard along first axis.
{ {
var fbs = std.io.fixedBufferStream(&rule); var writer: std.Io.Writer = .fixed(&attr_buf);
try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer()); try writeShardingRepresentation(x.withSharding(.{0}), 4, &writer);
try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten()); try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", writer.buffered());
} }
// Also shard along second axis. // Also shard along second axis.
{ {
var fbs = std.io.fixedBufferStream(&rule); var writer: std.Io.Writer = .fixed(&attr_buf);
try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer()); try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, &writer);
try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten()); try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", writer.buffered());
} }
} }

View File

@ -443,29 +443,23 @@ pub fn if_(
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {}); const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {}); const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); check: {
defer true_shapes.deinit(); const arena = ctx.allocator();
var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator()); const true_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &true_branch_res) catch break :check;
defer false_shapes.deinit(); defer arena.free(true_shapes);
var failed_to_collect = false; const false_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &false_branch_res) catch break :check;
meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch { defer arena.free(false_shapes);
failed_to_collect = true;
}; stdx.debug.assert(true_shapes.len == false_shapes.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) });
meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch { for (true_shapes, false_shapes) |true_shape, false_shape| {
failed_to_collect = true; stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) });
};
if (!failed_to_collect) {
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
} }
} }
const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0);
const loc = ctx.mlirCtx().location(@src()); const loc = ctx.mlirCtx().location(@src());
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{ const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
.operands = &.{scalar_pred.value()}, .operands = &.{pred.asScalar().value()},
.result_type_inference = true, .result_type_inference = true,
.blocks = &.{ true_branch_block, false_branch_block }, .blocks = &.{ true_branch_block, false_branch_block },
// We can't verify right away, cause the weights captured by the if haven't been added yet. // We can't verify right away, cause the weights captured by the if haven't been added yet.
@ -958,18 +952,20 @@ pub fn scatter(
const UpdateS = BlockSign(update_fn); const UpdateS = BlockSign(update_fn);
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar }); const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM"); const arena = ctx.allocator();
defer input_values.deinit(); const input_values = arena.alloc(mlir.Value, n_inputs) catch @panic("OOM");
meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable; defer arena.free(input_values);
var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM"); meta.collectBuf(CompilationContext.getValue, ctx, &inputs, input_values);
defer updates_values.deinit();
meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable; const updates_values = arena.alloc(mlir.Value, n_updates) catch @panic("OOM");
defer arena.free(updates_values);
meta.collectBuf(CompilationContext.getValue, ctx, &updates, updates_values);
const op = dialect.stablehlo.scatter( const op = dialect.stablehlo.scatter(
mlir_ctx, mlir_ctx,
input_values.items, input_values,
&.{indices.value()}, &.{indices.value()},
updates_values.items, updates_values,
update_block, update_block,
.{ .{
.update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(), .update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(),

View File

@ -75,14 +75,17 @@ pub const Client = opaque {
} }
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable { fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
var bytecode: std.array_list.Managed(u8) = .init(allocator); var bytecode: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096);
defer bytecode.deinit(); defer bytecode.deinit();
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| { module.op().writeBytecodeWithConfig(&bytecode.writer, .{ .desiredEmitedVersion = 1 }) catch |err| {
log.err("failed to write module bytecode: {}", .{err}); log.err("failed to write module bytecode: {}", .{err});
return err; return switch (err) {
std.Io.Writer.Error.WriteFailed => error.OutOfMemory,
else => |e| e,
};
}; };
var serialized_buffer: std.array_list.Managed(u8) = .init(allocator); var serialized_buffer: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096);
defer serialized_buffer.deinit(); defer serialized_buffer.deinit();
const stablehlo_version = blk: { const stablehlo_version = blk: {
@ -92,13 +95,16 @@ pub const Client = opaque {
break :blk dialects.stablehlo.getMinimumVersion(); break :blk dialects.stablehlo.getMinimumVersion();
}; };
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| { dialects.stablehlo.serializePortableArtifact(bytecode.written(), stablehlo_version, &serialized_buffer.writer) catch |err| {
log.err("failed to serialize to portable artifact: {}", .{err}); log.err("failed to serialize to portable artifact: {}", .{err});
return err; return switch (err) {
std.Io.Writer.Error.WriteFailed => error.OutOfMemory,
else => |e| e,
};
}; };
return @ptrCast(try self.inner().compile(api, .{ return @ptrCast(try self.inner().compile(api, .{
.bytecode = serialized_buffer.items, .bytecode = serialized_buffer.written(),
.bytecode_format = .mlir, .bytecode_format = .mlir,
.compile_options_pb = compile_options_pb, .compile_options_pb = compile_options_pb,
})); }));

View File

@ -95,7 +95,7 @@ const _CreateOptions = struct {
pub const Cpu = struct { pub const Cpu = struct {
device_count: u32, device_count: u32,
fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { fn writeNamedValues(self: Cpu, values: *std.ArrayList(pjrt.NamedValue)) void {
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count))); values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
} }
}; };
@ -124,7 +124,7 @@ const _CreateOptions = struct {
}; };
}; };
fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void { fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
switch (self.allocator) { switch (self.allocator) {
.platform => { .platform => {
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform")); values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
@ -145,7 +145,7 @@ const _CreateOptions = struct {
}; };
pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue { pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out); var values = std.ArrayList(pjrt.NamedValue).fromOwnedSlice(out);
values.shrinkRetainingCapacity(0); values.shrinkRetainingCapacity(0);
switch (target) { switch (target) {
.cpu => self.cpu.writeNamedValues(&values), .cpu => self.cpu.writeNamedValues(&values),

View File

@ -386,14 +386,8 @@ pub const Shape = struct {
/// Format the shape. /// Format the shape.
/// Default format: "Shape({.a=10, .b=20}, dtype=.f32)" /// Default format: "Shape({.a=10, .b=20}, dtype=.f32)"
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32" /// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
pub fn format( pub fn format(self: Shape, writer: *std.Io.Writer) !void {
self: Shape, _ = try writer.writeByte('{');
writer: anytype,
) !void {
// TODO: impl alternative format
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
const bare_fmt = true;
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
var need_comma = false; var need_comma = false;
for (self.dims(), 0..) |d, i| { for (self.dims(), 0..) |d, i| {
@ -411,7 +405,7 @@ pub const Shape = struct {
} }
if (need_comma) try writer.writeByte(','); if (need_comma) try writer.writeByte(',');
_ = try writer.write(@tagName(self.dtype())); _ = try writer.write(@tagName(self.dtype()));
_ = try writer.write(if (bare_fmt) "}" else "})"); _ = try writer.writeByte('}');
} }
/// Broadcasts a Tensor to the given shape, extending dimensions if needed. /// Broadcasts a Tensor to the given shape, extending dimensions if needed.

View File

@ -52,10 +52,7 @@ pub const Tensor = struct {
return CompilationContext.current(); return CompilationContext.current();
} }
pub fn format( pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
self: Tensor,
writer: anytype,
) !void {
// TODO(0.15.0) handle format // TODO(0.15.0) handle format
// const bare_fmt = fmt.len == 1 and fmt[0] == '_'; // const bare_fmt = fmt.len == 1 and fmt[0] == '_';
const bare_fmt = false; const bare_fmt = false;
@ -1146,7 +1143,7 @@ pub const Tensor = struct {
contracting_axes: []const [2]i8, contracting_axes: []const [2]i8,
batching_axes: []const [2]i8, batching_axes: []const [2]i8,
) Tensor { ) Tensor {
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {f} and {f}", .{ lhs, rhs });
const Axes = stdx.BoundedArray(i64, MAX_RANK); const Axes = stdx.BoundedArray(i64, MAX_RANK);
@ -1156,7 +1153,7 @@ pub const Tensor = struct {
var rhs_batching_axes: Axes = .{}; var rhs_batching_axes: Axes = .{};
for (batching_axes) |b_axes| { for (batching_axes) |b_axes| {
const l, const r = b_axes; const l, const r = b_axes;
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs }); stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {d} and {d} in {f} and {f}", .{ l, r, lhs, rhs });
var t = lhs._shape.tag(l); var t = lhs._shape.tag(l);
if (t == Shape.TagUnknown) t = rhs._shape.tag(r); if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
res_shape = res_shape.appendDim(lhs._shape.dim(l), t); res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
@ -1521,14 +1518,7 @@ pub const Tensor = struct {
const to_the_end = std.math.maxInt(i64); const to_the_end = std.math.maxInt(i64);
pub fn format( pub fn format(self: Slice, writer: *std.Io.Writer) !void {
self: Slice,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
if (self.singleton) { if (self.singleton) {
try writer.print("[{}]", .{self.start}); try writer.print("[{}]", .{self.start});
} else if (self.end == to_the_end and self.step == 1) { } else if (self.end == to_the_end and self.step == 1) {
@ -2043,7 +2033,7 @@ pub const Tensor = struct {
/// Converts the given 1 element Tensor into a 0-rank Tensor. /// Converts the given 1 element Tensor into a 0-rank Tensor.
pub fn asScalar(self: Tensor) Tensor { pub fn asScalar(self: Tensor) Tensor {
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self}); stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self});
return self.reshape(.{}); return if (self.rank() == 0) self else self.reshape(.{});
} }
pub const Pad = struct { pub const Pad = struct {
@ -2582,7 +2572,7 @@ pub const Tensor = struct {
/// that requires host<->device synchronization. /// that requires host<->device synchronization.
/// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR. /// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR.
pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor { pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor {
scoped_log.debug("scatterSlices({}, {any}, {})", .{ self, indices, updates }); scoped_log.debug("scatterSlices({f}, {any}, {f})", .{ self, indices, updates });
const UpdateType = @TypeOf(ScatterOpts.increment); const UpdateType = @TypeOf(ScatterOpts.increment);
@ -3087,7 +3077,7 @@ pub const Tensor = struct {
const tail_chunk_size: i64 = @rem(d, chunk_size); const tail_chunk_size: i64 = @rem(d, chunk_size);
const allocator = self.getContext().allocator(); const allocator = self.getContext().allocator();
var chunks = std.ArrayListUnmanaged(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM"); var chunks = std.ArrayList(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM");
for (0..n_chunks) |i| { for (0..n_chunks) |i| {
const start: i64 = @as(i64, @intCast(i)) * chunk_size; const start: i64 = @as(i64, @intCast(i)) * chunk_size;
chunks.appendAssumeCapacity( chunks.appendAssumeCapacity(

View File

@ -229,7 +229,7 @@ pub fn testLayerOut(
const FetchCtx = struct { const FetchCtx = struct {
store: zml.aio.BufferStore, store: zml.aio.BufferStore,
index: u32, index: u32,
prefix: std.ArrayListUnmanaged(u8), prefix: std.ArrayList(u8),
platform: zml.Platform, platform: zml.Platform,
fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer { fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer {

View File

@ -1,4 +1,4 @@
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library") load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_test")
zig_library( zig_library(
name = "tokenizer", name = "tokenizer",
@ -25,3 +25,14 @@ zig_binary(
"//stdx", "//stdx",
], ],
) )
zig_test(
name = "test",
main = "homemade.zig",
visibility = ["//visibility:public"],
deps = [
"//async",
"//ffi:zig",
"//stdx",
],
)

View File

@ -110,6 +110,7 @@ pub const HFTokenizer = opaque {
} }
pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 { pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 {
return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token)); const id = c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
return if (id == std.math.maxInt(u32)) null else id;
} }
}; };

View File

@ -189,10 +189,9 @@ pub const Tokenizer = struct {
// Step by step visualization of the progress. // Step by step visualization of the progress.
if (options.debug) { if (options.debug) {
var _debug_buf: [256]u8 = undefined; var _debug_buf: [256]u8 = undefined;
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf); var debug_progress: std.Io.Writer = .fixed(&_debug_buf);
var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator()); self.decodeWithOpts(tok_buff[0..num_tokens], &debug_progress, .{ .sep = "|" }) catch {};
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {}; log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.buffered() });
log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
} }
var best_score: f32 = -1e10; var best_score: f32 = -1e10;
var best_token: u32 = 0; var best_token: u32 = 0;
@ -311,22 +310,19 @@ pub const Tokenizer = struct {
/// Converts the given slice of tokens back into bytes. /// Converts the given slice of tokens back into bytes.
/// Note that if the tokenizer allows sub-unicode bytes, it's possible /// Note that if the tokenizer allows sub-unicode bytes, it's possible
/// the output is not valid utf8. /// the output is not valid utf8.
pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 { pub fn decode(self: *const Tokenizer, input: []const u32, writer: *std.Io.Writer) std.Io.Writer.Error!void {
var output = std.array_list.Managed(u8).init(allocator); try self.decodeWithOpts(input, writer, .{});
errdefer output.deinit();
try self.decodeWithOpts(&output, input, .{});
return output.toOwnedSlice();
} }
pub fn decodeWithOpts( pub fn decodeWithOpts(
self: *const Tokenizer, self: *const Tokenizer,
output: *std.array_list.Managed(u8),
input: []const u32, input: []const u32,
output: *std.Io.Writer,
opts: struct { sep: []const u8 = "" }, opts: struct { sep: []const u8 = "" },
) error{OutOfMemory}!void { ) std.Io.Writer.Error!void {
const escaped = if (self.normalizer) |n| n.escapedSpace() else null; const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
// Flag used to indicate if the first dummy whitespace has been consumed. // Flag used to indicate if the first dummy whitespace has been consumed.
var first_token: bool = true;
for (input) |id| { for (input) |id| {
// Retrieve the slice corresponding to the id. // Retrieve the slice corresponding to the id.
var piece = self.lookupPiece(id); var piece = self.lookupPiece(id);
@ -337,12 +333,13 @@ pub const Tokenizer = struct {
while (std.mem.startsWith(u8, piece, escspc)) { while (std.mem.startsWith(u8, piece, escspc)) {
piece = piece[escspc.len..]; piece = piece[escspc.len..];
// don't output a space at beginning of text. // don't output a space at beginning of text.
if (output.items.len > 0) try output.append(' '); if (!first_token) try output.writeByte(' ');
} }
} }
try output.appendSlice(piece); try output.writeAll(piece);
if (opts.sep.len > 0) try output.appendSlice(opts.sep); first_token = false;
try output.writeAll(opts.sep);
} }
} }
@ -472,10 +469,13 @@ pub const Decoder = struct {
} }
pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 { pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
// TODO: revisite tokenizer api to use std.Io.Writer.
self.reset(); self.reset();
const res = try self.inner.decode(self.arena.allocator(), ids); // Reuse the same warmup than in init, to maximize contiguous writes.
self.current_string = res; var arena_writer = std.Io.Writer.Allocating.initCapacity(self.arena.allocator(), 4096) catch unreachable;
return res; try self.inner.decode(ids, &arena_writer.writer);
self.current_string = arena_writer.written();
return self.current_string.?;
} }
pub fn string(self: *const Decoder) []const u8 { pub fn string(self: *const Decoder) []const u8 {
@ -623,9 +623,9 @@ pub const Normalizer = struct {
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null; return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
} }
fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void { fn addSlice(allocator: std.mem.Allocator, data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
try normalized.appendSlice(data); try normalized.appendSlice(allocator, data);
for (data) |_| try normalized_to_origin.append(consumed); for (data) |_| try normalized_to_origin.append(allocator, consumed);
} }
pub const Result = struct { pub const Result = struct {
@ -673,13 +673,13 @@ pub const Normalizer = struct {
// Pre-allocate outputs // Pre-allocate outputs
const space = self.escapedSpace() orelse " "; const space = self.escapedSpace() orelse " ";
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len; const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len); var normalized: std.ArrayList(u8) = try .initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
errdefer normalized.deinit(); errdefer normalized.deinit(allocator);
var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity); var normalized_to_origin: std.ArrayList(usize) = try .initCapacity(allocator, normalized.capacity);
errdefer normalized_to_origin.deinit(); errdefer normalized_to_origin.deinit(allocator);
// If the spec asks for it, add a whitespace at the beginning. // If the spec asks for it, add a whitespace at the beginning.
if (self.flags.add_dummy_prefix) try addSlice(space, consumed, &normalized, &normalized_to_origin); if (self.flags.add_dummy_prefix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin);
var is_prev_space: bool = true; var is_prev_space: bool = true;
var is_prev_word: bool = false; var is_prev_word: bool = false;
@ -706,23 +706,23 @@ pub const Normalizer = struct {
var byte = slice[0]; var byte = slice[0];
if (self.escapedSpace() != null and byte == ' ') { if (self.escapedSpace() != null and byte == ' ') {
// replace the space token by the special token // replace the space token by the special token
try addSlice(space, origin, &normalized, &normalized_to_origin); try addSlice(allocator, space, origin, &normalized, &normalized_to_origin);
is_prev_word = false; is_prev_word = false;
break :ascii; break :ascii;
} else if (self.flags.split_on_punct_ascii) { } else if (self.flags.split_on_punct_ascii) {
if (is_prev_word and isPunct(slice)) { if (is_prev_word and isPunct(slice)) {
// Insert a space, but continue handling the rest // Insert a space, but continue handling the rest
try addSlice(space, origin, &normalized, &normalized_to_origin); try addSlice(allocator, space, origin, &normalized, &normalized_to_origin);
} }
} }
if (self.flags.lower_case_ascii) { if (self.flags.lower_case_ascii) {
byte = std.ascii.toLower(byte); byte = std.ascii.toLower(byte);
} }
try normalized.append(byte); try normalized.append(allocator, byte);
try normalized_to_origin.append(origin); try normalized_to_origin.append(allocator, origin);
} else { } else {
// we can safely copy to the output. // we can safely copy to the output.
try addSlice(slice, origin, &normalized, &normalized_to_origin); try addSlice(allocator, slice, origin, &normalized, &normalized_to_origin);
} }
is_prev_word = !is_prev_space and !isPunct(slice); is_prev_word = !is_prev_space and !isPunct(slice);
} }
@ -732,20 +732,20 @@ pub const Normalizer = struct {
while (std.mem.endsWith(u8, normalized.items, space)) { while (std.mem.endsWith(u8, normalized.items, space)) {
const length = normalized.items.len - space.len; const length = normalized.items.len - space.len;
consumed = normalized_to_origin.items[length]; consumed = normalized_to_origin.items[length];
try normalized.resize(length); try normalized.resize(allocator, length);
try normalized_to_origin.resize(length); try normalized_to_origin.resize(allocator, length);
} }
} }
try normalized_to_origin.append(consumed); try normalized_to_origin.append(allocator, consumed);
std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1); std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1);
if (self.flags.add_dummy_suffix) try addSlice(space, consumed, &normalized, &normalized_to_origin); if (self.flags.add_dummy_suffix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin);
return .{ return .{
.normalized = try normalized.toOwnedSlice(), .normalized_to_origin = try normalized_to_origin.toOwnedSlice(allocator),
.normalized_to_origin = try normalized_to_origin.toOwnedSlice(), .normalized = try normalized.toOwnedSlice(allocator),
}; };
} }
@ -1006,8 +1006,7 @@ pub const Gpt2TextDecoder = struct {
/// Transform bytes representing text under the gpt2 encoding, /// Transform bytes representing text under the gpt2 encoding,
/// and write to the `unicode` buffer utf-8 bytes. /// and write to the `unicode` buffer utf-8 bytes.
pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 { pub fn decode(self: Gpt2TextDecoder, bytes: []const u8, writer: *std.Io.Writer) (error{InvalidInput} || std.Io.Writer.Error)!void {
const start = unicode.items.len;
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes }; var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
while (it.nextCodepointSlice()) |codepoint| { while (it.nextCodepointSlice()) |codepoint| {
const code: Code = switch (codepoint.len) { const code: Code = switch (codepoint.len) {
@ -1016,9 +1015,8 @@ pub const Gpt2TextDecoder = struct {
else => return error.InvalidInput, else => return error.InvalidInput,
}; };
const byte = self.code_to_byte.get(code) orelse return error.InvalidInput; const byte = self.code_to_byte.get(code) orelse return error.InvalidInput;
try unicode.append(byte); try writer.writeByte(byte);
} }
return unicode.items[start..];
} }
inline fn isPrintableByte(c: u8) bool { inline fn isPrintableByte(c: u8) bool {
@ -1030,15 +1028,33 @@ test Gpt2TextDecoder {
var decoder = try Gpt2TextDecoder.init(testing.allocator); var decoder = try Gpt2TextDecoder.init(testing.allocator);
defer decoder.deinit(); defer decoder.deinit();
var out = std.array_list.Managed(u8).init(testing.allocator); var buf: [128]u8 = undefined;
defer out.deinit();
// Ascii is not changed. // Ascii is not changed.
try testing.expectEqualStrings("getTitle", try decoder.decode(&out, "getTitle")); {
var out: std.Io.Writer = .fixed(&buf);
try decoder.decode("getTitle", &out);
try testing.expectEqualStrings("getTitle", out.buffered());
}
{
var out: std.Io.Writer = .fixed(&buf);
try decoder.decode("Ċ", &out);
try testing.expectEqualStrings("\n", out.buffered());
}
// Leading space are represented with 'Ġ' // Leading space are represented with 'Ġ'
try testing.expectEqualStrings(" UINavigationController", try decoder.decode(&out, "ĠUINavigationController")); {
var out: std.Io.Writer = .fixed(&buf);
try decoder.decode("ĠUINavigationController", &out);
try testing.expectEqualStrings(" UINavigationController", out.buffered());
}
// Russian is wild // Russian is wild
try testing.expectEqualStrings(" работ", try decoder.decode(&out, "ĠÑĢабоÑĤ")); {
var out: std.Io.Writer = .fixed(&buf);
try decoder.decode("ĠÑĢабоÑĤ", &out);
try testing.expectEqualStrings(" работ", out.buffered());
}
} }
/// Open a json file in HF format and load the vocab from it. /// Open a json file in HF format and load the vocab from it.
@ -1077,12 +1093,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
// Buffer containing all concatenated tokens. // Buffer containing all concatenated tokens.
// Reserve a big chunk, to avoid grow event, but release over-allocated memory. // Reserve a big chunk, to avoid grow event, but release over-allocated memory.
var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len); var all_tokens: std.Io.Writer.Allocating = try .initCapacity(tokenizer.arena_state.allocator(), file_content.len);
const original_alloc = all_tokens.items.ptr; const original_alloc = all_tokens.writer.buffer.ptr;
// A re-alloc event here means we have invalidated all slices inside the tokenizer. // A re-alloc event here means we have invalidated all slices inside the tokenizer.
// If this is too annoying we could switch to a custom type instead of slices. // If this is too annoying we could switch to a custom type instead of slices.
defer { defer {
std.debug.assert(all_tokens.items.ptr == original_alloc); std.debug.assert(all_tokens.writer.buffer.ptr == original_alloc);
} }
// gpt2 based tokenizer got a special way of encoding unicode. // gpt2 based tokenizer got a special way of encoding unicode.
@ -1094,16 +1110,18 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
defer gpt2_decoder.deinit(); defer gpt2_decoder.deinit();
var it = vocab.iterator(); var it = vocab.iterator();
while (it.next()) |kv| { while (it.next()) |kv| {
const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| { const n = all_tokens.writer.buffered().len;
gpt2_decoder.decode(kv.key_ptr.*, &all_tokens.writer) catch |err| {
switch (err) { switch (err) {
error.InvalidInput => { error.InvalidInput => {
is_gpt2_vocab = false; is_gpt2_vocab = false;
break; break;
}, },
else => return err, error.WriteFailed => return error.OutOfMemory,
} }
}; };
const idx: u32 = @intCast(kv.value_ptr.*.integer); const idx: u32 = @intCast(kv.value_ptr.*.integer);
const token = all_tokens.written()[n..];
tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token); tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
} }
@ -1126,15 +1144,16 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
if (token_obj != .object) return error.InvalidFormat; if (token_obj != .object) return error.InvalidFormat;
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat; const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat); const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
const token = try if (is_gpt2_vocab) const n = all_tokens.written().len;
gpt2_decoder.decode(&all_tokens, v) try if (is_gpt2_vocab)
gpt2_decoder.decode(v, &all_tokens.writer)
else else
dup(&all_tokens, v); all_tokens.writer.writeAll(v);
const token = all_tokens.written()[n..];
tokenizer.addOwnedTokenByIndex(id, 0, token); tokenizer.addOwnedTokenByIndex(id, 0, token);
} }
// We won't add more tokens here, let release. // We won't add more tokens here, let release.
all_tokens.shrinkAndFree(all_tokens.items.len); _ = try all_tokens.toOwnedSlice();
var unk = tokenizer.lookup("<unk>"); var unk = tokenizer.lookup("<unk>");
if (objectGet(model, .integer, "unk_token")) |unk_tok| { if (objectGet(model, .integer, "unk_token")) |unk_tok| {
@ -1165,10 +1184,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
} }
/// Returns a copy of the given string, stored inside the given ArrayList. /// Returns a copy of the given string, stored inside the given ArrayList.
fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 { fn dup(allocating: *std.Io.Writer.Allocating, str: []const u8) std.Io.Writer.Error![]const u8 {
const n = buffer.items.len; const n = allocating.written().len;
try buffer.appendSlice(str); try allocating.writer.writeAll(str);
return buffer.items[n..]; return allocating.written()[n..];
} }
/// Returns the given entry in a json object only if it has the right type. /// Returns the given entry in a json object only if it has the right type.

View File

@ -48,11 +48,11 @@ pub fn asyncMain() !void {
} }
if (args.expected.len > 0) { if (args.expected.len > 0) {
var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len); var expected: std.ArrayList(u32) = try .initCapacity(allocator, args.prompt.len);
var it = std.mem.splitSequence(u8, args.expected, ","); var it = std.mem.splitSequence(u8, args.expected, ",");
while (it.next()) |int_token| { while (it.next()) |int_token| {
const tok = try std.fmt.parseInt(u32, int_token, 10); const tok = try std.fmt.parseInt(u32, int_token, 10);
try expected.append(tok); try expected.append(allocator, tok);
} }
if (!std.mem.eql(u32, expected.items, prompt_tok)) { if (!std.mem.eql(u32, expected.items, prompt_tok)) {
log.err("Doesn't match expected: {any}", .{expected.items}); log.err("Doesn't match expected: {any}", .{expected.items});