diff --git a/docs/howtos/howto_torch2zml.md b/docs/howtos/howto_torch2zml.md index 4a6f195..a5785f4 100644 --- a/docs/howtos/howto_torch2zml.md +++ b/docs/howtos/howto_torch2zml.md @@ -249,7 +249,7 @@ pub fn asyncMain() !void { var ctx = try zml.Context.init(); defer ctx.deinit(); - const platform = ctx.autoPlatform(); + const platform = ctx.autoPlatform(.{}); const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform); zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3); diff --git a/docs/tutorials/write_first_model.md b/docs/tutorials/write_first_model.md index 8b4d925..cc7148a 100644 --- a/docs/tutorials/write_first_model.md +++ b/docs/tutorials/write_first_model.md @@ -184,7 +184,7 @@ pub fn asyncMain() !void { var context = try zml.Context.init(); defer context.deinit(); - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); ... } ``` @@ -458,7 +458,7 @@ pub fn asyncMain() !void { var context = try zml.Context.init(); defer context.deinit(); - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); // Our weights and bias to use var weights = [3]f16{ 2.0, 2.0, 2.0 }; diff --git a/examples/benchmark/main.zig b/examples/benchmark/main.zig index 0a8aedb..d8b0b47 100644 --- a/examples/benchmark/main.zig +++ b/examples/benchmark/main.zig @@ -34,7 +34,7 @@ pub fn asyncMain() !void { defer context.deinit(); // Auto-select platform - const platform = context.autoPlatform().withCompilationOptions(.{ + const platform = context.autoPlatform(.{}).withCompilationOptions(.{ .sharding_enabled = true, }); context.printAvailablePlatforms(platform); diff --git a/examples/llama/main.zig b/examples/llama/main.zig index ccf5f28..8508ac8 100644 --- a/examples/llama/main.zig +++ b/examples/llama/main.zig @@ -140,6 +140,8 @@ pub fn asyncMain() !void { prompt: ?[]const u8 = null, test_activations: ?[]const u8 = null, seed: ?u128 = null, + // eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}' + create_options: []const u8 = "{}", }; log.info(" LLama was compiled with {}", .{@import("builtin").mode}); @@ -157,9 +159,6 @@ pub fn asyncMain() !void { .sharding_enabled = true, }; - const platform = context.autoPlatform().withCompilationOptions(compilation_options); - context.printAvailablePlatforms(platform); - var args = std.process.args(); const cli_args = flags.parse(&args, CliArgs); const model_file = cli_args.model; @@ -168,6 +167,10 @@ pub fn asyncMain() !void { defer arena_state.deinit(); const model_arena = arena_state.allocator(); + const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{}); + const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options); + context.printAvailablePlatforms(platform); + log.info("Model file: {s}", .{model_file}); var ts = try zml.aio.detectFormatAndOpen(allocator, model_file); diff --git a/examples/llama/test.zig b/examples/llama/test.zig index ab2c895..ae30e23 100644 --- a/examples/llama/test.zig +++ b/examples/llama/test.zig @@ -33,7 +33,7 @@ pub fn asyncMain() !void { defer context.deinit(); // Select platform - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); // Parse program args var args = std.process.args(); diff --git a/examples/loader/main.zig b/examples/loader/main.zig index 8037712..6673289 100644 --- a/examples/loader/main.zig +++ b/examples/loader/main.zig @@ -34,7 +34,7 @@ pub fn asyncMain() !void { var context = try zml.Context.init(); defer context.deinit(); - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); context.printAvailablePlatforms(platform); var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count()); diff --git a/examples/mnist/mnist.zig b/examples/mnist/mnist.zig index 98fab69..0f35c7e 100644 --- a/examples/mnist/mnist.zig +++ b/examples/mnist/mnist.zig @@ -51,7 +51,7 @@ pub fn asyncMain() !void { // log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{}); // // Auto-select platform - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); context.printAvailablePlatforms(platform); // Parse program args diff --git a/examples/simple_layer/main.zig b/examples/simple_layer/main.zig index 475f10b..6f61c0b 100644 --- a/examples/simple_layer/main.zig +++ b/examples/simple_layer/main.zig @@ -34,7 +34,7 @@ pub fn asyncMain() !void { var context = try zml.Context.init(); defer context.deinit(); - const platform = context.autoPlatform(); + const platform = context.autoPlatform(.{}); context.printAvailablePlatforms(platform); // Our weights and bias to use