# tts_server_noncaching.py — multilingual Coqui XTTS v2 FastAPI server (no response caching)
import os
import tempfile

import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
from fastapi.responses import FileResponse, JSONResponse

from TTS.api import TTS
# Alternative approach kept for reference: allowlist the XTTS config class so
# torch.load could stay in safe weights_only mode instead of being patched.
#torch.serialization.add_safe_globals(["TTS.tts.configs.xtts_config.XttsConfig"])

# FastAPI application object; all endpoints below are registered on it.
app = FastAPI(title="Multilingual TTS Server")
  10. # Use full unpickling (PyTorch ≥2.6)
  11. # Only safe because we trust Coqui models
  12. torch_load_original = torch.load
  13. def torch_load_patch(*args, **kwargs):
  14. kwargs["weights_only"] = False
  15. return torch_load_original(*args, **kwargs)
  16. torch.load = torch_load_patch
# Load the multilingual XTTS v2 model once at startup and move it to the GPU.
# NOTE(review): assumes a CUDA device is present -- tts.to("cuda") will raise
# at import time otherwise; confirm every deployment target has a GPU.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
tts.to("cuda")
# Legacy constructor form, kept for reference:
#tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)

# Directory where the voice-sample WAV files used for cloning are stored.
VOICE_DIR = "/voices"
  23. # Helper: list available voices (WAV files in /voices)
  24. def list_voice_files():
  25. if not os.path.exists(VOICE_DIR):
  26. return []
  27. return [f for f in os.listdir(VOICE_DIR) if f.lower().endswith(".wav")]
  28. # Endpoint: list available voices
  29. @app.get("/voices")
  30. def get_voices():
  31. return JSONResponse(list_voice_files())
  32. # Endpoint: list supported languages
  33. @app.get("/languages")
  34. def get_languages():
  35. return JSONResponse(tts.languages)
  36. # Endpoint: list available speakers (from the model)
  37. #@app.get("/speakers")
  38. #def get_speakers():
  39. # return JSONResponse(tts.speakers)
  40. #@app.get("/api/speakers")
  41. #def speakers():
  42. # return {"speakers": tts.speakers}
  43. # Endpoint: TTS synthesis
  44. @app.get("/api/tts")
  45. def synthesize(
  46. text: str = Query(..., description="Text to speak"),
  47. lang: str = Query("en", description="Language code (e.g., en, de, es)"),
  48. speaker: str = Query(None, description="Speaker ID from model speakers"),
  49. voice_file: str = Query(None, description="Filename of WAV voice sample in /voices for cloning"),
  50. speed: float = Query(1.0, ge=0.5, le=2.0, description="Speech speed multiplier"),
  51. pitch: float = Query(1.0, ge=0.5, le=2.0, description="Pitch multiplier (approximate)")
  52. ):
  53. if voice_file:
  54. path = os.path.join(VOICE_DIR, voice_file)
  55. if not os.path.isfile(path):
  56. raise HTTPException(status_code=404, detail=f"Voice file '{voice_file}' not found")
  57. speaker_wav = path
  58. else:
  59. speaker_wav = None
  60. # Temporary output file
  61. tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
  62. # voice = request.args.get("voice", "default")
  63. # speaker_wav = f"voices/{voice}.wav"
  64. # Generate TTS
  65. tts.tts_to_file(
  66. text=text,
  67. file_path=tmp.name,
  68. language=lang,
  69. speaker_wav="/voices/trump.wav",
  70. )
  71. return FileResponse(tmp.name, media_type="audio/wav", filename="speech.wav")