zml/pjrtx: prefer the built‑in stablehlo version when a plugin reports a newer version, ensuring artifact serialization uses the correct stablehlo version.

This commit is contained in:
Tarry Singh 2023-09-07 17:06:19 +00:00
parent 9505992e00
commit c8c99d7d5a
2 changed files with 24 additions and 2 deletions

View File

@ -1261,6 +1261,23 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo
return state.call(requirement); return state.call(requirement);
} }
pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 {
var buf: [32]u8 = undefined;
var stream = std.io.fixedBufferStream(&buf);
var context = .{ .writer = stream.writer() };
const WriterContext = @TypeOf(context);
_ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct {
pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void {
const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata));
_ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable;
}
}).callback, &context);
return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2;
}
pub fn getCurrentVersion() []const u8 { pub fn getCurrentVersion() []const u8 {
const state = struct { const state = struct {
var buf: [32]u8 = undefined; var buf: [32]u8 = undefined;

View File

@ -98,8 +98,13 @@ pub const Client = opaque {
defer serialized_buffer.deinit(); defer serialized_buffer.deinit();
// spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399 // spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399
var stablehlo_version_buf: [32]u8 = undefined; var requested_stablehlo_version_buf: [32]u8 = undefined;
const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12); const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf);
const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: {
break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion());
} else blk: {
break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12);
};
dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| { dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| {
log.err("failed to serialize to portable artifact: {}", .{err}); log.err("failed to serialize to portable artifact: {}", .{err});