zml: eliminate compile-time fields from Bufferized, removing the need to pass undefined to exe.call for inlined arguments. Introduce BufferizedWithArgs in zml.testing for compileAndCall utility.
This commit is contained in:
parent
364a222dc1
commit
f5ab2c3a55
61
zml/exe.zig
61
zml/exe.zig
@ -231,6 +231,35 @@ pub const BaseExe = struct {
|
||||
}
|
||||
}
|
||||
|
||||
/// Fills every `Buffer` found in `result` from the raw outputs stored in this executable.
///
/// `result` is walked with `meta.visit`. The n-th `Buffer` encountered is rebuilt from
/// the n-th entry of every element of `self.output_per_device` (one shard per element)
/// together with `self.result_shapes[n]`.
///
/// `T` must contain exactly `self.result_shapes.len` buffers, otherwise the internal assert fires.
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        outputs: []const [*]*pjrt.Buffer,
        output_shapes: []Shape,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = self.platform,
        .outputs = self.output_per_device,
        .output_shapes = self.result_shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            ctx.index += 1;
            // Keep counting past the end so the assert below can report the mismatch,
            // but never index out of bounds.
            if (i >= ctx.output_shapes.len) return;

            var shards: Buffer.Shards = .{};
            for (ctx.outputs) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
        }
    }).cb, &local_ctx, result);
    // Report the number of returned tensors (`result_shapes.len`), which is what the
    // condition compares against; `output_per_device.len` counts shards, not tensors.
    stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
}
|
||||
|
||||
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
||||
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
||||
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||
@ -305,7 +334,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
||||
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
||||
self.inner._unsafeCall();
|
||||
var result: Bufferized(ReturnT) = undefined;
|
||||
assignRawBuffers(&result, self.inner.platform, self.inner.output_per_device, self.inner.result_shapes);
|
||||
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
|
||||
return result;
|
||||
}
|
||||
};
|
||||
@ -349,36 +378,6 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
|
||||
return context.index;
|
||||
}
|
||||
|
||||
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
///
/// `v` is walked with `meta.visit`. The n-th `Buffer` encountered is rebuilt from the n-th entry
/// of every element of `buffers` (one shard per element) together with `buffer_shapes[n]`.
/// `v` must contain exactly `buffer_shapes.len` buffers, otherwise the internal assert fires.
fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        buffers: []const [*]*pjrt.Buffer,
        buffer_shapes: []Shape,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = platform,
        .buffers = buffers,
        .buffer_shapes = buffer_shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            ctx.index += 1;
            // Keep counting past the end so the assert below can report the mismatch,
            // but never index out of bounds.
            if (i >= ctx.buffer_shapes.len) return;

            var shards: Buffer.Shards = .{};
            for (ctx.buffers) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
        }
    }).cb, &local_ctx, v);
    // Report the number of returned tensors (`buffer_shapes.len`), which is what the
    // condition compares against; `buffers.len` counts shards, not tensors.
    stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffer_shapes.len, @typeName(@TypeOf(v)), local_ctx.index });
}
|
||||
|
||||
fn prettyFnName(
|
||||
comptime func: anytype,
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
164
zml/meta.zig
164
zml/meta.zig
@ -1,5 +1,6 @@
|
||||
const std = @import("std");
|
||||
const testing = std.testing;
|
||||
const builtin = @import("builtin");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
const FnParam = stdx.meta.FnParam;
|
||||
@ -9,6 +10,7 @@ test {
|
||||
std.testing.refAllDecls(@This());
|
||||
}
|
||||
|
||||
/// Visit a given type `T` and replace all fields containing `From` by fields containing `To`.
|
||||
pub fn MapType(From: type, To: type) type {
|
||||
return struct {
|
||||
pub fn map(T: type) type {
|
||||
@ -78,6 +80,32 @@ pub fn MapType(From: type, To: type) type {
|
||||
};
|
||||
}
|
||||
|
||||
test MapType {
    // Two distinct marker types: every `Src` is expected to be rewritten to `Dst`.
    const Src = struct { a: u32 };
    const Dst = struct { b: u32 };
    const Mapper = MapType(Src, Dst);

    // Struct case: fields holding `Src` now hold `Dst`; the plain `u32` field is kept as is.
    const InputStruct = struct { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const MappedStruct = Mapper.map(InputStruct);
    const mapped_struct: MappedStruct = .{
        .other = 43,
        .one = .{ .b = 2 },
        .maybe = null,
        .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } },
    };
    _ = mapped_struct;

    // Union case.
    // TODO(corendos) fixme, the mapped union should contain `Dst`s, not `Src`s.
    const InputUnion = union { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const MappedUnion = Mapper.map(InputUnion);
    const mapped_unions: [4]MappedUnion = .{
        .{ .some = &[_]Src{ .{ .a = 0 }, .{ .a = 1 } } },
        .{ .one = .{ .a = 2 } },
        .{ .maybe = null },
        .{ .other = 43 },
    };
    _ = mapped_unions;
}
|
||||
|
||||
/// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct
|
||||
/// and calls the callback when it finds a `From` element, and writes it to the `to` struct.
|
||||
/// The `to` parameter must be passed with mutable pointer, and tensor data need to be mutable if callback needs it.
|
||||
@ -136,10 +164,10 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
||||
@field(from, field.name),
|
||||
&@field(to, field.name),
|
||||
);
|
||||
} else if (field.default_value) |_| {
|
||||
} else if (field.default_value_ptr) |_| {
|
||||
@field(to, field.name) = null;
|
||||
} else {
|
||||
stdx.debug.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
|
||||
stdx.debug.compileError("Mapping {} -> {} inside {} failed. Missing field {s} in {}", .{ From, To, FromStruct, field.name, ToStruct });
|
||||
},
|
||||
else => @field(to, field.name) = @field(from, field.name),
|
||||
}
|
||||
@ -234,6 +262,136 @@ test mapAlloc {
|
||||
try testing.expectEqual(12, bb.static_slice[1].b);
|
||||
}
|
||||
|
||||
/// Visit a given type `T` and:
/// * replace all fields containing `From` by fields containing `To`
/// * drop all fields not containing any `From`.
/// The returned type will contain only `To`, making it easy for the compiler to produce a compact layout.
/// Used by `zml.Bufferized` to strip compile time arguments from a model struct.
///
/// A type that doesn't contain any `From` is mapped to `void`.
pub fn MapRestrict(From: type, To: type) type {
    return struct {
        /// Returns the restricted mirror of `T`, see `MapRestrict`.
        pub fn map(T: type) type {
            // Direct occurrences of `From` and its simple wrappers.
            switch (T) {
                From => return To,
                ?From => return ?To,
                *From => return *To,
                *const From => return *const To,
                []From => return []To,
                []const From => return []const To,
                else => {},
            }

            // Nothing to map inside `T`: collapse it to `void`.
            if (!Contains(T, From)) return void;

            return switch (@typeInfo(T)) {
                .@"struct" => |struct_infos| {
                    // We know that at least one of the struct fields contains a From.
                    // We map each field individually. Fields without From and comptime fields are removed.
                    const fields = struct_infos.fields;
                    var num_fields: usize = 0;

                    var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
                    for (fields) |field| {
                        if (!field.is_comptime and Contains(field.type, From)) {
                            const R = map(field.type);
                            // Tuple fields are named after their position. Since fields can be
                            // dropped, a kept field must take the name of its new position,
                            // whether or not its type changed.
                            const name = if (struct_infos.is_tuple) struct_infos.fields[num_fields].name else field.name;
                            if (R == field.type) {
                                struct_fields[num_fields] = field;
                                struct_fields[num_fields].name = name;
                            } else {
                                struct_fields[num_fields] = .{
                                    .name = name,
                                    .type = R,
                                    .default_value_ptr = null,
                                    .is_comptime = false,
                                    .alignment = @alignOf(R),
                                };
                                // Handle the case `field: ?Tensor = null`
                                // Generic handling of default value is not possible.
                                if (R == ?To) {
                                    struct_fields[num_fields].default_value_ptr = &@as(R, null);
                                }
                            }
                            num_fields += 1;
                        }
                    }
                    if (num_fields == 0) return void;
                    return @Type(.{ .@"struct" = .{
                        .layout = .auto,
                        .fields = struct_fields[0..num_fields],
                        .decls = &.{},
                        .is_tuple = struct_infos.is_tuple,
                    } });
                },
                .@"union" => |union_info| {
                    // We know that at least one of the union fields contains a From.
                    // We map each field individually. Fields without From are replaced by "void".
                    const fields = union_info.fields;
                    var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
                    for (0.., fields) |i, field| {
                        union_fields[i] = .{
                            .name = field.name,
                            .type = map(field.type),
                            .alignment = 0,
                        };
                    }
                    return @Type(.{ .@"union" = .{
                        .layout = .auto,
                        .tag_type = union_info.tag_type,
                        .fields = union_fields[0..],
                        .decls = &.{},
                    } });
                },
                .array => |arr_info| [arr_info.len]map(arr_info.child),
                .pointer => |ptr_info| switch (ptr_info.size) {
                    .slice => if (ptr_info.is_const)
                        []const map(ptr_info.child)
                    else
                        []map(ptr_info.child),
                    .one => if (ptr_info.is_const)
                        *const map(ptr_info.child)
                    else
                        *map(ptr_info.child),
                    .many => if (ptr_info.is_const)
                        [*]const map(ptr_info.child)
                    else
                        [*]map(ptr_info.child),
                    // Keep constness for C pointers too, like the other pointer sizes.
                    .c => if (ptr_info.is_const)
                        [*c]const map(ptr_info.child)
                    else
                        [*c]map(ptr_info.child),
                },
                .optional => |opt_info| ?map(opt_info.child),
                else => T,
            };
        }
    };
}
|
||||
|
||||
test MapRestrict {
    // Two distinct marker types: every `Src` is expected to be rewritten to `Dst`.
    const Src = struct { a: u32 };
    const Dst = struct { b: u32 };
    const Mapper = MapRestrict(Src, Dst);

    // Struct case: only the fields that held a `Src` survive.
    const InputStruct = struct { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const MappedStruct = Mapper.map(InputStruct);
    const mapped_struct: MappedStruct = .{
        // Note how the mapped struct doesn't even have a .other field now.
        .one = .{ .b = 2 },
        .maybe = null,
        .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } },
    };
    _ = mapped_struct;

    // Union case: every variant is kept, the ones without `Src` become `void`.
    const InputUnion = union { some: []const Src, one: Src, maybe: ?Src, other: u32 };
    const MappedUnion = Mapper.map(InputUnion);
    const mapped_unions: [4]MappedUnion = .{
        .{ .some = &[_]Dst{ .{ .b = 0 }, .{ .b = 1 } } },
        .{ .one = .{ .b = 2 } },
        .{ .maybe = null },
        // Note how the .other variant is void now.
        .{ .other = {} },
    };
    _ = mapped_unions;
}
|
||||
|
||||
/// Recursively visit the given struct and calls the callback for each K found.
|
||||
/// The `v` parameter must be a pointer, and tensor data needs to be mutable if the callback needs it.
|
||||
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||
@ -554,7 +712,7 @@ pub fn Contains(Haystack: type, T: type) bool {
|
||||
return switch (@typeInfo(Haystack)) {
|
||||
.@"struct" => |info| {
|
||||
inline for (info.fields) |field| {
|
||||
if (Contains(field.type, T))
|
||||
if (!field.is_comptime and Contains(field.type, T))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
||||
@ -1011,21 +1011,21 @@ test FnCache {
|
||||
}
|
||||
};
|
||||
|
||||
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
|
||||
const nn: zml.Bufferized(NN) = .{
|
||||
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
|
||||
const nn: zml.testing.BufferizedWithArgs(NN) = .{
|
||||
.layers = .{
|
||||
.{
|
||||
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
|
||||
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
|
||||
.w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }),
|
||||
.b = try .fromArray(platform, [2]f16{ 0, 0 }),
|
||||
},
|
||||
.{
|
||||
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
|
||||
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
|
||||
.w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }),
|
||||
.b = try .fromArray(platform, [2]f16{ 10, 10 }),
|
||||
},
|
||||
// third layer is different
|
||||
.{
|
||||
.w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
|
||||
.b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
|
||||
.w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }),
|
||||
.b = try .fromArray(platform, [3]f16{ -10, -10, -10 }),
|
||||
},
|
||||
},
|
||||
};
|
||||
@ -1084,13 +1084,13 @@ test "FnCache with mixed integer/tensor" {
|
||||
}
|
||||
};
|
||||
|
||||
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
|
||||
const nn: zml.Bufferized(NN) = .{
|
||||
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
|
||||
const nn: zml.testing.BufferizedWithArgs(NN) = .{
|
||||
.layers = .{
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
|
||||
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }) },
|
||||
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }) },
|
||||
// third layer has different shape
|
||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
|
||||
.{ .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }) },
|
||||
},
|
||||
};
|
||||
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
|
||||
|
||||
25
zml/nn.zig
25
zml/nn.zig
@ -401,7 +401,7 @@ test "real/img" {
|
||||
{
|
||||
const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
|
||||
defer mod.deinit();
|
||||
const ret = mod.call(.{});
|
||||
const ret = mod.call({});
|
||||
try testing.expectEqual(20, ret.getValue(i32));
|
||||
}
|
||||
const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
|
||||
@ -1106,11 +1106,11 @@ test sdpaMemEfficient {
|
||||
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||
defer rng_mask.deinit();
|
||||
|
||||
// Note: it's fine to pass undefined here, because the arguments have already been baked into the executable.
|
||||
const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd });
|
||||
const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
|
||||
const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
|
||||
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
|
||||
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
|
||||
const q = rng.call({}).withTags(.{ .b, .h, .q, .hd });
|
||||
const k = rng.call({}).withTags(.{ .b, .h, .k, .hd });
|
||||
const v = rng.call({}).withTags(.{ .b, .h, .k, .hd });
|
||||
const mask = rng_mask.call({}).withTags(.{ .q, .k });
|
||||
|
||||
const ref_res = try zml.testing.compileAndCall(
|
||||
platform,
|
||||
@ -1164,11 +1164,11 @@ test "sdpaMemEfficient transposed" {
|
||||
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||
defer rng_mask.deinit();
|
||||
|
||||
// Note: it's fine to pass undefined here, because the arguments have already been baked into the executable.
|
||||
const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd });
|
||||
const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
|
||||
const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
|
||||
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
|
||||
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
|
||||
const q = rng.call({}).withTags(.{ .b, .q, .h, .hd });
|
||||
const k = rng.call({}).withTags(.{ .b, .k, .h, .hd });
|
||||
const v = rng.call({}).withTags(.{ .b, .k, .h, .hd });
|
||||
const mask = rng_mask.call({}).withTags(.{ .q, .k });
|
||||
|
||||
const ref_res = try zml.testing.compileAndCall(
|
||||
platform,
|
||||
@ -1266,7 +1266,7 @@ test sampleTokens {
|
||||
const logits, const expected: i32 = logits_expected;
|
||||
var logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||
defer logits_buff.deinit();
|
||||
var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
|
||||
var sampled, rng_buff = mod.call(.{ logits_buff, rng_buff });
|
||||
defer sampled.deinit();
|
||||
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
|
||||
}
|
||||
@ -1304,7 +1304,6 @@ pub const DynamicSamplingStrategy = struct {
|
||||
opts: Opts,
|
||||
) !zml.Bufferized(DynamicSamplingStrategy) {
|
||||
return .{
|
||||
.max_top_k = 0,
|
||||
.top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
|
||||
.temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
|
||||
.top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),
|
||||
|
||||
16
zml/ops.zig
16
zml/ops.zig
@ -108,13 +108,15 @@ test "simple while" {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
||||
const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
||||
const counter: zml.Bufferized(CountInts) = .{
|
||||
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
|
||||
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
|
||||
};
|
||||
const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum });
|
||||
const res0, const res1 = try zml.testing.compileAndCall(
|
||||
platform,
|
||||
CountInts._fwd,
|
||||
.{
|
||||
.{ .step = try .scalar(platform, 1, .i64), .end = try .scalar(platform, 10, .i64) },
|
||||
try .scalar(platform, 0, .i64),
|
||||
try .scalar(platform, 0, .i64),
|
||||
},
|
||||
);
|
||||
const last_i = try res0.getValue(i64);
|
||||
const sum = try res1.getValue(i64);
|
||||
|
||||
|
||||
@ -532,7 +532,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
pub const Rng = struct {
|
||||
_state: Tensor,
|
||||
_state: Tensor = .{ ._shape = .init(.{2}, .u64), ._id = .{ .buffer_id = 0 } },
|
||||
algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
|
||||
|
||||
pub fn shape() ShapeOf(Rng) {
|
||||
@ -542,10 +542,8 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
|
||||
const bits: [2]u64 = @bitCast(seed);
|
||||
return .{
|
||||
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
|
||||
.algorithm = undefined,
|
||||
._state = try Buffer.fromBytes(platform, Rng.shape()._state, std.mem.asBytes(&seed)),
|
||||
};
|
||||
}
|
||||
|
||||
@ -643,17 +641,15 @@ pub const Tensor = struct {
|
||||
|
||||
const platform = zml.testing.env();
|
||||
// Compute stats over a uniform distribution on [-2, 10].
|
||||
const rand, const stats = try zml.testing.compileAndCall(
|
||||
const rand, const stats = try zml.testing.compileAndCallWithTensors(
|
||||
platform,
|
||||
Stats.uniformStats,
|
||||
.{
|
||||
try Rng.init(platform, 1234),
|
||||
Shape.init(.{1024}, .f32),
|
||||
.{ .min = -2, .max = 10 },
|
||||
},
|
||||
.{ Rng.shape(), zml.Shape.init(.{1024}, .f32), .{ .min = -2, .max = 10 } },
|
||||
.{try Rng.init(platform, 1234)},
|
||||
);
|
||||
|
||||
// Check the Rng state has been modified.
|
||||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
||||
try std.testing.expect(try rand._state.getValue(u128) != 1234);
|
||||
|
||||
// Check the mean and variance are close to theoretical values.
|
||||
const mean_ = try stats.mean.getValue(f32);
|
||||
@ -746,9 +742,12 @@ pub const Tensor = struct {
|
||||
|
||||
const platform = zml.testing.env();
|
||||
const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
|
||||
const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{
|
||||
try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform),
|
||||
});
|
||||
const rand, const stats = try zml.testing.compileAndCallWithTensors(
|
||||
platform,
|
||||
Stats.gumbelStats,
|
||||
.{ Rng.shape(), zml.Shape.init(.{tgt_dist.len}, .f32) },
|
||||
.{ try Rng.init(platform, 1234), try .fromArray(platform, tgt_dist) },
|
||||
);
|
||||
// Check the Rng state has been modified.
|
||||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
||||
|
||||
@ -1620,15 +1619,15 @@ pub const Tensor = struct {
|
||||
};
|
||||
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
|
||||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
|
||||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||
}
|
||||
{
|
||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
|
||||
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
|
||||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||
}
|
||||
}
|
||||
@ -3965,13 +3964,12 @@ test "Tensor.maxPool2d" {
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
pub fn Bufferized(comptime T: type) type {
|
||||
// TODO: we should strip out the non-buffer fields.
|
||||
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
||||
// Also it will simplify the layout of the Bufferized struct.
|
||||
// accelerating the calls to execute.
|
||||
return meta.MapType(Tensor, Buffer).map(T);
|
||||
return meta.MapRestrict(Tensor, Buffer).map(T);
|
||||
}
|
||||
|
||||
/// Return a clone of a type with Tensors replaced by Shapes.
|
||||
|
||||
@ -112,13 +112,30 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
|
||||
return error.TestExpectedEqual;
|
||||
}
|
||||
|
||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
||||
/// This is similar to zml.Bufferized,
|
||||
/// but also keeps every other field that could be used during compilation.
|
||||
///
|
||||
/// see `compileAndCall`.
|
||||
pub fn BufferizedWithArgs(comptime T: type) type {
|
||||
// Unlike zml.Bufferized, the non-Tensor fields are intentionally kept here:
|
||||
// meta.MapType only swaps Tensor for Buffer and leaves the other fields in place,
|
||||
// so a single value can carry both the runtime buffers
|
||||
// and the other arguments (see `compileAndCall`).
|
||||
|
||||
return meta.MapType(zml.Tensor, zml.Buffer).map(T);
|
||||
}
|
||||
|
||||
/// Compile a function and immediately call it with the given buffers.
|
||||
/// The compiled module is discarded after the call.
|
||||
/// Useful during testing when a module is typically called only once.
|
||||
///
|
||||
/// Note: `func` needs explicit types on all parameters.
|
||||
/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types.
|
||||
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||
pub fn compileAndCall(
|
||||
platform: zml.Platform,
|
||||
func: anytype,
|
||||
buffer_and_args: BufferizedWithArgs(stdx.meta.FnArgs(func)),
|
||||
) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||
// This simplifies the test API and also ensures this fn isn't used outside of tests.
|
||||
const allocator = std.testing.allocator;
|
||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||
@ -130,18 +147,30 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
|
||||
}
|
||||
};
|
||||
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
|
||||
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);
|
||||
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_and_args, &shape_args);
|
||||
|
||||
const mod = try zml.compileFn(allocator, func, shape_args, platform);
|
||||
var mod = try zml.compileFn(allocator, func, shape_args, platform);
|
||||
defer mod.deinit();
|
||||
|
||||
return mod.call(buffer_args);
|
||||
// Note: we don't use the type safe API of mod,
|
||||
// cause mod.call expects a `zml.Bufferized` while we have `BufferizedWithArgs`.
|
||||
mod.inner.prepare(buffer_and_args);
|
||||
mod.inner._unsafeCall();
|
||||
|
||||
var result: zml.Bufferized(stdx.meta.FnResult(func)) = undefined;
|
||||
mod.inner._unsafeAssignResults(@TypeOf(result), &result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Compile a function and immediately call it with the given buffers.
|
||||
/// The compiled module is discarded after the call.
|
||||
/// Useful during testing when a module is typically called only once.
|
||||
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||
pub fn compileAndCallWithTensors(
|
||||
platform: zml.Platform,
|
||||
func: anytype,
|
||||
shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)),
|
||||
buffer_args: zml.Bufferized(stdx.meta.FnArgs(func)),
|
||||
) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||
// This simplifies the test API and also ensures this fn isn't used outside of tests.
|
||||
const allocator = std.testing.allocator;
|
||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
const std = @import("std");
|
||||
|
||||
const stdx = @import("stdx");
|
||||
|
||||
const zml = @import("zml.zig");
|
||||
|
||||
const Tensor = zml.Tensor;
|
||||
|
||||
const log = std.log.scoped(.zml_torch);
|
||||
@ -141,16 +141,14 @@ test pixelShuffle {
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const upscale_factor = 3;
|
||||
var digits: [9 * 4 * 4]i32 = undefined;
|
||||
for (&digits, 0..) |*d, i| d.* = @intCast(i);
|
||||
// TODO should we have tags in buffers ?
|
||||
const input = try zml.Buffer.fromSlice(platform, .{ 1, 9, 4, 4 }, &digits);
|
||||
const output = try zml.testing.compileAndCallWithTensors(
|
||||
platform,
|
||||
pixelShuffle,
|
||||
.{ zml.Shape.init(.{ .batch_size = 1, .c = 9, .h = 4, .w = 4 }, .i32), upscale_factor },
|
||||
.{ input, upscale_factor },
|
||||
);
|
||||
const shape = zml.Shape.init(.{ .b = 1, .c = 9, .h = 4, .w = 4 }, .i32);
|
||||
const input = input: {
|
||||
var digits: [9 * 4 * 4]i32 = undefined;
|
||||
for (&digits, 0..) |*d, i| d.* = @intCast(i);
|
||||
break :input try zml.Buffer.fromSlice(platform, shape, &digits);
|
||||
};
|
||||
|
||||
const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor });
|
||||
|
||||
const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
|
||||
.{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },
|
||||
|
||||
Loading…
Reference in New Issue
Block a user