"""Gradio app: two-stage single-photo → 3D reconstruction pipeline.

Stage 1 turns one photo into a set of multi-view renders (fal Qwen
Multi-Angles API); Stage 2 masks each view (fal SAM-3 or white-background
removal) and runs MV-SAM3D to reconstruct a 3D model (GLB/PLY/NPZ).
"""

import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path

import gradio as gr

# Strip trailing whitespace/newlines from API keys (HF Secrets UI sometimes
# adds \n).  This runs BEFORE the src.* imports on purpose, so that any
# module-level API-client initialisation sees the cleaned values.
for _key in ("FAL_KEY", "HF_TOKEN"):
    _val = os.environ.get(_key)
    if _val:
        os.environ[_key] = _val.strip()

from src.checkpoints import ensure_sam3d_checkpoints, link_mv_sam3d_checkpoints
from src.fal_multiview import generate_multiview_parallel, DEFAULT_ANGLES
from src.fal_sam3 import extract_main_object_alpha
from src.mv_sam3d import run_mv_sam3d_inference
from src.utils import (
    natural_key,
    load_rgb_image,
    save_rgb_png,
    make_rgba_with_alpha,
    save_rgba_png,
    remove_white_background_alpha,
)

# Root of the MV-SAM3D checkout (checkpoints are linked into it).
MV_ROOT = Path("/mv_sam3d")

# ── Sample directories ──────────────────────────────────────────────────────
# Stage 1 samples: single images for multi-view generation
SAMPLES_S1_DIR = Path("/data/samples/stage1")
SAMPLES_S1_DIR.mkdir(parents=True, exist_ok=True)

# Stage 2 samples: multi-view image sets
SAMPLES_S2_DIR = Path("/data/samples/stage2")
SAMPLES_S2_DIR.mkdir(parents=True, exist_ok=True)

# Built-in Stage 2 samples from MV-SAM3D/data
BUILTIN_S2_DIR = MV_ROOT / "data"

# ── Helpers ─────────────────────────────────────────────────────────────────

def _list_stage1_samples() -> list[str]:
    """Return the file names of sample images stored in SAMPLES_S1_DIR."""
    if not SAMPLES_S1_DIR.exists():
        return []
    samples = []
    for p in sorted(SAMPLES_S1_DIR.iterdir()):
        if p.is_file() and p.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp"):
            samples.append(p.name)
    return samples


def _list_stage2_samples() -> list[str]:
    """Return list of sample set names for Stage 2.

    Built-in sets (folders under MV-SAM3D/data containing an ``images/``
    subfolder) are prefixed with ``[builtin] ``; user/Stage-1 generated
    sets are plain folder names under SAMPLES_S2_DIR.
    """
    names = []
    # Built-in samples from MV-SAM3D/data
    if BUILTIN_S2_DIR.exists():
        for d in sorted(BUILTIN_S2_DIR.iterdir()):
            if d.is_dir() and (d / "images").is_dir():
                names.append(f"[builtin] {d.name}")
    # User / Stage-1 generated samples
    if SAMPLES_S2_DIR.exists():
        for d in sorted(SAMPLES_S2_DIR.iterdir()):
            if d.is_dir():
                names.append(d.name)
    return names


def _load_stage2_sample(name: str) -> list[str]:
    """Load image file paths from a named Stage 2 sample set.

    Returns paths sorted with natural ordering (so "2" < "10"); an empty
    list when the set does not exist.
    """
    if name.startswith("[builtin] "):
        folder = BUILTIN_S2_DIR / name.replace("[builtin] ", "")
        images_dir = folder / "images"
    else:
        folder = SAMPLES_S2_DIR / name
        images_dir = folder  # flat folder of images
    if not images_dir.exists():
        return []
    exts = {".png", ".jpg", ".jpeg", ".webp"}
    files = sorted(
        [str(p) for p in images_dir.iterdir() if p.suffix.lower() in exts],
        key=lambda f: natural_key(Path(f).name),
    )
    return files

