Instructions to use neuralvfx/Z-Image-SAM-ControlNet with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use neuralvfx/Z-Image-SAM-ControlNet with Diffusers:
```bash
pip install -U diffusers transformers accelerate
```
```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained("neuralvfx/Z-Image-SAM-ControlNet")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "Tongyi-MAI/Z-Image", controlnet=controlnet
)
```
- Notebooks
- Google Colab
- Kaggle
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
# Refactored and optimized by DEVAIEXP Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import math | |
| from typing import Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from diffusers.configuration_utils import ConfigMixin, register_to_config | |
| from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin | |
| from diffusers.models.attention_dispatch import dispatch_attention_fn | |
| from diffusers.models.attention_processor import Attention, AttentionProcessor | |
| from diffusers.models.modeling_outputs import Transformer2DModelOutput | |
| from diffusers.models.modeling_utils import ModelMixin | |
| from diffusers.models.normalization import RMSNorm | |
| from diffusers.utils import ( | |
| is_torch_version, | |
| ) | |
| from diffusers.utils.torch_utils import maybe_allow_in_graph | |
| from torch.nn.utils.rnn import pad_sequence | |
# Upper bound on the width of the AdaLN conditioning vector: the timestep
# embedder emits min(dim, ADALN_EMBED_DIM) features and every
# ``adaLN_modulation`` Linear takes that many inputs.
ADALN_EMBED_DIM: int = 256
# Image token sequences are padded so their length is a multiple of this value.
SEQ_MULTI_OF: int = 32
def zero_module(module):
    """
    Fill every parameter of ``module`` (including submodule parameters) with zeros, in place.

    Args:
        module (nn.Module): The module to zero out.

    Returns:
        nn.Module: ``module`` itself, so the call can wrap a constructor,
        e.g. ``zero_module(nn.Linear(d, d))``.
    """
    for _param_name, tensor in module.named_parameters():
        nn.init.zeros_(tensor)
    return module
