| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563 |
import hashlib
import io
import os
import pickle
import re
import subprocess
import threading
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
# Set CPU threads BEFORE any torch operations
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))

# FIX for PyTorch >= 2.6: torch.load now defaults to weights_only=True, so
# every class pickled inside the XTTS checkpoint must be allow-listed.
# The checkpoint references four classes, not two — without XttsArgs and
# BaseDatasetConfig, model loading fails with "Unsupported global" errors.
from torch.serialization import add_safe_globals
import TTS.config.shared_configs
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
    TTS.tts.models.xtts.XttsArgs,
    TTS.config.shared_configs.BaseDatasetConfig,
])
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# ─── Paths & constants ────────────────────────────────────────────────────────
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
SAMPLE_RATE = 24000
VRAM_HEADROOM = 0.20   # fall back to CPU when VRAM < 20% free
MAX_CHUNK_LEN = 200    # chars; XTTS hard-limit ~400 tokens ≈ 250 chars

for _dir in (VOICE_DIR, CACHE_DIR):
    _dir.mkdir(exist_ok=True)

# ─── Model loading ────────────────────────────────────────────────────────────
print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print(f"Model loaded on {_device}.")

# Serialise all model access so concurrent requests don't race on .to() calls
_model_lock = threading.Lock()

app = FastAPI()

# In-process cache: voice name → (gpt_cond_latent, speaker_embedding)
embedding_cache: dict = {}
# ─── Acronym / symbol tables ──────────────────────────────────────────────────
#
# Keys are matched as whole words (word-boundary regex).
# Values are phonetic spellings XTTS pronounces letter-by-letter;
# hyphens between letters reliably force individual-letter pronunciation.
#
# German rule: spell every letter using German letter names.
# English rule: most common EN acronyms are already correct; only fix known
# bad ones (mainly German acronyms appearing in mixed text).
ACRONYMS_DE: dict[str, str] = {
    # ── Technology / computing ───────────────────────────────────────────────
    "KI": "Ka-I",
    "IT": "I-Te",
    "PC": "Pe-Tse",
    "API": "A-Pe-I",
    "URL": "U-Er-El",
    "HTTP": "Ha-Te-Te-Pe",
    "AI": "Ei-Ei",  # English loanword in German text
    "ML": "Em-El",
    "UI": "U-I",
    "GPU": "Ge-Pe-U",
    "CPU": "Tse-Pe-U",
    # ── Geography / politics ─────────────────────────────────────────────────
    "EU": "E-U",
    "US": "U-Es",
    "USA": "U-Es-A",
    "UK": "U-Ka",
    "UN": "U-En",
    "NATO": "NATO",  # spoken as a word in German too
    "BRD": "Be-Er-De",
    "DDR": "De-De-Er",
    "SPD": "Es-Pe-De",
    "CDU": "Tse-De-U",
    "CSU": "Tse-Es-U",
    "FDP": "Ef-De-Pe",
    "AfD": "A-Ef-De",
    "ÖVP": "Ö-Fau-Pe",
    "FPÖ": "Ef-Pe-Ö",
    # ── Business / finance ───────────────────────────────────────────────────
    "AG": "A-Ge",
    "GmbH": "Ge-Em-Be-Ha",
    "CEO": "Tse-E-O",
    "CFO": "Tse-Ef-O",
    "CTO": "Tse-Te-O",
    "HR": "Ha-Er",
    "PR": "Pe-Er",
    "BIP": "Be-I-Pe",
    "EZB": "E-Tse-Be",
    "IWF": "I-Ve-Ef",
    "WTO": "Ve-Te-O",
    # ── Media / broadcasting ─────────────────────────────────────────────────
    "ARD": "A-Er-De",
    "ZDF": "Tse-De-Ef",
    "ORF": "O-Er-Ef",
    "SRF": "Es-Er-Ef",
    "WDR": "Ve-De-Er",
    "NDR": "En-De-Er",
    "MDR": "Em-De-Er",
    # ── Units / symbols (text substitution) ──────────────────────────────────
    "€": "Euro",
    "$": "Dollar",
    "£": "Pfund",
    "%": "Prozent",
    "°C": "Grad Celsius",
    "°F": "Grad Fahrenheit",
    "km": "Kilometer",
    "kg": "Kilogramm",
    # ── Common German abbreviations ──────────────────────────────────────────
    "bzw.": "beziehungsweise",
    "ca.": "circa",
    "usw.": "und so weiter",
    "z.B.": "zum Beispiel",
    "d.h.": "das heißt",
    "u.a.": "unter anderem",
    "etc.": "etcetera",
    "Nr.": "Nummer",
    "vs.": "versus",
    "Dr.": "Doktor",
    "Prof.": "Professor",
    "Hrsg.": "Herausgeber",
    "Jh.": "Jahrhundert",
    "Mrd.": "Milliarden",
    "Mio.": "Millionen",
}
ACRONYMS_EN: dict[str, str] = {
    # Only acronyms that XTTS mispronounces in an English context —
    # chiefly German acronyms that show up in English or mixed text.
    "KI": "Kay Eye",
    "EU": "E-U",
    "BRD": "B-R-D",
    "DDR": "D-D-R",
    "GmbH": "G-m-b-H",
    "EZB": "E-Z-B",
    "ARD": "A-R-D",
    "ZDF": "Z-D-F",
    "ORF": "O-R-F",
    "SRF": "S-R-F",
    "WDR": "W-D-R",
    "NDR": "N-D-R",
    "MDR": "M-D-R",
    # Units / symbols
    "€": "euros",
    "$": "dollars",
    "£": "pounds",
    "%": "percent",
    "°C": "degrees Celsius",
    "°F": "degrees Fahrenheit",
    "km": "kilometers",
    "kg": "kilograms",
    # Abbreviations
    "vs.": "versus",
    "etc.": "et cetera",
    "Dr.": "Doctor",
    "Prof.": "Professor",
    "Nr.": "Number",
    "Mrd.": "billion",
    "Mio.": "million",
}
- def _build_acronym_pattern(table: dict[str, str]) -> re.Pattern:
- """
- Compile a single regex matching all keys as whole tokens.
- Longer keys take priority (sorted descending by length).
- Pure-symbol keys (€, $, °C) are matched without word boundaries.
- """
- word_keys = sorted([k for k in table if re.match(r'\w', k)], key=len, reverse=True)
- special_keys = sorted([k for k in table if not re.match(r'\w', k)], key=len, reverse=True)
- parts = [r'\b' + re.escape(k) + r'\b' for k in word_keys]
- parts += [re.escape(k) for k in special_keys]
- return re.compile('|'.join(parts)) if parts else re.compile(r'(?!)')
_PATTERN_DE = _build_acronym_pattern(ACRONYMS_DE)
_PATTERN_EN = _build_acronym_pattern(ACRONYMS_EN)


