zml/tests: re‑enable all Zig tests, fix precision issue by switching to f32, and add refAllDecls to ensure all declarations are tested

This commit is contained in:
Tarry Singh 2023-01-23 16:28:19 +00:00
parent f39b16e13d
commit ebdb8db213
11 changed files with 44 additions and 8 deletions

View File

@ -9,6 +9,10 @@ const RawPickleOp = @import("ops.zig").RawPickleOp;
const Allocator = std.mem.Allocator;
const testing = std.testing;
test {
std.testing.refAllDecls(@This());
}
pub const Decoder = struct {
buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{},

View File

@ -4,6 +4,10 @@ const floats = @import("floats.zig");
const C64 = std.math.Complex(f32);
const C128 = std.math.Complex(f64);
test {
std.testing.refAllDecls(@This());
}
pub const DataType = enum(u8) {
bool,
f8e4m3b11fnuz,

View File

@ -1,5 +1,9 @@
const std = @import("std");
test {
std.testing.refAllDecls(@This());
}
fn allBitsOne(v: anytype) bool {
return v == std.math.maxInt(@TypeOf(v));
}

View File

@ -8,6 +8,10 @@ const Tensor = @import("tensor.zig").Tensor;
const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.zml_tensor);
test {
std.testing.refAllDecls(@This());
}
const ShapeError = error{ DimMismatch, NotFound };
const NOT_SET: i64 = 0;
const DIM_MISMATCH: i64 = -1;

View File

@ -8,6 +8,10 @@ const DataType = @import("dtype.zig").DataType;
const Platform = @import("platform.zig").Platform;
const meta = @import("meta.zig");
test {
std.testing.refAllDecls(@This());
}
/// Represents a tensor with associated data allocated by user code.
/// If the memory is `.managed` it needs to be freed with `x.deinit(allocator)`
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).

View File

@ -2,6 +2,10 @@ const std = @import("std");
const testing = std.testing;
test {
std.testing.refAllDecls(@This());
}
/// Computes floating point value division between two integers.
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
return toFloat(T, numerator) / toFloat(T, denominator);

View File

@ -28,6 +28,10 @@ const Tracer = @import("tools/tracer.zig").Tracer;
const log = std.log.scoped(.zml_module);
test {
std.testing.refAllDecls(@This());
}
pub const CompilationContext = struct {
_platform: Platform,

View File

@ -828,14 +828,12 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
};
}
test "sdpaMemEfficient without mask" {
if (true) return error.SkipZigTest;
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
@ -863,17 +861,15 @@ test "sdpaMemEfficient without mask" {
}
test "sdpaMemEfficient with mask" {
if (true) return error.SkipZigTest;
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.

View File

@ -21,6 +21,10 @@ const dialect = struct {
const assert = std.debug.assert;
const log = std.log.scoped(.zml_tensor);
test {
std.testing.refAllDecls(@This());
}
/// Generate an MLIR call to the given member function with the given tensors.
pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
// TODO: this should use `self.getContext().callFunc(self, args)`

View File

@ -23,6 +23,10 @@ const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo;
};
test {
std.testing.refAllDecls(@This());
}
const scoped_log = std.log.scoped(.zml_tensor);
/// Represents an abstract Tensor object, which can be the input,
@ -654,7 +658,7 @@ pub const Tensor = struct {
break :blk powers;
};
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16);
const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16);
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
}

View File

@ -8,6 +8,10 @@ const log = std.log.scoped(.zml_tokenizer);
const helpers = @import("helpers.zig");
const meta = @import("meta.zig");
test {
std.testing.refAllDecls(@This());
}
/// Byte Pair Encoding tokenizer generally used for LLM.
pub const Tokenizer = struct {
tokens: [][]const u8,