Various minor fixes: rewrite tinyllama tokenizer newline token, prevent HostBuffer.isContiguous false trigger on 1‑dim axes, improve HostBuffer.slice1d error messages, simplify module.zig output to show .mlir file path, correct setFlags handling of comptime int/float, make tokenizer.zig return <oob> for out‑of‑range detokenization, and speed up Buffer.constant creation up to 2.5 GB/s on CUDA.

This commit is contained in:
Tarry Singh 2024-02-19 12:34:18 +00:00
parent 3970df5b48
commit c109b12e1b
5 changed files with 59 additions and 18 deletions

View File

@ -152,6 +152,7 @@ pub fn loadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8, v
} }
tokenizer.vocab_size = i; tokenizer.vocab_size = i;
} }
try tokenizer.rewriteByteFallbackTokens();
return tokenizer; return tokenizer;
} }

View File

@ -124,7 +124,22 @@ pub const Buffer = struct {
/// Creates a Buffer with a single element repeated manytime. /// Creates a Buffer with a single element repeated manytime.
pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer { pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
var start = try std.time.Timer.start();
defer {
const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
if (duration_ms > 100) {
const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
}
}
// Convert val to the requested dtype.
const x = shape_.dtype().constant(val); const x = shape_.dtype().constant(val);
const byte_size = shape_.dtype().sizeOf();
const max_bytes = 1024;
// Naive version for scalars and buffers with long last axis.
if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
const host_buffer: HostBuffer = .{ const host_buffer: HostBuffer = .{
._shape = shape_, ._shape = shape_,
._strides = [1]i64{0} ** Shape.MAX_RANK, ._strides = [1]i64{0} ** Shape.MAX_RANK,
@ -133,6 +148,29 @@ pub const Buffer = struct {
return try from(platform, host_buffer); return try from(platform, host_buffer);
} }
// To speed up copies, duplicate the scalar value into a vector,
// so that PJRT can copy row by row.
// Because this is respecting the shape, it won't work if the last axis is too big.
// If this becomes an issue, we should create a new intermediary Buffer by splitting last axis into { n, max_bytes }
// so that the trick works, and then reshape it
// We could also handle sharded constant directly in this function to avoid having to create too big arrays.
var bytes: [max_bytes]u8 align(64) = undefined;
var strides = [1]i64{0} ** Shape.MAX_RANK;
strides[shape_.rank() - 1] = byte_size;
switch (byte_size) {
inline 1, 2, 4, 8, 16 => |b| {
const Int = std.meta.Int(.unsigned, b * 8);
const x_as_int: Int = @bitCast(x.constSlice()[0..b].*);
const bytes_as_int: [*]Int = @ptrCast(&bytes);
@memset(bytes_as_int[0..@intCast(shape_.dim(-1))], x_as_int);
},
else => unreachable,
}
const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes };
return try from(platform, host_buffer);
}
test constant { test constant {
const zml = @import("zml.zig"); const zml = @import("zml.zig");
const platform = zml.testing.env(); const platform = zml.testing.env();

View File

@ -195,9 +195,12 @@ pub const HostBuffer = struct {
} }
pub fn isContiguous(self: HostBuffer) bool { pub fn isContiguous(self: HostBuffer) bool {
const strd = self._strides orelse return true; const _strides = self._strides orelse return true;
const cont_strides = self._shape.computeStrides(); const cont_strides = self._shape.computeStrides();
return std.mem.eql(i64, strd[0..self.rank()], cont_strides.constSlice()); for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
if (d != 1 and stride != cont_stride) return false;
}
return true;
} }
pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer { pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
@ -219,9 +222,9 @@ pub const HostBuffer = struct {
const start: i64 = if (s.start < 0) s.start + d else s.start; const start: i64 = if (s.start < 0) s.start + d else s.start;
var end = s.end orelse d; var end = s.end orelse d;
if (end < 0) end += d; if (end < 0) end += d;
stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start }); stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end }); stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end }); stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
// If strides weren't set it means original buffer is contiguous. // If strides weren't set it means original buffer is contiguous.
// But it won't be anymore after slicing. The strides don't change though. // But it won't be anymore after slicing. The strides don't change though.
@ -230,7 +233,8 @@ pub const HostBuffer = struct {
return .{ return .{
._shape = self.shape().set(ax, end - start), ._shape = self.shape().set(ax, end - start),
.data = self.data[offset..], .data = self.data[offset..],
._strides = _strides, // When axis is 0, we stay contiguous.
._strides = if (ax == 0) self._strides else _strides,
._memory = .unmanaged, ._memory = .unmanaged,
}; };
} }

View File

@ -243,10 +243,8 @@ pub const CompilationContext = struct {
} }
const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| { const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
log.err( log.err("pjrt-{s} failed to compile: {}", .{ @tagName(self._platform.target), err });
"pjrt-{s} failed to compile following valid MLIR:\n{}\n{}", if (module_dir) |dir| log.err("mlir can be found at {s}/module.mlir", .{dir});
.{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
);
return err; return err;
}; };
@ -947,8 +945,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void { fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) { const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
.Bool => .{ .value = .{ .bool_field = value } }, .Bool => .{ .value = .{ .bool_field = value } },
.Int => .{ .value = .{ .int_field = value } }, .ComptimeInt, .Int => .{ .value = .{ .int_field = value } },
.Float => .{ .value = .{ .double_field = value } }, .ComptimeFloat, .Float => .{ .value = .{ .double_field = value } },
else => .{ .value = .{ .string_field = .{ .Const = value } } }, else => .{ .value = .{ .string_field = .{ .Const = value } } },
}; };
options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option }); options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });

View File

@ -275,7 +275,7 @@ pub const Tokenizer = struct {
else if (id == self.special_tokens.unk) else if (id == self.special_tokens.unk)
"<unk>" "<unk>"
else if (id > self.tokens.len) else if (id > self.tokens.len)
std.debug.panic("Unexpected token id: {d}, vocab_size: {d}", .{ id, self.vocab_size }) "<oob>" // this means we received an invalid id, but we didn't want to panic.
else else
self.tokens[id]; self.tokens[id];
} }