| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889 |
import os
import tempfile

import torch
from fastapi import BackgroundTasks, FastAPI, HTTPException, Query
from fastapi.responses import FileResponse, JSONResponse
from TTS.api import TTS
- # Allow XTTS classes for torch.load
- #torch.serialization.add_safe_globals(["TTS.tts.configs.xtts_config.XttsConfig"])
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(
    title="Multilingual TTS Server",
)
# Use full unpickling (PyTorch >= 2.6).
# SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
# This is only acceptable because the checkpoints loaded here are trusted
# Coqui models; never point this process at untrusted checkpoint files.
torch_load_original = torch.load


def torch_load_patch(*args, **kwargs):
    """Delegate to the original ``torch.load`` with ``weights_only=False``.

    Any ``weights_only`` value supplied by the caller is overridden; all
    other positional and keyword arguments are passed through unchanged.
    """
    forced_kwargs = {**kwargs, "weights_only": False}
    return torch_load_original(*args, **forced_kwargs)


# Install the wrapper globally so every later torch.load call goes through it.
torch.load = torch_load_patch
# Load the XTTS v2 model once at import time; all endpoints share this instance.
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
# Prefer the GPU, but fall back to the CPU so the server can still start on a
# host without CUDA instead of failing during import.
tts.to("cuda" if torch.cuda.is_available() else "cpu")
# Directory where voice samples are stored. Defaults to "/voices"; set the
# VOICE_DIR environment variable to use a different location. An empty value
# is treated as unset.
VOICE_DIR = os.environ.get("VOICE_DIR") or "/voices"
# Helper: list available voices (WAV files in the voice directory)
def list_voice_files(voice_dir=None):
    """Return the sorted names of the WAV files in the voice directory.

    Args:
        voice_dir: Directory to scan. Defaults to the module-level
            ``VOICE_DIR`` when omitted.

    Returns:
        A sorted list of filenames (not full paths) whose name ends in
        ``.wav``, compared case-insensitively. The list is empty when the
        directory does not exist or the path is not a directory.
    """
    directory = VOICE_DIR if voice_dir is None else voice_dir
    # Attempt the listing directly instead of checking existence first: that
    # also covers a path that exists but is a regular file, and avoids a
    # check-then-use race.
    try:
        entries = os.listdir(directory)
    except (FileNotFoundError, NotADirectoryError):
        return []
    # Keep regular files only, so a sub-directory named "*.wav" is not
    # advertised as a usable voice sample.
    return sorted(
        name
        for name in entries
        if name.lower().endswith(".wav")
        and os.path.isfile(os.path.join(directory, name))
    )
# Endpoint: list available voices
@app.get("/voices")
def get_voices():
    """Respond with the output of ``list_voice_files()`` as a JSON body."""
    voice_names = list_voice_files()
    return JSONResponse(content=voice_names)
# Endpoint: list supported languages
@app.get("/languages")
def get_languages():
    """Respond with the loaded model's ``languages`` attribute as a JSON body."""
    supported_languages = tts.languages
    return JSONResponse(content=supported_languages)
- # Endpoint: list available speakers (from the model)
- #@app.get("/speakers")
- #def get_speakers():
- # return JSONResponse(tts.speakers)
- #@app.get("/api/speakers")
- #def speakers():
- # return {"speakers": tts.speakers}
# Endpoint: TTS synthesis
# Voice sample used when a request names neither a voice file nor a speaker.
# This keeps the previous behaviour for callers that only pass `text`.
_DEFAULT_VOICE_FILE = "trump.wav"


def _resolve_voice_path(voice_file):
    """Map a client-supplied voice filename to a path inside ``VOICE_DIR``.

    Raises:
        HTTPException: 400 if the name contains a directory component,
            404 if no such regular file exists in ``VOICE_DIR``.
    """
    # Accept bare filenames only, so values such as "../x.wav" or an absolute
    # path cannot make os.path.join point outside VOICE_DIR.
    if os.path.basename(voice_file) != voice_file:
        raise HTTPException(status_code=400, detail=f"Invalid voice file name '{voice_file}'")
    path = os.path.join(VOICE_DIR, voice_file)
    if not os.path.isfile(path):
        raise HTTPException(status_code=404, detail=f"Voice file '{voice_file}' not found")
    return path


@app.get("/api/tts")
def synthesize(
    text: str = Query(..., description="Text to speak"),
    lang: str = Query("en", description="Language code (e.g., en, de, es)"),
    speaker: str = Query(None, description="Speaker ID from model speakers"),
    voice_file: str = Query(None, description="Filename of WAV voice sample in /voices for cloning"),
    speed: float = Query(1.0, ge=0.5, le=2.0, description="Speech speed multiplier"),
    pitch: float = Query(1.0, ge=0.5, le=2.0, description="Pitch multiplier (approximate)")
):
    """Synthesize ``text`` to speech and return it as a WAV download.

    Voice selection, in priority order:
      1. ``voice_file``: a WAV sample in ``VOICE_DIR`` used as the cloning
         reference.
      2. ``speaker``: passed to the model as the speaker name.
      3. Neither given: the default sample ``_DEFAULT_VOICE_FILE``.

    NOTE(review): ``speed`` and ``pitch`` are validated but not forwarded to
    the model; they currently have no effect on the generated audio.

    Returns:
        A ``FileResponse`` (``audio/wav``, filename ``speech.wav``). The
        temporary file behind it is deleted after the response is sent.

    Raises:
        HTTPException: 400 for a voice filename containing a path component,
            404 when the requested (or default) voice sample does not exist.
    """
    if voice_file:
        speaker_wav = _resolve_voice_path(voice_file)
        speaker_name = None
    elif speaker:
        speaker_wav = None
        speaker_name = speaker
    else:
        speaker_wav = _resolve_voice_path(_DEFAULT_VOICE_FILE)
        speaker_name = None

    # mkstemp returns an open descriptor; close it right away because the
    # model writes to the file by path.
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)

    # Generate TTS
    try:
        tts.tts_to_file(
            text=text,
            file_path=out_path,
            language=lang,
            speaker=speaker_name,
            speaker_wav=speaker_wav,
        )
    except Exception:
        # Do not leave the temp file behind when synthesis fails.
        os.remove(out_path)
        raise

    # Delete the temp file once the response body has been sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(os.remove, out_path)
    return FileResponse(
        out_path,
        media_type="audio/wav",
        filename="speech.wav",
        background=cleanup,
    )
|