"""XTTS v2 voice-cloning TTS server (FastAPI): /tts, /tts_stream, /voices."""
import hashlib
import io
import os
import pickle
import subprocess
from pathlib import Path

import torch
# PyTorch >= 2.6 defaults torch.load to weights_only=True; the XTTS
# checkpoint pickles config objects, so they must be allow-listed below.
from torch.serialization import add_safe_globals

from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Reference speaker clips live here as <name>.wav / <name>.mp3.
VOICE_DIR = Path("/voices")
# Pickled speaker-embedding cache, keyed by voice name.
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

print("Loading XTTS model...")
_device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(_device)
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*, read in 8 KiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while chunk := fh.read(8192):
            digest.update(chunk)
    return digest.hexdigest()
def ensure_wav(voice_name):
    """Return the path of a WAV reference clip for *voice_name*.

    Converts <name>.mp3 to <name>.wav via ffmpeg when no WAV exists, or when
    the MP3 is newer than the existing WAV (clip was updated). Raises
    HTTPException(404) when neither file is present.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"
    have_wav = wav_path.exists()
    have_mp3 = mp3_path.exists()

    if have_wav:
        # Re-convert when the MP3 source changed after the WAV was produced.
        if have_mp3 and mp3_path.stat().st_mtime > wav_path.stat().st_mtime:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path

    if have_mp3:
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path

    raise HTTPException(404, f"Voice '{voice_name}' not found")
def convert_to_wav(src, dst):
    """Convert *src* to a 22.05 kHz mono WAV at *dst* with ffmpeg.

    Overwrites *dst* (-y), silences ffmpeg output, and raises
    CalledProcessError on a non-zero exit (check=True).
    """
    cmd = ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)]
    subprocess.run(
        cmd,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
def load_cached_embedding(cache_file):
    """Unpickle and return the cached embedding record from *cache_file*.

    NOTE(review): pickle is acceptable here only because CACHE_DIR is written
    exclusively by this server; never point it at untrusted files.
    """
    return pickle.loads(Path(cache_file).read_bytes())
def save_cached_embedding(cache_file, data):
    """Pickle *data* to *cache_file*, replacing any previous cache entry."""
    Path(cache_file).write_bytes(pickle.dumps(data))
def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Lookup order:
      1. in-process ``embedding_cache`` dict,
      2. on-disk pickle cache, validated against the WAV's SHA-256 so a
         changed reference clip invalidates stale embeddings,
      3. fresh computation via the XTTS conditioning-latent extractor
         (then persisted to both caches).

    Raises HTTPException(404) when no reference clip exists for the voice.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # ensure_wav performs the mp3→wav conversion and raises the same 404 when
    # the voice is missing, so no separate existence pre-check is needed.
    wav_file = ensure_wav(voice_name)
    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():
        try:
            cached = load_cached_embedding(cache_file)
        except Exception as exc:
            # A corrupt or truncated cache file should trigger recomputation,
            # not a 500 on every request that uses this voice.
            print(f"Ignoring unreadable cache for {voice_name}: {exc}")
            cached = None
        if cached and cached.get("hash") == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    save_cached_embedding(
        cache_file,
        {"hash": file_hash, "data": data},
    )
    embedding_cache[voice_name] = data
    return data
@app.get("/")
def root():
    """Liveness probe: report that the server process is up."""
    return {"status": "XTTS server running"}
@app.get("/voices")
def list_voices():
    """List the stems of every .wav/.mp3 reference clip in VOICE_DIR."""
    stems = [p.stem for p in VOICE_DIR.iterdir() if p.suffix in (".wav", ".mp3")]
    return {"voices": stems}
@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* with the cloned voice *voice*; return one WAV file.

    Long inputs are split into sentence-ish chunks of ~150 chars so each
    inference call stays small; the chunk waveforms are concatenated before
    encoding. On CUDA OOM a chunk is retried on CPU and the model is then
    moved back to the GPU.
    """
    import io
    import re

    import numpy as np
    import soundfile as sf
    import torch

    def chunk_text(text, max_len=150):
        # Split on sentence-ending punctuation, then greedily pack sentences
        # into chunks of at most ~max_len characters.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current = ""
        for s in sentences:
            if len(current) + len(s) > max_len:
                if current:
                    chunks.append(current.strip())
                current = s
            else:
                current += " " + s
        if current:
            chunks.append(current.strip())
        return chunks

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    # Drop empty chunks: whitespace-only input previously yielded [""] and
    # fed an empty string to inference. Fail fast with a clear client error.
    text_chunks = [c for c in chunk_text(text, max_len=150) if c]
    if not text_chunks:
        raise HTTPException(400, "No text to synthesize")

    wav_all = []
    for chunk in text_chunks:
        try:
            out = tts.synthesizer.tts_model.inference(
                chunk,
                lang,
                gpt_cond_latent,
                speaker_embedding,
            )
        except torch.cuda.OutOfMemoryError:
            print("⚠ CUDA OOM – retrying chunk on CPU")
            torch.cuda.empty_cache()
            cpu_model = tts.synthesizer.tts_model.to("cpu")
            out = cpu_model.inference(
                chunk,
                lang,
                gpt_cond_latent.to("cpu"),
                speaker_embedding.to("cpu"),
            )
            tts.synthesizer.tts_model.to("cuda")
        wav_chunk = out["wav"]
        # Ensure (samples, channels) shape so axis-0 concatenation works.
        if len(wav_chunk.shape) == 1:
            wav_chunk = np.expand_dims(wav_chunk, 1)
        wav_all.append(wav_chunk)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    # XTTS emits 24 kHz audio.
    sf.write(buf, wav, 24000, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Stream synthesized audio chunk by chunk for lower time-to-first-audio.

    Each sentence chunk is synthesized and yielded as soon as it is ready.

    NOTE(review): every yielded chunk is a complete WAV file (header
    included), so the response body is a concatenation of WAV files rather
    than one continuous WAV. Many players tolerate this, but strict decoders
    may stop after the first chunk — confirm the intended client before
    changing the framing.
    """
    import io
    import re

    import soundfile as sf
    import torch

    def chunk_text(text, max_len=150):
        # Greedy sentence packing — same scheme as /tts.
        sentences = re.split(r'(?<=[.!?])\s+', text)
        chunks = []
        current = ""
        for s in sentences:
            if len(current) + len(s) > max_len:
                if current:
                    chunks.append(current.strip())
                current = s
            else:
                current += " " + s
        if current:
            chunks.append(current.strip())
        return chunks

    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    # Validate before streaming starts: once the generator is running, an
    # error can no longer change the HTTP status code.
    text_chunks = [c for c in chunk_text(text) if c]
    if not text_chunks:
        raise HTTPException(400, "No text to synthesize")

    def audio_generator():
        for chunk in text_chunks:
            try:
                out = tts.synthesizer.tts_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent,
                    speaker_embedding,
                )
            except torch.cuda.OutOfMemoryError:
                print("CUDA OOM – retrying on CPU")
                torch.cuda.empty_cache()
                cpu_model = tts.synthesizer.tts_model.to("cpu")
                out = cpu_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
                tts.synthesizer.tts_model.to("cuda")
            wav = out["wav"]
            buf = io.BytesIO()
            sf.write(buf, wav, 24000, format="WAV")
            buf.seek(0)
            yield buf.read()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return StreamingResponse(audio_generator(), media_type="audio/wav")