Rewrite simple transpose as reshape in core ZML modules and raise default profiler event limit to 1,000,000.
parent 8a031bd4c8
commit 7ef87236ce
@@ -619,7 +619,7 @@ pub const DenseElementsAttributeTypes = enum {

 pub fn DenseIntOrFPElementsAttribute(comptime dt: DenseElementsAttributeTypes) type {
     const ZigInDataType, const ZigOutDataType, const initFn, const getValue = switch (dt) {
-        .bool => .{ i32, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
+        .bool => .{ bool, bool, c.mlirDenseElementsAttrBoolGet, c.mlirDenseElementsAttrGetBoolValue },
         .i8 => .{ i8, i8, c.mlirDenseElementsAttrInt8Get, c.mlirDenseElementsAttrGetInt8Value },
         .i16 => .{ i16, i16, c.mlirDenseElementsAttrInt16Get, c.mlirDenseElementsAttrGetInt16Value },
         .i32 => .{ i32, i32, c.mlirDenseElementsAttrInt32Get, c.mlirDenseElementsAttrGetInt32Value },
@@ -135,7 +135,7 @@ pub const Profiler = struct {
             return;
         }

-        var converter = try TraceContainer.init(allocator, profile_data.items(), null);
+        var converter = try TraceContainer.init(allocator, profile_data.items(), 1_000_000);
        defer converter.deinit();

        var output_file = try dir.createFile(file_name, .{});
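Per the commit message, the third argument to `TraceContainer.init` is the profiler event limit; `null` previously fell back to a smaller built-in default. A minimal sketch of what such a cap plausibly does downstream (the `TraceEvent` type, `capEvents` helper, and truncation behavior are assumptions for illustration, not ZML's actual implementation):

// Hypothetical sketch: keep at most `max_events` parsed events.
fn capEvents(events: []const TraceEvent, max_events: usize) []const TraceEvent {
    // Assumption: events beyond the cap are dropped rather than failing the conversion.
    return events[0..@min(events.len, max_events)];
}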
@@ -177,6 +177,10 @@ pub const CompilationContext = struct {
         return self._mlir_ctx;
     }

+    pub fn location(self: *const CompilationContext, src: std.builtin.SourceLocation, comptime name: [:0]const u8, args: anytype) mlir.Location {
+        return self._mlir_ctx.location(src).namedFmt(self._mlir_ctx, name, args);
+    }
+
     /// Compiles the given function with the given arguments.
     /// This is the untyped API and is not meant to be use directly.
     ///
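The new `CompilationContext.location` helper collapses the repeated two-step pattern used throughout `zml/tensor.zig`. A sketch of the call-site simplification, with both forms taken from hunks in this commit:

// Before: fetch the MLIR context twice and chain location + namedFmt.
const loc_old = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_});

// After: one call on the CompilationContext.
const loc_new = self.getContext().location(@src(), "dynSlice({any})", .{slices_});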
@@ -883,11 +887,13 @@ fn compileModuleToPjrtExecutable(arena: std.mem.Allocator, platform: Platform, m
     try options.env_option_overrides.ensureUnusedCapacity(arena, 16);
     if (xla_dump_to_ orelse platform.compilation_options.xla_dump_to) |xla_dump_to| {
         setFlag(&options, "xla_dump_to", xla_dump_to);
+        setFlag(&options, "xla_dump_hlo_as_dot", true);
         if (platform.compilation_options.xla_dump_fusion_visualization) {
-            setFlag(&options, "xla_dump_hlo_as_html", true);
-            setFlag(&options, "xla_dump_hlo_as_dot", true);
             setFlag(&options, "xla_dump_fusion_visualization", true);
         }
+        if (platform.compilation_options.xla_dump_hlo_pass_re) |re| {
+            setFlag(&options, "xla_dump_hlo_pass_re", re);
+        }
     }
     switch (platform.target) {
         .cuda => cuda_dir: {
@@ -17,6 +17,7 @@ pub const available_targets = std.enums.values(Target);
 pub const CompilationOptions = struct {
     xla_dump_to: ?[]const u8 = null,
     xla_dump_fusion_visualization: bool = false,
+    xla_dump_hlo_pass_re: ?[]const u8 = null,
     sharding_enabled: bool = false,
     sharding_axes: std.BoundedArray([*:0]const u8, 8) = .{},
 };
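A hedged example of requesting the new dump behavior: the field names come from this diff, while the dump directory and regex values, and how the options reach the platform, are assumptions for illustration:

const opts: CompilationOptions = .{
    .xla_dump_to = "/tmp/xla_dump",
    .xla_dump_fusion_visualization = true,
    // New field: only dump HLO for passes whose name matches this regex.
    .xla_dump_hlo_pass_re = "fusion",
};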
@@ -391,20 +391,23 @@ pub const Shape = struct {
         const bare_fmt = fmt.len == 1 and fmt[0] == '_';
         _ = try writer.write(if (bare_fmt) "{" else "Shape({");

+        var need_comma = false;
         for (self.dims(), 0..) |d, i| {
-            const prefix = if (i == 0) "" else ",";
+            if (need_comma) try writer.writeByte(',');
             const t = self.tag(i);
             if (t != TagUnknown) {
-                try writer.print("{s}.{s}={d}", .{ prefix, t, d });
+                try writer.print("{s}={d}", .{ t, d });
             } else {
-                try writer.print("{s}{d}", .{ prefix, d });
+                try writer.print("{d}", .{d});
             }
             if (self._sharding_info[i]) {
                 try writer.writeByte('!');
             }
+            need_comma = true;
         }
-        _ = try writer.print("}}, dtype=.{s}", .{@tagName(self.dtype())});
-        if (!bare_fmt) _ = try writer.write(")");
+        if (need_comma) try writer.writeByte(',');
+        _ = try writer.write(@tagName(self.dtype()));
+        _ = try writer.write(if (bare_fmt) "}" else "})");
     }

     pub fn reshape(self: Shape, new_shape_: anytype) Shape {
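Assuming tags render through `{s}` as plain axis names, the new formatter would produce output like the following; this is a sketch derived from the format strings above, not verified against a test:

// Shape.init(.{ .b = 2, .n = 4 }, .f32) formatted with "{}":
//   Shape({b=2,n=4,f32})
// ...and with the bare "{_}" format:
//   {b=2,n=4,f32}
// A sharded axis is suffixed with '!', e.g. Shape({b=2!,n=4,f32}).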
zml/tensor.zig (101 lines changed)
@@ -1371,7 +1371,11 @@ pub const Tensor = struct {
         }

         const res_shape = self._shape.transpose(permutation);
-        const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "tr({any})", .{axes_});
+        if (transposeIsJustAReshape(self.shape(), permutation)) {
+            return self.reshape(res_shape);
+        }
+
+        const loc = self.getContext().location(@src(), "transpose({_}, {d})", .{ self.shape(), permutation });
         const op = dialect.stablehlo.transpose(
             self.getContext().mlirCtx(),
             self.value(),
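Why this rewrite is safe: when a permutation only moves size-1 axes, the row-major element order is unchanged, so a metadata-only reshape can replace the `stablehlo.transpose`. Two cases from the `transposeIsJustAReshape` tests at the bottom of this commit illustrate the boundary:

// Only the unit axis (size 1) moves: element order is preserved => reshape.
try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 8, 1, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
// Two non-unit axes (8 and 155) swap: data must actually move => real transpose.
try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 64, 8, 155, 128 }, .bf16), &.{ 0, 2, 1, 3 }));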
@@ -1772,7 +1776,7 @@ pub const Tensor = struct {
         const dt: DataType = if (sh.dim(a) <= std.math.maxInt(i32)) .i32 else .i64;
         const res_shape = sh.withDtype(dt);
         const mlir_ctx = CompilationContext.current().mlirCtx();
-        const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({}, {})", .{ res_shape, axis_ });
+        const loc = mlir_ctx.location(@src()).namedFmt(mlir_ctx, "iota({_}, {})", .{ res_shape, a });

         var op = dialect.stablehlo.iota(mlir_ctx, a, mlir.ext.RankedTensorType.fromShape(mlir_ctx, res_shape).asType(), loc);
         return _result(res_shape, op.result(0));
@@ -1808,11 +1812,27 @@ pub const Tensor = struct {
         return res;
     }

-    /// Returns a 0d Tensor with the given value.
+    /// Returns a 0-rank Tensor with the given value.
     pub fn scalar(val: anytype, dt: DataType) Tensor {
         return Tensor.constant(.{}, Data.init(dt, val));
     }

+    test scalar {
+        const zml = @import("zml.zig");
+        const platform = zml.testing.env();
+
+        const Local = struct {
+            pub fn _fwd() [6]Tensor {
+                var res: [6]Tensor = undefined;
+                const dtypes = .{ .bool, .u8, .i32, .f32, .bf16, .u64 };
+                inline for (0..6) |i| res[i] = scalar(0, dtypes[i]);
+                return res;
+            }
+        };
+
+        _ = try zml.testing.compileAndCall(platform, Local._fwd, .{});
+    }
+
     /// Returns a constant Tensor with the given value.
     pub fn constant(dimz: anytype, val: Data) Tensor {
         const sh = Shape.init(dimz, val.dtype());
@@ -1944,6 +1964,12 @@ pub const Tensor = struct {
         return _result(output_shape, reshape_value.result(0));
     }

+    /// Converts the given 1 element Tensor into a 0-rank Tensor.
+    pub fn asScalar(self: Tensor) Tensor {
+        stdx.debug.assert(self.count() == 1, "Tensor.asScalar expects an input with exactly 1-element got {}", .{self});
+        return self.reshape(.{});
+    }
+
     pub const Pad = struct {
         low: i64 = 0,
         high: i64 = 0,
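A short usage sketch for the new helper; the specific shapes are assumptions for illustration, while `scalar` and `reshape` appear elsewhere in this file:

// A {1,1}-shaped tensor has count() == 1, so it can collapse to rank 0.
const x = Tensor.scalar(7, .i32).reshape(.{ 1, 1 });
const s = x.asScalar(); // shape {}, a 0-rank tensor holding the same value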
@@ -2129,7 +2155,7 @@ pub const Tensor = struct {
         // Sometimes the backend recognize this pattern, but not always.
         // So let us handle that.
         if (indices.count() == 1) {
-            return self.dynamicSlice1d(coord_axes_.get(0), 1, indices.flattenAll().squeeze(0)).reshape(res_shape);
+            return self.dynamicSlice1d(coord_axes_.get(0), .{ .start = indices.flattenAll().squeeze(0), .len = 1 }).reshape(res_shape);
         }

         var slice_dims: Shape.DimsArray = .{};
@@ -3066,25 +3092,21 @@ pub const Tensor = struct {
     /// Slices the input Tensor along a specific axis, with a start offset known at runtime.
     /// Note: this doesn't support tagging, if you have tags,
     /// you should use `dynamicSlice` directly.
-    pub fn dynamicSlice1d(self: Tensor, axis_: i8, len: u63, start_indices: Tensor) Tensor {
-        stdx.debug.assert(start_indices.rank() == 0, "dynamicSlice1d expects 'start_indices' tensor rank to be equal to 0, got {}", .{start_indices.rank()});
+    pub fn dynamicSlice1d(self: Tensor, axis_: i8, slice_: DynSlice) Tensor {
+        stdx.debug.assert(slice_.start.rank() == 0, "dynamicSlice1d expects 'slice_.start' tensor rank to be a scalar, got {}", .{slice_.start});

         const a = self.axis(axis_);
-        const new_shape = self._shape.set(a, len);
-        const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "axis={}, len={}", .{ axis_, len });
-        var indices: [Tensor.MAX_RANK]mlir.Value = undefined;
-        for (0..self.rank()) |i| {
-            indices[i] = if (i == a)
-                start_indices.value()
-            else
-                constant(.{}, start_indices.dtype().zero()).value();
-        }
+        const new_shape = self._shape.set(a, slice_.len);
+        const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({}, len={})", .{ axis_, slice_.len });
+        var start_indices = [_]mlir.Value{constant(.{}, slice_.start.dtype().zero()).value()} ** MAX_RANK;
+        start_indices[a] = slice_.start.value();

         const op = dialect.stablehlo.dynamicSlice(
             self.getContext().mlirCtx(),
             self.value(),
             new_shape.dims(),
-            indices[0..self.rank()],
+            start_indices[0..self.rank()],
             loc,
         );

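Call sites migrate from the positional `(axis, len, start)` form to the `DynSlice` struct, matching the updated tests below:

// Old call: x.dynamicSlice1d(0, 2, z);
// New call, with the same meaning:
const res = x.dynamicSlice1d(0, .{ .start = z, .len = 2 });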
@@ -3115,7 +3137,7 @@ pub const Tensor = struct {
         // TODO use slices and slices_tags for the format.
         // Currently this prints: "dynSlice(struct{q: struct{start: tensor.Tensor, comptime len: comptime_int = 1}}{ .q = struct{start: tensor.Tensor, comptime len: comptime_int = 1}{ .start = Tensor({1,10}, dtype=.i64), .len = 1 } })"
        // which is kinda ugly.
-        const loc = self.getContext().mlirCtx().location(@src()).namedFmt(self.getContext().mlirCtx(), "dynSlice({any})", .{slices_});
+        const loc = self.getContext().location(@src(), "dynSlice({any})", .{slices_});

         const idx_dtype = if (slices.len > 0) slices.get(0).start.dtype() else .i32;
         const zero = Tensor.scalar(0, idx_dtype).value();
@@ -3153,7 +3175,7 @@ pub const Tensor = struct {
         {
             const x = try zml.Buffer.fromArray(platform, [10]T{ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 });
             const z = try zml.Buffer.scalar(platform, 4, .i32);
-            const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, 2, z });
+            const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 0, .{ .len = 2, .start = z } });

             try testing.expectEqual([2]T{ 4, 5 }, try res.getValue([2]T));
         }
@@ -3163,7 +3185,7 @@ pub const Tensor = struct {
             const x = try zml.Buffer.fromArray(platform, [2][5]T{ .{ 0, 1, 2, 3, 4 }, .{ 5, 6, 7, 8, 9 } });
             const z = try zml.Buffer.scalar(platform, 3, .i32);

-            const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, 2, z });
+            const res = try zml.testing.compileAndCall(platform, Tensor.dynamicSlice1d, .{ x, 1, .{ .len = 2, .start = z } });
             try testing.expectEqual([4]T{ 3, 4, 8, 9 }, res.getValue([4]T));
         }
     }
@@ -3976,3 +3998,44 @@ inline fn toI64(values: anytype) []i64 {
     for (values, 0..) |val, i| res[i] = @intCast(val);
     return res[0..values.len];
 }
+
+fn transposeIsJustAReshape(x: Shape, permutation: []const i64) bool {
+    var perm: std.BoundedArray(struct { u8, bool }, Tensor.MAX_RANK) = .{};
+    // Don't rewrite on invalid inputs.
+    if (permutation.len > x.rank()) return false;
+    for (permutation) |ax| {
+        const squeezable = x.dim(ax) == 1;
+        perm.appendAssumeCapacity(.{ @intCast(ax), squeezable });
+    }
+
+    var effective_ax: u8 = 0;
+    for (0..perm.len) |i| {
+        const ax, const squeezable = perm.get(i);
+        if (squeezable) {
+            // Effectively squeeze this axis by decrementing axes coming after by 1.
+            for (i..perm.len) |j| {
+                if (perm.buffer[j][0] > ax) {
+                    perm.buffer[j][0] -= 1;
+                }
+            }
+            continue;
+        }
+
+        if (ax != effective_ax) return false;
+        effective_ax += 1;
+    }
+
+    return true;
+}
+
+test transposeIsJustAReshape {
+    try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 0, 1, 2 }));
+    try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 1, 0, 2 }));
+    try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 5, 1, 3 }, .i32), &.{ 2, 1, 0 }));
+    try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 8, 1, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
+    try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 64, 8, 155, 128 }, .bf16), &.{ 0, 2, 1, 3 }));
+    try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 64, 1, 1, 128 }, .bf16), &.{ 1, 2, 0, 3 }));
+    try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ .b = 1, .h = 10, .q = 155, .hd = 1 }, .f32), &.{ 0, 2, 1, 3 }));
+    try std.testing.expect(!transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 2, 3, 1 }));
+    try std.testing.expect(transposeIsJustAReshape(Shape.init(.{ 1, 10, 155, 1 }, .f32), &.{ 0, 1, 3, 2 }));
+}
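In short, `transposeIsJustAReshape` drops size-1 axes from the permutation (shifting later indices down) and answers true exactly when the surviving axes are already in increasing order. A worked trace of one accepted and one rejected case from the tests above:

// {64, 8, 1, 128}, permutation {0, 2, 1, 3}:
//   axis 2 has size 1 -> squeezed; later indices shift down: survivors are {0, 1, 3} -> {0, 1, 2}
//   survivors are in increasing order => true (pure reshape)
// {64, 8, 155, 128}, permutation {0, 2, 1, 3}:
//   no unit axes; at position 1 we see axis 2 but expected 1 => false (real transpose)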