# How to port PyTorch models to ZML

## Requirements

We assume you already have a working ZML project, and that you can run it with a Bazel command like `bazel run //my_project:torch2zml`. You can refer to [write your first model](../tutorials/write_first_model.md) to do so.

We also assume that you know enough Python to run the reference implementation.

## Overview

Porting neural network implementations can be tedious. Small errors can degrade the output of the model, in subtle or not so subtle ways. To track down errors in a model with four thousand layers, we had better be organized.

By the way, if you are interested in a specific model, be aware that not all implementations of a model you can find on GitHub are equivalent. Sometimes people introduce subtle bugs when porting across Python libraries. Ideally use the author's implementation, or at least one you have tested yourself.

**The recommended process is as follows:**

1. run the reference implementation on a known input, and sample layer activations
2. start a ZML project and load the sampled reference activations
3. start porting layers one by one, and test individual layers
4. end-to-end test the model

## Sampling reference activations

PyTorch exposes "forward hooks" that let you inspect the input/output of each `torch.nn.Module`. That way it is possible to create a dictionary with each layer's input/output, keyed by the name of the layer.

The main caveat is that if you have a functional implementation that doesn't use `torch.nn.Module`, this technique won't work.

It is easiest to start from a Hugging Face snippet, or a Python script that calls the model of your choice on an example input.
For example:

```python
import torch
import transformers

model_path = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.float16},
    # device="cuda",
    # token=token,  # pass your Hugging Face access token here if the model requires one
)

prompt = "Q: What is the largest animal?\nA:"
output = pipeline(prompt)
print(output)
```

Then edit the script to import [zml_utils](https://github.com/zml/zml/blob/master/tools/zml_utils.py).

`zml_utils.py` is standalone and is currently not distributed as a Python package, so the simplest way to use it is to copy it next to your Python script. Then wrap the model/pipeline in a `zml_utils.ActivationCollector`. The collector wraps the given model and, when called, returns the original results AND the activations in a dict of `torch.Tensor`. After that, you can save those activations to a `.pt` file.

```python
import torch
import transformers
import zml_utils

model_path = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    model_kwargs={"torch_dtype": torch.float16},
    # device="cuda",
)
model, tokenizer = pipeline.model, pipeline.tokenizer

prompt = "Q: What is the largest animal?\nA:"

# Wrap the pipeline, and extract activations.
# Activations files can be huge for big models,
# so let's stop collecting after 1000 layers.
pipeline = zml_utils.ActivationCollector(pipeline, max_layers=1000, stop_after_first_step=True)
output, activations = pipeline(prompt)

# `output` can be `None` if activations collection
# has stopped before the end of the inference
if output:
    print(output)

# Save activations to a file.
filename = model_path.split("/")[-1] + ".activations.pt"
torch.save(activations, filename)
print(f"Saved {len(activations)} activations to {filename}")
```

Run this script: `python activations.py`

If you're using Hugging Face, make a note of the local path where the model is saved; it should be something like `~/.cache/huggingface/hub/...`.
(It should also appear on the console when running the script.) We will need it in the next steps.

## Loading model and activations in ZML

Let's create a basic ZML program that loads the activations and the PyTorch model. Put the following in `my_project/torch2zml.zig`.

```zig
const std = @import("std");
const log = std.log;

const async = @import("async");
const zml = @import("zml");

pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    try async.AsyncThread.main(gpa.allocator(), asyncMain, .{});
}

pub fn asyncMain() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Expected arguments: <model path> <activations path>
    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);

    const model_path, const activations_path = args[1..3].*;

    const activations = try zml.aio.torch.open(allocator, activations_path);
    defer activations.deinit();
    log.info("Found {} activations in {s}", .{ activations.buffers.count(), activations_path });

    const model_weights = try zml.aio.detectFormatAndOpen(allocator, model_path);
    defer model_weights.deinit();
    log.info("Found {} model layers in {s}", .{ model_weights.buffers.count(), activations_path });
}
```

And add a `zig_cc_binary` target in `my_project/BUILD.bazel`:

```python
load("@zml//bazel:zig.bzl", "zig_cc_binary")

zig_cc_binary(
    name = "torch2zml",
    main = "torch2zml.zig",
    deps = [
        "@zml//async",
        "@zml//zml",
    ],
)
```

Now check that the weights can be loaded correctly using the Bazel CLI.

```bash
bazel build //my_project:torch2zml
./bazel-bin/my_project/torch2zml /path/to/my/model.safetensors.index.json ./my_project/Meta-Llama-3-8B.activations.pt

info: Found 1108 activations in /Users/guw/Documents/zml/models/torch2zml/Meta-Llama-3-8B.activations.pt
debug(zml_io): Loading shard: model-00004-of-00004.safetensors
debug(zml_io): Loading shard: model-00001-of-00004.safetensors
debug(zml_io): Loading shard: model-00002-of-00004.safetensors
debug(zml_io): Loading shard: model-00003-of-00004.safetensors
info: Found 291 model layers in /Users/guw/Documents/zml/models/torch2zml/Meta-Llama-3-8B.activations.pt
```

## Loading an individual layer

In the above Zig code, the `model_weights` struct is a wrapper around a flat dictionary, containing an entry for each tensor in the model (similar to a "state dict"). Manipulating a dictionary is generally not very convenient, so let's convert it to a Zig struct.

Declare the following layer at the bottom of your file:

```zig
const Mlp = struct {
    up_proj: zml.nn.Linear,
    gate_proj: zml.nn.Linear,
    down_proj: zml.nn.Linear,
};
```

`zml.nn.Linear` is the equivalent of `torch.nn.Linear` and is defined by its `weight` and optional `bias` tensors.

To create such a struct from our `model_weights` dictionary, we can use the `zml.aio.populateModelWithPrefix` helper:

```zig
pub fn asyncMain() !void {
    ...
    const mlp_shape = try zml.aio.populateModelWithPrefix(Mlp, allocator, model_weights, "model.layers.0.mlp");
    log.info("layer.0.mlp: {}", .{mlp_shape});
}
```

Build and run, using the previous commands.

Typical errors are of the form _"Layer not found: ..."_. This usually means that the names of the layers in Zig don't match the names in the file. Double-check everything and don't hesitate to print more things, e.g. in the Python script. Alternatively, Hugging Face's web interface lets you peek into `.safetensors` files.

## Testing an individual layer

Finally, we are going to write the actual math code for our `Mlp` layer.

```zig
const Mlp = struct {
    up_proj: zml.nn.Linear,
    gate_proj: zml.nn.Linear,
    down_proj: zml.nn.Linear,

    pub fn forward(self: Mlp, x: zml.Tensor) zml.Tensor {
        const proj = zml.call(self.up_proj, .forward, .{x});
        var output = zml.call(self.gate_proj, .forward, .{x});
        // Gated activation: silu(gate_proj(x)) * up_proj(x)
        output = output.silu().mul(proj);
        return zml.call(self.down_proj, .forward, .{output});
    }
};
```

Note that we use `zml.call` instead of directly calling `self.up_proj.forward(x)`. Calling `forward` directly results in the same computation happening at runtime, but going through `zml.call` allows ZML to generate an MLIR representation that is closer to the Zig code and therefore easier to read.

We can test the MLP layer with the `zml.testing.testLayer` utility:

```zig
pub fn asyncMain() !void {
    ...

    var ctx = try zml.Context.init();
    defer ctx.deinit();
    const platform = ctx.autoPlatform(.{});
    const mlp_weights = try zml.aio.loadModelBuffers(Mlp, mlp_shape, model_weights, allocator, platform);

    // Compare the Zig layer against the reference activations, with a 1e-3 tolerance.
    try zml.testing.testLayer(platform, activations, "model.layers.0.mlp", mlp_shape, mlp_weights, 1e-3);
}
```

During this phase, three kinds of errors can appear:

* Zig compilation errors: we've all been there, learning a new language can be tough. Normally, the compiler should help you figure out what's wrong. You can also check [ZML concepts](../learn/concepts.md), which explains the types used by ZML.
* Buffer not found errors: you need to use the naming scheme of the inference pipeline when loading the activations. Depending on how you write your code, the naming convention in the model file may differ from the one in the activation file. This is because in Python, and in particular in the `transformers` library, it's not uncommon to wrap the model in a `Pipeline` object before using it. So a given layer may be named `layer.0.mlp` in the model file, but its activations may be saved under `model.layer.0.mlp`.
* MLIR compilation errors: typically this is caused by a mathematical error in the `forward` function. To help here, you can log the shapes of the input and intermediate values: `std.log.info("x: {}", .{x})`, and put similar print statements in the Python code.

You can also consider splitting a big layer into smaller parts. Since our code only explicitly captures `torch.nn.Module` input/output, you may need to modify the Python script to add some extra tensors to the dictionary with example input/output of a specific function.

## General tips

* Porting models can be hard, especially if the original code is messy, has poor comments, behaves differently on different input shapes, or has unused code paths. Start by identifying parts of the Python code which are **unused**. It is common in research code that some code paths were written for one paper, but didn't get used in subsequent papers.
* ZML offers a few PyTorch-specific helpers in `zml.torch`; those operators are offered to help you port models, but in general they may have weird APIs. If you're lucky and the code you are porting has comments indicating "tags", e.g. "C,W,H", of tensors, you can port this to actual tensor attributes using `x.withTags(.{.c, .w, .h})`, and use those tags (e.g. `.c`) to refer to axes instead of offsets. E.g. in PyTorch: `x.sum(0) # reduce over channel axis` becomes `x.sum(.c)`. More on this topic in ["Working with tensors"](../tutorials/working_with_tensors.md).
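
As noted under "Testing an individual layer", the collector only captures `torch.nn.Module` inputs and outputs, so tensors flowing through a plain function have to be recorded by hand. Below is a minimal sketch of one way to do that. The `capture` decorator, the layer name, and the `<name>.in.<i>` / `<name>.out.<i>` key scheme are assumptions made for illustration and are not part of `zml_utils`; check the keys in your own `.pt` file and mirror whatever scheme they use.

```python
import functools


def _snapshot(value):
    # torch.Tensor exposes detach/cpu/clone; any other value is stored as-is.
    if hasattr(value, "detach"):
        return value.detach().cpu().clone()
    return value


def capture(activations, name):
    """Hypothetical helper (not part of zml_utils): record a function's
    positional inputs and its outputs into the `activations` dict."""

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Snapshot inputs before the call, in case fn mutates them in place.
            inputs = [_snapshot(a) for a in args]
            out = fn(*args, **kwargs)
            outputs = out if isinstance(out, tuple) else (out,)
            # Assumed key scheme: "<name>.in.<i>" and "<name>.out.<i>".
            for i, value in enumerate(inputs):
                activations[f"{name}.in.{i}"] = value
            for i, value in enumerate(outputs):
                activations[f"{name}.out.{i}"] = _snapshot(value)
            return out

        return wrapper

    return decorator


# Usage with plain numbers; torch.Tensor arguments go through the same path.
activations = {}


@capture(activations, "model.layers.0.my_fn")  # hypothetical layer name
def scaled_add(x, y):
    return x + 2 * y


result = scaled_add(1.0, 3.0)
print(sorted(activations))
```

Each call overwrites the entries of the previous one, so call the wrapped function once on the reference input. In a real script you would likely record into a separate dict like this one and merge it into the collector's result before calling `torch.save`.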