Update docs and example to demonstrate adding the platform tag to buffers when adding weights.
This commit is contained in:
parent
12efc763d5
commit
9aeb4e9cd0
@ -44,7 +44,7 @@ Then, in your `BUILD.bazel`, you can refer to the files you defined above, in
|
|||||||
the following way:
|
the following way:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
zig_cc_binary(
|
zig_binary(
|
||||||
name = "mnist",
|
name = "mnist",
|
||||||
args = [
|
args = [
|
||||||
"$(location @mnist//:mnist.pt)",
|
"$(location @mnist//:mnist.pt)",
|
||||||
|
|||||||
@ -93,7 +93,7 @@ pub fn generateText(
|
|||||||
|
|
||||||
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
var prefill_tokens = try zml.Buffer.fromSlice(platform, .{max_seq_len}, prefill_buffer);
|
||||||
defer prefill_tokens.deinit();
|
defer prefill_tokens.deinit();
|
||||||
var prefill_token_pos = try zml.Buffer.constant(platform, zml.Shape.init(.{}, .u32), 0);
|
var prefill_token_pos = try zml.Buffer.scalar(platform, 0, .u32);
|
||||||
defer prefill_token_pos.deinit();
|
defer prefill_token_pos.deinit();
|
||||||
|
|
||||||
const prefilled_tokens, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_pos, kv_cache_, rng });
|
const prefilled_tokens, const kv_cache, rng = mod_prefill.call(.{ prefill_tokens, prefill_token_pos, kv_cache_, rng });
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user