FM_PhysMamba_UNET: Physics-Informed Mamba for Image Dehazing

This repository contains the official implementation of FM_PhysMamba_UNET, a State Space Model (Mamba) architecture integrated with Flow Matching and physical priors for high-performance image dehazing.

Image Illustrations

[Image: FM-PhysMamba dehazing result]
Figure 1: Comparison between hazy input (left) and FM-PhysMamba restored output (right).

📂 Repository Structure

To ensure all internal imports (e.g., from model import ...) function correctly, please maintain this layout:

  • model/: Contains the core UNet and PhysMamba blocks.

  • configs/: Model configuration files.

  • utils.py: Contains the utility logic for model inference.

  • data/: Data processing and re-standardization utilities.

  • checkpoints/: Directory for model weights (pytorch_model.bin); place your downloaded weight file here.

  • test_imgs/: Sample images for testing.

🚀 Quick Start: Inference

This script performs dehazing on large images by breaking them into overlapping tiles to manage memory efficiently.

```python
import torch
import os
from PIL import Image
from model import FM_PhysMamba_UNET, ODESolver
from utils import predict_large_image_vectorized, preprocess_single_image
from data.utils import restandardize_tensor
import matplotlib.pyplot as plt

# ==========================================
# CONFIGURATION
# ==========================================
# REPLACE THIS with your actual weight file path
WEIGHT_PATH = "checkpoints/final_weights/DENSE_HAZE/pytorch_model.bin"

INPUT_IMG_PATH = "test_imgs/test_images.jpeg"
OUTPUT_DIR = "test_imgs/results"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def load_dehazing_model(weight_path, device):
    """Initializes the model and loads pre-trained weights."""
    print(f"Loading model to {device}...")
    model = FM_PhysMamba_UNET("small").to(device)

    if not os.path.exists(weight_path):
        raise FileNotFoundError(f"Weights not found at: {weight_path}")

    checkpoint = torch.load(weight_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model

def run_inference(model, img_path, device, output_dir):
    """Performs tiled inference and saves a comparison result."""
    solver = ODESolver(model)
    raw_img = Image.open(img_path).convert("RGB")
    input_tensor = preprocess_single_image(raw_img, device=device)

    print(f"Starting tiled inference for {img_path}...")
    with torch.no_grad():
        restored_tensor = predict_large_image_vectorized(
            solver=solver,
            full_img_tensor=input_tensor,
            device=device,
            tile_size=256,
            overlap_ratio=0.25
        )

    # Process for visualization
    hazy_disp = restandardize_tensor(input_tensor.detach().squeeze(0).cpu()).permute(1, 2, 0).numpy()
    restored_disp = restandardize_tensor(restored_tensor.detach().squeeze(0).cpu()).permute(1, 2, 0).numpy()

    # Create comparison plot
    os.makedirs(output_dir, exist_ok=True)
    fig, axes = plt.subplots(1, 2, figsize=(15, 7))
    axes[0].imshow(hazy_disp)
    axes[0].set_title("Original Hazy Input")
    axes[0].axis("off")

    axes[1].imshow(restored_disp)
    axes[1].set_title("FM-PhysMamba Restored")
    axes[1].axis("off")

    save_path = os.path.join(output_dir, "comparison_result.png")
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"Success! Result saved to: {save_path}")
    plt.show()

if __name__ == "__main__":
    # 1. Load the model once
    my_model = load_dehazing_model(WEIGHT_PATH, DEVICE)

    # 2. Run inference
    run_inference(my_model, INPUT_IMG_PATH, DEVICE, OUTPUT_DIR)
```
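
The ODESolver above integrates the model's learned flow from the hazy input toward the restored image. Conceptually, flow-matching sampling is fixed-step integration of a learned velocity field over t ∈ [0, 1]. The sketch below shows the idea with plain Euler steps; velocity_fn and the step count are illustrative assumptions, not the repository's actual solver API:

```python
import torch

def euler_ode_solve(velocity_fn, x0, num_steps=10):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler.

    velocity_fn: callable (x, t) -> velocity tensor, same shape as x
                 (in flow matching, this is the trained network).
    x0: starting tensor, e.g. the preprocessed hazy image.
    """
    x = x0
    dt = 1.0 / num_steps
    for i in range(num_steps):
        # Broadcast the current time over the batch dimension.
        t = torch.full((x.shape[0],), i * dt)
        x = x + dt * velocity_fn(x, t)
    return x

# Toy check: a constant velocity field v = 1 integrated over [0, 1]
# should shift the input by exactly 1.
out = euler_ode_solve(lambda x, t: torch.ones_like(x), torch.zeros(1, 3, 8, 8))
```

In practice, higher-order solvers (e.g. midpoint or RK4) trade more function evaluations per step for fewer total steps; the repository's ODESolver may use any of these.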

✨ Key Features

  • PhysMamba Block: Combines the efficiency of State Space Models (SSMs) with physics-based architectural constraints.

  • Flow Matching: Uses continuous-time normalizing flows for high-fidelity image restoration.

  • Vectorized Tiling: The predict_large_image_vectorized function in utils.py handles high-resolution inputs (2K/4K) without running out of VRAM.
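
The tiling idea behind the last bullet can be sketched as follows. This is an illustrative re-implementation (uniform averaging over overlaps via a weight map); the actual blending and vectorization in predict_large_image_vectorized may differ:

```python
import torch

def tiled_apply(fn, img, tile=256, overlap_ratio=0.25):
    """Apply fn to overlapping tiles of img (C, H, W) and blend the results.

    fn: callable mapping a (C, h, w) patch to a tensor of the same shape.
    Overlapping regions are averaged: each pixel's sum of tile outputs is
    divided by the number of tiles that covered it.
    """
    C, H, W = img.shape
    stride = int(tile * (1 - overlap_ratio))
    out = torch.zeros_like(img)
    weight = torch.zeros(1, H, W)

    ys = list(range(0, max(H - tile, 0) + 1, stride))
    xs = list(range(0, max(W - tile, 0) + 1, stride))
    # Ensure the final tile reaches the bottom/right image border.
    if ys[-1] + tile < H:
        ys.append(H - tile)
    if xs[-1] + tile < W:
        xs.append(W - tile)

    for y in ys:
        for x in xs:
            patch = img[:, y:y + tile, x:x + tile]
            out[:, y:y + tile, x:x + tile] += fn(patch)
            weight[:, y:y + tile, x:x + tile] += 1
    return out / weight
```

A quick sanity check: with the identity function, the stitched output must reproduce the input exactly, confirming that the overlap weighting is consistent.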

🛠 Installation (using uv)

We recommend using uv for extremely fast dependency resolution.


```bash
# Install dependencies into a new virtual environment
uv sync

# Activate the environment
source .venv/bin/activate
```