InvariantBias

class InvariantBias(in_rep)

Bases: Module

Module parameterizing a learnable \(\mathbb{G}\)-invariant bias.

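For a representation space \(\mathcal{X}\), this module constrains the bias \(\mathbf{b}\in\mathcal{X}\) to satisfy \(\rho_{\mathcal{X}}(g)\mathbf{b}=\mathbf{b}\) for all \(g\in\mathbb{G}\). The vectors fixed by every \(\rho_{\mathcal{X}}(g)\) are exactly those in the trivial-irrep (isotypic) subspace. Consequently, in the irrep-spectral basis only the trivial-irrep coordinates carry free parameters, and all other coordinates are zero.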
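The following is a minimal NumPy sketch of the invariance constraint. It assumes an orthogonal representation of a small finite group, given as explicit matrices. To find the fixed subspace it uses group averaging (the Reynolds operator), which is a swapped-in technique and not this module's irrep-spectral decomposition. `invariant_basis` and `bias_dof` here are hypothetical stand-ins. For the cyclic shift action of C3 on \(\mathbb{R}^3\), only one direction is fixed, so a single coefficient parameterizes every invariant bias:

```python
import numpy as np

def invariant_basis(group_matrices):
    """Hypothetical helper (not part of this API): orthonormal basis of the
    subspace fixed by every matrix of an orthogonal finite-group representation."""
    # Group average = orthogonal projector onto the fixed subspace (symmetric for orthogonal reps).
    P = sum(group_matrices) / len(group_matrices)
    eigvals, eigvecs = np.linalg.eigh(P)
    # Projector eigenvalues are 0 or 1; keep the eigenvectors with eigenvalue 1.
    return eigvecs[:, eigvals > 0.5]

# C3 acting on R^3 by cyclic shifts of the coordinates.
shift = np.roll(np.eye(3), 1, axis=0)
group = [np.eye(3), shift, shift @ shift]

Q = invariant_basis(group)       # one column: the trivial-irrep direction
bias_dof = np.array([0.7])       # hypothetical learnable coefficient
b = Q @ bias_dof                 # invariant bias in the original basis

for g in group:
    assert np.allclose(g @ b, b)  # rho(g) b = b for every group element
print(Q.shape)
```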

If the input representation does not contain the trivial irrep (i.e., it has no invariant subspace), the module behaves as the identity function.
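Here is a toy check of the identity case. It assumes a representation of C2 on \(\mathbb{R}^2\) in which the non-identity element acts as \(-I\), which gives two copies of the sign irrep and no trivial irrep. The group average is the zero matrix, so the only invariant bias is \(\mathbf{0}\) and adding it leaves the input unchanged. This is an illustration in NumPy, not the module's code path:

```python
import numpy as np

# C2 acting on R^2 by -I: two copies of the sign irrep, no trivial irrep.
group = [np.eye(2), -np.eye(2)]
P = sum(group) / len(group)   # group-average projector onto the fixed subspace
assert np.allclose(P, 0)      # fixed subspace is {0}: nothing to learn

x = np.array([[1.0, -2.0], [3.0, 4.0]])
b = P @ np.ones(2)            # any candidate bias projects to the zero vector
y = x + b
assert np.array_equal(y, x)   # the bias layer reduces to the identity
```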

Note

Runtime behavior depends on the mode. In training mode (model.train()), the invariant bias is recomputed on every forward pass. In evaluation mode (model.eval()), the expanded invariant bias is cached and reused until bias_dof changes or invalidate_cache() is called, which avoids recomputing it on every call. With the cache active, the forward pass is the same computation as the standard symmetry-agnostic bias addition input + b with a fixed b.
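The caching behaviour described in this note can be sketched with a small NumPy class. `CachedBiasSketch` and its `expand_calls` counter are hypothetical. They are written only to mirror the documented names (`train`, `bias`, `invalidate_cache`, `expand_bias`) and are not the library's implementation. Unlike the description above, this toy does not detect changes to `bias_dof` on its own, so it calls `invalidate_cache()` explicitly:

```python
import numpy as np

class CachedBiasSketch:
    """Hypothetical toy mirroring the train/eval caching behaviour (not the library code)."""

    def __init__(self, basis, bias_dof):
        self.basis = basis          # (n, k) basis of the trivial-irrep subspace
        self.bias_dof = bias_dof    # (k,) learnable coefficients
        self.training = True
        self._cache = None
        self.expand_calls = 0       # counts how often the bias is rebuilt

    def expand_bias(self):
        self.expand_calls += 1
        return self.basis @ self.bias_dof

    def train(self, mode=True):
        self.training = mode
        self._cache = None          # drop any stale cache on a mode switch
        return self

    def eval(self):
        return self.train(False)

    def invalidate_cache(self):
        self._cache = None

    @property
    def bias(self):
        if self.training:
            return self.expand_bias()          # recompute on every access
        if self._cache is None:
            self._cache = self.expand_bias()   # build once in eval mode
        return self._cache

    def forward(self, x):
        return x + self.bias                   # plain bias add over the last axis


layer = CachedBiasSketch(np.ones((3, 1)) / np.sqrt(3), np.array([0.5]))
x = np.zeros((4, 3))
layer.forward(x)
layer.forward(x)
print(layer.expand_calls)   # recomputed on each training-mode call

layer.eval()
layer.forward(x)
layer.forward(x)
print(layer.expand_calls)   # eval mode builds the cache only once

layer.bias_dof = np.array([1.0])
layer.invalidate_cache()    # required here: this toy never detects dof changes itself
layer.forward(x)
print(layer.expand_calls)
```

Recomputing in training mode keeps the bias in sync with optimizer updates to the parameters. Caching in evaluation mode reduces inference to a single vector addition.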

in_rep

Representation defining the symmetry action on \(\mathcal{X}\).

Type:

Representation

out_rep

Same as in_rep (bias acts in the same space).

Type:

Representation

has_bias

True iff the trivial irrep is present and a learnable invariant bias exists.

Type:

bool

bias_dof

Learnable trivial-subspace coefficients (present only if has_bias=True).

Type:

Parameter

Construct the invariant bias module.

Parameters:

in_rep (Representation) – Representation \(\rho_{\text{in}}\) of the input space (same as output space).

property bias

The invariant bias, recomputed in training mode and cached in evaluation mode.

expand_bias()

Expand the learnable parameters into the invariant bias in the original basis.

expand_bias_spectral_basis()

Return the invariant bias expressed in the irrep-spectral basis.
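This worked example shows the two bases for C2 swapping the coordinates of \(\mathbb{R}^2\), a representation that decomposes into the trivial and sign irreps. It assumes an orthogonal change-of-basis matrix `Q` whose columns are the irrep directions. `Q` and `theta` are illustrative names, not attributes of this class. The spectral bias has a free entry only in the trivial coordinate, and mapping it through `Q` gives the bias in the original basis:

```python
import numpy as np

# C2 acting on R^2 by swapping coordinates: rho = trivial (+) sign.
swap = np.array([[0.0, 1.0], [1.0, 0.0]])
# Illustrative orthogonal change of basis: columns are the trivial and sign irrep directions.
Q = np.array([[1.0, 1.0], [1.0, -1.0]]) / np.sqrt(2)

theta = 0.3                              # the single trivial-irrep degree of freedom
b_spectral = np.array([theta, 0.0])      # non-trivial coordinates are fixed at zero
b = Q @ b_spectral                       # bias in the original basis

assert np.allclose(swap @ b, b)          # invariant under the swap
assert np.allclose(Q.T @ b, b_spectral)  # Q is orthogonal, so Q.T maps back
print(b)
```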

forward(input)

Apply the invariant bias.

Parameters:

input (Tensor) – Tensor whose last dimension equals in_rep.size.

Returns:

Output tensor with the same shape as input.

Return type:

Tensor
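Because the bias is invariant, adding it commutes with the group action: \(\rho(g)(\mathbf{x}+\mathbf{b})=\rho(g)\mathbf{x}+\mathbf{b}\). The following NumPy sketch checks this shape and equivariance contract. It assumes a constant stand-in bias for the C3 cyclic shift action and a group that acts on the last axis. The `forward` function below is hypothetical, not this method:

```python
import numpy as np

# C3 acting on R^3 by cyclic shifts; a constant vector is invariant under it.
shift = np.roll(np.eye(3), 1, axis=0)
b = np.full(3, 0.25)            # stand-in for an already-expanded invariant bias

def forward(x):
    """Hypothetical bias add: broadcasts b over every leading dimension."""
    return x + b

x = np.random.default_rng(0).normal(size=(8, 5, 3))   # last dim = representation size
y = forward(x)
assert y.shape == x.shape       # output keeps the input's shape

# Equivariance: applying rho(g) along the last axis commutes with the bias add.
assert np.allclose(forward(x @ shift.T), forward(x) @ shift.T)
```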

invalidate_cache()

Clear the cached bias so that it is recomputed on next use.

Return type:

None

load_state_dict(state_dict, strict=True)

Load parameters and invalidate the cached expanded bias.

Parameters:

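strict (bool) – Whether the keys in state_dict must exactly match the module's keys, as in torch.nn.Module.load_state_dict.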

refresh_eval_cache()

Ensure the eval-mode cache is populated.

Return type:

None

reset_parameters(scheme='zeros')

Initialize the invariant bias degrees of freedom.

train(mode=True)

Switch between training and evaluation modes. In training mode the bias is recomputed on every forward pass; in evaluation mode the expanded bias is cached.

Parameters:

mode (bool) – True for training mode, False for evaluation mode.