| import os |
| import shutil |
| import tempfile |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import gradio as gr |
|
|
| |
# Normalise API credentials taken from the environment: drop surrounding
# whitespace from FAL_KEY / HF_TOKEN when they are set and non-empty.
# This runs before the project imports below — presumably so modules that read
# these variables at import time already see the cleaned values (confirm).
for _env_name in ("FAL_KEY", "HF_TOKEN"):
    _raw_value = os.environ.get(_env_name)
    if not _raw_value:
        continue
    os.environ[_env_name] = _raw_value.strip()
|
|
| from src.checkpoints import ensure_sam3d_checkpoints, link_mv_sam3d_checkpoints |
| from src.fal_multiview import generate_multiview_parallel, DEFAULT_ANGLES |
| from src.fal_sam3 import extract_main_object_alpha |
| from src.mv_sam3d import run_mv_sam3d_inference |
| from src.utils import ( |
| natural_key, |
| load_rgb_image, |
| save_rgb_png, |
| make_rgba_with_alpha, |
| save_rgba_png, |
| remove_white_background_alpha, |
| ) |
|
|
# Root directory handed to the MV-SAM3D helpers (checkpoint linking, inference).
MV_ROOT = Path("/mv_sam3d")
# Saved Stage 1 sample images (single photos), one file per sample.
SAMPLES_S1_DIR = Path("/data/samples/stage1")
# Saved Stage 2 sample sets: one sub-folder of view images per set.
SAMPLES_S2_DIR = Path("/data/samples/stage2")
# Built-in Stage 2 sets: sub-folders of MV_ROOT/data that contain an images/ dir.
BUILTIN_S2_DIR = MV_ROOT / "data"

# The writable sample folders are created eagerly at import time.
for _sample_dir in (SAMPLES_S1_DIR, SAMPLES_S2_DIR):
    _sample_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| |
|
|
def _list_stage1_samples() -> list[str]:
    """Return the file names of the saved Stage 1 sample images.

    Only regular files directly inside ``SAMPLES_S1_DIR`` with a ``.png``,
    ``.jpg``, ``.jpeg`` or ``.webp`` extension (case-insensitive) are listed,
    sorted by path. Sub-folders are ignored. Returns ``[]`` when the directory
    is missing (or is not a directory).
    """
    if not SAMPLES_S1_DIR.is_dir():
        return []
    image_exts = {".png", ".jpg", ".jpeg", ".webp"}
    return [
        p.name
        for p in sorted(SAMPLES_S1_DIR.iterdir())
        if p.is_file() and p.suffix.lower() in image_exts
    ]
|
|
|
|
def _list_stage2_samples() -> list[str]:
    """Return the Stage 2 sample-set labels shown in the dropdown.

    Built-in sets come first: one per sub-folder of ``BUILTIN_S2_DIR`` that has
    an ``images/`` directory, labelled ``"[builtin] <folder>"``. They are
    followed by every sub-folder of ``SAMPLES_S2_DIR``, labelled with the bare
    folder name. Each group is sorted by path; a missing root contributes nothing.
    """
    builtin_labels = (
        [
            f"[builtin] {entry.name}"
            for entry in sorted(BUILTIN_S2_DIR.iterdir())
            if entry.is_dir() and (entry / "images").is_dir()
        ]
        if BUILTIN_S2_DIR.exists()
        else []
    )
    user_labels = (
        [entry.name for entry in sorted(SAMPLES_S2_DIR.iterdir()) if entry.is_dir()]
        if SAMPLES_S2_DIR.exists()
        else []
    )
    return builtin_labels + user_labels
