zml/ops.zig: Added zml.ops.case operation
This can be used to select which branch will be run at runtime. It wraps the `stablehlo.case` operation.
This commit is contained in:
parent
27aabf9beb
commit
221ece647d
70
zml/ops.zig
70
zml/ops.zig
@ -514,6 +514,76 @@ test "if" {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Execute exactly one of the `branches`, selected at runtime by `index`.
///
/// Wraps the `stablehlo.case` operation; `index` is its only operand.
/// StableHLO expects `index` to be a scalar `i32` tensor.
/// If `index` is out of bounds (negative or `>= branches.len`),
/// the last branch is executed.
///
/// The branches receive no block arguments; they can only access the local
/// context given by `blkctx`. Every branch is built with the signature of
/// `branches[0]`, and the returned value is created from the operation results
/// using the value produced by `branches[0]` as the base.
pub fn case(
    index: Tensor,
    comptime branches: anytype,
    blkctx: BlockSignNoArgs(branches[0]).BlkCtx,
) BlockSignNoArgs(branches[0]).Return {
    const Signature = BlockSignNoArgs(branches[0]);
    const ctx = CompilationContext.current();

    // One MLIR block (and one traced result) per branch.
    const branch_count = branches.len;
    var blocks: [branch_count]mlir.Block = undefined;
    var res: [branch_count]Signature.Return = undefined;
    inline for (branches, 0..) |branch, i| {
        blocks[i], res[i] = ctx.makeBlock(.open, Signature, &branch, blkctx, {});
    }

    const loc = ctx.mlirCtx().location(@src());
    const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.case", .{
        .operands = &.{index.value()},
        .result_type_inference = true,
        .blocks = &blocks,
        // We can't verify right away, because the weights captured by the branches haven't been added yet.
        .verify = false,
        .location = loc,
    });

    return fromMlirOperationWithTags(op, res[0]);
}
|
||||||
|
|
||||||
|
test "case" {
    const zml = @import("zml.zig");
    const platform = zml.testing.env();

    const CaseMod = struct {
        pub fn _fwd(index: Tensor, a: Tensor, b: Tensor) Tensor {
            return case(index, .{ case1, case2, case3 }, .{ a, b });
        }

        pub fn case1(a: Tensor, b: Tensor) Tensor {
            return a.matmul(b);
        }

        pub fn case2(a: Tensor, b: Tensor) Tensor {
            return a.add(b);
        }

        pub fn case3(a: Tensor, b: Tensor) Tensor {
            return a.sub(b);
        }
    };

    // With a = [[1, 1], [2, 2]] and b = [[1, 1], [1, 1]]:
    //   case1 (matmul) -> [[2, 2], [4, 4]]
    //   case2 (add)    -> [[2, 2], [3, 3]]
    //   case3 (sub)    -> [[0, 0], [1, 1]]
    // Out-of-range indices are expected to select the last branch (case3).
    const Expectation = struct { index: i32, expected: [2][2]f32 };
    const expectations = [_]Expectation{
        .{ .index = 0, .expected = .{ .{ 2, 2 }, .{ 4, 4 } } },
        .{ .index = 1, .expected = .{ .{ 2, 2 }, .{ 3, 3 } } },
        .{ .index = 2, .expected = .{ .{ 0, 0 }, .{ 1, 1 } } },
        .{ .index = 7, .expected = .{ .{ 0, 0 }, .{ 1, 1 } } },
        .{ .index = -1, .expected = .{ .{ 0, 0 }, .{ 1, 1 } } },
    };

    // `inline for` keeps every literal comptime-known, exactly like the
    // hand-written single case this table replaces.
    inline for (expectations) |exp| {
        const index = try zml.Buffer.fromSlice(platform, .{}, &[1]i32{exp.index});
        const a = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[4]f32{ 1, 1, 2, 2 });
        const b = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[4]f32{ 1, 1, 1, 1 });
        const result = try zml.testing.compileAndCall(platform, CaseMod._fwd, .{ index, a, b });

        try std.testing.expectEqual(exp.expected, result.getValue([2][2]f32));
    }
}
|
||||||
|
|
||||||
pub fn sort(
|
pub fn sort(
|
||||||
comptime comp_fn: anytype,
|
comptime comp_fn: anytype,
|
||||||
blkctx: BlockSign(comp_fn).BlkCtx,
|
blkctx: BlockSign(comp_fn).BlkCtx,
|
||||||
|
|||||||
@ -1,27 +0,0 @@
|
|||||||
const std = @import("std");
|
|
||||||
const c = @import("c");
|
|
||||||
const HFTokenizers = @import("hftokenizers").HFTokenizers;
|
|
||||||
|
|
||||||
/// Scratch benchmark for the HFTokenizers bindings.
///
/// Encodes a fixed sentence once, then for 100 rounds appends the encoded ids
/// to a growing list, decodes the whole list and prints how long that decode took.
pub fn main() !void {
    const tokenizer = HFTokenizers.init("/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json");
    defer HFTokenizers.deinit(tokenizer);

    var encoded = HFTokenizers.encode(tokenizer, "Hello, world! plane pouet plane");
    defer encoded.deinit();

    var all_ids = std.ArrayList(u32).init(std.heap.c_allocator);
    defer all_ids.deinit();

    var timer = try std.time.Timer.start();
    var round: usize = 0;
    while (round < 100) : (round += 1) {
        try all_ids.appendSlice(encoded.ids);

        // Only the decode call is timed: the timer is reset right before it.
        timer.reset();
        var decoded = HFTokenizers.decode(tokenizer, all_ids.items);
        defer decoded.deinit();
        const elapsed_ns = timer.lap();

        std.debug.print("{any} {any} {s} {d}ns {d}us\n", .{ tokenizer, encoded, decoded.str, elapsed_ns, elapsed_ns / std.time.ns_per_us });
    }
}
|
|
||||||
@ -1,289 +0,0 @@
|
|||||||
const std = @import("std");
|
|
||||||
const c = @import("c");
|
|
||||||
const ffi = @import("ffi");
|
|
||||||
|
|
||||||
/// Errors surfaced by the sentencepiece bindings.
///
/// Each member is the Zig counterpart of one non-OK
/// `sentencepiece_util_StatusCode` value, as mapped by
/// `SentencePieceProcessor.assertOk`.
pub const SentencePieceError = error{
    Aborted,
    AlreadyExists,
    Cancelled,
    DataLoss,
    DeadlineExceeded,
    FailedPrecondition,
    Internal,
    InvalidArgument,
    NotFound,
    OutOfRange,
    PermissionDenied,
    ResourceExhausted,
    Unauthenticated,
    Unavailable,
    Unimplemented,
    Unknown,
};
|
|
||||||
|
|
||||||
/// Incremental decoder built on top of `SentencePieceProcessor.Decoder`.
///
/// Keeps a sliding window of at most `TokensSize` tokens inside the decoder,
/// re-decodes that window after every new token, and returns only the part of
/// the decoded text that follows what the previous window already produced.
pub const DecoderStream = struct {
    /// Maximum number of tokens kept in the decoder window.
    const TokensSize = 4;
    /// Capacity, in bytes, of the copy of the previously decoded text.
    const StringSize = 128;

    decoder: SentencePieceProcessor.Decoder,
    /// Backing storage for `last_tokens`.
    buffer: [StringSize]u8 = undefined,
    /// Text decoded from the previous window. Despite the name these are
    /// decoded bytes, not token ids. Always a slice of `buffer`.
    last_tokens: []u8 = &.{},

    /// Wrap `decoder` and pre-reserve its token and string storage.
    pub fn init(decoder: SentencePieceProcessor.Decoder) DecoderStream {
        var ret: DecoderStream = .{ .decoder = decoder };
        ret.decoder.reserve_tokens(TokensSize);
        ret.decoder.reserve_string(StringSize);
        return ret;
    }

    /// Feed one token and return the newly produced text, if any.
    ///
    /// Returns `null` when no suffix of the previous text is a prefix of the
    /// new text; in that case the remembered text is left untouched.
    /// Returns `error.DecodedTextTooLong` when the decoded window does not fit
    /// in the internal `StringSize`-byte buffer (previously an out-of-bounds slice).
    /// Decoder errors are propagated unchanged.
    ///
    /// The returned slice comes from `Decoder.decode`.
    /// NOTE(review): it presumably aliases the decoder's string and is only
    /// valid until the next call — confirm against `ffi.ZigSlice.to`.
    pub fn next(self: *DecoderStream, next_token: u32) !?[]const u8 {
        const window = self.decoder.tokens();
        if (window.len >= TokensSize) {
            // Window is full: drop the oldest token and put the new one last.
            std.mem.copyForwards(u32, window[0 .. TokensSize - 1], window[1..TokensSize]);
            window[TokensSize - 1] = next_token;
        } else {
            self.decoder.append(next_token);
        }

        const new_tokens = try self.decoder.decode();
        // `remember` copies into `buffer`; refuse text that cannot fit.
        if (new_tokens.len > StringSize) return error.DecodedTextTooLong;

        if (self.last_tokens.len == 0) {
            self.remember(new_tokens);
            return new_tokens;
        }

        // Look for the longest proper suffix of the previous text that the new
        // text starts with; everything after that overlap is new output.
        for (1..self.last_tokens.len) |i| {
            if (std.mem.startsWith(u8, new_tokens, self.last_tokens[i..])) {
                const toks = new_tokens[self.last_tokens.len - i ..];
                self.remember(new_tokens);
                return toks;
            }
        }
        return null;
    }

    /// Copy `text` into `buffer` and point `last_tokens` at the copy.
    /// Caller guarantees `text.len <= StringSize`.
    fn remember(self: *DecoderStream, text: []const u8) void {
        self.last_tokens = self.buffer[0..text.len];
        @memcpy(self.last_tokens, text);
    }
};
|
|
||||||
|
|
||||||
/// Opaque Zig handle for the C-side `SentencePieceProcessor` object exposed by
/// the `c` bindings. Create it with `load`, release it with `deinit`.
pub const SentencePieceProcessor = opaque {
    /// Reusable encoding state: a processor plus the C++ `std::vector<int>`
    /// receiving the token ids.
    pub const Encoder = struct {
        inner: *SentencePieceProcessor,
        vec: *c.std_vector_int,

        fn init(inner: *SentencePieceProcessor) Encoder {
            return .{
                .inner = inner,
                // Panic instead of `unreachable`: a failed allocation must not
                // become undefined behavior in release builds.
                .vec = c.std_vector_int_new() orelse @panic("std_vector_int_new returned null"),
            };
        }

        /// Free the underlying vector.
        pub fn deinit(self: *Encoder) void {
            c.std_vector_int_delete(self.vec);
        }

        /// Reserve room for `size` token ids.
        pub fn reserve(self: *Encoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        /// Clear the token vector.
        pub fn reset(self: *Encoder) void {
            c.std_vector_int_clear(self.vec);
        }

        /// Encode `input` into `self.vec` and return its content as `u32` ids.
        /// NOTE(review): the result is built from the vector's data pointer and
        /// size, so it presumably aliases `self.vec` and is invalidated by a
        /// later `encode`, `reset` or `deinit` — confirm against `ffi.ZigSlice.to`.
        pub fn encode(self: *Encoder, input: []const u8) ![]const u32 {
            try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec));
            return ffi.ZigSlice.to(u32, .{
                .ptr = c.std_vector_int_data(self.vec),
                .len = c.std_vector_int_size(self.vec),
            });
        }
    };

    /// Reusable decoding state: a processor, the token ids to decode and the
    /// C++ `std::string` receiving the decoded text.
    pub const Decoder = struct {
        inner: *SentencePieceProcessor,
        vec: *c.std_vector_int,
        str: *c.std_string,

        fn init(inner: *SentencePieceProcessor) Decoder {
            return .{
                .inner = inner,
                // See `Encoder.init` for why these panic rather than `unreachable`.
                .vec = c.std_vector_int_new() orelse @panic("std_vector_int_new returned null"),
                .str = c.std_string_new() orelse @panic("std_string_new returned null"),
            };
        }

        /// Push one token id at the end of the token vector.
        pub fn append(self: *Decoder, token: u32) void {
            c.std_vector_int_push_back(self.vec, @intCast(token));
        }

        /// Free the underlying vector and string.
        pub fn deinit(self: *Decoder) void {
            c.std_vector_int_delete(self.vec);
            c.std_string_delete(self.str);
        }

        /// Reserve room for `size` token ids.
        pub fn reserve_tokens(self: *Decoder, size: usize) void {
            c.std_vector_int_reserve(self.vec, size);
        }

        /// Reserve room for `size` bytes of decoded text.
        pub fn reserve_string(self: *Decoder, size: usize) void {
            c.std_string_reserve(self.str, size);
        }

        /// Clear both the token vector and the decoded string.
        pub fn reset(self: *Decoder) void {
            c.std_vector_int_clear(self.vec);
            c.std_string_clear(self.str);
        }

        /// Decode the current token vector into `self.str` and return the text.
        pub fn decode(self: *Decoder) ![]const u8 {
            try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str));
            return self.string();
        }

        /// Current content of `self.str`, as reported by `std_string_data`.
        pub fn string(self: *const Decoder) []const u8 {
            const res = c.std_string_data(self.str);
            return ffi.ZigSlice.to(u8, res);
        }

        /// Mutable view of the token vector, reinterpreting its storage as `u32`.
        pub fn tokens(self: *const Decoder) []u32 {
            const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec));
            return ptr[0..c.std_vector_int_size(self.vec)];
        }
    };

    /// Translate a sentencepiece status code into a Zig error; `kOk` is success.
    /// Codes this binding does not know about are reported as `error.Unknown`
    /// rather than treated as unreachable, since the value comes from C.
    fn assertOk(code: c.sentencepiece_util_StatusCode) SentencePieceError!void {
        return switch (code) {
            c.sentencepiece_util_StatusCode_kOk => {},
            c.sentencepiece_util_StatusCode_kCancelled => error.Cancelled,
            c.sentencepiece_util_StatusCode_kUnknown => error.Unknown,
            c.sentencepiece_util_StatusCode_kInvalidArgument => error.InvalidArgument,
            c.sentencepiece_util_StatusCode_kDeadlineExceeded => error.DeadlineExceeded,
            c.sentencepiece_util_StatusCode_kNotFound => error.NotFound,
            c.sentencepiece_util_StatusCode_kAlreadyExists => error.AlreadyExists,
            c.sentencepiece_util_StatusCode_kPermissionDenied => error.PermissionDenied,
            c.sentencepiece_util_StatusCode_kResourceExhausted => error.ResourceExhausted,
            c.sentencepiece_util_StatusCode_kFailedPrecondition => error.FailedPrecondition,
            c.sentencepiece_util_StatusCode_kAborted => error.Aborted,
            c.sentencepiece_util_StatusCode_kOutOfRange => error.OutOfRange,
            c.sentencepiece_util_StatusCode_kUnimplemented => error.Unimplemented,
            c.sentencepiece_util_StatusCode_kInternal => error.Internal,
            c.sentencepiece_util_StatusCode_kUnavailable => error.Unavailable,
            c.sentencepiece_util_StatusCode_kDataLoss => error.DataLoss,
            c.sentencepiece_util_StatusCode_kUnauthenticated => error.Unauthenticated,
            else => error.Unknown,
        };
    }

    /// Create a processor and load `model` into it.
    /// The processor is destroyed again if loading fails.
    /// On success the caller must call `deinit`.
    pub fn load(model: []const u8) !*SentencePieceProcessor {
        const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new());
        errdefer sp.deinit();
        try sp.load_from(model);
        return sp;
    }

    /// Destroy the underlying C object.
    pub fn deinit(self: *SentencePieceProcessor) void {
        c.SentencePieceProcessor_delete(@ptrCast(self));
    }

    /// Forward `model` to `SentencePieceProcessor_Load`.
    fn load_from(self: *SentencePieceProcessor, model: []const u8) !void {
        try assertOk(c.SentencePieceProcessor_Load(@ptrCast(self), ffi.ZigSlice.from(model)));
    }

    /// Create a new `Encoder` bound to this processor; release it with `Encoder.deinit`.
    pub fn encoder(self: *SentencePieceProcessor) Encoder {
        return Encoder.init(self);
    }

    /// Create a new `Decoder` bound to this processor; release it with `Decoder.deinit`.
    pub fn decoder(self: *SentencePieceProcessor) Decoder {
        return Decoder.init(self);
    }
};
|
|
||||||
|
|
||||||
/// Copy `path` into a fixed-size, zero-terminated buffer returned by value.
///
/// Only the first `path.len` bytes and the terminator that follows them are
/// written; the rest of the buffer is left undefined.
/// `path.len` must not exceed `std.fs.max_path_bytes`.
pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 {
    var buf: [std.fs.max_path_bytes:0]u8 = undefined;
    const end = path.len;
    for (buf[0..end], path) |*dst, byte| {
        dst.* = byte;
    }
    buf[end] = 0;
    return buf;
}
|
|
||||||
|
|
||||||
/// Scratch driver for the sentencepiece bindings.
///
/// Loads a model from a hard-coded path, encodes the embedded `main.zig`
/// source, then streams the tokens through `DecoderStream`, printing every
/// produced chunk together with the time elapsed since the previous print.
pub fn main() !void {
    const processor = try SentencePieceProcessor.load("/Users/steeve/Downloads/poolside.sp.pb");
    defer processor.deinit();

    std.debug.print("Loaded model\n", .{});

    var enc = processor.encoder();
    defer enc.deinit();

    var dec = processor.decoder();
    defer dec.deinit();

    // The text to tokenize is the embedded `main.zig` file.
    const token_ids = try enc.encode(@embedFile("main.zig"));

    // The stream gets its own copy of the decoder handle.
    var stream = DecoderStream.init(dec);

    var lap_timer = try std.time.Timer.start();
    for (token_ids) |token_id| {
        // `next` yields null when the token produced no new text yet.
        const chunk = (try stream.next(token_id)) orelse continue;
        std.debug.print("{d}us - {s}\n", .{ lap_timer.lap() / std.time.ns_per_us, chunk });
    }
}
|
|
||||||
Loading…
Reference in New Issue
Block a user