WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis

arXiv preprint

University of Basel, Department of Biomedical Engineering

Schematic overview of the proposed wavelet-based medical image synthesis framework.

Abstract

Due to the three-dimensional nature of CT or MR scans, generative modeling of medical images is a particularly challenging task. Existing approaches mostly apply patch-wise, slice-wise, or cascaded generation techniques to fit the high-dimensional data into limited GPU memory. However, these approaches may introduce artifacts and potentially restrict the model's applicability for certain downstream tasks. This work presents WDM, a wavelet-based medical image synthesis framework that applies a diffusion model to wavelet-decomposed images. The presented approach is a simple yet effective way of scaling diffusion models to high resolutions and can be trained on a single 40 GB GPU. Experimental results on BraTS and LIDC-IDRI unconditional image generation at a resolution of 128 × 128 × 128 show state-of-the-art image fidelity (FID) and sample diversity (MS-SSIM) scores compared to GANs, Diffusion Models, and Latent Diffusion Models. Our proposed method is the only one capable of generating high-quality images at a resolution of 256 × 256 × 256.

General Concept

Our proposed method follows a concept that is closely related to Latent Diffusion Models (LDMs). While LDMs use a pretrained autoencoder to encode input images into a learned low-dimensional representation and then apply a diffusion model to these latent representations, our method uses a dataset-agnostic, training-free approach for spatial dimensionality reduction. Instead of a pretrained autoencoder, we apply a Discrete Wavelet Transform (DWT) to decompose input images \(y \in \mathbb{R}^{D \times H \times W}\) into eight wavelet coefficient subbands \(x_{\{lll, ..., hhh\}} \in \mathbb{R}^{\frac{D}{2} \times \frac{H}{2} \times \frac{W}{2}}\), which we then concatenate along a channel dimension to form a single target tensor \(x \in \mathbb{R}^{8 \times \frac{D}{2} \times \frac{H}{2} \times \frac{W}{2}}\) to be generated by a diffusion model. When processing \(x\), a first convolution maps it onto the network's base channels \(\mathcal{C}\) (the number of channels in the input layer), leaving the network width unchanged compared to standard architectures. Since the network then operates only in the wavelet domain, most parts of it profit from an \(8 \times\) reduction in spatial size (each of the three spatial dimensions is halved), allowing for shallower network architectures, fewer computations, and a significantly reduced memory footprint. The final output images are obtained by applying the Inverse Discrete Wavelet Transform (IDWT) to the generated wavelet coefficients.
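To make the decomposition concrete, the following is a minimal NumPy sketch of a single-level 3D DWT and its inverse, not the authors' implementation. It assumes the orthonormal Haar wavelet (this section only says "Discrete Wavelet Transform"; Haar is chosen here because its length-2 filters halve each axis exactly) and even-sized volumes. The names `dwt3d`, `idwt3d`, `first_conv_1x1`, the subband order in `SUBBANDS`, and the base channel count of 64 are all hypothetical. The input layer is stood in for by a 1×1×1 convolution written with `np.einsum`; the kernel size of the actual first convolution is an assumption.

```python
import numpy as np

# Assumption: orthonormal Haar wavelet, one decomposition level.
# Each subband name gives the low (l) or high (h) pass filter applied
# along the (D, H, W) axes, in that order.
SUBBANDS = ["lll", "llh", "lhl", "lhh", "hll", "hlh", "hhl", "hhh"]


def _every_other(ndim, axis, start):
    """Index tuple selecting every second element along `axis`, from `start`."""
    idx = [slice(None)] * ndim
    idx[axis] = slice(start, None, 2)
    return tuple(idx)


def _haar_split(x, axis):
    """Split `x` along `axis` into (low, high) Haar halves."""
    even = x[_every_other(x.ndim, axis, 0)]
    odd = x[_every_other(x.ndim, axis, 1)]
    return (even + odd) / np.sqrt(2.0), (even - odd) / np.sqrt(2.0)


def _haar_merge(low, high, axis):
    """Invert _haar_split by interleaving the reconstructed even/odd samples."""
    shape = list(low.shape)
    shape[axis] *= 2
    out = np.empty(shape, dtype=np.result_type(low, high))
    out[_every_other(out.ndim, axis, 0)] = (low + high) / np.sqrt(2.0)
    out[_every_other(out.ndim, axis, 1)] = (low - high) / np.sqrt(2.0)
    return out


def dwt3d(y):
    """Map a (D, H, W) volume to an (8, D/2, H/2, W/2) tensor of Haar subbands."""
    if any(s % 2 for s in y.shape):
        raise ValueError("all spatial dimensions must be even")
    bands = {"": y}
    for axis in range(3):  # filter along D, then H, then W
        new = {}
        for key, arr in bands.items():
            low, high = _haar_split(arr, axis)
            new[key + "l"] = low
            new[key + "h"] = high
        bands = new
    # Concatenate the eight subbands along a new leading channel axis.
    return np.stack([bands[k] for k in SUBBANDS], axis=0)


def idwt3d(x):
    """Invert dwt3d: (8, D/2, H/2, W/2) coefficients back to a (D, H, W) volume."""
    bands = dict(zip(SUBBANDS, x))
    for axis in (2, 1, 0):  # undo the W, H, then D splits
        prefixes = dict.fromkeys(key[:axis] for key in bands)
        bands = {p: _haar_merge(bands[p + "l"], bands[p + "h"], axis) for p in prefixes}
    return bands[""]


def first_conv_1x1(x, weight):
    """Hypothetical input layer: 1x1x1 convolution from 8 subbands to C channels.

    `weight` has shape (C, 8). No bias; the real kernel size is an assumption.
    """
    return np.einsum("kc,cdhw->kdhw", weight, x)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    y = rng.standard_normal((16, 16, 16))  # stand-in for a D x H x W scan
    x = dwt3d(y)
    print(x.shape)  # (8, 8, 8, 8)
    print(np.allclose(idwt3d(x), y))  # True: perfect reconstruction
    base_channels = 64  # hypothetical value of C
    w = rng.standard_normal((base_channels, 8)) * 0.1
    print(first_conv_1x1(x, w).shape)  # (64, 8, 8, 8)
```

Because this Haar transform is orthonormal, the coefficient tensor holds exactly as many values as the input volume; the saving comes from the network operating on a spatial grid with 8× fewer positions, for example 128 × 128 × 128 instead of 256 × 256 × 256.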

Unconditional Image Generation Results

To assess our model's performance, we evaluate it on unconditional brain MR and lung CT generation tasks. Our proposed approach not only outperforms most competing methods in terms of FID and MS-SSIM, but also has the lowest inference GPU memory footprint at a resolution of 128 × 128 × 128, and it was the only diffusion-based method that could be trained at a resolution of 256 × 256 × 256 (on a single 40 GB GPU). In the following section, we present some synthetic images generated by our method.

Brain MRI - BraTS (128 × 128 × 128)


Brain MRI - BraTS (256 × 256 × 256)


Lung CT - LIDC-IDRI (128 × 128 × 128)


Lung CT - LIDC-IDRI (256 × 256 × 256)

BibTeX

@article{friedrich2024wdm,
         title={WDM: 3D Wavelet Diffusion Models for High-Resolution Medical Image Synthesis},
         author={Paul Friedrich and Julia Wolleb and Florentin Bieder and Alicia Durrer and Philippe C. Cattin},
         year={2024},
         journal={arXiv preprint arXiv:2402.19043}}

Acknowledgements

This work was financially supported by the Werner Siemens Foundation through the MIRACLE II project.