Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
# Version X.X.X (only via git install)

# Version 0.2.1 (latest)

**Functionality**

- **augmentation module**:
- add Circular augmenter in compose module.
- add phase swap augmentation in functional module.
- **models module**:
- models can be initialized with a custom seed.
- add EEGConformer.
- add xEEGNet.
- **dataloading module**:
- EEGDataset now supports EEG with multiple labels (1 per window partition).
- **ssl module**:
Expand All @@ -14,7 +23,7 @@
* reduced unittest overall time


# Version 0.2.0 (latest)
# Version 0.2.0

**Functionality**

Expand Down
1 change: 1 addition & 0 deletions docs/selfeeg.augmentation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Functions
moving_avg
permutation_signal
permute_channels
phase_swap
random_FT_phase
random_slope_scale
scaling
Expand Down
4 changes: 4 additions & 0 deletions docs/selfeeg.models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ Classes
:template: classtemplate.rst

DeepConvNetEncoder
EEGConformerEncoder
EEGInceptionEncoder
EEGNetEncoder
EEGSymEncoder
Expand All @@ -49,6 +50,7 @@ Classes
StagerNetEncoder
STNetEncoder
TinySleepNetEncoder
xEEGNetEncoder


models.zoo module
Expand All @@ -65,6 +67,7 @@ Classes

ATCNet
DeepConvNet
EEGConformer
EEGInception
EEGNet
EEGSym
Expand All @@ -74,3 +77,4 @@ Classes
StagerNet
STNet
TinySleepNet
xEEGNet
1 change: 1 addition & 0 deletions selfeeg/augmentation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
moving_avg,
permutation_signal,
permute_channels,
phase_swap,
random_FT_phase,
random_slope_scale,
scaling,
Expand Down
92 changes: 88 additions & 4 deletions selfeeg/augmentation/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"moving_avg",
"permutation_signal",
"permute_channels",
"phase_swap",
"random_FT_phase",
"random_slope_scale",
"scaling",
Expand Down Expand Up @@ -114,7 +115,7 @@ def shift_horizontal(
batch_equal: bool = True,
) -> ArrayLike:
"""
shifts temporally the elements of the ArrayLike object.
Shifts temporally the elements of the ArrayLike object.

Shift is applied along the last dimension.
The empty elements at beginning or the ending part
Expand Down Expand Up @@ -438,6 +439,90 @@ def shift_frequency(
return _shift_frequency(x, shift_freq, Fs, forward, random_shift, batch_equal, t, h)


def phase_swap(x: ArrayLike) -> ArrayLike:
"""
Apply the phase swap data augmentation to the ArrayLike object.

The phase swap data augmentation consists in merging the amplitude
and phase components of biosignals from different sources to help
the model learn their coupling.
Specifically, the amplitude and phase of two randomly selected EEG samples
are extracted using the Fourier transform.
New samples are then generated by applying the inverse Fourier transform,
combining the amplitude from one sample with the phase from the other.
See the following paper for more information [phaseswap]_.

Parameters
----------
x : ArrayLike
A 3-dimensional torch tensor or numpy array.
The last two dimensions must refer to the EEG (Channels x Samples).

Returns
-------
x: ArrayLike
The augmented version of the input Tensor or Array.

Note
----
`Phase swap` ignores the class of each sample.


References
----------
.. [phaseswap] Lemkhenter, Abdelhak, and Favaro, Paolo.
"Boosting Generalization in Bio-signal Classification by
Learning the Phase-Amplitude Coupling". DAGM GCPR (2020).

"""

Ndim = len(x.shape)
if Ndim != 3:
raise ValueError("x must be a 3-dimensional array or tensor")

N = x.shape[0]

if isinstance(x, torch.Tensor):
# Compute fft, module and phase
xfft = torch.fft.fft(x)
amplitude = xfft.abs()
phase = xfft.angle()
x_aug = torch.clone(xfft)

# Random shuffle indeces
idx_shuffle = torch.randperm(N).to(device=x.device)
idx_shuffle_1 = idx_shuffle[: (N // 2)]
idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2]

# Apply phase swap
x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * torch.exp(1j * phase[idx_shuffle_2])
x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * torch.exp(1j * phase[idx_shuffle_1])

# Reconstruct the signal
x_aug = (torch.fft.ifft(x_aug)).real.to(device=x.device)

else:

xfft = np.fft.fft(x)
amplitude = np.abs(xfft)
phase = np.angle(xfft)
x_aug = np.copy(xfft)

# Random shuffle indeces
idx_shuffle = np.random.permutation(N)
idx_shuffle_1 = idx_shuffle[: (N // 2)]
idx_shuffle_2 = idx_shuffle[(N // 2) : (N // 2) * 2]

# Apply phase swap
x_aug[idx_shuffle_1] = amplitude[idx_shuffle_1] * np.exp(1j * phase[idx_shuffle_2])
x_aug[idx_shuffle_2] = amplitude[idx_shuffle_2] * np.exp(1j * phase[idx_shuffle_1])

# Reconstruct the signal
x_aug = (np.fft.ifft(x_aug)).real

return x_aug


def flip_vertical(x: ArrayLike) -> ArrayLike:
"""
changes the sign of all the elements of the input.
Expand All @@ -456,9 +541,8 @@ def flip_vertical(x: ArrayLike) -> ArrayLike:
-------
>>> import torch
>>> import selfeeg.augmentation as aug
>>> x = torch.zeros(16,32,1024) + torch.sin(torch.linspace(0, 8*np.pi,1024))
>>> xaug = aug.flip_vertical(x)
>>> print(torch.equal(xaug, x*(-1))) # should return True
>>> x = torch.randn(64,32,512)
>>> xaug = aug.phase_swap(x)

"""
x_flip = x * (-1)
Expand Down
4 changes: 4 additions & 0 deletions selfeeg/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .encoders import (
BasicBlock1,
DeepConvNetEncoder,
EEGConformerEncoder,
EEGInceptionEncoder,
EEGNetEncoder,
EEGSymEncoder,
Expand All @@ -17,11 +18,13 @@
StagerNetEncoder,
STNetEncoder,
TinySleepNetEncoder,
xEEGNetEncoder,
)

from .zoo import (
ATCNet,
DeepConvNet,
EEGConformer,
EEGInception,
EEGNet,
EEGSym,
Expand All @@ -31,4 +34,5 @@
StagerNet,
STNet,
TinySleepNet,
xEEGNet,
)
Loading
Loading