# hiendang7613vulcanlabs — commit 9e8cc3c: "UI: remove DA3 npz upload (not used)"
import os
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
import gradio as gr
# The HF Secrets UI sometimes appends a trailing newline to secret values;
# sanitize the API keys before any src.* module reads them at import time.
for _key in ("FAL_KEY", "HF_TOKEN"):
    _raw = os.environ.get(_key)
    if _raw:
        os.environ[_key] = _raw.strip()
from src.checkpoints import ensure_sam3d_checkpoints, link_mv_sam3d_checkpoints
from src.fal_multiview import generate_multiview_parallel, DEFAULT_ANGLES
from src.fal_sam3 import extract_main_object_alpha
from src.mv_sam3d import run_mv_sam3d_inference
from src.utils import (
natural_key,
load_rgb_image,
save_rgb_png,
make_rgba_with_alpha,
save_rgba_png,
remove_white_background_alpha,
)
# Root of the MV-SAM3D checkout (mounted inside the container).
MV_ROOT = Path("/mv_sam3d")

# ── Sample directories ──────────────────────────────────────────────────────
# Stage 1 samples: single images for multi-view generation
SAMPLES_S1_DIR = Path("/data/samples/stage1")
SAMPLES_S1_DIR.mkdir(parents=True, exist_ok=True)
# Stage 2 samples: multi-view image sets (flat folders of images)
SAMPLES_S2_DIR = Path("/data/samples/stage2")
SAMPLES_S2_DIR.mkdir(parents=True, exist_ok=True)
# Built-in Stage 2 samples from MV-SAM3D/data (each set has an images/ subdir)
BUILTIN_S2_DIR = MV_ROOT / "data"
# ── Helpers ─────────────────────────────────────────────────────────────────
def _list_stage1_samples() -> list[str]:
"""Return list of sample names (subfolders or images in SAMPLES_S1_DIR)."""
if not SAMPLES_S1_DIR.exists():
return []
samples = []
for p in sorted(SAMPLES_S1_DIR.iterdir()):
if p.is_file() and p.suffix.lower() in (".png", ".jpg", ".jpeg", ".webp"):
samples.append(p.name)
return samples
def _list_stage2_samples() -> list[str]:
"""Return list of sample set names for Stage 2."""
names = []
# Built-in samples from MV-SAM3D/data
if BUILTIN_S2_DIR.exists():
for d in sorted(BUILTIN_S2_DIR.iterdir()):
if d.is_dir() and (d / "images").is_dir():
names.append(f"[builtin] {d.name}")
# User / Stage-1 generated samples
if SAMPLES_S2_DIR.exists():
for d in sorted(SAMPLES_S2_DIR.iterdir()):
if d.is_dir():
names.append(d.name)
return names
def _load_stage2_sample(name: str) -> list[str]:
"""Load image file paths from a named Stage 2 sample set."""
if name.startswith("[builtin] "):
folder = BUILTIN_S2_DIR / name.replace("[builtin] ", "")
images_dir = folder / "images"
else:
folder = SAMPLES_S2_DIR / name
images_dir = folder # flat folder of images
if not images_dir.exists():
return []
exts = {".png", ".jpg", ".jpeg", ".webp"}
files = sorted(
[str(p) for p in images_dir.iterdir() if p.suffix.lower() in exts],
key=lambda f: natural_key(Path(f).name),
)
return files
# ── Stage 1: Multi-View Generation ─────────────────────────────────────────
def run_stage1(
    image_file,
    angles_str,
    vertical_angle,
    zoom,
    lora_scale,
    guidance_scale,
    num_steps,
    seed_val,
    progress=gr.Progress(),
):
    """Generate multi-view images from a single input image (Stage 1).

    Args:
        image_file: Source image path (str) or Gradio file object.
        angles_str: Comma-separated horizontal angles in degrees.
        vertical_angle, zoom, lora_scale, guidance_scale, num_steps:
            Generation parameters forwarded to the fal endpoint.
        seed_val: Seed value; negative (or None) means random.
        progress: Gradio progress reporter.

    Returns:
        List of PNG file paths for the result gallery.

    Raises:
        gr.Error: If no image was provided or the angle list cannot be parsed.
    """
    if image_file is None:
        raise gr.Error("Please upload or select a source image.")
    # Parse image path (Gradio may hand us a str or a file-like with .name).
    img_path = image_file if isinstance(image_file, str) else image_file.name

    # Parse angles
    try:
        angles = [float(a.strip()) for a in angles_str.split(",") if a.strip()]
    except ValueError:
        raise gr.Error("Invalid angles format. Use comma-separated numbers like: 0,60,120,180,240,300")
    if not angles:
        angles = DEFAULT_ANGLES

    # BUGFIX: the old truthiness check (`if seed_val and ...`) made seed 0
    # fall through to "random"; 0 is a valid deterministic seed.
    seed = int(seed_val) if seed_val is not None and int(seed_val) >= 0 else None

    progress(0.05, desc="Uploading image & starting generation...")
    results = generate_multiview_parallel(
        image_path=img_path,
        horizontal_angles=angles,
        vertical_angle=vertical_angle,
        zoom=zoom,
        lora_scale=lora_scale,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_steps),
        seed=seed,
        on_progress=lambda frac, msg: progress(0.05 + frac * 0.9, desc=msg),
    )

    progress(0.98, desc="Saving generated views...")
    # Save to temp files for gallery display; delete=False because Gradio
    # serves these paths after the handler returns.
    out_paths = []
    for angle, img in results:
        tmp = tempfile.NamedTemporaryFile(suffix=f"_{int(angle)}deg.png", delete=False)
        img.save(tmp.name, format="PNG")
        tmp.close()  # BUGFIX: release our file handle; the file itself is kept
        out_paths.append(tmp.name)
    progress(1.0, desc=f"Done! Generated {len(results)} views.")
    return out_paths
