Radix/zml/quantization.zig

102 lines
4.7 KiB
Zig

const std = @import("std");
const Allocator = std.mem.Allocator;
const zml = @import("zml.zig");
const module = zml.module;
// TODO add tests, use modern zml
/// Returns a module type that dequantizes a GGUF Q4_0-encoded byte buffer into a tensor of `dtype`.
///
/// The quantized buffer is a sequence of fixed-size blocks. Each block holds an f16 scale followed by
/// 32 packed 4-bit values. Value `v` of a block is reconstructed as `(nibble(v) - 8) * scale`.
pub fn Q4_0(comptime dtype: zml.DataType) type {
    return struct {
        const Self = @This();
        const QuantType = zml.io.gguf.quants.QuantType.q4_0;

        /// Raw quantized bytes. `compile` asserts this is a rank-1 `u8` tensor.
        quant_buffer: zml.Tensor,

        /// Bytes holding the f16 scale at the start of each block.
        const scale_bytes = @sizeOf(f16);
        /// Bytes holding the packed 4-bit values of each block (two values per byte).
        const packed_bytes = 16;
        /// Each block is composed of a f16 scale and 32 4-bit values (18 bytes in total).
        const block_stride = scale_bytes + packed_bytes;
        /// Number of dequantized values produced by one block.
        const values_per_block = 2 * packed_bytes;

        /// Compiles `forward` for the given quantized `input` buffer.
        /// The result of `forward` is reshaped to `shape`.
        ///
        /// `input` must be a rank-1 `u8` tensor whose length is a multiple of the block size.
        /// Any error raised by `module.compile` is propagated to the caller.
        pub fn compile(
            allocator: Allocator,
            ctx: *zml.Context,
            input: zml.Tensor,
            shape: zml.Shape,
        ) !module.CompiledModule(Self.forward) {
            std.debug.assert(input.dtype() == .u8);
            std.debug.assert(input.rank() == 1);
            // `forward` uses @divExact on the buffer length, so a partial trailing block is invalid.
            std.debug.assert(@rem(input.dim(0), block_stride) == 0);
            // Propagate compilation failures (e.g. OOM, backend errors) instead of asserting they
            // cannot happen: `catch unreachable` here would be undefined behavior in release builds.
            return module.compile(
                allocator,
                ctx,
                Self.forward,
                Self{ .quant_buffer = input },
                .{shape},
            );
        }

        /// Dequantizes `self.quant_buffer`, converts the values to `dtype` and reshapes them to `shape`.
        /// Every block's 32 values are multiplied by that block's scale.
        pub fn forward(self: Self, shape: zml.Shape) zml.Tensor {
            const input = self.quant_buffer;
            const block_count: u63 = @intCast(@divExact(input.dim(0), block_stride));
            const scales = extractScales(block_count, input);
            const weights = extractWeights(block_count, input);
            // Each scale covers one row (block) of the weights.
            return scales.broadcast(weights.shape(), &.{0})
                .mul(weights)
                .convert(dtype)
                .reshape(shape);
        }

        /// Gathers the first `scale_bytes` bytes of every block, reinterprets them as f16 and converts them to f32.
        /// The result holds one scale per block.
        pub fn extractScales(block_count: u63, input: zml.Tensor) zml.Tensor {
            // Build the list of byte offsets to gather from the input.
            // NOTE(review): offsets are i32, which presumably limits the buffer to < 2^31 bytes — confirm.
            // in_block_offsets: offsets of the scale bytes inside a block, repeated block_count times.
            const in_block_offsets = zml.Tensor.arange(.{ .start = 0, .end = scale_bytes }, .i32)
                .repeat1d(0, block_count);
            // block_starts: start offset of each block, repeated for each scale byte.
            const block_starts = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32)
                .reshape(.{ block_count, 1 })
                .broadcastLeft(zml.Shape.init(.{ block_count, scale_bytes }, .i32))
                .reshape(.{scale_bytes * block_count});
            // Absolute offsets of every scale byte in the input.
            const indices = in_block_offsets.add(block_starts);
            // Gather the bytes, group them per block, bitcast each pair to f16, then widen to f32.
            return input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true })
                .reshape(.{ block_count, scale_bytes })
                .bitCast(.f16)
                .convert(.f32);
        }

        /// Extracts the packed 4-bit values of every block as signed f32 values in [-8, 7].
        /// The result has shape `{ block_count, values_per_block }`.
        ///
        /// Within a block, the low nibbles of the packed bytes come first, then the high nibbles.
        pub fn extractWeights(block_count: u63, input: zml.Tensor) zml.Tensor {
            // Build the list of byte offsets to gather: everything after the scale in each block.
            // in_block_offsets: offsets of the packed bytes inside a block, repeated block_count times.
            const in_block_offsets = zml.Tensor.arange(.{ .start = scale_bytes, .end = block_stride }, .i32)
                .repeat1d(0, block_count);
            // block_starts: start offset of each block, repeated for each packed byte.
            const block_starts = zml.Tensor.arange(.{ .start = 0, .end = block_stride * block_count, .step = block_stride }, .i32)
                .reshape(.{ block_count, 1 })
                .broadcastLeft(zml.Shape.init(.{ block_count, packed_bytes }, .i32))
                .reshape(.{packed_bytes * block_count});
            // Absolute offsets of every packed byte in the input.
            const indices = in_block_offsets.add(block_starts);
            // NOTE(Corendos): i4 is not supported by bitcast convert, so we need the following workaround.
            // Gather the packed bytes, then split each byte into its low and high nibble.
            const quantized_weights = input.gather_(&.{0}, &.{indices}, .{ .indices_are_sorted = true });
            const low_nibbles = quantized_weights
                .logical(.And, zml.Tensor.constant(.{packed_bytes * block_count}, zml.Data.init(.u8, 0xf)))
                .bitCast(.i8);
            const high_nibbles = quantized_weights
                .shiftRightLogical(zml.Tensor.constant(.{packed_bytes * block_count}, zml.Data.init(.u8, 4)))
                .bitCast(.i8);
            // Lay out low nibbles then high nibbles per block, and re-center from [0, 15] to [-8, 7].
            return zml.Tensor.concatenate(
                &.{
                    low_nibbles.reshape(.{ block_count, packed_bytes }),
                    high_nibbles.reshape(.{ block_count, packed_bytes }),
                },
                1,
            )
                .sub(zml.Tensor.constant(.{ block_count, values_per_block }, zml.Data.init(.i8, 8)))
                .convert(.f32);
        }
    };
}