Add sharding usage to the benchmark and simple_layer example programs.
This commit is contained in:
parent
fc718ab649
commit
cad1a688da
@ -42,7 +42,9 @@ pub fn asyncMain() !void {
|
|||||||
defer context.deinit();
|
defer context.deinit();
|
||||||
|
|
||||||
// Auto-select platform
|
// Auto-select platform
|
||||||
const platform = context.autoPlatform();
|
const platform = context.autoPlatform().withCompilationOptions(.{
|
||||||
|
.sharding_enabled = true,
|
||||||
|
});
|
||||||
{
|
{
|
||||||
// List available targets
|
// List available targets
|
||||||
std.debug.print("Available Platforms:\n", .{});
|
std.debug.print("Available Platforms:\n", .{});
|
||||||
@ -76,8 +78,8 @@ pub fn asyncMain() !void {
|
|||||||
var args = std.process.args();
|
var args = std.process.args();
|
||||||
const cli_args = flags.parse(&args, CliArgs);
|
const cli_args = flags.parse(&args, CliArgs);
|
||||||
|
|
||||||
const input_shape = zml.Shape.init(.{ cli_args.size, cli_args.size }, cli_args.dtype);
|
const a_shape = zml.Shape.init(.{ cli_args.size, cli_args.size }, cli_args.dtype).withTags(.{ .m, .k }).withSharding(.{.k});
|
||||||
|
const b_shape = a_shape.withTags(.{ .k, .n }).withSharding(.{.k});
|
||||||
var timer = try std.time.Timer.start();
|
var timer = try std.time.Timer.start();
|
||||||
|
|
||||||
std.debug.print("\nCompiling model to MLIR....\n", .{});
|
std.debug.print("\nCompiling model to MLIR....\n", .{});
|
||||||
@ -85,7 +87,7 @@ pub fn asyncMain() !void {
|
|||||||
// Start compiling.
|
// Start compiling.
|
||||||
// The shape of the input tensor, we have to pass in manually.
|
// The shape of the input tensor, we have to pass in manually.
|
||||||
timer.reset();
|
timer.reset();
|
||||||
var compilation = try async_(zml.module.compileModel, .{ allocator, Benchmark{}, .forward, .{ input_shape.withTags(.{ .m, .k }), input_shape.withTags(.{ .k, .n }) }, platform });
|
var compilation = try async_(zml.module.compileModel, .{ allocator, Benchmark{}, .forward, .{ a_shape, b_shape }, platform });
|
||||||
|
|
||||||
// Wait for compilation to finish
|
// Wait for compilation to finish
|
||||||
const compiled = try compilation.await_();
|
const compiled = try compilation.await_();
|
||||||
@ -100,9 +102,9 @@ pub fn asyncMain() !void {
|
|||||||
var rng = std.Random.DefaultPrng.init(0);
|
var rng = std.Random.DefaultPrng.init(0);
|
||||||
const random = rng.random();
|
const random = rng.random();
|
||||||
|
|
||||||
var a_buffer = try createRandomBuffer(allocator, platform, input_shape, random);
|
var a_buffer = try createRandomBuffer(allocator, platform, a_shape, random);
|
||||||
defer a_buffer.deinit();
|
defer a_buffer.deinit();
|
||||||
var b_buffer = try createRandomBuffer(allocator, platform, input_shape, random);
|
var b_buffer = try createRandomBuffer(allocator, platform, b_shape, random);
|
||||||
defer b_buffer.deinit();
|
defer b_buffer.deinit();
|
||||||
|
|
||||||
std.debug.print("\nRunning benchmark....\n", .{});
|
std.debug.print("\nRunning benchmark....\n", .{});
|
||||||
|
|||||||
@ -41,9 +41,9 @@ pub fn asyncMain() !void {
|
|||||||
const platform = context.autoPlatform();
|
const platform = context.autoPlatform();
|
||||||
|
|
||||||
// Our weights and bias to use
|
// Our weights and bias to use
|
||||||
var weights = [3]f16{ 2.0, 2.0, 2.0 };
|
var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };
|
||||||
var bias = [3]f16{ 1.0, 2.0, 3.0 };
|
var bias = [4]f16{ 1.0, 2.0, 3.0, 4.0 };
|
||||||
const input_shape = zml.Shape.init(.{3}, .f16);
|
const input_shape = zml.Shape.init(.{4}, .f16);
|
||||||
|
|
||||||
// We manually produce a BufferStore. You would not normally do that.
|
// We manually produce a BufferStore. You would not normally do that.
|
||||||
// A BufferStore is usually created by loading model data from a file.
|
// A BufferStore is usually created by loading model data from a file.
|
||||||
@ -59,7 +59,9 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
// A clone of our model, consisting of shapes. We only need shapes for compiling.
|
// A clone of our model, consisting of shapes. We only need shapes for compiling.
|
||||||
// We use the BufferStore to infer the shapes.
|
// We use the BufferStore to infer the shapes.
|
||||||
const model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
|
var model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
|
||||||
|
model_shapes.weight = model_shapes.weight.withSharding(.{-1});
|
||||||
|
model_shapes.bias = model_shapes.bias.?.withSharding(.{-1});
|
||||||
|
|
||||||
// Start compiling. This uses the inferred shapes from the BufferStore.
|
// Start compiling. This uses the inferred shapes from the BufferStore.
|
||||||
// The shape of the input tensor, we have to pass in manually.
|
// The shape of the input tensor, we have to pass in manually.
|
||||||
@ -68,7 +70,7 @@ pub fn asyncMain() !void {
|
|||||||
// Produce a bufferized weights struct from the fake BufferStore.
|
// Produce a bufferized weights struct from the fake BufferStore.
|
||||||
// This is like the inferred shapes, but with actual values.
|
// This is like the inferred shapes, but with actual values.
|
||||||
// We will need to send those to the computation device later.
|
// We will need to send those to the computation device later.
|
||||||
var model_weights = try zml.aio.loadBuffers(Layer, .{}, buffer_store, arena, platform);
|
var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, buffer_store, arena, platform);
|
||||||
defer zml.aio.unloadBuffers(&model_weights); // for good practice
|
defer zml.aio.unloadBuffers(&model_weights); // for good practice
|
||||||
|
|
||||||
// Wait for compilation to finish
|
// Wait for compilation to finish
|
||||||
@ -82,7 +84,7 @@ pub fn asyncMain() !void {
|
|||||||
// Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
|
// Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
|
||||||
// with a specific shape from an array.
|
// with a specific shape from an array.
|
||||||
// For situations where e.g. you have an [4]f16 array but need a .{2, 2} input shape.
|
// For situations where e.g. you have an [4]f16 array but need a .{2, 2} input shape.
|
||||||
var input = [3]f16{ 5.0, 5.0, 5.0 };
|
var input = [4]f16{ 5.0, 5.0, 5.0, 5.0 };
|
||||||
var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input));
|
var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input));
|
||||||
defer input_buffer.deinit();
|
defer input_buffer.deinit();
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user