Coverage for mindsdb / utilities / render / sqlalchemy_render.py: 86%

581 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-21 00:36 +0000

1import re 

2import datetime as dt 

3 

4import sqlalchemy as sa 

5from sqlalchemy.exc import SQLAlchemyError 

6from sqlalchemy.ext.compiler import compiles 

7from sqlalchemy.orm import aliased 

8from sqlalchemy.engine.interfaces import Dialect 

9from sqlalchemy.dialects import mysql, postgresql, sqlite, mssql, oracle 

10from sqlalchemy.schema import CreateTable, DropTable 

11from sqlalchemy.sql import operators, ColumnElement, functions as sa_fnc 

12from sqlalchemy.sql.expression import ClauseElement 

13 

14from mindsdb_sql_parser import ast 

15 

16 

# identifiers that are always quoted when rendered as column parts
RESERVED_WORDS = {"collation"}

# names of objects exported by sqlalchemy.types that are defined in the
# sqltypes / type_api modules (i.e. the SQL type classes)
sa_type_names = [
    key
    for key, val in sa.types.__dict__.items()
    if getattr(val, "__module__", None) in ("sqlalchemy.sql.sqltypes", "sqlalchemy.sql.type_api")
]

# upper-cased type name -> sqlalchemy type, plus a few common aliases
types_map = {name.upper(): getattr(sa.types, name) for name in sa_type_names}
types_map["BOOL"] = types_map["BOOLEAN"]
types_map["DEC"] = types_map["DECIMAL"]

30 

31 

class RenderError(Exception):
    """Raised when an AST query cannot be rendered (e.g. aliased parameters, duplicate insert columns)."""

33 

34 

# https://github.com/sqlalchemy/sqlalchemy/discussions/9483?sort=old#discussioncomment-5312979
class INTERVAL(ColumnElement):
    """SQL ``INTERVAL`` expression, compiled per dialect by ``_compile_interval``.

    :param info: interval text taken from the AST; it is split on the first space
        into a value and an optional unit when compiled
    """

    def __init__(self, info):
        self.type = sa.Interval()
        self.info = info

40 

41 

@compiles(INTERVAL)
def _compile_interval(element, compiler, **kw):
    """Render an INTERVAL element for the target dialect.

    ``element.info`` is split on the first space into a value and an optional unit.
    Postgres and Snowflake receive the whole interval as one quoted string
    (``INTERVAL '<value> <unit>'``); other dialects get only the value quoted
    (``INTERVAL '<value>' <unit>``). For Oracle a trailing "S" is removed from the
    unit so that it is singular.
    """
    items = element.info.split(" ", maxsplit=1)
    if compiler.dialect.name == "oracle" and len(items) == 2:
        # singular unit names for oracle: drop a trailing "S" (e.g. "DAYS" -> "DAY")
        if items[1].upper().endswith("S"):
            items[1] = items[1][:-1]

    if getattr(compiler.dialect, "driver", None) == "snowflake" or compiler.dialect.name == "postgresql":
        # quote value and unit together
        args = "'{}'".format(" ".join(items))
    else:
        # quote only the value
        items[0] = f"'{items[0]}'"
        args = " ".join(items)
    return "INTERVAL " + args

59 

60 

# region definitions of custom clauses for GROUP BY ROLLUP
# This also works in DuckDB, as it uses the postgres dialect
class GroupByRollup(ClauseElement):
    """GROUP BY body that applies ROLLUP to ``columns`` (compiled by ``visit_group_by_rollup``)."""

    def __init__(self, *columns):
        # keep the expressions exactly as given, as a tuple
        self.columns = columns

66 

67 

@compiles(GroupByRollup)
def visit_group_by_rollup(element, compiler, **kw):
    """Compile GroupByRollup.

    MySQL and the default dialect get ``a, b WITH ROLLUP``; every other dialect
    gets ``ROLLUP(a, b)``.
    """
    rendered = ", ".join(compiler.process(column, **kw) for column in element.columns)
    with_rollup_syntax = compiler.dialect.name in ("mysql", "default")
    return f"{rendered} WITH ROLLUP" if with_rollup_syntax else f"ROLLUP({rendered})"


# endregion

78 

79 

class AttributedStr(str):
    """``str`` subclass carrying an ``is_quoted`` flag.

    Instances are handed to the identifier preparer's ``_requires_quotes`` so that
    identifiers which were quoted in the source query stay quoted when rendered.
    """

    def __new__(cls, string, is_quoted: bool):
        obj = str.__new__(cls, string)
        obj.is_quoted = is_quoted
        return obj

    def replace(self, *args, **kwargs):
        # keep the flag across replace() (presumably used when escaping quote chars);
        # kwargs are forwarded so str.replace(..., count=n) also works
        return AttributedStr(super().replace(*args, **kwargs), self.is_quoted)

