| from typing import Any |
| import torch |
| from torch import nn |
| import math |
| from fractions import Fraction |
| from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig |
| from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel |
| import torch.nn.functional as F |
|
|
|
|
class QFormerCrossAttention(nn.Module):
    """Multi-headed cross-attention for QFormer built on PyTorch SDPA.

    Queries are projected from ``hidden_states``; keys and values are projected
    from ``encoder_hidden_states``. The attention itself is computed with
    ``torch.nn.functional.scaled_dot_product_attention``, which can dispatch to
    fused (Flash / memory-efficient) kernels when the inputs allow it.

    Args:
        hidden_size: Width of the query stream and of the module output.
        num_heads: Number of attention heads. Must be a positive divisor of
            ``hidden_size``.
        attn_bias: Whether the q/k/v/o projections carry a bias term.
        attention_dropout: Dropout probability on the attention weights; only
            applied in training mode.
        final_dropout: Dropout probability applied after the output projection.
        encoder_hidden_size: Width of ``encoder_hidden_states``. Defaults to
            ``hidden_size``, which reproduces the original square k/v projections.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_heads, attn_bias=False, attention_dropout=0.05,
                 final_dropout=0.05, encoder_hidden_size=None):
        super().__init__()
        # Reject non-positive head counts up front: 0 would raise ZeroDivisionError
        # below, and a negative count can slip through the divisibility check
        # (e.g. 16 // -4 * -4 == 16) only to fail later inside `view`.
        if num_heads <= 0:
            raise ValueError(
                f"num_heads must be a positive integer (got `num_heads`: {num_heads})."
            )

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout
        self.encoder_hidden_size = hidden_size if encoder_hidden_size is None else encoder_hidden_size

        if self.head_dim * num_heads != hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size} "
                f"and `num_heads`: {num_heads})."
            )

        # Created in the original order so seeded initialisation and state_dict
        # layout are unchanged when encoder_hidden_size is left at its default.
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.k_proj = nn.Linear(self.encoder_hidden_size, hidden_size, bias=attn_bias)
        self.v_proj = nn.Linear(self.encoder_hidden_size, hidden_size, bias=attn_bias)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=attn_bias)
        self.dropout = nn.Dropout(final_dropout)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        """Attend from ``hidden_states`` (queries) to ``encoder_hidden_states``.

        Args:
            hidden_states: ``(B, query_len, hidden_size)`` query features.
            encoder_hidden_states: ``(B, encoder_len, encoder_hidden_size)``
                key/value features.
            attention_mask: Optional mask forwarded unchanged as ``attn_mask`` to
                ``scaled_dot_product_attention``.

        Returns:
            ``(B, query_len, hidden_size)`` attention output.
        """
        batch_size, query_len, _ = hidden_states.shape
        encoder_len = encoder_hidden_states.shape[1]

        # (B, L, hidden) -> (B, num_heads, L, head_dim), the layout SDPA expects.
        query_states = self.q_proj(hidden_states).view(
            batch_size, query_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = self.k_proj(encoder_hidden_states).view(
            batch_size, encoder_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = self.v_proj(encoder_hidden_states).view(
            batch_size, encoder_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            # Attention-weight dropout is a training-time regulariser only.
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,
        )

        # Merge heads: (B, num_heads, query_len, head_dim) -> (B, query_len, hidden_size).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return self.dropout(attn_output)
|
|
|
|
class QFormerMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> SiLU -> Linear -> Dropout.

    Args:
        hidden_size: Input and output width.
        mlp_hidden_size: Width of the intermediate expansion.
        mlp_bias: Whether both linear layers carry a bias.
        dropout_prob: Dropout probability applied to the block output.
    """

    def __init__(self, hidden_size, mlp_hidden_size, mlp_bias=False, dropout_prob=0.05):
        super().__init__()
        self.hidden_size = hidden_size
        # Registration order (fc1, act, fc2, dropout) fixes both the seeded
        # initialisation order and the state_dict keys.
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=mlp_hidden_size, bias=mlp_bias)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(in_features=mlp_hidden_size, out_features=hidden_size, bias=mlp_bias)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states):
        """Apply the feed-forward transform.

        Args:
            hidden_states: ``(B, seq_len, hidden_size)``.

        Returns:
            ``(B, seq_len, hidden_size)``.
        """
        expanded = self.act(self.fc1(hidden_states))
        projected = self.fc2(expanded)
        return self.dropout(projected)
