"""XTTS v2 text-to-speech HTTP server.

Serves voice-cloned synthesis over FastAPI. Reference speaker clips live in
/voices (<name>.wav or <name>.mp3); per-voice conditioning latents are cached
in /cache, keyed by the SHA-256 of the reference WAV so they are recomputed
only when the reference audio actually changes.
"""

import hashlib
import io
import os
import pickle
import subprocess
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# FIX for PyTorch >= 2.6: torch.load defaults to weights_only=True, which
# rejects arbitrary classes in checkpoints.  The XTTS checkpoint embeds these
# config objects, so they must be allow-listed before the model is loaded.
from torch.serialization import add_safe_globals

from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

VOICE_DIR = Path("/voices")  # reference speaker clips (<name>.wav / <name>.mp3)
CACHE_DIR = Path("/cache")   # pickled conditioning latents, one .pkl per voice
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}


def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*, read in chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(8192):
            h.update(chunk)
    return h.hexdigest()


def convert_to_wav(src, dst):
    """Transcode *src* to a 22.05 kHz mono WAV at *dst* using ffmpeg.

    Raises subprocess.CalledProcessError if ffmpeg exits non-zero.
    """
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i", str(src),
            "-ar", "22050",
            "-ac", "1",
            str(dst),
        ],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def ensure_wav(voice_name):
    """Return the path to a WAV reference clip for *voice_name*.

    Converts <name>.mp3 to <name>.wav when no WAV exists, and re-converts
    when the MP3 is newer than the WAV (i.e. the upload was replaced).

    Raises:
        HTTPException(404): neither a .wav nor an .mp3 exists for the voice.
    """
    wav = VOICE_DIR / f"{voice_name}.wav"
    mp3 = VOICE_DIR / f"{voice_name}.mp3"
    if wav.exists():
        # A newer MP3 means the previously converted WAV is stale.
        if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3, wav)
        return wav
    if mp3.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3, wav)
        return wav
    raise HTTPException(404, f"Voice '{voice_name}' not found")


def load_cached_embedding(cache_file):
    """Load a cached ``{"hash": ..., "data": ...}`` record from disk.

    NOTE(security): pickle is acceptable here only because /cache is written
    exclusively by this process; never point this at untrusted files.
    """
    with open(cache_file, "rb") as f:
        return pickle.load(f)


def save_cached_embedding(cache_file, data):
    """Persist a ``{"hash": ..., "data": ...}`` record to disk via pickle."""
    with open(cache_file, "wb") as f:
        pickle.dump(data, f)


def get_embedding(voice_name):
    """Return ``(gpt_cond_latent, speaker_embedding)`` for *voice_name*.

    Lookup order: in-process dict → on-disk pickle (validated against the
    reference WAV's SHA-256) → fresh computation with the XTTS model.

    Raises:
        HTTPException(404): the voice does not exist (via ensure_wav).
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # ensure_wav performs the wav/mp3 lookup and raises 404 itself; the
    # original's separate existence scan here was redundant.
    wav_file = ensure_wav(voice_name)

    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        cached = load_cached_embedding(cache_file)
        # Only trust the disk cache while the reference audio is unchanged.
        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    save_cached_embedding(
        cache_file,
        {"hash": file_hash, "data": data},
    )
    embedding_cache[voice_name] = data
    return data


@app.get("/")
def root():
    """Liveness probe."""
    return {"status": "XTTS server running"}


@app.get("/voices")
def list_voices():
    """List available voice names.

    BUG FIX: the original appended one entry per matching file, so a voice
    with both a .wav and an .mp3 appeared twice.  Collect stems into a set
    and return them sorted for a deterministic, deduplicated listing.
    """
    names = {
        f.stem for f in VOICE_DIR.iterdir() if f.suffix in (".wav", ".mp3")
    }
    return {"voices": sorted(names)}


@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* with the given voice and language; stream a WAV back.

    The 24000 Hz write rate is the XTTS output rate (independent of the
    22050 Hz reference-clip rate used for conditioning).
    """
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    out = tts.synthesizer.tts_model.inference(
        text,
        lang,
        gpt_cond_latent,
        speaker_embedding,
    )

    # XTTS returns a dict; "wav" holds the mono sample array.
    wav = out["wav"]
    # soundfile expects (samples, channels); promote 1-D audio to a column.
    if len(wav.shape) == 1:
        wav = np.expand_dims(wav, 1)

    buf = io.BytesIO()
    sf.write(buf, wav, 24000, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")