// Kernels for the Fourier-Wavelet stripe removal (remove_stripe_fw).
//
// Every kernel is templated on WSize, the number of filter taps, so the tap
// loop has a compile-time bound. The host side instantiates them by name,
// e.g. "grouped_convolution_x<10>".
//
// Each thread produces the output element(s) at (g_thd_x, g_thd_y, g_thd_z);
// threads outside [0, dim_x) x [0, dim_y) x [0, dim_z) exit immediately.

// 64-bit type for flattened element offsets. An offset is a sum of products of
// int extents/strides; evaluated in 32-bit int it overflows as soon as a
// buffer holds more than 2^31 - 1 elements.
typedef long long index_t;

// Filters along x and writes two output groups per input element position.
//   in  : read at g_thd_x * in_stride_x + j (+ y/z strides), j = 0..WSize-1
//   out : rows of length dim_x; slices are out_stride_z apart, groups are
//         out_stride_group apart
//   w   : 2 * WSize taps, group i uses w[i * WSize .. i * WSize + WSize - 1]
template <int WSize>
__global__ void grouped_convolution_x(
    int dim_x,
    int dim_y,
    int dim_z,
    const float* in,
    int in_stride_x,
    int in_stride_y,
    int in_stride_z,
    float* out,
    int out_stride_z,
    int out_stride_group,
    const float* w
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    // Offsets that do not depend on the group or the tap are computed once.
    const index_t in_base = static_cast<index_t>(g_thd_x) * in_stride_x
                          + static_cast<index_t>(g_thd_y) * in_stride_y
                          + static_cast<index_t>(g_thd_z) * in_stride_z;
    const index_t out_base = static_cast<index_t>(g_thd_x)
                           + static_cast<index_t>(g_thd_y) * dim_x
                           + static_cast<index_t>(g_thd_z) * out_stride_z;

    constexpr int out_groups = 2;
    for (int i = 0; i < out_groups; ++i)
    {
        float acc = 0.F;
        for (int j = 0; j < WSize; ++j)
        {
            acc += w[i * WSize + j] * in[in_base + j];
        }
        out[out_base + static_cast<index_t>(i) * out_stride_group] = acc;
    }
}

// Filters along y with an input step of two rows per output row. Each of the
// two input groups yields two output groups (four output groups in total).
//   in  : row (2 * g_thd_y + j) of input group `group`
//   out : rows of length dim_x; output group index is 2 * group + i
//   w   : 4 * WSize taps, indexed by (2 * group + i) * WSize + j
template <int WSize>
__global__ void grouped_convolution_y(
    int dim_x,
    int dim_y,
    int dim_z,
    const float* in,
    int in_stride_x,
    int in_stride_y,
    int in_stride_z,
    int in_stride_group,
    float* out,
    int out_stride_z,
    int out_stride_group,
    const float* w
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    constexpr int in_groups = 2;
    constexpr int out_groups = 2;
    constexpr int item_stride_y = 2;

    // Offset of the first tap for input group 0.
    const index_t in_base = static_cast<index_t>(g_thd_x) * in_stride_x
                          + static_cast<index_t>(item_stride_y) * g_thd_y * in_stride_y
                          + static_cast<index_t>(g_thd_z) * in_stride_z;
    const index_t out_base = static_cast<index_t>(g_thd_x)
                           + static_cast<index_t>(g_thd_y) * dim_x
                           + static_cast<index_t>(g_thd_z) * out_stride_z;

    for (int group = 0; group < in_groups; ++group)
    {
        const index_t in_group_base = in_base + static_cast<index_t>(group) * in_stride_group;
        for (int i = 0; i < out_groups; ++i)
        {
            const int out_group = out_groups * group + i;
            float acc = 0.F;
            for (int j = 0; j < WSize; ++j)
            {
                const index_t in_idx = in_group_base + static_cast<index_t>(j) * in_stride_y;
                acc += w[out_group * WSize + j] * in[in_idx];
            }
            out[out_base + static_cast<index_t>(out_group) * out_stride_group] = acc;
        }
    }
}

