| from typing import List, Optional, Union, Dict |
| from transformers.models.llama import LlamaTokenizerFast |
| from transformers import AutoTokenizer |
| from transformers.tokenization_utils_base import AddedToken |
| import torch |
| import os |
| from huggingface_hub import hf_hub_download |
|
|
# The 500 discrete HuBERT acoustic-unit tokens: "[Hu0]" .. "[Hu499]".
HUBERT_TOKENS = ["[Hu{}]".format(unit) for unit in range(500)]

# Marker tokens that delimit an audio-token span inside a mixed sequence.
AUDIO_SPECIAL_TOKENS = dict(
    begin_audio_token="<|begin▁of▁audio|>",
    end_audio_token="<|end▁of▁audio|>",
)
|
|
class MoSTTokenizerFast(LlamaTokenizerFast):
    """
    MoST tokenizer extending DeepSeek's tokenizer (which is based on LlamaTokenizerFast)
    with additional support for audio tokens: 500 HuBERT unit tokens plus
    begin/end audio marker tokens.
    """

    # Register the audio markers as first-class special-token attributes so the
    # base class treats them like bos/eos during (de)serialization.
    SPECIAL_TOKENS_ATTRIBUTES = LlamaTokenizerFast.SPECIAL_TOKENS_ATTRIBUTES + [
        "begin_audio_token",
        "end_audio_token",
    ]

    def __init__(self, *args, **kwargs):
        # Set before super().__init__ so the base-class special-token machinery
        # can see these attributes while it initializes.
        self._begin_audio_token = AUDIO_SPECIAL_TOKENS["begin_audio_token"]
        self._end_audio_token = AUDIO_SPECIAL_TOKENS["end_audio_token"]

        super().__init__(*args, **kwargs)

        # Extend the vocabulary with the HuBERT unit tokens.
        self.add_tokens(HUBERT_TOKENS)

        # Register the audio markers under the custom attribute names declared
        # in SPECIAL_TOKENS_ATTRIBUTES above.
        special_tokens_dict = {
            "begin_audio_token": AddedToken(
                AUDIO_SPECIAL_TOKENS["begin_audio_token"],
                lstrip=False, rstrip=False, normalized=True, single_word=False,
            ),
            "end_audio_token": AddedToken(
                AUDIO_SPECIAL_TOKENS["end_audio_token"],
                lstrip=False, rstrip=False, normalized=True, single_word=False,
            ),
        }
        self.add_special_tokens(special_tokens_dict)

        # Cache token -> id lookups once so encode_audio_sequence is a pure
        # dict lookup per unit.
        self._hubert_token_ids = {token: self.convert_tokens_to_ids(token) for token in HUBERT_TOKENS}
        self._begin_audio_token_id = self.convert_tokens_to_ids(AUDIO_SPECIAL_TOKENS["begin_audio_token"])
        self._end_audio_token_id = self.convert_tokens_to_ids(AUDIO_SPECIAL_TOKENS["end_audio_token"])

    @property
    def hubert_token_ids(self) -> Dict[str, int]:
        """Mapping of HuBERT token strings (e.g. "[Hu0]") to their vocabulary ids."""
        return self._hubert_token_ids

    @property
    def begin_audio_token_id(self) -> int:
        """Vocabulary id of the begin-audio marker token."""
        return self._begin_audio_token_id

    @property
    def end_audio_token_id(self) -> int:
        """Vocabulary id of the end-audio marker token."""
        return self._end_audio_token_id

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Converts a single index or a sequence of indices in a token or a sequence of tokens.
        Handles both text and audio tokens; ids unknown to the backend map to "".

        Args:
            ids: A single token id or a list of token ids.
            skip_special_tokens: If True, ids in ``all_special_ids`` are dropped.

        Returns:
            A token string for a single id, otherwise a list of token strings.
        """
        if isinstance(ids, int):
            return self._convert_id_to_token(ids)

        tokens = []
        for index in ids:
            index = int(index)  # accept tensor/numpy scalar ids as well
            if skip_special_tokens and index in self.all_special_ids:
                continue
            token = self._tokenizer.id_to_token(index)
            tokens.append(token if token is not None else "")
        return tokens

    def _convert_id_to_token(self, index: int) -> Optional[str]:
        """Convert a single token id to its string representation ("" if unknown)."""
        token = self._tokenizer.id_to_token(int(index))
        return token if token is not None else ""

    def encode_audio_sequence(self, hubert_indices: List[int], add_special_tokens: bool = True) -> List[int]:
        """
        Encode a sequence of HuBERT indices into token IDs, optionally adding audio special tokens.

        Args:
            hubert_indices: List of HuBERT indices (0 to ``len(HUBERT_TOKENS) - 1``)
            add_special_tokens: Whether to add begin/end audio tokens

        Returns:
            List of token IDs

        Raises:
            ValueError: If any index falls outside the valid HuBERT range.
        """
        # Derive the bound from the token table instead of hard-coding 500 so
        # validation cannot drift from the actual vocabulary.
        num_units = len(HUBERT_TOKENS)
        if not all(0 <= idx < num_units for idx in hubert_indices):
            raise ValueError(f"HuBERT indices must be between 0 and {num_units - 1}")

        token_ids = [self._hubert_token_ids[f"[Hu{idx}]"] for idx in hubert_indices]

        if add_special_tokens:
            token_ids = [self._begin_audio_token_id] + token_ids + [self._end_audio_token_id]

        return token_ids
| |
|
|
if __name__ == "__main__":
    # Smoke test: prefer a locally saved tokenizer, fall back to a fresh instance.
    try:
        tokenizer = MoSTTokenizerFast.from_pretrained("./")
        print("Loaded MoST tokenizer from local directory")
    except Exception as e:
        print(f"Creating new MoST tokenizer: {e}")
        tokenizer = MoSTTokenizerFast()

    # Fix: for fast tokenizers, .vocab_size excludes added tokens, so the
    # 500 HuBERT tokens + 2 audio markers were not reflected in the old
    # report; len(tokenizer) gives the full vocabulary size.
    print(f"\nVocabulary size: {tokenizer.vocab_size} (with added tokens: {len(tokenizer)})")

    # Round-trip plain text.
    text = "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"
    encoded = tokenizer(text)
    decoded = tokenizer.decode(encoded["input_ids"])
    print(f"\nText encoding/decoding test:")
    print(f"Original: {text}")
    print(f"Decoded: {decoded}")

    # Round-trip an audio (HuBERT) sequence with begin/end markers.
    hubert_indices = [0, 1, 2, 3, 4]
    audio_tokens = tokenizer.encode_audio_sequence(hubert_indices)
    decoded_audio = tokenizer.decode(audio_tokens)
    print(f"\nAudio encoding/decoding test:")
    print(f"HuBERT indices: {hubert_indices}")
    print(f"Decoded: {decoded_audio}")

    # Mixed text + audio sequence.
    mixed_tokens = tokenizer(text)["input_ids"] + tokenizer.encode_audio_sequence(hubert_indices)
    decoded_mixed = tokenizer.decode(mixed_tokens)
    print(f"\nMixed text/audio test:")
    print(f"Decoded: {decoded_mixed}")

    # Probe ids around the expected end of the base text vocabulary.
    print("\nChecking vocabulary boundary(text):")
    for i in range(99995, 100005):
        decoded = tokenizer.decode([i])
        if decoded.strip():
            print(f"Token {i}: '{decoded}'")

    # Probe ids just past the HuBERT token range.
    print("\nChecking vocabulary boundary(audio):")
    for i in range(100500, 100510):
        decoded = tokenizer.decode([i])
        if decoded.strip():
            print(f"Token {i}: '{decoded}'")
|