|
|
|
|
class SimplifiedQFormer(nn.Module):
    """Single pre-norm block in which learned queries cross-attend to encoder features.

    Computation::

        x = q + CrossAttn(LN_attn(q), enc)
        y = x + MLP(LN_mlp(x))

    There is no self-attention among the queries; the block is a lightweight
    stand-in for a one-layer Q-Former.

    Args:
        hidden_size: Width shared by the queries, the encoder features and the output.
        num_heads: Attention heads of the cross-attention layer.
        mlp_hidden_size: Intermediate width of the feed-forward layer.
        mlp_bias: Bias flag for both MLP linear layers.
        attn_bias: Bias flag for the q/k/v/o attention projections.
        eps: Epsilon of both LayerNorms.
    """

    def __init__(self, hidden_size, num_heads=8, mlp_hidden_size=2048, mlp_bias=False, attn_bias=False, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Sub-modules are registered in the original order (attn_norm,
        # cross_attention, mlp_norm, mlp) so seeded init and state_dict keys match.
        self.attn_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.cross_attention = QFormerCrossAttention(hidden_size, num_heads, attn_bias=attn_bias)
        self.mlp_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = QFormerMLP(hidden_size, mlp_hidden_size, mlp_bias=mlp_bias)

    def forward(self, query_embeds, encoder_hidden_states):
        """Run cross-attention followed by the MLP, each with a residual connection.

        Args:
            query_embeds: ``(B, num_queries, hidden_size)`` query tokens.
            encoder_hidden_states: ``(B, num_tokens, hidden_size)`` features to attend to.

        Returns:
            ``(B, num_queries, hidden_size)``.
        """
        attended = self.cross_attention(self.attn_norm(query_embeds), encoder_hidden_states)
        after_attn = query_embeds + attended

        mlp_out = self.mlp(self.mlp_norm(after_attn))
        return after_attn + mlp_out
|
|
|
|
|
|
class InterpolateDownsampler:
    """Resample a square, raster-ordered patch-token grid to a smaller square grid.

    The target side is ``int(orig_side * Fraction(config.downsample_rate))``; for
    example a rate of ``"1/2"`` halves each spatial side. Resampling is done by
    ``torch.nn.functional.interpolate`` in the requested ``mode``.

    Args:
        config: Object exposing ``vision_config.image_size``,
            ``vision_config.patch_size`` and ``downsample_rate`` (any value
            ``fractions.Fraction`` accepts, e.g. the string ``"1/2"``).
        mode: Interpolation mode passed to ``interpolate`` (default ``"area"``).
    """

    def __init__(self, config, mode="area"):
        vision = config.vision_config
        self.orig_image_side = vision.image_size // vision.patch_size
        rate = Fraction(config.downsample_rate)
        self.new_image_side = int(self.orig_image_side * rate)
        self.mode = mode

    def __call__(self, image_features):
        """Downsample ``(B, side*side, C)`` tokens to ``(B, new_side*new_side, C)``."""
        batch_size, _, dim = image_features.size()
        side = self.orig_image_side

        # Unflatten to a grid and move channels first, as interpolate expects (B, C, H, W).
        grid = image_features.view(batch_size, side, side, dim).movedim(-1, 1)
        target = (self.new_image_side, self.new_image_side)
        resized = torch.nn.functional.interpolate(grid, size=target, mode=self.mode)

        # Channels back to last and re-flatten the grid in raster order.
        return resized.movedim(1, -1).flatten(1, 2)
|
|
|
|
class SpatialOffsetDownsampler:
    """
    Strided downsampler that keeps one fixed position from every ``factor x factor`` block.

    Rows and columns are both sampled with stride ``factor`` starting at
    ``(offset_h, offset_w)``; this selects one phase of a polyphase
    (space-to-depth) decomposition of the patch grid. Coverage is global (every
    block contributes exactly one token), but consecutive output tokens are
    ``factor`` patches apart in the original grid, not adjacent. Taken together,
    the ``factor**2`` offsets partition the grid exactly when its side is a
    multiple of ``factor``.
    """

    def __init__(self, config, offset=0, factor=2):
        """
        Args:
            config: Object exposing ``vision_config.image_size`` and
                ``vision_config.patch_size``.
            offset: Index into ``self.offsets`` (row-major within a block). For the
                default ``factor=2``: 0 top-left, 1 top-right, 2 bottom-left,
                3 bottom-right.
            factor: Block side / sampling stride. Defaults to 2, the original behaviour.

        Raises:
            ValueError: if ``factor`` is smaller than 1.
            IndexError: if ``offset`` is not a valid index into ``self.offsets``.
        """
        if factor < 1:
            raise ValueError(f"factor must be >= 1 (got {factor})")
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.factor = factor
        self.new_image_side = self.orig_image_side // factor
        self.offset = offset

        # Row-major (dy, dx) positions inside a factor x factor block; for
        # factor=2 this is [(0, 0), (0, 1), (1, 0), (1, 1)], as before.
        self.offsets = [(dy, dx) for dy in range(factor) for dx in range(factor)]
        self.offset_h, self.offset_w = self.offsets[offset]

    def __call__(self, image_features):
        """
        Take every ``factor``-th row and column, starting at ``(offset_h, offset_w)``.

        For a 4x4 grid with factor=2 and offset=0 (top-left of each 2x2 block):
            Original:          Sampled (raster order):
            [A B | C D]         [A C]
            [E F | G H]   ->    [I K]
            [----+----]
            [I J | K L]
            [M N | O P]

        Result in sequence: [A, C, I, K]. If the grid side is not a multiple of
        ``factor``, the trailing partial block row/column is dropped, so the
        output is always ``new_image_side x new_image_side``.

        Args:
            image_features: Tensor of shape ``[batch, side*side, hidden_dim]`` with
                ``side = orig_image_side``, tokens in raster order.

        Returns:
            Tensor of shape ``[batch, new_image_side**2, hidden_dim]``.
        """
        batch_size, _, hidden_dim = image_features.shape
        side = self.orig_image_side
        n = self.new_image_side

        features_2d = image_features.reshape(batch_size, side, side, hidden_dim)

        # Strided slicing works for any side length (the previous 6-D block
        # reshape raised for odd sides); the [:n] crop drops a trailing partial block.
        sampled = features_2d[:, self.offset_h::self.factor, self.offset_w::self.factor, :]
        sampled = sampled[:, :n, :n, :]

        return sampled.reshape(batch_size, -1, hidden_dim)
| |
|
|
class SpatialQuadrantDownsampler:
    """
    Crop one contiguous quadrant of the square patch grid.

    Instead of subsampling across the whole image, this keeps a single
    ``(side // 2) x (side // 2)`` corner region intact. Neighbouring output tokens
    are therefore neighbouring patches, at the cost of discarding the other
    three quadrants.

    Use case: when queries should focus on one region with maximum local
    coherence, trading away global spatial coverage.
    """

    def __init__(self, config, offset=0):
        """
        Args:
            config: Object exposing ``vision_config.image_size`` and
                ``vision_config.patch_size``.
            offset: Quadrant index: 0 top-left, 1 top-right, 2 bottom-left,
                3 bottom-right.
        """
        vision = config.vision_config
        self.orig_image_side = vision.image_size // vision.patch_size
        half = self.orig_image_side // 2
        self.new_image_side = half
        self.offset = offset

        # (top, left) corner of each quadrant, enumerated row-major.
        self.offsets = [(top, left) for top in (0, half) for left in (0, half)]
        self.start_h, self.start_w = self.offsets[offset]

    def __call__(self, image_features):
        """
        Return the selected quadrant as a raster-ordered token sequence.

        For a 4x4 grid with offset=0 (top-left quadrant):
            Original:          Sampled:
            [A B | C D]         [A B]
            [E F | G H]   ->    [E F]
            [----+----]
            [I J | K L]
            [M N | O P]

        Result in sequence: [A, B, E, F].

        Args:
            image_features: Tensor of shape ``[batch, height*width, hidden_dim]``.

        Returns:
            Tensor of shape ``[batch, (height/2)*(width/2), hidden_dim]``.
        """
        batch_size, _, hidden_dim = image_features.shape
        side = self.orig_image_side
        grid = image_features.reshape(batch_size, side, side, hidden_dim)

        rows = slice(self.start_h, self.start_h + self.new_image_side)
        cols = slice(self.start_w, self.start_w + self.new_image_side)
        quadrant = grid[:, rows, cols, :]

        return quadrant.reshape(batch_size, -1, hidden_dim)
|
|
|
|
|
|
class WindowQFormerDownsampler(nn.Module):
    """Windowed Q-Former projector from vision-encoder tokens to LLM-width tokens.

    The ``image_side x image_side`` patch grid is cut into non-overlapping
    ``window_side x window_side`` windows. The token grid produced by
    ``self.downsampler`` is cut into the same number of ``query_side x query_side``
    windows. Each query window, plus the shared learned query embeddings,
    cross-attends only to its matching image window (windows are processed as
    independent batch entries). The per-window outputs are stitched back into a
    raster-ordered ``(n * query_side)**2`` sequence and projected to the LLM width.

    ``config.downsample_rate`` is a ``"query_side/window_side"`` string, e.g. ``"1/2"``.

    Args:
        config: Composite config exposing ``text_config.hidden_size``,
            ``vision_config.hidden_size/image_size/patch_size``,
            ``projector_dropout``, ``simplified_qformer`` and ``downsample_rate``.
        checkerboard_offset: When not ``None``, use a fixed-offset sampler
            (``SpatialOffsetDownsampler``, or ``SpatialQuadrantDownsampler`` if
            ``use_quadrant_sampling``) instead of ``InterpolateDownsampler``.
        use_quadrant_sampling: Only consulted when ``checkerboard_offset`` is set.
    """

    def __init__(self, config, checkerboard_offset=None, use_quadrant_sampling=False):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size

        self.dropout = nn.Dropout(config.projector_dropout)

        # Choose how the normalized patch grid is reduced to the query grid.
        if checkerboard_offset is not None:
            if use_quadrant_sampling:
                self.downsampler = SpatialQuadrantDownsampler(config, offset=checkerboard_offset)
            else:
                self.downsampler = SpatialOffsetDownsampler(config, offset=checkerboard_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)

        self.use_simplified_qformer = config.simplified_qformer

        if self.use_simplified_qformer:
            # Single cross-attention + MLP block defined in this file.
            self.qformer = SimplifiedQFormer(
                hidden_size=vision_hidden_size,
                num_heads=vision_hidden_size // 64,
                mlp_hidden_size=3072,
                mlp_bias=True,
                attn_bias=True,
            )
        else:
            # One-layer Hugging Face BLIP-2 Q-Former with cross-attention in that layer.
            configuration = Blip2QFormerConfig(
                hidden_size=vision_hidden_size,
                num_attention_heads=vision_hidden_size // 64,
                intermediate_size=3072,
                num_hidden_layers=1,
                encoder_hidden_size=vision_hidden_size,
                cross_attention_frequency=1,
                max_position_embeddings=2048,
                use_qformer_text_input=False,
            )
            self.qformer = Blip2QFormerModel(configuration)

        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        # "q/w": every w x w image window is summarized by q x q query tokens.
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)

        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        # Learned query embeddings, broadcast to every window.
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)
        # Learned positional embeddings for tokens inside an image window, shared across windows.
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)

    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win
        windows are raster-ordered, and tokens inside each window are raster-ordered.
        """
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)
            .flatten(0, 2)
            .flatten(1, 2)
        )

    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster; inverse of ``_win``.
        """
        Bnn, _, C = xw.shape
        # Internal invariant: _win always produces a multiple of n*n windows.
        assert Bnn % (n * n) == 0
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )

    def forward(self, image_features):
        """Project patch tokens to LLM-width tokens via per-window Q-Former attention.

        Args:
            image_features: ``(B, image_side**2, vision_hidden_size)`` patch tokens
                in raster order.

        Returns:
            ``(B, (n * query_side)**2, llm_hidden_size)`` in raster order, where
            ``n = image_side // window_side``.

        Raises:
            ValueError: if the token count does not match the configured grid, the
                grid side is not a multiple of ``window_side``, or the downsampler
                does not produce ``(n * query_side)**2`` tokens (e.g. the
                checkerboard samplers always halve the side, so they need
                ``window_side == 2 * query_side``).
        """
        _, num_tokens, _ = image_features.shape
        expected_tokens = self.image_side * self.image_side
        # Explicit checks instead of `assert` (stripped under -O) or a cryptic `view` error.
        if num_tokens != expected_tokens:
            raise ValueError(
                f"expected {expected_tokens} image tokens for a {self.image_side}x{self.image_side} "
                f"patch grid, got {num_tokens}"
            )
        if self.image_side % self.window_side != 0:
            raise ValueError(
                f"patch grid side {self.image_side} is not divisible by window side {self.window_side} "
                f"(downsample_rate '{self.query_side}/{self.window_side}')"
            )
        n = self.image_side // self.window_side

        image_features = self.norm(image_features)
        # (B, side^2, C) -> (B*n*n, window_side^2, C): one key/value sequence per window.
        enc = self._win(image_features, self.image_side, self.window_side)

        downsampled = self.downsampler(image_features)
        new_side = n * self.query_side
        if downsampled.shape[1] != new_side * new_side:
            raise ValueError(
                f"downsampler {type(self.downsampler).__name__} produced {downsampled.shape[1]} tokens, "
                f"expected {new_side * new_side} ({new_side}x{new_side}) for downsample_rate "
                f"'{self.query_side}/{self.window_side}'"
            )
        # (B, new_side^2, C) -> (B*n*n, query_side^2, C); window order matches `enc`.
        downsampled_w = self._win(downsampled, new_side, self.query_side)

        if self.use_simplified_qformer:
            query_embeds = self.dropout(self.query + downsampled_w)
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
            )
        else:
            # No dropout on the queries in this branch. NOTE(review): presumably
            # because Blip2QFormerModel applies its own embedding dropout; confirm.
            query_embeds = self.query + downsampled_w
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
                return_dict=True,
            ).last_hidden_state

        # Stitch per-window outputs back into one raster-ordered grid.
        out = self._unwin(out_w, n=n, win=self.query_side)
        out = self.dropout(out)
        return self.out_linear(out)
| |