From c8c99d7d5a862031ef87e0a38970385de3ebca9a Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Thu, 7 Sep 2023 17:06:19 +0000 Subject: [PATCH] =?UTF-8?q?zml/pjrtx:=20prefer=20the=20built=E2=80=91in=20?= =?UTF-8?q?stablehlo=20version=20when=20a=20plugin=20reports=20a=20newer?= =?UTF-8?q?=20version,=20ensuring=20artifact=20serialization=20uses=20the?= =?UTF-8?q?=20correct=20stablehlo=20version.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mlir/dialects/stablehlo.zig | 17 +++++++++++++++++ zml/pjrtx.zig | 9 +++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 24bd551..6d2c740 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -1261,6 +1261,23 @@ pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehlo return state.call(requirement); } +pub fn stablehloGetSmallerVersion(version1: []const u8, version2: []const u8) []const u8 { + var buf: [32]u8 = undefined; + + var stream = std.io.fixedBufferStream(&buf); + var context = .{ .writer = stream.writer() }; + const WriterContext = @TypeOf(context); + + _ = c.stablehloGetSmallerVersion(mlir.stringRef(version1), mlir.stringRef(version2), (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); + + return if (std.mem.eql(u8, buf[0..stream.pos], version1)) version1 else version2; +} + pub fn getCurrentVersion() []const u8 { const state = struct { var buf: [32]u8 = undefined; diff --git a/zml/pjrtx.zig b/zml/pjrtx.zig index 5f6aaed..9d43b9b 100644 --- a/zml/pjrtx.zig +++ b/zml/pjrtx.zig @@ -98,8 +98,13 @@ pub const Client = opaque { defer serialized_buffer.deinit(); // spec ref: https://github.com/openxla/xla/blob/39967ad6782a861ca029ab8d1a2b25f7e0c3902b/xla/pjrt/pjrt_c_api_client.cc#L399 - var stablehlo_version_buf: [32]u8 = undefined; - const stablehlo_version = api.stablehloCurrentVersion(&stablehlo_version_buf) orelse dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12); + var requested_stablehlo_version_buf: [32]u8 = undefined; + const requested_stablehlo_version = api.stablehloCurrentVersion(&requested_stablehlo_version_buf); + const stablehlo_version = if (requested_stablehlo_version) |requested_version| blk: { + break :blk dialects.stablehlo.stablehloGetSmallerVersion(requested_version, dialects.stablehlo.getCurrentVersion()); + } else blk: { + break :blk dialects.stablehlo.stablehloVersionFromCompatibilityRequirement(c.WEEK_12); + }; dialects.stablehlo.serializePortableArtifact(bytecode.items, stablehlo_version, serialized_buffer.writer()) catch |err| { log.err("failed to serialize to portable artifact: {}", .{err});