|
|
|
|
def _load_stage2_sample(name: str) -> list[str]:
    """Return the image file paths of a named Stage 2 sample set.

    ``name`` is a label as produced by ``_list_stage2_samples``:
    ``"[builtin] <set>"`` reads ``BUILTIN_S2_DIR/<set>/images``; any other
    name reads ``SAMPLES_S2_DIR/<name>`` directly.

    Only regular files with a ``.png``/``.jpg``/``.jpeg``/``.webp`` extension
    (case-insensitive) are returned, ordered by ``natural_key`` of the file
    name. Returns ``[]`` if the folder does not exist or the set name is not a
    plain folder name (the label comes from the client, so path components
    such as ``..`` or ``/`` are rejected).
    """
    builtin_prefix = "[builtin] "
    is_builtin = name.startswith(builtin_prefix)
    # Strip only the leading prefix (str.replace would also drop later matches).
    set_name = name[len(builtin_prefix):] if is_builtin else name

    if not set_name or set_name in (".", "..") or Path(set_name).name != set_name:
        return []

    if is_builtin:
        images_dir = BUILTIN_S2_DIR / set_name / "images"
    else:
        images_dir = SAMPLES_S2_DIR / set_name
    if not images_dir.is_dir():
        return []

    image_exts = {".png", ".jpg", ".jpeg", ".webp"}
    images = [
        p for p in images_dir.iterdir()
        if p.is_file() and p.suffix.lower() in image_exts
    ]
    images.sort(key=lambda p: natural_key(p.name))
    return [str(p) for p in images]
|
|
|
|
| |
|
|
def run_stage1(
    image_file,
    angles_str,
    vertical_angle,
    zoom,
    lora_scale,
    guidance_scale,
    num_steps,
    seed_val,
    progress=gr.Progress(),
):
    """Generate multi-view images from a single input image.

    Args:
        image_file: Source image as a file path, or an object exposing ``.name``.
        angles_str: Comma-separated horizontal angles in degrees; blank falls
            back to ``DEFAULT_ANGLES``.
        vertical_angle, zoom, lora_scale, guidance_scale: Forwarded unchanged
            to ``generate_multiview_parallel``.
        num_steps: Number of inference steps (converted to ``int``).
        seed_val: Seed; ``None``/empty or a negative value means random.
            ``0`` is a valid seed.
        progress: Gradio progress tracker.

    Returns:
        Paths of PNG files (one per generated view) in the system temp
        directory. Files are kept (``delete=False``) so Gradio can serve them.

    Raises:
        gr.Error: If no image is given or the angles cannot be parsed.
    """
    if image_file is None:
        raise gr.Error("Please upload or select a source image.")

    # Accept both a filepath string and a file-like upload object.
    img_path = image_file if isinstance(image_file, str) else image_file.name

    try:
        angles = [float(a.strip()) for a in (angles_str or "").split(",") if a.strip()]
    except ValueError:
        raise gr.Error(
            "Invalid angles format. Use comma-separated numbers like: 0,60,120,180,240,300"
        ) from None
    if not angles:
        # Copy so the shared module-level default cannot be mutated downstream.
        angles = list(DEFAULT_ANGLES)

    # Explicit None/"" test: a truthiness test would turn seed 0 into "random".
    seed = None
    if seed_val is not None and seed_val != "":
        seed = int(seed_val)
        if seed < 0:
            seed = None

    progress(0.05, desc="Uploading image & starting generation...")

    results = generate_multiview_parallel(
        image_path=img_path,
        horizontal_angles=angles,
        vertical_angle=vertical_angle,
        zoom=zoom,
        lora_scale=lora_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        seed=seed,
        # Map the generator's 0..1 progress into the 0.05..0.95 band of this step.
        on_progress=lambda frac, msg: progress(0.05 + frac * 0.9, desc=msg),
    )

    progress(0.98, desc="Saving generated views...")

    out_paths = []
    for angle, img in results:
        # Only reserve a unique file name here and close the handle immediately,
        # so no file descriptor is leaked per view.
        with tempfile.NamedTemporaryFile(suffix=f"_{int(angle)}deg.png", delete=False) as tmp:
            out_path = tmp.name
        img.save(out_path, format="PNG")
        out_paths.append(out_path)

    progress(1.0, desc=f"Done! Generated {len(results)} views.")
    return out_paths
