# Source code for symm_learning.nn.module
from __future__ import annotations
import torch
[docs]
class eModule(torch.nn.Module):
    """Lightweight base class for equivariant modules.

    This base centralizes lifecycle behavior related to cache validity and mode transitions:
    :meth:`invalidate_cache` is called after :meth:`train`, after ``_apply`` and after
    ``_load_from_state_dict``.

    Subclasses are expected to:

    - define ``in_rep`` and ``out_rep`` when ``requires_reps=True``
    - optionally override ``invalidate_cache``
    """

    # When True, accessing an undefined ``in_rep``/``out_rep`` raises a descriptive error.
    requires_reps: bool = True

    def __getattr__(self, name: str):  # noqa: D105
        # This hook only runs after normal attribute lookup has failed. The check lives
        # here (and not in ``__getattribute__``) because ``torch.nn.Module`` defines
        # ``__getattr__``: Python discards any AttributeError raised by
        # ``__getattribute__`` and falls back to ``__getattr__``, so a descriptive
        # message raised from ``__getattribute__`` would never reach the caller.
        try:
            # Resolves registered parameters, buffers and submodules.
            return super().__getattr__(name)
        except AttributeError as exc:
            if name in ("in_rep", "out_rep") and getattr(self, "requires_reps", True):
                raise AttributeError(
                    f"{self.__class__.__name__} did not define `{name}`. "
                    "Equivariant modules are expected to define `in_rep` and `out_rep` in the main constructor "
                    "(`__init__`)."
                ) from exc
            raise

    def invalidate_cache(self) -> None:
        """Clear derived cached tensors so they are recomputed on next use.

        The base implementation is a no-op; subclasses holding cached tensors override it.
        """

    def train(self, mode: bool = True):  # noqa: D102
        result = super().train(mode)
        self.invalidate_cache()
        return result

    def _apply(self, fn, *args, **kwargs):
        # Forward any extra arguments untouched: newer PyTorch releases call
        # ``_apply(fn, recurse=...)`` (e.g. from ``Module.to_empty``), and a fixed
        # ``(self, fn)`` signature would reject that call with a TypeError.
        result = super()._apply(fn, *args, **kwargs)
        self.invalidate_cache()
        return result

    def _load_from_state_dict(self, *args, **kwargs):  # noqa: D102
        # Use the recursive load path itself so submodules are covered when a parent
        # calls `load_state_dict(...)`.
        result = super()._load_from_state_dict(*args, **kwargs)
        self.invalidate_cache()
        return result