"""XTTS v2 text-to-speech HTTP server.

Serves /tts and /tts_stream endpoints backed by Coqui XTTS. Speaker
conditioning latents are computed once per voice file and cached on disk
(keyed by the WAV's SHA-256) and in memory. Inference runs on CUDA when
available, with a per-chunk CPU fallback on CUDA OOM and a whole-request
CPU pin when free VRAM is low.
"""

import os
import io
import re
import hashlib
import pickle
import subprocess
import threading
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# Set CPU threads BEFORE any torch operations — these calls raise if torch
# has already started parallel work.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))

# FIX for PyTorch >=2.6 security change: torch.load defaults to
# weights_only=True, so the XTTS config classes embedded in the checkpoint
# must be explicitly allow-listed before the model is loaded.
from torch.serialization import add_safe_globals
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# ─── Paths & constants ────────────────────────────────────────────────────────
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000          # XTTS v2 output sample rate
VRAM_HEADROOM = 0.20         # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200          # chars; XTTS hard-limit is ~400 tokens ≈ 250 chars

VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────
print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Single lock so concurrent requests don't fight over GPU / model.to() calls
_model_lock = threading.Lock()

app = FastAPI()

# In-memory cache: voice name -> (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}

# ─── Text helpers ─────────────────────────────────────────────────────────────
# Characters XTTS tokeniser chokes on → strip or replace before inference
_MARKDOWN_RE = re.compile(r'\*{1,2}|_{1,2}|`+|#{1,6}\s?|~~|\[([^\]]*)\]\([^)]*\)')
_MULTI_SPACE = re.compile(r' +')
_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')


def clean_text(text: str) -> str:
    """Remove markdown and control characters that corrupt XTTS tokenisation."""
    # Strip markdown syntax; for links, \1 keeps the label text (unmatched
    # groups substitute as '' on Python >= 3.5).
    text = _MARKDOWN_RE.sub(r'\1', text)
    text = _CONTROL_CHARS.sub('', text)
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    # Collapse multiple blank lines / spaces
    text = re.sub(r'\n{3,}', '\n\n', text)
    text = _MULTI_SPACE.sub(' ', text)
    return text.strip()


def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
    """
    Split on sentence boundaries. Falls back to word-boundary splitting for
    sentences that are still too long (e.g. no punctuation at all).
    """
    text = clean_text(text)
    # Split on sentence-ending punctuation followed by whitespace or end
    sentences = re.split(r'(?<=[.!?…])\s+', text)
    chunks: list[str] = []
    current = ""
    for s in sentences:
        s = s.strip()
        if not s:
            continue
        # Single sentence longer than max_len → split on word boundary
        if len(s) > max_len:
            if current:
                chunks.append(current)
                current = ""
            words = s.split()
            part = ""
            for w in words:
                if len(part) + len(w) + 1 > max_len:
                    if part:
                        chunks.append(part.strip())
                    part = w
                else:
                    part = (part + " " + w).strip()
            if part:
                chunks.append(part)
            continue
        if len(current) + len(s) + 1 > max_len:
            if current:
                chunks.append(current)
            current = s
        else:
            current = (current + " " + s).strip()
    if current:
        chunks.append(current)
    return [c for c in chunks if c]


# ─── Voice / embedding helpers ────────────────────────────────────────────────
def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 of *path*, read in 64 KiB blocks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(65536), b""):
            h.update(block)
    return h.hexdigest()


def convert_to_wav(src: Path, dst: Path) -> None:
    """Convert *src* to a 22050 Hz mono WAV at *dst* via ffmpeg.

    Raises subprocess.CalledProcessError on conversion failure.
    NOTE(review): 22050 Hz appears to be the reference-audio rate XTTS
    expects for conditioning (distinct from the 24000 Hz output) — confirm.
    """
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def ensure_wav(voice_name: str) -> Path:
    """Return the WAV path for *voice_name*, converting from MP3 if needed.

    An MP3 newer than an existing WAV triggers reconversion, so dropping an
    updated MP3 into VOICE_DIR takes effect without manual cleanup.

    Raises HTTPException(404) when neither a WAV nor an MP3 exists.
    """
    wav = VOICE_DIR / f"{voice_name}.wav"
    mp3 = VOICE_DIR / f"{voice_name}.mp3"
    if wav.exists():
        if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3, wav)
        return wav
    if mp3.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3, wav)
        return wav
    raise HTTPException(404, f"Voice '{voice_name}' not found")


