# tts_server.py — XTTS FastAPI text-to-speech server
import os
import io
import re
import hashlib
import pickle
import subprocess
import threading
from pathlib import Path
import numpy as np
import soundfile as sf
import torch

# Set CPU threads BEFORE any torch operations — torch rejects (or ignores)
# thread-count changes after the first parallel op has run.
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))

# FIX for PyTorch >=2.6 security change: torch.load defaults to
# weights_only=True, so the XTTS config classes inside the checkpoint must be
# allow-listed before TTS() loads the model.
from torch.serialization import add_safe_globals
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts
add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# ─── Paths & constants ────────────────────────────────────────────────────────
VOICE_DIR = Path("/voices")   # reference speaker audio (.wav / .mp3), one file per voice
CACHE_DIR = Path("/cache")    # pickled speaker embeddings keyed by voice name
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000           # XTTS v2 native output sample rate
VRAM_HEADROOM = 0.20          # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200           # chars; XTTS hard-limit ~400 tokens ≈ 250 chars
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────
print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Serialise all model access so concurrent requests don't race on .to() calls
_model_lock = threading.Lock()

app = FastAPI()
# In-process cache: voice name → (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}
# ─── Acronym / symbol tables ──────────────────────────────────────────────────
#
# Keys are matched as whole words (word-boundary regex).
# Values are phonetic spellings XTTS pronounces letter-by-letter.
# Hyphens between letters reliably force individual-letter pronunciation.
#
# German rule: spell every letter using German letter names.
# English rule: most common EN acronyms are already correct; only fix known
# bad ones (mainly German acronyms appearing in mixed text).
#
# NOTE: keys are substituted verbatim by expand_acronyms(); do not edit a
# value without listening to the result — these spellings are tuned for XTTS.
ACRONYMS_DE: dict[str, str] = {
    # ── Technology / computing ───────────────────────────────────────────────
    "KI": "Ka-I",
    "IT": "I-Te",
    "PC": "Pe-Tse",
    "API": "A-Pe-I",
    "URL": "U-Er-El",
    "HTTP": "Ha-Te-Te-Pe",
    "AI": "Ei-Ei",  # English loanword in German text
    "ML": "Em-El",
    "UI": "U-I",
    "GPU": "Ge-Pe-U",
    "CPU": "Tse-Pe-U",
    # ── Geography / politics ─────────────────────────────────────────────────
    "EU": "E-U",
    "US": "U-Es",
    "USA": "U-Es-A",
    "UK": "U-Ka",
    "UN": "U-En",
    "NATO": "NATO",  # spoken as a word in German too
    "BRD": "Be-Er-De",
    "DDR": "De-De-Er",
    "SPD": "Es-Pe-De",
    "CDU": "Tse-De-U",
    "CSU": "Tse-Es-U",
    "FDP": "Ef-De-Pe",
    "AfD": "A-Ef-De",
    "ÖVP": "Ö-Fau-Pe",
    "FPÖ": "Ef-Pe-Ö",
    # ── Business / finance ───────────────────────────────────────────────────
    "AG": "A-Ge",
    "GmbH": "Ge-Em-Be-Ha",
    "CEO": "Tse-E-O",
    "CFO": "Tse-Ef-O",
    "CTO": "Tse-Te-O",
    "HR": "Ha-Er",
    "PR": "Pe-Er",
    "BIP": "Be-I-Pe",
    "EZB": "E-Tse-Be",
    "IWF": "I-Ve-Ef",
    "WTO": "Ve-Te-O",
    # ── Media / broadcasting ─────────────────────────────────────────────────
    "ARD": "A-Er-De",
    "ZDF": "Tse-De-Ef",
    "ORF": "O-Er-Ef",
    "SRF": "Es-Er-Ef",
    "WDR": "Ve-De-Er",
    "NDR": "En-De-Er",
    "MDR": "Em-De-Er",
    # ── Units / symbols (text substitution) ──────────────────────────────────
    "€": "Euro",
    "$": "Dollar",
    "£": "Pfund",
    "%": "Prozent",
    "°C": "Grad Celsius",
    "°F": "Grad Fahrenheit",
    "km": "Kilometer",
    "kg": "Kilogramm",
    # ── Common German abbreviations ───────────────────────────────────────────
    "bzw.": "beziehungsweise",
    "ca.": "circa",
    "usw.": "und so weiter",
    "z.B.": "zum Beispiel",
    "d.h.": "das heißt",
    "u.a.": "unter anderem",
    "etc.": "etcetera",
    "Nr.": "Nummer",
    "vs.": "versus",
    "Dr.": "Doktor",
    "Prof.": "Professor",
    "Hrsg.": "Herausgeber",
    "Jh.": "Jahrhundert",
    "Mrd.": "Milliarden",
    "Mio.": "Millionen",
}

