Coverage for mindsdb / integrations / utilities / rag / rerankers / base_reranker.py: 15%
333 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1from __future__ import annotations
3import re
4import json
5import asyncio
6import logging
7import math
8import os
9import random
10from abc import ABC
11from typing import Any, List, Optional, Tuple
13from openai import AsyncOpenAI, AsyncAzureOpenAI
14from pydantic import BaseModel
16from mindsdb.integrations.utilities.rag.settings import (
17 DEFAULT_RERANKING_MODEL,
18 DEFAULT_LLM_ENDPOINT,
19 DEFAULT_RERANKER_N,
20 DEFAULT_RERANKER_LOGPROBS,
21 DEFAULT_RERANKER_TOP_LOGPROBS,
22 DEFAULT_RERANKER_MAX_TOKENS,
23 DEFAULT_VALID_CLASS_TOKENS,
24 RerankerMode,
25)
26from mindsdb.integrations.libs.base import BaseMLEngine
28log = logging.getLogger(__name__)
def get_event_loop():
    """Return the running asyncio loop, or create and install a fresh one.

    ``asyncio.get_running_loop`` raises ``RuntimeError`` when called outside a
    coroutine; in that case a new loop is created and registered as the current
    event loop for this thread so later callers can reuse it.
    """
    try:
        return asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread -- build one and make it current.
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop
class BaseLLMReranker(BaseModel, ABC):
    """LLM-backed pointwise reranker base class.

    Scores (query, document) pairs by prompting a chat-completion model and
    converting the model's answer (optionally using token log-probabilities)
    into a relevance score.
    """

    filtering_threshold: float = 0.0  # Default threshold for filtering
    provider: str = "openai"  # 'openai', 'azure_openai', 'ollama', or any litellm provider
    model: str = DEFAULT_RERANKING_MODEL  # Model to use for reranking
    temperature: float = 0.0  # Temperature for the model
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None  # Only used by the azure_openai provider
    num_docs_to_keep: Optional[int] = None  # How many of the top documents to keep after reranking & compressing.
    method: str = "multi-class"  # Scoring method: 'multi-class' or 'binary'
    mode: RerankerMode = RerankerMode.POINTWISE
    _api_key_var: str = "OPENAI_API_KEY"  # Env var consulted when api_key is not supplied
    client: Optional[AsyncOpenAI | BaseMLEngine] = None
    _semaphore: Optional[asyncio.Semaphore] = None  # Caps in-flight LLM requests
    max_concurrent_requests: int = 20
    max_retries: int = 4
    retry_delay: float = 1.0  # Base delay (seconds) for exponential backoff
    request_timeout: float = 20.0  # Timeout for API requests
    early_stop: bool = True  # Whether to enable early stopping
    early_stop_threshold: float = 0.8  # Confidence threshold for early stopping
    n: int = DEFAULT_RERANKER_N  # Number of completions to generate
    logprobs: bool = DEFAULT_RERANKER_LOGPROBS  # Whether to include log probabilities
    top_logprobs: int = DEFAULT_RERANKER_TOP_LOGPROBS  # Number of top log probabilities to include
    max_tokens: int = DEFAULT_RERANKER_MAX_TOKENS  # Maximum tokens to generate
    valid_class_tokens: List[str] = DEFAULT_VALID_CLASS_TOKENS  # Tokens accepted as class labels

    class Config:
        # Allow non-pydantic types (AsyncOpenAI client, BaseMLEngine) and extra kwargs.
        arbitrary_types_allowed = True
        extra = "allow"

    def __init__(self, **kwargs):
        """Build the model, then set up the request semaphore and the LLM client."""
        super().__init__(**kwargs)
        self._semaphore = asyncio.Semaphore(self.max_concurrent_requests)
        self._init_client()
