chatterbox / handler.py
aiplexdeveloper's picture
Update handler.py
4c6d421 verified
Raw
History Blame Contribute Delete
1.68 kB
import torchaudio as ta
from chatterbox.tts import ChatterboxTTS
from typing import Dict, Any, List
import soundfile as sf
import io
import base64
from huggingface_hub import hf_hub_download
class EndpointHandler:
    """Inference-endpoint handler that synthesizes speech with ChatterboxTTS.

    The model is loaded once when the handler is constructed. Each call
    turns a request payload into a base64-encoded WAV file plus its
    duration in seconds.
    """

    # Hub location of the reference voice clip used as the audio prompt.
    VOICE_PROMPT_REPO = "aiplexdeveloper/chatterbox"
    VOICE_PROMPT_FILE = "arjun_das_output_audio.mp3"

    # Local path of the downloaded voice prompt; filled in on first use so
    # later requests do not go back to the Hub for the same constant file.
    _audio_prompt_path = None

    def __init__(self, path: str = ""):
        """Load the pretrained model onto the CUDA device.

        Args:
            path: Not used by this handler; accepted so callers that pass a
                model directory keep working.

        Raises:
            RuntimeError: If the model cannot be loaded. The original
                exception is attached as ``__cause__``.
        """
        try:
            self.model = ChatterboxTTS.from_pretrained(device="cuda")
        except Exception as e:
            raise RuntimeError(f"[ERROR] Failed to load model: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate speech for a single request.

        Args:
            data: Request payload. ``data["inputs"]`` is either the text to
                speak, or a dict with ``text`` (required) and the optional
                ``exaggeration`` (default 0.3) and ``cfg_weight``
                (default 0.5) generation settings.

        Returns:
            A one-element list. On success the element holds
            ``audio_base64`` (the WAV file, base64-encoded) and
            ``audio_length_seconds``. On any failure, including an invalid
            payload, it holds a single ``error`` message instead; this
            method does not raise.
        """
        try:
            text, exaggeration, cfg_weight = self._parse_inputs(data)
            print(exaggeration, cfg_weight)

            wav = self.model.generate(
                text,
                audio_prompt_path=self._resolve_audio_prompt(),
                exaggeration=exaggeration,
                cfg_weight=cfg_weight,
            )

            # Serialize to an in-memory WAV file, then base64-encode it so
            # the audio can travel inside a JSON response.
            buffer = io.BytesIO()
            sf.write(buffer, wav.cpu().numpy().T, self.model.sr, format="WAV")
            audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

            # The last axis is the sample axis (the array is transposed to
            # samples-first for sf.write above), so this stays correct even
            # if a channel dimension is present.
            audio_length_seconds = wav.shape[-1] / self.model.sr

            return [
                {
                    "audio_base64": audio_base64,
                    "audio_length_seconds": audio_length_seconds,
                }
            ]
        except Exception as e:
            # Endpoint boundary: report the failure to the client instead of
            # letting the exception escape the handler.
            print(f"[ERROR] Inference failed: {e}")
            return [{"error": str(e)}]

    @staticmethod
    def _parse_inputs(data):
        """Validate the payload and return ``(text, exaggeration, cfg_weight)``.

        Raises:
            ValueError: If ``inputs`` has an unsupported type or carries no
                non-empty ``text`` string.
        """
        inputs = data.get("inputs", {})
        if isinstance(inputs, str):
            # Accept the plain-string form of "inputs" as the text itself.
            inputs = {"text": inputs}
        elif not isinstance(inputs, dict):
            raise ValueError(
                "'inputs' must be a string or an object with a 'text' field"
            )

        text = inputs.get("text")
        if not isinstance(text, str) or not text.strip():
            raise ValueError("'inputs.text' must be a non-empty string")

        return (
            text,
            inputs.get("exaggeration", 0.3),
            inputs.get("cfg_weight", 0.5),
        )

    def _resolve_audio_prompt(self):
        """Return the local path of the voice prompt, downloading it once."""
        if self._audio_prompt_path is None:
            self._audio_prompt_path = hf_hub_download(
                repo_id=self.VOICE_PROMPT_REPO,
                filename=self.VOICE_PROMPT_FILE,
            )
        return self._audio_prompt_path