Radix/zml/aio/torch/eval.zig

486 lines
21 KiB
Zig

const std = @import("std");
const zml = @import("../../zml.zig");
const meta = zml.meta;
const value = @import("value.zig");
const pickle = @import("pickle.zig");
const BTreeMap = @import("b_tree_map.zig").BTreeMap;
const Build = value.Build;
const Object = value.Object;
const PersId = value.PersId;
const Sequence = value.Sequence;
const SequenceType = value.SequenceType;
const Value = value.Value;
/// Upper bound used when following chains of memo references and when recursing into nested values during reference resolution.
const MAX_DEPTH: usize = 250;
/// Highest pickle protocol number accepted by the `.proto` opcode handler in `evaluate`.
const MAX_PROTOCOL: u8 = 5;
/// Memo table filled while evaluating a pickle program: maps a memo id to the value stored under it.
/// Values on the evaluation stack may be `.ref` entries pointing into this table.
pub const PickleMemo = struct {
    allocator: std.mem.Allocator,
    map: BTreeMap(u32, Value),

    /// Creates an empty memo whose map is backed by `allocator`.
    pub fn init(allocator: std.mem.Allocator) PickleMemo {
        return .{
            .allocator = allocator,
            .map = BTreeMap(u32, Value).init(allocator),
        };
    }

    /// Deinitializes every stored value, then the map itself, and invalidates `self`.
    pub fn deinit(self: *PickleMemo) void {
        var iterator = self.map.iterator();
        defer iterator.deinit();
        while (iterator.next()) |entry| {
            entry.value_ptr.deinit(self.allocator);
        }
        self.map.deinit() catch unreachable;
        self.* = undefined;
    }

    /// Replaces memo references in `op` by clones of the values they point to.
    ///
    /// If `op` is a `.ref`, the memo entry is looked up; with `recursive == false` a clone of that
    /// entry is returned immediately. With `recursive == true` chains of references are followed
    /// for at most `MAX_DEPTH` hops. Afterwards, references nested inside the resulting value are
    /// resolved in place.
    ///
    /// Returns `error.BadMemoRef` when a referenced id is not in the memo.
    pub fn resolve(self: *PickleMemo, allocator: std.mem.Allocator, op: Value, recursive: bool) !Value {
        var used_op = op;
        // The hop counter must live outside the loop, otherwise the MAX_DEPTH guard never triggers.
        var count: usize = 0;
        while (used_op == .ref) {
            // Look up the reference we are currently standing on (not the initial `op`),
            // so that ref -> ref chains actually advance instead of spinning on the same entry.
            const val = self.map.get(used_op.ref) orelse {
                return error.BadMemoRef;
            };
            if (!recursive) {
                return val.clone(allocator);
            }
            count += 1;
            if (count >= MAX_DEPTH or val != .ref) {
                used_op = try val.clone(allocator);
                break;
            }
            used_op = val;
        }
        if (used_op.containsRef()) {
            switch (used_op) {
                .app, .object, .global => |v| {
                    if (v.member.containsRef()) {
                        v.member = try self.resolve(allocator, v.member, recursive);
                    }
                    for (v.args) |*item| {
                        if (item.containsRef()) {
                            item.* = try self.resolve(allocator, item.*, recursive);
                        }
                    }
                },
                .build => |v| {
                    if (v.member.containsRef()) {
                        v.member = try self.resolve(allocator, v.member, recursive);
                    }
                    if (v.args.containsRef()) {
                        v.args = try self.resolve(allocator, v.args, recursive);
                    }
                },
                .pers_id => |v| {
                    if (v.ref.containsRef()) {
                        v.ref = try self.resolve(allocator, v.ref, recursive);
                    }
                },
                .seq => |*v| {
                    for (v.values) |*item| {
                        if (item.containsRef()) {
                            item.* = try self.resolve(allocator, item.*, recursive);
                        }
                    }
                },
                else => {},
            }
        }
        return used_op;
    }

    /// Stores `val` under memo id `mid`. Any value previously stored under that id is discarded
    /// without being deinitialized.
    pub fn insert(self: *PickleMemo, mid: u32, val: Value) !void {
        _ = try self.map.fetchPut(mid, val);
    }

    /// Returns a mutable pointer to the value `op` designates.
    ///
    /// When `op` is not a `.ref`, `op` itself is returned. Otherwise the memo entry is looked up;
    /// while that entry is itself a `.ref` the chain is followed (one hop only when `recursive`
    /// is false, at most `MAX_DEPTH` hops otherwise) and a pointer to the last entry reached is
    /// returned. Returns `error.BadMemoRef` when an id is missing from the memo.
    pub fn resolveMut(self: *PickleMemo, op: *Value, recursive: bool) !*Value {
        if (op.* != .ref) return op;
        var lastmid = op.ref;
        var count: usize = 0;
        var val = self.map.get(lastmid) orelse {
            return error.BadMemoRef;
        };
        while (val == .ref) {
            lastmid = val.ref;
            if (!recursive) {
                break;
            }
            count += 1;
            if (count >= MAX_DEPTH) {
                break;
            }
            val = self.map.get(lastmid) orelse {
                return error.BadMemoRef;
            };
        }
        return (self.map.getPtr(lastmid) orelse {
            return error.BadMemoRef;
        });
    }

    const MemoError = std.math.big.int.Managed.ConvertError || std.mem.Allocator.Error || error{BadMemoRef};

    /// Applies `resolveAllRefs` to every element of `vals` and returns the results in a newly
    /// allocated slice. Once `depth` reaches `MAX_DEPTH`, `vals` is returned as-is (not copied).
    pub fn resolveAllRefsIter(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, vals: []Value, fix_values: bool) MemoError![]Value {
        if (depth >= MAX_DEPTH) {
            return vals;
        }
        const res = try allocator.alloc(Value, vals.len);
        for (vals, 0..) |v, i| {
            res[i] = try self.resolveAllRefs(allocator, depth + 1, v, fix_values);
        }
        return res;
    }

    /// Builds a copy of `val` in which memo references are replaced by the referenced values,
    /// recursing into objects, builds, sequences and persistent ids. When `fix_values` is set the
    /// result is additionally passed through `Value.coerceFromRaw`.
    pub fn resolveAllRefs(self: *PickleMemo, allocator: std.mem.Allocator, depth: usize, val: Value, fix_values: bool) !Value {
        var output: Value = switch (val) {
            .ref => try self.resolve(allocator, val, true),
            inline .app, .object, .global => |v, tag| @unionInit(Value, @tagName(tag), try Object.init(
                allocator,
                try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
                try self.resolveAllRefsIter(allocator, depth + 1, v.args, fix_values),
            )),
            .build => |v| .{ .build = try Build.init(
                allocator,
                try self.resolveAllRefs(allocator, depth + 1, v.member, fix_values),
                try self.resolveAllRefs(allocator, depth + 1, v.args, fix_values),
            ) },
            .seq => |v| .{ .seq = .{ .type = v.type, .values = try self.resolveAllRefsIter(allocator, depth + 1, v.values, fix_values) } },
            .pers_id => |v| .{ .pers_id = try PersId.init(allocator, try self.resolveAllRefs(allocator, depth + 1, v.ref, fix_values)) },
            else => try val.clone(allocator),
        };
        if (fix_values) {
            output = try output.coerceFromRaw(allocator);
        }
        return output;
    }
};
/// Working stack used by `evaluate` while running a pickle program.
pub const InternalStack = struct {
    allocator: std.mem.Allocator,
    values: std.ArrayList(Value),

    /// Creates an empty stack backed by `allocator`.
    pub fn init(allocator: std.mem.Allocator) InternalStack {
        return .{
            .allocator = allocator,
            .values = std.ArrayList(Value).init(allocator),
        };
    }

    /// Deinitializes every value still on the stack, frees the list and invalidates `self`.
    pub fn deinit(self: *InternalStack) void {
        for (0..self.values.items.len) |i| self.values.items[i].deinit(self.allocator);
        self.values.deinit();
        self.* = undefined;
    }

    /// Removes and returns the top value, or `error.StackUnderrun` when the stack is empty.
    pub fn pop(self: *InternalStack) !Value {
        if (self.values.items.len == 0) {
            return error.StackUnderrun;
        }
        return self.values.pop();
    }

    /// Truncates the stack at the topmost mark (the mark itself is removed too).
    ///
    /// When `allocator` is non-null, the values that were above the mark are returned in a slice
    /// allocated with it (caller owns the slice and the values). When it is null, an empty slice
    /// is returned and the removed values are dropped without being deinitialized.
    ///
    /// If no mark is present, `findMark` falls back to index 0, so the whole stack is consumed
    /// and the bottom entry is treated as the mark.
    pub fn popMark(self: *InternalStack, allocator: ?std.mem.Allocator) ![]Value {
        const markidx = try self.findMark();
        // On an empty stack `findMark` yields 0, so `markidx + 1` would exceed the length and
        // the size computation would underflow; clamp to the end of the stack instead.
        const start = @min(markidx + 1, self.values.items.len);
        var postmark: []Value = &[_]Value{};
        if (allocator) |a| {
            postmark = try a.dupe(Value, self.values.items[start..]);
        }
        self.values.shrinkAndFree(markidx);
        return postmark;
    }

    /// Returns a pointer to the top value, or `error.UnexpectedEmptyStack` when the stack is empty.
    /// The pointer refers into the list's storage.
    pub fn lastMut(self: *InternalStack) !*Value {
        if (self.values.items.len == 0) {
            return error.UnexpectedEmptyStack;
        }
        return &self.values.items[self.values.items.len - 1];
    }

    /// Returns the index of the topmost mark on the stack.
    /// When there is none, a warning is logged and 0 is returned.
    pub fn findMark(self: *InternalStack) !usize {
        const len = self.values.items.len;
        // Scan from the top of the stack downwards.
        for (0..len) |i| {
            const idx = (len - 1) - i;
            const val = self.values.items[idx];
            if (val == .raw and val.raw == .mark) {
                return idx;
            }
        }
        zml.log.warn("pytorch loader: missing mark", .{});
        return 0;
    }

    /// Moves the stack contents into a `PickleStack`, leaving this stack empty.
    pub fn toPickleStack(self: *InternalStack) !PickleStack {
        return .{ .stack = try self.values.toOwnedSlice(), .allocator = self.allocator };
    }
};
/// Final evaluation stack as handed back to callers of `evaluate`.
pub const PickleStack = struct {
    stack: []Value,
    allocator: std.mem.Allocator,

    /// Wraps `values` without copying; `allocator` is the one `deinit` will use to release them.
    pub fn init(allocator: std.mem.Allocator, values: []Value) PickleStack {
        return PickleStack{
            .stack = values,
            .allocator = allocator,
        };
    }

    /// Deinitializes each value in order, then frees the slice holding them.
    pub fn deinit(self: *PickleStack) void {
        const owner = self.allocator;
        var i: usize = 0;
        while (i < self.stack.len) : (i += 1) {
            self.stack[i].deinit(owner);
        }
        owner.free(self.stack);
    }
};
/// Runs the pickle program `x` on a stack machine and returns the resulting stack together with
/// the memo table built along the way.
///
/// Evaluation stops at the first `.stop` opcode or at the end of `x`. When `resolve_refs` is
/// false the working stack is handed over as-is (it may contain `.ref` values pointing into the
/// memo). When it is true, a resolved copy of every stack entry is produced via
/// `PickleMemo.resolveAllRefsIter` with value coercion enabled.
///
/// The caller owns both returned values and must `deinit` them. On error the memo is released
/// here.
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
    var stack = InternalStack.init(allocator);
    defer stack.deinit();
    var memo = PickleMemo.init(allocator);
    errdefer memo.deinit();

    // Groups a flat [k0, v0, k1, v1, ...] slice into a list of two-element `kv_tuple` sequences.
    const makeKVList = (struct {
        pub fn call(alloc: std.mem.Allocator, items: []const Value) ![]Value {
            meta.assert(items.len & 1 == 0, "Bad value for setitems", .{});
            var kv_items = try std.ArrayList(Value).initCapacity(alloc, items.len);
            errdefer kv_items.deinit();
            var idx: usize = 0;
            while (idx < items.len) : (idx += 2) {
                if (idx + 1 >= items.len) {
                    return error.MissingValueItem;
                }
                const kv = try alloc.alloc(Value, 2);
                kv[0] = items[idx];
                kv[1] = items[idx + 1];
                kv_items.appendAssumeCapacity(.{ .seq = .{ .type = .kv_tuple, .values = kv } });
            }
            return kv_items.toOwnedSlice();
        }
    }).call;

    outer: for (x) |op| {
        switch (op) {
            .mark => try stack.values.append(.{ .raw = op }),
            .stop => break :outer,
            .pop => _ = try stack.pop(),
            // Nothing is done with the popped values, so do not ask popMark to allocate a
            // slice for them (the previous code allocated one and never freed it).
            .pop_mark => _ = try stack.popMark(null),
            .dup => {
                if (stack.values.getLastOrNull()) |item| {
                    try stack.values.append(try item.clone(allocator));
                } else {
                    return error.CannotDupEmptyStack;
                }
            },
            .persid => |v| try stack.values.append(.{ .pers_id = try PersId.init(allocator, .{ .string = try allocator.dupe(u8, v) }) }),
            .binpersid => try stack.values.append(.{ .pers_id = try PersId.init(allocator, try stack.pop()) }),
            // Top of stack is the argument, the callable sits below it.
            .reduce => try stack.values.append(.{ .global = blk: {
                const values = try allocator.alloc(Value, 1);
                values[0] = try memo.resolve(allocator, try stack.pop(), true);
                break :blk try Object.init(allocator, try memo.resolve(allocator, try stack.pop(), true), values);
            } }),
            .build => try stack.values.append(blk: {
                const args = try memo.resolve(allocator, try stack.pop(), true);
                const member = try memo.resolve(allocator, try stack.pop(), true);
                break :blk .{ .build = try Build.init(allocator, member, args) };
            }),
            .empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
            .get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
            inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
            .empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
            // The value moves into the memo; the stack keeps a reference to it.
            .binput, .long_binput => |v| {
                try memo.insert(v, try stack.pop());
                try stack.values.append(.{ .ref = v });
            },
            .tuple => try stack.values.append(blk: {
                const popped = try stack.popMark(allocator);
                break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
            }),
            .empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
            .setitem => {
                const v, const k = .{ try stack.pop(), try stack.pop() };
                const top = try stack.lastMut();
                const rtop = try memo.resolveMut(top, true);
                switch (rtop.*) {
                    .global => |obj| {
                        obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
                        obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
                    },
                    .seq => |*tup| {
                        tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
                        tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
                    },
                    else => {
                        return error.BadStackTopForSetItem;
                    },
                }
            },
            .setitems => {
                const popped = try stack.popMark(allocator);
                defer allocator.free(popped);
                const kv_items = try makeKVList(allocator, popped);
                const top = try stack.lastMut();
                const rtop = try memo.resolveMut(top, true);
                switch (rtop.*) {
                    .global => |obj| {
                        obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
                        obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
                    },
                    .seq => |*tup| {
                        tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
                        tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
                    },
                    else => {
                        defer allocator.free(kv_items);
                        return error.BadStackTopForSetItems;
                    },
                }
            },
            // NOTE(review): meta.assert presumably aborts instead of returning an error on an
            // unsupported protocol — confirm this is intended for untrusted input.
            .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
            .tuple1 => try stack.values.append(blk: {
                const tup_values = try allocator.alloc(Value, 1);
                tup_values[0] = try stack.pop();
                break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
            }),
            // Values are popped top-first, so fill the tuple from its last slot backwards.
            .tuple2 => try stack.values.append(blk: {
                const tup_values = try allocator.alloc(Value, 2);
                inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
                break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
            }),
            .tuple3 => try stack.values.append(blk: {
                const tup_values = try allocator.alloc(Value, 3);
                inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
                break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
            }),
            .append => {
                const v = try stack.pop();
                const top = try stack.lastMut();
                const rtop = try memo.resolveMut(top, true);
                switch (rtop.*) {
                    .global => |obj| {
                        obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
                        obj.args[obj.args.len - 1] = v;
                    },
                    .seq => |*tup| {
                        tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
                        tup.values[tup.values.len - 1] = v;
                    },
                    else => {
                        return error.BadStackTopForAppend;
                    },
                }
            },
            .appends => {
                const postmark = try stack.popMark(allocator);
                defer allocator.free(postmark);
                const top = try stack.lastMut();
                const rtop = try memo.resolveMut(top, true);
                switch (rtop.*) {
                    .global => |obj| {
                        const obj_len = obj.args.len;
                        obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
                        @memcpy(obj.args[obj_len..], postmark);
                    },
                    .seq => |*tup| {
                        const tup_len = tup.values.len;
                        tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
                        @memcpy(tup.values[tup_len..], postmark);
                    },
                    else => {
                        return error.BadStackTopForAppends;
                    },
                }
            },
            .dict => try stack.values.append(blk: {
                const popped = try stack.popMark(allocator);
                defer allocator.free(popped);
                const kv_items = try makeKVList(allocator, popped);
                break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
            }),
            .list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
            .inst => |v| try stack.values.append(blk: {
                // Copy the names so the resulting value owns its strings, as `.persid` does,
                // instead of borrowing memory that belongs to the op.
                const tup_items = try allocator.alloc(Value, 2);
                tup_items[0] = .{ .string = try allocator.dupe(u8, v.module) };
                tup_items[1] = .{ .string = try allocator.dupe(u8, v.class) };
                break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
            }),
            .obj => try stack.values.append(blk: {
                // Everything above the mark is consumed: the first entry is the class, the rest
                // are its arguments. Going through popMark removes the mark and those entries
                // from the stack so they are owned by the new object only.
                const popped = try stack.popMark(allocator);
                defer allocator.free(popped);
                if (popped.len == 0) return error.StackUnderrun;
                const args = try allocator.dupe(Value, popped[1..]);
                break :blk .{ .object = try Object.init(allocator, popped[0], args) };
            }),
            .put => |v| {
                const mid = try std.fmt.parseInt(u32, v, 10);
                try memo.insert(mid, try stack.pop());
                try stack.values.append(.{ .ref = mid });
            },
            .newobj => try stack.values.append(blk: {
                const args = try allocator.alloc(Value, 1);
                args[0] = try stack.pop();
                break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
            }),
            .empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
            .additems => {
                const postmark = try stack.popMark(allocator);
                defer allocator.free(postmark);
                const top = try stack.lastMut();
                const rtop = try memo.resolveMut(top, true);
                switch (rtop.*) {
                    .global => |obj| {
                        const obj_len = obj.args.len;
                        obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
                        @memcpy(obj.args[obj_len..], postmark);
                    },
                    .seq => |*tup| {
                        const tup_len = tup.values.len;
                        tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
                        @memcpy(tup.values[tup_len..], postmark);
                    },
                    else => {
                        // NOTE(review): reuses the setitem error name; kept so callers matching
                        // on it keep working.
                        return error.BadStackTopForSetItem;
                    },
                }
            },
            .frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
            .newobj_ex => try stack.values.append(blk: {
                const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
                const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
                break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
            }),
            .stack_global => try stack.values.append(blk: {
                const gn, const mn = .{
                    try memo.resolve(allocator, try stack.pop(), true),
                    try memo.resolve(allocator, try stack.pop(), true),
                };
                const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
                break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
            }),
            // Stores a clone of the top of the stack under the next free memo id; the stack is
            // left untouched.
            .memoize => {
                const item = stack.values.getLastOrNull() orelse {
                    return error.StackUnderrun;
                };
                try memo.insert(@intCast(memo.map.count()), try item.clone(allocator));
            },
            // Every other opcode is pushed as a raw value.
            else => try stack.values.append(.{ .raw = try op.clone(allocator) }),
        }
    }
    if (!resolve_refs) {
        return .{ try stack.toPickleStack(), memo };
    }
    return .{
        PickleStack.init(allocator, try memo.resolveAllRefsIter(allocator, 0, stack.values.items, true)),
        memo,
    };
}
// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead
/// Returns a slice of `new_length` elements holding the contents of `old`, growing in place when
/// the allocator allows it and otherwise moving to a fresh allocation.
///
/// `old` must have been allocated with `allocator` or be empty. After a successful call `old`
/// must no longer be used; on failure `old` is left untouched and still owned by the caller.
fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_length: usize) ![]T {
    // `Allocator.resize` only reports whether the in-place resize worked and leaves the caller's
    // slice at its old length; `realloc` returns a slice with the requested length and only
    // releases `old` once the replacement exists.
    return allocator.realloc(old, new_length);
}