def expand_acronyms(text: str, lang: str) -> str:
    """Replace acronyms/symbols with phonetic expansions for the given language."""
    is_german = lang.startswith("de")
    table = ACRONYMS_DE if is_german else ACRONYMS_EN
    pattern = _PATTERN_DE if is_german else _PATTERN_EN
    return pattern.sub(lambda match: table[match.group(0)], text)
# ─── Markdown → natural speech ────────────────────────────────────────────────
#
# XTTS has no SSML support, but punctuation shapes prosody directly:
#   Period         → short stop / breath
#   Ellipsis "..." → longer, contemplative pause
#   Comma          → brief breath
#
# Mapping:
#   H1           → "..." before + text + "." + "..." after (longest pause)
#   H2 / H3      → "." before + text + "." (medium pause)
#   H4–H6        → text + "." (small pause)
#   **bold**     → ", " + text + "," (emphasis breath)
#   *italic*     → ", " + text + ","
#   Bullets      → ", " + text + "." (list breath)
#   Blank line   → "." (paragraph stop)
#   Code block   → plain text, fences stripped
#   Link         → label text only
#   HR ---       → "..." (section break)
_RE_HR = re.compile(r'^\s*[-*_]{3,}\s*$', re.MULTILINE)
_RE_CODE_BLOCK = re.compile(r'```[\s\S]*?```')
_RE_INLINE_CODE = re.compile(r'`[^`]+`')
_RE_H1 = re.compile(r'^#\s+(.+)$', re.MULTILINE)
_RE_H2 = re.compile(r'^#{2,3}\s+(.+)$', re.MULTILINE)
_RE_H_DEEP = re.compile(r'^#{4,6}\s+(.+)$', re.MULTILINE)
_RE_BOLD_ITALIC = re.compile(r'\*{3}(.+?)\*{3}|_{3}(.+?)_{3}')
_RE_BOLD = re.compile(r'\*{2}(.+?)\*{2}|_{2}(.+?)_{2}')
_RE_ITALIC = re.compile(r'\*(.+?)\*|_(.+?)_')
_RE_LINK = re.compile(r'\[([^\]]+)\]\([^)]*\)')
_RE_BULLET = re.compile(r'^\s*[-*+]\s+(.+)$', re.MULTILINE)
_RE_NUMBERED = re.compile(r'^\s*\d+\.\s+(.+)$', re.MULTILINE)
_RE_BLOCKQUOTE = re.compile(r'^\s*>\s+(.+)$', re.MULTILINE)
_RE_MULTI_SPACE = re.compile(r' +')
_RE_MULTI_DOTS = re.compile(r'\.{4,}')
_RE_CONTROL = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')


