Rewrite simple transpose as reshape in core ZML modules and raise default profiler event limit to 1,000,000.

This commit is contained in:
Tarry Singh 2023-12-18 13:56:45 +00:00
parent 8a031bd4c8
commit 7ef87236ce
6 changed files with 101 additions and 28 deletions

View File

@ -619,7 +619,7 @@ pub const DenseElementsAttributeTypes = enum {
pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) { const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) {
.bool => .{ i32, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue }, .bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
.i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value }, .i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value },
.i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value }, .i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value },
.i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value }, .i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value },

View File

@ -135,7 +135,7 @@ pub const Profiler = struct {
return; return;
} }
var converter = try TraceContainer.init(allocator, profile_data.items(), null); var converter = try TraceContainer.init(allocator, profile_data.items(), 1_000_000);
defer converter.deinit(); defer converter.deinit();
var output_file = try dir.createFile(file_name, .{}); var output_file = try dir.createFile(file_name, .{});

View File

@ -177,6 +177,10 @@ pub const CompilationContext = struct {
return self._mlir_ctx; return self._mlir_ctx;
} }
pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location {
return self._mlir_ctx.location(src).namedFmt(self._mlir_ctx, name, args);
}
/// Compiles the given function with the given arguments. /// Compiles the given function with the given arguments.
/// This is the untyped API and is not meant to be use directly. /// This is the untyped API and is not meant to be use directly.
/// ///
@ -883,11 +887,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
try options.env_option_overrides.ensureUnusedCapacity(arena, 16); try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| { if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
setFlag(&options, "xla_dump_to", xla_dump_to); setFlag(&options, "xla_dump_to", xla_dump_to);
if (platform.compilation_options.xla_dump_fusion_visualization) {
setFlag(&options, "xla_dump_hlo_as_html", true);
setFlag(&options, "xla_dump_hlo_as_dot", true); setFlag(&options, "xla_dump_hlo_as_dot", true);
if (platform.compilation_options.xla_dump_fusion_visualization) {
setFlag(&options, "xla_dump_fusion_visualization", true); setFlag(&options, "xla_dump_fusion_visualization", true);
} }
if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
setFlag(&options, "xla_dump_hlo_pass_re", re);
}
} }
switch (platform.target) { switch (platform.target) {
.cuda => cuda_dir: { .cuda => cuda_dir: {

View File

@ -17,6 +17,7 @@ pub const available_targets = std.enums.values(Target);
pub const CompilationOptions = struct { pub const CompilationOptions = struct {
xla_dump_to: ?[]const u8 = null, xla_dump_to: ?[]const u8 = null,
xla_dump_fusion_visualization: bool = false, xla_dump_fusion_visualization: bool = false,
xla_dump_hlo_pass_re: ?[]const u8 = null,
sharding_enabled: bool = false, sharding_enabled: bool = false,
sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{}, sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
}; };

View File

@ -391,20 +391,23 @@ pub const Shape = struct {
const bare_fmt = fmt.len == 1 and fmt[0] == '_'; const bare_fmt = fmt.len == 1 and fmt[0] == '_';
_ = try writer.write(if (bare_fmt) "{" else "Shape({"); _ = try writer.write(if (bare_fmt) "{" else "Shape({");
var need_comma = false;
for (self.dims(), 0..) |d, i| { for (self.dims(), 0..) |d, i| {
const prefix = if (i == 0) "" else ","; if (need_comma) try writer.writeByte(',');
const t = self.tag(i); const t = self.tag(i);
if (t != TagUnknown) { if (t != TagUnknown) {
try writer.print("{s}.{s}={d}", .{ prefix, t, d }); try writer.print("{s}={d}", .{ t, d });
} else { } else {
try writer.print("{s}{d}", .{ prefix, d }); try writer.print("{d}", .{d});
} }
if (self._sharding_info[i]) { if (self._sharding_info[i]) {
try writer.writeByte('!'); try writer.writeByte('!');
} }
need_comma = true;
} }
_ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())}); if (need_comma) try writer.writeByte(',');
if (!bare_fmt) _ = try writer.write(")"); _ = try writer.write(@tagName(self.dtype()));
_ = try writer.write(if (bare_fmt) "}" else "})");
} }
pub fn reshape(self: Shape, new_shape_: anytype) Shape { pub fn reshape(self: Shape, new_shape_: anytype) Shape {

View File

@ -1371,7 +1371,11 @@ pub const Tensor = struct {
} }
const res_shape = self._shape.transpose(permutation); const res_shape = self._shape.transpose(permutation);
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "tr({any})", .{axes_}); if (transposeIsJustAReshape(self.shape(), permutation)) {
return self.reshape(res_shape);
}
const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation });
const op = dialect.stablehlo.transpose( const op = dialect.stablehlo.transpose(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
@ -1772,7 +1776,7 @@ pub const Tensor = struct {
const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
const res_shape = sh.withDtype(dt); const res_shape = sh.withDtype(dt);
const mlir_ctx = CompilationContext.current().mlirCtx(); const mlir_ctx = CompilationContext.current().mlirCtx();
const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ }); const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a });
var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
return _result(res_shape, op.result(0)); return _result(res_shape, op.result(0));
@ -1808,11 +1812,27 @@ pub const Tensor = struct {
return res; return res;
} }
/// Returns a 0d Tensor with the given value. /// Returns a 0-rank Tensor with the given value.
pub fn scalar(val: anytype, dt: DataType) Tensor { pub fn scalar(val: anytype, dt: DataType) Tensor {
return Tensor.constant(.{}, Data.init(dt, val)); return Tensor.constant(.{}, Data.init(dt, val));
} }
test scalar {
const zml = @import("zml.zig");
const platform = zml.testing.env();
const Local = struct {
pub fn _fwd() [6]Tensor {
var res: [6]Tensor = undefined;
const dtypes = .{ .bool, .u8, .i32, .f32, .bf16, .u64 };
inline for (0..6) |i| res[i] = scalar(0, dtypes[i]);
return res;
}
};
_ = try zml.testing.compileAndCall(platform, Local._fwd, .{});
}
/// Returns a constant Tensor with the given value. /// Returns a constant Tensor with the given value.
pub fn constant(dimz: anytype, val: Data) Tensor { pub fn constant(dimz: anytype, val: Data) Tensor {
const sh = Shape.init(dimz, val.dtype()); const sh = Shape.init(dimz, val.dtype());
@ -1944,6 +1964,12 @@ pub const Tensor = struct {
return _result(output_shape, reshape_value.result(0)); return _result(output_shape, reshape_value.result(0));
} }
/// Converts the given 1 element Tensor into a 0-rank Tensor.
pub fn asScalar(self: Tensor) Tensor {
stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self});
return self.reshape(.{});
}
pub const Pad = struct { pub const Pad = struct {
low: i64 = 0, low: i64 = 0,
high: i64 = 0, high: i64 = 0,
@ -2129,7 +2155,7 @@ pub const Tensor = struct {
// Sometimes the backend recognize this pattern, but not always. // Sometimes the backend recognize this pattern, but not always.
// So let us handle that. // So let us handle that.
if (indices.count() == 1) { if (indices.count() == 1) {
return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape);
} }
var slice_dims: Shape.DimsArray = .{}; var slice_dims: Shape.DimsArray = .{};
@ -3066,25 +3092,21 @@ pub const Tensor = struct {
/// Slices the input Tensor along a specific axis, with a start offset known at runtime. /// Slices the input Tensor along a specific axis, with a start offset known at runtime.
/// Note: this doesn't support tagging, if you have tags, /// Note: this doesn't support tagging, if you have tags,
/// you should use `dynamicSlice` directly. /// you should use `dynamicSlice` directly.
pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor { pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()}); stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start});
const a = self.axis(axis_); const a = self.axis(axis_);
const new_shape = self._shape.set(a, len); const new_shape = self._shape.set(a, slice_.len);
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, len={}", .{ axis_, len }); const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len });
var indices: [Tensor.MAX_RANK]mlir.Value = undefined;
for (0..self.rank()) |i| { var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
indices[i] = if (i == a) start_indices[a] = slice_.start.value();
start_indices.value()
else
constant(.{}, start_indices.dtype().zero()).value();
}
const op = dialect.stablehlo.dynamicSlice( const op = dialect.stablehlo.dynamicSlice(
self.getContext().mlirCtx(), self.getContext().mlirCtx(),
self.value(), self.value(),
new_shape.dims(), new_shape.dims(),
indices[0..self.rank()], start_indices[0..self.rank()],
loc, loc,
); );
@ -3115,7 +3137,7 @@ pub const Tensor = struct {
// TODO use slices and slices_tags for the format. // TODO use slices and slices_tags for the format.
// Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })" // Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })"
// which is kinda ugly. // which is kinda ugly.
const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_}); const loc = self.getContext().location(@src(), "dynSlice({any})", .{slices_});
const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32; const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32;
const zero = Tensor.scalar(0, idx_dtype).value(); const zero = Tensor.scalar(0, idx_dtype).value();
@ -3153,7 +3175,7 @@ pub const Tensor = struct {
{ {
const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
const z = try zml.Buffer.scalar(platform, 4, .i32); const z = try zml.Buffer.scalar(platform, 4, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z }); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });
try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T)); try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
} }
@ -3163,7 +3185,7 @@ pub const Tensor = struct {
const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } }); const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
const z = try zml.Buffer.scalar(platform, 3, .i32); const z = try zml.Buffer.scalar(platform, 3, .i32);
const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z }); const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T)); try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
} }
} }
@ -3976,3 +3998,44 @@ inline fn toI64(values: anytype) []i64 {
for (values, 0..) |val, i| res[i] = @intCast(val); for (values, 0..) |val, i| res[i] = @intCast(val);
return res[0..values.len]; return res[0..values.len];
} }
fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
// Don't rewrite on invalid inputs.
if (permutation.len > x.rank()) return false;
for (permutation) |ax| {
const squeezable = x.dim(ax) == 1;
perm.appendAssumeCapacity(.{ @intCast(ax), squeezable });
}
var effective_ax: u8 = 0;
for (0..perm.len) |i| {
const ax, const squeezable = perm.get(i);
if (squeezable) {
// Effectively squeeze this axis by decrementing axes coming after by 1.
for (i..perm.len) |j| {
if (perm.buffer[j][0] > ax) {
perm.buffer[j][0] -= 1;
}
}
continue;
}
if (ax != effective_ax) return false;
effective_ax += 1;
}
return true;
}
test transposeIsJustAReshape {
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 0, 1, 2 }));
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 1, 0, 2 }));
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 2, 1, 0 }));
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 8, 1, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 64, 8, 155, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 1, 1, 128 }, .bf16), &.{ 1, 2, 0, 3 }));
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ .b = 1, .h = 10, .q = 155, .hd = 1 }, .f32), &.{ 0, 2, 1, 3 }));
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 }));
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 }));
}