Add StableHLO bindings for versioning functions, enabling portable serialization of StableHLO.
This commit is contained in:
parent
8fa3878fc3
commit
fc718ab649
@ -1154,3 +1154,43 @@ pub const RngAlgorithm = struct {
|
||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||
}
|
||||
};
|
||||
|
||||
pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 {
|
||||
const Context = struct {
|
||||
str: []const u8 = &.{},
|
||||
};
|
||||
var context = Context{};
|
||||
|
||||
c.stablehloVersionFromCompatibilityRequirement(requirement, (struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *Context = @ptrCast(@alignCast(userdata));
|
||||
inner_ctx.str = mlir.fromStringRef(mlir_str);
|
||||
}
|
||||
}).callback, &context);
|
||||
|
||||
return context.str;
|
||||
}
|
||||
|
||||
pub fn stablehloGetMinimumVersion(writer: anytype) void {
|
||||
var context = .{ .writer = writer };
|
||||
const WriterContext = @TypeOf(context);
|
||||
|
||||
c.stablehloGetMinimumVersion((struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||
}
|
||||
}).callback, &context);
|
||||
}
|
||||
|
||||
pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void {
|
||||
var context = .{ .writer = writer };
|
||||
const WriterContext = @TypeOf(context);
|
||||
|
||||
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||
}
|
||||
}).callback, &context), error.InvalidMlirBytecodeVersion);
|
||||
}
|
||||
|
||||
@ -30,7 +30,7 @@ pub fn registerPasses(comptime passes: []const u8) void {
|
||||
@field(c, "mlirRegister" ++ passes ++ "Passes")();
|
||||
}
|
||||
|
||||
fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
||||
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
||||
return if (res.value == 0) err;
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user