76 def _init_client(self):
77 if self.client is None:
78 if self.provider == "azure_openai":
79 azure_api_key = self.api_key or os.getenv("AZURE_OPENAI_API_KEY")
80 azure_api_endpoint = self.base_url or os.environ.get("AZURE_OPENAI_ENDPOINT")
81 azure_api_version = self.api_version or os.environ.get("AZURE_OPENAI_API_VERSION")
82 self.client = AsyncAzureOpenAI(
83 api_key=azure_api_key,
84 azure_endpoint=azure_api_endpoint,
85 api_version=azure_api_version,
86 timeout=self.request_timeout,
87 max_retries=2,
88 )
89 elif self.provider in ("openai", "ollama"):
90 if self.provider == "ollama":
91 self.method = "no-logprobs"
92 if self.api_key is None:
93 self.api_key = "n/a"
95 api_key_var: str = "OPENAI_API_KEY"
96 openai_api_key = self.api_key or os.getenv(api_key_var)
97 if not openai_api_key:
98 raise ValueError(f"OpenAI API key not found in environment variable {api_key_var}")
100 base_url = self.base_url or DEFAULT_LLM_ENDPOINT
101 self.client = AsyncOpenAI(
102 api_key=openai_api_key, base_url=base_url, timeout=self.request_timeout, max_retries=2
103 )
104 else:
105 # try to use litellm
106 from mindsdb.api.executor.controllers.session_controller import SessionController
108 session = SessionController()
109 module = session.integration_controller.get_handler_module("litellm")
111 if module is None or module.Handler is None:
112 raise ValueError(f'Unable to use "{self.provider}" provider. Litellm handler is not installed')
114 self.client = module.Handler
115 self.method = "no-logprobs"
117 async def _call_llm(self, messages):
118 if self.provider in ("azure_openai", "openai", "ollama"):
119 return await self.client.chat.completions.create(
120 model=self.model,
121 messages=messages,
122 )
123 else:
124 kwargs = self.model_extra.copy()
126 if self.api_key is not None:
127 kwargs["api_key"] = self.api_key
129 return await self.client.acompletion(self.provider, model=self.model, messages=messages, args=kwargs)
131 async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
132 ranked_results = []
134 # Process in larger batches for better throughput
135 batch_size = min(self.max_concurrent_requests * 2, len(query_document_pairs))
136 for i in range(0, len(query_document_pairs), batch_size):
137 batch = query_document_pairs[i : i + batch_size]
139 results = await asyncio.gather(
140 *[
141 self._backoff_wrapper(query=query, document=document, rerank_callback=rerank_callback)
142 for (query, document) in batch
143 ],
144 return_exceptions=True,
145 )
147 for idx, result in enumerate(results):
148 if isinstance(result, Exception):
149 log.error(f"Error processing document {i + idx}: {str(result)}")
150 raise RuntimeError(f"Error during reranking: {result}") from result
152 score = result["relevance_score"]
154 ranked_results.append((batch[idx][1], score))
156 # Check if we should stop early
157 try:
158 high_scoring_docs = [r for r in ranked_results if r[1] >= self.filtering_threshold]
159 can_stop_early = (
160 self.early_stop # Early stopping is enabled
161 and self.num_docs_to_keep # We have a target number of docs
162 and len(high_scoring_docs) >= self.num_docs_to_keep # Found enough good docs
163 and score >= self.early_stop_threshold # Current doc is good enough
164 )
166 if can_stop_early:
167 log.info(f"Early stopping after finding {self.num_docs_to_keep} documents with high confidence")
168 return ranked_results
169 except Exception as e:
170 # Don't let early stopping errors stop the whole process
171 log.warning(f"Error in early stopping check: {e}")
173 return ranked_results
175 async def _backoff_wrapper(self, query: str, document: str, rerank_callback=None) -> Any:
176 async with self._semaphore:
177 for attempt in range(self.max_retries):
178 try:
179 if self.method == "multi-class":
180 rerank_data = await self.search_relevancy_score(query, document)
181 elif self.method == "no-logprobs":
182 rerank_data = await self.search_relevancy_no_logprob(query, document)
183 else:
184 rerank_data = await self.search_relevancy(query, document)
185 if rerank_callback is not None:
186 rerank_callback(rerank_data)
187 return rerank_data
189 except Exception as e:
190 if attempt == self.max_retries - 1:
191 log.error(f"Failed after {self.max_retries} attempts: {str(e)}")
192 raise
193 # Exponential backoff with jitter
194 retry_delay = self.retry_delay * (2**attempt) + random.uniform(0, 0.1)
195 await asyncio.sleep(retry_delay)
197 async def search_relevancy(self, query: str, document: str) -> Any:
198 response = await self.client.chat.completions.create(
199 model=self.model,
200 messages=[
201 {
202 "role": "system",
203 "content": "Rate the relevance of the document to the query. Respond with 'yes' or 'no'.",
204 },
205 {"role": "user", "content": f"Query: {query}\nDocument: {document}\nIs this document relevant?"},
206 ],
207 temperature=self.temperature,
208 n=1,
209 logprobs=True,
210 max_tokens=1,
211 )
213 # Extract response and logprobs
214 answer = response.choices[0].message.content
215 logprob = response.choices[0].logprobs.content[0].logprob
217 # Convert answer to score using the model's confidence
218 if answer.lower().strip() == "yes":
219 score = logprob # If yes, use the model's confidence
220 elif answer.lower().strip() == "no":
221 score = 1 - logprob # If no, invert the confidence
222 else:
223 score = 0.5 * logprob # For unclear answers, reduce confidence
225 rerank_data = {
226 "document": document,
227 "relevance_score": score,
228 }
230 return rerank_data
232 async def search_relevancy_no_logprob(self, query: str, document: str) -> Any:
233 prompt = (
234 f"Score the relevance between search query and user message on scale between 0 and 100 per cents. "
235 f"Consider semantic meaning, key concepts, and contextual relevance. "
236 f"Return ONLY a numerical score between 0 and 100 per cents. No other text. Stop after sending a number. "
237 f"Search query: {query}"
238 )
240 response = await self._call_llm(
241 messages=[{"role": "system", "content": prompt}, {"role": "user", "content": document}],
242 )
244 answer = response.choices[0].message.content
246 try:
247 value = re.findall(r"[\d]+", answer)[0]
248 score = float(value) / 100
249 score = max(0.0, min(score, 1.0))
250 except (ValueError, IndexError):
251 score = 0.0
253 rerank_data = {
254 "document": document,
255 "relevance_score": score,
256 }
258 return rerank_data
    async def search_relevancy_score(self, query: str, document: str) -> Any:
        """
        This method is used to score the relevance of a document to a query.

        A four-class prompt ("class_1".."class_4") is sent with logprobs
        enabled; the top-logprob distribution over the class token is turned
        into a probability-weighted score in [0, 1].

        Args:
            query: The query to score the relevance of.
            document: The document to score the relevance of.

        Returns:
            A dictionary with the document and the relevance score.
        """
        log.debug("Start search_relevancy_score")
        log.debug(f"Reranker query: {query[:5]}")
        log.debug(f"Reranker document: {document[:50]}")
        log.debug(f"Reranker model: {self.model}")
        log.debug(f"Reranker temperature: {self.temperature}")
        log.debug(f"Reranker n: {self.n}")
        log.debug(f"Reranker logprobs: {self.logprobs}")
        log.debug(f"Reranker top_logprobs: {self.top_logprobs}")
        log.debug(f"Reranker max_tokens: {self.max_tokens}")
        log.debug(f"Reranker valid_class_tokens: {self.valid_class_tokens}")

        response = await self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": """
