import torch
from speechbrain.inference.interfaces import Pretrained


class CustomSLUDecoder(Pretrained):
    """An end-to-end SLU model using a HuBERT self-supervised encoder.

    The class can be used either to run only the encoder (encode_batch()) to
    extract features, or to run the entire model (decode_batch() /
    decode_file()) to map the speech to its semantics.

    NOTE(review): the model ``source`` and the sample output in the example
    below have not been verified against this HuBERT-based model -- confirm
    before relying on them.

    Example
    -------
    >>> from speechbrain.inference.interfaces import foreign_class
    >>> slu_model = foreign_class(
    ...     source="speechbrain/slu-timers-and-such-direct-librispeech-asr",
    ...     pymodule_file="custom_interface.py",
    ...     classname="CustomSLUDecoder",
    ... )  # doctest: +SKIP
    >>> slu_model.decode_file("samples/audio_samples/example6.wav")  # doctest: +SKIP
    "{'intent': 'SimpleMath', 'slots': {'number1': 37.67, 'number2': 75.7, 'op': ' minus '}}"
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tokenizer declared in the hyperparameters; used to turn predicted
        # token ids back into text in decode_batch().
        self.tokenizer = self.hparams.tokenizer

    def decode_file(self, path):
        """Maps the given audio file to a string representing the
        semantic dictionary for the utterance.

        Arguments
        ---------
        path : str
            Path to audio file to decode.

        Returns
        -------
        str
            The predicted semantics.
        """
        waveform = self.load_audio(path)
        waveform = waveform.to(self.device)

        # Fake a batch of one utterance. A single item is by definition the
        # longest in its batch, hence a relative length of 1.0.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0], device=self.device)
        predicted_words, _ = self.decode_batch(batch, rel_length)
        return predicted_words[0]

    def encode_batch(self, wavs):
        """Encodes the input audio into a sequence of hidden states.

        Gradients are not disabled here; the input is only detached before
        being passed to the encoder. decode_batch() wraps this call in
        ``torch.no_grad()``.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.

        Returns
        -------
        torch.tensor
            The encoded batch, i.e. the output of the ``hubert`` module.
        """
        wavs = wavs.float()
        wavs = wavs.to(self.device)
        encoder_out = self.mods.hubert(wavs.detach())
        return encoder_out

    def decode_batch(self, wavs, wav_lens):
        """Maps the input audio to its semantics.

        Arguments
        ---------
        wavs : torch.tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model.
        wav_lens : torch.tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        predicted_words : list
            Each waveform in the batch decoded by the tokenizer.
        predicted_tokens
            The token-id sequences as returned by the beam searcher,
            one per waveform.
        """
        with torch.no_grad():
            wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
            encoder_out = self.encode_batch(wavs)
            # Only the hypotheses are needed; the remaining outputs of the
            # beam searcher are discarded.
            predicted_tokens, _, _, _ = self.mods.beam_searcher(
                encoder_out, wav_lens
            )
            predicted_words = [
                self.tokenizer.decode_ids(token_seq)
                for token_seq in predicted_tokens
            ]
        return predicted_words, predicted_tokens

    def forward(self, wavs, wav_lens):
        """Runs full decoding - note: no gradients through decoding.

        Thin wrapper that delegates to decode_batch() and returns its
        (predicted_words, predicted_tokens) pair unchanged.
        """
        return self.decode_batch(wavs, wav_lens)