Coverage for mindsdb / api / http / namespaces / sql.py: 77%

175 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import time 

2from http import HTTPStatus 

3from collections import defaultdict 

4 

5from flask import request 

6from flask_restx import Resource 

7 

8from mindsdb_sql_parser import parse_sql 

9from mindsdb_sql_parser import ast 

10 

11import mindsdb.utilities.hooks as hooks 

12import mindsdb.utilities.profiler as profiler 

13from mindsdb.api.http.utils import http_error 

14from mindsdb.api.http.namespaces.configs.sql import ns_conf 

15from mindsdb.api.mysql.mysql_proxy.mysql_proxy import SQLAnswer 

16from mindsdb.api.mysql.mysql_proxy.classes.fake_mysql_proxy import FakeMysqlProxy 

17from mindsdb.api.executor.data_types.response_type import ( 

18 RESPONSE_TYPE as SQL_RESPONSE_TYPE, 

19) 

20 

21from mindsdb.integrations.utilities.query_traversal import query_traversal 

22from mindsdb.api.executor.exceptions import ExecutorException, UnknownError 

23from mindsdb.metrics.metrics import api_endpoint_metrics 

24from mindsdb.utilities import log 

25from mindsdb.utilities.config import Config 

26from mindsdb.utilities.context import context as ctx 

27from mindsdb.utilities.exception import QueryError 

28from mindsdb.utilities.functions import mark_process 

29 

30logger = log.getLogger(__name__) 

31 

32 

