diff --git a/zml/aio/tinyllama.zig b/zml/aio/tinyllama.zig
index c39f3e2..e3fe3f9 100644
--- a/zml/aio/tinyllama.zig
+++ b/zml/aio/tinyllama.zig
@@ -152,6 +152,7 @@ pub fn loadTokenizer(allocator: std.mem.Allocator, tokenizer_path: []const u8, v
         }
         tokenizer.vocab_size = i;
     }
+    try tokenizer.rewriteByteFallbackTokens();
     return tokenizer;
 }
 
diff --git a/zml/buffer.zig b/zml/buffer.zig
index 1fffbaf..a9fba67 100644
--- a/zml/buffer.zig
+++ b/zml/buffer.zig
@@ -124,12 +124,50 @@ pub const Buffer = struct {
 
     /// Creates a Buffer with a single element repeated manytime.
     pub fn constant(platform: Platform, shape_: Shape, val: anytype) !Buffer {
+        var start = try std.time.Timer.start();
+        defer {
+            const duration_ms = stdx.math.divFloat(f32, start.read(), std.time.ns_per_ms);
+            if (duration_ms > 100) {
+                const size_gb = stdx.math.divFloat(f32, shape_.byteSize(), 1024 * 1024 * 1024);
+                log.info("Wrote constant({_}) to device ({d:.2}Gb) in {d:.0}ms: {d:.2}Gb/s", .{ shape_, size_gb, duration_ms, size_gb / duration_ms * 1000 });
+            }
+        }
+
+        // Convert val to the requested dtype.
         const x = shape_.dtype().constant(val);
-        const host_buffer: HostBuffer = .{
-            ._shape = shape_,
-            ._strides = [1]i64{0} ** Shape.MAX_RANK,
-            .data = x.constSlice(),
-        };
+        const byte_size = shape_.dtype().sizeOf();
+        const max_bytes = 1024;
+
+        // Naive version for scalars and buffers with long last axis.
+        if (shape_.rank() < 1 or byte_size * shape_.dim(-1) > max_bytes) {
+            const host_buffer: HostBuffer = .{
+                ._shape = shape_,
+                ._strides = [1]i64{0} ** Shape.MAX_RANK,
+                .data = x.constSlice(),
+            };
+            return try from(platform, host_buffer);
+        }
+
+        // To speed up copies, duplicate the scalar value into a vector,
+        // so that PJRT can copy row by row.
+        // Because this is respecting the shape, it won't work if the last axis is too big.
+        // If this becomes an issue, we should create a new intermediary Buffer by splitting last axis into { n, max_bytes }
+        // so that the trick works, and then reshape it
+        // We could also handle sharded constant directly in this function to avoid having to create too big arrays.
+        var bytes: [max_bytes]u8 align(64) = undefined;
+        var strides = [1]i64{0} ** Shape.MAX_RANK;
+        strides[shape_.rank() - 1] = byte_size;
+
+        switch (byte_size) {
+            inline 1, 2, 4, 8, 16 => |b| {
+                const Int = std.meta.Int(.unsigned, b * 8);
+                const x_as_int: Int = @bitCast(x.constSlice()[0..b].*);
+                const bytes_as_int: [*]Int = @ptrCast(&bytes);
+                @memset(bytes_as_int[0..@intCast(shape_.dim(-1))], x_as_int);
+            },
+            else => unreachable,
+        }
+        const host_buffer: HostBuffer = .{ ._shape = shape_, ._strides = strides, .data = &bytes };
         return try from(platform, host_buffer);
     }
 
diff --git a/zml/hostbuffer.zig b/zml/hostbuffer.zig
index 34531b9..adaa02e 100644
--- a/zml/hostbuffer.zig
+++ b/zml/hostbuffer.zig
@@ -195,9 +195,12 @@ pub const HostBuffer = struct {
     }
 
     pub fn isContiguous(self: HostBuffer) bool {
-        const strd = self._strides orelse return true;
+        const _strides = self._strides orelse return true;
         const cont_strides = self._shape.computeStrides();
-        return std.mem.eql(i64, strd[0..self.rank()], cont_strides.constSlice());
+        for (self._shape.dims(), _strides[0..self.rank()], cont_strides.constSlice()) |d, stride, cont_stride| {
+            if (d != 1 and stride != cont_stride) return false;
+        }
+        return true;
     }
 
     pub fn reshape(self: HostBuffer, shape_: anytype) HostBuffer {
@@ -219,9 +222,9 @@ pub const HostBuffer = struct {
         const start: i64 = if (s.start < 0) s.start + d else s.start;
         var end = s.end orelse d;
         if (end < 0) end += d;
-        stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, start });
-        stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, end });
-        stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({})", .{ self, ax, start, end });
+        stdx.debug.assert(start >= 0 and start < d, "slice1d({}, {}) expects the slice start to be between 0 and {} got: {}", .{ self, ax, d, s });
+        stdx.debug.assert(end >= 1 and end <= d, "slice1d({}, {}) expects the slice end to be between 1 and {} got: {}", .{ self, ax, d, s });
+        stdx.debug.assert(start < end, "slice1d({}, {}) expects the slice start ({}) to be smaller than the end ({}), got: {}", .{ self, ax, start, end, s });
 
         // If strides weren't set it means original buffer is contiguous.
         // But it won't be anymore after slicing. The strides don't change though.