def get_embedding(voice_name: str):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Lookup order: in-memory cache → on-disk pickle (validated against the
    WAV's SHA-256 so edited voice files invalidate stale latents) →
    recompute from audio and persist.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    wav_file = ensure_wav(voice_name)
    file_hash = sha256_file(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        # NOTE: pickle is only safe here because CACHE_DIR is written
        # exclusively by this process — never load untrusted pickles.
        try:
            with open(cache_file, "rb") as f:
                cached = pickle.load(f)
            if cached.get("hash") == file_hash:
                print(f"Using cached embedding for {voice_name}")
                embedding_cache[voice_name] = cached["data"]
                return cached["data"]
        except Exception as e:
            # Corrupt/stale cache is best-effort: fall through and recompute.
            print(f"Cache read error for {voice_name}: {e} – recomputing")

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    with open(cache_file, "wb") as f:
        pickle.dump({"hash": file_hash, "data": data}, f)
    embedding_cache[voice_name] = data
    return data


# ─── Core inference helper ────────────────────────────────────────────────────
def _vram_low() -> bool:
    """True when CUDA is absent or free VRAM is below VRAM_HEADROOM."""
    if not torch.cuda.is_available():
        return True
    free, total = torch.cuda.mem_get_info()
    return (free / total) < VRAM_HEADROOM


def _infer_chunk(chunk: str, lang: str, gpt_cond_latent, speaker_embedding) -> np.ndarray:
    """Run inference for one chunk; falls back to CPU on OOM.

    Returns a float waveform of shape (samples, 1). Holds _model_lock for
    the whole call so concurrent requests cannot interleave device moves.
    """
    model = tts.synthesizer.tts_model

    def _run(m, lat, emb):
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
        wav = out["wav"]
        if isinstance(wav, torch.Tensor):
            wav = wav.cpu().numpy()
        if wav.ndim == 1:
            wav = np.expand_dims(wav, 1)
        return wav

    with _model_lock:
        try:
            return _run(model, gpt_cond_latent, speaker_embedding)
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM on chunk – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always move back, even if CPU inference also fails
                model.to("cuda")
                torch.cuda.empty_cache()
            return result


# ─── Routes ───────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Liveness probe: reports the device the model was loaded on."""
    return {"status": "XTTS server running", "device": _device}


@app.get("/voices")
def list_voices():
    """List available voice names (deduplicated stems of .wav/.mp3 files)."""
    stems = {f.stem for f in VOICE_DIR.iterdir() if f.suffix in {".wav", ".mp3"}}
    return {"voices": sorted(stems)}


@app.get("/tts")
@app.get("/api/tts")
def synthesize(text: str, voice: str = "default", lang: str = "en"):
    """Synthesize *text* with *voice* and return one complete WAV response."""
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    # If VRAM is already scarce, pin embeddings on CPU for this whole request
    use_cpu = _vram_low()
    if use_cpu and torch.cuda.is_available():
        print("⚠ Low VRAM – pinning entire request to CPU")
        gpt_cond_latent = gpt_cond_latent.to("cpu")
        speaker_embedding = speaker_embedding.to("cpu")
        with _model_lock:
            tts.synthesizer.tts_model.to("cpu")

    chunks = chunk_text(text)
    wav_all = []
    try:
        for i, chunk in enumerate(chunks):
            print(f"  chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
            try:
                wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}") from e
            wav_all.append(wav_chunk)
    finally:
        # BUGFIX: restore the model to GPU even when a chunk fails mid-loop;
        # previously an exception left the model stranded on CPU.
        if use_cpu and torch.cuda.is_available():
            with _model_lock:
                tts.synthesizer.tts_model.to("cuda")

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    sf.write(buf, wav, SAMPLE_RATE, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")


@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
    """Stream WAV chunks as they are synthesised — lower latency for long texts."""
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")

    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    chunks = chunk_text(text)

    def audio_generator():
        # NOTE(review): each yield is a self-contained WAV file (own header);
        # clients must tolerate concatenated WAVs, not a single stream —
        # presumably the intended consumer does. Confirm before changing.
        for i, chunk in enumerate(chunks):
            print(f"  [stream] chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
            try:
                wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                print(f"  [stream] chunk {i+1} failed: {e}")
                continue  # skip bad chunk rather than kill the stream
            buf = io.BytesIO()
            sf.write(buf, wav, SAMPLE_RATE, format="WAV")
            buf.seek(0)
            yield buf.read()

    return StreamingResponse(audio_generator(), media_type="audio/wav")