import os
import io
import re
import hashlib
import pickle
import subprocess
import threading
from pathlib import Path

import numpy as np
import soundfile as sf
import torch

# Set CPU threads BEFORE any torch operations.
# os.cpu_count() returns None when the core count cannot be determined;
# fall back to 1 so the thread setup never raises TypeError at import time.
_CPU_COUNT = os.cpu_count() or 1
torch.set_num_threads(_CPU_COUNT)
torch.set_num_interop_threads(max(1, _CPU_COUNT // 2))

# FIX for PyTorch >=2.6 security change: the XTTS config classes must be
# registered as safe globals before the model checkpoint is loaded below.
from torch.serialization import add_safe_globals
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
# NOTE: this rebinds the name ``TTS`` from the package (imported above) to the
# API class, so it must stay after the add_safe_globals() call.
from TTS.api import TTS

# ─── Paths & constants ────────────────────────────────────────────────────────

VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000
VRAM_HEADROOM = 0.20   # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200    # chars; XTTS hard-limit ~400 tokens ≈ 250 chars

VOICE_DIR.mkdir(parents=True, exist_ok=True)
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────

print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Serialise all model access so concurrent requests don't race on .to() calls
_model_lock = threading.Lock()

app = FastAPI()

# In-memory cache: voice name → (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}

# ─── Acronym / symbol tables ──────────────────────────────────────────────────
#
# Keys are matched as whole words (word-boundary regex).
# Values are phonetic spellings XTTS pronounces letter-by-letter.
# Hyphens between letters reliably force individual-letter pronunciation.
#
# German rule: spell every letter using German letter names.
# English rule: most common EN acronyms are already correct; only fix known
#               bad ones (mainly German acronyms appearing in mixed text).
ACRONYMS_DE: dict[str, str] = {
    # ── Technology / computing ───────────────────────────────────────────────
    "KI": "Ka-I",
    "IT": "I-Te",
    "PC": "Pe-Tse",
    "API": "A-Pe-I",
    "URL": "U-Er-El",
    "HTTP": "Ha-Te-Te-Pe",
    "AI": "Ei-Ei",  # English loanword in German text
    "ML": "Em-El",
    "UI": "U-I",
    "GPU": "Ge-Pe-U",
    "CPU": "Tse-Pe-U",
    # ── Geography / politics ─────────────────────────────────────────────────
    "EU": "E-U",
    "US": "U-Es",
    "USA": "U-Es-A",
    "UK": "U-Ka",
    "UN": "U-En",
    "NATO": "NATO",  # spoken as a word in German too
    "BRD": "Be-Er-De",
    "DDR": "De-De-Er",
    "SPD": "Es-Pe-De",
    "CDU": "Tse-De-U",
    "CSU": "Tse-Es-U",
    "FDP": "Ef-De-Pe",
    "AfD": "A-Ef-De",
    "ÖVP": "Ö-Fau-Pe",
    "FPÖ": "Ef-Pe-Ö",
    # ── Business / finance ───────────────────────────────────────────────────
    "AG": "A-Ge",
    "GmbH": "Ge-Em-Be-Ha",
    "CEO": "Tse-E-O",
    "CFO": "Tse-Ef-O",
    "CTO": "Tse-Te-O",
    "HR": "Ha-Er",
    "PR": "Pe-Er",
    "BIP": "Be-I-Pe",
    "EZB": "E-Tse-Be",
    "IWF": "I-Ve-Ef",
    "WTO": "Ve-Te-O",
    # ── Media / broadcasting ─────────────────────────────────────────────────
    "ARD": "A-Er-De",
    "ZDF": "Tse-De-Ef",
    "ORF": "O-Er-Ef",
    "SRF": "Es-Er-Ef",
    "WDR": "Ve-De-Er",
    "NDR": "En-De-Er",
    "MDR": "Em-De-Er",
    # ── Units / symbols (text substitution) ──────────────────────────────────
    "€": "Euro",
    "$": "Dollar",
    "£": "Pfund",
    "%": "Prozent",
    "°C": "Grad Celsius",
    "°F": "Grad Fahrenheit",
    "km": "Kilometer",
    "kg": "Kilogramm",
    # ── Common German abbreviations ──────────────────────────────────────────
    "bzw.": "beziehungsweise",
    "ca.": "circa",
    "usw.": "und so weiter",
    "z.B.": "zum Beispiel",
    "d.h.": "das heißt",
    "u.a.": "unter anderem",
    "etc.": "etcetera",
    "Nr.": "Nummer",
    "vs.": "versus",
    "Dr.": "Doktor",
    "Prof.": "Professor",
    "Hrsg.": "Herausgeber",
    "Jh.": "Jahrhundert",
    "Mrd.": "Milliarden",
    "Mio.": "Millionen",
}

ACRONYMS_EN: dict[str, str] = {
    # Only list acronyms that XTTS mispronounces in English context.
    # German acronyms that appear in English/mixed text:
    "KI": "Kay Eye",
    "EU": "E-U",
    "BRD": "B-R-D",
    "DDR": "D-D-R",
    "GmbH": "G-m-b-H",
    "EZB": "E-Z-B",
    "ARD": "A-R-D",
    "ZDF": "Z-D-F",
    "ORF": "O-R-F",
    "SRF": "S-R-F",
    "WDR": "W-D-R",
    "NDR": "N-D-R",
    "MDR": "M-D-R",
    # Units / symbols
    "€": "euros",
    "$": "dollars",
    "£": "pounds",
    "%": "percent",
    "°C": "degrees Celsius",
    "°F": "degrees Fahrenheit",
    "km": "kilometers",
    "kg": "kilograms",
    # Abbreviations
    "vs.": "versus",
    "etc.": "et cetera",
    "Dr.": "Doctor",
    "Prof.": "Professor",
    "Nr.": "Number",
    "Mrd.": "billion",
    "Mio.": "million",
}


def _build_acronym_pattern(table: dict[str, str]) -> re.Pattern:
    """
    Compile a single regex matching all keys of *table* as whole tokens.

    Keys that start with a word character are tried first, longest first, so
    a longer key always wins over a shorter one. They must not be preceded
    by a word character, and — only when the key also ENDS in a word
    character — must not be followed by one. Keys ending in punctuation
    ("ca.", "z.B.", "Dr.") get no trailing assertion: a ``\\b`` after the
    final "." would require a word character to follow immediately, so
    "ca. 5" would never match.

    Keys that start with a non-word character (€, $, %, °C) are matched
    anywhere, without boundary checks.

    Returns a never-matching pattern when the table has no usable keys.
    """
    word_keys = sorted(
        (k for k in table if re.match(r'\w', k)), key=len, reverse=True
    )
    special_keys = sorted(
        (k for k in table if k and not re.match(r'\w', k)), key=len, reverse=True
    )

    parts = []
    for key in word_keys:
        # Lookarounds instead of \b so the assertion is only applied on the
        # side(s) of the key that actually are word characters.
        trailing = r'(?!\w)' if re.match(r'\w', key[-1]) else ''
        parts.append(r'(?<!\w)' + re.escape(key) + trailing)
    parts.extend(re.escape(key) for key in special_keys)

    return re.compile('|'.join(parts)) if parts else re.compile(r'(?!)')


_PATTERN_DE = _build_acronym_pattern(ACRONYMS_DE)
_PATTERN_EN = _build_acronym_pattern(ACRONYMS_EN)


def expand_acronyms(text: str, lang: str) -> str:
    """
    Replace acronyms/symbols with phonetic expansions for the given language.

    Any *lang* starting with "de" uses the German table; everything else uses
    the English table. Matching is case-sensitive and single-pass, so
    replacement text is never re-scanned.
    """
    if lang.startswith("de"):
        table, pattern = ACRONYMS_DE, _PATTERN_DE
    else:
        table, pattern = ACRONYMS_EN, _PATTERN_EN
    return pattern.sub(lambda m: table[m.group(0)], text)


# ─── Markdown → natural speech ────────────────────────────────────────────────
#
# XTTS has no SSML support, but punctuation shapes prosody directly:
#   Period          → short stop / breath
#   Ellipsis "..."  → longer, contemplative pause
#   Comma           → brief breath
#
# Mapping:
#   H1         → "..." before + text + "." + "..." after   (longest pause)
#   H2 / H3    → "." before + text + "."                   (medium pause)
#   H4–H6      → text + "."
#   (H4–H6     → small pause)
#   **bold**   → ", " + text + ","                         (emphasis breath)
#   *italic*   → ", " + text + ","
#   Bullets    → ", " + text + "."                         (list breath)
#   Blank line → "."                                       (paragraph stop)
#   Code block → plain text, fences stripped
#   Link       → label text only
#   HR ---     → "..."                                     (section break)

_RE_HR = re.compile(r'^\s*[-*_]{3,}\s*$', re.MULTILINE)
_RE_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
_RE_INLINE_CODE = re.compile(r'`[^`]+`')
_RE_H1 = re.compile(r'^#\s+(.+)$', re.MULTILINE)
_RE_H2 = re.compile(r'^#{2,3}\s+(.+)$', re.MULTILINE)
_RE_H_DEEP = re.compile(r'^#{4,6}\s+(.+)$', re.MULTILINE)
_RE_BOLD_ITALIC = re.compile(r'\*{3}(.+?)\*{3}|_{3}(.+?)_{3}')
_RE_BOLD = re.compile(r'\*{2}(.+?)\*{2}|_{2}(.+?)_{2}')
_RE_ITALIC = re.compile(r'\*(.+?)\*|_(.+?)_')
_RE_LINK = re.compile(r'\[([^\]]+)\]\([^)]*\)')
_RE_BULLET = re.compile(r'^\s*[-*+]\s+(.+)$', re.MULTILINE)
_RE_NUMBERED = re.compile(r'^\s*\d+\.\s+(.+)$', re.MULTILINE)
_RE_BLOCKQUOTE = re.compile(r'^\s*>\s+(.+)$', re.MULTILINE)
_RE_MULTI_SPACE = re.compile(r' +')
_RE_MULTI_DOTS = re.compile(r'\.{4,}')
_RE_CONTROL = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')


def markdown_to_speech_text(text: str) -> str:
    """
    Convert markdown to plain text shaped for natural TTS prosody.

    Uses only punctuation cues — no spoken labels. Returns the stripped
    result, which may be empty when the input holds nothing but markup.
    """

    def _emphasis(m: re.Match) -> str:
        # Either the asterisk group or the underscore group matched.
        return ', ' + (m.group(1) or m.group(2)) + ','

    # 1. Normalise line endings + strip control chars
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    text = _RE_CONTROL.sub('', text)

    # 2. Code blocks → plain text (strip fences, keep content)
    text = _RE_CODE_BLOCK.sub(
        lambda m: m.group(0).split('\n', 1)[-1].rsplit('\n', 1)[0], text
    )
    text = _RE_INLINE_CODE.sub(lambda m: m.group(0).strip('`'), text)

    # 3. Horizontal rules → long section-break pause
    text = _RE_HR.sub('\n...\n', text)

    # 4. Headings — longest pause for H1, medium for H2/H3, small for H4+
    text = _RE_H1.sub(r'\n...\n\1.\n...\n', text)
    text = _RE_H2.sub(r'\n.\n\1.\n', text)
    text = _RE_H_DEEP.sub(r'\n\1.\n', text)

    # 5. Blockquotes → comma-padded inline
    text = _RE_BLOCKQUOTE.sub(r', \1,', text)

    # 6. List items → comma breath before, period after.
    #    Must run BEFORE inline emphasis: otherwise the "*" of a
    #    "* item with *emphasis*" bullet is consumed as an italic opener.
    text = _RE_BULLET.sub(r', \1.', text)
    text = _RE_NUMBERED.sub(r', \1.', text)

    # 7. Inline emphasis — extract text, add comma-pauses
    #    (longest markers first so "***x***" is not parsed as nested "*")
    text = _RE_BOLD_ITALIC.sub(_emphasis, text)
    text = _RE_BOLD.sub(_emphasis, text)
    text = _RE_ITALIC.sub(_emphasis, text)

    # 8. Links → label text only
    text = _RE_LINK.sub(r'\1', text)

    # 9. Paragraph breaks → full stop + implicit pause
    text = re.sub(r'\n{2,}', '.\n', text)

    # 10. Remaining single newlines → space
    text = text.replace('\n', ' ')

    # 11. Clean up punctuation artifacts left by the above substitutions
    text = re.sub(r',\s*,', ',', text)                 # double commas
    # Double periods, but never part of an ellipsis. The lookbehind is
    # required: with only the lookahead, the LAST two dots of "..." match
    # and every ellipsis is collapsed to "..".
    text = re.sub(r'(?<!\.)\.\s*\.(?!\.)', '.', text)
    text = _RE_MULTI_DOTS.sub('...', text)             # normalise over-long ellipses
    text = re.sub(r'\s*\.\s*,', '.', text)             # ., → .
    text = re.sub(r',\s*\.', '.', text)                # ,. → .
    text = re.sub(r'\.\s*\.\.\.', '...', text)         # .... → ...
    text = _RE_MULTI_SPACE.sub(' ', text)

    return text.strip()


# ─── Text chunking ────────────────────────────────────────────────────────────

def _split_long_sentence(sentence: str, max_len: int) -> list[str]:
    """Split one over-long sentence on word boundaries into pieces ≤ max_len."""
    pieces: list[str] = []
    part = ""
    for word in sentence.split():
        # A single token longer than the limit (URL, hash, unbroken text)
        # cannot be placed on a word boundary — hard-split it.
        while len(word) > max_len:
            if part:
                pieces.append(part)
                part = ""
            pieces.append(word[:max_len])
            word = word[max_len:]
        if not word:
            continue
        if len(part) + len(word) + 1 > max_len:
            if part:
                pieces.append(part)
            part = word
        else:
            part = (part + " " + word).strip()
    if part:
        pieces.append(part)
    return pieces


def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
    """
    Split *text* into chunks of at most *max_len* characters.

    Splits on sentence boundaries (whitespace after ``.``, ``!``, ``?`` or
    ``…``) and packs consecutive sentences into one chunk while they fit.
    Sentences that exceed max_len (e.g. no punctuation, very long clauses)
    fall back to word-boundary splits; single tokens longer than max_len are
    hard-split. Returns an empty list for empty/whitespace-only input.
    """
    max_len = max(1, max_len)  # guard against a non-positive limit
    sentences = re.split(r'(?<=[.!?…])\s+', text)
    chunks: list[str] = []
    current = ""

    for s in sentences:
        s = s.strip()
        if not s:
            continue

        if len(s) > max_len:
            # Flush what we have, then emit the long sentence piecewise.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(_split_long_sentence(s, max_len))
            continue

        if len(current) + len(s) + 1 > max_len:
            if current:
                chunks.append(current)
            current = s
        else:
            current = (current + " " + s).strip()

    if current:
        chunks.append(current)

    return [c for c in chunks if c.strip()]


def prepare_text(text: str, lang: str) -> list[str]:
    """Full pipeline: markdown → prosody text → acronym expansion → chunks."""
    text = markdown_to_speech_text(text)
    text = expand_acronyms(text, lang)
    return chunk_text(text)


# ─── Voice / embedding helpers ────────────────────────────────────────────────

def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of the file at *path*, read in 64 KiB blocks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(65536), b""):
            h.update(block)
    return h.hexdigest()


def convert_to_wav(src: Path, dst: Path) -> None:
    """
    Convert *src* to a 22050 Hz mono WAV at *dst* using ffmpeg (overwrites dst).

    Raises subprocess.CalledProcessError when ffmpeg exits non-zero.
    """
    subprocess.run(
        ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )


def _validate_voice_name(voice_name: str) -> None:
    """
    Reject voice names that are not a plain file-name component.

    The name arrives from the request query string and is joined into paths
    under VOICE_DIR and CACHE_DIR; separators or ".." would let a caller read,
    convert or write files outside those directories.
    """
    if (
        not voice_name
        or voice_name in {".", ".."}
        or "\\" in voice_name
        or "\x00" in voice_name
        or Path(voice_name).name != voice_name
    ):
        raise HTTPException(400, f"Invalid voice name '{voice_name}'")


def ensure_wav(voice_name: str) -> Path:
    """
    Return the path of the WAV reference file for *voice_name*.

    Converts ``<voice>.mp3`` to WAV when no WAV exists yet, or when the MP3
    has a newer modification time than the existing WAV.

    Raises HTTPException 400 for an invalid name and 404 when neither a WAV
    nor an MP3 exists for the voice.
    """
    _validate_voice_name(voice_name)

    wav = VOICE_DIR / f"{voice_name}.wav"
    mp3 = VOICE_DIR / f"{voice_name}.mp3"

    if wav.exists():
        if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3, wav)
        return wav

    if mp3.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3, wav)
        return wav

    raise HTTPException(404, f"Voice '{voice_name}' not found")


