zml/nn: fix resize implementations (resizeBilinear and resizeBicubic) and expand refAllDecls usage; all tests pass

This commit is contained in:
Tarry Singh 2023-01-27 14:35:11 +00:00
parent 5e1688cbfd
commit 7dcd8b516c
10 changed files with 238 additions and 137 deletions

View File

@ -384,18 +384,6 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
}); });
} }
/// Builds a `stablehlo.reverse` operation reversing `operand` along every
/// axis listed in `dimensions`. The result type is the operand's own type,
/// since reversing never changes shape or element type.
pub fn reverseMany(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, location: mlir.Location) mlir.Operation {
    const result_type = operand.getType();
    return mlir.Operation.make(ctx, "stablehlo.reverse", .{
        .operands = &.{operand},
        .results = &.{result_type},
        .attributes = &.{
            // StableHLO expects the axes as a dense i64 array attribute.
            .{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
        },
        .location = location,
    });
}
pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation { pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.compare", .{ return mlir.Operation.make(ctx, "stablehlo.compare", .{
.operands = &.{ lhs, rhs }, .operands = &.{ lhs, rhs },

View File

@ -18,6 +18,19 @@ pub const log = std.log.scoped(.zml_aio);
pub const Value = @import("aio/value.zig").Value; pub const Value = @import("aio/value.zig").Value;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test {
    // Reference all public decls (here and in each loader submodule) so that
    // their nested `test` blocks are compiled and run by `zig test`.
    std.testing.refAllDecls(@This());
    std.testing.refAllDecls(gguf);
    // TODO(@cryptodeal)
    // std.testing.refAllDecls(nemo);
    std.testing.refAllDecls(safetensors);
    std.testing.refAllDecls(sentencepiece);
    std.testing.refAllDecls(tinyllama);
    std.testing.refAllDecls(torch);
    // TODO(@cryptodeal)
    // std.testing.refAllDecls(yaml);
}
/// Detects the format of the model file (base on filename) and open it. /// Detects the format of the model file (base on filename) and open it.
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore { pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
return if (std.mem.endsWith(u8, model_path, ".safetensors")) return if (std.mem.endsWith(u8, model_path, ".safetensors"))
@ -585,7 +598,3 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
else => {}, else => {},
} }
} }
test {
std.testing.refAllDecls(@This());
}

View File

@ -30,9 +30,9 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void { pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
switch (val) { switch (val) {
.int => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), .int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
.float => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), .float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
.string => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }), .string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
.list => |v| switch (validSlice(v)) { .list => |v| switch (validSlice(v)) {
true => { true => {
if (v.len == 0) return; if (v.len == 0) return;
@ -41,13 +41,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
const values = try allocator.alloc(i64, v.len); const values = try allocator.alloc(i64, v.len);
errdefer allocator.free(values); errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.int; for (v, 0..) |item, i| values[i] = item.int;
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = utils.toVoidSlice(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } });
}, },
.float => { .float => {
const values = try allocator.alloc(f64, v.len); const values = try allocator.alloc(f64, v.len);
errdefer allocator.free(values); errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.float; for (v, 0..) |item, i| values[i] = item.float;
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = utils.toVoidSlice(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } });
}, },
.string => { .string => {
const values = try allocator.alloc([]const u8, v.len); const values = try allocator.alloc([]const u8, v.len);
@ -55,7 +55,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
for (v, 0..) |item, i| { for (v, 0..) |item, i| {
values[i] = try allocator.dupe(u8, item.string); values[i] = try allocator.dupe(u8, item.string);
} }
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = utils.toVoidSlice(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } });
}, },
.list => unreachable, .list => unreachable,
else => {}, else => {},

View File

