Coverage for mindsdb / utilities / render / sqlalchemy_render.py: 86%
581 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-21 00:36 +0000
1import re
2import datetime as dt
4import sqlalchemy as sa
5from sqlalchemy.exc import SQLAlchemyError
6from sqlalchemy.ext.compiler import compiles
7from sqlalchemy.orm import aliased
8from sqlalchemy.engine.interfaces import Dialect
9from sqlalchemy.dialects import mysql, postgresql, sqlite, mssql, oracle
10from sqlalchemy.schema import CreateTable, DropTable
11from sqlalchemy.sql import operators, ColumnElement, functions as sa_fnc
12from sqlalchemy.sql.expression import ClauseElement
14from mindsdb_sql_parser import ast
# Identifiers that to_column always quotes, even when the dialect would not require it.
RESERVED_WORDS = {"collation"}

# Public names of ``sqlalchemy.types`` whose objects come from the core type modules.
sa_type_names = [
    name
    for name, obj in sa.types.__dict__.items()
    if getattr(obj, "__module__", None) in ("sqlalchemy.sql.sqltypes", "sqlalchemy.sql.type_api")
]

# Upper-cased type name -> SQLAlchemy type class; looked up by SqlalchemyRender.get_type.
types_map = {name.upper(): getattr(sa.types, name) for name in sa_type_names}
# short aliases
types_map.update(BOOL=types_map["BOOLEAN"], DEC=types_map["DECIMAL"])
class RenderError(Exception):
    """Raised by SqlalchemyRender for queries it refuses to render."""
# Custom clause so intervals are rendered per dialect, see:
# https://github.com/sqlalchemy/sqlalchemy/discussions/9483?sort=old#discussioncomment-5312979
class INTERVAL(ColumnElement):
    """SQL ``INTERVAL`` value built from raw interval text; compiled by ``_compile_interval``."""

    def __init__(self, info):
        # expose an Interval type for the expression
        self.type = sa.Interval()
        # raw interval text; _compile_interval splits it on the first space
        self.info = info
@compiles(INTERVAL)
def _compile_interval(element, compiler, **kw):
    """
    Render an INTERVAL clause for the target dialect.

    The interval text is split on its first space. For Oracle, a trailing "S" is
    removed from the second part (singular unit names). Snowflake and PostgreSQL
    get the whole text quoted; other dialects get only the first part quoted.
    """
    dialect = compiler.dialect
    pieces = element.info.split(" ", maxsplit=1)

    if dialect.name == "oracle" and len(pieces) == 2:
        unit = pieces[1]
        # replace to singular names (remove trailing S if exists)
        if unit.upper().endswith("S"):
            pieces[1] = unit[:-1]

    quote_everything = getattr(dialect, "driver", None) == "snowflake" or dialect.name == "postgresql"
    if quote_everything:
        body = "'" + " ".join(str(piece) for piece in pieces) + "'"
    else:
        first, *rest = pieces
        body = " ".join([f"'{first}'", *rest])
    return "INTERVAL " + body
# region definitions of custom clauses for GROUP BY ROLLUP
# This also works in DuckDB, as it uses the postgres dialect
class GroupByRollup(ClauseElement):
    """GROUP BY element grouping its columns with ROLLUP (see visit_group_by_rollup)."""

    def __init__(self, *columns):
        # positional column expressions, kept as a tuple
        self.columns = tuple(columns)
@compiles(GroupByRollup)
def visit_group_by_rollup(element, compiler, **kw):
    """Render GroupByRollup as ``<cols> WITH ROLLUP`` (mysql/default) or ``ROLLUP(<cols>)``."""
    rendered = ", ".join(compiler.process(column, **kw) for column in element.columns)
    mysql_style = compiler.dialect.name in ("mysql", "default")
    return f"{rendered} WITH ROLLUP" if mysql_style else f"ROLLUP({rendered})"


# endregion
class AttributedStr(str):
    """
    ``str`` subclass carrying an ``is_quoted`` flag.

    It is passed to the dialect preparer's ``_requires_quotes`` so that identifiers
    quoted in the source query can be force-quoted on render.
    """

    def __new__(cls, string, is_quoted: bool):
        instance = super().__new__(cls, string)
        instance.is_quoted = is_quoted
        return instance

    def replace(self, *args):
        # keep the subclass and its flag on replace() results
        return AttributedStr(str.replace(self, *args), self.is_quoted)
