Rename zml.aio.Value to zml.aio.Metadata, simplify its type variants, and update torch pickle/eval APIs accordingly.

This commit is contained in:
Tarry Singh 2023-04-07 16:45:58 +00:00
parent aea23c720e
commit 0189b71070
13 changed files with 1307 additions and 769 deletions

View File

@ -14,7 +14,6 @@ pub const torch = @import("aio/torch.zig");
pub const yaml = @import("aio/yaml.zig"); pub const yaml = @import("aio/yaml.zig");
pub const log = std.log.scoped(.zml_aio); pub const log = std.log.scoped(.zml_aio);
pub const Value = @import("aio/value.zig").Value;
const HostBuffer = @import("hostbuffer.zig").HostBuffer; const HostBuffer = @import("hostbuffer.zig").HostBuffer;
test { test {
@ -56,7 +55,7 @@ pub fn detectFormatAndLoadTokenizer(allocator: std.mem.Allocator, tokenizer_path
else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama")) else if (std.mem.endsWith(u8, tokenizer_path, ".tinyllama"))
try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000) try zml.aio.tinyllama.loadTokenizer(allocator, tokenizer_path, 32000)
else { else {
zml.log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path}); log.err("Failed to recognized tokenizer format of: {s}", .{tokenizer_path});
return error.FormatNotRecognized; return error.FormatNotRecognized;
}; };
} }
@ -87,8 +86,8 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
try prefix_builder.push(allocator, prefix); try prefix_builder.push(allocator, prefix);
defer prefix_builder.deinit(allocator); defer prefix_builder.deinit(allocator);
var unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count())); const unique_id = zml.Tensor.reserveIdRange(@intCast(store.buffers.count()));
const ok = _populateStruct(allocator, &prefix_builder, &unique_id, store, &model, true) catch |err| { const ok = _populateStruct(allocator, &prefix_builder, unique_id, store, &model, true) catch |err| {
std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) }); std.debug.panic("Can't populate model of type {s}: {s}", .{ @typeName(type), @errorName(err) });
}; };
if (!ok) return error.TensorNotFound; if (!ok) return error.TensorNotFound;
@ -98,12 +97,12 @@ pub fn populateModelWithPrefix(comptime Model: type, allocator: std.mem.Allocato
/// A struct containing all the buffers and metadata found in a model file. /// A struct containing all the buffers and metadata found in a model file.
pub const BufferStore = struct { pub const BufferStore = struct {
pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer); pub const Buffers = std.StringArrayHashMapUnmanaged(HostBuffer);
pub const Metadata = std.StringArrayHashMapUnmanaged(Value); pub const Metadatas = std.StringArrayHashMapUnmanaged(Metadata);
arena: std.heap.ArenaAllocator, arena: std.heap.ArenaAllocator,
files: []MemoryMappedFile = &.{}, files: []MemoryMappedFile = &.{},
buffers: Buffers = .{}, buffers: Buffers = .{},
_metadata: Metadata = .{}, _metadata: Metadatas = .{},
pub fn deinit(self: BufferStore) void { pub fn deinit(self: BufferStore) void {
for (self.files) |*file| file.deinit(); for (self.files) |*file| file.deinit();
@ -135,7 +134,7 @@ pub const BufferStore = struct {
return if (maybe_max_index) |index| index + 1 else 0; return if (maybe_max_index) |index| index + 1 else 0;
} }
pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Value)) ?std.meta.FieldType(Value, tag) { pub fn metadata(self: BufferStore, key: []const u8, comptime tag: std.meta.FieldEnum(Metadata)) ?std.meta.FieldType(Metadata, tag) {
const wrapped_value = self._metadata.get(key) orelse return null; const wrapped_value = self._metadata.get(key) orelse return null;
if (wrapped_value != tag) { if (wrapped_value != tag) {
@ -145,14 +144,86 @@ pub const BufferStore = struct {
return @field(wrapped_value, @tagName(tag)); return @field(wrapped_value, @tagName(tag));
} }
pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Value.Slice.ItemType) ?[]const Value.Slice.toZigType(tag) { pub fn metadataSlice(self: BufferStore, key: []const u8, comptime tag: Metadata.ItemType) ?[]const tag.toZigType() {
const wrapped_value = self._metadata.get(key) orelse return null; const wrapped_value = self._metadata.get(key) orelse return null;
const true_tag = std.meta.stringToEnum(std.meta.FieldEnum(Metadata), @tagName(tag)).?;
if (wrapped_value != .array or wrapped_value.array.item_type != tag) { if (wrapped_value == true_tag) {
return null; return @field(wrapped_value, "array_" ++ @tagName(tag));
}
return null;
}
};
pub const Metadata = union(enum) {
null: void,
int: i64,
float: f64,
bool: bool,
string: []const u8,
array_bool: []const bool,
array_int: []const i64,
array_float: []const f64,
array_string: []const []const u8,
pub const ItemType = enum {
int,
float,
bool,
string,
pub fn toZigType(comptime kind: ItemType) type {
return switch (kind) {
.int => i64,
.float => f64,
.bool => bool,
.string => []const u8,
};
}
};
pub fn wrap(x: anytype) Metadata {
return switch (@TypeOf(x)) {
inline u8, i8, u16, i16, u32, i32, u64, i64 => .{ .int = @intCast(x) },
inline f16, f32, f64 => .{ .float = @floatCast(x) },
bool => .{ .bool = x },
[]const u8 => .{ .string = x },
else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(x))),
};
}
pub fn copySlice(allocator: std.mem.Allocator, any_slice: anytype) !Metadata {
return switch (@TypeOf(any_slice[0])) {
inline u8, i8, u16, i16, u32, i32, u64, i64 => {
const res = try allocator.alloc(i64, any_slice.len);
for (res, any_slice) |*r, val| r.* = @intCast(val);
return .{ .array_int = res };
},
inline f16, f32, f64 => {
const res = try allocator.alloc(f64, any_slice.len);
for (res, any_slice) |*r, val| r.* = @floatCast(val);
return .{ .array_float = res };
},
bool => .{ .array_bool = try allocator.dupe(bool, any_slice) },
[]const u8 => .{ .array_string = try allocator.dupe([]const u8, @alignCast(any_slice)) },
else => @panic("Unsupported type for zml.aio.Value: " ++ @typeName(@TypeOf(any_slice))),
};
}
pub fn format(
self: Metadata,
comptime fmt: []const u8,
options: std.fmt.FormatOptions,
writer: anytype,
) !void {
_ = fmt;
_ = options;
switch (self) {
.null => _ = try writer.write("null"),
inline .bool, .array_bool => |b| try writer.print("{any}", .{b}),
inline else => |v| try writer.print("{d}", .{v}),
} }
const T = Value.Slice.toZigType(tag);
return @alignCast(std.mem.bytesAsSlice(T, wrapped_value.array.data));
} }
}; };
@ -244,7 +315,7 @@ const PrefixBuilder = struct {
fn _populateStruct( fn _populateStruct(
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
prefix_builder: *PrefixBuilder, prefix_builder: *PrefixBuilder,
unique_id: *u64, unique_id: u64,
buffer_store: BufferStore, buffer_store: BufferStore,
obj: anytype, obj: anytype,
required: bool, required: bool,
@ -260,17 +331,17 @@ fn _populateStruct(
const prefix = prefix_builder.data.items; const prefix = prefix_builder.data.items;
if (T == zml.Tensor) { if (T == zml.Tensor) {
return if (buffer_store.get(prefix)) |buffer| { return if (buffer_store.buffers.getIndex(prefix)) |entry_idx| {
const buffer = buffer_store.get(prefix).?;
obj.* = zml.Tensor{ obj.* = zml.Tensor{
._shape = buffer.shape(), ._shape = buffer.shape(),
._id = .{ .buffer_id = unique_id.* }, ._id = .{ .buffer_id = unique_id + entry_idx },
._donation = .input_buffer, ._donation = .input_buffer,
}; };
unique_id.* += 1;
return true; return true;
} else { } else {
if (required) { if (required) {
std.log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() }); log.err("Tensor not found: {s} ({d})", .{ prefix, buffer_store.buffers.count() });
} }
return false; return false;
}; };
@ -290,7 +361,7 @@ fn _populateStruct(
defer prefix_builder.pop(); defer prefix_builder.pop();
const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required); const found = try _populateStruct(allocator, prefix_builder, unique_id, buffer_store, value, required);
if (!found) { if (!found) {
std.log.err("Not able to load {s} as {s}", .{ prefix, @typeName(ptr_info.child) }); log.err("Not able to load {s} as {s}", .{ prefix_builder.data.items, @typeName(ptr_info.child) });
return false; return false;
} }
} }
@ -299,7 +370,7 @@ fn _populateStruct(
} }
return true; return true;
} else { } else {
std.log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) }); log.err("{s} - {s}: {s} type not supported", .{ @src().fn_name, prefix, @typeName(T) });
return false; return false;
} }
}, },
@ -346,7 +417,7 @@ fn _populateStruct(
}, },
.Void => true, .Void => true,
else => if (required) { else => if (required) {
std.log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) }); log.err("{s}: {s} type not supported", .{ prefix, @typeName(T) });
return error.UnsupportedMetadataType; return error.UnsupportedMetadataType;
} else return false, } else return false,
}; };
@ -431,7 +502,7 @@ pub fn loadBuffers(
zml.meta.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model}); zml.meta.assertComptime(@TypeOf(init_args) == void or @TypeOf(init_args) == @TypeOf(.{}), "Model of type {} has no init function, so `loadBuffers` should be call with init_args set to {{}} (void)", .{Model});
} }
return loadModelBuffers(Model, model, buffer_store, allocator, platform); return loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
} }
/// Creates a bufferized version of a Model from the given BufferStore. For details about /// Creates a bufferized version of a Model from the given BufferStore. For details about
@ -449,6 +520,7 @@ pub fn loadModelBuffers(
) !zml.Bufferized(Model) { ) !zml.Bufferized(Model) {
return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, ""); return try loadModelBuffersWithPrefix(Model, model, buffer_store, allocator, platform, "");
} }
/// Creates a bufferized version of a Model from the given BufferStore and the given prefix. /// Creates a bufferized version of a Model from the given BufferStore and the given prefix.
/// For details about bufferization, see the documentation of Bufferized(T). /// For details about bufferization, see the documentation of Bufferized(T).
/// ///

