BiliSakura committed on
Commit b8c42d6 · verified · 1 Parent(s): a310d14

Update README.md and *.py files for RSEdit-DiT

Files changed (1): pipeline_rsedit_dit.py (+490, −0)
pipeline_rsedit_dit.py ADDED
@@ -0,0 +1,490 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The RSEdit Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
RSEdit DiT Pipeline for Remote Sensing Image Editing using Token Concatenation.

This pipeline extends PixArtAlphaPipeline to support image-to-image editing
by concatenating source image tokens with noisy latent tokens, allowing the
DiT to leverage its sequence modeling capabilities for instruction-based editing.
"""

from typing import Callable, List, Optional, Union

import numpy as np
import PIL.Image
import torch
from transformers import T5EncoderModel, T5Tokenizer

from diffusers import AutoencoderKL, PixArtAlphaPipeline, PixArtTransformer2DModel
from diffusers.pipelines.pipeline_utils import ImagePipelineOutput
from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha import retrieve_timesteps
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import replace_example_docstring


EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from PIL import Image
        >>> from pipeline_rsedit_dit import RSEditDiTPipeline

        >>> # Load the pipeline
        >>> pipe = RSEditDiTPipeline.from_pretrained(
        ...     "path/to/rsedit-dit-model",
        ...     torch_dtype=torch.float16,
        ... )
        >>> pipe = pipe.to("cuda")

        >>> # Load the source satellite image
        >>> source_image = Image.open("satellite_image.png").convert("RGB")

        >>> # Edit with an instruction
        >>> prompt = "Flood the coastal area"
        >>> edited_image = pipe(
        ...     prompt=prompt,
        ...     source_image=source_image,
        ...     num_inference_steps=50,
        ...     guidance_scale=4.5,
        ... ).images[0]

        >>> edited_image.save("flooded_coastal_area.png")
        ```
"""


class RSEditDiTPipeline(PixArtAlphaPipeline):
    """
    Pipeline for RSEdit: Remote Sensing Image Editing using DiT with Token Concatenation.

    This pipeline extends PixArtAlphaPipeline to support instruction-based image editing
    for satellite imagery. It uses the Token Concatenation strategy, in which the source
    image latents are concatenated with the noisy target latents along the spatial width
    dimension, allowing the transformer to perform in-context learning for image-to-image
    translation.

    The pipeline uses the following components:

    - PixArtTransformer2DModel: Diffusion Transformer for denoising
    - T5EncoderModel: text encoder for instruction embeddings
    - AutoencoderKL: VAE for encoding/decoding images to/from latent space

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            Frozen text encoder. PixArt-Alpha uses T5.
        tokenizer ([`T5Tokenizer`]):
            Tokenizer of class T5Tokenizer.
        transformer ([`PixArtTransformer2DModel`]):
            A PixArt transformer to denoise the encoded image latents.
        scheduler ([`KarrasDiffusionSchedulers`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: T5EncoderModel,
        tokenizer: T5Tokenizer,
        transformer: PixArtTransformer2DModel,
        scheduler: KarrasDiffusionSchedulers,
    ):
        super().__init__(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )

    def _encode_source_image(
        self,
        source_image: PIL.Image.Image,
        device: torch.device,
        dtype: torch.dtype,
        num_images_per_prompt: int = 1,
    ) -> torch.Tensor:
        """
        Encode the source image into latent space.

        Args:
            source_image: PIL image to encode.
            device: Device to place the latents on.
            dtype: Data type of the returned latents (the VAE encodes in its own dtype).
            num_images_per_prompt: Number of images to generate per prompt.

        Returns:
            Encoded latents of shape (num_images_per_prompt, latent_channels, latent_height, latent_width).
        """
        # Convert the PIL image to a tensor scaled to [-1, 1]
        image_np = np.array(source_image.convert("RGB")).astype(np.float32) / 127.5 - 1.0
        image_tensor = torch.from_numpy(image_np).permute(2, 0, 1).unsqueeze(0)
        # Use the VAE's dtype for encoding to stay compatible with mixed precision
        image_tensor = image_tensor.to(device=device, dtype=self.vae.dtype)

        # Encode to latent space (use the distribution mode for deterministic encoding)
        latents = self.vae.encode(image_tensor).latent_dist.mode()
        latents = latents * self.vae.config.scaling_factor

        # Ensure the latents are on the correct device (critical for multi-GPU with device_map);
        # the VAE encoder may live on a different GPU, so the output must be moved explicitly
        latents = latents.to(device=device)

        # Cast back to the requested dtype for pipeline consistency
        latents = latents.to(dtype=dtype)

        # Duplicate for num_images_per_prompt
        if num_images_per_prompt > 1:
            latents = latents.repeat(num_images_per_prompt, 1, 1, 1)

        return latents
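
    # Example shapes for `_encode_source_image` (a sketch, assuming the common
    # 4-channel VAE with 8x downsampling; exact sizes depend on the checkpoint):
    # a 512x512 RGB source with num_images_per_prompt=2 encodes to latents of
    # shape (2, 4, 64, 64), already scaled by `vae.config.scaling_factor`.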

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        source_image: Union[PIL.Image.Image, List[PIL.Image.Image]] = None,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        guidance_scale: float = 4.5,
        image_guidance_scale: Optional[float] = None,
        num_images_per_prompt: Optional[int] = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        prompt_attention_mask: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_attention_mask: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        clean_caption: bool = True,
        use_resolution_binning: bool = True,
        max_sequence_length: int = 120,
        **kwargs,
    ) -> Union[ImagePipelineOutput, tuple]:
        """
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The editing instruction prompt or prompts to guide image generation. If not defined, you need
                to pass `prompt_embeds`.
            source_image (`PIL.Image.Image` or `List[PIL.Image.Image]`):
                The source satellite image(s) to edit.
            negative_prompt (`str`, *optional*, defaults to `""`):
                The prompt not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., when
                `guidance_scale` is not greater than `1`).
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher-quality image at
                the expense of slower inference.
            timesteps (`List[int]`, *optional*):
                Custom timesteps to use for the denoising process. If not defined, `num_inference_steps`
                equally spaced timesteps are used.
            guidance_scale (`float`, *optional*, defaults to 4.5):
                Guidance scale as defined in classifier-free guidance (CFG). A higher guidance scale
                encourages images that are closely linked to the text `prompt`, usually at the expense of
                lower image quality.
            image_guidance_scale (`float`, *optional*):
                Image guidance scale controlling the influence of the source image. If `None`, defaults to
                `guidance_scale`. This allows separate control over text vs. image conditioning strength.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper. Only applies to `DDIMScheduler`.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of torch generator(s) to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for
                image generation. Can be used to tweak the same generation with different prompts.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for text embeddings.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings.
            negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
                Pre-generated attention mask for negative text embeddings.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `"pil"` (`PIL.Image.Image`),
                `"np"` (`np.array`), or `"pt"` (`torch.Tensor`).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return an [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called.
            clean_caption (`bool`, *optional*, defaults to `True`):
                Whether or not to clean the caption before creating embeddings.
            use_resolution_binning (`bool`, *optional*, defaults to `True`):
                Whether to use resolution binning for PixArt models.
            max_sequence_length (`int`, *optional*, defaults to 120):
                Maximum sequence length for the text encoder.

        Returns:
            [`~pipelines.ImagePipelineOutput`] or `tuple`:
                If `return_dict` is `True`, an [`~pipelines.ImagePipelineOutput`] is returned, otherwise a
                `tuple` is returned where the first element is a list with the generated images.

        Examples:
        """
        # 1. Check inputs
        if source_image is None:
            raise ValueError("`source_image` must be provided for RSEdit image editing.")

        if prompt is None and prompt_embeds is None:
            raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")

        if height is None:
            height = self.transformer.config.sample_size * self.vae_scale_factor
        if width is None:
            width = self.transformer.config.sample_size * self.vae_scale_factor

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        # 3. Encode the source image(s)
        if isinstance(source_image, PIL.Image.Image):
            source_image = source_image.resize((width, height), PIL.Image.LANCZOS)
        elif isinstance(source_image, list):
            source_image = [img.resize((width, height), PIL.Image.LANCZOS) for img in source_image]
            if len(source_image) != batch_size:
                raise ValueError(
                    f"Number of source images ({len(source_image)}) must match batch size ({batch_size})"
                )

        if isinstance(source_image, list):
            source_latents_list = []
            for img in source_image:
                latent = self._encode_source_image(img, device, self.vae.dtype, num_images_per_prompt)
                # Ensure the latent is on the correct device (critical for multi-GPU)
                latent = latent.to(device=device)
                source_latents_list.append(latent)
            source_latents = torch.cat(source_latents_list, dim=0)
        else:
            source_latents = self._encode_source_image(
                source_image, device, self.vae.dtype, num_images_per_prompt
            )

        # Duplicate the source latents for the batch
        if batch_size > 1 and source_latents.shape[0] == 1:
            source_latents = source_latents.repeat(batch_size * num_images_per_prompt, 1, 1, 1)

        # 4. Encode the input prompt
        # Default image_guidance_scale to guidance_scale if not provided
        if image_guidance_scale is None:
            image_guidance_scale = guidance_scale

        do_classifier_free_guidance = guidance_scale > 1.0 and image_guidance_scale >= 1.0

        (
            prompt_embeds,
            prompt_attention_mask,
            negative_prompt_embeds,
            negative_prompt_attention_mask,
        ) = self.encode_prompt(
            prompt,
            do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            device=device,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            clean_caption=clean_caption,
            max_sequence_length=max_sequence_length,
        )

        if do_classifier_free_guidance:
            # InstructPix2Pix-style CFG needs three conditioning branches per sample:
            # (text + image), (image only), and (fully unconditional). On the text side
            # this means a batch of [positive text, negative/null text, negative/null text];
            # the matching image-side branches are arranged in the denoising loop below.
            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
            prompt_attention_mask = torch.cat(
                [prompt_attention_mask, negative_prompt_attention_mask, negative_prompt_attention_mask], dim=0
            )

        # 5. Prepare timesteps (honor custom `timesteps` when provided)
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # Ensure source_latents match the device and dtype of the noise latents
        # (critical for multi-GPU setups and mixed precision), so that all
        # concatenations in the denoising loop work correctly
        source_latents = source_latents.to(device=latents.device, dtype=latents.dtype)

        # 7. Prepare extra step kwargs
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 8. Prepare added conditioning (resolution, aspect ratio)
        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
        if self.transformer.config.sample_size == 128:
            resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
            aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
            resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
            aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)

            if do_classifier_free_guidance:
                resolution = torch.cat([resolution, resolution, resolution], dim=0)
                aspect_ratio = torch.cat([aspect_ratio, aspect_ratio, aspect_ratio], dim=0)

            added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}

        # 9. Denoising loop with Token Concatenation
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents for classifier-free guidance
                # (three copies for the three IP2P branches)
                latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # **Token Concatenation strategy**
                # Concatenate the source image latents with the noisy target latents along
                # the width dimension. IP2P batch on the image side: [source, source, zeros],
                # matching the text side [positive, negative, negative] prepared above.
                if do_classifier_free_guidance:
                    source_latents_input = torch.cat(
                        [source_latents, source_latents, torch.zeros_like(source_latents)], dim=0
                    )
                else:
                    source_latents_input = source_latents

                # Ensure both tensors are on the same device before concatenation (critical for multi-GPU)
                source_latents_input = source_latents_input.to(device=latent_model_input.device)
                concatenated_latents = torch.cat([source_latents_input, latent_model_input], dim=3)
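                # e.g., two (B, C, H, W) latent grids become one (B, C, H, 2*W) canvas,
                # so the transformer attends jointly over source and target tokens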

                # Expand the timestep for the expanded latents (CFG)
                current_timestep = t
                if not torch.is_tensor(current_timestep):
                    is_mps = concatenated_latents.device.type == "mps"
                    is_npu = concatenated_latents.device.type == "npu"
                    if isinstance(current_timestep, float):
                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                    else:
                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                    current_timestep = torch.tensor([current_timestep], dtype=dtype, device=concatenated_latents.device)
                elif len(current_timestep.shape) == 0:
                    current_timestep = current_timestep[None].to(concatenated_latents.device)

                current_timestep = current_timestep.expand(concatenated_latents.shape[0])

                # Predict the noise residual
                noise_pred = self.transformer(
                    concatenated_latents,
                    encoder_hidden_states=prompt_embeds,
                    encoder_attention_mask=prompt_attention_mask,
                    timestep=current_timestep,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # **Extract the target portion** (the right half, corresponding to the edited image):
                # the model predicts noise for both halves, but only the target half is needed
                target_width = latents.shape[3]
                noise_pred = noise_pred[:, :, :, target_width:]

                # Split off the variance if the model predicts it (PixArt can output 2x channels)
                if noise_pred.shape[1] == 2 * num_channels_latents:
                    noise_pred, _ = noise_pred.chunk(2, dim=1)

                # Perform classifier-free guidance
                if do_classifier_free_guidance:
                    # noise_pred batch order: [text + image, image only, unconditional]
                    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)

                    # InstructPix2Pix CFG, with s_T = guidance_scale and s_I = image_guidance_scale:
                    #   e = e(∅, ∅) + s_T * (e(c_I, c_T) - e(c_I, ∅)) + s_I * (e(c_I, ∅) - e(∅, ∅))
                    # where e(c_I, c_T) -> noise_pred_text, e(c_I, ∅) -> noise_pred_image,
                    # and e(∅, ∅) -> noise_pred_uncond.
                    noise_pred = (
                        noise_pred_uncond
                        + guidance_scale * (noise_pred_text - noise_pred_image)
                        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
                    )
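                    # Note: if image_guidance_scale == guidance_scale (the default when it is
                    # not passed), the image terms cancel and this reduces algebraically to
                    # standard CFG: uncond + guidance_scale * (text - uncond).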
464
+
465
+ # Compute previous noisy sample: x_t -> x_t-1
466
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
467
+
468
+ # Call the callback, if provided
469
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
470
+ progress_bar.update()
471
+ if callback is not None and i % callback_steps == 0:
472
+ step_idx = i // getattr(self.scheduler, "order", 1)
473
+ callback(step_idx, t, latents)
474
+
475
+ # 10. Post-processing
476
+ if not output_type == "latent":
477
+ # Cast latents to VAE dtype to ensure compatibility (fixes bf16 training validation)
478
+ image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
479
+ image = self.image_processor.postprocess(image, output_type=output_type)
480
+ else:
481
+ image = latents
482
+
483
+ # Offload all models
484
+ self.maybe_free_model_hooks()
485
+
486
+ if not return_dict:
487
+ return (image,)
488
+
489
+ return ImagePipelineOutput(images=image)
490
+