Radix/zml/meta.zig

658 lines
24 KiB
Zig
Raw Normal View History

const std = @import("std");
const testing = std.testing;
/// Computes floating point value division between two integers.
pub fn divFloat(T: type, numerator: anytype, denominator: anytype) T {
return toFloat(T, numerator) / toFloat(T, denominator);
}
fn toFloat(T: type, x: anytype) T {
return switch (@typeInfo(@TypeOf(x))) {
.Float => @floatCast(x),
else => @floatFromInt(x),
};
}
pub fn guard(check: bool, src: std.builtin.SourceLocation) void {
assert(check, "Invalid inputs {s}@{s}:{d}", .{ src.file, src.fn_name, src.line });
}
pub inline fn internalAssert(check: bool, comptime msg: []const u8, args: anytype) void {
assert(check, "ZML internal error: " ++ msg, args);
}
pub fn assert(check: bool, comptime msg: []const u8, args: anytype) void {
if (!check) panic(msg, args);
}
pub fn panic(comptime msg: []const u8, args: anytype) noreturn {
std.log.err(msg, args);
@panic(msg);
}
pub fn compileLog(comptime msg: []const u8, comptime args: anytype) void {
@compileLog(std.fmt.comptimePrint(msg, args));
}
pub fn compileError(comptime msg: []const u8, comptime args: anytype) noreturn {
@compileError(std.fmt.comptimePrint(msg, args));
}
pub fn assertComptime(comptime check: bool, comptime msg: []const u8, comptime args: anytype) void {
if (check == false) {
compileError(msg, args);
}
}
pub fn isStruct(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Struct => true,
else => false,
};
}
pub fn isTuple(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Struct => |info| info.is_tuple,
else => false,
};
}
pub fn isStructOf(comptime T: type, comptime Elem: type) bool {
return switch (@typeInfo(T)) {
.Struct => |info| blk: {
inline for (info.fields) |field| {
if (field.type != Elem) {
break :blk false;
}
}
break :blk true;
},
else => false,
};
}
pub fn isStructOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return switch (@typeInfo(T)) {
.Struct => |info| blk: {
inline for (info.fields) |field| {
if (f(field.type) == false) {
break :blk false;
}
}
break :blk true;
},
else => false,
};
}
pub fn isTupleOf(comptime T: type, comptime Elem: type) bool {
return isTuple(T) and isStructOf(T, Elem);
}
pub fn isTupleOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return isTuple(T) and isStructOfAny(T, f);
}
pub fn isSliceOf(comptime T: type, comptime Elem: type) bool {
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.Slice => info.child == Elem,
.One => switch (@typeInfo(info.child)) {
// As Zig, convert pointer to Array as a slice.
.Array => |arr_info| arr_info.child == Elem,
else => false,
},
else => false,
},
else => false,
};
}
pub fn asSlice(comptime T: type) type {
const err_msg = "Type " ++ @typeName(T) ++ " can't be interpreted as a slice";
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.Slice => info.child,
.One => switch (@typeInfo(info.child)) {
// As Zig, convert pointer to Array as a slice.
.Array => |arr_info| arr_info.child,
else => @compileError(err_msg),
},
else => @compileError(err_msg),
},
else => @compileError(err_msg),
};
}
pub fn isInteger(comptime T: type) bool {
return switch (@typeInfo(T)) {
.Int, .ComptimeInt => true,
else => false,
};
}
pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool {
return switch (@typeInfo(T)) {
.Pointer => |info| info.size == .Slice and f(info.child),
else => false,
};
}
pub fn DeclEnum(comptime T: type) type {
return std.meta.DeclEnum(UnwrapPtr(T));
}
pub fn UnwrapPtr(comptime T: type) type {
return switch (@typeInfo(T)) {
.Pointer => |info| switch (info.size) {
.One => info.child,
else => T,
},
else => T,
};
}
pub fn FnParam(func: anytype, n: comptime_int) type {
return @typeInfo(@TypeOf(func)).Fn.params[n].type orelse @compileError("anytype not supported in callbacks");
}
pub fn FnParams(func: anytype) type {
return std.meta.ArgsTuple(@TypeOf(func));
}
pub fn FnResult(func: anytype) type {
return @typeInfo(@TypeOf(func)).Fn.return_type.?;
}
pub fn FnResultPayload(func: anytype) type {
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
const payload_type = switch (@typeInfo(return_type)) {
.ErrorUnion => |u| u.payload,
else => return_type,
};
return payload_type;
}
pub fn FnResultErrorSet(func: anytype) ?type {
const return_type = @typeInfo(@TypeOf(func)).Fn.return_type.?;
const error_set = switch (@typeInfo(return_type)) {
.ErrorUnion => |u| u.error_set,
else => null,
};
return error_set;
}
pub fn Signature(comptime func: anytype, comptime argsT: ?type) type {
return struct {
pub const FuncT = if (@TypeOf(func) == type) func else @TypeOf(func);
pub const ArgsT = blk: {
if (@typeInfo(FuncT).Fn.params.len == 0) {
break :blk @TypeOf(.{});
}
break :blk argsT orelse std.meta.ArgsTuple(FuncT);
};
pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
pub const ReturnPayloadT = blk: {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.payload,
else => ReturnT,
};
};
pub const ReturnErrorSet: ?type = blk: {
break :blk switch (@typeInfo(ReturnT)) {
.ErrorUnion => |u| u.error_set,
else => null,
};
};
};
}
pub fn MapType(From: type, To: type) type {
return struct {
pub fn map(T: type) type {
switch (T) {
To => return To,
?To => return ?To,
From => return To,
*From => return *To,
?From => return ?To,
else => {},
}
return switch (@typeInfo(T)) {
.Struct => |struct_infos| {
const fields = struct_infos.fields;
var same: bool = true;
var struct_fields: [fields.len]std.builtin.Type.StructField = undefined;
for (struct_fields[0..], fields) |*struct_field, field| {
if (!field.is_comptime) {
const R = map(field.type);
if (R == field.type) {
struct_field.* = field;
} else {
struct_field.* = .{
.name = field.name,
.type = R,
.default_value = null,
.is_comptime = field.is_comptime,
.alignment = @alignOf(R),
};
same = false;
// Handle the case `field: ?Tensor = null`
// Generic handling of default value is complicated,
// it would require to call the callback at comptime.
if (R == ?To) {
struct_field.default_value = &@as(R, null);
}
}
} else {
struct_field.* = field;
}
}
if (same) return T;
return @Type(.{ .Struct = .{
.layout = .auto,
.fields = struct_fields[0..],
.decls = &.{},
.is_tuple = struct_infos.is_tuple,
} });
},
.Array => |arr_info| [arr_info.len]map(arr_info.child),
.Pointer => |ptr_info| switch (ptr_info.size) {
.Slice => if (ptr_info.is_const)
[]const map(ptr_info.child)
else
[]map(ptr_info.child),
.One => *map(ptr_info.child),
else => T,
},
.Optional => |opt_info| ?map(opt_info.child),
else => T,
};
}
};
}
/// Given a callback: `fn(Ctx, From) To`, recursively visits the given `from` struct
/// and calls the callback when it finds a `From` element, and writes it to the `to` struct.
/// The `to` parameter must be passed with mutable pointer, and tensor data need to be mutable if callback needs it.
/// `mapAlloc` tries as much as possible to respect the conversions made by Zig itself.
/// For example it can convert from a comptime array to a runtime slice.
/// `mapAlloc` can allocate new slices to write the result if the result struct requires it.
/// The caller is owning said allocations, using an `ArenaAllocator` might help tracking them.
// TODO: handle tuple to slice conversion
pub fn mapAlloc(comptime cb: anytype, allocator: std.mem.Allocator, ctx: FnParam(cb, 0), from: anytype, to: anytype) !void {
// const Ctx = FnParam(cb, 0);
const From = FnParam(cb, 1);
const FromStruct = @TypeOf(from);
const type_info_to_ptr = @typeInfo(@TypeOf(to));
if (type_info_to_ptr != .Pointer) {
@compileError("convertType is expecting a mutable `to` argument but received: " ++ @typeName(@TypeOf(to)));
}
const ToStruct = type_info_to_ptr.Pointer.child;
const type_info_to = @typeInfo(ToStruct);
if (FromStruct == From) {
// Special case for converting from shape to tensor:
// If the target type is Shape, skip tensor conversion.
// A general `to.* = from` assignment causes a Zig error in this scenario.
// (see below)
if (ToStruct == @import("shape.zig").Shape and FromStruct == ToStruct) { // FromStruct) {
to.* = from;
} else {
to.* = @call(.auto, cb, .{ ctx, from });
}
return;
}
// This is generally due to a user error, but let this fn compile,
// and the user will have a Zig error.
if (FromStruct == ToStruct) {
to.* = from;
return;
}
// Don't go into Shape objects because of the weird tag.
// TODO: we could not error on pointers to basic types like u8
if (FromStruct == @import("shape.zig").Shape) {
to.* = from;
return;
}
switch (type_info_to) {
.Struct => |info| inline for (info.fields) |field| {
// if (field.is_comptime) continue;
const field_type_info = @typeInfo(field.type);
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
switch (field_type_info) {
// .Pointer => try convertType(From, To, allocator, @field(from, field.name), @field(to, field.name), Ctx, ctx, cb),
.Array, .Optional, .Union, .Struct, .Pointer => if (@hasField(FromStruct, field.name)) {
try mapAlloc(
cb,
allocator,
ctx,
@field(from, field.name),
&@field(to, field.name),
);
} else if (field.default_value) |_| {
@field(to, field.name) = null;
} else {
compileError("Mapping {} to {} failed. Missing field {s}", .{ FromStruct, ToStruct, field.name });
},
else => @field(to, field.name) = @field(from, field.name),
}
},
.Array => for (from, to) |f, *t| {
try mapAlloc(cb, allocator, ctx, f, t);
},
.Pointer => |ptr_info| switch (ptr_info.size) {
.One => switch (type_info_to_ptr.Pointer.size) {
// pointer to array -> slice promotion
.Slice => {
to.* = try allocator.alloc(type_info_to_ptr.Pointer.child, from.len);
for (from, to.*) |f, *t| {
try mapAlloc(cb, allocator, ctx, f, t);
}
},
else => try mapAlloc(cb, allocator, ctx, from.*, to.*),
},
.Slice => {
const items = try allocator.alloc(@typeInfo(ToStruct).Pointer.child, from.len);
for (from, items) |f, *t| {
try mapAlloc(cb, allocator, ctx, f, t);
}
to.* = items;
},
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
},
.Optional => if (from) |f| {
to.* = @as(@typeInfo(type_info_to_ptr.Pointer.child).Optional.child, undefined);
try mapAlloc(cb, allocator, ctx, f, &(to.*.?));
} else {
to.* = null;
},
.Int, .Float => to.* = from,
else => @compileError("zml.meta.mapAlloc doesn't support: " ++ @typeName(FromStruct)),
}
}
test mapAlloc {
const B = struct { b: u8 };
const A = struct {
a: u8,
pub fn convert(_: void, a: @This()) B {
return .{ .b = a.a };
}
};
const AA = struct {
field: A,
array: [2]A,
slice: []const A,
other: u8,
// We want to allow conversion from comptime to runtime, because Zig type inference works like this.
comptime static_val: u8 = 8,
comptime static_slice: [2]A = .{ .{ .a = 11 }, .{ .a = 12 } },
};
const BB = struct {
field: B,
array: [2]B,
slice: []const B,
other: u8,
static_val: u8,
static_slice: []B,
};
const aa: AA = .{
.field = .{ .a = 4 },
.array = .{ .{ .a = 5 }, .{ .a = 6 } },
.other = 7,
.slice = &.{ .{ .a = 9 }, .{ .a = 10 } },
};
var bb: BB = undefined;
try mapAlloc(A.convert, testing.allocator, {}, aa, &bb);
defer testing.allocator.free(bb.slice);
defer testing.allocator.free(bb.static_slice);
try testing.expectEqual(4, bb.field.b);
try testing.expectEqual(5, bb.array[0].b);
try testing.expectEqual(6, bb.array[1].b);
try testing.expectEqual(7, bb.other);
try testing.expectEqual(8, bb.static_val);
try testing.expectEqual(9, bb.slice[0].b);
try testing.expectEqual(10, bb.slice[1].b);
try testing.expectEqual(11, bb.static_slice[0].b);
try testing.expectEqual(12, bb.static_slice[1].b);
}
/// Recursively visit the given struct and calls the callback for each K found.
/// The `v` parameter must me a pointer, and tensor data need to be mutable if callbacks needs it.
pub fn visit(comptime cb: anytype, ctx: FnParam(cb, 0), v: anytype) void {
const T = @TypeOf(v);
const type_info_v = @typeInfo(T);
const K = switch (@typeInfo(FnParam(cb, 1))) {
.Pointer => |info| info.child,
else => @compileError("zml.meta.visit is expecting a pointer value as second parameter in callback to use but found " ++ @typeName(FnParam(cb, 1))),
};
if (type_info_v != .Pointer) {
const Callback = @TypeOf(cb);
@compileError("zml.meta.visit is expecting a pointer input to go with following callback signature: " ++ @typeName(Callback) ++ " but received: " ++ @typeName(T));
}
const ptr_info = type_info_v.Pointer;
// This is important, because with trivial types like void,
// Zig sometimes decide to call `visit` at comptime, but can't do
// the pointer wrangling logic at comptime.
// So we detect early this case and return.
if (@sizeOf(ptr_info.child) == 0) return;
switch (ptr_info.size) {
// If we have a single pointer, two cases:
// * It's a pointer to K, in which case we call the callback.
// * It's a pointer to something else, in which case, we explore and recurse if needed.
.One => if (ptr_info.child == K) {
cb(ctx, v);
} else if (ptr_info.child == ?K) {
if (v.*) |*val| cb(ctx, val);
} else switch (@typeInfo(ptr_info.child)) {
.Struct => |s| inline for (s.fields) |field_info| {
if (field_info.is_comptime) continue;
const field_type_info = @typeInfo(field_info.type);
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
switch (field_type_info) {
.Pointer => visit(cb, ctx, @field(v, field_info.name)),
.Array, .Optional, .Union, .Struct => visit(cb, ctx, &@field(v, field_info.name)),
else => {},
}
},
.Array => |_| for (v) |*elem| visit(cb, ctx, elem),
.Optional => if (v.* != null) visit(cb, ctx, &v.*.?),
.Union => switch (v.*) {
inline else => |*v_field| visit(cb, ctx, v_field),
},
else => {},
},
// If we have a slice, two cases also:
// * It's a slice of K, in which case we call the callback for each element of the slice.
// * It's a slice to something else, in which case, for each element we explore and recurse if needed.
.Slice => {
for (v) |*v_elem| {
if (ptr_info.child == K) {
cb(ctx, v_elem);
} else switch (@typeInfo(ptr_info.child)) {
.Struct => |s| inline for (s.fields) |field_info| {
const field_type_info = @typeInfo(field_info.type);
// If the field is already a pointer, we recurse with it directly, otherwise, we recurse with a pointer to the field.
if (field_type_info == .Pointer) {
visit(cb, ctx, @field(v_elem, field_info.name));
} else {
visit(cb, ctx, &@field(v_elem, field_info.name));
}
},
.Array => |_| for (v) |*elem| visit(cb, ctx, elem),
.Optional => if (v.* != null) visit(cb, ctx, &v.*.?),
.Union => switch (v_elem.*) {
inline else => |*v_field| visit(cb, ctx, v_field),
},
else => {},
}
}
},
else => @compileError("Only single pointer and slice are supported. Received " ++ @typeName(T)),
}
}
test visit {
const Attr = struct { data: usize };
const OtherAttr = struct { other: []const u8 };
const NestedAttr = struct { nested: Attr };
const NestedAttrOptional = struct { nested: ?Attr };
const SimpleStruct = struct { prop: Attr };
const MultipleTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: ?Attr };
const NestedTypesStruct = struct { prop1: Attr, prop2: OtherAttr, prop3: NestedAttr, prop4: NestedAttrOptional };
const LocalContext = struct {
result: usize,
};
{
var context: LocalContext = .{ .result = 0 };
const container: SimpleStruct = .{ .prop = .{ .data = 1 } };
visit((struct {
fn cb(ctx: *LocalContext, attr: *const Attr) void {
ctx.result += attr.data;
}
}).cb, &context, &container);
try std.testing.expectEqual(1, context.result);
}
{
var context: LocalContext = .{ .result = 0 };
var container: SimpleStruct = .{ .prop = .{ .data = 1 } };
visit((struct {
fn cb(ctx: *LocalContext, attr: *Attr) void {
ctx.result += attr.data;
}
}).cb, &context, &container);
try std.testing.expectEqual(1, context.result);
}
{
var context: LocalContext = .{ .result = 0 };
var container: MultipleTypesStruct = .{ .prop1 = .{ .data = 1 }, .prop2 = .{ .other = "hello" }, .prop3 = null };
visit((struct {
fn cb(ctx: *LocalContext, attr: *Attr) void {
ctx.result += attr.data;
}
}).cb, &context, &container);
try std.testing.expectEqual(1, context.result);
}
{
var context: LocalContext = .{ .result = 0 };
const container: MultipleTypesStruct = .{ .prop1 = .{ .data = 1 }, .prop2 = .{ .other = "hello" }, .prop3 = .{ .data = 2 } };
visit((struct {
fn cb(ctx: *LocalContext, attr: *const Attr) void {
ctx.result += attr.data;
}
}).cb, &context, &container);
try std.testing.expectEqual(3, context.result);
}
{
var context: LocalContext = .{ .result = 0 };
const container: NestedTypesStruct = .{
.prop1 = .{ .data = 1 },
.prop2 = .{ .other = "hello" },
.prop3 = .{ .nested = .{ .data = 2 } },
.prop4 = .{ .nested = .{ .data = 3 } },
};
visit((struct {
fn cb(ctx: *LocalContext, attr: *const Attr) void {
ctx.result += attr.data;
}
}).cb, &context, &container);
try std.testing.expectEqual(6, context.result);
}
}
/// Given a `fn([]const T, Args) T` and a slice of values,
/// will combine all values in one value.
/// Only T elements of values will be looked at.
/// This only works for simple types, in particular `zip` doesn't follow pointers.
/// Which means that zip only allocate temp memory, and nothing need to be freed after the call.
pub fn zip(func: anytype, allocator: std.mem.Allocator, values: anytype, args: anytype) error{OutOfMemory}!asSlice(@TypeOf(values)) {
const sliceT = @typeInfo(FnParam(func, 0));
assertComptime(sliceT == .Pointer and sliceT.Pointer.size == .Slice and sliceT.Pointer.child == FnResult(func), "zip requires a `fn([]const T, Args) T`, received: {}", .{@TypeOf(func)});
const T = sliceT.Pointer.child;
const V = asSlice(@TypeOf(values));
if (V == T) {
return @call(.auto, func, .{values} ++ args);
}
// const fn_args
return switch (@typeInfo(V)) {
.Pointer => @compileError("zip only accept by value arguments. Received: " ++ @typeName(V)),
.Struct => |struct_info| {
var out: V = values[0];
inline for (struct_info.fields) |f| {
if (f.is_comptime) continue;
if (@typeInfo(f.type) == .Pointer) {
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
}
var fields = try allocator.alloc(f.type, values.len);
defer allocator.free(fields);
for (values, 0..) |val, i| {
fields[i] = @field(val, f.name);
}
@field(out, f.name) = try zip(func, allocator, fields, args);
}
return out;
},
.Array => |arr_info| {
if (@typeInfo(arr_info.child) == .Pointer) {
@compileError("zip doesn't follow pointers and don't accept struct containing them. Received: " ++ @typeName(V));
}
var out: V = undefined;
var slice = try allocator.alloc(arr_info.child, values.len);
defer allocator.free(slice);
for (&out, 0..) |*o, j| {
for (values, 0..) |val, i| {
slice[i] = val[j];
}
o.* = try zip(func, allocator, slice, args);
}
return out;
},
.Union, .Optional => @compileError("zip doesn't yet support " ++ @typeName(V)),
else => values[0],
};
}
test zip {
const A = struct { a: u8, b: [2]u8 };
const a0: A = .{ .a = 1, .b = .{ 2, 3 } };
const a1: A = .{ .a = 4, .b = .{ 5, 6 } };
const Sum = struct {
pub fn call(x: []const u8) u8 {
var res: u8 = 0;
for (x) |xx| res += xx;
return res;
}
};
const a_sum: A = try zip(Sum.call, testing.allocator, &[_]A{ a0, a1 }, .{});
try testing.expectEqual(A{ .a = 5, .b = .{ 7, 9 } }, a_sum);
}