diff --git a/zml/aio/torch/parser.zig b/zml/aio/torch/parser.zig index 93bafcb..5075299 100644 --- a/zml/aio/torch/parser.zig +++ b/zml/aio/torch/parser.zig @@ -9,6 +9,10 @@ const RawPickleOp = @import("ops.zig").RawPickleOp; const Allocator = std.mem.Allocator; const testing = std.testing; +test { + std.testing.refAllDecls(@This()); +} + pub const Decoder = struct { buffer_file: zml.aio.MemoryMappedFile, file_map: std.StringArrayHashMapUnmanaged(std.zip.Iterator(asynk.File.SeekableStream).Entry) = .{}, diff --git a/zml/dtype.zig b/zml/dtype.zig index fc1883c..2230463 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -4,6 +4,10 @@ const floats = @import("floats.zig"); const C64 = std.math.Complex(f32); const C128 = std.math.Complex(f64); +test { + std.testing.refAllDecls(@This()); +} + pub const DataType = enum(u8) { bool, f8e4m3b11fnuz, diff --git a/zml/floats.zig b/zml/floats.zig index 2694688..5adb522 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -1,5 +1,9 @@ const std = @import("std"); +test { + std.testing.refAllDecls(@This()); +} + fn allBitsOne(v: anytype) bool { return v == std.math.maxInt(@TypeOf(v)); } diff --git a/zml/helpers.zig b/zml/helpers.zig index f336316..a042e87 100644 --- a/zml/helpers.zig +++ b/zml/helpers.zig @@ -8,6 +8,10 @@ const Tensor = @import("tensor.zig").Tensor; const EnumLiteral = @TypeOf(.enum_literal); const log = std.log.scoped(.zml_tensor); +test { + std.testing.refAllDecls(@This()); +} + const ShapeError = error{ DimMismatch, NotFound }; const NOT_SET: i64 = 0; const DIM_MISMATCH: i64 = -1; diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index 5f34733..d658bcd 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -8,6 +8,10 @@ const DataType = @import("dtype.zig").DataType; const Platform = @import("platform.zig").Platform; const meta = @import("meta.zig"); +test { + std.testing.refAllDecls(@This()); +} + /// Represents a tensor with associated data allocated by user code. /// If the memory is `.managed` it needs to be freed with `x.deinit(allocator)` /// If the memory is `.unmanaged` it doesn't need to be freed (eg memory mapped, or tracked elsewhere). diff --git a/zml/meta.zig b/zml/meta.zig index e9883f9..0f593f1 100644 --- a/zml/meta.zig +++ b/zml/meta.zig @@ -2,6 +2,10 @@ const std = @import("std"); const testing = std.testing; +test { + std.testing.refAllDecls(@This()); +} + /// Computes floating point value division between two integers. pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T { return toFloat(T, numerator) / toFloat(T, denominator); diff --git a/zml/module.zig b/zml/module.zig index a6578f5..d4e8143 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -28,6 +28,10 @@ const Tracer = @import("tools/tracer.zig").Tracer; const log = std.log.scoped(.zml_module); +test { + std.testing.refAllDecls(@This()); +} + pub const CompilationContext = struct { _platform: Platform, diff --git a/zml/nn.zig b/zml/nn.zig index 4af9695..e418e94 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -828,14 +828,12 @@ fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn { }; } test "sdpaMemEfficient without mask" { - if (true) return error.SkipZigTest; - const platform = zml.testing.env(); const allocator = std.testing.allocator; // Note we use small input vectors to have the tests run reasonably fast, // but don't expect speed ups with this small sizes. - const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform); + const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); defer rng.deinit(); // Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable. @@ -863,17 +861,15 @@ test "sdpaMemEfficient without mask" { } test "sdpaMemEfficient with mask" { - if (true) return error.SkipZigTest; - const platform = zml.testing.env(); const allocator = std.testing.allocator; // Note we use small input vectors to have the tests run reasonably fast, // but don't expect speed ups with this small sizes. - const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform); + const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); defer rng.deinit(); - const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f16), .{ .mean = 0, .stddev = 1 } }, platform); + const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); defer rng_mask.deinit(); // Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable. diff --git a/zml/ops.zig b/zml/ops.zig index 9a6a1fe..4d7e758 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -21,6 +21,10 @@ const dialect = struct { const assert = std.debug.assert; const log = std.log.scoped(.zml_tensor); +test { + std.testing.refAllDecls(@This()); +} + /// Generate an MLIR call to the given member function with the given tensors. pub fn call(self: anytype, comptime func: meta.DeclEnum(@TypeOf(self)), args: anytype) @TypeOf(@call(.auto, @field(meta.UnwrapPtr(@TypeOf(self)), @tagName(func)), .{self} ++ args)) { // TODO: this should use `self.getContext().callFunc(self, args)` diff --git a/zml/tensor.zig b/zml/tensor.zig index b408aff..1be5a0e 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -23,6 +23,10 @@ const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; +test { + std.testing.refAllDecls(@This()); +} + const scoped_log = std.log.scoped(.zml_tensor); /// Represents an abstract Tensor object, which can be the input, @@ -654,7 +658,7 @@ pub const Tensor = struct { break :blk powers; }; const values = Tensor.constantTensor(HostBuffer.fromArray(&powers)).withTags(.{.d}); - const counts = values.gatherValues(.d, samples, .{}).sum(.d).bitCast(.u16); + const counts = values.gatherValues(.d, samples, .{}).sum(.n).bitCast(.u16); const actual_dist = counts.reshape(target_dist.shape()).convert(target_dist.dtype()).divByConst(s.dim(.n)); return .{ rng, .{ .mean = mean_, .variance = variance, .actual_dist = actual_dist } }; } diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig index 83634a6..6a1f394 100644 --- a/zml/tokenizer.zig +++ b/zml/tokenizer.zig @@ -8,6 +8,10 @@ const log = std.log.scoped(.zml_tokenizer); const helpers = @import("helpers.zig"); const meta = @import("meta.zig"); +test { + std.testing.refAllDecls(@This()); +} + /// Byte Pair Encoding tokenizer generally used for LLM. pub const Tokenizer = struct { tokens: [][]const u8,