# MoST-speech-text-moe / tokenization_most_fast.py
# Yuxuan98's picture
# Initial MoST model upload - tokenization_most_fast.py
# 91f0868 verified
from typing import List, Optional, Union, Dict
from transformers.models.llama import LlamaTokenizerFast
from transformers import AutoTokenizer
from transformers.tokenization_utils_base import AddedToken
import torch
import os
from huggingface_hub import hf_hub_download
# Token strings for the 500 HuBERT units, in index order: "[Hu0]" ... "[Hu499]".
HUBERT_TOKENS = ["[Hu{}]".format(unit) for unit in range(500)]

# Begin/end marker tokens for audio, keyed by the tokenizer attribute name
# each one is registered under.
AUDIO_SPECIAL_TOKENS = dict(
    begin_audio_token="<|begin▁of▁audio|>",
    end_audio_token="<|end▁of▁audio|>",
)
class MoSTTokenizerFast(LlamaTokenizerFast):
    """
    MoST tokenizer extending DeepSeek's tokenizer (which is based on LlamaTokenizerFast)
    with additional support for audio tokens.

    On construction the vocabulary is extended with every entry of ``HUBERT_TOKENS``
    and with the two audio markers from ``AUDIO_SPECIAL_TOKENS``. The resulting
    token ids are cached and exposed through read-only properties.
    """

    SPECIAL_TOKENS_ATTRIBUTES = LlamaTokenizerFast.SPECIAL_TOKENS_ATTRIBUTES + [
        "begin_audio_token",
        "end_audio_token",
    ]

    def __init__(self, *args, **kwargs):
        """Build the base tokenizer, then register the audio vocabulary.

        Args:
            *args: Forwarded unchanged to the base tokenizer's initializer.
            **kwargs: Forwarded unchanged to the base tokenizer's initializer.
        """
        # Initialize special token attributes before the base initializer runs.
        # NOTE(review): presumably the base class reads "_<name>" for each entry
        # of SPECIAL_TOKENS_ATTRIBUTES - confirm against the transformers version in use.
        self._begin_audio_token = AUDIO_SPECIAL_TOKENS["begin_audio_token"]
        self._end_audio_token = AUDIO_SPECIAL_TOKENS["end_audio_token"]

        # Initialize the base tokenizer
        super().__init__(*args, **kwargs)

        # Add HuBERT tokens
        self.add_tokens(HUBERT_TOKENS)

        # Add special audio tokens
        special_tokens_dict = {
            "begin_audio_token": AddedToken(
                AUDIO_SPECIAL_TOKENS["begin_audio_token"],
                lstrip=False, rstrip=False, normalized=True, single_word=False,
            ),
            "end_audio_token": AddedToken(
                AUDIO_SPECIAL_TOKENS["end_audio_token"],
                lstrip=False, rstrip=False, normalized=True, single_word=False,
            ),
        }
        self.add_special_tokens(special_tokens_dict)

        # Cache the token IDs for quick access
        self._hubert_token_ids = {
            token: self.convert_tokens_to_ids(token) for token in HUBERT_TOKENS
        }
        self._begin_audio_token_id = self.convert_tokens_to_ids(
            AUDIO_SPECIAL_TOKENS["begin_audio_token"]
        )
        self._end_audio_token_id = self.convert_tokens_to_ids(
            AUDIO_SPECIAL_TOKENS["end_audio_token"]
        )

    @property
    def hubert_token_ids(self) -> Dict[str, int]:
        """Get the mapping of HuBERT tokens to their IDs."""
        return self._hubert_token_ids

    @property
    def begin_audio_token_id(self) -> int:
        """Get the ID of the begin audio token."""
        return self._begin_audio_token_id

    @property
    def end_audio_token_id(self) -> int:
        """Get the ID of the end audio token."""
        return self._end_audio_token_id

    def convert_ids_to_tokens(
        self, ids: Union[int, List[int]], skip_special_tokens: bool = False
    ) -> Union[str, List[str]]:
        """
        Converts a single index or a sequence of indices in a token or a sequence of tokens.
        Handles both text and audio tokens.

        Args:
            ids: A single token id, or an iterable of values convertible with ``int()``.
            skip_special_tokens: When True, ids found in ``self.all_special_ids`` are
                dropped from the result. Only applied to the sequence form; a single
                ``int`` is always converted.

        Returns:
            One token string for an ``int`` input, otherwise a list of token strings.
            Ids the backend tokenizer has no token for are returned as ``""``.
        """
        if isinstance(ids, int):
            return self._convert_id_to_token(ids)

        # Look the special ids up once, as a set, so the per-id check in the
        # loop is a cheap membership test. Nothing is looked up when not skipping.
        skipped_ids = set(self.all_special_ids) if skip_special_tokens else set()

        tokens = []
        for index in ids:
            index = int(index)
            if index in skipped_ids:
                continue
            tokens.append(self._convert_id_to_token(index))
        return tokens

    def _convert_id_to_token(self, index: int) -> Optional[str]:
        """Convert a single token id to its string representation.

        Returns ``""`` (not ``None``) when the backend tokenizer returns no token
        for ``index``.
        """
        token = self._tokenizer.id_to_token(int(index))
        return token if token is not None else ""

    def encode_audio_sequence(self, hubert_indices: List[int], add_special_tokens: bool = True) -> List[int]:
        """
        Encode a sequence of HuBERT indices into token IDs, optionally adding audio special tokens.

        Args:
            hubert_indices: List of HuBERT indices, each in ``[0, len(HUBERT_TOKENS))``
                (0-499 for the 500 HuBERT tokens).
            add_special_tokens: Whether to add begin/end audio tokens

        Returns:
            List of token IDs

        Raises:
            ValueError: If any index falls outside the valid range.
        """
        # Derive the bound from the token table so range check, error message
        # and vocabulary cannot drift apart.
        num_units = len(HUBERT_TOKENS)
        if not all(0 <= idx < num_units for idx in hubert_indices):
            raise ValueError(f"HuBERT indices must be between 0 and {num_units - 1}")

        # Map index -> token string via HUBERT_TOKENS (single source of the
        # "[Hu<n>]" spelling), then token string -> cached id.
        token_ids = [self._hubert_token_ids[HUBERT_TOKENS[idx]] for idx in hubert_indices]

        if add_special_tokens:
            token_ids = [self._begin_audio_token_id] + token_ids + [self._end_audio_token_id]
        return token_ids