You are an intelligent assistant that evaluates how relevant a given document chunk is to a user's search query.
Your task is to analyze the similarity between the search query and the document chunk, and return **only the class label** that best represents the relevance:

- "class_1": Not relevant (score between 0.0 and 0.25)
- "class_2": Slightly relevant (score between 0.25 and 0.5)
- "class_3": Moderately relevant (score between 0.5 and 0.75)
- "class_4": Highly relevant (score between 0.75 and 1.0)

Respond with only one of: "class_1", "class_2", "class_3", or "class_4".

Examples:

Search query: "How to reset a router to factory settings?"
Document chunk: "Computers often come with customizable parental control settings."
Score: class_1

Search query: "Symptoms of vitamin D deficiency"
Document chunk: "Vitamin D deficiency has been linked to fatigue, bone pain, and muscle weakness."
Score: class_4

Search query: "Best practices for onboarding remote employees"
Document chunk: "An employee handbook can be useful for new hires, outlining company policies and benefits."
Score: class_2

Search query: "Benefits of mindfulness meditation"
Document chunk: "Practicing mindfulness has shown to reduce stress and improve focus in multiple studies."
Score: class_3

Search query: "What is Kubernetes used for?"
Document chunk: "Kubernetes is an open-source system for automating deployment, scaling, and management of containerized applications."
Score: class_4