def send_to_stage2(gallery_images, progress=gr.Progress()):
    """Persist Stage 1 gallery images as a new timestamped Stage 2 sample set.

    Returns a dropdown update (choices refreshed, new set selected) and a
    human-readable status string.
    """
    if not gallery_images:
        raise gr.Error("No generated views to send. Run Stage 1 first.")

    # Create a uniquely named folder under the Stage 2 samples root.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    name = f"multiview_{stamp}"
    dst = SAMPLES_S2_DIR / name
    dst.mkdir(parents=True, exist_ok=True)

    for idx, item in enumerate(gallery_images):
        # Gallery items may be (filepath, caption) pairs or bare filepaths.
        src_path = item[0] if isinstance(item, (list, tuple)) else item
        shutil.copy2(src_path, dst / f"{idx}.png")

    status = f"βœ… Saved {len(gallery_images)} views as sample set '{name}'"
    return gr.update(choices=_list_stage2_samples(), value=name), status
def save_stage1_sample(image_file):
    """Copy the currently loaded image into SAMPLES_S1_DIR and refresh choices."""
    if image_file is None:
        return gr.update()  # nothing loaded — leave the dropdown untouched
    src = image_file if isinstance(image_file, str) else image_file.name
    shutil.copy2(src, SAMPLES_S1_DIR / Path(src).name)
    return gr.update(choices=_list_stage1_samples())
def load_stage1_sample(sample_name):
    """Resolve a Stage 1 sample name to its image path, or None if absent."""
    if not sample_name:
        return None
    candidate = SAMPLES_S1_DIR / sample_name
    return str(candidate) if candidate.exists() else None
# ── Stage 2: 3D Reconstruction ─────────────────────────────────────────────
def load_s2_sample(sample_name):
    """Load a Stage 2 sample set's files into the upload widget (None if empty)."""
    if not sample_name:
        return None
    files = _load_stage2_sample(sample_name)
    return files or None
