Source code for symm_learning.nn.activations

import torch
import torch.nn.functional as F
from escnn.nn import EquivariantModule


class Mish(EquivariantModule):
    """Pointwise Mish activation for geometric tensors.

    Applies :func:`torch.nn.functional.mish` element-wise to the tensor wrapped
    by the input and re-wraps the result with ``out_type``, which is the same
    object as ``in_type`` (the activation does not change the field type).

    Args:
        in_type: Field type of the input. Every representation in
            ``in_type.representations`` must list ``"pointwise"`` in its
            ``supported_nonlinearities``.

    Raises:
        AssertionError: If any representation of ``in_type`` does not support
            pointwise non-linearities.
    """

    def __init__(self, in_type):
        super().__init__()
        self.in_type = in_type
        self.out_type = in_type

        # An element-wise non-linearity is only valid on representations that
        # declare support for "pointwise" non-linearities.
        for r in in_type.representations:
            assert "pointwise" in r.supported_nonlinearities, (
                'Error! Representation "{}" does not support "pointwise" non-linearity'.format(r.name)
            )

    def forward(self, x):
        """Apply Mish to ``x.tensor`` and return a tensor of type ``out_type``.

        Args:
            x: Input geometric tensor; its ``type`` must equal ``in_type``.

        Returns:
            The activated tensor wrapped with ``out_type``, carrying the same
            ``coords`` as the input.

        Raises:
            AssertionError: If ``x.type`` differs from ``in_type``.
        """
        # Without this check a tensor of a different field type would be
        # silently relabelled as ``out_type``.
        assert x.type == self.in_type, (
            "Error! The input type {} does not match the expected type {}".format(x.type, self.in_type)
        )
        # Forward the coordinates so point-based inputs do not lose them when
        # passing through the activation.
        return self.out_type(F.mish(x.tensor), x.coords)

    def export(self) -> torch.nn.Mish:
        """Export this module to a plain :class:`torch.nn.Mish`."""
        return torch.nn.Mish()

    def evaluate_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged, since the activation is element-wise."""
        return input_shape