zml: eliminate compile-time fields from Bufferized, removing the need to pass undefined to exe.call for inlined arguments. Introduce BufferizedWithArgs in zml.testing for compileAndCall utility.
This commit is contained in:
parent
364a222dc1
commit
f5ab2c3a55
61
zml/exe.zig
61
zml/exe.zig
@ -231,6 +231,35 @@ pub const BaseExe = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Visit `result` and overwrite every `Buffer` found in it with a new one built
/// from the PJRT output buffers held by this executable.
///
/// The k-th `Buffer` visited is assembled from the k-th output of every entry of
/// `self.output_per_device` (one shard per entry) and is given the shape
/// `self.result_shapes[k]`.
///
/// Asserts that `T` contains exactly `self.result_shapes.len` Buffers.
/// Callers in this file invoke it right after `_unsafeCall`.
pub fn _unsafeAssignResults(self: BaseExe, T: type, result: *T) void {
    const LocalContext = struct {
        index: u32,
        platform: Platform,
        outputs: []const [*]*pjrt.Buffer,
        output_shapes: []Shape,
    };
    var local_ctx: LocalContext = .{
        .index = 0,
        .platform = self.platform,
        .outputs = self.output_per_device,
        .output_shapes = self.result_shapes,
    };
    meta.visit((struct {
        fn cb(ctx: *LocalContext, buffer: *Buffer) void {
            const i = ctx.index;
            ctx.index += 1;
            // Keep counting the Buffers of `T`, but never read past the available outputs:
            // the mismatch is reported by the assertion after the visit.
            if (i >= ctx.output_shapes.len) return;

            var shards: Buffer.Shards = .{};
            for (ctx.outputs) |buff| {
                shards.appendAssumeCapacity(buff[i]);
            }
            buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.output_shapes[i], shards.constSlice());
        }
    }).cb, &local_ctx, result);
    // Report the number of returned tensors (`result_shapes.len`), not the number of
    // entries in `output_per_device`.
    stdx.debug.internalAssert(local_ctx.index == self.result_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ self.result_shapes.len, @typeName(T), local_ctx.index });
}
|
||||||
|
|
||||||
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
pub fn serialize(self: BaseExe, writer: anytype) !void {
|
||||||
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
var executable = try self.exe.getExecutable(self.platform.pjrt_api);
|
||||||
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
var serialize_result = try executable.serialize(self.platform.pjrt_api);
|
||||||
@ -305,7 +334,7 @@ pub fn Exe(ArgsT: type, ReturnT: type) type {
|
|||||||
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
std.debug.assert(total_ready == self.inner.input_buffer_count);
|
||||||
self.inner._unsafeCall();
|
self.inner._unsafeCall();
|
||||||
var result: Bufferized(ReturnT) = undefined;
|
var result: Bufferized(ReturnT) = undefined;
|
||||||
assignRawBuffers(&result, self.inner.platform, self.inner.output_per_device, self.inner.result_shapes);
|
self.inner._unsafeAssignResults(Bufferized(ReturnT), &result);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -349,36 +378,6 @@ fn fillBuffers(v: anytype, shapes: []const Shape, buffers: []const [*]*pjrt.Buff
|
|||||||
return context.index;
|
return context.index;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
|
|
||||||
fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape) void {
|
|
||||||
const LocalContext = struct {
|
|
||||||
index: u32,
|
|
||||||
platform: Platform,
|
|
||||||
buffers: []const [*]*pjrt.Buffer,
|
|
||||||
buffer_shapes: []Shape,
|
|
||||||
};
|
|
||||||
var local_ctx: LocalContext = .{
|
|
||||||
.index = 0,
|
|
||||||
.platform = platform,
|
|
||||||
.buffers = buffers,
|
|
||||||
.buffer_shapes = buffer_shapes,
|
|
||||||
};
|
|
||||||
meta.visit((struct {
|
|
||||||
fn cb(ctx: *LocalContext, buffer: *Buffer) void {
|
|
||||||
const i = ctx.index;
|
|
||||||
ctx.index += 1;
|
|
||||||
if (i >= ctx.buffer_shapes.len) return;
|
|
||||||
|
|
||||||
var shards: Buffer.Shards = .{};
|
|
||||||
for (ctx.buffers) |buff| {
|
|
||||||
shards.appendAssumeCapacity(buff[i]);
|
|
||||||
}
|
|
||||||
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
|
|
||||||
}
|
|
||||||
}).cb, &local_ctx, v);
|
|
||||||
stdx.debug.internalAssert(local_ctx.index == buffer_shapes.len, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prettyFnName(
|
fn prettyFnName(
|
||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
allocator: std.mem.Allocator,
|
allocator: std.mem.Allocator,
|
||||||
|
|||||||
164
zml/meta.zig
164
zml/meta.zig
@ -1,5 +1,6 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const testing = std.testing;
|
const testing = std.testing;
|
||||||
|
const builtin = @import("builtin");
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const FnParam = stdx.meta.FnParam;
|
const FnParam = stdx.meta.FnParam;
|
||||||
@ -9,6 +10,7 @@ test {
|
|||||||
std.testing.refAllDecls(@This());
|
std.testing.refAllDecls(@This());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Visit a given type `T` and replace all fields containing `From` by fields containing `To`.
|
||||||
pub fn MapType(From: type, To: type) type {
|
pub fn MapType(From: type, To: type) type {
|
||||||
return struct {
|
return struct {
|
||||||
pub fn map(T: type) type {
|
pub fn map(T: type) type {
|
||||||
@ -78,6 +80,32 @@ pub fn MapType(From: type, To: type) type {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test MapType {
    const A = struct { a: u32 };
    const B = struct { b: u32 };
    const Mapper = MapType(A, B);

    // Struct case: every `A` is replaced by a `B`; the `other` field, which holds no `A`, is still present.
    const StructA = struct { some: []const A, one: A, maybe: ?A, other: u32 };
    const StructB = Mapper.map(StructA);
    const struct_b: StructB = .{
        .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } },
        .maybe = null,
        .one = .{ .b = 2 },
        .other = 43,
    };
    _ = struct_b;

    // TODO(corendos) fixme, union_b should contains Bs not As.
    const UnionA = union { some: []const A, one: A, maybe: ?A, other: u32 };
    const UnionB = Mapper.map(UnionA);
    const union_b = [_]UnionB{
        .{ .some = &[2]A{ .{ .a = 0 }, .{ .a = 1 } } },
        .{ .one = .{ .a = 2 } },
        .{ .maybe = null },
        .{ .other = 43 },
    };
    _ = union_b;
}
|
||||||
|
|
||||||
/// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct
|
/// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct
|
||||||
/// and calls the callback when it finds a `From` element, and writes it to the `to` struct.
|
/// and calls the callback when it finds a `From` element, and writes it to the `to` struct.
|
||||||
/// The `to` parameter must be passed with mutable pointer, and tensor data need to be mutable if callback needs it.
|
/// The `to` parameter must be passed with mutable pointer, and tensor data need to be mutable if callback needs it.
|
||||||
@ -136,10 +164,10 @@ pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam
|
|||||||
@field(from, field.name),
|
@field(from, field.name),
|
||||||
&@field(to, field.name),
|
&@field(to, field.name),
|
||||||
);
|
);
|
||||||
} else if (field.default_value) |_| {
|
} else if (field.default_value_ptr) |_| {
|
||||||
@field(to, field.name) = null;
|
@field(to, field.name) = null;
|
||||||
} else {
|
} else {
|
||||||
stdx.debug.compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
|
stdx.debug.compileError("Mapping {} -> {} inside {} failed. Missing field {s} in {}", .{ From, To, FromStruct, field.name, ToStruct });
|
||||||
},
|
},
|
||||||
else => @field(to, field.name) = @field(from, field.name),
|
else => @field(to, field.name) = @field(from, field.name),
|
||||||
}
|
}
|
||||||
@ -234,6 +262,136 @@ test mapAlloc {
|
|||||||
try testing.expectEqual(12, bb.static_slice[1].b);
|
try testing.expectEqual(12, bb.static_slice[1].b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Visit a given type `T` and:
/// * replace all fields containing `From` by fields containing `To`
/// * drop all fields not containing any `From`.
/// The returned type contains only `To`, making it easy for the compiler to produce a compact layout.
/// Used by `zml.Bufferized` to strip compile time arguments from a model struct.
pub fn MapRestrict(From: type, To: type) type {
    return struct {
        /// Returns the restricted mirror of `T`.
        /// Types that don't contain any `From` are mapped to `void`.
        pub fn map(T: type) type {
            // Direct hits: `From` itself or a thin wrapper around it.
            switch (T) {
                From => return To,
                ?From => return ?To,
                *From => return *To,
                *const From => return *const To,
                []From => return []To,
                []const From => return []const To,
                else => {},
            }

            if (!Contains(T, From)) return void;

            return switch (@typeInfo(T)) {
                .@"struct" => |struct_infos| {
                    // We know that at least one of the struct field contains a From.
                    // We map each field individually. Fields without From and comptime fields are removed.
                    const fields = struct_infos.fields;
                    var num_fields: usize = 0;

                    var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
                    for (fields) |field| {
                        if (field.is_comptime or !Contains(field.type, From)) continue;

                        // Dropping fields shifts the position of the following ones.
                        // Tuple fields are named after their index, so a kept field takes
                        // the name of the slot it lands in, whether or not its type changed.
                        const name = if (struct_infos.is_tuple) fields[num_fields].name else field.name;

                        const R = map(field.type);
                        if (R == field.type) {
                            struct_fields[num_fields] = field;
                            struct_fields[num_fields].name = name;
                        } else {
                            struct_fields[num_fields] = .{
                                .name = name,
                                .type = R,
                                .default_value_ptr = null,
                                .is_comptime = false,
                                .alignment = @alignOf(R),
                            };
                            // Handle the case `field: ?Tensor = null`
                            // Generic handling of default value is not possible.
                            if (R == ?To) {
                                struct_fields[num_fields].default_value_ptr = &@as(R, null);
                            }
                        }
                        num_fields += 1;
                    }
                    if (num_fields == 0) return void;
                    return @Type(.{ .@"struct" = .{
                        .layout = .auto,
                        .fields = struct_fields[0..num_fields],
                        .decls = &.{},
                        .is_tuple = struct_infos.is_tuple,
                    } });
                },
                .@"union" => |union_info| {
                    // We know that at least one of the union field contains a From.
                    // We map each field individually. Fields without From, are replaced by "void".
                    const fields = union_info.fields;
                    var union_fields: [fields.len]std.builtin.Type.UnionField = undefined;
                    for (0.., fields) |i, field| {
                        union_fields[i] = .{
                            .name = field.name,
                            .type = map(field.type),
                            .alignment = 0,
                        };
                    }
                    return @Type(.{ .@"union" = .{
                        .layout = .auto,
                        .tag_type = union_info.tag_type,
                        .fields = union_fields[0..],
                        .decls = &.{},
                    } });
                },
                .array => |arr_info| [arr_info.len]map(arr_info.child),
                // Pointers keep their size and constness; only the pointee type is mapped.
                .pointer => |ptr_info| switch (ptr_info.size) {
                    .slice => if (ptr_info.is_const)
                        []const map(ptr_info.child)
                    else
                        []map(ptr_info.child),
                    .one => if (ptr_info.is_const)
                        *const map(ptr_info.child)
                    else
                        *map(ptr_info.child),
                    .many => if (ptr_info.is_const)
                        [*]const map(ptr_info.child)
                    else
                        [*]map(ptr_info.child),
                    // Const C pointers must stay const, like the other pointer sizes.
                    .c => if (ptr_info.is_const)
                        [*c]const map(ptr_info.child)
                    else
                        [*c]map(ptr_info.child),
                },
                .optional => |opt_info| ?map(opt_info.child),
                else => T,
            };
        }
    };
}
|
||||||
|
|
||||||
|
test MapRestrict {
    const A = struct { a: u32 };
    const B = struct { b: u32 };
    const Mapper = MapRestrict(A, B);

    // Struct case: every `A` is replaced by a `B`.
    const StructA = struct { some: []const A, one: A, maybe: ?A, other: u32 };
    const StructB = Mapper.map(StructA);
    const struct_b: StructB = .{
        .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } },
        .maybe = null,
        .one = .{ .b = 2 },
        // Note how struct_b doesn't even have a .other field now.
    };
    _ = struct_b;

    // Union case: the variants are all kept, and each payload is mapped.
    const UnionA = union { some: []const A, one: A, maybe: ?A, other: u32 };
    const UnionB = Mapper.map(UnionA);
    const union_b = [_]UnionB{
        .{ .some = &[2]B{ .{ .b = 0 }, .{ .b = 1 } } },
        .{ .one = .{ .b = 2 } },
        .{ .maybe = null },
        // Note how union_b.other is void now.
        .{ .other = {} },
    };
    _ = union_b;
}
|
||||||
|
|
||||||
/// Recursively visit the given struct and calls the callback for each K found.
|
/// Recursively visit the given struct and calls the callback for each K found.
|
||||||
/// The `v` parameter must be a pointer, and tensor data needs to be mutable if the callback needs it.
|
/// The `v` parameter must be a pointer, and tensor data needs to be mutable if the callback needs it.
|
||||||
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
|
||||||
@ -554,7 +712,7 @@ pub fn Contains(Haystack: type, T: type) bool {
|
|||||||
return switch (@typeInfo(Haystack)) {
|
return switch (@typeInfo(Haystack)) {
|
||||||
.@"struct" => |info| {
|
.@"struct" => |info| {
|
||||||
inline for (info.fields) |field| {
|
inline for (info.fields) |field| {
|
||||||
if (Contains(field.type, T))
|
if (!field.is_comptime and Contains(field.type, T))
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
|
|||||||
@ -1011,21 +1011,21 @@ test FnCache {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
|
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
|
||||||
const nn: zml.Bufferized(NN) = .{
|
const nn: zml.testing.BufferizedWithArgs(NN) = .{
|
||||||
.layers = .{
|
.layers = .{
|
||||||
.{
|
.{
|
||||||
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }),
|
.w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }),
|
||||||
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 0, 0 }),
|
.b = try .fromArray(platform, [2]f16{ 0, 0 }),
|
||||||
},
|
},
|
||||||
.{
|
.{
|
||||||
.w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }),
|
.w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }),
|
||||||
.b = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ 10, 10 }),
|
.b = try .fromArray(platform, [2]f16{ 10, 10 }),
|
||||||
},
|
},
|
||||||
// third layer is different
|
// third layer is different
|
||||||
.{
|
.{
|
||||||
.w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }),
|
.w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }),
|
||||||
.b = try zml.Buffer.fromSlice(platform, .{3}, &[_]f16{ -10, -10, -10 }),
|
.b = try .fromArray(platform, [3]f16{ -10, -10, -10 }),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -1084,13 +1084,13 @@ test "FnCache with mixed integer/tensor" {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const x = try zml.Buffer.fromSlice(platform, .{2}, &[_]f16{ -1, 1 });
|
const x = try zml.Buffer.fromArray(platform, [2]f16{ -1, 1 });
|
||||||
const nn: zml.Bufferized(NN) = .{
|
const nn: zml.testing.BufferizedWithArgs(NN) = .{
|
||||||
.layers = .{
|
.layers = .{
|
||||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, -1, 0, 1 }) },
|
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, -1 }, .{ 0, 1 } }) },
|
||||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[_]f16{ 1, 2, 1, -1 }) },
|
.{ .w = try .fromArray(platform, [2][2]f16{ .{ 1, 2 }, .{ 1, -1 } }) },
|
||||||
// third layer has different shape
|
// third layer has different shape
|
||||||
.{ .w = try zml.Buffer.fromSlice(platform, .{ 3, 2 }, &[_]f16{ 1, 2, 0, 1, -1, 0 }) },
|
.{ .w = try .fromArray(platform, [3][2]f16{ .{ 1, 2 }, .{ 0, 1 }, .{ -1, 0 } }) },
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
|
const res = try zml.testing.compileAndCall(platform, NN._fwd, .{ nn, x });
|
||||||
|
|||||||
25
zml/nn.zig
25
zml/nn.zig
@ -401,7 +401,7 @@ test "real/img" {
|
|||||||
{
|
{
|
||||||
const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
|
const mod = try zml.compileFn(std.testing.allocator, Fns.testSplitSeq, .{}, platform);
|
||||||
defer mod.deinit();
|
defer mod.deinit();
|
||||||
const ret = mod.call(.{});
|
const ret = mod.call({});
|
||||||
try testing.expectEqual(20, ret.getValue(i32));
|
try testing.expectEqual(20, ret.getValue(i32));
|
||||||
}
|
}
|
||||||
const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
|
const d_split_interleaved = try zml.testing.compileAndCall(platform, Fns.testSplitInterleaved, .{});
|
||||||
@ -1106,11 +1106,11 @@ test sdpaMemEfficient {
|
|||||||
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||||
defer rng_mask.deinit();
|
defer rng_mask.deinit();
|
||||||
|
|
||||||
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
|
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
|
||||||
const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd });
|
const q = rng.call({}).withTags(.{ .b, .h, .q, .hd });
|
||||||
const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
|
const k = rng.call({}).withTags(.{ .b, .h, .k, .hd });
|
||||||
const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
|
const v = rng.call({}).withTags(.{ .b, .h, .k, .hd });
|
||||||
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
|
const mask = rng_mask.call({}).withTags(.{ .q, .k });
|
||||||
|
|
||||||
const ref_res = try zml.testing.compileAndCall(
|
const ref_res = try zml.testing.compileAndCall(
|
||||||
platform,
|
platform,
|
||||||
@ -1164,11 +1164,11 @@ test "sdpaMemEfficient transposed" {
|
|||||||
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
|
||||||
defer rng_mask.deinit();
|
defer rng_mask.deinit();
|
||||||
|
|
||||||
// Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable.
|
// Note: we pass void here, cause Rng.normal doesn't take any runtime inputs.
|
||||||
const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd });
|
const q = rng.call({}).withTags(.{ .b, .q, .h, .hd });
|
||||||
const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
|
const k = rng.call({}).withTags(.{ .b, .k, .h, .hd });
|
||||||
const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
|
const v = rng.call({}).withTags(.{ .b, .k, .h, .hd });
|
||||||
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
|
const mask = rng_mask.call({}).withTags(.{ .q, .k });
|
||||||
|
|
||||||
const ref_res = try zml.testing.compileAndCall(
|
const ref_res = try zml.testing.compileAndCall(
|
||||||
platform,
|
platform,
|
||||||
@ -1266,7 +1266,7 @@ test sampleTokens {
|
|||||||
const logits, const expected: i32 = logits_expected;
|
const logits, const expected: i32 = logits_expected;
|
||||||
var logits_buff = try zml.Buffer.fromArray(platform, logits);
|
var logits_buff = try zml.Buffer.fromArray(platform, logits);
|
||||||
defer logits_buff.deinit();
|
defer logits_buff.deinit();
|
||||||
var sampled, rng_buff = mod.call(.{ logits_buff, undefined, rng_buff });
|
var sampled, rng_buff = mod.call(.{ logits_buff, rng_buff });
|
||||||
defer sampled.deinit();
|
defer sampled.deinit();
|
||||||
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
|
try zml.testing.expectEqual(expected, try sampled.getValue(i32));
|
||||||
}
|
}
|
||||||
@ -1304,7 +1304,6 @@ pub const DynamicSamplingStrategy = struct {
|
|||||||
opts: Opts,
|
opts: Opts,
|
||||||
) !zml.Bufferized(DynamicSamplingStrategy) {
|
) !zml.Bufferized(DynamicSamplingStrategy) {
|
||||||
return .{
|
return .{
|
||||||
.max_top_k = 0,
|
|
||||||
.top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
|
.top_k = try zml.Buffer.scalar(platform, opts.top_k, .i32),
|
||||||
.temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
|
.temperature = try zml.Buffer.scalar(platform, opts.temperature, dtype),
|
||||||
.top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),
|
.top_p = try zml.Buffer.scalar(platform, opts.top_p, dtype),
|
||||||
|
|||||||
16
zml/ops.zig
16
zml/ops.zig
@ -108,13 +108,15 @@ test "simple while" {
|
|||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
const init_i = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
const res0, const res1 = try zml.testing.compileAndCall(
|
||||||
const init_sum = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{0});
|
platform,
|
||||||
const counter: zml.Bufferized(CountInts) = .{
|
CountInts._fwd,
|
||||||
.step = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{1}),
|
.{
|
||||||
.end = try zml.Buffer.fromSlice(platform, .{}, &[_]i64{10}),
|
.{ .step = try .scalar(platform, 1, .i64), .end = try .scalar(platform, 10, .i64) },
|
||||||
};
|
try .scalar(platform, 0, .i64),
|
||||||
const res0, const res1 = try zml.testing.compileAndCall(platform, CountInts._fwd, .{ counter, init_i, init_sum });
|
try .scalar(platform, 0, .i64),
|
||||||
|
},
|
||||||
|
);
|
||||||
const last_i = try res0.getValue(i64);
|
const last_i = try res0.getValue(i64);
|
||||||
const sum = try res1.getValue(i64);
|
const sum = try res1.getValue(i64);
|
||||||
|
|
||||||
|
|||||||
@ -532,7 +532,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub const Rng = struct {
|
pub const Rng = struct {
|
||||||
_state: Tensor,
|
_state: Tensor = .{ ._shape = .init(.{2}, .u64), ._id = .{ .buffer_id = 0 } },
|
||||||
algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
|
algorithm: dialect.stablehlo.RngAlgorithm.Type = .DEFAULT,
|
||||||
|
|
||||||
pub fn shape() ShapeOf(Rng) {
|
pub fn shape() ShapeOf(Rng) {
|
||||||
@ -542,10 +542,8 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
|
pub fn init(platform: Platform, seed: u128) !Bufferized(Rng) {
|
||||||
const bits: [2]u64 = @bitCast(seed);
|
|
||||||
return .{
|
return .{
|
||||||
._state = try Buffer.fromSlice(platform, Shape.init(.{2}, .u64), &bits),
|
._state = try Buffer.fromBytes(platform, Rng.shape()._state, std.mem.asBytes(&seed)),
|
||||||
.algorithm = undefined,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -643,17 +641,15 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
// Compute stats over a uniform distribution on [-2, 10].
|
// Compute stats over a uniform distribution on [-2, 10].
|
||||||
const rand, const stats = try zml.testing.compileAndCall(
|
const rand, const stats = try zml.testing.compileAndCallWithTensors(
|
||||||
platform,
|
platform,
|
||||||
Stats.uniformStats,
|
Stats.uniformStats,
|
||||||
.{
|
.{ Rng.shape(), zml.Shape.init(.{1024}, .f32), .{ .min = -2, .max = 10 } },
|
||||||
try Rng.init(platform, 1234),
|
.{try Rng.init(platform, 1234)},
|
||||||
Shape.init(.{1024}, .f32),
|
|
||||||
.{ .min = -2, .max = 10 },
|
|
||||||
},
|
|
||||||
);
|
);
|
||||||
|
|
||||||
// Check the Rng state has been modified.
|
// Check the Rng state has been modified.
|
||||||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
try std.testing.expect(try rand._state.getValue(u128) != 1234);
|
||||||
|
|
||||||
// Check the mean and variance are close to theoretical values.
|
// Check the mean and variance are close to theoretical values.
|
||||||
const mean_ = try stats.mean.getValue(f32);
|
const mean_ = try stats.mean.getValue(f32);
|
||||||
@ -746,9 +742,12 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
|
const tgt_dist = [_]f32{ 2.0, 1.0, 4.0, 3.0 };
|
||||||
const rand, const stats = try zml.testing.compileAndCall(platform, Stats.gumbelStats, .{
|
const rand, const stats = try zml.testing.compileAndCallWithTensors(
|
||||||
try Rng.init(platform, 1234), try HostBuffer.fromArray(&tgt_dist).toDevice(platform),
|
platform,
|
||||||
});
|
Stats.gumbelStats,
|
||||||
|
.{ Rng.shape(), zml.Shape.init(.{tgt_dist.len}, .f32) },
|
||||||
|
.{ try Rng.init(platform, 1234), try .fromArray(platform, tgt_dist) },
|
||||||
|
);
|
||||||
// Check the Rng state has been modified.
|
// Check the Rng state has been modified.
|
||||||
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
try std.testing.expect(try rand._state.getValue(i128) != 1234);
|
||||||
|
|
||||||
@ -1620,15 +1619,15 @@ pub const Tensor = struct {
|
|||||||
};
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 0, .{ .end = 1 } }, .{ x, 0, .{ .end = 1 } });
|
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 0, .{ .end = 1 } });
|
||||||
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
try testing.expectEqual([5]f32{ 0, 1, 2, 3, 4 }, try res.getValue([5]f32));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), 1, .{ .start = 1, .step = 2 } }, .{ x, 0, .{ .start = 1, .step = 2 } });
|
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, 1, .{ .start = 1, .step = 2 } });
|
||||||
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
try testing.expectEqual([4]f32{ 1, 3, 6, 8 }, try res.getValue([4]f32));
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
const res = try zml.testing.compileAndCallWithTensors(platform, Local._slice1dAxis, .{ x.shape(), -1, .{ .start = -2 } }, .{ x, 0, .{ .start = -2 } });
|
const res = try zml.testing.compileAndCall(platform, Local._slice1dAxis, .{ x, -1, .{ .start = -2 } });
|
||||||
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
try testing.expectEqual([4]f32{ 3, 4, 8, 9 }, try res.getValue([4]f32));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -3965,13 +3964,12 @@ test "Tensor.maxPool2d" {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
|
|
||||||
pub fn Bufferized(comptime T: type) type {
|
pub fn Bufferized(comptime T: type) type {
|
||||||
// TODO: we should strip out the non-buffer fields.
|
// TODO: we should strip out the non-buffer fields.
|
||||||
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
// Currently it's confusing cause the Bufferized struct contains field that are never read.
|
||||||
// Also it will simplify the layout of the Bufferized struct.
|
// Also it will simplify the layout of the Bufferized struct.
|
||||||
// accelerating the calls to execute.
|
// accelerating the calls to execute.
|
||||||
return meta.MapType(Tensor, Buffer).map(T);
|
return meta.MapRestrict(Tensor, Buffer).map(T);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return a clone of a type with Tensors replaced by Shapes.
|
/// Return a clone of a type with Tensors replaced by Shapes.
|
||||||
|
|||||||
@ -112,13 +112,30 @@ pub fn expectEqualShapes(expected: zml.Shape, actual: zml.Shape) error{TestExpec
|
|||||||
return error.TestExpectedEqual;
|
return error.TestExpectedEqual;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a mirrored version of T where each Tensor has been replaced by a Buffer.
/// This is similar to `zml.Bufferized`,
/// but every other field is kept as well, so that it can be used during compilation.
///
/// See `compileAndCall`.
pub fn BufferizedWithArgs(comptime T: type) type {
    // This goes through `meta.MapType` (and not `meta.MapRestrict` like `zml.Bufferized`),
    // because the non-Tensor fields are wanted here.
    const TensorToBuffer = meta.MapType(zml.Tensor, zml.Buffer);
    return TensorToBuffer.map(T);
}
|
||||||
|
|
||||||
/// Compile a function and immediately call it with the given buffers.
|
/// Compile a function and immediately call it with the given buffers.
|
||||||
/// The compiled module is discarded after the call.
|
/// The compiled module is discarded after the call.
|
||||||
/// Useful during testing when a module is typically called only once.
|
/// Useful during testing when a module is typically called only once.
|
||||||
///
|
///
|
||||||
/// Note: `func` needs explicit types on all parameters.
|
/// Note: `func` needs explicit types on all parameters.
|
||||||
/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types.
|
/// To test a function with `anytype` (typically for tagged API), you need to create a specialized version of it with specific types.
|
||||||
pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
pub fn compileAndCall(
|
||||||
|
platform: zml.Platform,
|
||||||
|
func: anytype,
|
||||||
|
buffer_and_args: BufferizedWithArgs(stdx.meta.FnArgs(func)),
|
||||||
|
) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||||
// This simplify test API and also ensure this fn isn't used outside of tests.
|
// This simplify test API and also ensure this fn isn't used outside of tests.
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
@ -130,18 +147,30 @@ pub fn compileAndCall(platform: zml.Platform, func: anytype, buffer_args: zml.Bu
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
|
var shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)) = undefined;
|
||||||
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_args, &shape_args);
|
try meta.mapAlloc(Local.bufferToShape, arena.allocator(), {}, buffer_and_args, &shape_args);
|
||||||
|
|
||||||
const mod = try zml.compileFn(allocator, func, shape_args, platform);
|
var mod = try zml.compileFn(allocator, func, shape_args, platform);
|
||||||
defer mod.deinit();
|
defer mod.deinit();
|
||||||
|
|
||||||
return mod.call(buffer_args);
|
// Note: we don't use the type safe API of mod,
|
||||||
|
// cause mod.call expects a `zml.Bufferized` while we have `BufferizedWithArgs`.
|
||||||
|
mod.inner.prepare(buffer_and_args);
|
||||||
|
mod.inner._unsafeCall();
|
||||||
|
|
||||||
|
var result: zml.Bufferized(stdx.meta.FnResult(func)) = undefined;
|
||||||
|
mod.inner._unsafeAssignResults(@TypeOf(result), &result);
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Compile a function and immediately call it with the given buffers.
|
/// Compile a function and immediately call it with the given buffers.
|
||||||
/// The compiled module is discarded after the call.
|
/// The compiled module is discarded after the call.
|
||||||
/// Useful during testing when a module is typically called only once.
|
/// Useful during testing when a module is typically called only once.
|
||||||
pub fn compileAndCallWithTensors(platform: zml.Platform, func: anytype, shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)), buffer_args: zml.Bufferized(stdx.meta.FnArgs(func))) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
pub fn compileAndCallWithTensors(
|
||||||
|
platform: zml.Platform,
|
||||||
|
func: anytype,
|
||||||
|
shape_args: zml.ShapeOf(stdx.meta.FnArgs(func)),
|
||||||
|
buffer_args: zml.Bufferized(stdx.meta.FnArgs(func)),
|
||||||
|
) !zml.Bufferized(stdx.meta.FnResult(func)) {
|
||||||
// This simplify test API and also ensure this fn isn't used outside of tests.
|
// This simplify test API and also ensure this fn isn't used outside of tests.
|
||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
var arena = std.heap.ArenaAllocator.init(allocator);
|
var arena = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
|||||||
@ -1,8 +1,8 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
|
|
||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
|
|
||||||
const Tensor = zml.Tensor;
|
const Tensor = zml.Tensor;
|
||||||
|
|
||||||
const log = std.log.scoped(.zml_torch);
|
const log = std.log.scoped(.zml_torch);
|
||||||
@ -141,16 +141,14 @@ test pixelShuffle {
|
|||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
const upscale_factor = 3;
|
const upscale_factor = 3;
|
||||||
var digits: [9 * 4 * 4]i32 = undefined;
|
const shape = zml.Shape.init(.{ .b = 1, .c = 9, .h = 4, .w = 4 }, .i32);
|
||||||
for (&digits, 0..) |*d, i| d.* = @intCast(i);
|
const input = input: {
|
||||||
// TODO should we have tags in buffers ?
|
var digits: [9 * 4 * 4]i32 = undefined;
|
||||||
const input = try zml.Buffer.fromSlice(platform, .{ 1, 9, 4, 4 }, &digits);
|
for (&digits, 0..) |*d, i| d.* = @intCast(i);
|
||||||
const output = try zml.testing.compileAndCallWithTensors(
|
break :input try zml.Buffer.fromSlice(platform, shape, &digits);
|
||||||
platform,
|
};
|
||||||
pixelShuffle,
|
|
||||||
.{ zml.Shape.init(.{ .batch_size = 1, .c = 9, .h = 4, .w = 4 }, .i32), upscale_factor },
|
const output = try zml.testing.compileAndCall(platform, pixelShuffle, .{ input, upscale_factor });
|
||||||
.{ input, upscale_factor },
|
|
||||||
);
|
|
||||||
|
|
||||||
const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
|
const exp = zml.HostBuffer.fromArray(&[1][1][12][12]i32{.{.{
|
||||||
.{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },
|
.{ 0, 16, 32, 1, 17, 33, 2, 18, 34, 3, 19, 35 },
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user