Quellcode durchsuchen

stability improved

Lukas Goldschmidt vor 2 Tagen
Ursprung
Commit
40810df9e6
2 geänderte Dateien mit 547 neuen und 225 gelöschten Zeilen
  1. 204 225
      tts_server.py
  2. 343 0
      tts_server_unstable.py

+ 204 - 225
tts_server.py

@@ -1,330 +1,309 @@
 import os
 import io
+import re
 import hashlib
 import pickle
 import subprocess
+import threading
 from pathlib import Path
 
+import numpy as np
+import soundfile as sf
 import torch
 
+# Set CPU threads BEFORE any torch operations
+torch.set_num_threads(os.cpu_count())
+torch.set_num_interop_threads(max(1, os.cpu_count() // 2))
+
 # FIX for PyTorch >=2.6 security change
 from torch.serialization import add_safe_globals
-from TTS.tts.configs.xtts_config import XttsConfig
-
 import TTS.tts.configs.xtts_config
 import TTS.tts.models.xtts
-
 add_safe_globals([
     TTS.tts.configs.xtts_config.XttsConfig,
-    TTS.tts.models.xtts.XttsAudioConfig
+    TTS.tts.models.xtts.XttsAudioConfig,
 ])
 
 from fastapi import FastAPI, HTTPException
 from fastapi.responses import StreamingResponse
 from TTS.api import TTS
 
-VOICE_DIR = Path("/voices")
-CACHE_DIR = Path("/cache")
+# ─── Paths & constants ────────────────────────────────────────────────────────
+
+VOICE_DIR  = Path("/voices")
+CACHE_DIR  = Path("/cache")
+MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
+SAMPLE_RATE = 24000
+VRAM_HEADROOM = 0.20          # fall back to CPU when VRAM < 20% free
+MAX_CHUNK_LEN = 200           # chars; XTTS hard-limit is ~400 tokens ≈ 250 chars
 
 VOICE_DIR.mkdir(exist_ok=True)
 CACHE_DIR.mkdir(exist_ok=True)
 
-MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"
+# ─── Model loading ────────────────────────────────────────────────────────────
 
 print("Loading XTTS model...")
-tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
-print("Model loaded.")
+_device = "cuda" if torch.cuda.is_available() else "cpu"
+tts = TTS(MODEL_NAME).to(_device)
+print(f"Model loaded on {_device}.")
 
-app = FastAPI()
+# Single lock so concurrent requests don't fight over GPU / model.to() calls
+_model_lock = threading.Lock()
 
-embedding_cache = {}
-
-
-def sha256(path):
+app = FastAPI()
+embedding_cache: dict = {}
+
+# ─── Text helpers ─────────────────────────────────────────────────────────────
+
+# Characters XTTS tokeniser chokes on → strip or replace before inference
_MARKDOWN_RE   = re.compile(r'\*{1,2}|_{1,2}|`+|#{1,6}\s?|~~|\[([^\]]*)\]\([^)]*\)')
_MULTI_SPACE   = re.compile(r'  +')
_CONTROL_CHARS = re.compile(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]')

def clean_text(text: str) -> str:
    """Normalise raw input for the XTTS tokeniser.

    Strips markdown markers (keeping a link's label text), removes ASCII
    control characters, unifies line endings to ``\\n``, and collapses runs
    of blank lines and spaces.
    """
    # Markdown markers vanish; group 1 preserves the label of [label](url).
    without_md = _MARKDOWN_RE.sub(r'\1', text)
    without_ctrl = _CONTROL_CHARS.sub('', without_md)
    unified = without_ctrl.replace('\r\n', '\n').replace('\r', '\n')
    # At most one blank line in a row, single spaces between words.
    collapsed = re.sub(r'\n{3,}', '\n\n', unified)
    collapsed = _MULTI_SPACE.sub(' ', collapsed)
    return collapsed.strip()
+
+
def chunk_text(text: str, max_len: int = MAX_CHUNK_LEN) -> list[str]:
    """Break cleaned text into chunks of at most *max_len* characters.

    Splitting prefers sentence boundaries; a single sentence that itself
    exceeds *max_len* (e.g. text with no punctuation) is broken at word
    boundaries instead.
    """
    pieces: list[str] = []
    pending = ""

    def flush() -> None:
        # Emit the accumulated chunk, if any, and reset the accumulator.
        nonlocal pending
        if pending:
            pieces.append(pending)
            pending = ""

    for sentence in re.split(r'(?<=[.!?…])\s+', clean_text(text)):
        sentence = sentence.strip()
        if not sentence:
            continue

        # Oversized sentence: flush, then fall back to word-level splitting.
        if len(sentence) > max_len:
            flush()
            fragment = ""
            for word in sentence.split():
                if len(fragment) + len(word) + 1 > max_len:
                    if fragment:
                        pieces.append(fragment.strip())
                    fragment = word
                else:
                    fragment = (fragment + " " + word).strip()
            if fragment:
                pieces.append(fragment)
            continue

        # Normal path: pack sentences until the next one would overflow.
        if len(pending) + len(sentence) + 1 > max_len:
            flush()
            pending = sentence
        else:
            pending = (pending + " " + sentence).strip()

    flush()
    return [p for p in pieces if p]
+
+
+# ─── Voice / embedding helpers ────────────────────────────────────────────────
+
+def sha256_file(path: Path) -> str:
     h = hashlib.sha256()
     with open(path, "rb") as f:
-        while True:
-            chunk = f.read(8192)
-            if not chunk:
-                break
-            h.update(chunk)
+        for block in iter(lambda: f.read(65536), b""):
+            h.update(block)
     return h.hexdigest()
 
 
-def ensure_wav(voice_name):
def convert_to_wav(src: Path, dst: Path) -> None:
    """Transcode *src* into a 22.05 kHz mono WAV at *dst* using ffmpeg.

    Raises subprocess.CalledProcessError if ffmpeg exits non-zero; all
    ffmpeg console output is discarded.
    """
    command = [
        "ffmpeg", "-y",
        "-i", str(src),
        "-ar", "22050",
        "-ac", "1",
        str(dst),
    ]
    subprocess.run(
        command,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
+
 
+def ensure_wav(voice_name: str) -> Path:
     wav = VOICE_DIR / f"{voice_name}.wav"
     mp3 = VOICE_DIR / f"{voice_name}.mp3"
-
     if wav.exists():
-
         if mp3.exists() and mp3.stat().st_mtime > wav.stat().st_mtime:
             print(f"MP3 newer than WAV → reconverting {voice_name}")
             convert_to_wav(mp3, wav)
-
         return wav
-
     if mp3.exists():
         print(f"Converting MP3 → WAV for {voice_name}")
         convert_to_wav(mp3, wav)
         return wav
-
     raise HTTPException(404, f"Voice '{voice_name}' not found")
 
 
-def convert_to_wav(src, dst):
-
-    subprocess.run(
-        [
-            "ffmpeg",
-            "-y",
-            "-i",
-            str(src),
-            "-ar",
-            "22050",
-            "-ac",
-            "1",
-            str(dst),
-        ],
-        check=True,
-        stdout=subprocess.DEVNULL,
-        stderr=subprocess.DEVNULL,
-    )
-
-def load_cached_embedding(cache_file):
-    with open(cache_file, "rb") as f:
-        return pickle.load(f)
-
-
-def save_cached_embedding(cache_file, data):
-    with open(cache_file, "wb") as f:
-        pickle.dump(data, f)
-
-def get_embedding(voice_name):
-
+def get_embedding(voice_name: str):
     if voice_name in embedding_cache:
         return embedding_cache[voice_name]
 
-    src = None
-
-    for ext in ["wav", "mp3"]:
-        p = VOICE_DIR / f"{voice_name}.{ext}"
-        if p.exists():
-            src = p
-            break
-
-    if not src:
-        raise HTTPException(404, f"Voice '{voice_name}' not found")
-
-    wav_file = ensure_wav(voice_name)
-    # wav_file = src if src.suffix == ".wav" else convert_to_wav(src)
-
-    file_hash = sha256(wav_file)
+    wav_file  = ensure_wav(voice_name)
+    file_hash = sha256_file(wav_file)
     cache_file = CACHE_DIR / f"{voice_name}.pkl"
 
     if cache_file.exists():
-
-        cached = load_cached_embedding(cache_file)
-
-        if cached["hash"] == file_hash:
-            print(f"Using cached embedding for {voice_name}")
-            embedding_cache[voice_name] = cached["data"]
-            return cached["data"]
+        try:
+            with open(cache_file, "rb") as f:
+                cached = pickle.load(f)
+            if cached.get("hash") == file_hash:
+                print(f"Using cached embedding for {voice_name}")
+                embedding_cache[voice_name] = cached["data"]
+                return cached["data"]
+        except Exception as e:
+            print(f"Cache read error for {voice_name}: {e} – recomputing")
 
     print(f"Computing embedding for {voice_name}")
-
     model = tts.synthesizer.tts_model
-
     gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
         audio_path=str(wav_file)
     )
-
     data = (gpt_cond_latent, speaker_embedding)
+    with open(cache_file, "wb") as f:
+        pickle.dump({"hash": file_hash, "data": data}, f)
+    embedding_cache[voice_name] = data
+    return data
 
-    save_cached_embedding(
-        cache_file,
-        {"hash": file_hash, "data": data},
-    )
 
-    embedding_cache[voice_name] = data
+# ─── Core inference helper ────────────────────────────────────────────────────
+
+def _vram_low() -> bool:
+    if not torch.cuda.is_available():
+        return True
+    free, total = torch.cuda.mem_get_info()
+    return (free / total) < VRAM_HEADROOM
+
+
def _infer_chunk(chunk: str, lang: str, gpt_cond_latent, speaker_embedding) -> np.ndarray:
    """Run inference for one chunk; falls back to CPU on OOM.

    Returns a 2-D float array of shape (samples, 1) — 1-D model output is
    expanded on the last axis.  All inference runs under the module-wide
    _model_lock so concurrent requests cannot interleave device moves.
    """
    model = tts.synthesizer.tts_model

    def _run(m, lat, emb):
        # inference_mode: no autograd bookkeeping for pure synthesis.
        with torch.inference_mode():
            out = m.inference(chunk, lang, lat, emb)
        wav = out["wav"]
        if isinstance(wav, torch.Tensor):
            wav = wav.cpu().numpy()
        # Normalise to (samples, 1) so callers can concatenate on axis 0.
        if wav.ndim == 1:
            wav = np.expand_dims(wav, 1)
        return wav

    with _model_lock:
        try:
            return _run(model, gpt_cond_latent, speaker_embedding)
        except torch.cuda.OutOfMemoryError:
            # GPU ran out mid-request: free what we can, retry this one
            # chunk on CPU with CPU copies of the conditioning tensors.
            print(f"⚠ CUDA OOM on chunk – falling back to CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            model.to("cpu")
            try:
                result = _run(
                    model,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
            finally:
                # Always move back, even if CPU inference also fails
                model.to("cuda")
                torch.cuda.empty_cache()
            return result
 
-    return data
+
+# ─── Routes ───────────────────────────────────────────────────────────────────
 
 @app.get("/")
 def root():
-    return {"status": "XTTS server running"}
+    return {"status": "XTTS server running", "device": _device}
+
 
 @app.get("/voices")
 def list_voices():
+    seen = set()
     voices = []
     for f in VOICE_DIR.iterdir():
-        if f.suffix in [".wav", ".mp3"]:
+        if f.suffix in {".wav", ".mp3"} and f.stem not in seen:
             voices.append(f.stem)
-    return {"voices": voices}
+            seen.add(f.stem)
+    return {"voices": sorted(voices)}
+
 
 @app.get("/tts")
 @app.get("/api/tts")
-def synthesize(
-    text: str,
-    voice: str = "default",
-    lang: str = "en",
-):
-
-    import numpy as np
-    import torch
-    import io
-    import soundfile as sf
-    import re
-
-    def chunk_text(text, max_len=150):
-        sentences = re.split(r'(?<=[.!?])\s+', text)
-        chunks = []
-        current = ""
-
-        for s in sentences:
-            if len(current) + len(s) > max_len:
-                if current:
-                    chunks.append(current.strip())
-                current = s
-            else:
-                current += " " + s
-
-        if current:
-            chunks.append(current.strip())
-
-        return chunks
+def synthesize(text: str, voice: str = "default", lang: str = "en"):
+    if not text.strip():
+        raise HTTPException(400, "text parameter is empty")
 
     gpt_cond_latent, speaker_embedding = get_embedding(voice)
 
-    text_chunks = chunk_text(text, max_len=150)
+    # If VRAM is already scarce, pin embeddings on CPU for this whole request
+    use_cpu = _vram_low()
+    if use_cpu and torch.cuda.is_available():
+        print("⚠ Low VRAM – pinning entire request to CPU")
+        gpt_cond_latent  = gpt_cond_latent.to("cpu")
+        speaker_embedding = speaker_embedding.to("cpu")
+        with _model_lock:
+            tts.synthesizer.tts_model.to("cpu")
 
+    chunks  = chunk_text(text)
     wav_all = []
 
-    for chunk in text_chunks:
-
+    for i, chunk in enumerate(chunks):
+        print(f"  chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
         try:
-            out = tts.synthesizer.tts_model.inference(
-                chunk,
-                lang,
-                gpt_cond_latent,
-                speaker_embedding,
-            )
-
-        except torch.cuda.OutOfMemoryError:
-
-            print("⚠ CUDA OOM – retrying chunk on CPU")
-
-            torch.cuda.empty_cache()
-
-            cpu_model = tts.synthesizer.tts_model.to("cpu")
-
-            out = cpu_model.inference(
-                chunk,
-                lang,
-                gpt_cond_latent.to("cpu"),
-                speaker_embedding.to("cpu"),
-            )
-
-            tts.synthesizer.tts_model.to("cuda")
-
-        wav_chunk = out["wav"]
-
-        if len(wav_chunk.shape) == 1:
-            wav_chunk = np.expand_dims(wav_chunk, 1)
-
+            wav_chunk = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
+        except Exception as e:
+            raise HTTPException(500, f"Inference failed on chunk {i+1}: {e}")
         wav_all.append(wav_chunk)
 
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    # Restore model to GPU if we moved it
+    if use_cpu and torch.cuda.is_available():
+        with _model_lock:
+            tts.synthesizer.tts_model.to("cuda")
 
     wav = np.concatenate(wav_all, axis=0)
-
     buf = io.BytesIO()
-
-    sf.write(buf, wav, 24000, format="WAV")
-
+    sf.write(buf, wav, SAMPLE_RATE, format="WAV")
     buf.seek(0)
-
     return StreamingResponse(buf, media_type="audio/wav")
 
 
-
 @app.get("/tts_stream")
 @app.get("/api/tts_stream")
-def synthesize_stream(
-    text: str,
-    voice: str = "default",
-    lang: str = "en",
-):
-
-    import numpy as np
-    import torch
-    import soundfile as sf
-    import re
-    import io
-
-    def chunk_text(text, max_len=150):
-        sentences = re.split(r'(?<=[.!?])\s+', text)
-        chunks = []
-        current = ""
-
-        for s in sentences:
-            if len(current) + len(s) > max_len:
-                if current:
-                    chunks.append(current.strip())
-                current = s
-            else:
-                current += " " + s
-
-        if current:
-            chunks.append(current.strip())
-
-        return chunks
+def synthesize_stream(text: str, voice: str = "default", lang: str = "en"):
+    """Stream WAV chunks as they are synthesised — lower latency for long texts."""
+    if not text.strip():
+        raise HTTPException(400, "text parameter is empty")
 
     gpt_cond_latent, speaker_embedding = get_embedding(voice)
-
-    text_chunks = chunk_text(text)
+    chunks = chunk_text(text)
 
     def audio_generator():
-
-        for chunk in text_chunks:
-
+        for i, chunk in enumerate(chunks):
+            print(f"  [stream] chunk {i+1}/{len(chunks)}: {chunk[:60]!r}")
             try:
-                out = tts.synthesizer.tts_model.inference(
-                    chunk,
-                    lang,
-                    gpt_cond_latent,
-                    speaker_embedding,
-                )
-
-            except torch.cuda.OutOfMemoryError:
-
-                print("CUDA OOM – retrying on CPU")
-
-                torch.cuda.empty_cache()
-
-                cpu_model = tts.synthesizer.tts_model.to("cpu")
-
-                out = cpu_model.inference(
-                    chunk,
-                    lang,
-                    gpt_cond_latent.to("cpu"),
-                    speaker_embedding.to("cpu"),
-                )
-
-                tts.synthesizer.tts_model.to("cuda")
-
-            wav = out["wav"]
-
+                wav = _infer_chunk(chunk, lang, gpt_cond_latent, speaker_embedding)
+            except Exception as e:
+                print(f"  [stream] chunk {i+1} failed: {e}")
+                continue          # skip bad chunk rather than kill the stream
             buf = io.BytesIO()
-
-            sf.write(buf, wav, 24000, format="WAV")
-
+            sf.write(buf, wav, SAMPLE_RATE, format="WAV")
             buf.seek(0)
-
             yield buf.read()
 
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     return StreamingResponse(audio_generator(), media_type="audio/wav")

+ 343 - 0
tts_server_unstable.py

@@ -0,0 +1,343 @@
+import os
+import io
+import hashlib
+import pickle
+import subprocess
+from pathlib import Path
+
+import torch
+
+import numpy as np
+import soundfile as sf
+import re
+
+
+# FIX for PyTorch >=2.6 security change
+from torch.serialization import add_safe_globals
+from TTS.tts.configs.xtts_config import XttsConfig
+
+import TTS.tts.configs.xtts_config
+import TTS.tts.models.xtts
+
+add_safe_globals([
+    TTS.tts.configs.xtts_config.XttsConfig,
+    TTS.tts.models.xtts.XttsAudioConfig
+])
+
+from fastapi import FastAPI, HTTPException
+from fastapi.responses import StreamingResponse
+from TTS.api import TTS
+
+
+
+# Set CPU threads BEFORE any torch operations - do this at module level
+torch.set_num_threads(os.cpu_count())
+torch.set_num_interop_threads(max(1, os.cpu_count() // 2))
+
def chunk_text(text, max_len=150):
    """Split *text* into chunks of at most *max_len* characters.

    Splits on sentence-ending punctuation first.  A single sentence that
    exceeds *max_len* (e.g. text without punctuation) is further split on
    word boundaries, so no chunk — other than a lone oversized word — can
    exceed the limit.  Previously such sentences passed through unsplit,
    producing inputs longer than the model handles reliably.

    Returns an empty list for empty/whitespace-only input (was [""]).
    """
    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks = []
    current = ""

    for s in sentences:
        s = s.strip()
        if not s:
            continue

        if len(s) > max_len:
            # Flush the accumulator, then break the long sentence on words.
            if current:
                chunks.append(current)
                current = ""
            part = ""
            for w in s.split():
                # +1 accounts for the joining space.
                if part and len(part) + 1 + len(w) > max_len:
                    chunks.append(part)
                    part = w
                else:
                    part = f"{part} {w}".strip()
            if part:
                chunks.append(part)
            continue

        if current and len(current) + 1 + len(s) > max_len:
            chunks.append(current)
            current = s
        else:
            current = f"{current} {s}".strip()

    if current:
        chunks.append(current)

    return chunks
+
+
+
# Runtime directories: reference voice clips and pickled embedding cache.
VOICE_DIR = Path("/voices")
CACHE_DIR = Path("/cache")

VOICE_DIR.mkdir(exist_ok=True)
CACHE_DIR.mkdir(exist_ok=True)

MODEL_NAME = "tts_models/multilingual/multi-dataset/xtts_v2"

# Model is loaded once at import time and shared by all requests.
print("Loading XTTS model...")
tts = TTS(MODEL_NAME).to("cuda" if torch.cuda.is_available() else "cpu")
print("Model loaded.")

app = FastAPI()

# In-process cache: voice name -> (gpt_cond_latent, speaker_embedding).
embedding_cache = {}
+
+
def sha256(path):
    """Return the hex SHA-256 digest of the file at *path*, read in 8 KiB blocks."""
    digest = hashlib.sha256()
    with open(path, "rb") as fh:
        for block in iter(lambda: fh.read(8192), b""):
            digest.update(block)
    return digest.hexdigest()
+
+
def ensure_wav(voice_name):
    """Return the path to a WAV reference clip for *voice_name*.

    Converts the sibling MP3 when no WAV exists, or when the MP3 was
    modified more recently than the WAV.  Raises an HTTP 404 when the
    voice has neither file.
    """
    wav_path = VOICE_DIR / f"{voice_name}.wav"
    mp3_path = VOICE_DIR / f"{voice_name}.mp3"

    if wav_path.exists():
        mp3_is_newer = (
            mp3_path.exists()
            and mp3_path.stat().st_mtime > wav_path.stat().st_mtime
        )
        if mp3_is_newer:
            print(f"MP3 newer than WAV → reconverting {voice_name}")
            convert_to_wav(mp3_path, wav_path)
        return wav_path

    if mp3_path.exists():
        print(f"Converting MP3 → WAV for {voice_name}")
        convert_to_wav(mp3_path, wav_path)
        return wav_path

    raise HTTPException(404, f"Voice '{voice_name}' not found")
+
+
def convert_to_wav(src, dst):
    """Transcode *src* to a 22.05 kHz mono WAV at *dst* via ffmpeg.

    Raises subprocess.CalledProcessError on ffmpeg failure; ffmpeg's
    console output is suppressed.
    """
    command = [
        "ffmpeg", "-y",
        "-i", str(src),
        "-ar", "22050",
        "-ac", "1",
        str(dst),
    ]
    subprocess.run(
        command,
        check=True,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
+
def load_cached_embedding(cache_file):
    """Deserialise and return the embedding record stored in *cache_file*."""
    with open(cache_file, "rb") as handle:
        payload = pickle.load(handle)
    return payload
+
+
def save_cached_embedding(cache_file, data):
    """Serialise *data* to *cache_file* with pickle, overwriting any existing file."""
    with open(cache_file, "wb") as handle:
        pickle.dump(data, handle)
+
def get_embedding(voice_name):
    """Return (gpt_cond_latent, speaker_embedding) for *voice_name*.

    Lookup order: in-process dict cache → on-disk pickle cache (validated
    against the WAV's SHA-256) → fresh computation via the model, which is
    then written back to both caches.  Raises HTTP 404 when the voice has
    no .wav/.mp3 file.
    """
    if voice_name in embedding_cache:
        return embedding_cache[voice_name]

    # NOTE(review): this scan is redundant — ensure_wav() below performs
    # the same existence check and raises the same 404 itself.
    src = None

    for ext in ["wav", "mp3"]:
        p = VOICE_DIR / f"{voice_name}.{ext}"
        if p.exists():
            src = p
            break

    if not src:
        raise HTTPException(404, f"Voice '{voice_name}' not found")

    wav_file = ensure_wav(voice_name)
    # wav_file = src if src.suffix == ".wav" else convert_to_wav(src)

    # Hash the WAV so a re-recorded voice invalidates the disk cache.
    file_hash = sha256(wav_file)
    cache_file = CACHE_DIR / f"{voice_name}.pkl"

    if cache_file.exists():

        # NOTE(review): a corrupt pickle here raises instead of recomputing.
        cached = load_cached_embedding(cache_file)

        if cached["hash"] == file_hash:
            print(f"Using cached embedding for {voice_name}")
            embedding_cache[voice_name] = cached["data"]
            return cached["data"]

    print(f"Computing embedding for {voice_name}")

    model = tts.synthesizer.tts_model

    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
        audio_path=str(wav_file)
    )

    data = (gpt_cond_latent, speaker_embedding)

    save_cached_embedding(
        cache_file,
        {"hash": file_hash, "data": data},
    )

    embedding_cache[voice_name] = data

    return data
+
@app.get("/")
def root():
    """Liveness probe: confirms the server process is up and the model loaded."""
    return dict(status="XTTS server running")
+
@app.get("/voices")
def list_voices():
    """List available voice names (file stems under VOICE_DIR).

    A voice present as both .wav and .mp3 is reported once (previously it
    appeared twice); results are sorted for a stable API response.
    """
    stems = {
        f.stem
        for f in VOICE_DIR.iterdir()
        if f.suffix in {".wav", ".mp3"}
    }
    return {"voices": sorted(stems)}
+
@app.get("/tts")
@app.get("/api/tts")
def synthesize(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Synthesise *text* with *voice* and return one complete WAV response.

    Text is split into ≤150-char chunks, each chunk is synthesised, and the
    chunks are concatenated into a single 24 kHz WAV stream.
    """
    gpt_cond_latent, speaker_embedding = get_embedding(voice)
    text_chunks = chunk_text(text, max_len=150)
    wav_all = []

    # Move model to CPU once if GPU is already under pressure
    use_cpu_fallback = False
    if torch.cuda.is_available():
        free_mem, total_mem = torch.cuda.mem_get_info()
        if free_mem / total_mem < 0.2:  # less than 20% VRAM free
            print("⚠ Low VRAM detected – using CPU for this request")
            use_cpu_fallback = True

    for chunk in text_chunks:
        try:
            if use_cpu_fallback:
                raise torch.cuda.OutOfMemoryError  # skip straight to CPU path

            out = tts.synthesizer.tts_model.inference(
                chunk,
                lang,
                gpt_cond_latent,
                speaker_embedding,
            )

        except torch.cuda.OutOfMemoryError:
            print(f"⚠ CUDA OOM – retrying chunk on CPU ({os.cpu_count()} cores)")
            torch.cuda.empty_cache()
            use_cpu_fallback = True  # stay on CPU for remaining chunks

            with torch.inference_mode():  # faster, no grad tracking
                cpu_model = tts.synthesizer.tts_model.to("cpu")
                out = cpu_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )

            # NOTE(review): the model is moved back to "cuda" here while
            # use_cpu_fallback stays True, so every remaining chunk raises
            # the synthetic OOM above and migrates the model cpu↔cuda again
            # — a per-chunk device thrash.  The stable server keeps the
            # model on CPU for the rest of the request instead.
            tts.synthesizer.tts_model.to("cuda")

        wav_chunk = out["wav"]

        if isinstance(wav_chunk, torch.Tensor):
            wav_chunk = wav_chunk.cpu().numpy()

        # Normalise to (samples, 1) so np.concatenate on axis 0 works.
        if len(wav_chunk.shape) == 1:
            wav_chunk = np.expand_dims(wav_chunk, 1)

        wav_all.append(wav_chunk)

        if torch.cuda.is_available() and not use_cpu_fallback:
            torch.cuda.empty_cache()

    wav = np.concatenate(wav_all, axis=0)
    buf = io.BytesIO()
    sf.write(buf, wav, 24000, format="WAV")
    buf.seek(0)

    return StreamingResponse(buf, media_type="audio/wav")
+
+
+
@app.get("/tts_stream")
@app.get("/api/tts_stream")
def synthesize_stream(
    text: str,
    voice: str = "default",
    lang: str = "en",
):
    """Stream synthesised WAV audio chunk by chunk.

    Each sentence-sized chunk is synthesised and yielded as soon as it is
    ready, lowering time-to-first-audio for long texts.  On CUDA OOM a
    chunk is retried on CPU and the model is moved back to the GPU.

    The previous version re-imported numpy/torch/soundfile/re/io inside the
    function and defined a local chunk_text identical to the module-level
    one; both duplicates are removed with no behavior change.
    """
    gpt_cond_latent, speaker_embedding = get_embedding(voice)

    text_chunks = chunk_text(text)

    def audio_generator():
        for chunk in text_chunks:
            try:
                out = tts.synthesizer.tts_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent,
                    speaker_embedding,
                )

            except torch.cuda.OutOfMemoryError:
                print("CUDA OOM – retrying on CPU")
                torch.cuda.empty_cache()
                cpu_model = tts.synthesizer.tts_model.to("cpu")
                out = cpu_model.inference(
                    chunk,
                    lang,
                    gpt_cond_latent.to("cpu"),
                    speaker_embedding.to("cpu"),
                )
                tts.synthesizer.tts_model.to("cuda")

            wav = out["wav"]

            # Each yielded item is a self-contained 24 kHz WAV blob.
            buf = io.BytesIO()
            sf.write(buf, wav, 24000, format="WAV")
            buf.seek(0)
            yield buf.read()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return StreamingResponse(audio_generator(), media_type="audio/wav")