View File

@ -36,24 +36,24 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
log.err("GGUF File: Tokens not found", .{}); log.err("GGUF File: Tokens not found", .{});
return error.TokensNotFound; return error.TokensNotFound;
}; };
const scores = self.metadataSlice("tokenizer.ggml.scores", .float32) orelse { const scores = self.metadataSlice("tokenizer.ggml.scores", .float) orelse {
log.err("GGUF File: Scores not found", .{}); log.err("GGUF File: Scores not found", .{});
return error.ScoresNotFound; return error.ScoresNotFound;
}; };
assert(tokens.len == scores.len); assert(tokens.len == scores.len);
const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama"; const tokenizer_type = self.metadata("tokenizer.ggml.model", .string) orelse "llama";
const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece; const tokenizer_impl: zml.tokenizer.KnownImplementation = if (std.mem.eql(u8, tokenizer_type, "gpt2")) .gpt2 else .sentencepiece;
const bos = self.metadata("tokenizer.ggml.bos_token_id", .uint32); const bos = self.metadata("tokenizer.ggml.bos_token_id", .int);
const eos = self.metadata("tokenizer.ggml.eos_token_id", .uint32); const eos = self.metadata("tokenizer.ggml.eos_token_id", .int);
const unk = self.metadata("tokenizer.ggml.unknown_token_id", .uint32); const unk = self.metadata("tokenizer.ggml.unknown_token_id", .int);
const pad = self.metadata("tokenizer.ggml.padding_token_id", .uint32); const pad = self.metadata("tokenizer.ggml.padding_token_id", .int);
const NOT_FOUND = std.math.maxInt(u32); const NOT_FOUND = std.math.maxInt(u32);
const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{ const special_tokens: zml.tokenizer.Tokenizer.SpecialTokens = .{
.bos = bos.?, .bos = @intCast(bos.?),
.eos = eos.?, .eos = @intCast(eos.?),
.unk = unk orelse NOT_FOUND, .unk = @intCast(unk orelse NOT_FOUND),
.pad = pad orelse NOT_FOUND, .pad = @intCast(pad orelse NOT_FOUND),
}; };
const gguf_normalizer = if (tokenizer_impl == .gpt2) const gguf_normalizer = if (tokenizer_impl == .gpt2)
@ -85,10 +85,10 @@ pub fn getGgufTokenizer(self: zml.aio.BufferStore, allocator: std.mem.Allocator)
for (tokens, 0..tokens.len) |t, i| { for (tokens, 0..tokens.len) |t, i| {
if (tokenizer_impl == .gpt2) { if (tokenizer_impl == .gpt2) {
decoded.clearRetainingCapacity(); decoded.clearRetainingCapacity();
try tokenizer.addToken(scores[i], try gpt2_unicode.?.decode(&decoded, t)); try tokenizer.addToken(@floatCast(scores[i]), try gpt2_unicode.?.decode(&decoded, t));
// log.debug("token: {s} -> {s}", .{t, decoded.items}); // log.debug("token: {s} -> {s}", .{t, decoded.items});
} else { } else {
try tokenizer.addToken(scores[i], t); try tokenizer.addToken(@floatCast(scores[i]), t);
} }
} }
@ -112,7 +112,19 @@ fn loadMetadata(allocator: Allocator, store: *zml.aio.BufferStore, file: *core.G
log.warn("Found duplicated metadata key: {s}", .{entry.name}); log.warn("Found duplicated metadata key: {s}", .{entry.name});
continue; continue;
} }
res.value_ptr.* = entry.val.asLoaderValue(); res.value_ptr.* = switch (entry.val) {
.array => |arr| switch (arr.child) {
inline .uint8, .int8, .uint16, .int16, .uint32, .int32, .float32, .bool, .string, .uint64, .int64, .float64 => |tag| blk: {
const T = std.meta.FieldType(core.GgufValue, tag);
break :blk try zml.aio.Metadata.copySlice(allocator, std.mem.bytesAsSlice(T, arr.data));
},
else => blk: {
log.warn("ignoring array metadata", .{});
break :blk .null;
},
},
inline else => |v| zml.aio.Metadata.wrap(v),
};
} else |err| switch (err) { } else |err| switch (err) {
error.EndOfMetadata => {}, error.EndOfMetadata => {},
else => return err, else => return err,

View File

@ -176,20 +176,20 @@ pub const GgufValueType = enum(u32) {
} }
}; };
pub const ValueType = enum { pub const ValueType = enum(u8) {
uint8, uint8 = 0,
int8, int8 = 1,
uint16, uint16 = 2,
int16, int16 = 3,
uint32, uint32 = 4,
int32, int32 = 5,
float32, float32 = 6,
uint64, bool = 7,
int64, string = 8,
float64, array = 9,
boolval, uint64 = 10,
string, int64 = 11,
array, float64 = 12,
}; };
// Union of possible values. // Union of possible values.
@ -201,47 +201,20 @@ pub const GgufValue = union(ValueType) {
uint32: u32, uint32: u32,
int32: i32, int32: i32,
float32: f32, float32: f32,
bool: bool,
string: []const u8,
array: Array,
uint64: u64, uint64: u64,
int64: i64, int64: i64,
float64: f64, float64: f64,
boolval: bool,
string: []const u8,
array: Array,
pub const Array = struct { pub const Array = struct {
// Any value type is valid, including arrays. // Any value type is valid, including arrays.
child: GgufValueType, child: ValueType,
// Number of elements, not bytes // Number of elements, not bytes
len: usize, len: usize,
data: []u8, data: []u8,
}; };
pub fn asLoaderValue(self: GgufValue) zml.aio.Value {
return switch (self) {
.array => |v| .{
.array = .{
.item_type = switch (v.child) {
.bool => .boolval,
.uint8 => .uint8,
.int8 => .int8,
.uint16 => .uint16,
.int16 => .int16,
.uint32 => .uint32,
.int32 => .int32,
.float32 => .float32,
.uint64 => .uint64,
.int64 => .int64,
.float64 => .float64,
.string => .string,
// TODO: .array => .array,
else => unreachable,
},
.data = v.data,
},
},
inline else => |v, tag| @unionInit(zml.aio.Value, @tagName(tag), v),
};
}
}; };
// Header // Header
@ -403,6 +376,9 @@ pub const GgufFile = struct {
fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array { fn readArrayHeader(self: *GgufFile, allocator: std.mem.Allocator) !GgufValue.Array {
const child = try self.readValueType(); const child = try self.readValueType();
if (@intFromEnum(child) > @intFromEnum(ValueType.float64)) {
return error.UnsupportedGgufType;
}
const len: usize = try self.readInt(u64); const len: usize = try self.readInt(u64);
const data = switch (child) { const data = switch (child) {
// Since strings have variable lenghts, we need to read them one by one // Since strings have variable lenghts, we need to read them one by one
@ -414,7 +390,7 @@ pub const GgufFile = struct {
else => try self.readAlloc(allocator, len * child.sizeOf()), else => try self.readAlloc(allocator, len * child.sizeOf()),
}; };
return .{ return .{
.child = child, .child = @enumFromInt(@intFromEnum(child)),
.len = len, .len = len,
.data = data, .data = data,
}; };
@ -429,7 +405,7 @@ pub const GgufFile = struct {
.uint32 => .{ .uint32 = try self.readInt(u32) }, .uint32 => .{ .uint32 = try self.readInt(u32) },
.int32 => .{ .int32 = try self.readInt(i32) }, .int32 => .{ .int32 = try self.readInt(i32) },
.float32 => .{ .float32 = @bitCast(try self.readInt(u32)) }, .float32 => .{ .float32 = @bitCast(try self.readInt(u32)) },
.bool => .{ .boolval = try self.readInt(u8) != 0 }, .bool => .{ .bool = try self.readInt(u8) != 0 },
.string => .{ .string = try self.readString(allocator) }, .string => .{ .string = try self.readString(allocator) },
.array => .{ .array = try self.readArrayHeader(allocator) }, .array => .{ .array = try self.readArrayHeader(allocator) },
.uint64 => .{ .uint64 = try self.readInt(u64) }, .uint64 => .{ .uint64 = try self.readInt(u64) },

View File

@ -30,44 +30,40 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, prefix:
const metadata = &store._metadata; const metadata = &store._metadata;
const key = prefix.items; const key = prefix.items;
return switch (val) { return switch (val) {
.null => try metadata.put(allocator, try allocator.dupe(u8, key), .{ .null = {} }), .null => try metadata.put(allocator, try allocator.dupe(u8, key), .null),
.bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .boolval = v }), .bool => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .bool = v }),
.integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int64 = v }), .integer => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .int = v }),
.float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float64 = v }), .float => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .float = v }),
.number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }), .number_string, .string => |v| try metadata.put(allocator, try allocator.dupe(u8, key), .{ .string = try allocator.dupe(u8, v) }),
.array => |v| { .array => |v| {
if (v.items.len == 0) return; if (v.items.len == 0) return;
return if (validSlice(v)) |item_type| { return if (validSlice(v)) |item_type| {
const data, const dtype: zml.aio.Value.Slice.ItemType = switch (item_type) { const data: zml.aio.Metadata = switch (item_type) {
.bool => blk: { .bool => blk: {
const values = try allocator.alloc(bool, v.items.len); const values = try allocator.alloc(bool, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.bool; for (v.items, 0..) |item, i| values[i] = item.bool;
break :blk .{ std.mem.sliceAsBytes(values), .boolval }; break :blk .{ .array_bool = values };
}, },
.integer => blk: { .integer => blk: {
const values = try allocator.alloc(i64, v.items.len); const values = try allocator.alloc(i64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.integer; for (v.items, 0..) |item, i| values[i] = item.integer;
break :blk .{ std.mem.sliceAsBytes(values), .int64 }; break :blk .{ .array_int = values };
}, },
.float => blk: { .float => blk: {
const values = try allocator.alloc(f64, v.items.len); const values = try allocator.alloc(f64, v.items.len);
for (v.items, 0..) |item, i| values[i] = item.float; for (v.items, 0..) |item, i| values[i] = item.float;
break :blk .{ std.mem.sliceAsBytes(values), .float64 }; break :blk .{ .array_float = values };
}, },
inline .string, .number_string => |tag| blk: { inline .string, .number_string => |tag| blk: {
const values = try allocator.alloc([]const u8, v.items.len); const values = try allocator.alloc([]const u8, v.items.len);
for (v.items, 0..) |item, i| { for (v.items, 0..) |item, i| {
values[i] = @field(item, @tagName(tag)); values[i] = @field(item, @tagName(tag));
} }
break :blk .{ std.mem.sliceAsBytes(values), .string }; break :blk .{ .array_string = values };
}, },
.null, .array, .object => unreachable, .null, .array, .object => unreachable,
}; };
try metadata.put( try metadata.put(allocator, try allocator.dupe(u8, key), data);
allocator,
try allocator.dupe(u8, key),
.{ .array = .{ .item_type = dtype, .data = data } },
);
} else { } else {
for (v.items, 0..) |item, i| { for (v.items, 0..) |item, i| {
var new_prefix = prefix; var new_prefix = prefix;

View File

@ -39,10 +39,10 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const start = try mapped_file.file.getPos(); const start = try mapped_file.file.getPos();
var tmp: zml.aio.torch.PickleData = .{ var tmp: zml.aio.torch.PickleData = .{
.data = try parser.Parser.fromTarFile(arena, mapped_file, file), .data = try parser.Parser.fromTarFile(arena, mapped_file, file),
.memo = undefined,
.stack = undefined, .stack = undefined,
}; };
tmp.stack, tmp.memo = try eval.evaluate(arena, tmp.data.ops, true); tmp.stack = try eval.evaluate(arena, tmp.data.ops, true);
try tmp.parseModel(arena, &res); try tmp.parseModel(arena, &res);
// Since we directly manipulate the file handle pointer, // Since we directly manipulate the file handle pointer,
// reset to the end of file so iterator does not error // reset to the end of file so iterator does not error

View File

@ -91,17 +91,17 @@ pub fn open(allocator: std.mem.Allocator, model_path: []const u8) !zml.aio.Buffe
{ {
try res._metadata.ensureUnusedCapacity(arena, 11); try res._metadata.ensureUnusedCapacity(arena, 11);
res._metadata.putAssumeCapacityNoClobber("dim", .{ .int64 = c.dim }); res._metadata.putAssumeCapacityNoClobber("dim", .{ .int = c.dim });
res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int64 = c.hidden_dim }); res._metadata.putAssumeCapacityNoClobber("hidden_dim", .{ .int = c.hidden_dim });
res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int64 = c.n_layers }); res._metadata.putAssumeCapacityNoClobber("n_layers", .{ .int = c.n_layers });
res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int64 = c.n_heads }); res._metadata.putAssumeCapacityNoClobber("num_heads", .{ .int = c.n_heads });
res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int64 = c.n_kv_heads }); res._metadata.putAssumeCapacityNoClobber("num_kv_heads", .{ .int = c.n_kv_heads });
res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int64 = c.vocab.size }); res._metadata.putAssumeCapacityNoClobber("vocab_size", .{ .int = c.vocab.size });
res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .boolval = c.vocab.has_lm_head }); res._metadata.putAssumeCapacityNoClobber("has_lm_head", .{ .bool = c.vocab.has_lm_head });
res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int64 = c.seq_len }); res._metadata.putAssumeCapacityNoClobber("max_seq_len", .{ .int = c.seq_len });
res._metadata.putAssumeCapacityNoClobber("rope_impl", .{ .string = "interleaved" }); res._metadata.putAssumeCapacityNoClobber("rope_impl", .{ .string = "interleaved" });
res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float64 = 10_000 }); res._metadata.putAssumeCapacityNoClobber("rope_freq_base", .{ .float = 10_000 });
res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float64 = 1e-6 }); res._metadata.putAssumeCapacityNoClobber("rms_norm_eps", .{ .float = 1e-6 });
} }
return res; return res;

