From 62ccaa5a4a75e6f2a01713fac671382f70071d1c Mon Sep 17 00:00:00 2001 From: fedepup Date: Fri, 30 May 2025 09:17:42 +0000 Subject: [PATCH 1/4] add seed to models and EEGConformer --- RELEASE.md | 3 + docs/selfeeg.models.rst | 2 + selfeeg/models/__init__.py | 2 + selfeeg/models/encoders.py | 347 +++++++++++++++++----- selfeeg/models/zoo.py | 493 ++++++++++++++++++++++++-------- selfeeg/utils/utils.py | 30 +- test/EEGself/models/zoo_test.py | 50 ++++ 7 files changed, 729 insertions(+), 198 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 2050f60..4d769f4 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -2,6 +2,9 @@ **Functionality** +- **models module**: + - models can be initialized with a custom seed. + - add EEGConformer. - **dataloading module**: - EEGDataset now supports EEG with multiple labels (1 per window partition). - **ssl module**: diff --git a/docs/selfeeg.models.rst b/docs/selfeeg.models.rst index 66ec893..86074e6 100644 --- a/docs/selfeeg.models.rst +++ b/docs/selfeeg.models.rst @@ -40,6 +40,7 @@ Classes :template: classtemplate.rst DeepConvNetEncoder + EEGConformerEncoder EEGInceptionEncoder EEGNetEncoder EEGSymEncoder @@ -65,6 +66,7 @@ Classes ATCNet DeepConvNet + EEGConformer EEGInception EEGNet EEGSym diff --git a/selfeeg/models/__init__.py b/selfeeg/models/__init__.py index da1f618..9ae1820 100644 --- a/selfeeg/models/__init__.py +++ b/selfeeg/models/__init__.py @@ -9,6 +9,7 @@ from .encoders import ( BasicBlock1, DeepConvNetEncoder, + EEGConformerEncoder, EEGInceptionEncoder, EEGNetEncoder, EEGSymEncoder, @@ -22,6 +23,7 @@ from .zoo import ( ATCNet, DeepConvNet, + EEGConformer, EEGInception, EEGNet, EEGSym, diff --git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index 0924ffb..1c5f141 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -9,10 +9,12 @@ SeparableConv2d, FilterBank, ) +from ..utils.utils import _reset_seed __all__ = [ "BasicBlock1", "DeepConvNetEncoder", + "EEGConformerEncoder", "EEGInceptionEncoder", "EEGNetEncoder", "EEGSymEncoder", @@ -81,6 +83,11 @@ class EEGNetEncoder(nn.Module): If None no constraint will be applied. Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note @@ -102,24 +109,26 @@ class EEGNetEncoder(nn.Module): def __init__( self, - Chans, - kernLength=64, - dropRate=0.5, - F1=8, - D=2, - F2=16, - dropType="Dropout", - ELUalpha=1, - pool1=4, - pool2=8, - separable_kernel=16, - depthwise_max_norm=1.0, + Chans: int, + kernLength: int=64, + dropRate: float=0.5, + F1: int=8, + D: int=2, + F2: int=16, + dropType: str="Dropout", + ELUalpha: int=1, + pool1: int=4, + pool2: int=8, + separable_kernel: int=16, + depthwise_max_norm: float=1.0, + seed: int=None ): if dropType not in ["SpatialDropout2D", "Dropout"]: - raise ValueError("implemented Dropout types are" " 'Dropout' or 'SpatialDropout2D '") + raise ValueError("Dropout types are 'Dropout' or 'SpatialDropout2D'") super(EEGNetEncoder, self).__init__() + _reset_seed(seed) # Layer 1 self.conv1 = nn.Conv2d(1, F1, (1, kernLength), padding="same", bias=False) @@ -154,26 +163,20 @@ def forward(self, x): """ :meta private: """ - # Layer 1 x = torch.unsqueeze(x, 1) x = self.conv1(x) x = self.batchnorm1(x) - - # Layer 2 x = self.conv2(x) x = self.batchnorm2(x) x = self.elu2(x) x = self.pooling2(x) x = self.drop2(x) - - # Layer 3 x = self.sepconv3(x) x = self.batchnorm3(x) x = self.elu3(x) x = self.pooling3(x) x = self.drop3(x) x = self.flatten3(x) - return x @@ -226,6 +229,11 @@ class DeepConvNetEncoder(nn.Module): The dropout percentage in range [0,1]. Default = 0.5 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -241,19 +249,21 @@ class DeepConvNetEncoder(nn.Module): def __init__( self, - Chans, - kernLength=10, - F=25, - Pool=3, - stride=3, - max_norm=None, - batch_momentum=0.1, - ELUalpha=1, - dropRate=0.5, + Chans: int, + kernLength: int=10, + F: int=25, + Pool: int=3, + stride: int=3, + max_norm: float=None, + batch_momentum: float=0.1, + ELUalpha: int=1, + dropRate: float=0.5, + seed: int=None ): super(DeepConvNetEncoder, self).__init__() - + _reset_seed(seed) + self.conv1 = ConstrainedConv2d( 1, F, (1, kernLength), padding="valid", stride=(1, 1), max_norm=max_norm ) @@ -372,6 +382,11 @@ class EEGInceptionEncoder(nn.Module): If None no constraint will be included. Default = 1. + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -387,19 +402,22 @@ class EEGInceptionEncoder(nn.Module): def __init__( self, - Chans, - F1=8, - D=2, - kernel_size=64, - pool=4, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - batch_momentum=0.1, - max_depth_norm=1.0, + Chans: int, + F1: int=8, + D: int=2, + kernel_size: int=64, + pool: int=4, + dropRate: float=0.5, + ELUalpha: float=1.0, + bias: bool=True, + batch_momentum: float=0.1, + max_depth_norm: float=1.0, + seed: int=None ): super(EEGInceptionEncoder, self).__init__() + _reset_seed(seed) + self.inc1 = nn.Sequential( nn.Conv2d(1, F1, (1, kernel_size), padding="same", bias=bias), nn.BatchNorm2d(F1, momentum=batch_momentum), @@ -545,6 +563,11 @@ class TinySleepNetEncoder(nn.Module): Hidden size of the lstm block. Default = 128 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Example ------- @@ -560,17 +583,20 @@ class TinySleepNetEncoder(nn.Module): def __init__( self, - Chans, - Fs, - F=128, - kernlength=8, - pool=8, - dropRate=0.5, - batch_momentum=0.1, - hidden_lstm=128, + Chans: int, + Fs: int, + F: int=128, + kernlength: int=8, + pool: int=8, + dropRate: float=0.5, + batch_momentum: float=0.1, + hidden_lstm: int=128, + seed: int=None ): super(TinySleepNetEncoder, self).__init__() + _reset_seed(seed) + self.conv1 = nn.Conv1d(Chans, F, int(Fs // 2), stride=int(Fs // 16), padding="valid") self.BN1 = nn.BatchNorm1d(F, momentum=batch_momentum) self.Relu = nn.ReLU() @@ -611,14 +637,10 @@ def forward(self, x): x = self.conv4(x) x = self.BN4(x) x = self.Relu(x) - x = self.pool2(x) x = self.drop2(x) - x = torch.permute(x, (2, 0, 1)) - out, (ht, ct) = self.lstm1(x) - return ht[-1] @@ -649,6 +671,11 @@ class StagerNetEncoder(nn.Module): The temporal pooling kernel size. Default = 4 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -662,9 +689,10 @@ class StagerNetEncoder(nn.Module): """ - def __init__(self, Chans, kernLength=64, F=8, Pool=16): + def __init__(self, Chans, kernLength: int=64, F: int=8, Pool: int=16, seed: int=None): super(StagerNetEncoder, self).__init__() + _reset_seed(seed) self.conv1 = nn.Conv2d(1, Chans, (Chans, 1), stride=(1, 1), bias=True) self.conv2 = nn.Conv2d(1, F, (1, kernLength), stride=(1, 1), padding="same") @@ -676,7 +704,6 @@ def __init__(self, Chans, kernLength=64, F=8, Pool=16): def forward(self, x): """ :meta private: - """ x = torch.unsqueeze(x, 1) x = self.conv1(x) @@ -720,6 +747,11 @@ class ShallowNetEncoder(nn.Module): Dropout probability. Must be in [0,1) Default= 0.2 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -740,9 +772,11 @@ class ShallowNetEncoder(nn.Module): """ - def __init__(self, Chans, F=40, K1=25, Pool=75, p=0.2): + def __init__(self, Chans, F: int=40, K1: int=25, Pool: int=75, p: float=0.2, seed: int=None): super(ShallowNetEncoder, self).__init__() + _reset_seed(seed) + self.conv1 = nn.Conv2d(1, F, (1, K1), stride=(1, 1)) self.conv2 = nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)) self.batch1 = nn.BatchNorm2d(F) @@ -774,7 +808,7 @@ class BasicBlock1(nn.Module): :meta private: """ - def __init__(self, inplanes, planes, kernLength=7, stride=1): + def __init__(self, inplanes: int, planes: int, kernLength: int=7, stride: int=1): super(BasicBlock1, self).__init__() self.stride = stride @@ -820,18 +854,13 @@ def forward(self, x): :meta private: """ residual = self.downsample(x) - # print('residual: ', residual.shape) out = self.conv1(x) out = self.bn1(out) out = self.relu(out) - # print('out 1: ', out.shape) out = self.conv2(out) out = self.bn2(out) - # print('out 2: ', out.shape) - out += residual out = self.relu(out) - return out @@ -894,7 +923,12 @@ class ResNet1DEncoder(nn.Module): 2. nn.BatchNorm2d() 3. nn.ReLU() - Default = None + Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note ---- @@ -915,7 +949,7 @@ class ResNet1DEncoder(nn.Module): def __init__( self, - Chans, + Chans: int, block: nn.Module = BasicBlock1, Layers: "list of 4 ints" = [2, 2, 2, 2], inplane: int = 16, @@ -923,9 +957,12 @@ def __init__( addConnection: bool = False, preBlock: nn.Module = None, postBlock: nn.Module = None, + seed: int=None ): super(ResNet1DEncoder, self).__init__() + _reset_seed(seed) + self.inplane = inplane self.kernLength = kernLength self.connection = addConnection @@ -1121,6 +1158,11 @@ class STNetEncoder(nn.Module): If True, adds a learnable bias to the convolutional layers. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -1134,8 +1176,17 @@ class STNetEncoder(nn.Module): """ - def __init__(self, Samples, F=256, kernlength=5, dropRate=0.5, bias=True): + def __init__( + self, + Samples, + F: int=256, + kernlength: int=5, + dropRate: float=0.5, + bias: bool=True, + seed: int=None + ): super(STNetEncoder, self).__init__() + _reset_seed(seed) self.conv1 = nn.Conv2d( Samples, F, kernel_size=kernlength - 2, stride=1, padding="same", bias=bias @@ -1219,7 +1270,6 @@ class EEGSymInception(nn.Module): """ :meta private: """ - def __init__( self, in_channels, @@ -1288,7 +1338,6 @@ class EEGSymResBlock(nn.Module): """ :meta private: """ - def __init__( self, in_channels, @@ -1396,6 +1445,11 @@ class EEGSymEncoder(nn.Module): Currently not implemented, will be added in future releases. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Example ------- @@ -1411,21 +1465,22 @@ class EEGSymEncoder(nn.Module): def __init__( self, - Chans, - Samples, - Fs, - scales_time=(500, 250, 125), - lateral_chans=3, - first_left=True, - F=8, - pool=2, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - residual=True, + Chans: int, + Samples: int, + Fs: float, + scales_time: tuple=(500, 250, 125), + lateral_chans: int=3, + first_left: bool=True, + F: int=8, + pool: int=2, + dropRate: float=0.5, + ELUalpha: float=1.0, + bias: bool=True, + residual: bool=True, + seed: int=None ): - super(EEGSymEncoder, self).__init__() + _reset_seed(seed) self.input_samples = int(Samples * Fs / 1000) self.scales_samples = [int(s * Fs / 1000) for s in scales_time] @@ -1771,6 +1826,11 @@ class FBCNetEncoder(nn.Module): The maximum norm each filter can have in the depthwise block. If None no constraint will be included. + Default = None + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + Default = None Example @@ -1802,6 +1862,7 @@ def __init__( TemporalStride: int = 4, batch_momentum: float = 0.1, depthwise_max_norm=None, + seed: int=None ): super(FBCNetEncoder, self).__init__() self.FilterBands = FilterBands @@ -1858,3 +1919,133 @@ def forward(self, x): x = self.TMB(x) x = torch.flatten(x, 1) return x + + +class EEGConformerEncoder(nn.Module): + """ + Pytorch implementation of the EEGConformer Encoder. + + See EEGConformer for some references. + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). + + Parameters + ---------- + Chans: int + The number of EEG channels. + F: int, optional + The number of output filters in the temporal convolution layer. 
+ + Default = 40 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 25 + Pool: int, optional + The temporal pooling kernel size. + + Default = 75 + stride_pool: int, optional + The temporal pooling stride. + + Default = 15 + d_model: int, optional + The embedding size. It is the number of expected features in the input of + the transformer encoder layer. + + Default = 40 + nlayers: int, optional + The number of transformer encoder layers. + + Default = 6 + nheads: int, optional + The number of heads in the multi-head attention layers. + + Default = 10 + dim_feedforward: int, optional + The dimension of the feedforward hidden layer in the transformer encoder. + + Default = 160 + activation_transformer: str or Callabel, optional + The activation function in the transformer encoder. See the PyTorch + TransformerEncoderLayer documentation for accepted inputs. + + Default = "gelu" + p: float, optional + Dropout probability in the tokenizer. Must be in [0,1) + + Default= 0.2 + p_transformer: float, optional + Dropout probability in the transformer encoder. Must be in [0,1) + + Default= 0.5 + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4,8,512) + >>> mdl = models.EEGConformerEncoder(8) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 224]) + >>> print(torch.isnan(out).sum()) # shoud return 0 + + """ + + def __init__( + self, + Chans, + F: int=40, + K1: int=25, + Pool: int=75, + stride_pool: int=15, + d_model: int=40, + nlayers: int=6, + nheads: int=10, + dim_feedforward: int=160, + activation_transformer: str or Callable="gelu", + p: float=0.2, + p_transformer: float=0.5, + seed: int=None + ): + + super(EEGConformerEncoder, self).__init__() + _reset_seed(seed) + + self.tokenizer = nn.Sequential( + nn.Conv2d(1, F, (1, K1), stride=(1, 1)), + nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)), + nn.BatchNorm2d(F), + nn.AvgPool2d((1, Pool), stride=(1, stride_pool)), + nn.Dropout(p) + ) + self.projection = nn.Conv2d(F, d_model, (1, 1)) + # squeeze 2 and permute + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model, + nhead = nheads, + dim_feedforward = dim_feedforward, + dropout = p_transformer, + activation = "gelu", + batch_first = True + ), + num_layers = nlayers, + ) + + def forward(self, x): + """ + :meta private: + """ + x = torch.unsqueeze(x, 1) + x = self.tokenizer(x) + x = x.squeeze(2) + x = torch.permute(x,[0,2,1]) + x = self.transformer(x) + x = torch.permute(x,[0,2,1]) + return x \ No newline at end of file diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 3f79d2b..83f57bd 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -2,11 +2,11 @@ import torch.nn as nn import torch.nn.functional as F from .layers import ConstrainedDense, ConstrainedConv1d, ConstrainedConv2d - from .encoders import ( BasicBlock1, DeepConvNetEncoder, EEGInceptionEncoder, + EEGConformerEncoder, EEGNetEncoder, EEGSymEncoder, FBCNetEncoder, @@ -16,10 +16,12 @@ STNetEncoder, TinySleepNetEncoder, ) +from ..utils.utils import _reset_seed __all__ = [ "ATCNet", "DeepConvNet", + "EEGConformer", "EEGInception", "EEGNet", "EEGSym", @@ -103,6 +105,11 @@ class EEGNet(nn.Module): applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note ---- @@ -129,22 +136,23 @@ class EEGNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - kernLength=64, - dropRate=0.5, - F1=8, - D=2, - F2=16, - norm_rate=0.25, - dropType="Dropout", - ELUalpha=1, - pool1=4, - pool2=8, - separable_kernel=16, - depthwise_max_norm=1.0, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + kernLength: int=64, + dropRate: float=0.5, + F1: int=8, + D: int=2, + F2: int=16, + norm_rate: int=0.25, + dropType: str="Dropout", + ELUalpha: int=1, + pool1: int=4, + pool2: int=8, + separable_kernel: int=16, + depthwise_max_norm: float=1.0, + return_logits: bool=True, + seed: int=None ): super(EEGNet, self).__init__() @@ -164,7 +172,10 @@ def __init__( pool2, separable_kernel, depthwise_max_norm, + seed ) + + _reset_seed(seed) self.Dense = ConstrainedDense( F2 * (Samples // int(pool1 * pool2)), 1 if nb_classes <= 2 else nb_classes, @@ -253,6 +264,11 @@ class DeepConvNet(nn.Module): the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -281,32 +297,35 @@ class DeepConvNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - kernLength=10, - F=25, - Pool=3, - stride=3, - max_norm=None, - batch_momentum=0.1, - ELUalpha=1, - dropRate=0.5, - max_dense_norm=None, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + kernLength: int=10, + F: int=25, + Pool: int=3, + stride: int=3, + max_norm: int=None, + batch_momentum: float=0.1, + ELUalpha: int=1, + dropRate: float=0.5, + max_dense_norm: float=None, + return_logits: bool=True, + seed: int=None ): super(DeepConvNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits self.encoder = DeepConvNetEncoder( - Chans, kernLength, F, Pool, stride, max_norm, batch_momentum, ELUalpha, dropRate + Chans, kernLength, F, Pool, stride, max_norm, batch_momentum, ELUalpha, dropRate, seed ) k = kernLength Dense_input = [Samples] * 8 for i in range(4): Dense_input[i * 2] = Dense_input[i * 2 - 1] - k + 1 Dense_input[i * 2 + 1] = (Dense_input[i * 2] - Pool) // stride + 1 + + _reset_seed(seed) self.Dense = ConstrainedDense( F * 8 * Dense_input[-1], 1 if nb_classes <= 2 else nb_classes, max_norm=max_dense_norm ) @@ -392,6 +411,11 @@ class EEGInception(nn.Module): the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None References ---------- @@ -416,19 +440,20 @@ class EEGInception(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - F1=8, - D=2, - kernel_size=64, - pool=4, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - batch_momentum=0.1, - max_depth_norm=1.0, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + F1: int=8, + D: int=2, + kernel_size: int=64, + pool: int=4, + dropRate: float=0.5, + ELUalpha: float=1.0, + bias: bool=True, + batch_momentum: float=0.1, + max_depth_norm: float=1.0, + return_logits: bool=True, + seed: int=None ): super(EEGInception, self).__init__() self.nb_classes = nb_classes @@ -444,7 +469,10 @@ def __init__( bias, batch_momentum, max_depth_norm, + seed ) + + _reset_seed(seed) self.Dense = nn.Linear( int((F1 * 3) / 4) * int((Samples // (pool * (int(pool // 2) ** 3)))), 1 if nb_classes <= 2 else nb_classes, @@ -523,6 +551,11 @@ class TinySleepNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -547,26 +580,28 @@ class TinySleepNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Fs, - F=128, - kernlength=8, - pool=8, - dropRate=0.5, - batch_momentum=0.1, - max_dense_norm=2.0, - hidden_lstm=128, - return_logits=True, + nb_classes: int, + Chans: int, + Fs: int, + F: int=128, + kernlength: int=8, + pool: int=8, + dropRate: float=0.5, + batch_momentum: float=0.1, + max_dense_norm: float=2.0, + hidden_lstm: int=128, + return_logits: bool=True, + seed: int=None ): super(TinySleepNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits self.encoder = TinySleepNetEncoder( - Chans, Fs, F, kernlength, pool, dropRate, batch_momentum, hidden_lstm + Chans, Fs, F, kernlength, pool, dropRate, batch_momentum, hidden_lstm, seed ) + _reset_seed(seed) self.drop3 = nn.Dropout1d(dropRate) self.Dense = ConstrainedDense( hidden_lstm, 1 if nb_classes <= 2 else nb_classes, max_norm=max_dense_norm @@ -629,6 +664,11 @@ class StagerNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -650,21 +690,25 @@ class StagerNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - dropRate=0.5, - kernLength=64, - F=8, - Pool=16, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + dropRate: float=0.5, + kernLength: int=64, + F: int=8, + Pool: int=16, + return_logits: bool=True, + seed: int=None ): super(StagerNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = StagerNetEncoder(Chans, kernLength=kernLength, F=F, Pool=Pool) + self.encoder = StagerNetEncoder( + Chans, kernLength=kernLength, F=F, Pool=Pool, seed=seed + ) + _reset_seed(seed) self.drop = nn.Dropout(p=dropRate) self.Dense = nn.Linear( Chans * F * (int((int((Samples - Pool) / Pool + 1) - Pool) / Pool + 1)), @@ -728,6 +772,11 @@ class ShallowNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None Note ---- @@ -754,13 +803,26 @@ class ShallowNet(nn.Module): """ - def __init__(self, nb_classes, Chans, Samples, F=40, K1=25, Pool=75, p=0.2, return_logits=True): + def __init__( + self, + nb_classes: int, + Chans: int, + Samples: int, + F: int=40, + K1: int=25, + Pool: int=75, + p: float=0.2, + return_logits: bool=True, + seed: int=None + ): super(ShallowNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = ShallowNetEncoder(Chans, F=F, K1=K1, Pool=Pool, p=p) + self.encoder = ShallowNetEncoder(Chans, F=F, K1=K1, Pool=Pool, p=p, seed=seed) + + _reset_seed(seed) self.Dense = nn.Linear( F * ((Samples - K1 + 1 - Pool) // 15 + 1), 1 if nb_classes <= 2 else nb_classes ) @@ -864,6 +926,11 @@ class ResNet1D(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None Note ---- @@ -889,9 +956,9 @@ class ResNet1D(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, + nb_classes: int, + Chans: int, + Samples: int, block: nn.Module = BasicBlock1, Layers: "list of 4 int" = [2, 2, 2, 2], inplane: int = 16, @@ -900,7 +967,8 @@ def __init__( preBlock: nn.Module = None, postBlock: nn.Module = None, classifier: nn.Module = None, - return_logits=True, + return_logits: bool=True, + seed: int=None ): super(ResNet1D, self).__init__() @@ -916,8 +984,11 @@ def __init__( addConnection=addConnection, preBlock=preBlock, postBlock=postBlock, + seed=seed ) + # Classifier + _reset_seed(seed) if classifier is None: if addConnection: out1 = int((Samples + 2 * (int(kernLength // 2)) - kernLength) // 2) + 1 @@ -999,6 +1070,11 @@ class STNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1019,23 +1095,25 @@ class STNet(nn.Module): def __init__( self, - nb_classes, - Samples, - grid_size=9, - F=256, - kernlength=5, - dropRate=0.5, - bias=True, - dense_size=1024, - return_logits=True, + nb_classes: int, + Samples: int, + grid_size: int=9, + F: int=256, + kernlength: int=5, + dropRate: float=0.5, + bias: bool=True, + dense_size: int=1024, + return_logits: bool=True, + seed: int=None ): super(STNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = STNetEncoder(Samples, F, kernlength, dropRate, bias) - self.drop3 = nn.Dropout(dropRate) + self.encoder = STNetEncoder(Samples, F, kernlength, dropRate, bias, seed=seed) + _reset_seed(seed) + self.drop3 = nn.Dropout(dropRate) self.Dense = nn.Sequential( nn.Linear(int(F / 16) * (grid_size**2), dense_size), nn.Dropout(dropRate), @@ -1131,6 +1209,11 @@ class EEGSym(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. 
+ If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1151,20 +1234,21 @@ class EEGSym(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - Fs, - scales_time=(500, 250, 125), - lateral_chans=3, - first_left=True, - F=8, - pool=2, - dropRate=0.5, - ELUalpha=1.0, - bias=True, - residual=True, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + Fs: int, + scales_time: tuple=(500, 250, 125), + lateral_chans: int=3, + first_left: bool=True, + F: int=8, + pool: int=2, + dropRate: float=0.5, + ELUalpha: float=1.0, + bias: bool=True, + residual: bool=True, + return_logits: bool=True, + seed: int=None ): super(EEGSym, self).__init__() self.nb_classes = nb_classes @@ -1182,7 +1266,10 @@ def __init__( ELUalpha, bias, residual, + seed=seed ) + + _reset_seed(seed) self.Dense = nn.Linear(int((F * 9) / 2), 1 if nb_classes <= 2 else nb_classes) def forward(self, x): @@ -1304,6 +1391,11 @@ class FBCNet(nn.Module): to not use False as the pytorch crossentropy applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1344,6 +1436,7 @@ def __init__( linear_max_norm: float = None, classifier: nn.Module = None, return_logits: bool = True, + seed: int=None ): super(FBCNet, self).__init__() @@ -1367,8 +1460,11 @@ def __init__( TemporalStride, batch_momentum, depthwise_max_norm, + seed=seed ) + # Head + _reset_seed(seed) if classifier is None: self.head = ConstrainedDense( D * FilterBands * TemporalStride, @@ -1622,6 +1718,11 @@ class ATCNet(nn.Module): applies the softmax internally. Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None References ---------- @@ -1644,32 +1745,34 @@ class ATCNet(nn.Module): def __init__( self, - nb_classes, - Chans, - Samples, - Fs, - num_windows=5, - mha_heads=2, - tcn_depth=2, - F1=16, - D=2, - pool1=None, - pool2=None, - dropRate=0.3, - max_norm=None, - batchMomentum=0.1, - ELUAlpha=1.0, - mha_dropRate=0.5, - tcn_kernLength=4, - tcn_F=32, - tcn_ELUAlpha=0.0, - tcn_dropRate=0.3, - tcn_max_norm=None, - tcn_batchMom=0.1, - return_logits=True, + nb_classes: int, + Chans: int, + Samples: int, + Fs: float, + num_windows: int=5, + mha_heads: int=2, + tcn_depth: int=2, + F1: int=16, + D: int=2, + pool1: int=None, + pool2: int=None, + dropRate: float=0.3, + max_norm: float=None, + batchMomentum: float=0.1, + ELUAlpha: float=1.0, + mha_dropRate: float=0.5, + tcn_kernLength: int=4, + tcn_F: int=32, + tcn_ELUAlpha: float=0.0, + tcn_dropRate: float=0.3, + tcn_max_norm: float=None, + tcn_batchMom: float=0.1, + return_logits: bool=True, + seed: int=None ): super(ATCNet, self).__init__() + _reset_seed(seed) # important for model construction self.return_logits = return_logits @@ -1721,6 +1824,7 @@ def __init__( ) # Construct each Branch + _reset_seed(seed) for i in range(self.num_windows): self.add_multi_head(i) self.add_residual_tcn(i) @@ -1778,3 +1882,154 @@ def forward(self, x): else: x = F.softmax(x, dim=1) return x + + +class EEGConformer(nn.Module): + """ + Pytorch implementation of EEGConformer. + + For more information see the following paper [EEGcon]_ . + The original implementation of EEGconformer can be found here [EEGcongit]_ . 
+ + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). + + Parameters + ---------- + nb_classes: int + The number of classes. If less than 2, a binary classification problem + is considered (output dimensions will be [batch, 1] in this case). + Chans: int + The number of EEG channels. + F: int, optional + The number of output filters in the temporal convolution layer. + + Default = 40 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 25 + Pool: int, optional + The temporal pooling kernel size. + + Default = 75 + stride_pool: int, optional + The temporal pooling stride. + + Default = 15 + d_model: int, optional + The embedding size. It is the number of expected features in the input of + the transformer encoder layer. + + Default = 40 + nlayers: int, optional + The number of transformer encoder layers. + + Default = 6 + nheads: int, optional + The number of heads in the multi-head attention layers. + + Default = 10 + dim_feedforward: int, optional + The dimension of the feedforward hidden layer in the transformer encoder. + + Default = 160 + activation_transformer: str or Callabel, optional + The activation function in the transformer encoder. See the PyTorch + TransformerEncoderLayer documentation for accepted inputs. + + Default = "gelu" + p: float, optional + Dropout probability in the tokenizer. Must be in [0,1) + + Default = 0.2 + p_transformer: float, optional + Dropout probability in the transformer encoder. Must be in [0,1) + + Default = 0.5 + mlp_dim: list, optional + A two-element list indicating the output dimensions of the 2 FC + layers in the final classification head. + + Default = [256, 32] + return_logits: bool, optional + Whether to return the output as logit or probability. + It is suggested to not use False as the pytorch crossentropy loss function + applies the softmax internally. + + Default = True + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + References + ---------- + .. [EEGcon] Song et al., EEG Conformer: Convolutional Transformer for EEG Decoding + and Visualization. IEEE TNSRE. 2023. https://doi.org/10.1109/TNSRE.2022.3230250 + .. 
[EEGcongit] https://github.com/eeyhsong/EEG-Conformer + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4,8,512) + >>> mdl = models.EEGConformer(2, 8, 512) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 1]) + + """ + def __init__( + self, + nb_classes: int, + Chans: int, + Samples: int, + F: int=40, + K1: int=25, + Pool: int=75, + stride_pool: int=15, + d_model: int=40, + nlayers: int=6, + nheads: int=10, + dim_feedforward: int=160, + activation_transformer: str or Callable="gelu", + p: float=0.2, + p_transformer: float=0.5, + mlp_dim: list[int,int]=[256,32], + return_logits: bool=True, + seed: int=None + ): + + super(EEGConformer, self).__init__() + self.return_logits = return_logits + self.nb_classes = nb_classes + + self.encoder = EEGConformerEncoder( + Chans, F, K1, Pool, stride_pool, d_model, nlayers, nheads, + dim_feedforward, activation_transformer, p, p_transformer, seed + ) + + _reset_seed(seed) + self.MLP = nn.Sequential( + nn.AvgPool1d((Samples - K1 + 1 - Pool) // stride_pool + 1), + nn.Flatten(start_dim=1), + nn.LayerNorm(d_model), + nn.Linear(d_model, mlp_dim[0]), + nn.ELU(), + nn.Dropout(p), + nn.Linear(mlp_dim[0], mlp_dim[1]), + nn.ELU(), + nn.Dropout(p), + nn.Linear(mlp_dim[1], 1 if nb_classes <= 2 else nb_classes) + ) + + def forward(self, x): + x = self.encoder(x) + x = self.MLP(x) + if not (self.return_logits): + if self.nb_classes <= 2: + x = torch.sigmoid(x) + else: + x = F.softmax(x, dim=1) + return x \ No newline at end of file diff --git a/selfeeg/utils/utils.py b/selfeeg/utils/utils.py index c8958a3..8a286fd 100644 --- a/selfeeg/utils/utils.py +++ b/selfeeg/utils/utils.py @@ -382,7 +382,13 @@ class adaptation of the ``scale_range_with_soft_clip`` function. """ - def __init__(self, Range=200, asintote=1.2, scale="mV", exact=True): + def __init__( + self, + Range: float=200, + asintote: float=1.2, + scale: str="mV", + exact: bool=True + ): if Range < 0: raise ValueError("Range cannot be lower than 0") if asintote is None: @@ -910,3 +916,25 @@ def count_parameters( ) print(" " * char2add + "TOTAL TRAINABLE PARAMS" + " " * char2add2, total_params) return (layer_table, total_params) if return_table else total_params + + +def _reset_seed( + seed: int=None, + reset_random: bool=True, + reset_numpy: bool=True, + reset_torch: bool=True, +) -> None: + """ + :meta private: + """ + if seed is not None: + assert seed>=0, "seed must be a nonnegative number" + if reset_numpy: + np.random.seed(seed) + if reset_random: + random.seed(seed) + if reset_torch: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py index 50202db..e8d1196 100644 --- a/test/EEGself/models/zoo_test.py +++ b/test/EEGself/models/zoo_test.py @@ -69,6 +69,7 @@ def test_ATCNet(self): "F1": [12, 8], "D": [2, 3], "return_logits": [False], + "seed": [42] } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -106,6 +107,7 @@ def test_DeepConvNet(self): "dropRate": [0.5], "max_dense_norm": [1.0], "return_logits": [False], + "seed": [42] } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -127,6 +129,45 @@ def test_DeepConvNet(self): self.assertGreaterEqual(out.min(), 0) print(" DeepConvNet OK: tested ", len(DCN_grid), " combinations of input arguments") + def test_EEGConformer(self): + print("Testing EEGConformer...", end="", flush=True) + EEGcon_args = { + "nb_classes": [2, 4], + "Samples": 
[2048], + "Chans": [self.Chan], + "F": [40], + "K1": [25, 12], + "Pool": [75, 50], + "stride_pool": [20], + "nlayers": [4], + "nheads": [8, 10], + "dim_feedforward": [80], + "activation_transformer": ["gelu"], + "mlp_dim": [[128,32]], + "return_logits": [False], + "seed": [42] + } + EEGcon_args = self.makeGrid(EEGcon_args) + for i in EEGcon_args: + model = models.EEGConformer(**i) + out = model(self.x) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + + if self.device.type != "cpu": + for i in EEGcon_args: + model = models.EEGConformer(**i).to(device=self.device) + out = model(self.x2) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + print(" EEGConformer OK: tested", len(EEGcon_args), " combinations of input arguments") + def test_EEGInception(self): print("Testing EEGInception...", end="", flush=True) EEGin_args = { @@ -142,6 +183,7 @@ def test_EEGInception(self): "max_depth_norm": [1.0], "return_logits": [False], "bias": [True, False], + "seed": [42] } EEGin_args = self.makeGrid(EEGin_args) for i in EEGin_args: @@ -178,6 +220,7 @@ def test_EEGNet(self): "pool2": [8, 16], "separable_kernel": [16, 32], "return_logits": [False], + "seed": [42] } EEGnet_args = self.makeGrid(EEGnet_args) for i in EEGnet_args: @@ -214,6 +257,7 @@ def test_EEGSym(self): "pool": [2, 3], "bias": [True, False], "return_logits": [False], + "seed": [42] } EEGsym_args = self.makeGrid(EEGsym_args) for i in EEGsym_args: @@ -249,6 +293,7 @@ def test_FBCNet(self): "FilterType": ["Cheby2", "ellip"], "TemporalType": ["var", "max", "mean", "std", "logvar"], "return_logits": [False], + "seed": [42] } DCN_grid = self.makeGrid(DCN_args) for i in DCN_grid: @@ -282,6 +327,7 @@ def test_ResNet(self): "kernLength": [7, 13], "addConnection": [True, False], "return_logits": [False], + "seed": [42] } EEGres_args = self.makeGrid(EEGres_args) for i in EEGres_args: @@ -314,6 +360,7 @@ def test_ShallowNet(self): "K1": [25, 12], "Pool": [75, 50], "return_logits": [False], + "seed": [42] } EEGsha_args = self.makeGrid(EEGsha_args) for i in EEGsha_args: @@ -346,6 +393,7 @@ def test_StagerNet(self): "kernLength": [64, 120], "Pool": [16, 8], "return_logits": [False], + "seed": [42] } EEGsta_args = self.makeGrid(EEGsta_args) for i in EEGsta_args: @@ -378,6 +426,7 @@ def test_STNet(self): "kernlength": [5, 7], "dense_size": [1024, 512], "return_logits": [False], + "seed": [42] } EEGstn_args = self.makeGrid(EEGstn_args) for i in EEGstn_args: @@ -415,6 +464,7 @@ def test_TinySleepNet(self): "pool": [16, 5], "hidden_lstm": [128, 50], "return_logits": [False], + "seed": [42] } EEGsleep_args = self.makeGrid(EEGsleep_args) for i in EEGsleep_args: From 13b88ec19842aa1fe523981bf1481f22d30de721 Mon Sep 17 00:00:00 2001 From: fedepup Date: Tue, 3 Jun 2025 08:54:50 +0000 Subject: [PATCH 2/4] add xEEGNet --- RELEASE.md | 1 + docs/selfeeg.models.rst | 3 + selfeeg/models/__init__.py | 2 + selfeeg/models/encoders.py | 255 +++++++++++++++++++++++++++++++- selfeeg/models/zoo.py | 175 ++++++++++++++++++++++ test/EEGself/models/zoo_test.py | 48 ++++++ 6 files changed, 482 insertions(+), 2 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 4d769f4..c30d59b 100644 --- a/RELEASE.md +++ 
b/RELEASE.md @@ -5,6 +5,7 @@ - **models module**: - models can be initialized with a custom seed. - add EEGConformer. + - add xEEGNet. - **dataloading module**: - EEGDataset now supports EEG with multiple labels (1 per window partition). - **ssl module**: diff --git a/docs/selfeeg.models.rst b/docs/selfeeg.models.rst index 86074e6..773f6d0 100644 --- a/docs/selfeeg.models.rst +++ b/docs/selfeeg.models.rst @@ -50,6 +50,7 @@ Classes StagerNetEncoder STNetEncoder TinySleepNetEncoder + xEEGNetEncoder models.zoo module @@ -76,3 +77,5 @@ Classes StagerNet STNet TinySleepNet + xEEGNet + diff --git a/selfeeg/models/__init__.py b/selfeeg/models/__init__.py index 9ae1820..6923e96 100644 --- a/selfeeg/models/__init__.py +++ b/selfeeg/models/__init__.py @@ -18,6 +18,7 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, + xEEGNetEncoder ) from .zoo import ( @@ -33,4 +34,5 @@ StagerNet, STNet, TinySleepNet, + xEEGNet ) diff --git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index 1c5f141..42ddf6a 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -1,3 +1,5 @@ +from itertools import chain, combinations +from scipy.signal import firwin import torch import torch.nn as nn import torch.nn.functional as F @@ -24,6 +26,7 @@ "StagerNetEncoder", "STNetEncoder", "TinySleepNetEncoder", + "xEEGNetEncoder" ] @@ -2025,7 +2028,6 @@ def __init__( nn.Dropout(p) ) self.projection = nn.Conv2d(F, d_model, (1, 1)) - # squeeze 2 and permute self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model, @@ -2048,4 +2050,253 @@ def forward(self, x): x = torch.permute(x,[0,2,1]) x = self.transformer(x) x = torch.permute(x,[0,2,1]) - return x \ No newline at end of file + return x + + +class xEEGNetEncoder(nn.Module): + """ + Pytorch implementation of the xEEGNet Encoder. + + See xEEGNet for some references. + The expected **input** is a **3D tensor** with size: + (Batch x Channels x Samples). + + Parameters + ---------- + Chans: int + The number of EEG channels. + Fs: int + The sampling rate of the EEG signal in Hz. + It is used to initialize the weights of the filters. + Must be specified even if `random_temporal_filter` is False. + F1: int, optional + The number of output filters in the temporal convolution layer. + + Default = 7 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 125 + F2: int, optional + The number of output filters in the spatial convolution layer. + + Default = 7 + Pool: int, optional + Kernel size for temporal pooling. + + Default = 75 + p: float, optional + Dropout probability in [0,1) + + Default = 0.2 + random_temporal_filter: bool, optional + If True, initialize the temporal filter weights randomly. + Otherwise, use a passband FIR filter. + + Default = False + freeze_temporal: int, optional + Number of forward steps to keep the temporal layer frozen. + + Default = 1e12 + spatial_depthwise: bool, optional + Whether to apply a depthwise layer in the spatial convolution. + + Default = True + log_activation_base: str, optional + Base for the logarithmic activation after pooling. + Options: "e" (natural log), "10" (logarithm base 10), "dB" (decibel scale). + + Default = "dB" + norm_type: str, optional + The type of normalization. Expected values are "batch" or "instance". + + Default = "batchnorm" + global_pooling: bool, optional + If True, apply global average pooling instead of flattening. + + Default = True + bias: list[int, int], optional + A 2-element list with boolean values. 
+ If the first element is True, a bias will be added to the temporal + convolutional layer. + If the second element is True, a bias will be added to the spatial + convolutional layer. + + Default = [False, False] + seed: int, optional + A custom seed for model initialization. It must be a nonnegative number. + If None is passed, no custom seed will be set + + Default = None + + Example + ------- + >>> import selfeeg.models + >>> import torch + >>> x = torch.randn(4, 8, 512) + >>> mdl = models.xEEGNetEncoder(8, 125) + >>> out = mdl(x) + >>> print(out.shape) # shoud return torch.Size([4, 7]) + + """ + def __init__( + self, + Chans: int, + Fs: int, + F1: int = 7, + K1: int = 125, + F2: int = 7, + Pool: int = 75, + p: float = 0.2, + random_temporal_filter = False, + freeze_temporal: int = 1e12, + spatial_depthwise: bool = True, + log_activation_base: str = "dB", + norm_type: str = "batchnorm", + global_pooling = True, + bias: list[int, int] = [False, False], + seed: int = None + ): + + super(xEEGNetEncoder, self).__init__() + + # Set seed before initializing layers + self.custom_seed = seed + _reset_seed(seed) + + self.Fs = Fs + self.chans = Chans + self.freeze_temporal = freeze_temporal + self.bias_1conv = bias[0] + self.bias_2conv = bias[1] + self.do_global_pooling = global_pooling + if self.Fs <=0 and not(random_temporal_filter): + raise ValueError( + "to properly initialize non random temporal fir filters, " + "Fs (sampling rate) must be given" + ) + + if random_temporal_filter: + self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1, 1), bias=self.bias_1conv) + else: + self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1,1), bias=self.bias_1conv) + self._initialize_custom_temporal_filter(self.custom_seed) + + if spatial_depthwise: + self.conv2 = nn.Conv2d( + F1, F2, (Chans, 1), stride=(1, 1), groups=F1, bias=self.bias_2conv + ) + else: + self.conv2 = nn.Conv2d( + F1, F2, (Chans, 1), stride=(1, 1), bias = self.bias_2conv + ) + + if "batch" in norm_type.casefold(): + self.batch1 = nn.BatchNorm2d(F2,affine=True) + elif "instance" in norm_type.casefold(): + self.batch1 = nn.InstanceNorm2d(F2) + else: + raise ValueError( + "normalization layer type can be 'batchnorm' or 'instancenorm'" + ) + + if log_activation_base in ["e", torch.e]: + self.log_activation = lambda x: torch.log(torch.clamp(x, 1e-7, 1e4)) + elif log_activation_base in ["10", 10]: + self.log_activation = lambda x: torch.log10(torch.clamp(x, 1e-7, 1e4)) + elif log_activation_base in ["db", "dB"]: + self.log_activation = lambda x: 10*torch.log10(torch.clamp(x, 1e-7, 1e4)) + else: + raise ValueError( + "allowed activation base are 'e' for torch.log, " + "'10' for torch.log10, and 'dB' for 10*torch.log10" + ) + + if not self.do_global_pooling: + self.pool2 = nn.AvgPool2d((1, Pool), stride=(1, max(1, Pool//5))) + else: + self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) + + self.drop1 = nn.Dropout(p) + self.flatten = nn.Flatten() + + def forward(self, x): + if self.freeze_temporal: + self.freeze_temporal -= 1 + self.conv1.requires_grad_(False) + else: + self.conv1.requires_grad_(True) + x = torch.unsqueeze(x, 1) + x = self.conv1(x) + x = self.conv2(x) + x = self.batch1(x) + x = torch.square(x) + if self.do_global_pooling: + x = self.global_pooling(x) + else: + x = self.pool2(x) + x = self.log_activation(x) + x = self.drop1(x) + x = self.flatten(x) + return x + + @torch.no_grad() + def _get_spatial_softmax(self): + return torch.softmax(self.conv2.weight, -2) + + @torch.no_grad() + def _get_spatial_zero(self): + return 
self.conv2.weight-torch.sum(self.conv2.weight,-2, keepdim=True) + + @torch.no_grad() + def _initialize_custom_temporal_filter(self, seed=None): + _reset_seed(seed) + if self.conv1.weight.shape[-1] >= 75: + bands = ( + ( 0.5, 4.0), # delta + ( 4.0, 8.0), # theta + ( 8.0, 12.0), # alpha + (12.0, 16.0), # beta1 + (16.0, 20.0), # beta2 + (20.0, 28.0), # beta3 + (28.0, 45.0) # gamma + ) + else: + bands = ( + ( 0.5, 8.0), + ( 8.0, 16.0), + (16.0, 28.0), + (28.0, 45.0) + ) + F, KernLength = self.conv1.weight.shape[0], self.conv1.weight.shape[-1] + comb = self._powerset(bands) + for i in range(min(F,len(comb))): # if F <= len(comb): + filt_coeff = firwin( + KernLength, + self._merge_tuples(comb[i]), + pass_zero=False, + fs=self.Fs + ) + self.conv1.weight.data[i,0,0] = torch.from_numpy(filt_coeff) + + @torch.no_grad() + def _powerset(self, s): + return tuple(chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1))) + + @torch.no_grad() + def _merge_tuples(self, tuples): + merged = [num for tup in tuples for num in tup] + merged = sorted(merged) + if len(merged)>2: + new_merged = [merged[0]] + for i in range(1, len(merged)-2, 2): + if merged[i] != merged[i+1]: + new_merged.append(merged[i]) + new_merged.append(merged[i+1]) + new_merged.append(merged[-1]) + return sorted(new_merged) + return merged + + @torch.no_grad() + def _combinatorial_op(self, N, k): + return int((math.factorial(N))/(math.factorial(k)*math.factorial(N-k))) diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 83f57bd..533d63a 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -15,6 +15,7 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, + xEEGNetEncoder ) from ..utils.utils import _reset_seed @@ -31,6 +32,7 @@ "StagerNet", "STNet", "TinySleepNet", + "xEEGNet" ] @@ -2027,6 +2029,179 @@ def __init__( def forward(self, x): x = self.encoder(x) x = self.MLP(x) + if not (self.return_logits): + if self.nb_classes <= 2: + x = torch.sigmoid(x) + else: + x = F.softmax(x, dim=1) + return x + + +class xEEGNet(nn.Module): + """ + Pytorch implementation of xEEGNet. + + For more information see the following paper [xEEG]_ . + The original implementation of EEGconformer can be found here [xEEGgit]_ . + The expected **input** is a **3D tensor** with size: + (Batch x Channels x Samples). + + Parameters + ---------- + nb_classes: int + The number of classes. If less than 2, a binary classification problem + is considered (output dimensions will be [batch, 1] in this case). + Chans: int + The number of EEG channels. + Samples: int + The sample length (number of time steps). + It will be used to calculate the embedding size (for head initialization). + Fs: int + The sampling rate of the EEG signal in Hz. + It is used to initialize the weights of the filters. + Must be specified even if `random_temporal_filter` is False. + F1: int, optional + The number of output filters in the temporal convolution layer. + + Default = 7 + K1: int, optional + The length of the temporal convolutional layer. + + Default = 125 + F2: int, optional + The number of output filters in the spatial convolution layer. + + Default = 7 + Pool: int, optional + Kernel size for temporal pooling. + + Default = 75 + p: float, optional + Dropout probability in [0,1) + + Default = 0.2 + random_temporal_filter: bool, optional + If True, initialize the temporal filter weights randomly. + Otherwise, use a passband FIR filter. + + Default = False + freeze_temporal: int, optional + Number of forward steps to keep the temporal layer frozen. 
+
+        Default = 1e12
+    spatial_depthwise: bool, optional
+        Whether to apply a depthwise layer in the spatial convolution.
+
+        Default = True
+    log_activation_base: str, optional
+        Base for the logarithmic activation after pooling.
+        Options: "e" (natural log), "10" (logarithm base 10), "dB" (decibel scale).
+
+        Default = "dB"
+    norm_type: str, optional
+        The type of normalization. Expected values are "batch" or "instance".
+
+        Default = "batchnorm"
+    global_pooling: bool, optional
+        If True, apply global average pooling instead of flattening.
+
+        Default = True
+    bias: list[int, int, int], optional
+        A 3-element list with boolean values.
+        If the first element is True, a bias will be added to the temporal
+        convolutional layer.
+        If the second element is True, a bias will be added to the spatial
+        convolutional layer.
+        If the third element is True, a bias will be added to the final dense layer.
+
+        Default = [False, False, False]
+    dense_hidden: int, optional
+        The number of units in the hidden layer of the classification head.
+        If a non-positive value is given, the head is a single linear layer.
+
+        Default = -1
+    return_logits: bool, optional
+        If True, return the output as logit.
+        It is suggested to not use False as the pytorch crossentropy loss function
+        applies the softmax internally.
+
+        Default = True
+    seed: int, optional
+        A custom seed for model initialization. It must be a nonnegative number.
+        If None is passed, no custom seed will be set
+
+        Default = None
+
+    References
+    ----------
+    .. [xEEG] Zanola et al., xEEGNet: Towards Explainable AI in EEG Dementia
+       Classification. arXiv preprint. 2025. https://doi.org/10.48550/arXiv.2504.21457
+    .. [xEEGgit] https://github.com/MedMaxLab/shallownetXAI
+
+    Example
+    -------
+    >>> import selfeeg.models
+    >>> import torch
+    >>> x = torch.randn(4,8,512)
+    >>> mdl = models.xEEGNet(3, 8, 512, 125)
+    >>> out = mdl(x)
+    >>> print(out.shape) # should return torch.Size([4, 3])
+
+    """
+    def __init__(
+        self,
+        nb_classes: int,
+        Chans: int,
+        Samples: int,
+        Fs: int,
+        F1: int = 7,
+        K1: int = 125,
+        F2: int = 7,
+        Pool: int = 75,
+        p: float = 0.2,
+        random_temporal_filter = False,
+        freeze_temporal: int = 1e12,
+        spatial_depthwise: bool = True,
+        log_activation_base: str = "dB",
+        norm_type: str = "batchnorm",
+        global_pooling = True,
+        bias: list[int, int, int] = [False, False, False],
+        dense_hidden: int = -1,
+        return_logits=True,
+        seed = None
+    ):
+
+        super(xEEGNet, self).__init__()
+
+        self.nb_classes = nb_classes
+        self.return_logits = return_logits
+        self.encoder = xEEGNetEncoder(
+            Chans, Fs, F1, K1, F2, Pool, p, random_temporal_filter,
+            freeze_temporal, spatial_depthwise, log_activation_base,
+            norm_type, global_pooling, bias, seed
+        )
+
+        if global_pooling:
+            self.emb_size = F2
+        else:
+            self.emb_size = F2 * ((Samples - K1 + 1 - Pool) // max(1,int(Pool//5)) + 1)
+
+        _reset_seed(seed)
+        if dense_hidden<=0:
+            self.Dense = nn.Linear(
+                self.emb_size,
+                1 if nb_classes <= 2 else nb_classes,
+                bias=bias[2]
+            )
+        else:
+            self.Dense = nn.Sequential(
+                nn.Linear(self.emb_size, dense_hidden, bias=True),
+                nn.ReLU(),
+                nn.Linear(
+                    dense_hidden,
+                    1 if nb_classes <= 2 else nb_classes,
+                    bias=bias[2]
+                )
+            )
+
+    def forward(self, x):
+        x = self.encoder(x)
+        x = self.Dense(x)
         if not (self.return_logits):
             if self.nb_classes <= 2:
                 x = torch.sigmoid(x)
diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py
index e8d1196..a92728b 100644
--- a/test/EEGself/models/zoo_test.py
+++ b/test/EEGself/models/zoo_test.py
@@ -487,6 +487,54 @@ def test_TinySleepNet(self):
                     self.assertGreaterEqual(out.min(), 0)
         print(" TinySleepNet OK: tested", len(EEGsleep_args), " combinations of input arguments")
 
+    def test_xEEGNet(self):
+ print("Testing xEEGNet...", end="", flush=True) + EEGxeg_args = { + "nb_classes": [4], + "Samples": [2048], + "Chans": [self.Chan], + "Fs": [125], + "F1": [7, 126], + "K1": [125, 75], + "F2": [7, 126], + "Pool": [75, 50], + "random_temporal_filter": [True, False], + "freeze_temporal": [0, 1e12], + "spatial_depthwise": [True, False], + "log_activation_base": ["dB"], + "norm_type": ["batchnorm"], + "global_pooling": [True, False], + "bias": [[False]*3], + "dense_hidden": [-1, 32], + "return_logits": [False], + "seed": [42] + } + + EEGxeg_args = self.makeGrid(EEGxeg_args) + for i in EEGxeg_args: + if i["F1"]>i["F2"] and i["spatial_depthwise"]: + continue + model = models.xEEGNet(**i) + out = model(self.x) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + + if self.device.type != "cpu": + for i in EEGxeg_args: + if i["F1"]>i["F2"] and i["spatial_depthwise"]: + continue + model = models.xEEGNet(**i).to(device=self.device) + out = model(self.x2) + self.assertEqual(torch.isnan(out).sum(), 0) + self.assertEqual(out.shape[1], i["nb_classes"] if i["nb_classes"] > 2 else 1) + if not (i["return_logits"]): + self.assertLessEqual(out.max(), 1) + self.assertGreaterEqual(out.min(), 0) + print(" xEEGNet OK: tested", len(EEGxeg_args), " combinations of input arguments") + if __name__ == "__main__": unittest.main() From 78683a0dbd3e5361542e4d4134a9bd9a9075af03 Mon Sep 17 00:00:00 2001 From: fedepup Date: Tue, 3 Jun 2025 09:49:45 +0000 Subject: [PATCH 3/4] add phase swap --- RELEASE.md | 7 +- docs/selfeeg.augmentation.rst | 1 + selfeeg/augmentation/__init__.py | 1 + selfeeg/augmentation/functional.py | 91 +++++++++++++++++++- test/EEGself/augmentation/functional_test.py | 23 +++++ 5 files changed, 118 insertions(+), 5 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index c30d59b..5f8b9c8 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,7 +1,12 @@ # Version X.X.X (only via git install) +# Version 0.2.1 (latest) + **Functionality** +- **augmentation module**: + - add Circular augmenter in compose module. + - add phase swap augmentation in functional module. - **models module**: - models can be initialized with a custom seed. - add EEGConformer. 
@@ -18,7 +23,7 @@ * reduced unittest overall time -# Version 0.2.0 (latest) +# Version 0.2.0 **Functionality** diff --git a/docs/selfeeg.augmentation.rst b/docs/selfeeg.augmentation.rst index 97a3e25..dbb4679 100644 --- a/docs/selfeeg.augmentation.rst +++ b/docs/selfeeg.augmentation.rst @@ -58,6 +58,7 @@ Functions moving_avg permutation_signal permute_channels + phase_swap random_FT_phase random_slope_scale scaling diff --git a/selfeeg/augmentation/__init__.py b/selfeeg/augmentation/__init__.py index f232afd..2c869cb 100644 --- a/selfeeg/augmentation/__init__.py +++ b/selfeeg/augmentation/__init__.py @@ -25,6 +25,7 @@ moving_avg, permutation_signal, permute_channels, + phase_swap, random_FT_phase, random_slope_scale, scaling, diff --git a/selfeeg/augmentation/functional.py b/selfeeg/augmentation/functional.py index d0a1dc3..5255e3e 100755 --- a/selfeeg/augmentation/functional.py +++ b/selfeeg/augmentation/functional.py @@ -35,6 +35,7 @@ "moving_avg", "permutation_signal", "permute_channels", + "phase_swap", "random_FT_phase", "random_slope_scale", "scaling", @@ -114,7 +115,7 @@ def shift_horizontal( batch_equal: bool = True, ) -> ArrayLike: """ - shifts temporally the elements of the ArrayLike object. + Shifts temporally the elements of the ArrayLike object. Shift is applied along the last dimension. The empty elements at beginning or the ending part @@ -438,6 +439,89 @@ def shift_frequency( return _shift_frequency(x, shift_freq, Fs, forward, random_shift, batch_equal, t, h) +def phase_swap(x: ArrayLike) -> ArrayLike: + """ + Apply the phase swap data augmentation to the ArrayLike object. + + The phase swap data augmentation consists in merging the amplitude + and phase components of biosignals from different sources to help + the model learn their coupling. + Specifically, the amplitude and phase of two randomly selected EEG samples + are extracted using the Fourier transform. + New samples are then generated by applying the inverse Fourier transform, + combining the amplitude from one sample with the phase from the other. + See the following paper for more information [phaseswap]_. + + Parameters + ---------- + x : ArrayLike + A 3-dimensional torch tensor or numpy array. + The last two dimensions must refer to the EEG (Channels x Samples). + + Returns + ------- + x: ArrayLike + The augmented version of the input Tensor or Array. + + Note + ---- + `Phase swap` ignores the class of each sample. + + + References + ---------- + .. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo. + "Boosting Generalization in Bio-signal Classification by + Learning the Phase-Amplitude Coupling". DAGM GCPR (2020). 
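+
+    Example
+    -------
+    A minimal usage sketch; shapes are illustrative, and the batch size should
+    be at least 2 so that at least one pair of samples is swapped:
+
+    >>> import torch
+    >>> import selfeeg.augmentation as aug
+    >>> x = torch.randn(16, 32, 512)
+    >>> xaug = aug.phase_swap(x)
+    >>> print(xaug.shape)  # should return torch.Size([16, 32, 512])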
+    
+    """
+
+    Ndim = len(x.shape)
+    assert Ndim==3, "x must be a 3-dimensional array or tensor"
+
+    N = x.shape[0]
+
+    if isinstance(x, torch.Tensor):
+        # Compute fft, module and phase
+        xfft = torch.fft.fft(x)
+        amplitude = xfft.abs()
+        phase = xfft.angle()
+        x_aug = torch.clone(xfft)
+
+        # Random shuffle indeces
+        idx_shuffle = torch.randperm(N).to(device=x.device)
+        idx_shuffle_1 = idx_shuffle[:(N//2)]
+        idx_shuffle_2 = idx_shuffle[(N//2):(N//2)*2]
+
+        # Apply phase swap
+        x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1]*torch.exp(1j*phase[idx_shuffle_2])
+        x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2]*torch.exp(1j*phase[idx_shuffle_1])
+
+        # Reconstruct the signal
+        x_aug = (torch.fft.ifft(x_aug)).real.to(device=x.device)
+
+    else:
+
+        xfft = np.fft.fft(x)
+        amplitude = np.abs(xfft)
+        phase = np.angle(xfft)
+        x_aug = np.copy(xfft)
+
+        # Random shuffle indeces
+        idx_shuffle = np.random.permutation(N)
+        idx_shuffle_1 = idx_shuffle[:(N//2)]
+        idx_shuffle_2 = idx_shuffle[(N//2):(N//2)*2]
+
+        # Apply phase swap
+        x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1]*np.exp(1j*phase[idx_shuffle_2])
+        x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2]*np.exp(1j*phase[idx_shuffle_1])
+
+        # Reconstruct the signal
+        x_aug = (np.fft.ifft(x_aug)).real
+
+    return x_aug
+
+
 def flip_vertical(x: ArrayLike) -> ArrayLike:
     """
     changes the sign of all the elements of the input.
@@ -456,9 +540,8 @@ def flip_vertical(x: ArrayLike) -> ArrayLike:
     -------
     >>> import torch
     >>> import selfeeg.augmentation as aug
-    >>> x = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
-    >>> xaug = aug.flip_vertical(x)
-    >>> print(torch.equal(xaug, x*(-1))) # should return True
+    >>> x = torch.randn(64,32,512)
+    >>> xaug = aug.flip_vertical(x)  # equal to x*(-1)
 
     """
     x_flip = x * (-1)
diff --git a/test/EEGself/augmentation/functional_test.py b/test/EEGself/augmentation/functional_test.py
index 92daed1..a22b259 100644
--- a/test/EEGself/augmentation/functional_test.py
+++ b/test/EEGself/augmentation/functional_test.py
@@ -210,6 +210,29 @@ def test_shift_frequency(self):
         self.assertTrue(math.isclose(per1[24], per2[104], rel_tol=1e-5))
         print(" shift frequency OK: tested", N + len(aug_args), "combinations of input arguments")
 
+    def test_phase_swap(self):
+        print("Testing phase swap...", end="", flush=True)
+        aug_args = {
+            "x": [self.x3, self.x3np]
+        }
+        aug_args = self.makeGrid(aug_args)
+        for i in aug_args:
+            xaug = aug.phase_swap(**i)
+            if isinstance(xaug, torch.Tensor):
+                self.assertTrue(torch.isnan(xaug).sum() == 0)
+                self.assertFalse(torch.equal(i["x"], xaug))
+            else:
+                self.assertTrue(np.isnan(xaug).sum() == 0)
+                self.assertFalse(np.array_equal(i["x"], xaug))
+        N = len(aug_args)
+        if self.device.type != "cpu":
+            aug_args = {"x": [self.x3gpu]}
+            aug_args = self.makeGrid(aug_args)
+            for i in aug_args:
+                xaug = aug.phase_swap(**i)
+
+        print(" phase_swap OK: tested", N + len(aug_args), "combinations of input arguments")
+
     def test_flip_vertical(self):
         print("Testing flip vertical...", end="", flush=True)
         aug_args = {

From c7561f65ff78fb9baf97edf3d400b5f17e8d6e65 Mon Sep 17 00:00:00 2001
From: fedepup
Date: Tue, 3 Jun 2025 12:05:10 +0200
Subject: [PATCH 4/4] double-check before new version

---
 docs/selfeeg.models.rst | 1 -
 selfeeg/augmentation/functional.py | 35 +-
 selfeeg/models/__init__.py | 4 +-
 selfeeg/models/encoders.py | 298 +++++++++---------
 selfeeg/models/zoo.py | 329 ++++++++++---------
 selfeeg/utils/utils.py | 17 +-
 test/EEGself/augmentation/functional_test.py | 6 +-
 test/EEGself/models/zoo_test.py | 36 +-
 8 files changed, 379 
insertions(+), 347 deletions(-) diff --git a/docs/selfeeg.models.rst b/docs/selfeeg.models.rst index 773f6d0..470ae7c 100644 --- a/docs/selfeeg.models.rst +++ b/docs/selfeeg.models.rst @@ -78,4 +78,3 @@ Classes STNet TinySleepNet xEEGNet - diff --git a/selfeeg/augmentation/functional.py b/selfeeg/augmentation/functional.py index 5255e3e..cf3ee04 100755 --- a/selfeeg/augmentation/functional.py +++ b/selfeeg/augmentation/functional.py @@ -445,7 +445,7 @@ def phase_swap(x: ArrayLike) -> ArrayLike: The phase swap data augmentation consists in merging the amplitude and phase components of biosignals from different sources to help - the model learn their coupling. + the model learn their coupling. Specifically, the amplitude and phase of two randomly selected EEG samples are extracted using the Fourier transform. New samples are then generated by applying the inverse Fourier transform, @@ -465,7 +465,7 @@ def phase_swap(x: ArrayLike) -> ArrayLike: Note ---- - `Phase swap` ignores the class of each sample. + `Phase swap` ignores the class of each sample. References @@ -473,35 +473,36 @@ def phase_swap(x: ArrayLike) -> ArrayLike: .. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo. "Boosting Generalization in Bio-signal Classification by Learning the Phase-Amplitude Coupling". DAGM GCPR (2020). - + """ Ndim = len(x.shape) - assert Ndim==3, "x must be a 3-dimensional array or tensor" - + if Ndim != 3: + raise ValueError("x must be a 3-dimensional array or tensor") + N = x.shape[0] - + if isinstance(x, torch.Tensor): # Compute fft, module and phase xfft = torch.fft.fft(x) amplitude = xfft.abs() phase = xfft.angle() x_aug = torch.clone(xfft) - + # Random shuffle indeces idx_shuffle = torch.randperm(N).to(device=x.device) - idx_shuffle_1 = idx_shuffle[:(N//2)] - idx_shuffle_2 = idx_shuffle[(N//2):(N//2)*2] + idx_shuffle_1 = idx_shuffle[: (N // 2)] + idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2] # Apply phase swap - x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1]*torch.exp(1j*phase[idx_shuffle_2]) - x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2]*torch.exp(1j*phase[idx_shuffle_1]) + x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * torch.exp(1j * phase[idx_shuffle_2]) + x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * torch.exp(1j * phase[idx_shuffle_1]) # Reconstruct the signal x_aug = (torch.fft.ifft(x_aug)).real.to(device=x.device) else: - + xfft = np.fft.fft(x) amplitude = np.abs(xfft) phase = np.angle(xfft) @@ -509,13 +510,13 @@ def phase_swap(x: ArrayLike) -> ArrayLike: # Random shuffle indeces idx_shuffle = np.random.permutation(N) - idx_shuffle_1 = idx_shuffle[:(N//2)] - idx_shuffle_2 = idx_shuffle[(N//2):(N//2)*2] + idx_shuffle_1 = idx_shuffle[: (N // 2)] + idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2] # Apply phase swap - x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1]*np.exp(1j*phase[idx_shuffle_2]) - x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2]*np.exp(1j*phase[idx_shuffle_1]) - + x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * np.exp(1j * phase[idx_shuffle_2]) + x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * np.exp(1j * phase[idx_shuffle_1]) + # Reconstruct the signal x_aug = (np.fft.ifft(x_aug)).real diff --git a/selfeeg/models/__init__.py b/selfeeg/models/__init__.py index 6923e96..5a336c8 100644 --- a/selfeeg/models/__init__.py +++ b/selfeeg/models/__init__.py @@ -18,7 +18,7 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, - xEEGNetEncoder + xEEGNetEncoder, ) from .zoo import ( @@ -34,5 +34,5 @@ StagerNet, STNet, TinySleepNet, - xEEGNet + xEEGNet, ) diff 
--git a/selfeeg/models/encoders.py b/selfeeg/models/encoders.py index 42ddf6a..5b382dd 100644 --- a/selfeeg/models/encoders.py +++ b/selfeeg/models/encoders.py @@ -26,7 +26,7 @@ "StagerNetEncoder", "STNetEncoder", "TinySleepNetEncoder", - "xEEGNetEncoder" + "xEEGNetEncoder", ] @@ -113,18 +113,18 @@ class EEGNetEncoder(nn.Module): def __init__( self, Chans: int, - kernLength: int=64, - dropRate: float=0.5, - F1: int=8, - D: int=2, - F2: int=16, - dropType: str="Dropout", - ELUalpha: int=1, - pool1: int=4, - pool2: int=8, - separable_kernel: int=16, - depthwise_max_norm: float=1.0, - seed: int=None + kernLength: int = 64, + dropRate: float = 0.5, + F1: int = 8, + D: int = 2, + F2: int = 16, + dropType: str = "Dropout", + ELUalpha: int = 1, + pool1: int = 4, + pool2: int = 8, + separable_kernel: int = 16, + depthwise_max_norm: float = 1.0, + seed: int = None, ): if dropType not in ["SpatialDropout2D", "Dropout"]: @@ -253,20 +253,20 @@ class DeepConvNetEncoder(nn.Module): def __init__( self, Chans: int, - kernLength: int=10, - F: int=25, - Pool: int=3, - stride: int=3, - max_norm: float=None, - batch_momentum: float=0.1, - ELUalpha: int=1, - dropRate: float=0.5, - seed: int=None + kernLength: int = 10, + F: int = 25, + Pool: int = 3, + stride: int = 3, + max_norm: float = None, + batch_momentum: float = 0.1, + ELUalpha: int = 1, + dropRate: float = 0.5, + seed: int = None, ): super(DeepConvNetEncoder, self).__init__() _reset_seed(seed) - + self.conv1 = ConstrainedConv2d( 1, F, (1, kernLength), padding="valid", stride=(1, 1), max_norm=max_norm ) @@ -406,21 +406,21 @@ class EEGInceptionEncoder(nn.Module): def __init__( self, Chans: int, - F1: int=8, - D: int=2, - kernel_size: int=64, - pool: int=4, - dropRate: float=0.5, - ELUalpha: float=1.0, - bias: bool=True, - batch_momentum: float=0.1, - max_depth_norm: float=1.0, - seed: int=None + F1: int = 8, + D: int = 2, + kernel_size: int = 64, + pool: int = 4, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + batch_momentum: float = 0.1, + max_depth_norm: float = 1.0, + seed: int = None, ): super(EEGInceptionEncoder, self).__init__() _reset_seed(seed) - + self.inc1 = nn.Sequential( nn.Conv2d(1, F1, (1, kernel_size), padding="same", bias=bias), nn.BatchNorm2d(F1, momentum=batch_momentum), @@ -588,18 +588,18 @@ def __init__( self, Chans: int, Fs: int, - F: int=128, - kernlength: int=8, - pool: int=8, - dropRate: float=0.5, - batch_momentum: float=0.1, - hidden_lstm: int=128, - seed: int=None + F: int = 128, + kernlength: int = 8, + pool: int = 8, + dropRate: float = 0.5, + batch_momentum: float = 0.1, + hidden_lstm: int = 128, + seed: int = None, ): super(TinySleepNetEncoder, self).__init__() _reset_seed(seed) - + self.conv1 = nn.Conv1d(Chans, F, int(Fs // 2), stride=int(Fs // 16), padding="valid") self.BN1 = nn.BatchNorm1d(F, momentum=batch_momentum) self.Relu = nn.ReLU() @@ -692,7 +692,7 @@ class StagerNetEncoder(nn.Module): """ - def __init__(self, Chans, kernLength: int=64, F: int=8, Pool: int=16, seed: int=None): + def __init__(self, Chans, kernLength: int = 64, F: int = 8, Pool: int = 16, seed: int = None): super(StagerNetEncoder, self).__init__() _reset_seed(seed) @@ -775,11 +775,13 @@ class ShallowNetEncoder(nn.Module): """ - def __init__(self, Chans, F: int=40, K1: int=25, Pool: int=75, p: float=0.2, seed: int=None): + def __init__( + self, Chans, F: int = 40, K1: int = 25, Pool: int = 75, p: float = 0.2, seed: int = None + ): super(ShallowNetEncoder, self).__init__() _reset_seed(seed) - + self.conv1 = nn.Conv2d(1, F, 
(1, K1), stride=(1, 1)) self.conv2 = nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)) self.batch1 = nn.BatchNorm2d(F) @@ -811,7 +813,7 @@ class BasicBlock1(nn.Module): :meta private: """ - def __init__(self, inplanes: int, planes: int, kernLength: int=7, stride: int=1): + def __init__(self, inplanes: int, planes: int, kernLength: int = 7, stride: int = 1): super(BasicBlock1, self).__init__() self.stride = stride @@ -960,12 +962,12 @@ def __init__( addConnection: bool = False, preBlock: nn.Module = None, postBlock: nn.Module = None, - seed: int=None + seed: int = None, ): super(ResNet1DEncoder, self).__init__() _reset_seed(seed) - + self.inplane = inplane self.kernLength = kernLength self.connection = addConnection @@ -1182,11 +1184,11 @@ class STNetEncoder(nn.Module): def __init__( self, Samples, - F: int=256, - kernlength: int=5, - dropRate: float=0.5, - bias: bool=True, - seed: int=None + F: int = 256, + kernlength: int = 5, + dropRate: float = 0.5, + bias: bool = True, + seed: int = None, ): super(STNetEncoder, self).__init__() _reset_seed(seed) @@ -1273,6 +1275,7 @@ class EEGSymInception(nn.Module): """ :meta private: """ + def __init__( self, in_channels, @@ -1341,6 +1344,7 @@ class EEGSymResBlock(nn.Module): """ :meta private: """ + def __init__( self, in_channels, @@ -1471,16 +1475,16 @@ def __init__( Chans: int, Samples: int, Fs: float, - scales_time: tuple=(500, 250, 125), - lateral_chans: int=3, - first_left: bool=True, - F: int=8, - pool: int=2, - dropRate: float=0.5, - ELUalpha: float=1.0, - bias: bool=True, - residual: bool=True, - seed: int=None + scales_time: tuple = (500, 250, 125), + lateral_chans: int = 3, + first_left: bool = True, + F: int = 8, + pool: int = 2, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + residual: bool = True, + seed: int = None, ): super(EEGSymEncoder, self).__init__() _reset_seed(seed) @@ -1865,7 +1869,7 @@ def __init__( TemporalStride: int = 4, batch_momentum: float = 0.1, depthwise_max_norm=None, - seed: int=None + seed: int = None, ): super(FBCNetEncoder, self).__init__() self.FilterBands = FilterBands @@ -2003,41 +2007,41 @@ class EEGConformerEncoder(nn.Module): def __init__( self, Chans, - F: int=40, - K1: int=25, - Pool: int=75, - stride_pool: int=15, - d_model: int=40, - nlayers: int=6, - nheads: int=10, - dim_feedforward: int=160, - activation_transformer: str or Callable="gelu", - p: float=0.2, - p_transformer: float=0.5, - seed: int=None + F: int = 40, + K1: int = 25, + Pool: int = 75, + stride_pool: int = 15, + d_model: int = 40, + nlayers: int = 6, + nheads: int = 10, + dim_feedforward: int = 160, + activation_transformer: str or Callable = "gelu", + p: float = 0.2, + p_transformer: float = 0.5, + seed: int = None, ): super(EEGConformerEncoder, self).__init__() _reset_seed(seed) - + self.tokenizer = nn.Sequential( nn.Conv2d(1, F, (1, K1), stride=(1, 1)), nn.Conv2d(F, F, (Chans, 1), stride=(1, 1)), nn.BatchNorm2d(F), nn.AvgPool2d((1, Pool), stride=(1, stride_pool)), - nn.Dropout(p) + nn.Dropout(p), ) self.projection = nn.Conv2d(F, d_model, (1, 1)) self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer( d_model, - nhead = nheads, - dim_feedforward = dim_feedforward, - dropout = p_transformer, - activation = "gelu", - batch_first = True + nhead=nheads, + dim_feedforward=dim_feedforward, + dropout=p_transformer, + activation="gelu", + batch_first=True, ), - num_layers = nlayers, + num_layers=nlayers, ) def forward(self, x): @@ -2047,9 +2051,9 @@ def forward(self, x): x = torch.unsqueeze(x, 1) x = 
self.tokenizer(x) x = x.squeeze(2) - x = torch.permute(x,[0,2,1]) + x = torch.permute(x, [0, 2, 1]) x = self.transformer(x) - x = torch.permute(x,[0,2,1]) + x = torch.permute(x, [0, 2, 1]) return x @@ -2058,8 +2062,8 @@ class xEEGNetEncoder(nn.Module): Pytorch implementation of the xEEGNet Encoder. See xEEGNet for some references. - The expected **input** is a **3D tensor** with size: - (Batch x Channels x Samples). + The expected **input** is a **3D tensor** with size + (Batch x Channels x Samples). Parameters ---------- @@ -2068,11 +2072,11 @@ class xEEGNetEncoder(nn.Module): Fs: int The sampling rate of the EEG signal in Hz. It is used to initialize the weights of the filters. - Must be specified even if `random_temporal_filter` is False. + Must be specified even if `random_temporal_filter` is False. F1: int, optional The number of output filters in the temporal convolution layer. - Default = 7 + Default = 7 K1: int, optional The length of the temporal convolutional layer. @@ -2137,8 +2141,9 @@ class xEEGNetEncoder(nn.Module): >>> mdl = models.xEEGNetEncoder(8, 125) >>> out = mdl(x) >>> print(out.shape) # shoud return torch.Size([4, 7]) - + """ + def __init__( self, Chans: int, @@ -2148,14 +2153,14 @@ def __init__( F2: int = 7, Pool: int = 75, p: float = 0.2, - random_temporal_filter = False, + random_temporal_filter=False, freeze_temporal: int = 1e12, spatial_depthwise: bool = True, log_activation_base: str = "dB", norm_type: str = "batchnorm", - global_pooling = True, + global_pooling=True, bias: list[int, int] = [False, False], - seed: int = None + seed: int = None, ): super(xEEGNetEncoder, self).__init__() @@ -2163,23 +2168,23 @@ def __init__( # Set seed before initializing layers self.custom_seed = seed _reset_seed(seed) - + self.Fs = Fs self.chans = Chans self.freeze_temporal = freeze_temporal self.bias_1conv = bias[0] self.bias_2conv = bias[1] self.do_global_pooling = global_pooling - if self.Fs <=0 and not(random_temporal_filter): + if self.Fs <= 0 and not (random_temporal_filter): raise ValueError( "to properly initialize non random temporal fir filters, " "Fs (sampling rate) must be given" ) - + if random_temporal_filter: self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1, 1), bias=self.bias_1conv) else: - self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1,1), bias=self.bias_1conv) + self.conv1 = nn.Conv2d(1, F1, (1, K1), stride=(1, 1), bias=self.bias_1conv) self._initialize_custom_temporal_filter(self.custom_seed) if spatial_depthwise: @@ -2187,45 +2192,44 @@ def __init__( F1, F2, (Chans, 1), stride=(1, 1), groups=F1, bias=self.bias_2conv ) else: - self.conv2 = nn.Conv2d( - F1, F2, (Chans, 1), stride=(1, 1), bias = self.bias_2conv - ) - - if "batch" in norm_type.casefold(): - self.batch1 = nn.BatchNorm2d(F2,affine=True) + self.conv2 = nn.Conv2d(F1, F2, (Chans, 1), stride=(1, 1), bias=self.bias_2conv) + + if "batch" in norm_type.casefold(): + self.batch1 = nn.BatchNorm2d(F2, affine=True) elif "instance" in norm_type.casefold(): self.batch1 = nn.InstanceNorm2d(F2) else: - raise ValueError( - "normalization layer type can be 'batchnorm' or 'instancenorm'" - ) - - if log_activation_base in ["e", torch.e]: + raise ValueError("normalization layer type can be 'batchnorm' or 'instancenorm'") + + if log_activation_base in ["e", torch.e]: self.log_activation = lambda x: torch.log(torch.clamp(x, 1e-7, 1e4)) elif log_activation_base in ["10", 10]: self.log_activation = lambda x: torch.log10(torch.clamp(x, 1e-7, 1e4)) elif log_activation_base in ["db", "dB"]: - self.log_activation = lambda x: 
10*torch.log10(torch.clamp(x, 1e-7, 1e4)) + self.log_activation = lambda x: 10 * torch.log10(torch.clamp(x, 1e-7, 1e4)) else: raise ValueError( "allowed activation base are 'e' for torch.log, " "'10' for torch.log10, and 'dB' for 10*torch.log10" ) - + if not self.do_global_pooling: - self.pool2 = nn.AvgPool2d((1, Pool), stride=(1, max(1, Pool//5))) + self.pool2 = nn.AvgPool2d((1, Pool), stride=(1, max(1, Pool // 5))) else: self.global_pooling = nn.AdaptiveAvgPool2d((1, 1)) - + self.drop1 = nn.Dropout(p) self.flatten = nn.Flatten() def forward(self, x): + """ + :meta private: + """ if self.freeze_temporal: self.freeze_temporal -= 1 self.conv1.requires_grad_(False) else: - self.conv1.requires_grad_(True) + self.conv1.requires_grad_(True) x = torch.unsqueeze(x, 1) x = self.conv1(x) x = self.conv2(x) @@ -2242,61 +2246,71 @@ def forward(self, x): @torch.no_grad() def _get_spatial_softmax(self): + """ + :meta private: + """ return torch.softmax(self.conv2.weight, -2) @torch.no_grad() def _get_spatial_zero(self): - return self.conv2.weight-torch.sum(self.conv2.weight,-2, keepdim=True) - + """ + :meta private: + """ + return self.conv2.weight - torch.sum(self.conv2.weight, -2, keepdim=True) + @torch.no_grad() def _initialize_custom_temporal_filter(self, seed=None): + """ + :meta private: + """ _reset_seed(seed) if self.conv1.weight.shape[-1] >= 75: bands = ( - ( 0.5, 4.0), # delta - ( 4.0, 8.0), # theta - ( 8.0, 12.0), # alpha - (12.0, 16.0), # beta1 - (16.0, 20.0), # beta2 - (20.0, 28.0), # beta3 - (28.0, 45.0) # gamma + (0.5, 4.0), # delta + (4.0, 8.0), # theta + (8.0, 12.0), # alpha + (12.0, 16.0), # beta1 + (16.0, 20.0), # beta2 + (20.0, 28.0), # beta3 + (28.0, 45.0), # gamma ) else: - bands = ( - ( 0.5, 8.0), - ( 8.0, 16.0), - (16.0, 28.0), - (28.0, 45.0) - ) + bands = ((0.5, 8.0), (8.0, 16.0), (16.0, 28.0), (28.0, 45.0)) F, KernLength = self.conv1.weight.shape[0], self.conv1.weight.shape[-1] comb = self._powerset(bands) - for i in range(min(F,len(comb))): # if F <= len(comb): + for i in range(min(F, len(comb))): # if F <= len(comb): filt_coeff = firwin( - KernLength, - self._merge_tuples(comb[i]), - pass_zero=False, - fs=self.Fs + KernLength, self._merge_tuples(comb[i]), pass_zero=False, fs=self.Fs ) - self.conv1.weight.data[i,0,0] = torch.from_numpy(filt_coeff) + self.conv1.weight.data[i, 0, 0] = torch.from_numpy(filt_coeff) @torch.no_grad() def _powerset(self, s): - return tuple(chain.from_iterable(combinations(s, r) for r in range(1, len(s)+1))) + """ + :meta private: + """ + return tuple(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1))) @torch.no_grad() def _merge_tuples(self, tuples): + """ + :meta private: + """ merged = [num for tup in tuples for num in tup] merged = sorted(merged) - if len(merged)>2: + if len(merged) > 2: new_merged = [merged[0]] - for i in range(1, len(merged)-2, 2): - if merged[i] != merged[i+1]: + for i in range(1, len(merged) - 2, 2): + if merged[i] != merged[i + 1]: new_merged.append(merged[i]) - new_merged.append(merged[i+1]) + new_merged.append(merged[i + 1]) new_merged.append(merged[-1]) - return sorted(new_merged) + return sorted(new_merged) return merged @torch.no_grad() def _combinatorial_op(self, N, k): - return int((math.factorial(N))/(math.factorial(k)*math.factorial(N-k))) + """ + :meta private: + """ + return int((math.factorial(N)) / (math.factorial(k) * math.factorial(N - k))) diff --git a/selfeeg/models/zoo.py b/selfeeg/models/zoo.py index 533d63a..2c9b94a 100755 --- a/selfeeg/models/zoo.py +++ b/selfeeg/models/zoo.py @@ -15,7 
+15,7 @@ StagerNetEncoder, STNetEncoder, TinySleepNetEncoder, - xEEGNetEncoder + xEEGNetEncoder, ) from ..utils.utils import _reset_seed @@ -32,7 +32,7 @@ "StagerNet", "STNet", "TinySleepNet", - "xEEGNet" + "xEEGNet", ] @@ -141,20 +141,20 @@ def __init__( nb_classes: int, Chans: int, Samples: int, - kernLength: int=64, - dropRate: float=0.5, - F1: int=8, - D: int=2, - F2: int=16, - norm_rate: int=0.25, - dropType: str="Dropout", - ELUalpha: int=1, - pool1: int=4, - pool2: int=8, - separable_kernel: int=16, - depthwise_max_norm: float=1.0, - return_logits: bool=True, - seed: int=None + kernLength: int = 64, + dropRate: float = 0.5, + F1: int = 8, + D: int = 2, + F2: int = 16, + norm_rate: int = 0.25, + dropType: str = "Dropout", + ELUalpha: int = 1, + pool1: int = 4, + pool2: int = 8, + separable_kernel: int = 16, + depthwise_max_norm: float = 1.0, + return_logits: bool = True, + seed: int = None, ): super(EEGNet, self).__init__() @@ -174,7 +174,7 @@ def __init__( pool2, separable_kernel, depthwise_max_norm, - seed + seed, ) _reset_seed(seed) @@ -302,17 +302,17 @@ def __init__( nb_classes: int, Chans: int, Samples: int, - kernLength: int=10, - F: int=25, - Pool: int=3, - stride: int=3, - max_norm: int=None, - batch_momentum: float=0.1, - ELUalpha: int=1, - dropRate: float=0.5, - max_dense_norm: float=None, - return_logits: bool=True, - seed: int=None + kernLength: int = 10, + F: int = 25, + Pool: int = 3, + stride: int = 3, + max_norm: int = None, + batch_momentum: float = 0.1, + ELUalpha: int = 1, + dropRate: float = 0.5, + max_dense_norm: float = None, + return_logits: bool = True, + seed: int = None, ): super(DeepConvNet, self).__init__() @@ -445,17 +445,17 @@ def __init__( nb_classes: int, Chans: int, Samples: int, - F1: int=8, - D: int=2, - kernel_size: int=64, - pool: int=4, - dropRate: float=0.5, - ELUalpha: float=1.0, - bias: bool=True, - batch_momentum: float=0.1, - max_depth_norm: float=1.0, - return_logits: bool=True, - seed: int=None + F1: int = 8, + D: int = 2, + kernel_size: int = 64, + pool: int = 4, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + batch_momentum: float = 0.1, + max_depth_norm: float = 1.0, + return_logits: bool = True, + seed: int = None, ): super(EEGInception, self).__init__() self.nb_classes = nb_classes @@ -471,7 +471,7 @@ def __init__( bias, batch_momentum, max_depth_norm, - seed + seed, ) _reset_seed(seed) @@ -585,15 +585,15 @@ def __init__( nb_classes: int, Chans: int, Fs: int, - F: int=128, - kernlength: int=8, - pool: int=8, - dropRate: float=0.5, - batch_momentum: float=0.1, - max_dense_norm: float=2.0, - hidden_lstm: int=128, - return_logits: bool=True, - seed: int=None + F: int = 128, + kernlength: int = 8, + pool: int = 8, + dropRate: float = 0.5, + batch_momentum: float = 0.1, + max_dense_norm: float = 2.0, + hidden_lstm: int = 128, + return_logits: bool = True, + seed: int = None, ): super(TinySleepNet, self).__init__() @@ -695,20 +695,18 @@ def __init__( nb_classes: int, Chans: int, Samples: int, - dropRate: float=0.5, - kernLength: int=64, - F: int=8, - Pool: int=16, - return_logits: bool=True, - seed: int=None + dropRate: float = 0.5, + kernLength: int = 64, + F: int = 8, + Pool: int = 16, + return_logits: bool = True, + seed: int = None, ): super(StagerNet, self).__init__() self.nb_classes = nb_classes self.return_logits = return_logits - self.encoder = StagerNetEncoder( - Chans, kernLength=kernLength, F=F, Pool=Pool, seed=seed - ) + self.encoder = StagerNetEncoder(Chans, kernLength=kernLength, F=F, Pool=Pool, 
seed=seed) _reset_seed(seed) self.drop = nn.Dropout(p=dropRate) @@ -810,12 +808,12 @@ def __init__( nb_classes: int, Chans: int, Samples: int, - F: int=40, - K1: int=25, - Pool: int=75, - p: float=0.2, - return_logits: bool=True, - seed: int=None + F: int = 40, + K1: int = 25, + Pool: int = 75, + p: float = 0.2, + return_logits: bool = True, + seed: int = None, ): super(ShallowNet, self).__init__() @@ -969,8 +967,8 @@ def __init__( preBlock: nn.Module = None, postBlock: nn.Module = None, classifier: nn.Module = None, - return_logits: bool=True, - seed: int=None + return_logits: bool = True, + seed: int = None, ): super(ResNet1D, self).__init__() @@ -986,9 +984,9 @@ def __init__( addConnection=addConnection, preBlock=preBlock, postBlock=postBlock, - seed=seed + seed=seed, ) - + # Classifier _reset_seed(seed) if classifier is None: @@ -1099,14 +1097,14 @@ def __init__( self, nb_classes: int, Samples: int, - grid_size: int=9, - F: int=256, - kernlength: int=5, - dropRate: float=0.5, - bias: bool=True, - dense_size: int=1024, - return_logits: bool=True, - seed: int=None + grid_size: int = 9, + F: int = 256, + kernlength: int = 5, + dropRate: float = 0.5, + bias: bool = True, + dense_size: int = 1024, + return_logits: bool = True, + seed: int = None, ): super(STNet, self).__init__() @@ -1240,17 +1238,17 @@ def __init__( Chans: int, Samples: int, Fs: int, - scales_time: tuple=(500, 250, 125), - lateral_chans: int=3, - first_left: bool=True, - F: int=8, - pool: int=2, - dropRate: float=0.5, - ELUalpha: float=1.0, - bias: bool=True, - residual: bool=True, - return_logits: bool=True, - seed: int=None + scales_time: tuple = (500, 250, 125), + lateral_chans: int = 3, + first_left: bool = True, + F: int = 8, + pool: int = 2, + dropRate: float = 0.5, + ELUalpha: float = 1.0, + bias: bool = True, + residual: bool = True, + return_logits: bool = True, + seed: int = None, ): super(EEGSym, self).__init__() self.nb_classes = nb_classes @@ -1268,7 +1266,7 @@ def __init__( ELUalpha, bias, residual, - seed=seed + seed=seed, ) _reset_seed(seed) @@ -1438,7 +1436,7 @@ def __init__( linear_max_norm: float = None, classifier: nn.Module = None, return_logits: bool = True, - seed: int=None + seed: int = None, ): super(FBCNet, self).__init__() @@ -1462,9 +1460,9 @@ def __init__( TemporalStride, batch_momentum, depthwise_max_norm, - seed=seed + seed=seed, ) - + # Head _reset_seed(seed) if classifier is None: @@ -1751,26 +1749,26 @@ def __init__( Chans: int, Samples: int, Fs: float, - num_windows: int=5, - mha_heads: int=2, - tcn_depth: int=2, - F1: int=16, - D: int=2, - pool1: int=None, - pool2: int=None, - dropRate: float=0.3, - max_norm: float=None, - batchMomentum: float=0.1, - ELUAlpha: float=1.0, - mha_dropRate: float=0.5, - tcn_kernLength: int=4, - tcn_F: int=32, - tcn_ELUAlpha: float=0.0, - tcn_dropRate: float=0.3, - tcn_max_norm: float=None, - tcn_batchMom: float=0.1, - return_logits: bool=True, - seed: int=None + num_windows: int = 5, + mha_heads: int = 2, + tcn_depth: int = 2, + F1: int = 16, + D: int = 2, + pool1: int = None, + pool2: int = None, + dropRate: float = 0.3, + max_norm: float = None, + batchMomentum: float = 0.1, + ELUAlpha: float = 1.0, + mha_dropRate: float = 0.5, + tcn_kernLength: int = 4, + tcn_F: int = 32, + tcn_ELUAlpha: float = 0.0, + tcn_dropRate: float = 0.3, + tcn_max_norm: float = None, + tcn_batchMom: float = 0.1, + return_logits: bool = True, + seed: int = None, ): super(ATCNet, self).__init__() @@ -1892,7 +1890,7 @@ class EEGConformer(nn.Module): For more information see the 
following paper [EEGcon]_ . The original implementation of EEGconformer can be found here [EEGcongit]_ . - + The expected **input** is a **3D tensor** with size (Batch x Channels x Samples). @@ -1982,34 +1980,46 @@ class EEGConformer(nn.Module): >>> print(out.shape) # shoud return torch.Size([4, 1]) """ + def __init__( self, nb_classes: int, Chans: int, Samples: int, - F: int=40, - K1: int=25, - Pool: int=75, - stride_pool: int=15, - d_model: int=40, - nlayers: int=6, - nheads: int=10, - dim_feedforward: int=160, - activation_transformer: str or Callable="gelu", - p: float=0.2, - p_transformer: float=0.5, - mlp_dim: list[int,int]=[256,32], - return_logits: bool=True, - seed: int=None + F: int = 40, + K1: int = 25, + Pool: int = 75, + stride_pool: int = 15, + d_model: int = 40, + nlayers: int = 6, + nheads: int = 10, + dim_feedforward: int = 160, + activation_transformer: str or Callable = "gelu", + p: float = 0.2, + p_transformer: float = 0.5, + mlp_dim: list[int, int] = [256, 32], + return_logits: bool = True, + seed: int = None, ): super(EEGConformer, self).__init__() self.return_logits = return_logits self.nb_classes = nb_classes - + self.encoder = EEGConformerEncoder( - Chans, F, K1, Pool, stride_pool, d_model, nlayers, nheads, - dim_feedforward, activation_transformer, p, p_transformer, seed + Chans, + F, + K1, + Pool, + stride_pool, + d_model, + nlayers, + nheads, + dim_feedforward, + activation_transformer, + p, + p_transformer, + seed, ) _reset_seed(seed) @@ -2020,13 +2030,16 @@ def __init__( nn.Linear(d_model, mlp_dim[0]), nn.ELU(), nn.Dropout(p), - nn.Linear(mlp_dim[0], mlp_dim[1]), + nn.Linear(mlp_dim[0], mlp_dim[1]), nn.ELU(), nn.Dropout(p), - nn.Linear(mlp_dim[1], 1 if nb_classes <= 2 else nb_classes) + nn.Linear(mlp_dim[1], 1 if nb_classes <= 2 else nb_classes), ) def forward(self, x): + """ + :meta private: + """ x = self.encoder(x) x = self.MLP(x) if not (self.return_logits): @@ -2059,11 +2072,11 @@ class xEEGNet(nn.Module): Fs: int The sampling rate of the EEG signal in Hz. It is used to initialize the weights of the filters. - Must be specified even if `random_temporal_filter` is False. + Must be specified even if `random_temporal_filter` is False. F1: int, optional The number of output filters in the temporal convolution layer. - Default = 7 + Default = 7 K1: int, optional The length of the temporal convolutional layer. 
@@ -2141,8 +2154,9 @@ class xEEGNet(nn.Module):
     >>> mdl = models.xEEGNet(3, 8, 512, 125)
     >>> out = mdl(x)
     >>> print(out.shape) # should return torch.Size([4, 3])
-    
+
     """
+
     def __init__(
         self,
         nb_classes: int,
@@ -2154,16 +2168,16 @@ def __init__(
         F2: int = 7,
         Pool: int = 75,
         p: float = 0.2,
-        random_temporal_filter = False,
+        random_temporal_filter=False,
         freeze_temporal: int = 1e12,
         spatial_depthwise: bool = True,
         log_activation_base: str = "dB",
         norm_type: str = "batchnorm",
-        global_pooling = True,
+        global_pooling=True,
         bias: list[int, int, int] = [False, False, False],
         dense_hidden: int = -1,
         return_logits=True,
-        seed = None
+        seed=None,
     ):
         super(xEEGNet, self).__init__()
@@ -2171,35 +2185,44 @@
         self.nb_classes = nb_classes
         self.return_logits = return_logits
         self.encoder = xEEGNetEncoder(
-            Chans, Fs, F1, K1, F2, Pool, p, random_temporal_filter,
-            freeze_temporal, spatial_depthwise, log_activation_base,
-            norm_type, global_pooling, bias, seed
+            Chans,
+            Fs,
+            F1,
+            K1,
+            F2,
+            Pool,
+            p,
+            random_temporal_filter,
+            freeze_temporal,
+            spatial_depthwise,
+            log_activation_base,
+            norm_type,
+            global_pooling,
+            bias,
+            seed,
         )
-        
+
         if global_pooling:
             self.emb_size = F2
         else:
-            self.emb_size = F2 * ((Samples - K1 + 1 - Pool) // max(1,int(Pool//5)) + 1)
+            self.emb_size = F2 * ((Samples - K1 + 1 - Pool) // max(1, int(Pool // 5)) + 1)
 
         _reset_seed(seed)
-        if dense_hidden<=0:
+        if dense_hidden <= 0:
             self.Dense = nn.Linear(
-                self.emb_size,
-                1 if nb_classes <= 2 else nb_classes,
-                bias=bias[2]
+                self.emb_size, 1 if nb_classes <= 2 else nb_classes, bias=bias[2]
             )
         else:
             self.Dense = nn.Sequential(
                 nn.Linear(self.emb_size, dense_hidden, bias=True),
                 nn.ReLU(),
-                nn.Linear(
-                    dense_hidden,
-                    1 if nb_classes <= 2 else nb_classes,
-                    bias=bias[2]
-                )
+                nn.Linear(dense_hidden, 1 if nb_classes <= 2 else nb_classes, bias=bias[2]),
             )
-        
+
     def forward(self, x):
+        """
+        :meta private:
+        """
         x = self.encoder(x)
         x = self.Dense(x)
         if not (self.return_logits):
@@ -2207,4 +2230,4 @@ def forward(self, x):
             x = torch.sigmoid(x)
         else:
             x = F.softmax(x, dim=1)
-        return x
\ No newline at end of file
+        return x
diff --git a/selfeeg/utils/utils.py b/selfeeg/utils/utils.py
index 8a286fd..bd6fa4e 100644
--- a/selfeeg/utils/utils.py
+++ b/selfeeg/utils/utils.py
@@ -383,11 +383,7 @@ class adaptation of the ``scale_range_with_soft_clip`` function.
     """

     def __init__(
-        self,
-        Range: float=200,
-        asintote: float=1.2,
-        scale: str="mV",
-        exact: bool=True
+        self, Range: float = 200, asintote: float = 1.2, scale: str = "mV", exact: bool = True
     ):
         if Range < 0:
             raise ValueError("Range cannot be lower than 0")
@@ -919,16 +915,17 @@ def count_parameters(
 
 
 def _reset_seed(
-    seed: int=None,
-    reset_random: bool=True,
-    reset_numpy: bool=True,
-    reset_torch: bool=True,
+    seed: int = None,
+    reset_random: bool = True,
+    reset_numpy: bool = True,
+    reset_torch: bool = True,
 ) -> None:
     """
     :meta private:
     """
     if seed is not None:
-        assert seed>=0, "seed must be a nonnegative number"
+        if seed < 0:
+            raise ValueError("seed must be a nonnegative number")
         if reset_numpy:
             np.random.seed(seed)
         if reset_random:
diff --git a/test/EEGself/augmentation/functional_test.py b/test/EEGself/augmentation/functional_test.py
index a22b259..2df53bb 100644
--- a/test/EEGself/augmentation/functional_test.py
+++ b/test/EEGself/augmentation/functional_test.py
@@ -212,9 +212,7 @@ def test_shift_frequency(self):
 
     def test_phase_swap(self):
         print("Testing phase swap...", end="", flush=True)
-        aug_args = {
-            "x": [self.x3, self.x3np]
-        }
+        aug_args = {"x": [self.x3, self.x3np]}
         aug_args = self.makeGrid(aug_args)
         for i in aug_args:
             xaug = aug.phase_swap(**i)
@@ -232,7 +230,7 @@ def test_phase_swap(self):
                 xaug = aug.phase_swap(**i)
 
         print(" phase_swap OK: tested", N + len(aug_args), "combinations of input arguments")
-        
+
     def test_flip_vertical(self):
         print("Testing flip vertical...", end="", flush=True)
         aug_args = {
diff --git a/test/EEGself/models/zoo_test.py b/test/EEGself/models/zoo_test.py
index a92728b..6ba7f76 100644
--- a/test/EEGself/models/zoo_test.py
+++ b/test/EEGself/models/zoo_test.py
@@ -69,7 +69,7 @@ def test_ATCNet(self):
             "F1": [12, 8],
             "D": [2, 3],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         DCN_grid = self.makeGrid(DCN_args)
         for i in DCN_grid:
@@ -107,7 +107,7 @@ def test_DeepConvNet(self):
             "dropRate": [0.5],
             "max_dense_norm": [1.0],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         DCN_grid = self.makeGrid(DCN_args)
         for i in DCN_grid:
@@ -143,9 +143,9 @@ def test_EEGConformer(self):
             "nheads": [8, 10],
             "dim_feedforward": [80],
             "activation_transformer": ["gelu"],
-            "mlp_dim": [[128,32]],
+            "mlp_dim": [[128, 32]],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         EEGcon_args = self.makeGrid(EEGcon_args)
         for i in EEGcon_args:
@@ -183,7 +183,7 @@ def test_EEGInception(self):
             "max_depth_norm": [1.0],
             "return_logits": [False],
             "bias": [True, False],
-            "seed": [42]
+            "seed": [42],
         }
         EEGin_args = self.makeGrid(EEGin_args)
         for i in EEGin_args:
@@ -220,7 +220,7 @@ def test_EEGNet(self):
             "pool2": [8, 16],
             "separable_kernel": [16, 32],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         EEGnet_args = self.makeGrid(EEGnet_args)
         for i in EEGnet_args:
@@ -257,7 +257,7 @@ def test_EEGSym(self):
             "pool": [2, 3],
             "bias": [True, False],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         EEGsym_args = self.makeGrid(EEGsym_args)
         for i in EEGsym_args:
@@ -293,7 +293,7 @@ def test_FBCNet(self):
             "FilterType": ["Cheby2", "ellip"],
             "TemporalType": ["var", "max", "mean", "std", "logvar"],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         DCN_grid = self.makeGrid(DCN_args)
         for i in DCN_grid:
@@ -327,7 +327,7 @@ def test_ResNet(self):
             "kernLength": [7, 13],
             "addConnection": [True, False],
             "return_logits": [False],
-            "seed": [42]
+            "seed": [42],
         }
         EEGres_args = self.makeGrid(EEGres_args)
         for i in EEGres_args:
@@ -360,7 +360,7 @@ 
def test_ShallowNet(self): "K1": [25, 12], "Pool": [75, 50], "return_logits": [False], - "seed": [42] + "seed": [42], } EEGsha_args = self.makeGrid(EEGsha_args) for i in EEGsha_args: @@ -393,7 +393,7 @@ def test_StagerNet(self): "kernLength": [64, 120], "Pool": [16, 8], "return_logits": [False], - "seed": [42] + "seed": [42], } EEGsta_args = self.makeGrid(EEGsta_args) for i in EEGsta_args: @@ -426,7 +426,7 @@ def test_STNet(self): "kernlength": [5, 7], "dense_size": [1024, 512], "return_logits": [False], - "seed": [42] + "seed": [42], } EEGstn_args = self.makeGrid(EEGstn_args) for i in EEGstn_args: @@ -464,7 +464,7 @@ def test_TinySleepNet(self): "pool": [16, 5], "hidden_lstm": [128, 50], "return_logits": [False], - "seed": [42] + "seed": [42], } EEGsleep_args = self.makeGrid(EEGsleep_args) for i in EEGsleep_args: @@ -504,15 +504,15 @@ def test_xEEGNet(self): "log_activation_base": ["dB"], "norm_type": ["batchnorm"], "global_pooling": [True, False], - "bias": [[False]*3], + "bias": [[False] * 3], "dense_hidden": [-1, 32], "return_logits": [False], - "seed": [42] + "seed": [42], } - + EEGxeg_args = self.makeGrid(EEGxeg_args) for i in EEGxeg_args: - if i["F1"]>i["F2"] and i["spatial_depthwise"]: + if i["F1"] > i["F2"] and i["spatial_depthwise"]: continue model = models.xEEGNet(**i) out = model(self.x) @@ -524,7 +524,7 @@ def test_xEEGNet(self): if self.device.type != "cpu": for i in EEGxeg_args: - if i["F1"]>i["F2"] and i["spatial_depthwise"]: + if i["F1"] > i["F2"] and i["spatial_depthwise"]: continue model = models.xEEGNet(**i).to(device=self.device) out = model(self.x2)