Remove deprecated writer interface APIs from core ZML modules (async, MLIR, PJRT, runtime, fmt, aio, buffer, exe, hostbuffer, meta, mlirx).
This commit is contained in:
parent
090d7748d5
commit
3ed9bca5ad
@ -159,26 +159,21 @@ const Coro = struct {
|
|||||||
fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn {
|
fn runcoro(from: *base.Coro, this: *base.Coro) callconv(.c) noreturn {
|
||||||
const from_coro: *Coro = @fieldParentPtr("impl", from);
|
const from_coro: *Coro = @fieldParentPtr("impl", from);
|
||||||
const this_coro: *Coro = @fieldParentPtr("impl", this);
|
const this_coro: *Coro = @fieldParentPtr("impl", this);
|
||||||
log(.debug, "coro start {any}", .{this_coro.id});
|
log(.debug, "coro start {f}", .{this_coro.*});
|
||||||
@call(.auto, this_coro.func, .{});
|
@call(.auto, this_coro.func, .{});
|
||||||
this_coro.status = .Done;
|
this_coro.status = .Done;
|
||||||
thread_state.switchOut(from_coro);
|
thread_state.switchOut(from_coro);
|
||||||
|
|
||||||
// Never returns
|
// Never returns
|
||||||
stdx.debug.panic("Cannot resume an already completed coroutine {any}", .{this_coro.id});
|
stdx.debug.panic("Cannot resume an already completed coroutine {f}", .{this_coro.*});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getStorage(self: Coro, comptime T: type) *T {
|
pub fn getStorage(self: Coro, comptime T: type) *T {
|
||||||
return @ptrCast(@alignCast(self.storage));
|
return @ptrCast(@alignCast(self.storage));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: Coro, comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: Coro, writer: *std.Io.Writer) !void {
|
||||||
_ = fmt;
|
try writer.print("Coro{{.id = {any}, .status = {t}}}", .{ self.id, self.status });
|
||||||
_ = options;
|
|
||||||
try writer.print("Coro{{.id = {any}, .status = {s}}}", .{
|
|
||||||
self.id,
|
|
||||||
@tagName(self.status),
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -292,7 +287,7 @@ const ThreadState = struct {
|
|||||||
|
|
||||||
/// Called from resume
|
/// Called from resume
|
||||||
fn switchIn(self: *ThreadState, target: Frame) void {
|
fn switchIn(self: *ThreadState, target: Frame) void {
|
||||||
log(.debug, "coro resume {any} from {any}", .{ target.id, self.current().id });
|
log(.debug, "coro resume {f} from {f}", .{ target.id, self.current().id });
|
||||||
|
|
||||||
// Switch to target, setting this coro as the resumer.
|
// Switch to target, setting this coro as the resumer.
|
||||||
self.switchTo(target, true);
|
self.switchTo(target, true);
|
||||||
@ -307,7 +302,7 @@ const ThreadState = struct {
|
|||||||
|
|
||||||
/// Called from suspend
|
/// Called from suspend
|
||||||
fn switchOut(self: *ThreadState, target: Frame) void {
|
fn switchOut(self: *ThreadState, target: Frame) void {
|
||||||
log(.debug, "coro suspend {any} to {any}", .{ self.current().id, target.id });
|
log(.debug, "coro suspend {f} to {f}", .{ self.current().id, target.id });
|
||||||
self.switchTo(target, false);
|
self.switchTo(target, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -384,13 +379,8 @@ const CoroId = struct {
|
|||||||
self.invocation += 1;
|
self.invocation += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: @This(), comptime fmt: []const u8, options: std.fmt.FormatOptions, writer: anytype) !void {
|
pub fn format(self: @This(), writer: *std.Io.Writer) !void {
|
||||||
_ = fmt;
|
try writer.print("CoroId{{.cid={d}, .i={d}}}", .{ self.id.coro, self.invocation });
|
||||||
_ = options;
|
|
||||||
try writer.print("CoroId{{.cid={d}, .i={d}}}", .{
|
|
||||||
self.id.coro,
|
|
||||||
self.invocation,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|||||||
@ -750,13 +750,13 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
|
if (@intFromEnum(opts.api_version) < @intFromEnum(CustomCallOpts.ApiVersion.typed_ffi)) {
|
||||||
stdx.debug.assert(
|
stdx.debug.assert(
|
||||||
backend_config.isA(mlir.StringAttribute),
|
backend_config.isA(mlir.StringAttribute),
|
||||||
"API version < 4 requires a string as backend_config, got {}",
|
"API version < 4 requires a string as backend_config, got {f}",
|
||||||
.{backend_config},
|
.{backend_config},
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
stdx.debug.assert(
|
stdx.debug.assert(
|
||||||
backend_config.isA(mlir.DictionaryAttribute),
|
backend_config.isA(mlir.DictionaryAttribute),
|
||||||
"API version >= 4 requires a dictionary as backend_config, got {}",
|
"API version >= 4 requires a dictionary as backend_config, got {f}",
|
||||||
.{backend_config},
|
.{backend_config},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -1285,20 +1285,21 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
|
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
|
||||||
var buf: [32]u8 = undefined;
|
const Cmp = struct {
|
||||||
|
v1: []const u8,
|
||||||
|
v1_is_smaller: bool = undefined,
|
||||||
|
|
||||||
var stream = std.io.fixedBufferStream(&buf);
|
pub fn smallerCb(smaller_version: c.MlirStringRef, opaque_cmp: ?*anyopaque) callconv(.c) void {
|
||||||
var context = .{ .writer = stream.writer() };
|
var cmp: *@This() = @ptrCast(@alignCast(opaque_cmp));
|
||||||
const WriterContext = @TypeOf(context);
|
cmp.v1_is_smaller = std.mem.eql(u8, cmp.v1, smaller_version.data[0..smaller_version.length]);
|
||||||
|
|
||||||
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
|
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
|
||||||
}
|
}
|
||||||
}).callback, &context);
|
};
|
||||||
|
|
||||||
return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2;
|
var cmp_ctx: Cmp = .{ .v1 = version1 };
|
||||||
|
const cmp_res = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), Cmp.smallerCb, &cmp_ctx);
|
||||||
|
|
||||||
|
std.debug.assert(cmp_res.value != 0);
|
||||||
|
return if (cmp_ctx.v1_is_smaller) version1 else version2;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getCurrentVersion() []const u8 {
|
pub fn getCurrentVersion() []const u8 {
|
||||||
@ -1308,18 +1309,9 @@ pub fn getCurrentVersion() []const u8 {
|
|||||||
var once = std.once(call);
|
var once = std.once(call);
|
||||||
|
|
||||||
fn call() void {
|
fn call() void {
|
||||||
var stream = std.io.fixedBufferStream(&buf);
|
var writer: std.Io.Writer = .fixed(&buf);
|
||||||
var writer_ = stream.writer();
|
c.stablehloGetCurrentVersion(printCallbackNoFail, &writer);
|
||||||
const ContextWriter = @TypeOf(writer_);
|
str = writer.buffered();
|
||||||
|
|
||||||
c.stablehloGetCurrentVersion((struct {
|
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
|
||||||
const writer: *ContextWriter = @ptrCast(@alignCast(userdata));
|
|
||||||
_ = writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
|
||||||
}
|
|
||||||
}).callback, &writer_);
|
|
||||||
|
|
||||||
str = buf[0..stream.pos];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1334,18 +1326,9 @@ pub fn getMinimumVersion() []const u8 {
|
|||||||
var once = std.once(call);
|
var once = std.once(call);
|
||||||
|
|
||||||
fn call() void {
|
fn call() void {
|
||||||
var stream = std.io.fixedBufferStream(&buf);
|
var writer: std.Io.Writer = .fixed(&buf);
|
||||||
var context = .{ .writer = stream.writer() };
|
c.stablehloGetMinimumVersion(printCallbackNoFail, &writer);
|
||||||
const WriterContext = @TypeOf(context);
|
str = writer.buffered();
|
||||||
|
|
||||||
c.stablehloGetMinimumVersion((struct {
|
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
|
||||||
}
|
|
||||||
}).callback, &context);
|
|
||||||
|
|
||||||
str = buf[0..stream.pos];
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -1353,14 +1336,25 @@ pub fn getMinimumVersion() []const u8 {
|
|||||||
return state.str;
|
return state.str;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serializePortableArtifact(bytecode: []const u8, target_version: []const u8, writer: anytype) !void {
|
pub fn serializePortableArtifact(
|
||||||
var context = .{ .writer = writer };
|
bytecode: []const u8,
|
||||||
const WriterContext = @TypeOf(context);
|
target_version: []const u8,
|
||||||
|
writer: *std.Io.Writer,
|
||||||
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(bytecode), mlir.stringRef(target_version), (struct {
|
) error{ InvalidMlirBytecodeVersion, WriteFailed }!void {
|
||||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.c) void {
|
var writer_err: mlir.WriterWithErr = .{ .writer = writer };
|
||||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
try mlir.successOr(
|
||||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
c.stablehloSerializePortableArtifactFromStringRef(
|
||||||
}
|
mlir.stringRef(bytecode),
|
||||||
}).callback, &context), error.InvalidMlirBytecodeVersion);
|
mlir.stringRef(target_version),
|
||||||
|
mlir.WriterWithErr.printCallback,
|
||||||
|
&writer_err,
|
||||||
|
),
|
||||||
|
error.InvalidMlirBytecodeVersion,
|
||||||
|
);
|
||||||
|
return try writer_err.check();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn printCallbackNoFail(mlir_str: c.MlirStringRef, opaque_writer: ?*anyopaque) callconv(.c) void {
|
||||||
|
const writer: *std.Io.Writer = @ptrCast(@alignCast(opaque_writer));
|
||||||
|
writer.writeAll(mlir.fromStringRef(mlir_str)) catch @panic("Failed to write MLIR");
|
||||||
}
|
}
|
||||||
|
|||||||
320
mlir/mlir.zig
320
mlir/mlir.zig
@ -7,18 +7,20 @@ const stdx = @import("stdx");
|
|||||||
const log = std.log.scoped(.mlir);
|
const log = std.log.scoped(.mlir);
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDeclsRecursive(@This());
|
||||||
|
|
||||||
_ = try Context.init();
|
_ = try Context.init();
|
||||||
}
|
}
|
||||||
|
|
||||||
const Error = error{
|
pub const Error = error{
|
||||||
/// Invalid Mlir was created.
|
/// Invalid Mlir was created.
|
||||||
InvalidMlir,
|
InvalidMlir,
|
||||||
/// Another Mlir error. Check the log for more context.
|
/// Another Mlir error. Check the log for more context.
|
||||||
MlirUnexpected,
|
MlirUnexpected,
|
||||||
/// A resource/executor was not found.
|
/// A resource/executor was not found.
|
||||||
NotFound,
|
NotFound,
|
||||||
|
/// Bytecode version incompatibility.
|
||||||
|
InvalidMlirBytecodeVersion,
|
||||||
OutOfMemory,
|
OutOfMemory,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -35,64 +37,61 @@ pub fn registerPasses(comptime passes: []const u8) void {
|
|||||||
@field(c, "mlirRegister" ++ passes ++ "Passes")();
|
@field(c, "mlirRegister" ++ passes ++ "Passes")();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
pub fn successOr(res: c.MlirLogicalResult, err: anytype) @TypeOf(err)!void {
|
||||||
return if (res.value == 0) err;
|
return if (res.value == 0) err else {};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Alternative to MlirWrapperType
|
|
||||||
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;
|
|
||||||
|
|
||||||
pub const Registry = struct {
|
pub const Registry = struct {
|
||||||
_inner: c.MlirDialectRegistry,
|
_inner: c.MlirDialectRegistry,
|
||||||
|
|
||||||
pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy);
|
pub const deinit = helpers.deinit(Registry, c.mlirDialectRegistryDestroy);
|
||||||
|
|
||||||
pub fn init() !Registry {
|
pub fn init() !Registry {
|
||||||
return helpers.init(Registry, c.mlirDialectRegistryCreate(), c.mlirDialectRegistryIsNull) orelse Error.MlirUnexpected;
|
const registry = c.mlirDialectRegistryCreate();
|
||||||
|
return .{ ._inner = .{ .ptr = registry.ptr orelse return Error.MlirUnexpected } };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Context = struct {
|
pub const Context = struct {
|
||||||
_inner: c.MlirContext,
|
_inner: c.MlirContext,
|
||||||
const Self = Context;
|
|
||||||
pub const deinit = helpers.deinit(Context, c.mlirContextDestroy);
|
pub const deinit = helpers.deinit(Context, c.mlirContextDestroy);
|
||||||
pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull);
|
pub const wrapOr = helpers.wrapOr(Context, c.mlirContextIsNull);
|
||||||
|
|
||||||
pub fn init() !Self {
|
pub fn init() !Context {
|
||||||
return Self.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected;
|
return Context.wrapOr(c.mlirContextCreate()) orelse Error.MlirUnexpected;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Self {
|
pub fn initWithRegistry(registry: Registry, threadingEnabled: bool) !Context {
|
||||||
return Self.wrapOr(
|
return Context.wrapOr(
|
||||||
c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled),
|
c.mlirContextCreateWithRegistry(registry._inner, threadingEnabled),
|
||||||
) orelse Error.InvalidMlir;
|
) orelse Error.InvalidMlir;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn setMultiThreading(self: *Self, enabled: bool) void {
|
pub fn setMultiThreading(self: *Context, enabled: bool) void {
|
||||||
c.mlirContextEnableMultithreading(self._inner, enabled);
|
c.mlirContextEnableMultithreading(self._inner, enabled);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn appendDialectRegistry(self: *Self, registry: Registry) void {
|
pub fn appendDialectRegistry(self: *Context, registry: Registry) void {
|
||||||
c.mlirContextAppendDialectRegistry(self._inner, registry._inner);
|
c.mlirContextAppendDialectRegistry(self._inner, registry._inner);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn loadAllAvailableDialects(self: *Self) void {
|
pub fn loadAllAvailableDialects(self: *Context) void {
|
||||||
c.mlirContextLoadAllAvailableDialects(self._inner);
|
c.mlirContextLoadAllAvailableDialects(self._inner);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn numRegisteredDialects(self: Self) usize {
|
pub fn numRegisteredDialects(self: Context) usize {
|
||||||
return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner));
|
return @intCast(c.mlirContextGetNumRegisteredDialects(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn numLoadedDialects(self: Self) usize {
|
pub fn numLoadedDialects(self: Context) usize {
|
||||||
return @intCast(c.mlirContextGetNumLoadedDialects(self._inner));
|
return @intCast(c.mlirContextGetNumLoadedDialects(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn isRegisteredOperation(self: Self, op: [:0]const u8) bool {
|
pub fn isRegisteredOperation(self: Context, op: [:0]const u8) bool {
|
||||||
return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op));
|
return c.mlirContextIsRegisteredOperation(self._inner, stringRef(op));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn location(self: Self, src: std.builtin.SourceLocation) Location {
|
pub fn location(self: Context, src: std.builtin.SourceLocation) Location {
|
||||||
return Location.fromSrc(self, src);
|
return Location.fromSrc(self, src);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -103,9 +102,7 @@ pub const Module = struct {
|
|||||||
pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy);
|
pub const deinit = helpers.deinit(Module, c.mlirModuleDestroy);
|
||||||
pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull);
|
pub const wrapOr = helpers.wrapOr(Module, c.mlirModuleIsNull);
|
||||||
|
|
||||||
const Self = Module;
|
pub fn init(loc: Location) Module {
|
||||||
|
|
||||||
pub fn init(loc: Location) Self {
|
|
||||||
return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) };
|
return .{ ._inner = c.mlirModuleCreateEmpty(loc._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -142,29 +139,26 @@ pub const PassManager = struct {
|
|||||||
pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy);
|
pub const deinit = helpers.deinit(PassManager, c.mlirPassManagerDestroy);
|
||||||
pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull);
|
pub const wrapOr = helpers.wrapOr(PassManager, c.mlirPassManagerIsNull);
|
||||||
|
|
||||||
const Self = PassManager;
|
pub fn init(ctx: Context) !PassManager {
|
||||||
|
return PassManager.wrapOr(
|
||||||
pub fn init(ctx: Context) !Self {
|
|
||||||
return Self.wrapOr(
|
|
||||||
c.mlirPassManagerCreate(ctx._inner),
|
c.mlirPassManagerCreate(ctx._inner),
|
||||||
) orelse Error.MlirUnexpected;
|
) orelse Error.MlirUnexpected;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn initOnOperation(ctx: Context, op: [:0]const u8) !Self {
|
pub fn initOnOperation(ctx: Context, op: [:0]const u8) !PassManager {
|
||||||
return Self.wrapOr(
|
return PassManager.wrapOr(
|
||||||
c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)),
|
c.mlirPassManagerCreateOnOperation(ctx._inner, stringRef(op)),
|
||||||
) orelse Error.MlirUnexpected;
|
) orelse Error.MlirUnexpected;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn asOpPassManager(self: Self) OpPassManager {
|
pub fn asOpPassManager(self: PassManager) OpPassManager {
|
||||||
return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) };
|
return .{ ._inner = c.mlirPassManagerGetAsOpPassManager(self._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn enableIRPrinting(self: *Self) void {
|
// TODO mlirPassManagerEnableIRPrinting
|
||||||
c.mlirPassManagerEnableIRPrinting(self._inner);
|
// pub fn enableIRPrinting(self: *PassManager) void {}
|
||||||
}
|
|
||||||
|
|
||||||
pub fn runOnOp(self: *Self, op: Operation) error{InvalidMlir}!void {
|
pub fn runOnOp(self: *PassManager, op: Operation) error{InvalidMlir}!void {
|
||||||
if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) {
|
if (c.mlirPassManagerRunOnOp(self._inner, op._inner).value == 0) {
|
||||||
return Error.InvalidMlir;
|
return Error.InvalidMlir;
|
||||||
}
|
}
|
||||||
@ -193,21 +187,20 @@ pub const OpPassManager = struct {
|
|||||||
|
|
||||||
pub const Identifier = struct {
|
pub const Identifier = struct {
|
||||||
_inner: c.MlirIdentifier,
|
_inner: c.MlirIdentifier,
|
||||||
const Self = Identifier;
|
|
||||||
|
|
||||||
pub fn get(ctx: Context, str_: [:0]const u8) Self {
|
pub fn get(ctx: Context, str_: [:0]const u8) Identifier {
|
||||||
return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) };
|
return .{ ._inner = c.mlirIdentifierGet(ctx._inner, stringRef(str_)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn context(self: Self) Context {
|
pub fn context(self: Identifier) Context {
|
||||||
return .{ ._inner = c.mlirIdentifierGetContext(self._inner) };
|
return .{ ._inner = c.mlirIdentifierGetContext(self._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn str(self: Self) []const u8 {
|
pub fn str(self: Identifier) []const u8 {
|
||||||
return fromStringRef(c.mlirIdentifierStr(self._inner));
|
return fromStringRef(c.mlirIdentifierStr(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn equals(self: Self, other: Self) bool {
|
pub fn equals(self: Identifier, other: Identifier) bool {
|
||||||
return c.mlirIdentifierEqual(self._inner, other._inner);
|
return c.mlirIdentifierEqual(self._inner, other._inner);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -322,6 +315,14 @@ pub const Attribute = struct {
|
|||||||
|
|
||||||
return DictionaryAttribute.init(ctx, attrs).asAttr();
|
return DictionaryAttribute.init(ctx, attrs).asAttr();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn eqlAny(Attr: type) fn (Attr, Attr) bool {
|
||||||
|
return struct {
|
||||||
|
fn eql(a: Attr, b: Attr) bool {
|
||||||
|
return a.asAttr().eql(b.asAttr());
|
||||||
|
}
|
||||||
|
}.eql;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const NamedAttribute = extern struct {
|
pub const NamedAttribute = extern struct {
|
||||||
@ -349,15 +350,14 @@ pub const NamedAttribute = extern struct {
|
|||||||
pub const StringAttribute = struct {
|
pub const StringAttribute = struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
pub const is_a_fn = c.mlirAttributeIsAString;
|
pub const is_a_fn = c.mlirAttributeIsAString;
|
||||||
const Self = StringAttribute;
|
pub const asAttr = Attribute.fromAny(StringAttribute);
|
||||||
pub const asAttr = Attribute.fromAny(Self);
|
pub const eql = Attribute.eqlAny(StringAttribute);
|
||||||
pub const eql = Attribute.eqlAny(Self);
|
|
||||||
|
|
||||||
pub fn init(ctx: Context, str: []const u8) Self {
|
pub fn init(ctx: Context, str: []const u8) StringAttribute {
|
||||||
return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) };
|
return .{ ._inner = c.mlirStringAttrGet(ctx._inner, stringRef(str)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn value(self: Self) []const u8 {
|
pub fn value(self: StringAttribute) []const u8 {
|
||||||
return fromStringRef(c.mlirStringAttrGetValue(self._inner));
|
return fromStringRef(c.mlirStringAttrGetValue(self._inner));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -365,15 +365,14 @@ pub const StringAttribute = struct {
|
|||||||
pub const BoolAttribute = struct {
|
pub const BoolAttribute = struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
pub const is_a_fn = c.mlirAttributeIsABool;
|
pub const is_a_fn = c.mlirAttributeIsABool;
|
||||||
const Self = BoolAttribute;
|
pub const asAttr = Attribute.fromAny(BoolAttribute);
|
||||||
pub const asAttr = Attribute.fromAny(Self);
|
pub const eql = Attribute.eqlAny(BoolAttribute);
|
||||||
pub const eql = Attribute.eqlAny(Self);
|
|
||||||
|
|
||||||
pub fn init(ctx: Context, value_: bool) Self {
|
pub fn init(ctx: Context, value_: bool) BoolAttribute {
|
||||||
return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) };
|
return .{ ._inner = c.mlirBoolAttrGet(ctx._inner, if (value_) 1 else 0) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn value(self: Self) bool {
|
pub fn value(self: BoolAttribute) bool {
|
||||||
return c.mlirBoolAttrGetValue(self._inner);
|
return c.mlirBoolAttrGetValue(self._inner);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -397,19 +396,18 @@ pub const TypeAttribute = struct {
|
|||||||
pub const ArrayAttribute = struct {
|
pub const ArrayAttribute = struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
pub const is_a_fn = c.mlirAttributeIsAArray;
|
pub const is_a_fn = c.mlirAttributeIsAArray;
|
||||||
const Self = ArrayAttribute;
|
pub const asAttr = Attribute.fromAny(ArrayAttribute);
|
||||||
pub const asAttr = Attribute.fromAny(Self);
|
pub const eql = Attribute.eqlAny(ArrayAttribute);
|
||||||
pub const eql = Attribute.eqlAny(Self);
|
|
||||||
|
|
||||||
pub fn init(ctx: Context, attrs: []const Attribute) Self {
|
pub fn init(ctx: Context, attrs: []const Attribute) ArrayAttribute {
|
||||||
return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) };
|
return .{ ._inner = c.mlirArrayAttrGet(ctx._inner, @intCast(attrs.len), @ptrCast(attrs.ptr)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn size(self: Self) usize {
|
pub fn size(self: ArrayAttribute) usize {
|
||||||
return @intCast(c.mlirArrayAttrGetNumElements(self._inner));
|
return @intCast(c.mlirArrayAttrGetNumElements(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(self: Self, index: usize) Attribute {
|
pub fn get(self: ArrayAttribute, index: usize) Attribute {
|
||||||
return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) };
|
return .{ ._inner = c.mlirArrayAttrGetElement(self._inner, @intCast(index)) };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -590,13 +588,13 @@ pub fn DenseElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
|||||||
|
|
||||||
pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr {
|
pub fn init(shaped_type: Type, slice: []const dt.ZigType()) Attr {
|
||||||
const raw_bytes = std.mem.sliceAsBytes(slice);
|
const raw_bytes = std.mem.sliceAsBytes(slice);
|
||||||
const res: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet(
|
const attr: Attr = .{ ._inner = c.mlirDenseElementsAttrRawBufferGet(
|
||||||
shaped_type._inner,
|
shaped_type._inner,
|
||||||
@intCast(raw_bytes.len),
|
@intCast(raw_bytes.len),
|
||||||
@ptrCast(raw_bytes.ptr),
|
@ptrCast(raw_bytes.ptr),
|
||||||
) };
|
) };
|
||||||
std.debug.assert(Attribute.wrapOr(res._inner) != null);
|
std.debug.assert(attr._inner.ptr != null);
|
||||||
return res;
|
return attr;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn len(self: Attr) usize {
|
pub fn len(self: Attr) usize {
|
||||||
@ -717,7 +715,7 @@ pub const DictionaryAttribute = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute {
|
pub fn getByName(self: DictionaryAttribute, name: [:0]const u8) ?Attribute {
|
||||||
return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, name));
|
return Attribute.wrapOr(c.mlirDictionaryAttrGetElementByName(self._inner, stringRef(name)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -728,8 +726,7 @@ pub const Operation = struct {
|
|||||||
pub const dump = helpers.dump(Operation, c.mlirOperationDestroy);
|
pub const dump = helpers.dump(Operation, c.mlirOperationDestroy);
|
||||||
pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy);
|
pub const deinit = helpers.deinit(Operation, c.mlirOperationDestroy);
|
||||||
pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull);
|
pub const wrapOr = helpers.wrapOr(Operation, c.mlirOperationIsNull);
|
||||||
|
pub const eql = helpers.eql(Operation, c.mlirOperationEqual);
|
||||||
pub const eql = Attribute.eqlAny(Self);
|
|
||||||
|
|
||||||
pub fn init(state: *OperationState) !Self {
|
pub fn init(state: *OperationState) !Self {
|
||||||
return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir;
|
return Self.wrapOr(c.mlirOperationCreate(&state._inner)) orelse Error.InvalidMlir;
|
||||||
@ -881,52 +878,33 @@ pub const Operation = struct {
|
|||||||
return .{ ._inner = c.mlirOperationGetContext(self._inner) };
|
return .{ ._inner = c.mlirOperationGetContext(self._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writeBytecode(self: Self, writer: anytype) void {
|
pub fn writeBytecode(self: Self, writer: *std.Io.Writer) std.Io.Writer.Error!void {
|
||||||
var writer_context = .{ .writer = writer };
|
var writer_with_err: WriterWithErr = .{ .writer = writer };
|
||||||
const WriterContext = @TypeOf(writer_context);
|
|
||||||
|
|
||||||
c.mlirOperationWriteBytecode(
|
c.mlirOperationWriteBytecode(
|
||||||
self._inner,
|
self._inner,
|
||||||
(struct {
|
WriterWithErr.printCallback,
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
&writer_with_err,
|
||||||
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
|
||||||
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch unreachable;
|
|
||||||
}
|
|
||||||
}).callback,
|
|
||||||
&writer_context,
|
|
||||||
);
|
);
|
||||||
|
return writer_with_err.check();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn writeBytecodeWithConfig(self: Self, writer: anytype, config: struct {
|
pub fn writeBytecodeWithConfig(self: Self, writer: *std.Io.Writer, config: struct {
|
||||||
desiredEmitedVersion: ?i64 = null,
|
desiredEmitedVersion: ?i64 = null,
|
||||||
}) !void {
|
}) error{ InvalidMlirBytecodeVersion, WriteFailed }!void {
|
||||||
const cfg = c.mlirBytecodeWriterConfigCreate();
|
const cfg = c.mlirBytecodeWriterConfigCreate();
|
||||||
defer c.mlirBytecodeWriterConfigDestroy(cfg);
|
defer c.mlirBytecodeWriterConfigDestroy(cfg);
|
||||||
if (config.desiredEmitedVersion) |v| {
|
if (config.desiredEmitedVersion) |v| {
|
||||||
c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v);
|
c.mlirBytecodeWriterConfigDesiredEmitVersion(cfg, v);
|
||||||
}
|
}
|
||||||
|
|
||||||
const WriterContext = struct {
|
var writer_with_err: WriterWithErr = .{ .writer = writer };
|
||||||
writer: @TypeOf(writer),
|
|
||||||
write_error: ?@TypeOf(writer).Error = null,
|
|
||||||
};
|
|
||||||
var writer_context: WriterContext = .{ .writer = writer };
|
|
||||||
|
|
||||||
try successOr(c.mlirOperationWriteBytecodeWithConfig(
|
try successOr(c.mlirOperationWriteBytecodeWithConfig(
|
||||||
self._inner,
|
self._inner,
|
||||||
cfg,
|
cfg,
|
||||||
(struct {
|
&WriterWithErr.printCallback,
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
&writer_with_err,
|
||||||
const inner_writer_context: *WriterContext = @ptrCast(@alignCast(ctx_));
|
|
||||||
_ = inner_writer_context.writer.write(str.data[0..str.length]) catch |err| {
|
|
||||||
inner_writer_context.write_error = err;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}).callback,
|
|
||||||
&writer_context,
|
|
||||||
), error.InvalidMlirBytecodeVersion);
|
), error.InvalidMlirBytecodeVersion);
|
||||||
|
return writer_with_err.check();
|
||||||
if (writer_context.write_error) |err| return err;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Enable a full dump of the IR.
|
/// Enable a full dump of the IR.
|
||||||
@ -939,26 +917,18 @@ pub const Operation = struct {
|
|||||||
op: Operation,
|
op: Operation,
|
||||||
flags: OpPrintingFlags,
|
flags: OpPrintingFlags,
|
||||||
|
|
||||||
pub fn format(self: @This(), writer: anytype) !void {
|
pub fn format(self: MlirFormatter, writer: *std.Io.Writer) !void {
|
||||||
self.op.print(writer, self.flags);
|
try self.op.print(writer, self.flags);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) void {
|
pub fn print(self: Self, writer: *std.Io.Writer, flags: OpPrintingFlags) std.Io.Writer.Error!void {
|
||||||
const pflags = flags.create();
|
const pflags = flags.create();
|
||||||
defer c.mlirOpPrintingFlagsDestroy(pflags);
|
defer c.mlirOpPrintingFlagsDestroy(pflags);
|
||||||
|
|
||||||
c.mlirOperationPrintWithFlags(
|
var writer_err: WriterWithErr = .{ .writer = writer };
|
||||||
self._inner,
|
c.mlirOperationPrintWithFlags(self._inner, pflags, WriterWithErr.printCallback, &writer_err);
|
||||||
pflags,
|
return writer_err.check();
|
||||||
(struct {
|
|
||||||
pub fn callback(str: c.MlirStringRef, ctx_: ?*anyopaque) callconv(.c) void {
|
|
||||||
const _writer: *std.Io.Writer = @ptrCast(@alignCast(ctx_));
|
|
||||||
_writer.writeAll(str.data[0..str.length]) catch @panic("Mlir print failed");
|
|
||||||
}
|
|
||||||
}).callback,
|
|
||||||
writer,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify(self: Self) bool {
|
pub fn verify(self: Self) bool {
|
||||||
@ -1065,27 +1035,25 @@ pub const OpPrintingFlags = struct {
|
|||||||
|
|
||||||
pub const OpOperand = struct {
|
pub const OpOperand = struct {
|
||||||
_inner: c.MlirOpOperand,
|
_inner: c.MlirOpOperand,
|
||||||
const Self = OpOperand;
|
pub const wrapOr = helpers.wrapOr(OpOperand, c.mlirOpOperandIsNull);
|
||||||
|
|
||||||
pub fn owner(self: Self) Operation {
|
pub fn owner(self: OpOperand) Operation {
|
||||||
return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) };
|
return .{ ._inner = c.mlirOpOperandGetOwner(self._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn number(self: Self) usize {
|
pub fn number(self: OpOperand) usize {
|
||||||
return @intCast(c.mlirOpOperandGetOperandNumber(self._inner));
|
return @intCast(c.mlirOpOperandGetOperandNumber(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn nextUse(self: Self) ?Self {
|
pub fn nextUse(self: OpOperand) ?OpOperand {
|
||||||
return Self.wrapOr(
|
return wrapOr(c.mlirOpOperandGetNextUse(self._inner));
|
||||||
c.mlirOpOperandGetNextUse(self._inner),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Region = struct {
|
pub const Region = struct {
|
||||||
_inner: c.MlirRegion,
|
_inner: c.MlirRegion,
|
||||||
|
|
||||||
pub const eql = helpers.eql(Region, c.mlirBlockEqual);
|
pub const eql = helpers.eql(Region, c.mlirRegionEqual);
|
||||||
pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy);
|
pub const deinit = helpers.deinit(Region, c.mlirRegionDestroy);
|
||||||
pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull);
|
pub const wrapOr = helpers.wrapOr(Region, c.mlirRegionIsNull);
|
||||||
|
|
||||||
@ -1121,7 +1089,7 @@ pub const Value = struct {
|
|||||||
|
|
||||||
pub const dump = helpers.dump(Value, c.mlirValueDump);
|
pub const dump = helpers.dump(Value, c.mlirValueDump);
|
||||||
pub const eql = helpers.eql(Value, c.mlirValueEqual);
|
pub const eql = helpers.eql(Value, c.mlirValueEqual);
|
||||||
pub const format = helpers.format(Value, c.mlirValuePrint).format;
|
pub const format = helpers.format(Value, c.mlirValuePrint);
|
||||||
pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull);
|
pub const wrapOr = helpers.wrapOr(Value, c.mlirValueIsNull);
|
||||||
|
|
||||||
pub fn getType(val: Value) Type {
|
pub fn getType(val: Value) Type {
|
||||||
@ -1183,7 +1151,7 @@ pub const BlockArgument = struct {
|
|||||||
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
|
return @bitCast(c.mlirBlockArgumentGetArgNumber(arg._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: BlockArgument, writer: anytype) !void {
|
pub fn format(self: BlockArgument, writer: *std.Io.Writer) !void {
|
||||||
const value = Value{ ._inner = self._inner };
|
const value = Value{ ._inner = self._inner };
|
||||||
return value.format(writer);
|
return value.format(writer);
|
||||||
}
|
}
|
||||||
@ -1192,7 +1160,7 @@ pub const BlockArgument = struct {
|
|||||||
pub const Type = struct {
|
pub const Type = struct {
|
||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
|
|
||||||
pub const dump = helpers.eql(Type, c.mlirTypeDump);
|
pub const dump = helpers.dump(Type, c.mlirTypeDump);
|
||||||
pub const eql = helpers.eql(Type, c.mlirTypeEqual);
|
pub const eql = helpers.eql(Type, c.mlirTypeEqual);
|
||||||
pub const format = helpers.format(Type, c.mlirTypePrint);
|
pub const format = helpers.format(Type, c.mlirTypePrint);
|
||||||
pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull);
|
pub const wrapOr = helpers.wrapOr(Type, c.mlirTypeIsNull);
|
||||||
@ -1230,14 +1198,6 @@ pub const Type = struct {
|
|||||||
}.eql;
|
}.eql;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatAny(SpecificType: type) fn (SpecificType, SpecificType) type {
|
|
||||||
return struct {
|
|
||||||
pub fn format(self: SpecificType, writer: anytype) !void {
|
|
||||||
return try Type.format(self.asType(), writer);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn index(ctx: Context) Type {
|
pub fn index(ctx: Context) Type {
|
||||||
return IndexType.init(ctx).asType();
|
return IndexType.init(ctx).asType();
|
||||||
}
|
}
|
||||||
@ -1280,7 +1240,7 @@ pub const IndexType = struct {
|
|||||||
|
|
||||||
pub const asType = Type.fromAny(IndexType);
|
pub const asType = Type.fromAny(IndexType);
|
||||||
pub const eql = Type.eqlAny(IndexType);
|
pub const eql = Type.eqlAny(IndexType);
|
||||||
pub const format = Type.formatAny(IndexType).format;
|
pub const format = helpers.format(IndexType, c.mlirTypePrint);
|
||||||
|
|
||||||
pub fn init(ctx: Context) IndexType {
|
pub fn init(ctx: Context) IndexType {
|
||||||
return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) };
|
return .{ ._inner = c.mlirIndexTypeGet(ctx._inner) };
|
||||||
@ -1452,7 +1412,7 @@ pub fn ComplexType(comptime ct: ComplexTypes) type {
|
|||||||
|
|
||||||
pub const asType = Type.fromAny(Complex);
|
pub const asType = Type.fromAny(Complex);
|
||||||
pub const eql = Type.eqlAny(Complex);
|
pub const eql = Type.eqlAny(Complex);
|
||||||
pub const format = Type.formatAny(Complex).format;
|
pub const format = helpers.format(Complex, c.mlirTypePrint);
|
||||||
pub const ComplexTypeType: ComplexTypes = ct;
|
pub const ComplexTypeType: ComplexTypes = ct;
|
||||||
|
|
||||||
pub const init = if (ct != .unknown) struct {
|
pub const init = if (ct != .unknown) struct {
|
||||||
@ -1468,51 +1428,50 @@ pub const TupleType = struct {
|
|||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
pub const is_a_fn = c.mlirTypeIsATuple;
|
pub const is_a_fn = c.mlirTypeIsATuple;
|
||||||
|
|
||||||
const Self = TupleType;
|
pub fn init(ctx: Context, elements: []const Type) !TupleType {
|
||||||
|
const tuple_type = c.mlirTupleTypeGet(
|
||||||
pub fn init(ctx: Context, elements: []const Type) !Self {
|
|
||||||
return Self.wrapOr(c.mlirTupleTypeGet(
|
|
||||||
ctx._inner,
|
ctx._inner,
|
||||||
@intCast(elements.len),
|
@intCast(elements.len),
|
||||||
@ptrCast(elements.ptr),
|
@ptrCast(elements.ptr),
|
||||||
)) orelse Error.InvalidMlir;
|
);
|
||||||
|
return .{ ._inner = .{ .ptr = tuple_type.ptr orelse return error.InvalidMlir } };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getNumTypes(self: Self) usize {
|
pub fn getNumTypes(self: TupleType) usize {
|
||||||
return @intCast(c.mlirTupleTypeGetNumTypes(self._inner));
|
return @intCast(c.mlirTupleTypeGetNumTypes(self._inner));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getElementType(self: Self, index: usize) Type {
|
pub fn getElementType(self: TupleType, index: usize) Type {
|
||||||
return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) };
|
return .{ ._inner = c.mlirTupleTypeGetType(self._inner, @intCast(index)) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const asType = Type.fromAny(Self);
|
pub const asType = Type.fromAny(TupleType);
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const FunctionType = struct {
|
pub const FunctionType = struct {
|
||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
pub const is_a_fn = c.mlirTypeIsAFunction;
|
pub const is_a_fn = c.mlirTypeIsAFunction;
|
||||||
const Self = FunctionType;
|
pub const asType = Type.fromAny(FunctionType);
|
||||||
pub const asType = Type.fromAny(Self);
|
pub const eql = Type.eqlAny(FunctionType);
|
||||||
pub const eql = Type.eqlAny(IndexType);
|
|
||||||
|
|
||||||
pub fn init(ctx: Context, args: []const Type, results: []const Type) !Self {
|
pub fn init(ctx: Context, args: []const Type, results: []const Type) !FunctionType {
|
||||||
const func = Type.wrapOr(c.mlirFunctionTypeGet(
|
const func_type = c.mlirFunctionTypeGet(
|
||||||
ctx._inner,
|
ctx._inner,
|
||||||
@intCast(args.len),
|
@intCast(args.len),
|
||||||
@ptrCast(args.ptr),
|
@ptrCast(args.ptr),
|
||||||
@intCast(results.len),
|
@intCast(results.len),
|
||||||
@ptrCast(results.ptr),
|
@ptrCast(results.ptr),
|
||||||
)) orelse return Error.InvalidMlir;
|
);
|
||||||
return func.as(Self).?;
|
return .{ ._inner = .{ .ptr = func_type.ptr orelse return error.InvalidMlir } };
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const RankedTensorType = struct {
|
pub const RankedTensorType = struct {
|
||||||
_inner: c.MlirType,
|
_inner: c.MlirType,
|
||||||
pub const is_a_fn = c.mlirTypeIsARankedTensor;
|
pub const is_a_fn = c.mlirTypeIsARankedTensor;
|
||||||
|
pub const asType = Type.fromAny(RankedTensorType);
|
||||||
pub const eql = Type.eqlAny(RankedTensorType);
|
pub const eql = Type.eqlAny(RankedTensorType);
|
||||||
pub const format = helpers.format(Type, c.mlirTypePrint);
|
pub const format = helpers.format(RankedTensorType, c.mlirTypePrint);
|
||||||
|
|
||||||
pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType {
|
pub fn init(dimensions: []const i64, elemType: Type) RankedTensorType {
|
||||||
return .{ ._inner = c.mlirRankedTensorTypeGet(
|
return .{ ._inner = c.mlirRankedTensorTypeGet(
|
||||||
@ -1534,20 +1493,16 @@ pub const RankedTensorType = struct {
|
|||||||
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
pub fn getDimension(self: RankedTensorType, dim: usize) i64 {
|
||||||
return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim));
|
return c.mlirShapedTypeGetDimSize(self._inner, @intCast(dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const asType = Type.fromAny(RankedTensorType);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub const Dialect = struct {
|
pub const Dialect = struct {
|
||||||
_inner: c.MlirDialect,
|
_inner: c.MlirDialect,
|
||||||
|
|
||||||
const Self = Dialect;
|
pub fn getContext(self: Dialect) Context {
|
||||||
|
|
||||||
pub fn getContext(self: Self) Context {
|
|
||||||
return .{ ._inner = c.mlirDialectGetContext(self._inner) };
|
return .{ ._inner = c.mlirDialectGetContext(self._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getNamespace(self: Self) []const u8 {
|
pub fn getNamespace(self: Dialect) []const u8 {
|
||||||
return fromStringRef(c.mlirDialectGetNamespace(self._inner));
|
return fromStringRef(c.mlirDialectGetNamespace(self._inner));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1579,7 +1534,7 @@ pub const DialectHandle = struct {
|
|||||||
pub const Location = struct {
|
pub const Location = struct {
|
||||||
_inner: c.MlirLocation,
|
_inner: c.MlirLocation,
|
||||||
|
|
||||||
pub const eql = helpers.eql(Type, c.mlirLocationEqual);
|
pub const eql = helpers.eql(Location, c.mlirLocationEqual);
|
||||||
pub const format = helpers.format(Location, c.mlirLocationPrint);
|
pub const format = helpers.format(Location, c.mlirLocationPrint);
|
||||||
|
|
||||||
pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location {
|
pub fn fromSrc(ctx: Context, src: std.builtin.SourceLocation) Location {
|
||||||
@ -1613,17 +1568,17 @@ pub const Location = struct {
|
|||||||
) };
|
) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn named(loc: Location, ctx: Context, loc_name: [:0]const u8) Location {
|
pub fn named(loc: Location, ctx: Context, loc_name: []const u8) Location {
|
||||||
return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) };
|
return .{ ._inner = c.mlirLocationNameGet(ctx._inner, stringRef(loc_name), loc._inner) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location {
|
pub fn namedFmt(loc: Location, ctx: Context, comptime fmt: [:0]const u8, args: anytype) Location {
|
||||||
var buf: [256]u8 = undefined;
|
var buf: [256]u8 = undefined;
|
||||||
var stream = std.io.fixedBufferStream(&buf);
|
var writer: std.Io.Writer = .fixed(&buf);
|
||||||
std.fmt.format(stream.writer(), fmt, args) catch {
|
writer.print(fmt, args) catch {
|
||||||
buf[256 - 3 ..].* = "...".*;
|
buf[256 - 3 ..].* = "...".*;
|
||||||
};
|
};
|
||||||
return loc.named(ctx, @ptrCast(stream.getWritten()));
|
return loc.named(ctx, writer.buffered());
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn unknown(ctx: Context) Location {
|
pub fn unknown(ctx: Context) Location {
|
||||||
@ -1636,7 +1591,6 @@ pub const Block = struct {
|
|||||||
|
|
||||||
pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull);
|
pub const wrapOr = helpers.wrapOr(Block, c.mlirBlockIsNull);
|
||||||
pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy);
|
pub const deinit = helpers.deinit(Block, c.mlirBlockDestroy);
|
||||||
|
|
||||||
pub const eql = helpers.eql(Block, c.mlirBlockEqual);
|
pub const eql = helpers.eql(Block, c.mlirBlockEqual);
|
||||||
|
|
||||||
pub fn init(args: []const Type, locs: []const Location) !Block {
|
pub fn init(args: []const Type, locs: []const Location) !Block {
|
||||||
@ -1736,27 +1690,15 @@ pub const helpers = struct {
|
|||||||
}.isNull;
|
}.isNull;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(Any: type, print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void) type {
|
pub fn format(
|
||||||
|
Any: type,
|
||||||
|
print_fn: fn (@FieldType(Any, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void,
|
||||||
|
) fn (Any, *std.Io.Writer) std.Io.Writer.Error!void {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn format(self: Any, writer: *std.Io.Writer) !void {
|
pub fn format(self: Any, writer: *std.Io.Writer) std.Io.Writer.Error!void {
|
||||||
const WriterWithErr = struct {
|
try callPrintFn(Any, self, print_fn, writer);
|
||||||
writer: *std.Io.Writer,
|
|
||||||
err: ?std.Io.Writer.Error = null,
|
|
||||||
fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
|
|
||||||
var ctx: *@This() = @ptrCast(@alignCast(opaque_ctx));
|
|
||||||
if (ctx.err) |_| return;
|
|
||||||
_ = ctx.writer.write(mlir_str.data[0..mlir_str.length]) catch |err| {
|
|
||||||
ctx.err = err;
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
var context: WriterWithErr = .{ .writer = writer };
|
|
||||||
print_fn(self._inner, &WriterWithErr.printCallback, &context);
|
|
||||||
if (context.err) |err| return err;
|
|
||||||
}
|
}
|
||||||
};
|
}.format;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
|
pub fn wrapOr(T: type, is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) fn (@FieldType(T, "_inner")) ?T {
|
||||||
@ -1767,9 +1709,35 @@ pub const helpers = struct {
|
|||||||
}
|
}
|
||||||
}.wrapOr;
|
}.wrapOr;
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
pub fn init(T: type, inner: @FieldType(T, "_inner"), is_null_fn: fn (@FieldType(T, "_inner")) callconv(.c) bool) ?T {
|
pub const MlirStrCallback = fn (c.MlirStringRef, ?*anyopaque) callconv(.c) void;
|
||||||
if (is_null_fn(inner)) return null;
|
|
||||||
return .{ ._inner = inner };
|
pub fn callPrintFn(
|
||||||
|
T: type,
|
||||||
|
value: T,
|
||||||
|
print_fn: fn (@FieldType(T, "_inner"), ?*const MlirStrCallback, ?*anyopaque) callconv(.c) void,
|
||||||
|
writer: *std.Io.Writer,
|
||||||
|
) std.Io.Writer.Error!void {
|
||||||
|
var writer_with_err: WriterWithErr = .{ .writer = writer };
|
||||||
|
print_fn(value._inner, &WriterWithErr.printCallback, &writer_with_err);
|
||||||
|
return writer_with_err.check();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const WriterWithErr = struct {
|
||||||
|
writer: *std.Io.Writer,
|
||||||
|
err: ?std.Io.Writer.Error = null,
|
||||||
|
|
||||||
|
pub fn printCallback(mlir_str: c.MlirStringRef, opaque_ctx: ?*anyopaque) callconv(.c) void {
|
||||||
|
var ctx: *WriterWithErr = @ptrCast(@alignCast(opaque_ctx));
|
||||||
|
if (ctx.err) |_| return;
|
||||||
|
ctx.writer.writeAll(fromStringRef(mlir_str)) catch |err| {
|
||||||
|
ctx.err = err;
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check(self: WriterWithErr) std.Io.Writer.Error!void {
|
||||||
|
if (self.err) |err| return err;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
12
pjrt/ffi.zig
12
pjrt/ffi.zig
@ -265,16 +265,8 @@ pub const Buffer = extern struct {
|
|||||||
return self._dims[0..self.rank];
|
return self._dims[0..self.rank];
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(buffer: Buffer, writer: *std.Io.Writer) !void {
|
||||||
buffer: Buffer,
|
try writer.print("FfiBuffer({any}, .{t})@0x{x}", .{ buffer.dims(), buffer.dtype, @intFromPtr(buffer.data) });
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
|
|
||||||
try writer.print("FfiBuffer({d}, .{s})@0x{x}", .{ buffer.dims(), @tagName(buffer.dtype), @intFromPtr(buffer.data) });
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@ -1257,14 +1257,7 @@ pub const NamedValue = extern struct {
|
|||||||
}) };
|
}) };
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: NamedValue, writer: *std.Io.Writer) !void {
|
||||||
self: NamedValue,
|
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] });
|
try writer.print("{s}{{ .name = {s},", .{ @typeName(NamedValue), self.inner.name[0..self.inner.name_size] });
|
||||||
const u = self.inner.unnamed_0;
|
const u = self.inner.unnamed_0;
|
||||||
switch (self.kind()) {
|
switch (self.kind()) {
|
||||||
|
|||||||
@ -117,8 +117,8 @@ fn wrapNeffAsCustomCall(allocator: std.mem.Allocator, hlo_code: []const u8, neff
|
|||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
var operand_ids: std.ArrayListUnmanaged(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]);
|
var operand_ids: std.ArrayList(i64) = .initBuffer(c.xla_HloInstructionProto_resize_operand_ids(fused_root, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]);
|
||||||
var new_instructions: std.ArrayListUnmanaged(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]));
|
var new_instructions: std.ArrayList(*const c.xla_HloInstructionProto) = .initBuffer(@ptrCast(c.xla_HloComputationProto_resize_instructions(entry, parameters_len + 1, upb_arena)[0 .. parameters_len + 1]));
|
||||||
for (entry_instructions) |instruction| {
|
for (entry_instructions) |instruction| {
|
||||||
if (std.mem.eql(u8, upb.slice(c.xla_HloInstructionProto_opcode(instruction)) orelse continue, "parameter")) {
|
if (std.mem.eql(u8, upb.slice(c.xla_HloInstructionProto_opcode(instruction)) orelse continue, "parameter")) {
|
||||||
const id = c.xla_HloInstructionProto_id(instruction);
|
const id = c.xla_HloInstructionProto_id(instruction);
|
||||||
|
|||||||
@ -9,6 +9,10 @@ fn FmtSlice(T: type) type {
|
|||||||
return struct {
|
return struct {
|
||||||
slice: []const T,
|
slice: []const T,
|
||||||
|
|
||||||
|
pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
|
||||||
|
return try formatSliceAny(f.slice, .{}, writer);
|
||||||
|
}
|
||||||
|
|
||||||
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
|
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
|
||||||
|
|||||||
@ -315,14 +315,7 @@ pub const Metadata = union(enum) {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: Metadata, writer: *std.Io.Writer) !void {
|
||||||
self: Metadata,
|
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
switch (self) {
|
switch (self) {
|
||||||
.null => _ = try writer.write("null"),
|
.null => _ = try writer.write("null"),
|
||||||
inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
|
inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
const async = @import("async");
|
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
|
const async = @import("async");
|
||||||
|
|
||||||
const zml = @import("../zml.zig");
|
const zml = @import("../zml.zig");
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayList(u8);
|
||||||
const Allocator = std.mem.Allocator;
|
|
||||||
|
|
||||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
const file = try std.fs.cwd().openFile(path, .{});
|
const file = try std.fs.cwd().openFile(path, .{});
|
||||||
defer file.close();
|
defer file.close();
|
||||||
@ -26,7 +26,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
|
pub fn parseMetadata(allocator: std.mem.Allocator, store: *zml.aio.BufferStore, prefix: StringBuilder, val: std.json.Value) !void {
|
||||||
const metadata = &store._metadata;
|
const metadata = &store._metadata;
|
||||||
const key = prefix.items;
|
const key = prefix.items;
|
||||||
return switch (val) {
|
return switch (val) {
|
||||||
|
|||||||
@ -9,7 +9,6 @@ const zml = @import("../zml.zig");
|
|||||||
const HostBuffer = zml.HostBuffer;
|
const HostBuffer = zml.HostBuffer;
|
||||||
const json = @import("json.zig");
|
const json = @import("json.zig");
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
|
||||||
const log = std.log.scoped(.@"zml/io");
|
const log = std.log.scoped(.@"zml/io");
|
||||||
|
|
||||||
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||||
@ -17,19 +16,18 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
errdefer res.arena.deinit();
|
errdefer res.arena.deinit();
|
||||||
const arena = res.arena.allocator();
|
const arena = res.arena.allocator();
|
||||||
|
|
||||||
var files = std.array_list.Managed(MemoryMappedFile).init(arena);
|
var files: std.ArrayList(MemoryMappedFile) = .empty;
|
||||||
errdefer files.deinit();
|
|
||||||
|
|
||||||
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
|
if (std.mem.endsWith(u8, path, ".safetensors.index.json")) {
|
||||||
try loadFromIndex(arena, &res, &files, path);
|
try loadFromIndex(arena, &res, &files, path);
|
||||||
} else {
|
} else {
|
||||||
try loadFile(arena, &res, &files, path);
|
try loadFile(arena, &res, &files, path);
|
||||||
}
|
}
|
||||||
res.files = try files.toOwnedSlice();
|
res.files = try files.toOwnedSlice(allocator);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
|
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
|
||||||
const file = async.File.open(path, .{}) catch |err| {
|
const file = async.File.open(path, .{}) catch |err| {
|
||||||
log.err("Failed to open {s}: {}", .{ path, err });
|
log.err("Failed to open {s}: {}", .{ path, err });
|
||||||
return err;
|
return err;
|
||||||
@ -61,11 +59,11 @@ fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.
|
|||||||
|
|
||||||
if (index.object.get("__metadata__")) |metadata| {
|
if (index.object.get("__metadata__")) |metadata| {
|
||||||
var prefix_buf: [1024]u8 = undefined;
|
var prefix_buf: [1024]u8 = undefined;
|
||||||
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), metadata);
|
try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), metadata);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array_list.Managed(MemoryMappedFile), path: []const u8) !void {
|
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
|
||||||
const file = async.File.open(path, .{}) catch |err| {
|
const file = async.File.open(path, .{}) catch |err| {
|
||||||
log.err("Failed to open {s}: {}", .{ path, err });
|
log.err("Failed to open {s}: {}", .{ path, err });
|
||||||
return err;
|
return err;
|
||||||
@ -87,7 +85,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array
|
|||||||
errdefer buffer_file.deinit();
|
errdefer buffer_file.deinit();
|
||||||
buffer_file.data_offset = 8 + json_header_length;
|
buffer_file.data_offset = 8 + json_header_length;
|
||||||
|
|
||||||
try files.append(buffer_file);
|
try files.append(allocator, buffer_file);
|
||||||
errdefer _ = files.pop();
|
errdefer _ = files.pop();
|
||||||
|
|
||||||
var it = metadata.object.iterator();
|
var it = metadata.object.iterator();
|
||||||
@ -95,7 +93,7 @@ fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.array
|
|||||||
const key = entry.key_ptr.*;
|
const key = entry.key_ptr.*;
|
||||||
if (std.mem.eql(u8, key, "__metadata__")) {
|
if (std.mem.eql(u8, key, "__metadata__")) {
|
||||||
var prefix_buf: [1024]u8 = undefined;
|
var prefix_buf: [1024]u8 = undefined;
|
||||||
try json.parseMetadata(allocator, store, StringBuilder.initBuffer(&prefix_buf), entry.value_ptr.*);
|
try json.parseMetadata(allocator, store, .initBuffer(&prefix_buf), entry.value_ptr.*);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const val = entry.value_ptr.*;
|
const val = entry.value_ptr.*;
|
||||||
|
|||||||
@ -6,7 +6,6 @@ const zml = @import("../zml.zig");
|
|||||||
const eval = @import("torch/eval.zig");
|
const eval = @import("torch/eval.zig");
|
||||||
const File = @import("torch/file.zig").File;
|
const File = @import("torch/file.zig").File;
|
||||||
|
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
|
||||||
const log = std.log.scoped(.@"zml/aio");
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
|
|||||||
@ -348,7 +348,7 @@ pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bo
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
|
fn append(allocator: std.mem.Allocator, current: *[]py.Any, values: []const py.Any) !void {
|
||||||
var array_list = std.ArrayListUnmanaged(py.Any).fromOwnedSlice(current.*);
|
var array_list = std.ArrayList(py.Any).fromOwnedSlice(current.*);
|
||||||
try array_list.appendSlice(allocator, values);
|
try array_list.appendSlice(allocator, values);
|
||||||
current.* = array_list.items;
|
current.* = array_list.items;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -13,7 +13,7 @@ const py = @import("py.zig");
|
|||||||
const log = std.log.scoped(.@"zml/aio");
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
// TODO(cryptodeal): use zml.aio.PrefixBuilder instead
|
||||||
const StringBuilder = std.ArrayListUnmanaged(u8);
|
const StringBuilder = std.ArrayList(u8);
|
||||||
|
|
||||||
test {
|
test {
|
||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
@ -191,7 +191,7 @@ pub const File = struct {
|
|||||||
.boolval => bool,
|
.boolval => bool,
|
||||||
else => unreachable,
|
else => unreachable,
|
||||||
};
|
};
|
||||||
var values: std.ArrayListUnmanaged(ItemType) = .{};
|
var values: std.ArrayList(ItemType) = .{};
|
||||||
try values.append(allocator, val0);
|
try values.append(allocator, val0);
|
||||||
for (seq.values[1..], 1..) |val, i| {
|
for (seq.values[1..], 1..) |val, i| {
|
||||||
if (std.meta.activeTag(val) != tag) valid_slice = false;
|
if (std.meta.activeTag(val) != tag) valid_slice = false;
|
||||||
|
|||||||
@ -768,7 +768,7 @@ pub fn parse(arena: std.mem.Allocator, reader: *std.Io.Reader) ![]const Op {
|
|||||||
// It's not very efficient to interleave the results with the data copied from the stream,
|
// It's not very efficient to interleave the results with the data copied from the stream,
|
||||||
// because growth event in the results ArrayList will lead to fragmentation.
|
// because growth event in the results ArrayList will lead to fragmentation.
|
||||||
// Trying to mitigate that by using a generous default size.
|
// Trying to mitigate that by using a generous default size.
|
||||||
var results: std.ArrayListUnmanaged(Op) = try .initCapacity(arena, 512);
|
var results: std.ArrayList(Op) = try .initCapacity(arena, 512);
|
||||||
errdefer results.deinit(arena);
|
errdefer results.deinit(arena);
|
||||||
var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);
|
var alloc_writer = try std.Io.Writer.Allocating.initCapacity(arena, 512);
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const math = std.math;
|
const math = std.math;
|
||||||
const log = std.log.scoped(.@"zml/aio");
|
|
||||||
|
|
||||||
const pickle = @import("pickle.zig");
|
const pickle = @import("pickle.zig");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.@"zml/aio");
|
||||||
|
|
||||||
/// Correspond to a function/constructor call
|
/// Correspond to a function/constructor call
|
||||||
pub const Object = struct {
|
pub const Object = struct {
|
||||||
member: Any,
|
member: Any,
|
||||||
@ -206,12 +207,16 @@ pub const Any = union(Kind) {
|
|||||||
self.* = undefined;
|
self.* = undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn writeIndents(indents: usize, writer: anytype) !void {
|
inline fn writeIndents(indents: usize, writer: *std.Io.Writer) !void {
|
||||||
try writer.writeBytesNTimes(" ", indents); // resolve tab = 2 spaces
|
try writer.writeBytesNTimes(" ", indents); // resolve tab = 2 spaces
|
||||||
// try writer.writeByteNTimes('\t');
|
// try writer.writeByteNTimes('\t');
|
||||||
}
|
}
|
||||||
|
|
||||||
fn internalFormat(value: Any, indents: usize, writer: anytype) !void {
|
pub fn format(self: Any, writer: *std.Io.Writer) !void {
|
||||||
|
return internalFormat(self, 0, writer);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn internalFormat(value: Any, indents: usize, writer: *std.Io.Writer) !void {
|
||||||
try writeIndents(indents, writer);
|
try writeIndents(indents, writer);
|
||||||
try writer.writeAll(".{\n");
|
try writer.writeAll(".{\n");
|
||||||
try writeIndents(indents + 1, writer);
|
try writeIndents(indents + 1, writer);
|
||||||
@ -303,10 +308,6 @@ pub const Any = union(Kind) {
|
|||||||
try writer.writeByte('}');
|
try writer.writeByte('}');
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(self: Any, comptime _: []const u8, _: std.fmt.FormatOptions, writer: anytype) !void {
|
|
||||||
return internalFormat(self, 0, writer);
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn clone(self: Any, allocator: std.mem.Allocator) !Any {
|
pub fn clone(self: Any, allocator: std.mem.Allocator) !Any {
|
||||||
return switch (self) {
|
return switch (self) {
|
||||||
inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)),
|
inline .raw, .raw_num => |v, tag| @unionInit(Any, @tagName(tag), try v.clone(allocator)),
|
||||||
|
|||||||
@ -384,10 +384,7 @@ pub const Buffer = struct {
|
|||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: Buffer, writer: *std.Io.Writer) !void {
|
||||||
self: Buffer,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
try writer.print("Buffer({f})", .{self._shape});
|
try writer.print("Buffer({f})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -95,6 +95,8 @@ pub fn compileFn(
|
|||||||
) !FnExe(func) {
|
) !FnExe(func) {
|
||||||
var pretty_name = try prettyFnName(func, allocator);
|
var pretty_name = try prettyFnName(func, allocator);
|
||||||
defer pretty_name.deinit(allocator);
|
defer pretty_name.deinit(allocator);
|
||||||
|
log.info("Compiling {s} with {f}", .{ pretty_name.items, stdx.fmt.any(args) });
|
||||||
|
|
||||||
var context = try CompilationContext.init(allocator, pretty_name.items, platform);
|
var context = try CompilationContext.init(allocator, pretty_name.items, platform);
|
||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
|
|
||||||
@ -306,7 +308,7 @@ pub const BaseExe = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
pub fn serialize(self: BaseExe, writer: *std.Io.Writer) !void {
|
||||||
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
||||||
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||||
defer serialize_result.deinit();
|
defer serialize_result.deinit();
|
||||||
@ -377,7 +379,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
|||||||
try self.inner.bind(T, value);
|
try self.inner.bind(T, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn serialize(self: Self, writer: anytype) !void {
|
pub fn serialize(self: Self, writer: *std.Io.Writer) !void {
|
||||||
return try self.inner.serialize(writer);
|
return try self.inner.serialize(writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -437,7 +439,7 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
|
|||||||
fn prettyFnName(
|
fn prettyFnName(
|
||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
) !std.ArrayListUnmanaged(u8) {
|
) !std.ArrayList(u8) {
|
||||||
const full_noisy_name = @typeName(@TypeOf(func));
|
const full_noisy_name = @typeName(@TypeOf(func));
|
||||||
const og_len = full_noisy_name.len;
|
const og_len = full_noisy_name.len;
|
||||||
const buffer = try allocator.alloc(u8, og_len);
|
const buffer = try allocator.alloc(u8, og_len);
|
||||||
|
|||||||
@ -321,10 +321,7 @@ pub const HostBuffer = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: HostBuffer, writer: *std.Io.Writer) !void {
|
||||||
self: HostBuffer,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
try writer.print("HostBuffer(.{f})", .{self._shape});
|
try writer.print("HostBuffer(.{f})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
30
zml/meta.zig
30
zml/meta.zig
@ -679,28 +679,38 @@ test zip {
|
|||||||
|
|
||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
||||||
pub fn collect(func: anytype, func_ctx: _CollectCtx(func), out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT), obj: anytype) error{OutOfMemory}!void {
|
pub fn collectAlloc(
|
||||||
|
func: anytype,
|
||||||
|
func_ctx: _CollectCtx(func),
|
||||||
|
allocator: std.mem.Allocator,
|
||||||
|
obj: anytype,
|
||||||
|
) std.mem.Allocator.Error![]stdx.meta.FnSignature(func, null).ReturnT {
|
||||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collect expects a func with two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
|
||||||
|
const CollectAllocCtx = struct {
|
||||||
func_ctx: _CollectCtx(func),
|
func_ctx: _CollectCtx(func),
|
||||||
out: *std.array_list.Managed(stdx.meta.FnSignature(func, null).ReturnT),
|
allocator: std.mem.Allocator,
|
||||||
|
out: std.ArrayList(stdx.meta.FnSignature(func, null).ReturnT) = .empty,
|
||||||
oom: bool = false,
|
oom: bool = false,
|
||||||
};
|
|
||||||
var context = LocalContext{ .func_ctx = func_ctx, .out = out };
|
fn cb(ctx: *@This(), val: *const _CollectArg(func)) void {
|
||||||
visit((struct {
|
|
||||||
fn cb(ctx: *LocalContext, val: *const _CollectArg(func)) void {
|
|
||||||
if (ctx.oom) return;
|
if (ctx.oom) return;
|
||||||
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
|
const res = if (_CollectCtx(func) == void) func(val.*) else func(ctx.func_ctx, val.*);
|
||||||
ctx.out.append(res) catch {
|
ctx.out.append(ctx.allocator, res) catch {
|
||||||
ctx.oom = true;
|
ctx.oom = true;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}).cb, &context, obj);
|
};
|
||||||
|
var context = CollectAllocCtx{ .func_ctx = func_ctx, .allocator = allocator };
|
||||||
|
visit(CollectAllocCtx.cb, &context, obj);
|
||||||
if (context.oom) return error.OutOfMemory;
|
if (context.oom) return error.OutOfMemory;
|
||||||
|
|
||||||
|
return context.out.toOwnedSlice(allocator);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
/// Given a func(X) -> Y or a func(Ctx, X) -> Y,
|
||||||
/// finds all X in the given object, and write the result of func(X) into an arraylist.
|
/// finds all X in the given object, and write the result of func(X) into a given slice.
|
||||||
|
/// Asserts that the number of X found is equal to the slice len.
|
||||||
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
|
pub fn collectBuf(func: anytype, func_ctx: _CollectCtx(func), obj: anytype, out: []stdx.meta.FnResult(func)) void {
|
||||||
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
|
stdx.debug.assertComptime(@typeInfo(@TypeOf(func)).@"fn".params.len <= 2, "zml.meta.collectBuf expects a func with one or two arguments, got: {}", .{@TypeOf(func)});
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
|
|||||||
@ -96,6 +96,6 @@ pub const Type = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std.debug.panic("Could not convert mlir.Type to DataType: {}", .{mlir_type});
|
std.debug.panic("Could not convert mlir.Type to DataType: {f}", .{mlir_type});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -187,7 +187,8 @@ pub const CompilationContext = struct {
|
|||||||
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
|
if (cache_dir.createFile(mlir_name, .{ .truncate = true })) |file| {
|
||||||
var write_buf: [4096]u8 = undefined;
|
var write_buf: [4096]u8 = undefined;
|
||||||
var writer = file.writer(&write_buf);
|
var writer = file.writer(&write_buf);
|
||||||
module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
|
try module.op().print(&writer.interface, .{ .debug_info = true, .debug_info_pretty_form = false });
|
||||||
|
try writer.interface.flush();
|
||||||
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
|
log.info("Wrote MLIR to {s}/{s}", .{ module_dir.?, mlir_name });
|
||||||
} else |_| {
|
} else |_| {
|
||||||
log.warn("Failed to open {s}", .{mlir_name});
|
log.warn("Failed to open {s}", .{mlir_name});
|
||||||
@ -341,12 +342,11 @@ pub const CompilationContext = struct {
|
|||||||
const locations = try arena.alloc(mlir.Location, tensor_count);
|
const locations = try arena.alloc(mlir.Location, tensor_count);
|
||||||
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
@memset(locations, mlir.Location.unknown(mlir_ctx));
|
||||||
|
|
||||||
var input_shapes: std.array_list.Managed(Shape) = try .initCapacity(res_allocator, tensor_count);
|
const input_shapes = try res_allocator.alloc(Shape, tensor_count);
|
||||||
meta.collect(Tensor.shape, {}, &input_shapes, args) catch unreachable;
|
meta.collectBuf(Tensor.shape, {}, args, input_shapes);
|
||||||
stdx.debug.internalAssert(input_shapes.items.len == tensor_count, "args have changed ?", .{});
|
|
||||||
|
|
||||||
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
||||||
for (input_types, input_shapes.items) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
|
for (input_types, input_shapes) |*t, sh| t.* = mlirx.tensorType(mlir_ctx, sh);
|
||||||
|
|
||||||
const og_block_args = self._block_args;
|
const og_block_args = self._block_args;
|
||||||
defer {
|
defer {
|
||||||
@ -399,7 +399,7 @@ pub const CompilationContext = struct {
|
|||||||
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
||||||
self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
|
self.addOutputMemoryKindAttributes(res_attrs, fn_res_output_memory_kind);
|
||||||
if (self._platform.sharding().num_partitions > 1) {
|
if (self._platform.sharding().num_partitions > 1) {
|
||||||
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
|
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes, fn_res_shapes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -427,7 +427,7 @@ pub const CompilationContext = struct {
|
|||||||
return .{
|
return .{
|
||||||
.mlir_fn = mlir_fn,
|
.mlir_fn = mlir_fn,
|
||||||
.name = opts.name,
|
.name = opts.name,
|
||||||
.args_shapes = input_shapes.items,
|
.args_shapes = input_shapes,
|
||||||
.res_tensors = fn_res,
|
.res_tensors = fn_res,
|
||||||
.res_types = fn_res_types,
|
.res_types = fn_res_types,
|
||||||
.res_shapes = fn_res_shapes,
|
.res_shapes = fn_res_shapes,
|
||||||
@ -506,9 +506,9 @@ pub const CompilationContext = struct {
|
|||||||
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
||||||
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
const f = try comp.emitMlir(Local._fwd, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
||||||
|
|
||||||
var mlir_bytecode = std.array_list.Managed(u8).init(std.testing.allocator);
|
var mlir_code: std.Io.Writer.Allocating = .init(std.testing.allocator);
|
||||||
defer mlir_bytecode.deinit();
|
defer mlir_code.deinit();
|
||||||
try mlir_bytecode.writer().print("{f}", .{f.mlir_fn.mlirFormatter(.{})});
|
try f.mlir_fn.print(&mlir_code.writer, .{});
|
||||||
|
|
||||||
// Check that the `x` input argument gives its buffer to the result tensor.
|
// Check that the `x` input argument gives its buffer to the result tensor.
|
||||||
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
||||||
@ -518,8 +518,8 @@ pub const CompilationContext = struct {
|
|||||||
var buf = template.*;
|
var buf = template.*;
|
||||||
for (0..2) |i| {
|
for (0..2) |i| {
|
||||||
const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
|
const alias_attr = std.fmt.bufPrint(&buf, template, .{i}) catch unreachable;
|
||||||
std.testing.expect(std.mem.indexOf(u8, mlir_bytecode.items, alias_attr) != null) catch |err| {
|
std.testing.expect(std.mem.indexOf(u8, mlir_code.written(), alias_attr) != null) catch |err| {
|
||||||
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_bytecode.items});
|
log.warn("Didn't produced the expected IR:\n{s}", .{mlir_code.written()});
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -547,12 +547,14 @@ pub const CompilationContext = struct {
|
|||||||
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.Attribute {
|
||||||
const ctx = self.mlirCtx();
|
const ctx = self.mlirCtx();
|
||||||
const num_partitions = self.numPartitions();
|
const num_partitions = self.numPartitions();
|
||||||
var sharding_str: stdx.BoundedArray(u8, 128) = .{};
|
// This is big enough, see test below for examples values
|
||||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
var sharding_str_buf: [64]u8 = undefined;
|
||||||
return mlir.Attribute.string(ctx, sharding_str.constSlice());
|
var writer: std.Io.Writer = .fixed(&sharding_str_buf);
|
||||||
|
writeShardingRepresentation(shape, num_partitions, &writer) catch unreachable;
|
||||||
|
return mlir.Attribute.string(ctx, writer.buffered());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: anytype) @TypeOf(writer).Error!void {
|
fn writeShardingRepresentation(shape: Shape, num_partitions: u8, writer: *std.Io.Writer) std.Io.Writer.Error!void {
|
||||||
const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
|
const n_sharded: u8 = @popCount(@as(u8, @bitCast(shape._sharding_info)));
|
||||||
if (n_sharded == 0 or num_partitions == 1) {
|
if (n_sharded == 0 or num_partitions == 1) {
|
||||||
try writer.writeAll("{replicated}");
|
try writer.writeAll("{replicated}");
|
||||||
@ -567,26 +569,26 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test writeShardingRepresentation {
|
test writeShardingRepresentation {
|
||||||
var rule: [64]u8 = undefined;
|
var attr_buf: [64]u8 = undefined;
|
||||||
const x = Shape.init(.{ 16, 8 }, .f32);
|
const x = Shape.init(.{ 16, 8 }, .f32);
|
||||||
|
|
||||||
// By default tensors are replicated.
|
// By default tensors are replicated.
|
||||||
{
|
{
|
||||||
var fbs = std.io.fixedBufferStream(&rule);
|
var writer: std.Io.Writer = .fixed(&attr_buf);
|
||||||
try writeShardingRepresentation(x, 4, fbs.writer());
|
try writeShardingRepresentation(x, 4, &writer);
|
||||||
try std.testing.expectEqualStrings("{replicated}", fbs.getWritten());
|
try std.testing.expectEqualStrings("{replicated}", writer.buffered());
|
||||||
}
|
}
|
||||||
// Shard along first axis.
|
// Shard along first axis.
|
||||||
{
|
{
|
||||||
var fbs = std.io.fixedBufferStream(&rule);
|
var writer: std.Io.Writer = .fixed(&attr_buf);
|
||||||
try writeShardingRepresentation(x.withSharding(.{0}), 4, fbs.writer());
|
try writeShardingRepresentation(x.withSharding(.{0}), 4, &writer);
|
||||||
try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", fbs.getWritten());
|
try std.testing.expectEqualStrings("{devices=[4,1]<=[4]}", writer.buffered());
|
||||||
}
|
}
|
||||||
// Also shard along second axis.
|
// Also shard along second axis.
|
||||||
{
|
{
|
||||||
var fbs = std.io.fixedBufferStream(&rule);
|
var writer: std.Io.Writer = .fixed(&attr_buf);
|
||||||
try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, fbs.writer());
|
try writeShardingRepresentation(x.withSharding(.{ 0, 1 }), 2, &writer);
|
||||||
try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", fbs.getWritten());
|
try std.testing.expectEqualStrings("{devices=[2,2]<=[2]}", writer.buffered());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
46
zml/ops.zig
46
zml/ops.zig
@ -443,29 +443,23 @@ pub fn if_(
|
|||||||
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
const true_branch_block, const true_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &true_branch_fn, blkctx, {});
|
||||||
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
const false_branch_block, const false_branch_res = ctx.makeBlock(.open, TrueBlockSignature, &false_branch_fn, blkctx, {});
|
||||||
|
|
||||||
var true_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
|
check: {
|
||||||
defer true_shapes.deinit();
|
const arena = ctx.allocator();
|
||||||
var false_shapes = std.array_list.Managed(Shape).init(ctx.allocator());
|
const true_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &true_branch_res) catch break :check;
|
||||||
defer false_shapes.deinit();
|
defer arena.free(true_shapes);
|
||||||
|
|
||||||
var failed_to_collect = false;
|
const false_shapes = meta.collectAlloc(Tensor.shape, {}, arena, &false_branch_res) catch break :check;
|
||||||
meta.collect(Tensor.shape, {}, &true_shapes, &true_branch_res) catch {
|
defer arena.free(false_shapes);
|
||||||
failed_to_collect = true;
|
|
||||||
};
|
stdx.debug.assert(true_shapes.len == false_shapes.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) });
|
||||||
meta.collect(Tensor.shape, {}, &false_shapes, &false_branch_res) catch {
|
for (true_shapes, false_shapes) |true_shape, false_shape| {
|
||||||
failed_to_collect = true;
|
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {f}\n -false branch: {f}", .{ stdx.fmt.slice(true_shapes), stdx.fmt.slice(false_shapes) });
|
||||||
};
|
|
||||||
if (!failed_to_collect) {
|
|
||||||
stdx.debug.assert(true_shapes.items.len == false_shapes.items.len, "zml.ops.if_ expects the true and false branch to produce the same number of tensors. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
|
|
||||||
for (true_shapes.items, false_shapes.items) |true_shape, false_shape| {
|
|
||||||
stdx.debug.assert(true_shape.eqlWithTags(false_shape), "zml.ops.if_ expects the true and false branch to produce tensors of the same shape. Got: \n - true branch: {any}\n -false branch: {any}", .{ true_shapes.items, false_shapes.items });
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const scalar_pred = if (pred.rank() == 0) pred else pred.flattenAll().squeeze(0);
|
|
||||||
const loc = ctx.mlirCtx().location(@src());
|
const loc = ctx.mlirCtx().location(@src());
|
||||||
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
|
const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.if", .{
|
||||||
.operands = &.{scalar_pred.value()},
|
.operands = &.{pred.asScalar().value()},
|
||||||
.result_type_inference = true,
|
.result_type_inference = true,
|
||||||
.blocks = &.{ true_branch_block, false_branch_block },
|
.blocks = &.{ true_branch_block, false_branch_block },
|
||||||
// We can't verify right away, cause the weights captured by the if haven't been added yet.
|
// We can't verify right away, cause the weights captured by the if haven't been added yet.
|
||||||
@ -958,18 +952,20 @@ pub fn scatter(
|
|||||||
const UpdateS = BlockSign(update_fn);
|
const UpdateS = BlockSign(update_fn);
|
||||||
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
|
const update_block, _ = ctx.makeBlock(.hermetic, UpdateS, update_fn, blkctx, .{ _scalar, _scalar });
|
||||||
|
|
||||||
var input_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_inputs) catch @panic("OOM");
|
const arena = ctx.allocator();
|
||||||
defer input_values.deinit();
|
const input_values = arena.alloc(mlir.Value, n_inputs) catch @panic("OOM");
|
||||||
meta.collect(CompilationContext.getValue, ctx, &input_values, &inputs) catch unreachable;
|
defer arena.free(input_values);
|
||||||
var updates_values = std.array_list.Managed(mlir.Value).initCapacity(ctx.allocator(), n_updates) catch @panic("OOM");
|
meta.collectBuf(CompilationContext.getValue, ctx, &inputs, input_values);
|
||||||
defer updates_values.deinit();
|
|
||||||
meta.collect(CompilationContext.getValue, ctx, &updates_values, &updates) catch unreachable;
|
const updates_values = arena.alloc(mlir.Value, n_updates) catch @panic("OOM");
|
||||||
|
defer arena.free(updates_values);
|
||||||
|
meta.collectBuf(CompilationContext.getValue, ctx, &updates, updates_values);
|
||||||
|
|
||||||
const op = dialect.stablehlo.scatter(
|
const op = dialect.stablehlo.scatter(
|
||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
input_values.items,
|
input_values,
|
||||||
&.{indices.value()},
|
&.{indices.value()},
|
||||||
updates_values.items,
|
updates_values,
|
||||||
update_block,
|
update_block,
|
||||||
.{
|
.{
|
||||||
.update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(),
|
.update_window_dims = _collectAxes(AxisKind, config.up_kind, .update_window).constSlice(),
|
||||||
|
|||||||
@ -75,14 +75,17 @@ pub const Client = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
fn compileSync(self: *const Client, api: *const Api, allocator: std.mem.Allocator, module: mlir.Module, compile_options_pb: []const u8) CompileError!*LoadedExecutable {
|
||||||
var bytecode: std.array_list.Managed(u8) = .init(allocator);
|
var bytecode: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096);
|
||||||
defer bytecode.deinit();
|
defer bytecode.deinit();
|
||||||
module.op().writeBytecodeWithConfig(bytecode.writer(), .{ .desiredEmitedVersion = 1 }) catch |err| {
|
module.op().writeBytecodeWithConfig(&bytecode.writer, .{ .desiredEmitedVersion = 1 }) catch |err| {
|
||||||
log.err("failed to write module bytecode: {}", .{err});
|
log.err("failed to write module bytecode: {}", .{err});
|
||||||
return err;
|
return switch (err) {
|
||||||
|
std.Io.Writer.Error.WriteFailed => error.OutOfMemory,
|
||||||
|
else => |e| e,
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
var serialized_buffer: std.array_list.Managed(u8) = .init(allocator);
|
var serialized_buffer: std.Io.Writer.Allocating = try .initCapacity(allocator, 4096);
|
||||||
defer serialized_buffer.deinit();
|
defer serialized_buffer.deinit();
|
||||||
|
|
||||||
const stablehlo_version = blk: {
|
const stablehlo_version = blk: {
|
||||||
@ -92,13 +95,16 @@ pub const Client = opaque {
|
|||||||
break :blk dialects.stablehlo.getMinimumVersion();
|
break :blk dialects.stablehlo.getMinimumVersion();
|
||||||
};
|
};
|
||||||
|
|
||||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
dialects.stablehlo.serializePortableArtifact(bytecode.written(), stablehlo_version, &serialized_buffer.writer) catch |err| {
|
||||||
log.err("failed to serialize to portable artifact: {}", .{err});
|
log.err("failed to serialize to portable artifact: {}", .{err});
|
||||||
return err;
|
return switch (err) {
|
||||||
|
std.Io.Writer.Error.WriteFailed => error.OutOfMemory,
|
||||||
|
else => |e| e,
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
return @ptrCast(try self.inner().compile(api, .{
|
return @ptrCast(try self.inner().compile(api, .{
|
||||||
.bytecode = serialized_buffer.items,
|
.bytecode = serialized_buffer.written(),
|
||||||
.bytecode_format = .mlir,
|
.bytecode_format = .mlir,
|
||||||
.compile_options_pb = compile_options_pb,
|
.compile_options_pb = compile_options_pb,
|
||||||
}));
|
}));
|
||||||
|
|||||||
@ -95,7 +95,7 @@ const _CreateOptions = struct {
|
|||||||
pub const Cpu = struct {
|
pub const Cpu = struct {
|
||||||
device_count: u32,
|
device_count: u32,
|
||||||
|
|
||||||
fn writeNamedValues(self: Cpu, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
fn writeNamedValues(self: Cpu, values: *std.ArrayList(pjrt.NamedValue)) void {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
|
values.appendAssumeCapacity(pjrt.NamedValue.from("cpu_device_count", @as(i64, self.device_count)));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -124,7 +124,7 @@ const _CreateOptions = struct {
|
|||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
fn writeNamedValues(self: Cuda, values: *std.ArrayListUnmanaged(pjrt.NamedValue)) void {
|
fn writeNamedValues(self: Cuda, values: *std.ArrayList(pjrt.NamedValue)) void {
|
||||||
switch (self.allocator) {
|
switch (self.allocator) {
|
||||||
.platform => {
|
.platform => {
|
||||||
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
values.appendAssumeCapacity(pjrt.NamedValue.fromString("allocator", "platform"));
|
||||||
@ -145,7 +145,7 @@ const _CreateOptions = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
|
pub fn toNamedValues(self: _CreateOptions, target: Target, out: []pjrt.NamedValue) []pjrt.NamedValue {
|
||||||
var values = std.ArrayListUnmanaged(pjrt.NamedValue).fromOwnedSlice(out);
|
var values = std.ArrayList(pjrt.NamedValue).fromOwnedSlice(out);
|
||||||
values.shrinkRetainingCapacity(0);
|
values.shrinkRetainingCapacity(0);
|
||||||
switch (target) {
|
switch (target) {
|
||||||
.cpu => self.cpu.writeNamedValues(&values),
|
.cpu => self.cpu.writeNamedValues(&values),
|
||||||
|
|||||||
@ -386,14 +386,8 @@ pub const Shape = struct {
|
|||||||
/// Format the shape.
|
/// Format the shape.
|
||||||
/// Default format: "Shape({.a=10, .b=20}, dtype=.f32)"
|
/// Default format: "Shape({.a=10, .b=20}, dtype=.f32)"
|
||||||
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
|
/// Bare format {_}: "{.a=10, .b=20}, dtype=.f32"
|
||||||
pub fn format(
|
pub fn format(self: Shape, writer: *std.Io.Writer) !void {
|
||||||
self: Shape,
|
_ = try writer.writeByte('{');
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
// TODO: impl alternative format
|
|
||||||
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
|
||||||
const bare_fmt = true;
|
|
||||||
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
|
|
||||||
|
|
||||||
var need_comma = false;
|
var need_comma = false;
|
||||||
for (self.dims(), 0..) |d, i| {
|
for (self.dims(), 0..) |d, i| {
|
||||||
@ -411,7 +405,7 @@ pub const Shape = struct {
|
|||||||
}
|
}
|
||||||
if (need_comma) try writer.writeByte(',');
|
if (need_comma) try writer.writeByte(',');
|
||||||
_ = try writer.write(@tagName(self.dtype()));
|
_ = try writer.write(@tagName(self.dtype()));
|
||||||
_ = try writer.write(if (bare_fmt) "}" else "})");
|
_ = try writer.writeByte('}');
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Broadcasts a Tensor to the given shape, extending dimensions if needed.
|
/// Broadcasts a Tensor to the given shape, extending dimensions if needed.
|
||||||
|
|||||||
@ -52,10 +52,7 @@ pub const Tensor = struct {
|
|||||||
return CompilationContext.current();
|
return CompilationContext.current();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: Tensor, writer: *std.Io.Writer) !void {
|
||||||
self: Tensor,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
// TODO(0.15.0) handle format
|
// TODO(0.15.0) handle format
|
||||||
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
// const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||||
const bare_fmt = false;
|
const bare_fmt = false;
|
||||||
@ -1146,7 +1143,7 @@ pub const Tensor = struct {
|
|||||||
contracting_axes: []const [2]i8,
|
contracting_axes: []const [2]i8,
|
||||||
batching_axes: []const [2]i8,
|
batching_axes: []const [2]i8,
|
||||||
) Tensor {
|
) Tensor {
|
||||||
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
stdx.debug.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {f} and {f}", .{ lhs, rhs });
|
||||||
|
|
||||||
const Axes = stdx.BoundedArray(i64, MAX_RANK);
|
const Axes = stdx.BoundedArray(i64, MAX_RANK);
|
||||||
|
|
||||||
@ -1156,7 +1153,7 @@ pub const Tensor = struct {
|
|||||||
var rhs_batching_axes: Axes = .{};
|
var rhs_batching_axes: Axes = .{};
|
||||||
for (batching_axes) |b_axes| {
|
for (batching_axes) |b_axes| {
|
||||||
const l, const r = b_axes;
|
const l, const r = b_axes;
|
||||||
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {} and {} in {f} and {f}", .{ l, r, lhs, rhs });
|
stdx.debug.assert(lhs._shape.dim(l) == rhs._shape.dim(r), "dotGeneral expects batching dimensions to be equal, got {d} and {d} in {f} and {f}", .{ l, r, lhs, rhs });
|
||||||
var t = lhs._shape.tag(l);
|
var t = lhs._shape.tag(l);
|
||||||
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
if (t == Shape.TagUnknown) t = rhs._shape.tag(r);
|
||||||
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
res_shape = res_shape.appendDim(lhs._shape.dim(l), t);
|
||||||
@ -1521,14 +1518,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const to_the_end = std.math.maxInt(i64);
|
const to_the_end = std.math.maxInt(i64);
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(self: Slice, writer: *std.Io.Writer) !void {
|
||||||
self: Slice,
|
|
||||||
comptime fmt: []const u8,
|
|
||||||
options: std.fmt.FormatOptions,
|
|
||||||
writer: anytype,
|
|
||||||
) !void {
|
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
|
||||||
if (self.singleton) {
|
if (self.singleton) {
|
||||||
try writer.print("[{}]", .{self.start});
|
try writer.print("[{}]", .{self.start});
|
||||||
} else if (self.end == to_the_end and self.step == 1) {
|
} else if (self.end == to_the_end and self.step == 1) {
|
||||||
@ -2043,7 +2033,7 @@ pub const Tensor = struct {
|
|||||||
/// Converts the given 1 element Tensor into a 0-rank Tensor.
|
/// Converts the given 1 element Tensor into a 0-rank Tensor.
|
||||||
pub fn asScalar(self: Tensor) Tensor {
|
pub fn asScalar(self: Tensor) Tensor {
|
||||||
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self});
|
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {f}", .{self});
|
||||||
return self.reshape(.{});
|
return if (self.rank() == 0) self else self.reshape(.{});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const Pad = struct {
|
pub const Pad = struct {
|
||||||
@ -2582,7 +2572,7 @@ pub const Tensor = struct {
|
|||||||
/// that requires host<->device synchronization.
|
/// that requires host<->device synchronization.
|
||||||
/// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR.
|
/// ZML tries to generate the easiest to optimize IR, and will warn you if it generates known problematic IR.
|
||||||
pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor {
|
pub fn scatterSlices(self: Tensor, indices: anytype, updates: Tensor, opts: ScatterOpts) Tensor {
|
||||||
scoped_log.debug("scatterSlices({}, {any}, {})", .{ self, indices, updates });
|
scoped_log.debug("scatterSlices({f}, {any}, {f})", .{ self, indices, updates });
|
||||||
|
|
||||||
const UpdateType = @TypeOf(ScatterOpts.increment);
|
const UpdateType = @TypeOf(ScatterOpts.increment);
|
||||||
|
|
||||||
@ -3087,7 +3077,7 @@ pub const Tensor = struct {
|
|||||||
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
||||||
|
|
||||||
const allocator = self.getContext().allocator();
|
const allocator = self.getContext().allocator();
|
||||||
var chunks = std.ArrayListUnmanaged(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM");
|
var chunks = std.ArrayList(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM");
|
||||||
for (0..n_chunks) |i| {
|
for (0..n_chunks) |i| {
|
||||||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||||
chunks.appendAssumeCapacity(
|
chunks.appendAssumeCapacity(
|
||||||
|
|||||||
@ -229,7 +229,7 @@ pub fn testLayerOut(
|
|||||||
const FetchCtx = struct {
|
const FetchCtx = struct {
|
||||||
store: zml.aio.BufferStore,
|
store: zml.aio.BufferStore,
|
||||||
index: u32,
|
index: u32,
|
||||||
prefix: std.ArrayListUnmanaged(u8),
|
prefix: std.ArrayList(u8),
|
||||||
platform: zml.Platform,
|
platform: zml.Platform,
|
||||||
|
|
||||||
fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer {
|
fn fetch(ctx: *@This(), x: zml.Tensor) zml.Buffer {
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library")
|
load("@rules_zig//zig:defs.bzl", "zig_binary", "zig_library", "zig_test")
|
||||||
|
|
||||||
zig_library(
|
zig_library(
|
||||||
name = "tokenizer",
|
name = "tokenizer",
|
||||||
@ -25,3 +25,14 @@ zig_binary(
|
|||||||
"//stdx",
|
"//stdx",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
zig_test(
|
||||||
|
name = "test",
|
||||||
|
main = "homemade.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
"//async",
|
||||||
|
"//ffi:zig",
|
||||||
|
"//stdx",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@ -110,6 +110,7 @@ pub const HFTokenizer = opaque {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 {
|
pub fn tokenToId(self: *HFTokenizer, token: []const u8) ?u32 {
|
||||||
return c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
|
const id = c.hftokenizers_token_to_id(@ptrCast(self), ffi.ZigSlice.from(token));
|
||||||
|
return if (id == std.math.maxInt(u32)) null else id;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -189,10 +189,9 @@ pub const Tokenizer = struct {
|
|||||||
// Step by step visualization of the progress.
|
// Step by step visualization of the progress.
|
||||||
if (options.debug) {
|
if (options.debug) {
|
||||||
var _debug_buf: [256]u8 = undefined;
|
var _debug_buf: [256]u8 = undefined;
|
||||||
var _debug_alloc = std.heap.FixedBufferAllocator.init(&_debug_buf);
|
var debug_progress: std.Io.Writer = .fixed(&_debug_buf);
|
||||||
var debug_progress = std.array_list.Managed(u8).init(_debug_alloc.allocator());
|
self.decodeWithOpts(tok_buff[0..num_tokens], &debug_progress, .{ .sep = "|" }) catch {};
|
||||||
self.decodeWithOpts(&debug_progress, tok_buff[0..num_tokens], .{ .sep = "|" }) catch {};
|
log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.buffered() });
|
||||||
log.debug("tokens: {any} -> {s}", .{ tok_buff[0..num_tokens], debug_progress.items });
|
|
||||||
}
|
}
|
||||||
var best_score: f32 = -1e10;
|
var best_score: f32 = -1e10;
|
||||||
var best_token: u32 = 0;
|
var best_token: u32 = 0;
|
||||||
@ -311,22 +310,19 @@ pub const Tokenizer = struct {
|
|||||||
/// Converts the given slice of tokens back into bytes.
|
/// Converts the given slice of tokens back into bytes.
|
||||||
/// Note that if the tokenizer allows sub-unicode bytes, it's possible
|
/// Note that if the tokenizer allows sub-unicode bytes, it's possible
|
||||||
/// the output is not valid utf8.
|
/// the output is not valid utf8.
|
||||||
pub fn decode(self: *const Tokenizer, allocator: std.mem.Allocator, input: []const u32) error{OutOfMemory}![]u8 {
|
pub fn decode(self: *const Tokenizer, input: []const u32, writer: *std.Io.Writer) std.Io.Writer.Error!void {
|
||||||
var output = std.array_list.Managed(u8).init(allocator);
|
try self.decodeWithOpts(input, writer, .{});
|
||||||
errdefer output.deinit();
|
|
||||||
|
|
||||||
try self.decodeWithOpts(&output, input, .{});
|
|
||||||
return output.toOwnedSlice();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn decodeWithOpts(
|
pub fn decodeWithOpts(
|
||||||
self: *const Tokenizer,
|
self: *const Tokenizer,
|
||||||
output: *std.array_list.Managed(u8),
|
|
||||||
input: []const u32,
|
input: []const u32,
|
||||||
|
output: *std.Io.Writer,
|
||||||
opts: struct { sep: []const u8 = "" },
|
opts: struct { sep: []const u8 = "" },
|
||||||
) error{OutOfMemory}!void {
|
) std.Io.Writer.Error!void {
|
||||||
const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
|
const escaped = if (self.normalizer) |n| n.escapedSpace() else null;
|
||||||
// Flag used to indicate if the first dummy whitespace has been consumed.
|
// Flag used to indicate if the first dummy whitespace has been consumed.
|
||||||
|
var first_token: bool = true;
|
||||||
for (input) |id| {
|
for (input) |id| {
|
||||||
// Retrieve the slice corresponding to the id.
|
// Retrieve the slice corresponding to the id.
|
||||||
var piece = self.lookupPiece(id);
|
var piece = self.lookupPiece(id);
|
||||||
@ -337,12 +333,13 @@ pub const Tokenizer = struct {
|
|||||||
while (std.mem.startsWith(u8, piece, escspc)) {
|
while (std.mem.startsWith(u8, piece, escspc)) {
|
||||||
piece = piece[escspc.len..];
|
piece = piece[escspc.len..];
|
||||||
// don't output a space at beginning of text.
|
// don't output a space at beginning of text.
|
||||||
if (output.items.len > 0) try output.append(' ');
|
if (!first_token) try output.writeByte(' ');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try output.appendSlice(piece);
|
try output.writeAll(piece);
|
||||||
if (opts.sep.len > 0) try output.appendSlice(opts.sep);
|
first_token = false;
|
||||||
|
try output.writeAll(opts.sep);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -472,10 +469,13 @@ pub const Decoder = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
|
pub fn decode(self: *Decoder, ids: []const u32) ![]const u8 {
|
||||||
|
// TODO: revisite tokenizer api to use std.Io.Writer.
|
||||||
self.reset();
|
self.reset();
|
||||||
const res = try self.inner.decode(self.arena.allocator(), ids);
|
// Reuse the same warmup than in init, to maximize contiguous writes.
|
||||||
self.current_string = res;
|
var arena_writer = std.Io.Writer.Allocating.initCapacity(self.arena.allocator(), 4096) catch unreachable;
|
||||||
return res;
|
try self.inner.decode(ids, &arena_writer.writer);
|
||||||
|
self.current_string = arena_writer.written();
|
||||||
|
return self.current_string.?;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn string(self: *const Decoder) []const u8 {
|
pub fn string(self: *const Decoder) []const u8 {
|
||||||
@ -623,9 +623,9 @@ pub const Normalizer = struct {
|
|||||||
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
|
return if (self._whitespace.len > 1) self._whitespace.constSlice() else null;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn addSlice(data: []const u8, consumed: usize, normalized: *std.array_list.Managed(u8), normalized_to_origin: *std.array_list.Managed(usize)) !void {
|
fn addSlice(allocator: std.mem.Allocator, data: []const u8, consumed: usize, normalized: *std.ArrayList(u8), normalized_to_origin: *std.ArrayList(usize)) !void {
|
||||||
try normalized.appendSlice(data);
|
try normalized.appendSlice(allocator, data);
|
||||||
for (data) |_| try normalized_to_origin.append(consumed);
|
for (data) |_| try normalized_to_origin.append(allocator, consumed);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const Result = struct {
|
pub const Result = struct {
|
||||||
@ -673,13 +673,13 @@ pub const Normalizer = struct {
|
|||||||
// Pre-allocate outputs
|
// Pre-allocate outputs
|
||||||
const space = self.escapedSpace() orelse " ";
|
const space = self.escapedSpace() orelse " ";
|
||||||
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
|
const overhead = if (self.flags.split_on_punct_ascii) space.len + 1 else space.len;
|
||||||
var normalized = try std.array_list.Managed(u8).initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
|
var normalized: std.ArrayList(u8) = try .initCapacity(allocator, trimmed_input.len * overhead + 2 * space.len);
|
||||||
errdefer normalized.deinit();
|
errdefer normalized.deinit(allocator);
|
||||||
var normalized_to_origin = try std.array_list.Managed(usize).initCapacity(allocator, normalized.capacity);
|
var normalized_to_origin: std.ArrayList(usize) = try .initCapacity(allocator, normalized.capacity);
|
||||||
errdefer normalized_to_origin.deinit();
|
errdefer normalized_to_origin.deinit(allocator);
|
||||||
|
|
||||||
// If the spec asks for it, add a whitespace at the beginning.
|
// If the spec asks for it, add a whitespace at the beginning.
|
||||||
if (self.flags.add_dummy_prefix) try addSlice(space, consumed, &normalized, &normalized_to_origin);
|
if (self.flags.add_dummy_prefix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin);
|
||||||
|
|
||||||
var is_prev_space: bool = true;
|
var is_prev_space: bool = true;
|
||||||
var is_prev_word: bool = false;
|
var is_prev_word: bool = false;
|
||||||
@ -706,23 +706,23 @@ pub const Normalizer = struct {
|
|||||||
var byte = slice[0];
|
var byte = slice[0];
|
||||||
if (self.escapedSpace() != null and byte == ' ') {
|
if (self.escapedSpace() != null and byte == ' ') {
|
||||||
// replace the space token by the special token
|
// replace the space token by the special token
|
||||||
try addSlice(space, origin, &normalized, &normalized_to_origin);
|
try addSlice(allocator, space, origin, &normalized, &normalized_to_origin);
|
||||||
is_prev_word = false;
|
is_prev_word = false;
|
||||||
break :ascii;
|
break :ascii;
|
||||||
} else if (self.flags.split_on_punct_ascii) {
|
} else if (self.flags.split_on_punct_ascii) {
|
||||||
if (is_prev_word and isPunct(slice)) {
|
if (is_prev_word and isPunct(slice)) {
|
||||||
// Insert a space, but continue handling the rest
|
// Insert a space, but continue handling the rest
|
||||||
try addSlice(space, origin, &normalized, &normalized_to_origin);
|
try addSlice(allocator, space, origin, &normalized, &normalized_to_origin);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (self.flags.lower_case_ascii) {
|
if (self.flags.lower_case_ascii) {
|
||||||
byte = std.ascii.toLower(byte);
|
byte = std.ascii.toLower(byte);
|
||||||
}
|
}
|
||||||
try normalized.append(byte);
|
try normalized.append(allocator, byte);
|
||||||
try normalized_to_origin.append(origin);
|
try normalized_to_origin.append(allocator, origin);
|
||||||
} else {
|
} else {
|
||||||
// we can safely copy to the output.
|
// we can safely copy to the output.
|
||||||
try addSlice(slice, origin, &normalized, &normalized_to_origin);
|
try addSlice(allocator, slice, origin, &normalized, &normalized_to_origin);
|
||||||
}
|
}
|
||||||
is_prev_word = !is_prev_space and !isPunct(slice);
|
is_prev_word = !is_prev_space and !isPunct(slice);
|
||||||
}
|
}
|
||||||
@ -732,20 +732,20 @@ pub const Normalizer = struct {
|
|||||||
while (std.mem.endsWith(u8, normalized.items, space)) {
|
while (std.mem.endsWith(u8, normalized.items, space)) {
|
||||||
const length = normalized.items.len - space.len;
|
const length = normalized.items.len - space.len;
|
||||||
consumed = normalized_to_origin.items[length];
|
consumed = normalized_to_origin.items[length];
|
||||||
try normalized.resize(length);
|
try normalized.resize(allocator, length);
|
||||||
try normalized_to_origin.resize(length);
|
try normalized_to_origin.resize(allocator, length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try normalized_to_origin.append(consumed);
|
try normalized_to_origin.append(allocator, consumed);
|
||||||
|
|
||||||
std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1);
|
std.debug.assert(normalized_to_origin.items.len == normalized.items.len + 1);
|
||||||
|
|
||||||
if (self.flags.add_dummy_suffix) try addSlice(space, consumed, &normalized, &normalized_to_origin);
|
if (self.flags.add_dummy_suffix) try addSlice(allocator, space, consumed, &normalized, &normalized_to_origin);
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
.normalized = try normalized.toOwnedSlice(),
|
.normalized_to_origin = try normalized_to_origin.toOwnedSlice(allocator),
|
||||||
.normalized_to_origin = try normalized_to_origin.toOwnedSlice(),
|
.normalized = try normalized.toOwnedSlice(allocator),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1006,8 +1006,7 @@ pub const Gpt2TextDecoder = struct {
|
|||||||
|
|
||||||
/// Transform bytes representing text under the gpt2 encoding,
|
/// Transform bytes representing text under the gpt2 encoding,
|
||||||
/// and write to the `unicode` buffer utf-8 bytes.
|
/// and write to the `unicode` buffer utf-8 bytes.
|
||||||
pub fn decode(self: Gpt2TextDecoder, unicode: *std.array_list.Managed(u8), bytes: []const u8) ![]const u8 {
|
pub fn decode(self: Gpt2TextDecoder, bytes: []const u8, writer: *std.Io.Writer) (error{InvalidInput} || std.Io.Writer.Error)!void {
|
||||||
const start = unicode.items.len;
|
|
||||||
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
|
var it = std.unicode.Utf8Iterator{ .i = 0, .bytes = bytes };
|
||||||
while (it.nextCodepointSlice()) |codepoint| {
|
while (it.nextCodepointSlice()) |codepoint| {
|
||||||
const code: Code = switch (codepoint.len) {
|
const code: Code = switch (codepoint.len) {
|
||||||
@ -1016,9 +1015,8 @@ pub const Gpt2TextDecoder = struct {
|
|||||||
else => return error.InvalidInput,
|
else => return error.InvalidInput,
|
||||||
};
|
};
|
||||||
const byte = self.code_to_byte.get(code) orelse return error.InvalidInput;
|
const byte = self.code_to_byte.get(code) orelse return error.InvalidInput;
|
||||||
try unicode.append(byte);
|
try writer.writeByte(byte);
|
||||||
}
|
}
|
||||||
return unicode.items[start..];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inline fn isPrintableByte(c: u8) bool {
|
inline fn isPrintableByte(c: u8) bool {
|
||||||
@ -1030,15 +1028,33 @@ test Gpt2TextDecoder {
|
|||||||
var decoder = try Gpt2TextDecoder.init(testing.allocator);
|
var decoder = try Gpt2TextDecoder.init(testing.allocator);
|
||||||
defer decoder.deinit();
|
defer decoder.deinit();
|
||||||
|
|
||||||
var out = std.array_list.Managed(u8).init(testing.allocator);
|
var buf: [128]u8 = undefined;
|
||||||
defer out.deinit();
|
|
||||||
|
|
||||||
// Ascii is not changed.
|
// Ascii is not changed.
|
||||||
try testing.expectEqualStrings("getTitle", try decoder.decode(&out, "getTitle"));
|
{
|
||||||
|
var out: std.Io.Writer = .fixed(&buf);
|
||||||
|
try decoder.decode("getTitle", &out);
|
||||||
|
try testing.expectEqualStrings("getTitle", out.buffered());
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
var out: std.Io.Writer = .fixed(&buf);
|
||||||
|
try decoder.decode("Ċ", &out);
|
||||||
|
try testing.expectEqualStrings("\n", out.buffered());
|
||||||
|
}
|
||||||
|
|
||||||
// Leading space are represented with 'Ġ'
|
// Leading space are represented with 'Ġ'
|
||||||
try testing.expectEqualStrings(" UINavigationController", try decoder.decode(&out, "ĠUINavigationController"));
|
{
|
||||||
|
var out: std.Io.Writer = .fixed(&buf);
|
||||||
|
try decoder.decode("ĠUINavigationController", &out);
|
||||||
|
try testing.expectEqualStrings(" UINavigationController", out.buffered());
|
||||||
|
}
|
||||||
// Russian is wild
|
// Russian is wild
|
||||||
try testing.expectEqualStrings(" работ", try decoder.decode(&out, "ĠÑĢабоÑĤ"));
|
{
|
||||||
|
var out: std.Io.Writer = .fixed(&buf);
|
||||||
|
try decoder.decode("ĠÑĢабоÑĤ", &out);
|
||||||
|
try testing.expectEqualStrings(" работ", out.buffered());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Open a json file in HF format and load the vocab from it.
|
/// Open a json file in HF format and load the vocab from it.
|
||||||
@ -1077,12 +1093,12 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
|
|
||||||
// Buffer containing all concatenated tokens.
|
// Buffer containing all concatenated tokens.
|
||||||
// Reserve a big chunk, to avoid grow event, but release over-allocated memory.
|
// Reserve a big chunk, to avoid grow event, but release over-allocated memory.
|
||||||
var all_tokens = try std.array_list.Managed(u8).initCapacity(tokenizer.arena_state.allocator(), file_content.len);
|
var all_tokens: std.Io.Writer.Allocating = try .initCapacity(tokenizer.arena_state.allocator(), file_content.len);
|
||||||
const original_alloc = all_tokens.items.ptr;
|
const original_alloc = all_tokens.writer.buffer.ptr;
|
||||||
// A re-alloc event here means we have invalidated all slices inside the tokenizer.
|
// A re-alloc event here means we have invalidated all slices inside the tokenizer.
|
||||||
// If this is too annoying we could switch to a custom type instead of slices.
|
// If this is too annoying we could switch to a custom type instead of slices.
|
||||||
defer {
|
defer {
|
||||||
std.debug.assert(all_tokens.items.ptr == original_alloc);
|
std.debug.assert(all_tokens.writer.buffer.ptr == original_alloc);
|
||||||
}
|
}
|
||||||
|
|
||||||
// gpt2 based tokenizer got a special way of encoding unicode.
|
// gpt2 based tokenizer got a special way of encoding unicode.
|
||||||
@ -1094,16 +1110,18 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
defer gpt2_decoder.deinit();
|
defer gpt2_decoder.deinit();
|
||||||
var it = vocab.iterator();
|
var it = vocab.iterator();
|
||||||
while (it.next()) |kv| {
|
while (it.next()) |kv| {
|
||||||
const token = gpt2_decoder.decode(&all_tokens, kv.key_ptr.*) catch |err| {
|
const n = all_tokens.writer.buffered().len;
|
||||||
|
gpt2_decoder.decode(kv.key_ptr.*, &all_tokens.writer) catch |err| {
|
||||||
switch (err) {
|
switch (err) {
|
||||||
error.InvalidInput => {
|
error.InvalidInput => {
|
||||||
is_gpt2_vocab = false;
|
is_gpt2_vocab = false;
|
||||||
break;
|
break;
|
||||||
},
|
},
|
||||||
else => return err,
|
error.WriteFailed => return error.OutOfMemory,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
const idx: u32 = @intCast(kv.value_ptr.*.integer);
|
const idx: u32 = @intCast(kv.value_ptr.*.integer);
|
||||||
|
const token = all_tokens.written()[n..];
|
||||||
tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
|
tokenizer.addOwnedTokenByIndex(idx, @floatFromInt(vocab_size - idx), token);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1126,15 +1144,16 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
if (token_obj != .object) return error.InvalidFormat;
|
if (token_obj != .object) return error.InvalidFormat;
|
||||||
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
|
const v = objectGet(token_obj.object, .string, "content") orelse return error.InvalidFormat;
|
||||||
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
|
const id: u32 = @intCast(objectGet(token_obj.object, .integer, "id") orelse return error.InvalidFormat);
|
||||||
const token = try if (is_gpt2_vocab)
|
const n = all_tokens.written().len;
|
||||||
gpt2_decoder.decode(&all_tokens, v)
|
try if (is_gpt2_vocab)
|
||||||
|
gpt2_decoder.decode(v, &all_tokens.writer)
|
||||||
else
|
else
|
||||||
dup(&all_tokens, v);
|
all_tokens.writer.writeAll(v);
|
||||||
|
const token = all_tokens.written()[n..];
|
||||||
tokenizer.addOwnedTokenByIndex(id, 0, token);
|
tokenizer.addOwnedTokenByIndex(id, 0, token);
|
||||||
}
|
}
|
||||||
// We won't add more tokens here, let release.
|
// We won't add more tokens here, let release.
|
||||||
all_tokens.shrinkAndFree(all_tokens.items.len);
|
_ = try all_tokens.toOwnedSlice();
|
||||||
|
|
||||||
var unk = tokenizer.lookup("<unk>");
|
var unk = tokenizer.lookup("<unk>");
|
||||||
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
|
if (objectGet(model, .integer, "unk_token")) |unk_tok| {
|
||||||
@ -1165,10 +1184,10 @@ pub fn fromHfJson(allocator: std.mem.Allocator, tokenizer_path: []const u8) !Tok
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a copy of the given string, stored inside the given ArrayList.
|
/// Returns a copy of the given string, stored inside the given ArrayList.
|
||||||
fn dup(buffer: *std.array_list.Managed(u8), str: []const u8) ![]const u8 {
|
fn dup(allocating: *std.Io.Writer.Allocating, str: []const u8) std.Io.Writer.Error![]const u8 {
|
||||||
const n = buffer.items.len;
|
const n = allocating.written().len;
|
||||||
try buffer.appendSlice(str);
|
try allocating.writer.writeAll(str);
|
||||||
return buffer.items[n..];
|
return allocating.written()[n..];
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the given entry in a json object only if it has the right type.
|
/// Returns the given entry in a json object only if it has the right type.
|
||||||
|
|||||||
@ -48,11 +48,11 @@ pub fn asyncMain() !void {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (args.expected.len > 0) {
|
if (args.expected.len > 0) {
|
||||||
var expected = try std.array_list.Managed(u32).initCapacity(allocator, args.prompt.len);
|
var expected: std.ArrayList(u32) = try .initCapacity(allocator, args.prompt.len);
|
||||||
var it = std.mem.splitSequence(u8, args.expected, ",");
|
var it = std.mem.splitSequence(u8, args.expected, ",");
|
||||||
while (it.next()) |int_token| {
|
while (it.next()) |int_token| {
|
||||||
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
const tok = try std.fmt.parseInt(u32, int_token, 10);
|
||||||
try expected.append(tok);
|
try expected.append(allocator, tok);
|
||||||
}
|
}
|
||||||
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
if (!std.mem.eql(u32, expected.items, prompt_tok)) {
|
||||||
log.err("Doesn't match expected: {any}", .{expected.items});
|
log.err("Doesn't match expected: {any}", .{expected.items});
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user