268 lines
10 KiB
Zig
268 lines
10 KiB
Zig
|
|
// Standard library.
const std = @import("std");
// ZML machine-learning framework (tensors, compilation, PJRT runtimes).
const zml = @import("zml");
// ZML's cooperative async runtime; drives `asyncMain` from `main`.
const asynk = @import("async");

// Shorthand for launching an async computation (awaited via `.await_()`).
const async_ = asynk.async_;

// When true, dump the generated MLIR during compilation and enable
// debug logging for the `zml_module` scope (see `std_options` below).
const show_mlir = true;
|
/// Model definition: a two-layer fully-connected classifier for MNIST.
/// Field names (`fc1`, `fc2`) must match the PyTorch checkpoint's layer
/// names so `zml.aio.populateModel` can bind the weights.
const Mnist = struct {
    fc1: Layer,
    fc2: Layer,

    const Layer = struct {
        weight: zml.Tensor,
        bias: zml.Tensor,

        /// Affine transform (weight @ input + bias) followed by ReLU.
        pub fn forward(self: Layer, input: zml.Tensor) zml.Tensor {
            const pre_activation = self.weight.matmul(input).add(self.bias);
            return pre_activation.relu();
        }
    };

    /// just two linear layers + relu activation
    pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
        // Flatten the 28x28 u8 image into a vector and promote to f32.
        var x = input.flattenAll().convert(.f32);
        const stages: []const Layer = &.{ self.fc1, self.fc2 };
        for (stages) |stage| {
            x = zml.call(stage, .forward, .{x});
        }
        // The index of the highest score is the predicted digit (as u8).
        return x.argMax(0, .u8).indices;
    }
};
|
||
|
|
|
||
|
|
/// Program entry point: sets up a leak-checking allocator and hands
/// control to the async runtime, which runs `asyncMain`.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    try asynk.AsyncThread.main(gpa.allocator(), asyncMain, .{});
}
|
||
|
|
|
||
|
|
/// Async body of the example. Expects two CLI arguments:
///   argv[1] — PyTorch checkpoint (.pt) with fc1/fc2 weights,
///   argv[2] — MNIST t10k images file (IDX format).
/// Loads the model, compiles it for the auto-selected platform, then
/// runs inference on one randomly chosen test image.
pub fn asyncMain() !void {
    // Short lived allocations
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Create ZML context
    var context = try zml.Context.init();
    defer context.deinit();

    std.debug.print("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});

    // Auto-select platform
    const platform = context.autoPlatform();
    {
        // List available targets
        std.debug.print("Available Platforms:\n", .{});
        const selected_prefix = "✅";
        const not_selected_prefix = "• ";
        const selected_postfix = "(AUTO-SELECTED)\n";
        const not_selected_postfix = "\n";
        for (zml.platform.available_targets) |target| {
            std.debug.print(" {s} {s} {s}", .{
                if (target == platform.target) selected_prefix else not_selected_prefix,
                @tagName(target),
                if (target == platform.target) selected_postfix else not_selected_postfix,
            });

            // now the platform's devices
            if (context.platforms.get(target)) |pfm| {
                for (pfm.getDevices(), 0..) |device, index| {
                    const deviceKind = device.getDescription(platform.pjrt_api).getKind(platform.pjrt_api);
                    std.debug.print(" ◦ #{d}: {s}\n", .{
                        index,
                        deviceKind,
                    });
                    // we only list 1 CPU device
                    if (target == .cpu) break;
                }
            }
        }
        std.debug.print("\n", .{});
    }

    // Parse program args
    const process_args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, process_args);
    // Fix: fail with a usage message instead of panicking on an
    // out-of-bounds index when the two required paths are missing.
    if (process_args.len < 3) {
        std.debug.print("Usage: {s} <mnist_model.pt> <t10k-images-idx3-ubyte>\n", .{process_args[0]});
        return error.MissingArguments;
    }
    const pt_model = process_args[1];
    const t10kfilename = process_args[2];

    // Memory arena dedicated to model shapes and weights
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    // Read model shapes.
    // Note this works because Mnist struct uses the same layer names as the pytorch model
    var buffer_store = try zml.aio.torch.open(allocator, pt_model);
    defer buffer_store.deinit();

    const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);
    std.debug.print("✅ Read model shapes from PyTorch file {s}\n", .{pt_model});

    // Start loading weights
    var model_weights = try zml.aio.loadModelBuffers(Mnist, mnist_model, buffer_store, arena, platform);
    defer zml.aio.unloadBuffers(&model_weights);

    // Start compiling
    const comp_start_time = std.time.milliTimestamp();
    if (show_mlir) {
        std.debug.print("\nCompiling model to MLIR....\n", .{});
        std.debug.print("-" ** 160 ++ "\n", .{});
    } else {
        std.debug.print("Compiling model to MLIR....\r", .{});
    }
    // Compile asynchronously so it can overlap with weight loading.
    var compilation = try async_(zml.compile, .{ allocator, Mnist, .{}, .forward, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });

    // Wait for end of compilation and end of weights loading.
    const compiled_mnist = try compilation.await_();
    const comp_end_time = std.time.milliTimestamp();
    if (show_mlir) std.debug.print("-" ** 160 ++ "\n", .{});
    std.debug.print("✅ Compiled MNIST model in {d} milliseconds! \n", .{comp_end_time - comp_start_time});

    // send weights to accelerator / GPU
    var mnist = try compiled_mnist.prepare(allocator, model_weights);
    defer mnist.deinit();
    std.debug.print("✅ Weights transferred, starting inference...\n\n", .{});

    // Load a random digit image from the dataset.
    const dataset = try asynk.File.open(t10kfilename, .{ .mode = .read_only });
    defer dataset.close() catch unreachable;
    var rng = std.Random.Xoshiro256.init(@intCast(std.time.timestamp()));

    // inference - can be looped
    {
        const idx = rng.random().intRangeAtMost(u64, 0, 10000 - 1);
        var sample: [28 * 28]u8 align(16) = undefined;
        // IDX image file layout: 16-byte header, then 28x28 images back to back.
        const bytes_read = try dataset.pread(&sample, 16 + (idx * 28 * 28));
        // Fix: a short read (truncated dataset) was previously ignored,
        // leaving part of `sample` undefined during inference.
        if (bytes_read != sample.len) return error.UnexpectedEndOfFile;
        var input = try zml.Buffer.from(platform, zml.HostBuffer.fromBytes(zml.Shape.init(.{ 28, 28 }, .u8), &sample));
        defer input.deinit();

        printDigit(sample);
        var result: zml.Buffer = mnist.call(.{input});
        defer result.deinit();

        std.debug.print("\n✅ RECOGNIZED DIGIT:\n", .{});
        std.debug.print(" +-------------+\n", .{});
        std.debug.print("{s}\n", .{digits[try result.getValue(u8)]});
        std.debug.print(" +-------------+\n\n", .{});
    }
}
|
||
|
|
|
||
|
|
/// Renders a 28x28 grayscale digit to stdout as ASCII art.
/// Each pixel becomes a two-character cell so the image is roughly square
/// in a terminal; each row is framed by '|' and ends with a newline.
fn printDigit(digit: [28 * 28]u8) void {
    // 30 cells per row: frame cell, 28 pixel cells, frame + '\n' cell.
    var canvas: [28][30][2]u8 = undefined;
    std.debug.print(" R E C O G N I Z I N G I N P U T I M A G E :\n", .{});
    std.debug.print("+---------------------------------------------------------+\n", .{});
    for (0..28) |row| {
        canvas[row][0] = .{ '|', ' ' };
        canvas[row][29] = .{ '|', '\n' };
        for (1..29) |col| {
            const pixel = digit[(row * 28) + (col - 1)];
            // Map brightness to density glyphs (same cutoffs as before:
            // >240, >225, >210, else blank), written as exhaustive ranges.
            canvas[row][col] = switch (pixel) {
                241...255 => .{ '*', '*' },
                226...240 => .{ 'o', 'o' },
                211...225 => .{ '.', '.' },
                else => .{ ' ', ' ' },
            };
        }
    }
    // Emit the whole canvas in one write via its flat byte representation.
    std.fmt.format(asynk.StdOut().writer(), "{s}", .{std.mem.asBytes(&canvas)}) catch unreachable;
    std.debug.print("+---------------------------------------------------------+\n", .{});
}
|
||
|
|
|
||
|
|
/// ASCII-art banners for digits 0-9, indexed by the predicted class;
/// printed between the "+---+" frame lines after inference.
// NOTE(review): the run-length of spaces inside these literals looks
// collapsed by the source extraction — confirm padding against the
// upstream file before relying on visual alignment.
const digits = [_][]const u8{
    \\ | ### |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | ### |
    ,
    \\ | # |
    \\ | ## |
    \\ | # # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | ##### |
    ,
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ##### |
    \\ | # |
    \\ | # |
    \\ | ####### |
    ,
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ##### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
    \\ | # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | ####### |
    \\ | # |
    \\ | # |
    ,
    \\ | ####### |
    \\ | # |
    \\ | # |
    \\ | ###### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ###### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    ,
    \\ | ####### |
    \\ | # # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | # |
    ,
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    ,
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ###### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
};
|
||
|
|
|
||
|
|
/// Standard-library configuration, read by `std` at comptime.
pub const std_options = .{
    // Set the global log level to err
    .log_level = .err,
    // Per-scope overrides: keep the PJRT runtime quiet; surface
    // `zml_module` debug output only when MLIR dumping is enabled.
    .log_scope_levels = &[_]std.log.ScopeLevel{
        .{ .scope = .pjrt, .level = .err },
        .{ .scope = .zml_module, .level = if (show_mlir) .debug else .err },
    },
};
|