Coverage for mindsdb / integrations / libs / ml_handler_process / learn_process.py: 17%

103 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import os 

2import importlib 

3import datetime as dt 

4 

5from sqlalchemy.orm.attributes import flag_modified 

6 

7from mindsdb_sql_parser import parse_sql 

8from mindsdb_sql_parser.ast import Identifier, Select, Star, NativeQuery 

9 

10from mindsdb.api.executor.sql_query import SQLQuery 

11import mindsdb.utilities.profiler as profiler 

12from mindsdb.utilities.functions import mark_process 

13from mindsdb.utilities.config import Config 

14from mindsdb.utilities.context import context as ctx 

15from mindsdb.utilities import log 

16import mindsdb.interfaces.storage.db as db 

17from mindsdb.interfaces.storage.model_fs import ModelStorage, HandlerStorage 

18from mindsdb.interfaces.model.functions import get_model_records 

19from mindsdb.integrations.utilities.utils import format_exception_error 

20from mindsdb.integrations.utilities.sql_utils import make_sql_session 

21from mindsdb.integrations.libs.const import PREDICTOR_STATUS 

22from mindsdb.integrations.libs.ml_handler_process.handlers_cacher import handlers_cacher 

23 

# Module-level logger (mindsdb logging wrapper around stdlib logging).
logger = log.getLogger(__name__)

25 

26 

