From 7ef87236ce9f11e81087a3e8d018e1e9c1f6088d Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Mon, 18 Dec 2023 13:56:45 +0000 Subject: [PATCH] Rewrite simple transpose as reshape in core ZML modules and raise default profiler event limit to 1,000,000. --- mlir/mlir.zig | 2 +- pjrt/profiler.zig | 2 +- zml/module.zig | 10 ++++- zml/platform.zig | 1 + zml/shape.zig | 13 +++--- zml/tensor.zig | 101 +++++++++++++++++++++++++++++++++++++--------- 6 files changed, 101 insertions(+), 28 deletions(-) diff --git a/mlir/mlir.zig b/mlir/mlir.zig index fd6bdb8..2dc2b9f 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -619,7 +619,7 @@ pub const DenseElementsAttributeTypes = enum { pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type { const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) { - .bool => .{ i32, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue }, + .bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue }, .i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value }, .i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value }, .i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value }, diff --git a/pjrt/profiler.zig b/pjrt/profiler.zig index 93883b2..1e6e2f0 100644 --- a/pjrt/profiler.zig +++ b/pjrt/profiler.zig @@ -135,7 +135,7 @@ pub const Profiler = struct { return; } - var converter = try TraceContainer.init(allocator, profile_data.items(), null); + var converter = try TraceContainer.init(allocator, profile_data.items(), 1_000_000); defer converter.deinit(); var output_file = try dir.createFile(file_name, .{}); diff --git a/zml/module.zig b/zml/module.zig index 3b7e67c..f7e6930 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -177,6 +177,10 @@ pub const CompilationContext = struct { return self._mlir_ctx; } + pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location { + return self._mlir_ctx.location(src).namedFmt(self._mlir_ctx, name, args); + } + /// Compiles the given function with the given arguments. /// This is the untyped API and is not meant to be use directly. /// @@ -883,11 +887,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m try options.env_option_overrides.ensureUnusedCapacity(arena, 16); if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| { setFlag(&options, "xla_dump_to", xla_dump_to); + setFlag(&options, "xla_dump_hlo_as_dot", true); if (platform.compilation_options.xla_dump_fusion_visualization) { - setFlag(&options, "xla_dump_hlo_as_html", true); - setFlag(&options, "xla_dump_hlo_as_dot", true); setFlag(&options, "xla_dump_fusion_visualization", true); } + if (platform.compilation_options.xla_dump_hlo_pass_re) |re| { + setFlag(&options, "xla_dump_hlo_pass_re", re); + } } switch (platform.target) { .cuda => cuda_dir: { diff --git a/zml/platform.zig b/zml/platform.zig index 6c13597..ec4ae4c 100644 --- a/zml/platform.zig +++ b/zml/platform.zig @@ -17,6 +17,7 @@ pub const available_targets = std.enums.values(Target); pub const CompilationOptions = struct { xla_dump_to: ?[]const u8 = null, xla_dump_fusion_visualization: bool = false, + xla_dump_hlo_pass_re: ?[]const u8 = null, sharding_enabled: bool = false, sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{}, }; diff --git a/zml/shape.zig b/zml/shape.zig index f9643e9..c032ba2 100644 --- a/zml/shape.zig +++ b/zml/shape.zig @@ -391,20 +391,23 @@ pub const Shape = struct { const bare_fmt = fmt.len == 1 and fmt[0] == '_'; _ = try writer.write(if (bare_fmt) "{" else "Shape({"); + var need_comma = false; for (self.dims(), 0..) |d, i| { - const prefix = if (i == 0) "" else ","; + if (need_comma) try writer.writeByte(','); const t = self.tag(i); if (t != TagUnknown) { - try writer.print("{s}.{s}={d}", .{ prefix, t, d }); + try writer.print("{s}={d}", .{ t, d }); } else { - try writer.print("{s}{d}", .{ prefix, d }); + try writer.print("{d}", .{d}); } if (self._sharding_info[i]) { try writer.writeByte('!'); } + need_comma = true; } - _ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())}); - if (!bare_fmt) _ = try writer.write(")"); + if (need_comma) try writer.writeByte(','); + _ = try writer.write(@tagName(self.dtype())); + _ = try writer.write(if (bare_fmt) "}" else "})"); } pub fn reshape(self: Shape, new_shape_: anytype) Shape { diff --git a/zml/tensor.zig b/zml/tensor.zig index 00ae1d3..bc66a8c 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -1371,7 +1371,11 @@ pub const Tensor = struct { } const res_shape = self._shape.transpose(permutation); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "tr({any})", .{axes_}); + if (transposeIsJustAReshape(self.shape(), permutation)) { + return self.reshape(res_shape); + } + + const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation }); const op = dialect.stablehlo.transpose( self.getContext().mlirCtx(), self.value(), @@ -1772,7 +1776,7 @@ pub const Tensor = struct { const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64; const res_shape = sh.withDtype(dt); const mlir_ctx = CompilationContext.current().mlirCtx(); - const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ }); + const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a }); var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc); return _result(res_shape, op.result(0)); @@ -1808,11 +1812,27 @@ pub const Tensor = struct { return res; } - /// Returns a 0d Tensor with the given value. + /// Returns a 0-rank Tensor with the given value. pub fn scalar(val: anytype, dt: DataType) Tensor { return Tensor.constant(.{}, Data.init(dt, val)); } + test scalar { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const Local = struct { + pub fn _fwd() [6]Tensor { + var res: [6]Tensor = undefined; + const dtypes = .{ .bool, .u8, .i32, .f32, .bf16, .u64 }; + inline for (0..6) |i| res[i] = scalar(0, dtypes[i]); + return res; + } + }; + + _ = try zml.testing.compileAndCall(platform, Local._fwd, .{}); + } + /// Returns a constant Tensor with the given value. pub fn constant(dimz: anytype, val: Data) Tensor { const sh = Shape.init(dimz, val.dtype()); @@ -1944,6 +1964,12 @@ pub const Tensor = struct { return _result(output_shape, reshape_value.result(0)); } + /// Converts the given 1 element Tensor into a 0-rank Tensor. + pub fn asScalar(self: Tensor) Tensor { + stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self}); + return self.reshape(.{}); + } + pub const Pad = struct { low: i64 = 0, high: i64 = 0, @@ -2129,7 +2155,7 @@ pub const Tensor = struct { // Sometimes the backend recognize this pattern, but not always. // So let us handle that. if (indices.count() == 1) { - return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape); + return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape); } var slice_dims: Shape.DimsArray = .{}; @@ -3066,25 +3092,21 @@ pub const Tensor = struct { /// Slices the input Tensor along a specific axis, with a start offset known at runtime. /// Note: this doesn't support tagging, if you have tags, /// you should use `dynamicSlice` directly. - pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor { - stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()}); + pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor { + stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start}); const a = self.axis(axis_); - const new_shape = self._shape.set(a, len); - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, len={}", .{ axis_, len }); - var indices: [Tensor.MAX_RANK]mlir.Value = undefined; - for (0..self.rank()) |i| { - indices[i] = if (i == a) - start_indices.value() - else - constant(.{}, start_indices.dtype().zero()).value(); - } + const new_shape = self._shape.set(a, slice_.len); + const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len }); + + var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK; + start_indices[a] = slice_.start.value(); const op = dialect.stablehlo.dynamicSlice( self.getContext().mlirCtx(), self.value(), new_shape.dims(), - indices[0..self.rank()], + start_indices[0..self.rank()], loc, ); @@ -3115,7 +3137,7 @@ pub const Tensor = struct { // TODO use slices and slices_tags for the format. // Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })" // which is kinda ugly. - const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_}); + const loc = self.getContext().location(@src(), "dynSlice({any})", .{slices_}); const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32; const zero = Tensor.scalar(0, idx_dtype).value(); @@ -3153,7 +3175,7 @@ pub const Tensor = struct { { const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }); const z = try zml.Buffer.scalar(platform, 4, .i32); - const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z }); + const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } }); try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T)); } @@ -3163,7 +3185,7 @@ pub const Tensor = struct { const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } }); const z = try zml.Buffer.scalar(platform, 3, .i32); - const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z }); + const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } }); try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T)); } } @@ -3976,3 +3998,44 @@ inline fn toI64(values: anytype) []i64 { for (values, 0..) |val, i| res[i] = @intCast(val); return res[0..values.len]; } + +fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool { + var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{}; + // Don't rewrite on invalid inputs. + if (permutation.len > x.rank()) return false; + for (permutation) |ax| { + const squeezable = x.dim(ax) == 1; + perm.appendAssumeCapacity(.{ @intCast(ax), squeezable }); + } + + var effective_ax: u8 = 0; + for (0..perm.len) |i| { + const ax, const squeezable = perm.get(i); + if (squeezable) { + // Effectively squeeze this axis by decrementing axes coming after by 1. + for (i..perm.len) |j| { + if (perm.buffer[j][0] > ax) { + perm.buffer[j][0] -= 1; + } + } + continue; + } + + if (ax != effective_ax) return false; + effective_ax += 1; + } + + return true; +} + +test transposeIsJustAReshape { + try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 0, 1, 2 })); + try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 1, 0, 2 })); + try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 2, 1, 0 })); + try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 8, 1, 128 }, .bf16), &.{ 0, 2, 1, 3 })); + try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 64, 8, 155, 128 }, .bf16), &.{ 0, 2, 1, 3 })); + try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 1, 1, 128 }, .bf16), &.{ 1, 2, 0, 3 })); + try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ .b = 1, .h = 10, .q = 155, .hd = 1 }, .f32), &.{ 0, 2, 1, 3 })); + try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 })); + try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 })); +}