|
|
|
|
def send_to_stage2(gallery_images, progress=gr.Progress()):
    """Save the Stage 1 gallery images as a new Stage 2 sample set.

    Images are copied in gallery order to
    ``SAMPLES_S2_DIR/multiview_<YYYYmmdd_HHMMSS>/0.png, 1.png, ...``. If that
    folder already exists (e.g. two sends within the same second), a numeric
    suffix ``_1``, ``_2``, ... is appended. This keeps earlier sets from being
    overwritten or mixed with stale views.

    Args:
        gallery_images: Stage 1 gallery value. Each item is either a path or a
            list/tuple whose first element is the path.
        progress: Gradio progress tracker (not used here).

    Returns:
        ``(dropdown_update, status_message)``: refreshed Stage 2 choices with
        the new set selected, plus a status string.

    Raises:
        gr.Error: If the gallery is empty.
    """
    if not gallery_images:
        raise gr.Error("No generated views to send. Run Stage 1 first.")

    base_name = f"multiview_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    name = base_name
    attempt = 0
    while True:
        dst = SAMPLES_S2_DIR / name
        try:
            # No exist_ok: reusing an existing folder would mix in old views.
            dst.mkdir(parents=True)
            break
        except FileExistsError:
            attempt += 1
            name = f"{base_name}_{attempt}"

    for i, item in enumerate(gallery_images):
        src_path = item[0] if isinstance(item, (list, tuple)) else item
        shutil.copy2(src_path, dst / f"{i}.png")

    choices = _list_stage2_samples()
    return (
        gr.update(choices=choices, value=name),
        f"✅ Saved {len(gallery_images)} views as sample set '{name}'",
    )
|
|
|
|
def save_stage1_sample(image_file):
    """Copy the current Stage 1 image into ``SAMPLES_S1_DIR`` and refresh the dropdown.

    The sample keeps the source file's base name; an existing sample with
    the same name is overwritten. The dropdown's selected value is left
    unchanged.

    Args:
        image_file: Image as a file path, or an object exposing ``.name``;
            ``None`` is a no-op.

    Returns:
        A ``gr.update`` carrying the refreshed Stage 1 sample choices (or an
        empty update when there is no image).
    """
    if image_file is None:
        return gr.update()
    src = Path(image_file if isinstance(image_file, str) else image_file.name)
    dst = SAMPLES_S1_DIR / src.name
    try:
        shutil.copy2(src, dst)
    except shutil.SameFileError:
        # The image already is the stored sample; there is nothing to copy.
        pass
    return gr.update(choices=_list_stage1_samples())
|
|
|
|
def load_stage1_sample(sample_name):
    """Return the path of a saved Stage 1 sample image.

    Args:
        sample_name: File name selected in the Stage 1 samples dropdown.

    Returns:
        The sample's path as ``str``. Returns ``None`` if nothing is selected,
        if the value is not a bare file name, or if no such regular file
        exists in ``SAMPLES_S1_DIR``.
    """
    if not sample_name:
        return None
    # The selected value is sent by the client; accept only a bare file name so
    # it cannot resolve outside SAMPLES_S1_DIR (e.g. "../../etc/passwd").
    if sample_name in (".", "..") or Path(sample_name).name != sample_name:
        return None
    path = SAMPLES_S1_DIR / sample_name
    return str(path) if path.is_file() else None
|
|
|
|
| |
|
|
def load_s2_sample(sample_name):
    """Resolve a Stage 2 sample-set label to its image paths for the file upload.

    Returns the list of paths, or ``None`` when no set is selected or the set
    has no images (so the upload component is cleared).
    """
    if sample_name:
        return _load_stage2_sample(sample_name) or None
    return None
