tts_server_nochunks.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
import os
import io
import hashlib
import pickle
import subprocess
from pathlib import Path
import torch

# FIX for PyTorch >=2.6 security change: torch.load now refuses to
# unpickle arbitrary classes unless they are explicitly allow-listed.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

# Allow-list the Coqui config classes embedded in the XTTS checkpoint so
# the model weights can be deserialized under the stricter torch.load.
add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Mounted directories: reference voice clips and the on-disk embedding cache.
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

# Load the XTTS model once at startup; prefer GPU when available.
print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
  30. def sha256(path):
  31. h = hashlib.sha256()
  32. with open(path, "rb") as f:
  33. while True:
  34. chunk = f.read(8192)
  35. if not chunk:
  36. break
  37. h.update(chunk)
  38. return h.hexdigest()
  39. def ensure_wav(voice_name):
  40. wav = VOICE_DIR / f"{voice_name}.wav"
  41. mp3 = VOICE_DIR / f"{voice_name}.mp3"
  42. if wav.exists():
  43. if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
  44. print(f"MP3 newer than WAV → reconverting {voice_name}")
  45. convert_to_wav(mp3, wav)
  46. return wav
  47. if mp3.exists():
  48. print(f"Converting MP3 → WAV for {voice_name}")
  49. convert_to_wav(mp3, wav)
  50. return wav
  51. raise HTTPException(404, f"Voice '{voice_name}' not found")
  52. def convert_to_wav(src, dst):
  53. subprocess.run(
  54. [
  55. "ffmpeg",
  56. "-y",
  57. "-i",
  58. str(src),
  59. "-ar",
  60. "22050",
  61. "-ac",
  62. "1",
  63. str(dst),
  64. ],
  65. check=True,
  66. stdout=subprocess.DEVNULL,
  67. stderr=subprocess.DEVNULL,
  68. )
  69. def load_cached_embedding(cache_file):
  70. with open(cache_file, "rb") as f:
  71. return pickle.load(f)
  72. def save_cached_embedding(cache_file, data):
  73. with open(cache_file, "wb") as f:
  74. pickle.dump(data, f)
  75. def get_embedding(voice_name):
  76. if voice_name in embedding_cache:
  77. return embedding_cache[voice_name]
  78. src = None
  79. for ext in ["wav", "mp3"]:
  80. p = VOICE_DIR / f"{voice_name}.{ext}"
  81. if p.exists():
  82. src = p
  83. break
  84. if not src:
  85. raise HTTPException(404, f"Voice '{voice_name}' not found")
  86. wav_file = ensure_wav(voice_name)
  87. # wav_file = src if src.suffix == ".wav" else convert_to_wav(src)
  88. file_hash = sha256(wav_file)
  89. cache_file = CACHE_DIR / f"{voice_name}.pkl"
  90. if cache_file.exists():
  91. cached = load_cached_embedding(cache_file)
  92. if cached["hash"] == file_hash:
  93. print(f"Using cached embedding for {voice_name}")
  94. embedding_cache[voice_name] = cached["data"]
  95. return cached["data"]
  96. print(f"Computing embedding for {voice_name}")
  97. model = tts.synthesizer.tts_model
  98. gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
  99. audio_path=str(wav_file)
  100. )
  101. data = (gpt_cond_latent, speaker_embedding)
  102. save_cached_embedding(
  103. cache_file,
  104. {"hash": file_hash, "data": data},
  105. )
  106. embedding_cache[voice_name] = data
  107. return data
  108. @app.get("/")
  109. def root():
  110. return {"status": "XTTS server running"}
  111. @app.get("/voices")
  112. def list_voices():
  113. voices = []
  114. for f in VOICE_DIR.iterdir():
  115. if f.suffix in [".wav", ".mp3"]:
  116. voices.append(f.stem)
  117. return {"voices": voices}
  118. @app.get("/tts")
  119. @app.get("/api/tts")
  120. def synthesize(
  121. text: str,
  122. voice: str = "default",
  123. lang: str = "en",
  124. ):
  125. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  126. out = tts.synthesizer.tts_model.inference(
  127. text,
  128. lang,
  129. gpt_cond_latent,
  130. speaker_embedding,
  131. )
  132. buf = io.BytesIO()
  133. import numpy as np
  134. import soundfile as sf
  135. # XTTS returns a dict, extract numpy array
  136. wav = out["wav"]
  137. # Ensure 2D for soundfile
  138. if len(wav.shape) == 1:
  139. wav = np.expand_dims(wav, 1)
  140. sf.write(buf, wav, 24000, format="WAV")
  141. # Ensure the audio is 2D: (samples, channels)
  142. #if len(wav.shape) == 1:
  143. # wav = np.expand_dims(wav, 1)
  144. #sf.write(buf, wav, 24000, format="WAV")
  145. #import soundfile as sf
  146. #sf.write(buf, wav, 24000, format="WAV", subtype="PCM_16")
  147. #sf.write(buf, wav, 24000, format="WAV")
  148. buf.seek(0)
  149. return StreamingResponse(buf, media_type="audio/wav")