Search query: "How to bake sourdough bread at home"
Document chunk: "The French Revolution began in 1789 and radically transformed society."
Score: class_1

Search query: "Machine learning algorithms for image classification"
Document chunk: "Convolutional Neural Networks (CNNs) are particularly effective in image classification tasks."
Score: class_4

Search query: "How to improve focus while working remotely"
Document chunk: "Creating a dedicated workspace and setting a consistent schedule can significantly improve focus during remote work."
Score: class_4

Search query: "Carbon emissions from electric vehicles vs gas cars"
Document chunk: "Electric vehicles produce zero emissions while driving, but battery production has environmental impacts."
Score: class_3

Search query: "Time zones in the United States"
Document chunk: "The U.S. is divided into six primary time zones: Eastern, Central, Mountain, Pacific, Alaska, and Hawaii-Aleutian."
Score: class_4
""",
                },
                {
                    "role": "user",
                    "content": f"""
Now evaluate the following pair:

Search query: {query}
Document chunk: {document}

Which class best represents the relevance?
""",
                },
            ],
            temperature=self.temperature,
            n=self.n,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,
            max_tokens=self.max_tokens,
        )

        # Extract response and logprobs
        token_logprobs = response.choices[0].logprobs.content

        # Find the token that contains the class number
        # Instead of just taking the last token, search for the actual class number token
        class_token_logprob = None
        for token_logprob in reversed(token_logprobs):
            if token_logprob.token in self.valid_class_tokens:
                class_token_logprob = token_logprob
                break

        # If we couldn't find a class token, fall back to the last non-empty token
        if class_token_logprob is None:
            log.warning("No class token logprob found, using the last token as fallback")
            class_token_logprob = token_logprobs[-1]

        top_logprobs = class_token_logprob.top_logprobs

        # Create a map of 'class_1' -> probability, using token combinations
        # NOTE(review): this assumes the alternative tokens are the bare class
        # digits (e.g. "1"), so prefixing "class_" rebuilds the label -- confirm
        # against DEFAULT_VALID_CLASS_TOKENS.
        class_probs = {}
        for top_token in top_logprobs:
            full_label = f"class_{top_token.token}"
            prob = math.exp(top_token.logprob)
            class_probs[full_label] = prob
        # Optional: normalize in case some are missing
        # NOTE(review): assumes top_logprobs is non-empty, otherwise this divides by zero.
        total_prob = sum(class_probs.values())
        class_probs = {k: v / total_prob for k, v in class_probs.items()}
        # Assign weights to classes
        class_weights = {"class_1": 0.25, "class_2": 0.5, "class_3": 0.75, "class_4": 1.0}
        # Compute the final smooth score: probability-weighted average of class weights.
        score = sum(class_weights.get(class_label, 0) * prob for class_label, prob in class_probs.items())
        # Clamp to [0, 1] defensively.
        if score is not None:
            if score > 1.0:
                score = 1.0
            elif score < 0.0:
                score = 0.0

        rerank_data = {"document": document, "relevance_score": score}
        log.debug(f"Reranker score: {score}")
        log.debug("End search_relevancy_score")
        return rerank_data
403 def get_scores(self, query: str, documents: list[str]):
404 query_document_pairs = [(query, doc) for doc in documents]
405 # Create event loop and run async code
407 documents_and_scores = get_event_loop().run_until_complete(self._rank(query_document_pairs))
409 scores = [score for _, score in documents_and_scores]
410 return scores
413def _strip_code_fences(text: str) -> str:
414 """Strip code fences from text, handling cases where first line has content after fence."""
415 stripped = text.strip()
416 if stripped.startswith("```") and stripped.endswith("```"):
417 lines = stripped.splitlines()
418 # Check if first line has content after the fence (e.g., ```json)
419 first_line = lines[0] if lines else ""
420 if first_line.strip() == "```" or (first_line.startswith("```") and len(first_line.strip()) > 3):
421 # Drop first fence line (with or without language specifier)
422 lines = lines[1:]
423 # Drop trailing fence lines
424 while lines and lines[-1].strip().startswith("```"):
425 lines.pop()
426 stripped = "\n".join(lines).strip()
427 return stripped
class ListwiseLLMReranker(BaseLLMReranker):
    """Listwise reranker: asks the LLM to rank a whole batch of documents in one call."""

    mode: RerankerMode = RerankerMode.LISTWISE
    max_document_characters: int = 3000  # Per-document truncation limit for the prompt
    max_documents_per_batch: int = 50  # Maximum documents to rank in a single LLM call
    document_separator: str = "\n---DOCUMENT_SEPARATOR---\n"  # Unique separator to avoid conflicts
436 async def _rank(self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None) -> List[Tuple[str, float]]:
437 if not query_document_pairs:
438 return []
440 query = query_document_pairs[0][0]
441 documents = [document for _, document in query_document_pairs]
443 # Handle large document sets by batching
444 if len(documents) > self.max_documents_per_batch:
445 log.info(f"Batching {len(documents)} documents into groups of {self.max_documents_per_batch}")
446 return await self._rank_with_batching(query, documents, rerank_callback)
448 # Use _rank_single_batch for consistency
449 return await self._rank_single_batch(query_document_pairs, rerank_callback)
451 async def _rank_with_batching(
452 self, query: str, documents: List[str], rerank_callback=None
453 ) -> List[Tuple[str, float]]:
454 """Rank documents in batches to avoid overwhelming the LLM with too many documents."""
455 batch_size = self.max_documents_per_batch
456 num_batches = (len(documents) + batch_size - 1) // batch_size
458 all_results: List[Tuple[str, float]] = []
460 for batch_idx in range(num_batches):
461 start_idx = batch_idx * batch_size
462 end_idx = min(start_idx + batch_size, len(documents))
463 batch_docs = documents[start_idx:end_idx]
465 # Create query-document pairs for this batch
466 batch_pairs = [(query, doc) for doc in batch_docs]
468 # Rank this batch
469 batch_results = await self._rank_single_batch(batch_pairs, rerank_callback)
470 all_results.extend(batch_results)
472 # Sort all results by score to get final ranking
473 all_results.sort(key=lambda item: item[1], reverse=True)
474 return all_results
476 async def _rank_single_batch(
477 self, query_document_pairs: List[Tuple[str, str]], rerank_callback=None
478 ) -> List[Tuple[str, float]]:
479 """Rank a single batch of documents."""
480 query = query_document_pairs[0][0]
481 documents = [document for _, document in query_document_pairs]
483 messages = self._build_messages(query, documents)
485 for attempt in range(self.max_retries):
486 try:
487 response = await self._call_llm(messages)
488 content = response.choices[0].message.content
489 scores = self._extract_scores(content, len(documents))
490 return list(zip(documents, scores))
491 except Exception as exc:
492 if attempt == self.max_retries - 1:
493 log.error(f"Failed listwise reranking batch after {self.max_retries} attempts: {exc}")
494 raise
495 retry_delay = self.retry_delay * (2**attempt) + random.uniform(0, 0.1)
496 await asyncio.sleep(retry_delay)
498 return []
    def _build_messages(self, query: str, documents: List[str]) -> List[dict]:
        """Build the system/user chat messages asking the LLM to rank ``documents``.

        Each document is cleaned of any pre-existing 'Document [N]:' prefix,
        truncated to ``max_document_characters``, and labeled with a 1-based
        index that the model must echo back in its JSON ranking.
        """
        document_blocks = []
        for idx, document in enumerate(documents, start=1):
            # Remove any existing 'Document [N]:' prefix from content
            cleaned_doc = self._clean_document_prefix(document)
            truncated = self._truncate_document(cleaned_doc)
            document_blocks.append(f"Document {idx}:\n{truncated}")

        docs_text = self.document_separator.join(document_blocks)
        system_prompt = (
            "You are an expert reranker. Given a user query and a list of candidate "
            "documents, you must rank the documents from most to least relevant. "
            'Only respond with JSON following the schema: {"ranking": ['
            '{"doc_index": <1-based document index>, "score": <float between 0 and 1>}]}.'
        )

        user_prompt = (
            f"""
