Coverage for mindsdb / integrations / handlers / lightwood_handler / tests / test_lightwood_handler.py: 0%

110 statements  

coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

# Standard library first, then project imports (PEP 8 grouping).
import os
import tempfile
import time
import unittest

import mindsdb.interfaces.storage.db as db
from mindsdb.integrations.handlers.lightwood_handler.lightwood_handler.lightwood_handler import LightwoodHandler
from mindsdb.integrations.libs.response import RESPONSE_TYPE
from mindsdb.integrations.utilities.test_utils import PG_HANDLER_NAME, PG_CONNECTION_DATA
from mindsdb.interfaces.database.integrations import integration_controller
from mindsdb.interfaces.model.model_controller import ModelController
from mindsdb.interfaces.storage.fs import FsStore
from mindsdb.migrations import migrate
from mindsdb.utilities.config import Config

# Isolated storage for this test run.  Let tempfile choose the platform temp
# directory (honours TMPDIR) rather than hard-coding '/tmp/', which is not
# portable.  An explicitly set MINDSDB_STORAGE_DIR still wins.
temp_dir = tempfile.mkdtemp(prefix='lightwood_handler_test_')
os.environ['MINDSDB_STORAGE_DIR'] = os.environ.get('MINDSDB_STORAGE_DIR', temp_dir)
os.environ['MINDSDB_DB_CON'] = (
    'sqlite:///'
    + os.path.join(os.environ['MINDSDB_STORAGE_DIR'], 'mindsdb.sqlite3.db')
    + '?check_same_thread=False&timeout=30'
)

# Bring the SQLite metadata schema up to date before any db.* access below.
migrate.migrate_to_head()

# from mindsdb.integrations.handlers.lightwood_handler.lightwood_handler.utils import load_predictor

22 

23 

