Rewrite simple transpose as reshape in core ZML modules and raise default profiler event limit to 1,000,000.
This commit is contained in:
parent
8a031bd4c8
commit
7ef87236ce
@ -619,7 +619,7 @@ pub const DenseElementsAttributeTypes = enum {
|
||||
|
||||
pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
|
||||
const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) {
|
||||
.bool => .{ i32, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
|
||||
.bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
|
||||
.i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value },
|
||||
.i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value },
|
||||
.i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value },
|
||||
|
||||
@ -135,7 +135,7 @@ pub const Profiler = struct {
|
||||
return;
|
||||
}
|
||||
|
||||
var converter = try TraceContainer.init(allocator, profile_data.items(), null);
|
||||
var converter = try TraceContainer.init(allocator, profile_data.items(), 1_000_000);
|
||||
defer converter.deinit();
|
||||
|
||||
var output_file = try dir.createFile(file_name, .{});
|
||||
|
||||
@ -177,6 +177,10 @@ pub const CompilationContext = struct {
|
||||
return self._mlir_ctx;
|
||||
}
|
||||
|
||||
pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location {
|
||||
return self._mlir_ctx.location(src).namedFmt(self._mlir_ctx, name, args);
|
||||
}
|
||||
|
||||
/// Compiles the given function with the given arguments.
|
||||
/// This is the untyped API and is not meant to be use directly.
|
||||
///
|
||||
@ -883,11 +887,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
|
||||
try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
|
||||
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
|
||||
setFlag(&options, "xla_dump_to", xla_dump_to);
|
||||
setFlag(&options, "xla_dump_hlo_as_dot", true);
|
||||
if (platform.compilation_options.xla_dump_fusion_visualization) {
|
||||
setFlag(&options, "xla_dump_hlo_as_html", true);
|
||||
setFlag(&options, "xla_dump_hlo_as_dot", true);
|
||||
setFlag(&options, "xla_dump_fusion_visualization", true);
|
||||
}
|
||||
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
|
||||
setFlag(&options, "xla_dump_hlo_pass_re", re);
|
||||
}
|
||||
}
|
||||
switch (platform.target) {
|
||||
.cuda => cuda_dir: {
|
||||
|
||||
@ -17,6 +17,7 @@ pub const available_targets = std.enums.values(Target);
|
||||
pub const CompilationOptions = struct {
|
||||
xla_dump_to: ?[]const u8 = null,
|
||||
xla_dump_fusion_visualization: bool = false,
|
||||
xla_dump_hlo_pass_re: ?[]const u8 = null,
|
||||
sharding_enabled: bool = false,
|
||||
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
|
||||
};
|
||||
|
||||
@ -391,20 +391,23 @@ pub const Shape = struct {
|
||||
const bare_fmt = fmt.len == 1 and fmt[0] == '_';
|
||||
_ = try writer.write(if (bare_fmt) "{" else "Shape({");
|
||||
|
||||
var need_comma = false;
|
||||
for (self.dims(), 0..) |d, i| {
|
||||
const prefix = if (i == 0) "" else ",";
|
||||
if (need_comma) try writer.writeByte(',');
|
||||
const t = self.tag(i);
|
||||
if (t != TagUnknown) {
|
||||
try writer.print("{s}.{s}={d}", .{ prefix, t, d });
|
||||
try writer.print("{s}={d}", .{ t, d });
|
||||
} else {
|
||||
try writer.print("{s}{d}", .{ prefix, d });
|
||||
try writer.print("{d}", .{d});
|
||||
}
|
||||
if (self._sharding_info[i]) {
|
||||
try writer.writeByte('!');
|
||||
}
|
||||
need_comma = true;
|
||||
}
|
||||
_ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
|
||||
if (!bare_fmt) _ = try writer.write(")");
|
||||
if (need_comma) try writer.writeByte(',');
|
||||
_ = try writer.write(@tagName(self.dtype()));
|
||||
_ = try writer.write(if (bare_fmt) "}" else "})");
|
||||
}
|
||||
|
||||
pub fn reshape(self: Shape, new_shape_: anytype) Shape {
|
||||
|
||||
101
zml/tensor.zig
101
zml/tensor.zig
@ -1371,7 +1371,11 @@ pub const Tensor = struct {
|
||||
}
|
||||
|
||||
const res_shape = self._shape.transpose(permutation);
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "tr({any})", .{axes_});
|
||||
if (transposeIsJustAReshape(self.shape(), permutation)) {
|
||||
return self.reshape(res_shape);
|
||||
}
|
||||
|
||||
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation });
|
||||
const op = dialect.stablehlo.transpose(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
@ -1772,7 +1776,7 @@ pub const Tensor = struct {
|
||||
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
|
||||
const res_shape = sh.withDtype(dt);
|
||||
const mlir_ctx = CompilationContext.current().mlirCtx();
|
||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ });
|
||||
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a });
|
||||
|
||||
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
|
||||
return _result(res_shape, op.result(0));
|
||||
@ -1808,11 +1812,27 @@ pub const Tensor = struct {
|
||||
return res;
|
||||
}
|
||||
|
||||
/// Returns a 0d Tensor with the given value.
|
||||
/// Returns a 0-rank Tensor with the given value.
|
||||
pub fn scalar(val: anytype, dt: DataType) Tensor {
|
||||
return Tensor.constant(.{}, Data.init(dt, val));
|
||||
}
|
||||
|
||||
test scalar {
|
||||
const zml = @import("zml.zig");
|
||||
const platform = zml.testing.env();
|
||||
|
||||
const Local = struct {
|
||||
pub fn _fwd() [6]Tensor {
|
||||
var res: [6]Tensor = undefined;
|
||||
const dtypes = .{ .bool, .u8, .i32, .f32, .bf16, .u64 };
|
||||
inline for (0..6) |i| res[i] = scalar(0, dtypes[i]);
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
_ = try zml.testing.compileAndCall(platform, Local._fwd, .{});
|
||||
}
|
||||
|
||||
/// Returns a constant Tensor with the given value.
|
||||
pub fn constant(dimz: anytype, val: Data) Tensor {
|
||||
const sh = Shape.init(dimz, val.dtype());
|
||||
@ -1944,6 +1964,12 @@ pub const Tensor = struct {
|
||||
return _result(output_shape, reshape_value.result(0));
|
||||
}
|
||||
|
||||
/// Converts the given 1 element Tensor into a 0-rank Tensor.
|
||||
pub fn asScalar(self: Tensor) Tensor {
|
||||
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self});
|
||||
return self.reshape(.{});
|
||||
}
|
||||
|
||||
pub const Pad = struct {
|
||||
low: i64 = 0,
|
||||
high: i64 = 0,
|
||||
@ -2129,7 +2155,7 @@ pub const Tensor = struct {
|
||||
// Sometimes the backend recognize this pattern, but not always.
|
||||
// So let us handle that.
|
||||
if (indices.count() == 1) {
|
||||
return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
|
||||
return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape);
|
||||
}
|
||||
|
||||
var slice_dims: Shape.DimsArray = .{};
|
||||
@ -3066,25 +3092,21 @@ pub const Tensor = struct {
|
||||
/// Slices the input Tensor along a specific axis, with a start offset known at runtime.
|
||||
/// Note: this doesn't support tagging, if you have tags,
|
||||
/// you should use `dynamicSlice` directly.
|
||||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
|
||||
stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
|
||||
pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
|
||||
stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start});
|
||||
|
||||
const a = self.axis(axis_);
|
||||
const new_shape = self._shape.set(a, len);
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, len={}", .{ axis_, len });
|
||||
var indices: [Tensor.MAX_RANK]mlir.Value = undefined;
|
||||
for (0..self.rank()) |i| {
|
||||
indices[i] = if (i == a)
|
||||
start_indices.value()
|
||||
else
|
||||
constant(.{}, start_indices.dtype().zero()).value();
|
||||
}
|
||||
const new_shape = self._shape.set(a, slice_.len);
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len });
|
||||
|
||||
var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
|
||||
start_indices[a] = slice_.start.value();
|
||||
|
||||
const op = dialect.stablehlo.dynamicSlice(
|
||||
self.getContext().mlirCtx(),
|
||||
self.value(),
|
||||
new_shape.dims(),
|
||||
indices[0..self.rank()],
|
||||
start_indices[0..self.rank()],
|
||||
loc,
|
||||
);
|
||||
|
||||
@ -3115,7 +3137,7 @@ pub const Tensor = struct {
|
||||
// TODO use slices and slices_tags for the format.
|
||||
// Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })"
|
||||
// which is kinda ugly.
|
||||
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_});
|
||||
const loc = self.getContext().location(@src(), "dynSlice({any})", .{slices_});
|
||||
|
||||
const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32;
|
||||
const zero = Tensor.scalar(0, idx_dtype).value();
|
||||
@ -3153,7 +3175,7 @@ pub const Tensor = struct {
|
||||
{
|
||||
const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
|
||||
const z = try zml.Buffer.scalar(platform, 4, .i32);
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z });
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });
|
||||
|
||||
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
|
||||
}
|
||||
@ -3163,7 +3185,7 @@ pub const Tensor = struct {
|
||||
const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
|
||||
const z = try zml.Buffer.scalar(platform, 3, .i32);
|
||||
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z });
|
||||
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
|
||||
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
|
||||
}
|
||||
}
|
||||
@ -3976,3 +3998,44 @@ inline fn toI64(values: anytype) []i64 {
|
||||
for (values, 0..) |val, i| res[i] = @intCast(val);
|
||||
return res[0..values.len];
|
||||
}
|
||||
|
||||
fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
|
||||
var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
|
||||
// Don't rewrite on invalid inputs.
|
||||
if (permutation.len > x.rank()) return false;
|
||||
for (permutation) |ax| {
|
||||
const squeezable = x.dim(ax) == 1;
|
||||
perm.appendAssumeCapacity(.{ @intCast(ax), squeezable });
|
||||
}
|
||||
|
||||
var effective_ax: u8 = 0;
|
||||
for (0..perm.len) |i| {
|
||||
const ax, const squeezable = perm.get(i);
|
||||
if (squeezable) {
|
||||
// Effectively squeeze this axis by decrementing axes coming after by 1.
|
||||
for (i..perm.len) |j| {
|
||||
if (perm.buffer[j][0] > ax) {
|
||||
perm.buffer[j][0] -= 1;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (ax != effective_ax) return false;
|
||||
effective_ax += 1;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
test transposeIsJustAReshape {
|
||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 0, 1, 2 }));
|
||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 1, 0, 2 }));
|
||||
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 2, 1, 0 }));
|
||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 8, 1, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
|
||||
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 64, 8, 155, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
|
||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 1, 1, 128 }, .bf16), &.{ 1, 2, 0, 3 }));
|
||||
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ .b = 1, .h = 10, .q = 155, .hd = 1 }, .f32), &.{ 0, 2, 1, 3 }));
|
||||
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 }));
|
||||
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 }));
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user