Query:
{query}

Documents:
{docs_text}

Return the ranking as JSON. Make sure every document appears once. Scores must be between 0 and 1.
"""
        ).strip()

        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
533 def _clean_document_prefix(self, document: str) -> str:
534 """Remove 'Document [N]:' prefix if present in the document content."""
535 pattern = r"^Document\s+\d+:\s*"
536 return re.sub(pattern, "", document, count=1)
538 def _truncate_document(self, document: str) -> str:
539 if len(document) <= self.max_document_characters:
540 return document
541 return document[: self.max_document_characters] + "..."
    def _extract_scores(self, content: str, num_documents: int) -> List[float]:
        """Parse the model's JSON ranking into one score per document.

        The reply may be a {"ranking": [...]} object or a bare list; each entry
        may be a dict ({"doc_index", "score"}), a (index, score) pair, or a bare
        int index. Indices may be 1-based (preferred) or 0-based, and digit
        strings are accepted. Documents the model omitted -- and any reply that
        fails to parse -- receive positional fallback scores so the output
        always has exactly ``num_documents`` entries.
        """
        sanitized = _strip_code_fences(content)
        fallback_scores = self._fallback_scores(num_documents)
        parsed_scores = fallback_scores.copy()

        try:
            parsed = json.loads(sanitized)
        except json.JSONDecodeError as exc:
            log.warning(f"Failed to parse listwise reranker response as JSON: {exc}. Using fallback scores.")
            return parsed_scores

        ranking = parsed.get("ranking", []) if isinstance(parsed, dict) else parsed
        if not isinstance(ranking, list):
            log.warning("Listwise reranker response missing 'ranking' list. Using fallback scores.")
            return parsed_scores

        # assignment_order tracks how far down the ranking we got, so that
        # unranked documents fall below every ranked one.
        assignment_order = 0
        assigned: dict[int, float] = {}

        for rank_position, entry in enumerate(ranking):
            doc_index: Optional[int] = None
            score: Optional[float] = None

            # Tolerate several entry shapes the model might emit.
            if isinstance(entry, dict):
                doc_index = entry.get("doc_index")
                score = entry.get("score")
            elif isinstance(entry, (list, tuple)) and entry:
                doc_index = entry[0]
                if len(entry) > 1:
                    score = entry[1]
            elif isinstance(entry, int):
                doc_index = entry

            if doc_index is None:
                continue

            if isinstance(doc_index, str) and doc_index.isdigit():
                doc_index = int(doc_index)

            if not isinstance(doc_index, int):
                continue

            # Accept either 0-based or 1-based indices
            if doc_index <= 0:
                adjusted_index = doc_index
            else:
                adjusted_index = doc_index - 1

            if adjusted_index < 0 or adjusted_index >= num_documents:
                continue

            # Invalid/missing scores fall back to a score implied by rank position.
            normalized_score = self._normalize_score(score)
            if normalized_score is None:
                normalized_score = fallback_scores[min(rank_position, num_documents - 1)]

            assigned[adjusted_index] = normalized_score
            assignment_order = max(assignment_order, rank_position + 1)

        # Fill unranked documents with descending fallback scores after the
        # last ranked position.
        next_rank = assignment_order
        for doc_idx in range(num_documents):
            if doc_idx in assigned:
                parsed_scores[doc_idx] = assigned[doc_idx]
            else:
                parsed_scores[doc_idx] = fallback_scores[min(next_rank, num_documents - 1)]
                next_rank += 1

        return parsed_scores
611 def _normalize_score(self, score: Any) -> Optional[float]:
612 if score is None:
613 return None
614 try:
615 value = float(score)
616 except (TypeError, ValueError):
617 return None
619 if math.isnan(value) or math.isinf(value):
620 return None
622 if value > 1:
623 value = 1.0
624 elif value < 0:
625 value = 0.0
627 return value
629 def _fallback_scores(self, length: int) -> List[float]:
630 if length <= 0:
631 return []
632 return [max(0.0, (length - idx) / length) for idx in range(length)]