3-D FFT Example — Volumetric Frequency Analysis and 2-D Inverse FFT

This example demonstrates 3-D transform functions and the 2-D inverse FFT:

An 8×8×8 array built from three sinusoidal components (one per axis) and stored as complex values is transformed forward with fft_3d_forward and recovered with fft_3d_backward, and the round-trip error is reported. The forward transform is also cross-checked against fast_3dft.

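A separate 32×32 image is transformed with fast_2dft. Its high-frequency content is then zeroed out (a low-pass filter), and the filtered spectrum is reconstructed with fft_2d_backward. The unfiltered spectrum is also inverted to report the 2-D round-trip error. Both of the image's sinusoids (2 and 3 cycles across 32 samples) lie well below the filter cutoff, so the filtered reconstruction matches the original.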

Example Code

"""IMSL 3-D FFT example: volumetric frequency analysis and 2-D inverse FFT.

Demonstrates:
  - fft_3d_forward / fft_3d_backward  (IMSL FFT3F / FFT3B)
  - fast_3dft                          (IMSL FAST_3DFT)
  - fft_2d_backward                    (IMSL FFT2B, 2-D inverse FFT)
  - fast_2dft                          (IMSL FAST_2DFT, direct 2-D DFT)

Outputs:
  - Round-trip errors printed to stdout
  - SVG plot saved to test_output/example_imsl_3d_fft.svg
"""

from __future__ import annotations

from pathlib import Path
from typing import Dict

import matplotlib.pyplot as plt
import numpy as np

from transforms import (
    fast_2dft,
    fast_3dft,
    fft_2d_backward,
    fft_3d_backward,
    fft_3d_forward,
)


def run_demo_imsl_3d_fft() -> Dict[str, object]:
    """Run IMSL 3-D FFT round-trip and 2-D inverse FFT example.

    Creates an 8×8×8 array (real-valued data stored as complex) as a sum of
    sinusoids along each axis, computes its 3-D DFT via fft_3d_forward,
    reconstructs with fft_3d_backward, and reports the round-trip error.
    The forward transform is cross-checked against fast_3dft.

    Separately builds a 32×32 image from two 2-D sinusoids and computes its
    spectrum with fast_2dft. It then keeps only the bins with
    ``|kx| < P/4`` and ``|ky| < Q/4`` (a low-pass filter that is symmetric
    in signed frequency) and reconstructs via fft_2d_backward. The
    unfiltered spectrum is also inverted to measure the 2-D round-trip
    error.

    Side effects: prints a summary table to stdout and writes a three-panel
    SVG plot to ``test_output/example_imsl_3d_fft.svg``, creating the
    directory if needed.

    Args:
        None

    Returns:
        Dict[str, object]: Result dict with keys ``rt3d_error`` (float),
            ``rt2d_error`` (float), ``max_diff_3d`` (float), and
            ``plot_path`` (str).
    """
    # --- 3-D array: sum of sinusoids along x, y, z ---
    L, M, Nz = 8, 8, 8
    X3, Y3, Z3 = np.meshgrid(
        np.arange(L), np.arange(M), np.arange(Nz), indexing="ij"
    )
    vol = (
        np.cos(2 * np.pi * X3 / L)
        + 0.5 * np.sin(2 * np.pi * 2 * Y3 / M)
        + 0.25 * np.cos(2 * np.pi * 3 * Z3 / Nz)
    ).astype(complex)

    # Forward 3-D FFT followed by its inverse; the largest absolute deviation
    # from the input is the round-trip error.
    V_fwd = fft_3d_forward(vol)
    vol_rt = fft_3d_backward(V_fwd)
    rt3d_error = float(np.max(np.abs(vol_rt - vol)))

    # Cross-check the forward transform against fast_3dft.
    V_direct = fast_3dft(vol)
    max_diff_3d = float(np.max(np.abs(V_fwd - V_direct)))

    # --- 2-D: low-pass filter and reconstruct via fft_2d_backward ---
    P, Q = 32, 32
    PX, PY = np.meshgrid(np.arange(P), np.arange(Q), indexing="ij")
    img = (
        np.cos(2 * np.pi * 2 * PX / P) + np.sin(2 * np.pi * 3 * PY / Q)
    ).astype(complex)

    # Forward 2-D DFT via fast_2dft
    IMG_fwd = fast_2dft(img)

    # Low-pass: keep only bins whose signed frequency satisfies |k| < N/4 on
    # both axes. With d = 1/N, fftfreq yields the integer bin frequencies
    # 0..N/2-1, -N/2..-1, laid out in DFT storage order. Testing |k| keeps or
    # drops each +k/-k pair together, unlike an index slice such as
    # [N//4 : 3*N//4], which keeps -N/4 but drops +N/4.
    # NOTE: the image's components (2 and 3 cycles) lie inside this passband,
    # so the filtered reconstruction matches the original up to rounding.
    kx = np.fft.fftfreq(P, d=1.0 / P)
    ky = np.fft.fftfreq(Q, d=1.0 / Q)
    passband = (np.abs(kx)[:, None] < P / 4) & (np.abs(ky)[None, :] < Q / 4)
    IMG_lpf = np.where(passband, IMG_fwd, 0)

    # Inverse 2-D FFT via fft_2d_backward: the unfiltered spectrum measures
    # the round-trip error; the filtered one is the plotted reconstruction.
    img_full_rt = fft_2d_backward(IMG_fwd)
    img_rec = fft_2d_backward(IMG_lpf)
    rt2d_error = float(np.max(np.abs(img_full_rt - img)))

    print("\nIMSL 3-D FFT and 2-D Inverse FFT Example")
    print("-" * 55)
    print(f"{'3-D array shape':<42} {str(vol.shape):>8}")
    print(f"{'3-D round-trip max error':<42} {rt3d_error:>8.2e}")
    print(f"{'Max |fft_3d_forward - fast_3dft|':<42} {max_diff_3d:>8.2e}")
    print(f"{'2-D image shape':<42} {str(img.shape):>8}")
    print(f"{'2-D round-trip max error':<42} {rt2d_error:>8.2e}")
    print("-" * 55)

    output_dir = Path("test_output")
    output_dir.mkdir(parents=True, exist_ok=True)
    plot_path = output_dir / "example_imsl_3d_fft.svg"

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    try:
        # kz = 0 plane of the 3-D spectrum; fftshift moves the zero-frequency
        # bin to the centre of the image.
        spec_slice = np.abs(np.fft.fftshift(V_fwd[:, :, 0]))
        im0 = axes[0].imshow(spec_slice, cmap="plasma", aspect="auto")
        axes[0].set_title("3-D Spectrum |V[kx,ky,0]| (z-slice, shifted)")
        axes[0].set_xlabel("ky bin")
        axes[0].set_ylabel("kx bin")
        fig.colorbar(im0, ax=axes[0])

        # Original 2-D image (real part)
        im1 = axes[1].imshow(img.real, cmap="viridis", aspect="auto")
        axes[1].set_title("Original 2-D Signal (real part)")
        axes[1].set_xlabel("y")
        axes[1].set_ylabel("x")
        fig.colorbar(im1, ax=axes[1])

        # Reconstructed after low-pass filter
        im2 = axes[2].imshow(img_rec.real, cmap="viridis", aspect="auto")
        axes[2].set_title("After Low-Pass + fft_2d_backward (real part)")
        axes[2].set_xlabel("y")
        axes[2].set_ylabel("x")
        fig.colorbar(im2, ax=axes[2])

        fig.tight_layout()
        fig.savefig(plot_path, format="svg")
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # calls do not accumulate open figures.
        plt.close(fig)

    return {
        "rt3d_error": rt3d_error,
        "rt2d_error": rt2d_error,
        "max_diff_3d": max_diff_3d,
        "plot_path": str(plot_path),
    }


if __name__ == "__main__":
    # Script entry point only (skipped on import). The demo prints its own
    # report and saves the plot; its returned metrics dict is not needed here.
    run_demo_imsl_3d_fft()

Plot Output

Three-panel plot: 3-D spectrum z-slice, original 2-D image, filtered reconstruction

Left: magnitude of the kz = 0 slice through the 3-D spectrum (centred via fftshift). Centre: real part of the original 32×32 2-D signal. Right: real part after low-pass filtering and reconstruction via fft_2d_backward.

Console Output

IMSL 3-D FFT and 2-D Inverse FFT Example
-------------------------------------------------------
3-D array shape                            (8, 8, 8)
3-D round-trip max error                   4.44e-16
Max |fft_3d_forward - fast_3dft|           0.00e+00
2-D image shape                            (32, 32)
2-D round-trip max error                   6.73e-16
-------------------------------------------------------