// Transposed convolution along x with an output step of two: output column
// g_thd_x accumulates in[(g_thd_x - i) / 2] * w[i] for every tap i where
// (g_thd_x - i) is even and the input column lies inside [0, in_dim_x).
// `out` is written densely as dim_z x dim_y x dim_x.
template <int WSize>
__global__ void transposed_convolution_x(
    int dim_x,
    int dim_y,
    int dim_z,
    const float* in,
    int in_dim_x,
    int in_stride_y,
    int in_stride_z,
    const float* w,
    float* out
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    const index_t in_row = static_cast<index_t>(g_thd_y) * in_stride_y
                         + static_cast<index_t>(g_thd_z) * in_stride_z;

    constexpr int item_out_stride = 2;
    float acc = 0.F;
    for (int i = 0; i < WSize; ++i)
    {
        const int in_x = (g_thd_x - i) / item_out_stride;
        const int in_x_mod = (g_thd_x - i) % item_out_stride;
        // For negative (g_thd_x - i) the remainder is 0 or -1 and the quotient
        // truncates toward zero; the range test below rejects those taps.
        if (in_x_mod == 0 && in_x >= 0 && in_x < in_dim_x)
        {
            acc += in[in_row + in_x] * w[i];
        }
    }
    const index_t out_idx = static_cast<index_t>(g_thd_x)
                          + static_cast<index_t>(dim_x) * g_thd_y
                          + static_cast<index_t>(dim_x) * dim_y * g_thd_z;
    out[out_idx] = acc;
}

