Update docs and Zig examples to demonstrate the new client creation flags API.
This commit is contained in:
parent
9f4194ad97
commit
cb6fcbbb1a
@ -249,7 +249,7 @@ pub fn asyncMain() !void {
|
||||
|
||||
var ctx = try zml.Context.init();
|
||||
defer ctx.deinit();
|
||||
const platform = ctx.autoPlatform();
|
||||
const platform = ctx.autoPlatform(.{});
|
||||
const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform);
|
||||
|
||||
zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3);
|
||||
|
||||
@ -184,7 +184,7 @@ pub fn asyncMain() !void {
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
...
|
||||
}
|
||||
```
|
||||
@ -458,7 +458,7 @@ pub fn asyncMain() !void {
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
|
||||
// Our weights and bias to use
|
||||
var weights = [3]f16{ 2.0, 2.0, 2.0 };
|
||||
|
||||
@ -34,7 +34,7 @@ pub fn asyncMain() !void {
|
||||
defer context.deinit();
|
||||
|
||||
// Auto-select platform
|
||||
const platform = context.autoPlatform().withCompilationOptions(.{
|
||||
const platform = context.autoPlatform(.{}).withCompilationOptions(.{
|
||||
.sharding_enabled = true,
|
||||
});
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
@ -140,6 +140,8 @@ pub fn asyncMain() !void {
|
||||
prompt: ?[]const u8 = null,
|
||||
test_activations: ?[]const u8 = null,
|
||||
seed: ?u128 = null,
|
||||
// eg: --create-options='{"cuda":{"allocator":{"bfc":{"memory_fraction": 0.99}}}}'
|
||||
create_options: []const u8 = "{}",
|
||||
};
|
||||
|
||||
log.info(" LLama was compiled with {}", .{@import("builtin").mode});
|
||||
@ -157,9 +159,6 @@ pub fn asyncMain() !void {
|
||||
.sharding_enabled = true,
|
||||
};
|
||||
|
||||
const platform = context.autoPlatform().withCompilationOptions(compilation_options);
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
var args = std.process.args();
|
||||
const cli_args = flags.parse(&args, CliArgs);
|
||||
const model_file = cli_args.model;
|
||||
@ -168,6 +167,10 @@ pub fn asyncMain() !void {
|
||||
defer arena_state.deinit();
|
||||
const model_arena = arena_state.allocator();
|
||||
|
||||
const create_opts = try std.json.parseFromSliceLeaky(zml.Platform.CreateOptions, model_arena, cli_args.create_options, .{});
|
||||
const platform = context.autoPlatform(create_opts).withCompilationOptions(compilation_options);
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
log.info("Model file: {s}", .{model_file});
|
||||
|
||||
var ts = try zml.aio.detectFormatAndOpen(allocator, model_file);
|
||||
|
||||
@ -33,7 +33,7 @@ pub fn asyncMain() !void {
|
||||
defer context.deinit();
|
||||
|
||||
// Select platform
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
|
||||
// Parse program args
|
||||
var args = std.process.args();
|
||||
|
||||
@ -34,7 +34,7 @@ pub fn asyncMain() !void {
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
var buffers = try gpa.allocator().alloc(zml.Buffer, buffer_store.buffers.count());
|
||||
|
||||
@ -51,7 +51,7 @@ pub fn asyncMain() !void {
|
||||
// log.info("\n===========================\n== ZML MNIST Example ==\n===========================\n\n", .{});
|
||||
|
||||
// // Auto-select platform
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Parse program args
|
||||
|
||||
@ -34,7 +34,7 @@ pub fn asyncMain() !void {
|
||||
var context = try zml.Context.init();
|
||||
defer context.deinit();
|
||||
|
||||
const platform = context.autoPlatform();
|
||||
const platform = context.autoPlatform(.{});
|
||||
context.printAvailablePlatforms(platform);
|
||||
|
||||
// Our weights and bias to use
|
||||
|
||||
Loading…
Reference in New Issue
Block a user