diff --git a/tools/zml_utils.py b/tools/zml_utils.py index 10e3216..a0053e7 100644 --- a/tools/zml_utils.py +++ b/tools/zml_utils.py @@ -13,13 +13,11 @@ # limitations under the License. import builtins import enum -import functools import inspect import logging import re import torch -from torch import Tensor log = logging.getLogger(__name__) @@ -33,10 +31,6 @@ def log_and_open(file, *args, **kwargs): builtins.open = log_and_open - -class CollectionOver(Exception): - pass - class ActivationCollector: """Wrap a given torch.nn.Module and collect all its intermediary activations. @@ -52,6 +46,9 @@ class ActivationCollector: * blacklist_regexes: skeep layers matching any of the regexes """ + class CollectionOver(Exception): + pass + def __init__( self, model, @@ -77,22 +74,20 @@ class ActivationCollector: hook = torch.nn.modules.module.register_module_forward_hook( self.log_activation_hook ) - # modeling_llama.apply_rotary_pos_emb = my_rot + try: res = self.model(*args, **kwargs) except ActivationCollector.CollectionOver: res = None finally: - # modeling_llama.apply_rotary_pos_emb = rot hook.remove() tensors = {} - # print(module_outs) + for name, outputs, inputs in self.outs.values(): # Only save first layer for a smaller file. for blacklist in self.blacklist_regexes: if re.match(blacklist, name): - # print("skipping:", name) continue if name == "": @@ -108,8 +103,7 @@ class ActivationCollector: tensors[f"{name}.out.{idx}"] = out for k, v in tensors.items(): - if k.endswith(".out.1"): - print(k, "->", v.shape) + print(k, "->", v.shape) return res, tensors @@ -117,9 +111,8 @@ class ActivationCollector: name, prev_out, prev_in = self.outs.get(id(module), (None, None, None)) if self.stop_after_first_step and prev_out is not None: - print(f"stopping collection cause {name} was already recorded") + print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`") raise ActivationCollector.CollectionOver() - return if prev_out is None: self.count += 1