"""Run Reka Edge 7B on an image or video + text query.

Usage:
    uv run example.py --image photo.jpg
    uv run example.py --image photo.jpg --prompt "What is this?"
    uv run example.py --video media/dashcam.mp4 --prompt "Is this person falling asleep?"
    uv run example.py --image photo.jpg --model /path/to/local/checkpoint
"""
|
|
import argparse

import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
|
|
|
|
def _select_device() -> tuple:
    """Return the best available (device, dtype) pair.

    Preference order: CUDA (fp16) -> Apple MPS (fp16) -> CPU (fp32).
    fp32 is used on CPU because half precision there is slow and poorly
    supported by many ops.
    """
    if torch.cuda.is_available():
        return torch.device("cuda"), torch.float16

    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_built() and mps.is_available():
        # A nominally-available MPS backend can still fail on first use
        # (e.g. unsupported macOS / driver combinations), so probe it with
        # a tiny allocation and fall back to CPU instead of crashing.
        try:
            torch.zeros(1, device="mps")
        except (RuntimeError, OSError):
            pass
        else:
            return torch.device("mps"), torch.float16

    return torch.device("cpu"), torch.float32


def main() -> None:
    """Parse CLI arguments, run one multimodal query, and print the answer."""
    parser = argparse.ArgumentParser(description="Reka Edge 7B inference")
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--image", help="Path to an image file")
    group.add_argument("--video", help="Path to a video file")
    parser.add_argument(
        "--prompt",
        default="Describe what you see in detail.",
        help="Text prompt (default: 'Describe what you see in detail.')",
    )
    parser.add_argument(
        "--model",
        default=".",
        help="Model ID or local path (default: current directory)",
    )
    args = parser.parse_args()

    device, model_dtype = _select_device()

    processor = AutoProcessor.from_pretrained(args.model, trust_remote_code=True)
    model = AutoModelForImageTextToText.from_pretrained(
        args.model,
        trust_remote_code=True,
        torch_dtype=model_dtype,
    ).eval()
    model = model.to(device)

    # Single user turn: the media entry first, then the text prompt.
    if args.video:
        media_entry = {"type": "video", "video": args.video}
    else:
        media_entry = {"type": "image", "image": args.image}
    messages = [
        {
            "role": "user",
            "content": [
                media_entry,
                {"type": "text", "text": args.prompt},
            ],
        }
    ]

    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )

    # Move every tensor to the target device; floating-point tensors (pixel
    # values) also take the model dtype so they match the weights.
    for key, val in inputs.items():
        if isinstance(val, torch.Tensor):
            if val.is_floating_point():
                inputs[key] = val.to(device=device, dtype=model_dtype)
            else:
                inputs[key] = val.to(device=device)

    with torch.inference_mode():
        # Reka emits <sep> as an end-of-turn marker, so stop on it as well
        # as the regular EOS token. Guard against tokenizers that do not
        # know <sep> (convert_tokens_to_ids may return None for unknown
        # tokens); passing None inside eos_token_id would break generate().
        eos_ids = [processor.tokenizer.eos_token_id]
        sep_token_id = processor.tokenizer.convert_tokens_to_ids("<sep>")
        if sep_token_id is not None:
            eos_ids.append(sep_token_id)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=eos_ids,
        )

    # Decode only the newly generated tokens, skipping the prompt echo.
    input_len = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0, input_len:]
    output_text = processor.tokenizer.decode(new_tokens, skip_special_tokens=True)
    # <sep> can survive skip_special_tokens when it is not registered as a
    # special token, so scrub it from the text explicitly.
    output_text = output_text.replace("<sep>", "").strip()
    print(output_text)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|