zml: reduce memory usage of sdpaMemEfficient by using zml.ops.while_ instead of zml.ops.for_, avoiding concatenation of intermediate results.

This commit is contained in:
Tarry Singh 2023-08-14 14:24:11 +00:00
parent 022baf782b
commit 0709b1b32f
8 changed files with 294 additions and 457 deletions

View File

@ -1,86 +0,0 @@
const std = @import("std");
/// Builds the tuple type of arguments for `funcT`, suitable for `@call`.
/// Non-generic functions defer to `std.meta.ArgsTuple`. For generic
/// functions (some parameter types are null/`anytype`), the caller must
/// supply `argsT`, an explicit tuple type whose fields provide the concrete
/// types the signature leaves unspecified.
pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type {
    const params = @typeInfo(funcT).Fn.params;
    if (params.len == 0) {
        // Zero-parameter function: the empty tuple type.
        return @TypeOf(.{});
    }
    if (@typeInfo(funcT).Fn.is_generic == false) {
        return std.meta.ArgsTuple(funcT);
    }
    const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple"));
    var tuple_fields: [params.len]std.builtin.Type.StructField = undefined;
    inline for (params, args, 0..) |param, arg, i| {
        if (param.type == null) {
            // Generic parameter (`anytype`): take the concrete field from `argsT` as-is.
            tuple_fields[i] = arg;
            continue;
        }
        const T = param.type.?;
        // Tuple fields are named "0", "1", ...; `StructField.name` requires a
        // sentinel-terminated string, hence the manual 0-termination below.
        var num_buf: [32]u8 = undefined;
        tuple_fields[i] = .{
            .name = blk: {
                const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{});
                num_buf[s] = 0;
                break :blk num_buf[0..s :0];
            },
            .type = T,
            .default_value = null,
            .is_comptime = false,
            // Zero-sized types have no meaningful alignment.
            .alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0,
        };
    }
    return @Type(.{
        .Struct = .{
            .is_tuple = true,
            .layout = .auto,
            .decls = &.{},
            .fields = &tuple_fields,
        },
    });
}
/// Returns a tuple type containing the fields of `T` in positions
/// `[start, end)`, renumbered so the resulting tuple's fields are "0", "1", ...
pub fn TupleRange(comptime T: type, comptime start: usize, comptime end: usize) type {
    const src_fields = std.meta.fields(T);
    var out_fields: [end - start]std.builtin.Type.StructField = undefined;
    inline for (start..end, 0..) |src_idx, dst_idx| {
        // Copy the source field, then rename it to its new tuple index.
        var field = src_fields[src_idx];
        var name_buf: [32]u8 = undefined;
        const len = std.fmt.formatIntBuf(&name_buf, dst_idx, 10, .lower, .{});
        name_buf[len] = 0;
        field.name = name_buf[0..len :0];
        out_fields[dst_idx] = field;
    }
    return @Type(.{
        .Struct = .{
            .is_tuple = true,
            .layout = .auto,
            .decls = &.{},
            .fields = &out_fields,
        },
    });
}
/// Reflects over `func` and returns a struct type describing its signature
/// (see `FnSignatureX`). `argsT` is only required when `func` is generic.
pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type {
    return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT));
}
/// Describes the signature of `func` when called with arguments of type
/// `argsT`: the function type, the argument tuple type, the full return type,
/// and the return type split into payload / error set when it's an error union.
pub fn FnSignatureX(comptime func: anytype, comptime argsT: type) type {
    return struct {
        pub const FuncT = @TypeOf(func);
        pub const ArgsT = argsT;
        // Resolve the concrete return type by "calling" `func` at comptime
        // with undefined arguments; this also instantiates generic functions.
        pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined)));
        pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) {
            .ErrorUnion => |u| u.payload,
            else => ReturnT,
        };
        // Null when the function cannot fail.
        pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) {
            .ErrorUnion => |u| u.error_set,
            else => null,
        };
    };
}

View File

@ -1,238 +0,0 @@
const std = @import("std");
const xev = @import("xev");
const FnSignature = @import("meta.zig").FnSignature;
const NormalizedTuple = @import("meta.zig").NormalizedTuple;
/// Future type returned by `asyncc(func, ...)` for a non-generic `func`.
pub fn Frame(comptime func: anytype) type {
    return FrameExx(func, FnSignature(func, null));
}
/// Future type returned by `asyncc(func, args)` where `argsT == @TypeOf(args)`;
/// the explicit args type is needed to instantiate generic functions.
pub fn FrameEx(comptime func: anytype, comptime argsT: type) type {
    return FrameExx(func, FnSignature(func, argsT));
}
/// Implementation of the future returned by `asyncc`: a heap-allocated task
/// scheduled on the shared thread pool plus a handle to wait for its result.
pub fn FrameExx(comptime func: anytype, comptime Signature: type) type {
    return struct {
        const Self = @This();
        const Signature_ = Signature;
        const Task = struct {
            // Thread-pool hook; `run` recovers the enclosing Task from it.
            _task: xev.ThreadPool.Task = .{ .callback = &Self.run },
            // Signaled once `result` has been written by the worker thread.
            event: std.Thread.ResetEvent = .{},
            args: Signature.ArgsT,
            result: Signature.ReturnT = undefined,
        };
        _task: *Task,
        /// Runs on a pool thread: calls `func` and publishes the result.
        fn run(task_: *xev.ThreadPool.Task) void {
            const task: *Task = @alignCast(@fieldParentPtr("_task", task_));
            task.result = @call(.auto, func, task.args);
            // Set only after `result` is stored so `wait` observes it.
            task.event.set();
        }
        pub const await_ = wait;
        /// Blocks until the task finishes, frees it, and returns its result.
        /// Must be called exactly once per frame.
        pub fn wait(self: *Self) Signature.ReturnT {
            defer {
                // Allocator access is serialized with the AsyncThread mutex
                // (same discipline as `asyncc`).
                AsyncThread.current.mutex.lock();
                AsyncThread.current.allocator.destroy(self._task);
                AsyncThread.current.mutex.unlock();
            }
            self._task.event.wait();
            // The return operand is evaluated before the defer destroys the
            // task, so the result is copied out safely.
            return self._task.result;
        }
    };
}
/// Schedules `func(args...)` on the shared thread pool and returns a frame
/// whose `wait()` (aka `await_`) yields the result. Fails only if allocating
/// the task fails.
pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) {
    const FrameT = FrameEx(func, @TypeOf(args));
    // The allocator is not assumed thread-safe: serialize with the mutex.
    AsyncThread.current.mutex.lock();
    defer AsyncThread.current.mutex.unlock();
    const task = try AsyncThread.current.allocator.create(FrameT.Task);
    task.* = .{
        .args = args,
    };
    AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&task._task));
    return .{ ._task = task };
}
/// Calls `func` synchronously on the current thread (no pool dispatch).
pub inline fn callBlocking(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT {
    return @call(.auto, func, args);
}
/// Blocks the current thread for `ms` milliseconds.
/// The error union is kept for interface symmetry; it never actually fails.
pub inline fn sleep(ms: u64) !void {
    std.time.sleep(std.time.ns_per_ms * ms);
}
/// Process-wide async runtime context: the allocator and thread pool shared
/// by `asyncc` and frame teardown.
pub const AsyncThread = struct {
    // Singleton, initialized by `main` before any async work starts.
    var current: AsyncThread = undefined;
    allocator: std.mem.Allocator,
    thread_pool: xev.ThreadPool,
    // Guards `allocator`, which is used from multiple threads.
    mutex: std.Thread.Mutex,
    /// Sets up the global context, runs `mainFunc()` (no arguments), and
    /// tears the thread pool down afterwards — also on error return.
    pub fn main(allocator_: std.mem.Allocator, comptime mainFunc: anytype) !void {
        current = .{
            .allocator = allocator_,
            .thread_pool = xev.ThreadPool.init(.{}),
            .mutex = .{},
        };
        defer {
            current.thread_pool.shutdown();
            current.thread_pool.deinit();
        }
        return try mainFunc();
    }
};
/// One-shot cross-thread signal, backed by `std.Thread.ResetEvent`.
/// The error unions on `init`/`notify`/`wait` never actually fail; they are
/// kept for interface symmetry.
pub const Notification = struct {
    inner: std.Thread.ResetEvent,

    pub fn init() !Notification {
        return Notification{ .inner = .{} };
    }

    /// Wakes every current and future waiter.
    pub fn notify(self: *Notification) !void {
        self.inner.set();
    }

    /// Blocks until `notify` (or `deinit`) has been called.
    pub fn wait(self: *Notification) !void {
        self.inner.wait();
    }

    /// Releases any pending waiters, then invalidates the notification.
    pub fn deinit(self: *Notification) void {
        self.inner.set();
        self.* = undefined;
    }
};
/// Wraps stdin in the async-friendly `File`.
/// NOTE(review): returns `!File` while getStdOut/getStdErr return `File`,
/// even though `init` is caught with @panic here too; kept as-is since
/// changing the return type would break callers using `try`.
pub fn getStdIn() !File {
    return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin");
}
/// Wraps stdout in the async-friendly `File`.
pub fn getStdOut() File {
    return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout");
}
/// Wraps stderr in the async-friendly `File`.
pub fn getStdErr() File {
    return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr");
}
/// Thin wrapper around `std.fs.File` exposing the std.io generic
/// Reader/Writer/SeekableStream adapters, with error sets derived from the
/// wrapped methods via `FnSignature`.
pub const File = struct {
    pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?;
    pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?;
    pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read);
    pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write);
    pub const SeekableStream = std.io.SeekableStream(
        File,
        SeekError,
        GetSeekPosError,
        seekTo,
        seekBy,
        getPos,
        getEndPos,
    );
    inner: std.fs.File,
    /// Returns the wrapped std.fs.File.
    fn asFile(self: File) std.fs.File {
        return self.inner;
    }
    /// OS handle of the underlying file.
    pub fn handle(self: File) std.fs.File.Handle {
        return self.inner.handle;
    }
    pub fn init(file_: std.fs.File) !File {
        return .{ .inner = file_ };
    }
    /// Opens `path` relative to the current working directory.
    pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File {
        return init(try std.fs.cwd().openFile(path, flags));
    }
    /// Checks accessibility of `path` relative to the current working directory.
    pub fn access(path: []const u8, flags: std.fs.File.OpenFlags) !void {
        return try std.fs.cwd().access(path, flags);
    }
    pub fn read(self: File, buf: []u8) !usize {
        return try self.inner.read(buf);
    }
    /// Positional read at `offset`; does not move the file cursor.
    pub fn pread(self: File, buf: []u8, offset: u64) !usize {
        return try self.inner.pread(buf, offset);
    }
    pub fn write(self: File, buf: []const u8) !usize {
        return try self.inner.write(buf);
    }
    /// Positional write at `offset`; does not move the file cursor.
    pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize {
        return try self.inner.pwrite(buf, offset);
    }
    pub fn close(self: File) !void {
        return self.inner.close();
    }
    pub fn reader(self: File) Reader {
        return .{ .context = self };
    }
    pub fn seekableStream(file: File) SeekableStream {
        return .{ .context = file };
    }
    pub fn writer(self: File) Writer {
        return .{ .context = self };
    }
    pub fn stat(self: File) !std.fs.File.Stat {
        return try self.inner.stat();
    }
    pub fn seekBy(self: File, offset: i64) !void {
        try self.inner.seekBy(offset);
    }
    pub fn seekTo(self: File, offset: u64) !void {
        try self.inner.seekTo(offset);
    }
    pub fn getPos(self: File) !u64 {
        return try self.inner.getPos();
    }
    pub fn getEndPos(self: File) !u64 {
        return try self.inner.getEndPos();
    }
};
pub const Mutex = std.Thread.Mutex;
/// std.log-compatible handler: formats "<level><scope prefix><message>\n"
/// and writes it to stderr through a buffered writer while holding the
/// stderr lock. Write errors are silently ignored (logging is best-effort).
pub fn logFn(
    comptime message_level: std.log.Level,
    comptime scope: @Type(.EnumLiteral),
    comptime format: []const u8,
    args: anytype,
) void {
    const level_text = comptime message_level.asText();
    const scope_prefix = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): ";
    var buffered = std.io.bufferedWriter(getStdErr().writer());
    const out = buffered.writer();
    std.debug.lockStdErr();
    defer std.debug.unlockStdErr();
    nosuspend {
        out.print(level_text ++ scope_prefix ++ format ++ "\n", args) catch return;
        buffered.flush() catch return;
    }
}

