Update docs and example to demonstrate adding the platform tag to buffers when adding weights.
This commit is contained in:
parent
12efc763d5
commit
9aeb4e9cd0
@ -44,7 +44,7 @@ Then, in your `BUILD.bazel`, you can refer to the files you defined above, in
|
||||
the following way:
|
||||
|
||||
```python
|
||||
zig_cc_binary(
|
||||
zig_binary(
|
||||
name = "mnist",
|
||||
args = [
|
||||
"$(location @mnist//:mnist.pt)",
|
||||
|
||||
@ -93,7 +93,7 @@ pub fn generateText(
|
||||
|
||||
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
||||
defer prefill_tokens.deinit();
|
||||
var prefill_token_pos = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
|
||||
var prefill_token_pos = try zml.Buffer.scalar(platform, 0, .u32);
|
||||
defer prefill_token_pos.deinit();
|
||||
|
||||
const prefilled_tokens, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_pos, kv_cache_, rng });
|
||||
|
||||
Loading…
Reference in New Issue
Block a user