Integrate user sharding hints and HLO sharding annotations across MLIR dialects and ZML core, and remove the now‑unused module options arguments.
This commit is contained in:
parent
e30e35deeb
commit
a4f0fc96c0
@ -8,17 +8,23 @@ pub fn func(
|
|||||||
args: []const mlir.Type,
|
args: []const mlir.Type,
|
||||||
arg_attrs: []const mlir.Attribute = &.{},
|
arg_attrs: []const mlir.Attribute = &.{},
|
||||||
results: []const mlir.Type,
|
results: []const mlir.Type,
|
||||||
|
res_attrs: []const mlir.Attribute = &.{},
|
||||||
block: mlir.Block,
|
block: mlir.Block,
|
||||||
location: mlir.Location,
|
location: mlir.Location,
|
||||||
},
|
},
|
||||||
) mlir.Operation {
|
) mlir.Operation {
|
||||||
const AttrTuple = struct { [:0]const u8, mlir.Attribute };
|
const AttrTuple = struct { [:0]const u8, mlir.Attribute };
|
||||||
var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 3){};
|
var attrs_tuple_buffer = std.BoundedArray(AttrTuple, 4){};
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "sym_name", mlir.StringAttribute.init(ctx, args.sym_name).as(mlir.Attribute).? });
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "function_type", mlir.TypeAttribute.init((mlir.FunctionType.init(ctx, args.args, args.results) catch unreachable).as(mlir.Type).?).as(mlir.Attribute).? });
|
||||||
if (args.arg_attrs.len > 0) {
|
if (args.arg_attrs.len > 0) {
|
||||||
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "arg_attrs", mlir.ArrayAttribute.init(ctx, args.arg_attrs).as(mlir.Attribute).? });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (args.res_attrs.len > 0) {
|
||||||
|
attrs_tuple_buffer.appendAssumeCapacity(.{ "res_attrs", mlir.ArrayAttribute.init(ctx, args.res_attrs).as(mlir.Attribute).? });
|
||||||
|
}
|
||||||
|
|
||||||
return mlir.Operation.make(ctx, "func.func", .{
|
return mlir.Operation.make(ctx, "func.func", .{
|
||||||
.blocks = &.{args.block},
|
.blocks = &.{args.block},
|
||||||
.attributes = attrs_tuple_buffer.constSlice(),
|
.attributes = attrs_tuple_buffer.constSlice(),
|
||||||
|
|||||||
@ -697,6 +697,22 @@ pub fn custom_call(ctx: mlir.Context, inputs: []const mlir.Value, opts: CustomCa
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn sharding(ctx: mlir.Context, inputs: []const mlir.Value, sharding_spec: mlir.StringAttribute, res_types: []const mlir.Type, location: mlir.Location) mlir.Operation {
|
||||||
|
return mlir.Operation.make(ctx, "stablehlo.custom_call", .{
|
||||||
|
.operands = inputs,
|
||||||
|
.results = res_types,
|
||||||
|
.attributes = &.{
|
||||||
|
.{ "api_version", mlir.IntegerAttribute(.i32).init(ctx, 1).asAttr() },
|
||||||
|
.{ "call_target_name", mlir.StringAttribute.init(ctx, "Sharding").asAttr() },
|
||||||
|
.{ "has_side_effect", mlir.BoolAttribute.init(ctx, false).asAttr() },
|
||||||
|
.{ "backend_config", mlir.StringAttribute.init(ctx, &.{}).asAttr() },
|
||||||
|
.{ "output_operand_aliases", mlir.ArrayAttribute.init(ctx, &.{}).asAttr() },
|
||||||
|
.{ "mhlo.sharding", sharding_spec.asAttr() },
|
||||||
|
},
|
||||||
|
.location = location,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub const DotDimensionNumbersAttribute = struct {
|
pub const DotDimensionNumbersAttribute = struct {
|
||||||
_inner: c.MlirAttribute,
|
_inner: c.MlirAttribute,
|
||||||
|
|
||||||
|
|||||||
@ -466,6 +466,10 @@ pub const ArrayAttribute = struct {
|
|||||||
pub fn get(self: Self, index: usize) Attribute {
|
pub fn get(self: Self, index: usize) Attribute {
|
||||||
return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index)));
|
return Attribute.wrap(c.mlirArrayAttrGetElement(self.inner(), @intCast(index)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn asAttr(self: Self) Attribute {
|
||||||
|
return .{ ._inner = self._inner };
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
|
pub fn IntegerAttribute(comptime it: IntegerTypes) type {
|
||||||
|
|||||||
@ -82,18 +82,14 @@ pub const Buffer = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
|
/// Wraps pre-existing `pjrt.Buffer` shards into one `zml.Buffer`.
|
||||||
pub fn fromPjrtBuffers(platform: Platform, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
pub fn fromPjrtBuffers(platform: Platform, shape_: Shape, pjrt_buffers: []const *pjrt.Buffer) Buffer {
|
||||||
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
meta.assert(pjrt_buffers.len <= MAX_NUM_SHARDS, "ZML doesn't support having more than {} shards. Received {} shards for one buffer.", .{ MAX_NUM_SHARDS, pjrt_buffers.len });
|
||||||
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
|
meta.assert(pjrt_buffers.len > 0, "fromPjrtBuffers expects at least one buffer, got 0.", .{});
|
||||||
var shards: Shards = .{};
|
var shards: Shards = .{};
|
||||||
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
shards.appendSliceAssumeCapacity(pjrt_buffers);
|
||||||
return .{
|
return .{
|
||||||
._api = platform.pjrt_api,
|
._api = platform.pjrt_api,
|
||||||
._shape = Shape.init(
|
._shape = shape_,
|
||||||
// This isn't with sharded axes.
|
|
||||||
pjrt_buffers[0].getDimensions(platform.pjrt_api),
|
|
||||||
dtypeFromBufferType(pjrt_buffers[0].getElementType(platform.pjrt_api)),
|
|
||||||
),
|
|
||||||
._shards = shards,
|
._shards = shards,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@ -183,7 +183,6 @@ pub const CompilationContext = struct {
|
|||||||
comptime func: anytype,
|
comptime func: anytype,
|
||||||
model: *const ModuleSignature(func).ModelT,
|
model: *const ModuleSignature(func).ModelT,
|
||||||
args: *const ModuleSignature(func).ArgsT,
|
args: *const ModuleSignature(func).ArgsT,
|
||||||
opts: struct { add_donations_attributes: bool = false, sharding: bool = true },
|
|
||||||
) error{OutOfMemory}!MlirFn {
|
) error{OutOfMemory}!MlirFn {
|
||||||
const frame = self._tracer.frameStart("generateBytecode.emit");
|
const frame = self._tracer.frameStart("generateBytecode.emit");
|
||||||
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
|
errdefer self._tracer.frameEnd(frame, "generateBytecode.emit");
|
||||||
@ -244,6 +243,7 @@ pub const CompilationContext = struct {
|
|||||||
|
|
||||||
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
var fn_res_values: [out_tensor_count]mlir.Value = undefined;
|
||||||
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
self.extractValuesAndTypes(&fn_res, &fn_res_values, fn_res_types, fn_res_shapes, fn_res_donations);
|
||||||
|
|
||||||
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
const fn_ret = dialect.func.return_(mlir_ctx, &fn_res_values, loc);
|
||||||
fn_body.addOperationsRecursive(fn_ret);
|
fn_body.addOperationsRecursive(fn_ret);
|
||||||
}
|
}
|
||||||
@ -251,19 +251,21 @@ pub const CompilationContext = struct {
|
|||||||
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
|
const arg_attrs = try arena.alloc(AttributeList, tensor_count);
|
||||||
@memset(arg_attrs, .{});
|
@memset(arg_attrs, .{});
|
||||||
|
|
||||||
// Donations attributes only make sense on the main function.
|
const res_attrs = try arena.alloc(AttributeList, out_tensor_count);
|
||||||
if (opts.add_donations_attributes) {
|
@memset(res_attrs, .{});
|
||||||
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
|
||||||
}
|
|
||||||
if (opts.sharding) {
|
|
||||||
self.addShardingAttributes(arg_attrs, input_shapes.items);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
// Donations attributes only make sense on the main function.
|
||||||
|
self.addDonationsAttributes(arg_attrs, fn_res_donations);
|
||||||
|
|
||||||
|
if (self._platform.sharding().num_partitions > 1) {
|
||||||
|
self.addShardingAttributes(arg_attrs, res_attrs, input_shapes.items, fn_res_shapes);
|
||||||
|
}
|
||||||
const mlir_fn = dialect.func.func(self.mlirCtx(), .{
|
const mlir_fn = dialect.func.func(self.mlirCtx(), .{
|
||||||
.sym_name = fn_name,
|
.sym_name = fn_name,
|
||||||
.args = input_types,
|
.args = input_types,
|
||||||
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
|
.arg_attrs = try finalizeAttributeList(arena, mlir_ctx, arg_attrs),
|
||||||
.results = fn_res_types,
|
.results = fn_res_types,
|
||||||
|
.res_attrs = try finalizeAttributeList(arena, mlir_ctx, res_attrs),
|
||||||
.block = fn_body,
|
.block = fn_body,
|
||||||
.location = loc,
|
.location = loc,
|
||||||
});
|
});
|
||||||
@ -344,7 +346,7 @@ pub const CompilationContext = struct {
|
|||||||
var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
|
var comp = try zml.module.CompilationContext.init(allocator, "test", platform);
|
||||||
defer comp.deinit();
|
defer comp.deinit();
|
||||||
var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }};
|
var tensor_args = .{Tensor{ ._shape = s, ._id = .{ .arg_id = 1234 } }};
|
||||||
const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args, .{ .add_donations_attributes = true });
|
const f = try comp.generateBytecode(allocator, "test.generateBytecode.Local.forward", Local.forward, &model, &tensor_args);
|
||||||
|
|
||||||
var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
|
var mlir_bytecode: std.ArrayListUnmanaged(u8) = .{};
|
||||||
try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});
|
try mlir_bytecode.writer(allocator).print("{}", .{f.mlir_fn.mlirFormatter(.{})});
|
||||||
@ -359,25 +361,42 @@ pub const CompilationContext = struct {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
fn addShardingAttributes(self: CompilationContext, attributes: []AttributeList, shapes: []const Shape) void {
|
pub fn getShardingAttr(self: CompilationContext, shape: Shape) mlir.StringAttribute {
|
||||||
const mlir_ctx = self.mlirCtx();
|
const mlir_ctx = self.mlirCtx();
|
||||||
if (!self._platform.compilation_options.sharding_enabled) return;
|
|
||||||
|
|
||||||
const num_partitions = self._platform.sharding().num_partitions;
|
const num_partitions = self._platform.sharding().num_partitions;
|
||||||
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
var sharding_str: std.BoundedArray(u8, 128) = .{};
|
||||||
|
|
||||||
|
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
||||||
|
return mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn addShardingAttributes(self: CompilationContext, arg_attrs: []AttributeList, res_attrs: []AttributeList, input_shapes: []const Shape, output_shapes: []const Shape) void {
|
||||||
|
const mlir_ctx = self.mlirCtx();
|
||||||
|
if (!self._platform.compilation_options.sharding_enabled) return;
|
||||||
|
|
||||||
const mhlo_default_layout = mlir.NamedAttribute.init(
|
const mhlo_default_layout = mlir.NamedAttribute.init(
|
||||||
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
|
mlir.Identifier.get(mlir_ctx, "mhlo.layout_mode"),
|
||||||
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
|
mlir.StringAttribute.init(mlir_ctx, "default").asAttr(),
|
||||||
);
|
);
|
||||||
for (attributes, shapes) |*attr, shape| {
|
for (arg_attrs, input_shapes) |*attr, shape| {
|
||||||
attr.appendAssumeCapacity(mhlo_default_layout);
|
attr.appendAssumeCapacity(mhlo_default_layout);
|
||||||
|
|
||||||
writeShardingRepresentation(shape, num_partitions, sharding_str.writer()) catch unreachable;
|
const sharding_attr = self.getShardingAttr(shape);
|
||||||
defer sharding_str.len = 0;
|
|
||||||
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
|
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
|
||||||
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
|
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
|
||||||
mlir.StringAttribute.init(mlir_ctx, sharding_str.constSlice()).asAttr(),
|
sharding_attr.asAttr(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (res_attrs, output_shapes) |*attr, shape| {
|
||||||
|
attr.appendAssumeCapacity(mhlo_default_layout);
|
||||||
|
|
||||||
|
const sharding_attr = self.getShardingAttr(shape);
|
||||||
|
|
||||||
|
attr.appendAssumeCapacity(mlir.NamedAttribute.init(
|
||||||
|
mlir.Identifier.get(mlir_ctx, "mhlo.sharding"),
|
||||||
|
sharding_attr.asAttr(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -668,18 +687,20 @@ fn fillBuffers(v: anytype, buffers: []const [*]*pjrt.Buffer, start: u32, len: u3
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
|
/// Visit the given struct and override tensors by creating a new one using the provided PJRT buffers.
|
||||||
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, expected_count: u32) void {
|
pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjrt.Buffer, buffer_shapes: []Shape, expected_count: u32) void {
|
||||||
const LocalContext = struct {
|
const LocalContext = struct {
|
||||||
index: u32,
|
index: u32,
|
||||||
platform: Platform,
|
platform: Platform,
|
||||||
buffers: []const [*]*pjrt.Buffer,
|
buffers: []const [*]*pjrt.Buffer,
|
||||||
expected_count: u32,
|
expected_count: u32,
|
||||||
|
buffer_shapes: []Shape,
|
||||||
};
|
};
|
||||||
var local_ctx: LocalContext = .{
|
var local_ctx: LocalContext = .{
|
||||||
.index = 0,
|
.index = 0,
|
||||||
.platform = platform,
|
.platform = platform,
|
||||||
.buffers = buffers,
|
.buffers = buffers,
|
||||||
.expected_count = expected_count,
|
.expected_count = expected_count,
|
||||||
|
.buffer_shapes = buffer_shapes,
|
||||||
};
|
};
|
||||||
meta.visit((struct {
|
meta.visit((struct {
|
||||||
fn cb(ctx: *LocalContext, buffer: *Buffer) void {
|
fn cb(ctx: *LocalContext, buffer: *Buffer) void {
|
||||||
@ -691,7 +712,7 @@ pub fn assignRawBuffers(v: anytype, platform: Platform, buffers: []const [*]*pjr
|
|||||||
for (ctx.buffers) |buff| {
|
for (ctx.buffers) |buff| {
|
||||||
shards.appendAssumeCapacity(buff[i]);
|
shards.appendAssumeCapacity(buff[i]);
|
||||||
}
|
}
|
||||||
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, shards.constSlice());
|
buffer.* = Buffer.fromPjrtBuffers(ctx.platform, ctx.buffer_shapes[i], shards.constSlice());
|
||||||
}
|
}
|
||||||
}).cb, &local_ctx, v);
|
}).cb, &local_ctx, v);
|
||||||
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
meta.internalAssert(local_ctx.index == expected_count, "Pjrt call returned {} tensors, but the return type {s}, contains {} Buffers. Note that modules need to have a comptime know number of returned tensors.", .{ buffers.len, @typeName(@TypeOf(v)), local_ctx.index });
|
||||||
@ -731,8 +752,12 @@ const BaseExe = struct {
|
|||||||
args_buffer_count: u32,
|
args_buffer_count: u32,
|
||||||
/// Number of buffers in result.
|
/// Number of buffers in result.
|
||||||
result_buffer_count: u32,
|
result_buffer_count: u32,
|
||||||
|
/// Shapes of buffers in result.
|
||||||
|
result_buffer_shapes: []Shape,
|
||||||
/// Num devices used (>1 for sharded executable)
|
/// Num devices used (>1 for sharded executable)
|
||||||
num_devices: u8,
|
num_devices: u8,
|
||||||
|
/// Arena allocator backing `result_buffer_shapes`; it is deinitialized by `ExeWithWeights`.
|
||||||
|
_allocator: std.heap.ArenaAllocator,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Represents a ZML model, compiled into a PJRT executable.
|
/// Represents a ZML model, compiled into a PJRT executable.
|
||||||
@ -815,6 +840,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
// Free in reverse order of allocation.
|
// Free in reverse order of allocation.
|
||||||
self._allocator.free(self._all_per_device);
|
self._allocator.free(self._all_per_device);
|
||||||
self._allocator.free(self._all_buffers);
|
self._allocator.free(self._all_buffers);
|
||||||
|
self.inner._allocator.deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn platform(self: Self) Platform {
|
pub fn platform(self: Self) Platform {
|
||||||
@ -836,7 +862,7 @@ pub fn ExeWithWeights(comptime func: anytype) type {
|
|||||||
}) catch unreachable;
|
}) catch unreachable;
|
||||||
|
|
||||||
var result: Bufferized(Signature.ReturnT) = undefined;
|
var result: Bufferized(Signature.ReturnT) = undefined;
|
||||||
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_count);
|
assignRawBuffers(&result, self.inner.platform, self.output_per_device, self.inner.result_buffer_shapes, self.inner.result_buffer_count);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -860,7 +886,7 @@ fn compileInternal(
|
|||||||
var timer = std.time.Timer.start() catch null;
|
var timer = std.time.Timer.start() catch null;
|
||||||
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
const tensor_args = context.tensorFromShapes(ModuleSignature(func).ArgsT, arena, args);
|
||||||
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
// Run in a dedicated thread because compilation relies on `threadlocal`.
|
||||||
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args, .{ .add_donations_attributes = true } });
|
const f = try asynk.callGeneric(CompilationContext.generateBytecode, .{ context, arena, "main", func, &model, &tensor_args });
|
||||||
context._module.getBody().appendOperation(f.mlir_fn);
|
context._module.getBody().appendOperation(f.mlir_fn);
|
||||||
|
|
||||||
const sharding = context._platform.sharding();
|
const sharding = context._platform.sharding();
|
||||||
@ -884,13 +910,18 @@ fn compileInternal(
|
|||||||
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)});
|
if (time_ms > 1000) log.info("Compilation took {d:.3}s", .{meta.divFloat(f32, time_ms, 1000)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var arena_state_exe = std.heap.ArenaAllocator.init(allocator);
|
||||||
|
const arena_exe = arena_state_exe.allocator();
|
||||||
|
|
||||||
return .{
|
return .{
|
||||||
.platform = context._platform,
|
.platform = context._platform,
|
||||||
.exe = loaded_executable,
|
.exe = loaded_executable,
|
||||||
.model_buffer_count = f.n_model,
|
.model_buffer_count = f.n_model,
|
||||||
.args_buffer_count = f.n_args,
|
.args_buffer_count = f.n_args,
|
||||||
.result_buffer_count = @intCast(f.res_types.len),
|
.result_buffer_count = @intCast(f.res_types.len),
|
||||||
|
.result_buffer_shapes = arena_exe.dupe(Shape, f.res_shapes) catch unreachable,
|
||||||
.num_devices = sharding.num_replicas * sharding.num_partitions,
|
.num_devices = sharding.num_replicas * sharding.num_partitions,
|
||||||
|
._allocator = arena_state_exe,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -159,9 +159,30 @@ pub const Tensor = struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
|
pub fn withSharding(self: Tensor, axes_: anytype) Tensor {
|
||||||
|
return switch (self._id) {
|
||||||
|
.arg_id, .mlir => {
|
||||||
|
const ctx = self.getContext();
|
||||||
|
var res = self;
|
||||||
|
res._shape = self._shape.withSharding(axes_);
|
||||||
|
|
||||||
|
const sharding = ctx.getShardingAttr(res._shape);
|
||||||
|
|
||||||
|
const op = dialect.stablehlo.sharding(
|
||||||
|
ctx.mlirCtx(),
|
||||||
|
&.{self.value()},
|
||||||
|
sharding,
|
||||||
|
&.{self.value().getType()},
|
||||||
|
ctx.mlirCtx().location(@src()),
|
||||||
|
);
|
||||||
|
|
||||||
|
return _result(res._shape, op.result(0));
|
||||||
|
},
|
||||||
|
.buffer_id => {
|
||||||
var res = self;
|
var res = self;
|
||||||
res._shape = self._shape.withSharding(axes_);
|
res._shape = self._shape.withSharding(axes_);
|
||||||
return res;
|
return res;
|
||||||
|
},
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Tensor with new tag names.
|
/// Returns a Tensor with new tag names.
|
||||||
|
|||||||
@ -175,7 +175,7 @@ pub fn testLayerOut(
|
|||||||
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
|
log.warn("Reference models uses {d} inputs, but implementation uses {d}", .{ n_in_exp, n_in });
|
||||||
}
|
}
|
||||||
|
|
||||||
const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform, .{});
|
const exe = try zml.compileModel(alloc, layer, .forward, input_shapes, platform);
|
||||||
|
|
||||||
const n_out_exp = activations.countLayers(out_name);
|
const n_out_exp = activations.countLayers(out_name);
|
||||||
if (exe.inner.result_buffer_count != n_out_exp) {
|
if (exe.inner.result_buffer_count != n_out_exp) {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user