zml: eliminate compile-time fields from Bufferized, removing the need to pass undefined to exe.call for inlined arguments. Introduce BufferizedWithArgs in zml.testing for compileAndCall utility.

This commit is contained in:
Tarry Singh 2024-11-28 12:24:39 +00:00
parent 364a222dc1
commit f5ab2c3a55
8 changed files with 285 additions and 102 deletions

View File

@ -231,6 +231,35 @@ pub const BaseExe = struct {
}
}
/// Reads the per-device PJRT output buffers of the last execution and writes
/// them into `result`, visiting every `Buffer` field of `T` in declaration order.
///
/// Unsafe: assumes `_unsafeCall` already ran and that `T` contains exactly
/// `self.result_shapes.len` Buffers (checked by the assert at the end).
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        outputs: []const [*]*pjrt.Buffer,
        output_shapes: []Shape,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = self.platform,
        .outputs = self.output_per_device,
        .output_shapes = self.result_shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            ctx.index += 1;
            // Extra Buffer fields are left untouched; the assert below reports the mismatch.
            if (i >= ctx.output_shapes.len) return;
            // Gather the i-th output of every device into a single sharded Buffer.
            var shards: Buffer.Shards = .{};
            for (ctx.outputs) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
        }
    }).cb, &local_ctx, result);
    // Fix: report the number of returned tensors (result_shapes.len), not the
    // number of devices (output_per_device.len), and fix the "know"->"known" typo.
    stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s} contains {} Buffers. Note that modules need to have a comptime known number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
}
pub fn serialize(self: BaseExe, writer: anytype) !void {
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
var serialize_result = try executable.serialize(self.platform.pjrt_api);
@ -305,7 +334,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
std.debug.assert(total_ready == self.inner.input_buffer_count);
self.inner._unsafeCall();
var result: Bufferized(ReturnT) = undefined;
assignRawBuffers(&result, self.inner.platform, self.inner.output_per_device, self.inner.result_shapes);
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
return result;
}
};
@ -349,36 +378,6 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
return context.index;
}
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
/// Asserts that `v` contains exactly `buffer_shapes.len` `Buffer` fields.
fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        buffers: []const [*]*pjrt.Buffer,
        buffer_shapes: []Shape,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = platform,
        .buffers = buffers,
        .buffer_shapes = buffer_shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            ctx.index += 1;
            // Extra Buffer fields are left untouched; the assert below reports the mismatch.
            if (i >= ctx.buffer_shapes.len) return;
            // One shard per device for the i-th output.
            var shards: Buffer.Shards = .{};
            for (ctx.buffers) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
        }
    }).cb, &local_ctx, v);
    // Fix: report the number of returned tensors (buffer_shapes.len), not the
    // number of per-device buffer arrays (buffers.len), and fix the "know"->"known" typo.
    stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s} contains {} Buffers. Note that modules need to have a comptime known number of returned tensors.", .{ buffer_shapes.len, @typeName(@TypeOf(v)), local_ctx.index });
}
fn prettyFnName(
comptime func: anytype,
allocator: std.mem.Allocator,

View File

@ -1,5 +1,6 @@
const std = @import("std");
const testing = std.testing;
const builtin = @import("builtin");
const stdx = @import("stdx");
const FnParam = stdx.meta.FnParam;
@ -9,6 +10,7 @@ test {
std.testing.refAllDecls(@This());
}
/// Visit a given type `T` and replace all fields containing `From` by fields containing `To`.
pub fn MapType(From: type, To: type) type {
return struct {
pub fn map(T: type) type {
@ -78,6 +80,32 @@ pub fn MapType(From: type, To: type) type {
};
}
// Checks that MapType rewrites every Src-typed field (direct, slice, optional)
// into the Dst equivalent, while leaving unrelated fields untouched.
test MapType {
    const Src = struct { a: u32 };
    const Dst = struct { b: u32 };
    const SrcToDst = MapType(Src, Dst);

    const Mixed = struct { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const mapped: SrcToDst.map(Mixed) = .{
        .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } },
        .maybe = null,
        .one = .{ .b = 2 },
        .other = 43,
    };
    _ = mapped;

    // TODO(corendos) fixme: the mapped union should contain Dst values, not Src ones.
    const MixedUnion = union { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const mapped_unions = [_]SrcToDst.map(MixedUnion){
        .{ .some = &[_]Src{ .{ .a = 0 }, .{ .a = 1 } } },
        .{ .one = .{ .a = 2 } },
        .{ .maybe = null },
        .{ .other = 43 },
    };
    _ = mapped_unions;
}
/// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct
/// and calls the callback when it finds a `From` element, and writes it to the `to` struct.
/// The `to` parameter must be passed as a mutable pointer, and tensor data needs to be mutable if the callback needs it.
@ -136,10 +164,10 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
@field(from, field.name),
&@field(to, field.name),
);
} else if (field.default_value) |_| {
} else if (field.default_value_ptr) |_| {
@field(to, field.name) = null;
} else {
stdx.debug.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
stdx.debug.compileError("Mapping {} -> {} inside {} failed. Missing field {s} in {}", .{ From, To, FromStruct, field.name, ToStruct });
},
else => @field(to, field.name) = @field(from, field.name),
}
@ -234,6 +262,136 @@ test mapAlloc {
try testing.expectEqual(12, bb.static_slice[1].b);
}
/// Visit a given type `T` and:
/// * replace all fields containing `From` by fields containing `To`
/// * drop all fields not containing any `From`.
/// The returned type will contain only `To`, making it easy for the compiler to produce a compact layout.
/// Used by `zml.Bufferized` to strip compile time arguments from a model struct.
pub fn MapRestrict(From: type, To: type) type {
return struct {
/// Maps `T` recursively: `From` (and its common wrappers) becomes `To`,
/// containers are mapped element-wise, and anything not containing `From`
/// collapses to `void`.
pub fn map(T: type) type {
// Fast path: the common From-wrappers map one-to-one onto the To equivalent.
switch (T) {
From => return To,
?From => return ?To,
*From => return *To,
*const From => return *const To,
[]From => return []To,
[]const From => return []const To,
else => {},
}
// Types with no `From` anywhere are dropped entirely (replaced by void),
// which is what keeps the resulting layout compact.
if (!Contains(T, From)) return void;
return switch (@typeInfo(T)) {
.@"struct" => |struct_infos| {
// We know that at least one of the struct field contains a From.
// We map each field individually. Fields without From and comptime fields are removed.
const fields = struct_infos.fields;
var num_fields: usize = 0;
var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
for (fields) |field| {
if (!field.is_comptime and Contains(field.type, From)) {
const R = map(field.type);
if (R == field.type) {
struct_fields[num_fields] = field;
} else {
// Tuples need contiguous numeric field names ("0", "1", ...),
// so renumber kept fields using the output index.
const name = if (struct_infos.is_tuple) struct_infos.fields[num_fields].name else field.name;
struct_fields[num_fields] = .{
.name = name,
.type = R,
.default_value_ptr = null,
.is_comptime = false,
.alignment = @alignOf(R),
};
// Handle the case `field: ?Tensor = null`
// Generic handling of default value is not possible.
if (R == ?To) {
struct_fields[num_fields].default_value_ptr = &@as(R, null);
}
}
num_fields += 1;
}
}
// Every kept field was dropped (e.g. all were comptime): nothing remains.
if (num_fields == 0) return void;
return @Type(.{ .@"struct" = .{
.layout = .auto,
.fields = struct_fields[0..num_fields],
.decls = &.{},
.is_tuple = struct_infos.is_tuple,
} });
},
.@"union" => |union_info| {
// We know that at least one of the union field contains a From.
// We map each field individually. Fields without From, are replaced by "void".
const fields = union_info.fields;
var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
for (0.., fields) |i, field| {
union_fields[i] = .{
.name = field.name,
.type = map(field.type),
// 0 lets the compiler use the mapped type's natural alignment.
.alignment = 0,
};
}
return @Type(.{ .@"union" = .{
.layout = .auto,
.tag_type = union_info.tag_type,
.fields = union_fields[0..],
.decls = &.{},
} });
},
// Containers are mapped element-wise, preserving length, constness and pointer kind.
.array => |arr_info| [arr_info.len]map(arr_info.child),
.pointer => |ptr_info| switch (ptr_info.size) {
.slice => if (ptr_info.is_const)
[]const map(ptr_info.child)
else
[]map(ptr_info.child),
.one => if (ptr_info.is_const)
*const map(ptr_info.child)
else
*map(ptr_info.child),
.many => if (ptr_info.is_const)
[*]const map(ptr_info.child)
else
[*]map(ptr_info.child),
.c => if (ptr_info.is_const)
[*c]map(ptr_info.child)
else
[*c]map(ptr_info.child),
},
.optional => |opt_info| ?map(opt_info.child),
// Other type kinds are returned unchanged.
else => T,
};
}
};
}
// Checks that MapRestrict keeps only the fields reachable from Src (mapped to
// Dst) and drops the rest from structs, while unions keep every field but
// replace Src-free payloads with void.
test MapRestrict {
    const Src = struct { a: u32 };
    const Dst = struct { b: u32 };
    const SrcToDst = MapRestrict(Src, Dst);

    const Mixed = struct { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const mapped: SrcToDst.map(Mixed) = .{
        .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } },
        .maybe = null,
        .one = .{ .b = 2 },
        // Note how the mapped struct doesn't even have a `.other` field now.
    };
    _ = mapped;

    const MixedUnion = union { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const mapped_unions = [_]SrcToDst.map(MixedUnion){
        .{ .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } } },
        .{ .one = .{ .b = 2 } },
        .{ .maybe = null },
        // Note how `.other` now carries a void payload.
        .{ .other = {} },
    };
    _ = mapped_unions;
}
/// Recursively visit the given struct and calls the callback for each K found.
/// The `v` parameter must be a pointer, and tensor data needs to be mutable if the callback needs it.
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
@ -554,7 +712,7 @@ pub fn Contains(Haystack: type, T: type) bool {
return switch (@typeInfo(Haystack)) {
.@"struct" => |info| {
inline for (info.fields) |field| {
if (Contains(field.type, T))
if (!field.is_comptime and Contains(field.type, T))
return true;
}
return false;

View File

@ -1011,21 +1011,21 @@ test FnCache {
}
};
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
const nn: zml.Bufferized(NN) = .{
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
const nn: zml.testing.BufferizedWithArgs(NN) = .{
.layers = .{
.{
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
.w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }),
.b = try .fromArray(platform, [2]f16{ 0, 0 }),
},
.{
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
.w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }),
.b = try .fromArray(platform, [2]f16{ 10, 10 }),
},
// third layer is different
.{
.w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
.b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
.w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }),
.b = try .fromArray(platform, [3]f16{ -10, -10, -10 }),
},
},
};
@ -1084,13 +1084,13 @@ test "FnCache with mixed integer/tensor" {
}
};
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
const nn: zml.Bufferized(NN) = .{
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
const nn: zml.testing.BufferizedWithArgs(NN) = .{
.layers = .{
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }) },
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }) },
// third layer has different shape
.{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
.{ .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }) },
},
};
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });

