Coverage for mindsdb/integrations/handlers/lightwood_handler/tests/test_lightwood_handler.py: 0%
110 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import mindsdb.interfaces.storage.db as db
2from mindsdb.integrations.libs.response import RESPONSE_TYPE
3from mindsdb.interfaces.model.model_controller import ModelController
4from mindsdb.interfaces.storage.fs import FsStore
5from mindsdb.integrations.handlers.lightwood_handler.lightwood_handler.lightwood_handler import LightwoodHandler
6from mindsdb.interfaces.database.integrations import integration_controller
7from mindsdb.integrations.utilities.test_utils import PG_HANDLER_NAME, PG_CONNECTION_DATA
8from mindsdb.utilities.config import Config
9from mindsdb.migrations import migrate
10import os
11import time
12import unittest
13import tempfile
# Point MindsDB at an isolated scratch directory for this test run unless the
# caller has already configured one, and back it with a throwaway SQLite DB.
temp_dir = tempfile.mkdtemp(dir='/tmp/', prefix='lightwood_handler_test_')
os.environ.setdefault('MINDSDB_STORAGE_DIR', temp_dir)
db_file = os.path.join(os.environ['MINDSDB_STORAGE_DIR'], 'mindsdb.sqlite3.db')
os.environ['MINDSDB_DB_CON'] = 'sqlite:///' + db_file + '?check_same_thread=False&timeout=30'

