Add operator name to source locations and introduce QoL enhancements: remove bias from sdpa, support shape literals in gatherSlices, add Shape.outer, Tensor.all, and infer argMax dtype.
This commit is contained in:
parent
223857251d
commit
acc492454f
30
zml/nn.zig
30
zml/nn.zig
@ -716,7 +716,6 @@ pub fn causalAttnMask(
|
||||
pub const SdpaOpts = struct {
|
||||
attn_mask: ?Tensor = null,
|
||||
scale: ?Tensor = null,
|
||||
bias: ?Tensor = null,
|
||||
allow_cudnn: bool = true,
|
||||
// TODO: put a callback instead of all these fields,
|
||||
// so that
|
||||
@ -769,12 +768,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
// log.debug("attn_weights : {}", .{attn_weights});
|
||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||
|
||||
attn_weights = attn_weights.convert(.f32);
|
||||
if (opts.bias) |bias| {
|
||||
attn_weights = attn_weights.add(bias);
|
||||
}
|
||||
attn_weights = attn_weights.softmax(.k).convert(q.dtype());
|
||||
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||
|
||||
var attn = attn_weights.dot(v, .{.k});
|
||||
return attn.transpose(q.shape());
|
||||
@ -983,10 +977,6 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft
|
||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||
|
||||
if (opts.bias) |bias| {
|
||||
attn_weights = attn_weights.add(bias);
|
||||
}
|
||||
|
||||
const partial = partialSoftmax(attn_weights, .k);
|
||||
const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
|
||||
|
||||
@ -1021,7 +1011,7 @@ test sdpaMemEfficient {
|
||||
const ref_res = try zml.testing.compileAndCall(
|
||||
platform,
|
||||
sdpa,
|
||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
|
||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
|
||||
);
|
||||
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
||||
{
|
||||
@ -1033,7 +1023,7 @@ test sdpaMemEfficient {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
||||
.{ .attn_mask = mask, .scale = null },
|
||||
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
|
||||
},
|
||||
);
|
||||
@ -1049,7 +1039,7 @@ test sdpaMemEfficient {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
||||
.{ .attn_mask = mask, .scale = null },
|
||||
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
|
||||
},
|
||||
);
|
||||
@ -1079,7 +1069,7 @@ test "sdpaMemEfficient transposed" {
|
||||
const ref_res = try zml.testing.compileAndCall(
|
||||
platform,
|
||||
sdpa,
|
||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
|
||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
|
||||
);
|
||||
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
||||
|
||||
@ -1091,7 +1081,7 @@ test "sdpaMemEfficient transposed" {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
||||
.{ .attn_mask = mask, .scale = null },
|
||||
.{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
|
||||
},
|
||||
);
|
||||
@ -1107,7 +1097,7 @@ test "sdpaMemEfficient transposed" {
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
||||
.{ .attn_mask = mask, .scale = null },
|
||||
.{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
|
||||
},
|
||||
);
|
||||
@ -1127,7 +1117,7 @@ pub const SamplingStrategy = struct {
|
||||
/// Returns an integer tensor with a shape similar to the input, but without the .voc axis.
|
||||
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
|
||||
if (opts.topk <= 1) {
|
||||
const next_tokens = activations.argMax(.voc, .i32).indices.squeeze(.voc);
|
||||
const next_tokens = activations.argMax(.voc).indices.squeeze(.voc);
|
||||
return .{ next_tokens, rng };
|
||||
}
|
||||
|
||||
@ -1144,7 +1134,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
||||
// https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
|
||||
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||
x = x.add(gumbel_noise);
|
||||
const topk_idx = x.argMax(.topk, .i32).indices;
|
||||
const topk_idx = x.argMax(.topk).indices;
|
||||
|
||||
// topk_idx is indices into topk.values ! so in the range [0, topk]
|
||||
// Convert for the original indices from the full [0, voc] range.
|
||||
@ -1234,7 +1224,7 @@ pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: T
|
||||
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||
x = x.add(gumbel_noise);
|
||||
|
||||
const topk_idx = x.argMax(.topk, .i32).indices;
|
||||
const topk_idx = x.argMax(.topk).indices;
|
||||
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
||||
return .{ next_tokens, next_rng };
|
||||
}
|
||||
|
||||
@ -118,9 +118,6 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
||||
if (opts.attn_mask) |attn_mask| {
|
||||
bias = bias.add(attn_mask.broad(bias.shape()));
|
||||
}
|
||||
if (opts.bias) |b| {
|
||||
bias = bias.add(b);
|
||||
}
|
||||
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
const loc = mlir_ctx.location(@src());
|
||||
|
||||
@ -1008,4 +1008,20 @@ pub const Shape = struct {
|
||||
try std.testing.expectEqual(1, s.axis(.b));
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes the shape used for the outer product of `self` and `other`.
///
/// The result starts as `self`, then every axis of `other` (dimension and tag)
/// is appended in order. A tagged axis of `other` that also exists in `self`
/// is treated as a batching axis: it must sit at the same index in both shapes
/// and batching axes must come first; this is asserted.
///
/// NOTE(review): batching axes are appended like any other axis of `other`,
/// so the result presumably carries them twice — confirm this is intended.
pub fn outer(self: Shape, other: Shape) Shape {
    var result = self;
    // Number of batching axes matched so far; the next match must be at this index.
    var n_batching: u8 = 0;
    for (0..other.rank()) |i| {
        const other_tag = other.tag(i);
        const is_tagged = other_tag != Shape.TagUnknown;
        if (is_tagged) {
            if (self.hasTag(other_tag)) |self_ax| {
                stdx.debug.assert(self_ax == n_batching and self_ax == i, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
                n_batching += 1;
            }
        }

        result = result.appendDim(other.dim(i), other_tag);
    }
    return result;
}
|
||||
};
|
||||
|
||||
191
zml/tensor.zig
191
zml/tensor.zig
@ -58,9 +58,9 @@ pub const Tensor = struct {
|
||||
options: std.fmt.FormatOptions,
|
||||
writer: anytype,
|
||||
) !void {
|
||||
_ = fmt;
|
||||
_ = options;
|
||||
try writer.print("Tensor({_})", .{self._shape});
|
||||
const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||
try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape});
|
||||
}
|
||||
|
||||
/// Returns the shape of a Tensor.
|
||||
@ -277,7 +277,7 @@ pub const Tensor = struct {
|
||||
|
||||
res_shape = res_shape.withDtype(dt);
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "bitCast({})", .{dt});
|
||||
const loc = self.getContext().location(@src(), "bitCast({s})", .{@tagName(dt)});
|
||||
const op = dialect.stablehlo.bitcast_convert(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -317,13 +317,6 @@ pub const Tensor = struct {
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||||
pub fn remainder(self: Tensor, other: Tensor) Tensor {
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
const op = dialect.stablehlo.remainder(self.getContext().mlirCtx(), self.value(), other.value(), loc);
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||||
///
|
||||
/// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details.
|
||||
@ -349,17 +342,17 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
|
||||
pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
||||
return binaryOp(@src(), "shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'.
|
||||
pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
|
||||
return binaryOp(@src(), "shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'.
|
||||
pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
|
||||
return binaryOp(@src(), "shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
|
||||
}
|
||||
|
||||
/// Returns the Cholesky decomposition of the input Tensor.
|
||||
@ -369,7 +362,7 @@ pub const Tensor = struct {
|
||||
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
||||
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
|
||||
const loc = self.getContext().location(@src(), "lower={}", .{lower});
|
||||
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
@ -379,7 +372,7 @@ pub const Tensor = struct {
|
||||
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||||
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
||||
const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts });
|
||||
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
@ -492,7 +485,7 @@ pub const Tensor = struct {
|
||||
},
|
||||
};
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
||||
const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts });
|
||||
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
||||
return _result(sh, op.result(0));
|
||||
}
|
||||
@ -522,7 +515,7 @@ pub const Tensor = struct {
|
||||
/// but it is not guaranteed to be deterministic between implementations.
|
||||
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
||||
const ctx = CompilationContext.current();
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "rand.bitGen({})", .{sh});
|
||||
const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh});
|
||||
const op = dialect.stablehlo.rng_bit_generator(
|
||||
ctx.mlirCtx(),
|
||||
self.algorithm,
|
||||
@ -646,12 +639,12 @@ pub const Tensor = struct {
|
||||
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
||||
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
||||
|
||||
const ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
|
||||
const ctx = CompilationContext.current();
|
||||
const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
|
||||
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
||||
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
||||
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
||||
const op = dialect.stablehlo.rng(ctx, a.value(), b.value(), res_shape.value(), .NORMAL, loc);
|
||||
const op = dialect.stablehlo.rng(ctx.mlirCtx(), a.value(), b.value(), res_shape.value(), .NORMAL, loc);
|
||||
return _result(sh, op.result(0));
|
||||
}
|
||||
|
||||
@ -692,7 +685,7 @@ pub const Tensor = struct {
|
||||
// Test out the gumbel reparametrization trick
|
||||
var x = target_dist.log().withTags(.{.d}).broad(s);
|
||||
x = x.add(data);
|
||||
const samples = x.argMax(.d, .i32).indices.squeeze(.d);
|
||||
const samples = x.argMax(.d).indices.squeeze(.d);
|
||||
|
||||
// count 0, 1, 2 and 3 in samples:
|
||||
// - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to N**58
|
||||
@ -744,7 +737,7 @@ pub const Tensor = struct {
|
||||
stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
||||
stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
||||
const loc = self.getContext().location(@src(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
||||
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
||||
return _result(self._shape, op.result(0));
|
||||
}
|
||||
@ -867,7 +860,7 @@ pub const Tensor = struct {
|
||||
batch_group_count: i64 = 1,
|
||||
},
|
||||
) Tensor {
|
||||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
||||
const loc = input.getContext().location(@src(), "opts={}", .{opts});
|
||||
return input.convolution(kernel, .{
|
||||
.window_strides = &.{opts.window_strides},
|
||||
.pad_value = opts.padding,
|
||||
@ -912,7 +905,7 @@ pub const Tensor = struct {
|
||||
batch_group_count: i64 = 1,
|
||||
},
|
||||
) Tensor {
|
||||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
||||
const loc = input.getContext().location(@src(), "opts={}", .{opts});
|
||||
return input.convolution(kernel, .{
|
||||
.window_strides = opts.window_strides,
|
||||
.pad_value = opts.padding,
|
||||
@ -935,37 +928,42 @@ pub const Tensor = struct {
|
||||
|
||||
/// Returns a Tensor containing the element-wise addition of the input Tensors.
|
||||
pub fn add(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("add", dialect.stablehlo.add)(self, other);
|
||||
return binaryOp(@src(), "add", dialect.stablehlo.add)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise subtraction of the input Tensors.
|
||||
pub fn sub(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("subtract", dialect.stablehlo.subtract)(self, other);
|
||||
return binaryOp(@src(), "subtract", dialect.stablehlo.subtract)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise multiplication of the input Tensors.
|
||||
pub fn mul(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("mul", dialect.stablehlo.multiply)(self, other);
|
||||
return binaryOp(@src(), "mul", dialect.stablehlo.multiply)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise division of the input Tensors.
|
||||
pub fn div(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("div", dialect.stablehlo.divide)(self, other);
|
||||
return binaryOp(@src(), "div", dialect.stablehlo.divide)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise exponentiation of the input Tensors.
|
||||
pub fn pow(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("pow", dialect.stablehlo.power)(self, other);
|
||||
return binaryOp(@src(), "pow", dialect.stablehlo.power)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise maximum operation of the input Tensors.
|
||||
pub fn maximum(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("maximum", dialect.stablehlo.maximum)(self, other);
|
||||
return binaryOp(@src(), "maximum", dialect.stablehlo.maximum)(self, other);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise minimum operation of the input Tensors.
|
||||
pub fn minimum(self: Tensor, other: Tensor) Tensor {
|
||||
return binaryOp("minimum", dialect.stablehlo.minimum)(self, other);
|
||||
return binaryOp(@src(), "minimum", dialect.stablehlo.minimum)(self, other);
|
||||
}
|
||||
|
||||
/// Returns the element-wise remainder of the two input Tensors,
/// with 'self' as the dividend and 'other' as the divisor.
pub fn remainder(self: Tensor, other: Tensor) Tensor {
    // Build the named binary op first, then apply it to both operands.
    const remainder_op = binaryOp(@src(), "remainder", dialect.stablehlo.remainder);
    return remainder_op(self, other);
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
|
||||
@ -988,9 +986,9 @@ pub const Tensor = struct {
|
||||
/// Returns a Tensor containing the element-wise logical operation of the input Tensors.
|
||||
pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor {
|
||||
return switch (logical_op) {
|
||||
.OR => binaryOp("or", dialect.stablehlo.or_)(self, other),
|
||||
.XOR => binaryOp("xor", dialect.stablehlo.xor)(self, other),
|
||||
.AND => binaryOp("and", dialect.stablehlo.and_)(self, other),
|
||||
.OR => binaryOp(@src(), "or", dialect.stablehlo.or_)(self, other),
|
||||
.XOR => binaryOp(@src(), "xor", dialect.stablehlo.xor)(self, other),
|
||||
.AND => binaryOp(@src(), "and", dialect.stablehlo.and_)(self, other),
|
||||
};
|
||||
}
|
||||
|
||||
@ -1007,16 +1005,16 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise conversion to another type.
|
||||
pub fn convert(self: Tensor, dt: DataType) Tensor {
|
||||
if (dt == self.dtype()) {
|
||||
pub fn convert(self: Tensor, to: DataType) Tensor {
|
||||
if (to == self.dtype()) {
|
||||
return self;
|
||||
}
|
||||
|
||||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), dt)).as(mlir.Type).?;
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dtype={}", .{dt});
|
||||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?;
|
||||
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
||||
|
||||
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
||||
return _result(self._shape.withDtype(dt), op.result(0));
|
||||
return _result(self._shape.withDtype(to), op.result(0));
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
|
||||
@ -1174,7 +1172,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
const mlir_ctx = lhs.getContext().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src());
|
||||
const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
|
||||
const op = dialect.stablehlo.dot_general(
|
||||
mlir_ctx,
|
||||
lhs.value(),
|
||||
@ -1375,7 +1373,7 @@ pub const Tensor = struct {
|
||||
return self.reshape(res_shape);
|
||||
}
|
||||
|
||||
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation });
|
||||
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation });
|
||||
const op = dialect.stablehlo.transpose(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -1408,7 +1406,7 @@ pub const Tensor = struct {
|
||||
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
||||
const new_shape = self._shape.set(a, n).insert(a + 1, .{ ._ = new_dim });
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, n={}", .{ axis_, n });
|
||||
const loc = self.getContext().location(@src(), "axis={}, n={}", .{ axis_, n });
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -1425,7 +1423,7 @@ pub const Tensor = struct {
|
||||
pub fn splitAxis(self: Tensor, ax: anytype, split_shape: anytype) Tensor {
|
||||
const new_shape = self._shape.splitAxis(ax, split_shape);
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "splitAxis({}, {any})", .{ ax, split_shape });
|
||||
const loc = self.getContext().location(@src(), "splitAxis({}, {any})", .{ ax, split_shape });
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -1463,8 +1461,7 @@ pub const Tensor = struct {
|
||||
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||||
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
|
||||
|
||||
const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ });
|
||||
const reshaped_val = dialect.stablehlo.reshape(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -1582,8 +1579,9 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
const res_shape = tensors[0]._shape.set(a, concatenated_dim);
|
||||
const loc = tensors[0].getContext().mlirCtx().location(@src()).namedFmt(tensors[0].getContext().mlirCtx(), "axis={}", .{axis_});
|
||||
const op = dialect.stablehlo.concatenate(tensors[0].getContext().mlirCtx(), buffer[0..tensors.len], a, loc);
|
||||
const ctx = tensors[0].getContext();
|
||||
const loc = ctx.location(@src(), "axis={}", .{axis_});
|
||||
const op = dialect.stablehlo.concatenate(ctx.mlirCtx(), buffer[0..tensors.len], a, loc);
|
||||
// log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape });
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
@ -1601,7 +1599,7 @@ pub const Tensor = struct {
|
||||
const res_shape = shape0.insertTag(axis_, 1, tag);
|
||||
|
||||
for (tensors[1..]) |tensor| {
|
||||
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
|
||||
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape });
|
||||
}
|
||||
|
||||
var reshaped: [32]Tensor = undefined;
|
||||
@ -1748,7 +1746,7 @@ pub const Tensor = struct {
|
||||
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||
|
||||
const ctx = CompilationContext.current();
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
|
||||
const loc = ctx.location(@src(), "arange({}, dtype={})", .{ args, dt });
|
||||
|
||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||
const sh = Shape.init(.{n_steps}, dt);
|
||||
@ -1775,9 +1773,10 @@ pub const Tensor = struct {
|
||||
const a = sh.axis(axis_);
|
||||
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||
const res_shape = sh.withDtype(dt);
|
||||
const mlir_ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a });
|
||||
const ctx = CompilationContext.current();
|
||||
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
|
||||
|
||||
const mlir_ctx = ctx.mlirCtx();
|
||||
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
||||
return _result(res_shape, op.result(0));
|
||||
}
|
||||
@ -1795,7 +1794,7 @@ pub const Tensor = struct {
|
||||
stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
||||
|
||||
const ctx = CompilationContext.current();
|
||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
|
||||
const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt });
|
||||
|
||||
const sh = Shape.init(.{args.steps}, dt);
|
||||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||||
@ -1838,7 +1837,7 @@ pub const Tensor = struct {
|
||||
const sh = Shape.init(dimz, val.dtype());
|
||||
const singleton_sh = Shape.init(.{}, val.dtype());
|
||||
const ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
||||
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
|
||||
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||||
|
||||
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
|
||||
@ -1871,22 +1870,8 @@ pub const Tensor = struct {
|
||||
return self.mul(other);
|
||||
}
|
||||
|
||||
const other_shape = other.shape();
|
||||
var res_shape = self.shape();
|
||||
var batching_axes: u8 = 0;
|
||||
for (0..other.rank()) |ax| {
|
||||
if (other_shape.tag(ax) != Shape.TagUnknown) {
|
||||
if (self.shape().hasTag(other_shape.tag(ax))) |batching_ax| {
|
||||
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
|
||||
batching_axes += 1;
|
||||
}
|
||||
}
|
||||
|
||||
res_shape = res_shape.appendDim(other_shape.dim(ax), other_shape.tag(ax));
|
||||
}
|
||||
const left = self.broad(res_shape);
|
||||
const right = other.broad(res_shape);
|
||||
return left.mul(right);
|
||||
const res_shape = self.shape().outer(other.shape());
|
||||
return self.broad(res_shape).mul(other.broad(res_shape));
|
||||
}
|
||||
|
||||
/// Given a tensor and a shape of the same rank,
|
||||
@ -1904,9 +1889,10 @@ pub const Tensor = struct {
|
||||
const d = self.dim(self_ax);
|
||||
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
|
||||
}
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({}, {any}, axes={d})", .{ self, res_shape, axes_ });
|
||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc);
|
||||
const ctx = self.getContext();
|
||||
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?;
|
||||
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
||||
|
||||
return _result(res_shape, broadcast_op.result(0));
|
||||
}
|
||||
@ -1959,7 +1945,7 @@ pub const Tensor = struct {
|
||||
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
||||
const output_shape = self._shape.reshape(output_shape_);
|
||||
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reshape({any})", .{output_shape});
|
||||
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
|
||||
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
||||
return _result(output_shape, reshape_value.result(0));
|
||||
}
|
||||
@ -2050,7 +2036,7 @@ pub const Tensor = struct {
|
||||
pub fn reverse(self: Tensor, axes_: anytype) Tensor {
|
||||
const actual_axes = self._shape.axes(axes_);
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({any})", .{axes_});
|
||||
const loc = self.getContext().location(@src(), "reverse({any})", .{axes_});
|
||||
const reverse_op = dialect.stablehlo.reverse(self.getContext().mlirCtx(), self.value(), toI64(actual_axes.constSlice()), loc);
|
||||
return _result(self._shape, reverse_op.result(0));
|
||||
}
|
||||
@ -2257,7 +2243,8 @@ pub const Tensor = struct {
|
||||
/// while gatherValues, always copy values one by one, and as such don't have the same issues.
|
||||
/// In our example the contiguous dimension .d is not sliced
|
||||
/// and gatherSlices can copy data by group of C'*D elements.
|
||||
pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||
const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32);
|
||||
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
|
||||
|
||||
const tagged_api = slice_shape.isFullyTagged();
|
||||
@ -2307,7 +2294,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
}
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src());
|
||||
const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices });
|
||||
const gather_op = dialect.stablehlo.gather(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -2331,6 +2318,12 @@ pub const Tensor = struct {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
// Test-only wrapper around `Tensor.gatherSlices` whose `slice_shape` parameter
// is a concrete `Shape`. Presumably needed because `gatherSlices` itself takes
// an `anytype` slice shape and so cannot be handed directly to the compile helpers.
const Local = struct {
    pub fn _gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
        return Tensor.gatherSlices(self, slice_shape, indices, opts);
    }
};
|
||||
|
||||
{
|
||||
// Only test shapes
|
||||
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||
@ -2367,7 +2360,7 @@ pub const Tensor = struct {
|
||||
|
||||
const mod = try zml.compileFn(
|
||||
std.testing.allocator,
|
||||
gatherSlices,
|
||||
Local._gatherSlices,
|
||||
.{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } },
|
||||
platform,
|
||||
);
|
||||
@ -2383,7 +2376,7 @@ pub const Tensor = struct {
|
||||
const start_indices = (try zml.Buffer.fromArray(platform, [2][2]i32{ .{ 2, 1 }, .{ 0, 3 } })).withTags(.{ .n, ._ });
|
||||
defer start_indices.deinit();
|
||||
|
||||
const result = try zml.testing.compileAndCall(platform, gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
||||
const result = try zml.testing.compileAndCall(platform, Local._gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
||||
|
||||
const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{
|
||||
.{
|
||||
@ -2730,15 +2723,14 @@ pub const Tensor = struct {
|
||||
/// Stable argmax:
|
||||
/// * bubbles up NaN
|
||||
/// * in case of equality, returns the smallest index matching the maximum
|
||||
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
|
||||
stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
|
||||
|
||||
pub fn argMax(x: Tensor, axis_: anytype) ArgMaxRes {
|
||||
const a = x.axis(axis_);
|
||||
const dt: DataType = if (x.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||
|
||||
return ops.reduce(
|
||||
ArgMaxRes.cmp,
|
||||
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x.shape(), &.{a}) },
|
||||
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) },
|
||||
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, dt).broadcast(x.shape(), &.{a}) },
|
||||
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, dt) },
|
||||
&.{a},
|
||||
);
|
||||
}
|
||||
@ -2749,7 +2741,7 @@ pub const Tensor = struct {
|
||||
const allocator = std.testing.allocator;
|
||||
const ArgMaxTest = struct {
|
||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||
return x.argMax(1, .i32);
|
||||
return x.argMax(1);
|
||||
}
|
||||
};
|
||||
|
||||
@ -3097,7 +3089,7 @@ pub const Tensor = struct {
|
||||
|
||||
const a = self.axis(axis_);
|
||||
const new_shape = self._shape.set(a, slice_.len);
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len });
|
||||
const loc = self.getContext().location(@src(), "dynSlice({}, len={})", .{ axis_, slice_.len });
|
||||
|
||||
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
|
||||
start_indices[a] = slice_.start.value();
|
||||
@ -3397,7 +3389,7 @@ pub const Tensor = struct {
|
||||
|
||||
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
||||
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
|
||||
const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)});
|
||||
const op = dialect.stablehlo.compare(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -3568,17 +3560,31 @@ pub const Tensor = struct {
|
||||
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
||||
pub fn any(self: Tensor, axis_: anytype) Tensor {
|
||||
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
||||
const red = ops.reduce(
|
||||
return ops.reduce(
|
||||
struct {
|
||||
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||||
return res.logical(.OR, x);
|
||||
}
|
||||
}.acc,
|
||||
pred,
|
||||
Tensor.scalar(0, pred.dtype()),
|
||||
Tensor.scalar(false, .bool),
|
||||
&.{self.axis(axis_)},
|
||||
);
|
||||
}
|
||||
|
||||
/// Returns a Tensor containing booleans indicating whether all values are non-zero (true, for bool tensors) over the given axis.
|
||||
pub fn all(self: Tensor, axis_: anytype) Tensor {
|
||||
const pred = if (self.dtype() == .bool) self else self.cmp(.NE, Tensor.scalar(0, self.dtype()));
|
||||
return ops.reduce(
|
||||
struct {
|
||||
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||||
return res.logical(.AND, x);
|
||||
}
|
||||
}.acc,
|
||||
pred,
|
||||
Tensor.scalar(true, .bool),
|
||||
&.{self.axis(axis_)},
|
||||
);
|
||||
return red;
|
||||
}
|
||||
|
||||
/// Given a set of N vectors of lengths A, B, C, D,
|
||||
@ -3701,6 +3707,7 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
fn binaryOp(
|
||||
src: std.builtin.SourceLocation,
|
||||
op_name: []const u8,
|
||||
op_fn: fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation,
|
||||
) fn (Tensor, Tensor) Tensor {
|
||||
@ -3718,9 +3725,9 @@ pub const Tensor = struct {
|
||||
|
||||
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
||||
|
||||
const mlirCtx = self.getContext().mlirCtx();
|
||||
const location = mlirCtx.location(@src());
|
||||
const ret = @call(.auto, op_fn, .{ mlirCtx, self.value(), other.value(), location });
|
||||
const ctx = self.getContext();
|
||||
const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other });
|
||||
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
|
||||
return _result(self._shape, ret.result(0));
|
||||
}
|
||||
}.binaryOpHelper;
|
||||
|
||||
@ -25,6 +25,7 @@ pub const nn = @import("nn.zig");
|
||||
pub const module = @import("module.zig");
|
||||
pub const meta = @import("meta.zig");
|
||||
pub const platform = @import("platform.zig");
|
||||
pub const pjrt = @import("pjrtx.zig");
|
||||
pub const testing = @import("testing.zig");
|
||||
pub const torch = @import("torch.zig");
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user