def markdown_to_speech_text(text: str) -> str:
    """
    Convert markdown to plain text shaped for natural TTS prosody.

    Uses only punctuation cues — no spoken labels are ever inserted.
    The substitution ORDER matters: structure first (code, rules, headings),
    then inline markup, then punctuation cleanup.
    """
    # Normalise line endings, drop control characters.
    text = text.replace('\r\n', '\n').replace('\r', '\n')
    text = _RE_CONTROL.sub('', text)

    # Code: strip the fences, keep the content as plain text.
    text = _RE_CODE_BLOCK.sub(
        lambda m: m.group(0).split('\n', 1)[-1].rsplit('\n', 1)[0], text
    )
    text = _RE_INLINE_CODE.sub(lambda m: m.group(0).strip('`'), text)

    # Horizontal rules → long section-break pause.
    text = _RE_HR.sub('\n...\n', text)

    # Headings: longest pause for H1, medium for H2/H3, small for H4+.
    text = _RE_H1.sub(r'\n...\n\1.\n...\n', text)
    text = _RE_H2.sub(r'\n.\n\1.\n', text)
    text = _RE_H_DEEP.sub(r'\n\1.\n', text)

    # Blockquotes → comma-padded inline clause.
    text = _RE_BLOCKQUOTE.sub(r', \1,', text)

    # Inline emphasis: keep the inner text, wrap it in comma-pauses.
    def _emphasis(m):
        return ', ' + (m.group(1) or m.group(2)) + ','

    text = _RE_BOLD_ITALIC.sub(_emphasis, text)
    text = _RE_BOLD.sub(_emphasis, text)
    text = _RE_ITALIC.sub(_emphasis, text)

    # Links → label text only.
    text = _RE_LINK.sub(r'\1', text)

    # List items → comma breath before, period after.
    text = _RE_BULLET.sub(r', \1.', text)
    text = _RE_NUMBERED.sub(r', \1.', text)

    # Paragraph breaks → full stop; remaining single newlines → spaces.
    text = re.sub(r'\n{2,}', '.\n', text)
    text = text.replace('\n', ' ')

    # Clean up punctuation artifacts left by the substitutions above.
    text = re.sub(r',\s*,', ',', text)           # ",,"   → ","
    text = re.sub(r'\.\s*\.(?!\.)', '.', text)   # ".."   → "." (not ellipsis)
    text = _RE_MULTI_DOTS.sub('...', text)       # "...." → "..."
    text = re.sub(r'\s*\.\s*,', '.', text)       # ". ,"  → "."
    text = re.sub(r',\s*\.', '.', text)          # ", ."  → "."
    text = re.sub(r'\.\s*\.\.\.', '...', text)   # ". ..."→ "..."
    text = _RE_MULTI_SPACE.sub(' ', text)
    return text.strip()