def build_mv_input_from_uploads(
    files, mask_prompt: str, prompt: str | None, pick_mode: str, max_masks: int,
    skip_sam3: bool, white_threshold: int, progress: gr.Progress,
):
    """
    Create MV-SAM3D input folder:
        input/
            images/0.png ...
            <mask_prompt>/0.png ...   (RGBA with alpha mask)

    Returns (input_dir, list of saved RGB image paths).
    Raises gr.Error when no files were supplied.
    """
    if not files:
        raise gr.Error("Upload at least 2 images (multi-view).")

    ordered = sorted(files, key=lambda f: natural_key(Path(f).name))
    total = len(ordered)

    # Fresh working directory per run; MV-SAM3D reads from input/.
    root = Path(tempfile.mkdtemp(prefix="mv_input_"))
    input_dir = root / "input"
    images_dir = input_dir / "images"
    masks_dir = input_dir / mask_prompt
    images_dir.mkdir(parents=True, exist_ok=True)
    masks_dir.mkdir(parents=True, exist_ok=True)

    for idx, src in enumerate(ordered):
        img = load_rgb_image(src)
        if skip_sam3:
            # Cheap local path: treat near-white pixels as background.
            progress((idx + 1) / max(total, 1), desc=f"Removing white bg {idx+1}/{total}")
            alpha = remove_white_background_alpha(img, threshold=white_threshold)
        else:
            # Remote path: fal SAM-3 segments the main object in each view.
            progress((idx + 1) / max(total, 1), desc=f"SAM-3 masking view {idx+1}/{total}")
            cleaned_prompt = prompt.strip() if prompt and prompt.strip() else None
            alpha = extract_main_object_alpha(
                image_path=src,
                prompt=cleaned_prompt,
                pick_mode=pick_mode,
                max_masks=max_masks,
            )
        save_rgb_png(img, images_dir / f"{idx}.png")
        save_rgba_png(make_rgba_with_alpha(img, alpha), masks_dir / f"{idx}.png")

    return input_dir, [str(images_dir / f"{i}.png") for i in range(total)]
def run_stage2_pipeline(
    files,
    s2_sample_name,
    mask_prompt,
    sam3_prompt,
    pick_mode,
    max_masks,
    skip_sam3,
    white_threshold,
    image_names,
    stage1_weighting,
    stage2_weighting,
    stage2_weight_source,
    progress=gr.Progress(),
):
    """End-to-end Stage 2: multi-view images -> object masks -> MV-SAM3D model.

    Returns (viewer GLB path, downloadable result files, log tail, preview
    image paths). Raises gr.Error when fewer than 2 images are available.
    """
    # Resolve files: from upload or from sample (uploads win if present)
    actual_files = files
    if (not actual_files or len(actual_files) == 0) and s2_sample_name:
        actual_files = _load_stage2_sample(s2_sample_name)
    if not actual_files or len(actual_files) < 2:
        raise gr.Error("Provide at least 2 multi-view images (upload or select a sample set).")

    # 1) ensure checkpoints are downloaded and linked into the MV-SAM3D tree
    progress(0.02, desc="Ensuring SAM-3D checkpoints...")
    ensure_sam3d_checkpoints()
    link_mv_sam3d_checkpoints(MV_ROOT)

    # 2) build the input/ folder (RGB views + per-view alpha masks)
    progress(0.05, desc="Preparing multi-view input...")
    input_dir, preview_imgs = build_mv_input_from_uploads(
        files=actual_files,
        mask_prompt=mask_prompt,
        prompt=sam3_prompt,
        pick_mode=pick_mode,
        max_masks=int(max_masks),
        skip_sam3=skip_sam3,
        white_threshold=int(white_threshold),
        progress=progress,
    )

    # 3) run MV-SAM3D inference (GPU-bound)
    progress(0.75, desc="Running MV-SAM3D inference (GPU)...")
    out = run_mv_sam3d_inference(
        mv_root=MV_ROOT,
        input_dir=input_dir,
        mask_prompt=mask_prompt,
        image_names=image_names,
        stage1_weighting=stage1_weighting,
        stage2_weighting=stage2_weighting,
        stage2_weight_source=stage2_weight_source,
        da3_npz_path=None,  # DA3 npz upload was removed from the UI (not used)
    )
    progress(1.0, desc="Done!")
    # Offer whichever artifacts the run produced (GLB / PLY / NPZ).
    download_files = [p for p in [out.glb_path, out.ply_path, out.npz_path] if p]
    return (
        out.viewer_path,
        download_files,
        out.log_tail,
        preview_imgs,
    )