# Bring the schema up to date before any test touches the database.
migrate.migrate_to_head()
21# from mindsdb.integrations.handlers.lightwood_handler.lightwood_handler.utils import load_predictor
24# TODO: drop all models and tables when closing tests
class LightwoodHandlerTest(unittest.TestCase):
    """Integration tests for the Lightwood ML handler.

    Tests are order-dependent (numbered `test_NN_*`): later tests query models
    created by earlier ones, so the suite is expected to run with ``failfast``.
    Requires a reachable Postgres instance described by ``PG_CONNECTION_DATA``.
    """

    @classmethod
    def setUpClass(cls):
        """Register the integrations the handler needs and build the handler."""
        # region create permanent integrations
        for integration_name in ['files', 'lightwood']:
            integration_record = db.Integration(
                name=integration_name,
                data={},
                engine=integration_name,
                company_id=None
            )
            db.session.add(integration_record)
            db.session.commit()
        integration_record = db.Integration(
            name=PG_HANDLER_NAME,
            data=PG_CONNECTION_DATA,
            engine='postgres',
            company_id=None
        )
        db.session.add(integration_record)
        db.session.commit()
        # endregion

        handler_controller = integration_controller

        cls.handler = LightwoodHandler(
            'lightwood',
            handler_controller=handler_controller,
            fs_store=FsStore(),
            model_controller=ModelController()
        )
        cls.config = Config()

        # Plain (non-timeseries) model fixtures.
        cls.target_1 = 'rental_price'
        cls.data_table_1 = 'demo_data.home_rentals'
        cls.test_model_1 = 'test_lightwood_home_rentals'
        cls.test_model_1b = 'test_lightwood_home_rentals_custom'

        # Timeseries model fixtures.
        cls.target_2 = 'Traffic'
        cls.data_table_2 = 'demo_data.house_sales'
        cls.test_model_2 = 'test_lightwood_house_sales'

    def test_00_check_connection(self):
        conn = self.handler.check_connection()
        assert conn.success

    def test_01_drop_predictor(self):
        # Ensure the predictor exists so DROP has something to remove.
        if self.test_model_1 not in self.handler.get_tables().data_frame.values:
            # TODO: seems redundant because of test_02
            query = f"""
                CREATE PREDICTOR {self.test_model_1}
                FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
                PREDICT rental_price
            """
            self.handler.native_query(query)
        response = self.handler.native_query(f"DROP PREDICTOR {self.test_model_1}")
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_02_train_predictor(self):
        query = f"""
            CREATE PREDICTOR {self.test_model_1}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
            PREDICT rental_price
        """
        response = self.handler.native_query(query)
        time.sleep(5)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_03_retrain_predictor(self):
        query = f"RETRAIN {self.test_model_1}"
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_04_query_predictor_single_where_condition(self):
        time.sleep(120)  # TODO: replace fixed sleep with a poll on training status
        # NOTE(review): query selects `target` but assertions read `sqft` and
        # `rental_price` columns — presumably the handler returns the full
        # prediction frame regardless of the projection; confirm.
        query = f"""
            SELECT target
            from {self.test_model_1}
            WHERE sqft=100
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)
        self.assertTrue(len(response.data_frame) == 1)
        self.assertTrue(response.data_frame['sqft'][0] == 100)
        self.assertTrue(response.data_frame['rental_price'][0] is not None)

    def test_05_query_predictor_multi_where_condition(self):
        query = f"""
            SELECT target
            from {self.test_model_1}
            WHERE sqft=100
            AND number_of_rooms=2
            AND number_of_bathrooms=1
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)
        self.assertTrue(len(response.data_frame) == 1)
        self.assertTrue(response.data_frame['number_of_rooms'][0] == 2)
        self.assertTrue(response.data_frame['number_of_bathrooms'][0] == 1)

    def test_06_train_predictor_custom_jsonai(self):
        # TODO: turn this into a decorator?
        if self.test_model_1b in self.handler.get_tables().data_frame.values:  # TODO this accesor feels weird, maybe rethink output format?
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_1b}")

        using_str = 'model.args={"submodels": [{"module": "LightGBM", "args": {"stop_after": 12, "fit_on_dev": true}}]}'
        query = f"""
            CREATE PREDICTOR {self.test_model_1b}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_1} limit 50)
            PREDICT rental_price
            USING {using_str}
        """
        response = self.handler.native_query(query)
        self.assertTrue(response.type == RESPONSE_TYPE.OK)
        # TODO assert
        # m = load_predictor(self.handler.storage.get('models')[self.test_model_1b], self.test_model_1b)
        # assert len(m.ensemble.mixers) == 1
        # assert type(m.ensemble.mixers[0]).__name__ == 'LightGBM'

    def test_07_list_tables(self):
        response = self.handler.get_tables()
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    def test_08_get_columns(self):
        response = self.handler.get_columns(f'{self.test_model_1}')
        self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    # TODO
    # def test_09_join_predictor_into_table(self):
    #     into_table = 'test_join_into_lw'
    #     query = f"SELECT tb.{self.target_1} as predicted, ta.{self.target_1} as truth, ta.sqft from {PG_HANDLER_NAME}.{self.data_table_1} AS ta JOIN {self.test_model_1} AS tb LIMIT 10"
    #     parsed = self.handler.parser(query, dialect=self.handler.dialect)
    #     predicted = self.handler.join(parsed, self.data_handler, into=into_table)

    #     # checks whether `into` kwarg does insert into the table or not
    #     q = f"SELECT * FROM {into_table}"
    #     qp = self.handler.parser(q, dialect='mysql')
    #     assert len(self.data_handler.query(qp).data_frame) > 0

    def test_10_train_ts_predictor_multigby_hor4(self):
        # TODO: handle cap/uncapped column name returned from data handler? Had to rename 'MA' -> 'ma' for test to pass
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            GROUP BY bedrooms, type
            WINDOW 8
            HORIZON 4
        """
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            response = self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            response = self.handler.native_query(query)

        self.assertTrue(response.type == RESPONSE_TYPE.OK)

        # TODO: reactivate and add to the rest of the TS tests once cache is back on
        # p = self.handler.storage.get('models')
        # m = load_predictor(p[self.test_model_2], self.test_model_2)
        # assert m.problem_definition.timeseries_settings.is_timeseries

    def test_12_train_ts_predictor_multigby_hor1(self):
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            GROUP BY bedrooms, type
            WINDOW 8
            HORIZON 1
        """
        # Fix: previously the response was discarded and nothing was asserted;
        # capture it and check for OK, mirroring test_10.
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            response = self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            response = self.handler.native_query(query)

        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_13_train_ts_predictor_no_gby_hor1(self):
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            WINDOW 8
            HORIZON 1
        """
        # Fix: previously the response was discarded and nothing was asserted;
        # capture it and check for OK, mirroring test_10.
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            response = self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            response = self.handler.native_query(query)

        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    def test_14_train_ts_predictor_no_gby_hor4(self):
        query = f"""
            CREATE PREDICTOR {self.test_model_2}
            FROM {PG_HANDLER_NAME} (SELECT * FROM {self.data_table_2})
            PREDICT ma
            ORDER BY saledate
            WINDOW 8
            HORIZON 4
        """
        # Fix: previously the response was discarded and nothing was asserted;
        # capture it and check for OK, mirroring test_10.
        if self.test_model_2 not in self.handler.get_tables().data_frame.values:
            response = self.handler.native_query(query)
        else:
            self.handler.native_query(f"DROP PREDICTOR {self.test_model_2}")
            response = self.handler.native_query(query)

        self.assertTrue(response.type == RESPONSE_TYPE.OK)

    # TODO
    # def test_15_join_predictor_ts_into(self):
    #     query = f"""
    #         SELECT m.saledate as date,
    #             m.ma as forecast
    #         FROM mindsdb.{self.test_model_2} m JOIN {PG_HANDLER_NAME}.demo_data.house_sales t
    #         WHERE t.saledate > LATEST
    #         AND t.type = 'house'
    #         AND t.bedrooms = 2
    #         LIMIT 10
    #     """
    #     response = self.handler.native_query(query)
    #     self.assertTrue(response.type == RESPONSE_TYPE.TABLE)

    # def test_16_join_predictor_ts_model_left(self):
    #     # TODO: is this one needed?
    #     target = 'Traffic'
    #     oby = 'T'
    #     query = f"SELECT tb.{target} as predicted, ta.{target} as truth, ta.{oby} from mindsdb.{self.test_tsmodel_name_1} AS tb JOIN {self.sql_handler_name}.{self.data_table_name_2} AS ta ON 1=1 WHERE ta.{oby} > LATEST LIMIT 10"
    #     parsed = self.handler.parser(query, dialect=self.handler.dialect)
    #     predicted = self.handler.join(parsed, self.data_handler)  # , into=self.model_2_into_table) # TODO: restore when we add support for SQLite and other handlers for `into`
if __name__ == "__main__":
    # failfast: stop at the first failure — later tests query models that
    # earlier tests create, so continuing after a failure is pointless.
    unittest.main(failfast=True)