View File

@ -401,7 +401,7 @@ test "real/img" {
{
const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
defer mod.deinit();
const ret = mod.call(.{});
const ret = mod.call({});
try testing.expectEqual(20, ret.getValue(i32));
}
const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
@ -1106,11 +1106,11 @@ test sdpaMemEfficient {
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd });
const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
const q = rng.call({}).withTags(.{ .b, .h, .q, .hd });
const k = rng.call({}).withTags(.{ .b, .h, .k, .hd });
const v = rng.call({}).withTags(.{ .b, .h, .k, .hd });
const mask = rng_mask.call({}).withTags(.{ .q, .k });
const ref_res = try zml.testing.compileAndCall(
platform,
@ -1164,11 +1164,11 @@ test "sdpaMemEfficient transposed" {
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd });
const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
const q = rng.call({}).withTags(.{ .b, .q, .h, .hd });
const k = rng.call({}).withTags(.{ .b, .k, .h, .hd });
const v = rng.call({}).withTags(.{ .b, .k, .h, .hd });
const mask = rng_mask.call({}).withTags(.{ .q, .k });
const ref_res = try zml.testing.compileAndCall(
platform,
@ -1266,7 +1266,7 @@ test sampleTokens {
const logits, const expected: i32 = logits_expected;
var logits_buff = try zml.Buffer.fromArray(platform, logits);
defer logits_buff.deinit();
var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
var sampled, rng_buff = mod.call(.{ logits_buff, rng_buff });
defer sampled.deinit();
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
}
@ -1304,7 +1304,6 @@ pub const DynamicSamplingStrategy = struct {
opts: Opts,
) !zml.Bufferized(DynamicSamplingStrategy) {
return .{
.max_top_k = 0,
.top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
.temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
.top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),

View File

@ -108,13 +108,15 @@ test "simple while" {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
const counter: zml.Bufferized(CountInts) = .{
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
};
const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum });
const res0, const res1 = try zml.testing.compileAndCall(
platform,
CountInts._fwd,
.{
.{ .step = try .scalar(platform, 1, .i64), .end = try .scalar(platform, 10, .i64) },
try .scalar(platform, 0, .i64),
try .scalar(platform, 0, .i64),
},
);
const last_i = try res0.getValue(i64);
const sum = try res1.getValue(i64);

View File

@ -532,7 +532,7 @@ pub const Tensor = struct {
}
pub const Rng = struct {
_state: Tensor,
_state: Tensor = .{ ._shape = .init(.{2}, .u64), ._id = .{ .buffer_id = 0 } },
algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
pub fn shape() ShapeOf(Rng) {
@ -542,10 +542,8 @@ pub const Tensor = struct {
}
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
const bits: [2]u64 = @bitCast(seed);
return .{
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
.algorithm = undefined,
._state = try Buffer.fromBytes(platform, Rng.shape()._state, std.mem.asBytes(&seed)),
};
}
@ -643,17 +641,15 @@ pub const Tensor = struct {
const platform = zml.testing.env();
// Compute stats over a uniform distribution on [-2, 10].
const rand, const stats = try zml.testing.compileAndCall(
const rand, const stats = try zml.testing.compileAndCallWithTensors(
platform,
Stats.uniformStats,
.{
try Rng.init(platform, 1234),
Shape.init(.{1024}, .f32),
.{ .min = -2, .max = 10 },
},
.{ Rng.shape(), zml.Shape.init(.{1024}, .f32), .{ .min = -2, .max = 10 } },
.{try Rng.init(platform, 1234)},
);
// Check the Rng state has been modified.
try std.testing.expect(try rand._state.getValue(i128) != 1234);
try std.testing.expect(try rand._state.getValue(u128) != 1234);
// Check the mean and variance are close to theoretical values.
const mean_ = try stats.mean.getValue(f32);
@ -746,9 +742,12 @@ pub const Tensor = struct {
const platform = zml.testing.env();
const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{
try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform),
});
const rand, const stats = try zml.testing.compileAndCallWithTensors(
platform,
Stats.gumbelStats,
.{ Rng.shape(), zml.Shape.init(.{tgt_dist.len}, .f32) },
.{ try Rng.init(platform, 1234), try .fromArray(platform, tgt_dist) },
);
// Check the Rng state has been modified.
try std.testing.expect(try rand._state.getValue(i128) != 1234);
@ -1620,15 +1619,15 @@ pub const Tensor = struct {
};
{
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
}
{
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
}
{
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
}
}
@ -3965,13 +3964,12 @@ test "Tensor.maxPool2d" {
);
}
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
pub fn Bufferized(comptime T: type) type {
// TODO: we should strip out the non-buffer fields.
// Currently it's confusing cause the Bufferized struct contains field that are never read.
// Also it will simplify the layout of the Bufferized struct.
// accelerating the calls to execute.
return meta.MapType(Tensor, Buffer).map(T);
return meta.MapRestrict(Tensor, Buffer).map(T);
}
/// Return a clone of a type with Tensors replaced by Shapes.

View File

@ -112,13 +112,30 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
return error.TestExpectedEqual;
}
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
/// This is similar to zml.Bufferized,
/// but also keeps every other field that could be used during compilation.
///
/// see `compileAndCall`.
pub fn BufferizedWithArgs(comptime T: type) type {
// Unlike `zml.Bufferized` (which strips non-tensor fields), this variant
// intentionally keeps every non-tensor field so tests can pass compile-time
// arguments alongside the buffers in one struct.
return meta.MapType(zml.Tensor, zml.Buffer).map(T);
}
/// Compile a function and immediately call it with the given buffers.
/// The compiled module is discarded after the call.
/// Useful during testing when a module is typically called only once.
///
/// Note: `func` needs explicit types on all parameters.
/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types.
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
pub fn compileAndCall(
platform: zml.Platform,
func: anytype,
buffer_and_args: BufferizedWithArgs(stdx.meta.FnArgs(func)),
) !zml.Bufferized(stdx.meta.FnResult(func)) {
// This simplify test API and also ensure this fn isn't used outside of tests.
const allocator = std.testing.allocator;
var arena = std.heap.ArenaAllocator.init(allocator);
@ -130,18 +147,30 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
}
};
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_and_args, &shape_args);
const mod = try zml.compileFn(allocator, func, shape_args, platform);
var mod = try zml.compileFn(allocator, func, shape_args, platform);
defer mod.deinit();
return mod.call(buffer_args);
// Note: we don't use the type safe API of mod,
// cause mod.call expects a `zml.Bufferized` while we have `BufferizedWithArgs`.
mod.inner.prepare(buffer_and_args);
mod.inner._unsafeCall();
var result: zml.Bufferized(stdx.meta.FnResult(func)) = undefined;
mod.inner._unsafeAssignResults(@TypeOf(result), &result);
return result;
}
/// Compile a function and immediately call it with the given buffers.
/// The compiled module is discarded after the call.
/// Useful during testing when a module is typically called only once.
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
pub fn compileAndCallWithTensors(
platform: zml.Platform,
func: anytype,
shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)),
buffer_args: zml.Bufferized(stdx.meta.FnArgs(func)),
) !zml.Bufferized(stdx.meta.FnResult(func)) {
// This simplify test API and also ensure this fn isn't used outside of tests.
const allocator = std.testing.allocator;
var arena = std.heap.ArenaAllocator.init(allocator);

View File

@ -1,8 +1,8 @@
const std = @import("std");
const stdx = @import("stdx");
const zml = @import("zml.zig");
const Tensor = zml.Tensor;
const log = std.log.scoped(.zml_torch);
@ -141,16 +141,14 @@ test pixelShuffle {
const platform = zml.testing.env();
const upscale_factor = 3;
const shape = zml.Shape.init(.{ .b = 1, .c = 9, .h = 4, .w = 4 }, .i32);
const input = input: {
var digits: [9 * 4 * 4]i32 = undefined;
for (&digits, 0..) |*d, i| d.* = @intCast(i);
// TODO should we have tags in buffers ?
const input = try zml.Buffer.fromSlice(platform, .{ 1, 9, 4, 4 }, &digits);
const output = try zml.testing.compileAndCallWithTensors(
platform,
pixelShuffle,
.{ zml.Shape.init(.{ .batch_size = 1, .c = 9, .h = 4, .w = 4 }, .i32), upscale_factor },
.{ input, upscale_factor },
);
break :input try zml.Buffer.fromSlice(platform, shape, &digits);
};
const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor });
const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
.{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },