diff --git a/async/meta.zig b/async/meta.zig deleted file mode 100644 index b0f3e56..0000000 --- a/async/meta.zig +++ /dev/null @@ -1,86 +0,0 @@ -const std = @import("std"); - -pub fn ArgsTuple(comptime funcT: anytype, comptime argsT: ?type) type { - const params = @typeInfo(funcT).Fn.params; - if (params.len == 0) { - return @TypeOf(.{}); - } - - if (@typeInfo(funcT).Fn.is_generic == false) { - return std.meta.ArgsTuple(funcT); - } - - const args = std.meta.fields(argsT orelse @compileError("generic function requires an explicit ArgsTuple")); - var tuple_fields: [params.len]std.builtin.Type.StructField = undefined; - inline for (params, args, 0..) |param, arg, i| { - if (param.type == null) { - tuple_fields[i] = arg; - continue; - } - const T = param.type.?; - var num_buf: [32]u8 = undefined; - tuple_fields[i] = .{ - .name = blk: { - const s = std.fmt.formatIntBuf(&num_buf, i, 10, .lower, .{}); - num_buf[s] = 0; - break :blk num_buf[0..s :0]; - }, - .type = T, - .default_value = null, - .is_comptime = false, - .alignment = if (@sizeOf(T) > 0) @alignOf(T) else 0, - }; - } - - return @Type(.{ - .Struct = .{ - .is_tuple = true, - .layout = .auto, - .decls = &.{}, - .fields = &tuple_fields, - }, - }); -} - -pub fn TupleRange(comptime T: type, comptime start: usize, comptime end: usize) type { - const fields = std.meta.fields(T); - var new_fields: [end - start]std.builtin.Type.StructField = undefined; - inline for (start..end, 0..) |i, j| { - var new_field = fields[i]; - var num_buf: [32]u8 = undefined; - new_field.name = blk: { - const s = std.fmt.formatIntBuf(&num_buf, j, 10, .lower, .{}); - num_buf[s] = 0; - break :blk num_buf[0..s :0]; - }; - new_fields[j] = new_field; - } - return @Type(.{ - .Struct = .{ - .is_tuple = true, - .layout = .auto, - .decls = &.{}, - .fields = &new_fields, - }, - }); -} - -pub fn FnSignature(comptime func: anytype, comptime argsT: ?type) type { - return FnSignatureX(func, ArgsTuple(@TypeOf(func), argsT)); -} - -pub fn FnSignatureX(comptime func: anytype, comptime argsT: type) type { - return struct { - pub const FuncT = @TypeOf(func); - pub const ArgsT = argsT; - pub const ReturnT = @TypeOf(@call(.auto, func, @as(ArgsT, undefined))); - pub const ReturnPayloadT = switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.payload, - else => ReturnT, - }; - pub const ReturnErrorSet: ?type = switch (@typeInfo(ReturnT)) { - .ErrorUnion => |u| u.error_set, - else => null, - }; - }; -} diff --git a/async/threaded.zig b/async/threaded.zig deleted file mode 100644 index 5858de6..0000000 --- a/async/threaded.zig +++ /dev/null @@ -1,238 +0,0 @@ -const std = @import("std"); -const xev = @import("xev"); - -const FnSignature = @import("meta.zig").FnSignature; -const NormalizedTuple = @import("meta.zig").NormalizedTuple; - -pub fn Frame(comptime func: anytype) type { - const Signature = FnSignature(func, null); - return FrameExx(func, Signature); -} - -pub fn FrameEx(comptime func: anytype, comptime argsT: type) type { - const Signature = FnSignature(func, argsT); - return FrameExx(func, Signature); -} - -pub fn FrameExx(comptime func: anytype, comptime Signature: type) type { - return struct { - const Self = @This(); - const Signature_ = Signature; - const Task = struct { - _task: xev.ThreadPool.Task = .{ .callback = &Self.run }, - event: std.Thread.ResetEvent = .{}, - args: Signature.ArgsT, - result: Signature.ReturnT = undefined, - }; - - _task: *Task, - - fn run(task_: *xev.ThreadPool.Task) void { - const task: *Task = @alignCast(@fieldParentPtr("_task", task_)); - task.result = @call(.auto, func, task.args); - task.event.set(); - } - - pub const await_ = wait; - pub fn wait(self: *Self) Signature.ReturnT { - defer { - AsyncThread.current.mutex.lock(); - AsyncThread.current.allocator.destroy(self._task); - AsyncThread.current.mutex.unlock(); - } - self._task.event.wait(); - return self._task.result; - } - }; -} - -pub fn asyncc(comptime func: anytype, args: anytype) !FrameEx(func, @TypeOf(args)) { - const FrameT = FrameEx(func, @TypeOf(args)); - - AsyncThread.current.mutex.lock(); - defer AsyncThread.current.mutex.unlock(); - - const task = try AsyncThread.current.allocator.create(FrameT.Task); - task.* = .{ - .args = args, - }; - - AsyncThread.current.thread_pool.schedule(xev.ThreadPool.Batch.from(&task._task)); - return .{ ._task = task }; -} - -pub inline fn callBlocking(comptime func: anytype, args: anytype) FnSignature(func, @TypeOf(args)).ReturnT { - return @call(.auto, func, args); -} - -pub inline fn sleep(ms: u64) !void { - std.time.sleep(ms * std.time.ns_per_ms); -} - -pub const AsyncThread = struct { - var current: AsyncThread = undefined; - - allocator: std.mem.Allocator, - thread_pool: xev.ThreadPool, - mutex: std.Thread.Mutex, - - pub fn main(allocator_: std.mem.Allocator, comptime mainFunc: anytype) !void { - current = .{ - .allocator = allocator_, - .thread_pool = xev.ThreadPool.init(.{}), - .mutex = .{}, - }; - - defer { - current.thread_pool.shutdown(); - current.thread_pool.deinit(); - } - - return try mainFunc(); - } -}; - -pub const Notification = struct { - inner: std.Thread.ResetEvent, - - pub fn init() !Notification { - return .{ .inner = .{} }; - } - - pub fn notify(self: *Notification) !void { - self.inner.set(); - } - - pub fn wait(self: *Notification) !void { - self.inner.wait(); - } - - pub fn deinit(self: *Notification) void { - self.inner.set(); - self.* = undefined; - } -}; - -pub fn getStdIn() !File { - return File.init(std.io.getStdIn()) catch @panic("Unable to open stdin"); -} - -pub fn getStdOut() File { - return File.init(std.io.getStdOut()) catch @panic("Unable to open stdout"); -} - -pub fn getStdErr() File { - return File.init(std.io.getStdErr()) catch @panic("Unable to open stderr"); -} - -pub const File = struct { - pub const SeekError = FnSignature(File.seekTo, null).ReturnErrorSet.? || FnSignature(File.seekBy, null).ReturnErrorSet.?; - pub const GetSeekPosError = SeekError || FnSignature(File.stat, null).ReturnErrorSet.?; - pub const Reader = std.io.GenericReader(File, FnSignature(File.read, null).ReturnErrorSet.?, File.read); - pub const Writer = std.io.GenericWriter(File, FnSignature(File.write, null).ReturnErrorSet.?, File.write); - pub const SeekableStream = std.io.SeekableStream( - File, - SeekError, - GetSeekPosError, - seekTo, - seekBy, - getPos, - getEndPos, - ); - - inner: std.fs.File, - - fn asFile(self: File) std.fs.File { - return self.inner; - } - - pub fn handle(self: File) std.fs.File.Handle { - return self.inner.handle; - } - - pub fn init(file_: std.fs.File) !File { - return .{ .inner = file_ }; - } - - pub fn open(path: []const u8, flags: std.fs.File.OpenFlags) !File { - return init(try std.fs.cwd().openFile(path, flags)); - } - - pub fn access(path: []const u8, flags: std.fs.File.OpenFlags) !void { - return try std.fs.cwd().access(path, flags); - } - - pub fn read(self: File, buf: []u8) !usize { - return try self.inner.read(buf); - } - - pub fn pread(self: File, buf: []u8, offset: u64) !usize { - return try self.inner.pread(buf, offset); - } - - pub fn write(self: File, buf: []const u8) !usize { - return try self.inner.write(buf); - } - - pub fn pwrite(self: File, buf: []const u8, offset: u64) !usize { - return try self.inner.pwrite(buf, offset); - } - - pub fn close(self: File) !void { - return self.inner.close(); - } - - pub fn reader(self: File) Reader { - return .{ .context = self }; - } - - pub fn seekableStream(file: File) SeekableStream { - return .{ .context = file }; - } - - pub fn writer(self: File) Writer { - return .{ .context = self }; - } - - pub fn stat(self: File) !std.fs.File.Stat { - return try self.inner.stat(); - } - - pub fn seekBy(self: File, offset: i64) !void { - try self.inner.seekBy(offset); - } - - pub fn seekTo(self: File, offset: u64) !void { - try self.inner.seekTo(offset); - } - - pub fn getPos(self: File) !u64 { - return try self.inner.getPos(); - } - - pub fn getEndPos(self: File) !u64 { - return try self.inner.getEndPos(); - } -}; - -pub const Mutex = std.Thread.Mutex; - -pub fn logFn( - comptime message_level: std.log.Level, - comptime scope: @Type(.EnumLiteral), - comptime format: []const u8, - args: anytype, -) void { - const level_txt = comptime message_level.asText(); - const prefix2 = if (scope == .default) ": " else "(" ++ @tagName(scope) ++ "): "; - const stderr = getStdErr().writer(); - var bw = std.io.bufferedWriter(stderr); - const writer = bw.writer(); - - std.debug.lockStdErr(); - defer std.debug.unlockStdErr(); - nosuspend { - writer.print(level_txt ++ prefix2 ++ format ++ "\n", args) catch return; - bw.flush() catch return; - } -} diff --git a/zls.build.json b/zls.build.json index a426832..796644a 100644 --- a/zls.build.json +++ b/zls.build.json @@ -2,7 +2,7 @@ "build_options": [ { "name": "cmd", - "value": "bazel run @zml//zml:completion" + "value": "bazel run @zml//:completion" } ] } diff --git a/zml/module.zig b/zml/module.zig index c8cfc86..08fd89d 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -771,10 +771,13 @@ fn assignResults(op: mlir.Operation, v: anytype, shapes: []Shape) void { var context = LocalContext{ .index = 0, .op = op, .shapes = shapes }; meta.visit((struct { fn cb(inner_ctx: *LocalContext, tensor: *Tensor) void { - tensor.* = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); + var new = Tensor.fromMlirValue(inner_ctx.op.result(inner_ctx.index)); if (inner_ctx.shapes) |sh| { - tensor._shape = sh[inner_ctx.index]; + new._shape = sh[inner_ctx.index]; + } else { + new._shape._tags = tensor._shape._tags; } + tensor.* = new; inner_ctx.index += 1; } }).cb, &context, v); @@ -977,7 +980,7 @@ fn compileInternal( const name_attr = context._module.op().getAttributeByName("sym_name").?.as(mlir.StringAttribute).?; const file_name = std.fmt.allocPrint(arena, "{s}.mlir", .{name_attr.value()}) catch name_attr.value(); if (dir.createFile(file_name, .{ .truncate = true })) |file| { - context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = true }); + context._module.op().print(file.writer(), .{ .debug_info = true, .debug_info_pretty_form = false }); log.info("Wrote MLIR to {s}/{s}", .{ xla_dump_to, file_name }); } else |_| { log.warn("Failed to open {s}", .{file_name}); @@ -1098,7 +1101,7 @@ pub fn compileFn( comptime func: anytype, args: ShapeOf(stdx.meta.FnArgs(func)), platform: Platform, -) !ExeWithWeights(FnWithVoidArg(func)) { +) !FnExe(func) { const name = @typeName(@TypeOf(func)); var context = try CompilationContext.init(allocator, name, platform); defer context.deinit(); @@ -1116,6 +1119,10 @@ pub fn compileFn( return try ExeWithWeights(FnWithVoidArg(func)).initFromModel(allocator, raw_module, void_model); } +pub fn FnExe(comptime func: anytype) type { + return ExeWithWeights(FnWithVoidArg(func)); +} + fn FnWithVoidArg(comptime func: anytype) type { const fn_info = @typeInfo(@TypeOf(func)).Fn; const void_param = std.builtin.Type.Fn.Param{ .is_generic = false, .is_noalias = false, .type = void }; @@ -1163,7 +1170,10 @@ fn loadPjrtExecutable(arena: std.mem.Allocator, platform: Platform, module_hash: fn storePjrtExecutable(arena: std.mem.Allocator, platform: Platform, loaded_executable: *pjrt.LoadedExecutable, module_hash: u64, compilation_cache_location: []const u8) !void { const resolved_path = try std.fs.cwd().realpathAlloc(arena, compilation_cache_location); - const compilation_cache_dir = try std.fs.openDirAbsolute(resolved_path, .{}); + const compilation_cache_dir = std.fs.openDirAbsolute(resolved_path, .{}) catch blk: { + try std.fs.makeDirAbsolute(resolved_path); + break :blk try std.fs.openDirAbsolute(resolved_path, .{}); + }; const loaded_executable_file = try compilation_cache_dir.createFile(try std.fmt.allocPrint(arena, "{x}", .{module_hash}), .{}); defer loaded_executable_file.close(); diff --git a/zml/nn.zig b/zml/nn.zig index d66ec58..86f4bc2 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -752,7 +752,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); - if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.dim(.hd), q.dtype())) { + if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) { return cuda.sdpa(q, k, v, opts); } @@ -769,8 +769,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args); }; const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); - const scale_logit = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); - k = k.mul(scale_logit.convert(k.dtype())); + const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); + k = k.mul(head_scaling.convert(k.dtype())); var attn_weights = q.dot(k, .{.hd}); // log.debug("attn_weights : {}", .{attn_weights}); @@ -787,164 +787,227 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { return attn.transpose(q.shape()); } -pub const MemEfficientOps = struct { - scale: ?f32 = null, - query_chunk_size: u32, - key_chunk_size: u32, - opts: SdpaOpts = .{}, -}; +pub const SdpaChunks = struct { q_chunk_size: u32, k_chunk_size: u32 }; -pub fn sdpaMemEfficient(q_: Tensor, k_: Tensor, v_: Tensor, opts: MemEfficientOps) Tensor { - const q = q_.withTags(.{ .b, .hq, .sq, .hd }); - const k = k_.withTags(.{ .b, .hk, .sk, .hd }); - const v = v_.withTags(.{ .b, .hk, .sk, .hd }); - var sdpa_opts = opts.opts; - if (sdpa_opts.attn_mask) |*attn_mask| attn_mask.* = attn_mask.withTags(.{ .sq, .sk }); +pub fn sdpaMemEfficient( + q: Tensor, + k: Tensor, + v: Tensor, + sdpa_opts: SdpaOpts, + chunking: SdpaChunks, +) Tensor { + const sdpa_mem_efficient: SdpaMemEfficient = .{ + .q = q, + .k = k, + .v = v, + .sdpa_opts = sdpa_opts, + .chunking = .{ + .q_chunk_size = @intCast(@min(q.dim(.q), chunking.q_chunk_size)), + .k_chunk_size = @intCast(@min(k.dim(.k), chunking.k_chunk_size)), + }, + }; - const sdpa_mem_efficient: SdpaMemEfficient = .{ .q = q, .k = k, .v = v, .opt = .{ - .query_chunk_size = @intCast(@min(q.dim(.sq), opts.query_chunk_size)), - .key_chunk_size = @intCast(@min(k.dim(.sk), opts.key_chunk_size)), - .scale = opts.scale, - .opts = sdpa_opts, - } }; - - // TODO(Corentin): Maybe `withTags` could take a Shape to copy from. - var result = sdpa_mem_efficient.forward(); - result._shape = q_.shape(); - return result; + return sdpa_mem_efficient.forward(); } const SdpaMemEfficient = struct { q: Tensor, k: Tensor, v: Tensor, - opt: MemEfficientOps, + sdpa_opts: SdpaOpts, + chunking: SdpaChunks, fn forward(self: SdpaMemEfficient) Tensor { - const n_q_chunks = @divExact(self.q.dim(.sq), self.opt.query_chunk_size); - const res = ops.for_(SdpaMemEfficient.nextQueriesChunk, self, .{ .nq = n_q_chunks }); - // TODO: should "for_" operate on an axis ? - // res: (nq, b, nh, qlen / nq, dim) -> (b, nh, qlen, dim) - return res.transpose(.{ 1, 2, 0, 3, 4 }).flatten(2); - // return res.transpose(.{ .b, .hq, .nq, .sq, .hd }).merge(.{ .nq, .sq }, .sq); + stdx.debug.assert(@mod(self.q.dim(.q), self.chunking.q_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.q, self.chunking }); + stdx.debug.assert(@mod(self.k.dim(.k), self.chunking.k_chunk_size) == 0, "sdpaMemEfficient expects the chunk_size to exactly divise the seq_len, got: sdpaMemEfficient({}, {})", .{ self.k, self.chunking }); + const n_q_chunks: u32 = @intCast(@divExact(self.q.dim(.q), self.chunking.q_chunk_size)); + + const ctx = zml.module.CompilationContext.current(); + const q_chunks = ctx._allocator.alloc(zml.Tensor, n_q_chunks) catch unreachable; + defer ctx._allocator.free(q_chunks); + for (0..n_q_chunks) |i| { + const idx: u32 = @intCast(i); + const q_slice: zml.Tensor.DynSlice = .{ + .start = Tensor.scalar(idx * self.chunking.q_chunk_size, .i32), + .len = self.chunking.q_chunk_size, + }; + const q_chunk = self.q.dynamicSlice(.{ .q = q_slice }); + const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null; + + var chunk: SdpaMemEfficient = self; + chunk.q = q_chunk; + chunk.sdpa_opts.attn_mask = attn_chunk; + q_chunks[i] = chunk.scanKeyVal(); + } + + const res = zml.Tensor.concatenate(q_chunks, .q); + return res.transpose(self.q.shape()); } fn nextQueriesChunk(self: SdpaMemEfficient, idx: Tensor) Tensor { - const offset = idx.scale(self.opt.query_chunk_size); - const q_chunk = self.q.dynamicSlice(.{ .sq = .{ .start = offset, .len = self.opt.query_chunk_size } }); - const attn_chunk = if (self.opt.opts.attn_mask) |attn_mask| attn_mask.dynamicSlice1d(0, self.opt.query_chunk_size, offset) else null; + const q_slice: zml.Tensor.DynSlice = .{ + .start = idx.scale(self.chunking.q_chunk_size), + .len = self.chunking.q_chunk_size, + }; + const q_chunk = self.q.dynamicSlice(.{ .q = q_slice }); + const attn_chunk = if (self.sdpa_opts.attn_mask) |attn_mask| attn_mask.dynamicSlice(.{ .q = q_slice }) else null; var chunk: SdpaMemEfficient = self; chunk.q = q_chunk; - chunk.opt.opts.attn_mask = attn_chunk; + chunk.sdpa_opts.attn_mask = attn_chunk; return chunk.scanKeyVal(); } fn scanKeyVal(self: SdpaMemEfficient) Tensor { - const n_chunks = @divExact(self.k.dim(.sk), self.opt.key_chunk_size); - const res = ops.for_(SdpaMemEfficient.nextKeyValChunk, self, .{ .k_chunk = n_chunks }); - const global_max = res.max_value.max(.k_chunk).broad(res.max_value.shape()); - const max_diffs = res.max_value.sub(global_max).exp(); - const attn = res.attn.mul(max_diffs.broad(res.attn.shape())).sum(.k_chunk).squeeze(.k_chunk); - const exp_sum = res.exp_sum.mul(max_diffs.convert(.f32)).sum(.k_chunk).squeeze(.k_chunk).convert(attn.dtype()); - return attn.div(exp_sum.broad(self.q.shape())); + const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size); + return if (n_chunks <= 4) { + // Unrolled version + var partial_softmax: ?PartialSoftmax = null; + for (0..@intCast(n_chunks)) |idx| { + const next = self.nextKeyValChunk(Tensor.scalar(idx, .i32)); + partial_softmax = if (partial_softmax) |prev| prev.merge(next) else next; + } + return partial_softmax.?.finalize(); + } else { + // stablehlo.while version + const partial_softmax, _ = zml.ops.while_(hasNextKeyValChunk, nextKeyValChunkMerge, self, .{ PartialSoftmax.zeros(self.q.shape(), .f32), Tensor.scalar(0, .i32) }); + return partial_softmax.finalize(); + }; } - fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialAttn { - const offset = idx.scale(self.opt.key_chunk_size); - const k_chunk = self.k.dynamicSlice(.{ .sk = .{ .start = offset, .len = self.opt.key_chunk_size } }); - const v_chunk = self.v.dynamicSlice(.{ .sk = .{ .start = offset, .len = self.opt.key_chunk_size } }); - const attn_chunk = if (self.opt.opts.attn_mask) |mask| mask.dynamicSlice1d(1, self.opt.key_chunk_size, offset) else null; + fn nextKeyValChunkMerge(self: SdpaMemEfficient, prev: PartialSoftmax, idx: Tensor) struct { PartialSoftmax, Tensor } { + const next = self.nextKeyValChunk(idx); + return .{ prev.merge(next), idx.addConstant(1) }; + } + + fn nextKeyValChunk(self: SdpaMemEfficient, idx: Tensor) PartialSoftmax { + const k_slice: zml.Tensor.DynSlice = .{ + .start = idx.scale(self.chunking.k_chunk_size), + .len = self.chunking.k_chunk_size, + }; + + const k_chunk = self.k.dynamicSlice(.{ .k = k_slice }); + const v_chunk = self.v.dynamicSlice(.{ .k = k_slice }); + const attn_chunk = if (self.sdpa_opts.attn_mask) |mask| mask.dynamicSlice(.{ .k = k_slice }) else null; return sdpaChunk(self.q, k_chunk, v_chunk, .{ .attn_mask = attn_chunk }); } + + pub fn hasNextKeyValChunk(self: SdpaMemEfficient, _: PartialSoftmax, idx: Tensor) zml.Tensor { + const n_chunks = @divExact(self.k.dim(.k), self.chunking.k_chunk_size); + return idx.cmp(.LT, Tensor.scalar(n_chunks, idx.dtype())); + } }; -pub const PartialAttn = struct { - attn: Tensor, +pub const PartialSoftmax = struct { + values: Tensor, exp_sum: Tensor, max_value: Tensor, + + pub fn zeros(q_shape: Shape, exp_sum_precision: DataType) PartialSoftmax { + return .{ + .values = Tensor.constant(q_shape, q_shape.dtype().zero()), + .exp_sum = Tensor.constant(q_shape.setDim(.hd, 1), exp_sum_precision.zero()), + .max_value = Tensor.constant(q_shape.setDim(.hd, 1), q_shape.dtype().minValue()), + }; + } + + pub fn merge(self: PartialSoftmax, other: PartialSoftmax) PartialSoftmax { + // Rescale self and other using the new global_max. + const global_max = self.max_value.maximum(other.max_value); + const new_self = self.rescale(global_max); + const new_other = other.rescale(global_max); + + // Now that self and other are using the same scale, we can just add them: + return .{ + .max_value = global_max, + .values = new_self.values.add(new_other.values), + .exp_sum = new_self.exp_sum.add(new_other.exp_sum), + }; + } + + /// Update max_value and rescale attn and exp_sum accordingly. + pub fn rescale(self: PartialSoftmax, max_value: Tensor) PartialSoftmax { + const max_diff_exp = self.max_value.sub(max_value).exp(); + const sum_dtype = self.exp_sum.dtype(); + return .{ + .max_value = max_value, + .values = self.values.mul(max_diff_exp.broad(self.values.shape())), + .exp_sum = self.exp_sum.mul(max_diff_exp.convert(sum_dtype)), + }; + } + + /// Divides the intermediary results by the exp_sum to get the proper attention values. + pub fn finalize(self: PartialSoftmax) Tensor { + return self.values.div(self.exp_sum.broad(self.values.shape()).convert(self.values.dtype())); + } }; /// Compute softmax over a chunk. /// Returns intermediary results to allow aggregating later. -pub fn partialSoftmax(self: Tensor, axis: anytype) PartialAttn { +pub fn partialSoftmax(self: Tensor, axis: anytype) PartialSoftmax { const a = self.axis(axis); const max_val = self.max(a); const out = self.sub(max_val.broad(self.shape())).exp(); return .{ - .attn = out, - .exp_sum = out.convert(.f32).sum(a).squeeze(a), - .max_value = max_val.squeeze(a), + .values = out, + .exp_sum = out.convert(.f32).sum(a), + .max_value = max_val, }; } /// Compute sdpa on a chunk, and computes a partial softmax. /// q: (B, H, Sq, H_dim) ⊙ k: (B, H, Sk, H_dim) -> qk: (B, H, Sq, Sk) -fn sdpaChunk(q: Tensor, k: Tensor, v: Tensor, opts: SdpaOpts) PartialAttn { - // const bs, const num_head, const sk, const h_dim = q.dims[0..4]; - // TODO: rewrite using modern ZML +pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoftmax { + // this is a dupe of sdpa, but return the PartialSoftmax instead of true Attn. + // Consider implementing sdpa from sdpaChunk. + var q, var k, var v = .{ q_, k_, v_ }; - // If we have more query heads (hq) than key heads (hk), repeat keys. - const k_rep, const v_rep = if (q.dim(.hq) != k.dim(.hk)) blk: { - const num_rep: u63 = @intCast(@divExact(q.dim(.hq), k.dim(.hk))); - break :blk .{ k.repeat1d(0, num_rep).rename(.{ .hk = .hq }), v.repeat1d(0, num_rep).rename(.{ .hk = .hq }) }; - } else .{ k.rename(.{ .hk = .hq }), v.rename(.{ .hk = .hq }) }; + const err_template = "sdpa(q: {}, k: {}, v: {}, attn: {?}) is invalid ! "; + const err_args = .{ q, k, v, opts.attn_mask }; + stdx.debug.assert(q.shape().hasTags(.{ .h, .q, .hd }), err_template ++ "q is missing tags {{.h, .q, .hd}}", err_args); + stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args); + stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args); - var qk = q.dot(k_rep, .{.hd}); + if (q.dim(.h) != k.dim(.h)) { + stdx.debug.assert(@mod(q.dim(.h), k.dim(.h)) == 0, err_template ++ "Different number of heads for keys and queries, but can't repeat keys.", err_args); + // Note: we don't try to repeat queries. + // Repeating keys is the interesting optimisation cause it reduces KV cache memory usage. + const num_rep: u63 = @intCast(@divExact(q.dim(.h), k.dim(.h))); + k, v = .{ k.repeat1d(.h, num_rep), v.repeat1d(.h, num_rep) }; + } + const attn_mask = if (opts.attn_mask) |m| m else null; - const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(q.dim(.hd)))); - qk = qk.scale(sqrtHeadDim); + const dims = helpers.collectDims(.{ .h, .q, .k, .hd }, &.{ q, k, v, attn_mask }, .strict) catch { + stdx.debug.panic(err_template ++ "Inputs have incompatible shapes.", err_args); + }; + const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(dims.hd))); + const head_scaling = if (opts.scale) |s| s else Tensor.scalar(sqrtHeadDim, k.dtype()); + k = k.mul(head_scaling.convert(k.dtype())); - std.debug.assert(qk.rank() == q.rank()); - if (opts.attn_mask) |mask| { - qk = qk.add(mask.broad(qk.shape())); + var attn_weights = q.dot(k, .{.hd}); + // log.debug("attn_weights : {}", .{attn_weights}); + // log.debug("attn_mask : {?}", .{attn_mask}); + if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape())); + + if (opts.bias) |bias| { + attn_weights = attn_weights.add(bias); } - const partial = partialSoftmax(qk, -1); - const attn = partial.attn.dot(v_rep, .{.sk}); + const partial = partialSoftmax(attn_weights, .k); + const attn = partial.values.dot(v, .{.k}).transpose(q.shape()); return .{ - .attn = attn, - .exp_sum = partial.exp_sum, - .max_value = partial.max_value, + .values = attn, + // The renaming is because the above dot projected values.k into .hd, + // do the same thing on the other tensors. + // This work because dot is a linear operation, and commutes with `PartialSoftmax.finalize` + .exp_sum = partial.exp_sum.rename(.{ .k = .hd }).transpose(attn.shape()), + .max_value = partial.max_value.rename(.{ .k = .hd }).transpose(attn.shape()), }; } -test "sdpaMemEfficient without mask" { - const platform = zml.testing.env(); - const allocator = std.testing.allocator; - - // Note we use small input vectors to have the tests run reasonably fast, - // but don't expect speed ups with this small sizes. - const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 10, 512, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); - defer rng.deinit(); - - // Note: it's fine to pass undefined here, cause the arguments have already been baked into the executable. - const q = rng.call(undefined); - const k = rng.call(undefined); - const v = rng.call(undefined); - - const ref_res = try zml.testing.compileAndCallWithTensors(platform, sdpa, .{ - q.shape().withTags(.{ .b, .h, .q, .hd }), - k.shape().withTags(.{ .b, .h, .k, .hd }), - v.shape().withTags(.{ .b, .h, .k, .hd }), - .{ .attn_mask = null, .scale = null, .bias = null }, - }, .{ q, k, v, undefined }); - try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); - - const opts: zml.ShapeOf(MemEfficientOps) = .{ .query_chunk_size = 256, .key_chunk_size = 128, .opts = .{ .attn_mask = null, .scale = null, .bias = null } }; - const res = try zml.testing.compileAndCallWithTensors( - platform, - sdpaMemEfficient, - .{ q.shape(), k.shape(), v.shape(), opts }, - .{ q, k, v, undefined }, - ); - - try zml.testing.expectClose(ref_res, res, 2e-3); -} - -test "sdpaMemEfficient with mask" { +test sdpaMemEfficient { const platform = zml.testing.env(); const allocator = std.testing.allocator; @@ -957,22 +1020,107 @@ test "sdpaMemEfficient with mask" { defer rng_mask.deinit(); // Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable. - const q = rng.call(undefined); - const k = rng.call(undefined); - const v = rng.call(undefined); - const mask = rng_mask.call(undefined); + const q = rng.call(undefined).withTags(.{ .b, .h, .q, .hd }); + const k = rng.call(undefined).withTags(.{ .b, .h, .k, .hd }); + const v = rng.call(undefined).withTags(.{ .b, .h, .k, .hd }); + const mask = rng_mask.call(undefined).withTags(.{ .q, .k }); - const ref_res = try zml.testing.compileAndCall(platform, sdpa, .{ q.withTags(.{ .b, .h, .q, .hd }), k.withTags(.{ .b, .h, .k, .hd }), v.withTags(.{ .b, .h, .k, .hd }), .{ .attn_mask = mask.withTags(.{ .q, .k }), .scale = null, .bias = null } }); + const ref_res = try zml.testing.compileAndCall( + platform, + sdpa, + .{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } }, + ); + try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); + { + // 4 k_chunks + const res = try zml.testing.compileAndCall( + platform, + sdpaMemEfficient, + .{ + q, + k, + v, + .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) }, + }, + ); + + try zml.testing.expectClose(ref_res, res, 2e-3); + } + { + // 16 k_chunks + const res = try zml.testing.compileAndCall( + platform, + sdpaMemEfficient, + .{ + q, + k, + v, + .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) }, + }, + ); + + try zml.testing.expectClose(ref_res, res, 2e-3); + } +} + +test "sdpaMemEfficient transposed" { + const platform = zml.testing.env(); + const allocator = std.testing.allocator; + + // Note we use small input vectors to have the tests run reasonably fast, + // but don't expect speed ups with this small sizes. + const rng = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 1, 512, 10, 64 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); + defer rng.deinit(); + + const rng_mask = try zml.compileFn(allocator, Tensor.Rng.normal, .{ Shape.init(.{ 512, 512 }, .f32), .{ .mean = 0, .stddev = 1 } }, platform); + defer rng_mask.deinit(); + + // Note: it's fine to pass undefined here, cause the arguments have already been backed into the executable. + const q = rng.call(undefined).withTags(.{ .b, .q, .h, .hd }); + const k = rng.call(undefined).withTags(.{ .b, .k, .h, .hd }); + const v = rng.call(undefined).withTags(.{ .b, .k, .h, .hd }); + const mask = rng_mask.call(undefined).withTags(.{ .q, .k }); + + const ref_res = try zml.testing.compileAndCall( + platform, + sdpa, + .{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } }, + ); try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); - const res = try zml.testing.compileAndCall(platform, sdpaMemEfficient, .{ q, k, v, .{ - .query_chunk_size = 256, - .key_chunk_size = 128, - .scale = null, - .opts = .{ .attn_mask = mask, .scale = null, .bias = null, .allow_cudnn = false }, - } }); + { + const res = try zml.testing.compileAndCall( + platform, + sdpaMemEfficient, + .{ + q, + k, + v, + .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) }, + }, + ); - try zml.testing.expectClose(ref_res, res, 2e-3); + try zml.testing.expectClose(ref_res, res, 1e-3); + } + + { + const res = try zml.testing.compileAndCall( + platform, + sdpaMemEfficient, + .{ + q, + k, + v, + .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) }, + }, + ); + + try zml.testing.expectClose(ref_res, res, 1e-3); + } } /// Options controlling generation. The default values correspond to greedy decoding. diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 0bffc5f..f00ca76 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -12,16 +12,18 @@ const DataType = @import("../dtype.zig").DataType; const Data = @import("../dtype.zig").Data; const CompilationContext = module.CompilationContext; -pub fn canUseCudnnSdpa(head_dim: i64, dtype: DataType) bool { +pub fn canUseCudnnSdpa(q_shape: Shape) bool { const ctx = CompilationContext.current(); // TODO(Corendos): Check cuda version, cudnn version, device compatibility. if (!ctx.targetIs(.cuda)) return false; + if (q_shape.rank() != 4) return false; + // NOTE(Corentin): In Cudnn fused MHA head_dim is limited to 128. - if (head_dim > 128) return false; + if (q_shape.dim(.hd) > 128) return false; // NOTE(Corentin): In Cudnn fused MHA data type is limited to F16 and BF16. - if (dtype != .f16 and dtype != .bf16) return false; + if (q_shape.dtype() != .f16 and q_shape.dtype() != .bf16) return false; return true; } diff --git a/zml/tensor.zig b/zml/tensor.zig index 352cd8b..27ff65d 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -2677,7 +2677,7 @@ pub const Tensor = struct { return ops.reduce( ArgMaxRes.cmp, - .{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x._shape.withDtype(index_dtype), &.{a}) }, + .{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x.shape(), &.{a}) }, .{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) }, &.{a}, ); @@ -3171,7 +3171,7 @@ pub const Tensor = struct { var prev_ax: i8 = -1; for (self._shape.tags(), 0..) |t, self_ax| { if (update._shape.hasTag(t)) |up_ax| { - stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_._shape, self._shape }); + stdx.debug.assert(up_ax == prev_ax + 1, "dynamicUpdateSlice expects 'update_' and input tensor axis to have the same order, got {} and {}. (hint: you need to explicitly transpose 'update_')", .{ update_, self }); update_shape._dims.set(self_ax, update.dim(up_ax)); prev_ax = up_ax; @@ -3182,7 +3182,7 @@ pub const Tensor = struct { update = update.reshape(update_shape); } - stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {} (hint: it's probably an issue on our side)", .{ self.rank(), update.rank() }); + stdx.debug.assert(self.rank() == update.rank(), "dynamicUpdateSlice expects input and computed update tensors to have the same rank, got {} and {}", .{ self, update }); for (self.dims(), update.dims(), 0..) |self_d, up_d, ax| { const t = self._shape.debugTag(ax); diff --git a/zml/testing.zig b/zml/testing.zig index ba7710a..d4a2bef 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -17,6 +17,7 @@ pub fn env() zml.Platform { .{ .cache_location = "/tmp/zml/tests/cache", .xla_dump_to = "/tmp/zml/tests/", + .sharding_enabled = true, } else .{}; @@ -197,7 +198,7 @@ pub fn testLayerOut( log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in }); } - const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform); + const exe = try zml.compileModel(alloc, fwd, layer, input_shapes, platform); const n_out_exp = activations.countLayers(out_name); if (exe.inner.result_buffer_count != n_out_exp) {