|
|
@@ -1,13 +1,330 @@
|
|
|
-from fastapi import FastAPI, Query
|
|
|
-from fastapi.responses import FileResponse
|
|
|
+import os
|
|
|
+import io
|
|
|
+import hashlib
|
|
|
+import pickle
|
|
|
+import subprocess
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+import torch
|
|
|
+
|
|
|
+# FIX for PyTorch >=2.6 security change
|
|
|
+from torch.serialization import add_safe_globals
|
|
|
+from TTS.tts.configs.xtts_config import XttsConfig
|
|
|
+
|
|
|
+import TTS.tts.configs.xtts_config
|
|
|
+import TTS.tts.models.xtts
|
|
|
+
|
|
|
# PyTorch >= 2.6 switched torch.load to weights_only=True by default; the XTTS
# checkpoint pickles these config classes, so they must be allow-listed before
# the model is loaded below.
add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig
])
|
|
|
+
|
|
|
+from fastapi import FastAPI, HTTPException
|
|
|
+from fastapi.responses import StreamingResponse
|
|
|
from TTS.api import TTS
|
|
|
-import tempfile
|
|
|
+
|
|
|
# Reference voice clips, one per voice: <name>.wav and/or <name>.mp3.
VOICE_DIR = Path("/voices")
# Persisted speaker-embedding pickles, keyed by voice name.
CACHE_DIR = Path("/cache")

# Created at import time so first requests don't race on directory creation.
# NOTE(review): parents are not created — assumes / is writable or the dirs
# are container mounts; confirm deployment.
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

# Model is loaded once at import time (slow); GPU when available.
print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
|
|
|
+
|
|
|
+
|
|
|
def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*.

    The file is streamed in 8 KiB chunks so arbitrarily large voice
    clips can be hashed without loading them into memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        while block := handle.read(8192):
            digest.update(block)
    return digest.hexdigest()
