Fix CollectionOver scope error in ActivationCollector and clean dead code/comments in zml_utils.py
This commit is contained in:
parent
04ad137417
commit
48b671f100
@ -13,13 +13,11 @@
|
||||
# limitations under the License.
|
||||
import builtins
|
||||
import enum
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -33,10 +31,6 @@ def log_and_open(file, *args, **kwargs):
|
||||
|
||||
builtins.open = log_and_open
|
||||
|
||||
|
||||
class CollectionOver(Exception):
|
||||
pass
|
||||
|
||||
class ActivationCollector:
|
||||
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
||||
|
||||
@ -52,6 +46,9 @@ class ActivationCollector:
|
||||
* blacklist_regexes: skeep layers matching any of the regexes
|
||||
"""
|
||||
|
||||
class CollectionOver(Exception):
|
||||
pass
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -77,22 +74,20 @@ class ActivationCollector:
|
||||
hook = torch.nn.modules.module.register_module_forward_hook(
|
||||
self.log_activation_hook
|
||||
)
|
||||
# modeling_llama.apply_rotary_pos_emb = my_rot
|
||||
|
||||
try:
|
||||
res = self.model(*args, **kwargs)
|
||||
except ActivationCollector.CollectionOver:
|
||||
res = None
|
||||
finally:
|
||||
# modeling_llama.apply_rotary_pos_emb = rot
|
||||
hook.remove()
|
||||
|
||||
tensors = {}
|
||||
# print(module_outs)
|
||||
|
||||
for name, outputs, inputs in self.outs.values():
|
||||
# Only save first layer for a smaller file.
|
||||
for blacklist in self.blacklist_regexes:
|
||||
if re.match(blacklist, name):
|
||||
# print("skipping:", name)
|
||||
continue
|
||||
|
||||
if name == "":
|
||||
@ -108,7 +103,6 @@ class ActivationCollector:
|
||||
tensors[f"{name}.out.{idx}"] = out
|
||||
|
||||
for k, v in tensors.items():
|
||||
if k.endswith(".out.1"):
|
||||
print(k, "->", v.shape)
|
||||
|
||||
return res, tensors
|
||||
@ -117,9 +111,8 @@ class ActivationCollector:
|
||||
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
|
||||
|
||||
if self.stop_after_first_step and prev_out is not None:
|
||||
print(f"stopping collection cause {name} was already recorded")
|
||||
print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`")
|
||||
raise ActivationCollector.CollectionOver()
|
||||
return
|
||||
|
||||
if prev_out is None:
|
||||
self.count += 1
|
||||
|
||||
Loading…
Reference in New Issue
Block a user