View File

@ -2,7 +2,7 @@
"build_options": [
{
"name": "cmd",
"value": "bazel run @zml//zml:completion"
"value": "bazel run @zml//:completion"
}
]
}

View File

@ -771,10 +771,13 @@ fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void {
var context = LocalContext{ .index = 0, .op = op, .shapes = shapes };
meta.visit((struct {
fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void {
tensor.* = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index));
if (inner_ctx.shapes) |sh| {
tensor._shape = sh[inner_ctx.index];
new._shape = sh[inner_ctx.index];
} else {
new._shape._tags = tensor._shape._tags;
}
tensor.* = new;
inner_ctx.index += 1;
}
}).cb, &context, v);
@ -977,7 +980,7 @@ fn compileInternal(
const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?;
const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value();
if (dir.createFile(file_name, .{ .truncate = true })) |file| {
context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = true });
context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false });
log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name });
} else |_| {
log.warn("Failed to open {s}", .{file_name});
@ -1098,7 +1101,7 @@ pub fn compileFn(
comptime func: anytype,
args: ShapeOf(stdx.meta.FnArgs(func)),
platform: Platform,
) !ExeWithWeights(FnWithVoidArg(func)) {
) !FnExe(func) {
const name = @typeName(@TypeOf(func));
var context = try CompilationContext.init(allocator, name, platform);
defer context.deinit();
@ -1116,6 +1119,10 @@ pub fn compileFn(
return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model);
}
pub fn FnExe(comptime func: anytype) type {
return ExeWithWeights(FnWithVoidArg(func));
}
fn FnWithVoidArg(comptime func: anytype) type {
const fn_info = @typeInfo(@TypeOf(func)).Fn;
const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void };
@ -1163,7 +1170,10 @@ fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module_hash:
fn storePjrtExecutable(arena: std.mem.Allocator, platform: Platform, loaded_executable: *pjrt.LoadedExecutable, module_hash: u64, compilation_cache_location: []const u8) !void {
const resolved_path = try std.fs.cwd().realpathAlloc(arena, compilation_cache_location);
const compilation_cache_dir = try std.fs.openDirAbsolute(resolved_path, .{});
const compilation_cache_dir = std.fs.openDirAbsolute(resolved_path, .{}) catch blk: {
try std.fs.makeDirAbsolute(resolved_path);
break :blk try std.fs.openDirAbsolute(resolved_path, .{});
};
const loaded_executable_file = try compilation_cache_dir.createFile(try std.fmt.allocPrint(arena, "{x}", .{module_hash}), .{});
defer loaded_executable_file.close();

View File

@ -752,7 +752,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) {
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) {
return cuda.sdpa(q, k, v, opts);
}
@ -769,8 +769,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
};
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
k = k.mul(scale_logit.convert(k.dtype()));
const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
k = k.mul(head_scaling.convert(k.dtype()));
var attn_weights = q.dot(k, .{.hd});
// log.debug("attn_weights : {}", .{attn_weights});
@ -787,164 +787,227 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
return attn.transpose(q.shape());
}
pub const MemEfficientOps = struct {
scale: ?f32 = null,
query_chunk_size: u32,
key_chunk_size: u32,
opts: SdpaOpts = .{},
};
pub const SdpaChunks = struct { q_chunk_size: u32, k_chunk_size: u32 };
pub fn sdpaMemEfficient(q_: Tensor, k_: Tensor, v_: Tensor, opts: MemEfficientOps) Tensor {
const q = q_.withTags(.{ .b, .hq, .sq, .hd });
const k = k_.withTags(.{ .b, .hk, .sk, .hd });
const v = v_.withTags(.{ .b, .hk, .sk, .hd });
var sdpa_opts = opts.opts;
if (sdpa_opts.attn_mask) |*attn_mask| attn_mask.* = attn_mask.withTags(.{ .sq, .sk });
pub fn sdpaMemEfficient(
q: Tensor,
k: Tensor,
v: Tensor,
sdpa_opts: SdpaOpts,
chunking: SdpaChunks,
) Tensor {
const sdpa_mem_efficient: SdpaMemEfficient = .{
.q = q,
.k = k,
.v = v,
.sdpa_opts = sdpa_opts,
.chunking = .{
.q_chunk_size = @intCast(@min(q.dim(.q), chunking.q_chunk_size)),
.k_chunk_size = @intCast(@min(k.dim(.k), chunking.k_chunk_size)),
},
};
const sdpa_mem_efficient: SdpaMemEfficient = .{ .q = q, .k = k, .v = v, .opt = .{
.query_chunk_size = @intCast(@min(q.dim(.sq), opts.query_chunk_size)),
.key_chunk_size = @intCast(@min(k.dim(.sk), opts.key_chunk_size)),
.scale = opts.scale,
.opts = sdpa_opts,
} };
// TODO(Corentin): Maybe `withTags` could take a Shape to copy from.
var result = sdpa_mem_efficient.forward();
result._shape = q_.shape();
return result;
return sdpa_mem_efficient.forward();
}
const SdpaMemEfficient = struct {
q: Tensor,
k: Tensor,
v: Tensor,
opt: MemEfficientOps,
sdpa_opts: SdpaOpts,
chunking: SdpaChunks,
fn forward(self: SdpaMemEfficient) Tensor {
const n_q_chunks = @divExact(self.q.dim(.sq), self.opt.query_chunk_size);
const res = ops.for_(SdpaMemEfficient.nextQueriesChunk, self, .{ .nq = n_q_chunks });
// TODO: should "for_" operate on an axis ?
// res: (nq, b, nh, qlen / nq, dim) -> (b, nh, qlen, dim)
return res.transpose(.{ 1, 2, 0, 3, 4 }).flatten(2);
// return res.transpose(.{ .b, .hq, .nq, .sq, .hd }).merge(.{ .nq, .sq }, .sq);
stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.q, self.chunking });
stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.k, self.chunking });
const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size));
const ctx = zml.module.CompilationContext.current();
const q_chunks = ctx._allocator.alloc(zml.Tensor, n_q_chunks) catch unreachable;
defer ctx._allocator.free(q_chunks);
for (0..n_q_chunks) |i| {
const idx: u32 = @intCast(i);
const q_slice: zml.Tensor.DynSlice = .{
.start = Tensor.scalar(idx * self.chunking.q_chunk_size, .i32),
.len = self.chunking.q_chunk_size,
};
const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
var chunk: SdpaMemEfficient = self;
chunk.q = q_chunk;
chunk.sdpa_opts.attn_mask = attn_chunk;
q_chunks[i] = chunk.scanKeyVal();
}
const res = zml.Tensor.concatenate(q_chunks, .q);
return res.transpose(self.q.shape());
}
fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor {
const offset = idx.scale(self.opt.query_chunk_size);
const q_chunk = self.q.dynamicSlice(.{ .sq = .{ .start = offset, .len = self.opt.query_chunk_size } });
const attn_chunk = if (self.opt.opts.attn_mask) |attn_mask| attn_mask.dynamicSlice1d(0, self.opt.query_chunk_size, offset) else null;
const q_slice: zml.Tensor.DynSlice = .{
.start = idx.scale(self.chunking.q_chunk_size),
.len = self.chunking.q_chunk_size,
};
const q_chunk = self.q.dynamicSlice(.{ .q = q_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null;
var chunk: SdpaMemEfficient = self;
chunk.q = q_chunk;
chunk.opt.opts.attn_mask = attn_chunk;
chunk.sdpa_opts.attn_mask = attn_chunk;
return chunk.scanKeyVal();
}
fn scanKeyVal(self: SdpaMemEfficient) Tensor {
const n_chunks = @divExact(self.k.dim(.sk), self.opt.key_chunk_size);
const res = ops.for_(SdpaMemEfficient.nextKeyValChunk, self, .{ .k_chunk = n_chunks });
const global_max = res.max_value.max(.k_chunk).broad(res.max_value.shape());
const max_diffs = res.max_value.sub(global_max).exp();
const attn = res.attn.mul(max_diffs.broad(res.attn.shape())).sum(.k_chunk).squeeze(.k_chunk);
const exp_sum = res.exp_sum.mul(max_diffs.convert(.f32)).sum(.k_chunk).squeeze(.k_chunk).convert(attn.dtype());
return attn.div(exp_sum.broad(self.q.shape()));
const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
return if (n_chunks <= 4) {
// Unrolled version
var partial_softmax: ?PartialSoftmax = null;
for (0..@intCast(n_chunks)) |idx| {
const next = self.nextKeyValChunk(Tensor.scalar(idx, .i32));
partial_softmax = if (partial_softmax) |prev| prev.merge(next) else next;
}
return partial_softmax.?.finalize();
} else {
// stablehlo.while version
const partial_softmax, _ = zml.ops.while_(hasNextKeyValChunk, nextKeyValChunkMerge, self, .{ PartialSoftmax.zeros(self.q.shape(), .f32), Tensor.scalar(0, .i32) });
return partial_softmax.finalize();
};
}
fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialAttn {
const offset = idx.scale(self.opt.key_chunk_size);
const k_chunk = self.k.dynamicSlice(.{ .sk = .{ .start = offset, .len = self.opt.key_chunk_size } });
const v_chunk = self.v.dynamicSlice(.{ .sk = .{ .start = offset, .len = self.opt.key_chunk_size } });
const attn_chunk = if (self.opt.opts.attn_mask) |mask| mask.dynamicSlice1d(1, self.opt.key_chunk_size, offset) else null;
fn nextKeyValChunkMerge(self: SdpaMemEfficient, prev: PartialSoftmax, idx: Tensor) struct { PartialSoftmax, Tensor } {
const next = self.nextKeyValChunk(idx);
return .{ prev.merge(next), idx.addConstant(1) };
}
fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialSoftmax {
const k_slice: zml.Tensor.DynSlice = .{
.start = idx.scale(self.chunking.k_chunk_size),
.len = self.chunking.k_chunk_size,
};
const k_chunk = self.k.dynamicSlice(.{ .k = k_slice });
const v_chunk = self.v.dynamicSlice(.{ .k = k_slice });
const attn_chunk = if (self.sdpa_opts.attn_mask) |mask| mask.dynamicSlice(.{ .k = k_slice }) else null;
return sdpaChunk(self.q, k_chunk, v_chunk, .{ .attn_mask = attn_chunk });
}
pub fn hasNextKeyValChunk(self: SdpaMemEfficient, _: PartialSoftmax, idx: Tensor) zml.Tensor {
const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size);
return idx.cmp(.LT, Tensor.scalar(n_chunks, idx.dtype()));
}
};
pub const PartialAttn = struct {
attn: Tensor,
/// Running state of a numerically-stable streaming softmax: `values` and
/// `exp_sum` are scaled relative to the running `max_value`, so partial
/// results over chunks can be combined with `merge` and resolved with
/// `finalize`.
pub const PartialSoftmax = struct {
    values: Tensor,
    exp_sum: Tensor,
    max_value: Tensor,
    /// Neutral element for `merge`: zero values/exp_sum and the dtype's
    /// minimum value as the running max.
    pub fn zeros(q_shape: Shape, exp_sum_precision: DataType) PartialSoftmax {
        return .{
            .values = Tensor.constant(q_shape, q_shape.dtype().zero()),
            .exp_sum = Tensor.constant(q_shape.setDim(.hd, 1), exp_sum_precision.zero()),
            .max_value = Tensor.constant(q_shape.setDim(.hd, 1), q_shape.dtype().minValue()),
        };
    }
    /// Combines two partial softmax states into one covering both chunks.
    pub fn merge(self: PartialSoftmax, other: PartialSoftmax) PartialSoftmax {
        // Rescale self and other using the new global_max.
        const global_max = self.max_value.maximum(other.max_value);
        const new_self = self.rescale(global_max);
        const new_other = other.rescale(global_max);
        // Now that self and other are using the same scale, we can just add them:
        return .{
            .max_value = global_max,
            .values = new_self.values.add(new_other.values),
            .exp_sum = new_self.exp_sum.add(new_other.exp_sum),
        };
    }
    /// Update max_value and rescale attn and exp_sum accordingly.
    pub fn rescale(self: PartialSoftmax, max_value: Tensor) PartialSoftmax {
        const max_diff_exp = self.max_value.sub(max_value).exp();
        const sum_dtype = self.exp_sum.dtype();
        return .{
            .max_value = max_value,
            .values = self.values.mul(max_diff_exp.broad(self.values.shape())),
            .exp_sum = self.exp_sum.mul(max_diff_exp.convert(sum_dtype)),
        };
    }
    /// Divides the intermediary results by the exp_sum to get the proper attention values.
    pub fn finalize(self: PartialSoftmax) Tensor {
        return self.values.div(self.exp_sum.broad(self.values.shape()).convert(self.values.dtype()));
    }
};
/// Compute softmax over a chunk.
/// Returns intermediary results to allow aggregating later.
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialAttn {
pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax {
const a = self.axis(axis);
const max_val = self.max(a);
const out = self.sub(max_val.broad(self.shape())).exp();
return .{
.attn = out,
.exp_sum = out.convert(.f32).sum(a).squeeze(a),
.max_value = max_val.squeeze(a),
.values = out,
.exp_sum = out.convert(.f32).sum(a),
.max_value = max_val,
};
}
/// Compute sdpa on a chunk, and computes a partial softmax.
/// q: (B, H, Sq, H_dim) k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk)
fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn {
// const bs, const num_head, const sk, const h_dim = q.dims[0..4];
// TODO: rewrite using modern ZML
pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoftmax {
// this is a dupe of sdpa, but return the PartialSoftmax instead of true Attn.
// Consider implementing sdpa from sdpaChunk.
var q, var k, var v = .{ q_, k_, v_ };
// If we have more query heads (hq) than key heads (hk), repeat keys.
const k_rep, const v_rep = if (q.dim(.hq) != k.dim(.hk)) blk: {
const num_rep: u63 = @intCast(@divExact(q.dim(.hq), k.dim(.hk)));
break :blk .{ k.repeat1d(0, num_rep).rename(.{ .hk = .hq }), v.repeat1d(0, num_rep).rename(.{ .hk = .hq }) };
} else .{ k.rename(.{ .hk = .hq }), v.rename(.{ .hk = .hq }) };
const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! ";
const err_args = .{ q, k, v, opts.attn_mask };
stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args);
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
var qk = q.dot(k_rep, .{.hd});
if (q.dim(.h) != k.dim(.h)) {
stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args);
// Note: we don't try to repeat queries.
// Repeating keys is the interesting optimisation cause it reduces KV cache memory usage.
const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h)));
k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) };
}
const attn_mask = if (opts.attn_mask) |m| m else null;
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(q.dim(.hd))));
qk = qk.scale(sqrtHeadDim);
const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch {
stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args);
};
const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd)));
const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype());
k = k.mul(head_scaling.convert(k.dtype()));
std.debug.assert(qk.rank() == q.rank());
if (opts.attn_mask) |mask| {
qk = qk.add(mask.broad(qk.shape()));
var attn_weights = q.dot(k, .{.hd});
// log.debug("attn_weights : {}", .{attn_weights});
// log.debug("attn_mask : {?}", .{attn_mask});
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape()));
if (opts.bias) |bias| {
attn_weights = attn_weights.add(bias);
}
const partial = partialSoftmax(qk, -1);
const attn = partial.attn.dot(v_rep, .{.sk});
const partial = partialSoftmax(attn_weights, .k);
const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
return .{
.attn = attn,
.exp_sum = partial.exp_sum,
.max_value = partial.max_value,
.values = attn,
// The renaming is because the above dot projected values.k into .hd,
// do the same thing on the other tensors.
// This work because dot is a linear operation, and commutes with `PartialSoftmax.finalize`
.exp_sum = partial.exp_sum.rename(.{ .k = .hd }).transpose(attn.shape()),
.max_value = partial.max_value.rename(.{ .k = .hd }).transpose(attn.shape()),
};
}
test "sdpaMemEfficient without mask" {
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
const q = rng.call(undefined);
const k = rng.call(undefined);
const v = rng.call(undefined);
const ref_res = try zml.testing.compileAndCallWithTensors(platform, sdpa, .{
q.shape().withTags(.{ .b, .h, .q, .hd }),
k.shape().withTags(.{ .b, .h, .k, .hd }),
v.shape().withTags(.{ .b, .h, .k, .hd }),
.{ .attn_mask = null, .scale = null, .bias = null },
}, .{ q, k, v, undefined });
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
const opts: zml.ShapeOf(MemEfficientOps) = .{ .query_chunk_size = 256, .key_chunk_size = 128, .opts = .{ .attn_mask = null, .scale = null, .bias = null } };
const res = try zml.testing.compileAndCallWithTensors(
platform,
sdpaMemEfficient,
.{ q.shape(), k.shape(), v.shape(), opts },
.{ q, k, v, undefined },
);
try zml.testing.expectClose(ref_res, res, 2e-3);
}
test "sdpaMemEfficient with mask" {
test sdpaMemEfficient {
const platform = zml.testing.env();
const allocator = std.testing.allocator;
@ -957,22 +1020,107 @@ test "sdpaMemEfficient with mask" {
defer rng_mask.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
const q = rng.call(undefined);
const k = rng.call(undefined);
const v = rng.call(undefined);
const mask = rng_mask.call(undefined);
const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd });
const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd });
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
const ref_res = try zml.testing.compileAndCall(platform, sdpa, .{ q.withTags(.{ .b, .h, .q, .hd }), k.withTags(.{ .b, .h, .k, .hd }), v.withTags(.{ .b, .h, .k, .hd }), .{ .attn_mask = mask.withTags(.{ .q, .k }), .scale = null, .bias = null } });
const ref_res = try zml.testing.compileAndCall(
platform,
sdpa,
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
);
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
{
// 4 k_chunks
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null, .bias = null },
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 2e-3);
}
{
// 16 k_chunks
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null, .bias = null },
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
},
);
try zml.testing.expectClose(ref_res, res, 2e-3);
}
}
test "sdpaMemEfficient transposed" {
const platform = zml.testing.env();
const allocator = std.testing.allocator;
// Note we use small input vectors to have the tests run reasonably fast,
// but don't expect speed ups with this small sizes.
const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 512, 10, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng.deinit();
const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform);
defer rng_mask.deinit();
// Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable.
const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd });
const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd });
const mask = rng_mask.call(undefined).withTags(.{ .q, .k });
const ref_res = try zml.testing.compileAndCall(
platform,
sdpa,
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
);
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
const res = try zml.testing.compileAndCall(platform, sdpaMemEfficient, .{ q, k, v, .{
.query_chunk_size = 256,
.key_chunk_size = 128,
.scale = null,
.opts = .{ .attn_mask = mask, .scale = null, .bias = null, .allow_cudnn = false },
} });
{
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null, .bias = null },
.{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 2e-3);
try zml.testing.expectClose(ref_res, res, 1e-3);
}
{
const res = try zml.testing.compileAndCall(
platform,
sdpaMemEfficient,
.{
q,
k,
v,
.{ .attn_mask = mask, .scale = null, .bias = null },
.{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
},
);
try zml.testing.expectClose(ref_res, res, 1e-3);
}
}
/// Options controlling generation. The default values correspond to greedy decoding.

