Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
bd1b713
Copy remove_stripe_fw from Nabu/ToMoCuPy
mfep Nov 12, 2025
216043c
Add test_remove_stripe_fw_on_data
mfep Nov 12, 2025
a2e97e3
Cleanup: remove unused vars and arg defaults
mfep Nov 12, 2025
e7382ca
Type annotations, use np where possible
mfep Nov 13, 2025
7360a21
Estimate memory via simulated stack
mfep Nov 17, 2025
73460b9
Test memory estimation with stack
mfep Nov 17, 2025
8565a3d
Add zenodo test for remove_stripe_fw
mfep Nov 18, 2025
c855e27
More ergonomic mem peak estimation in remove_stripe_fw
mfep Nov 19, 2025
89b5ce7
adds function description and parameters, also wavelet filters variat…
dkazanc Nov 27, 2025
d68e9b0
Fix test_remove_stripe_fw_calc_mem
mfep Nov 27, 2025
57b6fd4
Fix and test remove_stripe_fw memory estimator for large sizes
mfep Nov 27, 2025
bebfbe0
Custom kernel in _conv2d
mfep Nov 24, 2025
3d9d1f1
Custom kernels in _conv_transpose2d
mfep Nov 25, 2025
33b0f02
Update mem estimator and tests for very large sizes
mfep Dec 1, 2025
c849133
minor correction to wname description
dkazanc Dec 2, 2025
b4a3542
Better memory estimation of FFT plan
mfep Dec 3, 2025
fea43a0
Prepare tests for memory allocation while estimating
mfep Dec 3, 2025
e33286e
Attempt to repair device memory fragmentation
mfep Dec 3, 2025
dd56265
update to defaults
dkazanc Dec 4, 2025
8bd38ef
Fix on_data test
mfep Dec 11, 2025
6531dab
Add test_remove_stripe_fw_performance
mfep Dec 11, 2025
0f2df10
docstring update and linting
dkazanc Dec 19, 2025
d26c216
adding pywavelets dependency
dkazanc Dec 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/source/bibliography/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,14 @@ @article{beck2009fast
year={2009},
publisher={SIAM}
}

@article{munch2009stripe,
title={Stripe and ring artifact removal with combined wavelet—Fourier filtering},
author={M{\"u}nch, Beat and Trtik, Pavel and Marone, Federica and Stampanoni, Marco},
journal={Optics express},
volume={17},
number={10},
pages={8567--8591},
year={2009},
publisher={Optical Society of America}
}
1 change: 1 addition & 0 deletions httomolibgpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from httomolibgpu.prep.phase import paganin_filter, paganin_filter_savu_legacy
from httomolibgpu.prep.stripe import (
remove_stripe_based_sorting,
remove_stripe_fw,
remove_stripe_ti,
remove_all_stripe,
raven_filter,
Expand Down
155 changes: 155 additions & 0 deletions httomolibgpu/cuda_kernels/remove_stripe_fw.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
template<int WSize>
__global__ void grouped_convolution_x(
int dim_x,
int dim_y,
int dim_z,
const float* in,
int in_stride_x,
int in_stride_y,
int in_stride_z,
float* out,
int out_stride_z,
int out_stride_group,
const float* w
)
{
const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
{
return;
}

constexpr int out_groups = 2;
for (int i = 0; i < out_groups; ++i)
{
float acc = 0.F;
for (int j = 0; j < WSize; ++j)
{
const int w_idx = i * WSize + j;
const int in_idx = (g_thd_x * in_stride_x + j) + g_thd_y * in_stride_y + g_thd_z * in_stride_z;
acc += w[w_idx] * in[in_idx];
}
const int out_idx = g_thd_x + g_thd_y * dim_x + g_thd_z * out_stride_z + i * out_stride_group;
out[out_idx] = acc;
}
}

template<int WSize>
__global__ void grouped_convolution_y(
int dim_x,
int dim_y,
int dim_z,
const float* in,
int in_stride_x,
int in_stride_y,
int in_stride_z,
int in_stride_group,
float* out,
int out_stride_z,
int out_stride_group,
const float* w
)
{
const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
{
return;
}

constexpr int in_groups = 2;
constexpr int out_groups = 2;
constexpr int item_stride_y = 2;
for (int group = 0; group < in_groups; ++group)
{
for (int i = 0; i < out_groups; ++i)
{
float acc = 0.F;
for (int j = 0; j < WSize; ++j)
{
const int w_idx = (out_groups * group + i) * WSize + j;
const int in_idx = g_thd_x * in_stride_x + (item_stride_y * g_thd_y + j) * in_stride_y + group * in_stride_group + g_thd_z * in_stride_z;
acc += w[w_idx] * in[in_idx];
}
const int out_idx = g_thd_x + g_thd_y * dim_x + g_thd_z * out_stride_z + (out_groups * group + i) * out_stride_group;
out[out_idx] = acc;
}
}
}

template<int WSize>
__global__ void transposed_convolution_x(
int dim_x,
int dim_y,
int dim_z,
const float* in,
int in_dim_x,
int in_stride_y,
int in_stride_z,
const float* w,
float* out
)
{
const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
{
return;
}

constexpr int item_out_stride = 2;
float acc = 0.F;
for (int i = 0; i < WSize; ++i)
{
const int in_x = (g_thd_x - i) / item_out_stride;
const int in_x_mod = (g_thd_x - i) % item_out_stride;
if (in_x_mod == 0 && in_x >= 0 && in_x < in_dim_x)
{
const int in_idx = in_x + g_thd_y * in_stride_y + g_thd_z * in_stride_z;
acc += in[in_idx] * w[i];
}
}
const int out_idx = g_thd_x + dim_x * g_thd_y + dim_x * dim_y * g_thd_z;
out[out_idx] = acc;
}

template<int WSize>
__global__ void transposed_convolution_y(
int dim_x,
int dim_y,
int dim_z,
const float* in,
int in_dim_y,
int in_stride_y,
int in_stride_z,
const float* w,
float* out
)
{
const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
{
return;
}

constexpr int item_out_stride = 2;
float acc = 0.F;
for (int i = 0; i < WSize; ++i)
{
const int in_y = (g_thd_y - i) / item_out_stride;
const int in_y_mod = (g_thd_y - i) % item_out_stride;
if (in_y_mod == 0 && in_y >= 0 && in_y < in_dim_y)
{
const int in_idx = g_thd_x + in_y * in_stride_y + g_thd_z * in_stride_z;
acc += in[in_idx] * w[i];
}
}
const int out_idx = g_thd_x + dim_x * g_thd_y + dim_x * dim_y * g_thd_z;
out[out_idx] = acc;
}
Loading