Add StableHLO bindings for versioning functions, enabling portable serialization of StableHLO.
This commit is contained in:
parent
8fa3878fc3
commit
fc718ab649
@ -1154,3 +1154,43 @@ pub const RngAlgorithm = struct {
|
|||||||
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
return std.meta.stringToEnum(Type, value) orelse unreachable;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 {
|
||||||
|
const Context = struct {
|
||||||
|
str: []const u8 = &.{},
|
||||||
|
};
|
||||||
|
var context = Context{};
|
||||||
|
|
||||||
|
c.stablehloVersionFromCompatibilityRequirement(requirement, (struct {
|
||||||
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||||
|
const inner_ctx: *Context = @ptrCast(@alignCast(userdata));
|
||||||
|
inner_ctx.str = mlir.fromStringRef(mlir_str);
|
||||||
|
}
|
||||||
|
}).callback, &context);
|
||||||
|
|
||||||
|
return context.str;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn stablehloGetMinimumVersion(writer: anytype) void {
|
||||||
|
var context = .{ .writer = writer };
|
||||||
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
|
c.stablehloGetMinimumVersion((struct {
|
||||||
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||||
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
|
}
|
||||||
|
}).callback, &context);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void {
|
||||||
|
var context = .{ .writer = writer };
|
||||||
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
|
try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct {
|
||||||
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||||
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
|
}
|
||||||
|
}).callback, &context), error.InvalidMlirBytecodeVersion);
|
||||||
|
}
|
||||||
|
|||||||
@ -30,7 +30,7 @@ pub fn registerPasses(comptime passes: []const u8) void {
|
|||||||
@field(c, "mlirRegister" ++ passes ++ "Passes")();
|
@field(c, "mlirRegister" ++ passes ++ "Passes")();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void {
|
||||||
return if (res.value == 0) err;
|
return if (res.value == 0) err;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user