zml/pjrtx: prefer the built‑in stablehlo version when a plugin reports a newer version, ensuring artifact serialization uses the correct stablehlo version.
This commit is contained in:
parent
9505992e00
commit
c8c99d7d5a
@ -1261,6 +1261,23 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
|
||||
return state.call(requirement);
|
||||
}
|
||||
|
||||
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
|
||||
var buf: [32]u8 = undefined;
|
||||
|
||||
var stream = std.io.fixedBufferStream(&buf);
|
||||
var context = .{ .writer = stream.writer() };
|
||||
const WriterContext = @TypeOf(context);
|
||||
|
||||
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
|
||||
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||
}
|
||||
}).callback, &context);
|
||||
|
||||
return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2;
|
||||
}
|
||||
|
||||
pub fn getCurrentVersion() []const u8 {
|
||||
const state = struct {
|
||||
var buf: [32]u8 = undefined;
|
||||
|
||||
@ -98,8 +98,13 @@ pub const Client = opaque {
|
||||
defer serialized_buffer.deinit();
|
||||
|
||||
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
|
||||
var stablehlo_version_buf: [32]u8 = undefined;
|
||||
const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
||||
var requested_stablehlo_version_buf: [32]u8 = undefined;
|
||||
const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf);
|
||||
const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: {
|
||||
break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
|
||||
} else blk: {
|
||||
break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
||||
};
|
||||
|
||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
||||
log.err("failed to serialize to portable artifact: {}", .{err});
|
||||
|
||||
Loading…
Reference in New Issue
Block a user