# ── Stage 1: Multi-View Generation ─────────────────────────────────────────

def run_stage1(
    image_file,
    angles_str,
    vertical_angle,
    zoom,
    lora_scale,
    guidance_scale,
    num_steps,
    seed_val,
    progress=gr.Progress(),
):
    """Generate multi-view images from a single input image.

    Returns a list of PNG file paths (one per requested angle) for the
    gallery. Raises gr.Error on missing input or malformed angle list.
    """
    if image_file is None:
        raise gr.Error("Please upload or select a source image.")

    # Parse image path (gr.Image with type="filepath" yields a str; be
    # tolerant of file-like objects with a .name)
    img_path = image_file if isinstance(image_file, str) else image_file.name

    # Parse angles
    try:
        angles = [float(a.strip()) for a in angles_str.split(",") if a.strip()]
    except ValueError:
        # Suppress the chained traceback: the message alone is what the UI
        # should show.
        raise gr.Error(
            "Invalid angles format. Use comma-separated numbers like: 0,60,120,180,240,300"
        ) from None
    if not angles:
        angles = DEFAULT_ANGLES

    # NOTE: an explicit `is not None` check so that seed 0 is honoured
    # (previously `if seed_val` treated 0 as "no seed" → random).
    seed = int(seed_val) if seed_val is not None and int(seed_val) >= 0 else None

    progress(0.05, desc="Uploading image & starting generation...")
    results = generate_multiview_parallel(
        image_path=img_path,
        horizontal_angles=angles,
        vertical_angle=vertical_angle,
        zoom=zoom,
        lora_scale=lora_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        seed=seed,
        on_progress=lambda frac, msg: progress(0.05 + frac * 0.9, desc=msg),
    )

    progress(0.98, desc="Saving generated views...")
    # Save to temp dir for gallery display
    out_paths = []
    for angle, img in results:
        tmp = tempfile.NamedTemporaryFile(suffix=f"_{int(angle)}deg.png", delete=False)
        tmp.close()  # we only need the unique path; avoid leaking the open fd
        img.save(tmp.name, format="PNG")
        out_paths.append(tmp.name)

    progress(1.0, desc=f"Done! Generated {len(results)} views.")
    return out_paths


def send_to_stage2(gallery_images, progress=gr.Progress()):
    """Save Stage 1 gallery images as a new Stage 2 sample set.

    Returns (dropdown update selecting the new set, status message).
    """
    if not gallery_images:
        raise gr.Error("No generated views to send. Run Stage 1 first.")

    # Create a named folder
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"multiview_{ts}"
    dst = SAMPLES_S2_DIR / name
    dst.mkdir(parents=True, exist_ok=True)

    for i, item in enumerate(gallery_images):
        # gallery_images items can be (filepath, caption) tuples or just filepaths
        src_path = item[0] if isinstance(item, (list, tuple)) else item
        shutil.copy2(src_path, dst / f"{i}.png")

    # Return updated dropdown choices and select the new one
    choices = _list_stage2_samples()
    return (
        gr.update(choices=choices, value=name),
        f"✅ Saved {len(gallery_images)} views as sample set '{name}'",
    )


def save_stage1_sample(image_file):
    """Save uploaded image as a Stage 1 sample and refresh the dropdown."""
    if image_file is None:
        return gr.update()
    src = image_file if isinstance(image_file, str) else image_file.name
    dst = SAMPLES_S1_DIR / Path(src).name
    shutil.copy2(src, dst)
    choices = _list_stage1_samples()
    return gr.update(choices=choices)


def load_stage1_sample(sample_name):
    """Load a Stage 1 sample image; None if the name is empty or missing."""
    if not sample_name:
        return None
    path = SAMPLES_S1_DIR / sample_name
    return str(path) if path.exists() else None

# ── Stage 2: 3D Reconstruction ─────────────────────────────────────────────

def load_s2_sample(sample_name):
    """Load a Stage 2 sample set into the file upload (None if empty)."""
    if not sample_name:
        return None
    files = _load_stage2_sample(sample_name)
    return files if files else None


