diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 03a8011..c20dd85 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -384,18 +384,6 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, }); } -pub fn reverseMany(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, location: mlir.Location) mlir.Operation { - const result_type = operand.getType(); - return mlir.Operation.make(ctx, "stablehlo.reverse", .{ - .operands = &.{operand}, - .results = &.{result_type}, - .attributes = &.{ - .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? }, - }, - .location = location, - }); -} - pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation { return mlir.Operation.make(ctx, "stablehlo.compare", .{ .operands = &.{ lhs, rhs }, diff --git a/zml/aio.zig b/zml/aio.zig index 46c6efd..ce1c47b 100644 --- a/zml/aio.zig +++ b/zml/aio.zig @@ -18,6 +18,19 @@ pub const log = std.log.scoped(.zml_aio); pub const Value = @import("aio/value.zig").Value; const HostBuffer = @import("hostbuffer.zig").HostBuffer; +test { + std.testing.refAllDecls(@This()); + std.testing.refAllDecls(gguf); + // TODO(@cryptodeal) + // std.testing.refAllDecls(nemo); + std.testing.refAllDecls(safetensors); + std.testing.refAllDecls(sentencepiece); + std.testing.refAllDecls(tinyllama); + std.testing.refAllDecls(torch); + // TODO(@cryptodeal) + // std.testing.refAllDecls(yaml); +} + /// Detects the format of the model file (base on filename) and open it. 
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore { return if (std.mem.endsWith(u8, model_path, ".safetensors")) @@ -585,7 +598,3 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi else => {}, } } - -test { - std.testing.refAllDecls(@This()); -} diff --git a/zml/aio/yaml.zig b/zml/aio/yaml.zig index 741a402..8200144 100644 --- a/zml/aio/yaml.zig +++ b/zml/aio/yaml.zig @@ -30,9 +30,9 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore { pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void { switch (val) { - .int => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), - .float => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), - .string => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }), + .int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), + .float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), + .string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }), .list => |v| switch (validSlice(v)) { true => { if (v.len == 0) return; @@ -41,13 +41,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str const values = try allocator.alloc(i64, v.len); errdefer allocator.free(values); for (v, 0..) 
|item, i| values[i] = item.int; - try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = utils.toVoidSlice(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } }); }, .float => { const values = try allocator.alloc(f64, v.len); errdefer allocator.free(values); for (v, 0..) |item, i| values[i] = item.float; - try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = utils.toVoidSlice(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } }); }, .string => { const values = try allocator.alloc([]const u8, v.len); @@ -55,7 +55,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str for (v, 0..) |item, i| { values[i] = try allocator.dupe(u8, item.string); } - try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = utils.toVoidSlice(values) } }); + try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } }); }, .list => unreachable, else => {}, diff --git a/zml/buffer.zig b/zml/buffer.zig index e38a602..eddf976 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -4,15 +4,17 @@ const testing = std.testing; const meta = @import("meta.zig"); const pjrt = @import("pjrt"); const pjrtx = @import("pjrtx.zig"); -const platform = @import("platform.zig"); const Context = @import("context.zig").Context; -const HostBuffer = @import("hostbuffer.zig").HostBuffer; -const Shape = @import("shape.zig").Shape; -const Tensor = @import("tensor.zig").Tensor; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; -const Target = @import("platform.zig").Target; 
+const HostBuffer = @import("hostbuffer.zig").HostBuffer; +const Platform = @import("platform.zig").Platform; +const Shape = @import("shape.zig").Shape; + +test { + std.testing.refAllDecls(Buffer); +} /// Buffer is a multi-dimension array, whose memory is allocated on an accelerator. /// @@ -22,54 +24,74 @@ const Target = @import("platform.zig").Target; pub const Buffer = struct { _shape: Shape, _shards: Shape = undefined, - _platform: platform.Platform, + _platform: Platform, _data: *pjrtx.Buffer, /// Copies the content of the given buffer from host memory to the accelerator memory. - pub fn from(platform_: platform.Platform, buf: HostBuffer) !Buffer { - const pjrt_buffer = try platform_.pjrt_client.bufferFromHostBuffer(platform_.pjrt_api, .{ + pub fn from(platform: Platform, buf: HostBuffer) !Buffer { + const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{ .data = buf.data, .buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), .dims = buf.shape().dims(), - .byte_strides = null, - .device = platform_.getDevices()[0], + .byte_strides = buf.strides(), + .device = platform.getDevices()[0], .host_buffer_semantics = .ImmutableUntilTransferCompletes, }); return .{ - ._platform = platform_, + ._platform = platform, ._shape = buf.shape(), ._data = pjrt_buffer, }; } /// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`. - pub fn fromPjrtBuffer(platform_: platform.Platform, pjrt_buffer: *pjrtx.Buffer) Buffer { + pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer { return .{ - ._platform = platform_, - ._shape = _shapeFromPjrtBuffer(platform_, pjrt_buffer), + ._platform = platform, + ._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer), ._data = pjrt_buffer, }; } /// Copies the given Zig slice to the accelerator memory and /// return a Buffer with the given dimensions. 
- pub fn fromSlice(platform_: platform.Platform, dimz: anytype, s: anytype) !Buffer { + pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer { const sh = Shape.init(dimz, DataType.fromSliceElementType(s)); - return from(platform_, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); + return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); } /// Copies the given Zig array to the accelerator memory and /// return a Buffer using the array shape. - pub fn fromArray(platform_: platform.Platform, arr: anytype) !Buffer { + pub fn fromArray(platform: Platform, arr: anytype) !Buffer { const host_buffer = HostBuffer.fromArray(&arr); - return try host_buffer.toDevice(platform_); + return try from(platform, host_buffer); } /// Creates a Buffer with a single element. - pub fn scalar(platform_: platform.Platform, val: anytype, dtype_: DataType) !Buffer { + pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer { const x = dtype_.constant(val); const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice()); - return try host_buffer.toDevice(platform_); + return try from(platform, host_buffer); + } + + /// Creates a Buffer with a single element repeated many times. + pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer { + const x = shape_.dtype().constant(val); + const host_buffer: HostBuffer = .{ + ._shape = shape_, + ._strides = [1]i64{0} ** Shape.MAX_RANK, + .data = x.constSlice(), + }; + return try from(platform, host_buffer); + } + + test constant { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42); + const y = try x.getValue([4 * 3 * 2]u16); + try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y); } /// Creates a Buffer as a view of memory visible from the device, @@ -79,7 +101,7 @@ pub const Buffer = struct { /// Be careful though, as it requires a specific alignment.
/// Also note that it might not work on all platforms, /// could lead to crashes and is considerably slower. - pub fn asViewOf(platform_: platform.Platform, buf: HostBuffer) !Buffer { + pub fn asViewOf(platform: Platform, buf: HostBuffer) !Buffer { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { var res: [Shape.MAX_RANK]i64 = undefined; for (0..Shape.MAX_RANK) |i| { @@ -88,11 +110,11 @@ pub const Buffer = struct { break :blk res; }; - const pjrt_buffer = try platform_.pjrt_client.createViewOfDeviceBuffer(platform_.pjrt_api, .{ + const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{ .data = buf.data, .element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), .dims = buf.shape().dims(), - .device = platform_.getDevices()[0], + .device = platform.getDevices()[0], .layout = .{ .Tiled = .{ .minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..], @@ -103,7 +125,7 @@ pub const Buffer = struct { }); return .{ - ._platform = platform_, + ._platform = platform, ._shape = buf.shape(), ._data = pjrt_buffer, }; @@ -177,11 +199,11 @@ pub const Buffer = struct { ) !void { _ = fmt; _ = options; - try writer.print("Tensor({_})", .{self._shape}); + try writer.print("Buffer({_})", .{self._shape}); } - fn _shapeFromPjrtBuffer(platform_: platform.Platform, buf: *pjrtx.Buffer) Shape { - const dt: DataType = switch (buf.getElementType(platform_.pjrt_api)) { + fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape { + const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) { // Please keep the list exhaustive and in the same order than in DataType. 
.PRED => .bool, .F8E4M3B11FNUZ => .f8e4m3b11fnuz, @@ -208,14 +230,6 @@ pub const Buffer = struct { .INVALID => @panic("Can't handle INVALID Pjrt buffers."), }; - return Shape.init(buf.getDimensions(platform_.pjrt_api), dt); + return Shape.init(buf.getDimensions(platform.pjrt_api), dt); } - - pub const From = meta.MapType(Tensor, Buffer).map; }; - -/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. -pub fn Bufferized(comptime T: type) type { - const M = meta.MapType(Tensor, Buffer); - return M.map(T); -} diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig index d658bcd..fa16886 100644 --- a/zml/hostbuffer.zig +++ b/zml/hostbuffer.zig @@ -1,15 +1,14 @@ const std = @import("std"); +const meta = @import("meta.zig"); const Buffer = @import("buffer.zig").Buffer; -const Shape = @import("shape.zig").Shape; -const Tensor = @import("tensor.zig").Tensor; const Data = @import("dtype.zig").Data; const DataType = @import("dtype.zig").DataType; const Platform = @import("platform.zig").Platform; -const meta = @import("meta.zig"); +const Shape = @import("shape.zig").Shape; test { - std.testing.refAllDecls(@This()); + std.testing.refAllDecls(HostBuffer); } /// Represents a tensor with associated data allocated by user code. @@ -99,10 +98,16 @@ pub const HostBuffer = struct { }; } + pub const ArangeArgs = struct { + start: i64 = 0, + end: i64, + step: i64 = 1, + }; + /// Allocates a HostBuffer with the given shape. /// The memory is initialized with increasing numbers. /// The caller owns the memory, and need to call `deinit()`. 
- pub fn arange(allocator: std.mem.Allocator, args: Tensor.ArangeArgs, dt: DataType) !HostBuffer { + pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer { meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); @@ -137,12 +142,6 @@ pub const HostBuffer = struct { } } - /// Embeds a tensor with concrete values into an Mlir program. - /// The content is copied, so the HostBuffer can be safely `deinit`. - pub fn toStaticTensor(self: HostBuffer) Tensor { - return Tensor.staticTensor(self.shape(), self.data); - } - /// Copies this HostBuffer to the given accelerator. pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer { return try Buffer.from(platform_, self); @@ -164,8 +163,9 @@ pub const HostBuffer = struct { return self._shape.dtype(); } - pub fn strides(self: HostBuffer) ?[]const i64 { - return self._strides; + pub fn strides(self: *const HostBuffer) ?[]const i64 { + // Take self by pointer, otherwise we would return a pointer into this stack frame.
+ return if (self._strides) |*strd| strd[0..self.rank()] else null; } pub fn data(self: HostBuffer) []const u8 { diff --git a/zml/module.zig b/zml/module.zig index d4e8143..b227ecd 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -23,7 +23,7 @@ const Tensor = @import("tensor.zig").Tensor; const ShapeOf = @import("tensor.zig").ShapeOf; const Shape = @import("shape.zig").Shape; const Buffer = @import("buffer.zig").Buffer; -const Bufferized = @import("buffer.zig").Bufferized; +const Bufferized = @import("tensor.zig").Bufferized; const Tracer = @import("tools/tracer.zig").Tracer; const log = std.log.scoped(.zml_module); diff --git a/zml/nn.zig b/zml/nn.zig index e418e94..a7c0cc3 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -16,6 +16,11 @@ const log = std.log.scoped(.zml_tensor); const cuda = @import("nn/cuda.zig"); +test { + _ = cuda; + std.testing.refAllDecls(@This()); +} + pub const Linear = struct { weight: Tensor, bias: ?Tensor = null, @@ -302,7 +307,7 @@ test "real/img" { } test "rope" { - const platofrm = zml.testing.env(); + const platform = zml.testing.env(); const TestRope = struct { fn forward(x: Tensor, opts: RopeOpts) Tensor { @@ -326,9 +331,9 @@ test "rope" { // x is made such as the interleaved and sequential reps are the same. // So the two implementations should give the same results. 
- const x = try zml.Buffer.fromSlice(platofrm, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5); - const res1 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } }); - const res2 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } }); + const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5); + const res1 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } }); + const res2 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } }); try zml.testing.expectClose(res1, res2, 1e-4); } @@ -516,70 +521,141 @@ test nearest { } } -pub const ResizeOpts = struct { original_len: ?Tensor = null }; +pub const ResizeOpts = struct { + /// scalar tensor containing the original dimension of the image. + /// It can be different from the image shape, + /// if the image has been padded. + /// This allows compiling one module that handles different input image sizes. -pub fn resizeBilinear(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor { + /// Internal precision to do the interpolation. + /// Result will always use the same dtype as the original. + /// If not set, will use the image dtype, unless it's an integer type, in which case f32 will be used.
+ precision: ?zml.DataType = null, +}; + +pub fn resizeBilinear(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor { + const new_size, const tags_ = Shape.parseStruct(u63, resized_axes); var out = image; - for (axes, dims) |a, d| { + for (new_size.constSlice(), tags_.constSlice()) |d, t| { + const ax = image.shape().axis(t); const child_opt: ResizeOpts = .{ - .original_len = if (opt.original_len) |o| o.choose1d(0, a) else null, + .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null, }; - out = resizeLinear1d(out, a, d, child_opt); + out = resizeLinear1d(out, ax, d, child_opt); } return out; } +test resizeBilinear { + const platform = zml.testing.env(); + + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } }, + .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } }, + .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } }, + }) |testcase| { + const x_shape, const resizing, const res_shape = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const y = resizeBilinear(x, resizing, .{}); + try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); + try std.testing.expect(y.value().owner().verify()); + } +} + pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor { - const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32); - const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len)); - const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio); + const res_shape = image.shape().set(axis, new_len); + + const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); + const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); + const ratio = 
og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); + const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const left = scaled.floor(); const right = left.addConstant(1); - const values = image.gatherSlices1d(axis, 2, left.convert(.i32), .{ .indices_are_sorted = true }); - const left_val, const right_val = helpers.mapTensors( - Tensor.squeeze, - values.convert(.f32).chunkExact(2, axis + 1), - .{@as(i64, @intCast(axis + 1))}, - ); - const left_weight = right.sub(scaled).broadcast(left_val.shape(), &.{axis}); - const right_weight = scaled.sub(left).broadcast(left_val.shape(), &.{axis}); + // TODO: check that two gather isn't too bad perf wise. + // Normally we should use gatherSlices to collect the values 2 by 2, + // but gatherSlices messes up with the order of axes. + const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype); + const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype); - return left_val.mul(left_weight).add(right_val.mul(right_weight)).convert(image.dtype()); + const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis}); + const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis}); + + const res = left_val.mul(left_weight).add(right_val.mul(right_weight)); + return res.convert(image.dtype()).withTags(image.shape().tags()); } /// Bicubic interpolation of the given image. /// Warning as of May 2024 the cpu backend don't optimize this very well /// and is not able to merge the weighting with the gather, /// leading to 20x slow down compared to STB implementation. 
-pub fn resizeBicubic(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor { +pub fn resizeBicubic(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor { + const new_size, const tags_ = Shape.parseStruct(u63, resized_axes); var out = image; - for (axes, dims) |a, d| { + for (new_size.constSlice(), tags_.constSlice()) |d, t| { + const ax = image.shape().axis(t); const child_opt: ResizeOpts = .{ - .original_len = if (opt.original_len) |o| o.choose1d(0, a) else null, + .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null, }; - out = resizeCubic1d(out, a, d, child_opt); + out = resizeCubic1d(out, ax, d, child_opt); } return out; } +test resizeBicubic { + const platform = zml.testing.env(); + + // Only test shapes + var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform); + defer comp.deinit(); + comp.activate(); + defer comp.deactivate(); + + inline for (.{ + .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } }, + .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } }, + .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } }, + }) |testcase| { + const x_shape, const resizing, const res_shape = testcase; + const x = Tensor.constant(x_shape, .{ .f16 = 0 }); + const y = resizeBicubic(x, resizing, .{}); + try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape()); + try std.testing.expect(y.value().owner().verify()); + } +} + pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor { // Extract neighboring pixels from the image. 
- const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32); - const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len)); - const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio); + const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype(); + const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype); + + const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len)); + const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio); const t = scaled.sub(scaled.floor()); const pos = Tensor.stack(&.{ - Tensor.scalar(1, .f32).broadcast(t.shape(), &.{}), + Tensor.constant(t.shape(), dtype.one()), t, t.mul(t), - t.pow(Tensor.scalar(3, .f32)), - }, -1, .features); + t.pow(Tensor.scalar(3, dtype)), + }, .last, ._interpolated); + std.debug.assert(pos.dim(0) == new_len); std.debug.assert(pos.dim(1) == 4); - const context = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32)); - const values = image.gatherSlices1d(axis, 4, context, .{ .indices_are_sorted = true }); + const neighbors = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32)); + + const values = image.renameAxis(axis, ._neighbors).gatherSlices( + Shape.init(.{ ._neighbors = 4 }, image.dtype()), + neighbors.appendAxes(.{.coord}), + .{ .indices_are_sorted = true }, + ).convert(dtype); const weights_: [4][4]f32 = .{ .{ 0, 1, 0, 0 }, @@ -587,22 +663,24 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten .{ 1, -2.5, 2, -0.5 }, .{ -0.5, 1.5, -1.5, 0.5 }, }; - const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)); + const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors }); // actually do the interpolation. // Note: ideally this matmul should be inlined with the gather, but that's currently not the case. 
- var res = values.convert(.f32).dotGeneral(weights, &.{.{ axis + 1, 1 }}, &.{}); - res = pos.dotGeneral(res, &.{.{ 1, image.rank() }}, &.{.{ 0, axis }}); + // TODO: not being able to use dot here is a bit annoying. + var res = values.dotGeneral(weights, &.{.{ values.axis(._neighbors), weights.axis(._neighbors) }}, &.{}); + res = pos.dotGeneral(res, &.{.{ pos.axis(._interpolated), res.axis(._interpolated) }}, &.{.{ 0, 0 }}); // the current axis is outputted in first position because it's a batching dim, put it back in place. - // if (axis != 0) - // res = res.transpose(Shape.range(image.rank()).swap(0, axis).dims()); + if (axis != 0) { + res = res.swapAxes(0, axis); + } // verify the shape const res_shape = image.shape().set(axis, new_len); // log.debug("resizeCubic1d: ({}, {}, {}, {}) -> {}", .{ image, axis, new_len, opt, res }); std.debug.assert(std.mem.eql(i64, res_shape.dims(), res.dims())); - return res.convert(image.dtype()); + return res.convert(image.dtype()).withTags(image.shape()); } /// Return causal attention masks for the given shape. @@ -927,7 +1005,3 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); return .{ next_tokens, next_rng }; } - -test { - _ = cuda; -} diff --git a/zml/shape.zig b/zml/shape.zig index 842e1ed..172327e 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -9,7 +9,7 @@ const EnumLiteral = @TypeOf(.enum_literal); const log = std.log.scoped(.shape); test { - std.testing.refAllDecls(@This()); + std.testing.refAllDecls(Shape); } /// Represent the shape of a tensor. 
diff --git a/zml/tensor.zig b/zml/tensor.zig index 1be5a0e..d2e6032 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -23,12 +23,12 @@ const dialect = struct { const stablehlo = @import("mlir/dialects").stablehlo; }; -test { - std.testing.refAllDecls(@This()); -} - const scoped_log = std.log.scoped(.zml_tensor); +test { + std.testing.refAllDecls(Tensor); +} + /// Represents an abstract Tensor object, which can be the input, /// output, weight or activations of a neural network. /// Tensor are abstract in the sense they only represent a computation, @@ -166,6 +166,12 @@ pub const Tensor = struct { return res; } + pub fn renameAxis(self: Tensor, ax: i8, name: EnumLiteral) Tensor { + var res = self; + res._shape._tags.set(self.axis(ax), @tagName(name).ptr); + return res; + } + /// Returns the mlir.Value associated with the Tensor. /// /// This will fail if used outside of a compilation context. @@ -462,7 +468,7 @@ pub const Tensor = struct { }; } - pub fn init(platform: Platform, seed: u128) !Buffer.From(Rng) { + pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) { const bits: [2]u64 = @bitCast(seed); return .{ ._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits), @@ -1071,7 +1077,12 @@ pub const Tensor = struct { /// In this version batching dimensions need to be explicitly specified. /// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes. /// Where "result axes" are non-contracting, non-batching axes of each input tensor. 
- pub fn dotGeneral(lhs: Tensor, rhs: Tensor, contracting_axes: []const [2]i8, batching_axes: []const [2]i8) Tensor { + pub fn dotGeneral( + lhs: Tensor, + rhs: Tensor, + contracting_axes: []const [2]i8, + batching_axes: []const [2]i8, + ) Tensor { meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); const Axes = std.BoundedArray(i64, MAX_RANK); @@ -1273,6 +1284,17 @@ pub const Tensor = struct { return _result(res_shape, op.result(0)); } + pub fn swapAxes(self: Tensor, a: anytype, b: anytype) Tensor { + if (self.axis(a) == self.axis(b)) return self; + var perm: Shape.AxesArray = .{}; + for (0..self.rank()) |i| { + perm.appendAssumeCapacity(@intCast(i)); + } + perm.set(self.axis(a), self.axis(b)); + perm.set(self.axis(b), self.axis(a)); + return self.transpose(perm.constSlice()); + } + /// Returns a Tensor with the given axis unflattened. /// /// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3) @@ -1577,11 +1599,7 @@ pub const Tensor = struct { return _result(self._shape, expm1_op.result(0)); } - pub const ArangeArgs = struct { - start: i64 = 0, - end: i64, - step: i64 = 1, - }; + pub const ArangeArgs = HostBuffer.ArangeArgs; /// Returns a Tensor containing evenly spaced values within a given interval. pub fn arange(args: ArangeArgs, dt: DataType) Tensor { @@ -1871,14 +1889,6 @@ pub const Tensor = struct { return _result(self._shape, reverse_op.result(0)); } - /// Returns a Tensor with the given axes reversed. 
- pub fn reverseMany(self: Tensor, axes_: []const i64) Tensor { - const actual_axes = self._shape.axes(axes_).constSlice(); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({d})", .{actual_axes}); - const reverse_op = dialect.stablehlo.reverseMany(self.getContext().mlirCtx(), self.value(), @ptrCast(actual_axes), loc); - return _result(self._shape, reverse_op.result(0)); - } - pub const GatherOpts = struct { indices_are_sorted: bool = false }; /// For each coordinate in `indices`, @@ -2448,7 +2458,7 @@ pub const Tensor = struct { return result; } - pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: i64) [n_chunks]Tensor { + pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor { const a = self.axis(axis_); const length = self.dim(a); _ = @divExact(length, n_chunks); @@ -3484,3 +3494,9 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va } return res; } + +/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer. +pub fn Bufferized(comptime T: type) type { + const M = meta.MapType(Tensor, Buffer); + return M.map(T); +} diff --git a/zml/zml.zig b/zml/zml.zig index dcf5ff9..484e414 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -4,7 +4,7 @@ //! pub const Buffer = @import("buffer.zig").Buffer; -pub const Bufferized = @import("buffer.zig").Bufferized; +pub const Bufferized = @import("tensor.zig").Bufferized; pub const CompilationOptions = @import("platform.zig").CompilationOptions; pub const Context = @import("context.zig").Context; pub const Data = @import("dtype.zig").Data;