Fix CollectionOver scope error in ActivationCollector and clean dead code/comments in zml_utils.py
This commit is contained in:
parent
04ad137417
commit
48b671f100
@ -13,13 +13,11 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import builtins
|
import builtins
|
||||||
import enum
|
import enum
|
||||||
import functools
|
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -33,10 +31,6 @@ def log_and_open(file, *args, **kwargs):
|
|||||||
|
|
||||||
builtins.open = log_and_open
|
builtins.open = log_and_open
|
||||||
|
|
||||||
|
|
||||||
class CollectionOver(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class ActivationCollector:
|
class ActivationCollector:
|
||||||
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
||||||
|
|
||||||
@ -52,6 +46,9 @@ class ActivationCollector:
|
|||||||
* blacklist_regexes: skeep layers matching any of the regexes
|
* blacklist_regexes: skeep layers matching any of the regexes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class CollectionOver(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
@ -77,22 +74,20 @@ class ActivationCollector:
|
|||||||
hook = torch.nn.modules.module.register_module_forward_hook(
|
hook = torch.nn.modules.module.register_module_forward_hook(
|
||||||
self.log_activation_hook
|
self.log_activation_hook
|
||||||
)
|
)
|
||||||
# modeling_llama.apply_rotary_pos_emb = my_rot
|
|
||||||
try:
|
try:
|
||||||
res = self.model(*args, **kwargs)
|
res = self.model(*args, **kwargs)
|
||||||
except ActivationCollector.CollectionOver:
|
except ActivationCollector.CollectionOver:
|
||||||
res = None
|
res = None
|
||||||
finally:
|
finally:
|
||||||
# modeling_llama.apply_rotary_pos_emb = rot
|
|
||||||
hook.remove()
|
hook.remove()
|
||||||
|
|
||||||
tensors = {}
|
tensors = {}
|
||||||
# print(module_outs)
|
|
||||||
for name, outputs, inputs in self.outs.values():
|
for name, outputs, inputs in self.outs.values():
|
||||||
# Only save first layer for a smaller file.
|
# Only save first layer for a smaller file.
|
||||||
for blacklist in self.blacklist_regexes:
|
for blacklist in self.blacklist_regexes:
|
||||||
if re.match(blacklist, name):
|
if re.match(blacklist, name):
|
||||||
# print("skipping:", name)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if name == "":
|
if name == "":
|
||||||
@ -108,7 +103,6 @@ class ActivationCollector:
|
|||||||
tensors[f"{name}.out.{idx}"] = out
|
tensors[f"{name}.out.{idx}"] = out
|
||||||
|
|
||||||
for k, v in tensors.items():
|
for k, v in tensors.items():
|
||||||
if k.endswith(".out.1"):
|
|
||||||
print(k, "->", v.shape)
|
print(k, "->", v.shape)
|
||||||
|
|
||||||
return res, tensors
|
return res, tensors
|
||||||
@ -117,9 +111,8 @@ class ActivationCollector:
|
|||||||
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
|
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
|
||||||
|
|
||||||
if self.stop_after_first_step and prev_out is not None:
|
if self.stop_after_first_step and prev_out is not None:
|
||||||
print(f"stopping collection cause {name} was already recorded")
|
print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`")
|
||||||
raise ActivationCollector.CollectionOver()
|
raise ActivationCollector.CollectionOver()
|
||||||
return
|
|
||||||
|
|
||||||
if prev_out is None:
|
if prev_out is None:
|
||||||
self.count += 1
|
self.count += 1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user