// Transposed convolution along y with an output step of two: output row
// g_thd_y accumulates in[(g_thd_y - i) / 2] * w[i] for every tap i where
// (g_thd_y - i) is even and the input row lies inside [0, in_dim_y).
// `out` is written densely as dim_z x dim_y x dim_x.
template <int WSize>
__global__ void transposed_convolution_y(
    int dim_x,
    int dim_y,
    int dim_z,
    const float* in,
    int in_dim_y,
    int in_stride_y,
    int in_stride_z,
    const float* w,
    float* out
)
{
    const int g_thd_x = blockDim.x * blockIdx.x + threadIdx.x;
    const int g_thd_y = blockDim.y * blockIdx.y + threadIdx.y;
    const int g_thd_z = blockDim.z * blockIdx.z + threadIdx.z;
    if (g_thd_x >= dim_x || g_thd_y >= dim_y || g_thd_z >= dim_z)
    {
        return;
    }

    const index_t in_col = static_cast<index_t>(g_thd_x)
                         + static_cast<index_t>(g_thd_z) * in_stride_z;

    constexpr int item_out_stride = 2;
    float acc = 0.F;
    for (int i = 0; i < WSize; ++i)
    {
        const int in_y = (g_thd_y - i) / item_out_stride;
        const int in_y_mod = (g_thd_y - i) % item_out_stride;
        // Same sign handling as in transposed_convolution_x.
        if (in_y_mod == 0 && in_y >= 0 && in_y < in_dim_y)
        {
            acc += in[in_col + static_cast<index_t>(in_y) * in_stride_y] * w[i];
        }
    }
    const index_t out_idx = static_cast<index_t>(g_thd_x)
                          + static_cast<index_t>(dim_x) * g_thd_y
                          + static_cast<index_t>(dim_x) * dim_y * g_thd_z;
    out[out_idx] = acc;
}
__all__ = [ "remove_stripe_based_sorting", + "remove_stripe_fw", "remove_stripe_ti", "remove_all_stripe", "raven_filter", @@ -156,6 +159,604 @@ def remove_stripe_ti( return data +###### Ring removal with wavelet filtering (adapted for cupy from pytroch_wavelet package https://pytorch-wavelets.readthedocs.io/)########## +# These functions are taken from TomoCuPy package +# *************************************************************************** # +# Copyright © 2022, UChicago Argonne, LLC # +# All Rights Reserved # +# Software Name: Tomocupy # +# By: Argonne National Laboratory # +# # +# OPEN SOURCE LICENSE # +# # +# Redistribution and use in source and binary forms, with or without # +# modification, are permitted provided that the following conditions are met: # +# # +# 1. Redistributions of source code must retain the above copyright notice, # +# this list of conditions and the following disclaimer. # +# 2. Redistributions in binary form must reproduce the above copyright # +# notice, this list of conditions and the following disclaimer in the # +# documentation and/or other materials provided with the distribution. # +# 3. Neither the name of the copyright holder nor the names of its # +# contributors may be used to endorse or promote products derived # +# from this software without specific prior written permission. # +# # +# # +# *************************************************************************** # + + +def _reflect(x: np.ndarray, minx: float, maxx: float) -> np.ndarray: + """Reflect the values in matrix *x* about the scalar values *minx* and + *maxx*. Hence a vector *x* containing a long linearly increasing series is + converted into a waveform which ramps linearly up and down between *minx* + and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers + + 0.5), the ramps will have repeated max and min samples. + + .. codeauthor:: Rich Wareham , Aug 2013 + .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999. 
class _DeviceMemStack:
    """Bookkeeping model of device allocations, used to estimate peak memory.

    Nothing is allocated: callers report the byte size of every buffer they
    would create (``malloc``) and release (``free``). The class tracks the
    rounded running total and the largest total seen so far.

    Attributes
    ----------
    allocations : list
        Raw (un-rounded) sizes of the allocations that are currently live.
    current : int
        Sum of the rounded sizes of the live allocations, in bytes.
    highwater : int
        Largest value ``current`` has reached.
    """

    # Every request is rounded up to a multiple of this many bytes.
    # NOTE(review): presumably mirrors the granularity of the CuPy memory
    # pool - confirm against the allocator actually in use.
    ALLOCATION_UNIT_SIZE = 512

    def __init__(self) -> None:
        self.allocations = []
        self.current = 0
        self.highwater = 0

    def malloc(self, bytes):
        """Record an allocation of ``bytes`` bytes and update the peak."""
        # The parameter name shadows the builtin; it is kept because it is
        # part of the existing call interface.
        self.allocations.append(bytes)
        self.current += self._round_up(bytes)
        self.highwater = max(self.current, self.highwater)

    def free(self, bytes):
        """Record the release of a live allocation of exactly ``bytes`` bytes.

        Raises
        ------
        ValueError
            If no live allocation of that size exists. The counters are left
            untouched in that case.
        """
        # Explicit check instead of ``assert``: assertions are stripped under
        # ``python -O``, which would leave only list.remove's generic error.
        if bytes not in self.allocations:
            raise ValueError(
                f"cannot free {bytes} bytes: no live allocation of that size"
            )
        self.allocations.remove(bytes)
        self.current -= self._round_up(bytes)
        # Internal invariant: every free matches an earlier malloc of the same
        # size, so the running total can never drop below zero.
        assert self.current >= 0

    def _round_up(self, size):
        """Return ``size`` rounded up to a multiple of ALLOCATION_UNIT_SIZE."""
        unit = self.ALLOCATION_UNIT_SIZE
        return (size + unit - 1) // unit * unit
def _next_power_of_two(x: int, max_val: int = 128) -> int:
    """Return the smallest power of two that is >= ``x``, capped by ``max_val``.

    Parameters
    ----------
    x : int
        Value to round up. Values below 1 yield 1.
    max_val : int
        Upper bound for the result. The result is the largest power of two
        that does not exceed ``max_val`` whenever ``x`` is larger than that.

    Returns
    -------
    int
        A power of two in ``[1, max(1, max_val)]``.
    """
    n = 1
    # Stop before doubling would cross the cap. Testing ``n < max_val`` instead
    # would overshoot a cap that is not itself a power of two (x=1000,
    # max_val=100 would give 128). For power-of-two caps both forms agree.
    while n < x and n * 2 <= max_val:
        n *= 2
    return n
def _conv_transpose2d(
    x: cp.ndarray,
    w: np.ndarray,
    stride: Tuple[int, int],
    pad: Tuple[int, int],
    groups: int,
    mem_stack: Optional[_DeviceMemStack],
) -> cp.ndarray:
    """Transposed convolution (equivalent pytorch.conv_transpose2d).

    The function has two modes, selected by ``mem_stack``:

    * ``mem_stack`` falsy: ``x`` is a 4D device array. The transposed
      convolution runs on the GPU and the cropped, contiguous result is
      returned.
    * ``mem_stack`` truthy: ``x`` is the *shape* ``[b, c, h, w]`` of that
      array. Nothing is computed; the buffers the real run would create are
      reported to ``mem_stack`` and the output shape (a list) is returned.

    Parameters
    ----------
    x : cp.ndarray or sequence of int
        Input array, or its shape in estimation mode.
    w : np.ndarray
        4D filter of shape ``(co, ci, hk, wk)``. Exactly one of ``hk``/``wk``
        is expected to be larger than 1; a filter with ``wk > 1`` is applied
        along the last axis, otherwise one with ``hk > 1`` along the
        second-to-last axis.
    stride : tuple of int
        ``(stride_h, stride_w)``, used to derive the output size.
        NOTE(review): the kernels hard-code an output step of 2 along the
        filtered axis - confirm callers only pass 2 there.
    pad : tuple of int
        ``(pad_h, pad_w)`` rows/columns cropped from *each* side of the result.
    groups : int
        Unused; kept for signature compatibility.
    mem_stack : _DeviceMemStack, optional
        Memory model that switches the function into estimation mode.

    Returns
    -------
    cp.ndarray or list of int
        The result array, or its shape in estimation mode.

    Raises
    ------
    ValueError
        In compute mode, if the filter is 1x1 (no axis to filter along).
    """
    estimating = bool(mem_stack)
    b, _, ho, wo = x if estimating else x.shape
    _, ci, hk, wk = w.shape

    hi = (ho - 1) * stride[0] + hk
    wi = (wo - 1) * stride[1] + wk
    out_shape = [b, ci, hi, wi]

    if estimating:
        float_size = np.float32().itemsize
        mem_stack.malloc(np.prod(out_shape) * float_size)
        # Device copy of the filter.
        mem_stack.malloc(w.size * float_size)
        # ``pad`` is a tuple, and a tuple never compares equal to 0, so only a
        # scalar 0 skips the crop accounting.
        if pad != 0:
            new_out_shape = [
                out_shape[0],
                out_shape[1],
                out_shape[2] - 2 * pad[0],
                out_shape[3] - 2 * pad[1],
            ]
            # The cropped copy coexists with the uncropped buffer for a moment.
            mem_stack.malloc(np.prod(new_out_shape) * float_size)
            mem_stack.free(np.prod(out_shape) * float_size)
            out_shape = new_out_shape
        mem_stack.free(w.size * float_size)
        return out_shape

    # Reject an unsupported filter before allocating or compiling anything.
    # An ``assert`` placed after the launch code would vanish under
    # ``python -O`` and let an all-zero result through.
    if wk <= 1 and hk <= 1:
        raise ValueError(
            f"filter of shape {w.shape} has no axis longer than 1 to convolve along"
        )

    out = cp.zeros(out_shape, dtype="float32")
    w = cp.asarray(w)

    symbol_names = [
        f"transposed_convolution_x<{wk}>",
        f"transposed_convolution_y<{hk}>",
    ]
    module = load_cuda_module("remove_stripe_fw", name_expressions=symbol_names)
    dim_x = out.shape[-1]
    dim_y = out.shape[-2]
    # NOTE(review): the launch covers batch x H x W only, so the channel axis
    # is presumably always 1 here - confirm against the callers.
    dim_z = out.shape[0]
    in_stride_y = x.strides[-2] // x.dtype.itemsize
    in_stride_z = x.strides[0] // x.dtype.itemsize

    block_x = _next_power_of_two(dim_x)
    block_dim = (block_x, 1, 1)
    grid_x = (dim_x + block_x - 1) // block_x
    grid_dim = (grid_x, dim_y, dim_z)

    if wk > 1:
        kernel = module.get_function(symbol_names[0])
        in_dim = x.shape[-1]
    else:
        kernel = module.get_function(symbol_names[1])
        in_dim = x.shape[-2]
    kernel(
        grid_dim,
        block_dim,
        (dim_x, dim_y, dim_z, x, in_dim, in_stride_y, in_stride_z, w, out),
    )

    # Always true for tuple ``pad`` (see above); a (0, 0) pad is a no-op slice.
    if pad != 0:
        out = out[:, :, pad[0] : out.shape[2] - pad[0], pad[1] : out.shape[3] - pad[1]]
    return cp.ascontiguousarray(out)
def _sfb1d(
    lo: cp.ndarray,
    hi: cp.ndarray,
    g0: np.ndarray,
    g1: np.ndarray,
    dim: int,
    mem_stack: Optional[_DeviceMemStack],
) -> cp.ndarray:
    """1D synthesis filter bank: recombine a lowpass/highpass pair along one axis.

    Both sub-bands go through ``_conv_transpose2d`` with their own filter and
    the two results are added. When ``mem_stack`` is truthy, ``lo`` and ``hi``
    are shapes rather than arrays: nothing is computed, the temporary buffers
    are reported to ``mem_stack`` and the output shape is returned.

    Parameters
    ----------
    lo, hi : lowpass and highpass sub-bands (or their shapes).
    g0, g1 : lowpass and highpass synthesis filters.
    dim : axis to filter along; taken modulo 4, so negative axes work.
    mem_stack : optional memory model enabling estimation mode.
    """
    estimating = bool(mem_stack)
    channels = lo[1] if estimating else lo.shape[1]
    axis = dim % 4
    taps = g0.size

    # Filters are reshaped so that only the filtered axis has extent > 1.
    kernel_shape = [1, 1, 1, 1]
    kernel_shape[axis] = taps
    if axis == 2:
        step = (2, 1)
        crop = (taps - 2, 0)
    else:
        step = (1, 2)
        crop = (0, taps - 2)

    # One copy of the filter per channel, stacked along the first axis.
    g0 = np.concatenate([g0.reshape(*kernel_shape)] * channels, axis=0)
    g1 = np.concatenate([g1.reshape(*kernel_shape)] * channels, axis=0)

    y_lo = _conv_transpose2d(
        lo, g0, stride=step, pad=crop, groups=channels, mem_stack=mem_stack
    )
    y_hi = _conv_transpose2d(
        hi, g1, stride=step, pad=crop, groups=channels, mem_stack=mem_stack
    )

    if not estimating:
        return y_lo + y_hi

    float_size = np.float32().itemsize
    # The sum is a new buffer; afterwards both operands are released.
    mem_stack.malloc(np.prod(y_hi) * float_size)
    mem_stack.free(np.prod(y_lo) * float_size)
    mem_stack.free(np.prod(y_hi) * float_size)
    return y_lo
+ """ + + def __init__(self, wave: str): + super().__init__() + + wave = pywt.Wavelet(wave) + h0_col, h1_col = wave.dec_lo, wave.dec_hi + h0_row, h1_row = h0_col, h1_col + + self.h0_col = np.array(h0_col).astype("float32")[::-1].reshape((1, 1, -1, 1)) + self.h1_col = np.array(h1_col).astype("float32")[::-1].reshape((1, 1, -1, 1)) + self.h0_row = np.array(h0_row).astype("float32")[::-1].reshape((1, 1, 1, -1)) + self.h1_row = np.array(h1_row).astype("float32")[::-1].reshape((1, 1, 1, -1)) + + def apply( + self, x: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None + ) -> Tuple[cp.ndarray, cp.ndarray]: + """Forward pass of the DWT. + + Args: + x (array): Input of shape :math:`(N, C_{in}, H_{in}, W_{in})` + + Returns: + (yl, yh) + tuple of lowpass (yl) and bandpass (yh) coefficients. + yh is a list of scale coefficients. yl has shape + :math:`(N, C_{in}, H_{in}', W_{in}')` and yh has shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. The new + dimension in yh iterates over the LH, HL and HH coefficients. + + Note: + :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly + downsampled shapes of the DWT pyramid. 
+ """ + # Do a multilevel transform + # Do 1 level of the transform + lohi = _afb1d(x, self.h0_row, self.h1_row, dim=3, mem_stack=mem_stack) + y = _afb1d(lohi, self.h0_col, self.h1_col, dim=2, mem_stack=mem_stack) + if mem_stack: + y_shape = [y[0], np.prod(y) // y[0] // 4 // y[-2] // y[-1], 4, y[-2], y[-1]] + x_shape = [y_shape[0], y_shape[1], y_shape[3], y_shape[4]] + yh_shape = [y_shape[0], y_shape[1], y_shape[2] - 1, y_shape[3], y_shape[4]] + + mem_stack.free(np.prod(lohi) * np.float32().itemsize) + mem_stack.malloc(np.prod(x_shape) * np.float32().itemsize) + mem_stack.malloc(np.prod(yh_shape) * np.float32().itemsize) + mem_stack.free(np.prod(y) * np.float32().itemsize) + return x_shape, yh_shape + del lohi + s = y.shape + y = y.reshape(s[0], -1, 4, s[-2], s[-1]) + x = cp.ascontiguousarray(y[:, :, 0]) + yh = cp.ascontiguousarray(y[:, :, 1:]) + return (x, yh) + + +class _DWTInverse: + """Performs a 2d DWT Inverse reconstruction of an image + + Args: + wave (str): Which wavelet to use. + """ + + def __init__(self, wave: str): + super().__init__() + wave = pywt.Wavelet(wave) + g0_col, g1_col = wave.rec_lo, wave.rec_hi + g0_row, g1_row = g0_col, g1_col + # Prepare the filters + self.g0_col = np.array(g0_col).astype("float32").reshape((1, 1, -1, 1)) + self.g1_col = np.array(g1_col).astype("float32").reshape((1, 1, -1, 1)) + self.g0_row = np.array(g0_row).astype("float32").reshape((1, 1, 1, -1)) + self.g1_row = np.array(g1_row).astype("float32").reshape((1, 1, 1, -1)) + + def apply( + self, + coeffs: Tuple[cp.ndarray, cp.ndarray], + mem_stack: Optional[_DeviceMemStack] = None, + ) -> cp.ndarray: + """ + Args: + coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where: + yl is a lowpass array of shape :math:`(N, C_{in}, H_{in}', + W_{in}')` and yh is a list of bandpass arrays of shape + :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. 
def _repair_memory_fragmentation_if_needed(fragmentation_threshold: float = 0.2):
    """Release cached blocks of the default CuPy pool when too many sit unused.

    The pool is flushed with ``free_all_blocks()`` when the bytes it holds but
    does not use exceed ``fragmentation_threshold`` times the bytes in use,
    i.e. when ``total / used - 1 > fragmentation_threshold``.

    Parameters
    ----------
    fragmentation_threshold : float
        Tolerated ratio of unused to used pool bytes.
    """
    pool = cp.get_default_memory_pool()
    total = pool.total_bytes()
    used = pool.used_bytes()
    # Division-free form of ``total / used - 1 > threshold``. It is
    # algebraically the same for used > 0 and stays well defined when nothing
    # is in use (e.g. a different allocator is active), where the quotient
    # would raise ZeroDivisionError. With used == 0, any cached bytes are
    # released.
    if total - used > fragmentation_threshold * used:
        pool.free_all_blocks()
+ + Returns + ------- + ndarray + Stripe-corrected 3D tomographic data as a CuPy array. + """ + + if level is None: + if calc_peak_gpu_mem: + size = np.max(data) # data is a tuple in this case + else: + size = np.max(data.shape) + level = int(np.ceil(np.log2(size))) + + [nproj, nz, ni] = data.shape if not calc_peak_gpu_mem else data + + nproj_pad = nproj + nproj // 8 + + # Accepts all wave types available to PyWavelets + xfm = _DWTForward(wave=wname) + ifm = _DWTInverse(wave=wname) + + # Wavelet decomposition. + cc = [] + sli_shape = [nz, 1, nproj_pad, ni] + + if calc_peak_gpu_mem: + mem_stack = _DeviceMemStack() + # A data copy is assumed when invoking the function + mem_stack.malloc(np.prod(data) * np.float32().itemsize) + mem_stack.malloc(np.prod(sli_shape) * np.float32().itemsize) + cc = [] + fcV_bytes = None + for k in range(level): + new_sli_shape, c = xfm.apply(sli_shape, mem_stack) + mem_stack.free(np.prod(sli_shape) * np.float32().itemsize) + sli_shape = new_sli_shape + cc.append(c) + + if fcV_bytes: + mem_stack.free(fcV_bytes) + fcV_shape = [c[0], c[3], c[4]] + fcV_bytes = np.prod(fcV_shape) * np.complex64().itemsize + mem_stack.malloc(fcV_bytes) + + # For the FFT + mem_stack.malloc(2 * np.prod(fcV_shape) * np.float32().itemsize) + mem_stack.malloc(2 * fcV_bytes) + + fft_dummy = cp.empty(fcV_shape, dtype="float32") + fft_plan = get_fft_plan(fft_dummy) + fft_plan_size = fft_plan.work_area.mem.size + del fft_dummy + del fft_plan + mem_stack.malloc(fft_plan_size) + mem_stack.free(2 * np.prod(fcV_shape) * np.float32().itemsize) + mem_stack.free(fft_plan_size) + mem_stack.free(2 * fcV_bytes) + + # The rest of the iteration doesn't contribute to the peak + # NOTE: The last iteration of fcV is "leaked" + + for k in range(level)[::-1]: + new_sli_shape = [sli_shape[0], sli_shape[1], cc[k][-2], cc[k][-1]] + new_sli_shape = ifm.apply((new_sli_shape, cc[k]), mem_stack) + mem_stack.free(np.prod(sli_shape) * np.float32().itemsize) + sli_shape = new_sli_shape + + 
mem_stack.malloc(np.prod(data) * np.float32().itemsize) + for c in cc: + mem_stack.free(np.prod(c) * np.float32().itemsize) + mem_stack.free(np.prod(sli_shape) * np.float32().itemsize) + return int(mem_stack.highwater * 1.1) + + sli = cp.zeros(sli_shape, dtype="float32") + sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2] = data.swapaxes(0, 1) + for k in range(level): + sli, c = xfm.apply(sli) + cc.append(c) + # FFT + fft_in = cp.ascontiguousarray(cc[k][:, 0, 1]) + fft_plan = get_fft_plan(fft_in, axes=1) + with fft_plan: + fcV = cp.fft.fft(fft_in, axis=1) + del fft_plan + del fft_in + _, my, mx = fcV.shape + # Damping of ring artifact information. + y_hat = np.fft.ifftshift((np.arange(-my, my, 2) + 1) / 2) + damp = -np.expm1(-(y_hat**2) / (2 * sigma**2)) + fcV *= cp.tile(damp, (mx, 1)).swapaxes(0, 1) + # Inverse FFT. + ifft_in = cp.ascontiguousarray(fcV) + ifft_plan = get_fft_plan(ifft_in, axes=1) + with ifft_plan: + cc[k][:, 0, 1] = cp.fft.ifft(ifft_in, my, axis=1).real + del ifft_plan + del ifft_in + _repair_memory_fragmentation_if_needed() + + # Wavelet reconstruction. 
+ for k in range(level)[::-1]: + shape0 = cc[k][0, 0, 1].shape + sli = sli[:, :, : shape0[0], : shape0[1]] + sli = ifm.apply((sli, cc[k])) + _repair_memory_fragmentation_if_needed() + + data = sli[:, 0, (nproj_pad - nproj) // 2 : (nproj_pad + nproj) // 2, :ni] + data = data.swapaxes(0, 1) + return cp.ascontiguousarray(data) + + ######## Optimized version for Vo-all ring removal in tomopy######## # This function is taken from TomoCuPy package # *************************************************************************** # diff --git a/pyproject.toml b/pyproject.toml index 27e53c35..9e448e2a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "pillow", "scikit-image", "tomobar", + "PyWavelets", ] [project.optional-dependencies] diff --git a/tests/test_prep/test_stripe.py b/tests/test_prep/test_stripe.py index 230aa46a..23713961 100644 --- a/tests/test_prep/test_stripe.py +++ b/tests/test_prep/test_stripe.py @@ -8,6 +8,7 @@ from httomolibgpu.prep.stripe import ( remove_stripe_based_sorting, remove_stripe_ti, + remove_stripe_fw, remove_all_stripe, raven_filter, ) @@ -34,6 +35,135 @@ def test_remove_stripe_ti_on_data(data, flats, darks): assert data_after_stripe_removal.dtype == np.float32 +class MaxMemoryHook(cp.cuda.MemoryHook): + def __init__(self, initial=0): + self.max_mem = initial + self.current = initial + + def malloc_postprocess( + self, device_id: int, size: int, mem_size: int, mem_ptr: int, pmem_id: int + ): + self.current += mem_size + self.max_mem = max(self.max_mem, self.current) + + def free_postprocess( + self, device_id: int, mem_size: int, mem_ptr: int, pmem_id: int + ): + self.current -= mem_size + + +def test_remove_stripe_fw_on_data(data, flats, darks): + # --- testing the CuPy implementation from TomoCupy ---# + data_norm = dark_flat_field_correction(data, flats, darks, cutoff=10) + data_norm = minus_log(data_norm) + + data_after_stripe_removal = remove_stripe_fw( + cp.copy(data_norm), wname="sym16", sigma=1, level=7 
+ ).get() + + assert_allclose(np.mean(data_after_stripe_removal), 0.279236, rtol=1e-05) + assert_allclose( + np.mean(data_after_stripe_removal, axis=(1, 2)).sum(), 50.2624, rtol=1e-06 + ) + assert_allclose(np.median(data_after_stripe_removal), 0.079203, rtol=1e-05) + assert_allclose(np.max(data_after_stripe_removal), 2.442347, rtol=1e-05) + assert data_after_stripe_removal.flags.c_contiguous + + data = None #: free up GPU memory + # make sure the output is float32 + assert data_after_stripe_removal.dtype == np.float32 + + +@pytest.fixture +def ensure_clean_memory(): + cp.get_default_memory_pool().free_all_blocks() + cp.get_default_pinned_memory_pool().free_all_blocks() + cache = cp.fft.config.get_plan_cache() + cache.clear() + yield None + cp.get_default_memory_pool().free_all_blocks() + cp.get_default_pinned_memory_pool().free_all_blocks() + cache = cp.fft.config.get_plan_cache() + cache.clear() + + +@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"]) +@pytest.mark.parametrize("slices", [3, 7, 32, 61, 109, 120, 150]) +@pytest.mark.parametrize("level", [None, 1, 3, 11]) +@pytest.mark.parametrize("dim_x", [128, 140]) +def test_remove_stripe_fw_calc_mem(slices, level, dim_x, wname, ensure_clean_memory): + dim_y = 159 + data = cp.random.random_sample((slices, dim_x, dim_y), dtype=np.float32) + hook = MaxMemoryHook() + with hook: + remove_stripe_fw(cp.copy(data), wname=wname, level=level) + actual_mem_peak = hook.max_mem + + try: + estimated_mem_peak = remove_stripe_fw( + data.shape, level=level, wname=wname, calc_peak_gpu_mem=True + ) + except cp.cuda.memory.OutOfMemoryError: + pytest.skip("Not enough GPU memory to estimate memory peak") + + assert actual_mem_peak * 0.99 <= estimated_mem_peak + assert estimated_mem_peak <= actual_mem_peak * 1.3 + + +@pytest.mark.parametrize("wname", ["haar", "db4", "sym5", "sym16", "bior4.4"]) +@pytest.mark.parametrize( + "slices", [38, 177, 239, 320, 490, 607, 803, 859, 902, 951, 1019, 1074, 1105] +) 
+@pytest.mark.parametrize("level", [None, 7, 11]) +@pytest.mark.parametrize("dims", [(901, 1200), (1801, 2560)]) +def test_remove_stripe_fw_calc_mem_big(wname, slices, level, dims, ensure_clean_memory): + dim_y, dim_x = dims + data_shape = (slices, dim_x, dim_y) + try: + estimated_mem_peak = remove_stripe_fw( + data_shape, wname=wname, level=level, calc_peak_gpu_mem=True + ) + except cp.cuda.memory.OutOfMemoryError: + pytest.skip("Not enough GPU memory to estimate memory peak") + av_mem = cp.cuda.Device().mem_info[0] + if av_mem < estimated_mem_peak: + pytest.skip("Not enough GPU memory to run this test") + + hook = MaxMemoryHook() + with hook: + data = cp.random.random_sample(data_shape, dtype=np.float32) + remove_stripe_fw(data, wname=wname, level=level) + actual_mem_peak = hook.max_mem + + assert actual_mem_peak * 0.99 <= estimated_mem_peak + assert estimated_mem_peak <= actual_mem_peak * 1.3 + + +@pytest.mark.perf +def test_remove_stripe_fw_performance(ensure_clean_memory): + data_host = ( + np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001 + ) + data = cp.asarray(data_host, dtype=np.float32) + + # do a cold run first + remove_stripe_fw(cp.copy(data)) + + dev = cp.cuda.Device() + dev.synchronize() + + start = time.perf_counter_ns() + nvtx.RangePush("Core") + for _ in range(10): + # have to take copy, as data is modified in-place + remove_stripe_fw(cp.copy(data)) + nvtx.RangePop() + dev.synchronize() + duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10 + + assert "performance in ms" == duration_ms + + @pytest.mark.parametrize("angles", [180, 181]) @pytest.mark.parametrize("det_x", [11, 18]) @pytest.mark.parametrize("det_y", [5, 7, 8]) diff --git a/zenodo-tests/test_prep/test_stripe.py b/zenodo-tests/test_prep/test_stripe.py index 9a92d8d5..847e67b2 100644 --- a/zenodo-tests/test_prep/test_stripe.py +++ b/zenodo-tests/test_prep/test_stripe.py @@ -5,6 +5,7 @@ from numpy.testing import assert_allclose from 
httomolibgpu.prep.stripe import ( remove_stripe_based_sorting, + remove_stripe_fw, remove_stripe_ti, remove_all_stripe, raven_filter, @@ -104,6 +105,51 @@ def test_remove_stripe_ti_i12_dataset4( assert output.flags.c_contiguous +@pytest.mark.parametrize( + "dataset_fixture, sigma_val, level, norm_res_expected", + [ + ( + "i12_dataset4", + 0.01, + 7, + 52.4856, + ), + ( + "i12_dataset4", + 0.3, + 5, + 53.7807, + ), + ( + "i12_dataset4", + 1.0, + 10, + 262.2167, + ), + ], + ids=["case_001", "case_003", "case_006"], +) +def test_remove_stripe_fw_i12_dataset4( + request, dataset_fixture, sigma_val, level, norm_res_expected +): + dataset = request.getfixturevalue(dataset_fixture) + data_normalised = dark_flat_field_correction(dataset[0], dataset[2], dataset[3]) + data_normalised = minus_log(data_normalised) + + del dataset + force_clean_gpu_memory() + + output = remove_stripe_fw(cp.copy(data_normalised), sigma=sigma_val, level=level) + + residual_calc = data_normalised - output + norm_res = cp.linalg.norm(residual_calc.flatten()) + + assert isclose(norm_res, norm_res_expected, abs_tol=10**-4) + + assert output.dtype == np.float32 + assert output.flags.c_contiguous + + @pytest.mark.parametrize( "dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected", [