From fc718ab6492bc529d7d21de69935f925288f205e Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Wed, 22 Feb 2023 15:41:33 +0000 Subject: [PATCH] Add StableHLO bindings for versioning functions, enabling portable serialization of StableHLO. --- mlir/dialects/stablehlo.zig | 40 +++++++++++++++++++++++++++++++++++++ mlir/mlir.zig | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 55e1915..2fa642c 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -1154,3 +1154,43 @@ pub const RngAlgorithm = struct { return std.meta.stringToEnum(Type, value) orelse unreachable; } }; + +pub fn stablehloVersionFromCompatibilityRequirement(requirement: c.MlirStablehloCompatibilityRequirement) []const u8 { + const Context = struct { + str: []const u8 = &.{}, + }; + var context = Context{}; + + c.stablehloVersionFromCompatibilityRequirement(requirement, (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *Context = @ptrCast(@alignCast(userdata)); + inner_ctx.str = mlir.fromStringRef(mlir_str); + } + }).callback, &context); + + return context.str; +} + +pub fn stablehloGetMinimumVersion(writer: anytype) void { + var context = .{ .writer = writer }; + const WriterContext = @TypeOf(context); + + c.stablehloGetMinimumVersion((struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context); +} + +pub fn serializePortableArtifact(module_str: []const u8, target_version: []const u8, writer: anytype) !void { + var context = .{ .writer = writer }; + const WriterContext = @TypeOf(context); + + try mlir.successOr(c.stablehloSerializePortableArtifactFromStringRef(mlir.stringRef(module_str), mlir.stringRef(target_version), (struct { + pub fn callback(mlir_str: c.MlirStringRef, userdata: ?*anyopaque) callconv(.C) void { + const inner_ctx: *WriterContext = @ptrCast(@alignCast(userdata)); + _ = inner_ctx.writer.write(mlir.fromStringRef(mlir_str)) catch unreachable; + } + }).callback, &context), error.InvalidMlirBytecodeVersion); +} diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 50a66ec..6d48e79 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -30,7 +30,7 @@ pub fn registerPasses(comptime passes: []const u8) void { @field(c, "mlirRegister" ++ passes ++ "Passes")(); } -fn successOr(res: c.MlirLogicalResult, err: anytype) !void { +pub fn successOr(res: c.MlirLogicalResult, err: anytype) !void { return if (res.value == 0) err; }