2023-01-03 10:21:07 +00:00
|
|
|
const std = @import("std");
|
|
|
|
|
|
2024-03-18 13:11:14 +00:00
|
|
|
const asynk = @import("async");
|
|
|
|
|
const zml = @import("zml");
|
2023-01-03 10:21:07 +00:00
|
|
|
|
2023-06-27 14:23:22 +00:00
|
|
|
const log = std.log.scoped(.mnist);
|
|
|
|
|
|
2023-08-03 11:35:24 +00:00
|
|
|
/// Root `std.Options` override for this program: sets the log level to
/// `.info` and installs the log function obtained by passing
/// `std.log.defaultLog` to `asynk.logFn`.
pub const std_options = std.Options{
    .logFn = asynk.logFn(std.log.defaultLog),
    .log_level = .info,
};
|
|
|
|
|
|
2023-01-03 10:21:07 +00:00
|
|
|
/// Model definition
///
/// Two layers, `fc1` and `fc2`, applied one after the other.
const Mnist = struct {
    fc1: Layer,
    fc2: Layer,

    /// A single layer made of a weight tensor and a bias tensor.
    const Layer = struct {
        weight: zml.Tensor,
        bias: zml.Tensor,

        /// Returns `relu(weight.matmul(input) + bias)`.
        pub fn forward(self: Layer, input: zml.Tensor) zml.Tensor {
            const projected = self.weight.matmul(input);
            const shifted = projected.add(self.bias);
            return shifted.relu();
        }
    };

    /// Just two linear layers, each followed by a relu activation.
    ///
    /// The input is flattened and converted to `f32`, passed through `fc1`
    /// and then `fc2`, and the index of the largest output along axis 0 is
    /// returned as a `u8` tensor.
    pub fn forward(self: Mnist, input: zml.Tensor) zml.Tensor {
        const pixels = input.flattenAll().convert(.f32);
        const hidden = zml.call(self.fc1, .forward, .{pixels});
        const scores = zml.call(self.fc2, .forward, .{hidden});
        const best = scores.argMax(0);
        return best.indices.convert(.u8);
    }
};
|
|
|
|
|
|
|
|
|
|
/// Program entry point: hands `asyncMain` to `asynk.AsyncThread.main`,
/// using the C allocator, and propagates any error it reports.
pub fn main() !void {
    const allocator = std.heap.c_allocator;
    try asynk.AsyncThread.main(allocator, asyncMain);
}
|
|
|
|
|
|
|
|
|
|
/// Runs the MNIST example end to end: opens the model file, compiles
/// `Mnist.forward` while the weights are being loaded on the selected
/// platform, then classifies one randomly chosen image from the dataset file
/// and logs both the image and the recognized digit.
///
/// Program arguments:
///   1. path to the PyTorch model file
///   2. path to the image dataset file
///
/// Returns `error.MissingArguments` when fewer than two arguments are given
/// and `error.ShortRead` when the dataset file does not hold a complete
/// 28x28 image at the chosen offset. Every other error is propagated from
/// the calls below.
pub fn asyncMain() !void {
    const allocator = std.heap.c_allocator;

    // Create ZML context
    var context = try zml.Context.init();
    defer context.deinit();

    // Auto-select platform
    const platform = context.autoPlatform(.{});
    context.printAvailablePlatforms(platform);

    // Parse program args
    const process_args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, process_args);

    // Both paths are mandatory: bail out with a usage hint instead of
    // indexing past the end of the argument list.
    if (process_args.len < 3) {
        const program_name: []const u8 = if (process_args.len > 0) process_args[0] else "mnist";
        log.err("Usage: {s} <model.pt> <t10k-images-file>", .{program_name});
        return error.MissingArguments;
    }
    const pt_model = process_args[1];
    const t10kfilename = process_args[2];

    // Memory arena dedicated to model shapes and weights
    var arena_state = std.heap.ArenaAllocator.init(allocator);
    defer arena_state.deinit();
    const arena = arena_state.allocator();

    // Read model shapes.
    // Note this works because Mnist struct uses the same layer names as the pytorch model
    log.info("Reading model shapes from PyTorch file {s}...", .{pt_model});
    var buffer_store = try zml.aio.torch.open(allocator, pt_model);
    defer buffer_store.deinit();

    const mnist_model = try zml.aio.populateModel(Mnist, allocator, buffer_store);

    // Start compiling
    log.info("Compiling model to MLIR....", .{});
    var start_time = try std.time.Timer.start();
    var compilation = try asynk.asyncc(zml.compile, .{ allocator, Mnist.forward, .{}, .{zml.Shape.init(.{ 28, 28 }, .u8)}, buffer_store, platform });

    // While compiling, start loading weights on the platform
    var model_weights = try zml.aio.loadModelBuffers(Mnist, mnist_model, buffer_store, arena, platform);
    defer zml.aio.unloadBuffers(&model_weights);

    // Wait for end of compilation and end of weights loading.
    const compiled_mnist = try compilation.awaitt();
    log.info("✅ Compiled model in {d}ms", .{start_time.read() / std.time.ns_per_ms});

    const mnist = compiled_mnist.prepare(model_weights);
    defer mnist.deinit();
    // Same timer as above, so this is the time elapsed since compilation started.
    log.info("✅ Weights transferred in {d}ms", .{start_time.read() / std.time.ns_per_ms});

    log.info("Starting inference...", .{});

    // Load a random digit image from the dataset.
    const dataset = try asynk.File.open(t10kfilename, .{ .mode = .read_only });
    // Closing can fail; report it rather than asserting it cannot happen.
    defer dataset.close() catch |err| log.warn("Failed to close {s}: {s}", .{ t10kfilename, @errorName(err) });

    // Reinterpret the signed timestamp as the u64 seed so that a negative
    // value cannot trip an integer-cast safety check.
    var rng = std.Random.Xoshiro256.init(@bitCast(std.time.timestamp()));

    // inference - can be looped
    {
        // Picks one of 10000 images; the first image starts 16 bytes into the
        // file and each image is 28 * 28 bytes long.
        const idx = rng.random().intRangeAtMost(u64, 0, 10000 - 1);
        var sample: [28 * 28]u8 align(16) = undefined;
        const bytes_read = try dataset.pread(&sample, 16 + (idx * 28 * 28));
        // A partial read would leave the tail of `sample` undefined.
        if (bytes_read != sample.len) {
            log.err("Short read from {s}: got {d} of {d} bytes", .{ t10kfilename, bytes_read, sample.len });
            return error.ShortRead;
        }

        var input = try zml.Buffer.from(platform, zml.HostBuffer.fromBytes(zml.Shape.init(.{ 28, 28 }, .u8), &sample), .{});
        defer input.deinit();

        printDigit(sample);
        var result: zml.Buffer = mnist.call(.{input});
        defer result.deinit();

        log.info(
            \\✅ RECOGNIZED DIGIT:
            \\ +-------------+
            \\{s}
            \\ +-------------+
            \\
        , .{digits[try result.getValue(u8)]});
    }
}
|
|
|
|
|
|
|
|
|
|
/// Logs a framed ASCII-art rendering of a 28x28 image stored row by row.
///
/// Every pixel is drawn as two identical characters chosen by its value:
/// `*` above 240, `o` above 225, `.` above 210, and a space otherwise.
/// Each row starts with `| ` and ends with `|` followed by a newline.
fn printDigit(digit: [28 * 28]u8) void {
    var canvas: [28][30][2]u8 = undefined;
    for (&canvas, 0..) |*row, y| {
        // Left and right borders of the row.
        row[0] = .{ '|', ' ' };
        row[29] = .{ '|', '\n' };

        // Cells 1..28 hold the 28 pixels of image row `y`.
        for (row[1..29], 0..) |*cell, col| {
            const pixel = digit[(y * 28) + col];
            const glyph: u8 = switch (pixel) {
                241...255 => '*',
                226...240 => 'o',
                211...225 => '.',
                else => ' ',
            };
            cell.* = .{ glyph, glyph };
        }
    }

    log.info(
        \\
        \\ R E C O G N I Z I N G I N P U T I M A G E :
        \\+---------------------------------------------------------+
        \\{s}+---------------------------------------------------------+
        \\
    , .{std.mem.asBytes(&canvas)});
}
|
|
|
|
|
|
|
|
|
|
/// Ten seven-row ASCII-art glyphs, one per entry. `asyncMain` indexes this
/// table with the `u8` value read back from the model output, so entry `i`
/// is the picture shown for digit `i`.
const digits = [10][]const u8{
    // 0
    \\ | ### |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | ### |
    ,
    // 1
    \\ | # |
    \\ | ## |
    \\ | # # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | ##### |
    ,
    // 2
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ##### |
    \\ | # |
    \\ | # |
    \\ | ####### |
    ,
    // 3
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ##### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
    // 4
    \\ | # |
    \\ | # # |
    \\ | # # |
    \\ | # # |
    \\ | ####### |
    \\ | # |
    \\ | # |
    ,
    // 5
    \\ | ####### |
    \\ | # |
    \\ | # |
    \\ | ###### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
    // 6
    \\ | ##### |
    \\ | # # |
    \\ | # |
    \\ | ###### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    ,
    // 7
    \\ | ####### |
    \\ | # # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | # |
    \\ | # |
    ,
    // 8
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ##### |
    ,
    // 9
    \\ | ##### |
    \\ | # # |
    \\ | # # |
    \\ | ###### |
    \\ | # |
    \\ | # # |
    \\ | ##### |
    ,
};
|