import os
import io
import re
import hashlib
import pickle
import subprocess
import threading
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# Set CPU threads BEFORE any torch operations — torch fixes its thread pools
# on first use, so later calls to these setters would have no effect.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))

# FIX for PyTorch >=2.6 security change: torch.load defaults to
# weights_only=True, so the XTTS config classes pickled inside the checkpoint
# must be allow-listed before TTS() loads the model.
# NOTE(review): some XTTS checkpoints also need XttsArgs / BaseDatasetConfig
# on this allow-list — confirm against the model version in use.
from torch.serialization import add_safe_globals
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS
# ─── Paths & constants ────────────────────────────────────────────────────────
VOICE_DIR = Path("/voices")   # reference speaker clips: <name>.wav / <name>.mp3
CACHE_DIR = Path("/cache")    # pickled conditioning latents, one .pkl per voice
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000           # rate passed to soundfile when writing responses
VRAM_HEADROOM = 0.20  # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200  # chars; XTTS hard-limit is ~400 tokens ≈ 250 chars

VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────
print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Single lock so concurrent requests don't fight over GPU / model.to() calls
_model_lock = threading.Lock()

app = FastAPI()
# In-process cache: voice name → (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}
# ─── Text helpers ─────────────────────────────────────────────────────────────
# Characters XTTS tokeniser chokes on → strip or replace before inference
_MARKDOWN_RE = re.compile(r'\*{1,2}|_{1,2}|`+|#{1,6}\s?|~~|\[([^\]]*)\]\([^)]*\)')
_MULTI_SPACE = re.compile(r' +')
_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')


def clean_text(text: str) -> str:
    """Strip markdown markers and control characters that corrupt XTTS tokenisation."""
    # Drop markdown syntax; for links the label (group 1) is kept.
    without_md = _MARKDOWN_RE.sub(r'\1', text)
    without_ctrl = _CONTROL_CHARS.sub('', without_md)
    # Normalise all line endings to '\n'.
    unified = without_ctrl.replace('\r\n', '\n').replace('\r', '\n')
    # Collapse runs of blank lines, then runs of spaces.
    collapsed = re.sub(r'\n{3,}', '\n\n', unified)
    return _MULTI_SPACE.sub(' ', collapsed).strip()
def _split_long_sentence(sentence: str, max_len: int) -> list[str]:
    """Word-boundary fallback for a single sentence longer than *max_len*."""
    parts: list[str] = []
    piece = ""
    for word in sentence.split():
        if len(piece) + len(word) + 1 > max_len:
            if piece:
                parts.append(piece.strip())
            piece = word
        else:
            piece = (piece + " " + word).strip()
    if piece:
        parts.append(piece)
    return parts


def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
    """
    Split cleaned text into chunks of at most *max_len* characters,
    preferring sentence boundaries and falling back to word boundaries
    for sentences with no usable punctuation.
    """
    cleaned = clean_text(text)
    # Sentence-ending punctuation followed by whitespace marks a boundary.
    pieces: list[str] = []
    buffer = ""
    for sentence in re.split(r'(?<=[.!?…])\s+', cleaned):
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) > max_len:
            # Oversized sentence: flush what we have, then word-split it.
            if buffer:
                pieces.append(buffer)
                buffer = ""
            pieces.extend(_split_long_sentence(sentence, max_len))
        elif len(buffer) + len(sentence) + 1 > max_len:
            if buffer:
                pieces.append(buffer)
            buffer = sentence
        else:
            buffer = (buffer + " " + sentence).strip()
    if buffer:
        pieces.append(buffer)
    return [p for p in pieces if p]
# ─── Voice / embedding helpers ────────────────────────────────────────────────
def sha256_file(path: Path) -> str:
    """Return the SHA-256 hex digest of *path*, read in 64 KiB blocks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while block := fh.read(65536):
            digest.update(block)
    return digest.hexdigest()
def convert_to_wav(src: Path, dst: Path) -> None:
    """Convert *src* to a 22.05 kHz mono WAV at *dst* via ffmpeg.

    Raises subprocess.CalledProcessError on failure. BUG FIX: stderr was
    previously sent to DEVNULL, so the raised error carried no diagnostic;
    it is now captured (stderr=PIPE) and attached to the exception.
    """
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
        check=True, stdout=subprocess.DEVNULL, stderr=subprocess.PIPE,
    )
def ensure_wav(voice_name: str) -> Path:
    """Return the WAV reference clip for *voice_name*, converting from MP3 if needed.

    Raises HTTPException(404) when neither a .wav nor an .mp3 exists.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"
    if wav_path.exists():
        # Re-convert when the MP3 source was updated after the cached WAV.
        if mp3_path.exists() and mp3_path.stat().st_mtime > wav_path.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path
    if mp3_path.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path
    raise HTTPException(404, f"Voice '{voice_name}' not found")
