Rename zml.aio.Value to zml.aio.Metadata, simplify its type variants, and update torch pickle/eval APIs accordingly.

This commit is contained in:
Tarry Singh 2023-04-07 16:45:58 +00:00
parent aea23c720e
commit 0189b71070
13 changed files with 1307 additions and 769 deletions

View File

@ -14,7 +14,6 @@ pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig");
pub const log = std.log.scoped(.zml_aio);
pub const Value = @import("aio/value.zig").Value;
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test {
@ -56,7 +55,7 @@ pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path
else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama"))
try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000)
else {
zml.log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
return error.FormatNotRecognized;
};
}
@ -87,8 +86,8 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
try prefix_builder.push(allocator, prefix);
defer prefix_builder.deinit(allocator);
var unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
const ok = _populateStruct(allocator, &prefix_builder, &unique_id, store, &model, true) catch |err| {
const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
};
if (!ok) return error.TensorNotFound;
@ -98,12 +97,12 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
/// A struct containing all the buffers and metadata found in a model file.
pub const BufferStore = struct {
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
pub const Metadata = std.StringArrayHashMapUnmanaged(Value);
pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
arena: std.heap.ArenaAllocator,
files: []MemoryMappedFile = &.{},
buffers: Buffers = .{},
_metadata: Metadata = .{},
_metadata: Metadatas = .{},
pub fn deinit(self: BufferStore) void {
for (self.files) |*file| file.deinit();
@ -135,7 +134,7 @@ pub const BufferStore = struct {
return if (maybe_max_index) |index| index + 1 else 0;
}
pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Value)) ?std.meta.FieldType(Value, tag) {
pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) {
const wrapped_value = self._metadata.get(key) orelse return null;
if (wrapped_value != tag) {
@ -145,14 +144,86 @@ pub const BufferStore = struct {
return @field(wrapped_value, @tagName(tag));
}
pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Value.Slice.ItemType) ?[]const Value.Slice.toZigType(tag) {
pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() {
const wrapped_value = self._metadata.get(key) orelse return null;
if (wrapped_value != .array or wrapped_value.array.item_type != tag) {
return null;
const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), @tagName(tag)).?;
if (wrapped_value == true_tag) {
return @field(wrapped_value, "array_" ++ @tagName(tag));
}
return null;
}
};
/// A single metadata entry loaded from a model file (GGUF, torch pickle,
/// safetensors JSON header, ...). Only scalars and homogeneous arrays are
/// representable; loaders flatten nested containers before storing here.
pub const Metadata = union(enum) {
    null: void,
    int: i64,
    float: f64,
    bool: bool,
    string: []const u8,
    array_bool: []const bool,
    array_int: []const i64,
    array_float: []const f64,
    array_string: []const []const u8,

    /// Element kinds that can appear inside the `array_*` variants.
    pub const ItemType = enum {
        int,
        float,
        bool,
        string,

        /// Zig type used to store elements of the given kind.
        pub fn toZigType(comptime kind: ItemType) type {
            return switch (kind) {
                .int => i64,
                .float => f64,
                .bool => bool,
                .string => []const u8,
            };
        }
    };

    /// Wraps a scalar into a Metadata value, widening integers to i64 and
    /// floats to f64. Panics at comptime resolution for unsupported types.
    /// NOTE(review): a u64 above maxInt(i64) will trap in @intCast — callers
    /// are assumed to only pass values in the i64 range; confirm upstream.
    pub fn wrap(x: anytype) Metadata {
        return switch (@TypeOf(x)) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) },
            inline f16, f32, f64 => .{ .float = @floatCast(x) },
            bool => .{ .bool = x },
            []const u8 => .{ .string = x },
            // Message updated to match the type's new name (was zml.aio.Value).
            else => @panic("Unsupported type for zml.aio.Metadata: " ++ @typeName(@TypeOf(x))),
        };
    }

    /// Copies a slice of scalars into an owned `array_*` variant, widening
    /// integers to i64 and floats to f64. Caller owns the returned memory
    /// (allocated with `allocator`).
    pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata {
        return switch (@TypeOf(any_slice[0])) {
            inline u8, i8, u16, i16, u32, i32, u64, i64 => {
                const res = try allocator.alloc(i64, any_slice.len);
                for (res, any_slice) |*r, val| r.* = @intCast(val);
                return .{ .array_int = res };
            },
            inline f16, f32, f64 => {
                const res = try allocator.alloc(f64, any_slice.len);
                for (res, any_slice) |*r, val| r.* = @floatCast(val);
                return .{ .array_float = res };
            },
            bool => .{ .array_bool = try allocator.dupe(bool, any_slice) },
            // Duplicates only the outer slice; the inner strings are shared
            // with the input (assumed to outlive the store — TODO confirm).
            []const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) },
            // Message updated to match the type's new name (was zml.aio.Value).
            else => @panic("Unsupported type for zml.aio.Metadata: " ++ @typeName(@TypeOf(any_slice))),
        };
    }

    /// std.fmt integration: prints "null", `{any}` for bools/bool arrays,
    /// and `{d}` for everything else (numbers, numeric arrays, strings).
    pub fn format(
        self: Metadata,
        comptime fmt: []const u8,
        options: std.fmt.FormatOptions,
        writer: anytype,
    ) !void {
        _ = fmt;
        _ = options;
        switch (self) {
            .null => _ = try writer.write("null"),
            inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
            inline else => |v| try writer.print("{d}", .{v}),
        }
        // Removed two stray lines left over from the old Value.Slice-based
        // metadataSlice implementation; they referenced names (`tag`,
        // `wrapped_value`, `Value.Slice`) that do not exist in this scope.
    }
};
@ -244,7 +315,7 @@ const PrefixBuilder = struct {
fn _populateStruct(
allocator: std.mem.Allocator,
prefix_builder: *PrefixBuilder,
unique_id: *u64,
unique_id: u64,
buffer_store: BufferStore,
obj: anytype,
required: bool,
@ -260,17 +331,17 @@ fn _populateStruct(
const prefix = prefix_builder.data.items;
if (T == zml.Tensor) {
return if (buffer_store.get(prefix)) |buffer| {
return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| {
const buffer = buffer_store.get(prefix).?;
obj.* = zml.Tensor{
._shape = buffer.shape(),
._id = .{ .buffer_id = unique_id.* },
._id = .{ .buffer_id = unique_id + entry_idx },
._donation = .input_buffer,
};
unique_id.* += 1;
return true;
} else {
if (required) {
std.log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
}
return false;
};
@ -290,7 +361,7 @@ fn _populateStruct(
defer prefix_builder.pop();
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
if (!found) {
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) });
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
return false;
}
}
@ -299,7 +370,7 @@ fn _populateStruct(
}
return true;
} else {
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
return false;
}
},
@ -346,7 +417,7 @@ fn _populateStruct(
},
.Void => true,
else => if (required) {
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
return error.UnsupportedMetadataType;
} else return false,
};
@ -431,7 +502,7 @@ pub fn loadBuffers(
zml.meta.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
}
return loadModelBuffers(Model, model, buffer_store, allocator, platform);
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore. For details about
@ -449,6 +520,7 @@ pub fn loadModelBuffers(
) !zml.Bufferized(Model) {
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
}
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T).
///

