benwiesel's picture
Upload folder using huggingface_hub
cae0922 verified
from typing import Any
import torch
from torch import nn
import math
from fractions import Fraction
from transformers.models.blip_2.configuration_blip_2 import Blip2QFormerConfig
from transformers.models.blip_2.modeling_blip_2 import Blip2QFormerModel
import torch.nn.functional as F
class QFormerCrossAttention(nn.Module):
    """Multi-head cross-attention for the QFormer block.

    Queries are projected from ``hidden_states``; keys and values are projected
    from ``encoder_hidden_states``. The attention is computed (non-causally) with
    ``torch.nn.functional.scaled_dot_product_attention``, which lets PyTorch pick
    a fused kernel such as Flash Attention when one is applicable.

    Args:
        hidden_size: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        attn_bias: whether the q/k/v/o projections have a bias term.
        attention_dropout: dropout probability on attention weights (training only).
        final_dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, attn_bias=False, attention_dropout=0.05, final_dropout=0.05):
        super().__init__()
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {hidden_size} "
                f"and `num_heads`: {num_heads})."
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout

        def square_linear():
            return nn.Linear(hidden_size, hidden_size, bias=attn_bias)

        # Queries read hidden_states; keys/values read encoder_hidden_states.
        self.q_proj = square_linear()
        self.k_proj = square_linear()
        self.v_proj = square_linear()
        self.o_proj = square_linear()
        self.dropout = nn.Dropout(final_dropout)

    def _split_heads(self, projected, batch_size):
        """Reshape ``(B, L, hidden_size)`` into ``(B, num_heads, L, head_dim)``."""
        length = projected.shape[1]
        return projected.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: (B, query_len, hidden_size) - queries
            encoder_hidden_states: (B, encoder_len, hidden_size) - keys and values
            attention_mask: optional mask forwarded as ``attn_mask`` to SDPA
        Returns:
            (B, query_len, hidden_size)
        """
        batch_size, query_len, _ = hidden_states.shape
        queries = self._split_heads(self.q_proj(hidden_states), batch_size)
        keys = self._split_heads(self.k_proj(encoder_hidden_states), batch_size)
        values = self._split_heads(self.v_proj(encoder_hidden_states), batch_size)

        # Attention-weight dropout is only active in training mode.
        drop_p = self.attention_dropout if self.training else 0.0
        context = F.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=attention_mask,
            dropout_p=drop_p,
            is_causal=False,
        )
        # (B, heads, query_len, head_dim) -> (B, query_len, hidden_size)
        merged = context.transpose(1, 2).reshape(batch_size, query_len, self.hidden_size)
        return self.dropout(self.o_proj(merged))
class QFormerMLP(nn.Module):
    """Two-layer feed-forward network with a SiLU non-linearity.

    Computes ``dropout(fc2(silu(fc1(x))))``; dropout is applied only to the
    final projection's output.

    Args:
        hidden_size: input/output width.
        mlp_hidden_size: width of the intermediate layer.
        mlp_bias: whether both linear layers have a bias term.
        dropout_prob: dropout probability after ``fc2``.
    """

    def __init__(self, hidden_size, mlp_hidden_size, mlp_bias=False, dropout_prob=0.05):
        super().__init__()
        self.hidden_size = hidden_size
        # Expand -> activate -> contract.
        self.fc1 = nn.Linear(hidden_size, mlp_hidden_size, bias=mlp_bias)
        self.act = nn.SiLU()
        self.fc2 = nn.Linear(mlp_hidden_size, hidden_size, bias=mlp_bias)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states):
        """
        Args:
            hidden_states: (B, seq_len, hidden_size)
        Returns:
            (B, seq_len, hidden_size)
        """
        expanded = self.act(self.fc1(hidden_states))
        contracted = self.fc2(expanded)
        return self.dropout(contracted)