@ns_conf.route("/query")
@ns_conf.param("query", "Execute query")
class Query(Resource):
    """HTTP resource that runs the SQL statement sent in the request body."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _error_response(error):
        """Build the error payload returned to the client for a failed query."""
        return {
            "type": SQL_RESPONSE_TYPE.ERROR,
            "error_code": 0,
            "error_message": str(error),
        }

    @ns_conf.doc("query")
    @api_endpoint_metrics("POST", "/sql/query")
    @mark_process(name="http_query")
    def post(self):
        """Execute the posted SQL query and return its result.

        The JSON body must contain a string ``"query"``. It may also contain
        a dict ``"context"`` (passed to the proxy, and enabling the profiler
        when ``context["profiling"]`` is ``True``) and ``"params"`` (stored on
        the request context).

        Returns:
            ``(response, 200)`` where ``response`` is the query result or an
            error payload (``type`` / ``error_code`` / ``error_message``),
            always extended with the proxy's ``"context"``. Malformed
            requests get the ``http_error`` BAD_REQUEST response instead.
        """
        start_time = time.time()
        payload = request.json
        if not isinstance(payload, dict):
            return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", 'Please provide "query" with the request.')

        # .get() rather than indexing: a missing "query" key must produce the
        # 400 response below, not an unhandled KeyError.
        query = payload.get("query")
        context = payload.get("context", {})
        if "params" in payload:
            ctx.params = payload["params"]
        if not isinstance(query, str) or not isinstance(context, dict):
            return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", 'Please provide "query" with the request.')
        logger.debug(f"Incoming query: {query}")

        if context.get("profiling") is True:
            profiler.enable()

        error_type = None
        error_code = None
        error_text = None
        # NOTE(review): never populated in this method, so the hook below
        # always receives traceback=None.
        error_traceback = None

        profiler.set_meta(query=query, api="http", environment=Config().get("environment"))
        with profiler.Context("http_query_processing"):
            mysql_proxy = FakeMysqlProxy()
            mysql_proxy.set_context(context)
            try:
                result: SQLAnswer = mysql_proxy.process_query(query)
                query_response: dict = result.dump_http_response()
            except ExecutorException as e:
                # classified error
                error_type = "expected"
                query_response = self._error_response(e)
                logger.warning(f"Error query processing: {e}")
            except QueryError as e:
                error_type = "expected" if e.is_expected else "unexpected"
                query_response = self._error_response(e)
                if e.is_expected:
                    logger.warning(f"Query failed due to expected reason: {e}")
                else:
                    logger.exception("Error query processing:")
            except UnknownError as e:
                # unclassified
                error_type = "unexpected"
                query_response = self._error_response(e)
                logger.exception("Error query processing:")
            except Exception as e:
                error_type = "unexpected"
                query_response = self._error_response(e)
                logger.exception("Error query processing:")

            if query_response.get("type") == SQL_RESPONSE_TYPE.ERROR:
                # Keep the classification chosen by the handlers above. Only an
                # error response produced without an exception is defaulted to
                # "expected"; overwriting unconditionally would report every
                # unexpected failure as expected.
                if error_type is None:
                    error_type = "expected"
                error_code = query_response.get("error_code")
                error_text = query_response.get("error_message")

            context = mysql_proxy.get_context()

            query_response["context"] = context

        hooks.after_api_query(
            company_id=ctx.company_id,
            api="http",
            command=None,
            payload=query,
            error_type=error_type,
            error_code=error_code,
            error_text=error_text,
            traceback=error_traceback,
        )

        end_time = time.time()
        log_msg = f"SQL processed in {(end_time - start_time):.2f}s ({end_time:.2f}-{start_time:.2f}), result is {query_response['type']}"
        if query_response["type"] is SQL_RESPONSE_TYPE.TABLE:
            log_msg += f" ({len(query_response['data'])} rows), "
        elif query_response["type"] is SQL_RESPONSE_TYPE.ERROR:
            log_msg += f" ({query_response['error_message']}), "
        log_msg += f"used handlers {ctx.used_handlers}"
        logger.debug(log_msg)

        return query_response, 200

136 

137 

@ns_conf.route("/query/utils/parametrize_constants")
class ParametrizeConstants(Resource):
    """HTTP resource that replaces the constants of a SQL query with named parameters."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @api_endpoint_metrics("POST", "/query/utils/parametrize_constants")
    def post(self):
        """Parse the posted query and turn its constants into parameters.

        The JSON body must contain a string ``"query"``.

        Returns:
            ``(response, 200)`` where ``response`` has:
            ``"query"`` - the rewritten query rendered with ``str()``;
            ``"parameters"`` - a list of ``{"name", "value", "type"}`` dicts,
            one per replaced constant;
            ``"databases"`` - a mapping from the first part of every
            multi-part table identifier to the list of its remaining parts
            joined with ``"."``.
            Malformed requests get the ``http_error`` BAD_REQUEST response.
        """
        payload = request.json
        sql_query = payload.get("query") if isinstance(payload, dict) else None
        if not isinstance(sql_query, str):
            return http_error(HTTPStatus.BAD_REQUEST, "Wrong arguments", 'Please provide "query" with the request.')

        # find constants in the query and replace them with parameters
        query = parse_sql(sql_query)

        parameters = []
        param_counts = {}
        databases = defaultdict(set)
        # Fallback name read by to_parameter(). It is bound up front so that a
        # constant without any resolvable name (e.g. in the INSERT branch)
        # cannot hit an unbound closure variable.
        default_param_name = "param"

        def to_parameter(param_name, value):
            """Record ``value`` as a parameter and return the node that replaces it.

            Repeated names are made unique by appending a counter starting at 2
            (``name``, ``name2``, ``name3``, ...).
            """
            if param_name is None:
                param_name = default_param_name

            num = param_counts.get(param_name, 1)
            param_counts[param_name] = num + 1

            if num > 1:
                param_name = param_name + str(num)

            parameters.append({"name": param_name, "value": value, "type": type(value).__name__})
            return ast.Parameter(param_name)

        def find_constants_f(node, is_table, is_target, callstack, **kwargs):
            """Traversal callback: collect table names and parametrize constants.

            Returns ``None`` for nodes that are not constants, otherwise the
            ``ast.Parameter`` built for the constant (presumably
            ``query_traversal`` substitutes the returned node - TODO confirm).
            """
            if is_table and isinstance(node, ast.Identifier):
                if len(node.parts) > 1:
                    databases[node.parts[0]].add(".".join(node.parts[1:]))

            if not isinstance(node, ast.Constant):
                return

            # it is a target
            if is_target and node.alias is not None:
                return to_parameter(node.alias.parts[-1], node.value)

            param_name = None

            for item in callstack:
                # try to find the name
                if isinstance(item, (ast.BinaryOperation, ast.BetweenOperation)) and item.op.lower() not in (
                    "and",
                    "or",
                ):
                    # it is probably a condition
                    for arg in item.args:
                        if isinstance(arg, ast.Identifier):
                            param_name = arg.parts[-1]
                            break
                    if param_name is not None:
                        break

                if item.alias is not None:
                    # it is probably a query target
                    param_name = item.alias.parts[-1]
                    break

            return to_parameter(param_name, node.value)

        if isinstance(query, ast.Update):
            # iterate over a copy because entries are reassigned in the loop
            for name, value in dict(query.update_columns).items():
                if isinstance(value, ast.Constant):
                    query.update_columns[name] = to_parameter(name, value.value)
                else:
                    # constants nested in the expression default to the column name
                    default_param_name = name
                    query_traversal(value, find_constants_f)

        elif isinstance(query, ast.Insert):
            # iterate over node.values and do some processing
            if query.values:
                values = []
                for row in query.values:
                    row2 = []
                    for i, val in enumerate(row):
                        if isinstance(val, ast.Constant):
                            param_name = None
                            if query.columns and i < len(query.columns):
                                param_name = query.columns[i].name
                            elif query.table:
                                param_name = query.table.parts[-1]
                            val = to_parameter(param_name, val.value)
                        row2.append(val)
                    values.append(row2)
                query.values = values

        # reset the fallback name before the generic pass over the whole query
        default_param_name = "param"
        query_traversal(query, find_constants_f)

        # to lists:
        databases = {k: list(v) for k, v in databases.items()}
        response = {"query": str(query), "parameters": parameters, "databases": databases}
        return response, 200