def build_mv_input_from_uploads(
    files,
    mask_prompt: str,
    prompt: str | None,
    pick_mode: str,
    max_masks: int,
    skip_sam3: bool,
    white_threshold: int,
    progress: gr.Progress,
):
    """
    Create MV-SAM3D input folder:
      input/
        images/0.png ...
        <mask_prompt>/0.png ...   (RGBA with alpha mask)

    Masks come either from the fal SAM-3 API or — when skip_sam3 is set —
    from simple white-background removal.  Returns (input_dir, list of
    saved RGB image paths for preview).
    """
    if not files:
        raise gr.Error("Upload at least 2 images (multi-view).")

    files_sorted = sorted(files, key=lambda f: natural_key(Path(f).name))

    workdir = Path(tempfile.mkdtemp(prefix="mv_input_"))
    input_dir = workdir / "input"
    images_dir = input_dir / "images"
    masks_dir = input_dir / mask_prompt  # mask folder is named after the prompt
    images_dir.mkdir(parents=True, exist_ok=True)
    masks_dir.mkdir(parents=True, exist_ok=True)

    n = len(files_sorted)
    for i, fp in enumerate(files_sorted):
        img = load_rgb_image(fp)
        if skip_sam3:
            progress((i + 1) / max(n, 1), desc=f"Removing white bg {i+1}/{n}")
            alpha = remove_white_background_alpha(img, threshold=white_threshold)
        else:
            progress((i + 1) / max(n, 1), desc=f"SAM-3 masking view {i+1}/{n}")
            alpha = extract_main_object_alpha(
                image_path=fp,
                prompt=(prompt.strip() if prompt and prompt.strip() else None),
                pick_mode=pick_mode,
                max_masks=max_masks,
            )
        rgb_path = images_dir / f"{i}.png"
        save_rgb_png(img, rgb_path)
        rgba = make_rgba_with_alpha(img, alpha)
        mask_path = masks_dir / f"{i}.png"
        save_rgba_png(rgba, mask_path)

    return input_dir, [str(images_dir / f"{i}.png") for i in range(n)]


def run_stage2_pipeline(
    files,
    s2_sample_name,
    mask_prompt,
    sam3_prompt,
    pick_mode,
    max_masks,
    skip_sam3,
    white_threshold,
    image_names,
    stage1_weighting,
    stage2_weighting,
    stage2_weight_source,
    progress=gr.Progress(),
):
    """Full Stage 2 pipeline: mask multi-view images, then run MV-SAM3D.

    Returns (GLB viewer path, downloadable result files, log tail,
    preview image paths).
    """
    # Resolve files: from upload or from sample
    actual_files = files
    if (not actual_files or len(actual_files) == 0) and s2_sample_name:
        actual_files = _load_stage2_sample(s2_sample_name)
    if not actual_files or len(actual_files) < 2:
        raise gr.Error("Provide at least 2 multi-view images (upload or select a sample set).")

    # 1) ensure checkpoints
    progress(0.02, desc="Ensuring SAM-3D checkpoints...")
    ensure_sam3d_checkpoints()
    link_mv_sam3d_checkpoints(MV_ROOT)

    # 2) build input folder
    progress(0.05, desc="Preparing multi-view input...")
    input_dir, preview_imgs = build_mv_input_from_uploads(
        files=actual_files,
        mask_prompt=mask_prompt,
        prompt=sam3_prompt,
        pick_mode=pick_mode,
        max_masks=int(max_masks),
        skip_sam3=skip_sam3,
        white_threshold=int(white_threshold),
        progress=progress,
    )

    # 3) run MV-SAM3D inference
    progress(0.75, desc="Running MV-SAM3D inference (GPU)...")
    out = run_mv_sam3d_inference(
        mv_root=MV_ROOT,
        input_dir=input_dir,
        mask_prompt=mask_prompt,
        image_names=image_names,
        stage1_weighting=stage1_weighting,
        stage2_weighting=stage2_weighting,
        stage2_weight_source=stage2_weight_source,
        da3_npz_path=None,
    )

    progress(1.0, desc="Done!")
    download_files = [p for p in [out.glb_path, out.ply_path, out.npz_path] if p]
    return (
        out.viewer_path,
        download_files,
        out.log_tail,
        preview_imgs,
    )

