Add Zig example programs for LLaMA, ModernBERT, and SimpleLayer, including a Bazel BUILD file for the LLaMA example.
This commit is contained in:
parent
488a844a0f
commit
0ed7f5c907
@ -10,6 +10,7 @@ zig_binary(
|
|||||||
"llama.zig",
|
"llama.zig",
|
||||||
],
|
],
|
||||||
main = "main.zig",
|
main = "main.zig",
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
deps = [
|
deps = [
|
||||||
"//async",
|
"//async",
|
||||||
"//stdx",
|
"//stdx",
|
||||||
|
|||||||
@ -87,7 +87,7 @@ pub const LlamaLM = struct {
|
|||||||
kv_cache: KvCache,
|
kv_cache: KvCache,
|
||||||
rng: Tensor.Rng,
|
rng: Tensor.Rng,
|
||||||
) struct { Tensor, KvCache, Tensor.Rng } {
|
) struct { Tensor, KvCache, Tensor.Rng } {
|
||||||
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {} and {}", .{ tokens_, token_index });
|
stdx.debug.assert(tokens_.dtype() == .u32 and tokens_.rank() >= 1 and token_index.dtype() == .u32 and token_index.rank() <= 1, "Can't run Llama ! Expected >=1d tokens and 0d token_index, got: {f} and {f}", .{ tokens_, token_index });
|
||||||
const tokens = tokens_.withPartialTags(.{.s});
|
const tokens = tokens_.withPartialTags(.{.s});
|
||||||
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
const out, const updated_kv_cache = zml.call(self.model, .forward, .{ tokens, token_index, kv_cache });
|
||||||
const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
|
const new_tokens, const new_rng = self.sampleTokens(self.lm_head, out, rng, self.gen_opts);
|
||||||
@ -200,8 +200,8 @@ pub const TransformerLayer = struct {
|
|||||||
kv_cache: KvCache,
|
kv_cache: KvCache,
|
||||||
) struct { Tensor, KvCache } {
|
) struct { Tensor, KvCache } {
|
||||||
// Self Attention
|
// Self Attention
|
||||||
//log.debug("TransformerLayer({}) -> {}", .{ x0, self.input_layernorm.forward(x0) });
|
//log.debug("TransformerLayer({f}) -> {f}", .{ x0, self.input_layernorm.forward(x0) });
|
||||||
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {}", .{x0});
|
stdx.debug.assert(x0.rank() >= 2 and x0.shape().hasTags(.{ .s, .d }), "TransformerLayer expected input shape: {{..., .s, .d}}, received: {f}", .{x0});
|
||||||
|
|
||||||
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
|
const x0_normalized = zml.call(self.input_layernorm, .forward, .{x0});
|
||||||
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
|
const delta0, const updated_kv_cache = zml.call(self.self_attn, .forward, .{ x0_normalized, token_index, kv_cache });
|
||||||
@ -336,7 +336,7 @@ pub const KvCache = struct {
|
|||||||
return .{
|
return .{
|
||||||
.k = try zml.Buffer.constant(platform, kv_shape, 1),
|
.k = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||||
.v = try zml.Buffer.constant(platform, kv_shape, 1),
|
.v = try zml.Buffer.constant(platform, kv_shape, 1),
|
||||||
.layer_index = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0),
|
.layer_index = try zml.Buffer.scalar(platform, 0, .u32),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,19 +1,19 @@
|
|||||||
|
const std = @import("std");
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const clap = @import("clap");
|
const clap = @import("clap");
|
||||||
const std = @import("std");
|
|
||||||
const stdx = @import("stdx");
|
const stdx = @import("stdx");
|
||||||
const zml = @import("zml");
|
const zml = @import("zml");
|
||||||
|
const Buffer = zml.Buffer;
|
||||||
|
const Tensor = zml.Tensor;
|
||||||
|
const ShapeOf = zml.ShapeOf;
|
||||||
|
|
||||||
const llama = @import("llama.zig");
|
const llama = @import("llama.zig");
|
||||||
|
|
||||||
const LlamaLM = llama.LlamaLM;
|
const LlamaLM = llama.LlamaLM;
|
||||||
const Llama = llama.Llama;
|
const Llama = llama.Llama;
|
||||||
const KvCache = llama.KvCache;
|
const KvCache = llama.KvCache;
|
||||||
const TransformerLayer = llama.TransformerLayer;
|
const TransformerLayer = llama.TransformerLayer;
|
||||||
const SelfAttn = llama.SelfAttn;
|
const SelfAttn = llama.SelfAttn;
|
||||||
const Buffer = zml.Buffer;
|
|
||||||
const Tensor = zml.Tensor;
|
|
||||||
const ShapeOf = zml.ShapeOf;
|
|
||||||
|
|
||||||
const log = std.log.scoped(.llama);
|
const log = std.log.scoped(.llama);
|
||||||
|
|
||||||
@ -23,7 +23,7 @@ pub const std_options: std.Options = .{
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
|
pub fn tokenizePrompt(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, config: LlamaLM.Config, prompt: []const u8, skip_llama3_encoding: bool) ![]u32 {
|
||||||
var tokens = std.ArrayList(u32).init(allocator);
|
var tokens = std.array_list.Managed(u32).init(allocator);
|
||||||
var encoder = try tokenizer.encoder();
|
var encoder = try tokenizer.encoder();
|
||||||
defer encoder.deinit();
|
defer encoder.deinit();
|
||||||
|
|
||||||
@ -101,7 +101,7 @@ pub fn generateText(
|
|||||||
defer current_token.deinit();
|
defer current_token.deinit();
|
||||||
|
|
||||||
// Here we collect the generated text
|
// Here we collect the generated text
|
||||||
var output = std.ArrayList(u8).init(allocator);
|
var output = std.array_list.Managed(u8).init(allocator);
|
||||||
defer output.deinit();
|
defer output.deinit();
|
||||||
|
|
||||||
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
const output_tokens_len = max_seq_len - prompt_tok.len - 1;
|
||||||
@ -179,21 +179,22 @@ pub fn asyncMain() !void {
|
|||||||
.PATH = clap.parsers.string,
|
.PATH = clap.parsers.string,
|
||||||
};
|
};
|
||||||
var diag: clap.Diagnostic = .{};
|
var diag: clap.Diagnostic = .{};
|
||||||
const stderr = std.io.getStdErr().writer();
|
var stderr_buffer: [1024]u8 = undefined;
|
||||||
|
var stderr = std.fs.File.stderr().writer(&stderr_buffer);
|
||||||
var res = clap.parse(clap.Help, &params, parsers, .{
|
var res = clap.parse(clap.Help, &params, parsers, .{
|
||||||
.diagnostic = &diag,
|
.diagnostic = &diag,
|
||||||
.allocator = allocator,
|
.allocator = allocator,
|
||||||
}) catch |err| {
|
}) catch |err| {
|
||||||
diag.report(stderr, err) catch {};
|
diag.report(&stderr.interface, err) catch {};
|
||||||
stderr.print("usage: ", .{}) catch {};
|
stderr.interface.print("usage: ", .{}) catch {};
|
||||||
clap.usage(stderr, clap.Help, &params) catch {};
|
clap.usage(&stderr.interface, clap.Help, &params) catch {};
|
||||||
stderr.print("\n", .{}) catch {};
|
stderr.interface.print("\n", .{}) catch {};
|
||||||
return;
|
return;
|
||||||
};
|
};
|
||||||
defer res.deinit();
|
defer res.deinit();
|
||||||
|
|
||||||
if (res.args.help != 0) {
|
if (res.args.help != 0) {
|
||||||
clap.help(std.io.getStdErr().writer(), clap.Help, &params, .{}) catch {};
|
clap.help(&stderr.interface, clap.Help, &params, .{}) catch {};
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -224,7 +225,9 @@ pub fn asyncMain() !void {
|
|||||||
const config = blk: {
|
const config = blk: {
|
||||||
var config_json_file = try asynk.File.open(model_config_path, .{ .mode = .read_only });
|
var config_json_file = try asynk.File.open(model_config_path, .{ .mode = .read_only });
|
||||||
defer config_json_file.close() catch unreachable;
|
defer config_json_file.close() catch unreachable;
|
||||||
var reader = std.json.reader(allocator, config_json_file.reader());
|
var config_json_buffer: [256]u8 = undefined;
|
||||||
|
var config_reader = config_json_file.reader(&config_json_buffer);
|
||||||
|
var reader = std.json.Reader.init(allocator, &config_reader.interface);
|
||||||
defer reader.deinit();
|
defer reader.deinit();
|
||||||
const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
const config_obj = try std.json.parseFromTokenSourceLeaky(llama.LlamaLM.Config, allocator, &reader, .{ .ignore_unknown_fields = true });
|
||||||
break :blk config_obj;
|
break :blk config_obj;
|
||||||
@ -298,16 +301,16 @@ pub fn asyncMain() !void {
|
|||||||
platform,
|
platform,
|
||||||
});
|
});
|
||||||
|
|
||||||
log.info("\tLoading Llama weights from {?s}...", .{model_weights_path});
|
log.info("\tLoading Llama weights from {s}...", .{model_weights_path});
|
||||||
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
|
var llama_weights = try zml.aio.loadBuffers(llama.LlamaLM, .{ config, llama_options }, ts, model_arena.allocator(), platform);
|
||||||
defer zml.aio.unloadBuffers(&llama_weights);
|
defer zml.aio.unloadBuffers(&llama_weights);
|
||||||
log.info("✅\tLoaded weights in {}", .{std.fmt.fmtDuration(start.read())});
|
log.info("✅\tLoaded weights in {D}", .{start.read()});
|
||||||
|
|
||||||
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
var llama_module_prefill = (try fut_mod_prefill.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module_prefill.deinit();
|
defer llama_module_prefill.deinit();
|
||||||
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
var llama_module = (try fut_mod.awaitt()).prepare(llama_weights);
|
||||||
defer llama_module.deinit();
|
defer llama_module.deinit();
|
||||||
log.info("✅\tCompiled model in {}", .{std.fmt.fmtDuration(start.read())});
|
log.info("✅\tCompiled model in {D}", .{start.read()});
|
||||||
|
|
||||||
log.info("Creating KvCache", .{});
|
log.info("Creating KvCache", .{});
|
||||||
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
const kv_cache = try llama.KvCache.initBuffer(kv_shape, platform);
|
||||||
@ -315,7 +318,7 @@ pub fn asyncMain() !void {
|
|||||||
var tokenizer = blk: {
|
var tokenizer = blk: {
|
||||||
log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
|
log.info("Loading tokenizer from {s}", .{model_tokenizer_path});
|
||||||
var timer = try stdx.time.Timer.start();
|
var timer = try stdx.time.Timer.start();
|
||||||
defer log.info("Loaded tokenizer from {s} [{}]", .{ model_tokenizer_path, timer.read() });
|
defer log.info("Loaded tokenizer from {s} [{D}]", .{ model_tokenizer_path, timer.read() });
|
||||||
|
|
||||||
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path);
|
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena.allocator(), model_tokenizer_path);
|
||||||
};
|
};
|
||||||
|
|||||||
@ -1,7 +1,4 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const log = std.log.scoped(.modernbert);
|
|
||||||
|
|
||||||
const modernbert = @import("modernbert.zig");
|
|
||||||
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
const clap = @import("clap");
|
const clap = @import("clap");
|
||||||
@ -9,6 +6,10 @@ const stdx = @import("stdx");
|
|||||||
const zml = @import("zml");
|
const zml = @import("zml");
|
||||||
const Tensor = zml.Tensor;
|
const Tensor = zml.Tensor;
|
||||||
|
|
||||||
|
const modernbert = @import("modernbert.zig");
|
||||||
|
|
||||||
|
const log = std.log.scoped(.modernbert);
|
||||||
|
|
||||||
pub const std_options: std.Options = .{
|
pub const std_options: std.Options = .{
|
||||||
.log_level = .info,
|
.log_level = .info,
|
||||||
.log_scope_levels = &[_]std.log.ScopeLevel{
|
.log_scope_levels = &[_]std.log.ScopeLevel{
|
||||||
@ -42,20 +43,20 @@ pub fn main() !void {
|
|||||||
|
|
||||||
pub fn asyncMain() !void {
|
pub fn asyncMain() !void {
|
||||||
const allocator = std.heap.c_allocator;
|
const allocator = std.heap.c_allocator;
|
||||||
const stderr = std.io.getStdErr().writer();
|
const stderr = std.fs.File.stderr();
|
||||||
|
|
||||||
var diag: clap.Diagnostic = .{};
|
var diag: clap.Diagnostic = .{};
|
||||||
var cli = clap.parse(clap.Help, &params, clap_parsers, .{
|
var cli = clap.parse(clap.Help, &params, clap_parsers, .{
|
||||||
.diagnostic = &diag,
|
.diagnostic = &diag,
|
||||||
.allocator = allocator,
|
.allocator = allocator,
|
||||||
}) catch |err| {
|
}) catch |err| {
|
||||||
try diag.report(stderr, err);
|
try diag.reportToFile(stderr, err);
|
||||||
try printUsageAndExit(stderr);
|
try printUsageAndExit(stderr);
|
||||||
};
|
};
|
||||||
defer cli.deinit();
|
defer cli.deinit();
|
||||||
|
|
||||||
if (cli.args.help != 0) {
|
if (cli.args.help != 0) {
|
||||||
try clap.help(stderr, clap.Help, &params, .{});
|
try clap.helpToFile(stderr, clap.Help, &params, .{});
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -80,7 +81,9 @@ pub fn asyncMain() !void {
|
|||||||
|
|
||||||
// Detects the format of the model file (base on filename) and open it.
|
// Detects the format of the model file (base on filename) and open it.
|
||||||
const model_file = cli.args.model orelse {
|
const model_file = cli.args.model orelse {
|
||||||
stderr.print("Error: missing --model=...\n\n", .{}) catch {};
|
var buf: [256]u8 = undefined;
|
||||||
|
var writer = stderr.writer(&buf);
|
||||||
|
writer.interface.print("Error: missing --model=...\n\n", .{}) catch {};
|
||||||
printUsageAndExit(stderr);
|
printUsageAndExit(stderr);
|
||||||
unreachable;
|
unreachable;
|
||||||
};
|
};
|
||||||
@ -96,7 +99,7 @@ pub fn asyncMain() !void {
|
|||||||
if (cli.args.tokenizer) |tok| {
|
if (cli.args.tokenizer) |tok| {
|
||||||
log.info("\tLoading tokenizer from {s}", .{tok});
|
log.info("\tLoading tokenizer from {s}", .{tok});
|
||||||
var timer = try stdx.time.Timer.start();
|
var timer = try stdx.time.Timer.start();
|
||||||
defer log.info("✅\tLoaded tokenizer from {s} [{}]", .{ tok, timer.read() });
|
defer log.info("✅\tLoaded tokenizer from {s} [{D}]", .{ tok, timer.read() });
|
||||||
|
|
||||||
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
|
break :blk try zml.tokenizer.Tokenizer.fromFile(model_arena, tok);
|
||||||
} else {
|
} else {
|
||||||
@ -123,13 +126,13 @@ pub fn asyncMain() !void {
|
|||||||
const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
|
const seq_len = @as(i64, @intCast(cli.args.@"seq-len" orelse 256));
|
||||||
const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);
|
const input_shape = zml.Shape.init(.{ .b = 1, .s = seq_len }, .u32);
|
||||||
|
|
||||||
var start = try std.time.Timer.start();
|
var start = try stdx.time.Timer.start();
|
||||||
|
|
||||||
// Load weights
|
// Load weights
|
||||||
log.info("\tLoading ModernBERT weights from {?s}...", .{model_file});
|
log.info("\tLoading ModernBERT weights from {s}...", .{model_file});
|
||||||
var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
|
var bert_weights = try zml.aio.loadBuffers(modernbert.ModernBertForMaskedLM, .{modernbert_options}, tensor_store, model_arena, platform);
|
||||||
defer zml.aio.unloadBuffers(&bert_weights);
|
defer zml.aio.unloadBuffers(&bert_weights);
|
||||||
log.info("✅\tLoaded weights in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tLoaded weights in {D}", .{start.read()});
|
||||||
|
|
||||||
// Compile the model
|
// Compile the model
|
||||||
log.info("\tCompiling ModernBERT model...", .{});
|
log.info("\tCompiling ModernBERT model...", .{});
|
||||||
@ -143,7 +146,7 @@ pub fn asyncMain() !void {
|
|||||||
});
|
});
|
||||||
var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
|
var bert_module = (try fut_mod.awaitt()).prepare(bert_weights);
|
||||||
defer bert_module.deinit();
|
defer bert_module.deinit();
|
||||||
log.info("✅\tLoaded weights and compiled model in {d}ms", .{start.read() / std.time.ns_per_ms});
|
log.info("✅\tLoaded weights and compiled model in {D}", .{start.read()});
|
||||||
|
|
||||||
const text = cli.args.text orelse "Paris is the [MASK] of France.";
|
const text = cli.args.text orelse "Paris is the [MASK] of France.";
|
||||||
log.info("\tInput text: {s}", .{text});
|
log.info("\tInput text: {s}", .{text});
|
||||||
@ -213,7 +216,7 @@ pub fn unmask(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
|
pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer, prompt: []const u8) ![]const u32 {
|
||||||
var tokens = std.ArrayList(u32).init(allocator);
|
var tokens = std.array_list.Managed(u32).init(allocator);
|
||||||
var encoder = try tokenizer.encoder();
|
var encoder = try tokenizer.encoder();
|
||||||
defer encoder.deinit();
|
defer encoder.deinit();
|
||||||
|
|
||||||
@ -228,7 +231,7 @@ pub fn tokenize(allocator: std.mem.Allocator, tokenizer: zml.tokenizer.Tokenizer
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
|
fn findMaskPositions(allocator: std.mem.Allocator, tokens: []const u32, mask_token: u32) ![]usize {
|
||||||
var mask_positions = std.ArrayList(usize).init(allocator);
|
var mask_positions = std.array_list.Managed(usize).init(allocator);
|
||||||
defer mask_positions.deinit();
|
defer mask_positions.deinit();
|
||||||
|
|
||||||
for (tokens, 0..) |token, i| {
|
for (tokens, 0..) |token, i| {
|
||||||
@ -267,9 +270,7 @@ fn bool_parser(in: []const u8) error{}!bool {
|
|||||||
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
return std.mem.indexOfScalar(u8, "tTyY1", in[0]) != null;
|
||||||
}
|
}
|
||||||
|
|
||||||
fn printUsageAndExit(stderr: anytype) noreturn {
|
fn printUsageAndExit(stderr: std.fs.File) noreturn {
|
||||||
stderr.print("usage: ", .{}) catch {};
|
clap.usageToFile(stderr, clap.Help, &params) catch {};
|
||||||
clap.usage(stderr, clap.Help, &params) catch {};
|
|
||||||
stderr.print("\n", .{}) catch {};
|
|
||||||
std.process.exit(0);
|
std.process.exit(0);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
const std = @import("std");
|
const std = @import("std");
|
||||||
const zml = @import("zml");
|
|
||||||
const asynk = @import("async");
|
const asynk = @import("async");
|
||||||
|
const zml = @import("zml");
|
||||||
|
|
||||||
/// Model definition
|
/// Model definition
|
||||||
const Layer = struct {
|
const Layer = struct {
|
||||||
@ -38,9 +39,9 @@ pub fn asyncMain() !void {
|
|||||||
context.printAvailablePlatforms(platform);
|
context.printAvailablePlatforms(platform);
|
||||||
|
|
||||||
// Our weights and bias to use
|
// Our weights and bias to use
|
||||||
var weights = [4]f16{ 2.0, 2.0, 2.0, 2.0 };
|
var weights = [4]f32{ 2.0, 2.0, 2.0, 2.0 };
|
||||||
var bias = [4]f16{ 1.0, 2.0, 3.0, 4.0 };
|
var bias = [4]f32{ 1.0, 2.0, 3.0, 4.0 };
|
||||||
const input_shape = zml.Shape.init(.{4}, .f16);
|
const input_shape = zml.Shape.init(.{4}, .f32);
|
||||||
|
|
||||||
// We manually produce a BufferStore. You would not normally do that.
|
// We manually produce a BufferStore. You would not normally do that.
|
||||||
// A BufferStore is usually created by loading model data from a file.
|
// A BufferStore is usually created by loading model data from a file.
|
||||||
@ -80,8 +81,8 @@ pub fn asyncMain() !void {
|
|||||||
// prepare an input buffer
|
// prepare an input buffer
|
||||||
// Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
|
// Here, we use zml.HostBuffer.fromSlice to show how you would create a HostBuffer
|
||||||
// with a specific shape from an array.
|
// with a specific shape from an array.
|
||||||
// For situations where e.g. you have an [4]f16 array but need a .{2, 2} input shape.
|
// For situations where e.g. you have an [4]f32 array but need a .{2, 2} input shape.
|
||||||
var input = [4]f16{ 5.0, 5.0, 5.0, 5.0 };
|
var input = [4]f32{ 5.0, 5.0, 5.0, 5.0 };
|
||||||
var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input), .{});
|
var input_buffer = try zml.Buffer.from(platform, zml.HostBuffer.fromSlice(input_shape, &input), .{});
|
||||||
defer input_buffer.deinit();
|
defer input_buffer.deinit();
|
||||||
|
|
||||||
@ -92,7 +93,7 @@ pub fn asyncMain() !void {
|
|||||||
// fetch the result to CPU memory
|
// fetch the result to CPU memory
|
||||||
const cpu_result = try result.toHostAlloc(arena);
|
const cpu_result = try result.toHostAlloc(arena);
|
||||||
std.debug.print(
|
std.debug.print(
|
||||||
"\nThe result of {d} * {d} + {d} = {d}\n",
|
"\nThe result of {any} * {any} + {any} = {any}\n",
|
||||||
.{ &weights, &input, &bias, cpu_result.items(f16) },
|
.{ &weights, &input, &bias, cpu_result.items(f32) },
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user