Shrinkage Fields (image restoration)

Shrinkage fields is a random field-based machine learning technique that aims to perform high-quality image restoration (denoising and deblurring) with low computational overhead.

Method

The restored image ${\displaystyle x}$ is predicted from a corrupted observation ${\displaystyle y}$ after training on a set of sample images ${\displaystyle S}$.

A shrinkage (mapping) function ${\displaystyle {f}_{{\pi }_{i}}\left(v\right)={\sum }_{j=1}^{M}{\pi }_{i,j}\exp \left(-{\frac {\gamma }{2}}{\left(v-{\mu }_{j}\right)}^{2}\right)}$ is modeled directly as a linear combination of radial basis function kernels, where ${\displaystyle \gamma }$ is the shared precision parameter, ${\displaystyle {\mu }_{j}}$ are the (equidistant) kernel positions, ${\displaystyle {\pi }_{i,j}}$ are the kernel weights, and ${\displaystyle M}$ is the number of Gaussian kernels.
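The shrinkage function above can be evaluated elementwise on a filter response. The following is a minimal NumPy sketch, not a reference implementation: the function name `shrinkage` and the values of `M`, `gamma`, `mu` and `pi` are illustrative assumptions, since in practice the weights are learned.

```python
import numpy as np

def shrinkage(v, pi, mu, gamma):
    """Evaluate f_pi(v) = sum_j pi[j] * exp(-gamma / 2 * (v - mu[j])**2) elementwise."""
    v = np.asarray(v, dtype=float)
    # Add a trailing axis so every element of v is compared with all M kernel positions.
    diff = v[..., None] - mu
    # Weighted sum over the kernel axis gives an array with the same shape as v.
    return np.exp(-0.5 * gamma * diff ** 2) @ pi

# Hypothetical settings, chosen only for illustration.
M = 5
mu = np.linspace(-1.0, 1.0, M)   # equidistant kernel positions
gamma = 8.0                      # shared precision parameter
pi = mu.copy()                   # placeholder weights; real weights are learned

response = np.linspace(-1.0, 1.0, 6).reshape(2, 3)
shrunk = shrinkage(response, pi, mu, gamma)
```

Because the kernels share one precision and sit on a fixed grid, only the weights `pi` change the shape of the function.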

Because the shrinkage function is modeled directly, the optimization procedure reduces to a single quadratic minimization per iteration, denoted as the prediction of a shrinkage field ${\displaystyle {g}_{\mathrm {\Theta } }\left(x\right)={\mathcal {F}}^{-1}\left\lbrack {\frac {{\mathcal {F}}\left(\lambda {K}^{T}y+{\sum }_{i=1}^{N}{F}_{i}^{T}{f}_{{\pi }_{i}}\left({F}_{i}x\right)\right)}{\lambda {\check {K}}^{\text{*}}\circ {\check {K}}+{\sum }_{i=1}^{N}{\check {F}}_{i}^{\text{*}}\circ {\check {F}}_{i}}}\right\rbrack ={\mathrm {\Omega } }^{-1}\eta }$, where ${\displaystyle {\mathcal {F}}}$ denotes the discrete Fourier transform, ${\displaystyle Fx}$ is the 2D convolution ${\displaystyle {\text{f}}\otimes {\text{x}}}$ with the point spread function filter ${\displaystyle {\text{f}}}$, ${\displaystyle {\check {F}}}$ is the optical transfer function defined as the discrete Fourier transform of ${\displaystyle {\text{f}}}$, and ${\displaystyle {\check {F}}^{\text{*}}}$ is the complex conjugate of ${\displaystyle {\check {F}}}$.
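One prediction step can be sketched with FFTs as follows. This is a hedged illustration that rests on several assumptions not stated in the text: circular (periodic) boundary conditions, so that each convolution becomes a pointwise product in the Fourier domain and each transpose becomes a complex conjugate; ${\displaystyle K}$ taken to be convolution with a blur kernel `k`; and a toy setup with two derivative filters and a linear placeholder in place of learned shrinkage functions. The names `psf2otf` and `shrinkage_field_step` are hypothetical.

```python
import numpy as np

def psf2otf(psf, shape):
    """DFT of a small filter, zero-padded to `shape` with its centre moved to index (0, 0)."""
    h, w = psf.shape
    padded = np.zeros(shape)
    padded[:h, :w] = psf
    padded = np.roll(padded, (-(h // 2), -(w // 2)), axis=(0, 1))
    return np.fft.fft2(padded)

def shrinkage_field_step(x, y, k, filters, shrink_fns, lam):
    """One prediction g(x), assuming circular boundary conditions."""
    K = psf2otf(k, x.shape)
    # Data term: lam * K^T y in the numerator, lam * conj(K) * K in the denominator.
    numerator = lam * np.conj(K) * np.fft.fft2(y)
    denominator = lam * np.abs(K) ** 2
    for f, shrink in zip(filters, shrink_fns):
        F_i = psf2otf(f, x.shape)
        # Filter response F_i x, computed as a circular convolution.
        response = np.real(np.fft.ifft2(F_i * np.fft.fft2(x)))
        # Add F_i^T f(F_i x) to the numerator and conj(F_i) * F_i to the denominator.
        numerator = numerator + np.conj(F_i) * np.fft.fft2(shrink(response))
        denominator = denominator + np.abs(F_i) ** 2
    return np.real(np.fft.ifft2(numerator / denominator))

# Toy denoising setup (all values are illustrative assumptions).
rng = np.random.default_rng(0)
y = rng.normal(size=(16, 16))            # stand-in for a corrupted observation
k = np.array([[1.0]])                    # identity blur kernel, i.e. pure denoising
filters = [np.array([[1.0, -1.0]]),      # horizontal difference filter
           np.array([[1.0], [-1.0]])]    # vertical difference filter
shrink_fns = [lambda v: 0.5 * v] * 2     # placeholder for learned shrinkage functions
x1 = shrinkage_field_step(y, y, k, filters, shrink_fns, lam=1.0)
```

The whole step consists of a fixed number of FFTs and elementwise operations, so no iterative linear solver is needed under these assumptions.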

The estimate ${\displaystyle {\hat {x}}_{t}}$ is computed as ${\displaystyle {\hat {x}}_{t}={g}_{{\mathrm {\Theta } }_{t}}\left({\hat {x}}_{t-1}\right)}$ for each iteration ${\displaystyle t}$, with the initial case ${\displaystyle {\hat {x}}_{0}=y}$. This forms a cascade of Gaussian conditional random fields, or cascade of shrinkage fields (CSF). Loss minimization is used to learn the model parameters ${\displaystyle {\mathrm {\Theta } }_{t}={\left\lbrace {\lambda }_{t},{\pi }_{\mathit {ti}},{f}_{\mathit {ti}}\right\rbrace }_{i=1}^{N}}$.
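The cascade is a simple loop in which each stage has its own parameters and receives both the previous estimate and the original observation. The sketch below shows only that control flow; `run_cascade` and `toy_stage` are hypothetical names, and `toy_stage` is a stand-in, not a shrinkage field prediction.

```python
import numpy as np

def run_cascade(y, stage_params, stage_fn):
    """Apply x_hat_t = g_{Theta_t}(x_hat_{t-1}) for each stage, starting from x_hat_0 = y."""
    x_hat = y
    for theta_t in stage_params:
        # Every stage sees the previous estimate and the original observation y.
        x_hat = stage_fn(x_hat, y, theta_t)
    return x_hat

def toy_stage(x_prev, y, theta_t):
    """Stand-in stage: pulls the estimate toward the mean of y (illustration only)."""
    return theta_t * x_prev + (1.0 - theta_t) * y.mean()

y = np.array([0.0, 2.0])
x_hat = run_cascade(y, stage_params=[0.5, 0.5], stage_fn=toy_stage)
```

Since the number of stages is fixed in advance, the total cost of prediction is known before it runs.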

The learning objective function is defined as ${\displaystyle J\left({\mathrm {\Theta } }_{t}\right)={\sum }_{s=1}^{S}l\left({\hat {x}}_{t}^{\left(s\right)};{x}_{gt}^{\left(s\right)}\right)}$, where ${\displaystyle l}$ is a differentiable loss function. The objective is greedily minimized using the training data ${\displaystyle {\left\lbrace {x}_{gt}^{\left(s\right)},{y}^{\left(s\right)},{k}^{\left(s\right)}\right\rbrace }_{s=1}^{S}}$ and ${\displaystyle {\hat {x}}_{t}^{\left(s\right)}}$.
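The objective can be evaluated by summing a per-image loss over the training set. The sketch below uses negative PSNR as the loss, which is an assumption: the text only requires that ${\displaystyle l}$ be differentiable. The names `neg_psnr`, `objective` and `predict` are hypothetical, and `predict` stands in for whatever produces ${\displaystyle {\hat {x}}_{t}^{\left(s\right)}}$ from the stage parameters.

```python
import numpy as np

def neg_psnr(x_hat, x_gt, peak=255.0):
    """Negative peak signal-to-noise ratio in dB (assumed loss; lower is better)."""
    mse = np.mean((x_hat - x_gt) ** 2)
    return -10.0 * np.log10(peak ** 2 / mse)

def objective(predict, theta, samples, loss=neg_psnr):
    """J(theta) = sum over samples of loss(x_hat, x_gt), with samples as (x_gt, y, k) triples."""
    return sum(loss(predict(theta, y, k), x_gt) for x_gt, y, k in samples)

# Toy data: two noisy copies of a flat image (illustration only).
rng = np.random.default_rng(0)
x_gt = np.full((8, 8), 128.0)
samples = [(x_gt, x_gt + rng.normal(scale=10.0, size=x_gt.shape), None) for _ in range(2)]

# Hypothetical predictor that returns the observation unchanged.
J = objective(lambda theta, y, k: y, theta=None, samples=samples)
```

A gradient-based optimizer would additionally need the derivative of this objective with respect to the stage parameters, which is not shown here.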

Performance

Preliminary tests by the author suggest that RTF5[1] obtains slightly better denoising performance than ${\displaystyle {\text{CSF}}_{7\times 7}^{\left\lbrace \mathrm {3,4,5} \right\rbrace }}$, followed by ${\displaystyle {\text{CSF}}_{5\times 5}^{5}}$, ${\displaystyle {\text{CSF}}_{7\times 7}^{2}}$, ${\displaystyle {\text{CSF}}_{5\times 5}^{\left\lbrace \mathrm {3,4} \right\rbrace }}$, and BM3D.

BM3D denoising speed falls between that of ${\displaystyle {\text{CSF}}_{5\times 5}^{4}}$ and ${\displaystyle {\text{CSF}}_{7\times 7}^{4}}$, with RTF being an order of magnitude slower.

• Results are comparable to those obtained by BM3D (a reference method in state-of-the-art denoising since its inception in 2007)
• Minimal runtime compared to other high-performance methods (potentially applicable in embedded devices)
• Parallelizable (e.g., a GPU implementation is possible)
• Predictability: ${\displaystyle O(D\log D)}$ runtime, where ${\displaystyle D}$ is the number of pixels
• Fast training, even on a CPU