"""
Local Reranker Server
=====================

Two-stage reranking architecture optimized for small GPUs.

Pipeline:
    1) FlashRank CPU reranker (fast filtering)
    2) GPU cross-encoder reranker (accurate final ranking)

Features
--------
- CPU-first architecture (safe for weak GPUs)
- GPU reranking when available
- Automatic CUDA fallback
- VRAM-aware routing
- Batching for low-VRAM GPUs
- Safe FlashRank integration
- FastAPI REST interface
- Health endpoint

Default API
-----------
POST /rerank
{
    "query": "...",
    "documents": ["doc1", "doc2"],
    "top_k": 5
}

Server
------
http://localhost:5200
"""
- from fastapi import FastAPI
- from pydantic import BaseModel
- from typing import List
- import torch
- from sentence_transformers import CrossEncoder
- from flashrank import Ranker, RerankRequest
# --------------------------------------------------
# Configuration
# --------------------------------------------------
# TCP port the FastAPI server is expected to listen on.
PORT = 5200

# Number of top candidates the FlashRank CPU stage passes on to the
# GPU cross-encoder stage (bounds GPU work per request).
FIRST_STAGE_TOP_K = 10

# Documents scored per GPU forward pass; kept small for low-VRAM cards.
BATCH_SIZE = 8

# Maximum token length for cross-encoder inputs (longer text truncated).
MAX_LENGTH = 256

# Minimum free VRAM (GiB) required before attempting GPU inference.
MIN_GPU_MEMORY_GB = 1.5
# --------------------------------------------------
# Model initialization
# --------------------------------------------------
# FlashRank always loads: it is the mandatory CPU first stage.
print("Loading FlashRank CPU model...")
cpu_ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2")

# The GPU cross-encoder is optional: it loads only when CUDA is
# available, and any initialization failure downgrades the server to
# CPU-only instead of crashing at startup.
print("Loading GPU cross-encoder model...")
gpu_available = torch.cuda.is_available()
gpu_model = None
if gpu_available:
    try:
        gpu_model = CrossEncoder(
            "cross-encoder/ms-marco-MiniLM-L-6-v2",
            device="cuda",
            max_length=MAX_LENGTH  # truncate long inputs to bound VRAM use
        )
        print("GPU reranker loaded successfully")
    except Exception as e:
        # Broad catch is deliberate: driver, VRAM, and model-download
        # errors all fall back to the CPU-only path.
        print("GPU initialization failed, running CPU-only:", e)
        gpu_model = None
        gpu_available = False
else:
    print("CUDA not available, running CPU-only")
# --------------------------------------------------
# API schema
# --------------------------------------------------
class RerankRequestModel(BaseModel):
    # Search query to score documents against.
    query: str
    # Candidate documents (plain text) to rerank.
    documents: List[str]
    # Number of top results to return; clamped to len(documents) by the handler.
    top_k: int = 5


class RerankResult(BaseModel):
    # Document text, unchanged from the request.
    text: str
    # Relevance score from whichever stage produced the final ranking.
    score: float


class RerankResponse(BaseModel):
    # Results sorted by descending score.
    results: List[RerankResult]
    # Backend that produced the final ordering: "cpu" or "gpu".
    backend: str
- # --------------------------------------------------
- # Utility functions
- # --------------------------------------------------
def gpu_memory_available():
    """
    Check whether enough free VRAM is available for GPU inference.

    Queries free memory on the current CUDA device and compares it
    against MIN_GPU_MEMORY_GB, preventing CUDA OOM on small GPUs.

    Returns:
        bool: True when CUDA is usable and free VRAM exceeds the
        configured threshold; False otherwise (including when the
        VRAM query itself fails).
    """
    if not gpu_available:
        return False
    try:
        # mem_get_info can raise RuntimeError if the CUDA context is
        # broken (e.g. after a hard OOM or driver reset); treat that
        # as "GPU unavailable" instead of crashing the request path.
        free, total = torch.cuda.mem_get_info()
    except RuntimeError as e:
        print("VRAM query failed, treating GPU as unavailable:", e)
        return False
    free_gb = free / (1024 ** 3)
    return free_gb > MIN_GPU_MEMORY_GB
