Coverage for mindsdb / integrations / handlers / writer_handler / settings.py: 0%
53 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
from typing import List, Optional, Union

from langchain_core.callbacks import StreamingStdOutCallbackHandler
from pydantic import BaseModel, Extra, Field, field_validator

from mindsdb.integrations.handlers.rag_handler.settings import (
    RAGBaseParameters,
)
# Column names an evaluation dataset is expected to provide.
EVAL_COLUMN_NAMES = ("question", "answers", "context")

# Evaluation modes the handler supports.
SUPPORTED_EVALUATION_TYPES = ("retrieval", "e2e")

# Metrics available for scoring generated answers.
GENERATION_METRICS = ("rouge", "meteor", "cosine_similarity", "accuracy")

# Metrics available for scoring retrieved context.
RETRIEVAL_METRICS = ("cosine_similarity", "accuracy")
# TODO: move the evaluation parameters below into a separate class
class WriterLLMParameters(BaseModel):
    """Model parameters for the Writer LLM API interface.

    Fields:
        writer_api_key: Writer API key (required).
        writer_org_id: Writer organisation id; optional — annotation fixed
            from ``str = None``, which did not admit its own default.
        base_url: override for the API base URL; optional for the same reason.
        model_id: Writer model to call (defaults to "palmyra-x").
        callbacks: LangChain callback handlers; built via ``default_factory``
            so each model instance gets its own list (no shared mutable default).
        stop: stop sequences; ``default_factory=list`` for the same reason.
    """

    writer_api_key: str
    writer_org_id: Optional[str] = None
    base_url: Optional[str] = None
    model_id: str = "palmyra-x"
    callbacks: List[StreamingStdOutCallbackHandler] = Field(
        default_factory=lambda: [StreamingStdOutCallbackHandler()]
    )
    max_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1
    stop: List[str] = Field(default_factory=list)
    best_of: int = 5
    verbose: bool = False

    class Config:
        # Reject unknown keys so typos in user-supplied params fail loudly.
        extra = Extra.forbid
        # Needed because callback handler instances are not pydantic types.
        arbitrary_types_allowed = True
class WriterHandlerParameters(RAGBaseParameters):
    """Parameters accepted when creating a Writer handler model.

    Extends the shared RAG parameters with Writer-specific LLM settings and
    evaluation configuration (which metrics to compute, on how many rows,
    and against which dataset).
    """

    llm_params: WriterLLMParameters
    generation_evaluation_metrics: List[str] = list(GENERATION_METRICS)
    retrieval_evaluation_metrics: List[str] = list(RETRIEVAL_METRICS)
    evaluation_type: str = "e2e"
    # Fixed from ``int = None``: the annotation must admit the None default.
    n_rows_evaluation: Optional[int] = None  # if None, evaluate on all rows
    retriever_match_threshold: float = 0.7
    generator_match_threshold: float = 0.8
    evaluate_dataset: Union[List[dict], str] = "squad_v2_val_100_sample"

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True

    @field_validator("generation_evaluation_metrics")
    @classmethod
    def generation_evaluation_metrics_must_be_supported(cls, v):
        """Reject any generation metric not in GENERATION_METRICS."""
        for metric in v:
            if metric not in GENERATION_METRICS:
                raise ValueError(
                    f"generation_evaluation_metrics must be one of {', '.join(str(v) for v in GENERATION_METRICS)}, got {metric}"
                )
        return v

    @field_validator("retrieval_evaluation_metrics")
    @classmethod
    def retrieval_evaluation_metrics_must_be_supported(cls, v):
        """Reject any retrieval metric not in RETRIEVAL_METRICS.

        BUG FIX: the original validated against GENERATION_METRICS, silently
        accepting generation-only metrics (e.g. "rouge") as retrieval metrics
        while the error message already referenced RETRIEVAL_METRICS.
        """
        for metric in v:
            if metric not in RETRIEVAL_METRICS:
                raise ValueError(
                    f"retrieval_evaluation_metrics must be one of {', '.join(str(v) for v in RETRIEVAL_METRICS)}, got {metric}"
                )
        return v

    @field_validator("evaluation_type")
    @classmethod
    def evaluation_type_must_be_supported(cls, v):
        """Only 'retrieval' and 'e2e' evaluation modes are supported."""
        if v not in SUPPORTED_EVALUATION_TYPES:
            raise ValueError(
                f"evaluation_type must be one of `retrieval` or `e2e`, got {v}"
            )
        return v