Add sharding usage to the benchmark and simple_layer example programs.

This commit is contained in:
Foke Singh 2023-02-23 11:18:27 +00:00
parent fc718ab649
commit cad1a688da
2 changed files with 16 additions and 12 deletions

View File

@@ -42,7 +42,9 @@ pub fn asyncMain() !void {
     defer context.deinit();
     // Auto-select platform
-    const platform = context.autoPlatform();
+    const platform = context.autoPlatform().withCompilationOptions(.{
+        .sharding_enabled = true,
+    });
     {
         // List available targets
         std.debug.print("Available Platforms:\n", .{});
@@ -76,8 +78,8 @@ pub fn asyncMain() !void {
     var args = std.process.args();
     const cli_args = flags.parse(&args, CliArgs);
-    const input_shape = zml.Shape.init(.{ cli_args.size, cli_args.size }, cli_args.dtype);
+    const a_shape = zml.Shape.init(.{ cli_args.size, cli_args.size }, cli_args.dtype).withTags(.{ .m, .k }).withSharding(.{.k});
+    const b_shape = a_shape.withTags(.{ .k, .n }).withSharding(.{.k});
     var timer = try std.time.Timer.start();
     std.debug.print("\nCompiling model to MLIR....\n", .{});
@@ -85,7 +87,7 @@ pub fn asyncMain() !void {
     // Start compiling.
     // The shape of the input tensor, we have to pass in manually.
     timer.reset();
-    var compilation = try async_(zml.module.compileModel, .{ allocator, Benchmark{}, .forward, .{ input_shape.withTags(.{ .m, .k }), input_shape.withTags(.{ .k, .n }) }, platform });
+    var compilation = try async_(zml.module.compileModel, .{ allocator, Benchmark{}, .forward, .{ a_shape, b_shape }, platform });
     // Wait for compilation to finish
     const compiled = try compilation.await_();
@@ -100,9 +102,9 @@ pub fn asyncMain() !void {
     var rng = std.Random.DefaultPrng.init(0);
    const random = rng.random();
-    var a_buffer = try createRandomBuffer(allocator, platform, input_shape, random);
+    var a_buffer = try createRandomBuffer(allocator, platform, a_shape, random);
     defer a_buffer.deinit();
-    var b_buffer = try createRandomBuffer(allocator, platform, input_shape, random);
+    var b_buffer = try createRandomBuffer(allocator, platform, b_shape, random);
     defer b_buffer.deinit();
     std.debug.print("\nRunning benchmark....\n", .{});

View File

@@ -41,9 +41,9 @@ pub fn asyncMain() !void {
     const platform = context.autoPlatform();
     // Our weights and bias to use
-    var weights = [3]f16{ 2.0, 2.0, 2.0 };
-    var bias = [3]f16{ 1.0, 2.0, 3.0 };
-    const input_shape = zml.Shape.init(.{3}, .f16);
+    var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };
+    var bias = [4]f16{ 1.0, 2.0, 3.0, 4.0 };
+    const input_shape = zml.Shape.init(.{4}, .f16);
     // We manually produce a BufferStore. You would not normally do that.
     // A BufferStore is usually created by loading model data from a file.
@@ -59,7 +59,9 @@ pub fn asyncMain() !void {
     // A clone of our model, consisting of shapes. We only need shapes for compiling.
     // We use the BufferStore to infer the shapes.
-    const model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
+    var model_shapes = try zml.aio.populateModel(Layer, allocator, buffer_store);
+    model_shapes.weight = model_shapes.weight.withSharding(.{-1});
+    model_shapes.bias = model_shapes.bias.?.withSharding(.{-1});
     // Start compiling. This uses the inferred shapes from the BufferStore.
     // The shape of the input tensor, we have to pass in manually.
@@ -68,7 +70,7 @@ pub fn asyncMain() !void {
     // Produce a bufferized weights struct from the fake BufferStore.
     // This is like the inferred shapes, but with actual values.
     // We will need to send those to the computation device later.
-    var model_weights = try zml.aio.loadBuffers(Layer, .{}, buffer_store, arena, platform);
+    var model_weights = try zml.aio.loadModelBuffers(Layer, model_shapes, buffer_store, arena, platform);
     defer zml.aio.unloadBuffers(&model_weights); // for good practice
     // Wait for compilation to finish
@@ -82,7 +84,7 @@ pub fn asyncMain() !void {
     // Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
     // with a specific shape from an array.
     // For situations where e.g. you have an [4]f16 array but need a .{2, 2} input shape.
-    var input = [3]f16{ 5.0, 5.0, 5.0 };
+    var input = [4]f16{ 5.0, 5.0, 5.0, 5.0 };
     var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input));
     defer input_buffer.deinit();