zml: Introduce arena allocator in CompilationContext. Expose arena allocator to replace existing allocator, enabling safe allocation for ops without misusing std.BoundedArray. Includes breaking changes to chunkAllowTrailing and split. Upgrade axis_ types to anytype for tag handling and add TODOs for upcoming Tensor API.
This commit is contained in:
parent
57bf667c90
commit
6e4fef8844
@ -150,11 +150,17 @@ pub fn FnParam(comptime func: anytype, comptime n: comptime_int) type {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn FnArgs(comptime func: anytype) type {
|
pub fn FnArgs(comptime func: anytype) type {
|
||||||
|
debug.assertComptime(!@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgs expects non generic function, got: {}", .{@TypeOf(func)});
|
||||||
return FnSignature(func, null).ArgsT;
|
return FnSignature(func, null).ArgsT;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn FnArgsWithHint(comptime func: anytype, ArgsT: type) type {
|
||||||
|
debug.assertComptime(@typeInfo(@TypeOf(func)).Fn.is_generic, "FnArgsWithHint expects a generic function, got: {}", .{@TypeOf(func)});
|
||||||
|
return FnSignature(func, ArgsT).ArgsT;
|
||||||
|
}
|
||||||
|
|
||||||
pub fn FnResult(comptime func: anytype) type {
|
pub fn FnResult(comptime func: anytype) type {
|
||||||
return FnSignature(func, null).ReturnT;
|
return @typeInfo(@TypeOf(func)).Fn.return_type orelse @compileError("anytype is not supported");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn Head(Tuple: type) type {
|
pub fn Head(Tuple: type) type {
|
||||||
|
|||||||
@ -12,7 +12,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
|||||||
return std.meta.ArgsTuple(funcT);
|
return std.meta.ArgsTuple(funcT);
|
||||||
}
|
}
|
||||||
|
|
||||||
const args = std.meta.fields(ArgsT orelse compileError("generic function requires an explicit ArgsTuple", .{}));
|
const args = std.meta.fields(ArgsT orelse @compileError("generic function requires an explicit ArgsTuple"));
|
||||||
var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
|
var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
|
||||||
if (params.len != args.len) {
|
if (params.len != args.len) {
|
||||||
compileError("function {} expected {} args, got {}", .{ funcT, params.len, args.len });
|
compileError("function {} expected {} args, got {}", .{ funcT, params.len, args.len });
|
||||||
@ -23,7 +23,7 @@ pub fn ArgsTuple(comptime funcT: anytype, comptime ArgsT: ?type) type {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const T = param.type.?;
|
const T = param.type.?;
|
||||||
var num_buf: [32]u8 = undefined;
|
var num_buf: [8]u8 = undefined;
|
||||||
tuple_fields[i] = .{
|
tuple_fields[i] = .{
|
||||||
.name = blk: {
|
.name = blk: {
|
||||||
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
|
const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
|
||||||
|
|||||||
@ -88,7 +88,7 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
|
|||||||
try prefix_builder.push(allocator, prefix);
|
try prefix_builder.push(allocator, prefix);
|
||||||
defer prefix_builder.deinit(allocator);
|
defer prefix_builder.deinit(allocator);
|
||||||
|
|
||||||
const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
|
const unique_id = zml.Tensor._reserveIdRange(@intCast(store.buffers.count()));
|
||||||
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
|
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
|
||||||
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
|
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
|
||||||
};
|
};
|
||||||
|
|||||||
146
zml/module.zig
146
zml/module.zig
@ -86,18 +86,17 @@ pub const CompilationContext = struct {
|
|||||||
_platform: Platform,
|
_platform: Platform,
|
||||||
_name: []const u8,
|
_name: []const u8,
|
||||||
|
|
||||||
|
_arena: std.heap.ArenaAllocator,
|
||||||
_mlir_ctx: mlir.Context,
|
_mlir_ctx: mlir.Context,
|
||||||
_mlir_registry: mlir.Registry,
|
_mlir_registry: mlir.Registry,
|
||||||
_mlir_canonicalizer: mlir.PassManager,
|
_mlir_canonicalizer: mlir.PassManager,
|
||||||
|
|
||||||
_module: mlir.Module,
|
_module: mlir.Module,
|
||||||
|
|
||||||
_blocks: std.BoundedArray(Block, 64),
|
_blocks: std.BoundedArray(Block, 64) = .{},
|
||||||
_fn_cache: FnCache,
|
_fn_cache: FnCache = .{},
|
||||||
// TODO: make this an arena, that way it's fine if ops allocate inside it.
|
|
||||||
_allocator: std.mem.Allocator,
|
|
||||||
|
|
||||||
_buffer_to_arg: TensorToBlockArg = .{},
|
_block_args: TensorToBlockArg = .{},
|
||||||
_unique_id: u64 = 10000,
|
_unique_id: u64 = 10000,
|
||||||
_tracer: Tracer,
|
_tracer: Tracer,
|
||||||
|
|
||||||
@ -107,7 +106,7 @@ pub const CompilationContext = struct {
|
|||||||
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
const TensorToBlockArg = std.AutoHashMapUnmanaged(Tensor._Id, struct { mlir.Value, Tensor._Donation });
|
||||||
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
const AttributeList = std.BoundedArray(mlir.NamedAttribute, 3);
|
||||||
|
|
||||||
pub fn init(allocator: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
|
pub fn init(allocator_: std.mem.Allocator, name: []const u8, platform: Platform) !CompilationContext {
|
||||||
const mlir_registry = mlir.Registry.init() catch unreachable;
|
const mlir_registry = mlir.Registry.init() catch unreachable;
|
||||||
inline for (.{ "func", "stablehlo" }) |d| {
|
inline for (.{ "func", "stablehlo" }) |d| {
|
||||||
mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
|
mlir.DialectHandle.fromString(d).insertDialect(mlir_registry);
|
||||||
@ -127,6 +126,10 @@ pub const CompilationContext = struct {
|
|||||||
try opm.addPipeline("canonicalize");
|
try opm.addPipeline("canonicalize");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var arena = std.heap.ArenaAllocator.init(allocator_);
|
||||||
|
_ = try arena.allocator().alloc(u8, std.mem.page_size);
|
||||||
|
_ = arena.reset(.retain_capacity);
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
._platform = platform,
|
._platform = platform,
|
||||||
._name = name,
|
._name = name,
|
||||||
@ -135,17 +138,21 @@ pub const CompilationContext = struct {
|
|||||||
._mlir_canonicalizer = canonicalizer,
|
._mlir_canonicalizer = canonicalizer,
|
||||||
._module = module,
|
._module = module,
|
||||||
._blocks = .{},
|
._blocks = .{},
|
||||||
._fn_cache = FnCache.init(allocator),
|
._fn_cache = .{},
|
||||||
._allocator = allocator,
|
._arena = arena,
|
||||||
._tracer = Tracer.init("ai.zml.compilation"),
|
._tracer = Tracer.init("ai.zml.compilation"),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: *CompilationContext) void {
|
pub fn deinit(self: *CompilationContext) void {
|
||||||
self._fn_cache.deinit();
|
// No need to deinit self._fn_cache because it allocates from our arena.
|
||||||
self._mlir_ctx.deinit();
|
self._mlir_ctx.deinit();
|
||||||
self._mlir_registry.deinit();
|
self._mlir_registry.deinit();
|
||||||
self._buffer_to_arg.deinit(self._allocator);
|
self._arena.deinit();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn allocator(self: *CompilationContext) std.mem.Allocator {
|
||||||
|
return self._arena.allocator();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn activate(self: *CompilationContext) void {
|
pub fn activate(self: *CompilationContext) void {
|
||||||
@ -173,22 +180,20 @@ pub const CompilationContext = struct {
|
|||||||
/// Compiles the given function with the given arguments.
|
/// Compiles the given function with the given arguments.
|
||||||
/// This is the untyped API and is not meant to be used directly.
|
/// This is the untyped API and is not meant to be used directly.
|
||||||
///
|
///
|
||||||
/// * allocator is used to allocate the
|
/// * allocator is used to allocate the result Exe
|
||||||
/// * args can contain a mix of tensors and shapes, which allows passing a "model struct" containing tensors.
|
/// * args can contain a mix of tensors and shapes, which allows passing a "model struct" containing tensors.
|
||||||
pub fn compileInternal(
|
pub fn compileInternal(
|
||||||
self: *CompilationContext,
|
self: *CompilationContext,
|
||||||
allocator: std.mem.Allocator,
|
allocator_: std.mem.Allocator,
|
||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
args: anytype,
|
args: anytype,
|
||||||
) !BaseExe {
|
) !BaseExe {
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
const arena = self.allocator();
|
||||||
defer arena_state.deinit();
|
|
||||||
const arena = arena_state.allocator();
|
|
||||||
|
|
||||||
var timer = std.time.Timer.start() catch null;
|
var timer = std.time.Timer.start() catch null;
|
||||||
const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
|
const tensor_args = try self.tensorFromShapes(stdx.meta.FnArgs(func), arena, args);
|
||||||
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
||||||
const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, arena, func, &tensor_args, .{ .name = "main", .kind = .main } });
|
const f = try asynk.callBlocking(CompilationContext.emitMlir, .{ self, func, &tensor_args, .{ .name = "main", .kind = .main } });
|
||||||
const module = self._module;
|
const module = self._module;
|
||||||
module.getBody().appendOperation(f.mlir_fn);
|
module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
@ -251,7 +256,7 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return BaseExe.init(
|
return BaseExe.init(
|
||||||
allocator,
|
allocator_,
|
||||||
self._platform,
|
self._platform,
|
||||||
loaded_executable,
|
loaded_executable,
|
||||||
.{
|
.{
|
||||||
@ -335,7 +340,6 @@ pub const CompilationContext = struct {
|
|||||||
/// tensors with unique tensor ids.
|
/// tensors with unique tensor ids.
|
||||||
pub fn emitMlir(
|
pub fn emitMlir(
|
||||||
self: *CompilationContext,
|
self: *CompilationContext,
|
||||||
allocator: std.mem.Allocator,
|
|
||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
args: *const stdx.meta.FnArgs(func),
|
args: *const stdx.meta.FnArgs(func),
|
||||||
opts: struct {
|
opts: struct {
|
||||||
@ -346,9 +350,10 @@ pub const CompilationContext = struct {
|
|||||||
const frame = self._tracer.frameStart("emitMlir.emit");
|
const frame = self._tracer.frameStart("emitMlir.emit");
|
||||||
errdefer self._tracer.frameEnd(frame, "emitMlir.emit");
|
errdefer self._tracer.frameEnd(frame, "emitMlir.emit");
|
||||||
|
|
||||||
|
const res_allocator = self.allocator();
|
||||||
// Note: only temp allocations are done in the arena,
|
// Note: only temp allocations are done in the arena,
|
||||||
// the other allocations are managed by the caller.
|
// the other allocations are in the context allocator.
|
||||||
var arena_state = std.heap.ArenaAllocator.init(allocator);
|
var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
const arena = arena_state.allocator();
|
const arena = arena_state.allocator();
|
||||||
|
|
||||||
@ -367,19 +372,28 @@ pub const CompilationContext = struct {
|
|||||||
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
const input_types = try arena.alloc(mlir.Type, tensor_count);
|
||||||
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
for (input_types, input_shapes.items) |*t, sh| t.* = mlir.ext.mlirType(mlir_ctx, sh);
|
||||||
|
|
||||||
|
const og_block_args = self._block_args;
|
||||||
|
defer {
|
||||||
|
self._block_args.deinit(self.allocator());
|
||||||
|
self._block_args = og_block_args;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset the tensor id -> block argument assignment.
|
||||||
|
self._block_args = .{};
|
||||||
|
|
||||||
// Note: this isn't strictly necessary. We call `countTensor` on `fn_res`.
|
// Note: this isn't strictly necessary. We call `countTensor` on `fn_res`.
|
||||||
// But it forces users to write simpler functions.
|
// But it forces users to write simpler functions.
|
||||||
const ReturnT = stdx.meta.FnResult(func);
|
const ReturnT = stdx.meta.FnResult(func);
|
||||||
const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
|
const out_tensor_count = comptime ops.staticCountTensors(ReturnT) orelse @compileError("Can't use " ++ @typeName(ReturnT) ++ " in an MLIR function, because it has a variable number of tensors");
|
||||||
// Those are returned to caller so we don't put them in the arena.
|
// Those are returned to caller so we don't put them in the arena.
|
||||||
const fn_res_types = try allocator.alloc(mlir.Type, out_tensor_count);
|
const fn_res_types = try res_allocator.alloc(mlir.Type, out_tensor_count);
|
||||||
const fn_res_shapes = try allocator.alloc(Shape, out_tensor_count);
|
const fn_res_shapes = try res_allocator.alloc(Shape, out_tensor_count);
|
||||||
const fn_res_donations = try allocator.alloc(Tensor._Donation, out_tensor_count);
|
const fn_res_donations = try res_allocator.alloc(Tensor._Donation, out_tensor_count);
|
||||||
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
|
var fn_body = self.openBlock(.hermetic, input_types, locations) catch unreachable;
|
||||||
{
|
{
|
||||||
defer self.closeBlock(fn_body);
|
defer self.closeBlock(fn_body);
|
||||||
|
|
||||||
try self._buffer_to_arg.ensureUnusedCapacity(self._allocator, @intCast(tensor_count));
|
try self._block_args.ensureUnusedCapacity(self.allocator(), @intCast(tensor_count));
|
||||||
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
|
const assigned_args_count = self.mapBlockArguments(args, fn_body.block(), 0);
|
||||||
std.debug.assert(assigned_args_count == tensor_count);
|
std.debug.assert(assigned_args_count == tensor_count);
|
||||||
|
|
||||||
@ -474,7 +488,6 @@ pub const CompilationContext = struct {
|
|||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
defer arena.deinit();
|
defer arena.deinit();
|
||||||
const allocator = arena.allocator();
|
|
||||||
|
|
||||||
const s = Shape.init(.{8}, .f16);
|
const s = Shape.init(.{8}, .f16);
|
||||||
|
|
||||||
@ -497,13 +510,14 @@ pub const CompilationContext = struct {
|
|||||||
.bias = zml.Tensor{ ._shape = s, ._id = .{ .buffer_id = 0 } },
|
.bias = zml.Tensor{ ._shape = s, ._id = .{ .buffer_id = 0 } },
|
||||||
};
|
};
|
||||||
|
|
||||||
var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
var tensor_args = .{ model, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1234 } }, Tensor{ ._shape = s, ._id = .{ .buffer_id = 1235 } } };
|
||||||
const f = try comp.emitMlir(allocator, Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
const f = try comp.emitMlir(Local.forward, &tensor_args, .{ .name = "test.emitMlir.Local.forward", .kind = .main });
|
||||||
|
|
||||||
var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
|
var mlir_bytecode = std.ArrayList(u8).init(std.testing.allocator);
|
||||||
try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});
|
defer mlir_bytecode.deinit();
|
||||||
|
try mlir_bytecode.writer().print("{}", .{f.mlir_fn.mlirFormatter(.{})});
|
||||||
|
|
||||||
// Check that the `x` input argument gives its buffer to the result tensor.
|
// Check that the `x` input argument gives its buffer to the result tensor.
|
||||||
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
// `%arg0` is the bias of the model, `%arg1` is `x`, `%arg2` is `y`.
|
||||||
@ -598,8 +612,8 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn finalizeAttributeList(allocator: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
|
fn finalizeAttributeList(allocator_: std.mem.Allocator, mlir_ctx: mlir.Context, attributes: []AttributeList) ![]mlir.Attribute {
|
||||||
const res = try allocator.alloc(mlir.Attribute, attributes.len);
|
const res = try allocator_.alloc(mlir.Attribute, attributes.len);
|
||||||
for (res, attributes) |*r, attr| {
|
for (res, attributes) |*r, attr| {
|
||||||
r.* = mlir.DictionaryAttribute.init(mlir_ctx, attr.constSlice()).asAttr();
|
r.* = mlir.DictionaryAttribute.init(mlir_ctx, attr.constSlice()).asAttr();
|
||||||
}
|
}
|
||||||
@ -617,7 +631,7 @@ pub const CompilationContext = struct {
|
|||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
args: stdx.meta.FnArgs(func),
|
args: stdx.meta.FnArgs(func),
|
||||||
) stdx.meta.FnResult(func) {
|
) stdx.meta.FnResult(func) {
|
||||||
var arena_state = std.heap.ArenaAllocator.init(self._allocator);
|
var arena_state = std.heap.ArenaAllocator.init(self._arena.child_allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
const arena = arena_state.allocator();
|
const arena = arena_state.allocator();
|
||||||
|
|
||||||
@ -625,21 +639,13 @@ pub const CompilationContext = struct {
|
|||||||
// the result of this will also have the correct tags of the result shapes
|
// the result of this will also have the correct tags of the result shapes
|
||||||
const args_hash = hashArgs(args);
|
const args_hash = hashArgs(args);
|
||||||
const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
|
const key: FnCache.Key = .{ .fn_ptr = &func, .input_hash = args_hash };
|
||||||
|
|
||||||
const function = self._fn_cache.getEntry(key) orelse b: {
|
const function = self._fn_cache.getEntry(key) orelse b: {
|
||||||
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
|
const full_name: [:0]const u8 = if (std.mem.eql(u8, "main", func_name))
|
||||||
arena.dupeZ(u8, func_name) catch unreachable
|
arena.dupeZ(u8, func_name) catch unreachable
|
||||||
else
|
else
|
||||||
std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;
|
std.fmt.allocPrintZ(arena, "{s}_{x}", .{ func_name, key.input_hash }) catch unreachable;
|
||||||
|
|
||||||
const og_buffer_to_arg = self._buffer_to_arg;
|
|
||||||
defer {
|
|
||||||
self._buffer_to_arg.deinit(self._allocator);
|
|
||||||
self._buffer_to_arg = og_buffer_to_arg;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Reset the buffer -> assignement
|
|
||||||
self._buffer_to_arg = .{};
|
|
||||||
|
|
||||||
var arg_id: u16 = 0;
|
var arg_id: u16 = 0;
|
||||||
var tensor_args: @TypeOf(args) = args;
|
var tensor_args: @TypeOf(args) = args;
|
||||||
meta.mapAlloc(struct {
|
meta.mapAlloc(struct {
|
||||||
@ -648,14 +654,14 @@ pub const CompilationContext = struct {
|
|||||||
arg_id_.* += 1;
|
arg_id_.* += 1;
|
||||||
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
|
return Tensor{ ._shape = x._shape, ._id = .{ .arg_id = a }, ._donation = .{ .arg = a } };
|
||||||
}
|
}
|
||||||
}.cb, self._allocator, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
|
}.cb, arena, &arg_id, args, &tensor_args) catch @panic("OutOfMemory");
|
||||||
|
|
||||||
const f = self.emitMlir(arena, func, &tensor_args, .{
|
const f = self.emitMlir(func, &tensor_args, .{
|
||||||
.name = full_name,
|
.name = full_name,
|
||||||
}) catch @panic("OOM");
|
}) catch @panic("OOM");
|
||||||
self._module.getBody().appendOperation(f.mlir_fn);
|
self._module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
break :b self._fn_cache.addEntry(key, f) catch unreachable;
|
break :b self._fn_cache.addEntry(self.allocator(), key, f) catch unreachable;
|
||||||
};
|
};
|
||||||
|
|
||||||
const loc = self.mlirCtx().location(@src());
|
const loc = self.mlirCtx().location(@src());
|
||||||
@ -701,7 +707,7 @@ pub const CompilationContext = struct {
|
|||||||
///
|
///
|
||||||
/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors
|
/// This is done so that we have a mapping between the arguments of the kernel associated with a module and the actual Tensors
|
||||||
/// stored in the Module.
|
/// stored in the Module.
|
||||||
/// Caller needs to allocate required memory in self._buffer_to_arg.
|
/// Caller needs to allocate required memory in self._block_args.
|
||||||
pub fn mapBlockArguments(self: *CompilationContext, v: anytype, block: mlir.Block, start: usize) usize {
|
pub fn mapBlockArguments(self: *CompilationContext, v: anytype, block: mlir.Block, start: usize) usize {
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
index: usize,
|
index: usize,
|
||||||
@ -714,7 +720,7 @@ pub const CompilationContext = struct {
|
|||||||
const arg_value = ctx.block.argument(ctx.index);
|
const arg_value = ctx.block.argument(ctx.index);
|
||||||
// log.debug("mapping {} to arg {}", .{ tensor._id, ctx.index });
|
// log.debug("mapping {} to arg {}", .{ tensor._id, ctx.index });
|
||||||
|
|
||||||
const res = ctx.self._buffer_to_arg.getOrPutAssumeCapacity(tensor._id);
|
const res = ctx.self._block_args.getOrPutAssumeCapacity(tensor._id);
|
||||||
if (res.found_existing) {
|
if (res.found_existing) {
|
||||||
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
|
stdx.debug.panic("Failed compilation because received two tensors arguments with the same ID: {} and {} at index {} ({}).", .{ res.value_ptr.*[0], tensor, ctx.index, tensor._id });
|
||||||
} else {
|
} else {
|
||||||
@ -728,7 +734,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
/// Create tensor from the given shapes.
|
/// Create tensor from the given shapes.
|
||||||
/// Each created tensor will receive a unique id, local to this CompilationContext.
|
/// Each created tensor will receive a unique id, local to this CompilationContext.
|
||||||
pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator: std.mem.Allocator, args_shapes: anytype) !ArgsT {
|
pub fn tensorFromShapes(self: *CompilationContext, ArgsT: type, allocator_: std.mem.Allocator, args_shapes: anytype) !ArgsT {
|
||||||
const Local = struct {
|
const Local = struct {
|
||||||
fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
|
fn tensorFromShape(arg_id: *u64, shape: Shape) Tensor {
|
||||||
defer arg_id.* += 1;
|
defer arg_id.* += 1;
|
||||||
@ -740,7 +746,7 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
var tensor_args: ArgsT = undefined;
|
var tensor_args: ArgsT = undefined;
|
||||||
try meta.mapAlloc(Local.tensorFromShape, allocator, &self._unique_id, args_shapes, &tensor_args);
|
try meta.mapAlloc(Local.tensorFromShape, allocator_, &self._unique_id, args_shapes, &tensor_args);
|
||||||
return tensor_args;
|
return tensor_args;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -771,7 +777,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } {
|
pub fn getValueAndDonation(self: *const CompilationContext, tensor: Tensor) struct { mlir.Value, Tensor._Donation } {
|
||||||
return switch (tensor._id) {
|
return switch (tensor._id) {
|
||||||
.buffer_id, .arg_id => if (self._buffer_to_arg.get(tensor._id)) |res|
|
.buffer_id, .arg_id => if (self._block_args.get(tensor._id)) |res|
|
||||||
.{ res[0], res[1] }
|
.{ res[0], res[1] }
|
||||||
else {
|
else {
|
||||||
log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id });
|
log.err("Found unknown tensor id {}({})", .{ tensor, tensor._id });
|
||||||
@ -1021,44 +1027,28 @@ pub fn xxHash64Writer(hasher: *std.hash.XxHash64) XxHash64Writer {
|
|||||||
pub const FnCache = struct {
|
pub const FnCache = struct {
|
||||||
pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 };
|
pub const Key = struct { fn_ptr: *const anyopaque, input_hash: u64 };
|
||||||
|
|
||||||
// TODO: merge arenas
|
cache: std.AutoHashMapUnmanaged(Key, MlirFn) = .{},
|
||||||
cache: std.AutoHashMapUnmanaged(Key, MlirFn),
|
|
||||||
// Arena for the cache entries
|
|
||||||
cache_arena: std.heap.ArenaAllocator,
|
|
||||||
// Arena for the cache data (name, res_type)
|
|
||||||
cache_data_arena: std.heap.ArenaAllocator,
|
|
||||||
|
|
||||||
pub fn init(allocator: std.mem.Allocator) FnCache {
|
pub fn deinit(self: FnCache, allocator: std.mem.Allocator) void {
|
||||||
return .{
|
self.cache.deinit(allocator);
|
||||||
.cache = .{},
|
|
||||||
.cache_arena = std.heap.ArenaAllocator.init(allocator),
|
|
||||||
.cache_data_arena = std.heap.ArenaAllocator.init(allocator),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn deinit(self: FnCache) void {
|
|
||||||
self.cache_arena.deinit();
|
|
||||||
self.cache_data_arena.deinit();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn {
|
pub fn getEntry(self: *const FnCache, key: Key) ?MlirFn {
|
||||||
return self.cache.get(key);
|
return self.cache.get(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn addEntry(self: *FnCache, key: Key, value: MlirFn) !MlirFn {
|
pub fn addEntry(self: *FnCache, allocator: std.mem.Allocator, key: Key, value: MlirFn) !MlirFn {
|
||||||
var cache_data_allocator = self.cache_data_arena.allocator();
|
const res_types_copy = try allocator.dupe(mlir.Type, value.res_types);
|
||||||
|
errdefer allocator.free(res_types_copy);
|
||||||
|
|
||||||
const res_types_copy = try cache_data_allocator.dupe(mlir.Type, value.res_types);
|
const res_shapes_copy = try allocator.dupe(Shape, value.res_shapes);
|
||||||
errdefer cache_data_allocator.free(res_types_copy);
|
errdefer allocator.free(res_shapes_copy);
|
||||||
|
|
||||||
const res_shapes_copy = try cache_data_allocator.dupe(Shape, value.res_shapes);
|
const res_donations_copy = try allocator.dupe(Tensor._Donation, value.res_donations);
|
||||||
errdefer cache_data_allocator.free(res_shapes_copy);
|
errdefer allocator.free(res_donations_copy);
|
||||||
|
|
||||||
const res_donations_copy = try cache_data_allocator.dupe(Tensor._Donation, value.res_donations);
|
const name_copy = try allocator.dupeZ(u8, value.name);
|
||||||
errdefer cache_data_allocator.free(res_donations_copy);
|
errdefer allocator.free(name_copy);
|
||||||
|
|
||||||
const name_copy = try cache_data_allocator.dupeZ(u8, value.name);
|
|
||||||
errdefer cache_data_allocator.free(name_copy);
|
|
||||||
|
|
||||||
const owned_value: MlirFn = .{
|
const owned_value: MlirFn = .{
|
||||||
.name = name_copy,
|
.name = name_copy,
|
||||||
@ -1069,7 +1059,7 @@ pub const FnCache = struct {
|
|||||||
.res_donations = res_donations_copy,
|
.res_donations = res_donations_copy,
|
||||||
};
|
};
|
||||||
|
|
||||||
try self.cache.putNoClobber(self.cache_arena.allocator(), key, owned_value);
|
try self.cache.putNoClobber(allocator, key, owned_value);
|
||||||
return owned_value;
|
return owned_value;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
@ -816,8 +816,8 @@ const SdpaMemEfficient = struct {
|
|||||||
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
|
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
|
||||||
|
|
||||||
const ctx = zml.module.CompilationContext.current();
|
const ctx = zml.module.CompilationContext.current();
|
||||||
const q_chunks = ctx._allocator.alloc(zml.Tensor, n_q_chunks) catch unreachable;
|
const q_chunks = ctx.allocator().alloc(zml.Tensor, n_q_chunks) catch unreachable;
|
||||||
defer ctx._allocator.free(q_chunks);
|
defer ctx.allocator().free(q_chunks);
|
||||||
for (0..n_q_chunks) |i| {
|
for (0..n_q_chunks) |i| {
|
||||||
const idx: u32 = @intCast(i);
|
const idx: u32 = @intCast(i);
|
||||||
const q_slice: zml.Tensor.DynSlice = .{
|
const q_slice: zml.Tensor.DynSlice = .{
|
||||||
|
|||||||
@ -316,7 +316,7 @@ pub fn for_(comptime func: anytype, blk_ctx: BlockSign(func).BlkCtx, num_steps_:
|
|||||||
// but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop.
|
// but because of https://github.com/zml/zml/issues/97 we also reuse it to start the while_ loop.
|
||||||
const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
|
const first_step = @call(.auto, func, .{ blk_ctx, Tensor.scalar(0, .i32) });
|
||||||
log.debug("for_ first_step: {}", .{first_step});
|
log.debug("for_ first_step: {}", .{first_step});
|
||||||
const allocator = CompilationContext.current()._allocator;
|
const allocator = CompilationContext.current().allocator();
|
||||||
// Optimize for small num reps
|
// Optimize for small num reps
|
||||||
if (num_steps == 1) {
|
if (num_steps == 1) {
|
||||||
var res = first_step;
|
var res = first_step;
|
||||||
|
|||||||
@ -126,7 +126,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns the dimension of axis 'axis_'.
|
/// Returns the dimension of axis 'axis_'.
|
||||||
///
|
///
|
||||||
/// 'axis_' can be a signed integer or a tag.
|
/// 'axis_' can be an integer or a tag.
|
||||||
pub fn dim(self: Tensor, axis_: anytype) i64 {
|
pub fn dim(self: Tensor, axis_: anytype) i64 {
|
||||||
return self._shape.dim(axis_);
|
return self._shape.dim(axis_);
|
||||||
}
|
}
|
||||||
@ -138,11 +138,18 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns the index of axis 'axis_'.
|
/// Returns the index of axis 'axis_'.
|
||||||
///
|
///
|
||||||
/// 'axis_' can be a signed integer or a tag.
|
/// 'axis_' can be an integer or a tag.
|
||||||
pub fn axis(self: Tensor, axis_: anytype) u3 {
|
pub fn axis(self: Tensor, axis_: anytype) u3 {
|
||||||
return self._shape.axis(axis_);
|
return self._shape.axis(axis_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the indices of each of the given axes.
|
||||||
|
///
|
||||||
|
/// Each axis in 'axes_' can be an integer or a tag.
|
||||||
|
pub fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
|
||||||
|
return self._shape.axes(axes_);
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns a Tensor tagged with the tags in 'tagz'.
|
/// Returns a Tensor tagged with the tags in 'tagz'.
|
||||||
pub fn withTags(self: Tensor, tagz: anytype) Tensor {
|
pub fn withTags(self: Tensor, tagz: anytype) Tensor {
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -227,13 +234,13 @@ pub const Tensor = struct {
|
|||||||
var _global_tensor_counter: u64 = 0;
|
var _global_tensor_counter: u64 = 0;
|
||||||
|
|
||||||
/// Internal use
|
/// Internal use
|
||||||
pub fn reserveIdRange(len: u32) u64 {
|
pub fn _reserveIdRange(len: u32) u64 {
|
||||||
return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
|
return @atomicRmw(u64, &_global_tensor_counter, .Add, len, .seq_cst);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Internal use
|
/// Internal use
|
||||||
pub fn setUniqueId(self: *Tensor) void {
|
pub fn setUniqueId(self: *Tensor) void {
|
||||||
self._id = .{ .buffer_id = reserveIdRange(1) };
|
self._id = .{ .buffer_id = _reserveIdRange(1) };
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the absolute value of each element of the input Tensor.
|
/// Returns a Tensor containing the absolute value of each element of the input Tensor.
|
||||||
@ -357,7 +364,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns the Cholesky decomposition of the input Tensor.
|
/// Returns the Cholesky decomposition of the input Tensor.
|
||||||
///
|
///
|
||||||
/// 'lower' controls the form of the outut Tensor. The output will be lower-triangular if 'lower' is true
|
/// 'lower' controls the form of the output Tensor. The output will be lower-triangular if 'lower' is true
|
||||||
/// and upper-triangular otherwise.
|
/// and upper-triangular otherwise.
|
||||||
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
||||||
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
||||||
@ -1070,7 +1077,7 @@ pub const Tensor = struct {
|
|||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
|
|
||||||
comp.activate();
|
comp.activate();
|
||||||
@ -1390,6 +1397,7 @@ pub const Tensor = struct {
|
|||||||
///
|
///
|
||||||
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
/// unflatten((d0, d1, axis_m, d3), 2, n) -> (d0, d1, n, d2_m, d3)
|
||||||
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
|
pub fn unflatten(self: Tensor, axis_: i8, n: i64) Tensor {
|
||||||
|
// TODO: move to torch.zig, this is equivalent to `splitAxis`
|
||||||
stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
|
stdx.debug.assert(self.rank() < Tensor.MAX_RANK, "unflatten expects input tensor rank to be less than {}, got {}", .{ Tensor.MAX_RANK, self.rank() });
|
||||||
|
|
||||||
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
|
const a = if (axis_ >= 0) self.axis(axis_) else self.axis(axis_) + 1;
|
||||||
@ -1445,6 +1453,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Flattens the given axis and the next one, into one new axis.
|
/// Flattens the given axis and the next one, into one new axis.
|
||||||
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
|
pub fn flatten(self: Tensor, axis_: anytype) Tensor {
|
||||||
|
// TODO: move to torch.zig, this is equivalent to merge
|
||||||
const old_shape = self._shape;
|
const old_shape = self._shape;
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||||||
@ -1463,6 +1472,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn flattenAll(self: Tensor) Tensor {
|
pub inline fn flattenAll(self: Tensor) Tensor {
|
||||||
|
// TODO: rename to just flatten, once flatten is moved to torch
|
||||||
return self.reshape(.{self.count()});
|
return self.reshape(.{self.count()});
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1547,7 +1557,8 @@ pub const Tensor = struct {
|
|||||||
return if (idx < 0) self.dim(axis_) + idx else idx;
|
return if (idx < 0) self.dim(axis_) + idx else idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn choose1d(self: Tensor, axis_: i64, i: i64) Tensor {
|
pub fn choose1d(self: Tensor, axis_: anytype, i: i64) Tensor {
|
||||||
|
// TODO: this use case could be handled directly by slice if we added a .single field
|
||||||
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_);
|
return self.slice1d(axis_, .{ .start = i, .end = i + 1 }).squeeze(axis_);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1651,6 +1662,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Repeats in line each value along the given axes.
|
/// Repeats in line each value along the given axes.
|
||||||
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
|
pub fn stutter(self: Tensor, n_reps: []const u63) Tensor {
|
||||||
|
// TODO: this should support the tagged syntax: x.repeat(.{ .a = 3, .b = 2});
|
||||||
stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
stdx.debug.assert(n_reps.len == self.rank(), "stutter expects tensor rank and 'n_reps' length to be equal, got {} and {}", .{ self.rank(), n_reps.len });
|
||||||
|
|
||||||
var res = self;
|
var res = self;
|
||||||
@ -2159,7 +2171,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
{
|
{
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
comp.activate();
|
comp.activate();
|
||||||
defer comp.deactivate();
|
defer comp.deactivate();
|
||||||
@ -2295,7 +2307,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
{
|
{
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
comp.activate();
|
comp.activate();
|
||||||
defer comp.deactivate();
|
defer comp.deactivate();
|
||||||
@ -2519,7 +2531,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
{
|
{
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
comp.activate();
|
comp.activate();
|
||||||
defer comp.deactivate();
|
defer comp.deactivate();
|
||||||
@ -2759,8 +2771,10 @@ pub const Tensor = struct {
|
|||||||
return .{ .values = res[0], .indices = res[1] };
|
return .{ .values = res[0], .indices = res[1] };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub const ArgSortOpts = struct { descending: bool = false };
|
||||||
|
|
||||||
/// Returns a Tensor containing the indices corresponding to the sorted values over the given axis.
|
/// Returns a Tensor containing the indices corresponding to the sorted values over the given axis.
|
||||||
pub fn argsort(self: Tensor, axis_: i64, opts: struct { descending: bool = false }) Tensor {
|
pub fn argsort(self: Tensor, axis_: anytype, opts: ArgSortOpts) Tensor {
|
||||||
return self.sort(axis_, .{ .descending = opts.descending }).indices;
|
return self.sort(axis_, .{ .descending = opts.descending }).indices;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2768,13 +2782,19 @@ pub const Tensor = struct {
|
|||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
const Local = struct {
|
||||||
|
pub fn _argsort(x: Tensor, axis_: u3, opts: ArgSortOpts) Tensor {
|
||||||
|
return x.argsort(axis_, opts);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
var arena_state = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||||
defer arena_state.deinit();
|
defer arena_state.deinit();
|
||||||
const allocator = arena_state.allocator();
|
const allocator = arena_state.allocator();
|
||||||
// 2D Tensor - dim = 1, ascending
|
// 2D Tensor - dim = 1, ascending
|
||||||
{
|
{
|
||||||
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
const x = try zml.Buffer.fromSlice(platform, .{ 2, 5 }, &[_]f32{ -0.9264, 0.7156, 1.0202, 0.3992, 1.2349, 1.0003, -0.1932, 1.3935, 0.7316, 0.0851 });
|
||||||
const res = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{} });
|
const res = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{} });
|
||||||
const res_cpu = try res.toHostAlloc(allocator);
|
const res_cpu = try res.toHostAlloc(allocator);
|
||||||
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
try testing.expectEqualSlices(i32, &.{ 0, 3, 1, 2, 4, 1, 4, 3, 0, 2 }, res_cpu.items(i32));
|
||||||
}
|
}
|
||||||
@ -2787,7 +2807,7 @@ pub const Tensor = struct {
|
|||||||
0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008,
|
0.6626, -0.3040, -0.8726, -1.4805, -1.6943, 1.1055, -2.0078, -0.5288, 0.8813, 0.8008,
|
||||||
2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918,
|
2.0527, 1.1230, 0.5430, 0.2494, -0.9434, 0.7876, 0.1818, 0.9258, -2.4902, 1.5918,
|
||||||
});
|
});
|
||||||
const res_dev = try zml.testing.compileAndCall(platform, Tensor.argsort, .{ x, 1, .{ .descending = true } });
|
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 1, .{ .descending = true } });
|
||||||
const res = try res_dev.toHostAlloc(allocator);
|
const res = try res_dev.toHostAlloc(allocator);
|
||||||
try testing.expectEqualSlices(i32, &.{
|
try testing.expectEqualSlices(i32, &.{
|
||||||
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
4, 1, 1, 2, 0, 2, 0, 0, 3, 4,
|
||||||
@ -2809,7 +2829,7 @@ pub const Tensor = struct {
|
|||||||
64, 86, 62, 88,
|
64, 86, 62, 88,
|
||||||
57, 21, 19, 12,
|
57, 21, 19, 12,
|
||||||
});
|
});
|
||||||
const res_dev = try zml.testing.compileAndCallWithTensors(platform, Tensor.argsort, .{ x.shape(), 3, .{} }, .{ x, 0, .{} });
|
const res_dev = try zml.testing.compileAndCall(platform, Local._argsort, .{ x, 3, .{} });
|
||||||
const res = try res_dev.toHostAlloc(allocator);
|
const res = try res_dev.toHostAlloc(allocator);
|
||||||
try testing.expectEqualSlices(i32, &.{
|
try testing.expectEqualSlices(i32, &.{
|
||||||
2, 1, 3, 0,
|
2, 1, 3, 0,
|
||||||
@ -2921,10 +2941,6 @@ pub const Tensor = struct {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub inline fn axes(self: Tensor, axes_: anytype) std.BoundedArray(u3, Tensor.MAX_RANK) {
|
|
||||||
return self._shape.axes(axes_);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Chunk a given tensor into exactly n parts of equal shape.
|
/// Chunk a given tensor into exactly n parts of equal shape.
|
||||||
/// `self.dim(axis_)` must be divisible by n_chunks.
|
/// `self.dim(axis_)` must be divisible by n_chunks.
|
||||||
pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
|
pub fn chunkExact(self: Tensor, axis_: anytype, n_chunks: comptime_int) [n_chunks]Tensor {
|
||||||
@ -2944,7 +2960,7 @@ pub const Tensor = struct {
|
|||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
comp.activate();
|
comp.activate();
|
||||||
defer comp.deactivate();
|
defer comp.deactivate();
|
||||||
@ -2959,7 +2975,7 @@ pub const Tensor = struct {
|
|||||||
const chunks = x.chunkExact(ax, n_chunks);
|
const chunks = x.chunkExact(ax, n_chunks);
|
||||||
|
|
||||||
const res_shape = Shape.init(res, .f16);
|
const res_shape = Shape.init(res, .f16);
|
||||||
for (&chunks) |chk| {
|
for (chunks) |chk| {
|
||||||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2971,13 +2987,14 @@ pub const Tensor = struct {
|
|||||||
self: Tensor,
|
self: Tensor,
|
||||||
axis_: i64,
|
axis_: i64,
|
||||||
n_chunks: comptime_int,
|
n_chunks: comptime_int,
|
||||||
) std.BoundedArray(Tensor, n_chunks + 1) {
|
) []Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const d = self.dim(a);
|
const d = self.dim(a);
|
||||||
const chunk_size: i64 = @divFloor(d, n_chunks);
|
const chunk_size: i64 = @divFloor(d, n_chunks);
|
||||||
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
const tail_chunk_size: i64 = @rem(d, chunk_size);
|
||||||
|
|
||||||
var chunks: std.BoundedArray(Tensor, n_chunks + 1) = .{};
|
const allocator = self.getContext().allocator();
|
||||||
|
var chunks = std.ArrayListUnmanaged(Tensor).initCapacity(allocator, n_chunks + 1) catch @panic("OOM");
|
||||||
for (0..n_chunks) |i| {
|
for (0..n_chunks) |i| {
|
||||||
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
const start: i64 = @as(i64, @intCast(i)) * chunk_size;
|
||||||
chunks.appendAssumeCapacity(
|
chunks.appendAssumeCapacity(
|
||||||
@ -2988,7 +3005,7 @@ pub const Tensor = struct {
|
|||||||
const start: i64 = n_chunks * chunk_size;
|
const start: i64 = n_chunks * chunk_size;
|
||||||
chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
|
chunks.appendAssumeCapacity(self.slice1d(a, .{ .start = start }));
|
||||||
}
|
}
|
||||||
return chunks;
|
return chunks.items;
|
||||||
}
|
}
|
||||||
|
|
||||||
test chunkAllowTrailing {
|
test chunkAllowTrailing {
|
||||||
@ -2996,7 +3013,7 @@ pub const Tensor = struct {
|
|||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.heap.page_allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
comp.activate();
|
comp.activate();
|
||||||
defer comp.deactivate();
|
defer comp.deactivate();
|
||||||
@ -3012,35 +3029,34 @@ pub const Tensor = struct {
|
|||||||
const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
|
const chunks = x.chunkAllowTrailing(x.axis(ax), n_chunks);
|
||||||
|
|
||||||
const res_shape = Shape.init(res, .f16);
|
const res_shape = Shape.init(res, .f16);
|
||||||
for (chunks.constSlice()[0..n_chunks]) |chk| {
|
for (chunks[0..n_chunks]) |chk| {
|
||||||
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
try zml.testing.expectEqualShapes(res_shape, chk.shape());
|
||||||
}
|
}
|
||||||
const trailing_shape = Shape.init(trailing, .f16);
|
const trailing_shape = Shape.init(trailing, .f16);
|
||||||
if (trailing_shape.rank() > 0) {
|
if (trailing_shape.rank() > 0) {
|
||||||
try std.testing.expectEqual(n_chunks + 1, chunks.len);
|
try std.testing.expectEqual(n_chunks + 1, chunks.len);
|
||||||
try zml.testing.expectEqualShapes(trailing_shape, chunks.get(n_chunks).shape());
|
try zml.testing.expectEqualShapes(trailing_shape, chunks[n_chunks].shape());
|
||||||
} else {
|
} else {
|
||||||
try std.testing.expectEqual(n_chunks, chunks.len);
|
try std.testing.expectEqual(n_chunks, chunks.len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn split(self: Tensor, allocator: std.mem.Allocator, split_size_or_sections: []const i64, axis_: i64) ![]Tensor {
|
pub fn split(self: Tensor, axis_: anytype, split_sizes: []const i64) []Tensor {
|
||||||
stdx.debug.assert(split_size_or_sections.len > 0, "split expects 'split_size_or_sections' length to be positive, got {}", .{split_size_or_sections.len});
|
stdx.debug.assert(split_sizes.len > 0, "split expects at least one 'split_sizes', got 0", .{});
|
||||||
|
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const length = self.dim(a);
|
const d = self.dim(a);
|
||||||
if (split_size_or_sections.len != 1) {
|
|
||||||
var split_sum: i64 = 0;
|
var split_sum: i64 = 0;
|
||||||
for (split_size_or_sections) |n| split_sum += n;
|
for (split_sizes) |n| split_sum += n;
|
||||||
stdx.debug.assert(split_sum == length, "split expects sum of 'split_size_or_sections' values and axis dimension to be equal, got {} and {}", .{ split_sum, length });
|
stdx.debug.assert(split_sum == d, "split expects sum of 'split_sizes' values and axis dimension to be equal, got {} and {}", .{ split_sum, d });
|
||||||
}
|
|
||||||
|
|
||||||
const res = try allocator.alloc(Tensor, split_size_or_sections.len);
|
const allocator = self.getContext().allocator();
|
||||||
|
const res = allocator.alloc(Tensor, split_sizes.len) catch @panic("OOM");
|
||||||
errdefer allocator.dealloc(res);
|
errdefer allocator.dealloc(res);
|
||||||
|
|
||||||
var start: i64 = 0;
|
var start: i64 = 0;
|
||||||
for (split_size_or_sections, 0..) |n, i| {
|
for (split_sizes, 0..) |n, i| {
|
||||||
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
res[i] = self.slice1d(a, .{ .start = start, .end = start + n });
|
||||||
start += n;
|
start += n;
|
||||||
}
|
}
|
||||||
@ -3528,7 +3544,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
||||||
pub fn any(self: Tensor, axis_: i64) Tensor {
|
pub fn any(self: Tensor, axis_: anytype) Tensor {
|
||||||
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
||||||
const red = ops.reduce(
|
const red = ops.reduce(
|
||||||
struct {
|
struct {
|
||||||
|
|||||||
@ -127,6 +127,9 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
|
|||||||
/// Compile a function and immediately call it with the given buffers.
|
/// Compile a function and immediately call it with the given buffers.
|
||||||
/// The compiled module is discarded after the call.
|
/// The compiled module is discarded after the call.
|
||||||
/// Useful during testing when a module is typically called only once.
|
/// Useful during testing when a module is typically called only once.
|
||||||
|
///
|
||||||
|
/// Note: `func` needs explicit types on all parameters.
|
||||||
|
/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types.
|
||||||
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||||
// This simplifies the test API and also ensures this fn isn't used outside of tests.
|
// This simplifies the test API and also ensures this fn isn't used outside of tests.
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user