Rename zml.aio.Value to zml.aio.Metadata, simplify its type variants, and update torch pickle/eval APIs accordingly.
This commit is contained in:
parent
aea23c720e
commit
0189b71070
116
zml/aio.zig
116
zml/aio.zig
@ -14,7 +14,6 @@ pub const torch = @import("aio/torch.zig");
|
||||
pub const yaml = @import("aio/yaml.zig");
|
||||
|
||||
pub const log = std.log.scoped(.zml_aio);
|
||||
pub const Value = @import("aio/value.zig").Value;
|
||||
const HostBuffer = @import("hostbuffer.zig").HostBuffer;
|
||||
|
||||
test {
|
||||
@ -56,7 +55,7 @@ pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path
|
||||
else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama"))
|
||||
try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000)
|
||||
else {
|
||||
zml.log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
|
||||
log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
|
||||
return error.FormatNotRecognized;
|
||||
};
|
||||
}
|
||||
@ -87,8 +86,8 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
|
||||
try prefix_builder.push(allocator, prefix);
|
||||
defer prefix_builder.deinit(allocator);
|
||||
|
||||
var unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
|
||||
const ok = _populateStruct(allocator, &prefix_builder, &unique_id, store, &model, true) catch |err| {
|
||||
const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
|
||||
const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
|
||||
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
|
||||
};
|
||||
if (!ok) return error.TensorNotFound;
|
||||
@ -98,12 +97,12 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
|
||||
/// A struct containing all the buffers and metadata found in a model file.
|
||||
pub const BufferStore = struct {
|
||||
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
|
||||
pub const Metadata = std.StringArrayHashMapUnmanaged(Value);
|
||||
pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
|
||||
|
||||
arena: std.heap.ArenaAllocator,
|
||||
files: []MemoryMappedFile = &.{},
|
||||
buffers: Buffers = .{},
|
||||
_metadata: Metadata = .{},
|
||||
_metadata: Metadatas = .{},
|
||||
|
||||
pub fn deinit(self: BufferStore) void {
|
||||
for (self.files) |*file| file.deinit();
|
||||
@ -135,7 +134,7 @@ pub const BufferStore = struct {
|
||||
return if (maybe_max_index) |index| index + 1 else 0;
|
||||
}
|
||||
|
||||
pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Value)) ?std.meta.FieldType(Value, tag) {
|
||||
pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) {
|
||||
const wrapped_value = self._metadata.get(key) orelse return null;
|
||||
|
||||
if (wrapped_value != tag) {
|
||||
@ -145,14 +144,86 @@ pub const BufferStore = struct {
|
||||
return @field(wrapped_value, @tagName(tag));
|
||||
}
|
||||
|
||||
pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Value.Slice.ItemType) ?[]const Value.Slice.toZigType(tag) {
|
||||
pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() {
|
||||
const wrapped_value = self._metadata.get(key) orelse return null;
|
||||
|
||||
if (wrapped_value != .array or wrapped_value.array.item_type != tag) {
|
||||
return null;
|
||||
const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), @tagName(tag)).?;
|
||||
if (wrapped_value == true_tag) {
|
||||
return @field(wrapped_value, "array_" ++ @tagName(tag));
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
pub const Metadata = union(enum) {
|
||||
null: void,
|
||||
int: i64,
|
||||
float: f64,
|
||||
bool: bool,
|
||||
string: []const u8,
|
||||
|
||||
array_bool: []const bool,
|
||||
array_int: []const i64,
|
||||
array_float: []const f64,
|
||||
array_string: []const []const u8,
|
||||
|
||||
pub const ItemType = enum {
|
||||
int,
|
||||
float,
|
||||
bool,
|
||||
string,
|
||||
|
||||
pub fn toZigType(comptime kind: ItemType) type {
|
||||
return switch (kind) {
|
||||
.int => i64,
|
||||
.float => f64,
|
||||
.bool => bool,
|
||||
.string => []const u8,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
pub fn wrap(x: anytype) Metadata {
|
||||
return switch (@TypeOf(x)) {
|
||||
inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) },
|
||||
inline f16, f32, f64 => .{ .float = @floatCast(x) },
|
||||
bool => .{ .bool = x },
|
||||
[]const u8 => .{ .string = x },
|
||||
else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata {
|
||||
return switch (@TypeOf(any_slice[0])) {
|
||||
inline u8, i8, u16, i16, u32, i32, u64, i64 => {
|
||||
const res = try allocator.alloc(i64, any_slice.len);
|
||||
for (res, any_slice) |*r, val| r.* = @intCast(val);
|
||||
return .{ .array_int = res };
|
||||
},
|
||||
inline f16, f32, f64 => {
|
||||
const res = try allocator.alloc(f64, any_slice.len);
|
||||
for (res, any_slice) |*r, val| r.* = @floatCast(val);
|
||||
return .{ .array_float = res };
|
||||
},
|
||||
bool => .{ .array_bool = try allocator.dupe(bool, any_slice) },
|
||||
[]const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) },
|
||||
else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(any_slice))),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn format(
|
||||
self: Metadata,
|
||||
comptime fmt: []const u8,
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
switch (self) {
|
||||
.null => _ = try writer.write("null"),
|
||||
inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
|
||||
inline else => |v| try writer.print("{d}", .{v}),
|
||||
}
|
||||
const T = Value.Slice.toZigType(tag);
|
||||
return @alignCast(std.mem.bytesAsSlice(T, wrapped_value.array.data));
|
||||
}
|
||||
};
|
||||
|
||||
@ -244,7 +315,7 @@ const PrefixBuilder = struct {
|
||||
fn _populateStruct(
|
||||
allocator: std.mem.Allocator,
|
||||
prefix_builder: *PrefixBuilder,
|
||||
unique_id: *u64,
|
||||
unique_id: u64,
|
||||
buffer_store: BufferStore,
|
||||
obj: anytype,
|
||||
required: bool,
|
||||
@ -260,17 +331,17 @@ fn _populateStruct(
|
||||
|
||||
const prefix = prefix_builder.data.items;
|
||||
if (T == zml.Tensor) {
|
||||
return if (buffer_store.get(prefix)) |buffer| {
|
||||
return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| {
|
||||
const buffer = buffer_store.get(prefix).?;
|
||||
obj.* = zml.Tensor{
|
||||
._shape = buffer.shape(),
|
||||
._id = .{ .buffer_id = unique_id.* },
|
||||
._id = .{ .buffer_id = unique_id + entry_idx },
|
||||
._donation = .input_buffer,
|
||||
};
|
||||
unique_id.* += 1;
|
||||
return true;
|
||||
} else {
|
||||
if (required) {
|
||||
std.log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
|
||||
log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
|
||||
}
|
||||
return false;
|
||||
};
|
||||
@ -290,7 +361,7 @@ fn _populateStruct(
|
||||
defer prefix_builder.pop();
|
||||
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
|
||||
if (!found) {
|
||||
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) });
|
||||
log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@ -299,7 +370,7 @@ fn _populateStruct(
|
||||
}
|
||||
return true;
|
||||
} else {
|
||||
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
|
||||
log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
|
||||
return false;
|
||||
}
|
||||
},
|
||||
@ -346,7 +417,7 @@ fn _populateStruct(
|
||||
},
|
||||
.Void => true,
|
||||
else => if (required) {
|
||||
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
||||
log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
|
||||
return error.UnsupportedMetadataType;
|
||||
} else return false,
|
||||
};
|
||||
@ -431,7 +502,7 @@ pub fn loadBuffers(
|
||||
zml.meta.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
|
||||
}
|
||||
|
||||
return loadModelBuffers(Model, model, buffer_store, allocator, platform);
|
||||
return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||
}
|
||||
|
||||
/// Creates a bufferized version of a Model from the given BufferStore. For details about
|
||||
@ -449,6 +520,7 @@ pub fn loadModelBuffers(
|
||||
) !zml.Bufferized(Model) {
|
||||
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
|
||||
}
|
||||
|
||||
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
|
||||
/// For details about bufferization, see the documentation of Bufferized(T).
|
||||
///
|
||||
|
||||
@ -36,24 +36,24 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
|
||||
log.err("GGUF File: Tokens not found", .{});
|
||||
return error.TokensNotFound;
|
||||
};
|
||||
const scores = self.metadataSlice("tokenizer.ggml.scores", .float32) orelse {
|
||||
const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse {
|
||||
log.err("GGUF File: Scores not found", .{});
|
||||
return error.ScoresNotFound;
|
||||
};
|
||||
assert(tokens.len == scores.len);
|
||||
const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama";
|
||||
const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece;
|
||||
const bos = self.metadata("tokenizer.ggml.bos_token_id", .uint32);
|
||||
const eos = self.metadata("tokenizer.ggml.eos_token_id", .uint32);
|
||||
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .uint32);
|
||||
const pad = self.metadata("tokenizer.ggml.padding_token_id", .uint32);
|
||||
const bos = self.metadata("tokenizer.ggml.bos_token_id", .int);
|
||||
const eos = self.metadata("tokenizer.ggml.eos_token_id", .int);
|
||||
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int);
|
||||
const pad = self.metadata("tokenizer.ggml.padding_token_id", .int);
|
||||
|
||||
const NOT_FOUND = std.math.maxInt(u32);
|
||||
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
|
||||
.bos = bos.?,
|
||||
.eos = eos.?,
|
||||
.unk = unk orelse NOT_FOUND,
|
||||
.pad = pad orelse NOT_FOUND,
|
||||
.bos = @intCast(bos.?),
|
||||
.eos = @intCast(eos.?),
|
||||
.unk = @intCast(unk orelse NOT_FOUND),
|
||||
.pad = @intCast(pad orelse NOT_FOUND),
|
||||
};
|
||||
|
||||
const gguf_normalizer = if (tokenizer_impl == .gpt2)
|
||||
@ -85,10 +85,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
|
||||
for (tokens, 0..tokens.len) |t, i| {
|
||||
if (tokenizer_impl == .gpt2) {
|
||||
decoded.clearRetainingCapacity();
|
||||
try tokenizer.addToken(scores[i], try gpt2_unicode.?.decode(&decoded, t));
|
||||
try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t));
|
||||
// log.debug("token: {s} -> {s}", .{t, decoded.items});
|
||||
} else {
|
||||
try tokenizer.addToken(scores[i], t);
|
||||
try tokenizer.addToken(@floatCast(scores[i]), t);
|
||||
}
|
||||
}
|
||||
|
||||
@ -112,7 +112,19 @@ fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.G
|
||||
log.warn("Found duplicated metadata key: {s}", .{entry.name});
|
||||
continue;
|
||||
}
|
||||
res.value_ptr.* = entry.val.asLoaderValue();
|
||||
res.value_ptr.* = switch (entry.val) {
|
||||
.array => |arr| switch (arr.child) {
|
||||
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
|
||||
const T = std.meta.FieldType(core.GgufValue, tag);
|
||||
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
|
||||
},
|
||||
else => blk: {
|
||||
log.warn("ignoring array metadata", .{});
|
||||
break :blk .null;
|
||||
},
|
||||
},
|
||||
inline else => |v| zml.aio.Metadata.wrap(v),
|
||||
};
|
||||
} else |err| switch (err) {
|
||||
error.EndOfMetadata => {},
|
||||
else => return err,
|
||||
|
||||
@ -176,20 +176,20 @@ pub const GgufValueType = enum(u32) {
|
||||
}
|
||||
};
|
||||
|
||||
pub const ValueType = enum {
|
||||
uint8,
|
||||
int8,
|
||||
uint16,
|
||||
int16,
|
||||
uint32,
|
||||
int32,
|
||||
float32,
|
||||
uint64,
|
||||
int64,
|
||||
float64,
|
||||
boolval,
|
||||
string,
|
||||
array,
|
||||
pub const ValueType = enum(u8) {
|
||||
uint8 = 0,
|
||||
int8 = 1,
|
||||
uint16 = 2,
|
||||
int16 = 3,
|
||||
uint32 = 4,
|
||||
int32 = 5,
|
||||
float32 = 6,
|
||||
bool = 7,
|
||||
string = 8,
|
||||
array = 9,
|
||||
uint64 = 10,
|
||||
int64 = 11,
|
||||
float64 = 12,
|
||||
};
|
||||
|
||||
// Union of possible values.
|
||||
@ -201,47 +201,20 @@ pub const GgufValue = union(ValueType) {
|
||||
uint32: u32,
|
||||
int32: i32,
|
||||
float32: f32,
|
||||
bool: bool,
|
||||
string: []const u8,
|
||||
array: Array,
|
||||
uint64: u64,
|
||||
int64: i64,
|
||||
float64: f64,
|
||||
boolval: bool,
|
||||
string: []const u8,
|
||||
array: Array,
|
||||
|
||||
pub const Array = struct {
|
||||
// Any value type is valid, including arrays.
|
||||
child: GgufValueType,
|
||||
child: ValueType,
|
||||
// Number of elements, not bytes
|
||||
len: usize,
|
||||
data: []u8,
|
||||
};
|
||||
|
||||
pub fn asLoaderValue(self: GgufValue) zml.aio.Value {
|
||||
return switch (self) {
|
||||
.array => |v| .{
|
||||
.array = .{
|
||||
.item_type = switch (v.child) {
|
||||
.bool => .boolval,
|
||||
.uint8 => .uint8,
|
||||
.int8 => .int8,
|
||||
.uint16 => .uint16,
|
||||
.int16 => .int16,
|
||||
.uint32 => .uint32,
|
||||
.int32 => .int32,
|
||||
.float32 => .float32,
|
||||
.uint64 => .uint64,
|
||||
.int64 => .int64,
|
||||
.float64 => .float64,
|
||||
.string => .string,
|
||||
// TODO: .array => .array,
|
||||
else => unreachable,
|
||||
},
|
||||
.data = v.data,
|
||||
},
|
||||
},
|
||||
inline else => |v, tag| @unionInit(zml.aio.Value, @tagName(tag), v),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Header
|
||||
@ -403,6 +376,9 @@ pub const GgufFile = struct {
|
||||
|
||||
fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
|
||||
const child = try self.readValueType();
|
||||
if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
|
||||
return error.UnsupportedGgufType;
|
||||
}
|
||||
const len: usize = try self.readInt(u64);
|
||||
const data = switch (child) {
|
||||
// Since strings have variable lenghts, we need to read them one by one
|
||||
@ -414,7 +390,7 @@ pub const GgufFile = struct {
|
||||
else => try self.readAlloc(allocator, len * child.sizeOf()),
|
||||
};
|
||||
return .{
|
||||
.child = child,
|
||||
.child = @enumFromInt(@intFromEnum(child)),
|
||||
.len = len,
|
||||
.data = data,
|
||||
};
|
||||
@ -429,7 +405,7 @@ pub const GgufFile = struct {
|
||||
.uint32 => .{ .uint32 = try self.readInt(u32) },
|
||||
.int32 => .{ .int32 = try self.readInt(i32) },
|
||||
.float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
|
||||
.bool => .{ .boolval = try self.readInt(u8) != 0 },
|
||||
.bool => .{ .bool = try self.readInt(u8) != 0 },
|
||||
.string => .{ .string = try self.readString(allocator) },
|
||||
.array => .{ .array = try self.readArrayHeader(allocator) },
|
||||
.uint64 => .{ .uint64 = try self.readInt(u64) },
|
||||
|
||||
@ -30,44 +30,40 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix:
|
||||
const metadata = &store._metadata;
|
||||
const key = prefix.items;
|
||||
return switch (val) {
|
||||
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }),
|
||||
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }),
|
||||
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }),
|
||||
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }),
|
||||
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
|
||||
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
|
||||
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
|
||||
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
|
||||
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
|
||||
.array => |v| {
|
||||
if (v.items.len == 0) return;
|
||||
return if (validSlice(v)) |item_type| {
|
||||
const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) {
|
||||
const data: zml.aio.Metadata = switch (item_type) {
|
||||
.bool => blk: {
|
||||
const values = try allocator.alloc(bool, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.bool;
|
||||
break :blk .{ std.mem.sliceAsBytes(values), .boolval };
|
||||
break :blk .{ .array_bool = values };
|
||||
},
|
||||
.integer => blk: {
|
||||
const values = try allocator.alloc(i64, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.integer;
|
||||
break :blk .{ std.mem.sliceAsBytes(values), .int64 };
|
||||
break :blk .{ .array_int = values };
|
||||
},
|
||||
.float => blk: {
|
||||
const values = try allocator.alloc(f64, v.items.len);
|
||||
for (v.items, 0..) |item, i| values[i] = item.float;
|
||||
break :blk .{ std.mem.sliceAsBytes(values), .float64 };
|
||||
break :blk .{ .array_float = values };
|
||||
},
|
||||
inline .string, .number_string => |tag| blk: {
|
||||
const values = try allocator.alloc([]const u8, v.items.len);
|
||||
for (v.items, 0..) |item, i| {
|
||||
values[i] = @field(item, @tagName(tag));
|
||||
}
|
||||
break :blk .{ std.mem.sliceAsBytes(values), .string };
|
||||
break :blk .{ .array_string = values };
|
||||
},
|
||||
.null, .array, .object => unreachable,
|
||||
};
|
||||
try metadata.put(
|
||||
allocator,
|
||||
try allocator.dupe(u8, key),
|
||||
.{ .array = .{ .item_type = dtype, .data = data } },
|
||||
);
|
||||
try metadata.put(allocator, try allocator.dupe(u8, key), data);
|
||||
} else {
|
||||
for (v.items, 0..) |item, i| {
|
||||
var new_prefix = prefix;
|
||||
|
||||
@ -39,10 +39,10 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
||||
const start = try mapped_file.file.getPos();
|
||||
var tmp: zml.aio.torch.PickleData = .{
|
||||
.data = try parser.Parser.fromTarFile(arena, mapped_file, file),
|
||||
.memo = undefined,
|
||||
.stack = undefined,
|
||||
};
|
||||
tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true);
|
||||
tmp.stack = try eval.evaluate(arena, tmp.data.ops, true);
|
||||
|
||||
try tmp.parseModel(arena, &res);
|
||||
// Since we directly manipulate the file handle pointer,
|
||||
// reset to the end of file so iterator does not error
|
||||
|
||||
@ -91,17 +91,17 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe
|
||||
|
||||
{
|
||||
try res._metadata.ensureUnusedCapacity(arena, 11);
|
||||
res._metadata.putAssumeCapacityNoClobber("dim", .{ .int64 = c.dim });
|
||||
res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int64 = c.hidden_dim });
|
||||
res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int64 = c.n_layers });
|
||||
res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int64 = c.n_heads });
|
||||
res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int64 = c.n_kv_heads });
|
||||
res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int64 = c.vocab.size });
|
||||
res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .boolval = c.vocab.has_lm_head });
|
||||
res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int64 = c.seq_len });
|
||||
res._metadata.putAssumeCapacityNoClobber("dim", .{ .int = c.dim });
|
||||
res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int = c.hidden_dim });
|
||||
res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int = c.n_layers });
|
||||
res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int = c.n_heads });
|
||||
res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int = c.n_kv_heads });
|
||||
res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int = c.vocab.size });
|
||||
res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .bool = c.vocab.has_lm_head });
|
||||
res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int = c.seq_len });
|
||||
res._metadata.putAssumeCapacityNoClobber("rope_impl", .{ .string = "interleaved" });
|
||||
res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float64 = 10_000 });
|
||||
res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float64 = 1e-6 });
|
||||
res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float = 10_000 });
|
||||
res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float = 1e-6 });
|
||||
}
|
||||
|
||||
return res;
|
||||
|
||||
@ -16,6 +16,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
|
||||
const log = std.log.scoped(.zml_io);
|
||||
|
||||
test {
|
||||
std.testing.refAllDecls(@This());
|
||||
std.testing.refAllDecls(eval);
|
||||
std.testing.refAllDecls(value);
|
||||
std.testing.refAllDecls(parser);
|
||||
@ -35,22 +36,21 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
||||
const tmp_alloc = arena.allocator();
|
||||
|
||||
const _parser = try parser.Parser.init(tmp_alloc, file);
|
||||
const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true);
|
||||
const stack = try eval.evaluate(tmp_alloc, _parser.ops, true);
|
||||
|
||||
// But we create the HostBuffer objects inside the result BufferStore arena.
|
||||
var res: zml.aio.BufferStore = .{
|
||||
.arena = std.heap.ArenaAllocator.init(allocator),
|
||||
};
|
||||
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file});
|
||||
var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack };
|
||||
var tmp: PickleData = .{ .data = _parser, .stack = stack };
|
||||
try tmp.parseModel(res.arena.allocator(), &res);
|
||||
return res;
|
||||
}
|
||||
|
||||
// TODO: rename me to PytorchFile
|
||||
pub const PickleData = struct {
|
||||
stack: eval.PickleStack,
|
||||
memo: eval.PickleMemo,
|
||||
stack: []const Value,
|
||||
data: parser.Parser,
|
||||
|
||||
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
|
||||
@ -63,7 +63,7 @@ pub const PickleData = struct {
|
||||
}
|
||||
|
||||
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
|
||||
for (self.stack.stack) |item| {
|
||||
for (self.stack) |item| {
|
||||
var prefix_buf: [1024]u8 = undefined;
|
||||
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
|
||||
}
|
||||
@ -147,7 +147,7 @@ pub const PickleData = struct {
|
||||
try store._metadata.put(
|
||||
allocator,
|
||||
try allocator.dupe(u8, prefix.items),
|
||||
.{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } },
|
||||
try zml.aio.Metadata.copySlice(allocator, values.items),
|
||||
);
|
||||
} else {
|
||||
for (values.items, 0..) |val, i| {
|
||||
@ -156,7 +156,13 @@ pub const PickleData = struct {
|
||||
new_prefix.appendAssumeCapacity('.');
|
||||
}
|
||||
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val));
|
||||
const new_tag = switch (tag) {
|
||||
.int64 => "int",
|
||||
.float64 => "float",
|
||||
.boolval => "bool",
|
||||
else => unreachable, // we are already inside a switch
|
||||
};
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val));
|
||||
}
|
||||
}
|
||||
},
|
||||
@ -212,15 +218,17 @@ pub const PickleData = struct {
|
||||
if (d.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } };
|
||||
} else d.value_ptr.* = .{ .string = val };
|
||||
},
|
||||
inline .float64, .int64, .boolval, .bigint, .string => |val, tag| {
|
||||
inline .float64, .int64, .boolval, .bigint, .string => |val| {
|
||||
const key = try allocator.dupe(u8, prefix.items);
|
||||
const d = try store._metadata.getOrPut(allocator, key);
|
||||
if (d.found_existing) {
|
||||
log.warn("Duplicate key: {s}", .{prefix.items});
|
||||
allocator.free(key);
|
||||
} else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val);
|
||||
} else {
|
||||
d.value_ptr.* = zml.aio.Metadata.wrap(val);
|
||||
}
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
@ -248,7 +256,7 @@ pub const PickleData = struct {
|
||||
}
|
||||
const d = try allocator.alloc(i64, size.len);
|
||||
for (d, 0..) |*di, i| di.* = size[i].int64;
|
||||
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } };
|
||||
entry.value_ptr.* = .{ .array_int = d };
|
||||
return true;
|
||||
} else if (basicTypeCheck(object, "fractions", "Fraction")) {
|
||||
const fraction_str = object.args[0].seq.values[0].string;
|
||||
@ -256,12 +264,12 @@ pub const PickleData = struct {
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".numerator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
|
||||
}
|
||||
{
|
||||
var new_prefix = prefix;
|
||||
new_prefix.appendSliceAssumeCapacity(".denominator");
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -159,84 +159,9 @@ pub const PickleMemo = struct {
|
||||
}
|
||||
};
|
||||
|
||||
pub const InternalStack = struct {
|
||||
allocator: std.mem.Allocator,
|
||||
values: std.ArrayList(Value),
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator) InternalStack {
|
||||
return .{
|
||||
.allocator = allocator,
|
||||
.values = std.ArrayList(Value).init(allocator),
|
||||
};
|
||||
}
|
||||
|
||||
pub fn deinit(self: *InternalStack) void {
|
||||
for (0..self.values.items.len) |i| self.values.items[i].deinit(self.allocator);
|
||||
self.values.deinit();
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
pub fn pop(self: *InternalStack) !Value {
|
||||
if (self.values.items.len == 0) {
|
||||
return error.StackUnderrun;
|
||||
}
|
||||
return self.values.pop();
|
||||
}
|
||||
|
||||
pub fn popMark(self: *InternalStack, allocator: ?std.mem.Allocator) ![]Value {
|
||||
const markidx = try self.findMark();
|
||||
var postmark: []Value = &[_]Value{};
|
||||
if (allocator) |a| {
|
||||
postmark = try a.alloc(Value, self.values.items.len - (markidx + 1));
|
||||
@memcpy(postmark, self.values.items[markidx + 1 ..]);
|
||||
}
|
||||
self.values.shrinkAndFree(markidx);
|
||||
return postmark;
|
||||
}
|
||||
|
||||
pub fn lastMut(self: *InternalStack) !*Value {
|
||||
if (self.values.items.len == 0) {
|
||||
return error.UnexpectedEmptyStack;
|
||||
}
|
||||
return &self.values.items[self.values.items.len - 1];
|
||||
}
|
||||
|
||||
pub fn findMark(self: *InternalStack) !usize {
|
||||
const len = self.values.items.len;
|
||||
for (0..len) |i| {
|
||||
const idx = (len - 1) - i;
|
||||
const val = self.values.items[idx];
|
||||
if (val == .raw and val.raw == .mark) {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
zml.log.warn("pytorch loader: missing mark", .{});
|
||||
return 0;
|
||||
}
|
||||
|
||||
pub fn toPickleStack(self: *InternalStack) !PickleStack {
|
||||
return .{ .stack = try self.values.toOwnedSlice(), .allocator = self.allocator };
|
||||
}
|
||||
};
|
||||
|
||||
pub const PickleStack = struct {
|
||||
stack: []Value,
|
||||
allocator: std.mem.Allocator,
|
||||
|
||||
pub fn init(allocator: std.mem.Allocator, values: []Value) PickleStack {
|
||||
return .{ .allocator = allocator, .stack = values };
|
||||
}
|
||||
|
||||
pub fn deinit(self: *PickleStack) void {
|
||||
for (self.stack) |*v| v.deinit(self.allocator);
|
||||
self.allocator.free(self.stack);
|
||||
}
|
||||
};
|
||||
|
||||
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
|
||||
var stack = InternalStack.init(allocator);
|
||||
defer stack.deinit();
|
||||
var memo = PickleMemo.init(allocator);
|
||||
pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value {
|
||||
var stack = std.ArrayList(Value).init(arena);
|
||||
var memo = PickleMemo.init(arena);
|
||||
errdefer memo.deinit();
|
||||
|
||||
const makeKVList = (struct {
|
||||
@ -258,56 +183,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
|
||||
}
|
||||
}).call;
|
||||
|
||||
outer: for (x) |op| {
|
||||
for (x) |op| {
|
||||
switch (op) {
|
||||
.mark => try stack.values.append(.{ .raw = op }),
|
||||
.stop => break :outer,
|
||||
.pop => _ = try stack.pop(),
|
||||
.pop_mark => _ = try stack.popMark(allocator),
|
||||
.dup => {
|
||||
if (stack.values.getLastOrNull()) |item| {
|
||||
try stack.values.append(try item.clone(allocator));
|
||||
} else {
|
||||
return error.CannotDupEmptyStack;
|
||||
}
|
||||
},
|
||||
.persid => |v| try stack.values.append(.{ .pers_id = try PersId.init(allocator, .{ .string = try allocator.dupe(u8, v) }) }),
|
||||
.binpersid => try stack.values.append(.{ .pers_id = try PersId.init(allocator, try stack.pop()) }),
|
||||
.reduce => try stack.values.append(.{ .global = blk: {
|
||||
const values = try allocator.alloc(Value, 1);
|
||||
values[0] = try memo.resolve(allocator, try stack.pop(), true);
|
||||
break :blk try Object.init(allocator, try memo.resolve(allocator, try stack.pop(), true), values);
|
||||
.mark => try stack.append(.{ .raw = op }),
|
||||
.frame => {},
|
||||
.stop => break,
|
||||
.pop => _ = try pop(&stack),
|
||||
.pop_mark => try popMarkDiscard(&stack),
|
||||
.dup => if (stack.getLastOrNull()) |item|
|
||||
try stack.append(try item.clone(arena))
|
||||
else
|
||||
return error.CannotDupEmptyStack,
|
||||
.persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
|
||||
.binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }),
|
||||
.reduce => try stack.append(.{ .global = blk: {
|
||||
const values = try arena.alloc(Value, 1);
|
||||
values[0] = try memo.resolve(arena, try pop(&stack), true);
|
||||
break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values);
|
||||
} }),
|
||||
.build => try stack.values.append(blk: {
|
||||
const args = try memo.resolve(allocator, try stack.pop(), true);
|
||||
const member = try memo.resolve(allocator, try stack.pop(), true);
|
||||
break :blk .{ .build = try Build.init(allocator, member, args) };
|
||||
.build => try stack.append(blk: {
|
||||
const args = try memo.resolve(arena, try pop(&stack), true);
|
||||
const member = try memo.resolve(arena, try pop(&stack), true);
|
||||
break :blk .{ .build = try Build.init(arena, member, args) };
|
||||
}),
|
||||
.empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
|
||||
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }),
|
||||
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }),
|
||||
.empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
|
||||
.binput, .long_binput => |v| {
|
||||
try memo.insert(v, try stack.pop());
|
||||
try stack.values.append(.{ .ref = v });
|
||||
.empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
|
||||
.get => |v| try stack.append(.{ .ref = v }),
|
||||
.empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
|
||||
.put => |v| {
|
||||
try memo.insert(v, try pop(&stack));
|
||||
try stack.append(.{ .ref = v });
|
||||
},
|
||||
.tuple => try stack.values.append(blk: {
|
||||
const popped = try stack.popMark(allocator);
|
||||
.tuple => try stack.append(blk: {
|
||||
const popped = try popMark(&stack, arena);
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
|
||||
}),
|
||||
.empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
|
||||
.empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
|
||||
.setitem => {
|
||||
const v, const k = .{ try stack.pop(), try stack.pop() };
|
||||
const top = try stack.lastMut();
|
||||
const v, const k = .{ try pop(&stack), try pop(&stack) };
|
||||
const top = try lastMut(&stack);
|
||||
const rtop = try memo.resolveMut(top, true);
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
|
||||
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } };
|
||||
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
|
||||
},
|
||||
else => {
|
||||
return error.BadStackTopForSetItem;
|
||||
@ -315,53 +237,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
|
||||
}
|
||||
},
|
||||
.setitems => {
|
||||
const popped = try stack.popMark(allocator);
|
||||
defer allocator.free(popped);
|
||||
const kv_items = try makeKVList(allocator, popped);
|
||||
const top = try stack.lastMut();
|
||||
const popped = try popMark(&stack, arena);
|
||||
defer arena.free(popped);
|
||||
const kv_items = try makeKVList(arena, popped);
|
||||
const top = try lastMut(&stack);
|
||||
const rtop = try memo.resolveMut(top, true);
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
|
||||
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
|
||||
},
|
||||
else => {
|
||||
defer allocator.free(kv_items);
|
||||
defer arena.free(kv_items);
|
||||
return error.BadStackTopForSetItems;
|
||||
},
|
||||
}
|
||||
},
|
||||
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
|
||||
.tuple1 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 1);
|
||||
tup_values[0] = try stack.pop();
|
||||
.tuple1 => try stack.append(blk: {
|
||||
const tup_values = try arena.alloc(Value, 1);
|
||||
tup_values[0] = try pop(&stack);
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.tuple2 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 2);
|
||||
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
|
||||
.tuple2 => try stack.append(blk: {
|
||||
const tup_values = try arena.alloc(Value, 2);
|
||||
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.tuple3 => try stack.values.append(blk: {
|
||||
const tup_values = try allocator.alloc(Value, 3);
|
||||
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop();
|
||||
.tuple3 => try stack.append(blk: {
|
||||
const tup_values = try arena.alloc(Value, 3);
|
||||
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
|
||||
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
|
||||
}),
|
||||
.append => {
|
||||
const v = try stack.pop();
|
||||
const top = try stack.lastMut();
|
||||
const v = try pop(&stack);
|
||||
const top = try lastMut(&stack);
|
||||
const rtop = try memo.resolveMut(top, true);
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1);
|
||||
obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
|
||||
obj.args[obj.args.len - 1] = v;
|
||||
},
|
||||
.seq => |*tup| {
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1);
|
||||
tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
|
||||
tup.values[tup.values.len - 1] = v;
|
||||
},
|
||||
else => {
|
||||
@ -370,19 +292,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
|
||||
}
|
||||
},
|
||||
.appends => {
|
||||
const postmark = try stack.popMark(allocator);
|
||||
defer allocator.free(postmark);
|
||||
const top = try stack.lastMut();
|
||||
const postmark = try popMark(&stack, arena);
|
||||
defer arena.free(postmark);
|
||||
const top = try lastMut(&stack);
|
||||
const rtop = try memo.resolveMut(top, true);
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
const obj_len = obj.args.len;
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
|
||||
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
|
||||
@memcpy(obj.args[obj_len..], postmark);
|
||||
},
|
||||
.seq => |*tup| {
|
||||
const tup_len = tup.values.len;
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
|
||||
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
|
||||
@memcpy(tup.values[tup_len..], postmark);
|
||||
},
|
||||
else => {
|
||||
@ -390,49 +312,43 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
|
||||
},
|
||||
}
|
||||
},
|
||||
.dict => try stack.values.append(blk: {
|
||||
const popped = try stack.popMark(allocator);
|
||||
defer allocator.free(popped);
|
||||
const kv_items = try makeKVList(allocator, popped);
|
||||
.dict => try stack.append(blk: {
|
||||
const popped = try popMark(&stack, arena);
|
||||
defer arena.free(popped);
|
||||
const kv_items = try makeKVList(arena, popped);
|
||||
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
|
||||
}),
|
||||
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }),
|
||||
.inst => |v| try stack.values.append(blk: {
|
||||
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
|
||||
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) };
|
||||
.list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }),
|
||||
.inst => |v| try stack.append(blk: {
|
||||
const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
|
||||
break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) };
|
||||
}),
|
||||
.obj => try stack.values.append(blk: {
|
||||
const markidx = try stack.findMark();
|
||||
const args = try allocator.alloc(Value, stack.values.items.len - (markidx + 2));
|
||||
@memcpy(args, stack.values.items[markidx + 2 ..]);
|
||||
const member = stack.values.items[markidx + 1];
|
||||
break :blk .{ .object = try Object.init(allocator, member, args) };
|
||||
.obj => try stack.append(blk: {
|
||||
const mark = try findMark(&stack);
|
||||
const args = try arena.dupe(Value, stack.items[mark + 2 ..]);
|
||||
const member = stack.items[mark + 1];
|
||||
break :blk .{ .object = try Object.init(arena, member, args) };
|
||||
}),
|
||||
.put => |v| {
|
||||
const mid = try std.fmt.parseInt(u32, v, 10);
|
||||
try memo.insert(mid, try stack.pop());
|
||||
try stack.values.append(.{ .ref = mid });
|
||||
},
|
||||
.newobj => try stack.values.append(blk: {
|
||||
const args = try allocator.alloc(Value, 1);
|
||||
args[0] = try stack.pop();
|
||||
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
|
||||
.newobj => try stack.append(blk: {
|
||||
const args = try arena.alloc(Value, 1);
|
||||
args[0] = try pop(&stack);
|
||||
break :blk .{ .object = try Object.init(arena, try pop(&stack), args) };
|
||||
}),
|
||||
.empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
|
||||
.empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
|
||||
.additems => {
|
||||
const postmark = try stack.popMark(allocator);
|
||||
defer allocator.free(postmark);
|
||||
const top = try stack.lastMut();
|
||||
const postmark = try popMark(&stack, arena);
|
||||
defer arena.free(postmark);
|
||||
const top = try lastMut(&stack);
|
||||
const rtop = try memo.resolveMut(top, true);
|
||||
switch (rtop.*) {
|
||||
.global => |obj| {
|
||||
const obj_len = obj.args.len;
|
||||
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len);
|
||||
obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
|
||||
@memcpy(obj.args[obj_len..], postmark);
|
||||
},
|
||||
.seq => |*tup| {
|
||||
const tup_len = tup.values.len;
|
||||
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len);
|
||||
tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
|
||||
@memcpy(tup.values[tup_len..], postmark);
|
||||
},
|
||||
else => {
|
||||
@ -440,36 +356,33 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
|
||||
},
|
||||
}
|
||||
},
|
||||
.frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }),
|
||||
.newobj_ex => try stack.values.append(blk: {
|
||||
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() };
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) };
|
||||
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) };
|
||||
.frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }),
|
||||
.newobj_ex => try stack.append(blk: {
|
||||
const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) };
|
||||
break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) };
|
||||
}),
|
||||
.stack_global => try stack.values.append(blk: {
|
||||
.stack_global => try stack.append(blk: {
|
||||
const gn, const mn = .{
|
||||
try memo.resolve(allocator, try stack.pop(), true),
|
||||
try memo.resolve(allocator, try stack.pop(), true),
|
||||
try memo.resolve(arena, try pop(&stack), true),
|
||||
try memo.resolve(arena, try pop(&stack), true),
|
||||
};
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) };
|
||||
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) };
|
||||
const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) };
|
||||
break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) };
|
||||
}),
|
||||
.memoize => {
|
||||
const item = stack.values.getLastOrNull() orelse {
|
||||
const item = stack.getLastOrNull() orelse {
|
||||
return error.StackUnderrun;
|
||||
};
|
||||
try memo.insert(@intCast(memo.map.count()), try item.clone(allocator));
|
||||
try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
|
||||
},
|
||||
else => try stack.values.append(.{ .raw = try op.clone(allocator) }),
|
||||
else => try stack.append(.{ .raw = try op.clone(arena) }),
|
||||
}
|
||||
}
|
||||
if (!resolve_refs) {
|
||||
return .{ try stack.toPickleStack(), memo };
|
||||
if (resolve_refs) {
|
||||
return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
|
||||
}
|
||||
return .{
|
||||
PickleStack.init(allocator, try memo.resolveAllRefsIter(allocator, 0, stack.values.items, true)),
|
||||
memo,
|
||||
};
|
||||
return stack.toOwnedSlice();
|
||||
}
|
||||
|
||||
// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead
|
||||
@ -483,3 +396,86 @@ fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_l
|
||||
return new;
|
||||
}
|
||||
}
|
||||
|
||||
test evaluate {
|
||||
var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
|
||||
defer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
|
||||
var buffered_reader = std.io.bufferedReader(file.reader());
|
||||
const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
|
||||
|
||||
const vals = try evaluate(allocator, ops, true);
|
||||
defer allocator.free(vals);
|
||||
|
||||
try std.testing.expect(vals.len == 1);
|
||||
try std.testing.expect(vals[0] == .seq);
|
||||
try std.testing.expect(vals[0].seq.type == .dict);
|
||||
const entries = vals[0].seq.values[0].seq.values;
|
||||
try std.testing.expect(entries.len == 5);
|
||||
const expected: []const Value = &.{
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
|
||||
.{ .string = "list" },
|
||||
.{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{
|
||||
.{ .int64 = 0 },
|
||||
.{ .int64 = 1 },
|
||||
.{ .int64 = 2 },
|
||||
.{ .int64 = 3 },
|
||||
.{ .int64 = 4 },
|
||||
}) } },
|
||||
}) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
|
||||
.{ .string = "tuple" },
|
||||
.{ .seq = .{
|
||||
.type = .tuple,
|
||||
.values = @constCast(&[_]Value{
|
||||
.{ .string = "a" },
|
||||
.{ .int64 = 10 },
|
||||
}),
|
||||
} },
|
||||
}) } },
|
||||
};
|
||||
|
||||
try std.testing.expectEqualDeep(expected, entries);
|
||||
}
|
||||
|
||||
pub fn pop(values: *std.ArrayList(Value)) !Value {
|
||||
if (values.items.len == 0) {
|
||||
return error.StackUnderrun;
|
||||
}
|
||||
return values.pop();
|
||||
}
|
||||
|
||||
fn popMarkDiscard(values: *std.ArrayList(Value)) !void {
|
||||
const mark = try findMark(values);
|
||||
values.shrinkRetainingCapacity(mark);
|
||||
}
|
||||
|
||||
fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value {
|
||||
const mark = try findMark(values);
|
||||
const popping = values.items[mark + 1 ..];
|
||||
values.shrinkRetainingCapacity(mark);
|
||||
return try allocator.dupe(Value, popping);
|
||||
}
|
||||
|
||||
fn lastMut(values: *std.ArrayList(Value)) !*Value {
|
||||
if (values.items.len == 0) {
|
||||
return error.UnexpectedEmptyStack;
|
||||
}
|
||||
return &values.items[values.items.len - 1];
|
||||
}
|
||||
|
||||
fn findMark(values: *std.ArrayList(Value)) !usize {
|
||||
const len = values.items.len;
|
||||
for (0..len) |i| {
|
||||
const idx = (len - 1) - i;
|
||||
const val = values.items[idx];
|
||||
if (val == .raw and val.raw == .mark) {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
return error.MarkNotFound;
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ pub const Parser = struct {
|
||||
buffer_file: zml.aio.MemoryMappedFile,
|
||||
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
|
||||
tar_file: ?TarStream = null,
|
||||
ops: []pickle.Op,
|
||||
ops: []const pickle.Op,
|
||||
is_zip_file: bool,
|
||||
zip_prefix: []const u8 = &[_]u8{},
|
||||
|
||||
@ -65,7 +65,7 @@ pub const Parser = struct {
|
||||
};
|
||||
if (!self.is_zip_file) {
|
||||
const reader = tar_stream.reader();
|
||||
self.ops = try parse(allocator, reader, try tar_stream.getEndPos());
|
||||
self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos());
|
||||
} else {
|
||||
self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream());
|
||||
}
|
||||
@ -82,7 +82,7 @@ pub const Parser = struct {
|
||||
};
|
||||
if (!self.is_zip_file) {
|
||||
const reader = self.buffer_file.file.reader();
|
||||
self.ops = try parse(allocator, reader, try reader.context.getEndPos());
|
||||
self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos());
|
||||
} else {
|
||||
self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream());
|
||||
}
|
||||
@ -94,7 +94,7 @@ pub const Parser = struct {
|
||||
self.* = undefined;
|
||||
}
|
||||
|
||||
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op {
|
||||
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op {
|
||||
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
|
||||
var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
|
||||
while (try iter.next()) |entry| {
|
||||
@ -152,7 +152,7 @@ pub const Parser = struct {
|
||||
|
||||
switch (entry.compression_method) {
|
||||
.store => {
|
||||
return parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
|
||||
return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
|
||||
},
|
||||
.deflate => {
|
||||
// TODO(cryptodeal): handle decompress
|
||||
@ -166,196 +166,6 @@ pub const Parser = struct {
|
||||
std.log.err("Could not find file ending in `data.pkl` in archive", .{});
|
||||
return error.PickleNotFound;
|
||||
}
|
||||
|
||||
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op {
|
||||
var results = std.ArrayList(pickle.Op).init(allocator);
|
||||
errdefer results.deinit();
|
||||
outer: while (true) {
|
||||
const b = try reader.readByte();
|
||||
switch (@as(pickle.OpCode, @enumFromInt(b))) {
|
||||
.mark => try results.append(.{ .mark = {} }),
|
||||
.stop => {
|
||||
try results.append(.{ .stop = {} });
|
||||
break :outer;
|
||||
},
|
||||
.pop => try results.append(.{ .pop = {} }),
|
||||
.pop_mark => try results.append(.{ .pop_mark = {} }),
|
||||
.dup => try results.append(.{ .dup = {} }),
|
||||
.float => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .float = buf });
|
||||
},
|
||||
.int => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .int = buf });
|
||||
},
|
||||
.binint => try results.append(.{ .binint = try reader.readInt(i32, .little) }),
|
||||
.binint1 => try results.append(.{ .binint1 = try reader.readByte() }),
|
||||
.long => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .long = buf });
|
||||
},
|
||||
.binint2 => try results.append(.{ .binint2 = try reader.readInt(u16, .little) }),
|
||||
.none => try results.append(.{ .none = {} }),
|
||||
.persid => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .persid = buf });
|
||||
},
|
||||
.binpersid => try results.append(.{ .binpersid = {} }),
|
||||
.reduce => try results.append(.{ .reduce = {} }),
|
||||
.string => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .string = buf });
|
||||
},
|
||||
.binstring => {
|
||||
const str_len = try reader.readInt(u32, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binstring = buf });
|
||||
},
|
||||
.short_binstring => {
|
||||
const str_len = try reader.readByte();
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .short_binstring = buf });
|
||||
},
|
||||
.unicode => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .unicode = buf });
|
||||
},
|
||||
.binunicode => {
|
||||
const str_len = try reader.readInt(u32, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binunicode = buf });
|
||||
},
|
||||
.append => try results.append(.{ .append = {} }),
|
||||
.build => try results.append(.{ .build = {} }),
|
||||
.global, .inst => {
|
||||
const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(module);
|
||||
const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(class);
|
||||
try results.append(.{ .global = .{ .module = module, .class = class } });
|
||||
},
|
||||
.dict => try results.append(.{ .dict = {} }),
|
||||
.empty_dict => try results.append(.{ .empty_dict = {} }),
|
||||
.appends => try results.append(.{ .appends = {} }),
|
||||
.get => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .get = buf });
|
||||
},
|
||||
.binget => try results.append(.{ .binget = try reader.readByte() }),
|
||||
.long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }),
|
||||
.list => try results.append(.{ .list = {} }),
|
||||
.empty_list => try results.append(.{ .empty_list = {} }),
|
||||
.obj => try results.append(.{ .obj = {} }),
|
||||
.put => {
|
||||
const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
|
||||
errdefer allocator.free(buf);
|
||||
try results.append(.{ .put = buf });
|
||||
},
|
||||
.binput => {
|
||||
try results.append(.{ .binput = try reader.readByte() });
|
||||
},
|
||||
.long_binput => {
|
||||
try results.append(.{ .long_binput = try reader.readInt(u32, .little) });
|
||||
},
|
||||
.setitem => try results.append(.{ .setitem = {} }),
|
||||
.tuple => try results.append(.{ .tuple = {} }),
|
||||
.empty_tuple => try results.append(.{ .empty_tuple = {} }),
|
||||
.setitems => try results.append(.{ .setitems = {} }),
|
||||
.binfloat => try results.append(.{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }),
|
||||
.proto => try results.append(.{ .proto = try reader.readByte() }),
|
||||
.newobj => try results.append(.{ .newobj = {} }),
|
||||
.ext1 => try results.append(.{ .ext1 = try reader.readByte() }),
|
||||
.ext2 => try results.append(.{ .ext2 = try reader.readInt(i16, .little) }),
|
||||
.ext4 => try results.append(.{ .ext4 = try reader.readInt(i32, .little) }),
|
||||
.tuple1 => try results.append(.{ .tuple1 = {} }),
|
||||
.tuple2 => try results.append(.{ .tuple2 = {} }),
|
||||
.tuple3 => try results.append(.{ .tuple3 = {} }),
|
||||
.newtrue => try results.append(.{ .newtrue = {} }),
|
||||
.newfalse => try results.append(.{ .newfalse = {} }),
|
||||
.long1 => {
|
||||
const str_len = try reader.readByte();
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .long1 = buf });
|
||||
},
|
||||
.long4 => {
|
||||
const str_len = try reader.readInt(u32, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .long4 = buf });
|
||||
},
|
||||
.binbytes => {
|
||||
const str_len = try reader.readInt(u32, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binbytes = buf });
|
||||
},
|
||||
.binbytes8 => {
|
||||
const str_len = try reader.readInt(u64, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binbytes8 = buf });
|
||||
},
|
||||
.short_binbytes => {
|
||||
const str_len = try reader.readByte();
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .short_binbytes = buf });
|
||||
},
|
||||
.binunicode8 => {
|
||||
const str_len = try reader.readInt(u64, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binunicode8 = buf });
|
||||
},
|
||||
.short_binunicode => {
|
||||
const str_len = try reader.readByte();
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .binunicode8 = buf });
|
||||
},
|
||||
.empty_set => try results.append(.{ .empty_set = {} }),
|
||||
.additems => try results.append(.{ .additems = {} }),
|
||||
.frozenset => try results.append(.{ .frozenset = {} }),
|
||||
.newobj_ex => try results.append(.{ .newobj_ex = {} }),
|
||||
.stack_global => try results.append(.{ .stack_global = {} }),
|
||||
.memoize => try results.append(.{ .memoize = {} }),
|
||||
.frame => try results.append(.{ .frame = try reader.readInt(u64, .little) }),
|
||||
.bytearray8 => {
|
||||
const str_len = try reader.readInt(u64, .little);
|
||||
const buf = try allocator.alloc(u8, str_len);
|
||||
errdefer allocator.free(buf);
|
||||
_ = try reader.read(buf);
|
||||
try results.append(.{ .bytearray8 = buf });
|
||||
},
|
||||
.next_buffer => try results.append(.{ .next_buffer = {} }),
|
||||
.readonly_buffer => try results.append(.{ .readonly_buffer = {} }),
|
||||
_ => {},
|
||||
}
|
||||
}
|
||||
return results.toOwnedSlice();
|
||||
}
|
||||
};
|
||||
|
||||
const TarStream = struct {
|
||||
@ -404,54 +214,6 @@ const TarStream = struct {
|
||||
}
|
||||
};
|
||||
|
||||
test "Read pickle (simple)" {
|
||||
const Value = @import("value.zig").Value;
|
||||
var arena = std.heap.ArenaAllocator.init(testing.allocator);
|
||||
defer arena.deinit();
|
||||
const allocator = arena.allocator();
|
||||
const eval = @import("eval.zig");
|
||||
const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
|
||||
var data = try Parser.init(allocator, file);
|
||||
defer data.deinit();
|
||||
var vals, var memo = try eval.evaluate(allocator, data.ops, true);
|
||||
defer vals.deinit();
|
||||
defer memo.deinit();
|
||||
|
||||
try testing.expect(vals.stack.len == 2);
|
||||
// skip first value (frame)
|
||||
try testing.expect(vals.stack[1] == .seq);
|
||||
try testing.expect(vals.stack[1].seq.type == .dict);
|
||||
const entries = vals.stack[1].seq.values[0].seq.values;
|
||||
try testing.expect(entries.len == 5);
|
||||
const expected: []const Value = &.{
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "list" },
|
||||
.{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .int64 = 0 },
|
||||
.{ .int64 = 1 },
|
||||
.{ .int64 = 2 },
|
||||
.{ .int64 = 3 },
|
||||
.{ .int64 = 4 },
|
||||
})) } },
|
||||
})) } },
|
||||
.{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "tuple" },
|
||||
.{ .seq = .{
|
||||
.type = .tuple,
|
||||
.values = @constCast(@as([]const Value, &.{
|
||||
.{ .string = "a" },
|
||||
.{ .int64 = 10 },
|
||||
})),
|
||||
} },
|
||||
})) } },
|
||||
};
|
||||
|
||||
try std.testing.expectEqualDeep(expected, entries);
|
||||
}
|
||||
|
||||
test "Read pickle (zipped)" {
|
||||
var arena = std.heap.ArenaAllocator.init(testing.allocator);
|
||||
defer arena.deinit();
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -109,7 +109,7 @@ pub const ValueType = enum {
|
||||
none,
|
||||
};
|
||||
|
||||
/// A processed value.
|
||||
/// A pickle operator that has been interpreted.
|
||||
pub const Value = union(ValueType) {
|
||||
/// Types that we can't handle or just had to give up on processing.
|
||||
raw: pickle.Op,
|
||||
@ -283,10 +283,10 @@ pub const Value = union(ValueType) {
|
||||
return switch (self) {
|
||||
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
|
||||
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
|
||||
.seq => |seq| blk: {
|
||||
const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) };
|
||||
for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator);
|
||||
break :blk .{ .seq = new_val };
|
||||
.seq => |seq| {
|
||||
const values = try allocator.alloc(Value, seq.values.len);
|
||||
for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator);
|
||||
return .{ .seq = .{ .type = seq.type, .values = values } };
|
||||
},
|
||||
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
|
||||
.bigint => |v| .{ .bigint = try v.clone() },
|
||||
@ -342,8 +342,9 @@ pub const Value = union(ValueType) {
|
||||
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
|
||||
return switch (self) {
|
||||
.raw => |raw_val| switch (raw_val) {
|
||||
.binint, .binint1, .binint2 => |val| .{ .int64 = val },
|
||||
.long1, .long4 => |b| if (b.len != 0) {
|
||||
.binint => |val| .{ .int64 = val },
|
||||
.long => |b| if (b.len != 0) {
|
||||
// TODO: handle trailing 'L'
|
||||
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
|
||||
var mutable = bint.toMutable();
|
||||
mutable.readTwosComplement(b, b.len, .little, .signed);
|
||||
@ -355,28 +356,17 @@ pub const Value = union(ValueType) {
|
||||
} else return .{ .bigint = bint };
|
||||
} else .{ .raw_num = raw_val },
|
||||
.binfloat => |val| .{ .float64 = val },
|
||||
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s },
|
||||
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b },
|
||||
.unicode => |s| .{ .string = s },
|
||||
.bytes => |b| .{ .bytes = b },
|
||||
// This isn't how Pickle actually works but we just try to UTF8 decode the
|
||||
// string and if it fails, we make it a bytes value instead. If anyone
|
||||
// actually cares they can just fix values themselves or recover the raw bytes
|
||||
// from the UTF8 string (it's guaranteed to be reversible, as far as I know).
|
||||
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
|
||||
.newtrue => .{ .boolval = true },
|
||||
.newfalse => .{ .boolval = false },
|
||||
.string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
|
||||
.bool => |b| .{ .boolval = b },
|
||||
.none => .{ .none = {} },
|
||||
inline .int,
|
||||
.float,
|
||||
.long,
|
||||
=> |v, tag| {
|
||||
if (tag == .int and std.mem.eql(u8, v, "01")) {
|
||||
return .{ .boolval = true };
|
||||
} else if (tag == .int and std.mem.eql(u8, v, "00")) {
|
||||
return .{ .boolval = false };
|
||||
} else {
|
||||
return .{ .raw_num = raw_val };
|
||||
}
|
||||
},
|
||||
// TODO .int should be handled like .long
|
||||
.int, .float => .{ .raw_num = raw_val },
|
||||
else => self,
|
||||
},
|
||||
.app, .object, .global => |v| blk: {
|
||||
|
||||
@ -25,8 +25,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
|
||||
|
||||
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
|
||||
switch (val) {
|
||||
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }),
|
||||
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }),
|
||||
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int = v }),
|
||||
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float = v }),
|
||||
.string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
|
||||
.list => |v| switch (validSlice(v)) {
|
||||
true => {
|
||||
@ -36,13 +36,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
|
||||
const values = try allocator.alloc(i64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.int;
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_int = values });
|
||||
},
|
||||
.float => {
|
||||
const values = try allocator.alloc(f64, v.len);
|
||||
errdefer allocator.free(values);
|
||||
for (v, 0..) |item, i| values[i] = item.float;
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_float = values });
|
||||
},
|
||||
.string => {
|
||||
const values = try allocator.alloc([]const u8, v.len);
|
||||
@ -50,7 +50,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
|
||||
for (v, 0..) |item, i| {
|
||||
values[i] = try allocator.dupe(u8, item.string);
|
||||
}
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } });
|
||||
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_string = values });
|
||||
},
|
||||
.list => unreachable,
|
||||
else => {},
|
||||
|
||||
@ -145,6 +145,8 @@ pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool
|
||||
}
|
||||
|
||||
pub fn DeclEnum(comptime T: type) type {
|
||||
const field_infos = std.meta.declarations(T);
|
||||
if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T});
|
||||
return std.meta.DeclEnum(UnwrapPtr(T));
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user