zml/nn: fix resize implementations (resizeBilinear and resizeBicubic) and expand refAllDecl usage; all tests pass
This commit is contained in:
parent
5e1688cbfd
commit
7dcd8b516c
@ -384,18 +384,6 @@ pub fn reverse(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn reverseMany(ctx: mlir.Context, operand: mlir.Value, dimensions: []const i64, location: mlir.Location) mlir.Operation {
|
||||
const result_type = operand.getType();
|
||||
return mlir.Operation.make(ctx, "stablehlo.reverse", .{
|
||||
.operands = &.{operand},
|
||||
.results = &.{result_type},
|
||||
.attributes = &.{
|
||||
.{ "dimensions", mlir.DenseArrayAttribute(.i64).init(ctx, dimensions).as(mlir.Attribute).? },
|
||||
},
|
||||
.location = location,
|
||||
});
|
||||
}
|
||||
|
||||
pub fn compare(ctx: mlir.Context, lhs: mlir.Value, rhs: mlir.Value, comparison_direction: ComparisonDirection, compare_type: CompareType, location: mlir.Location) mlir.Operation {
|
||||
return mlir.Operation.make(ctx, "stablehlo.compare", .{
|
||||
.operands = &.{ lhs, rhs },
|
||||
|
||||
17
zml/aio.zig
17
zml/aio.zig
@ -18,6 +18,19 @@ pub const log = std.log.scoped(.zml_aio);
|
||||
pub const Value = @import("aio/value.zig").Value;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(gguf);
|
||||
// TODO(@cryptodeal)
|
||||
// std.testing.refAllDecls(nemo);
|
||||
std.testing.refAllDecls(safetensors);
|
||||
std.testing.refAllDecls(sentencepiece);
|
||||
std.testing.refAllDecls(tinyllama);
|
||||
std.testing.refAllDecls(torch);
|
||||
// TODO(@cryptodeal)
|
||||
// std.testing.refAllDecls(yaml);
|
||||
}
|
||||
|
||||
/// Detects the format of the model file (base on filename) and open it.
|
||||
pub fn detectFormatAndOpen(allocator: std.mem.Allocator, model_path: []const u8) !BufferStore {
|
||||
return if (std.mem.endsWith(u8, model_path, ".safetensors"))
|
||||
@ -585,7 +598,3 @@ fn visitStructAndLoadBuffer(allocator: std.mem.Allocator, prefix_builder: *Prefi
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
@ -30,9 +30,9 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
|
||||
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
|
||||
switch (val) {
|
||||
.int => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
|
||||
.float => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
|
||||
.string => |v| try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
|
||||
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
|
||||
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
|
||||
.string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
|
||||
.list => |v| switch (validSlice(v)) {
|
||||
true => {
|
||||
if (v.len == 0) return;
|
||||
@ -41,13 +41,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
|
||||
const values = try allocator.alloc(i64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.int;
|
||||
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = utils.toVoidSlice(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } });
|
||||
},
|
||||
.float => {
|
||||
const values = try allocator.alloc(f64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.float;
|
||||
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = utils.toVoidSlice(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } });
|
||||
},
|
||||
.string => {
|
||||
const values = try allocator.alloc([]const u8, v.len);
|
||||
@ -55,7 +55,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
|
||||
for (v, 0..) |item, i| {
|
||||
values[i] = try allocator.dupe(u8, item.string);
|
||||
}
|
||||
try store.metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = utils.toVoidSlice(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } });
|
||||
},
|
||||
.list => unreachable,
|
||||
else => {},
|
||||
|
||||
@ -4,15 +4,17 @@ const testing = std.testing;
|
||||
const meta = @import("meta.zig");
|
||||
const pjrt = @import("pjrt");
|
||||
const pjrtx = @import("pjrtx.zig");
|
||||
const platform = @import("platform.zig");
|
||||
|
||||
const Context = @import("context.zig").Context;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const Target = @import("platform.zig").Target;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(Buffer);
|
||||
}
|
||||
|
||||
/// Buffer is a multi-dimension array, whose memory is allocated on an accelerator.
|
||||
///
|
||||
@ -22,54 +24,74 @@ const Target = @import("platform.zig").Target;
|
||||
pub const Buffer = struct {
|
||||
_shape: Shape,
|
||||
_shards: Shape = undefined,
|
||||
_platform: platform.Platform,
|
||||
_platform: Platform,
|
||||
_data: *pjrtx.Buffer,
|
||||
|
||||
/// Copies the content of the given buffer from host memory to the accelerator memory.
|
||||
pub fn from(platform_: platform.Platform, buf: HostBuffer) !Buffer {
|
||||
const pjrt_buffer = try platform_.pjrt_client.bufferFromHostBuffer(platform_.pjrt_api, .{
|
||||
pub fn from(platform: Platform, buf: HostBuffer) !Buffer {
|
||||
const pjrt_buffer = try platform.pjrt_client.bufferFromHostBuffer(platform.pjrt_api, .{
|
||||
.data = buf.data,
|
||||
.buffer_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
|
||||
.dims = buf.shape().dims(),
|
||||
.byte_strides = null,
|
||||
.device = platform_.getDevices()[0],
|
||||
.byte_strides = buf.strides(),
|
||||
.device = platform.getDevices()[0],
|
||||
.host_buffer_semantics = .ImmutableUntilTransferCompletes,
|
||||
});
|
||||
return .{
|
||||
._platform = platform_,
|
||||
._platform = platform,
|
||||
._shape = buf.shape(),
|
||||
._data = pjrt_buffer,
|
||||
};
|
||||
}
|
||||
|
||||
/// Wraps a pre-exisiting `pjrt.Buffer` into a `zml.Buffer`.
|
||||
pub fn fromPjrtBuffer(platform_: platform.Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
|
||||
pub fn fromPjrtBuffer(platform: Platform, pjrt_buffer: *pjrtx.Buffer) Buffer {
|
||||
return .{
|
||||
._platform = platform_,
|
||||
._shape = _shapeFromPjrtBuffer(platform_, pjrt_buffer),
|
||||
._platform = platform,
|
||||
._shape = _shapeFromPjrtBuffer(platform, pjrt_buffer),
|
||||
._data = pjrt_buffer,
|
||||
};
|
||||
}
|
||||
|
||||
/// Copies the given Zig slice to the accelerator memory and
|
||||
/// return a Buffer with the given dimensions.
|
||||
pub fn fromSlice(platform_: platform.Platform, dimz: anytype, s: anytype) !Buffer {
|
||||
pub fn fromSlice(platform: Platform, dimz: anytype, s: anytype) !Buffer {
|
||||
const sh = Shape.init(dimz, DataType.fromSliceElementType(s));
|
||||
return from(platform_, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
|
||||
return from(platform, HostBuffer.fromBytes(sh, std.mem.sliceAsBytes(s)));
|
||||
}
|
||||
|
||||
/// Copies the given Zig array to the accelerator memory and
|
||||
/// return a Buffer using the array shape.
|
||||
pub fn fromArray(platform_: platform.Platform, arr: anytype) !Buffer {
|
||||
pub fn fromArray(platform: Platform, arr: anytype) !Buffer {
|
||||
const host_buffer = HostBuffer.fromArray(&arr);
|
||||
return try host_buffer.toDevice(platform_);
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
|
||||
/// Creates a Buffer with a single element.
|
||||
pub fn scalar(platform_: platform.Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||
pub fn scalar(platform: Platform, val: anytype, dtype_: DataType) !Buffer {
|
||||
const x = dtype_.constant(val);
|
||||
const host_buffer = HostBuffer.fromBytes(Shape.init(.{}, dtype_), x.constSlice());
|
||||
return try host_buffer.toDevice(platform_);
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
|
||||
/// Creates a Buffer with a single element repeated manytime.
|
||||
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
|
||||
const x = shape_.dtype().constant(val);
|
||||
const host_buffer: HostBuffer = .{
|
||||
._shape = shape_,
|
||||
._strides = [1]i64{0} ** Shape.MAX_RANK,
|
||||
.data = x.constSlice(),
|
||||
};
|
||||
return try from(platform, host_buffer);
|
||||
}
|
||||
|
||||
test constant {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const x = try constant(platform, Shape.init(.{ 4, 3, 2 }, .u16), 42);
|
||||
const y = try x.getValue([4 * 3 * 2]u16);
|
||||
try std.testing.expectEqual([_]u16{42} ** (4 * 3 * 2), y);
|
||||
}
|
||||
|
||||
/// Creates a Buffer as a view of memory visible from the device,
|
||||
@ -79,7 +101,7 @@ pub const Buffer = struct {
|
||||
/// Be careful though, as it requires a specific alignment.
|
||||
/// Also note that it might not work on all platforms,
|
||||
/// could lead to crashes and is considerably slower.
|
||||
pub fn asViewOf(platform_: platform.Platform, buf: HostBuffer) !Buffer {
|
||||
pub fn asViewOf(platform: Platform, buf: HostBuffer) !Buffer {
|
||||
const minor_to_major: [Shape.MAX_RANK]i64 = comptime blk: {
|
||||
var res: [Shape.MAX_RANK]i64 = undefined;
|
||||
for (0..Shape.MAX_RANK) |i| {
|
||||
@ -88,11 +110,11 @@ pub const Buffer = struct {
|
||||
break :blk res;
|
||||
};
|
||||
|
||||
const pjrt_buffer = try platform_.pjrt_client.createViewOfDeviceBuffer(platform_.pjrt_api, .{
|
||||
const pjrt_buffer = try platform.pjrt_client.createViewOfDeviceBuffer(platform.pjrt_api, .{
|
||||
.data = buf.data,
|
||||
.element_type = pjrtx.Buffer.BufferTypeFromDType(buf.shape().dtype()),
|
||||
.dims = buf.shape().dims(),
|
||||
.device = platform_.getDevices()[0],
|
||||
.device = platform.getDevices()[0],
|
||||
.layout = .{
|
||||
.Tiled = .{
|
||||
.minor_to_major = minor_to_major[Shape.MAX_RANK - buf.shape().rank() ..],
|
||||
@ -103,7 +125,7 @@ pub const Buffer = struct {
|
||||
});
|
||||
|
||||
return .{
|
||||
._platform = platform_,
|
||||
._platform = platform,
|
||||
._shape = buf.shape(),
|
||||
._data = pjrt_buffer,
|
||||
};
|
||||
@ -177,11 +199,11 @@ pub const Buffer = struct {
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
try writer.print("Tensor({_})", .{self._shape});
|
||||
try writer.print("Buffer({_})", .{self._shape});
|
||||
}
|
||||
|
||||
fn _shapeFromPjrtBuffer(platform_: platform.Platform, buf: *pjrtx.Buffer) Shape {
|
||||
const dt: DataType = switch (buf.getElementType(platform_.pjrt_api)) {
|
||||
fn _shapeFromPjrtBuffer(platform: Platform, buf: *pjrtx.Buffer) Shape {
|
||||
const dt: DataType = switch (buf.getElementType(platform.pjrt_api)) {
|
||||
// Please keep the list exhaustive and in the same order than in DataType.
|
||||
.PRED => .bool,
|
||||
.F8E4M3B11FNUZ => .f8e4m3b11fnuz,
|
||||
@ -208,14 +230,6 @@ pub const Buffer = struct {
|
||||
.INVALID => @panic("Can't handle INVALID Pjrt buffers."),
|
||||
};
|
||||
|
||||
return Shape.init(buf.getDimensions(platform_.pjrt_api), dt);
|
||||
return Shape.init(buf.getDimensions(platform.pjrt_api), dt);
|
||||
}
|
||||
|
||||
pub const From = meta.MapType(Tensor, Buffer).map;
|
||||
};
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
const M = meta.MapType(Tensor, Buffer);
|
||||
return M.map(T);
|
||||
}
|
||||
|
||||
@ -1,15 +1,14 @@
|
||||
const std = @import("std");
|
||||
|
||||
const meta = @import("meta.zig");
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Tensor = @import("tensor.zig").Tensor;
|
||||
const Data = @import("dtype.zig").Data;
|
||||
const DataType = @import("dtype.zig").DataType;
|
||||
const Platform = @import("platform.zig").Platform;
|
||||
const meta = @import("meta.zig");
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(HostBuffer);
|
||||
}
|
||||
|
||||
/// Represents a tensor with associated data allocated by user code.
|
||||
@ -99,10 +98,16 @@ pub const HostBuffer = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub const ArangeArgs = struct {
|
||||
start: i64 = 0,
|
||||
end: i64,
|
||||
step: i64 = 1,
|
||||
};
|
||||
|
||||
/// Allocates a HostBuffer with the given shape.
|
||||
/// The memory is initialized with increasing numbers.
|
||||
/// The caller owns the memory, and need to call `deinit()`.
|
||||
pub fn arange(allocator: std.mem.Allocator, args: Tensor.ArangeArgs, dt: DataType) !HostBuffer {
|
||||
pub fn arange(allocator: std.mem.Allocator, args: ArangeArgs, dt: DataType) !HostBuffer {
|
||||
meta.assert(args.start < args.end, "arange expects 'args.start' to be less than 'args.end', got {} and {}", .{ args.start, args.end });
|
||||
meta.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||
|
||||
@ -137,12 +142,6 @@ pub const HostBuffer = struct {
|
||||
}
|
||||
}
|
||||
|
||||
/// Embeds a tensor with concrete values into an Mlir program.
|
||||
/// The content is copied, so the HostBuffer can be safely `deinit`.
|
||||
pub fn toStaticTensor(self: HostBuffer) Tensor {
|
||||
return Tensor.staticTensor(self.shape(), self.data);
|
||||
}
|
||||
|
||||
/// Copies this HostBuffer to the given accelerator.
|
||||
pub fn toDevice(self: HostBuffer, platform_: Platform) !Buffer {
|
||||
return try Buffer.from(platform_, self);
|
||||
@ -164,8 +163,9 @@ pub const HostBuffer = struct {
|
||||
return self._shape.dtype();
|
||||
}
|
||||
|
||||
pub fn strides(self: HostBuffer) ?[]const i64 {
|
||||
return self._strides;
|
||||
pub fn strides(self: *const HostBuffer) ?[]const i64 {
|
||||
// Pass strides per pointer otherwise we return a pointer to this stack frame.
|
||||
return if (self._strides) |*strd| strd[0..self.rank()] else null;
|
||||
}
|
||||
|
||||
pub fn data(self: HostBuffer) []const u8 {
|
||||
|
||||
@ -23,7 +23,7 @@ const Tensor = @import("tensor.zig").Tensor;
|
||||
const ShapeOf = @import("tensor.zig").ShapeOf;
|
||||
const Shape = @import("shape.zig").Shape;
|
||||
const Buffer = @import("buffer.zig").Buffer;
|
||||
const Bufferized = @import("buffer.zig").Bufferized;
|
||||
const Bufferized = @import("tensor.zig").Bufferized;
|
||||
const Tracer = @import("tools/tracer.zig").Tracer;
|
||||
|
||||
const log = std.log.scoped(.zml_module);
|
||||
|
||||
160
zml/nn.zig
160
zml/nn.zig
@ -16,6 +16,11 @@ const log = std.log.scoped(.zml_tensor);
|
||||
|
||||
const cuda = @import("nn/cuda.zig");
|
||||
|
||||
test {
|
||||
_ = cuda;
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
pub const Linear = struct {
|
||||
weight: Tensor,
|
||||
bias: ?Tensor = null,
|
||||
@ -302,7 +307,7 @@ test "real/img" {
|
||||
}
|
||||
|
||||
test "rope" {
|
||||
const platofrm = zml.testing.env();
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const TestRope = struct {
|
||||
fn forward(x: Tensor, opts: RopeOpts) Tensor {
|
||||
@ -326,9 +331,9 @@ test "rope" {
|
||||
|
||||
// x is made such as the interleaved and sequential reps are the same.
|
||||
// So the two implementations should give the same results.
|
||||
const x = try zml.Buffer.fromSlice(platofrm, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
||||
const res1 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } });
|
||||
const res2 = try zml.testing.compileAndCall(platofrm, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } });
|
||||
const x = try zml.Buffer.fromSlice(platform, .{ 1, 5, 4 }, &[_]f32{ 1.0, 0.1, -1.0, -0.5 } ** 5);
|
||||
const res1 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .interleaved } });
|
||||
const res2 = try zml.testing.compileAndCall(platform, TestRope.forward, .{ x, RopeOpts{ .impl = .sequential } });
|
||||
|
||||
try zml.testing.expectClose(res1, res2, 1e-4);
|
||||
}
|
||||
@ -516,70 +521,141 @@ test nearest {
|
||||
}
|
||||
}
|
||||
|
||||
pub const ResizeOpts = struct { original_len: ?Tensor = null };
|
||||
pub const ResizeOpts = struct {
|
||||
/// scalar tensor containing the original dimension of the image.
|
||||
/// It can be different from the image shape,
|
||||
/// if the image has been padded.
|
||||
/// This allows to compile one module that handle different input image sizes.
|
||||
original_len: ?Tensor = null,
|
||||
|
||||
pub fn resizeBilinear(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor {
|
||||
/// Internal precision to do the interpolation.
|
||||
/// Result will always use the same dtype than the original.
|
||||
/// If not set, will use the image dtype, unless it's an integer type, in which case f32 will be used.
|
||||
precision: ?zml.DataType = null,
|
||||
};
|
||||
|
||||
pub fn resizeBilinear(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
|
||||
const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
|
||||
var out = image;
|
||||
for (axes, dims) |a, d| {
|
||||
for (new_size.constSlice(), tags_.constSlice()) |d, t| {
|
||||
const ax = image.shape().axis(t);
|
||||
const child_opt: ResizeOpts = .{
|
||||
.original_len = if (opt.original_len) |o| o.choose1d(0, a) else null,
|
||||
.original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
|
||||
};
|
||||
out = resizeLinear1d(out, a, d, child_opt);
|
||||
out = resizeLinear1d(out, ax, d, child_opt);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
test resizeBilinear {
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||
defer comp.deinit();
|
||||
comp.activate();
|
||||
defer comp.deactivate();
|
||||
|
||||
inline for (.{
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
|
||||
}) |testcase| {
|
||||
const x_shape, const resizing, const res_shape = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
const y = resizeBilinear(x, resizing, .{});
|
||||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||||
try std.testing.expect(y.value().owner().verify());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resizeLinear1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
|
||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32);
|
||||
const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len));
|
||||
const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio);
|
||||
const res_shape = image.shape().set(axis, new_len);
|
||||
|
||||
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
||||
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
|
||||
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
||||
const left = scaled.floor();
|
||||
const right = left.addConstant(1);
|
||||
|
||||
const values = image.gatherSlices1d(axis, 2, left.convert(.i32), .{ .indices_are_sorted = true });
|
||||
const left_val, const right_val = helpers.mapTensors(
|
||||
Tensor.squeeze,
|
||||
values.convert(.f32).chunkExact(2, axis + 1),
|
||||
.{@as(i64, @intCast(axis + 1))},
|
||||
);
|
||||
const left_weight = right.sub(scaled).broadcast(left_val.shape(), &.{axis});
|
||||
const right_weight = scaled.sub(left).broadcast(left_val.shape(), &.{axis});
|
||||
// TODO: check that two gather isn't too bad perf wise.
|
||||
// Normally we should use gatherSlices to collect the values 2 by 2,
|
||||
// but gatherSlices messes up with the order of axes.
|
||||
const left_val = image.gatherValues(axis, left.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
|
||||
const right_val = image.gatherValues(axis, right.convert(.i32), .{ .indices_are_sorted = true }).convert(dtype);
|
||||
|
||||
return left_val.mul(left_weight).add(right_val.mul(right_weight)).convert(image.dtype());
|
||||
const left_weight = right.sub(scaled).broadcast(res_shape, &.{axis});
|
||||
const right_weight = scaled.sub(left).broadcast(res_shape, &.{axis});
|
||||
|
||||
const res = left_val.mul(left_weight).add(right_val.mul(right_weight));
|
||||
return res.convert(image.dtype()).withTags(image.shape().tags());
|
||||
}
|
||||
|
||||
/// Bicubic interpolation of the given image.
|
||||
/// Warning as of May 2024 the cpu backend don't optimize this very well
|
||||
/// and is not able to merge the weighting with the gather,
|
||||
/// leading to 20x slow down compared to STB implementation.
|
||||
pub fn resizeBicubic(image: Tensor, axes: []const i8, dims: []const u63, opt: ResizeOpts) Tensor {
|
||||
pub fn resizeBicubic(image: Tensor, resized_axes: anytype, opt: ResizeOpts) Tensor {
|
||||
const new_size, const tags_ = Shape.parseStruct(u63, resized_axes);
|
||||
var out = image;
|
||||
for (axes, dims) |a, d| {
|
||||
for (new_size.constSlice(), tags_.constSlice()) |d, t| {
|
||||
const ax = image.shape().axis(t);
|
||||
const child_opt: ResizeOpts = .{
|
||||
.original_len = if (opt.original_len) |o| o.choose1d(0, a) else null,
|
||||
.original_len = if (opt.original_len) |o| o.choose1d(0, ax) else null,
|
||||
};
|
||||
out = resizeCubic1d(out, a, d, child_opt);
|
||||
out = resizeCubic1d(out, ax, d, child_opt);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
test resizeBicubic {
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
||||
defer comp.deinit();
|
||||
comp.activate();
|
||||
defer comp.deactivate();
|
||||
|
||||
inline for (.{
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .a = 20 }, .{ .a = 20, .b = 10 } },
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .b = 5 }, .{ .a = 10, .b = 5 } },
|
||||
.{ .{ .a = 10, .b = 10 }, .{ .a = 20, .b = 5 }, .{ .a = 20, .b = 5 } },
|
||||
}) |testcase| {
|
||||
const x_shape, const resizing, const res_shape = testcase;
|
||||
const x = Tensor.constant(x_shape, .{ .f16 = 0 });
|
||||
const y = resizeBicubic(x, resizing, .{});
|
||||
try zml.testing.expectEqualShapes(Shape.init(res_shape, .f16), y.shape());
|
||||
try std.testing.expect(y.value().owner().verify());
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Tensor {
|
||||
// Extract neighboring pixels from the image.
|
||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), .f32);
|
||||
const ratio = og_len.convert(.f32).scale(meta.divFloat(f32, 1, new_len));
|
||||
const scaled = Tensor.arange(.{ .end = new_len }, .f32).mul(ratio);
|
||||
const dtype = opt.precision orelse if (image.dtype().class() == .integer) .f32 else image.dtype();
|
||||
const og_len = opt.original_len orelse Tensor.scalar(image.dim(axis), dtype);
|
||||
|
||||
const ratio = og_len.convert(dtype).scale(meta.divFloat(f32, 1, new_len));
|
||||
const scaled = Tensor.arange(.{ .end = new_len }, dtype).mul(ratio);
|
||||
const t = scaled.sub(scaled.floor());
|
||||
const pos = Tensor.stack(&.{
|
||||
Tensor.scalar(1, .f32).broadcast(t.shape(), &.{}),
|
||||
Tensor.constant(t.shape(), dtype.one()),
|
||||
t,
|
||||
t.mul(t),
|
||||
t.pow(Tensor.scalar(3, .f32)),
|
||||
}, -1, .features);
|
||||
t.pow(Tensor.scalar(3, dtype)),
|
||||
}, .last, ._interpolated);
|
||||
|
||||
std.debug.assert(pos.dim(0) == new_len);
|
||||
std.debug.assert(pos.dim(1) == 4);
|
||||
|
||||
const context = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32));
|
||||
const values = image.gatherSlices1d(axis, 4, context, .{ .indices_are_sorted = true });
|
||||
const neighbors = scaled.floor().addConstant(-1).convert(.i32).maximum(Tensor.scalar(0, .i32));
|
||||
|
||||
const values = image.renameAxis(axis, ._neighbors).gatherSlices(
|
||||
Shape.init(.{ ._neighbors = 4 }, image.dtype()),
|
||||
neighbors.appendAxes(.{.coord}),
|
||||
.{ .indices_are_sorted = true },
|
||||
).convert(dtype);
|
||||
|
||||
const weights_: [4][4]f32 = .{
|
||||
.{ 0, 1, 0, 0 },
|
||||
@ -587,22 +663,24 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
|
||||
.{ 1, -2.5, 2, -0.5 },
|
||||
.{ -0.5, 1.5, -1.5, 0.5 },
|
||||
};
|
||||
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_));
|
||||
const weights = zml.Tensor.constantTensor(zml.HostBuffer.fromArray(&weights_)).convert(dtype).withTags(.{ ._interpolated, ._neighbors });
|
||||
|
||||
// actually do the interpolation.
|
||||
// Note: ideally this matmul should be inlined with the gather, but that's currently not the case.
|
||||
var res = values.convert(.f32).dotGeneral(weights, &.{.{ axis + 1, 1 }}, &.{});
|
||||
res = pos.dotGeneral(res, &.{.{ 1, image.rank() }}, &.{.{ 0, axis }});
|
||||
// TODO: not being able to use dot here is a bit annoying.
|
||||
var res = values.dotGeneral(weights, &.{.{ values.axis(._neighbors), weights.axis(._neighbors) }}, &.{});
|
||||
res = pos.dotGeneral(res, &.{.{ pos.axis(._interpolated), res.axis(._interpolated) }}, &.{.{ 0, 0 }});
|
||||
|
||||
// the current axis is outputted in first position because it's a batching dim, put it back in place.
|
||||
// if (axis != 0)
|
||||
// res = res.transpose(Shape.range(image.rank()).swap(0, axis).dims());
|
||||
if (axis != 0) {
|
||||
res = res.swapAxes(0, axis);
|
||||
}
|
||||
|
||||
// verify the shape
|
||||
const res_shape = image.shape().set(axis, new_len);
|
||||
// log.debug("resizeCubic1d: ({}, {}, {}, {}) -> {}", .{ image, axis, new_len, opt, res });
|
||||
std.debug.assert(std.mem.eql(i64, res_shape.dims(), res.dims()));
|
||||
return res.convert(image.dtype());
|
||||
return res.convert(image.dtype()).withTags(image.shape());
|
||||
}
|
||||
|
||||
/// Return causal attention masks for the given shape.
|
||||
@ -927,7 +1005,3 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
||||
// log.debug("sampleTokens({}) -> {} -> {} -> {}", .{ activations, topk.indices, topk_idx, next_tokens });
|
||||
return .{ next_tokens, next_rng };
|
||||
}
|
||||
|
||||
test {
|
||||
_ = cuda;
|
||||
}
|
||||
|
||||
@ -9,7 +9,7 @@ const EnumLiteral = @TypeOf(.enum_literal);
|
||||
const log = std.log.scoped(.shape);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(Shape);
|
||||
}
|
||||
|
||||
/// Represent the shape of a tensor.
|
||||
|
||||
@ -23,12 +23,12 @@ const dialect = struct {
|
||||
const stablehlo = @import("mlir/dialects").stablehlo;
|
||||
};
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
const scoped_log = std.log.scoped(.zml_tensor);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(Tensor);
|
||||
}
|
||||
|
||||
/// Represents an abstract Tensor object, which can be the input,
|
||||
/// output, weight or activations of a neural network.
|
||||
/// Tensor are abstract in the sense they only represent a computation,
|
||||
@ -166,6 +166,12 @@ pub const Tensor = struct {
|
||||
return res;
|
||||
}
|
||||
|
||||
pub fn renameAxis(self: Tensor, ax: i8, name: EnumLiteral) Tensor {
|
||||
var res = self;
|
||||
res._shape._tags.set(self.axis(ax), @tagName(name).ptr);
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Returns the mlir.Value associated with the Tensor.
|
||||
///
|
||||
/// This will fail if used outside of a compilation context.
|
||||
@ -462,7 +468,7 @@ pub const Tensor = struct {
|
||||
};
|
||||
}
|
||||
|
||||
pub fn init(platform: Platform, seed: u128) !Buffer.From(Rng) {
|
||||
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
|
||||
const bits: [2]u64 = @bitCast(seed);
|
||||
return .{
|
||||
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
|
||||
@ -1071,7 +1077,12 @@ pub const Tensor = struct {
|
||||
/// In this version batching dimensions need to be explicitly specified.
|
||||
/// The result shape is made of (batching_axes ++ lhs_result_axes ++ rhs_result_axes.
|
||||
/// Where "result axes" are non-contracting, non-batching axes of each input tensor.
|
||||
pub fn dotGeneral(lhs: Tensor, rhs: Tensor, contracting_axes: []const [2]i8, batching_axes: []const [2]i8) Tensor {
|
||||
pub fn dotGeneral(
|
||||
lhs: Tensor,
|
||||
rhs: Tensor,
|
||||
contracting_axes: []const [2]i8,
|
||||
batching_axes: []const [2]i8,
|
||||
) Tensor {
|
||||
meta.assert(lhs.dtype() == rhs.dtype(), "dotGeneral expects tensors to be of the same type, got {} and {}", .{ lhs.dtype(), rhs.dtype() });
|
||||
|
||||
const Axes = std.BoundedArray(i64, MAX_RANK);
|
||||
@ -1273,6 +1284,17 @@ pub const Tensor = struct {
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
|
||||
pub fn swapAxes(self: Tensor, a: anytype, b: anytype) Tensor {
|
||||
if (self.axis(a) == self.axis(b)) return self;
|
||||
var perm: Shape.AxesArray = .{};
|
||||
for (0..self.rank()) |i| {
|
||||
perm.appendAssumeCapacity(@intCast(i));
|
||||
}
|
||||
perm.set(self.axis(a), self.axis(b));
|
||||
perm.set(self.axis(b), self.axis(a));
|
||||
return self.transpose(perm.constSlice());
|
||||
}
|
||||
|
||||
/// Returns a Tensor with the given axis unflattened.
|
||||
///
|
||||
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
||||
@ -1577,11 +1599,7 @@ pub const Tensor = struct {
|
||||
return _result(self._shape, expm1_op.result(0));
|
||||
}
|
||||
|
||||
pub const ArangeArgs = struct {
|
||||
start: i64 = 0,
|
||||
end: i64,
|
||||
step: i64 = 1,
|
||||
};
|
||||
pub const ArangeArgs = HostBuffer.ArangeArgs;
|
||||
|
||||
/// Returns a Tensor containing evenly spaced values within a given interval.
|
||||
pub fn arange(args: ArangeArgs, dt: DataType) Tensor {
|
||||
@ -1871,14 +1889,6 @@ pub const Tensor = struct {
|
||||
return _result(self._shape, reverse_op.result(0));
|
||||
}
|
||||
|
||||
/// Returns a Tensor with the given axes reversed.
|
||||
pub fn reverseMany(self: Tensor, axes_: []const i64) Tensor {
|
||||
const actual_axes = self._shape.axes(axes_).constSlice();
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({d})", .{actual_axes});
|
||||
const reverse_op = dialect.stablehlo.reverseMany(self.getContext().mlirCtx(), self.value(), @ptrCast(actual_axes), loc);
|
||||
return _result(self._shape, reverse_op.result(0));
|
||||
}
|
||||
|
||||
pub const GatherOpts = struct { indices_are_sorted: bool = false };
|
||||
|
||||
/// For each coordinate in `indices`,
|
||||
@ -2448,7 +2458,7 @@ pub const Tensor = struct {
|
||||
return result;
|
||||
}
|
||||
|
||||
pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: i64) [n_chunks]Tensor {
|
||||
pub fn chunkExact(self: Tensor, n_chunks: comptime_int, axis_: anytype) [n_chunks]Tensor {
|
||||
const a = self.axis(axis_);
|
||||
const length = self.dim(a);
|
||||
_ = @divExact(length, n_chunks);
|
||||
@ -3484,3 +3494,9 @@ fn _collectAxes(T: type, bounded_array: std.BoundedArray(T, Tensor.MAX_RANK), va
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
const M = meta.MapType(Tensor, Buffer);
|
||||
return M.map(T);
|
||||
}
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
//!
|
||||
|
||||
pub const Buffer = @import("buffer.zig").Buffer;
|
||||
pub const Bufferized = @import("buffer.zig").Bufferized;
|
||||
pub const Bufferized = @import("tensor.zig").Bufferized;
|
||||
pub const CompilationOptions = @import("platform.zig").CompilationOptions;
|
||||
pub const Context = @import("context.zig").Context;
|
||||
pub const Data = @import("dtype.zig").Data;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user