View File

@ -36,24 +36,24 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
log.err("GGUF File: Tokens not found", .{});
return error.TokensNotFound;
};
const scores = self.metadataSlice("tokenizer.ggml.scores", .float32) orelse {
const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse {
log.err("GGUF File: Scores not found", .{});
return error.ScoresNotFound;
};
assert(tokens.len == scores.len);
const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama";
const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece;
const bos = self.metadata("tokenizer.ggml.bos_token_id", .uint32);
const eos = self.metadata("tokenizer.ggml.eos_token_id", .uint32);
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .uint32);
const pad = self.metadata("tokenizer.ggml.padding_token_id", .uint32);
const bos = self.metadata("tokenizer.ggml.bos_token_id", .int);
const eos = self.metadata("tokenizer.ggml.eos_token_id", .int);
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int);
const pad = self.metadata("tokenizer.ggml.padding_token_id", .int);
const NOT_FOUND = std.math.maxInt(u32);
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
.bos = bos.?,
.eos = eos.?,
.unk = unk orelse NOT_FOUND,
.pad = pad orelse NOT_FOUND,
.bos = @intCast(bos.?),
.eos = @intCast(eos.?),
.unk = @intCast(unk orelse NOT_FOUND),
.pad = @intCast(pad orelse NOT_FOUND),
};
const gguf_normalizer = if (tokenizer_impl == .gpt2)
@ -85,10 +85,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
for (tokens, 0..tokens.len) |t, i| {
if (tokenizer_impl == .gpt2) {
decoded.clearRetainingCapacity();
try tokenizer.addToken(scores[i], try gpt2_unicode.?.decode(&decoded, t));
try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t));
// log.debug("token: {s} -> {s}", .{t, decoded.items});
} else {
try tokenizer.addToken(scores[i], t);
try tokenizer.addToken(@floatCast(scores[i]), t);
}
}
@ -112,7 +112,19 @@ fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.G
log.warn("Found duplicated metadata key: {s}", .{entry.name});
continue;
}
res.value_ptr.* = entry.val.asLoaderValue();
res.value_ptr.* = switch (entry.val) {
.array => |arr| switch (arr.child) {
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
const T = std.meta.FieldType(core.GgufValue, tag);
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
},
else => blk: {
log.warn("ignoring array metadata", .{});
break :blk .null;
},
},
inline else => |v| zml.aio.Metadata.wrap(v),
};
} else |err| switch (err) {
error.EndOfMetadata => {},
else => return err,

View File

@ -176,20 +176,20 @@ pub const GgufValueType = enum(u32) {
}
};
pub const ValueType = enum {
uint8,
int8,
uint16,
int16,
uint32,
int32,
float32,
uint64,
int64,
float64,
boolval,
string,
array,
/// GGUF metadata value-type tags. The explicit integer values mirror the
/// on-disk encoding of the GGUF format, so entries must not be reordered
/// or renumbered. Note the ids are not grouped by width: bool/string/array
/// (7-9) sit between float32 (6) and the 64-bit types (10-12).
pub const ValueType = enum(u8) {
    uint8 = 0,
    int8 = 1,
    uint16 = 2,
    int16 = 3,
    uint32 = 4,
    int32 = 5,
    float32 = 6,
    bool = 7,
    string = 8,
    array = 9,
    uint64 = 10,
    int64 = 11,
    float64 = 12,
};
// Union of possible values.
@ -201,47 +201,20 @@ pub const GgufValue = union(ValueType) {
uint32: u32,
int32: i32,
float32: f32,
bool: bool,
string: []const u8,
array: Array,
uint64: u64,
int64: i64,
float64: f64,
boolval: bool,
string: []const u8,
array: Array,
pub const Array = struct {
// Any value type is valid, including arrays.
child: GgufValueType,
child: ValueType,
// Number of elements, not bytes
len: usize,
data: []u8,
};
pub fn asLoaderValue(self: GgufValue) zml.aio.Value {
return switch (self) {
.array => |v| .{
.array = .{
.item_type = switch (v.child) {
.bool => .boolval,
.uint8 => .uint8,
.int8 => .int8,
.uint16 => .uint16,
.int16 => .int16,
.uint32 => .uint32,
.int32 => .int32,
.float32 => .float32,
.uint64 => .uint64,
.int64 => .int64,
.float64 => .float64,
.string => .string,
// TODO: .array => .array,
else => unreachable,
},
.data = v.data,
},
},
inline else => |v, tag| @unionInit(zml.aio.Value, @tagName(tag), v),
};
}
};
// Header
@ -403,6 +376,9 @@ pub const GgufFile = struct {
fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
const child = try self.readValueType();
if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
return error.UnsupportedGgufType;
}
const len: usize = try self.readInt(u64);
const data = switch (child) {
// Since strings have variable lenghts, we need to read them one by one
@ -414,7 +390,7 @@ pub const GgufFile = struct {
else => try self.readAlloc(allocator, len * child.sizeOf()),
};
return .{
.child = child,
.child = @enumFromInt(@intFromEnum(child)),
.len = len,
.data = data,
};
@ -429,7 +405,7 @@ pub const GgufFile = struct {
.uint32 => .{ .uint32 = try self.readInt(u32) },
.int32 => .{ .int32 = try self.readInt(i32) },
.float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
.bool => .{ .boolval = try self.readInt(u8) != 0 },
.bool => .{ .bool = try self.readInt(u8) != 0 },
.string => .{ .string = try self.readString(allocator) },
.array => .{ .array = try self.readArrayHeader(allocator) },
.uint64 => .{ .uint64 = try self.readInt(u64) },

View File

@ -30,44 +30,40 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix:
const metadata = &store._metadata;
const key = prefix.items;
return switch (val) {
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }),
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
.array => |v| {
if (v.items.len == 0) return;
return if (validSlice(v)) |item_type| {
const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) {
const data: zml.aio.Metadata = switch (item_type) {
.bool => blk: {
const values = try allocator.alloc(bool, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.bool;
break :blk .{ std.mem.sliceAsBytes(values), .boolval };
break :blk .{ .array_bool = values };
},
.integer => blk: {
const values = try allocator.alloc(i64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.integer;
break :blk .{ std.mem.sliceAsBytes(values), .int64 };
break :blk .{ .array_int = values };
},
.float => blk: {
const values = try allocator.alloc(f64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.float;
break :blk .{ std.mem.sliceAsBytes(values), .float64 };
break :blk .{ .array_float = values };
},
inline .string, .number_string => |tag| blk: {
const values = try allocator.alloc([]const u8, v.items.len);
for (v.items, 0..) |item, i| {
values[i] = @field(item, @tagName(tag));
}
break :blk .{ std.mem.sliceAsBytes(values), .string };
break :blk .{ .array_string = values };
},
.null, .array, .object => unreachable,
};
try metadata.put(
allocator,
try allocator.dupe(u8, key),
.{ .array = .{ .item_type = dtype, .data = data } },
);
try metadata.put(allocator, try allocator.dupe(u8, key), data);
} else {
for (v.items, 0..) |item, i| {
var new_prefix = prefix;

View File

@ -39,10 +39,10 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const start = try mapped_file.file.getPos();
var tmp: zml.aio.torch.PickleData = .{
.data = try parser.Parser.fromTarFile(arena, mapped_file, file),
.memo = undefined,
.stack = undefined,
};
tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true);
tmp.stack = try eval.evaluate(arena, tmp.data.ops, true);
try tmp.parseModel(arena, &res);
// Since we directly manipulate the file handle pointer,
// reset to the end of file so iterator does not error

View File

@ -91,17 +91,17 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe
{
try res._metadata.ensureUnusedCapacity(arena, 11);
res._metadata.putAssumeCapacityNoClobber("dim", .{ .int64 = c.dim });
res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int64 = c.hidden_dim });
res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int64 = c.n_layers });
res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int64 = c.n_heads });
res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int64 = c.n_kv_heads });
res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int64 = c.vocab.size });
res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .boolval = c.vocab.has_lm_head });
res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int64 = c.seq_len });
res._metadata.putAssumeCapacityNoClobber("dim", .{ .int = c.dim });
res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int = c.hidden_dim });
res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int = c.n_layers });
res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int = c.n_heads });
res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int = c.n_kv_heads });
res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int = c.vocab.size });
res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .bool = c.vocab.has_lm_head });
res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int = c.seq_len });
res._metadata.putAssumeCapacityNoClobber("rope_impl", .{ .string = "interleaved" });
res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float64 = 10_000 });
res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float64 = 1e-6 });
res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float = 10_000 });
res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float = 1e-6 });
}
return res;