View File

@ -16,6 +16,7 @@ const StringBuilder = std.ArrayListUnmanaged(u8);
const log = std.log.scoped(.zml_io); const log = std.log.scoped(.zml_io);
test { test {
std.testing.refAllDecls(@This());
std.testing.refAllDecls(eval); std.testing.refAllDecls(eval);
std.testing.refAllDecls(value); std.testing.refAllDecls(value);
std.testing.refAllDecls(parser); std.testing.refAllDecls(parser);
@ -35,22 +36,21 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
const tmp_alloc = arena.allocator(); const tmp_alloc = arena.allocator();
const _parser = try parser.Parser.init(tmp_alloc, file); const _parser = try parser.Parser.init(tmp_alloc, file);
const stack, const memo = try eval.evaluate(tmp_alloc, _parser.ops, true); const stack = try eval.evaluate(tmp_alloc, _parser.ops, true);
// But we create the HostBuffer objects inside the result BufferStore arena. // But we create the HostBuffer objects inside the result BufferStore arena.
var res: zml.aio.BufferStore = .{ var res: zml.aio.BufferStore = .{
.arena = std.heap.ArenaAllocator.init(allocator), .arena = std.heap.ArenaAllocator.init(allocator),
}; };
res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file}); res.files = try res.arena.allocator().dupe(zml.aio.MemoryMappedFile, &.{_parser.buffer_file});
var tmp: PickleData = .{ .data = _parser, .memo = memo, .stack = stack }; var tmp: PickleData = .{ .data = _parser, .stack = stack };
try tmp.parseModel(res.arena.allocator(), &res); try tmp.parseModel(res.arena.allocator(), &res);
return res; return res;
} }
// TODO: rename me to PytorchFile // TODO: rename me to PytorchFile
pub const PickleData = struct { pub const PickleData = struct {
stack: eval.PickleStack, stack: []const Value,
memo: eval.PickleMemo,
data: parser.Parser, data: parser.Parser,
fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool { fn basicTypeCheck(object: *const value.Object, module: []const u8, class: []const u8) bool {
@ -63,7 +63,7 @@ pub const PickleData = struct {
} }
pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void { pub fn parseModel(self: *PickleData, allocator: std.mem.Allocator, store: *zml.aio.BufferStore) !void {
for (self.stack.stack) |item| { for (self.stack) |item| {
var prefix_buf: [1024]u8 = undefined; var prefix_buf: [1024]u8 = undefined;
try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item); try self.parseValue(allocator, store, StringBuilder.initBuffer(&prefix_buf), item);
} }
@ -147,7 +147,7 @@ pub const PickleData = struct {
try store._metadata.put( try store._metadata.put(
allocator, allocator,
try allocator.dupe(u8, prefix.items), try allocator.dupe(u8, prefix.items),
.{ .array = .{ .item_type = std.meta.stringToEnum(zml.aio.Value.Slice.ItemType, @tagName(tag)).?, .data = std.mem.sliceAsBytes(try values.toOwnedSlice(allocator)) } }, try zml.aio.Metadata.copySlice(allocator, values.items),
); );
} else { } else {
for (values.items, 0..) |val, i| { for (values.items, 0..) |val, i| {
@ -156,7 +156,13 @@ pub const PickleData = struct {
new_prefix.appendAssumeCapacity('.'); new_prefix.appendAssumeCapacity('.');
} }
new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{}); new_prefix.items.len += std.fmt.formatIntBuf(new_prefix.unusedCapacitySlice(), i, 10, .lower, .{});
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Value, @tagName(tag), val)); const new_tag = switch (tag) {
.int64 => "int",
.float64 => "float",
.boolval => "bool",
else => unreachable, // we are already inside a switch
};
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), @unionInit(zml.aio.Metadata, new_tag, val));
} }
} }
}, },
@ -212,15 +218,17 @@ pub const PickleData = struct {
if (d.found_existing) { if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items}); log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key); allocator.free(key);
} else d.value_ptr.* = .{ .array = .{ .item_type = .uint8, .data = @constCast(val) } }; } else d.value_ptr.* = .{ .string = val };
}, },
inline .float64, .int64, .boolval, .bigint, .string => |val, tag| { inline .float64, .int64, .boolval, .bigint, .string => |val| {
const key = try allocator.dupe(u8, prefix.items); const key = try allocator.dupe(u8, prefix.items);
const d = try store._metadata.getOrPut(allocator, key); const d = try store._metadata.getOrPut(allocator, key);
if (d.found_existing) { if (d.found_existing) {
log.warn("Duplicate key: {s}", .{prefix.items}); log.warn("Duplicate key: {s}", .{prefix.items});
allocator.free(key); allocator.free(key);
} else d.value_ptr.* = @unionInit(zml.aio.Value, @tagName(tag), val); } else {
d.value_ptr.* = zml.aio.Metadata.wrap(val);
}
}, },
else => {}, else => {},
} }
@ -248,7 +256,7 @@ pub const PickleData = struct {
} }
const d = try allocator.alloc(i64, size.len); const d = try allocator.alloc(i64, size.len);
for (d, 0..) |*di, i| di.* = size[i].int64; for (d, 0..) |*di, i| di.* = size[i].int64;
entry.value_ptr.* = .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(d) } }; entry.value_ptr.* = .{ .array_int = d };
return true; return true;
} else if (basicTypeCheck(object, "fractions", "Fraction")) { } else if (basicTypeCheck(object, "fractions", "Fraction")) {
const fraction_str = object.args[0].seq.values[0].string; const fraction_str = object.args[0].seq.values[0].string;
@ -256,12 +264,12 @@ pub const PickleData = struct {
{ {
var new_prefix = prefix; var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".numerator"); new_prefix.appendSliceAssumeCapacity(".numerator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) }); try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[0..split_idx], 10) });
} }
{ {
var new_prefix = prefix; var new_prefix = prefix;
new_prefix.appendSliceAssumeCapacity(".denominator"); new_prefix.appendSliceAssumeCapacity(".denominator");
try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int64 = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) }); try store._metadata.put(allocator, try allocator.dupe(u8, new_prefix.items), .{ .int = try std.fmt.parseInt(i64, fraction_str[split_idx + 1 ..], 10) });
} }
return true; return true;
} }

