Update modeling_conformer.py
modeling_conformer.py (+3 -3)
```diff
@@ -41,7 +41,7 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
             else:
                 l.conv = nn.Sequential(nn.Conv2d(l.conv.in_channels,l.conv.out_channels,l.conv.kernel_size[0],l.conv.stride,1,groups=l.conv.out_channels),nn.Conv2d(l.conv.in_channels,l.conv.out_channels, 1))
         self.feature_extractor.conv_layers.append(Op(lambda x : x.transpose(1, 2)))
-        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] *
+        self.feature_projection.projection = nn.Linear(config.conv_dim[-1] * self.calc_length(80,repeat_num=config.num_feat_extract_layers),config.hidden_size)
         self.feature_projection.layer_norm = Op(lambda x:x.permute(0, 2, 1, 3).flatten(2))
         for l in self.encoder.layers:
             l.conv_module.glu = nn.Sequential(l.conv_module.glu,self.mask_layer)
@@ -61,10 +61,10 @@ class Wav2Vec2ConformerRNNT(Wav2Vec2ConformerModel):
         self.mask_layer.cache_pad_mask = (torch.arange(hidden_states.size(1), device=hidden_states.device).unsqueeze(0) >= self.cache_length.unsqueeze(1))
         return super()._mask_hidden_states(hidden_states, mask_time_indices, attention_mask)
 
-    def calc_length(self,lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
+    def calc_length(self, lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
         add_pad = all_paddings - kernel_size
         for _ in range(repeat_num):
-            lengths =
+            lengths = (lengths + add_pad) // stride + 1
         return lengths
 
     def preprocessing(self, x):
```
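The `l.conv` replacement in the first hunk (line 42) swaps each full convolution in the feature extractor for a depthwise-plus-pointwise pair. It groups by `l.conv.out_channels`; in the `else:` branch the channel counts presumably match, in which case this is equivalent to the textbook depthwise-separable grouping by `in_channels`. A minimal standalone sketch of that factorization, with hypothetical shapes not taken from the model config:

```python
import torch
import torch.nn as nn

# Hypothetical layer hyperparameters; the diff reads these off the original l.conv.
in_ch, out_ch, k, stride = 64, 128, 3, 2

separable = nn.Sequential(
    # Depthwise: one k x k filter per input channel (groups == in_channels).
    nn.Conv2d(in_ch, in_ch, kernel_size=k, stride=stride, padding=1, groups=in_ch),
    # Pointwise 1x1: mixes channels and maps in_ch -> out_ch.
    nn.Conv2d(in_ch, out_ch, kernel_size=1),
)

x = torch.randn(2, in_ch, 80, 80)
print(separable(x).shape)  # torch.Size([2, 128, 40, 40])
```

For these shapes the pair holds 64·9 + 64·128 = 8,768 weights versus 64·128·9 = 73,728 for the dense convolution it replaces.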
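The completed loop body in `calc_length` (line 67) is the standard convolution output-length recurrence, L_out = floor((L_in + total_padding - kernel_size) / stride) + 1, applied once per feature-extractor layer; `all_paddings` is the total padding (2·1 for symmetric padding of 1). Line 44 then sizes the projection as `conv_dim[-1]` times the frequency bins left after downsampling, presumably from an 80-bin input. A self-contained sanity check of the formula against a real conv stack:

```python
import torch
import torch.nn as nn

def calc_length(lengths, all_paddings=2, kernel_size=3, stride=2, repeat_num=1):
    # L_out = (L_in + total_padding - kernel_size) // stride + 1, once per layer.
    add_pad = all_paddings - kernel_size
    for _ in range(repeat_num):
        lengths = (lengths + add_pad) // stride + 1
    return lengths

# Four stacked stride-2 convs with kernel 3 and padding 1 (hypothetical channel count).
convs = nn.Sequential(*[nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=1) for _ in range(4)])
x = torch.randn(1, 1, 80)
print(convs(x).shape[-1])             # 5
print(calc_length(80, repeat_num=4))  # 5  -> 80 -> 40 -> 20 -> 10 -> 5
```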
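The `cache_pad_mask` assignment (line 61) builds a boolean padding mask by broadcasting a `(1, T)` position index against `(B, 1)` sequence lengths, so entry `[b, t]` is True exactly at padded positions where `t >= length[b]`. A minimal illustration with made-up lengths:

```python
import torch

cache_length = torch.tensor([6, 4, 2])  # hypothetical valid lengths per sequence
T = 6                                   # padded time dimension

pad_mask = torch.arange(T).unsqueeze(0) >= cache_length.unsqueeze(1)
print(pad_mask)
# tensor([[False, False, False, False, False, False],
#         [False, False, False, False,  True,  True],
#         [False, False,  True,  True,  True,  True]])
```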