# tts_server.py
  1. import os
  2. import io
  3. import re
  4. import hashlib
  5. import pickle
  6. import subprocess
  7. import threading
  8. from pathlib import Path
  9. import numpy as np
  10. import soundfile as sf
  11. import torch
  12. # Set CPU threads BEFORE any torch operations
  13. torch.set_num_threads(os.cpu_count())
  14. torch.set_num_interop_threads(max(1, os.cpu_count() // 2))
  15. # FIX for PyTorch >=2.6 security change
  16. from torch.serialization import add_safe_globals
  17. import TTS.tts.configs.xtts_config
  18. import TTS.tts.models.xtts
  19. add_safe_globals([
  20. TTS.tts.configs.xtts_config.XttsConfig,
  21. TTS.tts.models.xtts.XttsAudioConfig,
  22. ])
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# ─── Paths & constants ────────────────────────────────────────────────────────
VOICE_DIR = Path("/voices")  # reference speaker clips: <name>.wav / <name>.mp3
CACHE_DIR = Path("/cache")   # pickled speaker embeddings, keyed by voice name
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000  # rate used when encoding WAV responses (XTTS v2 output)
VRAM_HEADROOM = 0.20 # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200 # chars; XTTS hard-limit is ~400 tokens ≈ 250 chars
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────
# NOTE: must run after add_safe_globals() above, or checkpoint loading fails
# on PyTorch >= 2.6.
print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Single lock so concurrent requests don't fight over GPU / model.to() calls
_model_lock = threading.Lock()
app = FastAPI()
# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}
  44. # ─── Text helpers ─────────────────────────────────────────────────────────────
  45. # Characters XTTS tokeniser chokes on → strip or replace before inference
  46. _MARKDOWN_RE = re.compile(r'\*{1,2}|_{1,2}|`+|#{1,6}\s?|~~|\[([^\]]*)\]\([^)]*\)')
  47. _MULTI_SPACE = re.compile(r' +')
  48. _CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')
  49. def clean_text(text: str) -> str:
  50. """Remove markdown and control characters that corrupt XTTS tokenisation."""
  51. text = _MARKDOWN_RE.sub(r'\1', text) # strip md, keep link label
  52. text = _CONTROL_CHARS.sub('', text)
  53. text = text.replace('\r\n', '\n').replace('\r', '\n')
  54. # Collapse multiple blank lines / spaces
  55. text = re.sub(r'\n{3,}', '\n\n', text)
  56. text = _MULTI_SPACE.sub(' ', text)
  57. return text.strip()
  58. def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
  59. """
  60. Split on sentence boundaries. Falls back to word-boundary splitting
  61. for sentences that are still too long (e.g. no punctuation at all).
  62. """
  63. text = clean_text(text)
  64. # Split on sentence-ending punctuation followed by whitespace or end
  65. sentences = re.split(r'(?<=[.!?…])\s+', text)
  66. chunks: list[str] = []
  67. current = ""
  68. for s in sentences:
  69. s = s.strip()
  70. if not s:
  71. continue
  72. # Single sentence longer than max_len → split on word boundary
  73. if len(s) > max_len:
  74. if current:
  75. chunks.append(current)
  76. current = ""
  77. words = s.split()
  78. part = ""
  79. for w in words:
  80. if len(part) + len(w) + 1 > max_len:
  81. if part:
  82. chunks.append(part.strip())
  83. part = w
  84. else:
  85. part = (part + " " + w).strip()
  86. if part:
  87. chunks.append(part)
  88. continue
  89. if len(current) + len(s) + 1 > max_len:
  90. if current:
  91. chunks.append(current)
  92. current = s
  93. else:
  94. current = (current + " " + s).strip()
  95. if current:
  96. chunks.append(current)
  97. return [c for c in chunks if c]
  98. # ─── Voice / embedding helpers ────────────────────────────────────────────────
  99. def sha256_file(path: Path) -> str:
  100. h = hashlib.sha256()
  101. with open(path, "rb") as f:
  102. for block in iter(lambda: f.read(65536), b""):
  103. h.update(block)
  104. return h.hexdigest()
  105. def convert_to_wav(src: Path, dst: Path) -> None:
  106. subprocess.run(
  107. ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
  108. check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
  109. )
  110. def ensure_wav(voice_name: str) -> Path:
  111. wav = VOICE_DIR / f"{voice_name}.wav"
  112. mp3 = VOICE_DIR / f"{voice_name}.mp3"
  113. if wav.exists():
  114. if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
  115. print(f"MP3 newer than WAV → reconverting {voice_name}")
  116. convert_to_wav(mp3, wav)
  117. return wav
  118. if mp3.exists():
  119. print(f"Converting MP3 → WAV for {voice_name}")
  120. convert_to_wav(mp3, wav)
  121. return wav
  122. raise HTTPException(404, f"Voice '{voice_name}' not found")
  123. def get_embedding(voice_name: str):
  124. if voice_name in embedding_cache:
  125. return embedding_cache[voice_name]
  126. wav_file = ensure_wav(voice_name)
  127. file_hash = sha256_file(wav_file)
  128. cache_file = CACHE_DIR / f"{voice_name}.pkl"
  129. if cache_file.exists():
  130. try:
  131. with open(cache_file, "rb") as f:
  132. cached = pickle.load(f)
  133. if cached.get("hash") == file_hash:
  134. print(f"Using cached embedding for {voice_name}")
  135. embedding_cache[voice_name] = cached["data"]
  136. return cached["data"]
  137. except Exception as e:
  138. print(f"Cache read error for {voice_name}: {e} – recomputing")
  139. print(f"Computing embedding for {voice_name}")
  140. model = tts.synthesizer.tts_model
  141. gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
  142. audio_path=str(wav_file)
  143. )
  144. data = (gpt_cond_latent, speaker_embedding)
  145. with open(cache_file, "wb") as f:
  146. pickle.dump({"hash": file_hash, "data": data}, f)
  147. embedding_cache[voice_name] = data
  148. return data
  149. # ─── Core inference helper ────────────────────────────────────────────────────
  150. def _vram_low() -> bool:
  151. if not torch.cuda.is_available():
  152. return True
  153. free, total = torch.cuda.mem_get_info()
  154. return (free / total) < VRAM_HEADROOM
def _infer_chunk(chunk: str, lang: str, gpt_cond_latent, speaker_embedding) -> np.ndarray:
    """Run inference for one chunk; falls back to CPU on OOM.

    Returns the waveform as an (n_samples, 1) numpy array so callers can
    np.concatenate chunks along axis 0. Serialised via _model_lock so only
    one request touches the model (and its device placement) at a time.
    """
    model = tts.synthesizer.tts_model
    def _run(m, lat, emb):
        # Shared inner routine so the GPU path and CPU-fallback path stay identical.
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
        wav = out["wav"]
        if isinstance(wav, torch.Tensor):
            wav = wav.cpu().numpy()
        if wav.ndim == 1:
            wav = np.expand_dims(wav, 1)  # (n,) -> (n, 1) for axis-0 concatenation
        return wav
    with _model_lock:
        try:
            return _run(model, gpt_cond_latent, speaker_embedding)
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM on chunk – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            # Retry the same chunk on CPU; .to("cpu") on the latents returns
            # copies, so the caller's (possibly GPU) tensors are untouched.
            model.to("cpu")
            try:
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always move back, even if CPU inference also fails
                # NOTE(review): assumes CUDA exists here — true because
                # OutOfMemoryError can only be raised on the GPU path.
                model.to("cuda")
                torch.cuda.empty_cache()
            return result
  185. # ─── Routes ───────────────────────────────────────────────────────────────────
  186. @app.get("/")
  187. def root():
  188. return {"status": "XTTS server running", "device": _device}
  189. @app.get("/voices")
  190. def list_voices():
  191. seen = set()
  192. voices = []
  193. for f in VOICE_DIR.iterdir():
  194. if f.suffix in {".wav", ".mp3"} and f.stem not in seen:
  195. voices.append(f.stem)
  196. seen.add(f.stem)
  197. return {"voices": sorted(voices)}
  198. @app.get("/tts")
  199. @app.get("/api/tts")
  200. def synthesize(text: str, voice: str = "default", lang: str = "en"):
  201. if not text.strip():
  202. raise HTTPException(400, "text parameter is empty")
  203. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  204. # If VRAM is already scarce, pin embeddings on CPU for this whole request
  205. use_cpu = _vram_low()
  206. if use_cpu and torch.cuda.is_available():
  207. print("⚠ Low VRAM – pinning entire request to CPU")
  208. gpt_cond_latent = gpt_cond_latent.to("cpu")
  209. speaker_embedding = speaker_embedding.to("cpu")
  210. with _model_lock:
  211. tts.synthesizer.tts_model.to("cpu")
  212. chunks = chunk_text(text)
  213. wav_all = []
  214. for i, chunk in enumerate(chunks):
  215. print(f" chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
  216. try:
  217. wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
  218. except Exception as e:
  219. raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}")
  220. wav_all.append(wav_chunk)
  221. # Restore model to GPU if we moved it
  222. if use_cpu and torch.cuda.is_available():
  223. with _model_lock:
  224. tts.synthesizer.tts_model.to("cuda")
  225. wav = np.concatenate(wav_all, axis=0)
  226. buf = io.BytesIO()
  227. sf.write(buf, wav, SAMPLE_RATE, format="WAV")
  228. buf.seek(0)
  229. return StreamingResponse(buf, media_type="audio/wav")
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
    """Stream WAV chunks as they are synthesised — lower latency for long texts.

    NOTE(review): each yielded chunk is a complete WAV file (header + data), so
    the response body is a concatenation of WAV files rather than one continuous
    stream; many players stop after the first header — confirm clients accept
    this framing.
    """
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    chunks = chunk_text(text)
    def audio_generator():
        # Runs after the 200 status line has been sent, so errors here can no
        # longer change the HTTP status; failed chunks are logged and skipped.
        for i, chunk in enumerate(chunks):
            print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
            try:
                wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                print(f" [stream] chunk {i+1} failed: {e}")
                continue # skip bad chunk rather than kill the stream
            buf = io.BytesIO()
            sf.write(buf, wav, SAMPLE_RATE, format="WAV")
            buf.seek(0)
            yield buf.read()
    return StreamingResponse(audio_generator(), media_type="audio/wav")