from functools import partial
from typing import Any, Dict, List, Optional

import torch
from torch import nn

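# NOTE: `parent` is the owning VLM module. These encoders do not embed media
# themselves; they rely on the parent exposing `tokenizer`, `device`,
# `llm.model.embed_tokens(...)`, and `encode_images(...)`, and only wrap the
# resulting features with optional start/end token embeddings.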
class BaseEncoder(nn.Module):
    """Common base for media encoders that defer the heavy lifting to a parent VLM."""

    def __init__(self, parent: nn.Module) -> None:
        super().__init__()
        # Keep the parent in a plain list so it is not registered as a
        # submodule, which would otherwise create a cycle in the module tree
        # (the parent already owns this encoder).
        self._parent = [parent]

    @property
    def parent(self) -> nn.Module:
        return self._parent[0]


class BasicImageEncoder(BaseEncoder):
    """Encodes a batch of images and brackets each image's feature sequence
    with optional start/end token embeddings (a trailing newline by default)."""

    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        if tokens is None:
            return None
        # Tokenize the literal string and look up its embeddings in the LLM's
        # input embedding table.
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Prepend/append the special-token embeddings along the sequence axis.
        if start_token_embeds is not None:
            features = torch.cat([start_token_embeds, features], dim=0)
        if end_token_embeds is not None:
            features = torch.cat([features, end_token_embeds], dim=0)
        return features

    def forward(self, images: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        images = torch.stack(images, dim=0)
        features = self.parent.encode_images(images, block_sizes=config.get("block_sizes"))
        # Embed the start/end tokens once and reuse them for every image.
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]


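# Example usage of BasicImageEncoder (a sketch; `model` stands for a parent
# VLM exposing the interface noted above, and the per-image token count T and
# embedding width D are illustrative):
#
#     encoder = BasicImageEncoder(model, end_tokens="\n")
#     images = [torch.rand(3, 384, 384) for _ in range(2)]
#     embeds = encoder(images, config={})
#     # -> two tensors of shape [T + n, D], where n is the number of token
#     #    ids the tokenizer assigns to "\n"
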
class BasicVideoEncoder(BaseEncoder):
    """Encodes videos frame by frame and flattens the per-frame feature
    sequences (each bracketed by optional start/end tokens) into a single
    sequence per video."""

    def __init__(
        self,
        parent: nn.Module,
        start_tokens: Optional[str] = None,
        end_tokens: Optional[str] = "\n",
    ) -> None:
        super().__init__(parent)
        self.start_tokens = start_tokens
        self.end_tokens = end_tokens

    def embed_tokens(self, tokens: Optional[str]) -> Optional[torch.Tensor]:
        if tokens is None:
            return None
        token_ids = self.parent.tokenizer(tokens).input_ids
        token_ids = torch.tensor(token_ids, device=self.parent.device)
        return self.parent.llm.model.embed_tokens(token_ids)

    def _process_features(
        self,
        features: torch.Tensor,
        start_token_embeds: Optional[torch.Tensor],
        end_token_embeds: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # `features` is [frames, tokens, dim]; replicate the start/end token
        # embeddings for every frame before concatenating along the token axis.
        if start_token_embeds is not None:
            start_embeds = torch.stack([start_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([start_embeds, features], dim=1)
        if end_token_embeds is not None:
            end_embeds = torch.stack([end_token_embeds] * features.shape[0], dim=0)
            features = torch.cat([features, end_embeds], dim=1)
        # Merge the frame and token axes into one token sequence per video.
        return features.flatten(0, 1)

    def forward(self, videos: List[torch.Tensor], config: Dict[str, Any]) -> List[torch.Tensor]:
        # Encode all frames of all videos in one batch, then split the
        # features back into per-video chunks.
        num_frames = [video.shape[0] for video in videos]
        images = torch.cat(videos, dim=0)
        features = self.parent.encode_images(images)
        features = torch.split(features, num_frames)
        process_features = partial(
            self._process_features,
            start_token_embeds=self.embed_tokens(self.start_tokens),
            end_token_embeds=self.embed_tokens(self.end_tokens),
        )
        return [process_features(f) for f in features]


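if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the library API): the stub
    # parent below is an assumption that mimics the interface these encoders
    # rely on; a real parent is a full VLM.
    from types import SimpleNamespace

    class _StubParent(nn.Module):
        def __init__(self, dim: int = 16, tokens_per_image: int = 4) -> None:
            super().__init__()
            self.llm = nn.Module()
            self.llm.model = nn.Module()
            self.llm.model.embed_tokens = nn.Embedding(8, dim)
            self.dim = dim
            self.tokens_per_image = tokens_per_image

        def tokenizer(self, text: str) -> SimpleNamespace:
            # Map every character to token id 0; enough to exercise the path.
            return SimpleNamespace(input_ids=[0] * len(text))

        @property
        def device(self) -> torch.device:
            return self.llm.model.embed_tokens.weight.device

        def encode_images(self, images: torch.Tensor, block_sizes=None) -> torch.Tensor:
            # [N, C, H, W] -> [N, tokens_per_image, dim]
            return torch.zeros(images.shape[0], self.tokens_per_image, self.dim)

    parent = _StubParent()
    image_encoder = BasicImageEncoder(parent)
    video_encoder = BasicVideoEncoder(parent)

    images = [torch.rand(3, 8, 8) for _ in range(2)]
    videos = [torch.rand(5, 3, 8, 8), torch.rand(3, 3, 8, 8)]

    # Each image yields [tokens_per_image + 1, dim] ("\n" is one stub token);
    # each video yields [frames * (tokens_per_image + 1), dim].
    print([f.shape for f in image_encoder(images, config={})])
    print([f.shape for f in video_encoder(videos, config={})])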