def get_is_quoted(identifier: "ast.Identifier"):
    """
    Return per-part quoting flags for an identifier.

    :param identifier: identifier node with ``parts`` and an optional ``is_quoted``
        sequence of flags (a missing or ``None`` attribute counts as no flags)
    :return: a new list padded with ``None`` up to ``len(identifier.parts)``; extra
        flags beyond the number of parts are kept
    """
    # accept any sequence (list/tuple) or None without mutating the node's own list
    quoted = list(getattr(identifier, "is_quoted", None) or [])
    # len can be different: pad so callers can zip/index by part
    quoted.extend([None] * (len(identifier.parts) - len(quoted)))
    return quoted
# Dialect modules selectable by name in SqlalchemyRender ("postgres" is an alias of "postgresql").
dialects = dict(
    mysql=mysql,
    postgresql=postgresql,
    postgres=postgresql,
    sqlite=sqlite,
    mssql=mssql,
    oracle=oracle,
)
class SqlalchemyRender:
    def __init__(self, dialect_name: str | Dialect):
        """
        Build a renderer for one SQL dialect.

        :param dialect_name: a key of ``dialects`` (e.g. "mysql", "postgres") or a
            SQLAlchemy dialect class, which is instantiated here
        :raises KeyError: if a string name is not a key of ``dialects``
        """
        if isinstance(dialect_name, str):
            dialect = dialects[dialect_name].dialect
        else:
            dialect = dialect_name

        # Override the dialect's preparer.
        # NOTE: this assigns ``preparer`` on the dialect class itself, so it affects
        # every user of that class; the name check prevents wrapping it twice.
        if hasattr(dialect, "preparer") and dialect.preparer.__name__ != "MDBPreparer":

            class MDBPreparer(dialect.preparer):
                def _requires_quotes(self, value: str) -> bool:
                    # check force-quote flag
                    if isinstance(value, AttributedStr):
                        if value.is_quoted:
                            return True

                    lc_value = value.lower()
                    return (
                        lc_value in self.reserved_words
                        or value[0] in self.illegal_initial_characters
                        or not self.legal_characters.match(str(value))
                        # Override sqlalchemy behavior: don't require to quote mixed- or upper-case
                        # or (lc_value != value)
                    )

            dialect.preparer = MDBPreparer

        # remove double percent signs
        # https://docs.sqlalchemy.org/en/14/faq/sqlexpressions.html#why-are-percent-signs-being-doubled-up-when-stringifying-sql-statements
        self.dialect = dialect(paramstyle="named")
        self.dialect.div_is_floordiv = False

        # one {"aliases": [...]} frame per select currently being rendered
        self.selects_stack = []

        # Compare the instantiated dialect's name, not the constructor argument, so the
        # version tweaks also apply when a dialect class is passed instead of a name.
        if self.dialect.name == "mssql":
            # update version to MS_2008_VERSION for supports_multivalues_insert
            self.dialect.server_version_info = (10,)
            self.dialect._setup_version_attributes()
        elif self.dialect.name == "mysql":
            # update version for support float cast
            self.dialect.server_version_info = (8, 0, 17)