class TimestepEmbedder(nn.Module):
    """
    Embed timesteps with sinusoidal features followed by a Linear -> SiLU -> Linear MLP.
    """

    def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
        """
        Initializes the TimestepEmbedder module.

        Args:
            out_size (int): Output dimension of the embedding.
            mid_size (int, optional): Hidden width of the MLP. Defaults to ``out_size``.
            frequency_embedding_size (int, optional): Width of the sinusoidal
                features fed to the MLP. Defaults to 256.
        """
        super().__init__()
        if mid_size is None:
            mid_size = out_size
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, mid_size, bias=True),
            nn.SiLU(),
            nn.Linear(mid_size, out_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    # Must be a staticmethod: ``forward`` calls ``self.timestep_embedding(t, dim)``,
    # and without the decorator ``self`` would be bound to ``t``.
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep features.

        Args:
            t (torch.Tensor): 1-D tensor of N timesteps.
            dim (int): Number of output features. When odd, a zero column is appended.
            max_period (int, optional): Controls the lowest frequency. Defaults to 10000.

        Returns:
            torch.Tensor: Tensor of shape (N, dim): the cosine half, then the sine half.
        """
        # Compute outside CUDA autocast so the frequencies are not evaluated in reduced precision.
        with torch.amp.autocast("cuda", enabled=False):
            half = dim // 2
            freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
            args = t[:, None] * freqs[None]
            embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
            if dim % 2:
                embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
            return embedding

    def forward(self, t):
        """
        Embed timesteps.

        Args:
            t (torch.Tensor): 1-D tensor of timesteps.

        Returns:
            torch.Tensor: MLP output for the sinusoidal features of ``t``.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        # Match the MLP's parameter dtype (e.g. fp16/bf16 weights) when it is floating point.
        weight_dtype = self.mlp[0].weight.dtype
        if weight_dtype.is_floating_point:
            t_freq = t_freq.to(weight_dtype)
        t_emb = self.mlp(t_freq)
        return t_emb
class FeedForward(nn.Module):
    """
    SwiGLU feed-forward network computing ``w2(silu(w1(x)) * w3(x))``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        """
        Create the three bias-free projections.

        Args:
            dim (int): Input and output width.
            hidden_dim (int): Width of the gated hidden layer.
        """
        super().__init__()
        # The submodule names w1/w2/w3 become the parameter names, so keep them stable.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)  # output projection
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)  # value branch

    def _forward_silu_gating(self, x1, x3):
        """
        Gate ``x3`` with the SiLU of ``x1``.

        Args:
            x1 (torch.Tensor): Gate pre-activation.
            x3 (torch.Tensor): Value tensor, same shape as ``x1``.

        Returns:
            torch.Tensor: ``silu(x1) * x3``.
        """
        activated_gate = F.silu(x1)
        return activated_gate * x3

    def forward(self, x):
        """
        Apply the SwiGLU MLP.

        Args:
            x (torch.Tensor): Input whose last dimension is ``dim``.

        Returns:
            torch.Tensor: Output with the same last dimension as ``x``.
        """
        gate_pre = self.w1(x)
        value = self.w3(x)
        hidden = self._forward_silu_gating(gate_pre, value)
        return self.w2(hidden)
class FinalLayer(nn.Module):
    """
    Output head: parameter-free LayerNorm, then an AdaLN scale from the conditioning
    vector, then a linear projection to ``out_channels``.
    """

    def __init__(self, hidden_size, out_channels):
        """
        Build the head.

        Args:
            hidden_size (int): Width of the incoming token features.
            out_channels (int): Width of the projected output (patch dimension).
        """
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # The conditioning vector is at most ADALN_EMBED_DIM wide.
        cond_width = min(hidden_size, ADALN_EMBED_DIM)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_width, hidden_size, bias=True),
        )

    def forward(self, x, c):
        """
        Normalize, modulate and project the token features.

        Args:
            x (torch.Tensor): Token features; the scale is broadcast over dim 1
                (presumably the sequence axis — confirm against callers).
            c (torch.Tensor): Conditioning vector fed to ``adaLN_modulation``.

        Returns:
            torch.Tensor: ``linear(norm_final(x) * (1 + adaLN_modulation(c)))``.
        """
        modulation = self.adaLN_modulation(c)
        scale = (1.0 + modulation).unsqueeze(1)
        normalized = self.norm_final(x)
        return self.linear(normalized * scale)
class RopeEmbedder:
    """
    Rotary positional embedding tables for multi-axis integer coordinates.

    One (cos, sin) table is built per axis. Each coordinate row is embedded by
    looking every axis up in its own table and concatenating the results.
    """

    def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)):
        """
        Store the RoPE configuration; tables are built lazily per device.

        Args:
            theta (float, optional): Frequency base. Defaults to 256.0.
            axes_dims (List[int], optional): Rotary width per axis; each axis
                contributes ``axes_dims[i] // 2`` (cos, sin) pairs. Defaults to (32, 48, 48).
            axes_lens (List[int], optional): Number of positions tabulated per axis.
                Defaults to (1024, 512, 512).
        """
        self.theta = theta
        self.axes_dims = axes_dims
        self.axes_lens = axes_lens
        # device -> list of per-axis tables, filled on first use
        self.freqs_cis_cache = {}

    def _precompute_freqs_cis(self, device):
        """
        Return the per-axis (cos, sin) tables for ``device``, building and caching them on first request.

        Args:
            device (torch.device): Device the tables live on.

        Returns:
            List[torch.Tensor]: One float32 tensor per axis, shaped (axis_len, axis_dim // 2, 2).
        """
        cached_tables = self.freqs_cis_cache.get(device)
        if cached_tables is not None:
            return cached_tables
        tables = []
        for axis_dim, axis_len in zip(self.axes_dims, self.axes_lens):
            n_pairs = axis_dim // 2
            exponents = torch.arange(0, n_pairs, device=device, dtype=torch.float32) / n_pairs
            inv_freq = 1.0 / (self.theta ** exponents)
            positions = torch.arange(axis_len, device=device, dtype=torch.float32)
            angles = torch.outer(positions, inv_freq)
            tables.append(torch.stack([angles.cos(), angles.sin()], dim=-1))
        self.freqs_cis_cache[device] = tables
        return tables

    def __call__(self, ids: torch.Tensor):
        """
        Look up rotary (cos, sin) pairs for a batch of coordinates.

        Args:
            ids (torch.Tensor): Integer tensor of shape (N, num_axes).

        Returns:
            torch.Tensor: Tensor of shape (N, sum(axes_dims[i] // 2), 2).
        """
        assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims)
        tables = self._precompute_freqs_cis(ids.device)
        per_axis = [tables[axis][ids[:, axis]] for axis in range(len(self.axes_dims))]
        return torch.cat(per_axis, dim=-2)
class ZSingleStreamAttnProcessor:
    """
    Self-attention processor for the Z-Image transformer blocks.

    Projects ``hidden_states`` to query/key/value with the attached ``Attention``
    module, applies its optional query/key norms and the rotary position
    embedding, runs ``dispatch_attention_fn`` and applies the output projection.
    """

    # Passed unchanged to ``dispatch_attention_fn``. Both are None here.
    # NOTE(review): presumably overridden by diffusers' backend/parallel
    # configuration utilities — confirm.
    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        """
        Check that the PyTorch 2.0 attention API is available.

        Raises:
            ImportError: If ``torch.nn.functional.scaled_dot_product_attention`` is missing.
        """
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run rotary self-attention.

        Args:
            attn (Attention): Layer supplying ``to_q``/``to_k``/``to_v``, ``heads``,
                optional ``norm_q``/``norm_k`` and ``to_out``.
            hidden_states (torch.Tensor): Input tokens (assumed (batch, seq, dim) — TODO confirm).
            encoder_hidden_states (Optional[torch.Tensor]): Accepted for interface
                compatibility; not used (self-attention only).
            attention_mask (Optional[torch.Tensor]): Mask for ``dispatch_attention_fn``.
                A 2-D mask is expanded to (batch, 1, 1, seq).
            freqs_cis (Optional[torch.Tensor]): Rotary (cos, sin) pairs. When None,
                no rotary embedding is applied.

        Returns:
            torch.Tensor: Output of ``attn.to_out``, in the dtype of the incoming ``hidden_states``.
        """

        def apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
            """
            Rotate consecutive channel pairs of ``q_or_k`` by the angles in ``freqs_cis``
            (whose last axis holds (cos, sin)). Computed in float32 and cast back to
            the input dtype.
            """
            x = q_or_k.transpose(1, 2)
            x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
            x0 = x_reshaped[..., 0]
            x1 = x_reshaped[..., 1]
            # NOTE(review): unsqueeze(1) broadcasts over the head axis, which assumes
            # freqs_cis is (batch, seq, head_dim // 2, 2) — confirm.
            freqs_cos = freqs_cis[..., 0].unsqueeze(1)
            freqs_sin = freqs_cis[..., 1].unsqueeze(1)
            x_rotated_0 = x0 * freqs_cos - x1 * freqs_sin
            x_rotated_1 = x0 * freqs_sin + x1 * freqs_cos
            x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)
            x_out = x_rotated.flatten(-2).transpose(1, 2)
            return x_out.to(q_or_k.dtype)

        # Capture the incoming dtype now: ``hidden_states`` is rebound to the attention
        # output below, so the old ``hidden_states.to(hidden_states.dtype)`` was a no-op.
        input_dtype = hidden_states.dtype

        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # Split the channel axis into (heads, head_dim).
        query = query.unflatten(-1, (attn.heads, -1))
        key = key.unflatten(-1, (attn.heads, -1))
        value = value.unflatten(-1, (attn.heads, -1))
        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)
        if freqs_cis is not None:
            query = apply_rotary_emb(query, freqs_cis)
            key = apply_rotary_emb(key, freqs_cis)
        # Broadcast a (batch, seq) mask over heads and query positions.
        if attention_mask is not None and attention_mask.ndim == 2:
            attention_mask = attention_mask[:, None, None, :]
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        # Merge heads back into the channel axis and restore the input dtype before the
        # output projection (the attention backend is not relied on to preserve it).
        hidden_states = hidden_states.flatten(2, 3).to(input_dtype)
        output = attn.to_out[0](hidden_states)
        if len(attn.to_out) > 1:
            output = attn.to_out[1](output)
        return output
class ZImageTransformerBlock(nn.Module):
    """
    Transformer block: rotary self-attention and a SwiGLU feed-forward. Both
    sub-layers are wrapped in RMSNorm before and after, with optional AdaLN
    scale/gate modulation.
    """

    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
    ):
        """
        Initializes the ZImageTransformerBlock.

        Args:
            layer_id (int): Index of the layer (stored on the block).
            dim (int): Width of the input and output features.
            n_heads (int): Number of attention heads; head width is ``dim // n_heads``.
            n_kv_heads (int): Accepted for signature compatibility; not used by this block.
            norm_eps (float): Epsilon for the four RMSNorm layers.
            qk_norm (bool): Whether to apply RMS normalization to queries and keys.
            modulation (bool, optional): Whether to build and use AdaLN modulation. Defaults to True.
        """
        super().__init__()
        self.dim = dim
        self.head_dim = dim // n_heads
        self.attention = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // n_heads,
            heads=n_heads,
            qk_norm="rms_norm" if qk_norm else None,
            eps=1e-5,
            bias=False,
            out_bias=False,
            processor=ZSingleStreamAttnProcessor(),
        )
        self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
        self.layer_id = layer_id
        # "1" norms precede each sub-layer; "2" norms are applied to each sub-layer's output.
        self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
        self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
        self.modulation = modulation
        if modulation:
            # Produces (scale_msa, gate_msa, scale_mlp, gate_mlp) from the conditioning vector.
            self.adaLN_modulation = nn.Sequential(
                nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
            )

    # Must be a property: ``set_attn_processor`` reads ``self.attn_processors.keys()``.
    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        """
        Map ``"<submodule path>.processor"`` to the processor of every submodule
        that exposes ``get_processor``.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        """
        Install attention processor(s) on every submodule that exposes ``set_processor``.

        Args:
            processor: A single processor applied everywhere, or a dict keyed like
                ``attn_processors``. The caller's dict is not modified.

        Raises:
            ValueError: If a dict is given whose size differs from the number of processors.
            KeyError: If a dict is given that lacks a required key.
        """
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        if isinstance(processor, dict):
            # Work on a copy: the pops below would otherwise empty the caller's mapping.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def forward(self, x, attn_mask, freqs_cis, adaln_input=None):
        """
        Apply the attention and feed-forward sub-layers with residual connections.

        Args:
            x (torch.Tensor): Input tokens.
            attn_mask (torch.Tensor): Attention mask passed to the attention processor.
            freqs_cis (torch.Tensor): Rotary (cos, sin) pairs for the processor.
            adaln_input (torch.Tensor, optional): Conditioning vector. Required when
                ``modulation`` is True. Defaults to None.

        Returns:
            torch.Tensor: Output tokens, same shape as ``x``.
        """
        if self.modulation:
            # unsqueeze(1) so the per-sample modulation broadcasts across tokens.
            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
            scale_msa = scale_msa + 1.0
            gate_msa = gate_msa.tanh()
            scale_mlp = scale_mlp + 1.0
            gate_mlp = gate_mlp.tanh()
            normed = self.attention_norm1(x)
            normed = normed * scale_msa
            attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
            attn_out = self.attention_norm2(attn_out) * gate_msa
            x = x + attn_out
            normed = self.ffn_norm1(x)
            normed = normed * scale_mlp
            ffn_out = self.feed_forward(normed)
            ffn_out = self.ffn_norm2(ffn_out) * gate_mlp
            x = x + ffn_out
        else:
            normed = self.attention_norm1(x)
            attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
            x = x + self.attention_norm2(attn_out)
            normed = self.ffn_norm1(x)
            ffn_out = self.feed_forward(normed)
            x = x + self.ffn_norm2(ffn_out)
        return x
class ZImageControlTransformerBlock(ZImageTransformerBlock):
    """
    Transformer block for the control branch.

    Each block runs the parent block on the running control state and emits a
    zero-initialized ``after_proj`` projection of the result as a skip output.
    Outputs are accumulated in a stacked tensor whose last entry is always the
    running state.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Build the control block.

        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature width.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Forwarded to the parent (unused there).
            norm_eps (float): Epsilon for RMSNorm.
            qk_norm (bool): Whether to normalize queries and keys.
            modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
            block_id (int, optional): Position of this block. Block 0 also owns
                ``before_proj`` and starts the chain. Defaults to 0.
        """
        super().__init__(
            layer_id=layer_id,
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            norm_eps=norm_eps,
            qk_norm=qk_norm,
            modulation=modulation,
        )
        self.block_id = block_id
        # Both projections start at zero, so at initialization the control
        # branch adds nothing to the main stream.
        if block_id == 0:
            self.before_proj = zero_module(nn.Linear(dim, dim))
        self.after_proj = zero_module(nn.Linear(dim, dim))

    def forward(self, c, x, **kwargs):
        """
        Advance the control chain by one block.

        Args:
            c (torch.Tensor): For block 0, the embedded control input. For later
                blocks, the stacked output of the previous control block.
            x (torch.Tensor): Main-stream tokens. Only block 0 reads them.
            **kwargs: Forwarded to ``ZImageTransformerBlock.forward``
                (``attn_mask``, ``freqs_cis``, ``adaln_input``).

        Returns:
            torch.Tensor: Stack of all skip outputs so far, this block's skip
            output, and the updated running state (last).
        """
        if self.block_id != 0:
            history = list(torch.unbind(c))
            state = history.pop()
        else:
            history = []
            state = self.before_proj(c) + x
        state = super().forward(state, **kwargs)
        skip = self.after_proj(state)
        return torch.stack(history + [skip, state])
class BaseZImageTransformerBlock(ZImageTransformerBlock):
    """
    Main-stream transformer block that can add a scaled control hint to its output.
    """

    def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
        """
        Build the block.

        Args:
            layer_id (int): Index of the layer.
            dim (int): Feature width.
            n_heads (int): Number of attention heads.
            n_kv_heads (int): Forwarded to the parent (unused there).
            norm_eps (float): Epsilon for RMSNorm.
            qk_norm (bool): Whether to normalize queries and keys.
            modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
            block_id (int or None, optional): Index into ``hints`` for this block.
                None disables hint injection. Defaults to 0.
        """
        super().__init__(
            layer_id=layer_id,
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            norm_eps=norm_eps,
            qk_norm=qk_norm,
            modulation=modulation,
        )
        self.block_id = block_id

    def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
        """
        Run the parent block, then add ``hints[block_id] * context_scale`` when applicable.

        Args:
            hidden_states (torch.Tensor): Input tokens.
            hints (sequence of torch.Tensor, optional): Control hints indexed by ``block_id``.
            context_scale (float, optional): Multiplier for the hint. Defaults to 1.0.
            **kwargs: Forwarded to ``ZImageTransformerBlock.forward``.

        Returns:
            torch.Tensor: Block output, with the hint added when one applies.
        """
        block_out = super().forward(hidden_states, **kwargs)
        if hints is None or self.block_id is None:
            return block_out
        return block_out + hints[self.block_id] * context_scale
| class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): | |
| _supports_gradient_checkpointing = True | |
| _keys_to_ignore_on_load_unexpected = [ | |
| r"control_layers\..*", | |
| r"control_noise_refiner\..*", | |
| r"control_all_x_embedder\..*", | |
| ] | |
| _no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"] | |
| _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] | |
| _group_offload_block_modules = ["t_embedder", "cap_embedder"] | |
@register_to_config
def __init__(
    self,
    control_layers_places=None,
    control_refiner_layers_places=None,
    control_in_dim=None,
    add_control_noise_refiner=False,
    all_patch_size=(2,),
    all_f_patch_size=(1,),
    in_channels=16,
    dim=3840,
    n_layers=30,
    n_refiner_layers=2,
    n_heads=30,
    n_kv_heads=30,
    norm_eps=1e-5,
    qk_norm=True,
    cap_feat_dim=2560,
    rope_theta=256.0,
    t_scale=1000.0,
    axes_dims=[32, 48, 48],
    axes_lens=[1024, 512, 512],
    use_controlnet=True,
    checkpoint_ratio=0.5,
):
    """
    Build the Z-Image transformer and, optionally, its control branch.

    ``@register_to_config`` records these constructor arguments in the model
    config, so ConfigMixin / ModelMixin can save and reload them. The import of
    ``register_to_config`` was otherwise unused.

    Args:
        control_layers_places (List[int], optional): Main-layer indices that receive
            control hints. Defaults to every second layer starting at 0. Must contain 0.
        control_refiner_layers_places (List[int], optional): Noise-refiner layer
            indices that receive control hints. Defaults to all refiner layers.
        control_in_dim (int, optional): Channel width of the control input.
            Defaults to ``dim``.
        add_control_noise_refiner (bool, optional): In the two-stage path, whether to
            build a dedicated control noise refiner.
        all_patch_size (Tuple[int], optional): Supported spatial patch sizes.
        all_f_patch_size (Tuple[int], optional): Matching frame patch sizes; must have
            the same length as ``all_patch_size``.
        in_channels (int, optional): Latent channels (also the output channels).
        dim (int, optional): Transformer width.
        n_layers (int, optional): Number of main transformer layers.
        n_refiner_layers (int, optional): Number of layers in each refiner.
        n_heads (int, optional): Number of attention heads.
        n_kv_heads (int, optional): Forwarded to the blocks.
        norm_eps (float, optional): Epsilon for RMSNorm.
        qk_norm (bool, optional): Whether to normalize queries and keys.
        cap_feat_dim (int, optional): Width of the incoming caption features.
        rope_theta (float, optional): RoPE frequency base.
        t_scale (float, optional): Timestep scaling factor (stored).
        axes_dims (List[int], optional): RoPE width per axis; must sum to ``dim // n_heads``.
        axes_lens (List[int], optional): RoPE positions tabulated per axis.
        use_controlnet (bool, optional): When False, no control modules are built.
        checkpoint_ratio (float, optional): Stored ratio of layers to gradient-checkpoint.
    """
    super().__init__()
    self.use_controlnet = use_controlnet
    self.in_channels = in_channels
    # The model predicts a tensor with the same channel count as its latent input.
    self.out_channels = in_channels
    self.all_patch_size = all_patch_size
    self.all_f_patch_size = all_f_patch_size
    self.dim = dim
    self.control_in_dim = self.dim if control_in_dim is None else control_in_dim
    # NOTE(review): control inputs wider than 16 channels select the "two-stage"
    # control path. Why 16 (the default in_channels) is the threshold is not
    # visible here — confirm.
    self.is_two_stage_control = self.control_in_dim > 16
    self.n_heads = n_heads
    self.rope_theta = rope_theta
    self.t_scale = t_scale
    self.gradient_checkpointing = False
    self.checkpoint_ratio = checkpoint_ratio
    assert len(all_patch_size) == len(all_f_patch_size)
    self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
    self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places
    self.add_control_noise_refiner = add_control_noise_refiner
    # The control block built for place 0 is the only one with ``before_proj``,
    # i.e. the block that starts the control chain, so place 0 must be present.
    assert 0 in self.control_layers_places
    # Layer index -> position of that layer's hint, used as BaseZImageTransformerBlock.block_id.
    self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
    self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)}
    self.all_x_embedder = nn.ModuleDict(
        {
            f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
            for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
        }
    )
    self.all_final_layer = nn.ModuleDict(
        {
            f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
            for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
        }
    )
    self.context_refiner = nn.ModuleList(
        [ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)]
    )
    self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
    self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
    # Left uninitialized here; presumably filled from checkpoint weights — confirm.
    self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
    self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
    head_dim = dim // n_heads
    # RoPE covers the full head width.
    assert head_dim == sum(axes_dims)
    self.axes_dims = axes_dims
    self.axes_lens = axes_lens
    self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
    self.layers = nn.ModuleList(
        [BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)]
    )
    # Refiner layer ids are offset by 1000 to keep them distinct from main-layer ids.
    self.noise_refiner = nn.ModuleList(
        [
            BaseZImageTransformerBlock(
                1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i)
            )
            for i in range(n_refiner_layers)
        ]
    )
    if self.use_controlnet:
        self.control_layers = nn.ModuleList(
            [ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places]
        )
        self.control_all_x_embedder = nn.ModuleDict(
            {
                f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
                for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
            }
        )
        if self.is_two_stage_control:
            if self.add_control_noise_refiner:
                self.control_noise_refiner = nn.ModuleList(
                    [
                        ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id)
                        for layer_id in range(n_refiner_layers)
                    ]
                )
            else:
                self.control_noise_refiner = None
        else:  # V1: plain refiner blocks for the control input
            self.control_noise_refiner = nn.ModuleList(
                [ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)]
            )
    else:
        self.control_layers = None
        self.control_all_x_embedder = None
        self.control_noise_refiner = None
def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor:
    """
    Fold per-sample image token sequences back into a batched latent tensor.

    For each sample, only the first ``(F // pF) * (H // pW) * (W // pW)`` tokens
    are used; trailing tokens (sequence padding) are ignored. All samples must
    unpatchify to the same shape, because the results are stacked into one tensor.

    Args:
        x_image_tokens (torch.Tensor): Tokens of shape [B, SeqLen, pF * pH * pW * C].
        all_sizes (List[Tuple]): Original (F, H, W) of each sample.
        patch_size (int): Spatial patch size (height and width).
        f_patch_size (int): Frame/temporal patch size.

    Returns:
        torch.Tensor: Latent tensor of shape [B, C, F, H, W] with C = ``self.out_channels``.

    Raises:
        ValueError: If the per-sample results have different shapes and cannot be stacked.
    """
    p_h = p_w = patch_size
    p_f = f_patch_size
    channels = self.out_channels
    unpatched_images = []
    # Local names avoid shadowing the module-level ``F`` (torch.nn.functional) alias.
    for i, sample_tokens in enumerate(x_image_tokens):
        frames, height, width = all_sizes[i]
        f_tokens, h_tokens, w_tokens = frames // p_f, height // p_h, width // p_w
        original_seq_len = f_tokens * h_tokens * w_tokens
        patches = sample_tokens[:original_seq_len].view(f_tokens, h_tokens, w_tokens, p_f, p_h, p_w, channels)
        # (Ft, Ht, Wt, pF, pH, pW, C) -> (C, Ft, pF, Ht, pH, Wt, pW) -> (C, F, H, W)
        image = patches.permute(6, 0, 3, 1, 4, 2, 5).reshape(channels, frames, height, width)
        unpatched_images.append(image)
    try:
        return torch.stack(unpatched_images, dim=0)
    except RuntimeError as err:
        raise ValueError(
            "Could not stack unpatched images into a single batch tensor. "
            "This typically occurs if you are trying to generate images of different sizes in the same batch."
        ) from err
def _patchify(
    self,
    all_image: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
    cap_padding_len: int,
):
    """
    Convert images into padded patch sequences with position ids and padding masks.

    Each image of shape (C, F, H, W) becomes (F//pF * H//pH * W//pW) patch rows of
    width pF * pH * pW * C. The rows are padded to a multiple of ``SEQ_MULTI_OF``
    by repeating the last row.

    Args:
        all_image (List[torch.Tensor]): Images shaped (C, F, H, W). Must be non-empty;
            the first image's device is used for all new tensors.
        patch_size (int): Spatial patch size.
        f_patch_size (int): Frame/temporal patch size.
        cap_padding_len (int): Padded caption length. Image position ids start
            at ``cap_padding_len + 1`` on the first axis.

    Returns:
        Tuple[list, list, list, list]: Padded patch sequences, original (F, H, W)
        sizes, position ids, and boolean padding masks (True = padding).
    """
    p_h = p_w = patch_size
    p_f = f_patch_size
    device = all_image[0].device
    all_image_out = []
    all_image_size = []
    all_image_pos_ids = []
    all_image_pad_mask = []
    # Every padding token gets the same position id, so build that row once
    # instead of once per image.
    pad_pos_id = self._create_coordinate_grid(
        size=(1, 1, 1),
        start=(0, 0, 0),
        device=device,
    ).flatten(0, 2)
    # Local names avoid shadowing the module-level ``F`` (torch.nn.functional) alias.
    for image in all_image:
        channels, frames, height, width = image.size()
        all_image_size.append((frames, height, width))
        f_tokens, h_tokens, w_tokens = frames // p_f, height // p_h, width // p_w
        patches = image.view(channels, f_tokens, p_f, h_tokens, p_h, w_tokens, p_w)
        # (C, Ft, pF, Ht, pH, Wt, pW) -> (Ft, Ht, Wt, pF, pH, pW, C) -> one row per patch
        patches = patches.permute(1, 3, 5, 2, 4, 6, 0).reshape(f_tokens * h_tokens * w_tokens, p_f * p_h * p_w * channels)
        ori_len = len(patches)
        pad_len = (-ori_len) % SEQ_MULTI_OF
        ori_pos_ids = self._create_coordinate_grid(
            size=(f_tokens, h_tokens, w_tokens),
            start=(cap_padding_len + 1, 0, 0),
            device=device,
        ).flatten(0, 2)
        all_image_pos_ids.append(torch.cat([ori_pos_ids, pad_pos_id.repeat(pad_len, 1)], dim=0))
        all_image_pad_mask.append(
            torch.cat(
                [
                    torch.zeros((ori_len,), dtype=torch.bool, device=device),
                    torch.ones((pad_len,), dtype=torch.bool, device=device),
                ],
                dim=0,
            )
        )
        # Padding rows repeat the last real patch.
        all_image_out.append(torch.cat([patches, patches[-1:].repeat(pad_len, 1)], dim=0))
    return (
        all_image_out,
        all_image_size,
        all_image_pos_ids,
        all_image_pad_mask,
    )
def _patchify_and_embed(
    self,
    all_image: List[torch.Tensor],
    all_cap_feats: List[torch.Tensor],
    patch_size: int,
    f_patch_size: int,
):
    """
    Patchify images and pad caption features one sample at a time.

    This is the general-purpose path: it tolerates per-sample differences in image and caption
    shape. Both the caption sequence and the image patch sequence of each sample are right-padded
    to a multiple of ``SEQ_MULTI_OF`` by repeating their last row. Image position IDs start on the
    first axis right after the padded caption length.

    Args:
        all_image (List[torch.Tensor]): Image tensors, each unpacked as (C, F, H, W).
        all_cap_feats (List[torch.Tensor]): Caption feature tensors, each 2-D (length, dim).
        patch_size (int): The spatial patch size.
        f_patch_size (int): The frame/temporal patch size.

    Returns:
        Tuple: seven lists with one entry per sample, in this order:
        image patches, padded caption features, original (F, H, W) image sizes, image position IDs,
        caption position IDs, image padding masks, caption padding masks. Masks are True on padding.
    """
    pH = pW = patch_size
    pF = f_patch_size
    device = all_image[0].device

    image_feats, image_sizes, image_pos, image_masks = [], [], [], []
    cap_pos, cap_masks, cap_feats_out = [], [], []

    for image, cap_feat in zip(all_image, all_cap_feats):
        # Caption: pad to a multiple of SEQ_MULTI_OF; positions run 1..n_cap_total on the first axis.
        n_cap = len(cap_feat)
        n_cap_pad = (-n_cap) % SEQ_MULTI_OF
        n_cap_total = n_cap + n_cap_pad

        cap_pos.append(
            self._create_coordinate_grid(size=(n_cap_total, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2)
        )
        cap_masks.append(torch.arange(n_cap_total, device=device) >= n_cap)
        if n_cap_pad:
            cap_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(n_cap_pad, 1)], dim=0)
        cap_feats_out.append(cap_feat)

        # Image: split into (pF, pH, pW) patches, flatten, then pad like the caption.
        C, Fr, H, W = image.size()
        image_sizes.append((Fr, H, W))
        F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
        patches = (
            image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
            .permute(1, 3, 5, 2, 4, 6, 0)
            .reshape(-1, pF * pH * pW * C)
        )
        n_img = patches.shape[0]
        n_img_pad = (-n_img) % SEQ_MULTI_OF
        n_img_total = n_img + n_img_pad

        pos = self._create_coordinate_grid(
            size=(F_tokens, H_tokens, W_tokens), start=(n_cap_total + 1, 0, 0), device=device
        ).flatten(0, 2)
        if n_img_pad:
            # Padding rows get position (0, 0, 0) and a copy of the last real patch.
            pos = torch.cat([pos, torch.zeros((n_img_pad, 3), dtype=torch.int32, device=device)], dim=0)
            patches = torch.cat([patches, patches[-1:].repeat(n_img_pad, 1)], dim=0)
        image_pos.append(pos)
        image_masks.append(torch.arange(n_img_total, device=device) >= n_img)
        image_feats.append(patches)

    return (
        image_feats,
        cap_feats_out,
        image_sizes,
        image_pos,
        cap_pos,
        image_masks,
        cap_masks,
    )
def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask):
    """
    Embed and batch caption features, embedding duplicated captions only once.

    When the batch has at least two items and every caption tensor has the same shape, captions
    that compare equal under ``torch.equal`` share one embedding. A typical case is the repeated
    prompts of a Classifier-Free Guidance batch. The shared embeddings are gathered back into
    batch order.

    Args:
        cap_feats_list (List[torch.Tensor]): List of padded caption feature tensors.
        cap_pos_ids (List[torch.Tensor]): List of corresponding position ID tensors.
        cap_inner_pad_mask (List[torch.Tensor]): List of corresponding padding masks (True = padding).

    Returns:
        Tuple: ``(padded features, padded RoPE frequencies, attention mask, sequence lengths)``.
        The attention mask is bool of shape (batch, max_len) and True on valid positions.
        The sequence lengths are a list of ints, one per batch item.
    """
    device = cap_feats_list[0].device
    bsz = len(cap_feats_list)

    def _embed(feats, pos_ids, pad_masks):
        # Embed the captions as one concatenated sequence, overwrite padding rows with the learned
        # pad token, then re-split and right-pad both features and RoPE frequencies per item.
        seqlens = [len(item) for item in feats]
        embedded = self.cap_embedder(torch.cat(feats, dim=0))
        embedded[torch.cat(pad_masks)] = self.cap_pad_token
        padded = pad_sequence(list(embedded.split(seqlens, dim=0)), batch_first=True, padding_value=0.0)
        freqs = self.rope_embedder(torch.cat(pos_ids, dim=0))
        freqs = pad_sequence(list(freqs.split(seqlens, dim=0)), batch_first=True, padding_value=0.0)
        return padded, freqs, seqlens

    # representatives: batch index of each distinct caption (first-seen order).
    # slot_of[i]: position in `representatives` of the caption equal to item i.
    representatives = None
    slot_of = []
    if bsz >= 2 and all(c.shape == cap_feats_list[0].shape for c in cap_feats_list):
        representatives = []
        for idx, feat in enumerate(cap_feats_list):
            slot = next(
                (k for k, rep in enumerate(representatives) if torch.equal(feat, cap_feats_list[rep])),
                None,
            )
            if slot is None:
                slot = len(representatives)
                representatives.append(idx)
            slot_of.append(slot)

    if representatives is not None and len(representatives) < bsz:
        unique_padded, unique_freqs, unique_lens = _embed(
            [cap_feats_list[k] for k in representatives],
            [cap_pos_ids[k] for k in representatives],
            [cap_inner_pad_mask[k] for k in representatives],
        )
        cap_feats_padded = unique_padded[slot_of]
        cap_freqs_cis = unique_freqs[slot_of]
        max_len = max(unique_lens)
        cap_item_seqlens = [max_len] * bsz
    else:
        cap_feats_padded, cap_freqs_cis, cap_item_seqlens = _embed(cap_feats_list, cap_pos_ids, cap_inner_pad_mask)
        max_len = max(cap_item_seqlens)

    lengths = torch.tensor(cap_item_seqlens, device=device, dtype=torch.int32)
    positions = torch.arange(max_len, device=device, dtype=torch.int32)
    cap_attn_mask = positions[None, :] < lengths[:, None]
    return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
| def _create_coordinate_grid(size, start=None, device=None): | |
| """ | |
| Creates a 3D coordinate grid. | |
| Args: | |
| size (Tuple[int]): The dimensions of the grid (F, H, W). | |
| start (Tuple[int], optional): The starting coordinates for each axis. Defaults to (0, 0, 0). | |
| device (torch.device, optional): The device to create the tensor on. Defaults to None. | |
| Returns: | |
| torch.Tensor: The coordinate grid tensor. | |
| """ | |
| if start is None: | |
| start = (0 for _ in size) | |
| axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] | |
| grids = torch.meshgrid(axes, indexing="ij") | |
| return torch.stack(grids, dim=-1) | |
| def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs): | |
| """ | |
| Applies a list of transformer layers to the hidden states, with optional selective gradient checkpointing. | |
| Args: | |
| hidden_states (torch.Tensor): The input tensor. | |
| layers (nn.ModuleList): The list of transformer layers to apply. | |
| checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. Defaults to 0.5. | |
| **kwargs: Additional keyword arguments to pass to each layer's forward method. | |
| Returns: | |
| torch.Tensor: The output tensor after applying all layers. | |
| """ | |
| if torch.is_grad_enabled() and self.gradient_checkpointing: | |
| def create_custom_forward(module, **static_kwargs): | |
| def custom_forward(*inputs): | |
| return module(*inputs, **static_kwargs) | |
| return custom_forward | |
| ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} | |
| checkpoint_every_n = max(1, int(1.0 / checkpoint_ratio)) if checkpoint_ratio > 0 else len(layers) + 1 | |
| for i, layer in enumerate(layers): | |
| if i % checkpoint_every_n == 0: | |
| hidden_states = torch.utils.checkpoint.checkpoint( | |
| create_custom_forward(layer, **kwargs), | |
| hidden_states, | |
| **ckpt_kwargs, | |
| ) | |
| else: | |
| hidden_states = layer(hidden_states, **kwargs) | |
| else: | |
| for layer in layers: | |
| hidden_states = layer(hidden_states, **kwargs) | |
| return hidden_states | |
def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device):
    """
    Prepares the control context for the transformer: patchify, embed, and build positional information.

    A batch of two or more identically shaped control inputs takes a vectorized fast path.
    Any other input goes through ``self._patchify`` one sample at a time. Both paths produce the
    same result: padding rows carry ``self.x_pad_token`` and sit at position (0, 0, 0).

    Args:
        control_context (torch.Tensor or List[torch.Tensor]): The control context. Either a 5-D
            tensor (B, C, F, H, W) or a sequence of per-sample (C, F, H, W) tensors.
        cap_feats_ref (List[torch.Tensor] or torch.Tensor): Caption features. Only their padded
            length is used, as the first-axis offset of the control position IDs.
        t (torch.Tensor): The (already embedded) timestep tensor.
        patch_size (int): The spatial patch size.
        f_patch_size (int): The frame/temporal patch size.
        device (torch.device): The target device.

    Returns:
        Dict: Keys ``'c'``, ``'c_item_seqlens'``, ``'attn_mask'``, ``'freqs_cis'`` and ``'adaln_input'``.
    """
    # Normalize to a list of per-sample tensors *before* reading the batch size. A plain list has
    # no `.shape`, so reading `control_context.shape[0]` first would crash on the documented list input.
    if isinstance(control_context, torch.Tensor) and control_context.ndim == 5:
        control_list = list(torch.unbind(control_context, dim=0))
    else:
        control_list = control_context
    bsz = len(control_list)

    pH = pW = patch_size
    pF = f_patch_size
    embedder = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"]
    cap_padding_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1]

    shapes = [c.shape for c in control_list]
    same_shape = all(s == shapes[0] for s in shapes)
    if same_shape and bsz >= 2:
        control_batch = torch.stack(control_list, dim=0)
        # `Fr` rather than `F` so the `torch.nn.functional as F` alias is not shadowed.
        B, C, Fr, H, W = control_batch.shape
        F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
        control_batch = control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
        control_batch = control_batch.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
        ori_len = control_batch.shape[1]
        padding_len = (-ori_len) % SEQ_MULTI_OF
        if padding_len > 0:
            pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1)
            control_batch = torch.cat([control_batch, pad_tensor], dim=1)
        c = embedder(control_batch)
        if padding_len > 0:
            # Match the per-sample path below (and the image path in forward). Padding rows carry
            # the learned pad token, not the embedding of a repeated patch, so the result does not
            # depend on which path the batch took.
            c[:, ori_len:] = self.x_pad_token
        final_seq_len = control_batch.shape[1]

        pos_ids_ori = self._create_coordinate_grid(
            size=(F_tokens, H_tokens, W_tokens),
            start=(cap_padding_len + 1, 0, 0),
            device=device,
        ).flatten(0, 2)  # [ori_len, 3]
        pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device)
        pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0)
        c_freqs_cis_single = self.rope_embedder(pos_ids_padded)
        # Every sample shares the same positions: add a leading batch dim and repeat it B times,
        # whatever the rank of the RoPE output.
        c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, *([1] * c_freqs_cis_single.dim()))
        c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device)
        return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)}

    # General path: per-sample patchify (handles mixed shapes and single-item batches).
    c_patches, _, c_pos_ids, c_inner_pad_mask = self._patchify(control_list, patch_size, f_patch_size, cap_padding_len)
    c_item_seqlens = [len(p) for p in c_patches]
    c_max_item_seqlen = max(c_item_seqlens)
    c = embedder(torch.cat(c_patches, dim=0))
    c[torch.cat(c_inner_pad_mask)] = self.x_pad_token
    c_padded = pad_sequence(list(c.split(c_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    c_freqs_cis_padded = pad_sequence(
        [self.rope_embedder(pos_ids) for pos_ids in c_pos_ids], batch_first=True, padding_value=0.0
    )
    seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32)
    c_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
    return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)}
def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size):
    """
    Vectorized counterpart of ``_patchify_and_embed`` for batches with uniform shapes.

    If all images share one shape and all captions share one shape, the batch is stacked and
    patchified/padded with batched tensor ops. Otherwise the call is delegated to
    ``_patchify_and_embed``. Padding uses the same conventions as that method: the last row is
    repeated, padding positions are (0, 0, 0), and masks are True on padding.

    Args:
        all_image (List[torch.Tensor]): Image tensors, each (C, F, H, W).
        all_cap_feats (List[torch.Tensor]): Caption feature tensors, each (length, dim).
        patch_size (int): The spatial patch size.
        f_patch_size (int): The frame/temporal patch size.

    Returns:
        Tuple: the same seven per-sample lists, in the same order, as ``_patchify_and_embed``.
    """
    device = all_image[0].device
    uniform = all(img.shape == all_image[0].shape for img in all_image) and all(
        cap.shape == all_cap_feats[0].shape for cap in all_cap_feats
    )
    if not uniform:
        return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size)

    pH = pW = patch_size
    pF = f_patch_size
    images = torch.stack(all_image, dim=0)
    caps = torch.stack(all_cap_feats, dim=0)
    B, C, Fr, H, W = images.shape

    # Captions: pad along the sequence axis; positions run 1..cap_total on the first axis.
    cap_len = caps.shape[1]
    cap_pad = (-cap_len) % SEQ_MULTI_OF
    cap_total = cap_len + cap_pad
    if cap_pad:
        caps = torch.cat([caps, caps[:, -1:, :].repeat(1, cap_pad, 1)], dim=1)
    cap_pos_ids = (
        self._create_coordinate_grid(size=(cap_total, 1, 1), start=(1, 0, 0), device=device)
        .flatten(0, 2)
        .unsqueeze(0)
        .repeat(B, 1, 1)
    )
    cap_mask = (torch.arange(cap_total, device=device) >= cap_len).unsqueeze(0).repeat(B, 1)

    # Images: patchify the whole batch at once, then pad the token axis.
    F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
    n_tokens = F_tokens * H_tokens * W_tokens
    patches = images.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
    patches = patches.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, n_tokens, pF * pH * pW * C)
    img_pad = (-n_tokens) % SEQ_MULTI_OF
    img_total = n_tokens + img_pad

    pos = self._create_coordinate_grid(
        size=(F_tokens, H_tokens, W_tokens), start=(cap_total + 1, 0, 0), device=device
    ).flatten(0, 2)
    if img_pad:
        patches = torch.cat([patches, patches[:, -1:, :].repeat(1, img_pad, 1)], dim=1)
        pos = torch.cat([pos, torch.zeros((img_pad, 3), dtype=torch.int32, device=device)], dim=0)
    image_pos_ids = pos.unsqueeze(0).repeat(B, 1, 1)
    image_mask = (torch.arange(img_total, device=device) >= n_tokens).unsqueeze(0).repeat(B, 1)

    return (
        list(patches.unbind(0)),
        list(caps.unbind(0)),
        [(Fr, H, W)] * B,
        list(image_pos_ids.unbind(0)),
        list(cap_pos_ids.unbind(0)),
        list(image_mask.unbind(0)),
        list(cap_mask.unbind(0)),
    )
def forward(
    self,
    x: List[torch.Tensor],
    t,
    cap_feats: List[torch.Tensor],
    patch_size=2,
    f_patch_size=1,
    control_context=None,
    conditioning_scale=1.0,
    refiner_conditioning_scale=1.0,
):
    """
    The main forward pass of the transformer model.

    Args:
        x (List[torch.Tensor]):
            A list of latent image tensors.
        t (torch.Tensor):
            A batch of timesteps.
        cap_feats (List[torch.Tensor]):
            A list of caption feature tensors.
        patch_size (int, optional):
            The spatial patch size to use. Defaults to 2.
        f_patch_size (int, optional):
            The frame/temporal patch size to use. Defaults to 1.
        control_context (torch.Tensor, optional):
            The control context tensor. Defaults to None.
        conditioning_scale (float, optional):
            The scale for applying control hints. Defaults to 1.0. Control is skipped when it is <= 0.
        refiner_conditioning_scale (float, optional):
            The scale for applying refiner control hints. Defaults to 1.0. If ``None``, it falls
            back to ``conditioning_scale``, or 1.0 when that is falsy.

    Returns:
        Transformer2DModelOutput: An object containing the final denoised sample.
    """
    is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0
    if refiner_conditioning_scale is None:
        refiner_conditioning_scale = conditioning_scale or 1.0
    assert patch_size in self.all_patch_size
    assert f_patch_size in self.all_f_patch_size

    bsz = len(x)
    device = x[0].device
    t = t * self.t_scale
    t = self.t_embedder(t)

    # Identically shaped images and captions can be patchified as one stacked batch.
    can_optimize_patchify = (
        bsz == len(cap_feats) and bsz >= 2 and all(img.shape == x[0].shape for img in x) and all(cap.shape == cap_feats[0].shape for cap in cap_feats)
    )
    if can_optimize_patchify:
        (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed_batch_optimized(
            x, cap_feats, patch_size, f_patch_size
        )
    else:
        (x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed(
            x, cap_feats, patch_size, f_patch_size
        )

    # Embed all image patches in one call, swap padding rows for the learned pad token, then
    # re-split into per-item sequences right-padded to the batch maximum.
    x_item_seqlens = [len(i) for i in x_list]
    x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0
    x_cat = torch.cat(x_list, dim=0)
    x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)
    x_pad_mask = torch.cat(x_inner_pad_mask) if x_inner_pad_mask else None
    if x_pad_mask is not None and x_pad_mask.any():
        x_embedded[x_pad_mask] = self.x_pad_token
    x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    adaln_input = t.to(device).type_as(x)

    cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache(
        cap_feats_list, cap_pos_ids, cap_inner_pad_mask
    )

    x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) if x_pos_ids else torch.empty(0, device=device)
    x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
    seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32)
    x_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

    # Two-stage control: run the control tokens through their own refiner stack first. All but
    # the last unbound output become hints for the image noise refiner. The last one is the
    # processed control context reused for the main control layers further below.
    refiner_hints = None
    if is_control_mode and self.is_two_stage_control:
        prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
        c = prepared_control["c"]
        # NOTE(review): the control tokens use x's attention mask and RoPE frequencies, which
        # presumes `c` has the same padded token layout as `x` (control image sized like the latent).
        kwargs_for_control_refiner = {
            "x": x,
            "attn_mask": x_attn_mask,
            "freqs_cis": x_freqs_cis,
            "adaln_input": adaln_input,
        }
        c_processed = self._apply_transformer_blocks(
            c,
            self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers,
            checkpoint_ratio=self.checkpoint_ratio,
            **kwargs_for_control_refiner,
        )
        c_outputs = torch.unbind(c_processed)
        refiner_hints = c_outputs[:-1]
        control_context_processed = c_outputs[-1]
        control_context_item_seqlens = prepared_control["c_item_seqlens"]

    kwargs_for_refiner = {
        "attn_mask": x_attn_mask,
        "freqs_cis": x_freqs_cis,
        "adaln_input": adaln_input,
        "context_scale": refiner_conditioning_scale,
    }
    if refiner_hints is not None:
        kwargs_for_refiner["hints"] = refiner_hints
    x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner)

    kwargs_for_context = {"attn_mask": cap_attn_mask, "freqs_cis": cap_freqs_cis}
    cap_feats = self._apply_transformer_blocks(cap_feats_padded, self.context_refiner, checkpoint_ratio=1.0, **kwargs_for_context)

    # Build the joint sequence per item: image tokens first, then caption tokens.
    unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0
    unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device)
    unified_freqs_cis = torch.zeros((bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]), dtype=x_freqs_cis.dtype, device=device)
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified[i, :x_len] = x[i, :x_len]
        unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len]
        unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len]
        unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len]
    seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32)
    arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32)
    unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None]

    # Control hints for the main layers, computed on [control tokens, caption tokens] sequences.
    hints = None
    if is_control_mode:
        kwargs_for_hints = {
            "attn_mask": unified_attn_mask,
            "freqs_cis": unified_freqs_cis,
            "adaln_input": adaln_input,
        }
        if self.is_two_stage_control:
            control_context_unified_list = [
                torch.cat([control_context_processed[i][: control_context_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)
            ]
            c = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
            new_kwargs = dict(x=unified, **kwargs_for_hints)
            c_processed = self._apply_transformer_blocks(c, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
            hints = torch.unbind(c_processed)[:-1]
        else:
            prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
            c = prepared_control["c"]
            kwargs_for_v1_refiner = {
                "attn_mask": prepared_control["attn_mask"],
                "freqs_cis": prepared_control["freqs_cis"],
                "adaln_input": prepared_control["adaln_input"],
            }
            c = self._apply_transformer_blocks(c, self.control_noise_refiner, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_v1_refiner)
            c_item_seqlens = prepared_control["c_item_seqlens"]
            control_context_unified_list = [torch.cat([c[i, : c_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)]
            c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
            new_kwargs = dict(x=unified, **kwargs_for_hints)
            c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
            hints = torch.unbind(c_processed)[:-1]

    kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
    if hints is not None:
        kwargs_for_layers["hints"] = hints
        kwargs_for_layers["context_scale"] = conditioning_scale
    unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers)

    unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
    # Image tokens come first in every unified sequence, so the leading slice holds all of them.
    x_image_tokens = unified_out[:, :x_max_item_seqlen]
    x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size)
    return Transformer2DModelOutput(sample=x_final_tensor)