235 

236 

@ns_conf.route("/list_databases")
@ns_conf.param("list_databases", "lists databases of mindsdb")
class ListDatabases(Resource):
    """HTTP resource that lists the databases and the tables of each one."""

    @staticmethod
    def _list_tables(mysql_proxy, db_name):
        """Return the first column of ``SHOW TABLES`` for ``db_name``."""
        tables_answer = mysql_proxy.process_query("SHOW TABLES FROM `{}`".format(db_name))
        return [table_row[0] for table_row in tables_answer.result_set.to_lists()]

    @ns_conf.doc("list_databases")
    @api_endpoint_metrics("GET", "/sql/list_databases")
    def get(self):
        """Run ``SHOW DATABASES`` and, for a table result, ``SHOW TABLES`` per database.

        Returns:
            ``(response, 200)``. For a table result ``response`` is
            ``{"data": [{"name": ..., "tables": [...]}, ...]}``; for an OK
            result it is ``{"type": "ok"}``; otherwise it is an error payload
            with ``type`` / ``error_code`` / ``error_message``. Any exception
            is logged and reported as an error payload with code 0.
        """
        listing_query = "SHOW DATABASES"
        mysql_proxy = FakeMysqlProxy()
        try:
            result: SQLAnswer = mysql_proxy.process_query(listing_query)

            # iterate over result.data and perform a query on each item to get the name of the tables
            if result.type == SQL_RESPONSE_TYPE.ERROR:
                listing_query_response = {
                    "type": "error",
                    "error_code": result.error_code,
                    "error_message": result.error_message,
                }
            elif result.type == SQL_RESPONSE_TYPE.OK:
                listing_query_response = {"type": "ok"}
            elif result.type == SQL_RESPONSE_TYPE.TABLE:
                listing_query_response = {
                    "data": [
                        {
                            "name": db_row[0],
                            "tables": self._list_tables(mysql_proxy, db_row[0]),
                        }
                        for db_row in result.result_set.to_lists()
                    ]
                }
            else:
                # Without this branch the response variable stays unbound and
                # the return below raises UnboundLocalError.
                listing_query_response = {
                    "type": "error",
                    "error_code": 0,
                    "error_message": f"Unexpected response type: {result.type}",
                }
        except Exception as e:
            logger.exception("Error while retrieving list of databases")
            listing_query_response = {
                "type": "error",
                "error_code": 0,
                "error_message": str(e),
            }

        return listing_query_response, 200