155 def to_column(self, identifier: ast.Identifier) -> sa.Column:
156 # because sqlalchemy doesn't allow columns consist from parts therefore we do it manually
158 parts2 = []
160 quoted = get_is_quoted(identifier)
161 for i, is_quoted in zip(identifier.parts, quoted):
162 if isinstance(i, ast.Star):
163 part = "*"
164 elif is_quoted or i.lower() in RESERVED_WORDS:
165 # quote anyway
166 part = self.dialect.identifier_preparer.quote_identifier(i)
167 else:
168 # quote if required
169 part = self.dialect.identifier_preparer.quote(i)
171 parts2.append(part)
172 text = ".".join(parts2)
173 if identifier.is_outer and self.dialect.name == "oracle": 173 ↛ 174line 173 didn't jump to line 174 because the condition on line 173 was never true
174 text += "(+)"
175 return sa.column(text, is_literal=True)
177 def get_alias(self, alias):
178 if alias is None or len(alias.parts) == 0: 178 ↛ 179line 178 didn't jump to line 179 because the condition on line 178 was never true
179 return None
180 if len(alias.parts) > 1: 180 ↛ 181line 180 didn't jump to line 181 because the condition on line 180 was never true
181 raise NotImplementedError(f"Multiple alias {alias.parts}")
183 if self.selects_stack: 183 ↛ 186line 183 didn't jump to line 186 because the condition on line 183 was always true
184 self.selects_stack[-1]["aliases"].append(alias)
186 is_quoted = get_is_quoted(alias)[0]
187 return AttributedStr(alias.parts[0], is_quoted)
189 def make_unique_alias(self, name):
190 if self.selects_stack: 190 ↛ exitline 190 didn't return from function 'make_unique_alias' because the condition on line 190 was always true
191 aliases = self.selects_stack[-1]["aliases"]
192 for i in range(10): 192 ↛ exitline 192 didn't return from function 'make_unique_alias' because the loop on line 192 didn't complete
193 name2 = f"{name}_{i}"
194 if name2 not in aliases:
195 aliases.append(name2)
196 return name2
198 def to_expression(self, t):
199 # simple type
200 if isinstance(t, str) or isinstance(t, int) or isinstance(t, float) or t is None:
201 t = ast.Constant(t)
203 if isinstance(t, ast.Star):
204 col = sa.text("*")
205 elif isinstance(t, ast.Last): 205 ↛ 206line 205 didn't jump to line 206 because the condition on line 205 was never true
206 col = self.to_column(ast.Identifier(parts=["last"]))
207 elif isinstance(t, ast.Constant):
208 col = sa.literal(t.value)
209 if t.alias:
210 alias = self.get_alias(t.alias)
211 else:
212 if t.value is None:
213 alias = "NULL"
214 else:
215 alias = str(t.value)
216 col = col.label(alias)
217 elif isinstance(t, ast.Identifier):
218 # sql functions
219 col = None
220 if len(t.parts) == 1:
221 if isinstance(t.parts[0], str): 221 ↛ 231line 221 didn't jump to line 231 because the condition on line 221 was always true
222 name = t.parts[0].upper()
223 if name == "CURRENT_DATE": 223 ↛ 224line 223 didn't jump to line 224 because the condition on line 223 was never true
224 col = sa_fnc.current_date()
225 elif name == "CURRENT_TIME": 225 ↛ 226line 225 didn't jump to line 226 because the condition on line 225 was never true
226 col = sa_fnc.current_time()
227 elif name == "CURRENT_TIMESTAMP": 227 ↛ 228line 227 didn't jump to line 228 because the condition on line 227 was never true
228 col = sa_fnc.current_timestamp()
229 elif name == "CURRENT_USER": 229 ↛ 230line 229 didn't jump to line 230 because the condition on line 229 was never true
230 col = sa_fnc.current_user()
231 if col is None: 231 ↛ 233line 231 didn't jump to line 233 because the condition on line 231 was always true
232 col = self.to_column(t)
233 if t.alias:
234 alias_name = self.get_alias(t.alias)
235 # Skip self-referencing aliases (e.g., "column AS column")
236 if len(t.parts) == 1 and t.parts[0] == alias_name:
237 pass # Don't add alias if it matches the column name
238 else:
239 col = col.label(alias_name)
240 elif isinstance(t, ast.Select):
241 sub_stmt = self.prepare_select(t)
242 col = sub_stmt.scalar_subquery()
243 if t.alias:
244 alias = self.get_alias(t.alias)
245 col = col.label(alias)
246 elif isinstance(t, ast.Function):
247 col = self.to_function(t)
248 if t.alias:
249 alias = self.get_alias(t.alias)
250 col = col.label(alias)
251 else:
252 alias = self.make_unique_alias(str(t.op))
253 if alias: 253 ↛ 429line 253 didn't jump to line 429 because the condition on line 253 was always true
254 col = col.label(alias)
256 elif isinstance(t, ast.BinaryOperation):
257 ops = {
258 "+": operators.add,
259 "-": operators.sub,
260 "*": operators.mul,
261 "/": operators.truediv,
262 "%": operators.mod,
263 "=": operators.eq,
264 "!=": operators.ne,
265 "<>": operators.ne,
266 ">": operators.gt,
267 "<": operators.lt,
268 ">=": operators.ge,
269 "<=": operators.le,
270 "is": operators.is_,
271 "is not": operators.is_not,
272 "like": operators.like_op,
273 "not like": operators.not_like_op,
274 "in": operators.in_op,
275 "not in": operators.not_in_op,
276 "||": operators.concat_op,
277 }
278 functions = {
279 "and": sa.and_,
280 "or": sa.or_,
281 }
283 arg0 = self.to_expression(t.args[0])
284 arg1 = self.to_expression(t.args[1])
286 op = t.op.lower()
287 if op in ("in", "not in"):
288 if t.args[1].parentheses:
289 arg1 = [arg1]
290 if isinstance(arg1, sa.sql.selectable.ColumnClause):
291 raise NotImplementedError(f"Required list argument for: {op}")
293 sa_op = ops.get(op)
295 if sa_op is not None:
296 if isinstance(arg0, sa.TextClause): 296 ↛ 298line 296 didn't jump to line 298 because the condition on line 296 was never true
297 # text doesn't have operate method, reverse operator
298 col = arg1.reverse_operate(sa_op, arg0)
299 elif isinstance(arg1, sa.TextClause): 299 ↛ 301line 299 didn't jump to line 301 because the condition on line 299 was never true
300 # both args are text, return text
301 col = sa.text(f"{arg0.compile(dialect=self.dialect)} {op} {arg1.compile(dialect=self.dialect)}")
302 else:
303 col = arg0.operate(sa_op, arg1)
305 elif t.op.lower() in functions:
306 func = functions[t.op.lower()]
307 col = func(arg0, arg1)
308 else:
309 # for unknown operators wrap arguments into parens
310 if isinstance(t.args[0], ast.BinaryOperation): 310 ↛ 311line 310 didn't jump to line 311 because the condition on line 310 was never true
311 arg0 = arg0.self_group()
312 if isinstance(t.args[1], ast.BinaryOperation): 312 ↛ 313line 312 didn't jump to line 313 because the condition on line 312 was never true
313 arg1 = arg1.self_group()
315 col = arg0.op(t.op)(arg1)
317 if t.alias:
318 alias = self.get_alias(t.alias)
319 col = col.label(alias)
321 elif isinstance(t, ast.UnaryOperation):
322 # not or munus
323 opmap = {
324 "NOT": "__invert__",
325 "-": "__neg__",
326 }
327 arg = self.to_expression(t.args[0])
329 method = opmap[t.op.upper()]
330 col = getattr(arg, method)()
331 if t.alias: 331 ↛ 332line 331 didn't jump to line 332 because the condition on line 331 was never true
332 alias = self.get_alias(t.alias)
333 col = col.label(alias)
335 elif isinstance(t, ast.BetweenOperation):
336 col0 = self.to_expression(t.args[0])
337 lim_down = self.to_expression(t.args[1])
338 lim_up = self.to_expression(t.args[2])
340 col = sa.between(col0, lim_down, lim_up)
341 elif isinstance(t, ast.Interval):
342 col = INTERVAL(t.args[0])
343 if t.alias: 343 ↛ 344line 343 didn't jump to line 344 because the condition on line 343 was never true
344 alias = self.get_alias(t.alias)
345 col = col.label(alias)
347 elif isinstance(t, ast.WindowFunction):
348 func = self.to_expression(t.function)
350 partition = None
351 if t.partition is not None:
352 partition = [self.to_expression(i) for i in t.partition]
354 order_by = None
355 if t.order_by is not None:
356 order_by = []
357 for f in t.order_by:
358 col0 = self.to_expression(f.field)
359 if f.direction == "DESC": 359 ↛ 360line 359 didn't jump to line 360 because the condition on line 359 was never true
360 col0 = col0.desc()
361 order_by.append(col0)
363 rows, range_ = None, None
364 if t.modifier is not None:
365 words = t.modifier.lower().split()
366 if words[1] == "between" and words[4] == "and": 366 ↛ 389line 366 didn't jump to line 389 because the condition on line 366 was always true
367 # frame options
368 # rows/groups BETWEEN <> <> AND <> <>
369 # https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.over
370 items = []
371 for word1, word2 in (words[2:4], words[5:7]):
372 if word1 == "unbounded":
373 items.append(None)
374 elif (word1, word2) == ("current", "row"): 374 ↛ 376line 374 didn't jump to line 376 because the condition on line 374 was always true
375 items.append(0)
376 elif word1.isdigits():
377 val = int(word1)
378 if word2 == "preceding":
379 val = -val
380 elif word2 != "following":
381 continue
382 items.append(val)
383 if len(items) == 2: 383 ↛ 389line 383 didn't jump to line 389 because the condition on line 383 was always true
384 if words[0] == "rows": 384 ↛ 386line 384 didn't jump to line 386 because the condition on line 384 was always true
385 rows = tuple(items)
386 elif words[0] == "range":
387 range_ = tuple(items)
389 col = sa.over(func, partition_by=partition, order_by=order_by, range_=range_, rows=rows)
391 if t.alias:
392 col = col.label(self.get_alias(t.alias))
393 elif isinstance(t, ast.TypeCast):
394 arg = self.to_expression(t.arg)
395 type = self.get_type(t.type_name)
396 if t.precision is not None:
397 type = type(*t.precision)
398 col = sa.cast(arg, type)
400 if t.alias:
401 alias = self.get_alias(t.alias)
402 col = col.label(alias)
403 else:
404 alias = self.make_unique_alias("cast")
405 if alias: 405 ↛ 429line 405 didn't jump to line 429 because the condition on line 405 was always true
406 col = col.label(alias)
407 elif isinstance(t, ast.Parameter):
408 col = sa.column(t.value, is_literal=True)
409 if t.alias: 409 ↛ 410line 409 didn't jump to line 410 because the condition on line 409 was never true
410 raise RenderError("Parameter aliases are not supported in the renderer")
411 elif isinstance(t, ast.Tuple):
412 col = [self.to_expression(i) for i in t.items]
413 elif isinstance(t, ast.Variable): 413 ↛ 414line 413 didn't jump to line 414 because the condition on line 413 was never true
414 col = sa.column(t.to_string(), is_literal=True)
415 elif isinstance(t, ast.Latest): 415 ↛ 416line 415 didn't jump to line 416 because the condition on line 415 was never true
416 col = sa.column(t.to_string(), is_literal=True)
417 elif isinstance(t, ast.Exists):
418 sub_stmt = self.prepare_select(t.query)
419 col = sub_stmt.exists()
420 elif isinstance(t, ast.NotExists):
421 sub_stmt = self.prepare_select(t.query)
422 col = ~sub_stmt.exists()
423 elif isinstance(t, ast.Case):
424 col = self.prepare_case(t)
425 else:
426 # some other complex object?
427 raise NotImplementedError(f"Column {t}")
429 return col
431 def prepare_case(self, t: ast.Case):
432 conditions = []
433 for condition, result in t.rules:
434 conditions.append((self.to_expression(condition), self.to_expression(result)))
435 default = None
436 if t.default is not None:
437 default = self.to_expression(t.default)
439 value = None
440 if t.arg is not None:
441 value = self.to_expression(t.arg)
443 col = sa.case(*conditions, else_=default, value=value)
444 if t.alias:
445 col = col.label(self.get_alias(t.alias))
446 return col
448 def to_function(self, t):
449 if t.namespace is not None:
450 op = getattr(sa.func, t.namespace)
451 else:
452 op = sa.func
453 op = getattr(op, t.op)
454 if t.from_arg is not None:
455 arg = t.args[0].to_string()
456 from_arg = self.to_expression(t.from_arg)
458 fnc = op(arg, from_arg)
459 else:
460 args = [self.to_expression(i) for i in t.args]
461 if t.distinct:
462 # set first argument to distinct
463 args[0] = args[0].distinct()
464 fnc = op(*args)
465 return fnc
467 def get_type(self, typename):
468 # TODO how to get type
469 if not isinstance(typename, str):
470 # sqlalchemy type
471 return typename
473 typename = typename.upper()
474 if re.match(r"^INT[\d]+$", typename):
475 typename = "BIGINT"
476 if re.match(r"^FLOAT[\d]+$", typename): 476 ↛ 477line 476 didn't jump to line 477 because the condition on line 476 was never true
477 typename = "FLOAT"
479 return types_map[typename]
481 def prepare_join(self, join):
482 # join tree to table list
484 if isinstance(join.right, ast.Join): 484 ↛ 485line 484 didn't jump to line 485 because the condition on line 484 was never true
485 raise NotImplementedError("Wrong join AST")
487 items = []
489 if isinstance(join.left, ast.Join):
490 # dive to next level
491 items.extend(self.prepare_join(join.left))
492 else:
493 # this is first table
494 items.append(dict(table=join.left))
496 # all properties set to right table
497 items.append(
498 dict(table=join.right, join_type=join.join_type, is_implicit=join.implicit, condition=join.condition)
499 )
501 return items
503 def get_table_name(self, table_name):
504 schema = None
505 if isinstance(table_name, ast.Identifier):
506 parts = table_name.parts
507 quoted = get_is_quoted(table_name)
509 if len(parts) > 2:
510 # TODO tests is failing
511 raise NotImplementedError(f"Path to long: {table_name.parts}")
513 if len(parts) == 2:
514 schema = AttributedStr(parts[-2], quoted[-2])
516 table_name = AttributedStr(parts[-1], quoted[-1])
518 return schema, table_name
520 def to_table(self, node, is_lateral=False):
521 if isinstance(node, ast.Identifier):
522 schema, table_name = self.get_table_name(node)
524 table = sa.table(table_name, schema=schema)
526 if node.alias:
527 table = aliased(table, name=self.get_alias(node.alias))
529 elif isinstance(node, (ast.Select, ast.Union, ast.Intersect, ast.Except)): 529 ↛ 541line 529 didn't jump to line 541 because the condition on line 529 was always true
530 sub_stmt = self.prepare_select(node)
531 alias = None
532 if node.alias:
533 alias = self.get_alias(node.alias)
534 if is_lateral:
535 table = sub_stmt.lateral(alias)
536 else:
537 table = sub_stmt.subquery(alias)
539 else:
540 # TODO tests are failing
541 raise NotImplementedError(f"Table {node.__name__}")
543 return table
545 def prepare_select(self, node):
546 if isinstance(node, (ast.Union, ast.Except, ast.Intersect)):
547 return self.prepare_union(node)
549 cols = []
551 self.selects_stack.append({"aliases": []})
553 for t in node.targets:
554 col = self.to_expression(t)
555 cols.append(col)
557 query = sa.select(*cols)
559 if node.cte is not None:
560 for cte in node.cte:
561 if cte.columns is not None and len(cte.columns) > 0:
562 raise NotImplementedError("CTE columns")
564 stmt = self.prepare_select(cte.query)
565 alias = cte.name
567 query = query.add_cte(stmt.cte(self.get_alias(alias), nesting=True))
569 if node.distinct is True:
570 query = query.distinct()
571 elif isinstance(node.distinct, list):
572 columns = [self.to_expression(c) for c in node.distinct]
573 query = query.distinct(*columns)
575 if node.from_table is not None:
576 from_table = node.from_table
578 if isinstance(from_table, ast.Join):
579 join_list = self.prepare_join(from_table)
580 # first table
581 table = self.to_table(join_list[0]["table"])
582 query = query.select_from(table)
584 # other tables
585 has_explicit_join = False
586 for item in join_list[1:]:
587 join_type = item["join_type"]
588 table = self.to_table(item["table"], is_lateral=("LATERAL" in join_type))
589 if item["is_implicit"]:
590 # add to from clause
591 if has_explicit_join:
592 # sqlalchemy doesn't support implicit join after explicit
593 # convert it to explicit
594 query = query.join(table, sa.text("1=1"))
595 else:
596 query = query.select_from(table)
597 else:
598 has_explicit_join = True
599 if item["condition"] is None:
600 # otherwise, sqlalchemy raises "Don't know how to join to ..."
601 condition = sa.text("1=1")
602 else:
603 condition = self.to_expression(item["condition"])
605 if "ASOF" in join_type or "RIGHT" in join_type:
606 raise NotImplementedError(f"Unsupported join type: {join_type}")
608 is_full = False
609 is_outer = False
610 if join_type in ("LEFT JOIN", "LEFT OUTER JOIN"):
611 is_outer = True
612 if join_type == "FULL JOIN":
613 is_full = True
615 # perform join
616 query = query.join(table, condition, isouter=is_outer, full=is_full)
617 elif isinstance(from_table, (ast.Union, ast.Intersect, ast.Except)):
618 alias = None
619 if from_table.alias: 619 ↛ 621line 619 didn't jump to line 621 because the condition on line 619 was always true
620 alias = self.get_alias(from_table.alias)
621 table = self.prepare_union(from_table).subquery(alias)
622 query = query.select_from(table)
624 elif isinstance(from_table, ast.Select):
625 table = self.to_table(from_table)
626 query = query.select_from(table)
628 elif isinstance(from_table, ast.Identifier):
629 table = self.to_table(from_table)
630 query = query.select_from(table)
632 elif isinstance(from_table, ast.NativeQuery): 632 ↛ 633line 632 didn't jump to line 633 because the condition on line 632 was never true
633 alias = None
634 if from_table.alias:
635 alias = from_table.alias.parts[-1]
636 table = sa.text(from_table.query).columns().subquery(alias)
637 query = query.select_from(table)
638 else:
639 raise NotImplementedError(f"Select from {from_table}")
641 if node.where is not None:
642 query = query.filter(self.to_expression(node.where))
644 if node.group_by is not None:
645 cols = [self.to_expression(i) for i in node.group_by]
646 if getattr(node.group_by[-1], "with_rollup", False):
647 query = query.group_by(GroupByRollup(*cols))
648 else:
649 query = query.group_by(*cols)
651 if node.having is not None:
652 query = query.having(self.to_expression(node.having))
654 if node.order_by is not None:
655 order_by = []
656 for f in node.order_by:
657 col0 = self.to_expression(f.field)
658 if f.direction.upper() == "DESC":
659 col0 = col0.desc()
660 elif f.direction.upper() == "ASC":
661 col0 = col0.asc()
662 if f.nulls.upper() == "NULLS FIRST":
663 col0 = sa.nullsfirst(col0)
664 elif f.nulls.upper() == "NULLS LAST":
665 col0 = sa.nullslast(col0)
666 order_by.append(col0)
668 query = query.order_by(*order_by)
670 if node.limit is not None:
671 query = query.limit(node.limit.value)
673 if node.offset is not None:
674 query = query.offset(node.offset.value)
676 if node.mode is not None:
677 if node.mode == "FOR UPDATE": 677 ↛ 680line 677 didn't jump to line 680 because the condition on line 677 was always true
678 query = query.with_for_update()
679 else:
680 raise NotImplementedError(f"Select mode: {node.mode}")
682 self.selects_stack.pop()
684 return query
686 def prepare_union(self, from_table):
687 step1 = self.prepare_select(from_table.left)
688 step2 = self.prepare_select(from_table.right)
690 if isinstance(from_table, ast.Except):
691 func = sa.except_ if from_table.unique else sa.except_all
692 elif isinstance(from_table, ast.Intersect):
693 func = sa.intersect if from_table.unique else sa.intersect_all
694 else:
695 func = sa.union if from_table.unique else sa.union_all
697 return func(step1, step2)
699 def prepare_create_table(self, ast_query):
700 columns = []
702 for col in ast_query.columns:
703 default = None
704 if col.default is not None: 704 ↛ 705line 704 didn't jump to line 705 because the condition on line 704 was never true
705 if isinstance(col.default, str):
706 default = sa.text(col.default)
708 if isinstance(col.type, str) and col.type.lower() == "serial": 708 ↛ 709line 708 didn't jump to line 709 because the condition on line 708 was never true
709 col.is_primary_key = True
710 col.type = "INT"
712 kwargs = {
713 "primary_key": col.is_primary_key,
714 "server_default": default,
715 }
716 if col.nullable is not None: 716 ↛ 717line 716 didn't jump to line 717 because the condition on line 716 was never true
717 kwargs["nullable"] = col.nullable
719 columns.append(sa.Column(col.name, self.get_type(col.type), **kwargs))
721 schema, table_name = self.get_table_name(ast_query.name)
723 metadata = sa.MetaData()
724 table = sa.Table(table_name, metadata, schema=schema, *columns)
726 return CreateTable(table)
728 def prepare_drop_table(self, ast_query):
729 if len(ast_query.tables) != 1: 729 ↛ 730line 729 didn't jump to line 730 because the condition on line 729 was never true
730 raise NotImplementedError("Only one table is supported")
732 schema, table_name = self.get_table_name(ast_query.tables[0])
734 metadata = sa.MetaData()
735 table = sa.Table(table_name, metadata, schema=schema)
736 return DropTable(table, if_exists=ast_query.if_exists)
738 def prepare_insert(self, ast_query, with_params=False):
739 params = None
740 schema, table_name = self.get_table_name(ast_query.table)
742 names = []
743 columns = []
745 if ast_query.columns is None:
746 raise NotImplementedError("Columns is required in insert query")
747 for col in ast_query.columns:
748 columns.append(
749 sa.Column(
750 col.name,
751 # self.get_type(col.type)
752 )
753 )
754 # check doubles
755 if col.name in names: 755 ↛ 756line 755 didn't jump to line 756 because the condition on line 755 was never true
756 raise RenderError(f"Columns name double: {col.name}")
757 names.append(col.name)
759 table = sa.table(table_name, schema=schema, *columns)
761 if ast_query.values is not None:
762 values = []
764 if ast_query.is_plain and with_params:
765 for i in range(len(ast_query.columns)):
766 values.append(sa.column("%s", is_literal=True))
768 values = [values]
769 params = ast_query.values
770 else:
771 for row in ast_query.values:
772 row = [self.to_expression(val) for val in row]
773 values.append(row)
775 stmt = table.insert().values(values)
776 else:
777 # is insert from subselect
778 subquery = self.prepare_select(ast_query.from_select)
779 stmt = table.insert().from_select(names, subquery)
781 return stmt, params
783 def prepare_update(self, ast_query):
784 if ast_query.from_select is not None:
785 raise NotImplementedError("Render of update with sub-select is not implemented")
787 schema, table_name = self.get_table_name(ast_query.table)
789 columns = []
791 to_update = {}
792 for col, value in ast_query.update_columns.items():
793 columns.append(
794 sa.Column(
795 col,
796 )
797 )
799 to_update[col] = self.to_expression(value)
801 table = sa.table(table_name, schema=schema, *columns)
803 stmt = table.update().values(**to_update)
805 if ast_query.where is not None:
806 stmt = stmt.where(self.to_expression(ast_query.where))
808 return stmt
810 def prepare_delete(self, ast_query: ast.Delete):
811 schema, table_name = self.get_table_name(ast_query.table)
813 columns = []
815 table = sa.table(table_name, schema=schema, *columns)
817 stmt = table.delete()
819 if ast_query.where is not None:
820 stmt = stmt.where(self.to_expression(ast_query.where))
822 return stmt
824 def get_query(self, ast_query, with_params=False):
825 params = None
826 if isinstance(ast_query, (ast.Select, ast.Union, ast.Except, ast.Intersect)):
827 stmt = self.prepare_select(ast_query)
828 elif isinstance(ast_query, ast.Insert):
829 stmt, params = self.prepare_insert(ast_query, with_params=with_params)
830 elif isinstance(ast_query, ast.Update): 830 ↛ 831line 830 didn't jump to line 831 because the condition on line 830 was never true
831 stmt = self.prepare_update(ast_query)
832 elif isinstance(ast_query, ast.Delete):
833 stmt = self.prepare_delete(ast_query)
834 elif isinstance(ast_query, ast.CreateTable):
835 stmt = self.prepare_create_table(ast_query)
836 elif isinstance(ast_query, ast.DropTables):
837 stmt = self.prepare_drop_table(ast_query)
838 else:
839 raise NotImplementedError(f"Unknown statement: {ast_query.__class__.__name__}")
840 return stmt, params
842 def get_string(self, ast_query, with_failback=True):
843 """
844 Render query to sql string
846 :param ast_query: query to render
847 :param with_failback: switch to standard render in case of error
848 :return:
849 """
850 sql, _ = self.get_exec_params(ast_query, with_failback=with_failback, with_params=False)
851 return sql
853 def get_exec_params(self, ast_query, with_failback=True, with_params=True):
854 """
855 Render query with separated parameters and placeholders
856 :param ast_query: query to render
857 :param with_failback: switch to standard render in case of error
858 :return: sql query and parameters
859 """
861 if isinstance(ast_query, (ast.CreateTable, ast.DropTables)):
862 render_func = render_ddl_query
863 else:
864 render_func = render_dml_query
866 try:
867 stmt, params = self.get_query(ast_query, with_params=with_params)
869 sql = render_func(stmt, self.dialect)
871 return sql, params
873 except (SQLAlchemyError, NotImplementedError) as e:
874 if not with_failback: 874 ↛ 877line 874 didn't jump to line 877 because the condition on line 874 was always true
875 raise e
877 sql_query = str(ast_query)
878 if self.dialect.name == "postgresql":
879 sql_query = sql_query.replace("`", "")
880 return sql_query, None
def _compile_with_literals(compiler_cls, statement, dialect):
    """
    Compile ``statement`` with ``compiler_cls`` to SQL with all bound values inlined.

    Strings, dates, datetimes and timedeltas are rendered as single-quoted literals
    (``'`` doubled). When the dialect escapes backslashes (SQLAlchemy's MySQL dialects
    set ``_backslash_escapes``), ``\\`` is doubled too, as SQLAlchemy's own MySQL
    compiler does. Without that, a trailing backslash would escape the closing quote.
    """

    class LiteralCompiler(compiler_cls):
        def render_literal_value(self, value, type_):
            if isinstance(value, (str, dt.date, dt.datetime, dt.timedelta)):
                text = str(value)
                if getattr(self.dialect, "_backslash_escapes", False):
                    text = text.replace("\\", "\\\\")
                return "'{}'".format(text.replace("'", "''"))

            return super().render_literal_value(value, type_)

    return str(LiteralCompiler(dialect, statement, compile_kwargs={"literal_binds": True}))


def render_dml_query(statement, dialect):
    """Render a DML statement (SELECT/INSERT/UPDATE/DELETE) with literal values inlined."""
    return _compile_with_literals(dialect.statement_compiler, statement, dialect)


def render_ddl_query(statement, dialect):
    """Render a DDL statement (CREATE/DROP TABLE) with literal values inlined."""
    return _compile_with_literals(dialect.ddl_compiler, statement, dialect)