def get_embedding(voice_name: str):
    """
    Return ``(gpt_cond_latent, speaker_embedding)`` for *voice_name*.

    Lookup order: in-memory cache → on-disk pickle (valid only while the WAV's
    SHA-256 matches the stored hash) → fresh computation from the WAV file.

    Raises HTTPException (400/404) via ensure_wav for unknown or invalid names.
    """
    # NOTE: a hit here skips ensure_wav() and the hash check, so a voice file
    # replaced while the server is running is only picked up after a restart.
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    wav_file = ensure_wav(voice_name)  # also validates voice_name
    file_hash = sha256_file(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        try:
            # SECURITY: pickle.load executes arbitrary code from the file.
            # CACHE_DIR must only be writable by this service.
            with open(cache_file, "rb") as f:
                cached = pickle.load(f)
            if cached.get("hash") == file_hash:
                print(f"Using cached embedding for {voice_name}")
                embedding_cache[voice_name] = cached["data"]
                return cached["data"]
        except Exception as e:
            # Deliberately broad: any unreadable/stale cache just means recompute.
            print(f"Cache read error for {voice_name}: {e} – recomputing")

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)

    # The disk cache is an optimisation only — a failed write must not fail
    # the request after the embedding has already been computed.
    try:
        with open(cache_file, "wb") as f:
            pickle.dump({"hash": file_hash, "data": data}, f)
    except (OSError, pickle.PicklingError) as e:
        print(f"Cache write error for {voice_name}: {e} – continuing without disk cache")

    embedding_cache[voice_name] = data
    return data


