zml/tests: re‑enable all Zig tests, fix precision issue by switching to f32, and add refAllDecls to ensure all declarations are tested
This commit is contained in:
parent
f39b16e13d
commit
ebdb8db213
@ -9,6 +9,10 @@ const RawPickleOp = @import("ops.zig").RawPickleOp;
|
|||||||
const Allocator = std.mem.Allocator;
|
const Allocator = std.mem.Allocator;
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
pub const Decoder = struct {
|
pub const Decoder = struct {
|
||||||
buffer_file: zml.aio.MemoryMappedFile,
|
buffer_file: zml.aio.MemoryMappedFile,
|
||||||
file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{},
|
file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{},
|
||||||
|
|||||||
@ -4,6 +4,10 @@ const floats = @import("floats.zig");
|
|||||||
const C64 = std.math.Complex(f32);
|
const C64 = std.math.Complex(f32);
|
||||||
const C128 = std.math.Complex(f64);
|
const C128 = std.math.Complex(f64);
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
pub const DataType = enum(u8) {
|
pub const DataType = enum(u8) {
|
||||||
bool,
|
bool,
|
||||||
f8e4m3b11fnuz,
|
f8e4m3b11fnuz,
|
||||||
|
|||||||
@ -1,5 +1,9 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
fn allBitsOne(v: anytype) bool {
|
fn allBitsOne(v: anytype) bool {
|
||||||
return v == std.math.maxInt(@TypeOf(v));
|
return v == std.math.maxInt(@TypeOf(v));
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,10 @@ const Tensor = @import("tensor.zig").Tensor;
|
|||||||
const EnumLiteral = @TypeOf(.enum_literal);
|
const EnumLiteral = @TypeOf(.enum_literal);
|
||||||
const log = std.log.scoped(.zml_tensor);
|
const log = std.log.scoped(.zml_tensor);
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
const ShapeError = error{ DimMismatch, NotFound };
|
const ShapeError = error{ DimMismatch, NotFound };
|
||||||
const NOT_SET: i64 = 0;
|
const NOT_SET: i64 = 0;
|
||||||
const DIM_MISMATCH: i64 = -1;
|
const DIM_MISMATCH: i64 = -1;
|
||||||
|
|||||||
@ -8,6 +8,10 @@ const DataType = @import("dtype.zig").DataType;
|
|||||||
const Platform = @import("platform.zig").Platform;
|
const Platform = @import("platform.zig").Platform;
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
/// Represents a tensor with associated data allocated by user code.
|
/// Represents a tensor with associated data allocated by user code.
|
||||||
/// If the memory is `.managed` it needs to be freed with `x.deinit(allocator)`
|
/// If the memory is `.managed` it needs to be freed with `x.deinit(allocator)`
|
||||||
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
|
/// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere).
|
||||||
|
|||||||
@ -2,6 +2,10 @@ const std = @import("std");
|
|||||||
|
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
/// Computes floating point value division between two integers.
|
/// Computes floating point value division between two integers.
|
||||||
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
|
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
|
||||||
return toFloat(T, numerator) / toFloat(T, denominator);
|
return toFloat(T, numerator) / toFloat(T, denominator);
|
||||||
|
|||||||
@ -28,6 +28,10 @@ const Tracer = @import("tools/tracer.zig").Tracer;
|
|||||||
|
|
||||||
const log = std.log.scoped(.zml_module);
|
const log = std.log.scoped(.zml_module);
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
pub const CompilationContext = struct {
|
pub const CompilationContext = struct {
|
||||||
_platform: Platform,
|
_platform: Platform,
|
||||||
|
|
||||||
|
|||||||
10
zml/nn.zig
10
zml/nn.zig
@ -828,14 +828,12 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
test "sdpaMemEfficient without mask" {
|
test "sdpaMemEfficient without mask" {
|
||||||
if (true) return error.SkipZigTest;
|
|
||||||
|
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
|
|
||||||
// Note we use small input vectors to have the tests run reasonably fast,
|
// Note we use small input vectors to have the tests run reasonably fast,
|
||||||
// but don't expect speed ups with this small sizes.
|
// but don't expect speed ups with this small sizes.
|
||||||
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
|
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||||
defer rng.deinit();
|
defer rng.deinit();
|
||||||
|
|
||||||
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
|
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
|
||||||
@ -863,17 +861,15 @@ test "sdpaMemEfficient without mask" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test "sdpaMemEfficient with mask" {
|
test "sdpaMemEfficient with mask" {
|
||||||
if (true) return error.SkipZigTest;
|
|
||||||
|
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
|
|
||||||
// Note we use small input vectors to have the tests run reasonably fast,
|
// Note we use small input vectors to have the tests run reasonably fast,
|
||||||
// but don't expect speed ups with this small sizes.
|
// but don't expect speed ups with this small sizes.
|
||||||
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
|
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||||
defer rng.deinit();
|
defer rng.deinit();
|
||||||
|
|
||||||
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform);
|
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||||
defer rng_mask.deinit();
|
defer rng_mask.deinit();
|
||||||
|
|
||||||
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
|
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
|
||||||
|
|||||||
@ -21,6 +21,10 @@ const dialect = struct {
|
|||||||
const assert = std.debug.assert;
|
const assert = std.debug.assert;
|
||||||
const log = std.log.scoped(.zml_tensor);
|
const log = std.log.scoped(.zml_tensor);
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate an MLIR call to the given member function with the given tensors.
|
/// Generate an MLIR call to the given member function with the given tensors.
|
||||||
pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
|
pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) {
|
||||||
// TODO: this should use `self.getContext().callFunc(self, args)`
|
// TODO: this should use `self.getContext().callFunc(self, args)`
|
||||||
|
|||||||
@ -23,6 +23,10 @@ const dialect = struct {
|
|||||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
const scoped_log = std.log.scoped(.zml_tensor);
|
const scoped_log = std.log.scoped(.zml_tensor);
|
||||||
|
|
||||||
/// Represents an abstract Tensor object, which can be the input,
|
/// Represents an abstract Tensor object, which can be the input,
|
||||||
@ -654,7 +658,7 @@ pub const Tensor = struct {
|
|||||||
break :blk powers;
|
break :blk powers;
|
||||||
};
|
};
|
||||||
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
|
const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d});
|
||||||
const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16);
|
const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16);
|
||||||
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
|
const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n));
|
||||||
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
|
return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } };
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,6 +8,10 @@ const log = std.log.scoped(.zml_tokenizer);
|
|||||||
const helpers = @import("helpers.zig");
|
const helpers = @import("helpers.zig");
|
||||||
const meta = @import("meta.zig");
|
const meta = @import("meta.zig");
|
||||||
|
|
||||||
|
test {
|
||||||
|
std.testing.refAllDecls(@This());
|
||||||
|
}
|
||||||
|
|
||||||
/// Byte Pair Encoding tokenizer generally used for LLM.
|
/// Byte Pair Encoding tokenizer generally used for LLM.
|
||||||
pub const Tokenizer = struct {
|
pub const Tokenizer = struct {
|
||||||
tokens: [][]const u8,
|
tokens: [][]const u8,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user