@ -4,15 +4,17 @@ const testing = std.testing;
const meta = @import("meta.zig"); const meta = @import("meta.zig");
const pjrt = @import("pjrt"); const pjrt = @import("pjrt");
const pjrtx = @import("pjrtx.zig"); const pjrtx = @import("pjrtx.zig");
const platform = @import("platform.zig");
const Context = @import("context.zig").Context; const Context = @import("context.zig").Context;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const Target = @import("platform.zig").Target; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
const Platform = @import("platform.zig").Platform;
const Shape = @import("shape.zig").Shape;
test {
std.testing.refAllDecls(Buffer);
}
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator. /// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
/// ///
@ -22,54 +24,74 @@ const Target = @import("platform.zig").Target;
pub const Buffer = struct { pub const Buffer = struct {
_shape: Shape, _shape: Shape,
_shards: Shape = undefined, _shards: Shape = undefined,
_platform: platform.Platform, _platform: Platform,
_data: *pjrtx.Buffer, _data: *pjrtx.Buffer,
/// Copies the content of the given buffer from host memory to the accelerator memory. /// Copies the content of the given buffer from host memory to the accelerator memory.
pub fn from(platform_: platform.Platform, buf: HostBuffer) !Buffer { pub fn from(platform: Platform, buf: HostBuffer) !Buffer {
const pjrt_buffer = try platform_.pjrt_client.bufferFromHostBuffer(platform_.pjrt_api, .{ const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
.data = buf.data, .data = buf.data,
.buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), .buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
.dims = buf.shape().dims(), .dims = buf.shape().dims(),
.byte_strides = null, .byte_strides = buf.strides(),
.device = platform_.getDevices()[0], .device = platform.getDevices()[0],
.host_buffer_semantics = .ImmutableUntilTransferCompletes, .host_buffer_semantics = .ImmutableUntilTransferCompletes,
}); });
return .{ return .{
._platform = platform_, ._platform = platform,
._shape = buf.shape(), ._shape = buf.shape(),
._data = pjrt_buffer, ._data = pjrt_buffer,
}; };
} }
/// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`. /// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`.
pub fn fromPjrtBuffer(platform_: platform.Platform, pjrt_buffer: *pjrtx.Buffer) Buffer { pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
return .{ return .{
._platform = platform_, ._platform = platform,
._shape = _shapeFromPjrtBuffer(platform_, pjrt_buffer), ._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer),
._data = pjrt_buffer, ._data = pjrt_buffer,
}; };
} }
/// Copies the given Zig slice to the accelerator memory and /// Copies the given Zig slice to the accelerator memory and
/// return a Buffer with the given dimensions. /// return a Buffer with the given dimensions.
pub fn fromSlice(platform_: platform.Platform, dimz: anytype, s: anytype) !Buffer { pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
const sh = Shape.init(dimz, DataType.fromSliceElementType(s)); const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
return from(platform_, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s))); return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
} }
/// Copies the given Zig array to the accelerator memory and /// Copies the given Zig array to the accelerator memory and
/// return a Buffer using the array shape. /// return a Buffer using the array shape.
pub fn fromArray(platform_: platform.Platform, arr: anytype) !Buffer { pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
const host_buffer = HostBuffer.fromArray(&arr); const host_buffer = HostBuffer.fromArray(&arr);
return try host_buffer.toDevice(platform_); return try from(platform, host_buffer);
} }
/// Creates a Buffer with a single element. /// Creates a Buffer with a single element.
pub fn scalar(platform_: platform.Platform, val: anytype, dtype_: DataType) !Buffer { pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
const x = dtype_.constant(val); const x = dtype_.constant(val);
const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice()); const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
return try host_buffer.toDevice(platform_); return try from(platform, host_buffer);
}
/// Creates a Buffer whose every element is the single value `val`,
/// repeated to fill `shape_`. The value is converted to the shape's dtype.
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
    const x = shape_.dtype().constant(val);
    const host_buffer: HostBuffer = .{
        ._shape = shape_,
        // All-zero strides: every logical index maps onto the same single
        // element, so only one scalar worth of host data is needed.
        // NOTE(review): relies on the PJRT transfer honoring byte strides —
        // confirmed by `from` passing `buf.strides()` as `.byte_strides`.
        ._strides = [1]i64{0} ** Shape.MAX_RANK,
        .data = x.constSlice(),
    };
    return try from(platform, host_buffer);
}
test constant {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42);
const y = try x.getValue([4 * 3 * 2]u16);
try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
} }
/// Creates a Buffer as a view of memory visible from the device, /// Creates a Buffer as a view of memory visible from the device,
@ -79,7 +101,7 @@ pub const Buffer = struct {
/// Be careful though, as it requires a specific alignment. /// Be careful though, as it requires a specific alignment.
/// Also note that it might not work on all platforms, /// Also note that it might not work on all platforms,
/// could lead to crashes and is considerably slower. /// could lead to crashes and is considerably slower.
pub fn asViewOf(platform_: platform.Platform, buf: HostBuffer) !Buffer { pub fn asViewOf(platform: Platform, buf: HostBuffer) !Buffer {
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: { const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
var res: [Shape.MAX_RANK]i64 = undefined; var res: [Shape.MAX_RANK]i64 = undefined;
for (0..Shape.MAX_RANK) |i| { for (0..Shape.MAX_RANK) |i| {
@ -88,11 +110,11 @@ pub const Buffer = struct {
break :blk res; break :blk res;
}; };
const pjrt_buffer = try platform_.pjrt_client.createViewOfDeviceBuffer(platform_.pjrt_api, .{ const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
.data = buf.data, .data = buf.data,
.element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()), .element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
.dims = buf.shape().dims(), .dims = buf.shape().dims(),
.device = platform_.getDevices()[0], .device = platform.getDevices()[0],
.layout = .{ .layout = .{
.Tiled = .{ .Tiled = .{
.minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..], .minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..],
@ -103,7 +125,7 @@ pub const Buffer = struct {
}); });
return .{ return .{
._platform = platform_, ._platform = platform,
._shape = buf.shape(), ._shape = buf.shape(),
._data = pjrt_buffer, ._data = pjrt_buffer,
}; };
@ -177,11 +199,11 @@ pub const Buffer = struct {
) !void { ) !void {
_ = fmt; _ = fmt;
_ = options; _ = options;
try writer.print("Tensor({_})", .{self._shape}); try writer.print("Buffer({_})", .{self._shape});
} }
fn _shapeFromPjrtBuffer(platform_: platform.Platform, buf: *pjrtx.Buffer) Shape { fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape {
const dt: DataType = switch (buf.getElementType(platform_.pjrt_api)) { const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) {
// Please keep the list exhaustive and in the same order than in DataType. // Please keep the list exhaustive and in the same order than in DataType.
.PRED => .bool, .PRED => .bool,
.F8E4M3B11FNUZ => .f8e4m3b11fnuz, .F8E4M3B11FNUZ => .f8e4m3b11fnuz,
@ -208,14 +230,6 @@ pub const Buffer = struct {
.INVALID => @panic("Can't handle INVALID Pjrt buffers."), .INVALID => @panic("Can't handle INVALID Pjrt buffers."),
}; };
return Shape.init(buf.getDimensions(platform_.pjrt_api), dt); return Shape.init(buf.getDimensions(platform.pjrt_api), dt);
} }
pub const From = meta.MapType(Tensor, Buffer).map;
}; };
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized(comptime T: type) type {
const M = meta.MapType(Tensor, Buffer);
return M.map(T);
}

View File

@ -1,15 +1,14 @@
const std = @import("std"); const std = @import("std");
const meta = @import("meta.zig");
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Shape = @import("shape.zig").Shape;
const Tensor = @import("tensor.zig").Tensor;
const Data = @import("dtype.zig").Data; const Data = @import("dtype.zig").Data;
const DataType = @import("dtype.zig").DataType; const DataType = @import("dtype.zig").DataType;
const Platform = @import("platform.zig").Platform; const Platform = @import("platform.zig").Platform;
const meta = @import("meta.zig"); const Shape = @import("shape.zig").Shape;
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(HostBuffer);
} }
/// Represents a tensor with associated data allocated by user code. /// Represents a tensor with associated data allocated by user code.
@ -99,10 +98,16 @@ pub const HostBuffer = struct {
}; };
} }
/// Arguments for `arange`: describes the half-open range [start, end)
/// traversed with increment `step`.
pub const ArangeArgs = struct {
    // First value produced (inclusive). Defaults to 0.
    start: i64 = 0,
    // Upper bound (exclusive). Must be greater than `start` (asserted in `arange`).
    end: i64,
    // Increment between consecutive values. Must be positive (asserted in `arange`).
    step: i64 = 1,
};
/// Allocates a HostBuffer with the given shape. /// Allocates a HostBuffer with the given shape.
/// The memory is initialized with increasing numbers. /// The memory is initialized with increasing numbers.
/// The caller owns the memory, and need to call `deinit()`. /// The caller owns the memory, and need to call `deinit()`.
pub fn arange(allocator: std.mem.Allocator, args: Tensor.ArangeArgs, dt: DataType) !HostBuffer { pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end }); meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
@ -137,12 +142,6 @@ pub const HostBuffer = struct {
} }
} }
/// Embeds a tensor with concrete values into an Mlir program.
/// The content is copied, so the HostBuffer can be safely `deinit`.
pub fn toStaticTensor(self: HostBuffer) Tensor {
return Tensor.staticTensor(self.shape(), self.data);
}
/// Copies this HostBuffer to the given accelerator. /// Copies this HostBuffer to the given accelerator.
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer { pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
return try Buffer.from(platform_, self); return try Buffer.from(platform_, self);
@ -164,8 +163,9 @@ pub const HostBuffer = struct {
return self._shape.dtype(); return self._shape.dtype();
} }
pub fn strides(self: HostBuffer) ?[]const i64 { pub fn strides(self: *const HostBuffer) ?[]const i64 {
return self._strides; // Pass strides per pointer otherwise we return a pointer to this stack frame.
return if (self._strides) |*strd| strd[0..self.rank()] else null;
} }
pub fn data(self: HostBuffer) []const u8 { pub fn data(self: HostBuffer) []const u8 {

View File

@ -23,7 +23,7 @@ const Tensor = @import("tensor.zig").Tensor;
const ShapeOf = @import("tensor.zig").ShapeOf; const ShapeOf = @import("tensor.zig").ShapeOf;
const Shape = @import("shape.zig").Shape; const Shape = @import("shape.zig").Shape;
const Buffer = @import("buffer.zig").Buffer; const Buffer = @import("buffer.zig").Buffer;
const Bufferized = @import("buffer.zig").Bufferized; const Bufferized = @import("tensor.zig").Bufferized;
const Tracer = @import("tools/tracer.zig").Tracer; const Tracer = @import("tools/tracer.zig").Tracer;
const log = std.log.scoped(.zml_module); const log = std.log.scoped(.zml_module);

View File

@ -16,6 +16,11 @@ const log = std.log.scoped(.zml_tensor);
const cuda = @import("nn/cuda.zig"); const cuda = @import("nn/cuda.zig");
test {
_ = cuda;
std.testing.refAllDecls(@This());
}
pub const Linear = struct { pub const Linear = struct {
weight: Tensor, weight: Tensor,
bias: ?Tensor = null, bias: ?Tensor = null,
@ -302,7 +307,7 @@ test "real/img" {
} }
test "rope" { test "rope" {
const platofrm = zml.testing.env(); const platform = zml.testing.env();
const TestRope = struct { const TestRope = struct {
fn forward(x: Tensor, opts: RopeOpts) Tensor { fn forward(x: Tensor, opts: RopeOpts) Tensor {
@ -326,9 +331,9 @@ test "rope" {
// x is made such as the interleaved and sequential reps are the same. // x is made such as the interleaved and sequential reps are the same.
// So the two implementations should give the same results. // So the two implementations should give the same results.
const x = try zml.Buffer.fromSlice(platofrm, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5); const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
const res1 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } }); const res1 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } });
const res2 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } }); const res2 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } });
try zml.testing.expectClose(res1, res2, 1e-4); try zml.testing.expectClose(res1, res2, 1e-4);
} }
@ -516,70 +521,141 @@ test nearest {
} }
} }
pub const ResizeOpts = struct { original_len: ?Tensor = null }; pub const ResizeOpts = struct {
/// scalar tensor containing the original dimension of the image.
/// It can be different from the image shape,
/// if the image has been padded.
/// This allows to compile one module that handle different input image sizes.
original_len: ?Tensor = null,
pub fn resizeBilinear(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor { /// Internal precision to do the interpolation.
/// Result will always use the same dtype than the original.
/// If not set, will use the image dtype, unless it's an integer type, in which case f32 will be used.
precision: ?zml.DataType = null,
};
pub fn resizeBilinear(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
var out = image; var out = image;
for (axes, dims) |a, d| { for (new_size.constSlice(), tags_.constSlice()) |d, t| {
const ax = image.shape().axis(t);
const child_opt: ResizeOpts = .{ const child_opt: ResizeOpts = .{
.original_len = if (opt.original_len) |o| o.choose1d(0, a) else null, .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
}; };
out = resizeLinear1d(out, a, d, child_opt); out = resizeLinear1d(out, ax, d, child_opt);
} }
return out; return out;
} }
test resizeBilinear {
    const platform = zml.testing.env();
    // Only test shapes: we never execute the module, we just build the IR
    // inside a compilation context and check the resulting tensor shapes.
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();
    // Each testcase is a tuple of:
    //   { input shape, axes to resize (with new sizes), expected output shape }.
    inline for (.{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    }) |testcase| {
        const x_shape, const resizing, const res_shape = testcase;
        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
        const y = resizeBilinear(x, resizing, .{});
        try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
        // Also ask MLIR to verify the generated op is well-formed.
        try std.testing.expect(y.value().owner().verify());
    }
}
pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor { pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32); const res_shape = image.shape().set(axis, new_len);
const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
const left = scaled.floor(); const left = scaled.floor();
const right = left.addConstant(1); const right = left.addConstant(1);
const values = image.gatherSlices1d(axis, 2, left.convert(.i32), .{ .indices_are_sorted = true }); // TODO: check that two gather isn't too bad perf wise.
const left_val, const right_val = helpers.mapTensors( // Normally we should use gatherSlices to collect the values 2 by 2,
Tensor.squeeze, // but gatherSlices messes up with the order of axes.
values.convert(.f32).chunkExact(2, axis + 1), const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
.{@as(i64, @intCast(axis + 1))}, const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
);
const left_weight = right.sub(scaled).broadcast(left_val.shape(), &.{axis});
const right_weight = scaled.sub(left).broadcast(left_val.shape(), &.{axis});
return left_val.mul(left_weight).add(right_val.mul(right_weight)).convert(image.dtype()); const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis});
const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis});
const res = left_val.mul(left_weight).add(right_val.mul(right_weight));
return res.convert(image.dtype()).withTags(image.shape().tags());
} }
/// Bicubic interpolation of the given image. /// Bicubic interpolation of the given image.
/// Warning as of May 2024 the cpu backend don't optimize this very well /// Warning as of May 2024 the cpu backend don't optimize this very well
/// and is not able to merge the weighting with the gather, /// and is not able to merge the weighting with the gather,
/// leading to 20x slow down compared to STB implementation. /// leading to 20x slow down compared to STB implementation.
pub fn resizeBicubic(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor { pub fn resizeBicubic(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
var out = image; var out = image;
for (axes, dims) |a, d| { for (new_size.constSlice(), tags_.constSlice()) |d, t| {
const ax = image.shape().axis(t);
const child_opt: ResizeOpts = .{ const child_opt: ResizeOpts = .{
.original_len = if (opt.original_len) |o| o.choose1d(0, a) else null, .original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
}; };
out = resizeCubic1d(out, a, d, child_opt); out = resizeCubic1d(out, ax, d, child_opt);
} }
return out; return out;
} }
test resizeBicubic {
    const platform = zml.testing.env();
    // Only test shapes: we never execute the module, we just build the IR
    // inside a compilation context and check the resulting tensor shapes.
    var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
    defer comp.deinit();
    comp.activate();
    defer comp.deactivate();
    // Each testcase is a tuple of:
    //   { input shape, axes to resize (with new sizes), expected output shape }.
    inline for (.{
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
        .{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
        .{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
    }) |testcase| {
        const x_shape, const resizing, const res_shape = testcase;
        const x = Tensor.constant(x_shape, .{ .f16 = 0 });
        const y = resizeBicubic(x, resizing, .{});
        try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
        // Also ask MLIR to verify the generated op is well-formed.
        try std.testing.expect(y.value().owner().verify());
    }
}
pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor { pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
// Extract neighboring pixels from the image. // Extract neighboring pixels from the image.
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32); const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len)); const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio);
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
const t = scaled.sub(scaled.floor()); const t = scaled.sub(scaled.floor());
const pos = Tensor.stack(&.{ const pos = Tensor.stack(&.{
Tensor.scalar(1, .f32).broadcast(t.shape(), &.{}), Tensor.constant(t.shape(), dtype.one()),
t, t,
t.mul(t), t.mul(t),
t.pow(Tensor.scalar(3, .f32)), t.pow(Tensor.scalar(3, dtype)),
}, -1, .features); }, .last, ._interpolated);
std.debug.assert(pos.dim(0) == new_len); std.debug.assert(pos.dim(0) == new_len);
std.debug.assert(pos.dim(1) == 4); std.debug.assert(pos.dim(1) == 4);
const context = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32)); const neighbors = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32));
const values = image.gatherSlices1d(axis, 4, context, .{ .indices_are_sorted = true });
const values = image.renameAxis(axis, ._neighbors).gatherSlices(
Shape.init(.{ ._neighbors = 4 }, image.dtype()),
neighbors.appendAxes(.{.coord}),
.{ .indices_are_sorted = true },
).convert(dtype);
const weights_: [4][4]f32 = .{ const weights_: [4][4]f32 = .{
.{ 0, 1, 0, 0 }, .{ 0, 1, 0, 0 },
@ -587,22 +663,24 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
.{ 1, -2.5, 2, -0.5 }, .{ 1, -2.5, 2, -0.5 },
.{ -0.5, 1.5, -1.5, 0.5 }, .{ -0.5, 1.5, -1.5, 0.5 },
}; };
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)); const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });
// actually do the interpolation. // actually do the interpolation.
// Note: ideally this matmul should be inlined with the gather, but that's currently not the case. // Note: ideally this matmul should be inlined with the gather, but that's currently not the case.
var res = values.convert(.f32).dotGeneral(weights, &.{.{ axis + 1, 1 }}, &.{}); // TODO: not being able to use dot here is a bit annoying.
res = pos.dotGeneral(res, &.{.{ 1, image.rank() }}, &.{.{ 0, axis }}); var res = values.dotGeneral(weights, &.{.{ values.axis(._neighbors), weights.axis(._neighbors) }}, &.{});
res = pos.dotGeneral(res, &.{.{ pos.axis(._interpolated), res.axis(._interpolated) }}, &.{.{ 0, 0 }});
// the current axis is outputted in first position because it's a batching dim, put it back in place. // the current axis is outputted in first position because it's a batching dim, put it back in place.
// if (axis != 0) if (axis != 0) {
// res = res.transpose(Shape.range(image.rank()).swap(0, axis).dims()); res = res.swapAxes(0, axis);
}
// verify the shape // verify the shape
const res_shape = image.shape().set(axis, new_len); const res_shape = image.shape().set(axis, new_len);
// log.debug("resizeCubic1d: ({}, {}, {}, {}) -> {}", .{ image, axis, new_len, opt, res }); // log.debug("resizeCubic1d: ({}, {}, {}, {}) -> {}", .{ image, axis, new_len, opt, res });
std.debug.assert(std.mem.eql(i64, res_shape.dims(), res.dims())); std.debug.assert(std.mem.eql(i64, res_shape.dims(), res.dims()));
return res.convert(image.dtype()); return res.convert(image.dtype()).withTags(image.shape());
} }
/// Return causal attention masks for the given shape. /// Return causal attention masks for the given shape.
@ -927,7 +1005,3 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens }); // log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
return .{ next_tokens, next_rng }; return .{ next_tokens, next_rng };
} }
test {
_ = cuda;
}

