From acc492454f1d20f396df73207b8f3e2c5108e4a1 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 1 Jan 2024 15:31:41 +0000 Subject: [PATCH] Add operator name to source locations and introduce QoL enhancements: remove bias from sdpa, support shape literals in gatherSlices, add Shape.outer, Tensor.all, and infer argMax dtype. --- zml/nn.zig | 30 +++----- zml/nn/cuda.zig | 3 - zml/shape.zig | 16 ++++ zml/tensor.zig | 191 +++++++++++++++++++++++++----------------------- zml/zml.zig | 1 + 5 files changed, 126 insertions(+), 115 deletions(-) diff --git a/zml/nn.zig b/zml/nn.zig index 7ab1137..6f756e9 100644 --- a/zml/nn.zig +++ b/zml/nn.zig @@ -716,7 +716,6 @@ pub fn causalAttnMask( pub const SdpaOpts = struct { attn_mask: ?Tensor = null, scale: ?Tensor = null, - bias: ?Tensor = null, allow_cudnn: bool = true, // TODO: put a callback instead of all this field, // so that @@ -769,12 +768,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { // log.debug("attn_weights : {}", .{attn_weights}); // log.debug("attn_mask : {?}", .{attn_mask}); if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); - - attn_weights = attn_weights.convert(.f32); - if (opts.bias) |bias| { - attn_weights = attn_weights.add(bias); - } - attn_weights = attn_weights.softmax(.k).convert(q.dtype()); + attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype()); var attn = attn_weights.dot(v, .{.k}); return attn.transpose(q.shape()); @@ -983,10 +977,6 @@ pub fn sdpaChunk(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) PartialSoft // log.debug("attn_mask : {?}", .{attn_mask}); if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape())); - if (opts.bias) |bias| { - attn_weights = attn_weights.add(bias); - } - const partial = partialSoftmax(attn_weights, .k); const attn = partial.values.dot(v, .{.k}).transpose(q.shape()); @@ -1021,7 +1011,7 @@ test sdpaMemEfficient { const ref_res = try 
zml.testing.compileAndCall( platform, sdpa, - .{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } }, + .{ q, k, v, .{ .attn_mask = mask, .scale = null } }, ); try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); { @@ -1033,7 +1023,7 @@ test sdpaMemEfficient { q, k, v, - .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .attn_mask = mask, .scale = null }, .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 4) }, }, ); @@ -1049,7 +1039,7 @@ test sdpaMemEfficient { q, k, v, - .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .attn_mask = mask, .scale = null }, .{ .q_chunk_size = 256, .k_chunk_size = @divExact(512, 16) }, }, ); @@ -1079,7 +1069,7 @@ test "sdpaMemEfficient transposed" { const ref_res = try zml.testing.compileAndCall( platform, sdpa, - .{ q, k, v, .{ .attn_mask = mask, .scale = null, .bias = null } }, + .{ q, k, v, .{ .attn_mask = mask, .scale = null } }, ); try std.testing.expectEqualSlices(i64, q.shape().dims(), ref_res.shape().dims()); @@ -1091,7 +1081,7 @@ test "sdpaMemEfficient transposed" { q, k, v, - .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .attn_mask = mask, .scale = null }, .{ .q_chunk_size = @divExact(512, 2), .k_chunk_size = @divExact(512, 4) }, }, ); @@ -1107,7 +1097,7 @@ test "sdpaMemEfficient transposed" { q, k, v, - .{ .attn_mask = mask, .scale = null, .bias = null }, + .{ .attn_mask = mask, .scale = null }, .{ .q_chunk_size = 512, .k_chunk_size = @divExact(512, 4) }, }, ); @@ -1127,7 +1117,7 @@ pub const SamplingStrategy = struct { /// Returns an integer tensor with a shape similar to the input, but without the .voc axis. 
pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng) struct { Tensor, Tensor.Rng } { if (opts.topk <= 1) { - const next_tokens = activations.argMax(.voc, .i32).indices.squeeze(.voc); + const next_tokens = activations.argMax(.voc).indices.squeeze(.voc); return .{ next_tokens, rng }; } @@ -1144,7 +1134,7 @@ pub fn sampleTokens(activations: Tensor, opts: SamplingStrategy, rng: Tensor.Rng // https://en.wikipedia.org/wiki/Gumbel_distribution#Gumbel_reparametrization_tricks const next_rng, const gumbel_noise = rng.gumbel(x.shape()); x = x.add(gumbel_noise); - const topk_idx = x.argMax(.topk, .i32).indices; + const topk_idx = x.argMax(.topk).indices; // topk_idx is indices into topk.values ! so in the range [0, topk] // Convert for the original indices from the full [0, voc] range. @@ -1234,7 +1224,7 @@ pub fn sampleTokensDynamic(logits: Tensor, opts: DynamicSamplingStrategy, rng: T const next_rng, const gumbel_noise = rng.gumbel(x.shape()); x = x.add(gumbel_noise); - const topk_idx = x.argMax(.topk, .i32).indices; + const topk_idx = x.argMax(.topk).indices; const next_tokens = topk_indices.gatherValues(.voc, topk_idx.squeeze(.topk), .{}); return .{ next_tokens, next_rng }; } diff --git a/zml/nn/cuda.zig b/zml/nn/cuda.zig index 4c840cb..854da72 100644 --- a/zml/nn/cuda.zig +++ b/zml/nn/cuda.zig @@ -118,9 +118,6 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor { if (opts.attn_mask) |attn_mask| { bias = bias.add(attn_mask.broad(bias.shape())); } - if (opts.bias) |b| { - bias = bias.add(b); - } const mlir_ctx = ctx.mlirCtx(); const loc = mlir_ctx.location(@src()); diff --git a/zml/shape.zig b/zml/shape.zig index c032ba2..54fa3a9 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -1008,4 +1008,20 @@ pub const Shape = struct { try std.testing.expectEqual(1, s.axis(.b)); } } + + pub fn outer(self: Shape, other: Shape) Shape { + var res_shape = self; + var batching_axes: u8 = 0; + for (0..other.rank()) |ax| { + if 
(other.tag(ax) != Shape.TagUnknown) { + if (self.hasTag(other.tag(ax))) |batching_ax| { + stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other }); + batching_axes += 1; + } + } + + res_shape = res_shape.appendDim(other.dim(ax), other.tag(ax)); + } + return res_shape; + } }; diff --git a/zml/tensor.zig b/zml/tensor.zig index c375e24..8345a40 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -58,9 +58,9 @@ pub const Tensor = struct { options: std.fmt.FormatOptions, writer: anytype, ) !void { - _ = fmt; _ = options; - try writer.print("Tensor({_})", .{self._shape}); + const bare_fmt = fmt.len == 1 and fmt[0] == '_'; + try writer.print(if (bare_fmt) "{_}" else "Tensor({_})", .{self._shape}); } /// Returns the shape of a Tensor. @@ -277,7 +277,7 @@ pub const Tensor = struct { res_shape = res_shape.withDtype(dt); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "bitCast({})", .{dt}); + const loc = self.getContext().location(@src(), "bitCast({s})", .{@tagName(dt)}); const op = dialect.stablehlo.bitcast_convert( self.getContext().mlirCtx(), self.value(), @@ -317,13 +317,6 @@ pub const Tensor = struct { return _result(self._shape, op.result(0)); } - /// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'. - pub fn remainder(self: Tensor, other: Tensor) Tensor { - const loc = self.getContext().mlirCtx().location(@src()); - const op = dialect.stablehlo.remainder(self.getContext().mlirCtx(), self.value(), other.value(), loc); - return _result(self._shape, op.result(0)); - } - /// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'. /// /// See https://pytorch.org/docs/stable/generated/torch.fmod.html for more details. 
@@ -349,17 +342,17 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise left-shift operation of 'self' by 'other'. pub fn shiftLeft(self: Tensor, other: Tensor) Tensor { - return binaryOp("shiftLeft", dialect.stablehlo.shift_left)(self, other); + return binaryOp(@src(), "shiftLeft", dialect.stablehlo.shift_left)(self, other); } /// Returns a Tensor containing the element-wise arithmetic right-shift operation of 'self' by 'other'. pub fn shiftRightArithmetic(self: Tensor, other: Tensor) Tensor { - return binaryOp("shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other); + return binaryOp(@src(), "shiftRightArithmetic", dialect.stablehlo.shift_right_arithmetic)(self, other); } /// Returns a Tensor containing the element-wise logical right-shift operation of 'self' by 'other'. pub fn shiftRightLogical(self: Tensor, other: Tensor) Tensor { - return binaryOp("shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other); + return binaryOp(@src(), "shiftRightLogical", dialect.stablehlo.shift_right_logical)(self, other); } /// Returns the Cholesky decomposition of the input Tensor. 
@@ -369,7 +362,7 @@ pub const Tensor = struct { pub fn cholesky(self: Tensor, lower: bool) Tensor { stdx.debug.assert(self.rank() <= 2, "cholesky expects tensor rank to be <= 2, got {}", .{self.rank()}); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "lower={}", .{lower}); + const loc = self.getContext().location(@src(), "lower={}", .{lower}); const op = dialect.stablehlo.cholesky(self.getContext().mlirCtx(), self.value(), lower, loc); return _result(self._shape, op.result(0)); } @@ -379,7 +372,7 @@ pub const Tensor = struct { stdx.debug.assert(self.dtype() == other.dtype(), "triangularSolve expects tensors to be of the same type, got {} and {}", .{ self.dtype(), other.dtype() }); stdx.debug.assert(self.rank() <= 2 and self.rank() == other.rank(), "triangularSolve expects tensors to have the same rank and be <= 2, got {} and {}", .{ self.rank(), other.rank() }); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts}); + const loc = self.getContext().location(@src(), "triangularSolve({_}, {})", .{ self, opts }); const op = dialect.stablehlo.triangular_solve(self.getContext().mlirCtx(), self.value(), other.value(), loc, opts); return _result(self._shape, op.result(0)); } @@ -492,7 +485,7 @@ pub const Tensor = struct { }, }; - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "opts={}", .{opts}); + const loc = self.getContext().location(@src(), "fft({_},{})", .{ self, opts }); const op = dialect.stablehlo.fft(self.getContext().mlirCtx(), self.value(), loc, opts); return _result(sh, op.result(0)); } @@ -522,7 +515,7 @@ pub const Tensor = struct { /// but it is not guaranteed to be deterministic between implementations. 
pub fn bitGenerator(self: Rng, sh: Shape) struct { Rng, Tensor } { const ctx = CompilationContext.current(); - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "rand.bitGen({})", .{sh}); + const loc = ctx.location(@src(), "rand.bitGen({_})", .{sh}); const op = dialect.stablehlo.rng_bit_generator( ctx.mlirCtx(), self.algorithm, @@ -646,12 +639,12 @@ pub const Tensor = struct { pub fn normal(sh: Shape, opts: struct { mean: f64 = 0, stddev: f64 = 1 }) Tensor { stdx.debug.assert(sh.dtype().isFloat(), "normal expects tensor type to be a float, got {}", .{sh.dtype()}); - const ctx = CompilationContext.current().mlirCtx(); - const loc = ctx.location(@src()).namedFmt(ctx, "rand.normal({}, opts={})", .{ sh, opts }); + const ctx = CompilationContext.current(); + const loc = ctx.location(@src(), "rand.normal({_}, mean={},stddev={})", .{ sh, opts.mean, opts.stddev }); const a = Tensor.constant(.{}, Data.init(sh.dtype(), opts.mean)); const b = Tensor.constant(.{}, Data.init(sh.dtype(), opts.stddev)); const res_shape = Tensor.constantTensor(HostBuffer.fromSlice(.{sh.rank()}, sh.dims())); - const op = dialect.stablehlo.rng(ctx, a.value(), b.value(), res_shape.value(), .NORMAL, loc); + const op = dialect.stablehlo.rng(ctx.mlirCtx(), a.value(), b.value(), res_shape.value(), .NORMAL, loc); return _result(sh, op.result(0)); } @@ -692,7 +685,7 @@ pub const Tensor = struct { // Test out the gumbel reparametrization trick var x = target_dist.log().withTags(.{.d}).broad(s); x = x.add(data); - const samples = x.argMax(.d, .i32).indices.squeeze(.d); + const samples = x.argMax(.d).indices.squeeze(.d); // count 0, 1, 2 and 3 in samples: // - map 0 to 1, 1 to 2**16, 2 to 2**32, 3 to N**58 @@ -744,7 +737,7 @@ pub const Tensor = struct { stdx.debug.assert(1 <= exponent_bits, "reducePrecision expects 'exponent_bits' to be >= 1, got {}", .{exponent_bits}); stdx.debug.assert(0 <= mantissa_bits, "reducePrecision expects 'mantissa_bits' to be positive, got {}", .{mantissa_bits}); - 
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits }); + const loc = self.getContext().location(@src(), "reducePrecision(exponent_bits={}, mantissa_bits={})", .{ exponent_bits, mantissa_bits }); const op = dialect.stablehlo.reduce_precision(self.getContext().mlirCtx(), self.value(), exponent_bits, mantissa_bits, loc); return _result(self._shape, op.result(0)); } @@ -867,7 +860,7 @@ pub const Tensor = struct { batch_group_count: i64 = 1, }, ) Tensor { - const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts}); + const loc = input.getContext().location(@src(), "opts={}", .{opts}); return input.convolution(kernel, .{ .window_strides = &.{opts.window_strides}, .pad_value = opts.padding, @@ -912,7 +905,7 @@ pub const Tensor = struct { batch_group_count: i64 = 1, }, ) Tensor { - const loc = input.getContext().mlirCtx().location(@src()).namedFmt(input.getContext().mlirCtx(), "opts={}", .{opts}); + const loc = input.getContext().location(@src(), "opts={}", .{opts}); return input.convolution(kernel, .{ .window_strides = opts.window_strides, .pad_value = opts.padding, @@ -935,37 +928,42 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise addition of the input Tensors. pub fn add(self: Tensor, other: Tensor) Tensor { - return binaryOp("add", dialect.stablehlo.add)(self, other); + return binaryOp(@src(), "add", dialect.stablehlo.add)(self, other); } /// Returns a Tensor containing the element-wise subtraction of the input Tensors. pub fn sub(self: Tensor, other: Tensor) Tensor { - return binaryOp("subtract", dialect.stablehlo.subtract)(self, other); + return binaryOp(@src(), "subtract", dialect.stablehlo.subtract)(self, other); } /// Returns a Tensor containing the element-wise multiplication of the input Tensors. 
pub fn mul(self: Tensor, other: Tensor) Tensor { - return binaryOp("mul", dialect.stablehlo.multiply)(self, other); + return binaryOp(@src(), "mul", dialect.stablehlo.multiply)(self, other); } /// Returns a Tensor containing the element-wise division of the input Tensors. pub fn div(self: Tensor, other: Tensor) Tensor { - return binaryOp("div", dialect.stablehlo.divide)(self, other); + return binaryOp(@src(), "div", dialect.stablehlo.divide)(self, other); } /// Returns a Tensor containing the element-wise exponentiation of the input Tensors. pub fn pow(self: Tensor, other: Tensor) Tensor { - return binaryOp("pow", dialect.stablehlo.power)(self, other); + return binaryOp(@src(), "pow", dialect.stablehlo.power)(self, other); } /// Returns a Tensor containing the element-wise maximum operation of the input Tensors. pub fn maximum(self: Tensor, other: Tensor) Tensor { - return binaryOp("maximum", dialect.stablehlo.maximum)(self, other); + return binaryOp(@src(), "maximum", dialect.stablehlo.maximum)(self, other); } /// Returns a Tensor containing the element-wise minimum operation of the input Tensors. pub fn minimum(self: Tensor, other: Tensor) Tensor { - return binaryOp("minimum", dialect.stablehlo.minimum)(self, other); + return binaryOp(@src(), "minimum", dialect.stablehlo.minimum)(self, other); + } + + /// Returns a Tensor containing the element-wise remainder of dividend 'self' and divisor 'other'. + pub fn remainder(self: Tensor, other: Tensor) Tensor { + return binaryOp(@src(), "remainder", dialect.stablehlo.remainder)(self, other); } /// Returns a Tensor containing the element-wise addition of the input Tensor with a constant. @@ -988,9 +986,9 @@ pub const Tensor = struct { /// Returns a Tensor containing the element-wise logical operation of the input Tensors. 
pub fn logical(self: Tensor, comptime logical_op: LogicalOp, other: Tensor) Tensor { return switch (logical_op) { - .OR => binaryOp("or", dialect.stablehlo.or_)(self, other), - .XOR => binaryOp("xor", dialect.stablehlo.xor)(self, other), - .AND => binaryOp("and", dialect.stablehlo.and_)(self, other), + .OR => binaryOp(@src(), "or", dialect.stablehlo.or_)(self, other), + .XOR => binaryOp(@src(), "xor", dialect.stablehlo.xor)(self, other), + .AND => binaryOp(@src(), "and", dialect.stablehlo.and_)(self, other), }; } @@ -1007,16 +1005,16 @@ pub const Tensor = struct { } /// Returns a Tensor containing the element-wise conversion to another type. - pub fn convert(self: Tensor, dt: DataType) Tensor { - if (dt == self.dtype()) { + pub fn convert(self: Tensor, to: DataType) Tensor { + if (to == self.dtype()) { return self; } - const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), dt)).as(mlir.Type).?; - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dtype={}", .{dt}); + const res_type = mlir.RankedTensorType.init(self.dims(), mlir.ext.Type.fromDType(self.getContext().mlirCtx(), to)).as(mlir.Type).?; + const loc = self.getContext().location(@src(), "convert({_},to={s})", .{ self, @tagName(to) }); const op = dialect.stablehlo.convert(self.getContext().mlirCtx(), self.value(), res_type, loc); - return _result(self._shape.withDtype(dt), op.result(0)); + return _result(self._shape.withDtype(to), op.result(0)); } /// Returns a Tensor containing the element-wise rounding operation of the input Tensor. 
@@ -1174,7 +1172,7 @@ pub const Tensor = struct { } const mlir_ctx = lhs.getContext().mlirCtx(); - const loc = mlir_ctx.location(@src()); + const loc = lhs.getContext().location(@src(), "dot({_},{_},contracting={any},batching={any})", .{ lhs, rhs, contracting_axes, batching_axes }); const op = dialect.stablehlo.dot_general( mlir_ctx, lhs.value(), @@ -1375,7 +1373,7 @@ return self.reshape(res_shape); } - const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation }); + const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self, permutation }); const op = dialect.stablehlo.transpose( self.getContext().mlirCtx(), self.value(), @@ -1408,7 +1406,7 @@ const new_dim = std.math.divExact(i64, self.dim(a), n) catch std.debug.panic("unflatten expects chosen dimension to be divisible by 'n' but {} is not divisible by {}", .{ self.dim(a), n }); const new_shape = self._shape.set(a, n).insert(a + 1, .{ ._ = new_dim }); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, n={}", .{ axis_, n }); + const loc = self.getContext().location(@src(), "axis={}, n={}", .{ axis_, n }); const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), @@ -1425,7 +1423,7 @@ pub fn splitAxis(self: Tensor, ax: anytype, split_shape: anytype) Tensor { const new_shape = self._shape.splitAxis(ax, split_shape); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "splitAxis({}, {any})", .{ ax, split_shape }); + const loc = self.getContext().location(@src(), "splitAxis({}, {any})", .{ ax, split_shape }); const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), @@ -1463,8 +1461,7 @@ // stdx.debug.assert(a + 1 < self.rank(), "Can't flatten {} on the last axis {}.", .{ self, axis });
const new_shape = old_shape.remove(a + 1).set(a, old_shape.dim(a) * old_shape.dim(a + 1)); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}", .{axis_}); - + const loc = self.getContext().location(@src(), "flatten({_},{})", .{ self, axis_ }); const reshaped_val = dialect.stablehlo.reshape( self.getContext().mlirCtx(), self.value(), @@ -1582,8 +1579,9 @@ pub const Tensor = struct { } const res_shape = tensors[0]._shape.set(a, concatenated_dim); - const loc = tensors[0].getContext().mlirCtx().location(@src()).namedFmt(tensors[0].getContext().mlirCtx(), "axis={}", .{axis_}); - const op = dialect.stablehlo.concatenate(tensors[0].getContext().mlirCtx(), buffer[0..tensors.len], a, loc); + const ctx = tensors[0].getContext(); + const loc = ctx.location(@src(), "axis={}", .{axis_}); + const op = dialect.stablehlo.concatenate(ctx.mlirCtx(), buffer[0..tensors.len], a, loc); // log.debug("concatenate({}, {}, {d}) -> {d}", .{ tensors[0], tensors[1], a, res_shape }); return _result(res_shape, op.result(0)); } @@ -1601,7 +1599,7 @@ pub const Tensor = struct { const res_shape = shape0.insertTag(axis_, 1, tag); for (tensors[1..]) |tensor| { - stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ tensor._shape, shape0 }); + stdx.debug.assert(shape0.eqlWithTags(tensor._shape), "stack expects tensor shapes to match, got {} and {}", .{ shape0, tensor._shape }); } var reshaped: [32]Tensor = undefined; @@ -1748,7 +1746,7 @@ pub const Tensor = struct { stdx.debug.assert(args.step > 0, "arange expects 'args.step' to be positive, got {}", .{args.step}); const ctx = CompilationContext.current(); - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "{}, dtype={}", .{ args, dt }); + const loc = ctx.location(@src(), "arange({}, dtype={})", .{ args, dt }); const n_steps = std.math.divCeil(i64, args.end - args.start, args.step) catch unreachable; const sh = 
Shape.init(.{n_steps}, dt); @@ -1775,9 +1773,10 @@ pub const Tensor = struct { const a = sh.axis(axis_); const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; const res_shape = sh.withDtype(dt); - const mlir_ctx = CompilationContext.current().mlirCtx(); - const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a }); + const ctx = CompilationContext.current(); + const loc = ctx.location(@src(), "iota({_}, {})", .{ res_shape, a }); + const mlir_ctx = ctx.mlirCtx(); var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); return _result(res_shape, op.result(0)); } @@ -1795,7 +1794,7 @@ pub const Tensor = struct { stdx.debug.assert(dt.isFloat(), "linspace expects type to be a float, got {} (hint: use arange instead)", .{dt}); const ctx = CompilationContext.current(); - const loc = ctx.mlirCtx().location(@src()).namedFmt(ctx.mlirCtx(), "linspace({}, dtype={})", .{ args, dt }); + const loc = ctx.location(@src(), "linspace({}, dtype={})", .{ args, dt }); const sh = Shape.init(.{args.steps}, dt); var iota_op = dialect.stablehlo.iota(ctx.mlirCtx(), 0, mlir.ext.mlirType(ctx.mlirCtx(), sh), loc); @@ -1838,7 +1837,7 @@ pub const Tensor = struct { const sh = Shape.init(dimz, val.dtype()); const singleton_sh = Shape.init(.{}, val.dtype()); const ctx = CompilationContext.current().mlirCtx(); - const loc = ctx.location(@src()).namedFmt(ctx, "dims={d}, value={}", .{ sh, val }); + const loc = CompilationContext.current().location(@src(), "dims={d}, value={}", .{ sh, val }); const res_type = mlir.ext.RankedTensorType.fromShape(ctx, singleton_sh); var constant_op = if (mlir.ext.denseElementAttrType(val.dtype())) |elem_type| @@ -1871,22 +1870,8 @@ pub const Tensor = struct { return self.mul(other); } - const other_shape = other.shape(); - var res_shape = self.shape(); - var batching_axes: u8 = 0; - for (0..other.rank()) |ax| { - if (other_shape.tag(ax) != Shape.TagUnknown) 
{ - if (self.shape().hasTag(other_shape.tag(ax))) |batching_ax| { - stdx.debug.assert(batching_ax == batching_axes and batching_ax == ax, "outer expects batching dims to be the first dims in both tensors, got outer({}, {})", .{ self, other }); - batching_axes += 1; - } - } - - res_shape = res_shape.appendDim(other_shape.dim(ax), other_shape.tag(ax)); - } - const left = self.broad(res_shape); - const right = other.broad(res_shape); - return left.mul(right); + const res_shape = self.shape().outer(other.shape()); + return self.broad(res_shape).mul(other.broad(res_shape)); } /// Given a tensor and a shape of the same rank, @@ -1904,9 +1889,10 @@ pub const Tensor = struct { const d = self.dim(self_ax); stdx.debug.assert(d == 1 or d == output_shape.dim(other_ax), "broadcast expects shape axes to either be 1-sized or to match the target size. got broadcast({}, {}, {d}), error on self axis {} mapping to other axis {}", .{ self, output_shape, axes_, self_ax, other_ax }); } - const result_type = mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), res_shape).as(mlir.Type).?; - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "broadcast({}, {any}, axes={d})", .{ self, res_shape, axes_ }); - const broadcast_op = dialect.stablehlo.broadcast_in_dim(self.getContext().mlirCtx(), self.value(), axes_, result_type, loc); + const ctx = self.getContext(); + const result_type = mlir.ext.RankedTensorType.fromShape(ctx.mlirCtx(), res_shape).as(mlir.Type).?; + const loc = ctx.location(@src(), "broadcast({_}, {_}, axes={d})", .{ self, res_shape, axes_ }); + const broadcast_op = dialect.stablehlo.broadcast_in_dim(ctx.mlirCtx(), self.value(), axes_, result_type, loc); return _result(res_shape, broadcast_op.result(0)); } @@ -1959,7 +1945,7 @@ pub const Tensor = struct { pub fn reshape(self: Tensor, output_shape_: anytype) Tensor { const output_shape = self._shape.reshape(output_shape_); const tensor_type = 
mlir.ext.RankedTensorType.fromShape(self.getContext().mlirCtx(), output_shape); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reshape({any})", .{output_shape}); + const loc = self.getContext().location(@src(), "reshape({any})", .{output_shape}); const reshape_value = dialect.stablehlo.reshape(self.getContext().mlirCtx(), self.value(), tensor_type, loc); return _result(output_shape, reshape_value.result(0)); } @@ -2050,7 +2036,7 @@ pub const Tensor = struct { pub fn reverse(self: Tensor, axes_: anytype) Tensor { const actual_axes = self._shape.axes(axes_); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "reverse({any})", .{axes_}); + const loc = self.getContext().location(@src(), "reverse({any})", .{axes_}); const reverse_op = dialect.stablehlo.reverse(self.getContext().mlirCtx(), self.value(), toI64(actual_axes.constSlice()), loc); return _result(self._shape, reverse_op.result(0)); } @@ -2257,7 +2243,8 @@ pub const Tensor = struct { /// while gatherValues, always copy values one by one, and as such don't have the same issues. /// In our example the contiguous dimension .d is not sliced /// and gatherSlices can copy data by group of C'*D elements. 
- pub fn gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor { + pub fn gatherSlices(self: Tensor, slice_shape_: anytype, indices: Tensor, opts: GatherOpts) Tensor { + const slice_shape = if (@TypeOf(slice_shape_) == Shape) slice_shape_ else Shape.init(slice_shape_, .i32); // scoped_log.debug("gatherSlice({}, {_}, {})", .{ self, slice_shape, indices }); const tagged_api = slice_shape.isFullyTagged(); @@ -2307,7 +2294,7 @@ pub const Tensor = struct { } } - const loc = self.getContext().mlirCtx().location(@src()); + const loc = self.getContext().location(@src(), "gatherSlices({_}, slice_shape={_}, idx={_})", .{ self, slice_shape, indices }); const gather_op = dialect.stablehlo.gather( self.getContext().mlirCtx(), self.value(), @@ -2331,6 +2318,12 @@ pub const Tensor = struct { const zml = @import("zml.zig"); const platform = zml.testing.env(); + const Local = struct { + pub fn _gatherSlices(self: Tensor, slice_shape: Shape, indices: Tensor, opts: GatherOpts) Tensor { + return self.gatherSlices(slice_shape, indices, opts); + } + }; + { // Only test shapes var comp = try zml.module.CompilationContext.init(std.testing.allocator, "test", platform); @@ -2367,7 +2360,7 @@ pub const Tensor = struct { const mod = try zml.compileFn( std.testing.allocator, - gatherSlices, + Local._gatherSlices, .{ x.shape(), slice_shape, idx.shape(), .{ .indices_are_sorted = true } }, platform, ); @@ -2383,7 +2376,7 @@ pub const Tensor = struct { const start_indices = (try zml.Buffer.fromArray(platform, [2][2]i32{ .{ 2, 1 }, .{ 0, 3 } })).withTags(.{ .n, ._ }); defer start_indices.deinit(); - const result = try zml.testing.compileAndCall(platform, gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} }); + const result = try zml.testing.compileAndCall(platform, Local._gatherSlices, .{ operand, Shape.init(.{ .b = 2, .c = 3 }, .u16), start_indices, .{} }); const expected = zml.HostBuffer.fromArray(&[2][2][2][3]u16{ .{ @@ -2730,15 
+2723,14 @@ pub const Tensor = struct { /// Stable argmax: /// * bubbles up Nan /// * in case of equality the smallest index matching the maximum - pub fn argMax(x: Tensor, axis_: anytype, index_dtype: DataType) ArgMaxRes { - stdx.debug.assert(index_dtype.isInteger(), "argMax expect index type to be an integer, got {}", .{index_dtype}); - + pub fn argMax(x: Tensor, axis_: anytype) ArgMaxRes { const a = x.axis(axis_); + const dt: DataType = if (x.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; return ops.reduce( ArgMaxRes.cmp, - .{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, index_dtype).broadcast(x.shape(), &.{a}) }, - .{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, index_dtype) }, + .{ .values = x, .indices = Tensor.arange(.{ .end = x.dim(a) }, dt).broadcast(x.shape(), &.{a}) }, + .{ .values = Tensor.constant(&.{}, x.dtype().minValue()), .indices = Tensor.scalar(0, dt) }, &.{a}, ); } @@ -2749,7 +2741,7 @@ pub const Tensor = struct { const allocator = std.testing.allocator; const ArgMaxTest = struct { pub fn forward(x: Tensor) Tensor.ArgMaxRes { - return x.argMax(1, .i32); + return x.argMax(1); } }; @@ -3097,7 +3089,7 @@ pub const Tensor = struct { const a = self.axis(axis_); const new_shape = self._shape.set(a, slice_.len); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len }); + const loc = self.getContext().location(@src(), "dynSlice({}, len={})", .{ axis_, slice_.len }); var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK; start_indices[a] = slice_.start.value(); @@ -3397,7 +3389,7 @@ pub const Tensor = struct { stdx.debug.assert(self._shape.eql(other._shape), "cmp expects input tensor shapes to match, got {} and {}", .{ self._shape, other._shape }); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "cmp(.{s})", 
.{@tagName(direction)}); + const loc = self.getContext().location(@src(), "cmp(.{s})", .{@tagName(direction)}); const op = dialect.stablehlo.compare( self.getContext().mlirCtx(), self.value(), @@ -3568,17 +3560,31 @@ pub const Tensor = struct { /// Returns a Tensor containing boolean indicating if there is a non-zero value over the given axis. pub fn any(self: Tensor, axis_: anytype) Tensor { const pred = self.cmp(.NE, Tensor.constant(self.dims(), self.dtype().zero())); - const red = ops.reduce( + return ops.reduce( struct { pub fn acc(x: Tensor, res: Tensor) Tensor { return res.logical(.OR, x); } }.acc, pred, - Tensor.scalar(0, pred.dtype()), + Tensor.scalar(false, .bool), + &.{self.axis(axis_)}, + ); + } + + /// Returns a Tensor containing boolean indicating if all values are non-zero over the given axis. + pub fn all(self: Tensor, axis_: anytype) Tensor { + const pred = if (self.dtype() == .bool) self else self.cmp(.NE, Tensor.scalar(0, self.dtype())); + return ops.reduce( + struct { + pub fn acc(x: Tensor, res: Tensor) Tensor { + return res.logical(.AND, x); + } + }.acc, + pred, + Tensor.scalar(true, .bool), &.{self.axis(axis_)}, ); - return red; } /// Given a set of N vectors of lengths A, B, C, D, @@ -3701,6 +3707,7 @@ pub const Tensor = struct { } fn binaryOp( + src: std.builtin.SourceLocation, op_name: []const u8, op_fn: fn (mlir.Context, mlir.Value, mlir.Value, mlir.Location) mlir.Operation, ) fn (Tensor, Tensor) Tensor { @@ -3718,9 +3725,9 @@ stdx.debug.assert(self._shape.eql(other._shape), "{s} expects tensor shapes to match, got {} and {}", .{ op_name, self._shape, other._shape }); - const mlirCtx = self.getContext().mlirCtx(); - const location = mlirCtx.location(@src()); - const ret = @call(.auto, op_fn, .{ mlirCtx, self.value(), other.value(), location }); + const ctx = self.getContext(); + const location = ctx.location(src, "{s}({_}, {_})", .{ op_name, self, other }); + const ret = @call(.auto, op_fn, .{ ctx.mlirCtx(),
self.value(), other.value(), location }); return _result(self._shape, ret.result(0)); } }.binaryOpHelper; diff --git a/zml/zml.zig b/zml/zml.zig index 292708e..d4e6992 100644 --- a/zml/zml.zig +++ b/zml/zml.zig @@ -25,6 +25,7 @@ pub const nn = @import("nn.zig"); pub const module = @import("module.zig"); pub const meta = @import("meta.zig"); pub const platform = @import("platform.zig"); +pub const pjrt = @import("pjrtx.zig"); pub const testing = @import("testing.zig"); pub const torch = @import("torch.zig");