| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Denoising block for WorldEngine modular pipeline.""" |
|
|
| from typing import List |
|
|
| import torch |
|
|
| from diffusers.utils import logging |
| from diffusers.modular_pipelines import ( |
| ModularPipelineBlocks, |
| ModularPipeline, |
| PipelineState, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| ComponentSpec, |
| InputParam, |
| OutputParam, |
| ) |
| from diffusers import AutoModel |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Rectified-flow denoising block for the WorldEngine modular pipeline.

    Integrates the flow ODE ``x <- x + dsigma * v`` over the scheduler sigma
    schedule with the KV cache frozen, then runs one unfrozen forward pass so
    the clean frame is persisted into the cache for the next autoregressive
    generation step, and finally advances the frame clock in place.
    """

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The only model this block drives is the world transformer.
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> List[InputParam]:
        tensor = torch.Tensor
        return [
            InputParam("scheduler_sigmas", required=True, type_hint=tensor,
                       description="Scheduler sigmas for denoising"),
            InputParam("latents", required=True, type_hint=tensor,
                       description="Initial noisy latents [1, 1, C, H, W]"),
            # No type hint: the cache is a project-specific container, not a tensor.
            InputParam("kv_cache", required=True,
                       description="KV cache for transformer attention"),
            InputParam("frame_timestamp", required=True, type_hint=tensor,
                       description="Current frame timestamp"),
            InputParam("prompt_embeds", required=True, type_hint=tensor,
                       description="Text embeddings for conditioning"),
            # Optional: absent when the prompt batch needs no padding.
            InputParam("prompt_pad_mask", type_hint=tensor,
                       description="Padding mask for prompt embeddings"),
            InputParam("button_tensor", required=True, type_hint=tensor,
                       description="One-hot encoded button tensor"),
            InputParam("mouse_tensor", required=True, type_hint=tensor,
                       description="Mouse velocity tensor"),
            InputParam("scroll_tensor", required=True, type_hint=tensor,
                       description="Scroll wheel sign tensor"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Denoised latents",
            ),
        ]

    @staticmethod
    def _denoise_pass(
        transformer,
        latents,
        sigmas,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Run the rectified-flow integration loop and return the denoised latents.

        The cache is frozen first so that none of the intermediate noisy steps
        write attention entries into it.
        """
        kv_cache.set_frozen(True)
        # Reusable per-(batch, frame) sigma buffer, refilled each step.
        sigma_buf = latents.new_empty((latents.size(0), latents.size(1)))
        # sigmas.diff() has one fewer element than sigmas, so zip stops after
        # the final interval — exactly one transformer call per step.
        for level, delta in zip(sigmas, sigmas.diff()):
            velocity = transformer(
                x=latents,
                sigma=sigma_buf.fill_(level),
                frame_timestamp=frame_timestamp,
                prompt_emb=prompt_emb,
                prompt_pad_mask=prompt_pad_mask,
                mouse=mouse,
                button=button,
                scroll=scroll,
                kv_cache=kv_cache,
            )
            # Euler step of the flow ODE: x <- x + dsigma * v.
            latents = latents + delta * velocity
        return latents

    @staticmethod
    def _cache_pass(
        transformer,
        latents,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Persist the clean frame into the KV cache (sigma = 0, cache unfrozen)."""
        kv_cache.set_frozen(False)
        transformer(
            x=latents,
            sigma=latents.new_zeros((latents.size(0), latents.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)

        # Conditioning shared verbatim by the denoise and cache passes.
        conditioning = dict(
            frame_timestamp=block_state.frame_timestamp,
            prompt_emb=block_state.prompt_embeds,
            prompt_pad_mask=block_state.prompt_pad_mask,
            mouse=block_state.mouse_tensor,
            button=block_state.button_tensor,
            scroll=block_state.scroll_tensor,
            kv_cache=block_state.kv_cache,
        )

        denoised = self._denoise_pass(
            components.transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            **conditioning,
        )
        # Clone so downstream consumers of the state never alias a tensor the
        # cache pass (or later steps) might observe.
        block_state.latents = denoised.clone()

        self._cache_pass(
            components.transformer,
            block_state.latents,
            **conditioning,
        )
        # Advance the autoregressive frame clock in place.
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        # diffusers modular-pipeline convention: return both components and state.
        return components, state
|
|