Various minor fixes: rewrite tinyllama tokenizer newline token, prevent HostBuffer.isContiguous false trigger on 1‑dim axes, improve HostBuffer.slice1d error messages, simplify module.zig output to show .mlir file path, correct setFlags handling of comptime int/float, make tokenizer.zig return <oob> for out‑of‑range detokenization, and speed up Buffer.constant creation up to 2.5 GB/s on CUDA.
This commit is contained in:
parent
3970df5b48
commit
c109b12e1b
@ -152,6 +152,7 @@ pub fn loadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8, v
|
|||||||
}
|
}
|
||||||
tokenizer.vocab_size = i;
|
tokenizer.vocab_size = i;
|
||||||
}
|
}
|
||||||
|
try tokenizer.rewriteByteFallbackTokens();
|
||||||
return tokenizer;
|
return tokenizer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -124,12 +124,50 @@ pub const Buffer = struct {
|
|||||||
|
|
||||||
/// Creates a Buffer with a single element repeated manytime.
|
/// Creates a Buffer with a single element repeated manytime.
|
||||||
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
|
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
|
||||||
|
var start = try std.time.Timer.start();
|
||||||
|
defer {
|
||||||
|
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
|
||||||
|
if (duration_ms > 100) {
|
||||||
|
const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
|
||||||
|
log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert val to the requested dtype.
|
||||||
const x = shape_.dtype().constant(val);
|
const x = shape_.dtype().constant(val);
|
||||||
const host_buffer: HostBuffer = .{
|
const byte_size = shape_.dtype().sizeOf();
|
||||||
._shape = shape_,
|
const max_bytes = 1024;
|
||||||
._strides = [1]i64{0} ** Shape.MAX_RANK,
|
|
||||||
.data = x.constSlice(),
|
// Naive version for scalars and buffers with long last axis.
|
||||||
};
|
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
|
||||||
|
const host_buffer: HostBuffer = .{
|
||||||
|
._shape = shape_,
|
||||||
|
._strides = [1]i64{0} ** Shape.MAX_RANK,
|
||||||
|
.data = x.constSlice(),
|
||||||
|
};
|
||||||
|
return try from(platform, host_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
// To speed up copies, duplicate the scalar value into a vector,
|
||||||
|
// so that PJRT can copy row by row.
|
||||||
|
// Because this is respecting the shape, it won't work if the last axis is too big.
|
||||||
|
// If this becomes an issue, we should create a new intermediary Buffer by splitting last axis into { n, max_bytes }
|
||||||
|
// so that the trick works, and then reshape it
|
||||||
|
// We could also handle sharded constant directly in this function to avoid having to create too big arrays.
|
||||||
|
var bytes: [max_bytes]u8 align(64) = undefined;
|
||||||
|
var strides = [1]i64{0} ** Shape.MAX_RANK;
|
||||||
|
strides[shape_.rank() - 1] = byte_size;
|
||||||
|
|
||||||
|
switch (byte_size) {
|
||||||
|
inline 1, 2, 4, 8, 16 => |b| {
|
||||||
|
const Int = std.meta.Int(.unsigned, b * 8);
|
||||||
|
const x_as_int: Int = @bitCast(x.constSlice()[0..b].*);
|
||||||
|
const bytes_as_int: [*]Int = @ptrCast(&bytes);
|
||||||
|
@memset(bytes_as_int[0..@intCast(shape_.dim(-1))], x_as_int);
|
||||||
|
},
|
||||||
|
else => unreachable,
|
||||||
|
}
|
||||||
|
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes };
|
||||||
return try from(platform, host_buffer);
|
return try from(platform, host_buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -195,9 +195,12 @@ pub const HostBuffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn isContiguous(self: HostBuffer) bool {
|
pub fn isContiguous(self: HostBuffer) bool {
|
||||||
const strd = self._strides orelse return true;
|
const _strides = self._strides orelse return true;
|
||||||
const cont_strides = self._shape.computeStrides();
|
const cont_strides = self._shape.computeStrides();
|
||||||
return std.mem.eql(i64, strd[0..self.rank()], cont_strides.constSlice());
|
for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
|
||||||
|
if (d != 1 and stride != cont_stride) return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
|
||||||
@ -219,9 +222,9 @@ pub const HostBuffer = struct {
|
|||||||
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
const start: i64 = if (s.start < 0) s.start + d else s.start;
|
||||||
var end = s.end orelse d;
|
var end = s.end orelse d;
|
||||||
if (end < 0) end += d;
|
if (end < 0) end += d;
|
||||||
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
|
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
|
||||||
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
|
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
|
||||||
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
|
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
|
||||||
|
|
||||||
// If strides weren't set it means original buffer is contiguous.
|
// If strides weren't set it means original buffer is contiguous.
|
||||||
// But it won't be anymore after slicing. The strides don't change though.
|
// But it won't be anymore after slicing. The strides don't change though.
|
||||||
@ -230,7 +233,8 @@ pub const HostBuffer = struct {
|
|||||||
return .{
|
return .{
|
||||||
._shape = self.shape().set(ax, end - start),
|
._shape = self.shape().set(ax, end - start),
|
||||||
.data = self.data[offset..],
|
.data = self.data[offset..],
|
||||||
._strides = _strides,
|
// When axis is 0, we stay contiguous.
|
||||||
|
._strides = if (ax == 0) self._strides else _strides,
|
||||||
._memory = .unmanaged,
|
._memory = .unmanaged,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -243,10 +243,8 @@ pub const CompilationContext = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
|
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
|
||||||
log.err(
|
log.err("pjrt-{s} failed to compile: {}", .{ @tagName(self._platform.target), err });
|
||||||
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
|
if (module_dir) |dir| log.err("mlir can be found at {s}/module.mlir", .{dir});
|
||||||
.{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
|
|
||||||
);
|
|
||||||
return err;
|
return err;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -947,8 +945,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
|||||||
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
|
||||||
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
|
||||||
.Bool => .{ .value = .{ .bool_field = value } },
|
.Bool => .{ .value = .{ .bool_field = value } },
|
||||||
.Int => .{ .value = .{ .int_field = value } },
|
.ComptimeInt, .Int => .{ .value = .{ .int_field = value } },
|
||||||
.Float => .{ .value = .{ .double_field = value } },
|
.ComptimeFloat, .Float => .{ .value = .{ .double_field = value } },
|
||||||
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
else => .{ .value = .{ .string_field = .{ .Const = value } } },
|
||||||
};
|
};
|
||||||
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
|
||||||
|
|||||||
@ -275,7 +275,7 @@ pub const Tokenizer = struct {
|
|||||||
else if (id == self.special_tokens.unk)
|
else if (id == self.special_tokens.unk)
|
||||||
"<unk>"
|
"<unk>"
|
||||||
else if (id > self.tokens.len)
|
else if (id > self.tokens.len)
|
||||||
std.debug.panic("Unexpected token id: {d}, vocab_size: {d}", .{ id, self.vocab_size })
|
"<oob>" // this means we received an invalid id, but we didn't want to panic.
|
||||||
else
|
else
|
||||||
self.tokens[id];
|
self.tokens[id];
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user