View File

@ -9,7 +9,7 @@ const EnumLiteral = @TypeOf(.enum_literal);
const log = std.log.scoped(.shape); const log = std.log.scoped(.shape);
test { test {
std.testing.refAllDecls(@This()); std.testing.refAllDecls(Shape);
} }
/// Represent the shape of a tensor. /// Represent the shape of a tensor.

View File

@ -23,12 +23,12 @@ const dialect = struct {
const stablehlo = @import("mlir/dialects").stablehlo; const stablehlo = @import("mlir/dialects").stablehlo;
}; };
test {
std.testing.refAllDecls(@This());
}
const scoped_log = std.log.scoped(.zml_tensor); const scoped_log = std.log.scoped(.zml_tensor);
test {
std.testing.refAllDecls(Tensor);
}
/// Represents an abstract Tensor object, which can be the input, /// Represents an abstract Tensor object, which can be the input,
/// output, weight or activations of a neural network. /// output, weight or activations of a neural network.
/// Tensor are abstract in the sense they only represent a computation, /// Tensor are abstract in the sense they only represent a computation,
@ -166,6 +166,12 @@ pub const Tensor = struct {
return res; return res;
} }
/// Returns a copy of `self` where axis `ax` now carries the tag `name`.
/// The data and dimensions are untouched; only the axis tag changes.
pub fn renameAxis(self: Tensor, ax: i8, name: EnumLiteral) Tensor {
    const target = self.axis(ax);
    var renamed = self;
    renamed._shape._tags.set(target, @tagName(name).ptr);
    return renamed;
}
/// Returns the mlir.Value associated with the Tensor. /// Returns the mlir.Value associated with the Tensor.
/// ///
/// This will fail if used outside of a compilation context. /// This will fail if used outside of a compilation context.
@ -462,7 +468,7 @@ pub const Tensor = struct {
}; };
} }
pub fn init(platform: Platform, seed: u128) !Buffer.From(Rng) { pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
const bits: [2]u64 = @bitCast(seed); const bits: [2]u64 = @bitCast(seed);
return .{ return .{
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits), ._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
@ -1071,7 +1077,12 @@ pub const Tensor = struct {
/// In this version batching dimensions need to be explicitly specified. /// In this version batching dimensions need to be explicitly specified.
/// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes. /// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes.
/// Where "result axes" are non-contracting, non-batching axes of each input tensor. /// Where "result axes" are non-contracting, non-batching axes of each input tensor.
pub fn dotGeneral(lhs: Tensor, rhs: Tensor, contracting_axes: []const [2]i8, batching_axes: []const [2]i8) Tensor { pub fn dotGeneral(
lhs: Tensor,
rhs: Tensor,
contracting_axes: []const [2]i8,
batching_axes: []const [2]i8,
) Tensor {
meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() }); meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
const Axes = std.BoundedArray(i64, MAX_RANK); const Axes = std.BoundedArray(i64, MAX_RANK);
@ -1273,6 +1284,17 @@ pub const Tensor = struct {
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));
} }
pub fn swapAxes(self: Tensor, a: anytype, b: anytype) Tensor {
if (self.axis(a) == self.axis(b)) return self;
var perm: Shape.AxesArray = .{};
for (0..self.rank()) |i| {
perm.appendAssumeCapacity(@intCast(i));
}
perm.set(self.axis(a), self.axis(b));
perm.set(self.axis(b), self.axis(a));
return self.transpose(perm.constSlice());
}
/// Returns a Tensor with the given axis unflattened. /// Returns a Tensor with the given axis unflattened.
/// ///
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3) /// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
@ -1577,11 +1599,7 @@ pub const Tensor = struct {
return _result(self._shape, expm1_op.result(0)); return _result(self._shape, expm1_op.result(0));
} }
pub const ArangeArgs = struct { pub const ArangeArgs = HostBuffer.ArangeArgs;
start: i64 = 0,
end: i64,
step: i64 = 1,
};
/// Returns a Tensor containing evenly spaced values within a given interval. /// Returns a Tensor containing evenly spaced values within a given interval.
pub fn arange(args: ArangeArgs, dt: DataType) Tensor { pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
@ -1871,14 +1889,6 @@ pub const Tensor = struct {
return _result(self._shape, reverse_op.result(0)); return _result(self._shape, reverse_op.result(0));
} }
/// Returns a Tensor with the given axes reversed.
pub fn reverseMany(self: Tensor, axes_: []const i64) Tensor {
const actual_axes = self._shape.axes(axes_).constSlice();
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({d})", .{actual_axes});
const reverse_op = dialect.stablehlo.reverseMany(self.getContext().mlirCtx(), self.value(), @ptrCast(actual_axes), loc);
return _result(self._shape, reverse_op.result(0));
}
pub const GatherOpts = struct { indices_are_sorted: bool = false }; pub const GatherOpts = struct { indices_are_sorted: bool = false };
/// For each coordinate in `indices`, /// For each coordinate in `indices`,
@ -2448,7 +2458,7 @@ pub const Tensor = struct {
return result; return result;
} }
pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: i64) [n_chunks]Tensor { pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor {
const a = self.axis(axis_); const a = self.axis(axis_);
const length = self.dim(a); const length = self.dim(a);
_ = @divExact(length, n_chunks); _ = @divExact(length, n_chunks);
@ -3484,3 +3494,9 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va
} }
return res; return res;
} }
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized(comptime T: type) type {
const M = meta.MapType(Tensor, Buffer);
return M.map(T);
}

View File

@ -4,7 +4,7 @@
//! //!
pub const Buffer = @import("buffer.zig").Buffer; pub const Buffer = @import("buffer.zig").Buffer;
pub const Bufferized = @import("buffer.zig").Bufferized; pub const Bufferized = @import("tensor.zig").Bufferized;
pub const CompilationOptions = @import("platform.zig").CompilationOptions; pub const CompilationOptions = @import("platform.zig").CompilationOptions;
pub const Context = @import("context.zig").Context; pub const Context = @import("context.zig").Context;
pub const Data = @import("dtype.zig").Data; pub const Data = @import("dtype.zig").Data;