View File

@ -16,6 +16,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.zml_io);
test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(eval);
std.testing.refAllDecls(value);
std.testing.refAllDecls(parser);
@ -35,22 +36,21 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const tmp_alloc = arena.allocator();
const _parser = try parser.Parser.init(tmp_alloc, file);
const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true);
const stack = try eval.evaluate(tmp_alloc, _parser.ops, true);
// But we create the HostBuffer objects inside the result BufferStore arena.
var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator),
};
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file});
var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack };
var tmp: PickleData = .{ .data = _parser, .stack = stack };
try tmp.parseModel(res.arena.allocator(), &res);
return res;
}
// TODO: rename me to PytorchFile
pub const PickleData = struct {
stack: eval.PickleStack,
memo: eval.PickleMemo,
stack: []const Value,
data: parser.Parser,
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
@ -63,7 +63,7 @@ pub const PickleData = struct {
}
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
for (self.stack.stack) |item| {
for (self.stack) |item| {
var prefix_buf: [1024]u8 = undefined;
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
}
@ -147,7 +147,7 @@ pub const PickleData = struct {
try store._metadata.put(
allocator,
try allocator.dupe(u8, prefix.items),
.{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } },
try zml.aio.Metadata.copySlice(allocator, values.items),
);
} else {
for (values.items, 0..) |val, i| {
@ -156,7 +156,13 @@ pub const PickleData = struct {
new_prefix.appendAssumeCapacity('.');
}
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val));
const new_tag = switch (tag) {
.int64 => "int",
.float64 => "float",
.boolval => "bool",
else => unreachable, // we are already inside a switch
};
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val));
}
}
},
@ -212,15 +218,17 @@ pub const PickleData = struct {
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
} else d.value_ptr.* = .{ .string = val };
},
inline .float64, .int64, .boolval, .bigint, .string => |val, tag| {
inline .float64, .int64, .boolval, .bigint, .string => |val| {
const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key);
} else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val);
} else {
d.value_ptr.* = zml.aio.Metadata.wrap(val);
}
},
else => {},
}
@ -248,7 +256,7 @@ pub const PickleData = struct {
}
const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
entry.value_ptr.* = .{ .array_int = d };
return true;
} else if (basicTypeCheck(object, "fractions", "Fraction")) {
const fraction_str = object.args[0].seq.values[0].string;
@ -256,12 +264,12 @@ pub const PickleData = struct {
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".numerator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
}
{
var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".denominator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
}
return true;
}

