const std = @import("std");

const Allocator = std.mem.Allocator;

const asynk = @import("async");
const MemoryMappedFile = @import("../aio.zig").MemoryMappedFile;
const zml = @import("../zml.zig");
const HostBuffer = zml.HostBuffer;

const StringBuilder = std.ArrayListUnmanaged(u8);

const log = std.log.scoped(.@"zml/io");
/// Opens a safetensors model at `path` and returns a populated BufferStore.
/// A ".safetensors.index.json" path is treated as a sharded-model index;
/// any other path is loaded as a single safetensors file.
/// All parsed data lives in the store's arena; caller deinits the store.
pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore {
    // One arena backs every allocation of the store, so teardown is a single call.
    var store: zml.aio.BufferStore = .{
        .arena = std.heap.ArenaAllocator.init(allocator),
    };
    errdefer store.arena.deinit();
    const arena = store.arena.allocator();

    var mapped_files = std.ArrayList(MemoryMappedFile).init(arena);
    errdefer mapped_files.deinit();

    if (std.mem.endsWith(u8, path, ".safetensors.index.json"))
        try loadFromIndex(arena, &store, &mapped_files, path)
    else
        try loadFile(arena, &store, &mapped_files, path);

    store.files = try mapped_files.toOwnedSlice();
    return store;
}
|
|
|
|
|
|
|
|
|
|
/// Loads every shard referenced by a ".safetensors.index.json" file into `store`.
/// The index's "weight_map" object maps tensor names to shard filenames; each
/// distinct shard file is loaded exactly once via `loadFile`. Shard paths are
/// resolved relative to the index file's directory.
fn loadFromIndex(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
    const file = asynk.File.open(path, .{}) catch |err| {
        log.err("Failed to open {s}: {}", .{ path, err });
        return err;
    };
    // The index json is fully consumed in this function and is never mmapped,
    // so close it on every exit path (previously it leaked on success).
    defer file.close() catch unreachable;
    var r = file.reader();

    const json_data = try allocator.alloc(u8, (try file.stat()).size);
    _ = try r.readAtLeast(json_data, json_data.len);

    const index = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data, .{ .allocate = .alloc_if_needed });

    // Many tensors live in the same shard; remember which shards were loaded.
    var loaded_files = std.StringHashMap(void).init(allocator);

    const weight_map = index.object.get("weight_map").?.object;
    var it = weight_map.iterator();
    while (it.next()) |entry| {
        const filename = entry.value_ptr.string;
        if (loaded_files.contains(filename)) {
            continue;
        }

        log.debug("Loading shard: {s}", .{filename});
        try loaded_files.put(filename, {});

        const full_filename = try std.fs.path.join(allocator, &.{ std.fs.path.dirname(path).?, filename });
        try loadFile(allocator, store, files, full_filename);
    }
}
|
|
|
|
|
|
|
|
|
|
/// Parses a single ".safetensors" file: reads the little-endian u64 header
/// length, parses the json header, memory-maps the file, and registers one
/// HostBuffer per tensor in `store`. The mapping is appended to `files` and
/// stays open on success; ownership transfers with the list.
fn loadFile(allocator: Allocator, store: *zml.aio.BufferStore, files: *std.ArrayList(MemoryMappedFile), path: []const u8) !void {
    const file = asynk.File.open(path, .{}) catch |err| {
        log.err("Failed to open {s}: {}", .{ path, err });
        return err;
    };
    errdefer file.close() catch unreachable;
    var r = file.reader();

    // Safetensors layout: u64 LE header length, json header, raw tensor bytes.
    const json_header_length: usize = @intCast(try r.readInt(u64, .little));
    const json_data = try allocator.alloc(u8, json_header_length);
    const n = try r.readAll(json_data);
    if (n != json_header_length) {
        log.err("Failed to read the full {} bytes of json header from file {s}", .{ n, path });
        return error.CorruptedFile;
    }

    const metadata = try std.json.parseFromSliceLeaky(std.json.Value, allocator, json_data[0..n], .{});

    var buffer_file = try MemoryMappedFile.init(file);
    errdefer buffer_file.deinit();
    // Tensor data_offsets in the header are relative to the end of the header.
    buffer_file.data_offset = 8 + json_header_length;

    try files.append(buffer_file);
    errdefer _ = files.pop();

    var it = metadata.object.iterator();
    while (it.next()) |entry| {
        const key = entry.key_ptr.*;
        const val = entry.value_ptr.*;

        // The format allows an optional "__metadata__" entry holding free-form
        // strings; it has no "shape"/"dtype" and must not be parsed as a tensor.
        if (std.mem.eql(u8, key, "__metadata__")) {
            continue;
        }

        const shape_field = val.object.get("shape").?.array;
        if (shape_field.items.len > zml.Shape.MAX_RANK) {
            // Not an error until someone tries to read the tensor itself.
            log.warn("Can't load tensor {s}, too many dims: {}", .{ key, shape_field.items.len });
            continue;
        }

        const offset_field = val.object.get("data_offsets").?;
        const start: usize = @intCast(offset_field.array.items[0].integer);
        const end: usize = @intCast(offset_field.array.items[1].integer);
        const dtype = try stringToDtype(val.object.get("dtype").?.string);

        var dims: std.BoundedArray(i64, zml.Shape.MAX_RANK) = .{};
        for (shape_field.items) |d| {
            dims.appendAssumeCapacity(d.integer);
        }

        const out_shape = zml.Shape.init(dims.constSlice(), dtype);
        // We aren't storing 'end', so check we can infer it from the tensor shape.
        // This is fine cause safetensor only allow storing contiguous tensors.
        // https://github.com/huggingface/safetensors/blob/main/README.md#format
        // > The byte buffer needs to be entirely indexed, and cannot contain holes. This prevents the creation of polyglot files.
        // File content is untrusted input: report corruption instead of
        // asserting (a failed assert is UB in ReleaseFast builds).
        if (end < start or end - start != out_shape.byteSize()) {
            log.err("Tensor {s} in {s} has inconsistent data_offsets [{}, {}) for its shape", .{ key, path, start, end });
            return error.CorruptedFile;
        }

        const buf = HostBuffer.fromBytes(out_shape, buffer_file.mappedSlice(start, out_shape.byteSize()));
        try store.buffers.put(allocator, try allocator.dupe(u8, key), buf);
    }
}
|
/// Translates a safetensors "dtype" string (e.g. "F32", "BF16") into the
/// corresponding zml.DataType. Unknown strings are logged and rejected with
/// error.UnsupportedDataType.
fn stringToDtype(safetensor_type: []const u8) !zml.DataType {
    // Comptime lookup table covering every dtype this loader supports,
    // ordered roughly by element width.
    const dtype_table = std.StaticStringMap(zml.DataType).initComptime(.{
        .{ "BOOL", .bool },
        .{ "U8", .u8 },
        .{ "I8", .i8 },
        .{ "F8_E4M3", .f8e4m3fn },
        .{ "U16", .u16 },
        .{ "I16", .i16 },
        .{ "F16", .f16 },
        .{ "BF16", .bf16 },
        .{ "U32", .u32 },
        .{ "I32", .i32 },
        .{ "F32", .f32 },
        .{ "U64", .u64 },
        .{ "I64", .i64 },
        .{ "F64", .f64 },
    });

    if (dtype_table.get(safetensor_type)) |dtype| {
        return dtype;
    }
    log.err("Unsupported safetensor data type: {s}", .{safetensor_type});
    return error.UnsupportedDataType;
}
|