93 

94 

def get_is_quoted(identifier: "ast.Identifier") -> list:
    """Return the per-part quoting flags of ``identifier``.

    ``is_quoted`` may be missing, None, a tuple, or shorter than ``parts``; the
    result is a new list padded with None up to ``len(identifier.parts)``. A longer
    ``is_quoted`` is returned (copied) without truncation.
    """
    quoted = list(getattr(identifier, "is_quoted", None) or [])
    # len can be different: pad missing flags with None
    quoted.extend([None] * (len(identifier.parts) - len(quoted)))
    return quoted

100 

101 

# dialect names accepted by SqlalchemyRender -> sqlalchemy dialect module
dialects = dict(
    mysql=mysql,
    postgresql=postgresql,
    postgres=postgresql,  # short alias of "postgresql"
    sqlite=sqlite,
    mssql=mssql,
    oracle=oracle,
)

110 

111 

class SqlalchemyRender:
    """Render mindsdb_sql_parser AST queries to SQL strings through SQLAlchemy Core.

    AST nodes are converted into SQLAlchemy constructs and compiled with the dialect
    selected at construction time. ``selects_stack`` holds one ``{"aliases": [...]}``
    scope per SELECT currently being built; it is used to generate unique labels for
    unaliased functions and casts.
    """

    # AST binary operators that map directly onto SQLAlchemy operators
    _BINARY_OPERATORS = {
        "+": operators.add,
        "-": operators.sub,
        "*": operators.mul,
        "/": operators.truediv,
        "%": operators.mod,
        "=": operators.eq,
        "!=": operators.ne,
        "<>": operators.ne,
        ">": operators.gt,
        "<": operators.lt,
        ">=": operators.ge,
        "<=": operators.le,
        "is": operators.is_,
        "is not": operators.is_not,
        "like": operators.like_op,
        "not like": operators.not_like_op,
        "in": operators.in_op,
        "not in": operators.not_in_op,
        "||": operators.concat_op,
    }

    # AST binary operators rendered with SQLAlchemy conjunction functions
    _BINARY_FUNCTIONS = {
        "and": sa.and_,
        "or": sa.or_,
    }

    # AST unary operators -> ColumnElement method implementing them
    _UNARY_METHODS = {
        "NOT": "__invert__",
        "-": "__neg__",
    }

    # single-part identifiers that are rendered as SQL functions
    _IDENTIFIER_FUNCTIONS = {
        "CURRENT_DATE": sa_fnc.current_date,
        "CURRENT_TIME": sa_fnc.current_time,
        "CURRENT_TIMESTAMP": sa_fnc.current_timestamp,
        "CURRENT_USER": sa_fnc.current_user,
    }

    def __init__(self, dialect_name: str | Dialect):
        """
        :param dialect_name: a key of the module-level ``dialects`` mapping (e.g. "mysql",
            "postgres") or a SQLAlchemy dialect class, which is instantiated here
        """
        if isinstance(dialect_name, str):
            dialect = dialects[dialect_name].dialect
        else:
            dialect = dialect_name

        # Override the dialect's identifier preparer. This assigns onto the dialect
        # class itself; the name check prevents wrapping it more than once.
        if hasattr(dialect, "preparer") and dialect.preparer.__name__ != "MDBPreparer":

            class MDBPreparer(dialect.preparer):
                def _requires_quotes(self, value: str) -> bool:
                    # check force-quote flag
                    if isinstance(value, AttributedStr) and value.is_quoted:
                        return True

                    lc_value = value.lower()
                    return (
                        lc_value in self.reserved_words
                        or value[0] in self.illegal_initial_characters
                        or not self.legal_characters.match(str(value))
                        # Override sqlalchemy behavior: don't require to quote mixed- or upper-case
                        # or (lc_value != value)
                    )

            dialect.preparer = MDBPreparer

        # remove double percent signs
        # https://docs.sqlalchemy.org/en/14/faq/sqlexpressions.html#why-are-percent-signs-being-doubled-up-when-stringifying-sql-statements
        self.dialect = dialect(paramstyle="named")
        # render "/" as true division
        self.dialect.div_is_floordiv = False

        self.selects_stack = []

        if dialect_name == "mssql":
            # update version to MS_2008_VERSION for supports_multivalues_insert
            self.dialect.server_version_info = (10,)
            self.dialect._setup_version_attributes()
        elif dialect_name == "mysql":
            # update version for support float cast
            self.dialect.server_version_info = (8, 0, 17)

    def to_column(self, identifier: ast.Identifier) -> sa.Column:
        """Render a (possibly multi-part) identifier as a literal column expression.

        SQLAlchemy columns can't consist of several parts, so each part is quoted
        individually and the parts are joined with "." by hand.
        """
        rendered_parts = []
        for part, is_quoted in zip(identifier.parts, get_is_quoted(identifier)):
            if isinstance(part, ast.Star):
                rendered_parts.append("*")
            elif is_quoted or part.lower() in RESERVED_WORDS:
                # quote anyway
                rendered_parts.append(self.dialect.identifier_preparer.quote_identifier(part))
            else:
                # quote only if the dialect requires it
                rendered_parts.append(self.dialect.identifier_preparer.quote(part))

        text = ".".join(rendered_parts)
        if identifier.is_outer and self.dialect.name == "oracle":
            text += "(+)"
        return sa.column(text, is_literal=True)

    def get_alias(self, alias):
        """Convert an alias Identifier into an AttributedStr (None for an empty alias).

        The alias name is recorded in the current SELECT scope so that generated
        aliases (see ``make_unique_alias``) don't collide with it.

        :raises NotImplementedError: for multi-part aliases
        """
        if alias is None or len(alias.parts) == 0:
            return None
        if len(alias.parts) > 1:
            raise NotImplementedError(f"Multiple alias {alias.parts}")

        name = alias.parts[0]
        if self.selects_stack:
            # store the plain name: make_unique_alias compares strings
            self.selects_stack[-1]["aliases"].append(name)

        is_quoted = get_is_quoted(alias)[0]
        return AttributedStr(name, is_quoted)

    def make_unique_alias(self, name):
        """Return ``<name>_<i>`` (i in 0..9) not yet used in the current SELECT scope.

        Returns None when there is no SELECT scope or all ten candidates are taken.
        """
        if not self.selects_stack:
            return None
        aliases = self.selects_stack[-1]["aliases"]
        for i in range(10):
            candidate = f"{name}_{i}"
            if candidate not in aliases:
                aliases.append(candidate)
                return candidate
        return None

    def to_expression(self, t):
        """Convert an AST expression (or a plain str/int/float/None) into a SQLAlchemy expression.

        Aliases on the node are applied as labels. Unaliased constants are labelled with
        their value; unaliased functions and casts get a generated unique label.
        An ``ast.Tuple`` is returned as a Python list of expressions.

        :raises NotImplementedError: for AST nodes/operators that can't be rendered
        :raises RenderError: for aliased parameters
        """
        # simple type
        if isinstance(t, (str, int, float)) or t is None:
            t = ast.Constant(t)

        if isinstance(t, ast.Star):
            col = sa.text("*")
        elif isinstance(t, ast.Last):
            col = self.to_column(ast.Identifier(parts=["last"]))
        elif isinstance(t, ast.Constant):
            col = sa.literal(t.value)
            if t.alias:
                alias = self.get_alias(t.alias)
            elif t.value is None:
                alias = "NULL"
            else:
                alias = str(t.value)
            col = col.label(alias)
        elif isinstance(t, ast.Identifier):
            col = None
            if len(t.parts) == 1 and isinstance(t.parts[0], str):
                # some bare names are sql functions
                sql_function = self._IDENTIFIER_FUNCTIONS.get(t.parts[0].upper())
                if sql_function is not None:
                    col = sql_function()
            if col is None:
                col = self.to_column(t)
            if t.alias:
                alias_name = self.get_alias(t.alias)
                # Skip self-referencing aliases (e.g., "column AS column")
                if not (len(t.parts) == 1 and t.parts[0] == alias_name):
                    col = col.label(alias_name)
        elif isinstance(t, ast.Select):
            col = self.prepare_select(t).scalar_subquery()
            if t.alias:
                col = col.label(self.get_alias(t.alias))
        elif isinstance(t, ast.Function):
            col = self.to_function(t)
            if t.alias:
                col = col.label(self.get_alias(t.alias))
            else:
                alias = self.make_unique_alias(str(t.op))
                if alias:
                    col = col.label(alias)
        elif isinstance(t, ast.BinaryOperation):
            col = self._prepare_binary_operation(t)
        elif isinstance(t, ast.UnaryOperation):
            # NOT or unary minus
            arg = self.to_expression(t.args[0])
            method = self._UNARY_METHODS.get(t.op.upper())
            if method is None:
                # NotImplementedError lets get_exec_params fall back to the default render
                raise NotImplementedError(f"Unary operator: {t.op}")
            col = getattr(arg, method)()
            if t.alias:
                col = col.label(self.get_alias(t.alias))
        elif isinstance(t, ast.BetweenOperation):
            col0 = self.to_expression(t.args[0])
            lim_down = self.to_expression(t.args[1])
            lim_up = self.to_expression(t.args[2])
            col = sa.between(col0, lim_down, lim_up)
        elif isinstance(t, ast.Interval):
            col = INTERVAL(t.args[0])
            if t.alias:
                col = col.label(self.get_alias(t.alias))
        elif isinstance(t, ast.WindowFunction):
            col = self._prepare_window_function(t)
        elif isinstance(t, ast.TypeCast):
            arg = self.to_expression(t.arg)
            sa_type = self.get_type(t.type_name)
            if t.precision is not None:
                sa_type = sa_type(*t.precision)
            col = sa.cast(arg, sa_type)

            if t.alias:
                col = col.label(self.get_alias(t.alias))
            else:
                alias = self.make_unique_alias("cast")
                if alias:
                    col = col.label(alias)
        elif isinstance(t, ast.Parameter):
            col = sa.column(t.value, is_literal=True)
            if t.alias:
                raise RenderError("Parameter aliases are not supported in the renderer")
        elif isinstance(t, ast.Tuple):
            col = [self.to_expression(i) for i in t.items]
        elif isinstance(t, ast.Variable):
            col = sa.column(t.to_string(), is_literal=True)
        elif isinstance(t, ast.Latest):
            col = sa.column(t.to_string(), is_literal=True)
        elif isinstance(t, ast.Exists):
            col = self.prepare_select(t.query).exists()
        elif isinstance(t, ast.NotExists):
            col = ~self.prepare_select(t.query).exists()
        elif isinstance(t, ast.Case):
            col = self.prepare_case(t)
        else:
            # some other complex object?
            raise NotImplementedError(f"Column {t}")

        return col

    def _prepare_binary_operation(self, t):
        """Translate an ast.BinaryOperation into a SQLAlchemy expression (alias applied)."""
        arg0 = self.to_expression(t.args[0])
        arg1 = self.to_expression(t.args[1])

        op = t.op.lower()
        if op in ("in", "not in"):
            if t.args[1].parentheses:
                arg1 = [arg1]
            if isinstance(arg1, sa.sql.selectable.ColumnClause):
                raise NotImplementedError(f"Required list argument for: {op}")

        sa_op = self._BINARY_OPERATORS.get(op)
        if sa_op is not None:
            if isinstance(arg0, sa.TextClause):
                # text doesn't have operate method, reverse operator
                col = arg1.reverse_operate(sa_op, arg0)
            elif isinstance(arg1, sa.TextClause):
                # right side is text (left is not): combine both rendered sides as text
                col = sa.text(f"{arg0.compile(dialect=self.dialect)} {op} {arg1.compile(dialect=self.dialect)}")
            else:
                col = arg0.operate(sa_op, arg1)
        elif op in self._BINARY_FUNCTIONS:
            col = self._BINARY_FUNCTIONS[op](arg0, arg1)
        else:
            # for unknown operators wrap arguments into parens
            if isinstance(t.args[0], ast.BinaryOperation):
                arg0 = arg0.self_group()
            if isinstance(t.args[1], ast.BinaryOperation):
                arg1 = arg1.self_group()
            col = arg0.op(t.op)(arg1)

        if t.alias:
            col = col.label(self.get_alias(t.alias))
        return col

    def _prepare_window_function(self, t):
        """Translate an ast.WindowFunction into ``sa.over(...)`` (alias applied)."""
        func = self.to_expression(t.function)

        partition = None
        if t.partition is not None:
            partition = [self.to_expression(i) for i in t.partition]

        order_by = None
        if t.order_by is not None:
            order_by = []
            for f in t.order_by:
                col0 = self.to_expression(f.field)
                # case-insensitive, like the SELECT ORDER BY handling
                if str(f.direction).upper() == "DESC":
                    col0 = col0.desc()
                order_by.append(col0)

        rows, range_ = None, None
        if t.modifier is not None:
            rows, range_ = self._parse_window_frame(t.modifier)

        col = sa.over(func, partition_by=partition, order_by=order_by, range_=range_, rows=rows)

        if t.alias:
            col = col.label(self.get_alias(t.alias))
        return col

    @staticmethod
    def _parse_window_frame(modifier):
        """Parse ``ROWS|RANGE BETWEEN <bound> AND <bound>`` into ``(rows, range_)`` for ``sa.over``.

        Each bound becomes None (UNBOUNDED), 0 (CURRENT ROW), -N (N PRECEDING) or
        N (N FOLLOWING). Unrecognised frames (including GROUPS) yield ``(None, None)``.
        https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.over
        """
        words = modifier.lower().split()
        # expected shape: <rows|range> between <w1> <w2> and <w1> <w2>
        if len(words) < 7 or words[1] != "between" or words[4] != "and":
            return None, None

        items = []
        for word1, word2 in (words[2:4], words[5:7]):
            if word1 == "unbounded":
                items.append(None)
            elif (word1, word2) == ("current", "row"):
                items.append(0)
            elif word1.isdigit():
                val = int(word1)
                if word2 == "preceding":
                    val = -val
                elif word2 != "following":
                    continue
                items.append(val)

        if len(items) != 2:
            return None, None
        if words[0] == "rows":
            return tuple(items), None
        if words[0] == "range":
            return None, tuple(items)
        return None, None

    def prepare_case(self, t: ast.Case):
        """Translate an ast.Case into ``sa.case`` (alias applied)."""
        conditions = [(self.to_expression(condition), self.to_expression(result)) for condition, result in t.rules]

        default = None
        if t.default is not None:
            default = self.to_expression(t.default)

        value = None
        if t.arg is not None:
            value = self.to_expression(t.arg)

        col = sa.case(*conditions, else_=default, value=value)
        if t.alias:
            col = col.label(self.get_alias(t.alias))
        return col

    def to_function(self, t):
        """Translate an ast.Function into a ``sa.func`` call (no label applied).

        ``namespace.op`` resolves through ``sa.func.<namespace>.<op>``. For
        ``FUNC(arg FROM expr)`` the first argument is passed as its rendered string.
        """
        op = sa.func if t.namespace is None else getattr(sa.func, t.namespace)
        op = getattr(op, t.op)

        if t.from_arg is not None:
            arg = t.args[0].to_string()
            from_arg = self.to_expression(t.from_arg)
            return op(arg, from_arg)

        args = [self.to_expression(i) for i in t.args]
        if t.distinct:
            # set first argument to distinct
            args[0] = args[0].distinct()
        return op(*args)

    def get_type(self, typename):
        """Resolve a type name (case-insensitive) to a sqlalchemy type; non-strings are returned as is.

        ``INT<digits>`` maps to BIGINT and ``FLOAT<digits>`` to FLOAT.

        :raises NotImplementedError: for unknown type names (lets rendering fall back)
        """
        if not isinstance(typename, str):
            # sqlalchemy type
            return typename

        typename = typename.upper()
        if re.match(r"^INT[\d]+$", typename):
            typename = "BIGINT"
        if re.match(r"^FLOAT[\d]+$", typename):
            typename = "FLOAT"

        try:
            return types_map[typename]
        except KeyError:
            raise NotImplementedError(f"Unknown type: {typename}") from None

    def prepare_join(self, join):
        """Flatten a left-deep join tree into a list of dicts.

        The first item is ``{"table": ...}``; every following item describes one joined
        table with ``join_type``, ``is_implicit`` and ``condition``.

        :raises NotImplementedError: if the right side of a join is itself a join
        """
        if isinstance(join.right, ast.Join):
            raise NotImplementedError("Wrong join AST")

        if isinstance(join.left, ast.Join):
            # dive to next level
            items = self.prepare_join(join.left)
        else:
            # this is first table
            items = [dict(table=join.left)]

        # all properties set to right table
        items.append(
            dict(table=join.right, join_type=join.join_type, is_implicit=join.implicit, condition=join.condition)
        )
        return items

    def get_table_name(self, table_name):
        """Split a table identifier into ``(schema, table)`` AttributedStr values.

        ``schema`` is None for single-part names; non-Identifier input is returned as
        ``(None, table_name)``.

        :raises NotImplementedError: for identifiers with more than two parts
        """
        schema = None
        if isinstance(table_name, ast.Identifier):
            parts = table_name.parts
            quoted = get_is_quoted(table_name)

            if len(parts) > 2:
                # TODO tests is failing
                raise NotImplementedError(f"Path to long: {table_name.parts}")

            if len(parts) == 2:
                schema = AttributedStr(parts[-2], quoted[-2])

            table_name = AttributedStr(parts[-1], quoted[-1])

        return schema, table_name

    def to_table(self, node, is_lateral=False):
        """Convert a table identifier or a sub-select into a FROM-able SQLAlchemy object.

        :param is_lateral: for sub-selects, build a LATERAL subquery
        :raises NotImplementedError: for any other node type
        """
        if isinstance(node, ast.Identifier):
            schema, table_name = self.get_table_name(node)
            table = sa.table(table_name, schema=schema)
            if node.alias:
                table = aliased(table, name=self.get_alias(node.alias))
        elif isinstance(node, (ast.Select, ast.Union, ast.Intersect, ast.Except)):
            sub_stmt = self.prepare_select(node)
            alias = None
            if node.alias:
                alias = self.get_alias(node.alias)
            table = sub_stmt.lateral(alias) if is_lateral else sub_stmt.subquery(alias)
        else:
            # instances have no __name__; use the class name
            raise NotImplementedError(f"Table {type(node).__name__}")

        return table

    def prepare_select(self, node):
        """Build a SQLAlchemy select (or compound select) for a Select/Union/Except/Intersect node."""
        if isinstance(node, (ast.Union, ast.Except, ast.Intersect)):
            return self.prepare_union(node)

        # each SELECT gets its own alias scope; pop it even if rendering fails so a
        # failed render doesn't leak stale aliases into later renders
        self.selects_stack.append({"aliases": []})
        try:
            return self._build_select(node)
        finally:
            self.selects_stack.pop()

    def _build_select(self, node):
        """Build the select for ``node``; the caller manages the alias scope."""
        query = sa.select(*[self.to_expression(t) for t in node.targets])

        if node.cte is not None:
            for cte in node.cte:
                if cte.columns is not None and len(cte.columns) > 0:
                    raise NotImplementedError("CTE columns")

                stmt = self.prepare_select(cte.query)
                query = query.add_cte(stmt.cte(self.get_alias(cte.name), nesting=True))

        if node.distinct is True:
            query = query.distinct()
        elif isinstance(node.distinct, list):
            query = query.distinct(*[self.to_expression(c) for c in node.distinct])

        if node.from_table is not None:
            query = self._apply_from(query, node.from_table)

        if node.where is not None:
            query = query.filter(self.to_expression(node.where))

        if node.group_by is not None:
            cols = [self.to_expression(i) for i in node.group_by]
            if getattr(node.group_by[-1], "with_rollup", False):
                query = query.group_by(GroupByRollup(*cols))
            else:
                query = query.group_by(*cols)

        if node.having is not None:
            query = query.having(self.to_expression(node.having))

        if node.order_by is not None:
            order_by = []
            for f in node.order_by:
                col0 = self.to_expression(f.field)
                direction = f.direction.upper()
                if direction == "DESC":
                    col0 = col0.desc()
                elif direction == "ASC":
                    col0 = col0.asc()
                nulls = f.nulls.upper()
                if nulls == "NULLS FIRST":
                    col0 = sa.nullsfirst(col0)
                elif nulls == "NULLS LAST":
                    col0 = sa.nullslast(col0)
                order_by.append(col0)

            query = query.order_by(*order_by)

        if node.limit is not None:
            query = query.limit(node.limit.value)

        if node.offset is not None:
            query = query.offset(node.offset.value)

        if node.mode is not None:
            if node.mode != "FOR UPDATE":
                raise NotImplementedError(f"Select mode: {node.mode}")
            query = query.with_for_update()

        return query

    def _apply_from(self, query, from_table):
        """Attach the FROM clause (join chain, compound select, sub-select, table or native query)."""
        if isinstance(from_table, ast.Join):
            return self._apply_joins(query, from_table)

        if isinstance(from_table, (ast.Union, ast.Intersect, ast.Except)):
            alias = None
            if from_table.alias:
                alias = self.get_alias(from_table.alias)
            table = self.prepare_union(from_table).subquery(alias)
        elif isinstance(from_table, (ast.Select, ast.Identifier)):
            table = self.to_table(from_table)
        elif isinstance(from_table, ast.NativeQuery):
            alias = None
            if from_table.alias:
                alias = from_table.alias.parts[-1]
            table = sa.text(from_table.query).columns().subquery(alias)
        else:
            raise NotImplementedError(f"Select from {from_table}")

        return query.select_from(table)

    def _apply_joins(self, query, join):
        """Render a join tree as ``select_from`` / ``join`` calls on ``query``.

        :raises NotImplementedError: for ASOF and RIGHT joins
        """
        join_list = self.prepare_join(join)
        # first table
        query = query.select_from(self.to_table(join_list[0]["table"]))

        # other tables
        has_explicit_join = False
        for item in join_list[1:]:
            join_type = item["join_type"]
            table = self.to_table(item["table"], is_lateral=("LATERAL" in join_type))

            if item["is_implicit"]:
                # add to from clause
                if has_explicit_join:
                    # sqlalchemy doesn't support implicit join after explicit
                    # convert it to explicit
                    query = query.join(table, sa.text("1=1"))
                else:
                    query = query.select_from(table)
                continue

            has_explicit_join = True
            if item["condition"] is None:
                # otherwise, sqlalchemy raises "Don't know how to join to ..."
                condition = sa.text("1=1")
            else:
                condition = self.to_expression(item["condition"])

            if "ASOF" in join_type or "RIGHT" in join_type:
                raise NotImplementedError(f"Unsupported join type: {join_type}")

            is_outer = join_type in ("LEFT JOIN", "LEFT OUTER JOIN")
            # accept both spellings of a full outer join
            is_full = join_type in ("FULL JOIN", "FULL OUTER JOIN")

            # perform join
            query = query.join(table, condition, isouter=is_outer, full=is_full)

        return query

    def prepare_union(self, from_table):
        """Build UNION / INTERSECT / EXCEPT (``*_all`` variants when ``unique`` is false)."""
        step1 = self.prepare_select(from_table.left)
        step2 = self.prepare_select(from_table.right)

        if isinstance(from_table, ast.Except):
            func = sa.except_ if from_table.unique else sa.except_all
        elif isinstance(from_table, ast.Intersect):
            func = sa.intersect if from_table.unique else sa.intersect_all
        else:
            func = sa.union if from_table.unique else sa.union_all

        return func(step1, step2)

    def prepare_create_table(self, ast_query):
        """Build a CreateTable DDL element for an ast.CreateTable node.

        String column defaults become server defaults; SERIAL columns are rendered as
        INT primary keys. The input AST is not modified.
        """
        columns = []

        for col in ast_query.columns:
            default = None
            if isinstance(col.default, str):
                default = sa.text(col.default)

            col_type = col.type
            is_primary_key = col.is_primary_key
            if isinstance(col_type, str) and col_type.lower() == "serial":
                is_primary_key = True
                col_type = "INT"

            kwargs = {
                "primary_key": is_primary_key,
                "server_default": default,
            }
            if col.nullable is not None:
                kwargs["nullable"] = col.nullable

            columns.append(sa.Column(col.name, self.get_type(col_type), **kwargs))

        schema, table_name = self.get_table_name(ast_query.name)

        metadata = sa.MetaData()
        table = sa.Table(table_name, metadata, *columns, schema=schema)

        return CreateTable(table)

    def prepare_drop_table(self, ast_query):
        """Build a DropTable DDL element; only a single table is supported."""
        if len(ast_query.tables) != 1:
            raise NotImplementedError("Only one table is supported")

        schema, table_name = self.get_table_name(ast_query.tables[0])

        metadata = sa.MetaData()
        table = sa.Table(table_name, metadata, schema=schema)
        return DropTable(table, if_exists=ast_query.if_exists)

    def prepare_insert(self, ast_query, with_params=False):
        """Build an INSERT statement.

        :param with_params: for plain VALUES inserts, render one "%s" placeholder per
            column and return ``ast_query.values`` as params instead of inlining them
        :return: ``(statement, params)``; params is None unless placeholders were used
        :raises NotImplementedError: if the insert has no explicit column list
        :raises RenderError: if a column name is repeated
        """
        params = None
        schema, table_name = self.get_table_name(ast_query.table)

        if ast_query.columns is None:
            raise NotImplementedError("Columns is required in insert query")

        names = []
        columns = []
        for col in ast_query.columns:
            columns.append(sa.Column(col.name))
            # check doubles
            if col.name in names:
                raise RenderError(f"Columns name double: {col.name}")
            names.append(col.name)

        table = sa.table(table_name, *columns, schema=schema)

        if ast_query.values is not None:
            if ast_query.is_plain and with_params:
                values = [[sa.column("%s", is_literal=True) for _ in ast_query.columns]]
                params = ast_query.values
            else:
                values = [[self.to_expression(val) for val in row] for row in ast_query.values]

            stmt = table.insert().values(values)
        else:
            # is insert from subselect
            subquery = self.prepare_select(ast_query.from_select)
            stmt = table.insert().from_select(names, subquery)

        return stmt, params

    def prepare_update(self, ast_query):
        """Build an UPDATE statement (update with sub-select is not supported)."""
        if ast_query.from_select is not None:
            raise NotImplementedError("Render of update with sub-select is not implemented")

        schema, table_name = self.get_table_name(ast_query.table)

        columns = []
        to_update = {}
        for col, value in ast_query.update_columns.items():
            columns.append(sa.Column(col))
            to_update[col] = self.to_expression(value)

        table = sa.table(table_name, *columns, schema=schema)

        stmt = table.update().values(**to_update)

        if ast_query.where is not None:
            stmt = stmt.where(self.to_expression(ast_query.where))

        return stmt

    def prepare_delete(self, ast_query: ast.Delete):
        """Build a DELETE statement with an optional WHERE clause."""
        schema, table_name = self.get_table_name(ast_query.table)

        stmt = sa.table(table_name, schema=schema).delete()

        if ast_query.where is not None:
            stmt = stmt.where(self.to_expression(ast_query.where))

        return stmt

    def get_query(self, ast_query, with_params=False):
        """Build the SQLAlchemy statement for any supported AST query.

        :return: ``(statement, params)``; params is only set for parametrized inserts
        :raises NotImplementedError: for unsupported statement types
        """
        params = None
        if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)):
            stmt = self.prepare_select(ast_query)
        elif isinstance(ast_query, ast.Insert):
            stmt, params = self.prepare_insert(ast_query, with_params=with_params)
        elif isinstance(ast_query, ast.Update):
            stmt = self.prepare_update(ast_query)
        elif isinstance(ast_query, ast.Delete):
            stmt = self.prepare_delete(ast_query)
        elif isinstance(ast_query, ast.CreateTable):
            stmt = self.prepare_create_table(ast_query)
        elif isinstance(ast_query, ast.DropTables):
            stmt = self.prepare_drop_table(ast_query)
        else:
            raise NotImplementedError(f"Unknown statement: {ast_query.__class__.__name__}")
        return stmt, params

    def get_string(self, ast_query, with_failback=True):
        """
        Render query to sql string

        :param ast_query: query to render
        :param with_failback: switch to standard render in case of error
        :return: rendered SQL string (values inlined as literals)
        """
        sql, _ = self.get_exec_params(ast_query, with_failback=with_failback, with_params=False)
        return sql

    def get_exec_params(self, ast_query, with_failback=True, with_params=True):
        """
        Render query with separated parameters and placeholders

        :param ast_query: query to render
        :param with_failback: on SQLAlchemyError/NotImplementedError, return ``str(ast_query)``
            (backticks removed for postgresql) instead of raising
        :param with_params: use "%s" placeholders for plain VALUES inserts
        :return: sql query and parameters (None when nothing was parametrized)
        """
        if isinstance(ast_query, (ast.CreateTable, ast.DropTables)):
            render_func = render_ddl_query
        else:
            render_func = render_dml_query

        try:
            stmt, params = self.get_query(ast_query, with_params=with_params)
            sql = render_func(stmt, self.dialect)
        except (SQLAlchemyError, NotImplementedError):
            if not with_failback:
                raise

            sql_query = str(ast_query)
            if self.dialect.name == "postgresql":
                sql_query = sql_query.replace("`", "")
            return sql_query, None

        return sql, params

881 

882 

def _render_with_literal_binds(compiler_cls, statement, dialect):
    """Compile ``statement`` with a subclass of ``compiler_cls``, inlining bound values.

    str, date, datetime and timedelta values are rendered as single-quoted strings
    with embedded quotes doubled; other values use the compiler's own literal rendering.
    """

    class LiteralCompiler(compiler_cls):
        def render_literal_value(self, value, type_):
            if isinstance(value, (str, dt.date, dt.datetime, dt.timedelta)):
                return "'{}'".format(str(value).replace("'", "''"))

            return super().render_literal_value(value, type_)

    return str(LiteralCompiler(dialect, statement, compile_kwargs={"literal_binds": True}))


def render_dml_query(statement, dialect):
    """Render a non-DDL statement with ``dialect.statement_compiler``, values inlined."""
    return _render_with_literal_binds(dialect.statement_compiler, statement, dialect)


def render_ddl_query(statement, dialect):
    """Render a DDL statement (CREATE/DROP TABLE) with ``dialect.ddl_compiler``, values inlined."""
    return _render_with_literal_binds(dialect.ddl_compiler, statement, dialect)