View File

@ -159,84 +159,9 @@ pub const PickleMemo = struct {
}
};
/// Growable evaluation stack used while interpreting pickle opcodes.
/// Owns the values it holds: deinit releases every remaining element.
pub const InternalStack = struct {
    allocator: std.mem.Allocator,
    values: std.ArrayList(Value),

    pub fn init(allocator: std.mem.Allocator) InternalStack {
        return .{
            .allocator = allocator,
            .values = std.ArrayList(Value).init(allocator),
        };
    }

    /// Frees every remaining value, then the backing list itself.
    pub fn deinit(self: *InternalStack) void {
        for (0..self.values.items.len) |i| self.values.items[i].deinit(self.allocator);
        self.values.deinit();
        self.* = undefined;
    }

    /// Pops the top value; error.StackUnderrun when the stack is empty.
    pub fn pop(self: *InternalStack) !Value {
        if (self.values.items.len == 0) {
            return error.StackUnderrun;
        }
        return self.values.pop();
    }

    /// Removes everything above the most recent mark, and the mark itself.
    /// When `allocator` is non-null, returns a copy of the removed values
    /// (caller owns the slice); when null, the removed values are dropped
    /// without being copied.
    pub fn popMark(self: *InternalStack, allocator: ?std.mem.Allocator) ![]Value {
        const markidx = try self.findMark();
        var postmark: []Value = &[_]Value{};
        if (allocator) |a| {
            postmark = try a.alloc(Value, self.values.items.len - (markidx + 1));
            @memcpy(postmark, self.values.items[markidx + 1 ..]);
        }
        self.values.shrinkAndFree(markidx);
        return postmark;
    }

    /// Mutable pointer to the top element; error.UnexpectedEmptyStack when empty.
    /// NOTE(review): pointer is invalidated by any subsequent push — use immediately.
    pub fn lastMut(self: *InternalStack) !*Value {
        if (self.values.items.len == 0) {
            return error.UnexpectedEmptyStack;
        }
        return &self.values.items[self.values.items.len - 1];
    }

    /// Index of the topmost mark, scanning from the top down. Logs a warning
    /// and returns 0 when no mark is present, i.e. the stack bottom acts as
    /// an implicit mark instead of erroring out.
    pub fn findMark(self: *InternalStack) !usize {
        const len = self.values.items.len;
        for (0..len) |i| {
            const idx = (len - 1) - i;
            const val = self.values.items[idx];
            if (val == .raw and val.raw == .mark) {
                return idx;
            }
        }
        zml.log.warn("pytorch loader: missing mark", .{});
        return 0;
    }

    /// Transfers ownership of all values into an immutable PickleStack;
    /// this stack is left empty afterwards.
    pub fn toPickleStack(self: *InternalStack) !PickleStack {
        return .{ .stack = try self.values.toOwnedSlice(), .allocator = self.allocator };
    }
};
/// Finalized result of pickle evaluation: an owned slice of values plus the
/// allocator that created them, so deinit can release everything.
pub const PickleStack = struct {
    stack: []Value,
    allocator: std.mem.Allocator,

    /// Takes ownership of `values`; they must have been allocated with `allocator`.
    pub fn init(allocator: std.mem.Allocator, values: []Value) PickleStack {
        return .{ .allocator = allocator, .stack = values };
    }

    /// Frees each value, then the slice itself.
    pub fn deinit(self: *PickleStack) void {
        for (self.stack) |*v| v.deinit(self.allocator);
        self.allocator.free(self.stack);
    }
};
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
var stack = InternalStack.init(allocator);
defer stack.deinit();
var memo = PickleMemo.init(allocator);
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value {
var stack = std.ArrayList(Value).init(arena);
var memo = PickleMemo.init(arena);
errdefer memo.deinit();
const makeKVList = (struct {
@ -258,56 +183,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
}
}).call;
outer: for (x) |op| {
for (x) |op| {
switch (op) {
.mark => try stack.values.append(.{ .raw = op }),
.stop => break :outer,
.pop => _ = try stack.pop(),
.pop_mark => _ = try stack.popMark(allocator),
.dup => {
if (stack.values.getLastOrNull()) |item| {
try stack.values.append(try item.clone(allocator));
} else {
return error.CannotDupEmptyStack;
}
},
.persid => |v| try stack.values.append(.{ .pers_id = try PersId.init(allocator, .{ .string = try allocator.dupe(u8, v) }) }),
.binpersid => try stack.values.append(.{ .pers_id = try PersId.init(allocator, try stack.pop()) }),
.reduce => try stack.values.append(.{ .global = blk: {
const values = try allocator.alloc(Value, 1);
values[0] = try memo.resolve(allocator, try stack.pop(), true);
break :blk try Object.init(allocator, try memo.resolve(allocator, try stack.pop(), true), values);
.mark => try stack.append(.{ .raw = op }),
.frame => {},
.stop => break,
.pop => _ = try pop(&stack),
.pop_mark => try popMarkDiscard(&stack),
.dup => if (stack.getLastOrNull()) |item|
try stack.append(try item.clone(arena))
else
return error.CannotDupEmptyStack,
.persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
.binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }),
.reduce => try stack.append(.{ .global = blk: {
const values = try arena.alloc(Value, 1);
values[0] = try memo.resolve(arena, try pop(&stack), true);
break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values);
} }),
.build => try stack.values.append(blk: {
const args = try memo.resolve(allocator, try stack.pop(), true);
const member = try memo.resolve(allocator, try stack.pop(), true);
break :blk .{ .build = try Build.init(allocator, member, args) };
.build => try stack.append(blk: {
const args = try memo.resolve(arena, try pop(&stack), true);
const member = try memo.resolve(arena, try pop(&stack), true);
break :blk .{ .build = try Build.init(arena, member, args) };
}),
.empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
.empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
.binput, .long_binput => |v| {
try memo.insert(v, try stack.pop());
try stack.values.append(.{ .ref = v });
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
.get => |v| try stack.append(.{ .ref = v }),
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
.put => |v| {
try memo.insert(v, try pop(&stack));
try stack.append(.{ .ref = v });
},
.tuple => try stack.values.append(blk: {
const popped = try stack.popMark(allocator);
.tuple => try stack.append(blk: {
const popped = try popMark(&stack, arena);
break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
}),
.empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
.setitem => {
const v, const k = .{ try stack.pop(), try stack.pop() };
const top = try stack.lastMut();
const v, const k = .{ try pop(&stack), try pop(&stack) };
const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true);
switch (rtop.*) {
.global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
},
.seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
},
else => {
return error.BadStackTopForSetItem;
@ -315,53 +237,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
}
},
.setitems => {
const popped = try stack.popMark(allocator);
defer allocator.free(popped);
const kv_items = try makeKVList(allocator, popped);
const top = try stack.lastMut();
const popped = try popMark(&stack, arena);
defer arena.free(popped);
const kv_items = try makeKVList(arena, popped);
const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true);
switch (rtop.*) {
.global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
},
.seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
},
else => {
defer allocator.free(kv_items);
defer arena.free(kv_items);
return error.BadStackTopForSetItems;
},
}
},
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 1);
tup_values[0] = try stack.pop();
.tuple1 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 1);
tup_values[0] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple2 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
.tuple2 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.tuple3 => try stack.values.append(blk: {
const tup_values = try allocator.alloc(Value, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
.tuple3 => try stack.append(blk: {
const tup_values = try arena.alloc(Value, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}),
.append => {
const v = try stack.pop();
const top = try stack.lastMut();
const v = try pop(&stack);
const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true);
switch (rtop.*) {
.global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = v;
},
.seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = v;
},
else => {
@ -370,19 +292,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
}
},
.appends => {
const postmark = try stack.popMark(allocator);
defer allocator.free(postmark);
const top = try stack.lastMut();
const postmark = try popMark(&stack, arena);
defer arena.free(postmark);
const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true);
switch (rtop.*) {
.global => |obj| {
const obj_len = obj.args.len;
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
@memcpy(obj.args[obj_len..], postmark);
},
.seq => |*tup| {
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
},
else => {
@ -390,49 +312,43 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
},
}
},
.dict => try stack.values.append(blk: {
const popped = try stack.popMark(allocator);
defer allocator.free(popped);
const kv_items = try makeKVList(allocator, popped);
.dict => try stack.append(blk: {
const popped = try popMark(&stack, arena);
defer arena.free(popped);
const kv_items = try makeKVList(arena, popped);
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
}),
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
.inst => |v| try stack.values.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
.list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }),
.inst => |v| try stack.append(blk: {
const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) };
}),
.obj => try stack.values.append(blk: {
const markidx = try stack.findMark();
const args = try allocator.alloc(Value, stack.values.items.len - (markidx + 2));
@memcpy(args, stack.values.items[markidx + 2 ..]);
const member = stack.values.items[markidx + 1];
break :blk .{ .object = try Object.init(allocator, member, args) };
.obj => try stack.append(blk: {
const mark = try findMark(&stack);
const args = try arena.dupe(Value, stack.items[mark + 2 ..]);
const member = stack.items[mark + 1];
break :blk .{ .object = try Object.init(arena, member, args) };
}),
.put => |v| {
const mid = try std.fmt.parseInt(u32, v, 10);
try memo.insert(mid, try stack.pop());
try stack.values.append(.{ .ref = mid });
},
.newobj => try stack.values.append(blk: {
const args = try allocator.alloc(Value, 1);
args[0] = try stack.pop();
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
.newobj => try stack.append(blk: {
const args = try arena.alloc(Value, 1);
args[0] = try pop(&stack);
break :blk .{ .object = try Object.init(arena, try pop(&stack), args) };
}),
.empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
.additems => {
const postmark = try stack.popMark(allocator);
defer allocator.free(postmark);
const top = try stack.lastMut();
const postmark = try popMark(&stack, arena);
defer arena.free(postmark);
const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true);
switch (rtop.*) {
.global => |obj| {
const obj_len = obj.args.len;
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
@memcpy(obj.args[obj_len..], postmark);
},
.seq => |*tup| {
const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark);
},
else => {
@ -440,36 +356,33 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
},
}
},
.frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
.newobj_ex => try stack.values.append(blk: {
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
.frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }),
.newobj_ex => try stack.append(blk: {
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) };
break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) };
}),
.stack_global => try stack.values.append(blk: {
.stack_global => try stack.append(blk: {
const gn, const mn = .{
try memo.resolve(allocator, try stack.pop(), true),
try memo.resolve(allocator, try stack.pop(), true),
try memo.resolve(arena, try pop(&stack), true),
try memo.resolve(arena, try pop(&stack), true),
};
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) };
}),
.memoize => {
const item = stack.values.getLastOrNull() orelse {
const item = stack.getLastOrNull() orelse {
return error.StackUnderrun;
};
try memo.insert(@intCast(memo.map.count()), try item.clone(allocator));
try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
},
else => try stack.values.append(.{ .raw = try op.clone(allocator) }),
else => try stack.append(.{ .raw = try op.clone(arena) }),
}
}
if (!resolve_refs) {
return .{ try stack.toPickleStack(), memo };
if (resolve_refs) {
return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
}
return .{
PickleStack.init(allocator, try memo.resolveAllRefsIter(allocator, 0, stack.values.items, true)),
memo,
};
return stack.toOwnedSlice();
}
// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead
@ -483,3 +396,86 @@ fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_l
return new;
}
}
test evaluate {
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
defer arena.deinit();
const allocator = arena.allocator();
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
var buffered_reader = std.io.bufferedReader(file.reader());
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
const vals = try evaluate(allocator, ops, true);
defer allocator.free(vals);
try std.testing.expect(vals.len == 1);
try std.testing.expect(vals[0] == .seq);
try std.testing.expect(vals[0].seq.type == .dict);
const entries = vals[0].seq.values[0].seq.values;
try std.testing.expect(entries.len == 5);
const expected: []const Value = &.{
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
.{ .string = "list" },
.{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{
.{ .int64 = 0 },
.{ .int64 = 1 },
.{ .int64 = 2 },
.{ .int64 = 3 },
.{ .int64 = 4 },
}) } },
}) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
.{ .string = "tuple" },
.{ .seq = .{
.type = .tuple,
.values = @constCast(&[_]Value{
.{ .string = "a" },
.{ .int64 = 10 },
}),
} },
}) } },
};
try std.testing.expectEqualDeep(expected, entries);
}
pub fn pop(values: *std.ArrayList(Value)) !Value {
if (values.items.len == 0) {
return error.StackUnderrun;
}
return values.pop();
}
fn popMarkDiscard(values: *std.ArrayList(Value)) !void {
const mark = try findMark(values);
values.shrinkRetainingCapacity(mark);
}
fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value {
const mark = try findMark(values);
const popping = values.items[mark + 1 ..];
values.shrinkRetainingCapacity(mark);
return try allocator.dupe(Value, popping);
}
fn lastMut(values: *std.ArrayList(Value)) !*Value {
if (values.items.len == 0) {
return error.UnexpectedEmptyStack;
}
return &values.items[values.items.len - 1];
}
fn findMark(values: *std.ArrayList(Value)) !usize {
const len = values.items.len;
for (0..len) |i| {
const idx = (len - 1) - i;
const val = values.items[idx];
if (val == .raw and val.raw == .mark) {
return idx;
}
}
return error.MarkNotFound;
}

