Spaces:
Running on Zero
Running on Zero
| # --------------------------------------------------------- Config ------------------------------------------------- | |
| num_workers :int = 1 | |
| DDIM_STEPS = 50 | |
| BATCH_SIZE = 1 | |
| FIXED_CODE = False | |
| # for vis | |
| SAVE_INTERMEDIATES = True | |
| NUM_grid_in_a_column = 5 | |
| # ------------------------------------------------------------------------------------------------------------------------ | |
| import argparse | |
| parser = argparse.ArgumentParser(description="Custom inference for tgt/ref image pairs.") | |
| parser.add_argument("--task-name", type=str, | |
| default='face', | |
| help="face|hair|motion|head") | |
| parser.add_argument("--out-dir", type=str, default='examples/outputs', help="Output directory") | |
| # option 1: pass 2 paths | |
| parser.add_argument("--tgt", type=str, default=None, help="Path to target image. if None, will use paths read from --pair-list") | |
| parser.add_argument("--ref", type=str, default=None, help="Path to reference image") | |
| # option 2: pass a txt containing paths | |
| parser.add_argument("--pair-list", type=str, default='examples/inputs.txt', help="white-space-separated list file: tgt_path ref_path") | |
| args = parser.parse_args() | |
| #-----------------------------------------set TASK-------------------------------------------------------------------------- | |
| task_name :str = args.task_name | |
| TASK :int = { | |
| 'face': 0, | |
| 'hair': 1, | |
| 'motion': 2, | |
| 'head': 3, | |
| }[task_name] | |
| print(f'task: {task_name} transfer (ID: {TASK})') | |
| # ------------------------------------------------------------------------------------------------------------------------ | |
| import sys | |
| import os | |
| from pathlib import Path | |
| cur_dir = os.path.dirname(os.path.abspath(__file__)) | |
| from imports import * | |
| import torch | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from tqdm import tqdm | |
| from einops import rearrange | |
| from torchvision.utils import make_grid | |
| from my_py_lib.image_util import imgs_2_grid_A,img_paths_2_grid_A | |
| from pytorch_lightning import seed_everything | |
| from torch import autocast | |
| from contextlib import nullcontext | |
| import torchvision | |
| from ldm.models.diffusion.ddpm import LatentDiffusion | |
| from ldm.util import instantiate_from_config | |
| from ldm.models.diffusion.ddim import DDIMSampler | |
| from Dataset_custom import Dataset_custom | |
| from MoE import offload_unused_tasks__LD | |
| from ldm.models.diffusion.ddpm import LandmarkExtractor | |
| from my_py_lib.torch_util import cleanup_gpu_memory | |
| from gen_lmk_and_mask import gen_lmk_and_mask | |
| # ------------------------------------------------------------------------------------------------------------------------ | |
| DDIM_ETA = 0.0 | |
| SCALE = 3.0 | |
| PRECISION = "full" # "full" or "autocast" | |
| H = 512 | |
| W = 512 | |
| C = 4 | |
| F = 8 | |
| # ------------------------------------------------------------------------------------------------------------------------ | |
| def load_first_stage_from_sd14(model: LatentDiffusion, sd14_path: Path) -> None: | |
| print(f"Loading first_stage_model from {sd14_path}") | |
| sd14 = torch.load(str(sd14_path), map_location="cpu") | |
| if isinstance(sd14, dict) and "state_dict" in sd14: | |
| sd14_sd = sd14["state_dict"] | |
| else: | |
| sd14_sd = sd14 | |
| prefixes = ["first_stage_model.", "model.first_stage_model."] | |
| fs_sd = {} | |
| for prefix in prefixes: | |
| for k, v in sd14_sd.items(): | |
| if k.startswith(prefix): | |
| fs_sd[k[len(prefix):]] = v | |
| if fs_sd: | |
| break | |
| if not fs_sd: | |
| raise RuntimeError("Could not find first_stage_model weights in SD v1-4 checkpoint.") | |
| model.first_stage_model.load_state_dict(fs_sd, strict=True) | |
| def save_sample_by_decode(x, model, base_path, segment_id, intermediate_num): | |
| x = model.decode_first_stage(x) | |
| x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) | |
| x = x.cpu().permute(0, 2, 3, 1).numpy() | |
| for i in range(len(x)): | |
| img = Image.fromarray((x[i] * 255).astype(np.uint8)) | |
| save_path = Path(base_path) / segment_id | |
| save_path.mkdir(parents=True, exist_ok=True) | |
| img.save(save_path / f"{intermediate_num}.png") | |
| def get_tensor_clip(normalize=True, toTensor=True): | |
| transform_list = [] | |
| if toTensor: | |
| transform_list += [torchvision.transforms.ToTensor()] | |
| if normalize: | |
| transform_list += [ | |
| torchvision.transforms.Normalize( | |
| (0.48145466, 0.4578275, 0.40821073), | |
| (0.26862954, 0.26130258, 0.27577711), | |
| ) | |
| ] | |
| return torchvision.transforms.Compose(transform_list) | |
| def load_model_from_config(ckpt, verbose=1): | |
| if 1: | |
| ckpt = Path(ckpt) | |
| print(f"Loading model from {ckpt}") | |
| pl_sd = torch.load(str(ckpt), map_location="cpu") | |
| if isinstance(pl_sd, dict) and "state_dict" in pl_sd: | |
| sd = pl_sd["state_dict"] | |
| else: | |
| sd = pl_sd | |
| else: | |
| print("DEBUG_skip_load_ckpt") | |
| if 1: | |
| from init_model import get_moe | |
| model: LatentDiffusion = get_moe() | |
| model.ptsM_Generator = LandmarkExtractor(include_visualizer=True, img_256_mode=False) | |
| cleanup_gpu_memory() | |
| if 1: | |
| m, u = model.load_state_dict(sd, strict=False) | |
| if len(m) > 0 and verbose: | |
| print("missing keys:") | |
| pretty_print_torch_module_keys(m) | |
| if len(u) > 0 and verbose: | |
| print("unexpected keys:") | |
| pretty_print_torch_module_keys(u) | |
| load_first_stage_from_sd14(model, SD14_localpath) | |
| offload_unused_tasks__LD(model, TASK, method="del") # for save cuda mem | |
| model.cuda() | |
| model.eval() | |
| return model | |
| def load_pairs(pair_list, tgt, ref): | |
| if tgt and ref: | |
| pairs = [(tgt, ref), ] | |
| elif pair_list: | |
| pairs = [] | |
| with open(pair_list, "r") as f: | |
| for line_num, line in enumerate(f, start=1): | |
| line = line.strip() | |
| if not line or line.startswith("#"): | |
| continue | |
| parts = line.split(" ") | |
| if len(parts) != 2: | |
| raise ValueError(f"Invalid pair list line {line_num}: expected white-space-separated tgt/ref. got {parts=}") | |
| pairs.append((parts[0], parts[1])) | |
| else: | |
| raise ValueError("No input pairs provided. Use --tgt/--ref or --pair-list.") | |
| print(f"{pairs=}") | |
| return pairs | |
| def un_norm(x): | |
| return (x + 1.0) / 2.0 | |
| def un_norm_clip(x1): | |
| x = x1 * 1.0 | |
| reduce = False | |
| if len(x.shape) == 3: | |
| x = x.unsqueeze(0) | |
| reduce = True | |
| x[:, 0, :, :] = x[:, 0, :, :] * 0.26862954 + 0.48145466 | |
| x[:, 1, :, :] = x[:, 1, :, :] * 0.26130258 + 0.4578275 | |
| x[:, 2, :, :] = x[:, 2, :, :] * 0.27577711 + 0.40821073 | |
| if reduce: | |
| x = x.squeeze(0) | |
| return x | |
| if __name__ == "__main__": | |
| pairs = load_pairs(args.pair_list, args.tgt, args.ref) | |
| out_dir = Path(args.out_dir) | |
| result_path = out_dir / "results" | |
| grid_path = out_dir / "grid" | |
| inter_path = out_dir / "intermediates" | |
| inter_pred_path = inter_path / "pred_x0" | |
| inter_noised_path = inter_path / "noised" | |
| out_dir.mkdir(parents=False, exist_ok=True) | |
| result_path.mkdir(parents=False, exist_ok=True) | |
| grid_path.mkdir(parents=False, exist_ok=True) | |
| inter_path.mkdir(parents=False, exist_ok=True) | |
| if SAVE_INTERMEDIATES: | |
| inter_pred_path.mkdir(parents=False, exist_ok=True) | |
| inter_noised_path.mkdir(parents=False, exist_ok=True) | |
| paths_tgt = [p[0] for p in pairs] | |
| paths_ref = [p[1] for p in pairs] | |
| gen_lmk_and_mask(paths_tgt + paths_ref) | |
| seed_everything(42) | |
| model: LatentDiffusion = load_model_from_config(PRETRAIN_CKPT_PATH, ) | |
| device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
| model = model.to(device) | |
| sampler = DDIMSampler(model) | |
| dataset = Dataset_custom( | |
| "test", | |
| task=TASK, | |
| paths_tgt=paths_tgt, | |
| paths_ref=paths_ref, | |
| ) | |
| dataloader = torch.utils.data.DataLoader( | |
| dataset, | |
| batch_size=BATCH_SIZE, | |
| num_workers=num_workers, | |
| pin_memory=True, | |
| shuffle=False, | |
| drop_last=False, | |
| ) | |
| start_code = None | |
| if FIXED_CODE: | |
| start_code = torch.randn([BATCH_SIZE, C, H // F, W // F], device=device) | |
| precision_scope = autocast if PRECISION == "autocast" else nullcontext | |
| grids = [] | |
| grid_stems = [] | |
| with torch.no_grad(): | |
| with precision_scope("cuda"): | |
| with model.ema_scope(): | |
| for test_batch, prior, test_model_kwargs, out_stem_batch in tqdm(dataloader): | |
| model.set_task(test_model_kwargs) | |
| bs = test_batch.shape[0] | |
| batch_ = { | |
| **test_model_kwargs, | |
| "GT": torch.zeros_like(test_model_kwargs["inpaint_image"]), | |
| } | |
| batch_, c = model.get_input_and_conditioning(batch_, device=device) | |
| z_inpaint = batch_["z4_inpaint"] | |
| z_inpaint_mask = batch_["tgt_mask_64"] | |
| z_ref = batch_["z_ref"] | |
| z9 = batch_["z9"] | |
| uc = None | |
| if SCALE != 1.0: | |
| uc = model.learnable_vector[TASK].repeat(bs, 1, 1) | |
| shape = [C, H // F, W // F] | |
| local_start_code = start_code | |
| if FIXED_CODE and (local_start_code is None or local_start_code.shape[0] != bs): | |
| local_start_code = torch.randn([bs, C, H // F, W // F], device=device) | |
| samples_ddim, intermediates = sampler.sample( | |
| S=DDIM_STEPS, | |
| conditioning=c, | |
| batch_size=bs, | |
| shape=shape, | |
| verbose=False, | |
| unconditional_guidance_scale=SCALE, | |
| unconditional_conditioning=uc, | |
| eta=DDIM_ETA, | |
| x_T=local_start_code, | |
| log_every_t=100, | |
| z_inpaint=z_inpaint, | |
| z_inpaint_mask=z_inpaint_mask, | |
| z_ref=z_ref, | |
| z9=z9, | |
| ) | |
| if SAVE_INTERMEDIATES: | |
| intermediate_pred_x0 = intermediates["pred_x0"] | |
| intermediate_noised = intermediates["x_inter"] | |
| for i in range(len(intermediate_pred_x0)): | |
| for j in range(bs): | |
| stem = f"{out_stem_batch[j]}" | |
| save_sample_by_decode( | |
| intermediate_pred_x0[i][j : j + 1], | |
| model, | |
| inter_pred_path, | |
| stem, | |
| i, | |
| ) | |
| save_sample_by_decode( | |
| intermediate_noised[i][j : j + 1], | |
| model, | |
| inter_noised_path, | |
| stem, | |
| i, | |
| ) | |
| x_samples_ddim = model.decode_first_stage(samples_ddim) | |
| x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) | |
| x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() | |
| x_checked_image_torch = torch.from_numpy(x_samples_ddim).permute(0, 3, 1, 2) | |
| for i, x_sample in enumerate(x_checked_image_torch): | |
| stem = f"{out_stem_batch[i]}" | |
| out_path = result_path / f"{stem}.png" | |
| img = Image.fromarray((x_sample.permute(1, 2, 0).numpy() * 255).astype(np.uint8)) | |
| img.save(out_path) | |
| print(f"{out_path=}") | |
| for i, x_sample in enumerate(x_checked_image_torch): | |
| all_img = [] | |
| all_img.append(un_norm(test_batch[i]).cpu()) | |
| if TASK != 2: | |
| ref_img = test_model_kwargs["ref_imgs"].squeeze(1) | |
| ref_img = torchvision.transforms.Resize([512, 512])(ref_img) | |
| ref_img = un_norm_clip(ref_img[i]).cpu() | |
| else: | |
| ref_img = un_norm(test_model_kwargs["ref512"].squeeze(1)[i]).cpu() | |
| all_img.append(ref_img) | |
| all_img.append(x_sample) | |
| grid = torch.stack(all_img, 0) | |
| grid = make_grid(grid) | |
| grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() | |
| img = Image.fromarray(grid.astype(np.uint8)) | |
| stem = f"{out_stem_batch[i]}" | |
| path_save_img = grid_path / f"grid-{stem}.jpg" | |
| img.save(path_save_img) | |
| print(f"{path_save_img=}") | |
| grids.append(img) | |
| grid_stems.append(stem) | |
| if len(grids) >= NUM_grid_in_a_column: | |
| stem_start = grid_stems[0] | |
| stem_end = grid_stems[-1] | |
| grid_column = imgs_2_grid_A( | |
| grids, | |
| grid_layout='column', | |
| grid_path=os.path.join(grid_path, f"{stem_start}--{stem_end}.jpg"), | |
| ) | |
| grids = [] | |
| grid_stems = [] | |
| model.unset_task() | |
| print(f"Your samples are ready and waiting for you here: {out_dir}") | |