diff --git a/zml/ops.zig b/zml/ops.zig index f314e01..c73dc68 100644 --- a/zml/ops.zig +++ b/zml/ops.zig @@ -514,6 +514,76 @@ test "if" { } } +/// Execute exactly one of the `branches` given by `index`. +/// +/// If `index` is out of bounds, the last branch is executed. +/// +/// The branches take no parameters but can access the local context given by `blkctx`. +pub fn case( + index: Tensor, + comptime branches: anytype, + blkctx: BlockSignNoArgs(branches[0]).BlkCtx, +) BlockSign(branches[0]).Return { + const Signature = BlockSignNoArgs(branches[0]); + const ctx = CompilationContext.current(); + + const branch_count = branches.len; + var blocks: [branch_count]mlir.Block = undefined; + var res: [branch_count]Signature.Return = undefined; + inline for (branches, 0..) |branch, i| { + blocks[i], res[i] = ctx.makeBlock(.open, Signature, &branch, blkctx, {}); + } + + const loc = ctx.mlirCtx().location(@src()); + const op = mlir.Operation.make(ctx.mlirCtx(), "stablehlo.case", .{ + .operands = &.{index.value()}, + .result_type_inference = true, + .blocks = &blocks, + // We can't verify right away, because the weights captured by the if haven't been added yet. 
+ .verify = false, + .location = loc, + }); + + return fromMlirOperationWithTags(op, res[0]); +} + +test "case" { + const zml = @import("zml.zig"); + const platform = zml.testing.env(); + + const CaseMod = struct { + pub fn _fwd(index: Tensor, a: Tensor, b: Tensor) Tensor { + const result = case(index, .{ case1, case2, case3 }, .{ a, b }); + return result; + } + + pub fn case1(a: Tensor, b: Tensor) Tensor { + return a.matmul(b); + } + + pub fn case2(a: Tensor, b: Tensor) Tensor { + return a.add(b); + } + + pub fn case3(a: Tensor, b: Tensor) Tensor { + return a.sub(b); + } + }; + + { + const index = try zml.Buffer.fromSlice(platform, .{}, &[1]i32{1}); + const a = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[4]f32{ 1, 1, 2, 2 }); + const b = try zml.Buffer.fromSlice(platform, .{ 2, 2 }, &[4]f32{ 1, 1, 1, 1 }); + const result = try zml.testing.compileAndCall(platform, CaseMod._fwd, .{ index, a, b }); + + const expected: [2][2]f32 = .{ + .{ 2, 2 }, + .{ 3, 3 }, + }; + try std.testing.expectEqual(expected, result.getValue(@TypeOf(expected))); + } +} + pub fn sort( comptime comp_fn: anytype, blkctx: BlockSign(comp_fn).BlkCtx, diff --git a/zml/tokenizer/hftokenizers/main.zig b/zml/tokenizer/hftokenizers/main.zig deleted file mode 100644 index 395660b..0000000 --- a/zml/tokenizer/hftokenizers/main.zig +++ /dev/null @@ -1,27 +0,0 @@ -const std = @import("std"); -const c = @import("c"); -const HFTokenizers = @import("hftokenizers").HFTokenizers; - -pub fn main() !void { - const tokenizer = HFTokenizers.init("/private/var/tmp/_bazel_steeve/a67b810d44f2a673ebbd5bab86ccd5cc/external/zml~~huggingface~Meta-Llama-3.1-8B-Instruct/tokenizer.json"); - defer HFTokenizers.deinit(tokenizer); - - const input = "Hello, world! 
plane pouet plane"; - var encoded = HFTokenizers.encode(tokenizer, input); - defer encoded.deinit(); - var pouet = std.ArrayList(u32).init(std.heap.c_allocator); - defer pouet.deinit(); - - // try pouet.appendSlice(encoded.ids); - - var t = try std.time.Timer.start(); - for (0..100) |_| { - try pouet.appendSlice(encoded.ids); - t.reset(); - var decoded = HFTokenizers.decode(tokenizer, pouet.items); - defer decoded.deinit(); - const elapsed = t.lap(); - // std.debug.print("{any} {any} {d}us\n", .{tokenizer, encoded, elapsed / std.time.ns_per_us}); - std.debug.print("{any} {any} {s} {d}ns {d}us\n", .{ tokenizer, encoded, decoded.str, elapsed, elapsed / std.time.ns_per_us }); - } -} diff --git a/zml/tokenizer/sentencepiece/main.zig b/zml/tokenizer/sentencepiece/main.zig deleted file mode 100644 index caa336f..0000000 --- a/zml/tokenizer/sentencepiece/main.zig +++ /dev/null @@ -1,289 +0,0 @@ -const std = @import("std"); -const c = @import("c"); -const ffi = @import("ffi"); - -pub const SentencePieceError = error{ - Cancelled, - Unknown, - InvalidArgument, - DeadlineExceeded, - NotFound, - AlreadyExists, - PermissionDenied, - ResourceExhausted, - FailedPrecondition, - Aborted, - OutOfRange, - Unimplemented, - Internal, - Unavailable, - DataLoss, - Unauthenticated, -}; - -pub const DecoderStream = struct { - const TokensSize = 4; - const StringSize = 128; - decoder: SentencePieceProcessor.Decoder, - buffer: [StringSize]u8 = undefined, - last_tokens: []u8 = &.{}, - - pub fn init(decoder: SentencePieceProcessor.Decoder) DecoderStream { - var ret: DecoderStream = .{ - .decoder = decoder, - }; - ret.decoder.reserve_tokens(TokensSize); - ret.decoder.reserve_string(StringSize); - return ret; - } - - pub fn next(self: *DecoderStream, next_token: u32) !?[]const u8 { - if (self.decoder.tokens().len >= TokensSize) { - const tokens = self.decoder.tokens(); - inline for (0..TokensSize - 1) |i| { - tokens[i] = tokens[i + 1]; - } - tokens[TokensSize - 1] = next_token; - } else { - 
self.decoder.append(next_token); - } - const new_tokens = try self.decoder.decode(); - if (self.last_tokens.len == 0) { - self.last_tokens = self.buffer[0..new_tokens.len]; - @memcpy(self.last_tokens, new_tokens); - return new_tokens; - } - for (1..self.last_tokens.len) |i| { - if (std.mem.startsWith(u8, new_tokens, self.last_tokens[i..])) { - const toks = new_tokens[self.last_tokens.len - i ..]; - self.last_tokens = self.buffer[0..new_tokens.len]; - @memcpy(self.last_tokens, new_tokens); - return toks; - } - } - return null; - } -}; - -pub const SentencePieceProcessor = opaque { - pub const Encoder = struct { - inner: *SentencePieceProcessor, - vec: *c.std_vector_int, - - fn init(inner: *SentencePieceProcessor) Encoder { - return .{ - .inner = inner, - .vec = c.std_vector_int_new() orelse unreachable, - }; - } - - pub fn deinit(self: *Encoder) void { - c.std_vector_int_delete(self.vec); - } - - pub fn reserve(self: *Encoder, size: usize) void { - c.std_vector_int_reserve(self.vec, size); - } - - pub fn reset(self: *Encoder) void { - c.std_vector_int_clear(self.vec); - } - - pub fn encode(self: *Encoder, input: []const u8) ![]const u32 { - try assertOk(c.SentencePieceProcessor_Encode(@ptrCast(self.inner), ffi.ZigSlice.from(input), self.vec)); - return ffi.ZigSlice.to(u32, .{ - .ptr = c.std_vector_int_data(self.vec), - .len = c.std_vector_int_size(self.vec), - }); - } - }; - - pub const Decoder = struct { - inner: *SentencePieceProcessor, - vec: *c.std_vector_int, - str: *c.std_string, - - fn init(inner: *SentencePieceProcessor) Decoder { - return .{ - .inner = inner, - .vec = c.std_vector_int_new() orelse unreachable, - .str = c.std_string_new() orelse unreachable, - }; - } - - pub fn append(self: *Decoder, token: u32) void { - c.std_vector_int_push_back(self.vec, @intCast(token)); - } - - pub fn deinit(self: *Decoder) void { - c.std_vector_int_delete(self.vec); - c.std_string_delete(self.str); - } - - pub fn reserve_tokens(self: *Decoder, size: usize) void { - 
c.std_vector_int_reserve(self.vec, size); - } - - pub fn reserve_string(self: *Decoder, size: usize) void { - c.std_string_reserve(self.str, size); - } - - pub fn reset(self: *Decoder) void { - c.std_vector_int_clear(self.vec); - c.std_string_clear(self.str); - } - - pub fn decode(self: *Decoder) ![]const u8 { - try assertOk(c.SentencePieceProcessor_Decode(@ptrCast(self.inner), self.vec, self.str)); - return self.string(); - } - - pub fn string(self: *const Decoder) []const u8 { - const res = c.std_string_data(self.str); - return ffi.ZigSlice.to(u8, res); - } - - pub fn tokens(self: *const Decoder) []u32 { - const ptr: [*c]u32 = @ptrCast(c.std_vector_int_data(self.vec)); - return ptr[0..c.std_vector_int_size(self.vec)]; - } - }; - - fn assertOk(code: c.sentencepiece_util_StatusCode) SentencePieceError!void { - return switch (code) { - c.sentencepiece_util_StatusCode_kOk => {}, - c.sentencepiece_util_StatusCode_kCancelled => error.Cancelled, - c.sentencepiece_util_StatusCode_kUnknown => error.Unknown, - c.sentencepiece_util_StatusCode_kInvalidArgument => error.InvalidArgument, - c.sentencepiece_util_StatusCode_kDeadlineExceeded => error.DeadlineExceeded, - c.sentencepiece_util_StatusCode_kNotFound => error.NotFound, - c.sentencepiece_util_StatusCode_kAlreadyExists => error.AlreadyExists, - c.sentencepiece_util_StatusCode_kPermissionDenied => error.PermissionDenied, - c.sentencepiece_util_StatusCode_kResourceExhausted => error.ResourceExhausted, - c.sentencepiece_util_StatusCode_kFailedPrecondition => error.FailedPrecondition, - c.sentencepiece_util_StatusCode_kAborted => error.Aborted, - c.sentencepiece_util_StatusCode_kOutOfRange => error.OutOfRange, - c.sentencepiece_util_StatusCode_kUnimplemented => error.Unimplemented, - c.sentencepiece_util_StatusCode_kInternal => error.Internal, - c.sentencepiece_util_StatusCode_kUnavailable => error.Unavailable, - c.sentencepiece_util_StatusCode_kDataLoss => error.DataLoss, - c.sentencepiece_util_StatusCode_kUnauthenticated 
=> error.Unauthenticated, - else => unreachable, - }; - } - - pub fn load(model: []const u8) !*SentencePieceProcessor { - const sp: *SentencePieceProcessor = @ptrCast(c.SentencePieceProcessor_new()); - errdefer sp.deinit(); - try sp.load_from(model); - return sp; - } - - pub fn deinit(self: *SentencePieceProcessor) void { - c.SentencePieceProcessor_delete(@ptrCast(self)); - } - - fn load_from(self: *SentencePieceProcessor, model: []const u8) !void { - try assertOk(c.SentencePieceProcessor_Load(@ptrCast(self), ffi.ZigSlice.from(model))); - } - - pub fn encoder(self: *SentencePieceProcessor) Encoder { - return Encoder.init(self); - } - - pub fn decoder(self: *SentencePieceProcessor) Decoder { - return Decoder.init(self); - } -}; - -pub fn as_path(path: []const u8) [std.fs.max_path_bytes:0]u8 { - var result: [std.fs.max_path_bytes:0]u8 = undefined; - @memcpy(result[0..path.len], path); - result[path.len] = 0; - return result; -} - -pub fn main() !void { - const sp = try SentencePieceProcessor.load("/Users/steeve/Downloads/poolside.sp.pb"); - defer sp.deinit(); - - std.debug.print("Loaded model\n", .{}); - - var encoder = sp.encoder(); - defer encoder.deinit(); - - var decoder = sp.decoder(); - defer decoder.deinit(); - - const ss = @embedFile("main.zig"); - // \\String class - // \\Strings are objects that represent sequences of characters. - // \\ - // \\The standard string class provides support for such objects with an interface similar to that of a standard container of bytes, but adding features specifically designed to operate with strings of single-byte characters. - // \\ - // \\The string class is an instantiation of the basic_string class template that uses char (i.e., bytes) as its character type, with its default char_traits and allocator types (see basic_string for more info on the template). 
- // \\ - // \\Note that this class handles bytes independently of the encoding used: If used to handle sequences of multi-byte or variable-length characters (such as UTF-8), all members of this class (such as length or size), as well as its iterators, will still operate in terms of bytes (not actual encoded characters). - // \\ - // ; - const tokens = try encoder.encode(ss); - - // const ss2 = 128; - // var buf = [_]u8{0} ** ss2; - // // _ = buf; // autofix - // var last_tokens: []u8 = &.{}; - // // _ = last_tokens; // autofix - // decoder.reserve_tokens(4); - // decoder.reserve_string(128); - - var stream = DecoderStream.init(decoder); - - var start = try std.time.Timer.start(); - for (tokens) |token| { - if (try stream.next(token)) |chunk| { - // std.debug.print("{s}", .{chunk}); - std.debug.print("{d}us - {s}\n", .{ start.lap() / std.time.ns_per_us, chunk }); - } - } - - // var start = try std.time.Timer.start(); - // var it = std.mem.window(u32, tokens, 3, 1); - // while (it.next()) |slice| { - // if (decoder.tokens().len >= 4) { - // const kept_tokens = decoder.tokens()[1..]; - // std.mem.copyForwards(u32, decoder.tokens()[0..kept_tokens.len], kept_tokens); - // kept_tokens[kept_tokens.len - 1] = slice[2]; - // } else { - // for (slice) |token| { - // decoder.append(token); - // } - // } - // const new_tokens = try decoder.decode(); - // for (0..ss2) |i| { - // if (std.mem.startsWith(u8, new_tokens, last_tokens[i..])) { - // const toks = new_tokens[last_tokens.len - i..]; - // // std.debug.print("{s}", .{toks}); - // if (toks.len == 0) { - // // std.debug.print("WESH\n", .{}); - // } - // break; - // } - // } - // last_tokens = buf[0..new_tokens.len]; - // @memcpy(last_tokens, new_tokens); - // std.debug.print("{d}us\n", .{start.lap() / std.time.ns_per_us}); - // } - - // for (tokens) |token| { - // decoder.append(token); - // } - // const decoded = try decoder.decode(); - // std.debug.print("Decoded: {s}\n", .{decoded}); - - // const model = 
"/Users/steeve/Downloads/poolside.sp.pb"; - - // c.SentencePieceProcessor_LoadOrDie(sp, c.zig_slice{ .ptr = model.ptr, .len = model.len }); - - // const piece = c.SentencePieceProcessor_IdToPiece(sp, 10999); - // std.debug.print("{s}\n", .{piece.ptr[0..piece.len]}); -}