tts_server.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
import os
import io
import hashlib
import pickle
import subprocess
from pathlib import Path
import torch
# FIX for PyTorch >=2.6 security change: torch.load defaults to
# weights_only=True, so the XTTS config/model classes stored inside the
# checkpoint must be allow-listed before the model can be deserialized.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts
add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig
])
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Reference voice clips (.wav/.mp3) live in VOICE_DIR; computed speaker
# embeddings are pickled into CACHE_DIR. Both are created if missing.
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

# Model is loaded once at import time and shared by all requests;
# placed on GPU when available, CPU otherwise.
print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
  30. def sha256(path):
  31. h = hashlib.sha256()
  32. with open(path, "rb") as f:
  33. while True:
  34. chunk = f.read(8192)
  35. if not chunk:
  36. break
  37. h.update(chunk)
  38. return h.hexdigest()
  39. def ensure_wav(voice_name):
  40. wav = VOICE_DIR / f"{voice_name}.wav"
  41. mp3 = VOICE_DIR / f"{voice_name}.mp3"
  42. if wav.exists():
  43. if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
  44. print(f"MP3 newer than WAV → reconverting {voice_name}")
  45. convert_to_wav(mp3, wav)
  46. return wav
  47. if mp3.exists():
  48. print(f"Converting MP3 → WAV for {voice_name}")
  49. convert_to_wav(mp3, wav)
  50. return wav
  51. raise HTTPException(404, f"Voice '{voice_name}' not found")
  52. def convert_to_wav(src, dst):
  53. subprocess.run(
  54. [
  55. "ffmpeg",
  56. "-y",
  57. "-i",
  58. str(src),
  59. "-ar",
  60. "22050",
  61. "-ac",
  62. "1",
  63. str(dst),
  64. ],
  65. check=True,
  66. stdout=subprocess.DEVNULL,
  67. stderr=subprocess.DEVNULL,
  68. )
  69. def load_cached_embedding(cache_file):
  70. with open(cache_file, "rb") as f:
  71. return pickle.load(f)
  72. def save_cached_embedding(cache_file, data):
  73. with open(cache_file, "wb") as f:
  74. pickle.dump(data, f)
  75. def get_embedding(voice_name):
  76. if voice_name in embedding_cache:
  77. return embedding_cache[voice_name]
  78. src = None
  79. for ext in ["wav", "mp3"]:
  80. p = VOICE_DIR / f"{voice_name}.{ext}"
  81. if p.exists():
  82. src = p
  83. break
  84. if not src:
  85. raise HTTPException(404, f"Voice '{voice_name}' not found")
  86. wav_file = ensure_wav(voice_name)
  87. # wav_file = src if src.suffix == ".wav" else convert_to_wav(src)
  88. file_hash = sha256(wav_file)
  89. cache_file = CACHE_DIR / f"{voice_name}.pkl"
  90. if cache_file.exists():
  91. cached = load_cached_embedding(cache_file)
  92. if cached["hash"] == file_hash:
  93. print(f"Using cached embedding for {voice_name}")
  94. embedding_cache[voice_name] = cached["data"]
  95. return cached["data"]
  96. print(f"Computing embedding for {voice_name}")
  97. model = tts.synthesizer.tts_model
  98. gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
  99. audio_path=str(wav_file)
  100. )
  101. data = (gpt_cond_latent, speaker_embedding)
  102. save_cached_embedding(
  103. cache_file,
  104. {"hash": file_hash, "data": data},
  105. )
  106. embedding_cache[voice_name] = data
  107. return data
  108. @app.get("/")
  109. def root():
  110. return {"status": "XTTS server running"}
  111. @app.get("/voices")
  112. def list_voices():
  113. voices = []
  114. for f in VOICE_DIR.iterdir():
  115. if f.suffix in [".wav", ".mp3"]:
  116. voices.append(f.stem)
  117. return {"voices": voices}
  118. @app.get("/tts")
  119. @app.get("/api/tts")
  120. def synthesize(
  121. text: str,
  122. voice: str = "default",
  123. lang: str = "en",
  124. ):
  125. import numpy as np
  126. import torch
  127. import io
  128. import soundfile as sf
  129. import re
  130. def chunk_text(text, max_len=150):
  131. sentences = re.split(r'(?<=[.!?])\s+', text)
  132. chunks = []
  133. current = ""
  134. for s in sentences:
  135. if len(current) + len(s) > max_len:
  136. if current:
  137. chunks.append(current.strip())
  138. current = s
  139. else:
  140. current += " " + s
  141. if current:
  142. chunks.append(current.strip())
  143. return chunks
  144. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  145. text_chunks = chunk_text(text, max_len=150)
  146. wav_all = []
  147. for chunk in text_chunks:
  148. try:
  149. out = tts.synthesizer.tts_model.inference(
  150. chunk,
  151. lang,
  152. gpt_cond_latent,
  153. speaker_embedding,
  154. )
  155. except torch.cuda.OutOfMemoryError:
  156. print("⚠ CUDA OOM – retrying chunk on CPU")
  157. torch.cuda.empty_cache()
  158. cpu_model = tts.synthesizer.tts_model.to("cpu")
  159. out = cpu_model.inference(
  160. chunk,
  161. lang,
  162. gpt_cond_latent.to("cpu"),
  163. speaker_embedding.to("cpu"),
  164. )
  165. tts.synthesizer.tts_model.to("cuda")
  166. wav_chunk = out["wav"]
  167. if len(wav_chunk.shape) == 1:
  168. wav_chunk = np.expand_dims(wav_chunk, 1)
  169. wav_all.append(wav_chunk)
  170. if torch.cuda.is_available():
  171. torch.cuda.empty_cache()
  172. wav = np.concatenate(wav_all, axis=0)
  173. buf = io.BytesIO()
  174. sf.write(buf, wav, 24000, format="WAV")
  175. buf.seek(0)
  176. return StreamingResponse(buf, media_type="audio/wav")
  177. @app.get("/tts_stream")
  178. @app.get("/api/tts_stream")
  179. def synthesize_stream(
  180. text: str,
  181. voice: str = "default",
  182. lang: str = "en",
  183. ):
  184. import numpy as np
  185. import torch
  186. import soundfile as sf
  187. import re
  188. import io
  189. def chunk_text(text, max_len=150):
  190. sentences = re.split(r'(?<=[.!?])\s+', text)
  191. chunks = []
  192. current = ""
  193. for s in sentences:
  194. if len(current) + len(s) > max_len:
  195. if current:
  196. chunks.append(current.strip())
  197. current = s
  198. else:
  199. current += " " + s
  200. if current:
  201. chunks.append(current.strip())
  202. return chunks
  203. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  204. text_chunks = chunk_text(text)
  205. def audio_generator():
  206. for chunk in text_chunks:
  207. try:
  208. out = tts.synthesizer.tts_model.inference(
  209. chunk,
  210. lang,
  211. gpt_cond_latent,
  212. speaker_embedding,
  213. )
  214. except torch.cuda.OutOfMemoryError:
  215. print("CUDA OOM – retrying on CPU")
  216. torch.cuda.empty_cache()
  217. cpu_model = tts.synthesizer.tts_model.to("cpu")
  218. out = cpu_model.inference(
  219. chunk,
  220. lang,
  221. gpt_cond_latent.to("cpu"),
  222. speaker_embedding.to("cpu"),
  223. )
  224. tts.synthesizer.tts_model.to("cuda")
  225. wav = out["wav"]
  226. buf = io.BytesIO()
  227. sf.write(buf, wav, 24000, format="WAV")
  228. buf.seek(0)
  229. yield buf.read()
  230. if torch.cuda.is_available():
  231. torch.cuda.empty_cache()
  232. return StreamingResponse(audio_generator(), media_type="audio/wav")