Fix CollectionOver scope error in ActivationCollector and clean dead code/comments in zml_utils.py

This commit is contained in:
Tarry Singh 2023-01-10 09:43:03 +00:00
parent 04ad137417
commit 48b671f100

View File

@ -13,13 +13,11 @@
# limitations under the License.
import builtins
import enum
import functools
import inspect
import logging
import re
import torch
from torch import Tensor
log = logging.getLogger(__name__)
@ -33,10 +31,6 @@ def log_and_open(file, *args, **kwargs):
builtins.open = log_and_open
class CollectionOver(Exception):
pass
class ActivationCollector:
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
@ -52,6 +46,9 @@ class ActivationCollector:
* blacklist_regexes: skip layers matching any of the regexes
"""
class CollectionOver(Exception):
pass
def __init__(
self,
model,
@ -77,22 +74,20 @@ class ActivationCollector:
hook = torch.nn.modules.module.register_module_forward_hook(
self.log_activation_hook
)
# modeling_llama.apply_rotary_pos_emb = my_rot
try:
res = self.model(*args, **kwargs)
except ActivationCollector.CollectionOver:
res = None
finally:
# modeling_llama.apply_rotary_pos_emb = rot
hook.remove()
tensors = {}
# print(module_outs)
for name, outputs, inputs in self.outs.values():
# Only save first layer for a smaller file.
for blacklist in self.blacklist_regexes:
if re.match(blacklist, name):
# print("skipping:", name)
continue
if name == "":
@ -108,8 +103,7 @@ class ActivationCollector:
tensors[f"{name}.out.{idx}"] = out
for k, v in tensors.items():
if k.endswith(".out.1"):
print(k, "->", v.shape)
print(k, "->", v.shape)
return res, tensors
@ -117,9 +111,8 @@ class ActivationCollector:
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
if self.stop_after_first_step and prev_out is not None:
print(f"stopping collection cause {name} was already recorded")
print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`")
raise ActivationCollector.CollectionOver()
return
if prev_out is None:
self.count += 1