Integrate user sharding hints and HLO sharding annotations across MLIR dialects and ZML core, and remove the now‑unused module options arguments.

This commit is contained in:
Tarry Singh 2023-03-21 10:50:39 +00:00
parent e30e35deeb
commit a4f0fc96c0
7 changed files with 104 additions and 30 deletions

View File

@ -8,17 +8,23 @@ pub fn func(
args: []const mlir.Type, args: []const mlir.Type,
arg_attrs: []const mlir.Attribute = &.{}, arg_attrs: []const mlir.Attribute = &.{},
results: []const mlir.Type, results: []const mlir.Type,
res_attrs: []const mlir.Attribute = &.{},
block: mlir.Block, block: mlir.Block,
location: mlir.Location, location: mlir.Location,
}, },
) mlir.Operation { ) mlir.Operation {
const AttrTuple = struct { [:0]const u8, mlir.Attribute }; const AttrTuple = struct { [:0]const u8, mlir.Attribute };
var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 3){}; var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){};
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
if (args.arg_attrs.len > 0) { if (args.arg_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? }); attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
} }
if (args.res_attrs.len > 0) {
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? });
}
return mlir.Operation.make(ctx, "func.func", .{ return mlir.Operation.make(ctx, "func.func", .{
.blocks = &.{args.block}, .blocks = &.{args.block},
.attributes = attrs_tuple_buffer.constSlice(), .attributes = attrs_tuple_buffer.constSlice(),

View File

@ -697,6 +697,22 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
}); });
} }
pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
.operands = inputs,
.results = res_types,
.attributes = &.{
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
.{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() },
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() },
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
.{ "mhlo.sharding", sharding_spec.asAttr() },
},
.location = location,
});
}
pub const DotDimensionNumbersAttribute = struct { pub const DotDimensionNumbersAttribute = struct {
_inner: c.MlirAttribute, _inner: c.MlirAttribute,

View File

@ -466,6 +466,10 @@ pub const ArrayAttribute = struct {
pub fn get(self: Self, index: usize) Attribute { pub fn get(self: Self, index: usize) Attribute {
return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index))); return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index)));
} }
pub fn asAttr(self: Self) Attribute {
return .{ ._inner = self._inner };
}
}; };
pub fn IntegerAttribute(comptime it: IntegerTypes) type { pub fn IntegerAttribute(comptime it: IntegerTypes) type {

View File

@ -82,18 +82,14 @@ pub const Buffer = struct {
} }
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`. /// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer { pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len }); meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{}); meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
var shards: Shards = .{}; var shards: Shards = .{};
shards.appendSliceAssumeCapacity(pjrt_buffers); shards.appendSliceAssumeCapacity(pjrt_buffers);
return .{ return .{
._api = platform.pjrt_api, ._api = platform.pjrt_api,
._shape = Shape.init( ._shape = shape_,
// This isn't with sharded axes.
pjrt_buffers[0].getDimensions(platform.pjrt_api),
dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)),
),
._shards = shards, ._shards = shards,
}; };
} }

View File

@ -183,7 +183,6 @@ pub const CompilationContext = struct {
comptime func: anytype, comptime func: anytype,
model: *const ModuleSignature(func).ModelT, model: *const ModuleSignature(func).ModelT,
args: *const ModuleSignature(func).ArgsT, args: *const ModuleSignature(func).ArgsT,
opts: struct { add_donations_attributes: bool = false, sharding: bool = true },
) error{OutOfMemory}!MlirFn { ) error{OutOfMemory}!MlirFn {
const frame = self._tracer.frameStart("generateBytecode.emit"); const frame = self._tracer.frameStart("generateBytecode.emit");
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit"); errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
@ -244,6 +243,7 @@ pub const CompilationContext = struct {
var fn_res_values: [out_tensor_count]mlir.Value = undefined; var fn_res_values: [out_tensor_count]mlir.Value = undefined;
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations); self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc); const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
fn_body.addOperationsRecursive(fn_ret); fn_body.addOperationsRecursive(fn_ret);
} }
@ -251,19 +251,21 @@ pub const CompilationContext = struct {
const arg_attrs = try arena.alloc(AttributeList, tensor_count); const arg_attrs = try arena.alloc(AttributeList, tensor_count);
@memset(arg_attrs, .{}); @memset(arg_attrs, .{});
// Donations attributes only make sense on the main function. const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
if (opts.add_donations_attributes) { @memset(res_attrs, .{});
self.addDonationsAttributes(arg_attrs, fn_res_donations);
}
if (opts.sharding) {
self.addShardingAttributes(arg_attrs, input_shapes.items);
}
// Donations attributes only make sense on the main function.
self.addDonationsAttributes(arg_attrs, fn_res_donations);
if (self._platform.sharding().num_partitions > 1) {
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
}
const mlir_fn = dialect.func.func(self.mlirCtx(), .{ const mlir_fn = dialect.func.func(self.mlirCtx(), .{
.sym_name = fn_name, .sym_name = fn_name,
.args = input_types, .args = input_types,
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs), .arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
.results = fn_res_types, .results = fn_res_types,
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
.block = fn_body, .block = fn_body,
.location = loc, .location = loc,
}); });
@ -344,7 +346,7 @@ pub const CompilationContext = struct {
var comp = try zml.module.CompilationContext.init(allocator, "test", platform); var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
defer comp.deinit(); defer comp.deinit();
var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }}; var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }};
const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args, .{ .add_donations_attributes = true }); const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args);
var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{}; var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})}); try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});
@ -359,25 +361,42 @@ pub const CompilationContext = struct {
}; };
} }
fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void { pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
const mlir_ctx = self.mlirCtx(); const mlir_ctx = self.mlirCtx();
if (!self._platform.compilation_options.sharding_enabled) return;
const num_partitions = self._platform.sharding().num_partitions; const num_partitions = self._platform.sharding().num_partitions;
var sharding_str: std.BoundedArray(u8, 128) = .{}; var sharding_str: std.BoundedArray(u8, 128) = .{};
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice());
}
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
const mlir_ctx = self.mlirCtx();
if (!self._platform.compilation_options.sharding_enabled) return;
const mhlo_default_layout = mlir.NamedAttribute.init( const mhlo_default_layout = mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"), mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(), mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
); );
for (attributes, shapes) |*attr, shape| { for (arg_attrs, input_shapes) |*attr, shape| {
attr.appendAssumeCapacity(mhlo_default_layout); attr.appendAssumeCapacity(mhlo_default_layout);
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable; const sharding_attr = self.getShardingAttr(shape);
defer sharding_str.len = 0;
attr.appendAssumeCapacity(mlir.NamedAttribute.init( attr.appendAssumeCapacity(mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"), mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()).asAttr(), sharding_attr.asAttr(),
));
}
for (res_attrs, output_shapes) |*attr, shape| {
attr.appendAssumeCapacity(mhlo_default_layout);
const sharding_attr = self.getShardingAttr(shape);
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
sharding_attr.asAttr(),
)); ));
} }
} }
@ -668,18 +687,20 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3
} }
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers. /// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void { pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape, expected_count: u32) void {
const LocalContext = struct { const LocalContext = struct {
index: u32, index: u32,
platform: Platform, platform: Platform,
buffers: []const [*]*pjrt.Buffer, buffers: []const [*]*pjrt.Buffer,
expected_count: u32, expected_count: u32,
buffer_shapes: []Shape,
}; };
var local_ctx: LocalContext = .{ var local_ctx: LocalContext = .{
.index = 0, .index = 0,
.platform = platform, .platform = platform,
.buffers = buffers, .buffers = buffers,
.expected_count = expected_count, .expected_count = expected_count,
.buffer_shapes = buffer_shapes,
}; };
meta.visit((struct { meta.visit((struct {
fn cb(ctx: *LocalContext, buffer: *Buffer) void { fn cb(ctx: *LocalContext, buffer: *Buffer) void {
@ -691,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
for (ctx.buffers) |buff| { for (ctx.buffers) |buff| {
shards.appendAssumeCapacity(buff[i]); shards.appendAssumeCapacity(buff[i]);
} }
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice()); buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
} }
}).cb, &local_ctx, v); }).cb, &local_ctx, v);
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index }); meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
@ -731,8 +752,12 @@ const BaseExe = struct {
args_buffer_count: u32, args_buffer_count: u32,
/// Number of buffers in result. /// Number of buffers in result.
result_buffer_count: u32, result_buffer_count: u32,
/// Shapes of buffers in result.
result_buffer_shapes: []Shape,
/// Num devices used (>1 for sharded executable) /// Num devices used (>1 for sharded executable)
num_devices: u8, num_devices: u8,
/// Allocator backing result_buffer_shapes and deinit by ExeWithWeights
_allocator: std.heap.ArenaAllocator,
}; };
/// Represents a ZML model, compiled into a PJRT executable. /// Represents a ZML model, compiled into a PJRT executable.
@ -815,6 +840,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
// Free in reverse order of allocation. // Free in reverse order of allocation.
self._allocator.free(self._all_per_device); self._allocator.free(self._all_per_device);
self._allocator.free(self._all_buffers); self._allocator.free(self._all_buffers);
self.inner._allocator.deinit();
} }
pub fn platform(self: Self) Platform { pub fn platform(self: Self) Platform {
@ -836,7 +862,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
}) catch unreachable; }) catch unreachable;
var result: Bufferized(Signature.ReturnT) = undefined; var result: Bufferized(Signature.ReturnT) = undefined;
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count); assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count);
return result; return result;
} }
}; };
@ -860,7 +886,7 @@ fn compileInternal(
var timer = std.time.Timer.start() catch null; var timer = std.time.Timer.start() catch null;
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args); const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
// Run in a dedicated thread because compilation relies on `threadlocal`. // Run in a dedicated thread because compilation relies on `threadlocal`.
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } }); const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args });
context._module.getBody().appendOperation(f.mlir_fn); context._module.getBody().appendOperation(f.mlir_fn);
const sharding = context._platform.sharding(); const sharding = context._platform.sharding();
@ -884,13 +910,18 @@ fn compileInternal(
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)}); if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)});
} }
var arena_state_exe = std.heap.ArenaAllocator.init(allocator);
const arena_exe = arena_state_exe.allocator();
return .{ return .{
.platform = context._platform, .platform = context._platform,
.exe = loaded_executable, .exe = loaded_executable,
.model_buffer_count = f.n_model, .model_buffer_count = f.n_model,
.args_buffer_count = f.n_args, .args_buffer_count = f.n_args,
.result_buffer_count = @intCast(f.res_types.len), .result_buffer_count = @intCast(f.res_types.len),
.result_buffer_shapes = arena_exe.dupe(Shape, f.res_shapes) catch unreachable,
.num_devices = sharding.num_replicas * sharding.num_partitions, .num_devices = sharding.num_replicas * sharding.num_partitions,
._allocator = arena_state_exe,
}; };
} }

View File

@ -159,9 +159,30 @@ pub const Tensor = struct {
} }
pub fn withSharding(self: Tensor, axes_: anytype) Tensor { pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
return switch (self._id) {
.arg_id, .mlir => {
const ctx = self.getContext();
var res = self;
res._shape = self._shape.withSharding(axes_);
const sharding = ctx.getShardingAttr(res._shape);
const op = dialect.stablehlo.sharding(
ctx.mlirCtx(),
&.{self.value()},
sharding,
&.{self.value().getType()},
ctx.mlirCtx().location(@src()),
);
return _result(res._shape, op.result(0));
},
.buffer_id => {
var res = self; var res = self;
res._shape = self._shape.withSharding(axes_); res._shape = self._shape.withSharding(axes_);
return res; return res;
},
};
} }
/// Returns a Tensor with new tag names. /// Returns a Tensor with new tag names.

View File

@ -175,7 +175,7 @@ pub fn testLayerOut(
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in }); log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
} }
const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform, .{}); const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform);
const n_out_exp = activations.countLayers(out_name); const n_out_exp = activations.countLayers(out_name);
if (exe.inner.result_buffer_count != n_out_exp) { if (exe.inner.result_buffer_count != n_out_exp) {