# ══════════════════════════════════════════════════════════════════════════════
# Gradio UI
# ══════════════════════════════════════════════════════════════════════════════

with gr.Blocks(title="Multi-View 3D Reconstruction") as demo:
    gr.Markdown(
        """
        # 🎯 Multi-View 3D Object Reconstruction

        **Stage 1** — Generate multi-view images from a single photo (fal Qwen Multi-Angles)
        **Stage 2** — Reconstruct 3D model from multi-view images (MV-SAM3D)

        Each stage runs independently. Stage 1 results can be sent to Stage 2 as a sample set.
        """
    )

    # ── STAGE 1 ─────────────────────────────────────────────────────────────
    with gr.Accordion("🖼️ Stage 1: Single Image → Multi-View Generation", open=True):
        with gr.Row():
            with gr.Column(scale=2):
                s1_image = gr.Image(
                    label="Source image (single object photo)",
                    type="filepath",
                    height=300,
                )
            with gr.Column(scale=1):
                s1_samples_dd = gr.Dropdown(
                    label="Stage 1 Samples",
                    choices=_list_stage1_samples(),
                    value=None,
                    interactive=True,
                    info="Select a saved sample or upload a new image",
                )
                s1_save_btn = gr.Button("💾 Save current image as sample", size="sm")

        # Advanced settings
        with gr.Accordion("⚙️ Advanced Settings", open=False):
            with gr.Row():
                s1_angles = gr.Textbox(
                    label="Horizontal angles (comma-separated degrees)",
                    value="0, 60, 120, 180, 240, 300",
                    info="0°=front, 90°=right, 180°=back, 270°=left",
                )
                s1_vertical = gr.Slider(
                    label="Vertical angle",
                    minimum=-30,
                    maximum=90,
                    value=0,
                    step=5,
                    info="-30°=low angle, 0°=eye level, 90°=bird's eye",
                )
            with gr.Row():
                s1_zoom = gr.Slider(label="Zoom", minimum=0, maximum=10, value=5, step=0.5)
                s1_lora = gr.Slider(label="LoRA scale", minimum=0, maximum=2, value=1.0, step=0.1)
            with gr.Row():
                s1_guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, value=4.5, step=0.5)
                s1_steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, value=28, step=1)
            s1_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

        s1_run_btn = gr.Button("🚀 Generate Multi-Views", variant="primary", size="lg")
        s1_gallery = gr.Gallery(
            label="Generated multi-view images",
            columns=6,
            height=280,
            object_fit="contain",
        )
        with gr.Row():
            s1_send_btn = gr.Button("📤 Send to Stage 2 Sample Sets", variant="secondary", size="lg")
            s1_status = gr.Textbox(label="Status", interactive=False, scale=2)

        # ── Stage 1 event handlers
        s1_samples_dd.change(
            fn=load_stage1_sample,
            inputs=[s1_samples_dd],
            outputs=[s1_image],
        )
        s1_save_btn.click(
            fn=save_stage1_sample,
            inputs=[s1_image],
            outputs=[s1_samples_dd],
        )
        s1_run_btn.click(
            fn=run_stage1,
            inputs=[
                s1_image,
                s1_angles,
                s1_vertical,
                s1_zoom,
                s1_lora,
                s1_guidance,
                s1_steps,
                s1_seed,
            ],
            outputs=[s1_gallery],
        )

    # ── STAGE 2 ─────────────────────────────────────────────────────────────
    with gr.Accordion("🧊 Stage 2: Multi-View → 3D Reconstruction (MV-SAM3D)", open=True):
        with gr.Row():
            s2_samples_dd = gr.Dropdown(
                label="Sample sets",
                choices=_list_stage2_samples(),
                value=None,
                interactive=True,
                info="Select a built-in or generated sample set",
                scale=1,
            )
            s2_load_btn = gr.Button("📂 Load Sample", size="sm", scale=0)

        s2_files = gr.Files(label="Multi-view images (PNG/JPG)", file_types=["image"])

        with gr.Row():
            s2_mask_prompt = gr.Textbox(label="mask_prompt folder name", value="object")
            s2_skip_sam3 = gr.Checkbox(
                label="Skip SAM-3 (remove white background instead)",
                value=False,
                info="Use simple white background removal instead of fal SAM-3 API",
            )
            s2_white_threshold = gr.Slider(
                label="White BG threshold (R,G,B ≥ threshold → background)",
                minimum=200,
                maximum=255,
                value=240,
                step=1,
                visible=False,
            )
        with gr.Row():
            s2_sam3_prompt = gr.Textbox(
                label="SAM-3 prompt (optional, e.g. 'stuffed toy')",
                value="",
            )
            s2_pick_mode = gr.Dropdown(
                label="Pick main object mode",
                choices=["largest", "best_score"],
                value="largest",
            )
            s2_max_masks = gr.Slider(
                label="SAM-3 max_masks",
                minimum=1,
                maximum=10,
                value=5,
                step=1,
            )

        # Toggle SAM-3 vs white BG options: the threshold slider is only
        # relevant when SAM-3 is skipped, the SAM-3 options only when it isn't.
        s2_skip_sam3.change(
            fn=lambda s: (
                gr.update(visible=s),
                gr.update(visible=not s),
                gr.update(visible=not s),
                gr.update(visible=not s),
            ),
            inputs=[s2_skip_sam3],
            outputs=[s2_white_threshold, s2_sam3_prompt, s2_pick_mode, s2_max_masks],
        )

        with gr.Accordion("⚙️ MV-SAM3D Parameters", open=False):
            s2_image_names = gr.Textbox(
                label="image_names (comma-separated, optional)",
                placeholder="0,1,2,3,4,5",
            )
            with gr.Row():
                s2_stage1_w = gr.Checkbox(label="Stage 1 weighting", value=False)
                s2_stage2_w = gr.Checkbox(label="Stage 2 weighting", value=False)
                s2_w_source = gr.Dropdown(
                    label="Stage 2 weight source",
                    choices=["entropy", "visibility", "mixed"],
                    value="entropy",
                )

        s2_run_btn = gr.Button("🚀 Run 3D Reconstruction", variant="primary", size="lg")

        with gr.Row():
            s2_viewer = gr.Model3D(label="3D Preview (GLB)")
            s2_gallery = gr.Gallery(
                label="Prepared inputs (images/*.png)",
                columns=4,
                height=240,
            )
        s2_downloads = gr.Files(label="⬇️ Download Results (GLB / PLY / NPZ)")
        s2_log = gr.Textbox(label="Log tail", lines=15)

        # ── Stage 2 event handlers
        # Wire "Send to Stage 2" button from Stage 1
        s1_send_btn.click(
            fn=send_to_stage2,
            inputs=[s1_gallery],
            outputs=[s2_samples_dd, s1_status],
        )
        # Load sample set into file upload
        s2_load_btn.click(
            fn=load_s2_sample,
            inputs=[s2_samples_dd],
            outputs=[s2_files],
        )
        # Run 3D reconstruction
        s2_run_btn.click(
            fn=run_stage2_pipeline,
            inputs=[
                s2_files,
                s2_samples_dd,
                s2_mask_prompt,
                s2_sam3_prompt,
                s2_pick_mode,
                s2_max_masks,
                s2_skip_sam3,
                s2_white_threshold,
                s2_image_names,
                s2_stage1_w,
                s2_stage2_w,
                s2_w_source,
            ],
            outputs=[s2_viewer, s2_downloads, s2_log, s2_gallery],
        )


if __name__ == "__main__":
    # allowed_paths lets Gradio serve files from these absolute directories.
    demo.launch(allowed_paths=["/mv_sam3d", "/tmp", "/data"])