zml: fix float8 <-> float32 conversions, support for Tensor.constant(.{}, .{ .f8 = 1.0})
Mostly:
* fix float8 <-> float32 conversions
* support for `Tensor.constant(.{}, .{ .f8 = 1.0})`
Misc:
* fix small inconsistencies between different versions of sdpa
* better error message for broadcast
* bazelrc: --config=debug
This commit is contained in:
parent
455bb3877f
commit
b5c4fb7c58
@ -10,6 +10,7 @@ test {
|
|||||||
|
|
||||||
pub const DataType = enum(u8) {
|
pub const DataType = enum(u8) {
|
||||||
bool,
|
bool,
|
||||||
|
// Note: float8 support is a bit spotty; f8e4m3b11fnuz seems to be the best-supported variant on CUDA.
|
||||||
f8e4m3b11fnuz,
|
f8e4m3b11fnuz,
|
||||||
f8e4m3fn,
|
f8e4m3fn,
|
||||||
f8e4m3fnuz,
|
f8e4m3fnuz,
|
||||||
|
|||||||
121
zml/floats.zig
121
zml/floats.zig
@ -1,3 +1,4 @@
|
|||||||
|
//! Conversion utilities between different floating-point formats.
|
||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
|
|
||||||
test {
|
test {
|
||||||
@ -9,12 +10,15 @@ fn allBitsOne(v: anytype) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) type {
|
fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type) type {
|
||||||
return packed struct(std.meta.Int(.unsigned, @intCast(sign_bits + exponent_bits + mantissa_bits))) {
|
const bit_size = sign_bits + exponent_bits + mantissa_bits;
|
||||||
|
if (bit_size % 8 != 0) @compileError("FloatType should have a number of bits divisible by 8");
|
||||||
|
|
||||||
|
return packed struct(std.meta.Int(.unsigned, bit_size)) {
|
||||||
const Self = @This();
|
const Self = @This();
|
||||||
|
|
||||||
mantissa: std.meta.Int(.unsigned, @intCast(mantissa_bits)),
|
mantissa: std.meta.Int(.unsigned, mantissa_bits),
|
||||||
exponent: std.meta.Int(.unsigned, @intCast(exponent_bits)),
|
exponent: std.meta.Int(.unsigned, exponent_bits),
|
||||||
sign: std.meta.Int(.unsigned, @intCast(sign_bits)),
|
sign: std.meta.Int(.unsigned, sign_bits),
|
||||||
|
|
||||||
pub fn zero() Self {
|
pub fn zero() Self {
|
||||||
return .{
|
return .{
|
||||||
@ -35,34 +39,51 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type)
|
|||||||
/// Lossy conversion from f32, similar to @floatCast
|
/// Lossy conversion from f32, similar to @floatCast
|
||||||
pub fn fromF32(f: f32) Self {
|
pub fn fromF32(f: f32) Self {
|
||||||
const vf32: Float32 = @bitCast(f);
|
const vf32: Float32 = @bitCast(f);
|
||||||
const precision_loss = @bitSizeOf(@TypeOf(vf32.mantissa)) - mantissa_bits;
|
const exp_bias = comptime Self.expBias();
|
||||||
|
const exponent = @as(u16, vf32.exponent) + exp_bias -| Float32.expBias();
|
||||||
|
const overflow = exponent > std.math.maxInt(std.meta.Int(.unsigned, exponent_bits));
|
||||||
|
if (overflow) {
|
||||||
|
return if (@hasDecl(Self, "inf")) {
|
||||||
|
return if (vf32.sign == 0) Self.inf() else Self.minusInf();
|
||||||
|
} else Self.nan();
|
||||||
|
}
|
||||||
return .{
|
return .{
|
||||||
.sign = vf32.sign,
|
.sign = vf32.sign,
|
||||||
.exponent = @intCast(vf32.exponent),
|
.exponent = @intCast(exponent),
|
||||||
.mantissa = shr(vf32.mantissa, precision_loss),
|
.mantissa = truncMantissa(vf32.mantissa),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lossless conversion to f32.
|
/// Lossless conversion to f32.
|
||||||
pub fn toF32(self: Self) f32 {
|
pub fn toF32(self: Self) f32 {
|
||||||
var vf32: Float32 = undefined;
|
var vf32: Float32 = undefined;
|
||||||
const precision_loss = @bitSizeOf(@TypeOf(vf32.mantissa)) - mantissa_bits;
|
if (@hasDecl(Self, "isInf") and self.isInf()) {
|
||||||
|
return if (self.sign == 0) std.math.inf(f32) else -std.math.inf(f32);
|
||||||
|
}
|
||||||
vf32 = .{
|
vf32 = .{
|
||||||
.sign = self.sign,
|
.sign = self.sign,
|
||||||
.exponent = self.exponent,
|
.exponent = if (self.exponent == 0) 0 else @intCast(@as(i16, self.exponent) + Float32.expBias() - Self.expBias()),
|
||||||
.mantissa = @shlExact(@as(@TypeOf(vf32.mantissa), self.mantissa), precision_loss),
|
.mantissa = self.f32Mantissa(),
|
||||||
};
|
};
|
||||||
return @bitCast(vf32);
|
return @bitCast(vf32);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn truncMantissa(T: type, x: anytype) T {
|
fn truncMantissa(x: anytype) std.meta.FieldType(Self, .mantissa) {
|
||||||
const off = @bitSizeOf(@TypeOf(x)) - @bitSizeOf(T);
|
@setRuntimeSafety(false);
|
||||||
|
const off = @bitSizeOf(@TypeOf(x)) - mantissa_bits;
|
||||||
return @intCast(x >> off);
|
return @intCast(x >> off);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn shr(x: anytype, comptime off: u8) std.meta.Int(.unsigned, @bitSizeOf(@TypeOf(x)) - off) {
|
fn f32Mantissa(self: Self) std.meta.FieldType(Float32, .mantissa) {
|
||||||
// @setRuntimeSafety(false);
|
@setRuntimeSafety(false);
|
||||||
return @intCast(x >> off);
|
const f32_mantissa_bits = @bitSizeOf(std.meta.FieldType(Float32, .mantissa));
|
||||||
|
|
||||||
|
const Res = std.meta.FieldType(Float32, .mantissa);
|
||||||
|
return @shlExact(@as(Res, self.mantissa), f32_mantissa_bits - mantissa_bits);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn expBias() u8 {
|
||||||
|
return std.math.maxInt(std.meta.Int(.unsigned, exponent_bits - 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn format(
|
pub fn format(
|
||||||
@ -72,7 +93,11 @@ fn FloatType(sign_bits: u1, exponent_bits: u8, mantissa_bits: u8, innerT: type)
|
|||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = options;
|
_ = options;
|
||||||
try writer.print("{" ++ fmt ++ "}", .{self.toF32()});
|
if (fmt.len == 1 and fmt[0] == '_') {
|
||||||
|
try writer.print("{{ .sign={}, .exp={}, .mantissa={} }}", .{ self.sign, self.exponent, self.mantissa });
|
||||||
|
} else {
|
||||||
|
try writer.print("{" ++ fmt ++ "}", .{self.toF32()});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub usingnamespace innerT;
|
pub usingnamespace innerT;
|
||||||
@ -124,6 +149,22 @@ pub const Float8E4M3FNUZ = FloatType(1, 4, 3, struct {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test "Float8E4" {
    // Values that must come back bit-identical after f32 -> float8 -> f32.
    const exact_values = [_]f32{ 0, 1.0, -2, 1.0 / 64.0, -128 };
    // Values that only need to come back within the default TestCase tolerance.
    const approximate_values = [_]f32{3.02344107628};
    const cases: TestCase = .{
        .lossless = &exact_values,
        .lossy = &approximate_values,
    };

    const e4_types = .{
        Float8E4M3B11FNUZ,
        Float8E4M3FN,
        Float8E4M3FNUZ,
    };
    inline for (e4_types) |Float8T| {
        try testCustomFloat(Float8T, cases);
        // 1/128 is expected to come back as zero for each of these types.
        const flushed = Float8T.fromF32(1.0 / 128.0).toF32();
        try std.testing.expectEqual(0.0, flushed);
    }
}
|
||||||
|
|
||||||
pub const Float8E5M2 = FloatType(1, 5, 2, struct {
|
pub const Float8E5M2 = FloatType(1, 5, 2, struct {
|
||||||
pub fn nan() Float8E5M2 {
|
pub fn nan() Float8E5M2 {
|
||||||
return .{
|
return .{
|
||||||
@ -172,6 +213,16 @@ pub const Float8E5M2FNUZ = FloatType(1, 5, 2, struct {
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
test "Float8E5" {
    // Values that must come back bit-identical after f32 -> float8 -> f32.
    const exact_values = [_]f32{ 0, 1.0, -2, 1.0 / 128.0, -128 };
    // Values that only need to come back within the default TestCase tolerance.
    const approximate_values = [_]f32{3.02344107628};
    const cases: TestCase = .{
        .lossless = &exact_values,
        .lossy = &approximate_values,
    };

    const e5_types = .{ Float8E5M2, Float8E5M2FNUZ };
    inline for (e5_types) |Float8T| {
        try testCustomFloat(Float8T, cases);
    }
}
|
||||||
|
|
||||||
pub const BFloat16 = FloatType(1, 8, 7, struct {
|
pub const BFloat16 = FloatType(1, 8, 7, struct {
|
||||||
pub fn nan() BFloat16 {
|
pub fn nan() BFloat16 {
|
||||||
return .{
|
return .{
|
||||||
@ -215,18 +266,10 @@ test BFloat16 {
|
|||||||
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff });
|
try std.testing.expectEqual(std.mem.toBytes(BFloat16.inf().neg()), [_]u8{ 0x80, 0xff });
|
||||||
try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32)));
|
try std.testing.expectEqual(BFloat16.inf(), BFloat16.fromF32(std.math.inf(f32)));
|
||||||
|
|
||||||
const lossless = [_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) };
|
try testCustomFloat(BFloat16, .{
|
||||||
for (&lossless) |v| {
|
.lossless = &[_]f32{ 0, -2, 1.0 / 128.0, -1e64, std.math.inf(f32) },
|
||||||
try std.testing.expectEqual(v, BFloat16.fromF32(v).toF32());
|
.lossy = &[_]f32{3.02344107628},
|
||||||
}
|
});
|
||||||
const lossy = [_]f32{3.02344107628};
|
|
||||||
for (&lossy) |x| {
|
|
||||||
const y = BFloat16.fromF32(x).toF32();
|
|
||||||
if (!std.math.approxEqRel(f32, x, y, 1e-2)) {
|
|
||||||
std.log.err("expected ~{d}, got {d}", .{ x, y });
|
|
||||||
return error.TestUnexpectedResult;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn floatCast(T: type, x: anytype) T {
|
pub fn floatCast(T: type, x: anytype) T {
|
||||||
@ -235,3 +278,25 @@ pub fn floatCast(T: type, x: anytype) T {
|
|||||||
else => @floatCast(x.toF32()),
|
else => @floatCast(x.toF32()),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inputs for `testCustomFloat`: f32 values to round-trip through a custom float type.
const TestCase = struct {
    /// Values checked with exact equality after the round trip.
    lossless: []const f32,
    /// Values checked with a relative tolerance after the round trip.
    lossy: []const f32,
    /// Relative tolerance applied to the `lossy` values.
    tolerance: f32 = 1e-2,
};
|
||||||
|
|
||||||
|
/// Round-trips every value of `test_case` through `FloatT` (via `fromF32` then `toF32`).
/// `lossless` values must come back exactly equal; `lossy` values must come back
/// within `test_case.tolerance` (relative). Returns the first failure as an error.
fn testCustomFloat(FloatT: type, test_case: TestCase) !void {
    for (test_case.lossless) |expected| {
        const round_tripped = FloatT.fromF32(expected).toF32();
        try std.testing.expectEqual(expected, round_tripped);
    }

    for (test_case.lossy) |expected| {
        const round_tripped = FloatT.fromF32(expected).toF32();
        try expectApproxEqRel(f32, expected, round_tripped, test_case.tolerance);
    }
}
|
||||||
|
|
||||||
|
/// Test helper: succeeds when `x` and `y` are approximately equal according to
/// `std.math.approxEqRel` with the given relative `tolerance`.
/// Otherwise logs both values and returns `error.TestUnexpectedResult`.
/// `x` is treated as the expected value in the log message.
fn expectApproxEqRel(FloatT: type, x: FloatT, y: FloatT, tolerance: FloatT) !void {
    // Compare in `FloatT` itself: hard-coding f32 here made the helper fail to
    // compile for any float type that doesn't implicitly coerce to f32 (e.g. f64).
    if (!std.math.approxEqRel(FloatT, x, y, tolerance)) {
        std.log.err("expected ~{d}, got {d}", .{ x, y });
        return error.TestUnexpectedResult;
    }
}
|
||||||
|
|||||||
@ -18,7 +18,7 @@ pub const ext = struct {
|
|||||||
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?;
|
return mlir.RankedTensorType.init(sh.dims(), mlir.ext.Type.fromDType(ctx, sh.dtype())).as(mlir.Type).?;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn denseElementAttrType(dt: dtype.DataType) mlir.DenseElementsAttributeTypes {
|
pub fn denseElementAttrType(dt: dtype.DataType) ?mlir.DenseElementsAttributeTypes {
|
||||||
return switch (dt) {
|
return switch (dt) {
|
||||||
.bool => .bool,
|
.bool => .bool,
|
||||||
.i8 => .i8,
|
.i8 => .i8,
|
||||||
@ -33,7 +33,7 @@ pub const ext = struct {
|
|||||||
.f16 => .f16,
|
.f16 => .f16,
|
||||||
.f32 => .f32,
|
.f32 => .f32,
|
||||||
.f64 => .f64,
|
.f64 => .f64,
|
||||||
inline else => |tag| @panic("Unsupported data type: " ++ @tagName(tag)),
|
else => null,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -775,7 +775,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
var attn_weights = q.dot(k, .{.hd});
|
var attn_weights = q.dot(k, .{.hd});
|
||||||
// log.debug("attn_weights : {}", .{attn_weights});
|
// log.debug("attn_weights : {}", .{attn_weights});
|
||||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
|
|
||||||
attn_weights = attn_weights.convert(.f32);
|
attn_weights = attn_weights.convert(.f32);
|
||||||
if (opts.bias) |bias| {
|
if (opts.bias) |bias| {
|
||||||
@ -988,7 +988,7 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft
|
|||||||
var attn_weights = q.dot(k, .{.hd});
|
var attn_weights = q.dot(k, .{.hd});
|
||||||
// log.debug("attn_weights : {}", .{attn_weights});
|
// log.debug("attn_weights : {}", .{attn_weights});
|
||||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broadcastLeft(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
|
|
||||||
if (opts.bias) |bias| {
|
if (opts.bias) |bias| {
|
||||||
attn_weights = attn_weights.add(bias);
|
attn_weights = attn_weights.add(bias);
|
||||||
|
|||||||
@ -116,16 +116,16 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));
|
var bias = Tensor.constant(Shape.init(.{ .b = q.dim(.b), .h = q.dim(.h), .q = q.dim(.q), .k = k.dim(.k) }, q.dtype()), Data.init(q.dtype(), 0));
|
||||||
|
|
||||||
if (opts.attn_mask) |attn_mask| {
|
if (opts.attn_mask) |attn_mask| {
|
||||||
const mask = attn_mask.withTags(.{ .q, .k }).broad(bias.shape());
|
bias = bias.add(attn_mask.broad(bias.shape()));
|
||||||
bias = bias.add(mask);
|
|
||||||
}
|
}
|
||||||
if (opts.bias) |b| {
|
if (opts.bias) |b| {
|
||||||
bias = bias.add(b);
|
bias = bias.add(b);
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = ctx.mlirCtx().location(@src());
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
|
const loc = mlir_ctx.location(@src());
|
||||||
const op = dialect.stablehlo.custom_call(
|
const op = dialect.stablehlo.custom_call(
|
||||||
ctx.mlirCtx(),
|
mlir_ctx,
|
||||||
&.{ q.value(), k.value(), v.value(), bias.value() },
|
&.{ q.value(), k.value(), v.value(), bias.value() },
|
||||||
.{
|
.{
|
||||||
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
.call_target_name = "__cudnn$fmhaScaleBiasSoftmax",
|
||||||
@ -135,8 +135,8 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
.output_operand_aliases = &.{},
|
.output_operand_aliases = &.{},
|
||||||
},
|
},
|
||||||
&.{
|
&.{
|
||||||
mlir.ext.mlirType(ctx.mlirCtx(), q.shape()),
|
mlir.ext.mlirType(mlir_ctx, q.shape()),
|
||||||
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(ctx.mlirCtx()).as(mlir.Type).?).as(mlir.Type).?,
|
mlir.RankedTensorType.init(&.{0}, mlir.IntegerType(.u8).init(mlir_ctx).as(mlir.Type).?).asType(),
|
||||||
},
|
},
|
||||||
loc,
|
loc,
|
||||||
);
|
);
|
||||||
|
|||||||
@ -205,6 +205,7 @@ pub const Tensor = struct {
|
|||||||
pub fn value(self: Tensor) mlir.Value {
|
pub fn value(self: Tensor) mlir.Value {
|
||||||
return self.getContext().getValueAndDonation(self)[0];
|
return self.getContext().getValueAndDonation(self)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Tell the PJRT compiler that memory should be reused between the two tensors.
|
/// Tell the PJRT compiler that memory should be reused between the two tensors.
|
||||||
/// The compiler is already aggressively reusing tensors for intermediate results,
|
/// The compiler is already aggressively reusing tensors for intermediate results,
|
||||||
/// but this API allows to reuse buffer between input and output arguments
|
/// but this API allows to reuse buffer between input and output arguments
|
||||||
@ -1806,13 +1807,20 @@ pub const Tensor = struct {
|
|||||||
const singleton_sh = Shape.init(.{}, val.dtype());
|
const singleton_sh = Shape.init(.{}, val.dtype());
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype());
|
|
||||||
var constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.constSlice(), loc);
|
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
|
||||||
|
dialect.stablehlo.constant(ctx, res_type, elem_type, val.constSlice(), loc)
|
||||||
|
else blk: {
|
||||||
|
// Not all dtypes can be serialized in the IR. When that's not possible, serialize the value as f32 instead.
|
||||||
|
const val_f32 = val.as(f32);
|
||||||
|
break :blk dialect.stablehlo.constant(ctx, res_type, .f32, std.mem.asBytes(&val_f32), loc);
|
||||||
|
};
|
||||||
|
|
||||||
if (sh.rank() > 0) {
|
if (sh.rank() > 0) {
|
||||||
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
|
constant_op = dialect.stablehlo.broadcast_in_dim(ctx, constant_op.result(0), &.{}, mlir.ext.RankedTensorType.fromShape(ctx, sh).as(mlir.Type).?, loc);
|
||||||
}
|
}
|
||||||
return _result(sh, constant_op.result(0));
|
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Embeds a buffer with concrete values into an Mlir program.
|
/// Embeds a buffer with concrete values into an Mlir program.
|
||||||
@ -1820,7 +1828,7 @@ pub const Tensor = struct {
|
|||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
const result_type = mlir.ext.RankedTensorType.fromShape(ctx, val.shape());
|
||||||
const loc = ctx.location(@src());
|
const loc = ctx.location(@src());
|
||||||
const elem_type = mlir.ext.denseElementAttrType(val.dtype());
|
const elem_type = mlir.ext.denseElementAttrType(val.dtype()) orelse std.debug.panic("constantTensor expects a dtype that can be serialized to MLIR, like f32 or i32, got {}", .{val.shape()});
|
||||||
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc);
|
const constant_op = dialect.stablehlo.constant(ctx, result_type, elem_type, val.data, loc);
|
||||||
return _result(val.shape(), constant_op.result(0));
|
return _result(val.shape(), constant_op.result(0));
|
||||||
}
|
}
|
||||||
@ -1849,9 +1857,13 @@ pub const Tensor = struct {
|
|||||||
/// To avoid use favorise `.broad(shape)` when working with tagged tensors.
|
/// To avoid use favorise `.broad(shape)` when working with tagged tensors.
|
||||||
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
|
pub fn broadcast(self: Tensor, output_shape: Shape, axes_: []const i64) Tensor {
|
||||||
const res_shape = output_shape.withDtype(self.dtype());
|
const res_shape = output_shape.withDtype(self.dtype());
|
||||||
|
stdx.debug.assert(axes_.len == self.rank(), "broadcast expects axes_ to map all axes from self to axes of the output shape, got broadcast({}, {}, {d})", .{ self, output_shape, axes_ });
|
||||||
|
for (0.., axes_) |self_ax, other_ax| {
|
||||||
|
const d = self.dim(self_ax);
|
||||||
|
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
|
||||||
|
}
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({any}, axes={d})", .{ res_shape, axes_ });
|
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({}, {any}, axes={d})", .{ self, res_shape, axes_ });
|
||||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc);
|
const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc);
|
||||||
|
|
||||||
return _result(res_shape, broadcast_op.result(0));
|
return _result(res_shape, broadcast_op.result(0));
|
||||||
@ -2425,7 +2437,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
break :blk ax;
|
break :blk ax;
|
||||||
};
|
};
|
||||||
if (indices.count() == 1) {
|
if (indices.count() == 1 and !single_coord) {
|
||||||
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
|
return self.dynamicUpdateSlice1d(updates, coord_axes_.get(0), indices.reshape(.{}));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3131,6 +3143,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime.
|
/// Updates a slice of the input Tensor along a specific axis using the given 'update' Tensor, with a start offset known at runtime.
|
||||||
|
/// Note: this is the untagged API; if your tensors have tags, use `dynamicUpdateSlice` directly.
|
||||||
pub fn dynamicUpdateSlice1d(self: Tensor, update: Tensor, axis_: i64, offset: Tensor) Tensor {
|
pub fn dynamicUpdateSlice1d(self: Tensor, update: Tensor, axis_: i64, offset: Tensor) Tensor {
|
||||||
const placeholder = Tensor.scalar(0, .i32);
|
const placeholder = Tensor.scalar(0, .i32);
|
||||||
var start_indices = [_]Tensor{placeholder} ** MAX_RANK;
|
var start_indices = [_]Tensor{placeholder} ** MAX_RANK;
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user