ACRONYMS_EN: dict[str, str] = {
    # Only list acronyms that XTTS mispronounces in English context.
    # German acronyms that appear in English/mixed text:
    "KI": "Kay Eye",
    "EU": "E-U",
    "BRD": "B-R-D",
    "DDR": "D-D-R",
    "GmbH": "G-m-b-H",
    "EZB": "E-Z-B",
    "ARD": "A-R-D",
    "ZDF": "Z-D-F",
    "ORF": "O-R-F",
    "SRF": "S-R-F",
    "WDR": "W-D-R",
    "NDR": "N-D-R",
    "MDR": "M-D-R",
    # Units / symbols
    "€": "euros",
    "$": "dollars",
    "£": "pounds",
    "%": "percent",
    "°C": "degrees Celsius",
    "°F": "degrees Fahrenheit",
    "km": "kilometers",
    "kg": "kilograms",
    # Abbreviations
    "vs.": "versus",
    "etc.": "et cetera",
    "Dr.": "Doctor",
    "Prof.": "Professor",
    "Nr.": "Number",
    "Mrd.": "billion",
    "Mio.": "million",
}
  162. def _build_acronym_pattern(table: dict[str, str]) -> re.Pattern:
  163. """
  164. Compile a single regex matching all keys as whole tokens.
  165. Longer keys take priority (sorted descending by length).
  166. Pure-symbol keys (€, $, °C) are matched without word boundaries.
  167. """
  168. word_keys = sorted([k for k in table if re.match(r'\w', k)], key=len, reverse=True)
  169. special_keys = sorted([k for k in table if not re.match(r'\w', k)], key=len, reverse=True)
  170. parts = [r'\b' + re.escape(k) + r'\b' for k in word_keys]
  171. parts += [re.escape(k) for k in special_keys]
  172. return re.compile('|'.join(parts)) if parts else re.compile(r'(?!)')
