Add gpt-oss model support to core ZML components: formatting, utility functions, safetensors I/O, host buffer management, NN layer handling, and tensor operations.
This commit is contained in:
parent
e1b7fc5781
commit
77cd21d2b2
54
stdx/fmt.zig
54
stdx/fmt.zig
@ -10,20 +10,20 @@ fn FmtSlice(T: type) type {
|
|||||||
slice: []const T,
|
slice: []const T,
|
||||||
|
|
||||||
pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
|
pub fn format(f: @This(), writer: *std.io.Writer) std.io.Writer.Error!void {
|
||||||
return try formatSliceAny(f.slice, .{}, writer);
|
return try formatSliceAny(f.slice, .{}, 1, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
pub fn formatNumber(f: @This(), writer: *std.io.Writer, n: std.fmt.Number) std.io.Writer.Error!void {
|
||||||
return switch (@typeInfo(T)) {
|
return switch (@typeInfo(T)) {
|
||||||
.comptime_float, .float => try formatFloatSlice(f.slice, n, writer),
|
.comptime_float, .float => try formatFloatSlice(f.slice, n, 1, writer),
|
||||||
.comptime_int, .int => try formatIntSlice(f.slice, n, writer),
|
.comptime_int, .int => try formatIntSlice(f.slice, n, 1, writer),
|
||||||
.bool => try formatBoolSlice(f.slice, n, writer),
|
.bool => try formatBoolSlice(f.slice, n, 1, writer),
|
||||||
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
|
.@"struct" => if (@hasField(T, "re") and @hasField(T, "im")) {
|
||||||
try formatComplexSlice(f.slice, n, writer);
|
try formatComplexSlice(f.slice, n, 1, writer);
|
||||||
} else if (@hasDecl(T, "toF32")) {
|
} else if (@hasDecl(T, "toF32")) {
|
||||||
try formatFloatSlice(f.slice, n, writer);
|
try formatFloatSlice(f.slice, n, 1, writer);
|
||||||
} else {
|
} else {
|
||||||
try formatSliceAny(f.slice, n, writer);
|
try formatSliceAny(f.slice, n, 1, writer);
|
||||||
},
|
},
|
||||||
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
|
else => @compileError("FmtSlice doesn't support type: " ++ @typeName(T)),
|
||||||
};
|
};
|
||||||
@ -72,53 +72,55 @@ pub fn formatAny(value: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !
|
|||||||
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
|
return try writer.alignBufferOptions(s, .{ .alignment = spec.alignment, .fill = spec.fill });
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatSliceCustom(fmt_func: anytype, values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
// use the format "width" for the number of columns instead of individual width.
|
// use the format "width" for the number of columns instead of individual width.
|
||||||
const num_cols: usize = spec.width orelse 12;
|
const num_cols: usize = spec.width orelse 12;
|
||||||
var my_options = spec;
|
var my_options = spec;
|
||||||
my_options.width = null;
|
my_options.width = null;
|
||||||
const n: usize = values.len;
|
// TODO: handle negative strides
|
||||||
|
const strd: usize = @intCast(stride);
|
||||||
|
const n: usize = @divTrunc(values.len, strd);
|
||||||
|
|
||||||
_ = try writer.write("{");
|
_ = try writer.write("{");
|
||||||
if (n <= num_cols) {
|
if (n <= num_cols) {
|
||||||
for (values, 0..) |v, i| {
|
for (0..n) |i| {
|
||||||
// Force inlining so that the switch and the buffer can be done once.
|
// Force inlining so that the switch and the buffer can be done once.
|
||||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
|
||||||
if (i < n - 1) _ = try writer.write(",");
|
if (i < n - 1) _ = try writer.write(",");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
const half = @divFloor(num_cols, 2);
|
const half = @divFloor(num_cols, 2);
|
||||||
for (values[0..half]) |v| {
|
for (0..half) |i| {
|
||||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
|
||||||
_ = try writer.write(",");
|
_ = try writer.write(",");
|
||||||
}
|
}
|
||||||
_ = try writer.write(" ..., ");
|
_ = try writer.write(" ..., ");
|
||||||
for (values[n - half ..], 0..) |v, i| {
|
for (n - half..n) |i| {
|
||||||
try @call(.always_inline, fmt_func, .{ v, my_options, writer });
|
try @call(.always_inline, fmt_func, .{ values[i * strd], my_options, writer });
|
||||||
if (i < half - 1) _ = try writer.write(",");
|
if (i < n - 1) _ = try writer.write(",");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ = try writer.write("}");
|
_ = try writer.write("}");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatSliceAny(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatAny, values, spec, writer);
|
return try formatSliceCustom(formatAny, values, spec, stride, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatFloatSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatFloat, values, spec, writer);
|
return try formatSliceCustom(formatFloat, values, spec, stride, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatIntSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatInt, values, spec, writer);
|
return try formatSliceCustom(formatInt, values, spec, stride, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatComplexSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatComplex, values, spec, writer);
|
return try formatSliceCustom(formatComplex, values, spec, stride, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, writer: *std.Io.Writer) !void {
|
pub fn formatBoolSlice(values: anytype, spec: std.fmt.Number, stride: i64, writer: *std.Io.Writer) !void {
|
||||||
return try formatSliceCustom(formatBool, values, spec, writer);
|
return try formatSliceCustom(formatBool, values, spec, stride, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Format a struct using `format` method of subfields when possible.
|
/// Format a struct using `format` method of subfields when possible.
|
||||||
|
|||||||
@ -21,16 +21,6 @@ import torch
|
|||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
builtin_open = builtins.open
|
|
||||||
|
|
||||||
|
|
||||||
def log_and_open(file, *args, **kwargs):
|
|
||||||
print("open:", file, args, kwargs)
|
|
||||||
return builtin_open(file, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
builtins.open = log_and_open
|
|
||||||
|
|
||||||
class ActivationCollector:
|
class ActivationCollector:
|
||||||
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ pub fn open(allocator: std.mem.Allocator, path: []const u8) !zml.aio.BufferStore
|
|||||||
} else {
|
} else {
|
||||||
try loadFile(arena, &res, &files, path);
|
try loadFile(arena, &res, &files, path);
|
||||||
}
|
}
|
||||||
res.files = try files.toOwnedSlice(allocator);
|
res.files = try files.toOwnedSlice(arena);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -355,12 +355,17 @@ pub const HostBuffer = struct {
|
|||||||
try writer.splatByteAll(' ', indent_level);
|
try writer.splatByteAll(' ', indent_level);
|
||||||
switch (self.dtype()) {
|
switch (self.dtype()) {
|
||||||
inline else => |dt| {
|
inline else => |dt| {
|
||||||
const values = self.items(dt.toZigType());
|
const T = dt.toZigType();
|
||||||
|
const n: i64 = self._shape.dim(0);
|
||||||
|
// TODO: handle negative strides
|
||||||
|
const byte_stride: i64 = self._strides[0];
|
||||||
|
const elem_strides: i64 = @divExact(byte_stride, @sizeOf(T));
|
||||||
|
const values: []const T = @ptrCast(@alignCast(self._data[0..@intCast(n * byte_stride)]));
|
||||||
switch (comptime dt.class()) {
|
switch (comptime dt.class()) {
|
||||||
.float => try stdx.fmt.formatFloatSlice(values, options, writer),
|
.float => try stdx.fmt.formatFloatSlice(values, options, elem_strides, writer),
|
||||||
.integer => try stdx.fmt.formatIntSlice(values, options, writer),
|
.integer => try stdx.fmt.formatIntSlice(values, options, elem_strides, writer),
|
||||||
.complex => try stdx.fmt.formatComplexSlice(values, options, writer),
|
.complex => try stdx.fmt.formatComplexSlice(values, options, elem_strides, writer),
|
||||||
.bool => try stdx.fmt.formatBoolSlice(values, options, writer),
|
.bool => try stdx.fmt.formatBoolSlice(values, options, elem_strides, writer),
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
20
zml/nn.zig
20
zml/nn.zig
@ -110,9 +110,8 @@ pub const LayerNorm = struct {
|
|||||||
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
|
pub fn rmsNorm(x: Tensor, axis: anytype, eps: f32) Tensor {
|
||||||
const ax = x.axis(axis);
|
const ax = x.axis(axis);
|
||||||
// upcast to improve precision
|
// upcast to improve precision
|
||||||
const xf32 = x.convert(.f32);
|
const variance = x.convert(.f32).powByConst(2).mean(ax);
|
||||||
const mean = xf32.mul(xf32).mean(ax);
|
const rsqrt = Tensor.rsqrt(variance.addConstant(eps)).convert(x.dtype());
|
||||||
const rsqrt = Tensor.rsqrt(mean.addConstant(eps)).convert(x.dtype());
|
|
||||||
return x.mul(rsqrt.broad(x.shape()));
|
return x.mul(rsqrt.broad(x.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -816,7 +815,7 @@ pub fn resizeCubic1d(image: Tensor, axis: i8, new_len: u63, opt: ResizeOpts) Ten
|
|||||||
Tensor.constant(t.shape(), dtype.one()),
|
Tensor.constant(t.shape(), dtype.one()),
|
||||||
t,
|
t,
|
||||||
t.mul(t),
|
t.mul(t),
|
||||||
t.pow(Tensor.scalar(3, dtype)),
|
t.powByConst(3),
|
||||||
}, .last, ._interpolated);
|
}, .last, ._interpolated);
|
||||||
|
|
||||||
std.debug.assert(pos.dim(0) == new_len);
|
std.debug.assert(pos.dim(0) == new_len);
|
||||||
@ -894,6 +893,7 @@ pub fn causalAttnMask(
|
|||||||
pub const SdpaOpts = struct {
|
pub const SdpaOpts = struct {
|
||||||
attn_mask: ?Tensor = null,
|
attn_mask: ?Tensor = null,
|
||||||
scale: ?Tensor = null,
|
scale: ?Tensor = null,
|
||||||
|
softmax_bias: ?Tensor = null,
|
||||||
allow_cudnn: bool = true,
|
allow_cudnn: bool = true,
|
||||||
// TODO: put a callback instead of all this field,
|
// TODO: put a callback instead of all this field,
|
||||||
// so that
|
// so that
|
||||||
@ -922,7 +922,7 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(k.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "k is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
|
stdx.debug.assert(v.shape().hasTags(.{ .h, .k, .hd }), err_template ++ "v is missing tags {{.h, .k, .hd}}", err_args);
|
||||||
|
|
||||||
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape())) {
|
if (opts.allow_cudnn and cuda.canUseCudnnSdpa(q.shape()) and opts.softmax_bias == null) {
|
||||||
return cuda.sdpa(q, k, v, opts);
|
return cuda.sdpa(q, k, v, opts);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -940,9 +940,15 @@ pub fn sdpa(q_: Tensor, k_: Tensor, v_: Tensor, opts: SdpaOpts) Tensor {
|
|||||||
k = k.mul(head_scaling.convert(k.dtype()));
|
k = k.mul(head_scaling.convert(k.dtype()));
|
||||||
|
|
||||||
var attn_weights = q.dot(k, .{.hd});
|
var attn_weights = q.dot(k, .{.hd});
|
||||||
// log.debug("attn_weights : {}, attn_mask : {?}", .{ attn_weights, attn_mask });
|
// log.debug("attn_weights : {f}, attn_mask : {?f}", .{ attn_weights, attn_mask });
|
||||||
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
if (attn_mask) |mask| attn_weights = attn_weights.add(mask.broad(attn_weights.shape()));
|
||||||
attn_weights = attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
attn_weights = attn_weights.convert(.f32);
|
||||||
|
attn_weights = if (opts.softmax_bias) |softmax_bias| attn: {
|
||||||
|
// The split is needed because we also split q ourselves.
|
||||||
|
// TODO: consider letting the user do that.
|
||||||
|
const bias = softmax_bias.splitAxis(.h, .{ .h = k.dim(.h), .hq = .auto });
|
||||||
|
break :attn attn_weights.convert(.f32).softmaxBiased(.k, bias).convert(q.dtype());
|
||||||
|
} else attn_weights.convert(.f32).softmax(.k).convert(q.dtype());
|
||||||
|
|
||||||
var attn = attn_weights.dot(v, .{.k});
|
var attn = attn_weights.dot(v, .{.k});
|
||||||
return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } });
|
return attn.transpose(q.shape()).merge(.{ .h = .{ .h, .hq } });
|
||||||
|
|||||||
@ -1105,13 +1105,16 @@ pub const Tensor = struct {
|
|||||||
/// Axes with the same tag on both sides, and which aren't contracting,
|
/// Axes with the same tag on both sides, and which aren't contracting,
|
||||||
/// are considered "batching axes".
|
/// are considered "batching axes".
|
||||||
pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor {
|
pub fn dot(lhs: Tensor, rhs: Tensor, comptime contracting: anytype) Tensor {
|
||||||
var contracting_axes: [contracting.len][2]i8 = undefined;
|
var contracting_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{};
|
||||||
inline for (contracting, 0..) |c, i| {
|
if (@TypeOf(contracting) == EnumLiteral) {
|
||||||
contracting_axes[i] = .{ lhs.axis(c), rhs.axis(c) };
|
contracting_axes.appendAssumeCapacity(.{ lhs.axis(contracting), rhs.axis(contracting) });
|
||||||
|
} else {
|
||||||
|
inline for (contracting) |c| {
|
||||||
|
contracting_axes.appendAssumeCapacity(.{ lhs.axis(c), rhs.axis(c) });
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var batching_axes: [MAX_RANK][2]i8 = undefined;
|
var batching_axes: stdx.BoundedArray([2]i8, MAX_RANK) = .{};
|
||||||
var n_batching: u8 = 0;
|
|
||||||
for (lhs._shape.tags(), 0..) |l, li| {
|
for (lhs._shape.tags(), 0..) |l, li| {
|
||||||
stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
|
stdx.debug.assert(l != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, lhs });
|
||||||
|
|
||||||
@ -1119,20 +1122,19 @@ pub const Tensor = struct {
|
|||||||
stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
|
stdx.debug.assert(r != Shape.TagUnknown, "Can't use `dot(..., {any})` on {any}, it need to be explictily tagged.", .{ contracting, rhs });
|
||||||
|
|
||||||
if (l == r) {
|
if (l == r) {
|
||||||
for (contracting_axes) |ct| {
|
for (contracting_axes.slice()) |ct| {
|
||||||
if (l == lhs._shape.tag(ct[0])) {
|
if (l == lhs._shape.tag(ct[0])) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// tag is both in lhs and rhs but not in contracting -> it's a batching dim.
|
// tag is both in lhs and rhs but not in contracting -> it's a batching dim.
|
||||||
batching_axes[n_batching] = .{ @intCast(li), @intCast(ri) };
|
batching_axes.appendAssumeCapacity(.{ @intCast(li), @intCast(ri) });
|
||||||
n_batching += 1;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return dotGeneral(lhs, rhs, contracting_axes[0..], batching_axes[0..n_batching]);
|
return dotGeneral(lhs, rhs, contracting_axes.slice(), batching_axes.slice());
|
||||||
}
|
}
|
||||||
|
|
||||||
test dot {
|
test dot {
|
||||||
@ -1318,23 +1320,39 @@ pub const Tensor = struct {
|
|||||||
/// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor.
|
/// Returns a Tensor containing the Sigmoid Linear Unit (SiLU) activation function applied to each element of the input Tensor.
|
||||||
///
|
///
|
||||||
/// silu(x) = x σ(x)
|
/// silu(x) = x σ(x)
|
||||||
/// https://paperswithcode.com/method/silu
|
|
||||||
pub fn silu(x: Tensor) Tensor {
|
pub fn silu(x: Tensor) Tensor {
|
||||||
return x.mul(x.sigmoid());
|
return x.mul(.sigmoid(x));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the softmax function applied to each element of the input Tensor.
|
/// Computes softmax along the given axis.
|
||||||
|
/// y[i] = exp(x[i]) / Σ_k exp(x[k])
|
||||||
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
pub fn softmax(self: Tensor, axis_: anytype) Tensor {
|
||||||
const a = self.axis(axis_);
|
const a = self.axis(axis_);
|
||||||
const max_val = self.max(a);
|
const max_val = self.max(a);
|
||||||
const row_mask = max_val.cmp(.GT, Tensor.scalar(-std.math.inf(f64), self.dtype()));
|
const row_mask = max_val.cmp(.GT, .scalar(-std.math.inf(f64), self.dtype()));
|
||||||
|
|
||||||
const exp_diff_max = self.sub(self.max(a).broad(self._shape)).exp();
|
const exp_diff_max = self.sub(max_val).convert(.f32).exp();
|
||||||
const res = exp_diff_max.div(exp_diff_max.sum(a));
|
const res = exp_diff_max.div(exp_diff_max.sum(a)).convert(self.dtype());
|
||||||
|
|
||||||
// If a row is full -inf return full 0 instead of full nan,
|
// If a row is full -inf return full 0 instead of full nan,
|
||||||
// this fixes attention when the mask hides a full row.
|
// this fixes attention when the mask hides a full row.
|
||||||
return row_mask.broad(self.shape()).select(res, Tensor.scalar(0, self.dtype()));
|
return row_mask.broad(self.shape()).select(res, .scalar(0, self.dtype()));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes softmax, but adds a bias to the sum of exponentials.
|
||||||
|
/// y[i] = exp(x[i]) / ( Σ_k exp(x[k]) + bias )
|
||||||
|
pub fn softmaxBiased(self: Tensor, axis_: anytype, bias: ?Tensor) Tensor {
|
||||||
|
const a = self.axis(axis_);
|
||||||
|
|
||||||
|
if (bias == null) return self.softmax(axis_);
|
||||||
|
const b = bias.?.convert(self.dtype()).broad(self.shape().setDim(a, 1));
|
||||||
|
const max_val: Tensor = maximum(self.max(a), b);
|
||||||
|
const exp_diff_max = self.sub(max_val).exp();
|
||||||
|
const bias_diff_max = b.sub(max_val).exp();
|
||||||
|
const res = exp_diff_max.div(exp_diff_max.sum(a).add(bias_diff_max));
|
||||||
|
|
||||||
|
// The bias means that denominator won't be 0, we don't need to handle that case.
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
/// Returns a Tensor containing the log of the sum of exponential over the given axis.
|
||||||
@ -1534,21 +1552,31 @@ pub const Tensor = struct {
|
|||||||
step: i32 = 1,
|
step: i32 = 1,
|
||||||
singleton: bool = false,
|
singleton: bool = false,
|
||||||
|
|
||||||
|
const full = .{ .start = 0, .end = to_the_end, .step = 1 };
|
||||||
|
|
||||||
pub fn single(offset: i64) Slice {
|
pub fn single(offset: i64) Slice {
|
||||||
return .{ .start = offset, .end = offset + 1, .singleton = true };
|
return .{ .start = offset, .end = offset + 1, .singleton = true };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn absolute(s: Slice, d: i64) Slice {
|
||||||
|
const start = if (s.start < 0) d + s.start else s.start;
|
||||||
|
const end = if (s.end == to_the_end) d else if (s.end < 0) d + s.end else s.end;
|
||||||
|
const res: Slice = .{ .start = start, .end = end, .step = s.step, .singleton = s.singleton };
|
||||||
|
stdx.debug.assert(start < end, "Slice {f} is invalid for axis of dimension {d} (resolved to {f}", .{ s, d, res });
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
const to_the_end = std.math.maxInt(i64);
|
const to_the_end = std.math.maxInt(i64);
|
||||||
|
|
||||||
pub fn format(self: Slice, writer: *std.Io.Writer) !void {
|
pub fn format(self: Slice, writer: *std.Io.Writer) !void {
|
||||||
if (self.singleton) {
|
if (self.singleton) {
|
||||||
try writer.print("[{}]", .{self.start});
|
try writer.print("[{d}]", .{self.start});
|
||||||
} else if (self.end == to_the_end and self.step == 1) {
|
} else if (self.end == to_the_end and self.step == 1) {
|
||||||
try writer.print("[{}..]", .{self.start});
|
try writer.print("[{d}..]", .{self.start});
|
||||||
} else if (self.step == 1) {
|
} else if (self.step == 1) {
|
||||||
try writer.print("[{}..{}]", .{ self.start, self.end });
|
try writer.print("[{d}..{d}]", .{ self.start, self.end });
|
||||||
} else {
|
} else {
|
||||||
try writer.print("[{}..{}:{}]", .{ self.start, self.end, self.step });
|
try writer.print("[{d}..{d}:{d}]", .{ self.start, self.end, self.step });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -1570,11 +1598,7 @@ pub const Tensor = struct {
|
|||||||
for (slices, 0..) |s, a| {
|
for (slices, 0..) |s, a| {
|
||||||
stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
|
stdx.debug.assert(s.step > 0, "slice expects 'step' to be positive, got {} at index {}", .{ s.step, a });
|
||||||
|
|
||||||
const args: Slice = .{
|
const args: Slice = s.absolute(self.dim(a));
|
||||||
.start = self.wrapIndex(a, s.start),
|
|
||||||
.end = if (s.end == Slice.to_the_end) self.dim(a) else self.wrapIndex(a, s.end),
|
|
||||||
.step = s.step,
|
|
||||||
};
|
|
||||||
start_indices[a] = args.start;
|
start_indices[a] = args.start;
|
||||||
limit_indices[a] = args.end;
|
limit_indices[a] = args.end;
|
||||||
strides[a] = args.step;
|
strides[a] = args.step;
|
||||||
@ -2016,6 +2040,10 @@ pub const Tensor = struct {
|
|||||||
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
return _result(sh, constant_op.result(0)).convert(val.dtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn zeroes(sh: Shape) Tensor {
|
||||||
|
return .constant(sh, sh.dtype().zero());
|
||||||
|
}
|
||||||
|
|
||||||
/// Embeds a buffer with concrete values into an Mlir program.
|
/// Embeds a buffer with concrete values into an Mlir program.
|
||||||
pub fn constantTensor(val: HostBuffer) Tensor {
|
pub fn constantTensor(val: HostBuffer) Tensor {
|
||||||
const ctx = CompilationContext.current().mlirCtx();
|
const ctx = CompilationContext.current().mlirCtx();
|
||||||
@ -3047,6 +3075,7 @@ pub const Tensor = struct {
|
|||||||
|
|
||||||
/// Returns a Tensor representing the result of Top-K over the given axis.
|
/// Returns a Tensor representing the result of Top-K over the given axis.
|
||||||
pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes {
|
pub fn topK(self: Tensor, named_axis_: anytype, k: u32, opts: struct { descending: bool = true }) SortRes {
|
||||||
|
stdx.debug.assert(k > 0, "topK expects a k > 0, got 0", .{});
|
||||||
const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`";
|
const err_msg = "topK named axis should be an integer or a named axis, eg `x.topK(.{{ .best_token = .token }}, 16)` or `x.topK(-1, 16)`";
|
||||||
const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) {
|
const has_name: ?[:0]const u8, const a = switch (@typeInfo(@TypeOf(named_axis_))) {
|
||||||
.int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) },
|
.int, .comptime_int => .{ null, self.axis(@as(i64, @intCast(named_axis_))) },
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user