# ══════════════════════════════════════════════════════════════════════════════
# Gradio UI
# ══════════════════════════════════════════════════════════════════════════════
with gr.Blocks(title="Multi-View 3D Reconstruction") as demo:
    # App header / usage overview.
    gr.Markdown(
        """
# 🎯 Multi-View 3D Object Reconstruction
**Stage 1** β€” Generate multi-view images from a single photo (fal Qwen Multi-Angles)
**Stage 2** β€” Reconstruct 3D model from multi-view images (MV-SAM3D)
Each stage runs independently. Stage 1 results can be sent to Stage 2 as a sample set.
"""
    )

    # ── STAGE 1: single image → multi-view generation ───────────────────────
    with gr.Accordion("πŸ–ΌοΈ Stage 1: Single Image β†’ Multi-View Generation", open=True):
        with gr.Row():
            with gr.Column(scale=2):
                s1_image = gr.Image(
                    label="Source image (single object photo)",
                    type="filepath",
                    height=300,
                )
            with gr.Column(scale=1):
                s1_samples_dd = gr.Dropdown(
                    label="Stage 1 Samples",
                    choices=_list_stage1_samples(),
                    value=None,
                    interactive=True,
                    info="Select a saved sample or upload a new image",
                )
                s1_save_btn = gr.Button("πŸ’Ύ Save current image as sample", size="sm")

        # Advanced generation settings (collapsed by default)
        with gr.Accordion("βš™οΈ Advanced Settings", open=False):
            with gr.Row():
                s1_angles = gr.Textbox(
                    label="Horizontal angles (comma-separated degrees)",
                    value="0, 60, 120, 180, 240, 300",
                    info="0Β°=front, 90Β°=right, 180Β°=back, 270Β°=left",
                )
                s1_vertical = gr.Slider(
                    label="Vertical angle",
                    minimum=-30, maximum=90, value=0, step=5,
                    info="-30Β°=low angle, 0Β°=eye level, 90Β°=bird's eye",
                )
            with gr.Row():
                s1_zoom = gr.Slider(label="Zoom", minimum=0, maximum=10, value=5, step=0.5)
                s1_lora = gr.Slider(label="LoRA scale", minimum=0, maximum=2, value=1.0, step=0.1)
            with gr.Row():
                s1_guidance = gr.Slider(label="Guidance scale", minimum=1, maximum=20, value=4.5, step=0.5)
                s1_steps = gr.Slider(label="Inference steps", minimum=10, maximum=50, value=28, step=1)
                s1_seed = gr.Number(label="Seed (-1 = random)", value=-1, precision=0)

        s1_run_btn = gr.Button("πŸš€ Generate Multi-Views", variant="primary", size="lg")
        s1_gallery = gr.Gallery(
            label="Generated multi-view images",
            columns=6,
            height=280,
            object_fit="contain",
        )
        with gr.Row():
            s1_send_btn = gr.Button("πŸ“€ Send to Stage 2 Sample Sets", variant="secondary", size="lg")
            s1_status = gr.Textbox(label="Status", interactive=False, scale=2)

    # ── Stage 1 event handlers
    s1_samples_dd.change(
        fn=load_stage1_sample,
        inputs=[s1_samples_dd],
        outputs=[s1_image],
    )
    s1_save_btn.click(
        fn=save_stage1_sample,
        inputs=[s1_image],
        outputs=[s1_samples_dd],
    )
    s1_run_btn.click(
        fn=run_stage1,
        inputs=[
            s1_image,
            s1_angles,
            s1_vertical,
            s1_zoom,
            s1_lora,
            s1_guidance,
            s1_steps,
            s1_seed,
        ],
        outputs=[s1_gallery],
    )

    # ── STAGE 2: multi-view → 3D reconstruction ─────────────────────────────
    with gr.Accordion("🧊 Stage 2: Multi-View β†’ 3D Reconstruction (MV-SAM3D)", open=True):
        with gr.Row():
            s2_samples_dd = gr.Dropdown(
                label="Sample sets",
                choices=_list_stage2_samples(),
                value=None,
                interactive=True,
                info="Select a built-in or generated sample set",
                scale=1,
            )
            s2_load_btn = gr.Button("πŸ“‚ Load Sample", size="sm", scale=0)
        s2_files = gr.Files(label="Multi-view images (PNG/JPG)", file_types=["image"])

        with gr.Row():
            s2_mask_prompt = gr.Textbox(label="mask_prompt folder name", value="object")
            s2_skip_sam3 = gr.Checkbox(
                label="Skip SAM-3 (remove white background instead)",
                value=False,
                info="Use simple white background removal instead of fal SAM-3 API",
            )
            # Hidden until "skip SAM-3" is checked (see .change handler below)
            s2_white_threshold = gr.Slider(
                label="White BG threshold (R,G,B β‰₯ threshold β†’ background)",
                minimum=200, maximum=255, value=240, step=1,
                visible=False,
            )
        with gr.Row():
            s2_sam3_prompt = gr.Textbox(
                label="SAM-3 prompt (optional, e.g. 'stuffed toy')",
                value="",
            )
            s2_pick_mode = gr.Dropdown(
                label="Pick main object mode",
                choices=["largest", "best_score"],
                value="largest",
            )
            s2_max_masks = gr.Slider(
                label="SAM-3 max_masks",
                minimum=1, maximum=10, value=5, step=1,
            )

        # Toggle SAM-3 vs white-BG option visibility
        s2_skip_sam3.change(
            fn=lambda s: (
                gr.update(visible=s),
                gr.update(visible=not s),
                gr.update(visible=not s),
                gr.update(visible=not s),
            ),
            inputs=[s2_skip_sam3],
            outputs=[s2_white_threshold, s2_sam3_prompt, s2_pick_mode, s2_max_masks],
        )

        with gr.Accordion("βš™οΈ MV-SAM3D Parameters", open=False):
            s2_image_names = gr.Textbox(
                label="image_names (comma-separated, optional)",
                placeholder="0,1,2,3,4,5",
            )
            with gr.Row():
                s2_stage1_w = gr.Checkbox(label="Stage 1 weighting", value=False)
                s2_stage2_w = gr.Checkbox(label="Stage 2 weighting", value=False)
                s2_w_source = gr.Dropdown(
                    label="Stage 2 weight source",
                    choices=["entropy", "visibility", "mixed"],
                    value="entropy",
                )

        s2_run_btn = gr.Button("πŸš€ Run 3D Reconstruction", variant="primary", size="lg")
        with gr.Row():
            s2_viewer = gr.Model3D(label="3D Preview (GLB)")
            s2_gallery = gr.Gallery(
                label="Prepared inputs (images/*.png)",
                columns=4,
                height=240,
            )
        s2_downloads = gr.Files(label="⬇️ Download Results (GLB / PLY / NPZ)")
        s2_log = gr.Textbox(label="Log tail", lines=15)

    # ── Stage 2 event handlers
    # Wire "Send to Stage 2" button from Stage 1
    s1_send_btn.click(
        fn=send_to_stage2,
        inputs=[s1_gallery],
        outputs=[s2_samples_dd, s1_status],
    )
    # Load sample set into file upload
    s2_load_btn.click(
        fn=load_s2_sample,
        inputs=[s2_samples_dd],
        outputs=[s2_files],
    )
    # Run 3D reconstruction
    s2_run_btn.click(
        fn=run_stage2_pipeline,
        inputs=[
            s2_files,
            s2_samples_dd,
            s2_mask_prompt,
            s2_sam3_prompt,
            s2_pick_mode,
            s2_max_masks,
            s2_skip_sam3,
            s2_white_threshold,
            s2_image_names,
            s2_stage1_w,
            s2_stage2_w,
            s2_w_source,
        ],
        outputs=[s2_viewer, s2_downloads, s2_log, s2_gallery],
    )
if __name__ == "__main__":
    # allowed_paths lets Gradio serve files from these roots (model outputs,
    # sample images, temp files) outside its default working directory.
    demo.launch(allowed_paths=["/mv_sam3d", "/tmp", "/data"])