# Compiled once at import time from the tables above; if the tables are
# mutated at runtime these patterns must be rebuilt.
_PATTERN_DE = _build_acronym_pattern(ACRONYMS_DE)
_PATTERN_EN = _build_acronym_pattern(ACRONYMS_EN)
  175. def expand_acronyms(text: str, lang: str) -> str:
  176. """Replace acronyms/symbols with phonetic expansions for the given language."""
  177. if lang.startswith("de"):
  178. table, pattern = ACRONYMS_DE, _PATTERN_DE
  179. else:
  180. table, pattern = ACRONYMS_EN, _PATTERN_EN
  181. return pattern.sub(lambda m: table[m.group(0)], text)
  182. # ─── Markdown → natural speech ────────────────────────────────────────────────
  183. #
  184. # XTTS has no SSML support, but punctuation shapes prosody directly:
  185. # Period → short stop / breath
  186. # Ellipsis "..." → longer, contemplative pause
  187. # Comma → brief breath
  188. #
  189. # Mapping:
  190. # H1 → "..." before + text + "." + "..." after (longest pause)
  191. # H2 / H3 → "." before + text + "." (medium pause)
  192. # H4–H6 → text + "." (small pause)
  193. # **bold** → ", " + text + "," (emphasis breath)
  194. # *italic* → ", " + text + ","
  195. # Bullets → ", " + text + "." (list breath)
  196. # Blank line → "." (paragraph stop)
  197. # Code block → plain text, fences stripped
  198. # Link → label text only
  199. # HR --- → "..." (section break)
  200. _RE_HR = re.compile(r'^\s*[-*_]{3,}\s*$', re.MULTILINE)
  201. _RE_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
  202. _RE_INLINE_CODE = re.compile(r'`[^`]+`')
  203. _RE_H1 = re.compile(r'^#\s+(.+)$', re.MULTILINE)
  204. _RE_H2 = re.compile(r'^#{2,3}\s+(.+)$', re.MULTILINE)
  205. _RE_H_DEEP = re.compile(r'^#{4,6}\s+(.+)$', re.MULTILINE)
  206. _RE_BOLD_ITALIC = re.compile(r'\*{3}(.+?)\*{3}|_{3}(.+?)_{3}')
  207. _RE_BOLD = re.compile(r'\*{2}(.+?)\*{2}|_{2}(.+?)_{2}')
  208. _RE_ITALIC = re.compile(r'\*(.+?)\*|_(.+?)_')
  209. _RE_LINK = re.compile(r'\[([^\]]+)\]\([^)]*\)')
  210. _RE_BULLET = re.compile(r'^\s*[-*+]\s+(.+)$', re.MULTILINE)
  211. _RE_NUMBERED = re.compile(r'^\s*\d+\.\s+(.+)$', re.MULTILINE)
  212. _RE_BLOCKQUOTE = re.compile(r'^\s*>\s+(.+)$', re.MULTILINE)
  213. _RE_MULTI_SPACE = re.compile(r' +')
  214. _RE_MULTI_DOTS = re.compile(r'\.{4,}')
  215. _RE_CONTROL = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')
  216. def markdown_to_speech_text(text: str) -> str:
  217. """
  218. Convert markdown to plain text shaped for natural TTS prosody.
  219. Uses only punctuation cues — no spoken labels.
  220. """
  221. # 1. Normalise line endings + strip control chars
  222. text = text.replace('\r\n', '\n').replace('\r', '\n')
  223. text = _RE_CONTROL.sub('', text)
  224. # 2. Code blocks → plain text (strip fences, keep content)
  225. text = _RE_CODE_BLOCK.sub(
  226. lambda m: m.group(0).split('\n', 1)[-1].rsplit('\n', 1)[0], text
  227. )
  228. text = _RE_INLINE_CODE.sub(lambda m: m.group(0).strip('`'), text)
  229. # 3. Horizontal rules → long section-break pause
  230. text = _RE_HR.sub('\n...\n', text)
  231. # 4. Headings — longest pause for H1, medium for H2/H3, small for H4+
  232. text = _RE_H1.sub(r'\n...\n\1.\n...\n', text)
  233. text = _RE_H2.sub(r'\n.\n\1.\n', text)
  234. text = _RE_H_DEEP.sub(r'\n\1.\n', text)
  235. # 5. Blockquotes → comma-padded inline
  236. text = _RE_BLOCKQUOTE.sub(r', \1,', text)
  237. # 6. Inline emphasis — extract text, add comma-pauses
  238. text = _RE_BOLD_ITALIC.sub(lambda m: ', ' + (m.group(1) or m.group(2)) + ',', text)
  239. text = _RE_BOLD.sub( lambda m: ', ' + (m.group(1) or m.group(2)) + ',', text)
  240. text = _RE_ITALIC.sub( lambda m: ', ' + (m.group(1) or m.group(2)) + ',', text)
  241. # 7. Links → label text only
  242. text = _RE_LINK.sub(r'\1', text)
  243. # 8. List items → comma breath before, period after
  244. text = _RE_BULLET.sub( r', \1.', text)
  245. text = _RE_NUMBERED.sub(r', \1.', text)
  246. # 9. Paragraph breaks → full stop + implicit pause
  247. text = re.sub(r'\n{2,}', '.\n', text)
  248. # 10. Remaining single newlines → space
  249. text = text.replace('\n', ' ')
  250. # 11. Clean up punctuation artifacts left by the above substitutions
  251. text = re.sub(r',\s*,', ',', text) # double commas
  252. text = re.sub(r'\.\s*\.(?!\.)', '.', text) # double periods (not ellipsis)
  253. text = _RE_MULTI_DOTS.sub('...', text) # normalise over-long ellipses
  254. text = re.sub(r'\s*\.\s*,', '.', text) # ., → .
  255. text = re.sub(r',\s*\.', '.', text) # ,. → .
  256. text = re.sub(r'\.\s*\.\.\.', '...', text) # .... → ...
  257. text = _RE_MULTI_SPACE.sub(' ', text)
  258. return text.strip()
  259. # ─── Text chunking ────────────────────────────────────────────────────────────
  260. def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
  261. """
  262. Split on sentence boundaries; falls back to word-boundary splits for
  263. sentences that exceed max_len (e.g. no punctuation, very long clauses).
  264. """
  265. sentences = re.split(r'(?<=[.!?…])\s+', text)
  266. chunks: list[str] = []
  267. current = ""
  268. for s in sentences:
  269. s = s.strip()
  270. if not s:
  271. continue
  272. if len(s) > max_len:
  273. if current:
  274. chunks.append(current)
  275. current = ""
  276. words = s.split()
  277. part = ""
  278. for w in words:
  279. if len(part) + len(w) + 1 > max_len:
  280. if part:
  281. chunks.append(part.strip())
  282. part = w
  283. else:
  284. part = (part + " " + w).strip()
  285. if part:
  286. chunks.append(part)
  287. continue
  288. if len(current) + len(s) + 1 > max_len:
  289. if current:
  290. chunks.append(current)
  291. current = s
  292. else:
  293. current = (current + " " + s).strip()
  294. if current:
  295. chunks.append(current)
  296. return [c for c in chunks if c.strip()]
  297. def prepare_text(text: str, lang: str) -> list[str]:
  298. """Full pipeline: markdown → prosody text → acronym expansion → chunks."""
  299. text = markdown_to_speech_text(text)
  300. text = expand_acronyms(text, lang)
  301. return chunk_text(text)
  302. # ─── Voice / embedding helpers ────────────────────────────────────────────────
  303. def sha256_file(path: Path) -> str:
  304. h = hashlib.sha256()
  305. with open(path, "rb") as f:
  306. for block in iter(lambda: f.read(65536), b""):
  307. h.update(block)
  308. return h.hexdigest()
  309. def convert_to_wav(src: Path, dst: Path) -> None:
  310. subprocess.run(
  311. ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
  312. check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
  313. )
  314. def ensure_wav(voice_name: str) -> Path:
  315. wav = VOICE_DIR / f"{voice_name}.wav"
  316. mp3 = VOICE_DIR / f"{voice_name}.mp3"
  317. if wav.exists():
  318. if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
  319. print(f"MP3 newer than WAV → reconverting {voice_name}")
  320. convert_to_wav(mp3, wav)
  321. return wav
  322. if mp3.exists():
  323. print(f"Converting MP3 → WAV for {voice_name}")
  324. convert_to_wav(mp3, wav)
  325. return wav
  326. raise HTTPException(404, f"Voice '{voice_name}' not found")
  327. def get_embedding(voice_name: str):
  328. if voice_name in embedding_cache:
  329. return embedding_cache[voice_name]
  330. wav_file = ensure_wav(voice_name)
  331. file_hash = sha256_file(wav_file)
  332. cache_file = CACHE_DIR / f"{voice_name}.pkl"
  333. if cache_file.exists():
  334. try:
  335. with open(cache_file, "rb") as f:
  336. cached = pickle.load(f)
  337. if cached.get("hash") == file_hash:
  338. print(f"Using cached embedding for {voice_name}")
  339. embedding_cache[voice_name] = cached["data"]
  340. return cached["data"]
  341. except Exception as e:
  342. print(f"Cache read error for {voice_name}: {e} – recomputing")
  343. print(f"Computing embedding for {voice_name}")
  344. model = tts.synthesizer.tts_model
  345. gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
  346. audio_path=str(wav_file)
  347. )
  348. data = (gpt_cond_latent, speaker_embedding)
  349. with open(cache_file, "wb") as f:
  350. pickle.dump({"hash": file_hash, "data": data}, f)
  351. embedding_cache[voice_name] = data
  352. return data
  353. # ─── Core inference ───────────────────────────────────────────────────────────
  354. def _vram_low() -> bool:
  355. if not torch.cuda.is_available():
  356. return True
  357. free, total = torch.cuda.mem_get_info()
  358. return (free / total) < VRAM_HEADROOM