|
|
|
|
def build_mv_input_from_uploads(
    files, mask_prompt: str, prompt: str | None, pick_mode: str, max_masks: int,
    skip_sam3: bool, white_threshold: int, progress: gr.Progress,
):
    """Create an MV-SAM3D input folder inside a fresh temporary directory.

    Layout::

        <tmpdir>/input/
            images/0.png ...          RGB copy of each view
            <mask_prompt>/0.png ...   RGBA copy of each view (alpha = object mask)

    Views are numbered in natural order of their file names.

    Args:
        files: At least two image paths (``str``/``os.PathLike``) or file-like
            objects exposing ``.name``.
        mask_prompt: Name of the mask sub-folder. Must be a plain folder name.
        prompt: Optional SAM-3 text prompt. Blank or whitespace means no prompt.
        pick_mode: Forwarded to ``extract_main_object_alpha``.
        max_masks: Forwarded to ``extract_main_object_alpha``.
        skip_sam3: If true, build the alpha with
            ``remove_white_background_alpha`` instead of calling SAM-3.
        white_threshold: Threshold for the white-background removal.
        progress: Gradio progress tracker, updated once per view.

    Returns:
        ``(input_dir, preview_paths)``: the ``input`` folder and the written
        ``images/*.png`` paths as strings.

    Raises:
        gr.Error: If fewer than two files are given or ``mask_prompt`` is not a
            safe folder name.
    """
    if not files or len(files) < 2:
        raise gr.Error("Upload at least 2 images (multi-view).")
    # mask_prompt becomes a directory under input/. Reject values that would
    # escape it (separators, ".."), collapse onto it ("" or "."), or overwrite
    # the RGB folder ("images").
    if (
        not mask_prompt
        or mask_prompt in (".", "..", "images")
        or Path(mask_prompt).name != mask_prompt
    ):
        raise gr.Error(
            "mask_prompt must be a plain folder name "
            "(not empty, '.', '..' or 'images', and without '/')."
        )

    # Accept plain paths as well as file-like upload objects.
    paths = [f if isinstance(f, (str, os.PathLike)) else f.name for f in files]
    files_sorted = sorted(paths, key=lambda f: natural_key(Path(f).name))
    # The same SAM-3 prompt is used for every view; blank means None.
    sam3_prompt = (prompt or "").strip() or None

    workdir = Path(tempfile.mkdtemp(prefix="mv_input_"))
    input_dir = workdir / "input"
    images_dir = input_dir / "images"
    masks_dir = input_dir / mask_prompt
    images_dir.mkdir(parents=True, exist_ok=True)
    masks_dir.mkdir(parents=True, exist_ok=True)

    n = len(files_sorted)
    for i, fp in enumerate(files_sorted):
        img = load_rgb_image(fp)

        if skip_sam3:
            progress((i + 1) / n, desc=f"Removing white bg {i+1}/{n}")
            alpha = remove_white_background_alpha(img, threshold=white_threshold)
        else:
            progress((i + 1) / n, desc=f"SAM-3 masking view {i+1}/{n}")
            alpha = extract_main_object_alpha(
                image_path=fp,
                prompt=sam3_prompt,
                pick_mode=pick_mode,
                max_masks=max_masks,
            )

        save_rgb_png(img, images_dir / f"{i}.png")
        save_rgba_png(make_rgba_with_alpha(img, alpha), masks_dir / f"{i}.png")

    return input_dir, [str(images_dir / f"{i}.png") for i in range(n)]
|
|
|
|
def run_stage2_pipeline(
    files,
    s2_sample_name,
    mask_prompt,
    sam3_prompt,
    pick_mode,
    max_masks,
    skip_sam3,
    white_threshold,
    image_names,
    stage1_weighting,
    stage2_weighting,
    stage2_weight_source,
    progress=gr.Progress(),
):
    """Run Stage 2: build the MV-SAM3D input from multi-view images and reconstruct 3D.

    Uploaded ``files`` take precedence. If none were uploaded, the images of
    the selected sample set ``s2_sample_name`` are used instead.

    Steps:
    1. Ensure the SAM-3D checkpoints exist and link them into ``MV_ROOT``.
    2. Build the ``images/`` + ``<mask_prompt>/`` input folder. Masks come from
       SAM-3, or from white-background removal when ``skip_sam3`` is set.
    3. Run MV-SAM3D inference with the weighting options.

    Returns:
        ``(viewer_path, download_files, log_tail, preview_imgs)``.
        ``download_files`` holds whichever of the GLB/PLY/NPZ outputs are set.

    Raises:
        gr.Error: If fewer than two images are available.
    """
    actual_files = files
    if not actual_files and s2_sample_name:
        actual_files = _load_stage2_sample(s2_sample_name)

    if not actual_files or len(actual_files) < 2:
        raise gr.Error("Provide at least 2 multi-view images (upload or select a sample set).")

    progress(0.02, desc="Ensuring SAM-3D checkpoints...")
    ensure_sam3d_checkpoints()
    link_mv_sam3d_checkpoints(MV_ROOT)

    progress(0.05, desc="Preparing multi-view input...")
    input_dir, preview_imgs = build_mv_input_from_uploads(
        files=actual_files,
        mask_prompt=mask_prompt,
        prompt=sam3_prompt,
        pick_mode=pick_mode,
        max_masks=int(max_masks),
        skip_sam3=skip_sam3,
        white_threshold=int(white_threshold),
        progress=progress,
    )

    progress(0.75, desc="Running MV-SAM3D inference (GPU)...")
    out = run_mv_sam3d_inference(
        mv_root=MV_ROOT,
        input_dir=input_dir,
        mask_prompt=mask_prompt,
        image_names=image_names,
        stage1_weighting=stage1_weighting,
        stage2_weighting=stage2_weighting,
        stage2_weight_source=stage2_weight_source,
        da3_npz_path=None,
    )

    progress(1.0, desc="Done!")
    # Offer only the artifacts the inference actually produced.
    download_files = [p for p in (out.glb_path, out.ply_path, out.npz_path) if p]
    return (
        out.viewer_path,
        download_files,
        out.log_tail,
        preview_imgs,
    )
