const std = @import("std");

const Context = @import("../context.zig").Context;
const module = @import("../module.zig");
const mlir = @import("../mlir.zig");
const dialect = @import("mlir/dialects");
const Tensor = @import("../tensor.zig").Tensor;
const Shape = @import("../shape.zig").Shape;
const SdpaOpts = @import("../nn.zig").SdpaOpts;
const DataType = @import("../dtype.zig").DataType;
const Data = @import("../dtype.zig").Data;
const CompilationContext = module.CompilationContext;
/// Reports whether the cuDNN fused-MHA custom call can handle an attention
/// whose query tensor has shape `q_shape`.
///
/// All of the following must hold:
/// - the compilation target is CUDA,
/// - the query is rank 4,
/// - the head dimension (`.hd`) is at most 128 (cuDNN fused MHA limit),
/// - the element type is f16 or bf16 (the only types cuDNN fused MHA accepts).
pub fn canUseCudnnSdpa(q_shape: Shape) bool {
    const ctx = CompilationContext.current();

    // TODO(Corendos): Check cuda version, cudnn version, device compatibility.
    const dtype = q_shape.dtype();
    return ctx.target() == .cuda and
        q_shape.rank() == 4 and
        q_shape.dim(.hd) <= 128 and
        (dtype == .f16 or dtype == .bf16);
}
/// Maps a `DataType` to the XLA `element_type` string used inside the
/// cuDNN fMHA backend-config JSON. Panics on any type other than f16/bf16,
/// which are the only ones the fused kernel supports.
fn elementTypeFromDataType(dtype: DataType) [:0]const u8 {
    switch (dtype) {
        .f16 => return "F16",
        .bf16 => return "BF16",
        else => @panic("Unsupported DataType"),
    }
}
/// Scaled dot-product attention lowered to the cuDNN fused-MHA XLA custom
/// call (`__cudnn$fmhaScaleBiasSoftmax`).
///
/// Inputs are transposed to the (.b, .h, q/k, .hd) layout the kernel expects,
/// and the result is transposed back to `q_`'s original axis order before
/// returning. Callers should have validated the inputs with
/// `canUseCudnnSdpa` first (rank 4, hd <= 128, f16/bf16 — see that function).
pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
    const ctx = CompilationContext.current();

    // Kernel-expected layouts: batch, heads, sequence, head_dim.
    const q = q_.transpose(.{ .b, .h, .q, .hd });
    var k = k_.transpose(.{ .b, .h, .k, .hd });
    const v = v_.transpose(.{ .b, .h, .k, .hd });

    // NOTE(review): despite the name, this is the *inverse* sqrt of the head
    // dimension, i.e. the default attention scale 1/sqrt(hd).
    const sqrtHeadDim: f32 = 1.0 / std.math.sqrt(@as(f32, @floatFromInt(q.dim(.hd))));
    // If the caller supplies an explicit scale, it is applied to `k` below and
    // the kernel's built-in fmha_scale is neutralized to 1.0; otherwise the
    // kernel applies the default 1/sqrt(hd).
    const scale: f32 = if (opts.scale) |_| 1.0 else sqrtHeadDim;

    if (opts.scale) |s| k = k.mul(s);

    // The backend-config JSON is rendered into a fixed 4 KiB stack buffer;
    // `catch unreachable` asserts the rendered config always fits.
    var buffer: [4096]u8 = undefined;
    var fba = std.heap.FixedBufferAllocator.init(&buffer);
    const allocator = fba.allocator();

    // Serialized CudnnfMHABackendConfig consumed by XLA's GPU backend.
    // `{{`/`}}` are escaped braces; the runtime placeholders are the scale,
    // the element type, and the b/h/q/k dims of the intermediate
    // (q·kᵀ) tensor.
    const backend_config =
        std.fmt.allocPrintZ(
        allocator,
        \\{{
        \\  "operation_queue_id":"0",
        \\  "wait_on_operation_queues":[],
        \\  "cudnn_fmha_backend_config": {{
        \\    "algorithm": {{
        \\      "algo_id": "0",
        \\      "math_type": "TENSOR_OP_MATH",
        \\      "tuning_knobs": {{"17": "1", "24": "0"}},
        \\      "is_cudnn_frontend": true,
        \\      "workspace_size": "0",
        \\    }},
        \\    "fmha_scale": {d},
        \\    "dropout_rate": 0,
        \\    "intermediate_tensor_shape": {{
        \\      "element_type": "{s}",
        \\      "dimensions": ["{}", "{}", "{}", "{}"],
        \\      "tuple_shapes": [],
        \\      "layout": {{
        \\        "dim_level_types": [],
        \\        "dim_unique": [],
        \\        "dim_ordered": [],
        \\        "minor_to_major": ["3", "2", "1", "0"],
        \\        "tiles": [],
        \\        "element_size_in_bits": "0",
        \\        "memory_space": "0",
        \\        "index_primitive_type": "PRIMITIVE_TYPE_INVALID",
        \\        "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
        \\        "dynamic_shape_metadata_prefix_bytes": "0",
        \\      }},
        \\      "is_dynamic_dimension": [false, false, false, false],
        \\    }},
        \\    "seed": 0,
        \\    "is_flash_attention": true,
        \\    "mask_type": "NO_MASK",
        \\    "bmm1_dot_dimension_numbers": {{
        \\      "lhs_contracting_dimensions": ["3"],
        \\      "rhs_contracting_dimensions": ["3"],
        \\      "lhs_batch_dimensions": ["0", "1"],
        \\      "rhs_batch_dimensions": ["0", "1"],
        \\    }},
        \\    "bmm2_dot_dimension_numbers": {{
        \\      "lhs_contracting_dimensions": ["3"],
        \\      "rhs_contracting_dimensions": ["2"],
        \\      "lhs_batch_dimensions": ["0", "1"],
        \\      "rhs_batch_dimensions": ["0", "1"],
        \\    }}
        \\  }}
        \\}}
    ,
        .{
            scale,
            elementTypeFromDataType(q.dtype()),
            q.dim(.b),
            q.dim(.h),
            q.dim(.q),
            k.dim(.k),
        },
    ) catch unreachable;

    // Additive bias passed as the fourth operand: starts as all zeros and,
    // when an attention mask is provided, has it broadcast-added in.
    var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));

    if (opts.attn_mask) |attn_mask| {
        bias = bias.add(attn_mask.broad(bias.shape()));
    }

    const mlir_ctx = ctx.mlirCtx();
    const loc = mlir_ctx.location(@src());
    const op = dialect.stablehlo.custom_call(
        mlir_ctx,
        // Operand order expected by the fMHA call: Q, K, V, bias.
        &.{ q.value(), k.value(), v.value(), bias.value() },
        .{
            .call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
            .backend_config = backend_config,
            .api_version = 2,
            .has_side_effect = false,
            .output_operand_aliases = &.{},
        },
        &.{
            // First result: attention output, same shape as q.
            mlir.ext.mlirType(mlir_ctx, q.shape()),
            // Second result: a 0-element u8 tensor — presumably the kernel's
            // scratch/workspace output, which is unused here. TODO confirm.
            mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(),
        },
        loc,
    );

    // Only result(0) is consumed; transpose back to the caller's layout.
    const result = Tensor._result(q.shape(), op.result(0)).transpose(q_.shape());

    return result;
}