| | import glob |
| | import os |
| | from collections import defaultdict |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | import cv2 |
| | import numpy as np |
| | import PIL |
| | import PIL.Image |
| | import requests |
| | from transformers import PretrainedConfig |
| |
|
| | |
| | |
| | |
| | |
| |
|
# Placeholder strings written into the prompt text where a media item appears,
# keyed by media kind.
MEDIA_TOKENS: Dict[str, str] = {
    "image": "<image>",
    "video": "<vila/video>",
}
| |
|
| |
|
class Media:
    """Marker base class for media items that can appear in a prompt."""
| |
|
| |
|
class File(Media):
    """A media item identified by a path string.

    Args:
        path: Location of the media; stored unchanged on ``self.path``.
    """

    def __init__(self, path: str) -> None:
        # No validation or I/O happens here; the path is only recorded.
        self.path = path
| |
|
| |
|
class Image(File):
    """A ``File`` whose path refers to an image."""
| |
|
| |
|
class Video(File):
    """A ``File`` whose path refers to a video."""
| |
|
| |
|
def make_list(obj: Any) -> List:
    """Return ``obj`` itself if it is a list, otherwise wrap it in a new list.

    Only ``list`` instances are passed through; tuples, strings and other
    iterables are treated as single values and wrapped.
    """
    if isinstance(obj, list):
        return obj
    return [obj]
| |
|
| |
|
def _extract_image(
    image: Union[Image, PIL.Image.Image],
    *,
    timeout: float = 30.0,
) -> PIL.Image.Image:
    """Return a PIL image for ``image``.

    Args:
        image: Either an ``Image`` wrapper (local path or http(s) URL) or an
            object that is already a PIL image.
        timeout: Seconds passed to ``requests.get`` when ``image.path`` is a
            URL, so an unresponsive server cannot block forever.

    Returns:
        The opened PIL image. Anything that is not an ``Image`` wrapper is
        returned unchanged.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if not isinstance(image, Image):
        return image

    path = image.path
    if path.startswith(("http://", "https://")):
        from io import BytesIO

        response = requests.get(path, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        # ``response.content`` is fully read and content-decoded, so the
        # connection is released and PIL decodes from an in-memory buffer
        # rather than a raw, never-closed socket stream.
        return PIL.Image.open(BytesIO(response.content))
    return PIL.Image.open(path)
| |
|
| |
|
| | def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]: |
| | |
| | if os.path.isdir(video_path): |
| | frame_paths = sorted(glob.glob(os.path.join(video_path, "*"))) |
| | indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int) |
| | return [PIL.Image.open(frame_paths[index]) for index in indices] |
| |
|
| | |
| | vidcap = cv2.VideoCapture(video_path) |
| |
|
| | |
| | frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) |
| | while frame_count > 0: |
| | vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1) |
| | if vidcap.grab(): |
| | break |
| | frame_count -= 1 |
| | else: |
| | raise ValueError(f"Video '{video_path}' has no frames.") |
| |
|
| | |
| | indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int) |
| | frames = {} |
| | for index in indices: |
| | if index in frames: |
| | continue |
| | vidcap.set(cv2.CAP_PROP_POS_FRAMES, index) |
| | success, frame = vidcap.read() |
| | if not success: |
| | print(f"Failed to read frame {index} from video '{video_path}'. Skipped.") |
| | continue |
| | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
| | frames[index] = PIL.Image.fromarray(frame) |
| | return [frames[index] for index in indices if index in frames] |
| |
|
| |
|
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
    """Load the frames of ``video`` as PIL images.

    Args:
        video: Video wrapper; its ``path`` is handed to ``_load_video``.
        config: Model config. ``num_video_frames`` is required and sets how
            many frames are sampled. ``fps`` is optional.

    Returns:
        The list of frames produced by ``_load_video``.
    """
    num_frames = config.num_video_frames
    # ``fps`` may be absent from the config; a missing, ``None`` or zero value
    # means no FPS-based sampling was requested, so there is nothing to warn
    # about. (A bare ``getattr(config, "fps")`` raises AttributeError when the
    # attribute is missing.)
    if getattr(config, "fps", 0):
        print("Extracting frames from video with specified FPS is not supported yet. Ignored.")

    return _load_video(video.path, num_frames=num_frames)
| |
|
| |
|
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Pull media out of ``messages`` and flatten each message to text.

    Each message's ``"value"`` may be a single part or a list of parts. String
    parts are concatenated; image and video parts are replaced in the text by
    the matching ``MEDIA_TOKENS`` placeholder and collected separately.

    NOTE: ``messages`` is modified in place: every message's ``"value"`` is
    overwritten with the flattened text.

    Args:
        messages: Messages whose ``"value"`` holds ``str``, ``Image``,
            ``PIL.Image.Image`` or ``Video`` parts.
        config: Model config used for video frame extraction. Required when a
            ``Video`` part is present and ``draft`` is false.
        draft: If true, media parts are collected as given, without loading
            images or decoding videos.

    Returns:
        A ``defaultdict(list)`` with the collected items under ``"image"``
        and ``"video"``. For videos with ``draft=False`` each item is a list
        of frames.

    Raises:
        ValueError: On an unsupported part type, or when a video must be
            extracted but ``config`` is ``None``.
    """
    media = defaultdict(list)
    for message in messages:
        text = ""
        for part in make_list(message["value"]):
            if isinstance(part, str):
                # Literal media tokens typed by the user would be confused
                # with the placeholders inserted below, so drop them.
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                text += part
            elif isinstance(part, (Image, PIL.Image.Image)):
                media["image"].append(part if draft else _extract_image(part))
                text += MEDIA_TOKENS["image"]
            elif isinstance(part, Video):
                if draft:
                    media["video"].append(part)
                else:
                    if config is None:
                        # ``config`` defaults to None but frame extraction
                        # reads attributes from it; fail with a clear message
                        # instead of an AttributeError on NoneType.
                        raise ValueError("A config is required to extract video frames when draft=False.")
                    media["video"].append(_extract_video(part, config))
                text += MEDIA_TOKENS["video"]
            else:
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = text
    return media
| |
|