@mark_process(name="learn")
def learn_process(
    data_integration_ref: dict,
    problem_definition: dict,
    fetch_data_query: str,
    project_name: str,
    model_id: int,
    integration_id: int,
    base_model_id: int,
    set_active: bool,
    module_path: str,
):
    """Train (or fine-tune) a model; runs as the 'learn' worker process.

    Fetches the training data described by ``data_integration_ref`` and
    ``fetch_data_query``, instantiates the ML handler imported from
    ``module_path`` and runs either ``create`` (new model) or ``finetune``
    (when ``base_model_id`` is set). The final status (COMPLETE / ERROR) and
    training metadata are persisted on the ``Predictor`` record.

    Args:
        data_integration_ref: reference to the data source; its 'type' key is
            one of 'integration', 'system', 'view' or 'project'. May be None,
            in which case the handler is called without a dataframe.
        problem_definition: learn parameters (e.g. the 'target' column).
        fetch_data_query: SQL used to fetch the training data.
        project_name: name of the project the model belongs to.
        model_id: id of the Predictor record being trained.
        integration_id: id of the ML engine integration (handler storage).
        base_model_id: id of the model version to fine-tune, or None to
            create a new model.
        set_active: if True, make this version the active one on success.
        module_path: import path of the ML handler module.

    Side effects:
        Commits status/metadata changes to the db session; never raises —
        any failure is recorded on the Predictor record as an error.
    """
    ctx.profiling = {"level": 0, "enabled": True, "pointer": None, "tree": None}
    profiler.set_meta(query="learn_process", api="http", environment=Config().get("environment"))
    with profiler.Context("learn_process"):
        # Imported here (not at module level), presumably to avoid a circular
        # import — TODO confirm before hoisting.
        from mindsdb.interfaces.database.database import DatabaseController

        try:
            # Lock the record and remember the pid of this training process.
            predictor_record = db.Predictor.query.with_for_update().get(model_id)
            predictor_record.training_metadata["process_id"] = os.getpid()
            flag_modified(predictor_record, "training_metadata")
            db.session.commit()

            target = problem_definition.get("target", None)
            training_data_df = None
            if data_integration_ref is not None:
                database_controller = DatabaseController()
                sql_session = make_sql_session()
                integration_ref_type = data_integration_ref["type"]
                if integration_ref_type == "integration":
                    integration_name = database_controller.get_integration(data_integration_ref["id"])["name"]
                    query = Select(
                        targets=[Star()],
                        from_table=NativeQuery(integration=Identifier(integration_name), query=fetch_data_query),
                    )
                    sqlquery = SQLQuery(query, session=sql_session)
                elif integration_ref_type == "system":
                    query = Select(
                        targets=[Star()], from_table=NativeQuery(integration=Identifier("log"), query=fetch_data_query)
                    )
                    sqlquery = SQLQuery(query, session=sql_session)
                elif integration_ref_type == "view":
                    project = database_controller.get_project(project_name)
                    query_ast = parse_sql(fetch_data_query)
                    view_meta = project.get_view_meta(query_ast)
                    sqlquery = SQLQuery(view_meta["query_ast"], session=sql_session)
                elif integration_ref_type == "project":
                    query_ast = parse_sql(fetch_data_query)
                    sqlquery = SQLQuery(query_ast, session=sql_session)
                else:
                    # Previously an unknown type surfaced as a confusing
                    # NameError on 'sqlquery'; fail with an explicit message
                    # instead (still recorded by the except-block below).
                    raise ValueError(f"Unknown data integration type: {integration_ref_type}")

                training_data_df = sqlquery.fetched_data.to_df()

            # Persist the size of the training dataset on the record.
            training_data_columns_count, training_data_rows_count = 0, 0
            if training_data_df is not None:
                training_data_columns_count = len(training_data_df.columns)
                training_data_rows_count = len(training_data_df)

            predictor_record.training_data_columns_count = training_data_columns_count
            predictor_record.training_data_rows_count = training_data_rows_count
            db.session.commit()

            module = importlib.import_module(module_path)

            # check if module is imported successfully and raise exception if not
            if module.import_error is not None:
                raise module.import_error

            handlerStorage = HandlerStorage(integration_id)
            modelStorage = ModelStorage(model_id)
            modelStorage.fileStorage.push()  # FIXME

            kwargs = {}
            if base_model_id is not None:
                kwargs["base_model_storage"] = ModelStorage(base_model_id)
                kwargs["base_model_storage"].fileStorage.pull()
            ml_handler = module.Handler(engine_storage=handlerStorage, model_storage=modelStorage, **kwargs)
            handlers_cacher[predictor_record.id] = ml_handler

            if not ml_handler.generative and target is not None:
                if training_data_df is not None and target not in training_data_df.columns:
                    # is the case different? convert column case in input dataframe
                    col_names = {c.lower(): c for c in training_data_df.columns}
                    target_found = col_names.get(target.lower())
                    if target_found:
                        training_data_df.rename(columns={target_found: target}, inplace=True)
                    else:
                        raise Exception(
                            f'Prediction target "{target}" not found in training dataframe: {list(training_data_df.columns)}'
                        )

            # create new model
            if base_model_id is None:
                with profiler.Context("create"):
                    ml_handler.create(target, df=training_data_df, args=problem_definition)

            # fine-tune (partially train) existing model
            else:
                # load model from previous version, use it as starting point
                with profiler.Context("finetune"):
                    problem_definition["base_model_id"] = base_model_id
                    ml_handler.finetune(df=training_data_df, args=problem_definition)

            predictor_record.status = PREDICTOR_STATUS.COMPLETE
            predictor_record.active = set_active
            db.session.commit()
            # if retrain and set_active after success creation
            if set_active is True:
                models = get_model_records(
                    name=predictor_record.name, project_id=predictor_record.project_id, active=None
                )
                for model in models:
                    model.active = False
                # the most recently created COMPLETE version becomes active;
                # at least one exists: this record was just marked COMPLETE.
                models = [x for x in models if x.status == PREDICTOR_STATUS.COMPLETE]
                models.sort(key=lambda x: x.created_at)
                models[-1].active = True
        except Exception as e:
            # Never let the worker crash: record the error on the Predictor.
            logger.exception("Error during 'learn' process:")
            error_message = format_exception_error(e)

            predictor_record = db.Predictor.query.with_for_update().get(model_id)
            predictor_record.data = {"error": error_message}
            predictor_record.status = PREDICTOR_STATUS.ERROR
            db.session.commit()

        # Stop timestamp is written on both the success and the error path.
        # NOTE(review): naive local time — confirm whether UTC is expected.
        predictor_record.training_stop_at = dt.datetime.now()
        db.session.commit()