# TODO: drop all models and tables when closing tests
class LightwoodHandlerTest(unittest.TestCase):
    """Ordered integration tests for ``LightwoodHandler.native_query``.

    The numeric prefixes matter: test_02 trains the model that the later
    query tests read, so the suite must run in definition order (and is
    typically launched with ``failfast``).  Requires a reachable Postgres
    instance holding the ``demo_data`` schema; connection details come from
    ``PG_CONNECTION_DATA``.
    """

    @classmethod
    def setUpClass(cls):
        """Register required integrations and build the shared handler."""
        # region create permanent integrations
        # Insert the integration records straight into the metadata DB
        # (the sqlite file configured at module import time).
        for integration_name in ['files', 'lightwood']:
            integration_record = db.Integration(
                name=integration_name,
                data={},
                engine=integration_name,
                company_id=None
            )
            db.session.add(integration_record)
            db.session.commit()
        # Postgres integration that supplies the training data.
        integration_record = db.Integration(
            name=PG_HANDLER_NAME,
            data=PG_CONNECTION_DATA,
            engine='postgres',
            company_id=None
        )
        db.session.add(integration_record)
        db.session.commit()
        # endregion

        handler_controller = integration_controller

        cls.handler = LightwoodHandler(
            'lightwood',
            handler_controller=handler_controller,
            fs_store=FsStore(),
            model_controller=ModelController()
        )
        cls.config = Config()

        # Fixtures for the plain (non-timeseries) model tests.
        cls.target_1 = 'rental_price'
        cls.data_table_1 = 'demo_data.home_rentals'
        cls.test_model_1 = 'test_lightwood_home_rentals'
        cls.test_model_1b = 'test_lightwood_home_rentals_custom'

        # Fixtures for the timeseries model tests.
        cls.target_2 = 'Traffic'
        cls.data_table_2 = 'demo_data.house_sales'
        cls.test_model_2 = 'test_lightwood_house_sales'

    def test_00_check_connection(self):
        # Sanity check: the handler reports itself reachable.
        conn = self.handler.check_connection()
        assert conn.success

    def test_01_drop_predictor(self):
        # Ensure the model exists so the DROP below has something to remove.
        if self.test_model_1 not in self.handler.get_tables().data_frame.values:
            # TODO: seems redundant because of test_02
            query = f"""
                CREATE PREDICTOR {self.test_model_1}
                FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
                PREDICT rental_price
            """
            self.handler.native_query(query)
        response = self.handler.native_query(f"DROP PREDICTOR {self.test_model_1}")
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_02_train_predictor(self):
        # Trains the model consumed by tests 03-05 and 08.
        query = f"""
            CREATE PREDICTOR {self.test_model_1}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
            PREDICT rental_price
        """
        response = self.handler.native_query(query)
        # NOTE(review): fixed sleep, presumably to let training get under way
        # before the next test -- confirm whether training here is async.
        time.sleep(5)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_03_retrain_predictor(self):
        query = f"RETRAIN {self.test_model_1}"
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_04_query_predictor_single_where_condition(self):
        # Fixed wait for training/retraining to finish before querying.
        time.sleep(120)  # TODO
        query = f"""
            SELECT target
            from {self.test_model_1}
            WHERE sqft=100
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)
        self.assertTrue(len(response.data_frame) == 1)
        # The result frame apparently echoes the WHERE inputs alongside the
        # prediction (hence 'sqft' and 'rental_price' columns) -- per the
        # assertions below, not verified against the handler here.
        self.assertTrue(response.data_frame['sqft'][0] == 100)
        self.assertTrue(response.data_frame['rental_price'][0] is not None)

    def test_05_query_predictor_multi_where_condition(self):
        query = f"""
            SELECT target
            from {self.test_model_1}
            WHERE sqft=100
            AND number_of_rooms=2
            AND number_of_bathrooms=1
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)
        self.assertTrue(len(response.data_frame) == 1)
        self.assertTrue(response.data_frame['number_of_rooms'][0] == 2)
        self.assertTrue(response.data_frame['number_of_bathrooms'][0] == 1)

    def test_06_train_predictor_custom_jsonai(self):
        # TODO: turn this into a decorator?
        if self.test_model_1b in self.handler.get_tables().data_frame.values:  # TODO this accesor feels weird, maybe rethink output format?
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_1b}")

        # USING clause passes a JsonAI override: single LightGBM submodel.
        using_str = 'model.args={"submodels": [{"module": "LightGBM", "args": {"stop_after": 12, "fit_on_dev": true}}]}'
        query = f"""
            CREATE PREDICTOR {self.test_model_1b}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
            PREDICT rental_price
            USING {using_str}
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)
        # TODO assert
        # m = load_predictor(self.handler.storage.get('models')[self.test_model_1b], self.test_model_1b)
        # assert len(m.ensemble.mixers) == 1
        # assert type(m.ensemble.mixers[0]).__name__ == 'LightGBM'

    def test_07_list_tables(self):
        response = self.handler.get_tables()
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    def test_08_get_columns(self):
        response = self.handler.get_columns(f'{self.test_model_1}')
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    # TODO
    # def test_09_join_predictor_into_table(self):
    #     into_table = 'test_join_into_lw'
    #     query = f"SELECT tb.{self.target_1} as predicted, ta.{self.target_1} as truth, ta.sqft from {PG_HANDLER_NAME}.{self.data_table_1} AS ta JOIN {self.test_model_1} AS tb LIMIT 10"
    #     parsed = self.handler.parser(query, dialect=self.handler.dialect)
    #     predicted = self.handler.join(parsed, self.data_handler, into=into_table)

    #     # checks whether `into` kwarg does insert into the table or not
    #     q = f"SELECT * FROM {into_table}"
    #     qp = self.handler.parser(q, dialect='mysql')
    #     assert len(self.data_handler.query(qp).data_frame) > 0

    def test_10_train_ts_predictor_multigby_hor4(self):
        # Timeseries model: two GROUP BY columns, horizon 4.
        # TODO: handle cap/uncapped column name returned from data handler? Had to rename 'MA' -> 'ma' for test to pass
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            GROUP BY bedrooms, type
            WINDOW 8
            HORIZON 4
        """
        # Drop any leftover model from a previous run before (re)creating it.
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            response = self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            response = self.handler.native_query(query)

        self.assertTrue(response.type == RESPONSE_TYPE.OK)

        # TODO: reactivate and add to the rest of the TS tests once cache is back on
        # p = self.handler.storage.get('models')
        # m = load_predictor(p[self.test_model_2], self.test_model_2)
        # assert m.problem_definition.timeseries_settings.is_timeseries

    def test_12_train_ts_predictor_multigby_hor1(self):
        # Timeseries model: two GROUP BY columns, horizon 1.  Note: only
        # checks that the calls do not raise; no response assertion here.
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            GROUP BY bedrooms, type
            WINDOW 8
            HORIZON 1
        """
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            self.handler.native_query(query)

    def test_13_train_ts_predictor_no_gby_hor1(self):
        # Timeseries model: no GROUP BY, horizon 1.  No response assertion.
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            WINDOW 8
            HORIZON 1
        """
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            self.handler.native_query(query)

    def test_14_train_ts_predictor_no_gby_hor4(self):
        # Timeseries model: no GROUP BY, horizon 4.  No response assertion.
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            WINDOW 8
            HORIZON 4
        """
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            self.handler.native_query(query)

    # TODO
    # def test_15_join_predictor_ts_into(self):
    #     query = f"""
    #        SELECT m.saledate as date,
    #        m.ma as forecast
    #        FROM mindsdb.{self.test_model_2} m JOIN {PG_HANDLER_NAME}.demo_data.house_sales t
    #        WHERE t.saledate > LATEST
    #        AND t.type = 'house'
    #        AND t.bedrooms = 2
    #        LIMIT 10
    #     """
    #     response = self.handler.native_query(query)
    #     self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    # def test_16_join_predictor_ts_model_left(self):
    #     # TODO: is this one needed?
    #     target = 'Traffic'
    #     oby = 'T'
    #     query = f"SELECT tb.{target} as predicted, ta.{target} as truth, ta.{oby} from mindsdb.{self.test_tsmodel_name_1} AS tb JOIN {self.sql_handler_name}.{self.data_table_name_2} AS ta ON 1=1 WHERE ta.{oby} > LATEST LIMIT 10"
    #     parsed = self.handler.parser(query, dialect=self.handler.dialect)
    #     predicted = self.handler.join(parsed, self.data_handler)  # , into=self.model_2_into_table) # TODO: restore when we add support for SQLite and other handlers for `into`

256 

def _main():
    # failfast: abort on the first failure -- the suite is order-dependent
    # (later query tests read the model trained by earlier ones).
    unittest.main(failfast=True)


if __name__ == "__main__":
    _main()