Skip to content

csttool.preprocess

The preprocessing module exposes the high-level functions that the CLI preprocess command orchestrates. Use it directly when you want to script preprocessing from Python.

from csttool.preprocess import run_preprocessing

# Preprocess ./raw_dwi.nii.gz together with its raw_dwi.bval / raw_dwi.bvec
# files; results are written to ./preproc.
result = run_preprocessing(
    input_dir=".",
    output_dir="./preproc",
    filename="raw_dwi",
    denoise_method="patch2self",
    apply_motion_correction=True,
)

csttool.preprocess

Preprocessing package for csttool.

Exports the main preprocessing orchestrator and module functions.

Functions

run_preprocessing(input_dir, output_dir, filename, *, denoise_method='patch2self', coil_count=4, apply_gibbs_correction=False, apply_motion_correction=False, target_voxel_size=None, save_visualizations=False, verbose=False)

Run the complete DWI preprocessing pipeline.

Steps: 1. Load dataset (NIfTI/DICOM + gradient table) 2. Reslice to target voxel size (optional) 3. Denoise (Patch2Self or NLMeans) 4. Brain masking (median Otsu on b0 volumes) 5. Gibbs unringing (optional) 6. Motion correction (optional) 7. Save outputs

Parameters:

Name Type Description Default
input_dir str or Path

Directory containing input NIfTI/DICOM and gradient files.

required
output_dir str or Path

Output directory for preprocessed files.

required
filename str

Base filename without extension (e.g., "sub01_dwi").

required
denoise_method str

Denoising method: "patch2self" or "nlmeans".

"patch2self"
coil_count int

Number of scanner coils (for NLMeans noise estimation).

4
apply_gibbs_correction bool

Apply Gibbs ringing correction.

False
apply_motion_correction bool

Apply between-volume motion correction.

False
target_voxel_size tuple[float, float, float] or None

Target voxel size in mm (x, y, z). If provided, data will be resliced to this voxel size. If None, no reslicing is performed.

None
save_visualizations bool

Save QC visualizations.

False
verbose bool

Print detailed processing information.

False

Returns:

Type Description
dict

Dictionary containing: - 'output_paths': Paths to saved files - 'brain_mask': The computed brain mask array - 'motion_correction_applied': Whether motion correction was applied - 'gtab': The gradient table

Source code in src/csttool/preprocess/preprocess.py
def run_preprocessing(
    input_dir: str | Path,
    output_dir: str | Path,
    filename: str,
    *,
    # Denoising options
    denoise_method: str = "patch2self",
    coil_count: int = 4,
    # Optional steps
    apply_gibbs_correction: bool = False,
    apply_motion_correction: bool = False,
    target_voxel_size: tuple[float, float, float] | None = None,
    # Visualization options
    save_visualizations: bool = False,
    verbose: bool = False,
) -> dict:
    """
    Run the complete DWI preprocessing pipeline.

    Steps:
        1. Load dataset (NIfTI/DICOM + gradient table)
        2. Reslice to target voxel size (optional)
        3. Denoise (Patch2Self or NLMeans)
        4. Brain masking (median Otsu on b0 volumes)
        5. Gibbs unringing (optional)
        6. Motion correction (optional)
        7. Save outputs
        8. Save QC visualizations (optional)

    Parameters
    ----------
    input_dir : str or Path
        Directory containing input NIfTI/DICOM and gradient files.
    output_dir : str or Path
        Output directory for preprocessed files. Created if missing.
    filename : str
        Base filename without extension (e.g., "sub01_dwi").
    denoise_method : str, default="patch2self"
        Denoising method: "patch2self" or "nlmeans".
    coil_count : int, default=4
        Number of scanner coils (for NLMeans noise estimation).
    apply_gibbs_correction : bool, default=False
        Apply Gibbs ringing correction.
    apply_motion_correction : bool, default=False
        Apply between-volume motion correction. If the correction raises,
        the error is printed and the pipeline continues without it.
    target_voxel_size : tuple[float, float, float] or None, default=None
        Target voxel size in mm (x, y, z). If provided, data will be resliced
        to this voxel size. If None, no reslicing is performed.
    save_visualizations : bool, default=False
        Save QC visualizations. Failures here are printed, not raised.
    verbose : bool, default=False
        Print detailed processing information.

    Returns
    -------
    dict
        Dictionary containing:
        - 'output_paths': Paths to saved files
        - 'brain_mask': The computed brain mask array
        - 'motion_correction_applied': Whether motion correction was applied
        - 'gtab': The gradient table
    """
    input_dir = Path(input_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # -------------------------------------------------------------------------
    # Step 1: Load dataset
    # -------------------------------------------------------------------------
    if verbose:
        print(f"Loading dataset from {input_dir}")

    nii, gtab, nifti_dir, metadata = load_dataset(str(input_dir), filename)
    data = np.asarray(nii.dataobj) if hasattr(nii, 'dataobj') else nii
    affine = nii.affine if hasattr(nii, 'affine') else np.eye(4)

    # Get current voxel size from the NIfTI header
    current_voxel_size = nii.header.get_zooms()[:3]
    print(f"PREPROCESSING: Loaded data with shape {data.shape}")
    print(f"PREPROCESSING: Current voxel size: {current_voxel_size} mm")

    # -------------------------------------------------------------------------
    # Step 2: Reslice to target voxel size (optional)
    # -------------------------------------------------------------------------
    if target_voxel_size is not None:
        print(f"PREPROCESSING: Reslicing to target voxel size: {target_voxel_size} mm")
        data, affine = reslice_voxels(
            data,
            affine,
            voxel_size=current_voxel_size,
            new_voxel_size=target_voxel_size
        )
        print(f"PREPROCESSING: Reslicing complete. New shape: {data.shape}")
    elif verbose:
        print("PREPROCESSING: Reslicing skipped")

    # -------------------------------------------------------------------------
    # Step 3: Denoise
    # -------------------------------------------------------------------------
    # No brain mask exists yet at this point, so denoise() runs unmasked.
    denoised = denoise(
        data,
        bvals=gtab.bvals,
        brain_mask=None,
        denoise_method=denoise_method,
        N=coil_count
    )
    print(f"PREPROCESSING: Denoising complete ({denoise_method})")

    # -------------------------------------------------------------------------
    # Step 4: Brain masking
    # -------------------------------------------------------------------------
    masked_data, brain_mask = background_segmentation(denoised, gtab)
    print("PREPROCESSING: Brain masking complete")

    # -------------------------------------------------------------------------
    # Step 5: Gibbs unringing (optional)
    # -------------------------------------------------------------------------
    unringed = None
    if apply_gibbs_correction:
        unringed = gibbs_unringing(masked_data)
        data_for_motion = unringed
        print("PREPROCESSING: Gibbs ringing correction complete")
    else:
        data_for_motion = masked_data
        if verbose:
            print("PREPROCESSING: Gibbs ringing correction skipped")

    # -------------------------------------------------------------------------
    # Step 6: Motion correction (optional)
    # -------------------------------------------------------------------------
    motion_correction_applied = False
    reg_affines = None

    if apply_motion_correction:
        # Best effort: a failed registration must not discard the work done
        # so far, so fall back to the uncorrected data.
        try:
            preprocessed, reg_affines = perform_motion_correction(
                data_for_motion,
                gtab,
                affine,
                brain_mask=brain_mask
            )
            motion_correction_applied = True
            print("PREPROCESSING: Motion correction complete")
        except Exception as e:
            print(f"PREPROCESSING: Motion correction failed: {e}")
            print("   Continuing without motion correction")
            preprocessed = data_for_motion
    else:
        preprocessed = data_for_motion
        if verbose:
            print("PREPROCESSING: Motion correction skipped")

    # -------------------------------------------------------------------------
    # Step 7: Save outputs
    # -------------------------------------------------------------------------
    suffix = "_mc" if motion_correction_applied else "_nomc"
    output_stem = f"{filename}_dwi_preproc{suffix}"

    # Gradient files sit next to the NIfTI that was actually loaded. For DICOM
    # input that is the directory reported by load_dataset (where the converted
    # files were written), not input_dir; for NIfTI input the two coincide.
    gradient_dir = Path(nifti_dir)
    gradient_files = {}
    for grad_type in ('bval', 'bvec'):
        # Try the singular extension first, then the plural variant.
        candidate = gradient_dir / f"{filename}.{grad_type}"
        if not candidate.exists():
            candidate = gradient_dir / f"{filename}.{grad_type}s"
        if candidate.exists():
            gradient_files[grad_type] = candidate

    output_paths = save_preprocessed(
        data=preprocessed,
        affine=affine,
        output_dir=output_dir,
        filename_stem=output_stem,
        gradient_files=gradient_files if gradient_files else None,
        brain_mask=brain_mask,
        # Acquisition/derived metadata collected by load_dataset goes into the
        # processing report instead of being dropped.
        metadata=metadata if metadata else None,
        processing_params={
            'denoise_method': denoise_method,
            'gibbs_correction': apply_gibbs_correction,
            'motion_correction': motion_correction_applied,
            'resliced': target_voxel_size is not None,
            'target_voxel_size': target_voxel_size if target_voxel_size else None,
        }
    )
    print(f"PREPROCESSING: Saved outputs to {output_dir}")

    # -------------------------------------------------------------------------
    # Step 8: Visualizations (optional)
    # -------------------------------------------------------------------------
    if save_visualizations:
        # QC figures are optional extras; a plotting failure is reported but
        # does not invalidate the saved outputs.
        try:
            from .modules.visualizations import save_all_preprocessing_visualizations
            viz_dir = output_dir / "visualizations"
            viz_dir.mkdir(parents=True, exist_ok=True)
            save_all_preprocessing_visualizations(
                data_original=data,  # Input data (after optional reslicing), before denoising
                data_denoised=denoised,  # Denoised input
                data_masked=masked_data,  # After brain masking
                data_unringed=unringed,  # None when Gibbs correction was skipped
                data_preprocessed=preprocessed,
                brain_mask=brain_mask,
                gtab=gtab,
                output_dir=viz_dir,
                stem=filename,
                denoise_method=denoise_method,
                reg_affines=reg_affines,
                motion_correction_applied=motion_correction_applied,
            )
            print("PREPROCESSING: QC visualizations saved")
        except Exception as e:
            print(f"PREPROCESSING: Visualization saving failed: {e}")

    print("\nPREPROCESSING COMPLETED")

    return {
        'output_paths': output_paths,
        'brain_mask': brain_mask,
        'motion_correction_applied': motion_correction_applied,
        'gtab': gtab,
    }

load_dataset(dir_path, fname)

Load dataset from DICOM directory or NIfTI file and build gradient table.

Parameters:

Name Type Description Default
dir_path str

Path to the directory containing the dataset.

required
fname str

Name of the file to load.

required

Returns:

Name Type Description
nii Nifti1Image

NIfTI image.

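gtab GradientTable

Gradient table built from the b-values and b-vectors.

nifti_dir Path

Directory containing the NIfTI file: the input directory itself, or the sibling "nifti" directory created during DICOM conversion.

metadata dict

Selected fields from the JSON sidecar (if one is found) plus derived fields (VoxelSize, Dimensions, NumVolumes, NumDirections, MaxBValue).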

Source code in src/csttool/preprocess/modules/load_dataset.py
def load_dataset(dir_path: str, fname: str):
    """
    Load dataset from DICOM directory or NIfTI file and build gradient table.

    A directory containing at least one ``.dcm`` file (any letter case) is
    treated as a DICOM series and converted into a ``nifti`` directory next to
    ``dir_path``. Otherwise ``<fname>.nii.gz`` and its gradient files
    (``.bval``/``.bvals`` and ``.bvec``/``.bvecs``) are read from ``dir_path``.

    Parameters
    ----------
    dir_path : str
        Path to the directory containing the dataset.
    fname : str
        Base name of the file to load, without extension.

    Returns
    -------
    nii : Nifti1Image
        NIfTI image.
    gtab : GradientTable
        Gradient table built from the b-values and b-vectors.
    nifti_dir : Path
        Directory holding the NIfTI file: ``dir_path`` itself, or the
        ``nifti`` directory created during DICOM conversion.
    metadata : dict
        Selected fields from the JSON sidecar (when one is found) plus the
        derived fields 'VoxelSize', 'Dimensions', 'NumVolumes',
        'NumDirections' and 'MaxBValue'.

    Raises
    ------
    ValueError
        If ``dir_path`` is not an existing directory, or if DICOM conversion
        did not produce bval/bvec files.
    """
    # Check if the directory exists
    dir_path = Path(dir_path)
    if not dir_path.is_dir():
        raise ValueError(f"Directory {dir_path} does not exist")

    # Check if DICOM directory (extension match is case-insensitive so that
    # files named *.DCM are recognised as well)
    if any(f.suffix.lower() == ".dcm" for f in dir_path.iterdir()):
        print(f"DICOM directory detected: {dir_path}")
        # Convert DICOM to NIfTI
        # Save NIfTI, bval and bvec files to a directory called nifti one level up from dir_path
        print("Converting DICOM to NIfTI...")
        nifti_dir = dir_path.parent / "nifti"
        nifti_dir.mkdir(parents=True, exist_ok=True)  # Create directory if it doesn't exist
        result = dicom2nifti.dicom_series_to_nifti(
            str(dir_path),
            str(nifti_dir / (fname + ".nii.gz")),
            reorient_nifti=True
        )
        nii_path = result["NII_FILE"]  # Also used below for the sidecar lookup
        nii = nib.load(nii_path)
        bval_path = result.get("BVAL_FILE")
        bvec_path = result.get("BVEC_FILE")
        # Without gradient files no gradient table can be built; fail here
        # with a clear message rather than deep inside the gradient reader.
        if not bval_path or not bvec_path:
            raise ValueError(
                f"DICOM conversion of {dir_path} produced no bval/bvec files; "
                "a diffusion-weighted series is required"
            )
    else:
        print(f"NIfTI directory detected: {dir_path}")
        nifti_dir = dir_path
        nii_path = os.path.join(dir_path, fname + ".nii.gz")

        # Try .bval first, then .bvals for gradient files
        bval_path = os.path.join(dir_path, fname + ".bval")
        if not os.path.exists(bval_path):
            bval_path = os.path.join(dir_path, fname + ".bvals")

        bvec_path = os.path.join(dir_path, fname + ".bvec")
        if not os.path.exists(bvec_path):
            bvec_path = os.path.join(dir_path, fname + ".bvecs")

        nii = nib.load(nii_path)

    # Read bvalues and bvectors, build a gradient table
    bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path)
    gtab = gradient_table(bvals=bvals, bvecs=bvecs)
    num_of_gradients = len(gtab)

    # Try to load JSON sidecar
    metadata = {}
    json_path = nifti_dir / (fname + ".json")
    if not json_path.exists():
        # Try finding json with same name as nifti (strip ".gz", swap ".nii")
        json_path = Path(nii_path).with_suffix('').with_suffix('.json')

    if json_path.exists():
        # The sidecar is optional extra information: any problem reading it is
        # reported and then ignored.
        try:
            import json
            with open(json_path, 'r', encoding='utf-8') as f:
                sidecar = json.load(f)

            # Extract standard BIDS fields
            metadata = {
                'MagneticFieldStrength': sidecar.get('MagneticFieldStrength'),
                'EchoTime': sidecar.get('EchoTime'),
                'RepetitionTime': sidecar.get('RepetitionTime'),
                'Manufacturer': sidecar.get('Manufacturer'),
                'DeviceSerialNumber': sidecar.get('DeviceSerialNumber'),
                'SoftwareVersions': sidecar.get('SoftwareVersions'),
                'dcm2niix_version': sidecar.get('ConversionSoftwareVersion'),
            }
            # Clean up None values
            metadata = {k: v for k, v in metadata.items() if v is not None}
            print(f"  → Loaded metadata from sidecar: {json_path}")
        except Exception as e:
            print(f"  ⚠️ Could not read JSON sidecar: {e}")

    # Add derived fields (cast to plain Python numbers)
    metadata.update({
        'VoxelSize': [float(x) for x in nii.header.get_zooms()[:3]],
        'Dimensions': [int(x) for x in nii.shape],
        'NumVolumes': nii.shape[3] if len(nii.shape) > 3 else 1,
        'NumDirections': len(gtab.bvals) if hasattr(gtab, 'bvals') else 0,
        'MaxBValue': float(gtab.bvals.max()) if hasattr(gtab, 'bvals') and len(gtab.bvals) > 0 else 0,
    })

    print(gtab.info)
    print(f"Number of gradients: {num_of_gradients}")
    print("\n" + "=" * 60)

    return nii, gtab, nifti_dir, metadata

denoise(data, bvals=None, brain_mask=None, denoise_method='patch2self', N=4)

Denoise DWI data.

Parameters:

Name Type Description Default
data ndarray

4D DWI data array.

required
bvals ndarray

1D array of b values.

None
brain_mask ndarray

3D brain mask array.

None
denoise_method str

Denoising method to use. Can be "nlmeans" or "patch2self".

'patch2self'
N int

Number of scanner head coils used for acquisition, needed for NLM.

4

Returns:

Name Type Description
denoised_data ndarray

4D denoised DWI data array.

Source code in src/csttool/preprocess/modules/denoise.py
def denoise(
    data: np.ndarray, 
    bvals: np.ndarray = None, 
    brain_mask: np.ndarray = None, 
    denoise_method: str = "patch2self", 
    N: int = 4
):
    """
    Denoise DWI data.

    Parameters
    ----------
    data : np.ndarray
        4D DWI data array.
    bvals : np.ndarray
        1D array of b values.
    brain_mask : np.ndarray
        3D brain mask array.
    denoise_method : str
        Denoising method to use. Can be "nlmeans" or "patch2self".
    N : int
        Number of scanner head coils used for acquisition, needed for NLM.

    Returns
    -------
    denoised_data : np.ndarray
        4D denoised DWI data array.
    """

    available_methods = ["nlmeans", "patch2self"]
    if denoise_method not in available_methods:
        raise ValueError(f"Invalid denoise method: {denoise_method}. Available methods: {available_methods}")

    # Denoise with NLM
    # https://docs.dipy.org/dev/examples_built/preprocessing/denoise_nlmeans.html
    if denoise_method == "nlmeans":
        print("Denoising with NLM...")
        noise, noise_mask = piesno(data, N=N, return_mask=True)
        sigma = float(np.mean(noise))  # Calculate the noise standard deviation
        if brain_mask is None:
            print("  ⚠️ Brain mask is None, using noise mask as rudimentary brain mask")
            brain_mask = ~noise_mask  # Invert the noise mask as rudimentary brain mask
        denoised_data = nlmeans(
            data.astype(np.float32),
            sigma=sigma,
            mask=brain_mask,
            patch_radius=1,
            block_radius=2,
            rician=False,
            num_threads=-1
        )

    # Denoise with Patch2Self
    # Requires bvals
    # https://docs.dipy.org/dev/examples_built/preprocessing/denoise_patch2self.html
    elif denoise_method == "patch2self":
        print("Denoising with Patch2Self...")
        denoised_data = patch2self(
            data,
            bvals=bvals,
            model="ols",
            shift_intensity=True,
            clip_negative_vals=False,
            b0_threshold=50,
        )

    return denoised_data

gibbs_unringing(data, slice_axis=2, n_points=3)

Remove Gibbs' ringing artifacts from DWI data.

Gibbs ringing artifacts appear as spurious oscillations near sharp edges in MR images due to truncation of k-space data.

Parameters:

Name Type Description Default
data ndarray

3D or 4D DWI data array.

required
slice_axis int

Axis along which slices were acquired (0, 1, or 2). Default is 2.

2
n_points int

Number of neighbor points to access local TV. Default is 3.

3

Returns:

Name Type Description
data_corrected ndarray

3D or 4D DWI data array with Gibbs ringing artifacts removed.

Source code in src/csttool/preprocess/modules/gibbs_unringing.py
def gibbs_unringing(
    data: np.ndarray,
    slice_axis: int = 2,
    n_points: int = 3
) -> np.ndarray:
    """
    Remove Gibbs' ringing artifacts from DWI data.

    Gibbs ringing artifacts appear as spurious oscillations near sharp edges
    in MR images due to truncation of k-space data.

    The input array is left untouched; the corrected data is returned as a
    new array.

    Parameters
    ----------
    data : np.ndarray
        3D or 4D DWI data array.
    slice_axis : int, optional
        Axis along which slices were acquired (0, 1, or 2). Default is 2.
    n_points : int, optional
        Number of neighbor points to access local TV. Default is 3.

    Returns
    -------
    data_corrected : np.ndarray
        3D or 4D DWI data array with Gibbs ringing artifacts removed.

    Raises
    ------
    ValueError
        If slice_axis is not 0, 1, or 2.
    """
    if slice_axis not in [0, 1, 2]:
        raise ValueError(f"slice_axis must be 0, 1, or 2, got {slice_axis}")

    # gibbs_removal overwrites its input by default (inplace=True). Callers
    # keep using the uncorrected array afterwards (e.g. for before/after QC),
    # so explicitly request a separate output array.
    data_corrected = gibbs_removal(
        data,
        slice_axis=slice_axis,
        n_points=n_points,
        inplace=False,
        num_processes=-1
    )
    return data_corrected

background_segmentation(data, gtab=None, median_radius=2, numpass=1, autocrop=False)

Estimate brain mask with median Otsu.

Parameters:

Name Type Description Default
data ndarray

4D DWI data array.

required
gtab GradientTable

Gradient table to identify b0 volumes. If provided, only b0 volumes (b-value < 50) are used for mask computation. If None, only the first volume of a 4D array is used.

None
median_radius int

Radius of the median filter. Default is 2.

2
numpass int

Number of passes for the median filter. Default is 1.

1
autocrop bool

Whether to autocrop the data. Default is False.

False

Returns:

Name Type Description
masked_data ndarray

4D masked DWI data array.

mask ndarray

3D binary brain mask array.

Source code in src/csttool/preprocess/modules/background_segmentation.py
def background_segmentation(
    data: np.ndarray,
    gtab=None,
    median_radius: int = 2,
    numpass: int = 1,
    autocrop: bool = False
) -> tuple[np.ndarray, np.ndarray]:
    """
    Estimate brain mask with median Otsu.

    Parameters
    ----------
    data : np.ndarray
        4D DWI data array.
    gtab : GradientTable, optional
        Gradient table to identify b0 volumes. If provided, only b0 volumes
        (b-value < 50) are used for mask computation. If None, or if the
        table contains no b0 volume, the first volume of a 4D array is used.
    median_radius : int, optional
        Radius of the median filter. Default is 2.
    numpass : int, optional
        Number of passes for the median filter. Default is 1.
    autocrop : bool, optional
        Whether to autocrop the data. Default is False.

    Returns
    -------
    masked_data : np.ndarray
        4D masked DWI data array.
    mask : np.ndarray
        3D binary brain mask array.
    """
    # Determine which volumes to use for mask computation
    vol_idx = None
    if gtab is not None:
        # Use only b0 volumes (b-value < 50)
        b0_idx = np.where(gtab.bvals < 50)[0]
        if b0_idx.size > 0:
            vol_idx = b0_idx

    # 4D input needs an explicit volume selection. Without a usable b0 index
    # (no gtab, or a gtab with no b0 volume) fall back to the first volume.
    if vol_idx is None and data.ndim == 4:
        if gtab is not None:
            print("  ⚠️ No b0 volumes (b < 50) found in gradient table, using first volume for brain mask")
        vol_idx = [0]

    masked_data, mask = median_otsu(
        data,
        vol_idx=vol_idx,
        median_radius=median_radius,
        numpass=numpass,
        autocrop=autocrop
    )

    return masked_data, mask

perform_motion_correction(data, gtab, affine, brain_mask=None)

Perform between-volume motion correction.

Parameters:

Name Type Description Default
data ndarray

4D DWI data array.

required
gtab GradientTable

Gradient table containing b-values and b-vectors.

required
affine ndarray

4x4 affine transformation matrix.

required
brain_mask ndarray or None

Binary brain mask to constrain registration. If provided, it is cast to float64 and passed as static_mask.

None

Returns:

Name Type Description
data_corrected ndarray

4D DWI data array with motion correction applied.

reg_affines list

List of 4x4 registration affine matrices for each volume.

Source code in src/csttool/preprocess/modules/perform_motion_correction.py
def perform_motion_correction(
    data: np.ndarray,
    gtab,
    affine: np.ndarray,
    brain_mask: np.ndarray | None = None
) -> tuple[np.ndarray, list]:
    """
    Register the volumes of a DWI series to correct between-volume motion.

    Parameters
    ----------
    data : np.ndarray
        4D DWI data array.
    gtab : GradientTable
        Gradient table containing b-values and b-vectors.
    affine : np.ndarray
        4x4 affine transformation matrix.
    brain_mask : np.ndarray or None, optional
        Brain mask to constrain registration. When given, it is cast to
        float64 and forwarded as ``static_mask``; when None, no mask
        argument is passed at all.

    Returns
    -------
    data_corrected : np.ndarray
        4D DWI data array with motion correction applied.
    reg_affines : list
        Registration affines as returned by the underlying routine.
    """
    # Only forward static_mask when a mask was actually supplied.
    mask_kwargs = {}
    if brain_mask is not None:
        mask_kwargs['static_mask'] = brain_mask.astype(np.float64)

    corrected, reg_affines = motion_correction(
        data,
        gtab,
        affine=affine,
        **mask_kwargs
    )

    # The result may be an image object rather than an array; if it exposes
    # get_fdata, pull out the voxel data as float32.
    if hasattr(corrected, 'get_fdata'):
        corrected = corrected.get_fdata(dtype=np.float32)

    return corrected, reg_affines

save_preprocessed(data, affine, output_dir, filename_stem, *, gradient_files=None, brain_mask=None, metadata=None, processing_params=None, create_report=True)

Save preprocessed DWI data with auxiliary files and metadata.

Parameters:

Name Type Description Default
data ndarray

Preprocessed 4D DWI data array (X, Y, Z, volumes).

required
affine ndarray

4x4 affine transformation matrix from NIfTI header.

required
output_dir str or Path

Output directory for all files (flat structure).

required
filename_stem str

Base filename without extension (e.g., "sub01_dwi_preproc").

required
gradient_files dict or None

Dictionary with keys 'bval' and 'bvec' pointing to source files. If provided, these will be copied to the output directory.

None
brain_mask ndarray or None

3D binary brain mask to save alongside data.

None
metadata dict or None

Custom metadata to include in report (e.g., subject ID, session).

None
processing_params dict or None

Processing parameters used (e.g., denoising method, motion correction).

None
create_report bool

Whether to generate a JSON processing report.

True

Returns:

Name Type Description
output_paths dict[str, Path]

Dictionary mapping output types to their absolute paths: - 'data': Path to saved preprocessed data - 'bval': Path to copied bval file (if provided) - 'bvec': Path to copied bvec file (if provided) - 'mask': Path to saved brain mask (if provided) - 'report': Path to processing report (if created)

Examples:

>>> paths = save_preprocessed(
...     data=preprocessed_data,
...     affine=affine,
...     output_dir="/data/preprocessed",
...     filename_stem="sub01_dwi_preproc",
...     gradient_files={"bval": "original.bval", "bvec": "original.bvec"},
...     brain_mask=mask,
...     processing_params={"denoise_method": "patch2self", "motion_correction": True}
... )
Source code in src/csttool/preprocess/modules/save_preprocessed.py
def save_preprocessed(
    data: np.ndarray,
    affine: np.ndarray,
    output_dir: str | Path,
    filename_stem: str,
    *,
    gradient_files: dict[str, str | Path] | None = None,
    brain_mask: np.ndarray | None = None,
    metadata: dict | None = None,
    processing_params: dict | None = None,
    create_report: bool = True,
) -> dict[str, Path]:
    """
    Save preprocessed DWI data with auxiliary files and metadata.

    Parameters
    ----------
    data : np.ndarray
        Preprocessed 4D DWI data array (X, Y, Z, volumes).
    affine : np.ndarray
        4x4 affine transformation matrix from NIfTI header.
    output_dir : str or Path
        Output directory for all files (flat structure).
    filename_stem : str
        Base filename without extension (e.g., "sub01_dwi_preproc").

    gradient_files : dict or None, optional
        Dictionary with keys 'bval' and 'bvec' pointing to source files.
        If provided, these will be copied to the output directory.
    brain_mask : np.ndarray or None, optional
        3D binary brain mask to save alongside data.
    metadata : dict or None, optional
        Custom metadata to include in report (e.g., subject ID, session).
    processing_params : dict or None, optional
        Processing parameters used (e.g., denoising method, motion correction).
    create_report : bool, default=True
        Whether to generate a JSON processing report.

    Returns
    -------
    output_paths : dict[str, Path]
        Dictionary mapping output types to their absolute paths:
        - 'data': Path to saved preprocessed data
        - 'bval': Path to copied bval file (if provided)
        - 'bvec': Path to copied bvec file (if provided)
        - 'mask': Path to saved brain mask (if provided)
        - 'report': Path to processing report (if created)

    Examples
    --------
    >>> paths = save_preprocessed(
    ...     data=preprocessed_data,
    ...     affine=affine,
    ...     output_dir="/data/preprocessed",
    ...     filename_stem="sub01_dwi_preproc",
    ...     gradient_files={"bval": "original.bval", "bvec": "original.bvec"},
    ...     brain_mask=mask,
    ...     processing_params={"denoise_method": "patch2self", "motion_correction": True}
    ... )
    """
    # Validate inputs
    if not isinstance(data, np.ndarray):
        raise TypeError(f"data must be a numpy array, got {type(data)}")
    if affine.shape != (4, 4):
        raise ValueError(f"affine must be 4x4, got shape {affine.shape}")

    # Setup output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    output_paths = {}

    # Save preprocessed data
    data_path = output_dir / f"{filename_stem}.nii.gz"
    nib.save(nib.Nifti1Image(data, affine), data_path)
    output_paths['data'] = data_path
    print(f"  ✓ Saved preprocessed data: {data_path}")

    # Copy gradient files
    if gradient_files is not None:
        for grad_type in ['bval', 'bvec']:
            if grad_type in gradient_files:
                src = Path(gradient_files[grad_type])
                if src.exists():
                    dest = output_dir / f"{filename_stem}.{grad_type}"
                    shutil.copy2(src, dest)
                    output_paths[grad_type] = dest
                    print(f"  ✓ Copied {grad_type}: {dest}")
                else:
                    print(f"  ⚠️ {grad_type} file not found: {src}")

    # Save brain mask
    if brain_mask is not None:
        mask_path = output_dir / f"{filename_stem}_mask.nii.gz"
        nib.save(nib.Nifti1Image(brain_mask.astype(np.uint8), affine), mask_path)
        output_paths['mask'] = mask_path
        print(f"  ✓ Saved brain mask: {mask_path}")

    # Create processing report
    if create_report:
        report = {
            'timestamp': datetime.now().isoformat(),
            'filename_stem': filename_stem,
            'data_shape': list(data.shape),
            'data_dtype': str(data.dtype),
            'voxel_size': np.sqrt(np.sum(affine[:3, :3]**2, axis=0)).tolist(),
        }

        if processing_params is not None:
            report['processing_params'] = processing_params

        if metadata is not None:
            report['metadata'] = metadata

        report_path = output_dir / f"{filename_stem}_report.json"
        with open(report_path, 'w') as f:
            json.dump(report, f, indent=2)
        output_paths['report'] = report_path
        print(f"  ✓ Saved processing report: {report_path}")

    return output_paths

plot_denoising_comparison(data_before, data_after, brain_mask, output_dir, stem, denoise_method, vol_idx=None, verbose=True)

Create before/after denoising comparison figure.

Shows three orthogonal views comparing original and denoised data, plus RMS residuals highlighting removed noise.

Parameters:

Name Type Description Default
data_before ndarray

4D DWI data before denoising.

required
data_after ndarray

4D DWI data after denoising.

required
brain_mask ndarray

3D binary brain mask.

required
output_dir str or Path

Output directory for saving figure.

required
stem str

Subject/scan identifier for filename.

required
denoise_method str

Denoising method used.

required
vol_idx int

Volume index to visualize. Default picks a DWI volume (middle of 4th dim).

None
verbose bool

Print progress information.

True

Returns:

Name Type Description
fig_path Path

Path to saved figure.

Source code in src/csttool/preprocess/modules/visualizations.py
def plot_denoising_comparison(
    data_before,
    data_after,
    brain_mask,
    output_dir,
    stem,
    denoise_method,
    vol_idx=None,
    verbose=True
):
    """
    Create before/after denoising comparison figure.

    Draws a 3x3 grid: one row per orthogonal view (axial, coronal,
    sagittal) and one column each for the original volume, the denoised
    volume and the residual between them. The figure is written to
    ``<output_dir>/visualizations/<stem>_denoising_qc.png``.

    Parameters
    ----------
    data_before : ndarray
        4D DWI data before denoising.
    data_after : ndarray
        4D DWI data after denoising. Must have the same spatial shape as
        ``data_before``.
    brain_mask : ndarray or None
        3D brain mask. Any dtype is accepted; non-zero voxels count as
        brain. The mask is only applied when its shape matches the
        selected volume, otherwise the residual is shown unmasked.
    output_dir : str or Path
        Output directory for saving figure.
    stem : str
        Subject/scan identifier for filename.
    denoise_method : str
        Denoising method used (shown in the figure title).
    vol_idx : int, optional
        Volume index to visualize. Defaults to the middle index of the
        4th dimension, which is usually (but not guaranteed to be) a
        diffusion-weighted volume.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    fig_path : Path
        Path to saved figure.
    """
    output_dir = Path(output_dir)
    viz_dir = output_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Default to the middle of the 4th dimension
    if vol_idx is None:
        vol_idx = data_before.shape[3] // 2

    # Get middle slice indices for each orientation
    mid_ax = data_before.shape[2] // 2
    mid_cor = data_before.shape[1] // 2
    mid_sag = data_before.shape[0] // 2

    # Extract 3D volumes
    before = data_before[..., vol_idx]
    after = data_after[..., vol_idx]

    # Residual for a single volume: sqrt(diff ** 2) is the per-voxel
    # absolute difference (the "RMS" label follows the figure convention).
    rms_diff = np.sqrt((before.astype(np.float64) - after.astype(np.float64)) ** 2)
    if brain_mask is not None:
        # Coerce to bool first: on an integer mask (e.g. uint8 loaded from
        # NIfTI) `~` is a bitwise NOT and would be used as integer indices.
        mask = np.asarray(brain_mask).astype(bool)
        if mask.shape == before.shape:
            rms_diff[~mask] = 0

    # Define orthogonal views
    views = [
        ('Axial', before[:, :, mid_ax], after[:, :, mid_ax], rms_diff[:, :, mid_ax]),
        ('Coronal', before[:, mid_cor, :], after[:, mid_cor, :], rms_diff[:, mid_cor, :]),
        ('Sagittal', before[mid_sag, :, :], after[mid_sag, :, :], rms_diff[mid_sag, :, :]),
    ]

    # Create figure: 3 rows (views) × 3 columns (original, denoised, residuals)
    fig, axes = plt.subplots(3, 3, figsize=(12, 12),
                              subplot_kw={'xticks': [], 'yticks': []})
    fig.subplots_adjust(hspace=0.1, wspace=0.05)
    fig.suptitle(f"Denoising using {denoise_method} - {stem} (Volume {vol_idx})", fontsize=14, fontweight='bold')

    # Column titles
    axes[0, 0].set_title('Original', fontsize=12)
    axes[0, 1].set_title('Denoised', fontsize=12)
    axes[0, 2].set_title('Residuals (RMS)', fontsize=12)

    for row, (view_name, orig, den, res) in enumerate(views):
        # Original (transposed so the first array axis runs left-to-right)
        axes[row, 0].imshow(orig.T, cmap='gray', interpolation='none', origin='lower')
        axes[row, 0].set_ylabel(view_name, fontsize=12, fontweight='bold')

        # Denoised
        axes[row, 1].imshow(den.T, cmap='gray', interpolation='none', origin='lower')

        # Residuals
        axes[row, 2].imshow(res.T, cmap='gray', interpolation='none', origin='lower')

    fig_path = viz_dir / f"{stem}_denoising_qc.png"
    fig.savefig(fig_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    if verbose:
        print(f"  ✓ Denoising QC: {fig_path}")

    return fig_path

plot_gibbs_unringing_comparison(data_before, data_after, brain_mask, output_dir, stem, vol_idx=None, verbose=True)

Create before/after Gibbs unringing comparison figure.

Shows three orthogonal views comparing data before and after unringing, plus RMS residuals highlighting removed ringing artifacts.

Parameters:

Name Type Description Default
data_before ndarray

4D DWI data before Gibbs unringing.

required
data_after ndarray

4D DWI data after Gibbs unringing.

required
brain_mask ndarray

3D binary brain mask.

required
output_dir str or Path

Output directory for saving figure.

required
stem str

Subject/scan identifier for filename.

required
vol_idx int

Volume index to visualize. Defaults to the middle index of the 4th dimension, which is usually (but not guaranteed to be) a diffusion-weighted volume.

None
verbose bool

Print progress information.

True

Returns:

Name Type Description
fig_path Path

Path to saved figure.

Source code in src/csttool/preprocess/modules/visualizations.py
def plot_gibbs_unringing_comparison(
    data_before,
    data_after,
    brain_mask,
    output_dir,
    stem,
    vol_idx=None,
    verbose=True
):
    """
    Create before/after Gibbs unringing comparison figure.

    Draws a 3x3 grid: one row per orthogonal view (axial, coronal,
    sagittal) and one column each for the volume before unringing, the
    volume after unringing and the residual between them. The figure is
    written to ``<output_dir>/visualizations/<stem>_gibbs_unringing_qc.png``.

    Parameters
    ----------
    data_before : ndarray
        4D DWI data before Gibbs unringing.
    data_after : ndarray
        4D DWI data after Gibbs unringing. Must have the same spatial
        shape as ``data_before``.
    brain_mask : ndarray or None
        3D brain mask. Any dtype is accepted; non-zero voxels count as
        brain. The mask is only applied when its shape matches the
        selected volume, otherwise the residual is shown unmasked.
    output_dir : str or Path
        Output directory for saving figure.
    stem : str
        Subject/scan identifier for filename.
    vol_idx : int, optional
        Volume index to visualize. Defaults to the middle index of the
        4th dimension, which is usually (but not guaranteed to be) a
        diffusion-weighted volume.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    fig_path : Path
        Path to saved figure.
    """
    output_dir = Path(output_dir)
    viz_dir = output_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Default to the middle of the 4th dimension
    if vol_idx is None:
        vol_idx = data_before.shape[3] // 2

    # Get middle slice indices for each orientation
    mid_ax = data_before.shape[2] // 2
    mid_cor = data_before.shape[1] // 2
    mid_sag = data_before.shape[0] // 2

    # Extract 3D volumes
    before = data_before[..., vol_idx]
    after = data_after[..., vol_idx]

    # Residual for a single volume: sqrt(diff ** 2) is the per-voxel
    # absolute difference (the "RMS" label follows the figure convention).
    rms_diff = np.sqrt((before.astype(np.float64) - after.astype(np.float64)) ** 2)
    if brain_mask is not None:
        # Coerce to bool first: on an integer mask (e.g. uint8 loaded from
        # NIfTI) `~` is a bitwise NOT and would be used as integer indices.
        mask = np.asarray(brain_mask).astype(bool)
        if mask.shape == before.shape:
            rms_diff[~mask] = 0

    # Define orthogonal views
    views = [
        ('Axial', before[:, :, mid_ax], after[:, :, mid_ax], rms_diff[:, :, mid_ax]),
        ('Coronal', before[:, mid_cor, :], after[:, mid_cor, :], rms_diff[:, mid_cor, :]),
        ('Sagittal', before[mid_sag, :, :], after[mid_sag, :, :], rms_diff[mid_sag, :, :]),
    ]

    # Create figure: 3 rows (views) × 3 columns (before, after, residuals)
    fig, axes = plt.subplots(3, 3, figsize=(12, 12),
                              subplot_kw={'xticks': [], 'yticks': []})
    fig.subplots_adjust(hspace=0.1, wspace=0.05)
    fig.suptitle(f"Gibbs Unringing - {stem} (Volume {vol_idx})", fontsize=14, fontweight='bold')

    # Column titles
    axes[0, 0].set_title('Before', fontsize=12)
    axes[0, 1].set_title('After', fontsize=12)
    axes[0, 2].set_title('Residuals (RMS)', fontsize=12)

    for row, (view_name, bef, aft, res) in enumerate(views):
        # Before (transposed so the first array axis runs left-to-right)
        axes[row, 0].imshow(bef.T, cmap='gray', interpolation='none', origin='lower')
        axes[row, 0].set_ylabel(view_name, fontsize=12, fontweight='bold')

        # After
        axes[row, 1].imshow(aft.T, cmap='gray', interpolation='none', origin='lower')

        # Residuals
        axes[row, 2].imshow(res.T, cmap='gray', interpolation='none', origin='lower')

    fig_path = viz_dir / f"{stem}_gibbs_unringing_qc.png"
    fig.savefig(fig_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    if verbose:
        print(f"  ✓ Gibbs unringing QC: {fig_path}")

    return fig_path

plot_brain_mask_overlay(data, brain_mask, gtab, output_dir, stem, verbose=True)

Create brain mask overlay visualization.

Shows brain mask overlaid on b0 image in three orthogonal views, plus mask coverage statistics.

Parameters:

Name Type Description Default
data ndarray

4D DWI data (masked or unmasked).

required
brain_mask ndarray

3D binary brain mask.

required
gtab GradientTable

Gradient table to identify b0 volumes.

required
output_dir str or Path

Output directory for saving figure.

required
stem str

Subject/scan identifier for filename.

required
verbose bool

Print progress information.

True

Returns:

Name Type Description
fig_path Path

Path to saved figure.

Source code in src/csttool/preprocess/modules/visualizations.py
def plot_brain_mask_overlay(
    data,
    brain_mask,
    gtab,
    output_dir,
    stem,
    verbose=True
):
    """
    Create brain mask overlay visualization.

    Draws a 2x3 grid: the top row shows the first b0 volume in axial,
    coronal and sagittal mid-slices; the bottom row repeats those slices
    with the brain mask drawn as a translucent red overlay plus a red
    contour. Mask coverage (voxel count and percentage of the volume) is
    reported in the figure title. The figure is written to
    ``<output_dir>/visualizations/<stem>_brain_mask_qc.png``.

    Parameters
    ----------
    data : ndarray
        4D DWI data (masked or unmasked).
    brain_mask : ndarray
        3D brain mask with the same spatial shape as ``data``. Any dtype
        is accepted; non-zero voxels count as brain.
    gtab : GradientTable
        Gradient table to identify b0 volumes. Only ``gtab.bvals`` is
        read; volumes with b < 50 are treated as b0, and volume 0 is used
        when none qualifies.
    output_dir : str or Path
        Output directory for saving figure.
    stem : str
        Subject/scan identifier for filename.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    fig_path : Path
        Path to saved figure.
    """
    output_dir = Path(output_dir)
    viz_dir = output_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Boolean indexing below requires a bool mask; an integer mask (e.g.
    # uint8 loaded from NIfTI) would be interpreted as fancy indices.
    mask = np.asarray(brain_mask).astype(bool)

    # Get b0 volume
    b0_idx = np.where(gtab.bvals < 50)[0]
    if len(b0_idx) == 0:
        b0_idx = [0]
    b0 = data[..., b0_idx[0]]

    # Get slice indices
    mid_ax = data.shape[2] // 2
    mid_cor = data.shape[1] // 2
    mid_sag = data.shape[0] // 2

    # Compute statistics
    total_voxels = mask.size
    brain_voxels = mask.sum()
    coverage = brain_voxels / total_voxels * 100

    # Create figure
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), constrained_layout=True)
    fig.suptitle(f"Brain Mask QC - {stem}\n"
                 f"Coverage: {brain_voxels:,} voxels ({coverage:.1f}%)",
                 fontsize=14, fontweight='bold')

    # Clip the display range at the 99th percentile so a few bright voxels
    # do not wash out the image; prefer in-brain intensities when available.
    vmax = np.percentile(b0[mask], 99) if mask.any() else np.percentile(b0, 99)

    views = [
        ('Axial', b0[:, :, mid_ax], mask[:, :, mid_ax], f'z={mid_ax}'),
        ('Coronal', b0[:, mid_cor, :], mask[:, mid_cor, :], f'y={mid_cor}'),
        ('Sagittal', b0[mid_sag, :, :], mask[mid_sag, :, :], f'x={mid_sag}'),
    ]

    for col, (view_name, b0_slice, mask_slice, coord) in enumerate(views):
        # Row 0: b0 only
        axes[0, col].imshow(b0_slice.T, cmap='gray', origin='lower', vmin=0, vmax=vmax)
        axes[0, col].set_title(f'{view_name} ({coord})\nb0 image')
        axes[0, col].axis('off')

        # Row 1: b0 with mask overlay
        axes[1, col].imshow(b0_slice.T, cmap='gray', origin='lower', vmin=0, vmax=vmax)

        # Hide background voxels so only the mask is tinted
        mask_img = mask_slice.T.astype(np.float64)
        mask_overlay = np.ma.masked_where(mask_img == 0, mask_img)
        # Fix the colour range explicitly: all visible voxels equal 1, so
        # autoscaling would give vmin == vmax and map them to the pale end
        # of 'Reds' instead of red.
        axes[1, col].imshow(mask_overlay, cmap='Reds', alpha=0.4, origin='lower',
                            vmin=0, vmax=1)

        # Add contour
        axes[1, col].contour(mask_img, levels=[0.5], colors='red', linewidths=1)
        axes[1, col].set_title(f'{view_name}\nwith brain mask')
        axes[1, col].axis('off')

    fig_path = viz_dir / f"{stem}_brain_mask_qc.png"
    # Save/close this figure explicitly rather than pyplot's "current" one
    fig.savefig(fig_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    if verbose:
        print(f"  ✓ Brain mask QC: {fig_path}")

    return fig_path

plot_motion_correction_summary(reg_affines, output_dir, stem, verbose=True)

Create motion correction summary visualization.

Shows translation and rotation parameters across volumes, both as absolute values and relative to the first volume, and reports the maximum relative displacement and rotation.

Parameters:

Name Type Description Default
reg_affines list of ndarray

List of 4x4 registration affine matrices (one per volume).

required
output_dir str or Path

Output directory for saving figure.

required
stem str

Subject/scan identifier for filename.

required
verbose bool

Print progress information.

True

Returns:

Name Type Description
fig_path Path

Path to saved figure.

Source code in src/csttool/preprocess/modules/visualizations.py
def _motion_qc_parameters(reg_affines):
    """Return per-volume translations and rotation angles (degrees) as (n, 3) arrays."""
    stack = np.stack([np.asarray(affine, dtype=np.float64) for affine in reg_affines])

    # Translation is in the last column of each 4x4 matrix
    translations = stack[:, :3, 3].copy()

    # Euler-angle extraction from the upper-left 3x3 block. This is exact
    # for a pure rotation matrix; any scale/shear in the block is not
    # removed first, so angles are approximate for general affines.
    R = stack[:, :3, :3]
    rotations = np.column_stack([
        np.arctan2(R[:, 2, 1], R[:, 2, 2]),                                   # Roll (x)
        np.arctan2(-R[:, 2, 0], np.sqrt(R[:, 2, 1]**2 + R[:, 2, 2]**2)),      # Pitch (y)
        np.arctan2(R[:, 1, 0], R[:, 0, 0]),                                   # Yaw (z)
    ]) * 180 / np.pi

    return translations, rotations


def _motion_qc_panel(ax, volumes, series, labels, ylabel, title, zero_line=False):
    """Plot the three columns of ``series`` against ``volumes`` on one axes."""
    for column, (fmt, label) in enumerate(zip(('r-', 'g-', 'b-'), labels)):
        ax.plot(volumes, series[:, column], fmt, label=label, linewidth=1.5)
    if zero_line:
        ax.axhline(y=0, color='k', linestyle='--', linewidth=0.5)
    ax.set_xlabel('Volume')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)


def plot_motion_correction_summary(
    reg_affines,
    output_dir,
    stem,
    verbose=True
):
    """
    Create motion correction summary visualization.

    Draws a 2x2 grid of line plots: absolute and relative (vs. volume 0)
    translations on the top row, absolute and relative rotations on the
    bottom row. The maximum relative displacement and rotation are printed
    as a footer. The figure is written to
    ``<output_dir>/visualizations/<stem>_motion_qc.png``.

    Parameters
    ----------
    reg_affines : list of ndarray
        List of 4x4 registration affine matrices (one per volume). Must
        contain at least one matrix.
    output_dir : str or Path
        Output directory for saving figure.
    stem : str
        Subject/scan identifier for filename.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    fig_path : Path
        Path to saved figure.

    Raises
    ------
    ValueError
        If ``reg_affines`` is empty.
    """
    n_vols = len(reg_affines)
    if n_vols == 0:
        # Relative motion is defined against volume 0, so there is nothing
        # meaningful to plot without at least one affine.
        raise ValueError("reg_affines must contain at least one affine matrix")

    output_dir = Path(output_dir)
    viz_dir = output_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)

    translations, rotations = _motion_qc_parameters(reg_affines)

    # Compute relative motion (difference from first volume)
    translations_rel = translations - translations[0]
    rotations_rel = rotations - rotations[0]

    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(15, 10), constrained_layout=True)
    fig.suptitle(f"Motion Correction QC - {stem}\n{n_vols} volumes",
                 fontsize=14, fontweight='bold')

    volumes = np.arange(n_vols)
    translation_labels = ('X', 'Y', 'Z')
    rotation_labels = ('Roll (X)', 'Pitch (Y)', 'Yaw (Z)')

    _motion_qc_panel(axes[0, 0], volumes, translations, translation_labels,
                     'Translation (mm)', 'Absolute Translation')
    _motion_qc_panel(axes[0, 1], volumes, translations_rel, translation_labels,
                     'Translation (mm)', 'Relative Translation (vs. volume 0)',
                     zero_line=True)
    _motion_qc_panel(axes[1, 0], volumes, rotations, rotation_labels,
                     'Rotation (degrees)', 'Absolute Rotation')
    _motion_qc_panel(axes[1, 1], volumes, rotations_rel, rotation_labels,
                     'Rotation (degrees)', 'Relative Rotation (vs. volume 0)',
                     zero_line=True)

    # Add summary statistics
    max_trans = np.max(np.abs(translations_rel))
    max_rot = np.max(np.abs(rotations_rel))

    fig.text(0.5, 0.02,
             f"Max displacement: {max_trans:.2f} mm | Max rotation: {max_rot:.2f}°",
             ha='center', fontsize=11, style='italic')

    fig_path = viz_dir / f"{stem}_motion_qc.png"
    # Save/close this figure explicitly rather than pyplot's "current" one
    fig.savefig(fig_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    if verbose:
        print(f"  ✓ Motion correction QC: {fig_path}")

    return fig_path

create_preprocessing_summary(data_original, data_preprocessed, brain_mask, gtab, output_dir, stem, motion_correction_applied=False, verbose=True)

Create multi-panel preprocessing summary figure.

Combines key QC visualizations into a single summary figure for quick assessment of preprocessing quality.

Parameters:

Name Type Description Default
data_original ndarray

4D DWI data before preprocessing.

required
data_preprocessed ndarray

4D DWI data after preprocessing.

required
brain_mask ndarray

3D binary brain mask.

required
gtab GradientTable

Gradient table.

required
output_dir str or Path

Output directory for saving figure.

required
stem str

Subject/scan identifier for filename.

required
motion_correction_applied bool

Whether motion correction was applied.

False
verbose bool

Print progress information.

True

Returns:

Name Type Description
fig_path Path

Path to saved figure.

Source code in src/csttool/preprocess/modules/visualizations.py
def create_preprocessing_summary(
    data_original,
    data_preprocessed,
    brain_mask,
    gtab,
    output_dir,
    stem,
    motion_correction_applied=False,
    verbose=True
):
    """
    Create multi-panel preprocessing summary figure.

    Layout (3x4 grid):

    * Row 0: original b0, preprocessed b0, their absolute difference
      (masked), and the preprocessed b0 with the mask contour.
    * Row 1: sagittal/coronal/axial mid-slices of the preprocessed b0 with
      the mask contour, plus an in-brain intensity histogram comparing
      original and preprocessed data.
    * Row 2: a text panel with shape, b-values, mask coverage, in-brain
      intensity mean/std and the motion-correction status.

    The figure is written to
    ``<output_dir>/visualizations/<stem>_preprocessing_summary.png``.

    Parameters
    ----------
    data_original : ndarray
        4D DWI data before preprocessing. If its spatial shape differs
        from the preprocessed data or the mask, the preprocessed data is
        used in its place, so the "original" panels and statistics then
        show the preprocessed data and the difference panel is empty.
    data_preprocessed : ndarray
        4D DWI data after preprocessing.
    brain_mask : ndarray
        3D brain mask. Any dtype is accepted; non-zero voxels count as
        brain. Expected to match the spatial shape of
        ``data_preprocessed``.
    gtab : GradientTable
        Gradient table. Only ``gtab.bvals`` is read; the first volume
        with b < 50 is displayed (volume 0 when none qualifies).
    output_dir : str or Path
        Output directory for saving figure.
    stem : str
        Subject/scan identifier for filename.
    motion_correction_applied : bool, optional
        Whether motion correction was applied.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    fig_path : Path
        Path to saved figure.
    """
    output_dir = Path(output_dir)
    viz_dir = output_dir / "visualizations"
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Boolean indexing and `~` below require a bool mask; an integer mask
    # (e.g. uint8 loaded from NIfTI) would be treated as fancy indices.
    mask = np.asarray(brain_mask).astype(bool)

    # Check if shapes are compatible
    orig_shape = data_original.shape[:3]
    proc_shape = data_preprocessed.shape[:3]
    mask_shape = mask.shape

    # If original and preprocessed have different shapes, use preprocessed for all comparisons
    if orig_shape != proc_shape or orig_shape != mask_shape:
        # Use preprocessed data as "original" for visualization when shapes mismatch
        data_original = data_preprocessed

    # Get b0 volume index
    b0_idx = np.where(gtab.bvals < 50)[0]
    if len(b0_idx) == 0:
        b0_idx = [0]
    vol_idx = b0_idx[0]

    # Get slice indices
    mid_ax = data_original.shape[2] // 2

    # Extract data
    orig_b0 = data_original[:, :, mid_ax, vol_idx]
    proc_b0 = data_preprocessed[:, :, mid_ax, vol_idx]
    mask_slice = mask[:, :, mid_ax]

    # Compute difference
    diff = np.abs(proc_b0.astype(np.float64) - orig_b0.astype(np.float64))
    diff[~mask_slice] = 0

    # Compute statistics
    brain_voxels = mask.sum()
    coverage = brain_voxels / mask.size * 100

    # In-brain intensities across all volumes; extracted once and reused
    # for both the summary statistics and the histogram.
    orig_vals = data_original[mask].ravel()
    proc_vals = data_preprocessed[mask].ravel()

    # Intensity statistics
    orig_mean = np.mean(orig_vals)
    proc_mean = np.mean(proc_vals)
    orig_std = np.std(orig_vals)
    proc_std = np.std(proc_vals)

    # Create figure with constrained_layout
    fig = plt.figure(figsize=(18, 12), constrained_layout=True)
    fig.suptitle(f"Preprocessing Summary - {stem}", fontsize=16, fontweight='bold')

    # Create grid
    gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

    # Clip display ranges at the 99th in-brain percentile
    vmax = np.percentile(orig_b0[mask_slice], 99) if mask_slice.any() else np.percentile(orig_b0, 99)
    diff_vmax = np.percentile(diff[mask_slice], 99) if diff[mask_slice].any() else 1

    # Row 0: Original, Preprocessed, Difference, Mask
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(orig_b0.T, cmap='gray', origin='lower', vmin=0, vmax=vmax)
    ax1.set_title('Original b0')
    ax1.axis('off')

    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(proc_b0.T, cmap='gray', origin='lower', vmin=0, vmax=vmax)
    ax2.set_title('Preprocessed b0')
    ax2.axis('off')

    ax3 = fig.add_subplot(gs[0, 2])
    ax3.imshow(diff.T, cmap='hot', origin='lower', vmin=0, vmax=diff_vmax)
    ax3.set_title('Difference (denoising)')
    ax3.axis('off')

    ax4 = fig.add_subplot(gs[0, 3])
    ax4.imshow(proc_b0.T, cmap='gray', origin='lower', vmin=0, vmax=vmax)
    ax4.contour(mask_slice.T, levels=[0.5], colors='red', linewidths=1.5)
    ax4.set_title(f'Brain Mask\n({brain_voxels:,} voxels)')
    ax4.axis('off')

    # Row 1: Three orthogonal views with mask overlay
    views_data = [
        ('Sagittal', data_preprocessed[data_preprocessed.shape[0]//2, :, :, vol_idx],
         mask[mask.shape[0]//2, :, :]),
        ('Coronal', data_preprocessed[:, data_preprocessed.shape[1]//2, :, vol_idx],
         mask[:, mask.shape[1]//2, :]),
        ('Axial', data_preprocessed[:, :, data_preprocessed.shape[2]//2, vol_idx],
         mask[:, :, mask.shape[2]//2]),
    ]

    for i, (name, img, msk) in enumerate(views_data):
        ax = fig.add_subplot(gs[1, i])
        ax.imshow(img.T, cmap='gray', origin='lower')
        ax.contour(msk.T, levels=[0.5], colors='cyan', linewidths=1)
        ax.set_title(name)
        ax.axis('off')

    # Row 1, col 3: Histogram comparison
    ax_hist = fig.add_subplot(gs[1, 3])

    def _subsample(values, max_samples=100000):
        # Each array is sampled on its own so differing lengths cannot
        # cause out-of-range indices. A fixed-seed local generator keeps
        # the figure reproducible and leaves NumPy's global RNG untouched.
        if len(values) > max_samples:
            rng = np.random.default_rng(0)
            return rng.choice(values, max_samples, replace=False)
        return values

    ax_hist.hist(_subsample(orig_vals), bins=100, alpha=0.5, label='Original', density=True)
    ax_hist.hist(_subsample(proc_vals), bins=100, alpha=0.5, label='Preprocessed', density=True)
    ax_hist.set_xlabel('Intensity')
    ax_hist.set_ylabel('Density')
    ax_hist.set_title('Intensity Distribution')
    ax_hist.legend(fontsize=8)

    # Row 2: Statistics panel
    ax_stats = fig.add_subplot(gs[2, :])
    ax_stats.axis('off')

    mc_status = "Applied" if motion_correction_applied else "Not applied"

    # Convert to plain Python ints: under NumPy >= 2 a list of NumPy
    # scalars renders as "[np.int64(0), ...]" inside an f-string.
    unique_bvals = sorted({int(b) for b in gtab.bvals})

    stats_text = (
        f"{'─' * 80}\n"
        f"PREPROCESSING STATISTICS\n"
        f"{'─' * 80}\n\n"
        f"Data Shape:           {data_original.shape}\n"
        f"Voxel Dimensions:     {data_original.shape[:3]}\n"
        f"Number of Volumes:    {data_original.shape[3]}\n"
        f"B-values:             {unique_bvals}\n\n"
        f"Brain Mask Coverage:  {brain_voxels:,} voxels ({coverage:.1f}%)\n\n"
        f"Intensity (in brain):\n"
        f"  Original:           mean = {orig_mean:.1f}, std = {orig_std:.1f}\n"
        f"  Preprocessed:       mean = {proc_mean:.1f}, std = {proc_std:.1f}\n\n"
        f"Motion Correction:    {mc_status}\n"
        f"{'─' * 80}"
    )

    ax_stats.text(0.5, 0.5, stats_text, transform=ax_stats.transAxes,
                  fontsize=10, fontfamily='monospace',
                  verticalalignment='center', horizontalalignment='center',
                  bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig_path = viz_dir / f"{stem}_preprocessing_summary.png"
    # Save/close this figure explicitly rather than pyplot's "current" one
    fig.savefig(fig_path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)

    if verbose:
        print(f"  ✓ Preprocessing summary: {fig_path}")

    return fig_path

save_all_preprocessing_visualizations(data_original, data_denoised, data_masked, data_unringed, data_preprocessed, brain_mask, gtab, output_dir, stem, denoise_method, reg_affines=None, motion_correction_applied=False, verbose=True)

Generate and save all preprocessing visualizations.

Convenience function that calls all visualization functions and returns paths to all generated figures.

Parameters:

Name Type Description Default
data_original ndarray

4D DWI data before any preprocessing.

required
data_denoised ndarray

4D DWI data after denoising (before masking).

required
data_masked ndarray

4D DWI data after brain masking (cropped).

required
data_unringed ndarray

4D DWI data after Gibbs unringing (cropped).

required
data_preprocessed ndarray

4D DWI data after full preprocessing.

required
brain_mask ndarray

3D binary brain mask.

required
gtab GradientTable

Gradient table.

required
output_dir str or Path

Output directory for saving figures.

required
stem str

Subject/scan identifier for filenames.

required
denoise_method str

Denoising method used.

required
reg_affines list of ndarray

Registration affines from motion correction.

None
motion_correction_applied bool

Whether motion correction was applied.

False
verbose bool

Print progress information.

True

Returns:

Name Type Description
viz_paths dict

Dictionary mapping visualization names to file paths.

Source code in src/csttool/preprocess/modules/visualizations.py
def save_all_preprocessing_visualizations(
    data_original,
    data_denoised,
    data_masked,
    data_unringed,
    data_preprocessed,
    brain_mask,
    gtab,
    output_dir,
    stem,
    denoise_method,
    reg_affines=None,
    motion_correction_applied=False,
    verbose=True
):
    """
    Generate and save all preprocessing visualizations.

    Runs the individual QC plotting functions in a fixed order
    (denoising, Gibbs unringing, brain mask, motion, summary) and
    collects the path of every figure that was produced. Steps whose
    inputs are missing are skipped.

    Parameters
    ----------
    data_original : ndarray
        4D DWI data before any preprocessing.
    data_denoised : ndarray or None
        4D DWI data after denoising (before masking). When None, the
        denoising figure is skipped.
    data_masked : ndarray or None
        4D DWI data after brain masking (cropped). Used as the "before"
        image of the Gibbs comparison.
    data_unringed : ndarray or None
        4D DWI data after Gibbs unringing (cropped). The Gibbs figure is
        skipped unless both this and ``data_masked`` are given.
    data_preprocessed : ndarray
        4D DWI data after full preprocessing.
    brain_mask : ndarray
        3D binary brain mask.
    gtab : GradientTable
        Gradient table.
    output_dir : str or Path
        Output directory for saving figures.
    stem : str
        Subject/scan identifier for filenames.
    denoise_method : str
        Denoising method used.
    reg_affines : list of ndarray, optional
        Registration affines from motion correction.
    motion_correction_applied : bool, optional
        Whether motion correction was applied. The motion figure is only
        produced when this is true and ``reg_affines`` is not None.
    verbose : bool, optional
        Print progress information.

    Returns
    -------
    viz_paths : dict
        Dictionary mapping visualization names ('denoising_qc',
        'gibbs_unringing_qc', 'brain_mask_qc', 'motion_qc', 'summary')
        to file paths; optional entries are absent when skipped.
    """
    if verbose:
        print("\nGenerating preprocessing visualizations...")

    # Arguments every plotting function receives unchanged
    common = dict(output_dir=output_dir, stem=stem, verbose=verbose)
    figures = {}

    # Denoising: original vs. denoised
    if data_denoised is not None:
        figures['denoising_qc'] = plot_denoising_comparison(
            data_before=data_original,
            data_after=data_denoised,
            brain_mask=brain_mask,
            denoise_method=denoise_method,
            **common,
        )

    # Gibbs unringing: needs both the masked input and the unringed output
    gibbs_inputs_missing = data_unringed is None or data_masked is None
    if not gibbs_inputs_missing:
        figures['gibbs_unringing_qc'] = plot_gibbs_unringing_comparison(
            data_before=data_masked,
            data_after=data_unringed,
            brain_mask=brain_mask,
            **common,
        )

    # Brain mask overlay is always produced
    figures['brain_mask_qc'] = plot_brain_mask_overlay(
        data=data_preprocessed,
        brain_mask=brain_mask,
        gtab=gtab,
        **common,
    )

    # Motion parameters, only when correction ran and affines were kept
    if motion_correction_applied and reg_affines is not None:
        figures['motion_qc'] = plot_motion_correction_summary(
            reg_affines=reg_affines,
            **common,
        )

    # Summary figure is always produced
    figures['summary'] = create_preprocessing_summary(
        data_original=data_original,
        data_preprocessed=data_preprocessed,
        brain_mask=brain_mask,
        gtab=gtab,
        motion_correction_applied=motion_correction_applied,
        **common,
    )

    if verbose:
        print(f"  ✓ All preprocessing visualizations saved to: {Path(output_dir) / 'visualizations'}")

    return figures