Fix CollectionOver scope error in ActivationCollector and clean dead code/comments in zml_utils.py

This commit is contained in:
Tarry Singh 2023-01-10 09:43:03 +00:00
parent 04ad137417
commit 48b671f100

View File

@ -13,13 +13,11 @@
# limitations under the License. # limitations under the License.
import builtins import builtins
import enum import enum
import functools
import inspect import inspect
import logging import logging
import re import re
import torch import torch
from torch import Tensor
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -33,10 +31,6 @@ def log_and_open(file, *args, **kwargs):
builtins.open = log_and_open builtins.open = log_and_open
class CollectionOver(Exception):
pass
class ActivationCollector: class ActivationCollector:
"""Wrap a given torch.nn.Module and collect all its intermediary activations. """Wrap a given torch.nn.Module and collect all its intermediary activations.
@ -52,6 +46,9 @@ class ActivationCollector:
* blacklist_regexes: skip layers matching any of the regexes * blacklist_regexes: skip layers matching any of the regexes
""" """
class CollectionOver(Exception):
pass
def __init__( def __init__(
self, self,
model, model,
@ -77,22 +74,20 @@ class ActivationCollector:
hook = torch.nn.modules.module.register_module_forward_hook( hook = torch.nn.modules.module.register_module_forward_hook(
self.log_activation_hook self.log_activation_hook
) )
# modeling_llama.apply_rotary_pos_emb = my_rot
try: try:
res = self.model(*args, **kwargs) res = self.model(*args, **kwargs)
except ActivationCollector.CollectionOver: except ActivationCollector.CollectionOver:
res = None res = None
finally: finally:
# modeling_llama.apply_rotary_pos_emb = rot
hook.remove() hook.remove()
tensors = {} tensors = {}
# print(module_outs)
for name, outputs, inputs in self.outs.values(): for name, outputs, inputs in self.outs.values():
# Only save first layer for a smaller file. # Only save first layer for a smaller file.
for blacklist in self.blacklist_regexes: for blacklist in self.blacklist_regexes:
if re.match(blacklist, name): if re.match(blacklist, name):
# print("skipping:", name)
continue continue
if name == "": if name == "":
@ -108,8 +103,7 @@ class ActivationCollector:
tensors[f"{name}.out.{idx}"] = out tensors[f"{name}.out.{idx}"] = out
for k, v in tensors.items(): for k, v in tensors.items():
if k.endswith(".out.1"): print(k, "->", v.shape)
print(k, "->", v.shape)
return res, tensors return res, tensors
@ -117,9 +111,8 @@ class ActivationCollector:
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None)) name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
if self.stop_after_first_step and prev_out is not None: if self.stop_after_first_step and prev_out is not None:
print(f"stopping collection cause {name} was already recorded") print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`")
raise ActivationCollector.CollectionOver() raise ActivationCollector.CollectionOver()
return
if prev_out is None: if prev_out is None:
self.count += 1 self.count += 1