import os
import io
import hashlib
import pickle
import subprocess
from pathlib import Path

import torch
# PyTorch >= 2.6 defaults torch.load to weights_only=True; allow-list the
# XTTS config classes so the model checkpoint can be deserialized.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig,
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Reference voice clips live in /voices; computed embeddings are cached in /cache.
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

print("Loading XTTS model...")
device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS(MODEL_NAME).to(device)
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*, read in 8 KiB chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        while chunk := fh.read(8192):
            digest.update(chunk)
    return digest.hexdigest()
def ensure_wav(voice_name):
    """Return the path to the WAV reference clip for *voice_name*.

    The WAV is treated as an artifact derived from the MP3: when no WAV
    exists it is converted from a sibling MP3, and when the MP3 has been
    modified more recently than the WAV it is re-converted.

    Raises:
        HTTPException: 404 when neither a .wav nor a .mp3 file exists.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"

    if wav_path.exists():
        mp3_is_newer = (
            mp3_path.exists()
            and mp3_path.stat().st_mtime > wav_path.stat().st_mtime
        )
        if mp3_is_newer:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path

    if mp3_path.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path

    raise HTTPException(404, f"Voice '{voice_name}' not found")
def convert_to_wav(src, dst):
    """Transcode *src* to a 22.05 kHz mono WAV at *dst* using ffmpeg.

    ffmpeg output is suppressed; "-y" overwrites an existing *dst*.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero.
    """
    cmd = ["ffmpeg", "-y", "-i", str(src), "-ar", "22050", "-ac", "1", str(dst)]
    subprocess.run(
        cmd,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
def load_cached_embedding(cache_file):
    """Unpickle and return a cached embedding record from *cache_file*.

    NOTE(review): pickle is acceptable here only because CACHE_DIR is
    written exclusively by this process — never load untrusted pickles.
    """
    with open(cache_file, "rb") as fh:
        return pickle.load(fh)
def save_cached_embedding(cache_file, data):
    """Pickle *data* (a {"hash": ..., "data": ...} record) to *cache_file*."""
    with open(cache_file, "wb") as fh:
        pickle.dump(data, fh)
def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Results are cached at two levels: in-process in ``embedding_cache``,
    and on disk in CACHE_DIR keyed by the SHA-256 of the reference WAV,
    so replacing a reference clip invalidates the stale pickle.

    Raises:
        HTTPException: 404 (via ensure_wav) when no reference audio exists.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # ensure_wav converts MP3 -> WAV as needed and raises the 404 itself,
    # so no separate existence scan is required here.
    wav_file = ensure_wav(voice_name)
    file_hash = sha256(wav_file)

    cache_file = CACHE_DIR / f"{voice_name}.pkl"
    if cache_file.exists():
        cached = load_cached_embedding(cache_file)
        # Only trust the pickle if it was computed from this exact audio.
        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    save_cached_embedding(cache_file, {"hash": file_hash, "data": data})
    embedding_cache[voice_name] = data
    return data
@app.get("/")
def root():
    """Health-check endpoint confirming the server is up."""
    return {"status": "XTTS server running"}
@app.get("/voices")
def list_voices():
    """Return the available voice names (file stems under VOICE_DIR).

    A voice present as both .wav and .mp3 is reported once (the original
    appended both, producing duplicates), and the list is sorted so the
    response is stable across filesystem iteration orders.
    """
    stems = {
        f.stem for f in VOICE_DIR.iterdir() if f.suffix in (".wav", ".mp3")
    }
    return {"voices": sorted(stems)}
@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesize *text* in the cloned voice and stream it back as WAV.

    Args:
        text: Text to speak.
        voice: Voice name; must match a clip in VOICE_DIR.
        lang: Language code passed to the XTTS model.

    Returns:
        StreamingResponse with audio/wav at 24 kHz (XTTS output rate).

    Raises:
        HTTPException: 404 (via get_embedding) for an unknown voice.
    """
    # Local imports keep these heavy deps off the module import path.
    import numpy as np
    import soundfile as sf

    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    out = tts.synthesizer.tts_model.inference(
        text,
        lang,
        gpt_cond_latent,
        speaker_embedding,
    )

    # XTTS returns a dict; "wav" holds the raw sample array.
    wav = out["wav"]
    # soundfile expects (samples, channels); promote mono 1-D to 2-D.
    if wav.ndim == 1:
        wav = np.expand_dims(wav, 1)

    buf = io.BytesIO()
    sf.write(buf, wav, 24000, format="WAV")
    buf.seek(0)
    return StreamingResponse(buf, media_type="audio/wav")