259 lines
7.8 KiB
Python
259 lines
7.8 KiB
Python
# Copyright 2024 ZML
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import builtins
|
|
import enum
|
|
import inspect
|
|
import logging
|
|
import re
|
|
|
|
import torch
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
class ActivationCollector:
|
|
"""Wrap a given torch.nn.Module and collect all its intermediary activations.
|
|
|
|
Usage:
|
|
|
|
collector = zml_utils.ActivationCollector(model, **collection_config)
|
|
model_output, activations = collector(model_input)
|
|
zml_utils.save_with_confirmation("activations.pt", activations)
|
|
|
|
Args:
|
|
* max_layers: stop collecting activations after this many collected
|
|
* stop_after_first_step: if a layer is called twice (typically for generative model), stop immediately
|
|
* blacklist_regexes: skeep layers matching any of the regexes
|
|
"""
|
|
|
|
class CollectionOver(Exception):
|
|
pass
|
|
|
|
def __init__(
|
|
self,
|
|
model,
|
|
*,
|
|
max_layers: int = -1,
|
|
stop_after_first_step: bool = False,
|
|
blacklist_regexes: list[str] = [r".*\.(\d\d+)\.", r".*\.[1-9]\."],
|
|
):
|
|
self.model = model
|
|
self.max_layers = max_layers
|
|
self.stop_after_first_step = stop_after_first_step
|
|
self.blacklist_regexes = blacklist_regexes
|
|
self.count = 0
|
|
mods = named_modules(model)
|
|
self.outs = {id(module): (name, None, None) for name, module in mods}
|
|
|
|
def __call__(self, *args, **kwargs):
|
|
"""Call the wrapped model with the given arguments.
|
|
|
|
Return the model output and the activations.
|
|
"""
|
|
self.count = 0
|
|
hook = torch.nn.modules.module.register_module_forward_hook(
|
|
self.log_activation_hook
|
|
)
|
|
|
|
try:
|
|
res = self.model(*args, **kwargs)
|
|
except ActivationCollector.CollectionOver:
|
|
res = None
|
|
finally:
|
|
hook.remove()
|
|
|
|
tensors = {}
|
|
|
|
for name, outputs, inputs in self.outs.values():
|
|
# Only save first layer for a smaller file.
|
|
for blacklist in self.blacklist_regexes:
|
|
if re.match(blacklist, name):
|
|
continue
|
|
|
|
if name == "":
|
|
# Skip the softmax output
|
|
continue
|
|
if (outputs, inputs) == (None, None):
|
|
# print(f"no inputs/outputs for {name}")
|
|
continue
|
|
for idx, inp in enumerate(inputs):
|
|
tensors[f"{name}.in.{idx}"] = inp
|
|
|
|
for idx, out in enumerate(outputs):
|
|
tensors[f"{name}.out.{idx}"] = out
|
|
|
|
for k, v in tensors.items():
|
|
print(k, "->", v.shape)
|
|
|
|
return res, tensors
|
|
|
|
def log_activation_hook(self, module, input, out) -> None:
|
|
name, prev_out, prev_in = self.outs.get(id(module), (None, None, None))
|
|
|
|
if self.stop_after_first_step and prev_out is not None:
|
|
print(f"stopping collection cause {name} was already recorded or stop_after_first_step was set to `True`")
|
|
raise ActivationCollector.CollectionOver()
|
|
|
|
if prev_out is None:
|
|
self.count += 1
|
|
|
|
if name is None:
|
|
print("err: unknown module", module.__class__)
|
|
breakpoint()
|
|
return
|
|
|
|
assert out is not None
|
|
outs = [o.detach().cpu() for o in _flatten(out)]
|
|
inputs = [i.detach().cpu() for i in _flatten(input)]
|
|
|
|
kwargs = inspect.stack()[1].frame.f_locals["kwargs"]
|
|
extra_inputs = [i.detach().cpu() for i in _flatten(kwargs)]
|
|
|
|
self.outs[id(module)] = (name, outs, inputs + extra_inputs)
|
|
if 0 < self.max_layers < self.count:
|
|
print(f"stopping collection cause we got {self.count} activations already")
|
|
raise ActivationCollector.CollectionOver()
|
|
|
|
|
|
def save_with_confirmation(filename: str, tensors: dict):
|
|
"""Regular torch.save with a CLI confirmation."""
|
|
sizes = [(v.numel() * v.dtype.itemsize, k) for k, v in tensors.items()]
|
|
sizes.sort()
|
|
disk_size = sum(s for s, k in sizes)
|
|
|
|
GB = 1024**3
|
|
print(f"About to write {disk_size/ GB:.3f}GB at {filename}. Biggest tensors:")
|
|
print(sizes[-20:])
|
|
print("Enter `c` to continue, `q` to quit.")
|
|
|
|
breakpoint()
|
|
torch.save(tensors, filename)
|
|
|
|
|
|
def _flatten(out):
|
|
if out is None:
|
|
return []
|
|
elif isinstance(out, torch.Tensor):
|
|
outs = [out]
|
|
elif isinstance(out, tuple):
|
|
outs = []
|
|
for x in out:
|
|
outs.extend(_flatten(x))
|
|
elif isinstance(out, dict):
|
|
outs = []
|
|
for x in out.values():
|
|
outs.extend(_flatten(x))
|
|
else:
|
|
outs = []
|
|
return outs
|
|
|
|
|
|
def named_modules(model):
|
|
if hasattr(model, "named_modules"):
|
|
return model.named_modules()
|
|
|
|
else:
|
|
root_modules = [
|
|
(k, v) for k, v in model.__dict__.items() if isinstance(v, torch.nn.Module)
|
|
]
|
|
for root, mod in root_modules:
|
|
for k, v in mod.named_modules():
|
|
if k:
|
|
yield f"{root}.{k}", v
|
|
else:
|
|
yield root, v
|
|
|
|
|
|
def read_layer_config(model: torch.nn.Module) -> dict:
|
|
layer_config = {}
|
|
|
|
def _append_node_config(node, prefix: str) -> None:
|
|
for k, v in node.__dict__.items():
|
|
# Skip special members. In particular all children module and tensors
|
|
# will be hidden in special dicts `_parameters` and `_modules`
|
|
if k.startswith("_"):
|
|
continue
|
|
# All modules have a "training" flag
|
|
if k in ("training", "init_fn"):
|
|
continue
|
|
if v is None:
|
|
continue
|
|
|
|
if not is_basic_type(v):
|
|
log.warning(f"Skipping layer config {k}={v!r}")
|
|
continue
|
|
layer_config[prefix + k] = v
|
|
|
|
_append_node_config(model, "")
|
|
for name, node in find_children(model, torch.nn.Module):
|
|
_append_node_config(node, name + ".")
|
|
|
|
return layer_config
|
|
|
|
|
|
def find_children(model: torch.nn.Module, t: type, layer_filter: str = "") -> list:
|
|
queue = list(model._modules.items())
|
|
modules = []
|
|
while queue:
|
|
name, node = queue.pop()
|
|
if node is None:
|
|
continue
|
|
if layer_filter and not re.match(layer_filter, name):
|
|
continue
|
|
if isinstance(node, t):
|
|
modules.append((name, node))
|
|
for child_name, child_node in node._modules.items():
|
|
queue.append((".".join((name, child_name)), child_node))
|
|
|
|
return modules
|
|
|
|
|
|
def is_basic_type(value) -> bool:
|
|
if isinstance(value, int):
|
|
return True
|
|
if isinstance(value, float):
|
|
return True
|
|
if isinstance(value, bool):
|
|
return True
|
|
if isinstance(value, enum.Enum):
|
|
return True
|
|
if isinstance(value, tuple) and len(value) == 1:
|
|
return True
|
|
if isinstance(value, str) and len(value) < 8:
|
|
return True
|
|
return False
|
|
|
|
|
|
def pdb_persistent(name, fn, *args, **kwargs):
|
|
"""Cache that can survive through a PDB restart.
|
|
|
|
Useful when debugging to avoid reloading models all the time.
|
|
"""
|
|
import sys
|
|
|
|
pdb = sys.modules.get("pdb", None)
|
|
if pdb is None:
|
|
return fn(*args, **kwargs)
|
|
|
|
if not hasattr(pdb, "__cache__"):
|
|
setattr(pdb, "__cache__", {})
|
|
|
|
cache = getattr(pdb, "__cache__")
|
|
entry = cache.get(name)
|
|
if entry is not None:
|
|
return entry
|
|
|
|
res = fn(*args, **kwargs)
|
|
cache[name] = res
|
|
return res
|