From a4f0fc96c05e58b2b8a5bc28533440cfd3d8a1a9 Mon Sep 17 00:00:00 2001 From: Tarry Singh Date: Tue, 21 Mar 2023 10:50:39 +0000 Subject: [PATCH] =?UTF-8?q?Integrate=20user=20sharding=20hints=20and=20HLO?= =?UTF-8?q?=20sharding=20annotations=20across=20MLIR=20dialects=20and=20ZM?= =?UTF-8?q?L=20core,=20and=20remove=20the=20now=E2=80=91unused=20module=20?= =?UTF-8?q?options=20arguments.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mlir/dialects/func.zig | 8 ++++- mlir/dialects/stablehlo.zig | 16 +++++++++ mlir/mlir.zig | 4 +++ zml/buffer.zig | 8 ++--- zml/module.zig | 69 +++++++++++++++++++++++++++---------- zml/tensor.zig | 27 +++++++++++++-- zml/testing.zig | 2 +- 7 files changed, 104 insertions(+), 30 deletions(-) diff --git a/mlir/dialects/func.zig b/mlir/dialects/func.zig index 4f2feb0..96ba1d4 100644 --- a/mlir/dialects/func.zig +++ b/mlir/dialects/func.zig @@ -8,17 +8,23 @@ pub fn func( args: []const mlir.Type, arg_attrs: []const mlir.Attribute = &.{}, results: []const mlir.Type, + res_attrs: []const mlir.Attribute = &.{}, block: mlir.Block, location: mlir.Location, }, ) mlir.Operation { const AttrTuple = struct { [:0]const u8, mlir.Attribute }; - var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 3){}; + var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){}; attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? }); if (args.arg_attrs.len > 0) { attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? }); } + + if (args.res_attrs.len > 0) { + attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? 
}); + } + return mlir.Operation.make(ctx, "func.func", .{ .blocks = &.{args.block}, .attributes = attrs_tuple_buffer.constSlice(), diff --git a/mlir/dialects/stablehlo.zig b/mlir/dialects/stablehlo.zig index 2fa642c..9f3606e 100644 --- a/mlir/dialects/stablehlo.zig +++ b/mlir/dialects/stablehlo.zig @@ -697,6 +697,22 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa }); } +pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation { + return mlir.Operation.make(ctx, "stablehlo.custom_call", .{ + .operands = inputs, + .results = res_types, + .attributes = &.{ + .{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() }, + .{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() }, + .{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() }, + .{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() }, + .{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() }, + .{ "mhlo.sharding", sharding_spec.asAttr() }, + }, + .location = location, + }); +} + pub const DotDimensionNumbersAttribute = struct { _inner: c.MlirAttribute, diff --git a/mlir/mlir.zig b/mlir/mlir.zig index 0ec47ea..0c00cf5 100644 --- a/mlir/mlir.zig +++ b/mlir/mlir.zig @@ -466,6 +466,10 @@ pub const ArrayAttribute = struct { pub fn get(self: Self, index: usize) Attribute { return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index))); } + + pub fn asAttr(self: Self) Attribute { + return .{ ._inner = self._inner }; + } }; pub fn IntegerAttribute(comptime it: IntegerTypes) type { diff --git a/zml/buffer.zig b/zml/buffer.zig index 02d4bbf..d093d65 100644 --- a/zml/buffer.zig +++ b/zml/buffer.zig @@ -82,18 +82,14 @@ pub const Buffer = struct { } /// Wraps pre-exisiting `pjrt.Buffer` shards into one `zml.Buffer`. 
- pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer { + pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer { meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); var shards: Shards = .{}; shards.appendSliceAssumeCapacity(pjrt_buffers); return .{ ._api = platform.pjrt_api, - ._shape = Shape.init( - // This isn't with sharded axes. - pjrt_buffers[0].getDimensions(platform.pjrt_api), - dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)), - ), + ._shape = shape_, ._shards = shards, }; } diff --git a/zml/module.zig b/zml/module.zig index 2cd64d0..69f2c95 100644 --- a/zml/module.zig +++ b/zml/module.zig @@ -183,7 +183,6 @@ pub const CompilationContext = struct { comptime func: anytype, model: *const ModuleSignature(func).ModelT, args: *const ModuleSignature(func).ArgsT, - opts: struct { add_donations_attributes: bool = false, sharding: bool = true }, ) error{OutOfMemory}!MlirFn { const frame = self._tracer.frameStart("generateBytecode.emit"); errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); @@ -244,6 +243,7 @@ pub const CompilationContext = struct { var fn_res_values: [out_tensor_count]mlir.Value = undefined; self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); + const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); fn_body.addOperationsRecursive(fn_ret); } @@ -251,19 +251,21 @@ pub const CompilationContext = struct { const arg_attrs = try arena.alloc(AttributeList, tensor_count); @memset(arg_attrs, .{}); - // Donations attributes only make sense on the main function. 
- if (opts.add_donations_attributes) { - self.addDonationsAttributes(arg_attrs, fn_res_donations); - } - if (opts.sharding) { - self.addShardingAttributes(arg_attrs, input_shapes.items); - } + const res_attrs = try arena.alloc(AttributeList, out_tensor_count); + @memset(res_attrs, .{}); + // Donations attributes only make sense on the main function. + self.addDonationsAttributes(arg_attrs, fn_res_donations); + + if (self._platform.sharding().num_partitions > 1) { + self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes); + } const mlir_fn = dialect.func.func(self.mlirCtx(), .{ .sym_name = fn_name, .args = input_types, .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .results = fn_res_types, + .res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs), .block = fn_body, .location = loc, }); @@ -344,7 +346,7 @@ pub const CompilationContext = struct { var comp = try zml.module.CompilationContext.init(allocator, "test", platform); defer comp.deinit(); var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }}; - const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args, .{ .add_donations_attributes = true }); + const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args); var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); @@ -359,25 +361,42 @@ pub const CompilationContext = struct { }; } - fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void { + pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute { const mlir_ctx = self.mlirCtx(); - if (!self._platform.compilation_options.sharding_enabled) return; const num_partitions = self._platform.sharding().num_partitions; var sharding_str: std.BoundedArray(u8, 128) = .{}; 
+ writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; + return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()); + } + + fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void { + const mlir_ctx = self.mlirCtx(); + if (!self._platform.compilation_options.sharding_enabled) return; + const mhlo_default_layout = mlir.NamedAttribute.init( mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"), mlir.StringAttribute.init(mlir_ctx, "default").asAttr(), ); - for (attributes, shapes) |*attr, shape| { + for (arg_attrs, input_shapes) |*attr, shape| { attr.appendAssumeCapacity(mhlo_default_layout); - writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; - defer sharding_str.len = 0; + const sharding_attr = self.getShardingAttr(shape); attr.appendAssumeCapacity(mlir.NamedAttribute.init( mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), - mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()).asAttr(), + sharding_attr.asAttr(), + )); + } + + for (res_attrs, output_shapes) |*attr, shape| { + attr.appendAssumeCapacity(mhlo_default_layout); + + const sharding_attr = self.getShardingAttr(shape); + + attr.appendAssumeCapacity(mlir.NamedAttribute.init( + mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), + sharding_attr.asAttr(), )); } } @@ -668,18 +687,20 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3 } /// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers. 
-pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void { +pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape, expected_count: u32) void { const LocalContext = struct { index: u32, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32, + buffer_shapes: []Shape, }; var local_ctx: LocalContext = .{ .index = 0, .platform = platform, .buffers = buffers, .expected_count = expected_count, + .buffer_shapes = buffer_shapes, }; meta.visit((struct { fn cb(ctx: *LocalContext, buffer: *Buffer) void { @@ -691,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr for (ctx.buffers) |buff| { shards.appendAssumeCapacity(buff[i]); } - buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice()); + buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice()); } }).cb, &local_ctx, v); meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); @@ -731,8 +752,12 @@ const BaseExe = struct { args_buffer_count: u32, /// Number of buffers in result. result_buffer_count: u32, + /// Shapes of buffers in result. + result_buffer_shapes: []Shape, /// Num devices used (>1 for sharded executable) num_devices: u8, + /// Arena allocator backing `result_buffer_shapes`; deinitialized by `ExeWithWeights`. + _allocator: std.heap.ArenaAllocator, }; /// Represents a ZML model, compiled into a PJRT executable. @@ -815,6 +840,7 @@ pub fn ExeWithWeights(comptime func: anytype) type { // Free in reverse order of allocation. 
self._allocator.free(self._all_per_device); self._allocator.free(self._all_buffers); + self.inner._allocator.deinit(); } pub fn platform(self: Self) Platform { @@ -836,7 +862,7 @@ pub fn ExeWithWeights(comptime func: anytype) type { }) catch unreachable; var result: Bufferized(Signature.ReturnT) = undefined; - assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count); + assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count); return result; } }; @@ -860,7 +886,7 @@ fn compileInternal( var timer = std.time.Timer.start() catch null; const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); // Run in a dedicated thread because compilation relies on `threadlocal`. - const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } }); + const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args }); context._module.getBody().appendOperation(f.mlir_fn); const sharding = context._platform.sharding(); @@ -884,13 +910,18 @@ fn compileInternal( if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)}); } + var arena_state_exe = std.heap.ArenaAllocator.init(allocator); + const arena_exe = arena_state_exe.allocator(); + return .{ .platform = context._platform, .exe = loaded_executable, .model_buffer_count = f.n_model, .args_buffer_count = f.n_args, .result_buffer_count = @intCast(f.res_types.len), + .result_buffer_shapes = arena_exe.dupe(Shape, f.res_shapes) catch unreachable, .num_devices = sharding.num_replicas * sharding.num_partitions, + ._allocator = arena_state_exe, }; } diff --git a/zml/tensor.zig b/zml/tensor.zig index 97d001c..0d182fb 100644 --- a/zml/tensor.zig +++ b/zml/tensor.zig @@ -159,9 +159,30 @@ pub const Tensor = 
struct { } pub fn withSharding(self: Tensor, axes_: anytype) Tensor { - var res = self; - res._shape = self._shape.withSharding(axes_); - return res; + return switch (self._id) { + .arg_id, .mlir => { + const ctx = self.getContext(); + var res = self; + res._shape = self._shape.withSharding(axes_); + + const sharding = ctx.getShardingAttr(res._shape); + + const op = dialect.stablehlo.sharding( + ctx.mlirCtx(), + &.{self.value()}, + sharding, + &.{self.value().getType()}, + ctx.mlirCtx().location(@src()), + ); + + return _result(res._shape, op.result(0)); + }, + .buffer_id => { + var res = self; + res._shape = self._shape.withSharding(axes_); + return res; + }, + }; } /// Returns a Tensor with new tag names. diff --git a/zml/testing.zig b/zml/testing.zig index 6f0544a..df3c04e 100644 --- a/zml/testing.zig +++ b/zml/testing.zig @@ -175,7 +175,7 @@ pub fn testLayerOut( log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in }); } - const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform, .{}); + const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform); const n_out_exp = activations.countLayers(out_name); if (exe.inner.result_buffer_count != n_out_exp) {