# ─── Text chunking ────────────────────────────────────────────────────────────
def chunk_text(text: str, max_len=None) -> list[str]:
    """
    Split *text* into chunks of at most *max_len* characters.

    Splits on sentence boundaries first; any single sentence longer than
    *max_len* is further split on word boundaries.  A single word longer
    than *max_len* is kept whole (XTTS tolerates slight overshoot better
    than a mid-word cut).

    :param text: input text (already prosody-shaped).
    :param max_len: chunk size limit in characters; ``None`` (the default)
        uses the module-level ``MAX_CHUNK_LEN``.  Late-binding the default
        means changes to ``MAX_CHUNK_LEN`` at runtime take effect, unlike
        the previous import-time-bound default.
    :returns: list of non-empty chunks, in original order.
    """
    if max_len is None:
        max_len = MAX_CHUNK_LEN
    sentences = re.split(r'(?<=[.!?…])\s+', text)
    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        if len(sentence) > max_len:
            # Oversized sentence: flush the buffer, then split on words.
            if current:
                chunks.append(current)
                current = ""
            part = ""
            for word in sentence.split():
                if len(part) + len(word) + 1 > max_len:
                    if part:
                        chunks.append(part.strip())
                    part = word
                else:
                    part = (part + " " + word).strip()
            if part:
                chunks.append(part)
        elif len(current) + len(sentence) + 1 > max_len:
            if current:
                chunks.append(current)
            current = sentence
        else:
            current = (current + " " + sentence).strip()
    if current:
        chunks.append(current)
    return [c for c in chunks if c.strip()]
def prepare_text(text: str, lang: str) -> list[str]:
    """Full pipeline: markdown → prosody text → acronym expansion → chunks."""
    speech_text = markdown_to_speech_text(text)
    expanded = expand_acronyms(speech_text, lang)
    return chunk_text(expanded)
# ─── Voice / embedding helpers ────────────────────────────────────────────────
def sha256_file(path: Path) -> str:
    """Return the hex SHA-256 digest of *path*, read in 64 KiB blocks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while block := fh.read(65536):
            digest.update(block)
    return digest.hexdigest()
def convert_to_wav(src: Path, dst: Path) -> None:
    """Transcode *src* to mono 22.05 kHz WAV at *dst* via ffmpeg; raises on failure."""
    # NOTE(review): 22050 Hz differs from SAMPLE_RATE (24000); presumably the
    # XTTS reference-voice audio expects 22.05 kHz — confirm against the model docs.
    cmd = [
        "ffmpeg", "-y",
        "-i", str(src),
        "-ar", "22050",
        "-ac", "1",
        str(dst),
    ]
    subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
def ensure_wav(voice_name: str) -> Path:
    """
    Return the WAV reference file for *voice_name*, converting from MP3 when
    the WAV is missing or stale.  Raises HTTP 404 when neither file exists.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"

    if wav_path.exists():
        # Re-convert when the MP3 was updated after the WAV was produced.
        if mp3_path.exists() and mp3_path.stat().st_mtime > wav_path.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path

    if mp3_path.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path

    raise HTTPException(404, f"Voice '{voice_name}' not found")