if __name__ == "__main__":
    # Manual smoke test: load (or create) a tokenizer and print round-trip
    # results for text, audio, and mixed sequences.
    try:
        # Prefer a tokenizer saved in the current directory.
        tokenizer = MoSTTokenizerFast.from_pretrained("./")
        print("Loaded MoST tokenizer from local directory")
    except Exception as err:
        # Best-effort fallback: report why loading failed and build a fresh one.
        print(f"Creating new MoST tokenizer: {err}")
        tokenizer = MoSTTokenizerFast()

    print(f"\nVocabulary size: {tokenizer.vocab_size}")

    # Text round trip.
    text = "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"
    text_ids = tokenizer(text)["input_ids"]
    text_round_trip = tokenizer.decode(text_ids)
    print("\nText encoding/decoding test:")
    print(f"Original: {text}")
    print(f"Decoded: {text_round_trip}")

    # Audio round trip.
    hubert_indices = [0, 1, 2, 3, 4]
    audio_ids = tokenizer.encode_audio_sequence(hubert_indices)
    audio_round_trip = tokenizer.decode(audio_ids)
    print("\nAudio encoding/decoding test:")
    print(f"HuBERT indices: {hubert_indices}")
    print(f"Decoded: {audio_round_trip}")

    # Text ids followed by an audio segment, decoded as one sequence.
    mixed_ids = [
        *tokenizer(text)["input_ids"],
        *tokenizer.encode_audio_sequence(hubert_indices),
    ]
    mixed_round_trip = tokenizer.decode(mixed_ids)
    print("\nMixed text/audio test:")
    print(f"Decoded: {mixed_round_trip}")

    # Decode a window of ids in each probed region; ids that decode to
    # nothing but whitespace are not printed.
    boundary_probes = (
        ("text", range(99995, 100005)),
        ("audio", range(100500, 100510)),
    )
    for label, id_window in boundary_probes:
        print(f"\nChecking vocabulary boundary({label}):")
        for token_id in id_window:
            piece = tokenizer.decode([token_id])
            if not piece.strip():
                continue
            print(f"Token {token_id}: '{piece}'")