# FM_PhysMamba_UNET: Physics-Informed Mamba for Image Dehazing
This repository contains the official implementation of FM_PhysMamba_UNET, a State Space Model (Mamba) architecture integrated with Flow Matching and physical priors for high-performance image dehazing.
## Image Illustrations

*Figure 1: Comparison between the hazy input (left) and the FM-PhysMamba restored output (right).*
## Repository Structure
To ensure all internal imports (e.g., from model import ...) function correctly, please maintain this layout:
- `model/`: Core UNet and PhysMamba blocks.
- `configs/`: Model configuration files.
- `utils.py`: Utility logic for model inference.
- `data/`: Data processing and re-standardization utilities.
- `checkpoints/`: Directory for model weights (`pytorch_model.bin`); place your weights here.
- `test_imgs/`: Sample images for testing.
## Quick Start: Inference
This script performs dehazing on large images by breaking them into overlapping tiles to manage memory efficiently.
```python
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image

from model import FM_PhysMamba_UNET, ODESolver
from utils import predict_large_image_vectorized, preprocess_single_image
from data.utils import restandardize_tensor

# ==========================================
# CONFIGURATION
# ==========================================
# REPLACE THIS with your actual weight file path
WEIGHT_PATH = "checkpoints/final_weights/DENSE_HAZE/pytorch_model.bin"
INPUT_IMG_PATH = "test_imgs/test_images.jpeg"
OUTPUT_DIR = "test_imgs/results"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def load_dehazing_model(weight_path, device):
    """Initializes the model and loads pre-trained weights."""
    print(f"Loading model to {device}...")
    model = FM_PhysMamba_UNET("small").to(device)
    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weights not found at: {weight_path}")
    checkpoint = torch.load(weight_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model


def run_inference(model, img_path, device, output_dir):
    """Performs tiled inference and saves a side-by-side comparison."""
    solver = ODESolver(model)
    raw_img = Image.open(img_path).convert("RGB")
    input_tensor = preprocess_single_image(raw_img, device=device)

    print(f"Starting tiled inference for {img_path}...")
    with torch.no_grad():
        restored_tensor = predict_large_image_vectorized(
            solver=solver,
            full_img_tensor=input_tensor,
            device=device,
            tile_size=256,
            overlap_ratio=0.25,
        )

    # Undo normalization and convert to HWC numpy arrays for display
    hazy_disp = restandardize_tensor(input_tensor.detach().squeeze(0).cpu()).permute(1, 2, 0).numpy()
    restored_disp = restandardize_tensor(restored_tensor.detach().squeeze(0).cpu()).permute(1, 2, 0).numpy()

    # Create the comparison plot
    os.makedirs(output_dir, exist_ok=True)
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    axes[0].imshow(hazy_disp)
    axes[0].set_title("Original Hazy Input")
    axes[0].axis("off")
    axes[1].imshow(restored_disp)
    axes[1].set_title("FM-PhysMamba Restored")
    axes[1].axis("off")

    save_path = os.path.join(output_dir, "comparison_result.png")
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    print(f"Success! Result saved to: {save_path}")
    plt.show()


if __name__ == "__main__":
    # 1. Load the model once
    my_model = load_dehazing_model(WEIGHT_PATH, DEVICE)
    # 2. Run inference
    run_inference(my_model, INPUT_IMG_PATH, DEVICE, OUTPUT_DIR)
```
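The `ODESolver` above integrates the model's learned velocity field to transport the hazy image toward its restored counterpart. The repository's solver internals are not shown here, so as a rough intuition, a flow-matching sampler can be sketched as fixed-step Euler integration over `t ∈ [0, 1]` (the names `velocity_fn`, `euler_flow_sample`, and `n_steps` below are illustrative, not the repository's API):

```python
import torch

def euler_flow_sample(velocity_fn, x0: torch.Tensor, n_steps: int = 8) -> torch.Tensor:
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler.

    velocity_fn(x, t) returns the predicted velocity for state x at time t.
    """
    x = x0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0],), i * dt, device=x.device)
        x = x + dt * velocity_fn(x, t)
    return x

# Toy check: with a constant velocity field v == 1, integrating over [0, 1]
# shifts every value by exactly 1.
x0 = torch.zeros(2, 3, 4, 4)
out = euler_flow_sample(lambda x, t: torch.ones_like(x), x0)
```

Higher-order solvers (e.g. Heun or RK4) trade extra model evaluations for fewer integration steps; the sketch only illustrates the simplest case.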
## Key Features
- **PhysMamba Block:** Combines the efficiency of State Space Models (SSMs) with physics-based architectural constraints.
- **Flow Matching:** Uses continuous-time normalizing flows for high-fidelity image restoration.
- **Vectorized Tiling:** The `predict_large_image_vectorized` function in `utils.py` handles high-resolution inputs (2K/4K) without running out of VRAM.
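The exact tiling implementation lives in `utils.py`; a standalone sketch of the underlying idea (split the image into overlapping tiles, run each through the model, and average the overlapping regions back together) looks like the following, where `tiled_apply` and its parameters are illustrative names, not the repository's API:

```python
import torch

def tiled_apply(fn, img: torch.Tensor, tile: int = 256, overlap: float = 0.25) -> torch.Tensor:
    """Apply fn to overlapping square tiles of img (C, H, W), averaging overlaps."""
    _, H, W = img.shape
    stride = max(1, int(tile * (1 - overlap)))
    out = torch.zeros_like(img)
    weight = torch.zeros(1, H, W)
    ys = list(range(0, max(H - tile, 0) + 1, stride))
    xs = list(range(0, max(W - tile, 0) + 1, stride))
    # Ensure the bottom and right edges are covered by a final tile.
    if ys[-1] + tile < H:
        ys.append(H - tile)
    if xs[-1] + tile < W:
        xs.append(W - tile)
    for y in ys:
        for x in xs:
            patch = img[:, y:y + tile, x:x + tile]
            out[:, y:y + tile, x:x + tile] += fn(patch)
            weight[:, y:y + tile, x:x + tile] += 1.0
    return out / weight  # uniform-weight average over overlapping tiles

# Sanity check with an identity "model": the stitched output must equal
# the input, since overlapping contributions are identical before averaging.
img = torch.rand(3, 300, 520)
res = tiled_apply(lambda p: p, img, tile=256, overlap=0.25)
```

Because only one `tile × tile` patch is resident on the GPU at a time, peak VRAM stays roughly constant regardless of the input resolution.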
## Installation (using uv)
We recommend using [`uv`](https://docs.astral.sh/uv/) for fast dependency resolution and environment setup.
```bash
# Install dependencies into a new virtual environment
uv sync

# Activate the environment
source .venv/bin/activate
```