# ─── Core inference ───────────────────────────────────────────────────────────

def _vram_low() -> bool:
    """Return True when CUDA is unavailable or free VRAM is below VRAM_HEADROOM."""
    if not torch.cuda.is_available():
        return True
    free, total = torch.cuda.mem_get_info()
    return (free / total) < VRAM_HEADROOM


def _infer_chunk(
    chunk: str, lang: str, gpt_cond_latent, speaker_embedding
) -> np.ndarray:
    """
    Synthesise one text chunk; auto-falls back to CPU on CUDA OOM.

    Holds _model_lock for the whole call. Returns the waveform as a NumPy
    array; a 1-D result is expanded to shape (samples, 1).
    """
    model = tts.synthesizer.tts_model

    def _run(m, lat, emb):
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
        wav = out["wav"]
        if isinstance(wav, torch.Tensor):
            wav = wav.cpu().numpy()
        if wav.ndim == 1:
            wav = np.expand_dims(wav, 1)
        return wav

    with _model_lock:
        try:
            result = _run(model, gpt_cond_latent, speaker_embedding)
            # Release XTTS activation memory after every chunk so it doesn't
            # accumulate across a long document and starve the next request.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always move the model back, even if CPU inference failed.
                model.to("cuda")
                torch.cuda.empty_cache()
            return result