def get_embedding(voice_name: str):
    """
    Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Three-level lookup: in-process dict → on-disk pickle (validated against
    the SHA-256 of the reference WAV) → fresh computation via the model.
    """
    cached = embedding_cache.get(voice_name)
    if cached is not None:
        return cached

    wav_file = ensure_wav(voice_name)
    file_hash = sha256_file(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        try:
            # NOTE: pickle is acceptable only because /cache is written solely
            # by this process — never load pickles from untrusted sources.
            with open(cache_file, "rb") as fh:
                entry = pickle.load(fh)
            if entry.get("hash") == file_hash:
                print(f"Using cached embedding for {voice_name}")
                embedding_cache[voice_name] = entry["data"]
                return entry["data"]
        except Exception as e:
            print(f"Cache read error for {voice_name}: {e} – recomputing")

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    with open(cache_file, "wb") as fh:
        pickle.dump({"hash": file_hash, "data": data}, fh)
    embedding_cache[voice_name] = data
    return data
- # ─── Core inference ───────────────────────────────────────────────────────────
- def _vram_low() -> bool:
- if not torch.cuda.is_available():
- return True
- free, total = torch.cuda.mem_get_info()
- return (free / total) < VRAM_HEADROOM
def _infer_chunk(
    chunk: str, lang: str, gpt_cond_latent, speaker_embedding
) -> np.ndarray:
    """
    Synthesise one text chunk and return a 2-D float array (samples, 1).

    All model access happens under the global ``_model_lock``; on a CUDA
    out-of-memory error the chunk is retried on CPU, after which the model
    is always moved back to the GPU.
    """
    model = tts.synthesizer.tts_model

    def _synthesise(m, latent, embedding):
        # inference_mode: skip autograd bookkeeping entirely.
        with torch.inference_mode():
            out = m.inference(chunk, lang, latent, embedding)
            audio = out["wav"]
            if isinstance(audio, torch.Tensor):
                audio = audio.cpu().numpy()
            if audio.ndim == 1:
                audio = np.expand_dims(audio, 1)
            return audio

    with _model_lock:
        try:
            result = _synthesise(model, gpt_cond_latent, speaker_embedding)
            # Release XTTS activation memory after every chunk so it doesn't
            # accumulate across a long document and starve the next request.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result
        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                return _synthesise(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Restore GPU residency for subsequent requests.
                model.to("cuda")
                torch.cuda.empty_cache()
# ─── Routes ───────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Liveness endpoint: confirms the server is up and reports its device."""
    return {"status": "XTTS server running", "device": _device}
@app.get("/health")
def health():
    """Health probe; includes VRAM statistics when a GPU is present."""
    info = {"status": "ok", "device": _device}
    if torch.cuda.is_available():
        free, total = torch.cuda.mem_get_info()
        info["vram_free_mb"] = round(free / 1024 ** 2)
        info["vram_total_mb"] = round(total / 1024 ** 2)
        info["vram_used_pct"] = round((1 - free / total) * 100, 1)
    return info
@app.get("/voices")
def list_voices():
    """List distinct voice names (stems of every .wav/.mp3 in VOICE_DIR), sorted."""
    # A set comprehension deduplicates name.wav / name.mp3 pairs in one pass.
    stems = {f.stem for f in VOICE_DIR.iterdir() if f.suffix in {".wav", ".mp3"}}
    return {"voices": sorted(stems)}
@app.get("/tts")
@app.get("/api/tts")
def synthesize(text: str, voice: str = "default", lang: str = "en"):
    """
    Synthesise *text* with the given voice and language; returns one WAV body.

    Falls back to CPU for the whole request when free VRAM is below
    VRAM_HEADROOM.  The model is restored to the GPU even when a chunk
    fails mid-request (previously an exception left it pinned to CPU
    forever, degrading every later request).
    """
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")
    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    chunks = prepare_text(text, lang)
    if not chunks:
        # Markdown stripping can leave nothing speakable (e.g. pure "---");
        # previously this crashed in np.concatenate with an opaque 500.
        raise HTTPException(400, "text contains nothing speakable")

    # Pin everything to CPU for this request if VRAM is already low.
    use_cpu = _vram_low() and torch.cuda.is_available()
    if use_cpu:
        print("⚠ Low VRAM – pinning entire request to CPU")
        gpt_cond_latent = gpt_cond_latent.to("cpu")
        speaker_embedding = speaker_embedding.to("cpu")
        with _model_lock:
            tts.synthesizer.tts_model.to("cpu")

    wav_all = []
    try:
        for i, chunk in enumerate(chunks):
            print(f" chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
            try:
                wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}") from e
            wav_all.append(wav_chunk)
    finally:
        # BUG FIX: restore unconditionally — the old code skipped this when a
        # chunk raised, leaving the model on CPU for all subsequent requests.
        if use_cpu:
            with _model_lock:
                tts.synthesizer.tts_model.to("cuda")
        # Final sweep — catches anything the per-chunk clears missed.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    sf.write(buf, wav, SAMPLE_RATE, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
    """Stream audio chunk-by-chunk as synthesised — lower latency for long texts."""
    if not text.strip():
        raise HTTPException(400, "text parameter is empty")
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    chunks = prepare_text(text, lang)

    def audio_generator():
        # NOTE(review): each yielded piece is a complete WAV file (RIFF header
        # included), so the body is a sequence of WAV files rather than one
        # continuous stream — confirm the consuming clients tolerate this.
        for i, chunk in enumerate(chunks):
            print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:80]!r}")
            try:
                audio = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
            except Exception as e:
                print(f" [stream] chunk {i+1} failed: {e}")
                continue  # skip bad chunk rather than kill the stream
            with io.BytesIO() as buf:
                sf.write(buf, audio, SAMPLE_RATE, format="WAV")
                yield buf.getvalue()
            # Clear after each streamed chunk — long documents would otherwise
            # accumulate VRAM and cause the next request to fall back to CPU.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return StreamingResponse(audio_generator(), media_type="audio/wav")
|