View File

@ -159,84 +159,9 @@ pub const PickleMemo = struct {
} }
}; };
pub const InternalStack = struct { pub fn evaluate(arena: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) ![]const Value {
allocator: std.mem.Allocator, var stack = std.ArrayList(Value).init(arena);
values: std.ArrayList(Value), var memo = PickleMemo.init(arena);
pub fn init(allocator: std.mem.Allocator) InternalStack {
return .{
.allocator = allocator,
.values = std.ArrayList(Value).init(allocator),
};
}
pub fn deinit(self: *InternalStack) void {
for (0..self.values.items.len) |i| self.values.items[i].deinit(self.allocator);
self.values.deinit();
self.* = undefined;
}
pub fn pop(self: *InternalStack) !Value {
if (self.values.items.len == 0) {
return error.StackUnderrun;
}
return self.values.pop();
}
pub fn popMark(self: *InternalStack, allocator: ?std.mem.Allocator) ![]Value {
const markidx = try self.findMark();
var postmark: []Value = &[_]Value{};
if (allocator) |a| {
postmark = try a.alloc(Value, self.values.items.len - (markidx + 1));
@memcpy(postmark, self.values.items[markidx + 1 ..]);
}
self.values.shrinkAndFree(markidx);
return postmark;
}
pub fn lastMut(self: *InternalStack) !*Value {
if (self.values.items.len == 0) {
return error.UnexpectedEmptyStack;
}
return &self.values.items[self.values.items.len - 1];
}
pub fn findMark(self: *InternalStack) !usize {
const len = self.values.items.len;
for (0..len) |i| {
const idx = (len - 1) - i;
const val = self.values.items[idx];
if (val == .raw and val.raw == .mark) {
return idx;
}
}
zml.log.warn("pytorch loader: missing mark", .{});
return 0;
}
pub fn toPickleStack(self: *InternalStack) !PickleStack {
return .{ .stack = try self.values.toOwnedSlice(), .allocator = self.allocator };
}
};
pub const PickleStack = struct {
stack: []Value,
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator, values: []Value) PickleStack {
return .{ .allocator = allocator, .stack = values };
}
pub fn deinit(self: *PickleStack) void {
for (self.stack) |*v| v.deinit(self.allocator);
self.allocator.free(self.stack);
}
};
pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs: bool) !struct { PickleStack, PickleMemo } {
var stack = InternalStack.init(allocator);
defer stack.deinit();
var memo = PickleMemo.init(allocator);
errdefer memo.deinit(); errdefer memo.deinit();
const makeKVList = (struct { const makeKVList = (struct {
@ -258,56 +183,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
} }
}).call; }).call;
outer: for (x) |op| { for (x) |op| {
switch (op) { switch (op) {
.mark => try stack.values.append(.{ .raw = op }), .mark => try stack.append(.{ .raw = op }),
.stop => break :outer, .frame => {},
.pop => _ = try stack.pop(), .stop => break,
.pop_mark => _ = try stack.popMark(allocator), .pop => _ = try pop(&stack),
.dup => { .pop_mark => try popMarkDiscard(&stack),
if (stack.values.getLastOrNull()) |item| { .dup => if (stack.getLastOrNull()) |item|
try stack.values.append(try item.clone(allocator)); try stack.append(try item.clone(arena))
} else { else
return error.CannotDupEmptyStack; return error.CannotDupEmptyStack,
} .persid => |v| try stack.append(.{ .pers_id = try PersId.init(arena, .{ .string = try arena.dupe(u8, v) }) }),
}, .binpersid => try stack.append(.{ .pers_id = try PersId.init(arena, try pop(&stack)) }),
.persid => |v| try stack.values.append(.{ .pers_id = try PersId.init(allocator, .{ .string = try allocator.dupe(u8, v) }) }), .reduce => try stack.append(.{ .global = blk: {
.binpersid => try stack.values.append(.{ .pers_id = try PersId.init(allocator, try stack.pop()) }), const values = try arena.alloc(Value, 1);
.reduce => try stack.values.append(.{ .global = blk: { values[0] = try memo.resolve(arena, try pop(&stack), true);
const values = try allocator.alloc(Value, 1); break :blk try Object.init(arena, try memo.resolve(arena, try pop(&stack), true), values);
values[0] = try memo.resolve(allocator, try stack.pop(), true);
break :blk try Object.init(allocator, try memo.resolve(allocator, try stack.pop(), true), values);
} }), } }),
.build => try stack.values.append(blk: { .build => try stack.append(blk: {
const args = try memo.resolve(allocator, try stack.pop(), true); const args = try memo.resolve(arena, try pop(&stack), true);
const member = try memo.resolve(allocator, try stack.pop(), true); const member = try memo.resolve(arena, try pop(&stack), true);
break :blk .{ .build = try Build.init(allocator, member, args) }; break :blk .{ .build = try Build.init(arena, member, args) };
}), }),
.empty_dict => try stack.values.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }), .empty_dict => try stack.append(.{ .seq = .{ .type = .dict, .values = &[_]Value{} } }),
.get => |v| try stack.values.append(.{ .ref = try std.fmt.parseInt(u32, v, 10) }), .get => |v| try stack.append(.{ .ref = v }),
inline .binget, .long_binget => |v| try stack.values.append(.{ .ref = v }), .empty_list => try stack.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }),
.empty_list => try stack.values.append(.{ .seq = .{ .type = .list, .values = &[_]Value{} } }), .put => |v| {
.binput, .long_binput => |v| { try memo.insert(v, try pop(&stack));
try memo.insert(v, try stack.pop()); try stack.append(.{ .ref = v });
try stack.values.append(.{ .ref = v });
}, },
.tuple => try stack.values.append(blk: { .tuple => try stack.append(blk: {
const popped = try stack.popMark(allocator); const popped = try popMark(&stack, arena);
break :blk .{ .seq = .{ .type = .tuple, .values = popped } }; break :blk .{ .seq = .{ .type = .tuple, .values = popped } };
}), }),
.empty_tuple => try stack.values.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }), .empty_tuple => try stack.append(.{ .seq = .{ .type = .tuple, .values = &[_]Value{} } }),
.setitem => { .setitem => {
const v, const k = .{ try stack.pop(), try stack.pop() }; const v, const k = .{ try pop(&stack), try pop(&stack) };
const top = try stack.lastMut(); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
}, },
.seq => |*tup| { .seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ k, v }) } }; tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ k, v }) } };
}, },
else => { else => {
return error.BadStackTopForSetItem; return error.BadStackTopForSetItem;
@ -315,53 +237,53 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
} }
}, },
.setitems => { .setitems => {
const popped = try stack.popMark(allocator); const popped = try popMark(&stack, arena);
defer allocator.free(popped); defer arena.free(popped);
const kv_items = try makeKVList(allocator, popped); const kv_items = try makeKVList(arena, popped);
const top = try stack.lastMut(); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; obj.args[obj.args.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
}, },
.seq => |*tup| { .seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } }; tup.values[tup.values.len - 1] = .{ .seq = .{ .type = .tuple, .values = kv_items } };
}, },
else => { else => {
defer allocator.free(kv_items); defer arena.free(kv_items);
return error.BadStackTopForSetItems; return error.BadStackTopForSetItems;
}, },
} }
}, },
.proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}), .proto => |proto| meta.assert(proto <= MAX_PROTOCOL, "Unsupported protocol {d}", .{proto}),
.tuple1 => try stack.values.append(blk: { .tuple1 => try stack.append(blk: {
const tup_values = try allocator.alloc(Value, 1); const tup_values = try arena.alloc(Value, 1);
tup_values[0] = try stack.pop(); tup_values[0] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple2 => try stack.values.append(blk: { .tuple2 => try stack.append(blk: {
const tup_values = try allocator.alloc(Value, 2); const tup_values = try arena.alloc(Value, 2);
inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); inline for (0..2) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.tuple3 => try stack.values.append(blk: { .tuple3 => try stack.append(blk: {
const tup_values = try allocator.alloc(Value, 3); const tup_values = try arena.alloc(Value, 3);
inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try stack.pop(); inline for (0..3) |i| tup_values[(tup_values.len - 1) - i] = try pop(&stack);
break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } }; break :blk .{ .seq = .{ .type = .tuple, .values = tup_values } };
}), }),
.append => { .append => {
const v = try stack.pop(); const v = try pop(&stack);
const top = try stack.lastMut(); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
obj.args = try assuredResize(Value, allocator, obj.args, obj.args.len + 1); obj.args = try assuredResize(Value, arena, obj.args, obj.args.len + 1);
obj.args[obj.args.len - 1] = v; obj.args[obj.args.len - 1] = v;
}, },
.seq => |*tup| { .seq => |*tup| {
tup.values = try assuredResize(Value, allocator, tup.values, tup.values.len + 1); tup.values = try assuredResize(Value, arena, tup.values, tup.values.len + 1);
tup.values[tup.values.len - 1] = v; tup.values[tup.values.len - 1] = v;
}, },
else => { else => {
@ -370,19 +292,19 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
} }
}, },
.appends => { .appends => {
const postmark = try stack.popMark(allocator); const postmark = try popMark(&stack, arena);
defer allocator.free(postmark); defer arena.free(postmark);
const top = try stack.lastMut(); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
const obj_len = obj.args.len; const obj_len = obj.args.len;
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len); obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
@memcpy(obj.args[obj_len..], postmark); @memcpy(obj.args[obj_len..], postmark);
}, },
.seq => |*tup| { .seq => |*tup| {
const tup_len = tup.values.len; const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark); @memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
@ -390,49 +312,43 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
}, },
} }
}, },
.dict => try stack.values.append(blk: { .dict => try stack.append(blk: {
const popped = try stack.popMark(allocator); const popped = try popMark(&stack, arena);
defer allocator.free(popped); defer arena.free(popped);
const kv_items = try makeKVList(allocator, popped); const kv_items = try makeKVList(arena, popped);
break :blk .{ .seq = .{ .type = .dict, .values = kv_items } }; break :blk .{ .seq = .{ .type = .dict, .values = kv_items } };
}), }),
.list => try stack.values.append(.{ .seq = .{ .type = .list, .values = try stack.popMark(allocator) } }), .list => try stack.append(.{ .seq = .{ .type = .list, .values = try popMark(&stack, arena) } }),
.inst => |v| try stack.values.append(blk: { .inst => |v| try stack.append(blk: {
const tup_items = try allocator.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } }); const tup_items = try arena.dupe(Value, &.{ .{ .string = v.module }, .{ .string = v.class } });
break :blk .{ .object = try Object.init(allocator, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try stack.popMark(allocator)) }; break :blk .{ .object = try Object.init(arena, .{ .seq = .{ .type = .tuple, .values = tup_items } }, try popMark(&stack, arena)) };
}), }),
.obj => try stack.values.append(blk: { .obj => try stack.append(blk: {
const markidx = try stack.findMark(); const mark = try findMark(&stack);
const args = try allocator.alloc(Value, stack.values.items.len - (markidx + 2)); const args = try arena.dupe(Value, stack.items[mark + 2 ..]);
@memcpy(args, stack.values.items[markidx + 2 ..]); const member = stack.items[mark + 1];
const member = stack.values.items[markidx + 1]; break :blk .{ .object = try Object.init(arena, member, args) };
break :blk .{ .object = try Object.init(allocator, member, args) };
}), }),
.put => |v| { .newobj => try stack.append(blk: {
const mid = try std.fmt.parseInt(u32, v, 10); const args = try arena.alloc(Value, 1);
try memo.insert(mid, try stack.pop()); args[0] = try pop(&stack);
try stack.values.append(.{ .ref = mid }); break :blk .{ .object = try Object.init(arena, try pop(&stack), args) };
},
.newobj => try stack.values.append(blk: {
const args = try allocator.alloc(Value, 1);
args[0] = try stack.pop();
break :blk .{ .object = try Object.init(allocator, try stack.pop(), args) };
}), }),
.empty_set => try stack.values.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }), .empty_set => try stack.append(.{ .seq = .{ .type = .set, .values = &[_]Value{} } }),
.additems => { .additems => {
const postmark = try stack.popMark(allocator); const postmark = try popMark(&stack, arena);
defer allocator.free(postmark); defer arena.free(postmark);
const top = try stack.lastMut(); const top = try lastMut(&stack);
const rtop = try memo.resolveMut(top, true); const rtop = try memo.resolveMut(top, true);
switch (rtop.*) { switch (rtop.*) {
.global => |obj| { .global => |obj| {
const obj_len = obj.args.len; const obj_len = obj.args.len;
obj.args = try assuredResize(Value, allocator, obj.args, obj_len + postmark.len); obj.args = try assuredResize(Value, arena, obj.args, obj_len + postmark.len);
@memcpy(obj.args[obj_len..], postmark); @memcpy(obj.args[obj_len..], postmark);
}, },
.seq => |*tup| { .seq => |*tup| {
const tup_len = tup.values.len; const tup_len = tup.values.len;
tup.values = try assuredResize(Value, allocator, tup.values, tup_len + postmark.len); tup.values = try assuredResize(Value, arena, tup.values, tup_len + postmark.len);
@memcpy(tup.values[tup_len..], postmark); @memcpy(tup.values[tup_len..], postmark);
}, },
else => { else => {
@ -440,36 +356,33 @@ pub fn evaluate(allocator: std.mem.Allocator, x: []const pickle.Op, resolve_refs
}, },
} }
}, },
.frozenset => try stack.values.append(.{ .seq = .{ .type = .frozen_set, .values = try stack.popMark(allocator) } }), .frozenset => try stack.append(.{ .seq = .{ .type = .frozen_set, .values = try popMark(&stack, arena) } }),
.newobj_ex => try stack.values.append(blk: { .newobj_ex => try stack.append(blk: {
const kwargs, const args, const cls = .{ try stack.pop(), try stack.pop(), try stack.pop() }; const kwargs, const args, const cls = .{ try pop(&stack), try pop(&stack), try pop(&stack) };
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ args, kwargs }) }; const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ args, kwargs }) };
break :blk .{ .object = try Object.init(allocator, cls, try allocator.dupe(Value, &.{.{ .seq = new_seq }})) }; break :blk .{ .object = try Object.init(arena, cls, try arena.dupe(Value, &.{.{ .seq = new_seq }})) };
}), }),
.stack_global => try stack.values.append(blk: { .stack_global => try stack.append(blk: {
const gn, const mn = .{ const gn, const mn = .{
try memo.resolve(allocator, try stack.pop(), true), try memo.resolve(arena, try pop(&stack), true),
try memo.resolve(allocator, try stack.pop(), true), try memo.resolve(arena, try pop(&stack), true),
}; };
const new_seq: Sequence = .{ .type = .tuple, .values = try allocator.dupe(Value, &.{ gn, mn }) }; const new_seq: Sequence = .{ .type = .tuple, .values = try arena.dupe(Value, &.{ gn, mn }) };
break :blk .{ .object = try Object.init(allocator, .{ .seq = new_seq }, &[_]Value{}) }; break :blk .{ .object = try Object.init(arena, .{ .seq = new_seq }, &[_]Value{}) };
}), }),
.memoize => { .memoize => {
const item = stack.values.getLastOrNull() orelse { const item = stack.getLastOrNull() orelse {
return error.StackUnderrun; return error.StackUnderrun;
}; };
try memo.insert(@intCast(memo.map.count()), try item.clone(allocator)); try memo.insert(@intCast(memo.map.count()), try item.clone(arena));
}, },
else => try stack.values.append(.{ .raw = try op.clone(allocator) }), else => try stack.append(.{ .raw = try op.clone(arena) }),
} }
} }
if (!resolve_refs) { if (resolve_refs) {
return .{ try stack.toPickleStack(), memo }; return try memo.resolveAllRefsIter(arena, 0, stack.items, true);
} }
return .{ return stack.toOwnedSlice();
PickleStack.init(allocator, try memo.resolveAllRefsIter(allocator, 0, stack.values.items, true)),
memo,
};
} }
// TODO: this is a unmanaged array list, minus the optimisation. We should use that instead // TODO: this is a unmanaged array list, minus the optimisation. We should use that instead
@ -483,3 +396,86 @@ fn assuredResize(comptime T: type, allocator: std.mem.Allocator, old: []T, new_l
return new; return new;
} }
} }
// End-to-end check of `evaluate`: parse a small pickle fixture produced by
// Python (a dict with string, int, float, list and tuple entries) and compare
// the decoded value tree against the expected literals.
test evaluate {
    var arena = std.heap.ArenaAllocator.init(std.testing.allocator);
    defer arena.deinit();
    const allocator = arena.allocator();
    const file = try std.fs.cwd().openFile("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
    // Fix: the file handle was previously leaked; release it when the test ends.
    defer file.close();
    var buffered_reader = std.io.bufferedReader(file.reader());
    const ops = try pickle.parse(allocator, buffered_reader.reader(), 4096);
    // resolve_refs = true: memo references are inlined into the result.
    const vals = try evaluate(allocator, ops, true);
    defer allocator.free(vals);
    // The whole pickle decodes to a single top-level dict.
    try std.testing.expect(vals.len == 1);
    try std.testing.expect(vals[0] == .seq);
    try std.testing.expect(vals[0].seq.type == .dict);
    // `setitems` stores the dict entries as one tuple of kv_tuple pairs,
    // hence the extra `.seq.values[0]` hop.
    const entries = vals[0].seq.values[0].seq.values;
    try std.testing.expect(entries.len == 5);
    const expected: []const Value = &.{
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "hello" }, .{ .string = "world" } }) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "int" }, .{ .int64 = 1 } }) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{ .{ .string = "float" }, .{ .float64 = 3.141592 } }) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
            .{ .string = "list" },
            .{ .seq = .{ .type = .list, .values = @constCast(&[_]Value{
                .{ .int64 = 0 },
                .{ .int64 = 1 },
                .{ .int64 = 2 },
                .{ .int64 = 3 },
                .{ .int64 = 4 },
            }) } },
        }) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(&[_]Value{
            .{ .string = "tuple" },
            .{ .seq = .{
                .type = .tuple,
                .values = @constCast(&[_]Value{
                    .{ .string = "a" },
                    .{ .int64 = 10 },
                }),
            } },
        }) } },
    };
    try std.testing.expectEqualDeep(expected, entries);
}
/// Removes and returns the value on top of the pickle stack.
/// Returns `error.StackUnderrun` when the stack holds no values.
pub fn pop(values: *std.ArrayList(Value)) !Value {
    return if (values.items.len > 0) values.pop() else error.StackUnderrun;
}
/// Drops every value above the topmost mark (and the mark itself) without
/// returning them. Errors if no mark is present (see `findMark`).
fn popMarkDiscard(values: *std.ArrayList(Value)) !void {
    values.shrinkRetainingCapacity(try findMark(values));
}
/// Pops every value above the topmost mark and returns them as a new slice
/// owned by `allocator`; the mark itself is discarded from the stack.
fn popMark(values: *std.ArrayList(Value), allocator: std.mem.Allocator) ![]Value {
    const mark_idx = try findMark(values);
    // The shrink keeps capacity, so this view stays valid for the dupe below.
    const tail = values.items[mark_idx + 1 ..];
    values.shrinkRetainingCapacity(mark_idx);
    return allocator.dupe(Value, tail);
}
/// Returns a mutable pointer to the top-of-stack value.
/// Returns `error.UnexpectedEmptyStack` when the stack is empty.
fn lastMut(values: *std.ArrayList(Value)) !*Value {
    const count = values.items.len;
    if (count == 0) return error.UnexpectedEmptyStack;
    return &values.items[count - 1];
}
/// Scans the stack from the top down and returns the index of the nearest
/// mark (`.raw == .mark`). Returns `error.MarkNotFound` if none exists.
fn findMark(values: *std.ArrayList(Value)) !usize {
    var idx = values.items.len;
    while (idx > 0) {
        idx -= 1;
        const candidate = values.items[idx];
        if (candidate == .raw and candidate.raw == .mark) return idx;
    }
    return error.MarkNotFound;
}