|
|
|
+
|
|
|
+
|
|
|
def ensure_wav(voice_name):
    """Return the path to the reference WAV for *voice_name*.

    Converts <voice>.mp3 to <voice>.wav when no WAV exists, and
    re-converts when the MP3 has been modified after the WAV.
    Raises HTTPException(404) when neither file is present.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"

    if wav_path.exists():
        # A newer MP3 means the user replaced the source clip: refresh the WAV.
        if mp3_path.exists() and mp3_path.stat().st_mtime > wav_path.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path

    if mp3_path.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path

    raise HTTPException(404, f"Voice '{voice_name}' not found")
|
|
|
+
|
|
|
+
|
|
|
def convert_to_wav(src, dst):
    """Transcode *src* into a 22.05 kHz mono WAV at *dst* via ffmpeg.

    Runs ffmpeg silently (-y overwrites any existing output) and raises
    CalledProcessError if the conversion fails.
    """
    command = [
        "ffmpeg", "-y",
        "-i", str(src),
        "-ar", "22050",   # resample to 22.05 kHz
        "-ac", "1",       # downmix to mono
        str(dst),
    ]
    subprocess.run(
        command,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
|
|
|
+
|
|
|
def load_cached_embedding(cache_file):
    """Deserialize the cached embedding record stored at *cache_file*.

    NOTE(review): pickle is only safe because /cache is written exclusively
    by this service; never point this at untrusted files.
    """
    return pickle.loads(Path(cache_file).read_bytes())
|
|
|
+
|
|
|
+
|
|
|
def save_cached_embedding(cache_file, data):
    """Serialize *data* (a {"hash", "data"} record) to *cache_file*."""
    Path(cache_file).write_bytes(pickle.dumps(data))
|
|
|
+
|
|
|
def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Lookup order:
      1. in-process dict (``embedding_cache``)
      2. on-disk pickle in CACHE_DIR, accepted only when its stored SHA-256
         matches the current reference WAV
      3. freshly computed from the reference WAV, then persisted

    Raises HTTPException(404) (from ensure_wav) when no voice file exists.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # ensure_wav already performs the wav/mp3 lookup (converting if needed)
    # and raises the 404 itself, so the previous duplicate existence scan
    # and the commented-out conversion line have been removed.
    wav_file = ensure_wav(voice_name)

    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        cached = load_cached_embedding(cache_file)
        # Trust the on-disk cache only while the source WAV is unchanged.
        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")

    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )

    data = (gpt_cond_latent, speaker_embedding)

    save_cached_embedding(
        cache_file,
        {"hash": file_hash, "data": data},
    )

    embedding_cache[voice_name] = data

    return data
|
|
|
+
|
|
|
@app.get("/")
def root():
    """Health-check endpoint: confirms the server (and model) are up."""
    payload = {"status": "XTTS server running"}
    return payload
|
|
|
+
|
|
|
@app.get("/voices")
def list_voices():
    """Return the name (file stem) of every .wav/.mp3 clip in VOICE_DIR."""
    names = [
        entry.stem
        for entry in VOICE_DIR.iterdir()
        if entry.suffix in (".wav", ".mp3")
    ]
    return {"voices": names}
|
|
|
+
|
|
|
@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* in the cloned voice *voice*, returning one WAV.

    The input is split into sentence-aligned chunks of ~150 characters,
    each chunk is run through XTTS inference, and the resulting audio is
    concatenated before encoding. A per-chunk CUDA-OOM fallback reruns
    the failing chunk on CPU, then moves the model back to the GPU.
    """
    # torch and io are already imported at module level; the previous
    # redundant function-local imports of them have been removed.
    import numpy as np
    import soundfile as sf
    import re

    def chunk_text(text, max_len=150):
        # Split on sentence boundaries, then greedily pack sentences into
        # chunks of roughly max_len characters (a single over-long sentence
        # still becomes its own chunk).
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current = ""

        for s in sentences:
            if len(current) + len(s) > max_len:
                if current:
                    chunks.append(current.strip())
                current = s
            else:
                current += " " + s

        if current:
            chunks.append(current.strip())

        return chunks

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    text_chunks = chunk_text(text, max_len=150)

    wav_all = []

    for chunk in text_chunks:

        try:
            out = tts.synthesizer.tts_model.inference(
                chunk,
                lang,
                gpt_cond_latent,
                speaker_embedding,
            )

        except torch.cuda.OutOfMemoryError:
            # Retry this one chunk on CPU, then restore the model to CUDA.
            # NOTE(review): moving the model back right after an OOM may
            # OOM again on the next chunk — confirm this is acceptable.
            print("⚠ CUDA OOM – retrying chunk on CPU")

            torch.cuda.empty_cache()

            cpu_model = tts.synthesizer.tts_model.to("cpu")

            out = cpu_model.inference(
                chunk,
                lang,
                gpt_cond_latent.to("cpu"),
                speaker_embedding.to("cpu"),
            )

            tts.synthesizer.tts_model.to("cuda")

        wav_chunk = out["wav"]

        # assumes out["wav"] is a 1-D array of samples — TODO confirm for
        # the installed TTS version.
        if len(wav_chunk.shape) == 1:
            wav_chunk = np.expand_dims(wav_chunk, 1)

        wav_all.append(wav_chunk)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    wav = np.concatenate(wav_all, axis=0)

    buf = io.BytesIO()

    # 24000 Hz is XTTS v2's output sample rate.
    sf.write(buf, wav, 24000, format="WAV")

    buf.seek(0)

    return StreamingResponse(buf, media_type="audio/wav")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Stream synthesized audio for *text*, one WAV per text chunk.

    Same chunking and CUDA-OOM fallback as /tts, but each chunk's audio
    is yielded as soon as it is ready instead of being concatenated.

    NOTE(review): every yielded chunk is a complete WAV file with its own
    header; many players stop after the first header, so the concatenated
    stream may not decode as one file — verify against intended clients.
    """
    # torch and io are already imported at module level; the previous
    # redundant function-local imports of them have been removed.
    import numpy as np
    import soundfile as sf
    import re

    def chunk_text(text, max_len=150):
        # Split on sentence boundaries, then greedily pack sentences into
        # chunks of roughly max_len characters.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current = ""

        for s in sentences:
            if len(current) + len(s) > max_len:
                if current:
                    chunks.append(current.strip())
                current = s
            else:
                current += " " + s

        if current:
            chunks.append(current.strip())

        return chunks

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    text_chunks = chunk_text(text)

    def audio_generator():

        for chunk in text_chunks:

            try:
                out = tts.synthesizer.tts_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent,
                    speaker_embedding,
                )

            except torch.cuda.OutOfMemoryError:
                # Retry this one chunk on CPU, then restore the model to CUDA.
                print("CUDA OOM – retrying on CPU")

                torch.cuda.empty_cache()

                cpu_model = tts.synthesizer.tts_model.to("cpu")

                out = cpu_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )

                tts.synthesizer.tts_model.to("cuda")

            wav = out["wav"]

            buf = io.BytesIO()

            # 24000 Hz is XTTS v2's output sample rate.
            sf.write(buf, wav, 24000, format="WAV")

            buf.seek(0)

            yield buf.read()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return StreamingResponse(audio_generator(), media_type="audio/wav")
|