diff --git a/zml/dtype.zig b/zml/dtype.zig index cfc13cf..ce030ac 100644 --- a/zml/dtype.zig +++ b/zml/dtype.zig @@ -10,6 +10,7 @@ test { pub const DataType = enum(u8) { bool, + // Note: support for the float8 types is a bit spotty; f8e4m3b11fnuz seems to be the best supported one on CUDA. f8e4m3b11fnuz, f8e4m3fn, f8e4m3fnuz, diff --git a/zml/floats.zig b/zml/floats.zig index ef09fa0..230ae97 100644 --- a/zml/floats.zig +++ b/zml/floats.zig @@ -1,3 +1,4 @@ +//! Conversion utilities between different floating point formats. const std = @import("std"); test { @@ -9,12 +10,15 @@ fn allBitsOne(v: anytype) bool { } fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) type { - return packed struct(std.meta.Int(.unsigned, @intCast(sign_bits + exponent_bits + mantissa_bits))) { + const bit_size = sign_bits + exponent_bits + mantissa_bits; + if (bit_size % 8 != 0) @compileError("FloatType should have a number of bits divisible by 8"); + + return packed struct(std.meta.Int(.unsigned, bit_size)) { const Self = @This(); - mantissa: std.meta.Int(.unsigned, @intCast(mantissa_bits)), - exponent: std.meta.Int(.unsigned, @intCast(exponent_bits)), - sign: std.meta.Int(.unsigned, @intCast(sign_bits)), + mantissa: std.meta.Int(.unsigned, mantissa_bits), + exponent: std.meta.Int(.unsigned, exponent_bits), + sign: std.meta.Int(.unsigned, sign_bits), pub fn zero() Self { return .{ @@ -35,34 +39,51 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) /// Lossy conversion from f32, similar to @floatCast pub fn fromF32(f: f32) Self { const vf32: Float32 = @bitCast(f); - const precision_loss = @bitSizeOf(@TypeOf(vf32.mantissa)) - mantissa_bits; + const exp_bias = comptime Self.expBias(); + const exponent = @as(u16, vf32.exponent) + exp_bias -| Float32.expBias(); + const overflow = exponent > std.math.maxInt(std.meta.Int(.unsigned, exponent_bits)); + if (overflow) { + return if (@hasDecl(Self, "inf")) { + return if (vf32.sign == 0) 
Self.inf() else Self.minusInf(); + } else Self.nan(); + } return .{ .sign = vf32.sign, - .exponent = @intCast(vf32.exponent), - .mantissa = shr(vf32.mantissa, precision_loss), + .exponent = @intCast(exponent), + .mantissa = truncMantissa(vf32.mantissa), }; } /// Lossless conversion to f32. pub fn toF32(self: Self) f32 { var vf32: Float32 = undefined; - const precision_loss = @bitSizeOf(@TypeOf(vf32.mantissa)) - mantissa_bits; + if (@hasDecl(Self, "isInf") and self.isInf()) { + return if (self.sign == 0) std.math.inf(f32) else -std.math.inf(f32); + } vf32 = .{ .sign = self.sign, - .exponent = self.exponent, - .mantissa = @shlExact(@as(@TypeOf(vf32.mantissa), self.mantissa), precision_loss), + .exponent = if (self.exponent == 0) 0 else @intCast(@as(i16, self.exponent) + Float32.expBias() - Self.expBias()), + .mantissa = self.f32Mantissa(), }; return @bitCast(vf32); } - fn truncMantissa(T: type, x: anytype) T { - const off = @bitSizeOf(@TypeOf(x)) - @bitSizeOf(T); + fn truncMantissa(x: anytype) std.meta.FieldType(Self, .mantissa) { + @setRuntimeSafety(false); + const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits; return @intCast(x >> off); } - fn shr(x: anytype, comptime off: u8) std.meta.Int(.unsigned, @bitSizeOf(@TypeOf(x)) - off) { - // @setRuntimeSafety(false); - return @intCast(x >> off); + fn f32Mantissa(self: Self) std.meta.FieldType(Float32, .mantissa) { + @setRuntimeSafety(false); + const f32_mantissa_bits = @bitSizeOf(std.meta.FieldType(Float32, .mantissa)); + + const Res = std.meta.FieldType(Float32, .mantissa); + return @shlExact(@as(Res, self.mantissa), f32_mantissa_bits - mantissa_bits); + } + + fn expBias() u8 { + return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1)); } pub fn format( @@ -72,7 +93,11 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) writer: anytype, ) !void { _ = options; - try writer.print("{" ++ fmt ++ "}", .{self.toF32()}); + if (fmt.len == 1 and fmt[0] == '_') { + try writer.print("{{ 
.sign={}, .exp={}, .mantissa={} }}", .{ self.sign, self.exponent, self.mantissa }); + } else { + try writer.print("{" ++ fmt ++ "}", .{self.toF32()}); + } } pub usingnamespace innerT; @@ -124,6 +149,22 @@ pub const Float8E4M3FNUZ = FloatType(1, 4, 3, struct { } }); +test "Float8E4" { + const test_case_e4: TestCase = .{ + .lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 }, + .lossy = &[_]f32{3.02344107628}, + }; + + inline for (.{ + Float8E4M3B11FNUZ, + Float8E4M3FN, + Float8E4M3FNUZ, + }) |Float8T| { + try testCustomFloat(Float8T, test_case_e4); + try std.testing.expectEqual(0.0, Float8T.fromF32(1.0 / 128.0).toF32()); + } +} + pub const Float8E5M2 = FloatType(1, 5, 2, struct { pub fn nan() Float8E5M2 { return .{ @@ -172,6 +213,16 @@ pub const Float8E5M2FNUZ = FloatType(1, 5, 2, struct { } }); +test "Float8E5" { + const test_case_e5: TestCase = .{ + .lossless = &[_]f32{ 0, 1.0, -2, 1.0 / 128.0, -128 }, + .lossy = &[_]f32{3.02344107628}, + }; + inline for (.{ Float8E5M2, Float8E5M2FNUZ }) |Float8T| { + try testCustomFloat(Float8T, test_case_e5); + } +} + pub const BFloat16 = FloatType(1, 8, 7, struct { pub fn nan() BFloat16 { return .{ @@ -215,18 +266,10 @@ test BFloat16 { try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff }); try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32))); - const lossless = [_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) }; - for (&lossless) |v| { - try std.testing.expectEqual(v, BFloat16.fromF32(v).toF32()); - } - const lossy = [_]f32{3.02344107628}; - for (&lossy) |x| { - const y = BFloat16.fromF32(x).toF32(); - if (!std.math.approxEqRel(f32, x, y, 1e-2)) { - std.log.err("expected ~{d}, got {d}", .{ x, y }); - return error.TestUnexpectedResult; - } - } + try testCustomFloat(BFloat16, .{ + .lossless = &[_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) }, + .lossy = &[_]f32{3.02344107628}, + }); } pub fn floatCast(T: type, x: anytype) T { @@ -235,3 +278,25 @@ pub fn 
floatCast(T: type, x: anytype) T { else => @floatCast(x.toF32()), }; } + +const TestCase = struct { + lossless: []const f32, + lossy: []const f32, + tolerance: f32 = 1e-2, +}; + +fn testCustomFloat(FloatT: type, test_case: TestCase) !void { + for (test_case.lossless) |x| { + try std.testing.expectEqual(x, FloatT.fromF32(x).toF32()); + } + for (test_case.lossy) |x| { + try expectApproxEqRel(f32, x, FloatT.fromF32(x).toF32(), test_case.tolerance); + } +} + +fn expectApproxEqRel(FloatT: type, x: FloatT, y: FloatT, tolerance: FloatT) !void { + if (!std.math.approxEqRel(f32, x, y, tolerance)) { + std.log.err("expected ~{d}, got {d}", .{ x, y }); + return error.TestUnexpectedResult; + } +} diff --git a/zml/mlir.zig b/zml/mlir.zig index 1b0c0a5..38cb104 100644 --- a/zml/mlir.zig +++ b/zml/mlir.zig @@ -18,7 +18,7 @@ pub const ext = struct { return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?; } - pub fn denseElementAttrType(dt: dtype.DataType) mlir.DenseElementsAttributeTypes { + pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes { return switch (dt) { .bool => .bool, .i8 => .i8, @@ -33,7 +33,7 @@ pub const ext = struct { .f16 => .f16, .f32 => .f32, .f64 => .f64, - inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)), + else => null, }; } diff --git a/zml/nn.zig b/zml/nn.zig index 86f4bc2..5b33a50 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -775,7 +775,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { var attn_weights = q.dot(k, .{.hd}); // log.debug("attn_weights : {}", .{attn_weights}); // log.debug("attn_mask : {?}", .{attn_mask}); - if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape())); + if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); attn_weights = attn_weights.convert(.f32); if (opts.bias) |bias| { @@ -988,7 +988,7 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, 
v_: Tensor, opts: SdpaOpts) PartialSoft var attn_weights = q.dot(k, .{.hd}); // log.debug("attn_weights : {}", .{attn_weights}); // log.debug("attn_mask : {?}", .{attn_mask}); - if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape())); + if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); if (opts.bias) |bias| { attn_weights = attn_weights.add(bias); diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index f00ca76..456b52d 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -116,16 +116,16 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0)); if (opts.attn_mask) |attn_mask| { - const mask = attn_mask.withTags(.{ .q, .k }).broad(bias.shape()); - bias = bias.add(mask); + bias = bias.add(attn_mask.broad(bias.shape())); } if (opts.bias) |b| { bias = bias.add(b); } - const loc = ctx.mlirCtx().location(@src()); + const mlir_ctx = ctx.mlirCtx(); + const loc = mlir_ctx.location(@src()); const op = dialect.stablehlo.custom_call( - ctx.mlirCtx(), + mlir_ctx, &.{ q.value(), k.value(), v.value(), bias.value() }, .{ .call_target_name = "__cudnn$fmhaScaleBiasSoftmax", @@ -135,8 +135,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { .output_operand_aliases = &.{}, }, &.{ - mlir.ext.mlirType(ctx.mlirCtx(), q.shape()), - mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(ctx.mlirCtx()).as(mlir.Type).?).as(mlir.Type).?, + mlir.ext.mlirType(mlir_ctx, q.shape()), + mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(), }, loc, ); diff --git a/zml/tensor.zig b/zml/tensor.zig index 27ff65d..584dd12 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -205,6 +205,7 @@ pub const Tensor = struct { pub fn value(self: Tensor) mlir.Value { return 
self.getContext().getValueAndDonation(self)[0]; } + /// Tell PJRT compiler that memory should be reused between the two tensors. /// The compiler is already aggressively reusing tensors for intermediate results, /// but this API allows to reuse buffer between input and output arguments @@ -1806,13 +1807,20 @@ const singleton_sh = Shape.init(.{}, val.dtype()); const ctx = CompilationContext.current().mlirCtx(); const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val }); - const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); - const elem_type = mlir.ext.denseElementAttrType(val.dtype()); - var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc); + const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); + + var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type| + dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc) + else blk: { + // Not all dtypes can be serialized in the IR. If that's not possible, use f32. + const val_f32 = val.as(f32); + break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc); + }; + if (sh.rank() > 0) { constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc); } - return _result(sh, constant_op.result(0)); + return _result(sh, constant_op.result(0)).convert(val.dtype()); } /// Embeds a buffer with concrete values into an Mlir program. 
@@ -1820,7 +1828,7 @@ pub const Tensor = struct { const ctx = CompilationContext.current().mlirCtx(); const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape()); const loc = ctx.location(@src()); - const elem_type = mlir.ext.denseElementAttrType(val.dtype()); + const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()}); const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc); return _result(val.shape(), constant_op.result(0)); } @@ -1849,9 +1857,13 @@ pub const Tensor = struct { /// To avoid use favorise `.broad(shape)` when working with tagged tensors. pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor { const res_shape = output_shape.withDtype(self.dtype()); - + stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ }); + for (0.., axes_) |self_ax, other_ax| { + const d = self.dim(self_ax); + stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. 
got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); + } const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?; - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({any}, axes={d})", .{ res_shape, axes_ }); + const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({}, {any}, axes={d})", .{ self, res_shape, axes_ }); const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc); return _result(res_shape, broadcast_op.result(0)); @@ -2425,7 +2437,7 @@ break :blk ax; }; - if (indices.count() == 1) { + if (indices.count() == 1 and !single_coord) { return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{})); } @@ -3131,6 +3143,7 @@ } /// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime. + /// Note: this is the untagged API; if you have tags, you should use dynamicUpdateSlice directly. pub fn dynamicUpdateSlice1d(self: Tensor, update: Tensor, axis_: i64, offset: Tensor) Tensor { const placeholder = Tensor.scalar(0, .i32); var start_indices = [_]Tensor{placeholder} ** MAX_RANK;