Update docs and Zig examples to demonstrate the new client creation flags API.

This commit is contained in:
Foke Singh 2023-11-09 12:31:11 +00:00
parent 9f4194ad97
commit cb6fcbbb1a
8 changed files with 14 additions and 11 deletions

View File

@ -249,7 +249,7 @@ pub fn asyncMain() !void {
var ctx = try zml.Context.init(); var ctx = try zml.Context.init();
defer ctx.deinit(); defer ctx.deinit();
const platform = ctx.autoPlatform(); const platform = ctx.autoPlatform(.{});
const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform); const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform);
zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3); zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3);

View File

@ -184,7 +184,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init(); var context = try zml.Context.init();
defer context.deinit(); defer context.deinit();
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
... ...
} }
``` ```
@ -458,7 +458,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init(); var context = try zml.Context.init();
defer context.deinit(); defer context.deinit();
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
// Our weights and bias to use // Our weights and bias to use
var weights = [3]f16{ 2.0, 2.0, 2.0 }; var weights = [3]f16{ 2.0, 2.0, 2.0 };

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
defer context.deinit(); defer context.deinit();
// Auto-select platform // Auto-select platform
const platform = context.autoPlatform().withCompilationOptions(.{ const platform = context.autoPlatform(.{}).withCompilationOptions(.{
.sharding_enabled = true, .sharding_enabled = true,
}); });
context.printAvailablePlatforms(platform); context.printAvailablePlatforms(platform);

View File

@ -140,6 +140,8 @@ pub fn asyncMain() !void {
prompt: ?[]const u8 = null, prompt: ?[]const u8 = null,
test_activations: ?[]const u8 = null, test_activations: ?[]const u8 = null,
seed: ?u128 = null, seed: ?u128 = null,
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
create_options: []const u8 = "{}",
}; };
log.info(" LLama was compiled with {}", .{@import("builtin").mode}); log.info(" LLama was compiled with {}", .{@import("builtin").mode});
@ -157,9 +159,6 @@ pub fn asyncMain() !void {
.sharding_enabled = true, .sharding_enabled = true,
}; };
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
context.printAvailablePlatforms(platform);
var args = std.process.args(); var args = std.process.args();
const cli_args = flags.parse(&args, CliArgs); const cli_args = flags.parse(&args, CliArgs);
const model_file = cli_args.model; const model_file = cli_args.model;
@ -168,6 +167,10 @@ pub fn asyncMain() !void {
defer arena_state.deinit(); defer arena_state.deinit();
const model_arena = arena_state.allocator(); const model_arena = arena_state.allocator();
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
context.printAvailablePlatforms(platform);
log.info("Model file: {s}", .{model_file}); log.info("Model file: {s}", .{model_file});
var ts = try zml.aio.detectFormatAndOpen(allocator, model_file); var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);

View File

@ -33,7 +33,7 @@ pub fn asyncMain() !void {
defer context.deinit(); defer context.deinit();
// Select platform // Select platform
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
// Parse program args // Parse program args
var args = std.process.args(); var args = std.process.args();

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init(); var context = try zml.Context.init();
defer context.deinit(); defer context.deinit();
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform); context.printAvailablePlatforms(platform);
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count()); var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());

View File

@ -51,7 +51,7 @@ pub fn asyncMain() !void {
// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{}); // log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
// // Auto-select platform // // Auto-select platform
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform); context.printAvailablePlatforms(platform);
// Parse program args // Parse program args

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init(); var context = try zml.Context.init();
defer context.deinit(); defer context.deinit();
const platform = context.autoPlatform(); const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform); context.printAvailablePlatforms(platform);
// Our weights and bias to use // Our weights and bias to use