Add operator name to source locations and introduce QoL enhancements: remove bias from sdpa, support shape literals in gatherSlices, add Shape.outer, Tensor.all, and infer argMax dtype.
This commit is contained in:
parent
223857251d
commit
acc492454f
30
zml/nn.zig
30
zml/nn.zig
@ -716,7 +716,6 @@ pub fn causalAttnMask(
|
|||||||
pub const SdpaOpts = struct {
|
pub const SdpaOpts = struct {
|
||||||
attn_mask: ?Tensor = null,
|
attn_mask: ?Tensor = null,
|
||||||
scale: ?Tensor = null,
|
scale: ?Tensor = null,
|
||||||
bias: ?Tensor = null,
|
|
||||||
allow_cudnn: bool = true,
|
allow_cudnn: bool = true,
|
||||||
// TODO: put a callback instead of all this field,
|
// TODO: put a callback instead of all this field,
|
||||||
// so that
|
// so that
|
||||||
@ -769,12 +768,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
// log.debug("attn_weights : {}", .{attn_weights});
|
// log.debug("attn_weights : {}", .{attn_weights});
|
||||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
|
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||||
attn_weights = attn_weights.convert(.f32);
|
|
||||||
if (opts.bias) |bias| {
|
|
||||||
attn_weights = attn_weights.add(bias);
|
|
||||||
}
|
|
||||||
attn_weights = attn_weights.softmax(.k).convert(q.dtype());
|
|
||||||
|
|
||||||
var attn = attn_weights.dot(v, .{.k});
|
var attn = attn_weights.dot(v, .{.k});
|
||||||
return attn.transpose(q.shape());
|
return attn.transpose(q.shape());
|
||||||
@ -983,10 +977,6 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft
|
|||||||
// log.debug("attn_mask : {?}", .{attn_mask});
|
// log.debug("attn_mask : {?}", .{attn_mask});
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
|
|
||||||
if (opts.bias) |bias| {
|
|
||||||
attn_weights = attn_weights.add(bias);
|
|
||||||
}
|
|
||||||
|
|
||||||
const partial = partialSoftmax(attn_weights, .k);
|
const partial = partialSoftmax(attn_weights, .k);
|
||||||
const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
|
const attn = partial.values.dot(v, .{.k}).transpose(q.shape());
|
||||||
|
|
||||||
@ -1021,7 +1011,7 @@ test sdpaMemEfficient {
|
|||||||
const ref_res = try zml.testing.compileAndCall(
|
const ref_res = try zml.testing.compileAndCall(
|
||||||
platform,
|
platform,
|
||||||
sdpa,
|
sdpa,
|
||||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
|
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
|
||||||
);
|
);
|
||||||
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
||||||
{
|
{
|
||||||
@ -1033,7 +1023,7 @@ test sdpaMemEfficient {
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
.{ .attn_mask = mask, .scale = null },
|
||||||
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
|
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) },
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -1049,7 +1039,7 @@ test sdpaMemEfficient {
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
.{ .attn_mask = mask, .scale = null },
|
||||||
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
|
.{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) },
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -1079,7 +1069,7 @@ test "sdpaMemEfficient transposed" {
|
|||||||
const ref_res = try zml.testing.compileAndCall(
|
const ref_res = try zml.testing.compileAndCall(
|
||||||
platform,
|
platform,
|
||||||
sdpa,
|
sdpa,
|
||||||
.{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } },
|
.{ q, k, v, .{ .attn_mask = mask, .scale = null } },
|
||||||
);
|
);
|
||||||
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims());
|
||||||
|
|
||||||
@ -1091,7 +1081,7 @@ test "sdpaMemEfficient transposed" {
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
.{ .attn_mask = mask, .scale = null },
|
||||||
.{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
|
.{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) },
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -1107,7 +1097,7 @@ test "sdpaMemEfficient transposed" {
|
|||||||
q,
|
q,
|
||||||
k,
|
k,
|
||||||
v,
|
v,
|
||||||
.{ .attn_mask = mask, .scale = null, .bias = null },
|
.{ .attn_mask = mask, .scale = null },
|
||||||
.{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
|
.{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) },
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
@ -1127,7 +1117,7 @@ pub const SamplingStrategy = struct {
|
|||||||
/// Returns an integer tensor with a shape similar to the input, but without the .voc axis.
|
/// Returns an integer tensor with a shape similar to the input, but without the .voc axis.
|
||||||
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
|
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } {
|
||||||
if (opts.topk <= 1) {
|
if (opts.topk <= 1) {
|
||||||
const next_tokens = activations.argMax(.voc, .i32).indices.squeeze(.voc);
|
const next_tokens = activations.argMax(.voc).indices.squeeze(.voc);
|
||||||
return .{ next_tokens, rng };
|
return .{ next_tokens, rng };
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1144,7 +1134,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng
|
|||||||
// https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
|
// https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks
|
||||||
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||||
x = x.add(gumbel_noise);
|
x = x.add(gumbel_noise);
|
||||||
const topk_idx = x.argMax(.topk, .i32).indices;
|
const topk_idx = x.argMax(.topk).indices;
|
||||||
|
|
||||||
// topk_idx is indices into topk.values ! so in the range [0, topk]
|
// topk_idx is indices into topk.values ! so in the range [0, topk]
|
||||||
// Convert for the original indices from the full [0, voc] range.
|
// Convert for the original indices from the full [0, voc] range.
|
||||||
@ -1234,7 +1224,7 @@ pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: T
|
|||||||
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
const next_rng, const gumbel_noise = rng.gumbel(x.shape());
|
||||||
x = x.add(gumbel_noise);
|
x = x.add(gumbel_noise);
|
||||||
|
|
||||||
const topk_idx = x.argMax(.topk, .i32).indices;
|
const topk_idx = x.argMax(.topk).indices;
|
||||||
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{});
|
||||||
return .{ next_tokens, next_rng };
|
return .{ next_tokens, next_rng };
|
||||||
}
|
}
|
||||||
|
|||||||
@ -118,9 +118,6 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
if (opts.attn_mask) |attn_mask| {
|
if (opts.attn_mask) |attn_mask| {
|
||||||
bias = bias.add(attn_mask.broad(bias.shape()));
|
bias = bias.add(attn_mask.broad(bias.shape()));
|
||||||
}
|
}
|
||||||
if (opts.bias) |b| {
|
|
||||||
bias = bias.add(b);
|
|
||||||
}
|
|
||||||
|
|
||||||
const mlir_ctx = ctx.mlirCtx();
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
const loc = mlir_ctx.location(@src());
|
const loc = mlir_ctx.location(@src());
|
||||||
|
|||||||
@ -1008,4 +1008,20 @@ pub const Shape = struct {
|
|||||||
try std.testing.expectEqual(1, s.axis(.b));
|
try std.testing.expectEqual(1, s.axis(.b));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn outer(self: Shape, other: Shape) Shape {
|
||||||
|
var res_shape = self;
|
||||||
|
var batching_axes: u8 = 0;
|
||||||
|
for (0..other.rank()) |ax| {
|
||||||
|
if (other.tag(ax) != Shape.TagUnknown) {
|
||||||
|
if (self.hasTag(other.tag(ax))) |batching_ax| {
|
||||||
|
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
|
||||||
|
batching_axes += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res_shape = res_shape.appendDim(other.dim(ax), other.tag(ax));
|
||||||
|
}
|
||||||
|
return res_shape;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|||||||
191
zml/tensor.zig
191
zml/tensor.zig
@ -58,9 +58,9 @@ pub const Tensor = struct {
|
|||||||
options: std.fmt.FormatOptions,
|
options: std.fmt.FormatOptions,
|
||||||
writer: anytype,
|
writer: anytype,
|
||||||
) !void {
|
) !void {
|
||||||
_ = fmt;
|
|
||||||
_ = options;
|
_ = options;
|
||||||
try writer.print("Tensor({_})", .{self._shape});
|
const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||||
|
try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape});
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the shape of a Tensor.
|
/// Returns the shape of a Tensor.
|
||||||
@ -277,7 +277,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
res_shape = res_shape.withDtype(dt);
|
res_shape = res_shape.withDtype(dt);
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "bitCast({})", .{dt});
|
const loc = self.getContext().location(@src(), "bitCast({s})", .{@tagName(dt)});
|
||||||
const op = dialect.stablehlo.bitcast_convert(
|
const op = dialect.stablehlo.bitcast_convert(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -317,13 +317,6 @@ pub const Tensor = struct {
|
|||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
|
||||||
pub fn remainder(self: Tensor, other: Tensor) Tensor {
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
|
||||||
const op = dialect.stablehlo.remainder(self.getContext().mlirCtx(), self.value(), other.value(), loc);
|
|
||||||
return _result(self._shape, op.result(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||||||
///
|
///
|
||||||
/// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details.
|
/// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details.
|
||||||
@ -349,17 +342,17 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
|
/// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'.
|
||||||
pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
|
pub fn shiftLeft(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
return binaryOp(@src(), "shiftLeft", dialect.stablehlo.shift_left)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'.
|
/// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'.
|
||||||
pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor {
|
pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
|
return binaryOp(@src(), "shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'.
|
/// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'.
|
||||||
pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor {
|
pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
|
return binaryOp(@src(), "shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the Cholesky decomposition of the input Tensor.
|
/// Returns the Cholesky decomposition of the input Tensor.
|
||||||
@ -369,7 +362,7 @@ pub const Tensor = struct {
|
|||||||
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
pub fn cholesky(self: Tensor, lower: bool) Tensor {
|
||||||
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower});
|
const loc = self.getContext().location(@src(), "lower={}", .{lower});
|
||||||
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc);
|
||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -379,7 +372,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() });
|
||||||
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() });
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts });
|
||||||
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts);
|
||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -492,7 +485,7 @@ pub const Tensor = struct {
|
|||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts});
|
const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts });
|
||||||
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts);
|
||||||
return _result(sh, op.result(0));
|
return _result(sh, op.result(0));
|
||||||
}
|
}
|
||||||
@ -522,7 +515,7 @@ pub const Tensor = struct {
|
|||||||
/// but it is not guaranteed to be deterministic between implementations.
|
/// but it is not guaranteed to be deterministic between implementations.
|
||||||
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } {
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "rand.bitGen({})", .{sh});
|
const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh});
|
||||||
const op = dialect.stablehlo.rng_bit_generator(
|
const op = dialect.stablehlo.rng_bit_generator(
|
||||||
ctx.mlirCtx(),
|
ctx.mlirCtx(),
|
||||||
self.algorithm,
|
self.algorithm,
|
||||||
@ -646,12 +639,12 @@ pub const Tensor = struct {
|
|||||||
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor {
|
||||||
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()});
|
||||||
|
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts });
|
const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev });
|
||||||
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean));
|
||||||
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev));
|
||||||
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims()));
|
||||||
const op = dialect.stablehlo.rng(ctx, a.value(), b.value(), res_shape.value(), .NORMAL, loc);
|
const op = dialect.stablehlo.rng(ctx.mlirCtx(), a.value(), b.value(), res_shape.value(), .NORMAL, loc);
|
||||||
return _result(sh, op.result(0));
|
return _result(sh, op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -692,7 +685,7 @@ pub const Tensor = struct {
|
|||||||
// Test out the gumbel reparametrization trick
|
// Test out the gumbel reparametrization trick
|
||||||
var x = target_dist.log().withTags(.{.d}).broad(s);
|
var x = target_dist.log().withTags(.{.d}).broad(s);
|
||||||
x = x.add(data);
|
x = x.add(data);
|
||||||
const samples = x.argMax(.d, .i32).indices.squeeze(.d);
|
const samples = x.argMax(.d).indices.squeeze(.d);
|
||||||
|
|
||||||
// count 0, 1, 2 and 3 in samples:
|
// count 0, 1, 2 and 3 in samples:
|
||||||
// - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to N**58
|
// - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to N**58
|
||||||
@ -744,7 +737,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits});
|
||||||
stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits});
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
const loc = self.getContext().location(@src(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits });
|
||||||
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc);
|
||||||
return _result(self._shape, op.result(0));
|
return _result(self._shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -867,7 +860,7 @@ pub const Tensor = struct {
|
|||||||
batch_group_count: i64 = 1,
|
batch_group_count: i64 = 1,
|
||||||
},
|
},
|
||||||
) Tensor {
|
) Tensor {
|
||||||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
const loc = input.getContext().location(@src(), "opts={}", .{opts});
|
||||||
return input.convolution(kernel, .{
|
return input.convolution(kernel, .{
|
||||||
.window_strides = &.{opts.window_strides},
|
.window_strides = &.{opts.window_strides},
|
||||||
.pad_value = opts.padding,
|
.pad_value = opts.padding,
|
||||||
@ -912,7 +905,7 @@ pub const Tensor = struct {
|
|||||||
batch_group_count: i64 = 1,
|
batch_group_count: i64 = 1,
|
||||||
},
|
},
|
||||||
) Tensor {
|
) Tensor {
|
||||||
const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts});
|
const loc = input.getContext().location(@src(), "opts={}", .{opts});
|
||||||
return input.convolution(kernel, .{
|
return input.convolution(kernel, .{
|
||||||
.window_strides = opts.window_strides,
|
.window_strides = opts.window_strides,
|
||||||
.pad_value = opts.padding,
|
.pad_value = opts.padding,
|
||||||
@ -935,37 +928,42 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor containing the element-wise addition of the input Tensors.
|
/// Returns a Tensor containing the element-wise addition of the input Tensors.
|
||||||
pub fn add(self: Tensor, other: Tensor) Tensor {
|
pub fn add(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("add", dialect.stablehlo.add)(self, other);
|
return binaryOp(@src(), "add", dialect.stablehlo.add)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise subtraction of the input Tensors.
|
/// Returns a Tensor containing the element-wise subtraction of the input Tensors.
|
||||||
pub fn sub(self: Tensor, other: Tensor) Tensor {
|
pub fn sub(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("subtract", dialect.stablehlo.subtract)(self, other);
|
return binaryOp(@src(), "subtract", dialect.stablehlo.subtract)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise multiplication of the input Tensors.
|
/// Returns a Tensor containing the element-wise multiplication of the input Tensors.
|
||||||
pub fn mul(self: Tensor, other: Tensor) Tensor {
|
pub fn mul(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("mul", dialect.stablehlo.multiply)(self, other);
|
return binaryOp(@src(), "mul", dialect.stablehlo.multiply)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise division of the input Tensors.
|
/// Returns a Tensor containing the element-wise division of the input Tensors.
|
||||||
pub fn div(self: Tensor, other: Tensor) Tensor {
|
pub fn div(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("div", dialect.stablehlo.divide)(self, other);
|
return binaryOp(@src(), "div", dialect.stablehlo.divide)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise exponentiation of the input Tensors.
|
/// Returns a Tensor containing the element-wise exponentiation of the input Tensors.
|
||||||
pub fn pow(self: Tensor, other: Tensor) Tensor {
|
pub fn pow(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("pow", dialect.stablehlo.power)(self, other);
|
return binaryOp(@src(), "pow", dialect.stablehlo.power)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise maximum operation of the input Tensors.
|
/// Returns a Tensor containing the element-wise maximum operation of the input Tensors.
|
||||||
pub fn maximum(self: Tensor, other: Tensor) Tensor {
|
pub fn maximum(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("maximum", dialect.stablehlo.maximum)(self, other);
|
return binaryOp(@src(), "maximum", dialect.stablehlo.maximum)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise minimum operation of the input Tensors.
|
/// Returns a Tensor containing the element-wise minimum operation of the input Tensors.
|
||||||
pub fn minimum(self: Tensor, other: Tensor) Tensor {
|
pub fn minimum(self: Tensor, other: Tensor) Tensor {
|
||||||
return binaryOp("minimum", dialect.stablehlo.minimum)(self, other);
|
return binaryOp(@src(), "minimum", dialect.stablehlo.minimum)(self, other);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'.
|
||||||
|
pub fn remainder(self: Tensor, other: Tensor) Tensor {
|
||||||
|
return binaryOp(@src(), "remainder", dialect.stablehlo.remainder)(self, other);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
|
/// Returns a Tensor containing the element-wise addition of the input Tensor with a constant.
|
||||||
@ -988,9 +986,9 @@ pub const Tensor = struct {
|
|||||||
/// Returns a Tensor containing the element-wise logical operation of the input Tensors.
|
/// Returns a Tensor containing the element-wise logical operation of the input Tensors.
|
||||||
pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor {
|
pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor {
|
||||||
return switch (logical_op) {
|
return switch (logical_op) {
|
||||||
.OR => binaryOp("or", dialect.stablehlo.or_)(self, other),
|
.OR => binaryOp(@src(), "or", dialect.stablehlo.or_)(self, other),
|
||||||
.XOR => binaryOp("xor", dialect.stablehlo.xor)(self, other),
|
.XOR => binaryOp(@src(), "xor", dialect.stablehlo.xor)(self, other),
|
||||||
.AND => binaryOp("and", dialect.stablehlo.and_)(self, other),
|
.AND => binaryOp(@src(), "and", dialect.stablehlo.and_)(self, other),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1007,16 +1005,16 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise conversion to another type.
|
/// Returns a Tensor containing the element-wise conversion to another type.
|
||||||
pub fn convert(self: Tensor, dt: DataType) Tensor {
|
pub fn convert(self: Tensor, to: DataType) Tensor {
|
||||||
if (dt == self.dtype()) {
|
if (to == self.dtype()) {
|
||||||
return self;
|
return self;
|
||||||
}
|
}
|
||||||
|
|
||||||
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), dt)).as(mlir.Type).?;
|
const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?;
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dtype={}", .{dt});
|
const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) });
|
||||||
|
|
||||||
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc);
|
||||||
return _result(self._shape.withDtype(dt), op.result(0));
|
return _result(self._shape.withDtype(to), op.result(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
|
/// Returns a Tensor containing the element-wise rounding operation of the input Tensor.
|
||||||
@ -1174,7 +1172,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const mlir_ctx = lhs.getContext().mlirCtx();
|
const mlir_ctx = lhs.getContext().mlirCtx();
|
||||||
const loc = mlir_ctx.location(@src());
|
const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any}", .{ lhs, rhs, contracting_axes, batching_axes });
|
||||||
const op = dialect.stablehlo.dot_general(
|
const op = dialect.stablehlo.dot_general(
|
||||||
mlir_ctx,
|
mlir_ctx,
|
||||||
lhs.value(),
|
lhs.value(),
|
||||||
@ -1375,7 +1373,7 @@ pub const Tensor = struct {
|
|||||||
return self.reshape(res_shape);
|
return self.reshape(res_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation });
|
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation });
|
||||||
const op = dialect.stablehlo.transpose(
|
const op = dialect.stablehlo.transpose(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1408,7 +1406,7 @@ pub const Tensor = struct {
|
|||||||
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n });
|
||||||
const new_shape = self._shape.set(a, n).insert(a + 1, .{ ._ = new_dim });
|
const new_shape = self._shape.set(a, n).insert(a + 1, .{ ._ = new_dim });
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, n={}", .{ axis_, n });
|
const loc = self.getContext().location(@src(), "axis={}, n={}", .{ axis_, n });
|
||||||
const reshaped_val = dialect.stablehlo.reshape(
|
const reshaped_val = dialect.stablehlo.reshape(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1425,7 +1423,7 @@ pub const Tensor = struct {
|
|||||||
pub fn splitAxis(self: Tensor, ax: anytype, split_shape: anytype) Tensor {
|
pub fn splitAxis(self: Tensor, ax: anytype, split_shape: anytype) Tensor {
|
||||||
const new_shape = self._shape.splitAxis(ax, split_shape);
|
const new_shape = self._shape.splitAxis(ax, split_shape);
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "splitAxis({}, {any})", .{ ax, split_shape });
|
const loc = self.getContext().location(@src(), "splitAxis({}, {any})", .{ ax, split_shape });
|
||||||
const reshaped_val = dialect.stablehlo.reshape(
|
const reshaped_val = dialect.stablehlo.reshape(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1463,8 +1461,7 @@ pub const Tensor = struct {
|
|||||||
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
// stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
|
||||||
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1));
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_});
|
const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ });
|
||||||
|
|
||||||
const reshaped_val = dialect.stablehlo.reshape(
|
const reshaped_val = dialect.stablehlo.reshape(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -1582,8 +1579,9 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const res_shape = tensors[0]._shape.set(a, concatenated_dim);
|
const res_shape = tensors[0]._shape.set(a, concatenated_dim);
|
||||||
const loc = tensors[0].getContext().mlirCtx().location(@src()).namedFmt(tensors[0].getContext().mlirCtx(), "axis={}", .{axis_});
|
const ctx = tensors[0].getContext();
|
||||||
const op = dialect.stablehlo.concatenate(tensors[0].getContext().mlirCtx(), buffer[0..tensors.len], a, loc);
|
const loc = ctx.location(@src(), "axis={}", .{axis_});
|
||||||
|
const op = dialect.stablehlo.concatenate(ctx.mlirCtx(), buffer[0..tensors.len], a, loc);
|
||||||
// log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape });
|
// log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape });
|
||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -1601,7 +1599,7 @@ pub const Tensor = struct {
|
|||||||
const res_shape = shape0.insertTag(axis_, 1, tag);
|
const res_shape = shape0.insertTag(axis_, 1, tag);
|
||||||
|
|
||||||
for (tensors[1..]) |tensor| {
|
for (tensors[1..]) |tensor| {
|
||||||
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 });
|
stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape });
|
||||||
}
|
}
|
||||||
|
|
||||||
var reshaped: [32]Tensor = undefined;
|
var reshaped: [32]Tensor = undefined;
|
||||||
@ -1748,7 +1746,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt });
|
const loc = ctx.location(@src(), "arange({}, dtype={})", .{ args, dt });
|
||||||
|
|
||||||
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable;
|
||||||
const sh = Shape.init(.{n_steps}, dt);
|
const sh = Shape.init(.{n_steps}, dt);
|
||||||
@ -1775,9 +1773,10 @@ pub const Tensor = struct {
|
|||||||
const a = sh.axis(axis_);
|
const a = sh.axis(axis_);
|
||||||
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||||
const res_shape = sh.withDtype(dt);
|
const res_shape = sh.withDtype(dt);
|
||||||
const mlir_ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current();
|
||||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a });
|
const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a });
|
||||||
|
|
||||||
|
const mlir_ctx = ctx.mlirCtx();
|
||||||
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
||||||
return _result(res_shape, op.result(0));
|
return _result(res_shape, op.result(0));
|
||||||
}
|
}
|
||||||
@ -1795,7 +1794,7 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt});
|
||||||
|
|
||||||
const ctx = CompilationContext.current();
|
const ctx = CompilationContext.current();
|
||||||
const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt });
|
const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt });
|
||||||
|
|
||||||
const sh = Shape.init(.{args.steps}, dt);
|
const sh = Shape.init(.{args.steps}, dt);
|
||||||
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc);
|
||||||
@ -1838,7 +1837,7 @@ pub const Tensor = struct {
|
|||||||
const sh = Shape.init(dimz, val.dtype());
|
const sh = Shape.init(dimz, val.dtype());
|
||||||
const singleton_sh = Shape.init(.{}, val.dtype());
|
const singleton_sh = Shape.init(.{}, val.dtype());
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val });
|
const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val });
|
||||||
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh);
|
||||||
|
|
||||||
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
|
var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type|
|
||||||
@ -1871,22 +1870,8 @@ pub const Tensor = struct {
|
|||||||
return self.mul(other);
|
return self.mul(other);
|
||||||
}
|
}
|
||||||
|
|
||||||
const other_shape = other.shape();
|
const res_shape = self.shape().outer(other.shape());
|
||||||
var res_shape = self.shape();
|
return self.broad(res_shape).mul(other.broad(res_shape));
|
||||||
var batching_axes: u8 = 0;
|
|
||||||
for (0..other.rank()) |ax| {
|
|
||||||
if (other_shape.tag(ax) != Shape.TagUnknown) {
|
|
||||||
if (self.shape().hasTag(other_shape.tag(ax))) |batching_ax| {
|
|
||||||
stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other });
|
|
||||||
batching_axes += 1;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
res_shape = res_shape.appendDim(other_shape.dim(ax), other_shape.tag(ax));
|
|
||||||
}
|
|
||||||
const left = self.broad(res_shape);
|
|
||||||
const right = other.broad(res_shape);
|
|
||||||
return left.mul(right);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a tensor and a shape of the same rank,
|
/// Given a tensor and a shape of the same rank,
|
||||||
@ -1904,9 +1889,10 @@ pub const Tensor = struct {
|
|||||||
const d = self.dim(self_ax);
|
const d = self.dim(self_ax);
|
||||||
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
|
stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax });
|
||||||
}
|
}
|
||||||
const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?;
|
const ctx = self.getContext();
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({}, {any}, axes={d})", .{ self, res_shape, axes_ });
|
const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?;
|
||||||
const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc);
|
const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ });
|
||||||
|
const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc);
|
||||||
|
|
||||||
return _result(res_shape, broadcast_op.result(0));
|
return _result(res_shape, broadcast_op.result(0));
|
||||||
}
|
}
|
||||||
@ -1959,7 +1945,7 @@ pub const Tensor = struct {
|
|||||||
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
pub fn reshape(self: Tensor, output_shape_: anytype) Tensor {
|
||||||
const output_shape = self._shape.reshape(output_shape_);
|
const output_shape = self._shape.reshape(output_shape_);
|
||||||
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
|
const tensor_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape);
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reshape({any})", .{output_shape});
|
const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape});
|
||||||
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc);
|
||||||
return _result(output_shape, reshape_value.result(0));
|
return _result(output_shape, reshape_value.result(0));
|
||||||
}
|
}
|
||||||
@ -2050,7 +2036,7 @@ pub const Tensor = struct {
|
|||||||
pub fn reverse(self: Tensor, axes_: anytype) Tensor {
|
pub fn reverse(self: Tensor, axes_: anytype) Tensor {
|
||||||
const actual_axes = self._shape.axes(axes_);
|
const actual_axes = self._shape.axes(axes_);
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({any})", .{axes_});
|
const loc = self.getContext().location(@src(), "reverse({any})", .{axes_});
|
||||||
const reverse_op = dialect.stablehlo.reverse(self.getContext().mlirCtx(), self.value(), toI64(actual_axes.constSlice()), loc);
|
const reverse_op = dialect.stablehlo.reverse(self.getContext().mlirCtx(), self.value(), toI64(actual_axes.constSlice()), loc);
|
||||||
return _result(self._shape, reverse_op.result(0));
|
return _result(self._shape, reverse_op.result(0));
|
||||||
}
|
}
|
||||||
@ -2257,7 +2243,8 @@ pub const Tensor = struct {
|
|||||||
/// while gatherValues, always copy values one by one, and as such don't have the same issues.
|
/// while gatherValues, always copy values one by one, and as such don't have the same issues.
|
||||||
/// In our example the contiguous dimension .d is not sliced
|
/// In our example the contiguous dimension .d is not sliced
|
||||||
/// and gatherSlices can copy data by group of C'*D elements.
|
/// and gatherSlices can copy data by group of C'*D elements.
|
||||||
pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
|
pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor {
|
||||||
|
const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32);
|
||||||
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
|
// scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices });
|
||||||
|
|
||||||
const tagged_api = slice_shape.isFullyTagged();
|
const tagged_api = slice_shape.isFullyTagged();
|
||||||
@ -2307,7 +2294,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src());
|
const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices });
|
||||||
const gather_op = dialect.stablehlo.gather(
|
const gather_op = dialect.stablehlo.gather(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -2331,6 +2318,12 @@ pub const Tensor = struct {
|
|||||||
const zml = @import("zml.zig");
|
const zml = @import("zml.zig");
|
||||||
const platform = zml.testing.env();
|
const platform = zml.testing.env();
|
||||||
|
|
||||||
|
const Local = struct {
|
||||||
|
pub fn _gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor {
|
||||||
|
return self.gatherSlices(slice_shape, indices, opts);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
{
|
{
|
||||||
// Only test shapes
|
// Only test shapes
|
||||||
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform);
|
||||||
@ -2367,7 +2360,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const mod = try zml.compileFn(
|
const mod = try zml.compileFn(
|
||||||
std.testing.allocator,
|
std.testing.allocator,
|
||||||
gatherSlices,
|
Local._gatherSlices,
|
||||||
.{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } },
|
.{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } },
|
||||||
platform,
|
platform,
|
||||||
);
|
);
|
||||||
@ -2383,7 +2376,7 @@ pub const Tensor = struct {
|
|||||||
const start_indices = (try zml.Buffer.fromArray(platform, [2][2]i32{ .{ 2, 1 }, .{ 0, 3 } })).withTags(.{ .n, ._ });
|
const start_indices = (try zml.Buffer.fromArray(platform, [2][2]i32{ .{ 2, 1 }, .{ 0, 3 } })).withTags(.{ .n, ._ });
|
||||||
defer start_indices.deinit();
|
defer start_indices.deinit();
|
||||||
|
|
||||||
const result = try zml.testing.compileAndCall(platform, gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
const result = try zml.testing.compileAndCall(platform, Local._gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} });
|
||||||
|
|
||||||
const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{
|
const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{
|
||||||
.{
|
.{
|
||||||
@ -2730,15 +2723,14 @@ pub const Tensor = struct {
|
|||||||
/// Stable argmax:
|
/// Stable argmax:
|
||||||
/// * bubbles up Nan
|
/// * bubbles up Nan
|
||||||
/// * in case of equality the smallest index matching the maximum
|
/// * in case of equality the smallest index matching the maximum
|
||||||
pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes {
|
pub fn argMax(x: Tensor, axis_: anytype) ArgMaxRes {
|
||||||
stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype});
|
|
||||||
|
|
||||||
const a = x.axis(axis_);
|
const a = x.axis(axis_);
|
||||||
|
const dt: DataType = if (x.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||||
|
|
||||||
return ops.reduce(
|
return ops.reduce(
|
||||||
ArgMaxRes.cmp,
|
ArgMaxRes.cmp,
|
||||||
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x.shape(), &.{a}) },
|
.{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, dt).broadcast(x.shape(), &.{a}) },
|
||||||
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) },
|
.{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, dt) },
|
||||||
&.{a},
|
&.{a},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -2749,7 +2741,7 @@ pub const Tensor = struct {
|
|||||||
const allocator = std.testing.allocator;
|
const allocator = std.testing.allocator;
|
||||||
const ArgMaxTest = struct {
|
const ArgMaxTest = struct {
|
||||||
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
pub fn forward(x: Tensor) Tensor.ArgMaxRes {
|
||||||
return x.argMax(1, .i32);
|
return x.argMax(1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -3097,7 +3089,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const new_shape = self._shape.set(a, slice_.len);
|
const new_shape = self._shape.set(a, slice_.len);
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len });
|
const loc = self.getContext().location(@src(), "dynSlice({}, len={})", .{ axis_, slice_.len });
|
||||||
|
|
||||||
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
|
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
|
||||||
start_indices[a] = slice_.start.value();
|
start_indices[a] = slice_.start.value();
|
||||||
@ -3397,7 +3389,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape });
|
||||||
|
|
||||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", .{@tagName(direction)});
|
const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)});
|
||||||
const op = dialect.stablehlo.compare(
|
const op = dialect.stablehlo.compare(
|
||||||
self.getContext().mlirCtx(),
|
self.getContext().mlirCtx(),
|
||||||
self.value(),
|
self.value(),
|
||||||
@ -3568,17 +3560,31 @@ pub const Tensor = struct {
|
|||||||
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
||||||
pub fn any(self: Tensor, axis_: anytype) Tensor {
|
pub fn any(self: Tensor, axis_: anytype) Tensor {
|
||||||
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero()));
|
||||||
const red = ops.reduce(
|
return ops.reduce(
|
||||||
struct {
|
struct {
|
||||||
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||||||
return res.logical(.OR, x);
|
return res.logical(.OR, x);
|
||||||
}
|
}
|
||||||
}.acc,
|
}.acc,
|
||||||
pred,
|
pred,
|
||||||
Tensor.scalar(0, pred.dtype()),
|
Tensor.scalar(false, .bool),
|
||||||
|
&.{self.axis(axis_)},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis.
|
||||||
|
pub fn all(self: Tensor, axis_: anytype) Tensor {
|
||||||
|
const pred = if (self.dtype() == .bool) self else self.cmp(.NE, Tensor.scalar(0, self.dtype()));
|
||||||
|
return ops.reduce(
|
||||||
|
struct {
|
||||||
|
pub fn acc(x: Tensor, res: Tensor) Tensor {
|
||||||
|
return res.logical(.AND, x);
|
||||||
|
}
|
||||||
|
}.acc,
|
||||||
|
pred,
|
||||||
|
Tensor.scalar(true, .bool),
|
||||||
&.{self.axis(axis_)},
|
&.{self.axis(axis_)},
|
||||||
);
|
);
|
||||||
return red;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Given a set of N vectors of lengths A, B, C, D,
|
/// Given a set of N vectors of lengths A, B, C, D,
|
||||||
@ -3701,6 +3707,7 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn binaryOp(
|
fn binaryOp(
|
||||||
|
src: std.builtin.SourceLocation,
|
||||||
op_name: []const u8,
|
op_name: []const u8,
|
||||||
op_fn: fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation,
|
op_fn: fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation,
|
||||||
) fn (Tensor, Tensor) Tensor {
|
) fn (Tensor, Tensor) Tensor {
|
||||||
@ -3718,9 +3725,9 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape });
|
||||||
|
|
||||||
const mlirCtx = self.getContext().mlirCtx();
|
const ctx = self.getContext();
|
||||||
const location = mlirCtx.location(@src());
|
const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other });
|
||||||
const ret = @call(.auto, op_fn, .{ mlirCtx, self.value(), other.value(), location });
|
const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(), self.value(), other.value(), location });
|
||||||
return _result(self._shape, ret.result(0));
|
return _result(self._shape, ret.result(0));
|
||||||
}
|
}
|
||||||
}.binaryOpHelper;
|
}.binaryOpHelper;
|
||||||
|
|||||||
@ -25,6 +25,7 @@ pub const nn = @import("nn.zig");
|
|||||||
pub const module = @import("module.zig");
|
pub const module = @import("module.zig");
|
||||||
pub const meta = @import("meta.zig");
|
pub const meta = @import("meta.zig");
|
||||||
pub const platform = @import("platform.zig");
|
pub const platform = @import("platform.zig");
|
||||||
|
pub const pjrt = @import("pjrtx.zig");
|
||||||
pub const testing = @import("testing.zig");
|
pub const testing = @import("testing.zig");
|
||||||
pub const torch = @import("torch.zig");
|
pub const torch = @import("torch.zig");
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user