fix
Before we where using `module.op().writeBytecode(writer)` to compute the hash of a model but it crashes on some inputs, notably for unused variables. So I used the text representation of the mlir.
This commit is contained in:
parent
acc492454f
commit
83b5e1ec48
@ -813,8 +813,11 @@ fn computeModuleHash(platform: Platform, module: mlir.Module) u64 {
|
|||||||
const writer = hasher_writer.writer();
|
const writer = hasher_writer.writer();
|
||||||
|
|
||||||
// Hash the canonicalized IR, without debug information that can change across builds.
|
// Hash the canonicalized IR, without debug information that can change across builds.
|
||||||
module.op().writeBytecode(writer);
|
module.op().print(writer, .{ .debug_info = false });
|
||||||
//module.op().print(writer, .{ .debug_info = false });
|
// Note: before we where using module.op().writeBytecode(writer),
|
||||||
|
// but it crashes on some inputs, notably for unused variables.
|
||||||
|
// So we use the text representation of the mlir.
|
||||||
|
// See https://github.com/zml/zml/issues/97.
|
||||||
// Writes can't fail because we are writing to a hasher.
|
// Writes can't fail because we are writing to a hasher.
|
||||||
writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable;
|
writer.writeAll(platform.pjrt_client.getPlatformName(platform.pjrt_api)) catch unreachable;
|
||||||
const api_version = platform.pjrt_api.version();
|
const api_version = platform.pjrt_api.version();
|
||||||
|
|||||||
@ -4040,3 +4040,19 @@ test transposeIsJustAReshape {
|
|||||||
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 }));
|
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 }));
|
||||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 }));
|
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 }));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test "unused tensor" {
|
||||||
|
const zml = @import("zml.zig");
|
||||||
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
const Local = struct {
|
||||||
|
pub fn forward(x: Tensor) Tensor {
|
||||||
|
const y = x.addConstant(1);
|
||||||
|
_ = y;
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const mod = try zml.compileFn(std.testing.allocator, Local.forward, .{Shape.init(.{10}, .f32)}, platform);
|
||||||
|
defer mod.deinit();
|
||||||
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user