Coverage for mindsdb / integrations / handlers / writer_handler / settings.py: 0%

53 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1from typing import List, Union 

2from pydantic import BaseModel, Extra, field_validator 

3 

4from mindsdb.integrations.handlers.rag_handler.settings import ( 

5 RAGBaseParameters, 

6) 

7from langchain_core.callbacks import StreamingStdOutCallbackHandler 

8 

# Column names an evaluation dataset must provide.
EVAL_COLUMN_NAMES = (
    "question",
    "answers",
    "context",
)

# Evaluation can score retrieval in isolation or the end-to-end pipeline.
SUPPORTED_EVALUATION_TYPES = ("retrieval", "e2e")

# Metric names accepted for scoring generated answers / retrieved contexts.
GENERATION_METRICS = ("rouge", "meteor", "cosine_similarity", "accuracy")
RETRIEVAL_METRICS = ("cosine_similarity", "accuracy")


# todo make a separate class for evaluation parameters

class WriterLLMParameters(BaseModel):
    """Model parameters for the Writer LLM API interface.

    Mirrors the keyword arguments forwarded to the Writer LLM client.
    Unknown fields are rejected (``extra = forbid``).
    """

    writer_api_key: str
    # FIX: these optional fields were annotated plain ``str`` with a ``None``
    # default; under pydantic v2 an explicitly supplied None would then fail
    # validation even though None is the default. ``Union[str, None]`` makes
    # the annotation consistent with the default and accepts explicit None.
    writer_org_id: Union[str, None] = None
    base_url: Union[str, None] = None
    model_id: str = "palmyra-x"
    # Mutable defaults are safe here: pydantic deep-copies field defaults
    # per instance, so the list/handler are not shared across models.
    callbacks: List[StreamingStdOutCallbackHandler] = [StreamingStdOutCallbackHandler()]
    max_tokens: int = 1024
    temperature: float = 0.0
    top_p: float = 1
    stop: List[str] = []
    best_of: int = 5
    verbose: bool = False

    class Config:
        # NOTE(review): class-based Config / Extra are deprecated in pydantic
        # v2 in favour of ``model_config = ConfigDict(...)``; kept as-is to
        # match the style used elsewhere in this file.
        extra = Extra.forbid
        arbitrary_types_allowed = True

42 

43 

class WriterHandlerParameters(RAGBaseParameters):
    """Model parameters for ``create model``.

    Extends the shared RAG parameters with Writer LLM settings and
    evaluation configuration; validators reject unsupported metric names
    and evaluation types at construction time.
    """

    llm_params: WriterLLMParameters
    generation_evaluation_metrics: List[str] = list(GENERATION_METRICS)
    retrieval_evaluation_metrics: List[str] = list(RETRIEVAL_METRICS)
    evaluation_type: str = "e2e"  # one of SUPPORTED_EVALUATION_TYPES
    # FIX: was ``int = None`` — annotation contradicted the None default.
    n_rows_evaluation: Union[int, None] = None  # if None, evaluate on all rows
    retriever_match_threshold: float = 0.7
    generator_match_threshold: float = 0.8
    evaluate_dataset: Union[List[dict], str] = "squad_v2_val_100_sample"

    class Config:
        extra = Extra.forbid
        arbitrary_types_allowed = True
        use_enum_values = True

    @field_validator("generation_evaluation_metrics")
    @classmethod
    def generation_evaluation_metrics_must_be_supported(cls, v):
        """Reject any metric name not listed in GENERATION_METRICS."""
        for metric in v:
            if metric not in GENERATION_METRICS:
                raise ValueError(
                    f"generation_evaluation_metrics must be one of {', '.join(str(v) for v in GENERATION_METRICS)}, got {metric}"
                )
        return v

    @field_validator("retrieval_evaluation_metrics")
    @classmethod
    def retrieval_evaluation_metrics_must_be_supported(cls, v):
        """Reject any metric name not listed in RETRIEVAL_METRICS.

        BUG FIX: the original checked membership against GENERATION_METRICS,
        silently accepting generation-only metrics (e.g. "rouge", "meteor")
        even though the error message names RETRIEVAL_METRICS.
        """
        for metric in v:
            if metric not in RETRIEVAL_METRICS:
                raise ValueError(
                    f"retrieval_evaluation_metrics must be one of {', '.join(str(v) for v in RETRIEVAL_METRICS)}, got {metric}"
                )
        return v

    @field_validator("evaluation_type")
    @classmethod
    def evaluation_type_must_be_supported(cls, v):
        """Allow only the supported evaluation types ("retrieval" / "e2e")."""
        if v not in SUPPORTED_EVALUATION_TYPES:
            raise ValueError(
                f"evaluation_type must be one of `retrieval` or `e2e`, got {v}"
            )
        return v