Add StableHLO bindings for versioning functions, enabling portable serialization of StableHLO.

This commit is contained in:
Tarry Singh 2023-02-22 15:41:33 +00:00
parent 8fa3878fc3
commit fc718ab649
2 changed files with 41 additions and 1 deletions

View File

@ -1154,3 +1154,43 @@ pub const RngAlgorithm = struct {
return std.meta.stringToEnum(Type, value) orelse unreachable;
}
};
pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 {
const Context = struct {
str: []const u8 = &.{},
};
var context = Context{};
c.stablehloVersionFromCompatibilityRequirement(requirement, (struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const inner_ctx: *Context = @ptrCast(@alignCast(userdata));
inner_ctx.str = mlir.fromStringRef(mlir_str);
}
}).callback, &context);
return context.str;
}
pub fn stablehloGetMinimumVersion(writer: anytype) void {
var context = .{ .writer = writer };
const WriterContext = @TypeOf(context);
c.stablehloGetMinimumVersion((struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &context);
}
pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void {
var context = .{ .writer = writer };
const WriterContext = @TypeOf(context);
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &context), error.InvalidMlirBytecodeVersion);
}

View File

@ -30,7 +30,7 @@ pub fn registerPasses(comptime passes: []const u8) void {
@field(c, "mlirRegister" ++ passes ++ "Passes")();
}
fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
return if (res.value == 0) err;
}