# tts_server_unstable.py
import os
import io
import hashlib
import pickle
import subprocess
from pathlib import Path
import torch
import numpy as np
import soundfile as sf
import re

# FIX for PyTorch >=2.6 security change: torch.load defaults to
# weights_only=True, so the XTTS config/model classes must be allow-listed
# before the checkpoint can be deserialized.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
import TTS.tts.configs.xtts_config
import TTS.tts.models.xtts

add_safe_globals([
    TTS.tts.configs.xtts_config.XttsConfig,
    TTS.tts.models.xtts.XttsAudioConfig
])

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from TTS.api import TTS

# Set CPU threads BEFORE any torch operations - do this at module level
# (thread counts cannot be changed after torch starts parallel work).
torch.set_num_threads(os.cpu_count())
torch.set_num_interop_threads(max(1, os.cpu_count() // 2))
  26. def chunk_text(text, max_len=150):
  27. sentences = re.split(r'(?<=[.!?])\s+', text)
  28. chunks = []
  29. current = ""
  30. for s in sentences:
  31. if len(current) + len(s) > max_len:
  32. if current:
  33. chunks.append(current.strip())
  34. current = s
  35. else:
  36. current += " " + s
  37. if current:
  38. chunks.append(current.strip())
  39. return chunks
# Directory containing reference speaker audio (<name>.wav / <name>.mp3).
VOICE_DIR = Path("/voices")
# Directory for pickled speaker-embedding caches (<name>.pkl).
CACHE_DIR = Path("/cache")
VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

print("Loading XTTS model...")
# Load once at import time; placed on GPU when available, CPU otherwise.
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
  50. def sha256(path):
  51. h = hashlib.sha256()
  52. with open(path, "rb") as f:
  53. while True:
  54. chunk = f.read(8192)
  55. if not chunk:
  56. break
  57. h.update(chunk)
  58. return h.hexdigest()
  59. def ensure_wav(voice_name):
  60. wav = VOICE_DIR / f"{voice_name}.wav"
  61. mp3 = VOICE_DIR / f"{voice_name}.mp3"
  62. if wav.exists():
  63. if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
  64. print(f"MP3 newer than WAV → reconverting {voice_name}")
  65. convert_to_wav(mp3, wav)
  66. return wav
  67. if mp3.exists():
  68. print(f"Converting MP3 → WAV for {voice_name}")
  69. convert_to_wav(mp3, wav)
  70. return wav
  71. raise HTTPException(404, f"Voice '{voice_name}' not found")
  72. def convert_to_wav(src, dst):
  73. subprocess.run(
  74. [
  75. "ffmpeg",
  76. "-y",
  77. "-i",
  78. str(src),
  79. "-ar",
  80. "22050",
  81. "-ac",
  82. "1",
  83. str(dst),
  84. ],
  85. check=True,
  86. stdout=subprocess.DEVNULL,
  87. stderr=subprocess.DEVNULL,
  88. )
  89. def load_cached_embedding(cache_file):
  90. with open(cache_file, "rb") as f:
  91. return pickle.load(f)
  92. def save_cached_embedding(cache_file, data):
  93. with open(cache_file, "wb") as f:
  94. pickle.dump(data, f)
def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Resolution order:
      1. in-process memory cache (``embedding_cache``),
      2. on-disk pickle cache, validated against the WAV's SHA-256 so a
         changed reference recording invalidates the entry,
      3. freshly computed from the reference WAV, then written to both caches.

    Raises:
        HTTPException: 404 when no .wav/.mp3 reference exists for the voice.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # Locate the reference audio; prefer .wav over .mp3.
    src = None
    for ext in ["wav", "mp3"]:
        p = VOICE_DIR / f"{voice_name}.{ext}"
        if p.exists():
            src = p
            break
    if not src:
        raise HTTPException(404, f"Voice '{voice_name}' not found")

    # ensure_wav converts the MP3 (or reconverts a stale WAV) as needed.
    wav_file = ensure_wav(voice_name)

    # Content hash guards the disk cache against a changed recording.
    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"
    if cache_file.exists():
        cached = load_cached_embedding(cache_file)
        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")
    model = tts.synthesizer.tts_model
    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )
    data = (gpt_cond_latent, speaker_embedding)
    save_cached_embedding(
        cache_file,
        {"hash": file_hash, "data": data},
    )
    embedding_cache[voice_name] = data
    return data
  128. @app.get("/")
  129. def root():
  130. return {"status": "XTTS server running"}
  131. @app.get("/voices")
  132. def list_voices():
  133. voices = []
  134. for f in VOICE_DIR.iterdir():
  135. if f.suffix in [".wav", ".mp3"]:
  136. voices.append(f.stem)
  137. return {"voices": voices}
  138. @app.get("/tts")
  139. @app.get("/api/tts")
  140. def synthesize(
  141. text: str,
  142. voice: str = "default",
  143. lang: str = "en",
  144. ):
  145. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  146. text_chunks = chunk_text(text, max_len=150)
  147. wav_all = []
  148. # Move model to CPU once if GPU is already under pressure
  149. use_cpu_fallback = False
  150. if torch.cuda.is_available():
  151. free_mem, total_mem = torch.cuda.mem_get_info()
  152. if free_mem / total_mem < 0.2: # less than 20% VRAM free
  153. print("⚠ Low VRAM detected – using CPU for this request")
  154. use_cpu_fallback = True
  155. for chunk in text_chunks:
  156. try:
  157. if use_cpu_fallback:
  158. raise torch.cuda.OutOfMemoryError # skip straight to CPU path
  159. out = tts.synthesizer.tts_model.inference(
  160. chunk,
  161. lang,
  162. gpt_cond_latent,
  163. speaker_embedding,
  164. )
  165. except torch.cuda.OutOfMemoryError:
  166. print(f"⚠ CUDA OOM – retrying chunk on CPU ({os.cpu_count()} cores)")
  167. torch.cuda.empty_cache()
  168. use_cpu_fallback = True # stay on CPU for remaining chunks
  169. with torch.inference_mode(): # faster, no grad tracking
  170. cpu_model = tts.synthesizer.tts_model.to("cpu")
  171. out = cpu_model.inference(
  172. chunk,
  173. lang,
  174. gpt_cond_latent.to("cpu"),
  175. speaker_embedding.to("cpu"),
  176. )
  177. tts.synthesizer.tts_model.to("cuda")
  178. wav_chunk = out["wav"]
  179. if isinstance(wav_chunk, torch.Tensor):
  180. wav_chunk = wav_chunk.cpu().numpy()
  181. if len(wav_chunk.shape) == 1:
  182. wav_chunk = np.expand_dims(wav_chunk, 1)
  183. wav_all.append(wav_chunk)
  184. if torch.cuda.is_available() and not use_cpu_fallback:
  185. torch.cuda.empty_cache()
  186. wav = np.concatenate(wav_all, axis=0)
  187. buf = io.BytesIO()
  188. sf.write(buf, wav, 24000, format="WAV")
  189. buf.seek(0)
  190. return StreamingResponse(buf, media_type="audio/wav")
  191. @app.get("/tts_stream")
  192. @app.get("/api/tts_stream")
  193. def synthesize_stream(
  194. text: str,
  195. voice: str = "default",
  196. lang: str = "en",
  197. ):
  198. import numpy as np
  199. import torch
  200. import soundfile as sf
  201. import re
  202. import io
  203. def chunk_text(text, max_len=150):
  204. sentences = re.split(r'(?<=[.!?])\s+', text)
  205. chunks = []
  206. current = ""
  207. for s in sentences:
  208. if len(current) + len(s) > max_len:
  209. if current:
  210. chunks.append(current.strip())
  211. current = s
  212. else:
  213. current += " " + s
  214. if current:
  215. chunks.append(current.strip())
  216. return chunks
  217. gpt_cond_latent, speaker_embedding = get_embedding(voice)
  218. text_chunks = chunk_text(text)
  219. def audio_generator():
  220. for chunk in text_chunks:
  221. try:
  222. out = tts.synthesizer.tts_model.inference(
  223. chunk,
  224. lang,
  225. gpt_cond_latent,
  226. speaker_embedding,
  227. )
  228. except torch.cuda.OutOfMemoryError:
  229. print("CUDA OOM – retrying on CPU")
  230. torch.cuda.empty_cache()
  231. cpu_model = tts.synthesizer.tts_model.to("cpu")
  232. out = cpu_model.inference(
  233. chunk,
  234. lang,
  235. gpt_cond_latent.to("cpu"),
  236. speaker_embedding.to("cpu"),
  237. )
  238. tts.synthesizer.tts_model.to("cuda")
  239. wav = out["wav"]
  240. buf = io.BytesIO()
  241. sf.write(buf, wav, 24000, format="WAV")
  242. buf.seek(0)
  243. yield buf.read()
  244. if torch.cuda.is_available():
  245. torch.cuda.empty_cache()
  246. return StreamingResponse(audio_generator(), media_type="audio/wav")