diff --git a/RELEASE.md b/RELEASE.md index 2050f60..5f8b9c8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,7 +1,16 @@ # Version X.X.X (only via git install) +# Version 0.2.1 (latest) + **Functionality** +- **augmentation module**: + - add Circular augmenter in compose module. + - add phase swap augmentation in functional module. +- **models module**: + - models can be initialized with a custom seed. + - add EEGConformer. + - add xEEGNet. - **dataloading module**: - EEGDataset now supports EEG with multiple labels (1 per window partition). - **ssl module**: @@ -14,7 +23,7 @@ * reduced unittest overall time -# Version 0.2.0 (latest) +# Version 0.2.0 **Functionality** diff --git a/docs/selfeeg.augmentation.rst b/docs/selfeeg.augmentation.rst index 97a3e25..dbb4679 100644 --- a/docs/selfeeg.augmentation.rst +++ b/docs/selfeeg.augmentation.rst @@ -58,6 +58,7 @@ Functions moving_avg permutation_signal permute_channels + phase_swap random_FT_phase random_slope_scale scaling diff --git a/docs/selfeeg.models.rst b/docs/selfeeg.models.rst index 66ec893..470ae7c 100644 --- a/docs/selfeeg.models.rst +++ b/docs/selfeeg.models.rst @@ -40,6 +40,7 @@ Classes :template: classtemplate.rst DeepConvNetEncoder + EEGConformerEncoder EEGInceptionEncoder EEGNetEncoder EEGSymEncoder @@ -49,6 +50,7 @@ Classes StagerNetEncoder STNetEncoder TinySleepNetEncoder + xEEGNetEncoder models.zoo module @@ -65,6 +67,7 @@ Classes ATCNet DeepConvNet + EEGConformer EEGInception EEGNet EEGSym @@ -74,3 +77,4 @@ Classes StagerNet STNet TinySleepNet + xEEGNet diff --git a/selfeeg/augmentation/__init__.py b/selfeeg/augmentation/__init__.py index f232afd..2c869cb 100644 --- a/selfeeg/augmentation/__init__.py +++ b/selfeeg/augmentation/__init__.py @@ -25,6 +25,7 @@ moving_avg, permutation_signal, permute_channels, + phase_swap, random_FT_phase, random_slope_scale, scaling, diff --git a/selfeeg/augmentation/functional.py b/selfeeg/augmentation/functional.py index d0a1dc3..cf3ee04 100755 --- a/selfeeg/augmentation/functional.py +++ b/selfeeg/augmentation/functional.py @@ -35,6 +35,7 @@ "moving_avg", "permutation_signal", "permute_channels", + "phase_swap", "random_FT_phase", "random_slope_scale", "scaling", @@ -114,7 +115,7 @@ def shift_horizontal( batch_equal: bool = True, ) -> ArrayLike: """ - shifts temporally the elements of the ArrayLike object. + Shifts temporally the elements of the ArrayLike object. Shift is applied along the last dimension. The empty elements at beginning or the ending part @@ -438,6 +439,90 @@ def shift_frequency( return _shift_frequency(x, shift_freq, Fs, forward, random_shift, batch_equal, t, h) +def phase_swap(x: ArrayLike) -> ArrayLike: + """ + Apply the phase swap data augmentation to the ArrayLike object. + + The phase swap data augmentation consists in merging the amplitude + and phase components of biosignals from different sources to help + the model learn their coupling. + Specifically, the amplitude and phase of two randomly selected EEG samples + are extracted using the Fourier transform. + New samples are then generated by applying the inverse Fourier transform, + combining the amplitude from one sample with the phase from the other. + See the following paper for more information [phaseswap]_. + + Parameters + ---------- + x : ArrayLike + A 3-dimensional torch tensor or numpy array. + The last two dimensions must refer to the EEG (Channels x Samples). + + Returns + ------- + x: ArrayLike + The augmented version of the input Tensor or Array. 
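+
+    Example
+    -------
+    A minimal usage sketch on a random batch (shapes are illustrative; the
+    output keeps the input shape):
+
+    >>> import torch
+    >>> import selfeeg.augmentation as aug
+    >>> x = torch.randn(16, 32, 1024)
+    >>> xaug = aug.phase_swap(x)
+    >>> print(xaug.shape) # should return torch.Size([16, 32, 1024])
+    >>> print(torch.isnan(xaug).sum()) # should return 0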
+ + Note + ---- + `Phase swap` ignores the class of each sample. + + + References + ---------- + .. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo. + "Boosting Generalization in Bio-signal Classification by + Learning the Phase-Amplitude Coupling". DAGM GCPR (2020). + + """ + + Ndim = len(x.shape) + if Ndim != 3: + raise ValueError("x must be a 3-dimensional array or tensor") + + N = x.shape[0] + + if isinstance(x, torch.Tensor): + # Compute fft, module and phase + xfft = torch.fft.fft(x) + amplitude = xfft.abs() + phase = xfft.angle() + x_aug = torch.clone(xfft) + + # Random shuffle indeces + idx_shuffle = torch.randperm(N).to(device=x.device) + idx_shuffle_1 = idx_shuffle[: (N // 2)] + idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2] + + # Apply phase swap + x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * torch.exp(1j * phase[idx_shuffle_2]) + x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * torch.exp(1j * phase[idx_shuffle_1]) + + # Reconstruct the signal + x_aug = (torch.fft.ifft(x_aug)).real.to(device=x.device) + + else: + + xfft = np.fft.fft(x) + amplitude = np.abs(xfft) + phase = np.angle(xfft) + x_aug = np.copy(xfft) + + # Random shuffle indeces + idx_shuffle = np.random.permutation(N) + idx_shuffle_1 = idx_shuffle[: (N // 2)] + idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2] + + # Apply phase swap + x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * np.exp(1j * phase[idx_shuffle_2]) + x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * np.exp(1j * phase[idx_shuffle_1]) + + # Reconstruct the signal + x_aug = (np.fft.ifft(x_aug)).real + + return x_aug + + def flip_vertical(x: ArrayLike) -> ArrayLike: """ changes the sign of all the elements of the input. @@ -456,9 +541,8 @@ def flip_vertical(x: ArrayLike) -> ArrayLike: ------- >>> import torch >>> import selfeeg.augmentation as aug - >>> x = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024)) - >>> xaug = aug.flip_vertical(x) - >>> print(torch.equal(xaug, x*(-1))) # should return True + >>> x = torch.randn(64,32,512) + >>> xaug = aug.phase_swap(x) """ x_flip = x * (-1) diff --git a/selfeeg/models/__init__.py b/selfeeg/models/__init__.py index da1f618..5a336c8 100644 --- a/selfeeg/models/__init__.py +++ b/selfeeg/models/__init__.py @@ -9,6 +9,7 @@ from .encoders import ( BasicBlock1, DeepConvNetEncoder, + EEGConformerEncoder, EEGInceptionEncoder, EEGNetEncoder, EEGSymEncoder, @@ -17,11 +18,13 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, + xEEGNetEncoder, ) from .zoo import ( ATCNet, DeepConvNet, + EEGConformer, EEGInception, EEGNet, EEGSym, @@ -31,4 +34,5 @@ StagerNet, STNet, TinySleepNet, + xEEGNet, ) diff --git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index 0924ffb..5b382dd 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -1,3 +1,5 @@ +from itertools import chain, combinations +from scipy.signal import firwin import torch import torch.nn as nn import torch.nn.functional as F @@ -9,10 +11,12 @@ SeparableConv2d, FilterBank, ) +from ..utils.utils import _reset_seed __all__ = [ "BasicBlock1", "DeepConvNetEncoder", + "EEGConformerEncoder", "EEGInceptionEncoder", "EEGNetEncoder", "EEGSymEncoder", @@ -22,6 +26,7 @@ "StagerNetEncoder", "STNetEncoder", "TinySleepNetEncoder", + "xEEGNetEncoder", ] @@ -81,6 +86,11 @@ class EEGNetEncoder(nn.Module): If None no constraint will be applied. Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note @@ -102,24 +112,26 @@ class EEGNetEncoder(nn.Module): def __init__( self, - Chans, - kernLength=64, - dropRate=0.5, - F1=8, - D=2, - F2=16, - dropType="Dropout", - ELUalpha=1, - pool1=4, - pool2=8, - separable_kernel=16, - depthwise_max_norm=1.0, + Chans: int, + kernLength: int = 64, + dropRate: float = 0.5, + F1: int = 8, + D: int = 2, + F2: int = 16, + dropType: str = "Dropout", + ELUalpha: int = 1, + pool1: int = 4, + pool2: int = 8, + separable_kernel: int = 16, + depthwise_max_norm: float = 1.0, + seed: int = None, ): if dropType not in ["SpatialDropout2D", "Dropout"]: - raise ValueError("implemented Dropout types are" " 'Dropout' or 'SpatialDropout2D '") + raise ValueError("Dropout types are 'Dropout' or 'SpatialDropout2D'") super(EEGNetEncoder, self).__init__() + _reset_seed(seed) # Layer 1 self.conv1 = nn.Conv2d(1, F1, (1, kernLength), padding="same", bias=False) @@ -154,26 +166,20 @@ def forward(self, x): """ :meta private: """ - # Layer 1 x = torch.unsqueeze(x, 1) x = self.conv1(x) x = self.batchnorm1(x) - - # Layer 2 x = self.conv2(x) x = self.batchnorm2(x) x = self.elu2(x) x = self.pooling2(x) x = self.drop2(x) - - # Layer 3 x = self.sepconv3(x) x = self.batchnorm3(x) x = self.elu3(x) x = self.pooling3(x) x = self.drop3(x) x = self.flatten3(x) - return x @@ -226,6 +232,11 @@ class DeepConvNetEncoder(nn.Module): The dropout percentage in range [0,1]. Default = 0.5 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -241,18 +252,20 @@ class DeepConvNetEncoder(nn.Module): def __init__( self, - Chans, - kernLength=10, - F=25, - Pool=3, - stride=3, - max_norm=None, - batch_momentum=0.1, - ELUalpha=1, - dropRate=0.5, + Chans: int, + kernLength: int = 10, + F: int = 25, + Pool: int = 3, + stride: int = 3, + max_norm: float = None, + batch_momentum: float = 0.1, + ELUalpha: int = 1, + dropRate: float = 0.5, + seed: int = None, ): super(DeepConvNetEncoder, self).__init__() + _reset_seed(seed) self.conv1 = ConstrainedConv2d( 1, F, (1, kernLength), padding="valid", stride=(1, 1), max_norm=max_norm @@ -372,6 +385,11 @@ class EEGInceptionEncoder(nn.Module): If None no constraint will be included. Default = 1. + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -387,19 +405,22 @@ class EEGInceptionEncoder(nn.Module): def __init__( self, - Chans, - F1=8, - D=2, - kernel_size=64, - pool=4, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - batch_momentum=0.1, - max_depth_norm=1.0, + Chans: int, + F1: int = 8, + D: int = 2, + kernel_size: int = 64, + pool: int = 4, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + batch_momentum: float = 0.1, + max_depth_norm: float = 1.0, + seed: int = None, ): super(EEGInceptionEncoder, self).__init__() + _reset_seed(seed) + self.inc1 = nn.Sequential( nn.Conv2d(1, F1, (1, kernel_size), padding="same", bias=bias), nn.BatchNorm2d(F1, momentum=batch_momentum), @@ -545,6 +566,11 @@ class TinySleepNetEncoder(nn.Module): Hidden size of the lstm block. Default = 128 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Example ------- @@ -560,17 +586,20 @@ class TinySleepNetEncoder(nn.Module): def __init__( self, - Chans, - Fs, - F=128, - kernlength=8, - pool=8, - dropRate=0.5, - batch_momentum=0.1, - hidden_lstm=128, + Chans: int, + Fs: int, + F: int = 128, + kernlength: int = 8, + pool: int = 8, + dropRate: float = 0.5, + batch_momentum: float = 0.1, + hidden_lstm: int = 128, + seed: int = None, ): super(TinySleepNetEncoder, self).__init__() + _reset_seed(seed) + self.conv1 = nn.Conv1d(Chans, F, int(Fs // 2), stride=int(Fs // 16), padding="valid") self.BN1 = nn.BatchNorm1d(F, momentum=batch_momentum) self.Relu = nn.ReLU() @@ -611,14 +640,10 @@ def forward(self, x): x = self.conv4(x) x = self.BN4(x) x = self.Relu(x) - x = self.pool2(x) x = self.drop2(x) - x = torch.permute(x, (2, 0, 1)) - out, (ht, ct) = self.lstm1(x) - return ht[-1] @@ -649,6 +674,11 @@ class StagerNetEncoder(nn.Module): The temporal pooling kernel size. Default = 4 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -662,9 +692,10 @@ class StagerNetEncoder(nn.Module): """ - def __init__(self, Chans, kernLength=64, F=8, Pool=16): + def __init__(self, Chans, kernLength: int = 64, F: int = 8, Pool: int = 16, seed: int = None): super(StagerNetEncoder, self).__init__() + _reset_seed(seed) self.conv1 = nn.Conv2d(1, Chans, (Chans, 1), stride=(1, 1), bias=True) self.conv2 = nn.Conv2d(1, F, (1, kernLength), stride=(1, 1), padding="same") @@ -676,7 +707,6 @@ def __init__(self, Chans, kernLength=64, F=8, Pool=16): def forward(self, x): """ :meta private: - """ x = torch.unsqueeze(x, 1) x = self.conv1(x) @@ -720,6 +750,11 @@ class ShallowNetEncoder(nn.Module): Dropout probability. Must be in [0,1) Default= 0.2 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -740,9 +775,13 @@ class ShallowNetEncoder(nn.Module): """ - def __init__(self, Chans, F=40, K1=25, Pool=75, p=0.2): + def __init__( + self, Chans, F: int = 40, K1: int = 25, Pool: int = 75, p: float = 0.2, seed: int = None + ): super(ShallowNetEncoder, self).__init__() + _reset_seed(seed) + self.conv1 = nn.Conv2d(1, F, (1, K1), stride=(1, 1)) self.conv2 = nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)) self.batch1 = nn.BatchNorm2d(F) @@ -774,7 +813,7 @@ class BasicBlock1(nn.Module): :meta private: """ - def __init__(self, inplanes, planes, kernLength=7, stride=1): + def __init__(self, inplanes: int, planes: int, kernLength: int = 7, stride: int = 1): super(BasicBlock1, self).__init__() self.stride = stride @@ -820,18 +859,13 @@ def forward(self, x): :meta private: """ residual = self.downsample(x) - # print('residual: ', residual.shape) out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - # print('out 1: ', out.shape) out = self.conv2(out) out = self.bn2(out) - # print('out 2: ', out.shape) - out += residual out = self.relu(out) - return out @@ -894,7 +928,12 @@ class ResNet1DEncoder(nn.Module): 2. nn.BatchNorm2d() 3. nn.ReLU() - Default = None + Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note ---- @@ -915,7 +954,7 @@ class ResNet1DEncoder(nn.Module): def __init__( self, - Chans, + Chans: int, block: nn.Module = BasicBlock1, Layers: "list of 4 ints" = [2, 2, 2, 2], inplane: int = 16, @@ -923,9 +962,12 @@ def __init__( addConnection: bool = False, preBlock: nn.Module = None, postBlock: nn.Module = None, + seed: int = None, ): super(ResNet1DEncoder, self).__init__() + _reset_seed(seed) + self.inplane = inplane self.kernLength = kernLength self.connection = addConnection @@ -1121,6 +1163,11 @@ class STNetEncoder(nn.Module): If True, adds a learnable bias to the convolutional layers. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -1134,8 +1181,17 @@ class STNetEncoder(nn.Module): """ - def __init__(self, Samples, F=256, kernlength=5, dropRate=0.5, bias=True): + def __init__( + self, + Samples, + F: int = 256, + kernlength: int = 5, + dropRate: float = 0.5, + bias: bool = True, + seed: int = None, + ): super(STNetEncoder, self).__init__() + _reset_seed(seed) self.conv1 = nn.Conv2d( Samples, F, kernel_size=kernlength - 2, stride=1, padding="same", bias=bias @@ -1396,6 +1452,11 @@ class EEGSymEncoder(nn.Module): Currently not implemented, will be added in future releases. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -1411,21 +1472,22 @@ class EEGSymEncoder(nn.Module): def __init__( self, - Chans, - Samples, - Fs, - scales_time=(500, 250, 125), - lateral_chans=3, - first_left=True, - F=8, - pool=2, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - residual=True, + Chans: int, + Samples: int, + Fs: float, + scales_time: tuple = (500, 250, 125), + lateral_chans: int = 3, + first_left: bool = True, + F: int = 8, + pool: int = 2, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + residual: bool = True, + seed: int = None, ): - super(EEGSymEncoder, self).__init__() + _reset_seed(seed) self.input_samples = int(Samples * Fs / 1000) self.scales_samples = [int(s * Fs / 1000) for s in scales_time] @@ -1771,6 +1833,11 @@ class FBCNetEncoder(nn.Module): The maximum norm each filter can have in the depthwise block. If None no constraint will be included. + Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + Default = None Example @@ -1802,6 +1869,7 @@ def __init__( TemporalStride: int = 4, batch_momentum: float = 0.1, depthwise_max_norm=None, + seed: int = None, ): super(FBCNetEncoder, self).__init__() self.FilterBands = FilterBands @@ -1858,3 +1926,391 @@ def forward(self, x): x = self.TMB(x) x = torch.flatten(x, 1) return x + + +class EEGConformerEncoder(nn.Module): + """ + Pytorch implementation of the EEGConformer Encoder. + + See EEGConformer for some references. + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). + + Parameters + ---------- + Chans: int + The number of EEG channels. + F: int, optional + The number of output filters in the temporal convolution layer. + + Default = 40 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 25 + Pool: int, optional + The temporal pooling kernel size. 
+ + Default = 75 + stride_pool: int, optional + The temporal pooling stride. + + Default = 15 + d_model: int, optional + The embedding size. It is the number of expected features in the input of + the transformer encoder layer. + + Default = 40 + nlayers: int, optional + The number of transformer encoder layers. + + Default = 6 + nheads: int, optional + The number of heads in the multi-head attention layers. + + Default = 10 + dim_feedforward: int, optional + The dimension of the feedforward hidden layer in the transformer encoder. + + Default = 160 + activation_transformer: str or Callabel, optional + The activation function in the transformer encoder. See the PyTorch + TransformerEncoderLayer documentation for accepted inputs. + + Default = "gelu" + p: float, optional + Dropout probability in the tokenizer. Must be in [0,1) + + Default= 0.2 + p_transformer: float, optional + Dropout probability in the transformer encoder. Must be in [0,1) + + Default= 0.5 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4,8,512) + >>> mdl = models.EEGConformerEncoder(8) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 224]) + >>> print(torch.isnan(out).sum()) # shoud return 0 + + """ + + def __init__( + self, + Chans, + F: int = 40, + K1: int = 25, + Pool: int = 75, + stride_pool: int = 15, + d_model: int = 40, + nlayers: int = 6, + nheads: int = 10, + dim_feedforward: int = 160, + activation_transformer: str or Callable = "gelu", + p: float = 0.2, + p_transformer: float = 0.5, + seed: int = None, + ): + + super(EEGConformerEncoder, self).__init__() + _reset_seed(seed) + + self.tokenizer = nn.Sequential( + nn.Conv2d(1, F, (1, K1), stride=(1, 1)), + nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)), + nn.BatchNorm2d(F), + nn.AvgPool2d((1, Pool), stride=(1, stride_pool)), + nn.Dropout(p), + ) + self.projection = nn.Conv2d(F, d_model, (1, 1)) + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model, + nhead=nheads, + dim_feedforward=dim_feedforward, + dropout=p_transformer, + activation="gelu", + batch_first=True, + ), + num_layers=nlayers, + ) + + def forward(self, x): + """ + :meta private: + """ + x = torch.unsqueeze(x, 1) + x = self.tokenizer(x) + x = x.squeeze(2) + x = torch.permute(x, [0, 2, 1]) + x = self.transformer(x) + x = torch.permute(x, [0, 2, 1]) + return x + + +class xEEGNetEncoder(nn.Module): + """ + Pytorch implementation of the xEEGNet Encoder. + + See xEEGNet for some references. + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). + + Parameters + ---------- + Chans: int + The number of EEG channels. + Fs: int + The sampling rate of the EEG signal in Hz. + It is used to initialize the weights of the filters. + Must be specified even if `random_temporal_filter` is False. + F1: int, optional + The number of output filters in the temporal convolution layer. + + Default = 7 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 125 + F2: int, optional + The number of output filters in the spatial convolution layer. + + Default = 7 + Pool: int, optional + Kernel size for temporal pooling. + + Default = 75 + p: float, optional + Dropout probability in [0,1) + + Default = 0.2 + random_temporal_filter: bool, optional + If True, initialize the temporal filter weights randomly. 
+ Otherwise, use a passband FIR filter. + + Default = False + freeze_temporal: int, optional + Number of forward steps to keep the temporal layer frozen. + + Default = 1e12 + spatial_depthwise: bool, optional + Whether to apply a depthwise layer in the spatial convolution. + + Default = True + log_activation_base: str, optional + Base for the logarithmic activation after pooling. + Options: "e" (natural log), "10" (logarithm base 10), "dB" (decibel scale). + + Default = "dB" + norm_type: str, optional + The type of normalization. Expected values are "batch" or "instance". + + Default = "batchnorm" + global_pooling: bool, optional + If True, apply global average pooling instead of flattening. + + Default = True + bias: list[int, int], optional + A 2-element list with boolean values. + If the first element is True, a bias will be added to the temporal + convolutional layer. + If the second element is True, a bias will be added to the spatial + convolutional layer. + + Default = [False, False] + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4, 8, 512) + >>> mdl = models.xEEGNetEncoder(8, 125) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 7]) + + """ + + def __init__( + self, + Chans: int, + Fs: int, + F1: int = 7, + K1: int = 125, + F2: int = 7, + Pool: int = 75, + p: float = 0.2, + random_temporal_filter=False, + freeze_temporal: int = 1e12, + spatial_depthwise: bool = True, + log_activation_base: str = "dB", + norm_type: str = "batchnorm", + global_pooling=True, + bias: list[int, int] = [False, False], + seed: int = None, + ): + + super(xEEGNetEncoder, self).__init__() + + # Set seed before initializing layers + self.custom_seed = seed + _reset_seed(seed) + + self.Fs = Fs + self.chans = Chans + self.freeze_temporal = freeze_temporal + self.bias_1conv = bias[0] + self.bias_2conv = bias[1] + self.do_global_pooling = global_pooling + if self.Fs <= 0 and not (random_temporal_filter): + raise ValueError( + "to properly initialize non random temporal fir filters, " + "Fs (sampling rate) must be given" + ) + + if random_temporal_filter: + self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1, 1), bias=self.bias_1conv) + else: + self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1, 1), bias=self.bias_1conv) + self._initialize_custom_temporal_filter(self.custom_seed) + + if spatial_depthwise: + self.conv2 = nn.Conv2d( + F1, F2, (Chans, 1), stride=(1, 1), groups=F1, bias=self.bias_2conv + ) + else: + self.conv2 = nn.Conv2d(F1, F2, (Chans, 1), stride=(1, 1), bias=self.bias_2conv) + + if "batch" in norm_type.casefold(): + self.batch1 = nn.BatchNorm2d(F2, affine=True) + elif "instance" in norm_type.casefold(): + self.batch1 = nn.InstanceNorm2d(F2) + else: + raise ValueError("normalization layer type can be 'batchnorm' or 'instancenorm'") + + if log_activation_base in ["e", torch.e]: + self.log_activation = lambda x: torch.log(torch.clamp(x, 1e-7, 1e4)) + elif log_activation_base in ["10", 10]: + self.log_activation = lambda x: torch.log10(torch.clamp(x, 1e-7, 1e4)) + elif log_activation_base in ["db", "dB"]: + self.log_activation = lambda x: 10 * torch.log10(torch.clamp(x, 1e-7, 1e4)) + else: + raise ValueError( + "allowed activation base are 'e' for torch.log, " + "'10' for torch.log10, and 'dB' for 10*torch.log10" + ) + + if not self.do_global_pooling: + self.pool2 = 
nn.AvgPool2d((1, Pool), stride=(1, max(1, Pool // 5))) + else: + self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) + + self.drop1 = nn.Dropout(p) + self.flatten = nn.Flatten() + + def forward(self, x): + """ + :meta private: + """ + if self.freeze_temporal: + self.freeze_temporal -= 1 + self.conv1.requires_grad_(False) + else: + self.conv1.requires_grad_(True) + x = torch.unsqueeze(x, 1) + x = self.conv1(x) + x = self.conv2(x) + x = self.batch1(x) + x = torch.square(x) + if self.do_global_pooling: + x = self.global_pooling(x) + else: + x = self.pool2(x) + x = self.log_activation(x) + x = self.drop1(x) + x = self.flatten(x) + return x + + @torch.no_grad() + def _get_spatial_softmax(self): + """ + :meta private: + """ + return torch.softmax(self.conv2.weight, -2) + + @torch.no_grad() + def _get_spatial_zero(self): + """ + :meta private: + """ + return self.conv2.weight - torch.sum(self.conv2.weight, -2, keepdim=True) + + @torch.no_grad() + def _initialize_custom_temporal_filter(self, seed=None): + """ + :meta private: + """ + _reset_seed(seed) + if self.conv1.weight.shape[-1] >= 75: + bands = ( + (0.5, 4.0), # delta + (4.0, 8.0), # theta + (8.0, 12.0), # alpha + (12.0, 16.0), # beta1 + (16.0, 20.0), # beta2 + (20.0, 28.0), # beta3 + (28.0, 45.0), # gamma + ) + else: + bands = ((0.5, 8.0), (8.0, 16.0), (16.0, 28.0), (28.0, 45.0)) + F, KernLength = self.conv1.weight.shape[0], self.conv1.weight.shape[-1] + comb = self._powerset(bands) + for i in range(min(F, len(comb))): # if F <= len(comb): + filt_coeff = firwin( + KernLength, self._merge_tuples(comb[i]), pass_zero=False, fs=self.Fs + ) + self.conv1.weight.data[i, 0, 0] = torch.from_numpy(filt_coeff) + + @torch.no_grad() + def _powerset(self, s): + """ + :meta private: + """ + return tuple(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1))) + + @torch.no_grad() + def _merge_tuples(self, tuples): + """ + :meta private: + """ + merged = [num for tup in tuples for num in tup] + merged = sorted(merged) + if len(merged) > 2: + new_merged = [merged[0]] + for i in range(1, len(merged) - 2, 2): + if merged[i] != merged[i + 1]: + new_merged.append(merged[i]) + new_merged.append(merged[i + 1]) + new_merged.append(merged[-1]) + return sorted(new_merged) + return merged + + @torch.no_grad() + def _combinatorial_op(self, N, k): + """ + :meta private: + """ + return int((math.factorial(N)) / (math.factorial(k) * math.factorial(N - k))) diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 3f79d2b..2c9b94a 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -2,11 +2,11 @@ import torch.nn as nn import torch.nn.functional as F from .layers import ConstrainedDense, ConstrainedConv1d, ConstrainedConv2d - from .encoders import ( BasicBlock1, DeepConvNetEncoder, EEGInceptionEncoder, + EEGConformerEncoder, EEGNetEncoder, EEGSymEncoder, FBCNetEncoder, @@ -15,11 +15,14 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, + xEEGNetEncoder, ) +from ..utils.utils import _reset_seed __all__ = [ "ATCNet", "DeepConvNet", + "EEGConformer", "EEGInception", "EEGNet", "EEGSym", @@ -29,6 +32,7 @@ "StagerNet", "STNet", "TinySleepNet", + "xEEGNet", ] @@ -103,6 +107,11 @@ class EEGNet(nn.Module): applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note ---- @@ -129,22 +138,23 @@ class EEGNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - kernLength=64, - dropRate=0.5, - F1=8, - D=2, - F2=16, - norm_rate=0.25, - dropType="Dropout", - ELUalpha=1, - pool1=4, - pool2=8, - separable_kernel=16, - depthwise_max_norm=1.0, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + kernLength: int = 64, + dropRate: float = 0.5, + F1: int = 8, + D: int = 2, + F2: int = 16, + norm_rate: int = 0.25, + dropType: str = "Dropout", + ELUalpha: int = 1, + pool1: int = 4, + pool2: int = 8, + separable_kernel: int = 16, + depthwise_max_norm: float = 1.0, + return_logits: bool = True, + seed: int = None, ): super(EEGNet, self).__init__() @@ -164,7 +174,10 @@ def __init__( pool2, separable_kernel, depthwise_max_norm, + seed, ) + + _reset_seed(seed) self.Dense = ConstrainedDense( F2 * (Samples // int(pool1 * pool2)), 1 if nb_classes <= 2 else nb_classes, @@ -253,6 +266,11 @@ class DeepConvNet(nn.Module): the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -281,32 +299,35 @@ class DeepConvNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - kernLength=10, - F=25, - Pool=3, - stride=3, - max_norm=None, - batch_momentum=0.1, - ELUalpha=1, - dropRate=0.5, - max_dense_norm=None, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + kernLength: int = 10, + F: int = 25, + Pool: int = 3, + stride: int = 3, + max_norm: int = None, + batch_momentum: float = 0.1, + ELUalpha: int = 1, + dropRate: float = 0.5, + max_dense_norm: float = None, + return_logits: bool = True, + seed: int = None, ): super(DeepConvNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits self.encoder = DeepConvNetEncoder( - Chans, kernLength, F, Pool, stride, max_norm, batch_momentum, ELUalpha, dropRate + Chans, kernLength, F, Pool, stride, max_norm, batch_momentum, ELUalpha, dropRate, seed ) k = kernLength Dense_input = [Samples] * 8 for i in range(4): Dense_input[i * 2] = Dense_input[i * 2 - 1] - k + 1 Dense_input[i * 2 + 1] = (Dense_input[i * 2] - Pool) // stride + 1 + + _reset_seed(seed) self.Dense = ConstrainedDense( F * 8 * Dense_input[-1], 1 if nb_classes <= 2 else nb_classes, max_norm=max_dense_norm ) @@ -392,6 +413,11 @@ class EEGInception(nn.Module): the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None References ---------- @@ -416,19 +442,20 @@ class EEGInception(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - F1=8, - D=2, - kernel_size=64, - pool=4, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - batch_momentum=0.1, - max_depth_norm=1.0, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + F1: int = 8, + D: int = 2, + kernel_size: int = 64, + pool: int = 4, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + batch_momentum: float = 0.1, + max_depth_norm: float = 1.0, + return_logits: bool = True, + seed: int = None, ): super(EEGInception, self).__init__() self.nb_classes = nb_classes @@ -444,7 +471,10 @@ def __init__( bias, batch_momentum, max_depth_norm, + seed, ) + + _reset_seed(seed) self.Dense = nn.Linear( int((F1 * 3) / 4) * int((Samples // (pool * (int(pool // 2) ** 3)))), 1 if nb_classes <= 2 else nb_classes, @@ -523,6 +553,11 @@ class TinySleepNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -547,26 +582,28 @@ class TinySleepNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Fs, - F=128, - kernlength=8, - pool=8, - dropRate=0.5, - batch_momentum=0.1, - max_dense_norm=2.0, - hidden_lstm=128, - return_logits=True, + nb_classes: int, + Chans: int, + Fs: int, + F: int = 128, + kernlength: int = 8, + pool: int = 8, + dropRate: float = 0.5, + batch_momentum: float = 0.1, + max_dense_norm: float = 2.0, + hidden_lstm: int = 128, + return_logits: bool = True, + seed: int = None, ): super(TinySleepNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits self.encoder = TinySleepNetEncoder( - Chans, Fs, F, kernlength, pool, dropRate, batch_momentum, hidden_lstm + Chans, Fs, F, kernlength, pool, dropRate, batch_momentum, hidden_lstm, seed ) + _reset_seed(seed) self.drop3 = nn.Dropout1d(dropRate) self.Dense = ConstrainedDense( hidden_lstm, 1 if nb_classes <= 2 else nb_classes, max_norm=max_dense_norm @@ -629,6 +666,11 @@ class StagerNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -650,21 +692,23 @@ class StagerNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - dropRate=0.5, - kernLength=64, - F=8, - Pool=16, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + dropRate: float = 0.5, + kernLength: int = 64, + F: int = 8, + Pool: int = 16, + return_logits: bool = True, + seed: int = None, ): super(StagerNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = StagerNetEncoder(Chans, kernLength=kernLength, F=F, Pool=Pool) + self.encoder = StagerNetEncoder(Chans, kernLength=kernLength, F=F, Pool=Pool, seed=seed) + _reset_seed(seed) self.drop = nn.Dropout(p=dropRate) self.Dense = nn.Linear( Chans * F * (int((int((Samples - Pool) / Pool + 1) - Pool) / Pool + 1)), @@ -728,6 +772,11 @@ class ShallowNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. 
Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -754,13 +803,26 @@ class ShallowNet(nn.Module): """ - def __init__(self, nb_classes, Chans, Samples, F=40, K1=25, Pool=75, p=0.2, return_logits=True): + def __init__( + self, + nb_classes: int, + Chans: int, + Samples: int, + F: int = 40, + K1: int = 25, + Pool: int = 75, + p: float = 0.2, + return_logits: bool = True, + seed: int = None, + ): super(ShallowNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = ShallowNetEncoder(Chans, F=F, K1=K1, Pool=Pool, p=p) + self.encoder = ShallowNetEncoder(Chans, F=F, K1=K1, Pool=Pool, p=p, seed=seed) + + _reset_seed(seed) self.Dense = nn.Linear( F * ((Samples - K1 + 1 - Pool) // 15 + 1), 1 if nb_classes <= 2 else nb_classes ) @@ -864,6 +926,11 @@ class ResNet1D(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -889,9 +956,9 @@ class ResNet1D(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, + nb_classes: int, + Chans: int, + Samples: int, block: nn.Module = BasicBlock1, Layers: "list of 4 int" = [2, 2, 2, 2], inplane: int = 16, @@ -900,7 +967,8 @@ def __init__( preBlock: nn.Module = None, postBlock: nn.Module = None, classifier: nn.Module = None, - return_logits=True, + return_logits: bool = True, + seed: int = None, ): super(ResNet1D, self).__init__() @@ -916,8 +984,11 @@ def __init__( addConnection=addConnection, preBlock=preBlock, postBlock=postBlock, + seed=seed, ) + # Classifier + _reset_seed(seed) if classifier is None: if addConnection: out1 = int((Samples + 2 * (int(kernLength // 2)) - kernLength) // 2) + 1 @@ -999,6 +1070,11 @@ class STNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1019,23 +1095,25 @@ class STNet(nn.Module): def __init__( self, - nb_classes, - Samples, - grid_size=9, - F=256, - kernlength=5, - dropRate=0.5, - bias=True, - dense_size=1024, - return_logits=True, + nb_classes: int, + Samples: int, + grid_size: int = 9, + F: int = 256, + kernlength: int = 5, + dropRate: float = 0.5, + bias: bool = True, + dense_size: int = 1024, + return_logits: bool = True, + seed: int = None, ): super(STNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = STNetEncoder(Samples, F, kernlength, dropRate, bias) - self.drop3 = nn.Dropout(dropRate) + self.encoder = STNetEncoder(Samples, F, kernlength, dropRate, bias, seed=seed) + _reset_seed(seed) + self.drop3 = nn.Dropout(dropRate) self.Dense = nn.Sequential( nn.Linear(int(F / 16) * (grid_size**2), dense_size), nn.Dropout(dropRate), @@ -1131,6 +1209,11 @@ class EEGSym(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1151,20 +1234,21 @@ class EEGSym(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - Fs, - scales_time=(500, 250, 125), - lateral_chans=3, - first_left=True, - F=8, - pool=2, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - residual=True, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + Fs: int, + scales_time: tuple = (500, 250, 125), + lateral_chans: int = 3, + first_left: bool = True, + F: int = 8, + pool: int = 2, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + residual: bool = True, + return_logits: bool = True, + seed: int = None, ): super(EEGSym, self).__init__() self.nb_classes = nb_classes @@ -1182,7 +1266,10 @@ def __init__( ELUalpha, bias, residual, + seed=seed, ) + + _reset_seed(seed) self.Dense = nn.Linear(int((F * 9) / 2), 1 if nb_classes <= 2 else nb_classes) def forward(self, x): @@ -1304,6 +1391,11 @@ class FBCNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1344,6 +1436,7 @@ def __init__( linear_max_norm: float = None, classifier: nn.Module = None, return_logits: bool = True, + seed: int = None, ): super(FBCNet, self).__init__() @@ -1367,8 +1460,11 @@ def __init__( TemporalStride, batch_momentum, depthwise_max_norm, + seed=seed, ) + # Head + _reset_seed(seed) if classifier is None: self.head = ConstrainedDense( D * FilterBands * TemporalStride, @@ -1622,6 +1718,11 @@ class ATCNet(nn.Module): applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1644,32 +1745,34 @@ class ATCNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - Fs, - num_windows=5, - mha_heads=2, - tcn_depth=2, - F1=16, - D=2, - pool1=None, - pool2=None, - dropRate=0.3, - max_norm=None, - batchMomentum=0.1, - ELUAlpha=1.0, - mha_dropRate=0.5, - tcn_kernLength=4, - tcn_F=32, - tcn_ELUAlpha=0.0, - tcn_dropRate=0.3, - tcn_max_norm=None, - tcn_batchMom=0.1, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + Fs: float, + num_windows: int = 5, + mha_heads: int = 2, + tcn_depth: int = 2, + F1: int = 16, + D: int = 2, + pool1: int = None, + pool2: int = None, + dropRate: float = 0.3, + max_norm: float = None, + batchMomentum: float = 0.1, + ELUAlpha: float = 1.0, + mha_dropRate: float = 0.5, + tcn_kernLength: int = 4, + tcn_F: int = 32, + tcn_ELUAlpha: float = 0.0, + tcn_dropRate: float = 0.3, + tcn_max_norm: float = None, + tcn_batchMom: float = 0.1, + return_logits: bool = True, + seed: int = None, ): super(ATCNet, self).__init__() + _reset_seed(seed) # important for model construction self.return_logits = return_logits @@ -1721,6 +1824,7 @@ def __init__( ) # Construct each Branch + _reset_seed(seed) for i in range(self.num_windows): self.add_multi_head(i) self.add_residual_tcn(i) @@ -1778,3 +1882,352 @@ def forward(self, x): else: x = F.softmax(x, dim=1) return x + + +class EEGConformer(nn.Module): + """ + Pytorch implementation of EEGConformer. + + For more information see the following paper [EEGcon]_ . + The original implementation of EEGconformer can be found here [EEGcongit]_ . 
+ + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). + + Parameters + ---------- + nb_classes: int + The number of classes. If less than 2, a binary classification problem + is considered (output dimensions will be [batch, 1] in this case). + Chans: int + The number of EEG channels. + F: int, optional + The number of output filters in the temporal convolution layer. + + Default = 40 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 25 + Pool: int, optional + The temporal pooling kernel size. + + Default = 75 + stride_pool: int, optional + The temporal pooling stride. + + Default = 15 + d_model: int, optional + The embedding size. It is the number of expected features in the input of + the transformer encoder layer. + + Default = 40 + nlayers: int, optional + The number of transformer encoder layers. + + Default = 6 + nheads: int, optional + The number of heads in the multi-head attention layers. + + Default = 10 + dim_feedforward: int, optional + The dimension of the feedforward hidden layer in the transformer encoder. + + Default = 160 + activation_transformer: str or Callabel, optional + The activation function in the transformer encoder. See the PyTorch + TransformerEncoderLayer documentation for accepted inputs. + + Default = "gelu" + p: float, optional + Dropout probability in the tokenizer. Must be in [0,1) + + Default = 0.2 + p_transformer: float, optional + Dropout probability in the transformer encoder. Must be in [0,1) + + Default = 0.5 + mlp_dim: list, optional + A two-element list indicating the output dimensions of the 2 FC + layers in the final classification head. + + Default = [256, 32] + return_logits: bool, optional + Whether to return the output as logit or probability. + It is suggested to not use False as the pytorch crossentropy loss function + applies the softmax internally. + + Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + References + ---------- + .. [EEGcon] Song et al., EEG Conformer: Convolutional Transformer for EEG Decoding + and Visualization. IEEE TNSRE. 2023. https://doi.org/10.1109/TNSRE.2022.3230250 + .. 
[EEGcongit] https://github.com/eeyhsong/EEG-Conformer + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4,8,512) + >>> mdl = models.EEGConformer(2, 8, 512) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 1]) + + """ + + def __init__( + self, + nb_classes: int, + Chans: int, + Samples: int, + F: int = 40, + K1: int = 25, + Pool: int = 75, + stride_pool: int = 15, + d_model: int = 40, + nlayers: int = 6, + nheads: int = 10, + dim_feedforward: int = 160, + activation_transformer: str or Callable = "gelu", + p: float = 0.2, + p_transformer: float = 0.5, + mlp_dim: list[int, int] = [256, 32], + return_logits: bool = True, + seed: int = None, + ): + + super(EEGConformer, self).__init__() + self.return_logits = return_logits + self.nb_classes = nb_classes + + self.encoder = EEGConformerEncoder( + Chans, + F, + K1, + Pool, + stride_pool, + d_model, + nlayers, + nheads, + dim_feedforward, + activation_transformer, + p, + p_transformer, + seed, + ) + + _reset_seed(seed) + self.MLP = nn.Sequential( + nn.AvgPool1d((Samples - K1 + 1 - Pool) // stride_pool + 1), + nn.Flatten(start_dim=1), + nn.LayerNorm(d_model), + nn.Linear(d_model, mlp_dim[0]), + nn.ELU(), + nn.Dropout(p), + nn.Linear(mlp_dim[0], mlp_dim[1]), + nn.ELU(), + nn.Dropout(p), + nn.Linear(mlp_dim[1], 1 if nb_classes <= 2 else nb_classes), + ) + + def forward(self, x): + """ + :meta private: + """ + x = self.encoder(x) + x = self.MLP(x) + if not (self.return_logits): + if self.nb_classes <= 2: + x = torch.sigmoid(x) + else: + x = F.softmax(x, dim=1) + return x + + +class xEEGNet(nn.Module): + """ + Pytorch implementation of xEEGNet. + + For more information see the following paper [xEEG]_ . + The original implementation of EEGconformer can be found here [xEEGgit]_ . + The expected **input** is a **3D tensor** with size: + (Batch x Channels x Samples). + + Parameters + ---------- + nb_classes: int + The number of classes. If less than 2, a binary classification problem + is considered (output dimensions will be [batch, 1] in this case). + Chans: int + The number of EEG channels. + Samples: int + The sample length (number of time steps). + It will be used to calculate the embedding size (for head initialization). + Fs: int + The sampling rate of the EEG signal in Hz. + It is used to initialize the weights of the filters. + Must be specified even if `random_temporal_filter` is False. + F1: int, optional + The number of output filters in the temporal convolution layer. + + Default = 7 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 125 + F2: int, optional + The number of output filters in the spatial convolution layer. + + Default = 7 + Pool: int, optional + Kernel size for temporal pooling. + + Default = 75 + p: float, optional + Dropout probability in [0,1) + + Default = 0.2 + random_temporal_filter: bool, optional + If True, initialize the temporal filter weights randomly. + Otherwise, use a passband FIR filter. + + Default = False + freeze_temporal: int, optional + Number of forward steps to keep the temporal layer frozen. + + Default = 1e12 + spatial_depthwise: bool, optional + Whether to apply a depthwise layer in the spatial convolution. + + Default = True + log_activation_base: str, optional + Base for the logarithmic activation after pooling. + Options: "e" (natural log), "10" (logarithm base 10), "dB" (decibel scale). + + Default = "dB" + norm_type: str, optional + The type of normalization. Expected values are "batch" or "instance". 
+ + Default = "batchnorm" + global_pooling: bool, optional + If True, apply global average pooling instead of flattening. + + Default = True + bias: list[int, int], optional + A 2-element list with boolean values. + If the first element is True, a bias will be added to the temporal + convolutional layer. + If the second element is True, a bias will be added to the spatial + convolutional layer. + If the third element is True, a bias will be added to the final dense layer. + + Default = [False, False, False] + return_logits: bool, optional + If True, return the output as logit. + It is suggested to not use False as the pytorch crossentropy loss function + applies the softmax internally. + + Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + References + ---------- + .. [xEEG] zanola et al., xEEGNet: Towards Explainable AI in EEG Dementia + Classification. arXiv preprint. 2025. https://doi.org/10.48550/arXiv.2504.21457 + .. [xEEGgit] https://github.com/MedMaxLab/shallownetXAI + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4,8,512) + >>> mdl = models.xEEGNet(3, 8, 512, 125) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 3]) + + """ + + def __init__( + self, + nb_classes: int, + Chans: int, + Samples: int, + Fs: int, + F1: int = 7, + K1: int = 125, + F2: int = 7, + Pool: int = 75, + p: float = 0.2, + random_temporal_filter=False, + freeze_temporal: int = 1e12, + spatial_depthwise: bool = True, + log_activation_base: str = "dB", + norm_type: str = "batchnorm", + global_pooling=True, + bias: list[int, int, int] = [False, False, False], + dense_hidden: int = -1, + return_logits=True, + seed=None, + ): + + super(xEEGNet, self).__init__() + + self.nb_classes = nb_classes + self.return_logits = return_logits + self.encoder = xEEGNetEncoder( + Chans, + Fs, + F1, + K1, + F2, + Pool, + p, + random_temporal_filter, + freeze_temporal, + spatial_depthwise, + log_activation_base, + norm_type, + global_pooling, + bias, + seed, + ) + + if global_pooling: + self.emb_size = F2 + else: + self.emb_size = F2 * ((Samples - K1 + 1 - Pool) // max(1, int(Pool // 5)) + 1) + + _reset_seed(seed) + if dense_hidden <= 0: + self.Dense = nn.Linear( + self.emb_size, 1 if nb_classes <= 2 else nb_classes, bias=bias[2] + ) + else: + self.Dense = nn.Sequential( + nn.Linear(self.emb_size, dense_hidden, bias=True), + nn.ReLU(), + nn.Linear(dense_hidden, 1 if nb_classes <= 2 else nb_classes, bias=bias[2]), + ) + + def forward(self, x): + """ + :meta private: + """ + x = self.encoder(x) + x = self.Dense(x) + if not (self.return_logits): + if self.nb_classes <= 2: + x = torch.sigmoid(x) + else: + x = F.softmax(x, dim=1) + return x diff --git a/selfeeg/utils/utils.py b/selfeeg/utils/utils.py index c8958a3..bd6fa4e 100644 --- a/selfeeg/utils/utils.py +++ b/selfeeg/utils/utils.py @@ -382,7 +382,9 @@ class adaptation of the ``scale_range_with_soft_clip`` function. 
""" - def __init__(self, Range=200, asintote=1.2, scale="mV", exact=True): + def __init__( + self, Range: float = 200, asintote: float = 1.2, scale: str = "mV", exact: bool = True + ): if Range < 0: raise ValueError("Range cannot be lower than 0") if asintote is None: @@ -910,3 +912,26 @@ def count_parameters( ) print(" " * char2add + "TOTAL TRAINABLE PARAMS" + " " * char2add2, total_params) return (layer_table, total_params) if return_table else total_params + + +def _reset_seed( + seed: int = None, + reset_random: bool = True, + reset_numpy: bool = True, + reset_torch: bool = True, +) -> None: + """ + :meta private: + """ + if seed is not None: + if seed <= 0: + raise ValueError("seed must be a nonnegative number") + if reset_numpy: + np.random.seed(seed) + if reset_random: + random.seed(seed) + if reset_torch: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/test/EEGself/augmentation/functional_test.py b/test/EEGself/augmentation/functional_test.py index 92daed1..2df53bb 100644 --- a/test/EEGself/augmentation/functional_test.py +++ b/test/EEGself/augmentation/functional_test.py @@ -210,6 +210,27 @@ def test_shift_frequency(self): self.assertTrue(math.isclose(per1[24], per2[104], rel_tol=1e-5)) print(" shift frequency OK: tested", N + len(aug_args), "combinations of input arguments") + def test_phase_swap(self): + print("Testing phase swap...", end="", flush=True) + aug_args = {"x": [self.x3, self.x3np]} + aug_args = self.makeGrid(aug_args) + for i in aug_args: + xaug = aug.phase_swap(**i) + if isinstance(xaug, torch.Tensor): + self.assertTrue(torch.isnan(xaug).sum() == 0) + self.assertFalse(torch.equal(i["x"], xaug)) + else: + self.assertTrue(np.isnan(xaug).sum() == 0) + self.assertFalse(np.array_equal(i["x"], xaug)) + N = len(aug_args) + if self.device.type != "cpu": + aug_args = {"x": [self.x3gpu]} + aug_args = self.makeGrid(aug_args) + for i in aug_args: + xaug = aug.phase_swap(**i) + + print(" phase_swap OK: tested", N + len(aug_args), "combinations of input arguments") + def test_flip_vertical(self): print("Testing flip vertical...", end="", flush=True) aug_args = { diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py index 50202db..6ba7f76 100644 --- a/test/EEGself/models/zoo_test.py +++ b/test/EEGself/models/zoo_test.py @@ -69,6 +69,7 @@ def test_ATCNet(self): "F1": [12, 8], "D": [2, 3], "return_logits": [False], + "seed": [42], } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -106,6 +107,7 @@ def test_DeepConvNet(self): "dropRate": [0.5], "max_dense_norm": [1.0], "return_logits": [False], + "seed": [42], } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -127,6 +129,45 @@ def test_DeepConvNet(self): self.assertGreaterEqual(out.min(), 0) print(" DeepConvNet OK: tested ", len(DCN_grid), " combinations of input arguments") + def test_EEGConformer(self): + print("Testing EEGConformer...", end="", flush=True) + EEGcon_args = { + "nb_classes": [2, 4], + "Samples": [2048], + "Chans": [self.Chan], + "F": [40], + "K1": [25, 12], + "Pool": [75, 50], + "stride_pool": [20], + "nlayers": [4], + "nheads": [8, 10], + "dim_feedforward": [80], + "activation_transformer": ["gelu"], + "mlp_dim": [[128, 32]], + "return_logits": [False], + "seed": [42], + } + EEGcon_args = self.makeGrid(EEGcon_args) + for i in EEGcon_args: + model = models.EEGConformer(**i) + out = model(self.x) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], 
i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + + if self.device.type != "cpu": + for i in EEGcon_args: + model = models.EEGConformer(**i).to(device=self.device) + out = model(self.x2) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + print(" EEGConformer OK: tested", len(EEGcon_args), " combinations of input arguments") + def test_EEGInception(self): print("Testing EEGInception...", end="", flush=True) EEGin_args = { @@ -142,6 +183,7 @@ def test_EEGInception(self): "max_depth_norm": [1.0], "return_logits": [False], "bias": [True, False], + "seed": [42], } EEGin_args = self.makeGrid(EEGin_args) for i in EEGin_args: @@ -178,6 +220,7 @@ def test_EEGNet(self): "pool2": [8, 16], "separable_kernel": [16, 32], "return_logits": [False], + "seed": [42], } EEGnet_args = self.makeGrid(EEGnet_args) for i in EEGnet_args: @@ -214,6 +257,7 @@ def test_EEGSym(self): "pool": [2, 3], "bias": [True, False], "return_logits": [False], + "seed": [42], } EEGsym_args = self.makeGrid(EEGsym_args) for i in EEGsym_args: @@ -249,6 +293,7 @@ def test_FBCNet(self): "FilterType": ["Cheby2", "ellip"], "TemporalType": ["var", "max", "mean", "std", "logvar"], "return_logits": [False], + "seed": [42], } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -282,6 +327,7 @@ def test_ResNet(self): "kernLength": [7, 13], "addConnection": [True, False], "return_logits": [False], + "seed": [42], } EEGres_args = self.makeGrid(EEGres_args) for i in EEGres_args: @@ -314,6 +360,7 @@ def test_ShallowNet(self): "K1": [25, 12], "Pool": [75, 50], "return_logits": [False], + "seed": [42], } EEGsha_args = self.makeGrid(EEGsha_args) for i in EEGsha_args: @@ -346,6 +393,7 @@ def test_StagerNet(self): "kernLength": [64, 120], "Pool": [16, 8], "return_logits": [False], + "seed": [42], } EEGsta_args = self.makeGrid(EEGsta_args) for i in EEGsta_args: @@ -378,6 +426,7 @@ def test_STNet(self): "kernlength": [5, 7], "dense_size": [1024, 512], "return_logits": [False], + "seed": [42], } EEGstn_args = self.makeGrid(EEGstn_args) for i in EEGstn_args: @@ -415,6 +464,7 @@ def test_TinySleepNet(self): "pool": [16, 5], "hidden_lstm": [128, 50], "return_logits": [False], + "seed": [42], } EEGsleep_args = self.makeGrid(EEGsleep_args) for i in EEGsleep_args: @@ -437,6 +487,54 @@ def test_TinySleepNet(self): self.assertGreaterEqual(out.min(), 0) print(" TinySleepNet OK: tested", len(EEGsleep_args), " combinations of input arguments") + def test_xEEGNet(self): + print("Testing xEEGNet...", end="", flush=True) + EEGxeg_args = { + "nb_classes": [4], + "Samples": [2048], + "Chans": [self.Chan], + "Fs": [125], + "F1": [7, 126], + "K1": [125, 75], + "F2": [7, 126], + "Pool": [75, 50], + "random_temporal_filter": [True, False], + "freeze_temporal": [0, 1e12], + "spatial_depthwise": [True, False], + "log_activation_base": ["dB"], + "norm_type": ["batchnorm"], + "global_pooling": [True, False], + "bias": [[False] * 3], + "dense_hidden": [-1, 32], + "return_logits": [False], + "seed": [42], + } + + EEGxeg_args = self.makeGrid(EEGxeg_args) + for i in EEGxeg_args: + if i["F1"] > i["F2"] and i["spatial_depthwise"]: + continue + model = models.xEEGNet(**i) + out = model(self.x) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if 
i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + + if self.device.type != "cpu": + for i in EEGxeg_args: + if i["F1"] > i["F2"] and i["spatial_depthwise"]: + continue + model = models.xEEGNet(**i).to(device=self.device) + out = model(self.x2) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + print(" xEEGNet OK: tested", len(EEGxeg_args), " combinations of input arguments") + if __name__ == "__main__": unittest.main()