View File

@ -17,7 +17,7 @@ pub const Parser = struct {
buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null,
ops: []pickle.Op,
ops: []const pickle.Op,
is_zip_file: bool,
zip_prefix: []const u8 = &[_]u8{},
@ -65,7 +65,7 @@ pub const Parser = struct {
};
if (!self.is_zip_file) {
const reader = tar_stream.reader();
self.ops = try parse(allocator, reader, try tar_stream.getEndPos());
self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos());
} else {
self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream());
}
@ -82,7 +82,7 @@ pub const Parser = struct {
};
if (!self.is_zip_file) {
const reader = self.buffer_file.file.reader();
self.ops = try parse(allocator, reader, try reader.context.getEndPos());
self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos());
} else {
self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream());
}
@ -94,7 +94,7 @@ pub const Parser = struct {
self.* = undefined;
}
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op {
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op {
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| {
@ -152,7 +152,7 @@ pub const Parser = struct {
switch (entry.compression_method) {
.store => {
return parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
},
.deflate => {
// TODO(cryptodeal): handle decompress
@ -166,196 +166,6 @@ pub const Parser = struct {
std.log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound;
}
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op {
var results = std.ArrayList(pickle.Op).init(allocator);
errdefer results.deinit();
outer: while (true) {
const b = try reader.readByte();
switch (@as(pickle.OpCode, @enumFromInt(b))) {
.mark => try results.append(.{ .mark = {} }),
.stop => {
try results.append(.{ .stop = {} });
break :outer;
},
.pop => try results.append(.{ .pop = {} }),
.pop_mark => try results.append(.{ .pop_mark = {} }),
.dup => try results.append(.{ .dup = {} }),
.float => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .float = buf });
},
.int => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .int = buf });
},
.binint => try results.append(.{ .binint = try reader.readInt(i32, .little) }),
.binint1 => try results.append(.{ .binint1 = try reader.readByte() }),
.long => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .long = buf });
},
.binint2 => try results.append(.{ .binint2 = try reader.readInt(u16, .little) }),
.none => try results.append(.{ .none = {} }),
.persid => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .persid = buf });
},
.binpersid => try results.append(.{ .binpersid = {} }),
.reduce => try results.append(.{ .reduce = {} }),
.string => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .string = buf });
},
.binstring => {
const str_len = try reader.readInt(u32, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binstring = buf });
},
.short_binstring => {
const str_len = try reader.readByte();
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .short_binstring = buf });
},
.unicode => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .unicode = buf });
},
.binunicode => {
const str_len = try reader.readInt(u32, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binunicode = buf });
},
.append => try results.append(.{ .append = {} }),
.build => try results.append(.{ .build = {} }),
.global, .inst => {
const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(module);
const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(class);
try results.append(.{ .global = .{ .module = module, .class = class } });
},
.dict => try results.append(.{ .dict = {} }),
.empty_dict => try results.append(.{ .empty_dict = {} }),
.appends => try results.append(.{ .appends = {} }),
.get => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .get = buf });
},
.binget => try results.append(.{ .binget = try reader.readByte() }),
.long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }),
.list => try results.append(.{ .list = {} }),
.empty_list => try results.append(.{ .empty_list = {} }),
.obj => try results.append(.{ .obj = {} }),
.put => {
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
errdefer allocator.free(buf);
try results.append(.{ .put = buf });
},
.binput => {
try results.append(.{ .binput = try reader.readByte() });
},
.long_binput => {
try results.append(.{ .long_binput = try reader.readInt(u32, .little) });
},
.setitem => try results.append(.{ .setitem = {} }),
.tuple => try results.append(.{ .tuple = {} }),
.empty_tuple => try results.append(.{ .empty_tuple = {} }),
.setitems => try results.append(.{ .setitems = {} }),
.binfloat => try results.append(.{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }),
.proto => try results.append(.{ .proto = try reader.readByte() }),
.newobj => try results.append(.{ .newobj = {} }),
.ext1 => try results.append(.{ .ext1 = try reader.readByte() }),
.ext2 => try results.append(.{ .ext2 = try reader.readInt(i16, .little) }),
.ext4 => try results.append(.{ .ext4 = try reader.readInt(i32, .little) }),
.tuple1 => try results.append(.{ .tuple1 = {} }),
.tuple2 => try results.append(.{ .tuple2 = {} }),
.tuple3 => try results.append(.{ .tuple3 = {} }),
.newtrue => try results.append(.{ .newtrue = {} }),
.newfalse => try results.append(.{ .newfalse = {} }),
.long1 => {
const str_len = try reader.readByte();
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .long1 = buf });
},
.long4 => {
const str_len = try reader.readInt(u32, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .long4 = buf });
},
.binbytes => {
const str_len = try reader.readInt(u32, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binbytes = buf });
},
.binbytes8 => {
const str_len = try reader.readInt(u64, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binbytes8 = buf });
},
.short_binbytes => {
const str_len = try reader.readByte();
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .short_binbytes = buf });
},
.binunicode8 => {
const str_len = try reader.readInt(u64, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binunicode8 = buf });
},
.short_binunicode => {
const str_len = try reader.readByte();
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .binunicode8 = buf });
},
.empty_set => try results.append(.{ .empty_set = {} }),
.additems => try results.append(.{ .additems = {} }),
.frozenset => try results.append(.{ .frozenset = {} }),
.newobj_ex => try results.append(.{ .newobj_ex = {} }),
.stack_global => try results.append(.{ .stack_global = {} }),
.memoize => try results.append(.{ .memoize = {} }),
.frame => try results.append(.{ .frame = try reader.readInt(u64, .little) }),
.bytearray8 => {
const str_len = try reader.readInt(u64, .little);
const buf = try allocator.alloc(u8, str_len);
errdefer allocator.free(buf);
_ = try reader.read(buf);
try results.append(.{ .bytearray8 = buf });
},
.next_buffer => try results.append(.{ .next_buffer = {} }),
.readonly_buffer => try results.append(.{ .readonly_buffer = {} }),
_ => {},
}
}
return results.toOwnedSlice();
}
};
const TarStream = struct {
@ -404,54 +214,6 @@ const TarStream = struct {
}
};
test "Read pickle (simple)" {
const Value = @import("value.zig").Value;
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();
const allocator = arena.allocator();
const eval = @import("eval.zig");
const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
var data = try Parser.init(allocator, file);
defer data.deinit();
var vals, var memo = try eval.evaluate(allocator, data.ops, true);
defer vals.deinit();
defer memo.deinit();
try testing.expect(vals.stack.len == 2);
// skip first value (frame)
try testing.expect(vals.stack[1] == .seq);
try testing.expect(vals.stack[1].seq.type == .dict);
const entries = vals.stack[1].seq.values[0].seq.values;
try testing.expect(entries.len == 5);
const expected: []const Value = &.{
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
.{ .string = "list" },
.{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
.{ .int64 = 0 },
.{ .int64 = 1 },
.{ .int64 = 2 },
.{ .int64 = 3 },
.{ .int64 = 4 },
})) } },
})) } },
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
.{ .string = "tuple" },
.{ .seq = .{
.type = .tuple,
.values = @constCast(@as([]const Value, &.{
.{ .string = "a" },
.{ .int64 = 10 },
})),
} },
})) } },
};
try std.testing.expectEqualDeep(expected, entries);
}
test "Read pickle (zipped)" {
var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit();

File diff suppressed because it is too large Load Diff

View File

@ -109,7 +109,7 @@ pub const ValueType = enum {
none,
};
/// A processed value.
/// A pickle operator that has been interpreted.
pub const Value = union(ValueType) {
/// Types that we can't handle or just had to give up on processing.
raw: pickle.Op,
@ -283,10 +283,10 @@ pub const Value = union(ValueType) {
return switch (self) {
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
.seq => |seq| blk: {
const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) };
for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator);
break :blk .{ .seq = new_val };
.seq => |seq| {
const values = try allocator.alloc(Value, seq.values.len);
for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator);
return .{ .seq = .{ .type = seq.type, .values = values } };
},
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
.bigint => |v| .{ .bigint = try v.clone() },
@ -342,8 +342,9 @@ pub const Value = union(ValueType) {
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
return switch (self) {
.raw => |raw_val| switch (raw_val) {
.binint, .binint1, .binint2 => |val| .{ .int64 = val },
.long1, .long4 => |b| if (b.len != 0) {
.binint => |val| .{ .int64 = val },
.long => |b| if (b.len != 0) {
// TODO: handle trailing 'L'
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
var mutable = bint.toMutable();
mutable.readTwosComplement(b, b.len, .little, .signed);
@ -355,28 +356,17 @@ pub const Value = union(ValueType) {
} else return .{ .bigint = bint };
} else .{ .raw_num = raw_val },
.binfloat => |val| .{ .float64 = val },
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s },
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b },
.unicode => |s| .{ .string = s },
.bytes => |b| .{ .bytes = b },
// This isn't how Pickle actually works but we just try to UTF8 decode the
// string and if it fails, we make it a bytes value instead. If anyone
// actually cares they can just fix values themselves or recover the raw bytes
// from the UTF8 string (it's guaranteed to be reversible, as far as I know).
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
.newtrue => .{ .boolval = true },
.newfalse => .{ .boolval = false },
.string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
.bool => |b| .{ .boolval = b },
.none => .{ .none = {} },
inline .int,
.float,
.long,
=> |v, tag| {
if (tag == .int and std.mem.eql(u8, v, "01")) {
return .{ .boolval = true };
} else if (tag == .int and std.mem.eql(u8, v, "00")) {
return .{ .boolval = false };
} else {
return .{ .raw_num = raw_val };
}
},
// TODO .int should be handled like .long
.int, .float => .{ .raw_num = raw_val },
else => self,
},
.app, .object, .global => |v| blk: {

View File

@ -25,8 +25,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
switch (val) {
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int = v }),
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float = v }),
.string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
.list => |v| switch (validSlice(v)) {
true => {
@ -36,13 +36,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
const values = try allocator.alloc(i64, v.len);
errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.int;
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } });
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_int = values });
},
.float => {
const values = try allocator.alloc(f64, v.len);
errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.float;
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } });
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_float = values });
},
.string => {
const values = try allocator.alloc([]const u8, v.len);
@ -50,7 +50,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
for (v, 0..) |item, i| {
values[i] = try allocator.dupe(u8, item.string);
}
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } });
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_string = values });
},
.list => unreachable,
else => {},

View File

@ -145,6 +145,8 @@ pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool
}
pub fn DeclEnum(comptime T: type) type {
const field_infos = std.meta.declarations(T);
if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T});
return std.meta.DeclEnum(UnwrapPtr(T));
}