2023-01-03 10:21:07 +00:00
|
|
|
const std = @import("std");
|
|
|
|
|
const zml = @import("zml");
|
|
|
|
|
const asynk = @import("async");
|
|
|
|
|
const flags = @import("tigerbeetle/flags");
|
|
|
|
|
|
2023-03-20 15:31:44 +00:00
|
|
|
// Set the log level to debug to print the generated IR.
// The explicit `std.Options` type makes the compiler check the field names here
// instead of relying on coercion from an anonymous struct literal.
pub const std_options: std.Options = .{
    .log_level = .debug,
};
|
|
|
|
|
|
2023-01-03 10:21:07 +00:00
|
|
|
/// Model definition: a stateless module whose `forward` is a single `dot`
/// between its two inputs.
const Benchmark = struct {
    /// Applies `.k` sharding to both operands, calls `dot` on them over the
    /// `.k` tag, and applies `.m` sharding to the result.
    pub fn forward(self: Benchmark, a: zml.Tensor, b: zml.Tensor) zml.Tensor {
        // The model carries no state; the receiver is intentionally unused.
        _ = self;

        const lhs = a.withSharding(.{.k});
        const rhs = b.withSharding(.{.k});
        const product = lhs.dot(rhs, .{.k});
        return product.withSharding(.{.m});
    }
};
|
|
|
|
|
|
|
|
|
|
/// Process entry point: hands `asyncMain` to `asynk.AsyncThread.main` together
/// with the C allocator and forwards whatever error it reports.
pub fn main() !void {
    return asynk.AsyncThread.main(std.heap.c_allocator, asyncMain);
}
|
|
|
|
|
|
|
|
|
|
/// Benchmark driver passed to `asynk.AsyncThread.main` by `main`.
///
/// Steps: set up allocators and a `zml.Context`, pick a platform, parse the
/// `--size` / `--dtype` flags, compile `Benchmark.forward` for two square
/// shapes, run it once as a warm-up and once timed on random inputs, then
/// print the elapsed time and the derived GFLOP/s figure.
pub fn asyncMain() !void {
    const CliArgs = struct {
        pub const help =
            \\ benchmark --size=4096 --dtype=f16
        ;
        size: usize = 4096,
        dtype: zml.DataType = .f16,
    };

    // General-purpose allocator for short-lived allocations.
    var gpa_state = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa_state.deinit();
    const gpa = gpa_state.allocator();

    // Arena handed to `compiled.prepare` below; released in one go on exit.
    var arena_state = std.heap.ArenaAllocator.init(gpa);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    var ctx = try zml.Context.init();
    defer ctx.deinit();

    // Auto-select the platform and enable sharding in its compilation options.
    const target = ctx.autoPlatform().withCompilationOptions(.{
        .sharding_enabled = true,
    });
    ctx.printAvailablePlatforms(target);

    var arg_iter = std.process.args();
    const opts = flags.parse(&arg_iter, CliArgs);
    const n = opts.size;

    // Both operands are n x n; the second shape reuses the first with new tags.
    const lhs_shape = zml.Shape.init(.{ n, n }, opts.dtype)
        .withTags(.{ .m, .k })
        .withSharding(.{.k});
    const rhs_shape = lhs_shape
        .withTags(.{ .k, .n })
        .withSharding(.{.k});

    var stopwatch = try std.time.Timer.start();

    // Horizontal rule printed before and after the compilation output.
    const rule = "-" ** 160;

    std.debug.print("\nCompiling model to MLIR....\n", .{});
    std.debug.print(rule ++ "\n", .{});

    // Kick off compilation; the input shapes have to be supplied by hand.
    stopwatch.reset();
    var pending = try asynk.asyncGeneric(zml.module.compileModel, .{ gpa, Benchmark{}, .forward, .{ lhs_shape, rhs_shape }, target });

    // Block until the compilation result is available.
    const compiled = try pending.await_();
    const compile_ms = stopwatch.lap() / std.time.ns_per_ms;
    std.debug.print(rule ++ "\n\n", .{});
    std.debug.print("✅ Compiled Benchmark model in {d} milliseconds! \n", .{compile_ms});

    // Turn the compiled module into something callable (no weights to bind).
    var exe = try compiled.prepare(arena, .{});
    defer exe.deinit();

    // Fixed seed, so every run draws the same pseudo-random sequence.
    var prng = std.Random.DefaultPrng.init(0);
    const rand = prng.random();

    var lhs = try createRandomBuffer(gpa, target, lhs_shape, rand);
    defer lhs.deinit();
    var rhs = try createRandomBuffer(gpa, target, rhs_shape, rand);
    defer rhs.deinit();

    std.debug.print("\nRunning benchmark....\n", .{});

    // Warm-up call: its result is released immediately and it is not timed.
    {
        var warmup: zml.Buffer = exe.call(.{ lhs, rhs });
        defer warmup.deinit();
    }

    // Timed call.
    stopwatch.reset();
    var output: zml.Buffer = exe.call(.{ lhs, rhs });
    defer output.deinit();
    const elapsed_ns = stopwatch.lap();

    const elapsed_ns_f: f64 = @floatFromInt(elapsed_ns);
    const elapsed_ms = elapsed_ns_f / std.time.ns_per_ms;
    const elapsed_s = elapsed_ns_f / std.time.ns_per_s;

    std.debug.print("\n✅ Benchmark done!\n\n", .{});

    // An n x n by n x n product is counted as 2 * n^3 floating point operations.
    const op_count = 2 * n * n * n;
    const op_count_f: f64 = @floatFromInt(op_count);
    const flops = op_count_f / elapsed_s;
    std.debug.print("Dot product size: {d}x{d} - Datatype: {s} - Elapsed: {d:.3}ms - {d:.3} GFLOP/s\n\n", .{ n, n, @tagName(opts.dtype), elapsed_ms, flops / 1_000_000_000 });
}
|
|
|
|
|
|
|
|
|
|
/// Creates a `zml.Buffer` of the given `shape` on `platform`, filled with
/// pseudo-random data drawn from `random`.
///
/// A staging slice of `shape.byteSize()` bytes is allocated from `allocator`,
/// filled element by element according to the class of `shape.dtype()`, passed
/// to `zml.Buffer.from` through a `zml.HostBuffer`, and freed when this
/// function returns.
///
/// Element generation:
/// - integer dtypes: `random.int` of the element type;
/// - `f16` / `f32` / `f64`: a fresh `random.float(f64)` per element, cast down
///   where needed;
/// - any other float dtype: random bits of the element's width, bit-cast.
///
/// Errors: `error.UnsupportedDataType` for bool and complex dtypes, plus any
/// error from the allocator or from `zml.Buffer.from`.
fn createRandomBuffer(allocator: std.mem.Allocator, platform: zml.Platform, shape: zml.Shape, random: std.Random) !zml.Buffer {
    const data = try allocator.alloc(u8, shape.byteSize());
    defer allocator.free(data);

    switch (shape.dtype()) {
        inline else => |v| {
            const ZigType = v.toZigType();
            switch (comptime v.class()) {
                // The dtype comes from user input, so report it instead of
                // declaring the case unreachable.
                .bool, .complex => return error.UnsupportedDataType,
                .integer => {
                    for (std.mem.bytesAsSlice(ZigType, data)) |*e| e.* = random.int(ZigType);
                },
                .float => {
                    // Sample inside the loop so every element gets its own
                    // value rather than one value repeated across the buffer.
                    for (std.mem.bytesAsSlice(ZigType, data)) |*e| {
                        e.* = if (ZigType == f64)
                            random.float(f64)
                        else if (ZigType == f32 or ZigType == f16)
                            @floatCast(random.float(f64))
                        else
                            @bitCast(random.int(std.meta.Int(.unsigned, @bitSizeOf(ZigType))));
                    }
                },
            }
        },
    }

    var host_buffer = zml.HostBuffer.fromBytes(shape, data);
    // NOTE(review): `data` is also released by the `defer` above; confirm that
    // `HostBuffer.deinit` does not free bytes it merely borrows.
    errdefer host_buffer.deinit(allocator);
    return zml.Buffer.from(platform, host_buffer);
}
|