View File

@ -12,16 +12,18 @@ const DataType = @import("../dtype.zig").DataType;
const Data = @import("../dtype.zig").Data;
const CompilationContext = module.CompilationContext;
pub fn canUseCudnnSdpa(head_dim: i64, dtype: DataType) bool {
pub fn canUseCudnnSdpa(q_shape: Shape) bool {
const ctx = CompilationContext.current();
// TODO(Corendos): Check cuda version, cudnn version, device compatibility.
if (!ctx.targetIs(.cuda)) return false;
if (q_shape.rank() != 4) return false;
// NOTE(Corentin): In Cudnn fused MHA head_dim is limited to 128.
if (head_dim > 128) return false;
if (q_shape.dim(.hd) > 128) return false;
// NOTE(Corentin): In Cudnn fused MHA data type is limited to F16 and BF16.
if (dtype != .f16 and dtype != .bf16) return false;
if (q_shape.dtype() != .f16 and q_shape.dtype() != .bf16) return false;
return true;
}

View File

@ -2677,7 +2677,7 @@ pub const Tensor = struct {
return ops.reduce(
ArgMaxRes.cmp,
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x._shape.withDtype(index_dtype), &.{a}) },
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x.shape(), &.{a}) },
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) },
&.{a},
);
@ -3171,7 +3171,7 @@ pub const Tensor = struct {
var prev_ax: i8 = -1;
for (self._shape.tags(), 0..) |t, self_ax| {
if (update._shape.hasTag(t)) |up_ax| {
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape });
stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_, self });
update_shape._dims.set(self_ax, update.dim(up_ax));
prev_ax = up_ax;
@ -3182,7 +3182,7 @@ pub const Tensor = struct {
update = update.reshape(update_shape);
}
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() });
stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {}", .{ self, update });
for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| {
const t = self._shape.debugTag(ax);

View File

@ -17,6 +17,7 @@ pub fn env() zml.Platform {
.{
.cache_location = "/tmp/zml/tests/cache",
.xla_dump_to = "/tmp/zml/tests/",
.sharding_enabled = true,
}
else
.{};
@ -197,7 +198,7 @@ pub fn testLayerOut(
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
}
const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform);
const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform);
const n_out_exp = activations.countLayers(out_name);
if (exe.inner.result_buffer_count != n_out_exp) {