shethjenil committed (verified)
Commit a705b55 · 1 parent: e63e063

Update modeling_conformer.py

Files changed (1): modeling_conformer.py (+3 −3)
modeling_conformer.py:

@@ -41,7 +41,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
             else:
                 l.conv = nn.Sequential(nn.Conv2d(l.conv.in_channels,l.conv.out_channels,l.conv.kernel_size[0],l.conv.stride,1,groups=l.conv.out_channels),nn.Conv2d(l.conv.in_channels,l.conv.out_channels, 1))
         self.feature_extractor.conv_layers.append(Op(lambda x : x.transpose(1, 2)))
-        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * int(self.calc_length(torch.tensor(80.),repeat_num=config.num_feat_extract_layers)),config.hidden_size)
+        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * self.calc_length(80,repeat_num=config.num_feat_extract_layers),config.hidden_size)
         self.feature_projection.layer_norm = Op(lambda x:x.permute(0, 2, 1, 3).flatten(2))
         for l in self.encoder.layers:
             l.conv_module.glu = nn.Sequential(l.conv_module.glu,self.mask_layer)
@@ -61,10 +61,10 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
         return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
 
-    def calc_length(self,lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
+    def calc_length(self, lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
         add_pad = all_paddings - kernel_size
         for _ in range(repeat_num):
-            lengths = torch.floor((lengths.float() + add_pad) / stride + 1)
+            lengths = (lengths + add_pad) // stride + 1
         return lengths
 
     def preprocessing(self, x):
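
The net effect of the commit: calc_length now does the strided-convolution length arithmetic in plain Python integers ((lengths + add_pad) // stride + 1) instead of building a float tensor and calling torch.floor, so the call site can pass 80 (presumably the number of mel feature bins) directly and drop the torch.tensor(80.) / int(...) wrapping when sizing the projection layer. A minimal standalone sketch (not part of the commit; the repeat_num values and the input-length sweep are illustrative assumptions) checking that the two forms agree for positive integer lengths:

import math

def calc_length_old(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
    # Pre-commit form, with math.floor standing in for torch.floor on a 0-d tensor.
    add_pad = all_paddings - kernel_size
    for _ in range(repeat_num):
        lengths = math.floor((float(lengths) + add_pad) / stride + 1)
    return lengths

def calc_length_new(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
    # Post-commit form: the same floor, expressed as integer floor division.
    add_pad = all_paddings - kernel_size
    for _ in range(repeat_num):
        lengths = (lengths + add_pad) // stride + 1
    return lengths

# Illustrative sweep (values are assumptions, not from the commit):
# floor(a/2 + 1) == a//2 + 1 for any integer a, so both forms track the
# same sequence of intermediate lengths.
for n in range(1, 512):
    for reps in range(1, 5):
        assert calc_length_old(n, repeat_num=reps) == calc_length_new(n, repeat_num=reps)

# At the call site, the 80-bin input shrinks by roughly 2x per feature-extractor
# layer (e.g. 80 -> 40 -> 20 if num_feat_extract_layers were 2), and the result
# is now a plain int that nn.Linear's in_features accepts without an int(...) cast.
print(calc_length_new(80, repeat_num=2))  # 20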