Update docs and Zig examples to demonstrate the new client creation flags API.

This commit is contained in:
Foke Singh 2023-11-09 12:31:11 +00:00
parent 9f4194ad97
commit cb6fcbbb1a
8 changed files with 14 additions and 11 deletions

View File

@ -249,7 +249,7 @@ pub fn asyncMain() !void {
var ctx = try zml.Context.init();
defer ctx.deinit();
const platform = ctx.autoPlatform();
const platform = ctx.autoPlatform(.{});
const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform);
zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3);

View File

@ -184,7 +184,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init();
defer context.deinit();
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
...
}
```
@ -458,7 +458,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init();
defer context.deinit();
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
// Our weights and bias to use
var weights = [3]f16{ 2.0, 2.0, 2.0 };

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
defer context.deinit();
// Auto-select platform
const platform = context.autoPlatform().withCompilationOptions(.{
const platform = context.autoPlatform(.{}).withCompilationOptions(.{
.sharding_enabled = true,
});
context.printAvailablePlatforms(platform);

View File

@ -140,6 +140,8 @@ pub fn asyncMain() !void {
prompt: ?[]const u8 = null,
test_activations: ?[]const u8 = null,
seed: ?u128 = null,
// e.g. --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
create_options: []const u8 = "{}",
};
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
@ -157,9 +159,6 @@ pub fn asyncMain() !void {
.sharding_enabled = true,
};
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
context.printAvailablePlatforms(platform);
var args = std.process.args();
const cli_args = flags.parse(&args, CliArgs);
const model_file = cli_args.model;
@ -168,6 +167,10 @@ pub fn asyncMain() !void {
defer arena_state.deinit();
const model_arena = arena_state.allocator();
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
context.printAvailablePlatforms(platform);
log.info("Model file: {s}", .{model_file});
var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);

View File

@ -33,7 +33,7 @@ pub fn asyncMain() !void {
defer context.deinit();
// Select platform
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
// Parse program args
var args = std.process.args();

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init();
defer context.deinit();
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform);
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());

View File

@ -51,7 +51,7 @@ pub fn asyncMain() !void {
// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
// // Auto-select platform
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform);
// Parse program args

View File

@ -34,7 +34,7 @@ pub fn asyncMain() !void {
var context = try zml.Context.init();
defer context.deinit();
const platform = context.autoPlatform();
const platform = context.autoPlatform(.{});
context.printAvailablePlatforms(platform);
// Our weights and bias to use