- # --------------------------------------------------
- # Stage 1: FlashRank CPU
- # --------------------------------------------------
def rerank_cpu_stage(query: str, docs: List[str]):
    """
    First-stage reranking with FlashRank on CPU.

    Wraps each document in the passage-dict format FlashRank expects,
    runs the ranker, and returns (text, score) tuples in ranked order.
    """
    request = RerankRequest(
        query=query,
        passages=[{"text": doc} for doc in docs],
    )
    return [
        (entry.get("text", ""), float(entry.get("score", 0)))
        for entry in cpu_ranker.rerank(request)
    ]
- # --------------------------------------------------
- # Stage 2: GPU cross-encoder
- # --------------------------------------------------
def rerank_gpu_stage(query: str, docs: List[str]):
    """
    Second-stage reranking with the GPU cross-encoder.

    Scores (query, doc) pairs in small batches to stay within the
    VRAM budget, then returns the documents paired with their scores,
    sorted by descending score.
    """
    pairs = [(query, doc) for doc in docs]
    scores = []
    start = 0
    while start < len(pairs):
        batch = pairs[start:start + BATCH_SIZE]
        scores.extend(gpu_model.predict(batch).tolist())
        start += BATCH_SIZE
    return sorted(zip(docs, scores), key=lambda pair: pair[1], reverse=True)
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Single application instance; endpoints are registered via decorators below.
app = FastAPI(
    title="Local Two-Stage Reranker",
    description="FlashRank CPU + MiniLM GPU reranking",
    version="2.0"
)
@app.post("/rerank", response_model=RerankResponse)
def rerank(request: RerankRequestModel):
    """
    Two-stage rerank endpoint.

    Stage 1 filters all documents with the FlashRank CPU ranker; stage 2
    (when a GPU model is loaded and enough VRAM is free) rescores the top
    FIRST_STAGE_TOP_K candidates with the cross-encoder. Any CUDA failure
    falls back to the CPU-stage ordering.

    Returns a RerankResponse whose `backend` field reports which stage
    produced the final ordering ("cpu" or "gpu").
    """
    query = request.query
    docs = request.documents

    # Clamp top_k to [0, len(docs)] and short-circuit when there is
    # nothing to rank — calling FlashRank with an empty passage list
    # (or honoring a non-positive top_k) would be pointless or unsafe.
    top_k = max(0, min(request.top_k, len(docs)))
    if not docs or top_k == 0:
        return RerankResponse(results=[], backend="cpu")

    # --------------------------------------------------
    # Stage 1: FlashRank CPU filtering
    # --------------------------------------------------
    first_stage = rerank_cpu_stage(query, docs)
    candidates = first_stage[:FIRST_STAGE_TOP_K]
    candidate_docs = [d for d, s in candidates]
    backend = "cpu"

    # --------------------------------------------------
    # Stage 2: GPU reranking (optional, VRAM-gated)
    # --------------------------------------------------
    if gpu_model and gpu_memory_available():
        try:
            second_stage = rerank_gpu_stage(query, candidate_docs)
            backend = "gpu"
        except RuntimeError as e:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError, so
            # this single clause covers OOM and other CUDA failures.
            print("CUDA failure -> using CPU stage results:", e)
            torch.cuda.empty_cache()
            second_stage = candidates
            backend = "cpu"
    else:
        second_stage = candidates

    # --------------------------------------------------
    # Final result selection: best top_k (doc, score) pairs
    # --------------------------------------------------
    final = second_stage[:top_k]
    results = [
        RerankResult(text=d, score=float(s))
        for d, s in final
    ]
    return RerankResponse(
        results=results,
        backend=backend
    )
- # --------------------------------------------------
- # Health endpoint
- # --------------------------------------------------
@app.get("/health")
def health():
    """Liveness probe reporting CUDA availability and model load state."""
    cuda_ok = torch.cuda.is_available()
    model_loaded = gpu_model is not None
    return {
        "status": "ok",
        "cuda_available": cuda_ok,
        "gpu_model_loaded": model_loaded,
    }
|