@@ -230,7 +233,8 @@ pub const HostBuffer = struct {
         return .{
             ._shape = self.shape().set(ax, end - start),
             .data = self.data[offset..],
-            ._strides = _strides,
+            // When axis is 0, we stay contiguous.
+            ._strides = if (ax == 0) self._strides else _strides,
             ._memory = .unmanaged,
         };
     }
diff --git a/zml/module.zig b/zml/module.zig
index 58d4c8b..86d7908 100644
--- a/zml/module.zig
+++ b/zml/module.zig
@@ -243,10 +243,8 @@ pub const CompilationContext = struct {
         }
 
         const loaded_executable = compileModuleToPjrtExecutable(arena, self._platform, module, module_dir) catch |err| {
-            log.err(
-                "pjrt-{s} failed to compile following valid MLIR:\n{}\n{}",
-                .{ @tagName(self._platform.target), module.op().mlirFormatter(.{}), err },
-            );
+            log.err("pjrt-{s} failed to compile: {}", .{ @tagName(self._platform.target), err });
+            if (module_dir) |dir| log.err("mlir can be found at {s}/module.mlir", .{dir});
             return err;
         };
 
@@ -947,8 +945,8 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
 fn setFlag(options: *xla_pb.CompileOptionsProto, comptime flag: [:0]const u8, value: anytype) void {
     const option: xla_pb.OptionOverrideProto = switch (@typeInfo(@TypeOf(value))) {
         .Bool => .{ .value = .{ .bool_field = value } },
-        .Int => .{ .value = .{ .int_field = value } },
-        .Float => .{ .value = .{ .double_field = value } },
+        .ComptimeInt, .Int => .{ .value = .{ .int_field = value } },
+        .ComptimeFloat, .Float => .{ .value = .{ .double_field = value } },
         else => .{ .value = .{ .string_field = .{ .Const = value } } },
     };
     options.env_option_overrides.appendAssumeCapacity(.{ .key = .{ .Const = flag }, .value = option });
diff --git a/zml/tokenizer.zig b/zml/tokenizer.zig
index dbe2046..e62ff2a 100644
--- a/zml/tokenizer.zig
+++ b/zml/tokenizer.zig
@@ -275,7 +275,7 @@ pub const Tokenizer = struct {
         else if (id == self.special_tokens.unk)
             ""
         else if (id > self.tokens.len)
-            std.debug.panic("Unexpected token id: {d}, vocab_size: {d}", .{ id, self.vocab_size })
+            "" // this means we received an invalid id, but we didn't want to panic.
         else
             self.tokens[id];
     }