def _infer_chunk(
    chunk: str, lang: str, gpt_cond_latent, speaker_embedding
) -> np.ndarray:
    """Synthesise one text chunk; auto-falls back to CPU on CUDA OOM.

    Returns the waveform as a float ndarray of shape (samples, 1).

    NOTE(review): on the happy path the latents are assumed to be on the
    same device as the model — confirm get_embedding() returns them on
    `_device`.
    """
    model = tts.synthesizer.tts_model

    def _run(m, lat, emb):
        # inference_mode: no autograd bookkeeping during synthesis.
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
            wav = out["wav"]
            if isinstance(wav, torch.Tensor):
                wav = wav.cpu().numpy()
            if wav.ndim == 1:
                # soundfile expects (frames, channels).
                wav = np.expand_dims(wav, 1)
            return wav

    # Lock held for the whole attempt so no other request can move the
    # model between devices mid-inference.
    with _model_lock:
        try:
            result = _run(model, gpt_cond_latent, speaker_embedding)
            # Release XTTS activation memory after every chunk so it doesn't
            # accumulate across a long document and starve the next request.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always restore the model to the GPU for later requests,
                # even if the CPU retry also failed.
                model.to("cuda")
                torch.cuda.empty_cache()
            return result
  395. # ─── Routes ───────────────────────────────────────────────────────────────────
  396. @app.get("/")
  397. def root():
  398. return {"status": "XTTS server running", "device": _device}
  399. @app.get("/health")
  400. def health():
  401. info = {"status": "ok", "device": _device}
  402. if torch.cuda.is_available():
  403. free, total = torch.cuda.mem_get_info()
  404. info["vram_free_mb"] = round(free / 1024 ** 2)
  405. info["vram_total_mb"] = round(total / 1024 ** 2)
  406. info["vram_used_pct"] = round((1 - free / total) * 100, 1)
  407. return info
  408. @app.get("/voices")
  409. def list_voices():
  410. seen: set = set()
  411. voices: list = []
  412. for f in VOICE_DIR.iterdir():
  413. if f.suffix in {".wav", ".mp3"} and f.stem not in seen:
  414. voices.append(f.stem)
  415. seen.add(f.stem)
  416. return {"voices": sorted(voices)}
  417. @app.get("/tts")
  418. @app.get("/api/tts")
  419. def synthesize(text: str, voice: str = "default", lang: str = "en"):
  420. if not text.strip():
  421. raise HTTPException(400, "text parameter is empty")
  422. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  423. # Pin everything to CPU for this request if VRAM is already low
  424. use_cpu = _vram_low()
  425. if use_cpu and torch.cuda.is_available():
  426. print("⚠ Low VRAM – pinning entire request to CPU")
  427. gpt_cond_latent = gpt_cond_latent.to("cpu")
  428. speaker_embedding = speaker_embedding.to("cpu")
  429. with _model_lock:
  430. tts.synthesizer.tts_model.to("cpu")
  431. chunks = prepare_text(text, lang)
  432. wav_all = []
  433. for i, chunk in enumerate(chunks):
  434. print(f" chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
  435. try:
  436. wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
  437. except Exception as e:
  438. raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}")
  439. wav_all.append(wav_chunk)
  440. if use_cpu and torch.cuda.is_available():
  441. with _model_lock:
  442. tts.synthesizer.tts_model.to("cuda")
  443. # Final sweep — catches anything the per-chunk clears missed
  444. if torch.cuda.is_available():
  445. torch.cuda.empty_cache()
  446. wav = np.concatenate(wav_all, axis=0)
  447. buf = io.BytesIO()
  448. sf.write(buf, wav, SAMPLE_RATE, format="WAV")
  449. buf.seek(0)
  450. return StreamingResponse(buf, media_type="audio/wav")
  451. @app.get("/tts_stream")
  452. @app.get("/api/tts_stream")
  453. def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
  454. """Stream WAV chunks as synthesised — lower latency for long texts."""
  455. if not text.strip():
  456. raise HTTPException(400, "text parameter is empty")
  457. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  458. chunks = prepare_text(text, lang)
  459. def audio_generator():
  460. for i, chunk in enumerate(chunks):
  461. print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
  462. try:
  463. wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
  464. except Exception as e:
  465. print(f" [stream] chunk {i+1} failed: {e}")
  466. continue # skip bad chunk rather than kill the stream
  467. buf = io.BytesIO()
  468. sf.write(buf, wav, SAMPLE_RATE, format="WAV")
  469. buf.seek(0)
  470. yield buf.read()
  471. # Clear after each streamed chunk — long documents would otherwise
  472. # accumulate VRAM and cause the next request to fall back to CPU.
  473. if torch.cuda.is_available():
  474. torch.cuda.empty_cache()
  475. return StreamingResponse(audio_generator(), media_type="audio/wav")