reranker_server.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
"""
Local Reranker Server
=====================

Two-stage reranking architecture optimized for small GPUs.

Pipeline:
1) FlashRank CPU reranker (fast filtering)
2) GPU cross-encoder reranker (accurate final ranking)

Features
--------
✓ CPU-first architecture (safe for weak GPUs)
✓ GPU reranking when available
✓ automatic CUDA fallback
✓ VRAM-aware routing
✓ batching for low-VRAM GPUs
✓ safe FlashRank integration
✓ FastAPI REST interface
✓ health endpoint

Default API
-----------
POST /rerank
{
    "query": "...",
    "documents": ["doc1", "doc2"],
    "top_k": 5
}

Server
------
http://localhost:5200
"""
  30. from fastapi import FastAPI
  31. from pydantic import BaseModel
  32. from typing import List
  33. import torch
  34. from sentence_transformers import CrossEncoder
  35. from flashrank import Ranker, RerankRequest
# --------------------------------------------------
# Configuration
# --------------------------------------------------
# TCP port the server is documented to run on (see module docstring).
# NOTE(review): PORT is not referenced anywhere in this file — presumably
# consumed by an external launcher (e.g. `uvicorn reranker_server:app`);
# confirm before removing.
PORT = 5200
# Number of candidates the FlashRank CPU stage passes on to the GPU stage.
FIRST_STAGE_TOP_K = 10
# Documents scored per GPU forward pass; kept small for low-VRAM cards.
BATCH_SIZE = 8
# Maximum token length for cross-encoder (query, document) input pairs.
MAX_LENGTH = 256
# Minimum free VRAM (GiB) required before routing a request to the GPU.
MIN_GPU_MEMORY_GB = 1.5
# --------------------------------------------------
# Model initialization
# --------------------------------------------------
print("Loading FlashRank CPU model...")
# Stage-1 reranker: runs entirely on CPU and is used for every request.
cpu_ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2")
print("Loading GPU cross-encoder model...")
gpu_available = torch.cuda.is_available()
gpu_model = None  # stays None when CUDA is missing or initialization fails
if gpu_available:
    try:
        # Stage-2 reranker: more accurate, only used when VRAM allows.
        gpu_model = CrossEncoder(
            "cross-encoder/ms-marco-MiniLM-L-6-v2",
            device="cuda",
            max_length=MAX_LENGTH
        )
        print("GPU reranker loaded successfully")
    except Exception as e:
        # Broad catch is deliberate at startup: any GPU failure (driver
        # issue, OOM, model download) downgrades the server to CPU-only
        # instead of crashing the process.
        print("GPU initialization failed, running CPU-only:", e)
        gpu_model = None
        gpu_available = False
else:
    print("CUDA not available, running CPU-only")
# --------------------------------------------------
# API schema
# --------------------------------------------------
class RerankRequestModel(BaseModel):
    # Request payload for POST /rerank.
    query: str            # search query to rank documents against
    documents: List[str]  # candidate document texts
    top_k: int = 5        # number of ranked results to return
class RerankResult(BaseModel):
    # One ranked document in the /rerank response.
    text: str     # document text, unchanged from the request
    score: float  # relevance score (higher is better)
class RerankResponse(BaseModel):
    # Response payload for POST /rerank.
    results: List[RerankResult]  # ranked documents, best first
    backend: str                 # "gpu" or "cpu": which stage produced the scores
  83. # --------------------------------------------------
  84. # Utility functions
  85. # --------------------------------------------------
  86. def gpu_memory_available():
  87. """
  88. Check available VRAM before attempting GPU inference.
  89. Prevents CUDA OOM on small GPUs.
  90. """
  91. if not gpu_available:
  92. return False
  93. free, total = torch.cuda.mem_get_info()
  94. free_gb = free / (1024 ** 3)
  95. return free_gb > MIN_GPU_MEMORY_GB
  96. # --------------------------------------------------
  97. # Stage 1: FlashRank CPU
  98. # --------------------------------------------------
  99. def rerank_cpu_stage(query: str, docs: List[str]):
  100. """
  101. First-stage reranking using FlashRank.
  102. FlashRank is extremely fast and filters
  103. the candidate documents before GPU reranking.
  104. """
  105. passages = [{"text": d} for d in docs]
  106. request = RerankRequest(
  107. query=query,
  108. passages=passages
  109. )
  110. result = cpu_ranker.rerank(request)
  111. ranked = [
  112. (r.get("text", ""), float(r.get("score", 0)))
  113. for r in result
  114. ]
  115. return ranked
  116. # --------------------------------------------------
  117. # Stage 2: GPU cross-encoder
  118. # --------------------------------------------------
  119. def rerank_gpu_stage(query: str, docs: List[str]):
  120. """
  121. Second-stage reranking using GPU cross-encoder.
  122. """
  123. pairs = [(query, d) for d in docs]
  124. scores = []
  125. for i in range(0, len(pairs), BATCH_SIZE):
  126. batch = pairs[i:i + BATCH_SIZE]
  127. batch_scores = gpu_model.predict(batch)
  128. scores.extend(batch_scores.tolist())
  129. ranked = list(zip(docs, scores))
  130. ranked.sort(key=lambda x: x[1], reverse=True)
  131. return ranked
# --------------------------------------------------
# FastAPI app
# --------------------------------------------------
# Application instance exposing POST /rerank and GET /health.
app = FastAPI(
    title="Local Two-Stage Reranker",
    description="FlashRank CPU + MiniLM GPU reranking",
    version="2.0"
)
  140. @app.post("/rerank", response_model=RerankResponse)
  141. def rerank(request: RerankRequestModel):
  142. query = request.query
  143. docs = request.documents
  144. top_k = min(request.top_k, len(docs))
  145. # --------------------------------------------------
  146. # Stage 1: FlashRank CPU filtering
  147. # --------------------------------------------------
  148. first_stage = rerank_cpu_stage(query, docs)
  149. # select best candidates
  150. candidates = first_stage[:FIRST_STAGE_TOP_K]
  151. candidate_docs = [d for d, s in candidates]
  152. backend = "cpu"
  153. # --------------------------------------------------
  154. # Stage 2: GPU reranking (optional)
  155. # --------------------------------------------------
  156. if gpu_model and gpu_memory_available():
  157. try:
  158. second_stage = rerank_gpu_stage(query, candidate_docs)
  159. backend = "gpu"
  160. except (torch.cuda.OutOfMemoryError, RuntimeError):
  161. print("CUDA failure -> using CPU stage results")
  162. torch.cuda.empty_cache()
  163. second_stage = candidates
  164. backend = "cpu"
  165. else:
  166. second_stage = candidates
  167. # --------------------------------------------------
  168. # Final result selection
  169. # --------------------------------------------------
  170. final = second_stage[:top_k]
  171. results = [
  172. RerankResult(text=d, score=float(s))
  173. for d, s in final
  174. ]
  175. return RerankResponse(
  176. results=results,
  177. backend=backend
  178. )
  179. # --------------------------------------------------
  180. # Health endpoint
  181. # --------------------------------------------------
  182. @app.get("/health")
  183. def health():
  184. return {
  185. "status": "ok",
  186. "cuda_available": torch.cuda.is_available(),
  187. "gpu_model_loaded": gpu_model is not None
  188. }