zml/pjrtx: prefer the built‑in stablehlo version when a plugin reports a newer version, ensuring artifact serialization uses the correct stablehlo version.
This commit is contained in:
parent
9505992e00
commit
c8c99d7d5a
@ -1261,6 +1261,23 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
|
|||||||
return state.call(requirement);
|
return state.call(requirement);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
|
||||||
|
var buf: [32]u8 = undefined;
|
||||||
|
|
||||||
|
var stream = std.io.fixedBufferStream(&buf);
|
||||||
|
var context = .{ .writer = stream.writer() };
|
||||||
|
const WriterContext = @TypeOf(context);
|
||||||
|
|
||||||
|
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
|
||||||
|
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
|
||||||
|
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
|
||||||
|
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
|
||||||
|
}
|
||||||
|
}).callback, &context);
|
||||||
|
|
||||||
|
return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn getCurrentVersion() []const u8 {
|
pub fn getCurrentVersion() []const u8 {
|
||||||
const state = struct {
|
const state = struct {
|
||||||
var buf: [32]u8 = undefined;
|
var buf: [32]u8 = undefined;
|
||||||
|
|||||||
@ -98,8 +98,13 @@ pub const Client = opaque {
|
|||||||
defer serialized_buffer.deinit();
|
defer serialized_buffer.deinit();
|
||||||
|
|
||||||
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
|
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
|
||||||
var stablehlo_version_buf: [32]u8 = undefined;
|
var requested_stablehlo_version_buf: [32]u8 = undefined;
|
||||||
const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf);
|
||||||
|
const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: {
|
||||||
|
break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
|
||||||
|
} else blk: {
|
||||||
|
break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
|
||||||
|
};
|
||||||
|
|
||||||
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
|
||||||
log.err("failed to serialize to portable artifact: {}", .{err});
|
log.err("failed to serialize to portable artifact: {}", .{err});
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user