Coverage for mindsdb/integrations/handlers/writer_handler/writer_handler.py: 0% of 78 statements

from typing import Dict, Optional

import pandas as pd

from mindsdb.integrations.handlers.writer_handler.evaluate import WriterEvaluator
from mindsdb.integrations.handlers.writer_handler.ingest import WriterIngestor
from mindsdb.integrations.handlers.writer_handler.rag import QuestionAnswerer
from mindsdb.integrations.handlers.writer_handler.settings import (
    DEFAULT_EMBEDDINGS_MODEL,
    EVAL_COLUMN_NAMES,
    WriterHandlerParameters,
    WriterLLMParameters,
)
from mindsdb.integrations.libs.base import BaseMLEngine
from mindsdb.integrations.utilities.datasets.dataset import (
    load_dataset,
    validate_dataframe,
)
from mindsdb.utilities import log

logger = log.getLogger(__name__)


def extract_llm_params(args: dict) -> dict:
    """Move any Writer LLM parameters from the input query args into a nested 'llm_params' dict."""
    llm_params = {}
    for param in WriterLLMParameters.model_fields:
        if param in args:
            llm_params[param] = args.pop(param)

    args["llm_params"] = llm_params
    return args
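
# A minimal illustration of extract_llm_params, assuming WriterLLMParameters declares a
# field named 'temperature' (a hypothetical name here; the real fields are defined in
# settings.py):
#
#     >>> extract_llm_params({"temperature": 0.7, "question_column": "question"})
#     {'question_column': 'question', 'llm_params': {'temperature': 0.7}}
#
# Recognised LLM fields are popped out of the flat args dict and nested under "llm_params";
# everything else passes through untouched.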


class WriterHandler(BaseMLEngine):
    """
    WriterHandler is a MindsDB integration with Writer API LLMs that allows users to run
    question answering on their data by providing a question.

    The user can supply data that provides context for the questions; see the create()
    method for more details.
    """

    name = "writer"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generative = True

    @staticmethod
    def create_validation(target, args=None, **kwargs):
        # guard against a missing args dict as well as a missing USING clause
        if args is None or "using" not in args:
            raise Exception(
                "Writer engine requires a USING clause! Refer to its documentation for more details."
            )
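
    # For reference, the USING clause validated above typically comes from SQL such as
    # (illustrative only; the accepted parameters are defined by WriterHandlerParameters
    # and WriterLLMParameters in settings.py):
    #
    #     CREATE MODEL writer_qa_model
    #     FROM my_datasource (SELECT * FROM my_documents)
    #     PREDICT answer
    #     USING engine = 'writer', run_embeddings = true;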

    def create(
        self,
        target: str,
        df: pd.DataFrame = pd.DataFrame(),
        args: Optional[Dict] = None,
    ):
        """
        Dispatch runs embeddings and stores them in a vector DB, unless the user already
        has embeddings persisted.
        """

        input_args = extract_llm_params(args["using"])

        if "evaluate_dataset" not in input_args and df is not None:
            # if the user doesn't provide an evaluation dataset, use the input df from the create query
            input_args["evaluate_dataset"] = df.to_dict(orient="records")

        args = WriterHandlerParameters(**input_args)

        # create a folder for the vector store to persist embeddings
        args.vector_store_storage_path = self.engine_storage.folder_get(
            args.vector_store_folder_name
        )

        if df is not None and args.run_embeddings:
            if "context_columns" not in input_args:
                # if no context columns are provided, use all columns in df
                logger.info("No context columns provided, using all columns in df")
                args.context_columns = df.columns.tolist()

            if "embeddings_model_name" not in input_args:
                logger.info(
                    f"No embeddings model provided in query, using default model: {DEFAULT_EMBEDDINGS_MODEL}"
                )

            ingestor = WriterIngestor(args=args, df=df)
            ingestor.embeddings_to_vectordb()

        else:
            logger.info("Skipping embeddings and ingestion into Chroma VectorDB")

        export_args = args.dict(exclude={"llm_params"})
        # 'callbacks' aren't JSON serializable, so exclude them to avoid errors
        export_args["llm_params"] = args.llm_params.dict(exclude={"callbacks"})

        # for MindsDB cloud, sync the vector store folder to the shared file system
        # so that it is usable by all MindsDB nodes
        self.engine_storage.folder_sync(args.vector_store_folder_name)

        self.model_storage.json_set("args", export_args)
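
    # Sketch of the round trip (hypothetical values): a USING clause such as
    # {"run_embeddings": True, "context_columns": ["body"]} is validated through
    # WriterHandlerParameters, persisted above with json_set("args", ...), and
    # reloaded in predict() via json_get("args"); 'callbacks' are excluded because
    # they cannot be serialized to JSON.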

    def predict(self, df: pd.DataFrame = None, args: dict = None):
        """
        Dispatch is performed depending on the underlying model type. Currently, only
        question answering is supported.
        """

        input_args = self.model_storage.json_get("args")
        args = WriterHandlerParameters(**input_args)

        if args.evaluation_type:
            # if the user adds a WHERE clause with 'run_evaluation = true', run evaluation
            if "run_evaluation" in df.columns and df["run_evaluation"].tolist()[0]:
                return self.evaluate(args)
            else:
                logger.info(
                    "Skipping evaluation, running prediction only. "
                    "To run evaluation, add a WHERE clause with 'run_evaluation = true'"
                )

        args.vector_store_storage_path = self.engine_storage.folder_get(
            args.vector_store_folder_name
        )

        # get question answering results
        question_answerer = QuestionAnswerer(args=args)

        # get the question from the SQL query,
        # e.g. WHERE question = 'What is the capital of France?'
        response = question_answerer.query(df["question"].tolist()[0])

        return pd.DataFrame(response)
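
    # Illustrative queries (model and column names are examples only):
    #
    #     -- question answering
    #     SELECT * FROM writer_qa_model WHERE question = 'What is the capital of France?';
    #
    #     -- trigger evaluation instead of prediction (requires evaluation_type to be set)
    #     SELECT * FROM writer_qa_model WHERE run_evaluation = true;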

    def evaluate(self, args: WriterHandlerParameters):
        """Evaluate QA performance on the configured dataset and persist the metrics."""

        if isinstance(args.evaluate_dataset, list):
            # if the user provides a list of dicts, convert it to a dataframe and validate it
            evaluate_df = validate_dataframe(
                pd.DataFrame(args.evaluate_dataset), EVAL_COLUMN_NAMES
            )
        else:
            # otherwise load a named benchmark dataset, which uses a 'context' column
            evaluate_df = load_dataset(
                ml_task_type="question_answering", dataset_name=args.evaluate_dataset
            )
            args.context_columns = "context"

        if args.n_rows_evaluation:
            # if the user specifies n_rows_evaluation in create, only use that many rows
            evaluate_df = evaluate_df.head(args.n_rows_evaluation)

        ingestor = WriterIngestor(df=evaluate_df, args=args)
        ingestor.embeddings_to_vectordb()

        evaluator = WriterEvaluator(args=args, df=evaluate_df, rag=QuestionAnswerer)
        df = evaluator.evaluate()

        evaluation_metrics = dict(
            mean_evaluation_metrics=evaluator.mean_evaluation_metrics,
            evaluation_df=df.to_dict(orient="records"),
        )

        self.model_storage.json_set("evaluation", evaluation_metrics)

        return df
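
    # The persisted "evaluation" blob has the shape built above, roughly:
    #
    #     {"mean_evaluation_metrics": {...},      # aggregates from WriterEvaluator
    #      "evaluation_df": [{...}, ...]}         # per-row results as records
    #
    # describe() below reads it back for 'evaluation_output' and 'mean_evaluation_metrics'.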

    def describe(self, attribute: Optional[str] = None) -> pd.DataFrame:
        """
        Describe the model, or a specific attribute of the model.
        """

        if attribute == "evaluation_output":
            evaluation = self.model_storage.json_get("evaluation")
            return pd.DataFrame(evaluation["evaluation_df"])
        elif attribute == "mean_evaluation_metrics":
            evaluation = self.model_storage.json_get("evaluation")
            return pd.DataFrame(evaluation["mean_evaluation_metrics"])
        else:
            raise ValueError(
                f"Attribute {attribute} not supported, try 'evaluation_output' or 'mean_evaluation_metrics'"
            )
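
    # Typical usage from SQL (illustrative; exact DESCRIBE syntax depends on the
    # MindsDB version):
    #
    #     DESCRIBE writer_qa_model.evaluation_output;
    #     DESCRIBE writer_qa_model.mean_evaluation_metrics;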