"""XTTS v2 FastAPI server: voice-cloned TTS with embedding caching.

Voices live in /voices as <name>.wav or <name>.mp3; computed speaker
embeddings are cached in /cache as pickles keyed by the WAV's SHA-256.
"""

import hashlib
import io
import os
import pickle
import re
import subprocess
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# FIX for PyTorch >= 2.6: torch.load defaults to weights_only=True, so the
# XTTS config classes unpickled from the checkpoint must be allow-listed.
from torch.serialization import add_safe_globals

import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts
from TTS.tts.configs.xtts_config import XttsConfig

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])
# NOTE(review): some TTS/torch version combinations also require XttsArgs and
# BaseDatasetConfig in the safe-globals list — confirm against the deployed
# versions if model loading raises an UnpicklingError.

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Set CPU threads BEFORE any torch operations - do this at module level.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))


def chunk_text(text, max_len=150):
    """Split *text* into chunks of roughly at most *max_len* characters.

    Splits on sentence boundaries (after ., ! or ?) and greedily packs
    sentences into chunks.  A single sentence longer than *max_len* is kept
    as one oversized chunk rather than split mid-sentence.  Empty chunks are
    dropped, so empty/whitespace input yields [] (the original returned
    [""] for empty input, which would be sent to inference).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current = ""
    for sentence in sentences:
        if current and len(current) + 1 + len(sentence) > max_len:
            chunks.append(current)
            current = sentence
        elif current:
            current += " " + sentence
        else:
            current = sentence
    if current:
        chunks.append(current)
    return [c.strip() for c in chunks if c.strip()]


VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}


def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        while chunk := f.read(8192):
            digest.update(chunk)
    return digest.hexdigest()


def ensure_wav(voice_name):
    """Return the path to a WAV file for *voice_name*, converting if needed.

    If both an MP3 and a WAV exist and the MP3 is newer, the WAV is
    regenerated so edits to the MP3 take effect.

    Raises:
        HTTPException(404): no .wav or .mp3 exists for this voice.
    """
    wav = VOICE_DIR / f"{voice_name}.wav"
    mp3 = VOICE_DIR / f"{voice_name}.mp3"
    if wav.exists():
        if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3, wav)
        return wav
    if mp3.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3, wav)
        return wav
    raise HTTPException(404, f"Voice '{voice_name}' not found")


def convert_to_wav(src, dst):
    """Convert *src* to a 22.05 kHz mono WAV at *dst* via ffmpeg.

    Raises:
        subprocess.CalledProcessError: ffmpeg exited non-zero.
    """
    subprocess.run(
        [
            "ffmpeg",
            "-y",
            "-i", str(src),
            "-ar", "22050",
            "-ac", "1",
            str(dst),
        ],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def load_cached_embedding(cache_file):
    """Load a pickled {"hash": ..., "data": ...} record from *cache_file*.

    The cache dir is local and written only by this process, so pickle is
    acceptable here; do not point CACHE_DIR at untrusted storage.
    """
    with open(cache_file, "rb") as f:
        return pickle.load(f)


def save_cached_embedding(cache_file, data):
    """Pickle *data* to *cache_file*."""
    with open(cache_file, "wb") as f:
        pickle.dump(data, f)


def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Lookup order: in-memory cache, then on-disk pickle (validated against
    the WAV's SHA-256 so stale embeddings are recomputed), then a fresh
    conditioning-latent computation.

    Raises:
        HTTPException(404): the voice file does not exist.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    src = None
    for ext in ["wav", "mp3"]:
        p = VOICE_DIR / f"{voice_name}.{ext}"
        if p.exists():
            src = p
            break
    if not src:
        raise HTTPException(404, f"Voice '{voice_name}' not found")

    wav_file = ensure_wav(voice_name)

    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        cached = load_cached_embedding(cache_file)
        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    save_cached_embedding(cache_file, {"hash": file_hash, "data": data})
    embedding_cache[voice_name] = data
    return data


def _inference_on_cpu(chunk, lang, gpt_cond_latent, speaker_embedding):
    """Run one inference chunk with the model on CPU; model STAYS on CPU.

    The caller is responsible for moving the model back to GPU once the
    whole request is done — the previous code moved it back after every
    chunk, bouncing the weights GPU↔CPU for each remaining chunk.
    """
    with torch.inference_mode():  # faster, no grad tracking
        cpu_model = tts.synthesizer.tts_model.to("cpu")
        return cpu_model.inference(
            chunk,
            lang,
            gpt_cond_latent.to("cpu"),
            speaker_embedding.to("cpu"),
        )


def _restore_model_to_gpu():
    """Move the model back to CUDA after a CPU fallback, if CUDA exists."""
    if torch.cuda.is_available():
        tts.synthesizer.tts_model.to("cuda")


@app.get("/")
def root():
    """Liveness endpoint."""
    return {"status": "XTTS server running"}


@app.get("/voices")
def list_voices():
    """List the stems of all .wav/.mp3 files in VOICE_DIR."""
    voices = []
    for f in VOICE_DIR.iterdir():
        if f.suffix in [".wav", ".mp3"]:
            voices.append(f.stem)
    return {"voices": voices}


@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* with the given cloned *voice* and return one WAV.

    Falls back to CPU inference when VRAM is low up front or when a chunk
    hits CUDA OOM; once fallen back, stays on CPU for the remaining chunks
    and restores the model to GPU a single time at the end.
    """
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    text_chunks = chunk_text(text, max_len=150)
    if not text_chunks:
        # Previously np.concatenate([]) raised a bare ValueError -> HTTP 500.
        raise HTTPException(400, "No text to synthesize")

    wav_all = []

    # Move to CPU up front if GPU is already under pressure.
    use_cpu_fallback = False
    if torch.cuda.is_available():
        free_mem, total_mem = torch.cuda.mem_get_info()
        if free_mem / total_mem < 0.2:  # less than 20% VRAM free
            print("⚠ Low VRAM detected – using CPU for this request")
            use_cpu_fallback = True

    for chunk in text_chunks:
        if use_cpu_fallback:
            # No fake OOM raise: go straight to the CPU path.
            out = _inference_on_cpu(chunk, lang, gpt_cond_latent, speaker_embedding)
        else:
            try:
                out = tts.synthesizer.tts_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent,
                    speaker_embedding,
                )
            except torch.cuda.OutOfMemoryError:
                print(f"⚠ CUDA OOM – retrying chunk on CPU ({os.cpu_count()} cores)")
                torch.cuda.empty_cache()
                use_cpu_fallback = True  # stay on CPU for remaining chunks
                out = _inference_on_cpu(chunk, lang, gpt_cond_latent, speaker_embedding)

        wav_chunk = out["wav"]
        if isinstance(wav_chunk, torch.Tensor):
            wav_chunk = wav_chunk.cpu().numpy()
        if wav_chunk.ndim == 1:
            wav_chunk = np.expand_dims(wav_chunk, 1)
        wav_all.append(wav_chunk)

        if torch.cuda.is_available() and not use_cpu_fallback:
            torch.cuda.empty_cache()

    if use_cpu_fallback:
        _restore_model_to_gpu()

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    # XTTS v2 produces 24 kHz audio (the 22.05 kHz rate above is only for
    # the reference-voice input) — presumably; confirm against model config.
    sf.write(buf, wav, 24000, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")


@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* chunk by chunk, streaming each chunk as it's ready.

    NOTE(review): each yielded chunk is a complete WAV file with its own
    header; many clients will not play concatenated WAVs as one stream —
    confirm the consumer handles this framing.
    """
    # The original handler re-imported numpy/torch/soundfile/re/io and
    # redefined an identical chunk_text() locally; both duplications removed.
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    text_chunks = chunk_text(text)

    def audio_generator():
        on_cpu = False
        for chunk in text_chunks:
            if on_cpu:
                # Stay on CPU once fallen back (the model would otherwise be
                # moved back to GPU after every chunk and OOM again).
                out = _inference_on_cpu(chunk, lang, gpt_cond_latent, speaker_embedding)
            else:
                try:
                    out = tts.synthesizer.tts_model.inference(
                        chunk,
                        lang,
                        gpt_cond_latent,
                        speaker_embedding,
                    )
                except torch.cuda.OutOfMemoryError:
                    print("CUDA OOM – retrying on CPU")
                    torch.cuda.empty_cache()
                    on_cpu = True
                    out = _inference_on_cpu(chunk, lang, gpt_cond_latent, speaker_embedding)

            wav = out["wav"]
            # Convert tensors like the non-streaming endpoint does.
            if isinstance(wav, torch.Tensor):
                wav = wav.cpu().numpy()

            buf = io.BytesIO()
            sf.write(buf, wav, 24000, format="WAV")
            yield buf.getvalue()

            if torch.cuda.is_available() and not on_cpu:
                torch.cuda.empty_cache()

        if on_cpu:
            _restore_model_to_gpu()

    return StreamingResponse(audio_generator(), media_type="audio/wav")