|
|
@@ -1,330 +1,309 @@
|
|
|
import os
|
|
|
import io
|
|
|
+import re
|
|
|
import hashlib
|
|
|
import pickle
|
|
|
import subprocess
|
|
|
+import threading
|
|
|
from pathlib import Path
|
|
|
|
|
|
+import numpy as np
|
|
|
+import soundfile as sf
|
|
|
import torch
|
|
|
|
|
|
+# Set CPU threads BEFORE any torch operations
|
|
|
+torch.set_num_threads(os.cpu_count())
|
|
|
+torch.set_num_interop_threads(max(1, os.cpu_count() // 2))
|
|
|
+
|
|
|
# FIX for PyTorch >=2.6 security change
|
|
|
from torch.serialization import add_safe_globals
|
|
|
-from TTS.tts.configs.xtts_config import XttsConfig
|
|
|
-
|
|
|
import TTS.tts.configs.xtts_config
|
|
|
import TTS.tts.models.xtts
|
|
|
-
|
|
|
add_safe_globals([
|
|
|
TTS.tts.configs.xtts_config.XttsConfig,
|
|
|
- TTS.tts.models.xtts.XttsAudioConfig
|
|
|
+ TTS.tts.models.xtts.XttsAudioConfig,
|
|
|
])
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException
|
|
|
from fastapi.responses import StreamingResponse
|
|
|
from TTS.api import TTS
|
|
|
|
|
|
-VOICE_DIR = Path("/voices")
|
|
|
-CACHE_DIR = Path("/cache")
|
|
|
+# ─── Paths & constants ────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+VOICE_DIR = Path("/voices")
|
|
|
+CACHE_DIR = Path("/cache")
|
|
|
+MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
|
|
|
+SAMPLE_RATE = 24000
|
|
|
+VRAM_HEADROOM = 0.20 # fall back to CPU when VRAM < 20% free
|
|
|
+MAX_CHUNK_LEN = 200 # chars; XTTS hard-limit is ~400 tokens ≈ 250 chars
|
|
|
|
|
|
VOICE_DIR.mkdir(exist_ok=True)
|
|
|
CACHE_DIR.mkdir(exist_ok=True)
|
|
|
|
|
|
-MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
|
|
|
+# ─── Model loading ────────────────────────────────────────────────────────────
|
|
|
|
|
|
print("Loading XTTS model...")
|
|
|
-tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
-print("Model loaded.")
|
|
|
+_device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
+tts = TTS(MODEL_NAME).to(_device)
|
|
|
+print(f"Model loaded on {_device}.")
|
|
|
|
|
|
-app = FastAPI()
|
|
|
+# Single lock so concurrent requests don't fight over GPU / model.to() calls
|
|
|
+_model_lock = threading.Lock()
|
|
|
|
|
|
-embedding_cache = {}
|
|
|
-
|
|
|
-
|
|
|
-def sha256(path):
|
|
|
+app = FastAPI()
|
|
|
+embedding_cache: dict = {}
|
|
|
+
|
|
|
+# ─── Text helpers ─────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+# Characters XTTS tokeniser chokes on → strip or replace before inference
|
|
|
+_MARKDOWN_RE = re.compile(r'\*{1,2}|_{1,2}|`+|#{1,6}\s?|~~|\[([^\]]*)\]\([^)]*\)')
|
|
|
+_MULTI_SPACE = re.compile(r' +')
|
|
|
+_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')
|
|
|
+
|
|
|
+def clean_text(text: str) -> str:
|
|
|
+ """Remove markdown and control characters that corrupt XTTS tokenisation."""
|
|
|
+ text = _MARKDOWN_RE.sub(r'\1', text) # strip md, keep link label
|
|
|
+ text = _CONTROL_CHARS.sub('', text)
|
|
|
+ text = text.replace('\r\n', '\n').replace('\r', '\n')
|
|
|
+ # Collapse multiple blank lines / spaces
|
|
|
+ text = re.sub(r'\n{3,}', '\n\n', text)
|
|
|
+ text = _MULTI_SPACE.sub(' ', text)
|
|
|
+ return text.strip()
|
|
|
+
|
|
|
+
|
|
|
+def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
|
|
|
+ """
|
|
|
+ Split on sentence boundaries. Falls back to word-boundary splitting
|
|
|
+ for sentences that are still too long (e.g. no punctuation at all).
|
|
|
+ """
|
|
|
+ text = clean_text(text)
|
|
|
+ # Split on sentence-ending punctuation followed by whitespace or end
|
|
|
+ sentences = re.split(r'(?<=[.!?…])\s+', text)
|
|
|
+
|
|
|
+ chunks: list[str] = []
|
|
|
+ current = ""
|
|
|
+
|
|
|
+ for s in sentences:
|
|
|
+ s = s.strip()
|
|
|
+ if not s:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Single sentence longer than max_len → split on word boundary
|
|
|
+ if len(s) > max_len:
|
|
|
+ if current:
|
|
|
+ chunks.append(current)
|
|
|
+ current = ""
|
|
|
+ words = s.split()
|
|
|
+ part = ""
|
|
|
+ for w in words:
|
|
|
+ if len(part) + len(w) + 1 > max_len:
|
|
|
+ if part:
|
|
|
+ chunks.append(part.strip())
|
|
|
+ part = w
|
|
|
+ else:
|
|
|
+ part = (part + " " + w).strip()
|
|
|
+ if part:
|
|
|
+ chunks.append(part)
|
|
|
+ continue
|
|
|
+
|
|
|
+ if len(current) + len(s) + 1 > max_len:
|
|
|
+ if current:
|
|
|
+ chunks.append(current)
|
|
|
+ current = s
|
|
|
+ else:
|
|
|
+ current = (current + " " + s).strip()
|
|
|
+
|
|
|
+ if current:
|
|
|
+ chunks.append(current)
|
|
|
+
|
|
|
+ return [c for c in chunks if c]
|
|
|
+
|
|
|
+
|
|
|
+# ─── Voice / embedding helpers ────────────────────────────────────────────────
|
|
|
+
|
|
|
+def sha256_file(path: Path) -> str:
|
|
|
h = hashlib.sha256()
|
|
|
with open(path, "rb") as f:
|
|
|
- while True:
|
|
|
- chunk = f.read(8192)
|
|
|
- if not chunk:
|
|
|
- break
|
|
|
- h.update(chunk)
|
|
|
+ for block in iter(lambda: f.read(65536), b""):
|
|
|
+ h.update(block)
|
|
|
return h.hexdigest()
|
|
|
|
|
|
|
|
|
-def ensure_wav(voice_name):
|
|
|
+def convert_to_wav(src: Path, dst: Path) -> None:
|
|
|
+ subprocess.run(
|
|
|
+ ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)],
|
|
|
+ check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
|
|
|
+ )
|
|
|
+
|
|
|
|
|
|
+def ensure_wav(voice_name: str) -> Path:
|
|
|
wav = VOICE_DIR / f"{voice_name}.wav"
|
|
|
mp3 = VOICE_DIR / f"{voice_name}.mp3"
|
|
|
-
|
|
|
if wav.exists():
|
|
|
-
|
|
|
if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
|
|
|
print(f"MP3 newer than WAV → reconverting {voice_name}")
|
|
|
convert_to_wav(mp3, wav)
|
|
|
-
|
|
|
return wav
|
|
|
-
|
|
|
if mp3.exists():
|
|
|
print(f"Converting MP3 → WAV for {voice_name}")
|
|
|
convert_to_wav(mp3, wav)
|
|
|
return wav
|
|
|
-
|
|
|
raise HTTPException(404, f"Voice '{voice_name}' not found")
|
|
|
|
|
|
|
|
|
-def convert_to_wav(src, dst):
|
|
|
-
|
|
|
- subprocess.run(
|
|
|
- [
|
|
|
- "ffmpeg",
|
|
|
- "-y",
|
|
|
- "-i",
|
|
|
- str(src),
|
|
|
- "-ar",
|
|
|
- "22050",
|
|
|
- "-ac",
|
|
|
- "1",
|
|
|
- str(dst),
|
|
|
- ],
|
|
|
- check=True,
|
|
|
- stdout=subprocess.DEVNULL,
|
|
|
- stderr=subprocess.DEVNULL,
|
|
|
- )
|
|
|
-
|
|
|
-def load_cached_embedding(cache_file):
|
|
|
- with open(cache_file, "rb") as f:
|
|
|
- return pickle.load(f)
|
|
|
-
|
|
|
-
|
|
|
-def save_cached_embedding(cache_file, data):
|
|
|
- with open(cache_file, "wb") as f:
|
|
|
- pickle.dump(data, f)
|
|
|
-
|
|
|
-def get_embedding(voice_name):
|
|
|
-
|
|
|
+def get_embedding(voice_name: str):
|
|
|
if voice_name in embedding_cache:
|
|
|
return embedding_cache[voice_name]
|
|
|
|
|
|
- src = None
|
|
|
-
|
|
|
- for ext in ["wav", "mp3"]:
|
|
|
- p = VOICE_DIR / f"{voice_name}.{ext}"
|
|
|
- if p.exists():
|
|
|
- src = p
|
|
|
- break
|
|
|
-
|
|
|
- if not src:
|
|
|
- raise HTTPException(404, f"Voice '{voice_name}' not found")
|
|
|
-
|
|
|
- wav_file = ensure_wav(voice_name)
|
|
|
- # wav_file = src if src.suffix == ".wav" else convert_to_wav(src)
|
|
|
-
|
|
|
- file_hash = sha256(wav_file)
|
|
|
+ wav_file = ensure_wav(voice_name)
|
|
|
+ file_hash = sha256_file(wav_file)
|
|
|
cache_file = CACHE_DIR / f"{voice_name}.pkl"
|
|
|
|
|
|
if cache_file.exists():
|
|
|
-
|
|
|
- cached = load_cached_embedding(cache_file)
|
|
|
-
|
|
|
- if cached["hash"] == file_hash:
|
|
|
- print(f"Using cached embedding for {voice_name}")
|
|
|
- embedding_cache[voice_name] = cached["data"]
|
|
|
- return cached["data"]
|
|
|
+ try:
|
|
|
+ with open(cache_file, "rb") as f:
|
|
|
+ cached = pickle.load(f)
|
|
|
+ if cached.get("hash") == file_hash:
|
|
|
+ print(f"Using cached embedding for {voice_name}")
|
|
|
+ embedding_cache[voice_name] = cached["data"]
|
|
|
+ return cached["data"]
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Cache read error for {voice_name}: {e} – recomputing")
|
|
|
|
|
|
print(f"Computing embedding for {voice_name}")
|
|
|
-
|
|
|
model = tts.synthesizer.tts_model
|
|
|
-
|
|
|
gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
|
|
|
audio_path=str(wav_file)
|
|
|
)
|
|
|
-
|
|
|
data = (gpt_cond_latent, speaker_embedding)
|
|
|
+ with open(cache_file, "wb") as f:
|
|
|
+ pickle.dump({"hash": file_hash, "data": data}, f)
|
|
|
+ embedding_cache[voice_name] = data
|
|
|
+ return data
|
|
|
|
|
|
- save_cached_embedding(
|
|
|
- cache_file,
|
|
|
- {"hash": file_hash, "data": data},
|
|
|
- )
|
|
|
|
|
|
- embedding_cache[voice_name] = data
|
|
|
+# ─── Core inference helper ────────────────────────────────────────────────────
|
|
|
+
|
|
|
+def _vram_low() -> bool:
|
|
|
+ if not torch.cuda.is_available():
|
|
|
+ return True
|
|
|
+ free, total = torch.cuda.mem_get_info()
|
|
|
+ return (free / total) < VRAM_HEADROOM
|
|
|
+
|
|
|
+
|
|
|
+def _infer_chunk(chunk: str, lang: str, gpt_cond_latent, speaker_embedding) -> np.ndarray:
|
|
|
+ """Run inference for one chunk; falls back to CPU on OOM."""
|
|
|
+ model = tts.synthesizer.tts_model
|
|
|
+
|
|
|
+ def _run(m, lat, emb):
|
|
|
+ with torch.inference_mode():
|
|
|
+ out = m.inference(chunk, lang, lat, emb)
|
|
|
+ wav = out["wav"]
|
|
|
+ if isinstance(wav, torch.Tensor):
|
|
|
+ wav = wav.cpu().numpy()
|
|
|
+ if wav.ndim == 1:
|
|
|
+ wav = np.expand_dims(wav, 1)
|
|
|
+ return wav
|
|
|
+
|
|
|
+ with _model_lock:
|
|
|
+ try:
|
|
|
+ return _run(model, gpt_cond_latent, speaker_embedding)
|
|
|
+ except torch.cuda.OutOfMemoryError:
|
|
|
+ print(f"⚠ CUDA OOM on chunk – falling back to CPU ({os.cpu_count()} cores)")
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ model.to("cpu")
|
|
|
+ try:
|
|
|
+ result = _run(
|
|
|
+ model,
|
|
|
+ gpt_cond_latent.to("cpu"),
|
|
|
+ speaker_embedding.to("cpu"),
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+ # Always move back, even if CPU inference also fails
|
|
|
+ model.to("cuda")
|
|
|
+ torch.cuda.empty_cache()
|
|
|
+ return result
|
|
|
|
|
|
- return data
|
|
|
+
|
|
|
+# ─── Routes ───────────────────────────────────────────────────────────────────
|
|
|
|
|
|
@app.get("/")
|
|
|
def root():
|
|
|
- return {"status": "XTTS server running"}
|
|
|
+ return {"status": "XTTS server running", "device": _device}
|
|
|
+
|
|
|
|
|
|
@app.get("/voices")
|
|
|
def list_voices():
|
|
|
+ seen = set()
|
|
|
voices = []
|
|
|
for f in VOICE_DIR.iterdir():
|
|
|
- if f.suffix in [".wav", ".mp3"]:
|
|
|
+ if f.suffix in {".wav", ".mp3"} and f.stem not in seen:
|
|
|
voices.append(f.stem)
|
|
|
- return {"voices": voices}
|
|
|
+ seen.add(f.stem)
|
|
|
+ return {"voices": sorted(voices)}
|
|
|
+
|
|
|
|
|
|
@app.get("/tts")
|
|
|
@app.get("/api/tts")
|
|
|
-def synthesize(
|
|
|
- text: str,
|
|
|
- voice: str = "default",
|
|
|
- lang: str = "en",
|
|
|
-):
|
|
|
-
|
|
|
- import numpy as np
|
|
|
- import torch
|
|
|
- import io
|
|
|
- import soundfile as sf
|
|
|
- import re
|
|
|
-
|
|
|
- def chunk_text(text, max_len=150):
|
|
|
- sentences = re.split(r'(?<=[.!?])\s+', text)
|
|
|
- chunks = []
|
|
|
- current = ""
|
|
|
-
|
|
|
- for s in sentences:
|
|
|
- if len(current) + len(s) > max_len:
|
|
|
- if current:
|
|
|
- chunks.append(current.strip())
|
|
|
- current = s
|
|
|
- else:
|
|
|
- current += " " + s
|
|
|
-
|
|
|
- if current:
|
|
|
- chunks.append(current.strip())
|
|
|
-
|
|
|
- return chunks
|
|
|
+def synthesize(text: str, voice: str = "default", lang: str = "en"):
|
|
|
+ if not text.strip():
|
|
|
+ raise HTTPException(400, "text parameter is empty")
|
|
|
|
|
|
gpt_cond_latent, speaker_embedding = get_embedding(voice)
|
|
|
|
|
|
- text_chunks = chunk_text(text, max_len=150)
|
|
|
+ # If VRAM is already scarce, pin embeddings on CPU for this whole request
|
|
|
+ use_cpu = _vram_low()
|
|
|
+ if use_cpu and torch.cuda.is_available():
|
|
|
+ print("⚠ Low VRAM – pinning entire request to CPU")
|
|
|
+ gpt_cond_latent = gpt_cond_latent.to("cpu")
|
|
|
+ speaker_embedding = speaker_embedding.to("cpu")
|
|
|
+ with _model_lock:
|
|
|
+ tts.synthesizer.tts_model.to("cpu")
|
|
|
|
|
|
+ chunks = chunk_text(text)
|
|
|
wav_all = []
|
|
|
|
|
|
- for chunk in text_chunks:
|
|
|
-
|
|
|
+ for i, chunk in enumerate(chunks):
|
|
|
+ print(f" chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
|
|
|
try:
|
|
|
- out = tts.synthesizer.tts_model.inference(
|
|
|
- chunk,
|
|
|
- lang,
|
|
|
- gpt_cond_latent,
|
|
|
- speaker_embedding,
|
|
|
- )
|
|
|
-
|
|
|
- except torch.cuda.OutOfMemoryError:
|
|
|
-
|
|
|
- print("⚠ CUDA OOM – retrying chunk on CPU")
|
|
|
-
|
|
|
- torch.cuda.empty_cache()
|
|
|
-
|
|
|
- cpu_model = tts.synthesizer.tts_model.to("cpu")
|
|
|
-
|
|
|
- out = cpu_model.inference(
|
|
|
- chunk,
|
|
|
- lang,
|
|
|
- gpt_cond_latent.to("cpu"),
|
|
|
- speaker_embedding.to("cpu"),
|
|
|
- )
|
|
|
-
|
|
|
- tts.synthesizer.tts_model.to("cuda")
|
|
|
-
|
|
|
- wav_chunk = out["wav"]
|
|
|
-
|
|
|
- if len(wav_chunk.shape) == 1:
|
|
|
- wav_chunk = np.expand_dims(wav_chunk, 1)
|
|
|
-
|
|
|
+ wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
|
|
|
+ except Exception as e:
|
|
|
+ raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}")
|
|
|
wav_all.append(wav_chunk)
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.empty_cache()
|
|
|
+ # Restore model to GPU if we moved it
|
|
|
+ if use_cpu and torch.cuda.is_available():
|
|
|
+ with _model_lock:
|
|
|
+ tts.synthesizer.tts_model.to("cuda")
|
|
|
|
|
|
wav = np.concatenate(wav_all, axis=0)
|
|
|
-
|
|
|
buf = io.BytesIO()
|
|
|
-
|
|
|
- sf.write(buf, wav, 24000, format="WAV")
|
|
|
-
|
|
|
+ sf.write(buf, wav, SAMPLE_RATE, format="WAV")
|
|
|
buf.seek(0)
|
|
|
-
|
|
|
return StreamingResponse(buf, media_type="audio/wav")
|
|
|
|
|
|
|
|
|
-
|
|
|
@app.get("/tts_stream")
|
|
|
@app.get("/api/tts_stream")
|
|
|
-def synthesize_stream(
|
|
|
- text: str,
|
|
|
- voice: str = "default",
|
|
|
- lang: str = "en",
|
|
|
-):
|
|
|
-
|
|
|
- import numpy as np
|
|
|
- import torch
|
|
|
- import soundfile as sf
|
|
|
- import re
|
|
|
- import io
|
|
|
-
|
|
|
- def chunk_text(text, max_len=150):
|
|
|
- sentences = re.split(r'(?<=[.!?])\s+', text)
|
|
|
- chunks = []
|
|
|
- current = ""
|
|
|
-
|
|
|
- for s in sentences:
|
|
|
- if len(current) + len(s) > max_len:
|
|
|
- if current:
|
|
|
- chunks.append(current.strip())
|
|
|
- current = s
|
|
|
- else:
|
|
|
- current += " " + s
|
|
|
-
|
|
|
- if current:
|
|
|
- chunks.append(current.strip())
|
|
|
-
|
|
|
- return chunks
|
|
|
+def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
|
|
|
+ """Stream WAV chunks as they are synthesised — lower latency for long texts."""
|
|
|
+ if not text.strip():
|
|
|
+ raise HTTPException(400, "text parameter is empty")
|
|
|
|
|
|
gpt_cond_latent, speaker_embedding = get_embedding(voice)
|
|
|
-
|
|
|
- text_chunks = chunk_text(text)
|
|
|
+ chunks = chunk_text(text)
|
|
|
|
|
|
def audio_generator():
|
|
|
-
|
|
|
- for chunk in text_chunks:
|
|
|
-
|
|
|
+ for i, chunk in enumerate(chunks):
|
|
|
+ print(f" [stream] chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
|
|
|
try:
|
|
|
- out = tts.synthesizer.tts_model.inference(
|
|
|
- chunk,
|
|
|
- lang,
|
|
|
- gpt_cond_latent,
|
|
|
- speaker_embedding,
|
|
|
- )
|
|
|
-
|
|
|
- except torch.cuda.OutOfMemoryError:
|
|
|
-
|
|
|
- print("CUDA OOM – retrying on CPU")
|
|
|
-
|
|
|
- torch.cuda.empty_cache()
|
|
|
-
|
|
|
- cpu_model = tts.synthesizer.tts_model.to("cpu")
|
|
|
-
|
|
|
- out = cpu_model.inference(
|
|
|
- chunk,
|
|
|
- lang,
|
|
|
- gpt_cond_latent.to("cpu"),
|
|
|
- speaker_embedding.to("cpu"),
|
|
|
- )
|
|
|
-
|
|
|
- tts.synthesizer.tts_model.to("cuda")
|
|
|
-
|
|
|
- wav = out["wav"]
|
|
|
-
|
|
|
+ wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
|
|
|
+ except Exception as e:
|
|
|
+ print(f" [stream] chunk {i+1} failed: {e}")
|
|
|
+ continue # skip bad chunk rather than kill the stream
|
|
|
buf = io.BytesIO()
|
|
|
-
|
|
|
- sf.write(buf, wav, 24000, format="WAV")
|
|
|
-
|
|
|
+ sf.write(buf, wav, SAMPLE_RATE, format="WAV")
|
|
|
buf.seek(0)
|
|
|
-
|
|
|
yield buf.read()
|
|
|
|
|
|
- if torch.cuda.is_available():
|
|
|
- torch.cuda.empty_cache()
|
|
|
-
|
|
|
return StreamingResponse(audio_generator(), media_type="audio/wav")
|