"""Equivariant Transformer layers (``symm_learning.models.transformer.etransformer``)."""
from __future__ import annotations
import logging
from collections.abc import Callable
from math import ceil
from typing import Literal
import torch
import torch.nn.functional as F
from escnn.group import Representation
# from torch.nn import Transformer
from symm_learning.nn.activation import eMultiheadAttention
from symm_learning.nn.linear import eLinear
from symm_learning.nn.normalization import eLayerNorm, eRMSNorm
from symm_learning.representation_theory import direct_sum
logger = logging.getLogger(__name__)
class eTransformerEncoderLayer(torch.nn.Module):
r"""Equivariant Transformer encoder layer with the same API as ``torch.nn.TransformerEncoderLayer``.
    Applies :class:`~symm_learning.nn.activation.eMultiheadAttention` followed by an equivariant feed-forward block
    built from :class:`~symm_learning.nn.linear.eLinear` layers and :class:`~symm_learning.nn.normalization.eRMSNorm`
    (or :class:`~symm_learning.nn.normalization.eLayerNorm`), mirroring PyTorch’s ordering
    (pre- or post-norm) while constraining every linear map to commute with the group action.
The layer defines:
.. math::
\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T} \to \mathcal{X}^{T}.
Functional equivariance constraint:
.. math::
\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{x}_{1:T})
= \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{x}_{1:T})
\quad \forall g\in\mathbb{G},
where :math:`\rho_{\mathcal{X}}(g)` acts on the feature/channel axis at every token.
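    Example:
        A minimal usage sketch; the group, representation, and tensor shapes below are
        illustrative choices, not requirements of the layer:

        .. code-block:: python

            import torch
            from escnn.group import CyclicGroup
            from symm_learning.representation_theory import direct_sum

            G = CyclicGroup(2)
            in_rep = direct_sum([G.regular_representation] * 2)  # feature dimension 4
            layer = eTransformerEncoderLayer(in_rep=in_rep, nhead=1, dim_feedforward=16).eval()

            x = torch.randn(8, 10, in_rep.size)  # (B, T, D) since batch_first=True
            y = layer(x)                         # (B, T, D), same representation as the input

            # Equivariance: applying rho(g) to the input features commutes with the layer.
            g = G.sample()
            rho_g = torch.tensor(in_rep(g), dtype=x.dtype)
            assert torch.allclose(layer(x @ rho_g.T), y @ rho_g.T, atol=1e-4)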
"""
__constants__ = ["norm_first"]
def __init__(
self,
in_rep: Representation,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str | Callable[[torch.Tensor], torch.Tensor] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = True,
norm_first: bool = True,
norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm",
bias: bool = True,
device=None,
dtype=None,
init_scheme: str | None = "xavier_uniform",
) -> None:
r"""Create an equivariant Transformer encoder layer.
Args:
in_rep (:class:`~escnn.group.Representation`): Input representation :math:`\rho_{\text{in}}`.
nhead: Number of attention heads.
dim_feedforward: Hidden dimension of the feedforward network.
dropout: Dropout probability.
activation: Activation function (``'relu'`` or ``'gelu'``).
layer_norm_eps: Epsilon for layer normalization.
batch_first: If ``True``, input/output shape is ``(B, T, D)``.
norm_first: If ``True``, apply normalization before attention/feedforward.
norm_module: Normalization layer type (``'layernorm'`` or ``'rmsnorm'``).
bias: Whether to use bias in linear layers.
device: Tensor device.
dtype: Tensor dtype.
init_scheme: Initialization scheme for equivariant layers.
"""
super().__init__()
if dim_feedforward <= 0:
raise ValueError(f"dim_feedforward must be positive, got {dim_feedforward}")
self.in_rep, self.out_rep = in_rep, in_rep
factory_kwargs = {"device": device, "dtype": dtype or torch.get_default_dtype()}
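        # The feed-forward hidden space must itself carry a G-representation: round the requested
        # dim_feedforward up to a multiple of |G| and realize it as copies of the regular
        # representation (this construction assumes a finite group, since it relies on G.order()
        # and the regular representation).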
G = in_rep.group
num_hidden_reps = max(1, ceil(dim_feedforward / G.order()))
self.embedding_rep = direct_sum([G.regular_representation] * num_hidden_reps)
self.hidden_dim = self.embedding_rep.size
self.requested_dim_feedforward = dim_feedforward
self.self_attn = eMultiheadAttention(
in_rep=self.in_rep,
num_heads=nhead,
dropout=dropout,
bias=bias,
batch_first=batch_first,
device=device,
dtype=dtype,
init_scheme=init_scheme,
)
self.linear1 = eLinear(self.in_rep, self.embedding_rep, bias, init_scheme=init_scheme).to(**factory_kwargs)
self.linear2 = eLinear(self.embedding_rep, self.out_rep, bias, init_scheme=init_scheme).to(**factory_kwargs)
self.dropout = torch.nn.Dropout(dropout)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
if norm_module == "layernorm":
norm_cls = eLayerNorm
norm_kwargs = {"bias": bias} | factory_kwargs
raise ValueError("eLayerNorm is numerically unstable. Use eRMSNorm instead for now.")
elif norm_module == "rmsnorm":
norm_cls = eRMSNorm
norm_kwargs = factory_kwargs
else:
raise ValueError(f"norm_module must be 'layernorm' or 'rmsnorm', got {norm_module}")
self.norm1 = norm_cls(self.in_rep, eps=layer_norm_eps, equiv_affine=True, **norm_kwargs)
self.norm2 = norm_cls(self.out_rep, eps=layer_norm_eps, equiv_affine=True, **norm_kwargs)
self.norm_first = norm_first
self.batch_first = batch_first
if isinstance(activation, str):
activation = _get_activation_fn(activation)
self.activation = activation
if init_scheme is not None:
self.reset_parameters(scheme=init_scheme)
def forward(
self,
src: torch.Tensor,
src_mask: torch.Tensor | None = None,
src_key_padding_mask: torch.Tensor | None = None,
is_causal: bool = False,
) -> torch.Tensor:
r"""Pass the input through the equivariant encoder layer.
Args:
src: input sequence of shape ``(T, B, D)`` or ``(B, T, D)`` depending on ``batch_first``,
with last dimension equal to ``in_rep.size``.
src_mask: optional attention mask for the input sequence.
src_key_padding_mask: optional padding mask for the batch.
is_causal: if ``True``, applies a causal mask to the self-attention block.
"""
src_key_padding_mask = F._canonical_mask(
mask=src_key_padding_mask,
mask_name="src_key_padding_mask",
other_type=F._none_or_dtype(src_mask),
other_name="src_mask",
target_type=src.dtype,
)
src_mask = F._canonical_mask(
mask=src_mask,
mask_name="src_mask",
other_type=None,
other_name="",
target_type=src.dtype,
check_other=False,
)
x = src
if self.norm_first:
x = x + self._self_attention_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
x = x + self._feed_forward_block(self.norm2(x))
else:
x = self.norm1(x + self._self_attention_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
x = self.norm2(x + self._feed_forward_block(x))
return x
def _self_attention_block(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
key_padding_mask: torch.Tensor | None = None,
is_causal: bool = False,
) -> torch.Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
)[0]
return self.dropout1(x)
def _feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
@torch.no_grad()
def reset_parameters(self, scheme="xavier_uniform") -> None: # noqa: D102
logger.debug(f"Resetting parameters of {self.__class__.__name__} with scheme: {scheme}")
self.linear1.reset_parameters(scheme)
self.linear2.reset_parameters(scheme)
self.norm1.reset_parameters()
self.norm2.reset_parameters()
# Reset attention layers:
self.self_attn.reset_parameters(scheme)
class eTransformerDecoderLayer(torch.nn.Module):
r"""Equivariant Transformer decoder layer mirroring :class:`torch.nn.TransformerDecoderLayer`.
    Combines an equivariant self-attention block, an equivariant cross-attention block, and the same
    :class:`~symm_learning.nn.linear.eLinear` / :class:`~symm_learning.nn.normalization.eRMSNorm`
    (or :class:`~symm_learning.nn.normalization.eLayerNorm`) feed-forward structure used by the encoder,
    so every submodule commutes with the group action while keeping PyTorch’s runtime logic intact.
The layer defines:
.. math::
\mathbf{f}_{\mathbf{\theta}}: \mathcal{X}^{T_{\mathrm{tgt}}}\times\mathcal{X}^{T_{\mathrm{mem}}}
\to \mathcal{X}^{T_{\mathrm{tgt}}}.
Functional equivariance constraint (assuming ``tgt`` and ``memory`` transform under the same representation):
.. math::
\mathbf{f}_{\mathbf{\theta}}(\rho_{\mathcal{X}}(g)\mathbf{tgt},\rho_{\mathcal{X}}(g)\mathbf{mem})
= \rho_{\mathcal{X}}(g)\,\mathbf{f}_{\mathbf{\theta}}(\mathbf{tgt},\mathbf{mem})
\quad \forall g\in\mathbb{G}.
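    Example:
        A minimal usage sketch; the group, representation, and sequence lengths below are
        illustrative choices:

        .. code-block:: python

            import torch
            from escnn.group import CyclicGroup
            from symm_learning.representation_theory import direct_sum

            G = CyclicGroup(2)
            in_rep = direct_sum([G.regular_representation] * 2)
            layer = eTransformerDecoderLayer(in_rep=in_rep, nhead=1, dim_feedforward=16).eval()

            tgt = torch.randn(8, 3, in_rep.size)  # (B, T_tgt, D), batch_first=True
            mem = torch.randn(8, 5, in_rep.size)  # (B, T_mem, D), same representation as tgt
            out = layer(tgt, mem)                 # (B, T_tgt, D)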
"""
__constants__ = ["norm_first"]
def __init__(
self,
in_rep: Representation,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str | Callable[[torch.Tensor], torch.Tensor] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = True,
norm_first: bool = True,
norm_module: Literal["layernorm", "rmsnorm"] = "rmsnorm",
bias: bool = True,
device=None,
dtype=None,
init_scheme: str | None = "xavier_uniform",
) -> None:
r"""Create an equivariant Transformer decoder layer.
Args:
in_rep (:class:`~escnn.group.Representation`): Input representation :math:`\rho_{\text{in}}`.
nhead: Number of attention heads.
dim_feedforward: Hidden dimension of the feedforward network.
dropout: Dropout probability.
activation: Activation function (``'relu'`` or ``'gelu'``).
layer_norm_eps: Epsilon for layer normalization.
batch_first: If ``True``, input/output shape is ``(B, T, D)``.
norm_first: If ``True``, apply normalization before attention/feedforward.
norm_module: Normalization layer type (``'layernorm'`` or ``'rmsnorm'``).
bias: Whether to use bias in linear layers.
device: Tensor device.
dtype: Tensor dtype.
init_scheme: Initialization scheme for equivariant layers.
"""
super().__init__()
if dim_feedforward <= 0:
raise ValueError(f"dim_feedforward must be positive, got {dim_feedforward}")
self.in_rep, self.out_rep = in_rep, in_rep
factory_kwargs = {"device": device, "dtype": dtype or torch.get_default_dtype()}
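        # Same hidden-space construction as the encoder layer: dim_feedforward is rounded up to a
        # multiple of |G| and realized as copies of the regular representation.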
G = in_rep.group
num_hidden_reps = max(1, ceil(dim_feedforward / G.order()))
self.embedding_rep = direct_sum([G.regular_representation] * num_hidden_reps)
self.hidden_dim = self.embedding_rep.size
self.requested_dim_feedforward = dim_feedforward
self.self_attn = eMultiheadAttention(
in_rep=self.in_rep,
num_heads=nhead,
dropout=dropout,
bias=bias,
batch_first=batch_first,
device=device,
dtype=dtype,
init_scheme=init_scheme,
)
self.cross_attn = eMultiheadAttention(
in_rep=self.in_rep,
num_heads=nhead,
dropout=dropout,
bias=bias,
batch_first=batch_first,
device=device,
dtype=dtype,
init_scheme=init_scheme,
)
self.linear1 = eLinear(self.in_rep, self.embedding_rep, bias, init_scheme=init_scheme).to(**factory_kwargs)
self.dropout = torch.nn.Dropout(dropout)
self.linear2 = eLinear(self.embedding_rep, self.out_rep, bias, init_scheme=init_scheme).to(**factory_kwargs)
self.norm_first = norm_first
if norm_module == "layernorm":
norm_cls = eLayerNorm
norm_kwargs = {"bias": bias} | factory_kwargs
elif norm_module == "rmsnorm":
norm_cls = eRMSNorm
norm_kwargs = factory_kwargs
else:
raise ValueError(f"norm_module must be 'layernorm' or 'rmsnorm', got {norm_module}")
self.norm1 = norm_cls(self.in_rep, eps=layer_norm_eps, equiv_affine=True, **norm_kwargs)
self.norm2 = norm_cls(self.in_rep, eps=layer_norm_eps, equiv_affine=True, **norm_kwargs)
self.norm3 = norm_cls(self.out_rep, eps=layer_norm_eps, equiv_affine=True, **norm_kwargs)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.dropout3 = torch.nn.Dropout(dropout)
if isinstance(activation, str):
activation = _get_activation_fn(activation)
self.activation = activation
self.batch_first = batch_first
if init_scheme is not None:
self.reset_parameters(scheme=init_scheme)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: torch.Tensor | None = None,
memory_mask: torch.Tensor | None = None,
tgt_key_padding_mask: torch.Tensor | None = None,
memory_key_padding_mask: torch.Tensor | None = None,
tgt_is_causal: bool = False,
memory_is_causal: bool = False,
) -> torch.Tensor:
r"""Pass the input through the equivariant decoder layer.
Args:
tgt: target/query tensor of shape ``(T, B, D)`` or ``(B, T, D)`` matching
``batch_first``. The last dimension must equal ``in_rep.size``.
memory: encoder memory tensor of shape ``(S, B, D)`` or ``(B, S, D)``
(same ``batch_first``). We assume this tensor transforms under the
*same representation* as ``tgt``; i.e., it is typically the output
of an equivariant encoder with representation ``in_rep``.
tgt_mask: optional target attention mask (same semantics as PyTorch’s API).
memory_mask: optional memory attention mask.
tgt_key_padding_mask: optional padding mask for the target batch.
memory_key_padding_mask: optional padding mask for the memory batch.
tgt_is_causal: if ``True``, applies a causal mask to the target self-attention.
memory_is_causal: if ``True``, applies a causal mask to the cross-attention.
"""
tgt_key_padding_mask = F._canonical_mask(
mask=tgt_key_padding_mask,
mask_name="tgt_key_padding_mask",
other_type=F._none_or_dtype(tgt_mask),
other_name="tgt_mask",
target_type=tgt.dtype,
)
tgt_mask = F._canonical_mask(
mask=tgt_mask,
mask_name="tgt_mask",
other_type=None,
other_name="",
target_type=tgt.dtype,
check_other=False,
)
memory_key_padding_mask = F._canonical_mask(
mask=memory_key_padding_mask,
mask_name="memory_key_padding_mask",
other_type=F._none_or_dtype(memory_mask),
other_name="memory_mask",
target_type=memory.dtype,
)
memory_mask = F._canonical_mask(
mask=memory_mask,
mask_name="memory_mask",
other_type=None,
other_name="",
target_type=memory.dtype,
check_other=False,
)
x = tgt
if self.norm_first:
x = x + self._self_attention_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
x = x + self._multihead_attention_block(
self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal
)
x = x + self._feed_forward_block(self.norm3(x))
else:
x = self.norm1(x + self._self_attention_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
x = self.norm2(
x + self._multihead_attention_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal)
)
x = self.norm3(x + self._feed_forward_block(x))
return x
def _self_attention_block(
self,
x: torch.Tensor,
attn_mask: torch.Tensor | None = None,
key_padding_mask: torch.Tensor | None = None,
is_causal: bool = False,
) -> torch.Tensor:
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
)[0]
return self.dropout1(x)
def _multihead_attention_block(
self,
x: torch.Tensor,
mem: torch.Tensor,
attn_mask: torch.Tensor | None = None,
key_padding_mask: torch.Tensor | None = None,
is_causal: bool = False,
) -> torch.Tensor:
x = self.cross_attn(
x,
mem,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
is_causal=is_causal,
)[0]
return self.dropout2(x)
def _feed_forward_block(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
@torch.no_grad()
def reset_parameters(self, scheme="xavier_uniform") -> None: # noqa: D102
logger.debug(f"Resetting parameters of {self.__class__.__name__} with scheme: {scheme}")
# Reset equivariant linear layers (symm_learning.nn.eLinear)
self.linear1.reset_parameters(scheme)
self.linear2.reset_parameters(scheme)
self.norm1.reset_parameters()
self.norm2.reset_parameters()
self.norm3.reset_parameters()
# Reset attention layers:
self.self_attn.reset_parameters(scheme)
self.cross_attn.reset_parameters(scheme)
@torch.no_grad()
def check_equivariance(
self,
batch_size: int = 4,
tgt_len: int = 3,
mem_len: int = 5,
samples: int = 20,
atol: float = 1e-4,
rtol: float = 1e-4,
) -> None:
"""Quick sanity check ensuring both attention blocks and the full layer are equivariant."""
G = self.in_rep.group
device = next(self.parameters()).device
def act(rep: Representation, g, tensor: torch.Tensor) -> torch.Tensor:
mat = torch.tensor(rep(g), dtype=tensor.dtype, device=device)
return torch.einsum("ij,...j->...i", mat, tensor)
for _ in range(samples):
g = G.sample()
tgt = torch.randn(batch_size, tgt_len, self.in_rep.size, device=device)
mem = torch.randn(batch_size, mem_len, self.in_rep.size, device=device)
g_tgt = act(self.in_rep, g, tgt)
g_mem = act(self.in_rep, g, mem)
sa = self._self_attention_block(tgt)
g_sa = act(self.in_rep, g, sa)
g_sa_exp = self._self_attention_block(g_tgt)
assert torch.allclose(g_sa, g_sa_exp, atol=atol, rtol=rtol), (
f"Self-attention equivarinace failed max error: {torch.max(g_sa - g_sa_exp).item():.3e}"
)
ca = self._multihead_attention_block(tgt, mem)
g_ca = act(self.in_rep, g, ca)
g_ca_exp = self._multihead_attention_block(g_tgt, g_mem)
assert torch.allclose(g_ca, g_ca_exp, atol=atol, rtol=rtol), (
f"Cross-attention equivarinace failed max error: {torch.max(g_ca - g_ca_exp).item():.3e}"
)
out = self(tgt, mem)
g_out = act(self.in_rep, g, out)
g_out_exp = self(g_tgt, g_mem)
assert torch.allclose(g_out, g_out_exp, atol=atol, rtol=rtol), (
f"Transormer decoder equivarinace failed max error: {torch.max(g_ca - g_ca_exp).item():.3e}"
)
def _get_activation_fn(activation: str) -> Callable[[torch.Tensor], torch.Tensor]:
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
raise RuntimeError(f"activation should be relu/gelu, not {activation}")
if __name__ == "__main__":
import logging
import sys
import types
from pathlib import Path
# logging.basicConfig(level=logging.DEBUG)
from escnn.group import CyclicGroup, DihedralGroup, Icosahedral
repo_root = Path(__file__).resolve().parents[3]
test_dir = repo_root / "test"
sys.path.insert(0, str(repo_root))
test_pkg = sys.modules.get("test")
test_paths = [str(path) for path in getattr(test_pkg, "__path__", [])] if test_pkg else []
if str(test_dir) not in test_paths:
test_pkg = types.ModuleType("test")
test_pkg.__path__ = [str(test_dir)]
sys.modules["test"] = test_pkg
from symm_learning.models.transformer.etransformer import eTransformerEncoderLayer
from symm_learning.utils import (
bytes_to_mb,
check_equivariance,
describe_memory,
module_device_memory,
module_memory,
)
from test.utils import benchmark, benchmark_eval_forward
G = CyclicGroup(2)
m = 2
in_rep = direct_sum([G.regular_representation] * m)
encoder_kwargs = dict(
in_rep=in_rep,
nhead=1,
dim_feedforward=in_rep.size * 4,
dropout=0.1,
activation="relu",
norm_first=True,
batch_first=True,
)
etransformer = eTransformerEncoderLayer(**encoder_kwargs)
etransformer.eval() # disable dropout for the test
# describe_memory("transformer encoder", etransformer)
check_equivariance(
lambda x: etransformer._feed_forward_block(x),
in_rep=etransformer.in_rep,
out_rep=etransformer.out_rep,
module_name="feed forward",
)
check_equivariance(
lambda x: etransformer._self_attention_block(x),
in_rep=etransformer.in_rep,
out_rep=etransformer.out_rep,
module_name="self_attention",
)
for depth in [1, 3, 5, 10]:
base_layer = eTransformerEncoderLayer(**encoder_kwargs)
base_layer.reset_parameters()
base_layer.eval()
encoder_stack = torch.nn.TransformerEncoder(
encoder_layer=base_layer, num_layers=depth, enable_nested_tensor=False
)
for layer in encoder_stack.layers:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
encoder_stack.eval()
print(f"\n Testing encoder stack depth={depth} equivariance...")
check_equivariance(
encoder_stack,
input_dim=3,
module_name=f"encoder stack depth={depth}",
atol=1e-4,
rtol=1e-4,
in_rep=in_rep,
out_rep=in_rep,
)
print(f"Encoder stack depth={depth} equivariance test passed")
print("\n\n\nTesting decoder layer equivariance...")
decoder_kwargs = dict(
in_rep=in_rep,
nhead=1,
dim_feedforward=in_rep.size * 2,
dropout=0.0,
activation="relu",
norm_first=True,
batch_first=True,
)
tdecoder = eTransformerDecoderLayer(**decoder_kwargs)
tdecoder.eval()
tdecoder.check_equivariance()
def check_decoder_stack(module: torch.nn.Module, rep: Representation, depth: int, atol=1e-4, rtol=1e-4): # noqa: D103
G = rep.group
def act(rep: Representation, g, tensor: torch.Tensor) -> torch.Tensor:
mat = torch.tensor(rep(g), dtype=tensor.dtype, device=tensor.device)
return torch.einsum("ij,...j->...i", mat, tensor)
B, tgt_len, mem_len = 11, 3, 5
module.eval()
for _ in range(10):
g = G.sample()
tgt = torch.randn(B, tgt_len, rep.size)
mem = torch.randn(B, mem_len, rep.size)
out = module(tgt=tgt, memory=mem)
g_tgt = act(rep, g, tgt)
g_mem = act(rep, g, mem)
g_out = module(tgt=g_tgt, memory=g_mem)
g_out_exp = act(rep, g, out)
assert torch.allclose(g_out, g_out_exp, atol=atol, rtol=rtol), (
f"Decoder stack depth={depth} equivariance failed, max err {(g_out - g_out_exp).abs().max().item():.3e}"
)
for depth in (1, 3, 5, 10):
base_layer = eTransformerDecoderLayer(**decoder_kwargs)
base_layer.reset_parameters()
base_layer.eval()
decoder_stack = torch.nn.TransformerDecoder(decoder_layer=base_layer, num_layers=depth)
for layer in decoder_stack.layers:
if hasattr(layer, "reset_parameters"):
layer.reset_parameters()
decoder_stack.eval()
print(f"\n Testing decoder stack depth={depth} equivariance...")
check_decoder_stack(decoder_stack, in_rep, depth=depth, atol=1e-3, rtol=1e-3)
print(f"Decoder stack depth={depth} equivariance test passed")
print("Decoder equivariance test passed")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_float32_matmul_precision("high")
print(f"\nBenchmarking on device: {device}")
fastpath_prev = torch.backends.mha.get_fastpath_enabled()
torch.backends.mha.set_fastpath_enabled(False)
batch_size = 256
src_len = 16
tgt_len = 16
mem_len = 16
iters = 200
warmup = 50
requested_dim_feedforward = in_rep.size * 4
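    # Match the hidden width the equivariant layers actually use (dim_feedforward rounded up to a
    # multiple of |G|) so the torch.nn baselines are built with the same feed-forward width.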
effective_dim_feedforward = max(1, ceil(requested_dim_feedforward / in_rep.group.order())) * in_rep.group.order()
bench_encoder_kwargs = dict(
in_rep=in_rep,
nhead=1,
dim_feedforward=requested_dim_feedforward,
dropout=0.1,
activation="relu",
norm_first=True,
batch_first=True,
norm_module="rmsnorm",
bias=True,
)
bench_decoder_kwargs = dict(
in_rep=in_rep,
nhead=1,
dim_feedforward=requested_dim_feedforward,
dropout=0.1,
activation="relu",
norm_first=True,
batch_first=True,
norm_module="rmsnorm",
bias=True,
)
src = torch.randn(batch_size, src_len, in_rep.size, device=device)
tgt = torch.randn(batch_size, tgt_len, in_rep.size, device=device)
mem = torch.randn(batch_size, mem_len, in_rep.size, device=device)
encoder_modules = [
("eTransformer Encoder", eTransformerEncoderLayer(**bench_encoder_kwargs).to(device)),
(
"Torch Encoder",
torch.nn.TransformerEncoderLayer(
d_model=in_rep.size,
nhead=bench_encoder_kwargs["nhead"],
dim_feedforward=effective_dim_feedforward,
dropout=bench_encoder_kwargs["dropout"],
activation=bench_encoder_kwargs["activation"],
batch_first=bench_encoder_kwargs["batch_first"],
norm_first=bench_encoder_kwargs["norm_first"],
bias=bench_encoder_kwargs["bias"],
).to(device),
),
]
decoder_modules = [
("eTransformer Decoder", eTransformerDecoderLayer(**bench_decoder_kwargs).to(device)),
(
"Torch Decoder",
torch.nn.TransformerDecoderLayer(
d_model=in_rep.size,
nhead=bench_decoder_kwargs["nhead"],
dim_feedforward=effective_dim_feedforward,
dropout=bench_decoder_kwargs["dropout"],
activation=bench_decoder_kwargs["activation"],
batch_first=bench_decoder_kwargs["batch_first"],
norm_first=bench_decoder_kwargs["norm_first"],
bias=bench_decoder_kwargs["bias"],
).to(device),
),
]
def print_benchmark_table(title: str, results: list[dict], name_width: int) -> None: # noqa: D103
header = (
f"{'Layer':<{name_width}} {'Forward eval (ms)':>18} {'Forward (ms)':>18} {'Backward (ms)':>18} "
f"{'Total (ms)':>15} {'Trainable MB':>15} {'Non-train MB':>15} {'Total MB':>12} "
f"{'GPU Alloc MB':>15} {'GPU Peak MB':>15}"
)
separator = "-" * len(header)
print(f"\n{title}")
print(separator)
print(header)
print(separator)
for res in results:
fwd_eval_str = f"{res['fwd_eval_mean']:.3f} +/- {res['fwd_eval_std']:.3f}"
fwd_str = f"{res['fwd_mean']:.3f} +/- {res['fwd_std']:.3f}"
bwd_str = f"{res['bwd_mean']:.3f} +/- {res['bwd_std']:.3f}"
total_mb = res["train_mem"] + res["non_train_mem"]
gpu_alloc_mb = bytes_to_mb(res["gpu_mem"])
gpu_peak_mb = bytes_to_mb(res["gpu_peak"])
print(
f"{res['name']:<{name_width}} {fwd_eval_str:>18} {fwd_str:>18} {bwd_str:>18} "
f"{res['total_time']:>15.3f} {bytes_to_mb(res['train_mem']):>15.3f} "
f"{bytes_to_mb(res['non_train_mem']):>15.3f} {bytes_to_mb(total_mb):>12.3f} "
f"{gpu_alloc_mb:>15.3f} {gpu_peak_mb:>15.3f}"
)
print(separator)
def benchmark_modules( # noqa: D103
title: str,
modules: list[tuple[str, torch.nn.Module]],
forward_fn_builder: Callable[[torch.nn.Module], torch.Tensor],
) -> None:
results = []
for name, module in modules:
def forward_fn(mod=module): # noqa: D103
return forward_fn_builder(mod)
train_mem, non_train_mem = module_memory(module)
gpu_alloc, gpu_peak = module_device_memory(module, device=device)
eval_fwd_mean, eval_fwd_std = benchmark_eval_forward(module, forward_fn, iters=iters, warmup=warmup)
(fwd_mean, fwd_std), (bwd_mean, bwd_std) = benchmark(module, forward_fn, iters=iters, warmup=warmup)
results.append(
{
"name": name,
"fwd_eval_mean": eval_fwd_mean,
"fwd_eval_std": eval_fwd_std,
"fwd_mean": fwd_mean,
"fwd_std": fwd_std,
"bwd_mean": bwd_mean,
"bwd_std": bwd_std,
"total_time": fwd_mean + bwd_mean,
"train_mem": train_mem,
"non_train_mem": non_train_mem,
"gpu_mem": gpu_alloc,
"gpu_peak": gpu_peak,
}
)
name_width = max(22, max(len(res["name"]) for res in results) + 2)
print_benchmark_table(title, results, name_width)
benchmark_modules(
title=f"Encoder layer benchmark per batch={batch_size}, seq_len={src_len}",
modules=encoder_modules,
forward_fn_builder=lambda mod: mod(src),
)
benchmark_modules(
title=f"Decoder layer benchmark per batch={batch_size}, tgt_len={tgt_len}, mem_len={mem_len}",
modules=decoder_modules,
forward_fn_builder=lambda mod: mod(tgt, mem),
)
torch.backends.mha.set_fastpath_enabled(fastpath_prev)