# ─── Routes ───────────────────────────────────────────────────────────────────

@app.get("/")
def root():
    """Report that the server is up and which device the model was loaded on."""
    return {"status": "XTTS server running", "device": _device}


@app.get("/health")
def health():
    """Return status and device, plus VRAM usage figures when CUDA is available."""
    info = {"status": "ok", "device": _device}
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()
        info["vram_free_mb"] = round(free / 1024 ** 2)
        info["vram_total_mb"] = round(total / 1024 ** 2)
        info["vram_used_pct"] = round((1 - free / total) * 100, 1)
    return info


@app.get("/voices")
def list_voices():
    """List the distinct voice names (file stems of .wav/.mp3 in VOICE_DIR), sorted."""
    voices = {f.stem for f in VOICE_DIR.iterdir() if f.suffix in {".wav", ".mp3"}}
    return {"voices": sorted(voices)}


@app.get("/tts")
@app.get("/api/tts")
def synthesize(text: str, voice: str = "default", lang: str = "en"):
    """
    Synthesise *text* with *voice* and return one complete WAV file.

    Raises HTTPException 400 when the text is empty or reduces to nothing
    speakable, 400/404 for a bad or unknown voice, and 500 when inference
    fails on any chunk.
    """
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    chunks = prepare_text(text, lang)
    if not chunks:
        # Markup-only input: np.concatenate([]) below would raise ValueError.
        raise HTTPException(400, "text contains no speakable content")

    # Pin everything to CPU for this request if VRAM is already low
    use_cpu = _vram_low()
    pinned = use_cpu and torch.cuda.is_available()
    if pinned:
        print("⚠ Low VRAM – pinning entire request to CPU")
        gpt_cond_latent = gpt_cond_latent.to("cpu")
        speaker_embedding = speaker_embedding.to("cpu")
        with _model_lock:
            tts.synthesizer.tts_model.to("cpu")

    wav_all = []
    try:
        for i, chunk in enumerate(chunks):
            print(f" chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
            try:
                wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}") from e
            wav_all.append(wav_chunk)
    finally:
        # Restore the model even when a chunk failed; otherwise one error
        # would leave the model on CPU for every later request.
        if pinned:
            with _model_lock:
                tts.synthesizer.tts_model.to("cuda")
        # Final sweep — catches anything the per-chunk clears missed
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    sf.write(buf, wav, SAMPLE_RATE, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")


@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
    """
    Stream WAV chunks as synthesised — lower latency for long texts.

    Raises HTTPException 400 when the text is empty or reduces to nothing
    speakable, and 400/404 for a bad or unknown voice. Chunks that fail
    during streaming are logged and skipped.
    """
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")

    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    chunks = prepare_text(text, lang)
    if not chunks:
        raise HTTPException(400, "text contains no speakable content")

    def audio_generator():
        # NOTE(review): every yielded piece is a self-contained WAV file with
        # its own header, so the response body is several WAVs back to back.
        # Clients presumably have to split on the headers — confirm before
        # changing the wire format.
        for i, chunk in enumerate(chunks):
            print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
            try:
                wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                # Headers are already sent, so an HTTP error is no longer
                # possible; skip the bad chunk rather than kill the stream.
                print(f" [stream] chunk {i+1} failed: {e}")
                continue
            buf = io.BytesIO()
            sf.write(buf, wav, SAMPLE_RATE, format="WAV")
            buf.seek(0)
            yield buf.read()
            # Clear after each streamed chunk — long documents would otherwise
            # accumulate VRAM and cause the next request to fall back to CPU.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return StreamingResponse(audio_generator(), media_type="audio/wav")