View File

@ -17,7 +17,7 @@ pub const Parser = struct {
buffer_file: zml.aio.MemoryMappedFile, buffer_file: zml.aio.MemoryMappedFile,
file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{}, file_map: std.StringArrayHashMapUnmanaged(FileEntry) = .{},
tar_file: ?TarStream = null, tar_file: ?TarStream = null,
ops: []pickle.Op, ops: []const pickle.Op,
is_zip_file: bool, is_zip_file: bool,
zip_prefix: []const u8 = &[_]u8{}, zip_prefix: []const u8 = &[_]u8{},
@ -65,7 +65,7 @@ pub const Parser = struct {
}; };
if (!self.is_zip_file) { if (!self.is_zip_file) {
const reader = tar_stream.reader(); const reader = tar_stream.reader();
self.ops = try parse(allocator, reader, try tar_stream.getEndPos()); self.ops = try pickle.parse(allocator, reader, try tar_stream.getEndPos());
} else { } else {
self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream()); self.ops = try self.parseOps(allocator, self.tar_file.?.seekableStream());
} }
@ -82,7 +82,7 @@ pub const Parser = struct {
}; };
if (!self.is_zip_file) { if (!self.is_zip_file) {
const reader = self.buffer_file.file.reader(); const reader = self.buffer_file.file.reader();
self.ops = try parse(allocator, reader, try reader.context.getEndPos()); self.ops = try pickle.parse(allocator, reader, try reader.context.getEndPos());
} else { } else {
self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream()); self.ops = try self.parseOps(allocator, self.buffer_file.file.seekableStream());
} }
@ -94,7 +94,7 @@ pub const Parser = struct {
self.* = undefined; self.* = undefined;
} }
fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]pickle.Op { fn parseOps(self: *Parser, allocator: Allocator, seekable_stream: anytype) ![]const pickle.Op {
var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream); var iter = try std.zip.Iterator(@TypeOf(seekable_stream)).init(seekable_stream);
var filename_buf: [std.fs.max_path_bytes]u8 = undefined; var filename_buf: [std.fs.max_path_bytes]u8 = undefined;
while (try iter.next()) |entry| { while (try iter.next()) |entry| {
@ -152,7 +152,7 @@ pub const Parser = struct {
switch (entry.compression_method) { switch (entry.compression_method) {
.store => { .store => {
return parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size); return pickle.parse(allocator, seekable_stream.context.reader(), entry.uncompressed_size);
}, },
.deflate => { .deflate => {
// TODO(cryptodeal): handle decompress // TODO(cryptodeal): handle decompress
@ -166,196 +166,6 @@ pub const Parser = struct {
std.log.err("Could not find file ending in `data.pkl` in archive", .{}); std.log.err("Could not find file ending in `data.pkl` in archive", .{});
return error.PickleNotFound; return error.PickleNotFound;
} }
/// Decodes a raw pickle byte stream from `reader` into a flat list of
/// `pickle.Op`, stopping after the first `stop` opcode.
/// `len` bounds the size of newline-delimited text arguments (protocol-0 ops).
/// Caller owns the returned slice and every buffer referenced by its ops.
fn parse(allocator: Allocator, reader: anytype, len: usize) ![]pickle.Op {
    var results = std.ArrayList(pickle.Op).init(allocator);
    errdefer results.deinit();
    outer: while (true) {
        const b = try reader.readByte();
        // NOTE(review): unknown opcode bytes fall into `_ => {}` at the bottom
        // and are silently skipped, which can desynchronize on malformed input.
        switch (@as(pickle.OpCode, @enumFromInt(b))) {
            .mark => try results.append(.{ .mark = {} }),
            .stop => {
                try results.append(.{ .stop = {} });
                break :outer;
            },
            .pop => try results.append(.{ .pop = {} }),
            .pop_mark => try results.append(.{ .pop_mark = {} }),
            .dup => try results.append(.{ .dup = {} }),
            // Protocol-0 ops carry their argument as newline-terminated text.
            .float => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .float = buf });
            },
            .int => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .int = buf });
            },
            .binint => try results.append(.{ .binint = try reader.readInt(i32, .little) }),
            .binint1 => try results.append(.{ .binint1 = try reader.readByte() }),
            .long => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .long = buf });
            },
            .binint2 => try results.append(.{ .binint2 = try reader.readInt(u16, .little) }),
            .none => try results.append(.{ .none = {} }),
            .persid => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .persid = buf });
            },
            .binpersid => try results.append(.{ .binpersid = {} }),
            .reduce => try results.append(.{ .reduce = {} }),
            .string => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .string = buf });
            },
            // Length-prefixed binary payloads below.
            // NOTE(review): `reader.read(buf)` may return fewer than `str_len`
            // bytes and the count is discarded; `readNoEof` would be stricter.
            // The same pattern repeats in every length-prefixed case — confirm.
            .binstring => {
                const str_len = try reader.readInt(u32, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .binstring = buf });
            },
            .short_binstring => {
                const str_len = try reader.readByte();
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .short_binstring = buf });
            },
            .unicode => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .unicode = buf });
            },
            .binunicode => {
                const str_len = try reader.readInt(u32, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .binunicode = buf });
            },
            .append => try results.append(.{ .append = {} }),
            .build => try results.append(.{ .build = {} }),
            // `global` and `inst` both read a module name then a class name;
            // both are stored under the `.global` op.
            .global, .inst => {
                const module = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(module);
                const class = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(class);
                try results.append(.{ .global = .{ .module = module, .class = class } });
            },
            .dict => try results.append(.{ .dict = {} }),
            .empty_dict => try results.append(.{ .empty_dict = {} }),
            .appends => try results.append(.{ .appends = {} }),
            .get => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .get = buf });
            },
            .binget => try results.append(.{ .binget = try reader.readByte() }),
            .long_binget => try results.append(.{ .long_binget = try reader.readInt(u32, .little) }),
            .list => try results.append(.{ .list = {} }),
            .empty_list => try results.append(.{ .empty_list = {} }),
            .obj => try results.append(.{ .obj = {} }),
            .put => {
                const buf = try reader.readUntilDelimiterAlloc(allocator, '\n', len);
                errdefer allocator.free(buf);
                try results.append(.{ .put = buf });
            },
            .binput => {
                try results.append(.{ .binput = try reader.readByte() });
            },
            .long_binput => {
                try results.append(.{ .long_binput = try reader.readInt(u32, .little) });
            },
            .setitem => try results.append(.{ .setitem = {} }),
            .tuple => try results.append(.{ .tuple = {} }),
            .empty_tuple => try results.append(.{ .empty_tuple = {} }),
            .setitems => try results.append(.{ .setitems = {} }),
            // binfloat is a big-endian IEEE-754 double, reinterpreted via bitCast.
            .binfloat => try results.append(.{ .binfloat = @bitCast(try reader.readInt(u64, .big)) }),
            .proto => try results.append(.{ .proto = try reader.readByte() }),
            .newobj => try results.append(.{ .newobj = {} }),
            .ext1 => try results.append(.{ .ext1 = try reader.readByte() }),
            .ext2 => try results.append(.{ .ext2 = try reader.readInt(i16, .little) }),
            .ext4 => try results.append(.{ .ext4 = try reader.readInt(i32, .little) }),
            .tuple1 => try results.append(.{ .tuple1 = {} }),
            .tuple2 => try results.append(.{ .tuple2 = {} }),
            .tuple3 => try results.append(.{ .tuple3 = {} }),
            .newtrue => try results.append(.{ .newtrue = {} }),
            .newfalse => try results.append(.{ .newfalse = {} }),
            .long1 => {
                const str_len = try reader.readByte();
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .long1 = buf });
            },
            .long4 => {
                const str_len = try reader.readInt(u32, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .long4 = buf });
            },
            .binbytes => {
                const str_len = try reader.readInt(u32, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .binbytes = buf });
            },
            .binbytes8 => {
                const str_len = try reader.readInt(u64, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .binbytes8 = buf });
            },
            .short_binbytes => {
                const str_len = try reader.readByte();
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .short_binbytes = buf });
            },
            .binunicode8 => {
                const str_len = try reader.readInt(u64, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .binunicode8 = buf });
            },
            .short_binunicode => {
                const str_len = try reader.readByte();
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                // NOTE(review): stores the payload under the `.binunicode8` tag,
                // not `.short_binunicode` — looks like a copy/paste of the case
                // above; confirm downstream consumers expect this.
                try results.append(.{ .binunicode8 = buf });
            },
            .empty_set => try results.append(.{ .empty_set = {} }),
            .additems => try results.append(.{ .additems = {} }),
            .frozenset => try results.append(.{ .frozenset = {} }),
            .newobj_ex => try results.append(.{ .newobj_ex = {} }),
            .stack_global => try results.append(.{ .stack_global = {} }),
            .memoize => try results.append(.{ .memoize = {} }),
            .frame => try results.append(.{ .frame = try reader.readInt(u64, .little) }),
            .bytearray8 => {
                const str_len = try reader.readInt(u64, .little);
                const buf = try allocator.alloc(u8, str_len);
                errdefer allocator.free(buf);
                _ = try reader.read(buf);
                try results.append(.{ .bytearray8 = buf });
            },
            .next_buffer => try results.append(.{ .next_buffer = {} }),
            .readonly_buffer => try results.append(.{ .readonly_buffer = {} }),
            // Unknown opcode bytes are ignored (non-exhaustive enum).
            _ => {},
        }
    }
    return results.toOwnedSlice();
}
}; };
const TarStream = struct { const TarStream = struct {
@ -404,54 +214,6 @@ const TarStream = struct {
} }
}; };
// End-to-end check of the old tuple-returning `eval.evaluate` API: parse a
// small pickle fixture (a Python dict with string, int, float, list and tuple
// entries) through `Parser` and compare the decoded tree to expected literals.
test "Read pickle (simple)" {
    const Value = @import("value.zig").Value;
    var arena = std.heap.ArenaAllocator.init(testing.allocator);
    defer arena.deinit();
    const allocator = arena.allocator();
    const eval = @import("eval.zig");
    // NOTE(review): the file handle is never closed here; confirm whether
    // `Parser.deinit` takes ownership or a `defer file.close()` is needed.
    const file = try asynk.File.open("zml/aio/torch/simple_test.pickle", .{ .mode = .read_only });
    var data = try Parser.init(allocator, file);
    defer data.deinit();
    var vals, var memo = try eval.evaluate(allocator, data.ops, true);
    defer vals.deinit();
    defer memo.deinit();
    // Two stack entries: a leading frame value plus the decoded dict.
    try testing.expect(vals.stack.len == 2);
    // skip first value (frame)
    try testing.expect(vals.stack[1] == .seq);
    try testing.expect(vals.stack[1].seq.type == .dict);
    // The dict entries are stored as one tuple of kv_tuple pairs, hence the
    // extra `.seq.values[0]` hop.
    const entries = vals.stack[1].seq.values[0].seq.values;
    try testing.expect(entries.len == 5);
    const expected: []const Value = &.{
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "hello" }, .{ .string = "world" } })) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "int" }, .{ .int64 = 1 } })) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{ .{ .string = "float" }, .{ .float64 = 3.141592 } })) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
            .{ .string = "list" },
            .{ .seq = .{ .type = .list, .values = @constCast(@as([]const Value, &.{
                .{ .int64 = 0 },
                .{ .int64 = 1 },
                .{ .int64 = 2 },
                .{ .int64 = 3 },
                .{ .int64 = 4 },
            })) } },
        })) } },
        .{ .seq = .{ .type = .kv_tuple, .values = @constCast(@as([]const Value, &.{
            .{ .string = "tuple" },
            .{ .seq = .{
                .type = .tuple,
                .values = @constCast(@as([]const Value, &.{
                    .{ .string = "a" },
                    .{ .int64 = 10 },
                })),
            } },
        })) } },
    };
    try std.testing.expectEqualDeep(expected, entries);
}
test "Read pickle (zipped)" { test "Read pickle (zipped)" {
var arena = std.heap.ArenaAllocator.init(testing.allocator); var arena = std.heap.ArenaAllocator.init(testing.allocator);
defer arena.deinit(); defer arena.deinit();

File diff suppressed because it is too large Load Diff

View File

@ -109,7 +109,7 @@ pub const ValueType = enum {
none, none,
}; };
/// A processed value. /// A pickle operator that has been interpreted.
pub const Value = union(ValueType) { pub const Value = union(ValueType) {
/// Types that we can't handle or just had to give up on processing. /// Types that we can't handle or just had to give up on processing.
raw: pickle.Op, raw: pickle.Op,
@ -283,10 +283,10 @@ pub const Value = union(ValueType) {
return switch (self) { return switch (self) {
inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .raw, .raw_num => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)), inline .app, .object, .global, .build, .pers_id => |v, tag| @unionInit(Value, @tagName(tag), try v.clone(allocator)),
.seq => |seq| blk: { .seq => |seq| {
const new_val: Sequence = .{ .type = seq.type, .values = try allocator.alloc(Value, seq.values.len) }; const values = try allocator.alloc(Value, seq.values.len);
for (seq.values, 0..) |v, i| new_val.values[i] = try v.clone(allocator); for (seq.values, 0..) |v, i| values[i] = try v.clone(allocator);
break :blk .{ .seq = new_val }; return .{ .seq = .{ .type = seq.type, .values = values } };
}, },
inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)), inline .string, .bytes => |v, tag| @unionInit(Value, @tagName(tag), try allocator.dupe(u8, v)),
.bigint => |v| .{ .bigint = try v.clone() }, .bigint => |v| .{ .bigint = try v.clone() },
@ -342,8 +342,9 @@ pub const Value = union(ValueType) {
pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value { pub fn coerceFromRaw(self: Value, allocator: std.mem.Allocator) !Value {
return switch (self) { return switch (self) {
.raw => |raw_val| switch (raw_val) { .raw => |raw_val| switch (raw_val) {
.binint, .binint1, .binint2 => |val| .{ .int64 = val }, .binint => |val| .{ .int64 = val },
.long1, .long4 => |b| if (b.len != 0) { .long => |b| if (b.len != 0) {
// TODO: handle trailing 'L'
var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len)); var bint = try big_int.Managed.initCapacity(allocator, std.math.big.int.calcTwosCompLimbCount(b.len));
var mutable = bint.toMutable(); var mutable = bint.toMutable();
mutable.readTwosComplement(b, b.len, .little, .signed); mutable.readTwosComplement(b, b.len, .little, .signed);
@ -355,28 +356,17 @@ pub const Value = union(ValueType) {
} else return .{ .bigint = bint }; } else return .{ .bigint = bint };
} else .{ .raw_num = raw_val }, } else .{ .raw_num = raw_val },
.binfloat => |val| .{ .float64 = val }, .binfloat => |val| .{ .float64 = val },
.binunicode, .binunicode8, .short_binunicode => |s| .{ .string = s }, .unicode => |s| .{ .string = s },
.binbytes, .binbytes8, .short_binbytes, .bytearray8 => |b| .{ .bytes = b }, .bytes => |b| .{ .bytes = b },
// This isn't how Pickle actually works but we just try to UTF8 decode the // This isn't how Pickle actually works but we just try to UTF8 decode the
// string and if it fails, we make it a bytes value instead. If anyone // string and if it fails, we make it a bytes value instead. If anyone
// actually cares they can just fix values themselves or recover the raw bytes // actually cares they can just fix values themselves or recover the raw bytes
// from the UTF8 string (it's guaranteed to be reversible, as far as I know). // from the UTF8 string (it's guaranteed to be reversible, as far as I know).
.binstring, .short_binstring => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b }, .string => |b| if (std.unicode.utf8ValidateSlice(b)) .{ .string = b } else .{ .bytes = b },
.newtrue => .{ .boolval = true }, .bool => |b| .{ .boolval = b },
.newfalse => .{ .boolval = false },
.none => .{ .none = {} }, .none => .{ .none = {} },
inline .int, // TODO .int should be handled like .long
.float, .int, .float => .{ .raw_num = raw_val },
.long,
=> |v, tag| {
if (tag == .int and std.mem.eql(u8, v, "01")) {
return .{ .boolval = true };
} else if (tag == .int and std.mem.eql(u8, v, "00")) {
return .{ .boolval = false };
} else {
return .{ .raw_num = raw_val };
}
},
else => self, else => self,
}, },
.app, .object, .global => |v| blk: { .app, .object, .global => |v| blk: {

View File

@ -25,8 +25,8 @@ pub fn open(allocator: Allocator, path: []const u8) !zml.aio.BufferStore {
pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void { pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: StringBuilder, val: yaml.Value) !void {
switch (val) { switch (val) {
.int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int64 = v }), .int => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .int = v }),
.float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float64 = v }), .float => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .float = v }),
.string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }), .string => |v| try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .string = v }),
.list => |v| switch (validSlice(v)) { .list => |v| switch (validSlice(v)) {
true => { true => {
@ -36,13 +36,13 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
const values = try allocator.alloc(i64, v.len); const values = try allocator.alloc(i64, v.len);
errdefer allocator.free(values); errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.int; for (v, 0..) |item, i| values[i] = item.int;
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .int64, .data = std.mem.sliceAsBytes(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_int = values });
}, },
.float => { .float => {
const values = try allocator.alloc(f64, v.len); const values = try allocator.alloc(f64, v.len);
errdefer allocator.free(values); errdefer allocator.free(values);
for (v, 0..) |item, i| values[i] = item.float; for (v, 0..) |item, i| values[i] = item.float;
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .float64, .data = std.mem.sliceAsBytes(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_float = values });
}, },
.string => { .string => {
const values = try allocator.alloc([]const u8, v.len); const values = try allocator.alloc([]const u8, v.len);
@ -50,7 +50,7 @@ pub fn parseMetadata(allocator: Allocator, store: *zml.aio.BufferStore, key: Str
for (v, 0..) |item, i| { for (v, 0..) |item, i| {
values[i] = try allocator.dupe(u8, item.string); values[i] = try allocator.dupe(u8, item.string);
} }
try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array = .{ .item_type = .string, .data = std.mem.sliceAsBytes(values) } }); try store._metadata.put(allocator, try allocator.dupe(u8, key.items), .{ .array_string = values });
}, },
.list => unreachable, .list => unreachable,
else => {}, else => {},

View File

@ -145,6 +145,8 @@ pub fn isSliceOfAny(comptime T: type, comptime f: fn (comptime type) bool) bool
} }
pub fn DeclEnum(comptime T: type) type { pub fn DeclEnum(comptime T: type) type {
const field_infos = std.meta.declarations(T);
if (field_infos.len == 0) compileError("Struct {} has no declarations", .{T});
return std.meta.DeclEnum(UnwrapPtr(T)); return std.meta.DeclEnum(UnwrapPtr(T));
} }