def get_embedding(voice_name: str):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Two cache tiers: the in-process dict, then an on-disk pickle validated
    against the SHA-256 of the reference WAV so stale latents are recomputed.
    """
    # Tier 1: in-process cache.
    cached_pair = embedding_cache.get(voice_name)
    if cached_pair is not None:
        return cached_pair
    wav_file = ensure_wav(voice_name)
    file_hash = sha256_file(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"
    # Tier 2: on-disk pickle (trusted local file, written by this server).
    if cache_file.exists():
        try:
            with open(cache_file, "rb") as f:
                entry = pickle.load(f)
            if entry.get("hash") == file_hash:
                print(f"Using cached embedding for {voice_name}")
                embedding_cache[voice_name] = entry["data"]
                return entry["data"]
        except Exception as e:
            print(f"Cache read error for {voice_name}: {e} – recomputing")
    print(f"Computing embedding for {voice_name}")
    xtts_model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    pair = (gpt_cond_latent, speaker_embedding)
    with open(cache_file, "wb") as f:
        pickle.dump({"hash": file_hash, "data": pair}, f)
    embedding_cache[voice_name] = pair
    return pair
- # ─── Core inference helper ────────────────────────────────────────────────────
- def _vram_low() -> bool:
- if not torch.cuda.is_available():
- return True
- free, total = torch.cuda.mem_get_info()
- return (free / total) < VRAM_HEADROOM
def _infer_chunk(chunk: str, lang: str, gpt_cond_latent, speaker_embedding) -> np.ndarray:
    """Run inference for one chunk; falls back to CPU on OOM.

    Returns the waveform as a (n_samples, 1) float array so callers can
    concatenate chunks along axis 0.
    """
    model = tts.synthesizer.tts_model

    def _run(m, lat, emb):
        # inference_mode: no autograd bookkeeping during synthesis.
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
            wav = out["wav"]
            if isinstance(wav, torch.Tensor):
                wav = wav.cpu().numpy()
            if wav.ndim == 1:
                # Promote to a column vector for axis-0 concatenation.
                wav = np.expand_dims(wav, 1)
            return wav

    # Serialise all inference and device moves behind the single model lock.
    with _model_lock:
        try:
            return _run(model, gpt_cond_latent, speaker_embedding)
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM on chunk – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                # Latents must live on the same device as the model.
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always move back, even if CPU inference also fails
                model.to("cuda")
                torch.cuda.empty_cache()
            return result
# ─── Routes ───────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Health check: reports that the server is up and which device it uses."""
    return dict(status="XTTS server running", device=_device)
@app.get("/voices")
def list_voices():
    """List available voice names: deduplicated stems of .wav/.mp3 files."""
    stems = {f.stem for f in VOICE_DIR.iterdir() if f.suffix in {".wav", ".mp3"}}
    return {"voices": sorted(stems)}
@app.get("/tts")
@app.get("/api/tts")
def synthesize(text: str, voice: str = "default", lang: str = "en"):
    """Synthesise *text* with *voice* and return one complete WAV response.

    Query params:
        text:  text to speak (400 if blank)
        voice: reference voice name under /voices (404 if unknown)
        lang:  XTTS language code
    """
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    # If VRAM is already scarce, pin embeddings on CPU for this whole request
    use_cpu = _vram_low()
    if use_cpu and torch.cuda.is_available():
        print("⚠ Low VRAM – pinning entire request to CPU")
        gpt_cond_latent = gpt_cond_latent.to("cpu")
        speaker_embedding = speaker_embedding.to("cpu")
        with _model_lock:
            tts.synthesizer.tts_model.to("cpu")
    chunks = chunk_text(text)
    wav_all = []
    try:
        for i, chunk in enumerate(chunks):
            print(f" chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
            try:
                wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}") from e
            wav_all.append(wav_chunk)
    finally:
        # BUG FIX: restore the model to the GPU even when a chunk fails.
        # Previously the HTTPException above skipped the restore, leaving the
        # model pinned on CPU for every subsequent request.
        if use_cpu and torch.cuda.is_available():
            with _model_lock:
                tts.synthesizer.tts_model.to("cuda")
    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    sf.write(buf, wav, SAMPLE_RATE, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
    """Stream audio as it is synthesised — lower latency for long texts.

    BUG FIX: each chunk was previously emitted as a complete WAV file
    (header + data), so the response body was a concatenation of many WAV
    files and most decoders stop at the first chunk's declared length.
    Now a single streaming WAV header is sent once (RIFF/data sizes set to
    0xFFFFFFFF, the conventional "unknown length" marker), followed by raw
    little-endian 16-bit PCM frames for every chunk.
    """
    import struct  # local import: only this endpoint builds raw WAV headers

    if not text.strip():
        raise HTTPException(400, "text parameter is empty")
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    chunks = chunk_text(text)

    def wav_stream_header(sample_rate: int, channels: int = 1, bits: int = 16) -> bytes:
        """Canonical 44-byte PCM WAV header with 'unknown' RIFF/data sizes."""
        byte_rate = sample_rate * channels * bits // 8
        block_align = channels * bits // 8
        return (
            b"RIFF" + struct.pack("<I", 0xFFFFFFFF) + b"WAVE"
            + b"fmt " + struct.pack("<IHHIIHH", 16, 1, channels, sample_rate,
                                    byte_rate, block_align, bits)
            + b"data" + struct.pack("<I", 0xFFFFFFFF)
        )

    def audio_generator():
        yield wav_stream_header(SAMPLE_RATE)
        for i, chunk in enumerate(chunks):
            print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
            try:
                wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                print(f" [stream] chunk {i+1} failed: {e}")
                continue  # skip bad chunk rather than kill the stream
            # Float waveform in [-1, 1] → little-endian int16 PCM frames.
            pcm = (np.clip(wav, -1.0, 1.0) * 32767.0).astype("<i2")
            yield pcm.tobytes()

    return StreamingResponse(audio_generator(), media_type="audio/wav")
|