class SimplifiedQFormer(nn.Module):
    """
    Single pre-norm transformer block in which queries cross-attend to encoder states.

    Computes::

        x   = query_embeds + CrossAttention(LayerNorm(query_embeds), encoder_hidden_states)
        out = x + MLP(LayerNorm(x))

    There is no self-attention among the queries; each query only attends to the
    encoder hidden states.
    """

    def __init__(self, hidden_size, num_heads=8, mlp_hidden_size=2048, mlp_bias=False, attn_bias=False, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Sub-layer 1: normalized queries attend to the encoder states.
        self.attn_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.cross_attention = QFormerCrossAttention(hidden_size, num_heads, attn_bias=attn_bias)
        # Sub-layer 2: position-wise feed-forward network.
        self.mlp_norm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = QFormerMLP(hidden_size, mlp_hidden_size, mlp_bias=mlp_bias)

    def forward(self, query_embeds, encoder_hidden_states):
        """
        Args:
            query_embeds: (B, num_queries, hidden_size) - query tokens
            encoder_hidden_states: (B, num_tokens, hidden_size) - input features
        Returns:
            (B, num_queries, hidden_size) - output features
        """
        attended = self.cross_attention(self.attn_norm(query_embeds), encoder_hidden_states)
        after_attn = query_embeds + attended
        ffn_out = self.mlp(self.mlp_norm(after_attn))
        return after_attn + ffn_out
class InterpolateDownsampler:
    """Resize a square, raster-ordered patch grid with ``torch.nn.functional.interpolate``.

    The input grid side is ``vision_config.image_size // vision_config.patch_size``;
    the output side is ``int(side * Fraction(config.downsample_rate))``.
    """

    def __init__(self, config, mode="area"):
        vision_cfg = config.vision_config
        side = vision_cfg.image_size // vision_cfg.patch_size
        self.orig_image_side = side
        self.new_image_side = int(side * Fraction(config.downsample_rate))
        # Interpolation mode passed straight to F.interpolate.
        self.mode = mode

    def __call__(self, image_features):
        """Map ``(B, side*side, C)`` features to ``(B, new_side*new_side, C)`` (raster order)."""
        batch_size, _, dim = image_features.size()
        side = self.orig_image_side
        # interpolate works on channels-first images: (B, H*W, C) -> (B, C, H, W).
        grid = image_features.view(batch_size, side, side, dim).permute(0, 3, 1, 2)
        target_size = (self.new_image_side, self.new_image_side)
        resized = torch.nn.functional.interpolate(grid, size=target_size, mode=self.mode)
        # (B, C, h, w) -> (B, h*w, C)
        return resized.permute(0, 2, 3, 1).flatten(1, 2)
class SpatialOffsetDownsampler:
    """
    Downsampler that samples with local block continuity pattern.

    Downsamples a square patch grid by 2x by taking the same position from every
    2x2 block. Unlike a global stride pattern, each output token comes from its
    own 2x2 block, so the result covers the whole image and neighbouring outputs
    come from neighbouring blocks. The grid side must be even.
    """

    def __init__(self, config, offset=0):
        """
        Args:
            config: Model configuration; ``vision_config.image_size`` and
                ``vision_config.patch_size`` determine the grid side.
            offset: Integer offset (0, 1, 2, or 3) for position within each 2x2 block
                0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right

        Raises:
            ValueError: if ``offset`` is not in 0..3.
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = self.orig_image_side // 2  # downsample by 2x
        # Map offset to (row, col) position within each 2x2 block
        self.offsets = [(0, 0), (0, 1), (1, 0), (1, 1)]
        # Reject out-of-range values explicitly: plain list indexing would let
        # negative offsets wrap around silently (e.g. -1 -> bottom-right).
        if not 0 <= offset < len(self.offsets):
            raise ValueError(f"offset must be one of 0, 1, 2, 3 (got {offset!r}).")
        self.offset = offset
        self.offset_h, self.offset_w = self.offsets[offset]

    def __call__(self, image_features):
        """
        Extract features by sampling one position from each 2x2 block across the image.
        This maintains full spatial coverage while creating local continuity.
        For a 4x4 image with offset=0 (top-left of each 2x2 block):
            Original:           Sampled (raster order):
            [A B | C D]         [A C]
            [E F | G H]   ->    [I K]
            [---+---]
            [I J | K L]
            [M N | O P]
        Result in sequence: [A, C, I, K] - maintains spatial structure

        Args:
            image_features: Tensor of shape [batch, height*width, hidden_dim]
        Returns:
            Downsampled features of shape [batch, (height/2)*(width/2), hidden_dim]
        """
        batch_size, _, hidden_dim = image_features.shape
        # Reshape to [batch, height, width, hidden_dim]; this step checks that the
        # sequence length equals orig_image_side ** 2.
        features_2d = image_features.reshape(batch_size, self.orig_image_side, self.orig_image_side, hidden_dim)
        # Split into 2x2 blocks: [batch, n_blocks_h, 2, n_blocks_w, 2, hidden_dim]
        # (fails for an odd grid side, where 2 * n_blocks != orig_image_side).
        n_blocks = self.new_image_side
        features_blocks = features_2d.reshape(batch_size, n_blocks, 2, n_blocks, 2, hidden_dim)
        # Pick the configured in-block position from every block.
        sampled = features_blocks[:, :, self.offset_h, :, self.offset_w, :]
        # Flatten spatial dimensions back to [batch, n_blocks*n_blocks, hidden_dim]
        return sampled.reshape(batch_size, -1, hidden_dim)
class SpatialQuadrantDownsampler:
    """
    Alternative downsampler that samples contiguous spatial quadrants.

    Takes a full quadrant of the image rather than sampling across the entire image.
    This creates maximum local continuity but only covers 1/4 of the spatial extent.
    Use case: When you want queries to focus on a specific region with maximum
    local coherence, trading off global spatial coverage.
    """

    def __init__(self, config, offset=0):
        """
        Args:
            config: Model configuration; ``vision_config.image_size`` and
                ``vision_config.patch_size`` determine the grid side.
            offset: Integer offset (0, 1, 2, or 3) for quadrant selection
                0: top-left, 1: top-right, 2: bottom-left, 3: bottom-right

        Raises:
            ValueError: if ``offset`` is not in 0..3.
        """
        self.orig_image_side = config.vision_config.image_size // config.vision_config.patch_size
        self.new_image_side = self.orig_image_side // 2  # downsample by 2x
        # Map offset to quadrant (start_row, start_col)
        self.offsets = [
            (0, 0),  # top-left
            (0, self.new_image_side),  # top-right
            (self.new_image_side, 0),  # bottom-left
            (self.new_image_side, self.new_image_side),  # bottom-right
        ]
        # Reject out-of-range values explicitly: plain list indexing would let
        # negative offsets wrap around silently (e.g. -1 -> bottom-right).
        if not 0 <= offset < len(self.offsets):
            raise ValueError(f"offset must be one of 0, 1, 2, 3 (got {offset!r}).")
        self.offset = offset
        self.start_h, self.start_w = self.offsets[offset]

    def __call__(self, image_features):
        """
        Extract a contiguous quadrant from the image.
        For a 4x4 image with offset=0 (top-left quadrant):
            Original:           Sampled:
            [A B | C D]         [A B]
            [E F | G H]   ->    [E F]
            [---+---]
            [I J | K L]
            [M N | O P]
        Result in sequence: [A, B, E, F] - maximum local continuity

        Args:
            image_features: Tensor of shape [batch, height*width, hidden_dim]
        Returns:
            Downsampled features of shape [batch, (height//2)*(width//2), hidden_dim]
        """
        batch_size, _, hidden_dim = image_features.shape
        # Reshape to [batch, height, width, hidden_dim]
        features_2d = image_features.reshape(batch_size, self.orig_image_side, self.orig_image_side, hidden_dim)
        # Slice out the selected contiguous quadrant.
        side = self.new_image_side
        sampled = features_2d[:, self.start_h:self.start_h + side, self.start_w:self.start_w + side, :]
        # Flatten spatial dimensions back to [batch, side*side, hidden_dim]
        return sampled.reshape(batch_size, -1, hidden_dim)
class WindowQFormerDownsampler(nn.Module):
    """Windowed QFormer projector from vision patch features to the LLM hidden size.

    The ``image_side x image_side`` patch grid (``image_side = image_size // patch_size``)
    is split into non-overlapping ``window_side x window_side`` windows. Each window
    is summarized by ``query_side x query_side`` query tokens, formed as learnable
    query embeddings plus the matching window of a downsampled copy of the grid.
    The queries attend to the position-embedded patches of their own window through
    a single-layer QFormer. The per-window outputs are stitched back into a
    raster-ordered grid of side ``(image_side // window_side) * query_side`` and
    projected to the LLM width by ``out_linear``.

    Args:
        config: reads ``text_config.hidden_size``, ``vision_config.hidden_size``,
            ``vision_config.image_size``, ``vision_config.patch_size``,
            ``projector_dropout``, ``simplified_qformer`` and ``downsample_rate``.
            ``downsample_rate`` is a ``"q/w"`` string giving query side ``q`` and
            window side ``w``.
        checkerboard_offset: if not None, use a spatial (offset or quadrant)
            downsampler with this offset; otherwise use interpolation.
        use_quadrant_sampling: together with ``checkerboard_offset``, select
            quadrant sampling instead of per-2x2-block sampling.
    """

    def __init__(self, config, checkerboard_offset=None, use_quadrant_sampling=False):
        super().__init__()
        llm_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        # Shared dropout, applied to the QFormer inputs and to the output before projection.
        self.dropout = nn.Dropout(config.projector_dropout)
        # Choose how the per-window query content is produced.
        # NOTE: both spatial downsamplers always halve the grid side, so they only match
        # the windowing in forward() when (image_side // window_side) * query_side equals
        # image_side // 2 (e.g. downsample_rate "2/4").
        if checkerboard_offset is not None:
            if use_quadrant_sampling:
                # Quadrant sampling: maximum local continuity, limited spatial coverage
                self.downsampler = SpatialQuadrantDownsampler(config, offset=checkerboard_offset)
            else:
                # Block sampling: one position per 2x2 block, full spatial coverage
                self.downsampler = SpatialOffsetDownsampler(config, offset=checkerboard_offset)
        else:
            self.downsampler = InterpolateDownsampler(config)
        self.use_simplified_qformer = config.simplified_qformer
        # Choose between SimplifiedQFormer and Blip2QFormerModel
        if self.use_simplified_qformer:
            # Simplified QFormer: one cross-attention sub-layer + MLP (no query self-attention)
            self.qformer = SimplifiedQFormer(
                hidden_size=vision_hidden_size,
                num_heads=vision_hidden_size // 64,
                mlp_hidden_size=3072,
                mlp_bias=True,
                attn_bias=True,
            )
        else:
            # Original single-layer Blip2QFormerModel with cross-attention
            configuration = Blip2QFormerConfig(
                hidden_size=vision_hidden_size,
                num_attention_heads=vision_hidden_size // 64,
                intermediate_size=3072,
                num_hidden_layers=1,
                encoder_hidden_size=vision_hidden_size,
                cross_attention_frequency=1,
                max_position_embeddings=2048,
                use_qformer_text_input=False,
            )
            self.qformer = Blip2QFormerModel(configuration)
        self.image_side = config.vision_config.image_size // config.vision_config.patch_size
        q, w = config.downsample_rate.split("/")
        self.query_side, self.window_side = int(q), int(w)
        # Queries form a square query_side x query_side grid per window, which keeps the
        # output grid square for seamless integration with llava next.
        self.query_length = self.query_side ** 2
        embed_std = 1 / math.sqrt(vision_hidden_size)
        self.norm = nn.LayerNorm(vision_hidden_size, eps=1e-6)
        self.query = nn.Parameter(torch.randn(1, self.query_length, vision_hidden_size) * embed_std)
        # The QFormer has no positional embeddings of its own; these are added to the
        # flattened patches of every window (shared across windows).
        self.image_positions = nn.Parameter(torch.randn(1, self.window_side ** 2, vision_hidden_size) * embed_std)
        self.out_linear = nn.Linear(vision_hidden_size, llm_hidden_size, bias=True)

    def _win(self, x, side, win):
        """
        (B, side*side, C) raster -> (B*n*n, win*win, C) where n=side//win
        windows are raster-ordered, and tokens inside each window are raster-ordered.
        Requires side == n * win.
        """
        B, _, C = x.shape
        n = side // win
        return (
            x.view(B, side, side, C)
            .view(B, n, win, n, win, C)
            .transpose(2, 3)  # (B, n, n, win, win, C)
            .flatten(0, 2)  # (B*n*n, win, win, C)
            .flatten(1, 2)  # (B*n*n, win*win, C)
        )

    def _unwin(self, xw, n, win):
        """
        (B*n*n, win*win, C) -> (B, (n*win)^2, C) raster; inverse of ``_win``.
        """
        Bnn, _, C = xw.shape
        # Internal invariant: the leading dim always comes from _win-style batching.
        assert Bnn % (n * n) == 0
        B = Bnn // (n * n)
        side = n * win
        return (
            xw.view(B, n, n, win, win, C)
            .transpose(2, 3)  # (B, n, win, n, win, C)
            .contiguous()
            .view(B, side, side, C)
            .flatten(1, 2)
        )

    def forward(self, image_features):
        """Project vision patch features through windowed QFormer attention.

        Args:
            image_features: (B, image_side**2, C) raster-ordered patch features.

        Returns:
            (B, new_side**2, llm_hidden_size) where
            new_side = (image_side // window_side) * query_side.

        Raises:
            ValueError: if the sequence length is not image_side**2, or image_side
                is not a multiple of window_side.
        """
        _, num_patches, _ = image_features.shape
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if num_patches != self.image_side * self.image_side:
            raise ValueError(
                f"expected {self.image_side * self.image_side} patches "
                f"({self.image_side}x{self.image_side}), got {num_patches}."
            )
        if self.image_side % self.window_side != 0:
            raise ValueError(
                f"image side {self.image_side} must be divisible by window side {self.window_side}."
            )
        n = self.image_side // self.window_side  # windows per side
        image_features = self.norm(image_features)
        enc = self._win(image_features, self.image_side, self.window_side)  # (B*n^2, w^2, C)
        # Apply downsampling (either spatial offset/quadrant or interpolation)
        downsampled = self.downsampler(image_features)  # (B, new_side^2, C) raster
        new_side = n * self.query_side
        downsampled_w = self._win(downsampled, new_side, self.query_side)  # (B*n^2, q^2, C)
        # Apply QFormer based on the chosen mechanism
        if self.use_simplified_qformer:
            # SimplifiedQFormer: queries cross-attend to the patches of their window.
            # self.query (1, q^2, C) broadcasts over the B*n^2 windows.
            query_embeds = self.dropout(self.query + downsampled_w)
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
            )  # (B*n^2, q^2, C)
        else:
            # Blip2QFormerModel: cross-attention mechanism
            query_embeds = self.query + downsampled_w  # blip already applies dropout to the queries
            encoder_embeds = self.dropout(enc + self.image_positions)
            out_w = self.qformer(
                query_embeds=query_embeds,
                encoder_hidden_states=encoder_embeds,
                return_dict=True,
            ).last_hidden_state  # (B*n^2, q^2, C)
        out = self._unwin(out_w, n=n, win=self.query_side)  # (B, new_side^2, C) raster
        # Apply output dropout before final projection
        out = self.dropout(out)
        return self.out_linear(out)