Image-to-Text · English
saad1926q committed · Commit 2b67222 · verified · 1 Parent(s): f120b40

Create model.py

Files changed (1): model.py (+49, -0)
model.py ADDED
import torch
from typing import Optional
from transformers import GPT2LMHeadModel


class MLP(torch.nn.Module):
    def __init__(self, prefix_size, intermediate_size, out_size):
        super().__init__()
        self.proj1 = torch.nn.Linear(prefix_size, intermediate_size, bias=True)
        self.proj2 = torch.nn.Linear(intermediate_size, out_size, bias=True)

    def forward(self, X):
        z = self.proj1(X)
        z = torch.tanh(z)  # torch.nn.functional.tanh is deprecated; torch.tanh is the current API
        z = self.proj2(z)
        return z


class ClipCapModel(torch.nn.Module):
    def __init__(self, prefix_length: int, clip_length: Optional[int] = None,
                 prefix_size: int = 512, num_layers: int = 8):
        super().__init__()
        # clip_length and num_layers are accepted for API compatibility but are
        # unused by this MLP-based mapping network.
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        self.mapping = MLP(prefix_size=prefix_size,
                           intermediate_size=(self.gpt_embedding_size * prefix_length) // 2,
                           out_size=self.gpt_embedding_size * prefix_length)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor,
                mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None):
        # Word token embedding layer: each token ID becomes an embedding vector.
        text_embeddings = self.gpt.transformer.wte(tokens)
        # Map the image prefix from (batch_size, prefix_length * gpt_embedding_size)
        # to (batch_size, prefix_length, gpt_embedding_size).
        prefix_mapped = self.mapping(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)

        embeddings = torch.cat((prefix_mapped, text_embeddings), dim=1)

        batch_size = tokens.shape[0]

        # For training, GPT-2 needs a label for every input position.
        if labels is not None:
            # Insert dummy tokens (zeros) as labels for the prefix positions, since
            # there is no ground-truth text corresponding to them. The dtype and
            # device must match `tokens`, otherwise torch.cat raises an error.
            # (Using -100 instead would make GPT-2's internal loss ignore the
            # prefix positions entirely.)
            dummy_tokens = torch.zeros(batch_size, self.prefix_length,
                                       dtype=torch.long, device=tokens.device)
            labels = torch.cat((dummy_tokens, tokens), dim=1)

        out = self.gpt(inputs_embeds=embeddings, labels=labels, attention_mask=mask)

        return out
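
A quick shape check can confirm the pieces fit together. The snippet below is an illustrative smoke test, not part of the commit: the batch size, caption length, and prefix_length are arbitrary, and random tensors stand in for a CLIP image feature (assumed 512-dimensional, matching the default prefix_size) and a tokenized caption.

import torch

prefix_length = 10
model = ClipCapModel(prefix_length=prefix_length)

batch_size, caption_len = 2, 20
tokens = torch.randint(0, 50257, (batch_size, caption_len))  # GPT-2's vocabulary size is 50257
prefix = torch.randn(batch_size, 512)                        # stand-in for a CLIP image embedding
# The attention mask must cover the prefix positions as well as the caption tokens.
mask = torch.ones(batch_size, prefix_length + caption_len)

out = model(tokens, prefix, mask=mask, labels=tokens)
print(out.logits.shape)  # torch.Size([2, 30, 50257]) -- (batch, prefix + caption, vocab)
print(out.loss)          # scalar language-modeling loss

Note that the value passed as labels acts only as a flag here: forward discards it and rebuilds the label sequence from tokens, prepending zeros for the prefix positions.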