|
|
|
|
| |
| |
| |
|
|
with gr.Blocks(title="Multi-View 3D Reconstruction") as demo:
    gr.Markdown(
        """
# 🎯 Multi-View 3D Object Reconstruction
**Stage 1** — Generate multi-view images from a single photo (fal Qwen Multi-Angles)
**Stage 2** — Reconstruct 3D model from multi-view images (MV-SAM3D)

Each stage runs independently. Stage 1 results can be sent to Stage 2 as a sample set.
"""
    )

    # ------------------------------------------------------------------
    # Stage 1: single image -> multi-view images
    # ------------------------------------------------------------------
    with gr.Accordion("🖼️ Stage 1: Single Image → Multi-View Generation", open=True):
        with gr.Row():
            with gr.Column(scale=2):
                s1_image = gr.Image(
                    label="Source image (single object photo)",
                    type="filepath",
                    height=300,
                )
            with gr.Column(scale=1):
                s1_samples_dd = gr.Dropdown(
                    label="Stage 1 Samples",
                    choices=_list_stage1_samples(),
                    value=None,
                    interactive=True,
                    info="Select a saved sample or upload a new image",
                )
                s1_save_btn = gr.Button("💾 Save current image as sample", size="sm")

        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                s1_angles = gr.Textbox(
                    label="Horizontal angles (comma-separated degrees)",
                    value="0, 60, 120, 180, 240, 300",
                    info="0°=front, 90°=right, 180°=back, 270°=left",
                )
                s1_vertical = gr.Slider(
                    label="Vertical angle",
                    minimum=-30, maximum=90, value=0, step=5,
                    info="-30°=low angle, 0°=eye level, 90°=bird's eye",
                )
            with gr.Row():
                s1_zoom = gr.Slider(label="Zoom", minimum=0, maximum=10, value=5, step=0.5)
                s1_lora = gr.Slider(label="LoRA scale", minimum=0, maximum=2, value=1.0, step=0.1)
            with gr.Row():
                s1_guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, value=4.5, step=0.5)
                s1_steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, value=28, step=1)
                s1_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

        s1_run_btn = gr.Button("🚀 Generate Multi-Views", variant="primary", size="lg")

        s1_gallery = gr.Gallery(
            label="Generated multi-view images",
            columns=6,
            height=280,
            object_fit="contain",
        )

        with gr.Row():
            s1_send_btn = gr.Button("📤 Send to Stage 2 Sample Sets", variant="secondary", size="lg")
            s1_status = gr.Textbox(label="Status", interactive=False, scale=2)

        # Selecting a saved sample loads it into the source image.
        s1_samples_dd.change(
            fn=load_stage1_sample,
            inputs=[s1_samples_dd],
            outputs=[s1_image],
        )

        # Saving refreshes the dropdown's choices.
        s1_save_btn.click(
            fn=save_stage1_sample,
            inputs=[s1_image],
            outputs=[s1_samples_dd],
        )

        s1_run_btn.click(
            fn=run_stage1,
            inputs=[
                s1_image,
                s1_angles,
                s1_vertical,
                s1_zoom,
                s1_lora,
                s1_guidance,
                s1_steps,
                s1_seed,
            ],
            outputs=[s1_gallery],
        )

    # ------------------------------------------------------------------
    # Stage 2: multi-view images -> 3D reconstruction
    # ------------------------------------------------------------------
    with gr.Accordion("🧊 Stage 2: Multi-View → 3D Reconstruction (MV-SAM3D)", open=True):
        with gr.Row():
            s2_samples_dd = gr.Dropdown(
                label="Sample sets",
                choices=_list_stage2_samples(),
                value=None,
                interactive=True,
                info="Select a built-in or generated sample set",
                scale=1,
            )
            s2_load_btn = gr.Button("📂 Load Sample", size="sm", scale=0)

        s2_files = gr.Files(label="Multi-view images (PNG/JPG)", file_types=["image"])

        with gr.Row():
            s2_mask_prompt = gr.Textbox(label="mask_prompt folder name", value="object")
            s2_skip_sam3 = gr.Checkbox(
                label="Skip SAM-3 (remove white background instead)",
                value=False,
                info="Use simple white background removal instead of fal SAM-3 API",
            )

        # Hidden initially because skip_sam3 defaults to False.
        s2_white_threshold = gr.Slider(
            label="White BG threshold (R,G,B ≥ threshold → background)",
            minimum=200, maximum=255, value=240, step=1,
            visible=False,
        )

        with gr.Row():
            s2_sam3_prompt = gr.Textbox(
                label="SAM-3 prompt (optional, e.g. 'stuffed toy')",
                value="",
            )
            s2_pick_mode = gr.Dropdown(
                label="Pick main object mode",
                choices=["largest", "best_score"],
                value="largest",
            )
            s2_max_masks = gr.Slider(
                label="SAM-3 max_masks",
                minimum=1, maximum=10, value=5, step=1,
            )

        # Show the threshold slider only when SAM-3 is skipped, and the SAM-3
        # options only when it is used.
        s2_skip_sam3.change(
            fn=lambda s: (
                gr.update(visible=s),
                gr.update(visible=not s),
                gr.update(visible=not s),
                gr.update(visible=not s),
            ),
            inputs=[s2_skip_sam3],
            outputs=[s2_white_threshold, s2_sam3_prompt, s2_pick_mode, s2_max_masks],
        )

        with gr.Accordion("⚙️ MV-SAM3D Parameters", open=False):
            s2_image_names = gr.Textbox(
                label="image_names (comma-separated, optional)",
                placeholder="0,1,2,3,4,5",
            )
            with gr.Row():
                s2_stage1_w = gr.Checkbox(label="Stage 1 weighting", value=False)
                s2_stage2_w = gr.Checkbox(label="Stage 2 weighting", value=False)
                s2_w_source = gr.Dropdown(
                    label="Stage 2 weight source",
                    choices=["entropy", "visibility", "mixed"],
                    value="entropy",
                )

        s2_run_btn = gr.Button("🚀 Run 3D Reconstruction", variant="primary", size="lg")

        with gr.Row():
            s2_viewer = gr.Model3D(label="3D Preview (GLB)")
            s2_gallery = gr.Gallery(
                label="Prepared inputs (images/*.png)",
                columns=4,
                height=240,
            )

        s2_downloads = gr.Files(label="⬇️ Download Results (GLB / PLY / NPZ)")

        s2_log = gr.Textbox(label="Log tail", lines=15)

    # ------------------------------------------------------------------
    # Cross-stage wiring (needs components from both stages)
    # ------------------------------------------------------------------
    s1_send_btn.click(
        fn=send_to_stage2,
        inputs=[s1_gallery],
        outputs=[s2_samples_dd, s1_status],
    )

    s2_load_btn.click(
        fn=load_s2_sample,
        inputs=[s2_samples_dd],
        outputs=[s2_files],
    )

    s2_run_btn.click(
        fn=run_stage2_pipeline,
        inputs=[
            s2_files,
            s2_samples_dd,
            s2_mask_prompt,
            s2_sam3_prompt,
            s2_pick_mode,
            s2_max_masks,
            s2_skip_sam3,
            s2_white_threshold,
            s2_image_names,
            s2_stage1_w,
            s2_stage2_w,
            s2_w_source,
        ],
        outputs=[s2_viewer, s2_downloads, s2_log, s2_gallery],
    )
|
|
|
|
if __name__ == "__main__":
    # Directories Gradio may serve files from:
    # - MV_ROOT;
    # - the temp directory, where Stage 1 views and MV input folders are
    #   created via `tempfile`;
    # - the /data volume holding the saved samples.
    # `tempfile` honours TMPDIR, so its actual directory is added next to "/tmp".
    # dict.fromkeys drops the duplicate when the two are the same.
    _allowed_paths = list(dict.fromkeys([str(MV_ROOT), "/tmp", tempfile.gettempdir(), "/data"]))
    demo.launch(allowed_paths=_allowed_paths)
|
|