Source code for odata_query.sql.base

import logging
from typing import Optional

from odata_query import ast, exceptions, typing, visitor

log = logging.getLogger(__name__)


class AstToSqlVisitor(visitor.NodeVisitor):
    """
    :class:`NodeVisitor` that transforms an :term:`AST` into a SQL ``WHERE``
    clause.

    Based on SQL-99 as described here: https://crate.io/docs/sql-99/en/latest/
    (``floor`` and ``ceiling`` are rendered with the ``FLOOR`` / ``CEILING``
    numeric functions that were standardised in SQL:2003).

    Every ``visit_*`` method returns the SQL fragment for one node type, and
    every ``sqlfunc_*`` method returns the SQL for one OData function call;
    :meth:`visit_Call` dispatches to the latter by lower-cased function name.

    Args:
        table_alias: Optional alias for the root table.
    """

    def __init__(self, table_alias: Optional[str] = None):
        super().__init__()
        self.table_alias = table_alias

    # ------------------------------------------------------------------
    # Identifiers and literals
    # ------------------------------------------------------------------

    def visit_Identifier(self, node: ast.Identifier) -> str:
        """
        Render a column reference, prefixed with the table alias when one
        was configured.

        :meta private:
        """
        # Double quotes for column names acc SQL Standard
        sql_id = f'"{node.name}"'
        if self.table_alias:
            sql_id = f'"{self.table_alias}".' + sql_id
        return sql_id

    def visit_Null(self, node: ast.Null) -> str:
        """
        Render the ``NULL`` keyword.

        :meta private:
        """
        return "NULL"

    def visit_Integer(self, node: ast.Integer) -> str:
        """
        Render an integer literal; the node's value is emitted as-is.

        :meta private:
        """
        return node.val

    def visit_Float(self, node: ast.Float) -> str:
        """
        Render a float literal; the node's value is emitted as-is.

        :meta private:
        """
        return node.val

    def visit_Boolean(self, node: ast.Boolean) -> str:
        """
        Render a boolean literal as the upper-cased node value.

        :meta private:
        """
        return node.val.upper()

    def visit_String(self, node: ast.String) -> str:
        """
        Render a string constant with embedded single quotes doubled.

        :meta private:
        """
        # Replace single quotes with double single-quotes acc SQL standard:
        val = node.val.replace("'", "''")
        # Wrap in single quotes for string constants acc SQL Standard
        return f"'{val}'"

    def visit_Date(self, node: ast.Date) -> str:
        """
        Render a ``DATE '...'`` constant.

        :meta private:
        """
        # Single quotes for date constants acc SQL Standard
        return f"DATE '{node.val}'"

    def visit_DateTime(self, node: ast.DateTime) -> str:
        """
        Render a ``TIMESTAMP '...'`` constant, replacing the ``T`` separator
        of the node value with a space.

        :meta private:
        """
        sql_ts = node.val.replace("T", " ")
        # Single quotes for datetime constants acc SQL Standard
        return f"TIMESTAMP '{sql_ts}'"

    def visit_Duration(self, node: ast.Duration) -> str:
        """
        Render a duration as one ``INTERVAL`` per non-empty component.

        A single component is emitted bare, several are summed inside
        parentheses; either form is prefixed with the sign when present.
        An empty string is returned when no component is set.

        :meta private:
        """
        sign, years, months, days, hours, minutes, seconds = node.unpack()
        sign = sign or ""

        components = (
            (years, "YEAR"),
            (months, "MONTH"),
            (days, "DAY"),
            (hours, "HOUR"),
            (minutes, "MINUTE"),
            (seconds, "SECOND"),
        )
        intervals = [
            f"INTERVAL '{amount}' {unit}" for amount, unit in components if amount
        ]

        if not intervals:
            # Shouldn't occur but whatever
            return ""
        if len(intervals) == 1:
            return f"{sign}{intervals[0]}"

        interval = " + ".join(intervals)
        return f"{sign}({interval})"

    def visit_GUID(self, node: ast.GUID) -> str:
        """
        Render a GUID as a single-quoted constant.

        :meta private:
        """
        return f"'{node.val}'"

    def visit_List(self, node: ast.List) -> str:
        """
        Render a list as a parenthesised, comma-separated sequence of its
        visited elements.

        :meta private:
        """
        options = ", ".join(self.visit(n) for n in node.val)
        return f"({options})"

    # ------------------------------------------------------------------
    # Arithmetic
    # ------------------------------------------------------------------

    def visit_Add(self, node: ast.Add) -> str:
        """
        Render the addition operator.

        :meta private:
        """
        return "+"

    def visit_Sub(self, node: ast.Sub) -> str:
        """
        Render the subtraction operator.

        :meta private:
        """
        return "-"

    def visit_Mult(self, node: ast.Mult) -> str:
        """
        Render the multiplication operator.

        :meta private:
        """
        return "*"

    def visit_Div(self, node: ast.Div) -> str:
        """
        Render the division operator.

        :meta private:
        """
        return "/"

    def visit_Mod(self, node: ast.Mod) -> str:
        """
        Render the modulo operator.

        :meta private:
        """
        return "%"

    @staticmethod
    def _arith_precedence(op: ast._Node) -> int:
        """
        Rank an arithmetic operator node: multiplicative operators bind
        tighter (2) than everything else (1).

        :meta private:
        """
        return 2 if isinstance(op, (ast.Mult, ast.Div, ast.Mod)) else 1

    def visit_BinOp(self, node: ast.BinOp) -> str:
        """
        Render ``<left> <op> <right>``, parenthesising a nested arithmetic
        operand when dropping the grouping would change the evaluation order.

        :meta private:
        """
        left = self.visit(node.left)
        right = self.visit(node.right)
        op = self.visit(node.op)

        precedence = self._arith_precedence(node.op)

        # A looser-binding expression on the left must stay grouped,
        # e.g. (a + b) * c.
        if (
            isinstance(node.left, ast.BinOp)
            and self._arith_precedence(node.left.op) < precedence
        ):
            left = f"({left})"

        # On the right, equal precedence needs grouping as well because the
        # operators are left-associative, e.g. a - (b - c) or a / (b * c).
        if (
            isinstance(node.right, ast.BinOp)
            and self._arith_precedence(node.right.op) <= precedence
        ):
            right = f"({right})"

        return f"{left} {op} {right}"

    # ------------------------------------------------------------------
    # Comparison
    # ------------------------------------------------------------------

    def visit_Eq(self, node: ast.Eq) -> str:
        """
        Render the equality comparator.

        :meta private:
        """
        return "="

    def visit_NotEq(self, node: ast.NotEq) -> str:
        """
        Render the inequality comparator.

        :meta private:
        """
        return "!="

    def visit_Lt(self, node: ast.Lt) -> str:
        """
        Render the less-than comparator.

        :meta private:
        """
        return "<"

    def visit_LtE(self, node: ast.LtE) -> str:
        """
        Render the less-than-or-equal comparator.

        :meta private:
        """
        return "<="

    def visit_Gt(self, node: ast.Gt) -> str:
        """
        Render the greater-than comparator.

        :meta private:
        """
        return ">"

    def visit_GtE(self, node: ast.GtE) -> str:
        """
        Render the greater-than-or-equal comparator.

        :meta private:
        """
        return ">="

    def visit_In(self, node: ast.In) -> str:
        """
        Render the ``IN`` comparator.

        :meta private:
        """
        return "IN"

    def visit_Compare(self, node: ast.Compare) -> str:
        """
        Render ``<left> <comparator> <right>``.

        Boolean or comparison sub-expressions are parenthesised, and
        ``eq`` / ``ne`` against ``null`` become ``IS`` / ``IS NOT``.

        :meta private:
        """
        left = self.visit(node.left)
        right = self.visit(node.right)
        comparator = self.visit(node.comparator)

        # In case of a subexpression, wrap it in parentheses
        if isinstance(node.left, (ast.BoolOp, ast.Compare)):
            left = f"({left})"
        if isinstance(node.right, (ast.BoolOp, ast.Compare)):
            right = f"({right})"

        # 'eq/ne null' should become 'IS (NOT) NULL' instead of '(!)= NULL'
        if isinstance(node.right, ast.Null):
            if isinstance(node.comparator, ast.Eq):
                comparator = "IS"
            elif isinstance(node.comparator, ast.NotEq):
                comparator = "IS NOT"

        return f"{left} {comparator} {right}"

    # ------------------------------------------------------------------
    # Boolean logic
    # ------------------------------------------------------------------

    def visit_And(self, node: ast.And) -> str:
        """
        Render the ``AND`` operator.

        :meta private:
        """
        return "AND"

    def visit_Or(self, node: ast.Or) -> str:
        """
        Render the ``OR`` operator.

        :meta private:
        """
        return "OR"

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        """
        Render ``<left> <op> <right>``, parenthesising nested boolean
        operands that use a different operator than this node.

        :meta private:
        """
        left = self.visit(node.left)
        op = self.visit(node.op)
        right = self.visit(node.right)

        # In case of a subexpression, wrap it in parentheses
        # UNLESS it has the same operator as the current BoolOp, e.g.:
        # x AND y AND z
        if isinstance(node.left, ast.BoolOp) and node.left.op != node.op:
            left = f"({left})"
        if isinstance(node.right, ast.BoolOp) and node.right.op != node.op:
            right = f"({right})"

        return f"{left} {op} {right}"

    def visit_Not(self, node: ast.Not) -> str:
        """
        Render the ``NOT`` operator.

        :meta private:
        """
        return "NOT"

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        """
        Render ``<op> <operand>``, parenthesising a boolean operand.

        :meta private:
        """
        op = self.visit(node.op)
        operand = self.visit(node.operand)

        # In case of a subexpression, wrap it in parentheses
        if isinstance(node.operand, ast.BoolOp):
            operand = f"({operand})"

        return f"{op} {operand}"

    # ------------------------------------------------------------------
    # Function calls
    # ------------------------------------------------------------------

    def visit_Call(self, node: ast.Call) -> str:
        """
        Dispatch a function call to the matching ``sqlfunc_<name>`` method.

        Raises:
            UnsupportedFunctionException: When this visitor has no
                ``sqlfunc_`` method for the (lower-cased) function name.

        :meta private:
        """
        # Grammar has already validated that the function is valid OData,
        # but that doesn't guarantee we can represent it in SQL:
        sql_gen = getattr(self, "sqlfunc_" + node.func.name.lower(), None)
        if sql_gen is None:
            raise exceptions.UnsupportedFunctionException(node.func.name)

        return sql_gen(*node.args)

    @staticmethod
    def _assume_string(inferred_types: list) -> bool:
        """
        Decide whether a call should be treated as its string variant: true
        when any inferred type is a string, or when no type could be
        inferred at all.

        :meta private:
        """
        return any(typ is ast.String for typ in inferred_types) or all(
            typ is None for typ in inferred_types
        )

    @staticmethod
    def _non_string_error(func_name: str, inferred_types: list) -> Exception:
        """
        Build the exception for a call that can't be treated as its string
        variant: "unsupported" for list inputs, "argument type" otherwise.
        The caller is responsible for raising it.

        :meta private:
        """
        if any(typ is ast.List for typ in inferred_types):
            return exceptions.UnsupportedFunctionException(f"{func_name}<List>")
        return exceptions.ArgumentTypeException(func_name)

    def sqlfunc_concat(self, *args: ast._Node) -> str:
        """
        Concatenate all arguments with the ``||`` operator.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        return " || ".join(args_sql)

    def _to_pattern(self, arg: ast._Node, prefix: str = "", suffix: str = "") -> str:
        """
        Transform a node into a pattern usable in `LIKE` clauses.

        Identifiers and calls are rendered as SQL expressions with the
        prefix/suffix concatenated on as string constants. Any other node is
        treated as a literal: its value is inlined into a single quoted
        constant together with the prefix and suffix.

        :meta private:
        """
        if isinstance(arg, (ast.Identifier, ast.Call)):
            res = self.visit(arg)
            if prefix:
                res = f"'{prefix}' || " + res
            if suffix:
                res = res + f" || '{suffix}'"
            return res

        # NOTE(review): doubling `%` and `_` does not make them literal in
        # standard LIKE (that needs an ESCAPE clause); kept as-is so the
        # emitted pattern shape doesn't change for existing callers.
        res = str(arg.val).replace("%", "%%").replace("_", "__")  # type: ignore
        # The value ends up inside a quoted constant, so embedded single
        # quotes must be doubled exactly like `visit_String` does; otherwise
        # a quote in the search term terminates the constant early.
        res = res.replace("'", "''")
        return "'" + prefix + res + suffix + "'"

    def sqlfunc_contains(self, *args: ast._Node) -> str:
        """
        Render string ``contains`` as ``LIKE '%...%'``.

        Raises:
            UnsupportedFunctionException: For list inputs.
            ArgumentTypeException: For any other non-string input.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = [typing.infer_type(arg) for arg in args]

        # If any of the inputs is a string or default, assume str-contains:
        if self._assume_string(inferred_type):
            pattern = self._to_pattern(args[1], prefix="%", suffix="%")
            return f"{args_sql[0]} LIKE {pattern}"

        raise self._non_string_error("contains", inferred_type)

    def sqlfunc_endswith(self, *args: ast._Node) -> str:
        """
        Render string ``endswith`` as ``LIKE '%...'``.

        Raises:
            UnsupportedFunctionException: For list inputs.
            ArgumentTypeException: For any other non-string input.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = [typing.infer_type(arg) for arg in args]

        # If any of the inputs is a string or default, assume str-endswith:
        if self._assume_string(inferred_type):
            pattern = self._to_pattern(args[1], prefix="%")
            return f"{args_sql[0]} LIKE {pattern}"

        # list-endswith isn't easily doable at the moment
        raise self._non_string_error("endswith", inferred_type)

    def sqlfunc_indexof(self, *args: ast._Node) -> str:
        """
        Render string ``indexof`` as a zero-based ``POSITION``.

        Raises:
            UnsupportedFunctionException: For list inputs.
            ArgumentTypeException: For any other non-string input.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = [typing.infer_type(arg) for arg in args]

        # If any of the inputs is a string, assume str-indexof:
        if self._assume_string(inferred_type):
            # NOTE(review): the trailing `- 1` is not parenthesised, so this
            # fragment may regroup when used as an operand of `*`, `/` or `%`.
            return f"POSITION({args_sql[1]} IN {args_sql[0]}) - 1"

        # list-indexof isn't easily doable at the moment
        raise self._non_string_error("indexof", inferred_type)

    def sqlfunc_length(self, arg: ast._Node) -> str:
        """
        Render ``length`` as ``CHAR_LENGTH`` for strings (or unknown types)
        and ``CARDINALITY`` for lists.

        Raises:
            ArgumentTypeException: For any other input type.

        :meta private:
        """
        arg_sql = self.visit(arg)
        inferred_type = typing.infer_type(arg)

        # If the input is a string or default, assume str-length:
        if inferred_type is ast.String or inferred_type is None:
            return f"CHAR_LENGTH({arg_sql})"

        # If the input is a list, assume list-length:
        if inferred_type is ast.List:
            return f"CARDINALITY({arg_sql})"

        raise exceptions.ArgumentTypeException("length")

    def sqlfunc_startswith(self, *args: ast._Node) -> str:
        """
        Render string ``startswith`` as ``LIKE '...%'``.

        Raises:
            UnsupportedFunctionException: For list inputs.
            ArgumentTypeException: For any other non-string input.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = [typing.infer_type(arg) for arg in args]

        # If any of the inputs is a string or default, assume str-startswith:
        if self._assume_string(inferred_type):
            pattern = self._to_pattern(args[1], suffix="%")
            return f"{args_sql[0]} LIKE {pattern}"

        # list-startswith isn't easily doable at the moment
        raise self._non_string_error("startswith", inferred_type)

    def sqlfunc_substring(self, *args: ast._Node) -> str:
        """
        Render string ``substring`` with a zero-based start offset and an
        optional length.

        Raises:
            UnsupportedFunctionException: When the first input is a list.
            ArgumentTypeException: For any other non-string first input, or
                when the number of arguments is neither two nor three.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = typing.infer_type(args[0])

        # If the first input is a string or default, assume str-substr:
        if inferred_type is ast.String or inferred_type is None:
            # `+ 1` converts the zero-based offset to SQL's one-based one.
            if len(args) == 2:
                return f"SUBSTRING({args_sql[0]} FROM {args_sql[1]} + 1)"
            if len(args) == 3:
                return (
                    f"SUBSTRING({args_sql[0]} FROM {args_sql[1]} + 1 FOR {args_sql[2]})"
                )

        raise self._non_string_error("substring", [inferred_type])

    def sqlfunc_tolower(self, arg: ast._Node) -> str:
        """
        Render ``tolower`` as ``LOWER``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"LOWER({arg_sql})"

    def sqlfunc_toupper(self, arg: ast._Node) -> str:
        """
        Render ``toupper`` as ``UPPER``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"UPPER({arg_sql})"

    def sqlfunc_trim(self, arg: ast._Node) -> str:
        """
        Render ``trim`` as ``TRIM``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"TRIM({arg_sql})"

    def sqlfunc_year(self, arg: ast._Node) -> str:
        """
        Render ``year`` as ``EXTRACT (YEAR FROM ...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"EXTRACT (YEAR FROM {arg_sql})"

    def sqlfunc_month(self, arg: ast._Node) -> str:
        """
        Render ``month`` as ``EXTRACT (MONTH FROM ...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"EXTRACT (MONTH FROM {arg_sql})"

    def sqlfunc_day(self, arg: ast._Node) -> str:
        """
        Render ``day`` as ``EXTRACT (DAY FROM ...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"EXTRACT (DAY FROM {arg_sql})"

    def sqlfunc_hour(self, arg: ast._Node) -> str:
        """
        Render ``hour`` as ``EXTRACT (HOUR FROM ...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"EXTRACT (HOUR FROM {arg_sql})"

    def sqlfunc_minute(self, arg: ast._Node) -> str:
        """
        Render ``minute`` as ``EXTRACT (MINUTE FROM ...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"EXTRACT (MINUTE FROM {arg_sql})"

    def sqlfunc_date(self, arg: ast._Node) -> str:
        """
        Render ``date`` as a cast to ``DATE``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"CAST ({arg_sql} AS DATE)"

    def sqlfunc_now(self) -> str:
        """
        Render ``now`` as ``CURRENT_TIMESTAMP``.

        :meta private:
        """
        return "CURRENT_TIMESTAMP"

    def sqlfunc_round(self, arg: ast._Node) -> str:
        """
        Render ``round`` as a cast of ``<arg> + 0.5`` to ``INTEGER``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        # NOTE(review): the result depends on how the backend converts to
        # INTEGER (truncate vs. round) and looks off for negative inputs
        # under truncation - confirm against the targeted databases.
        return f"CAST ({arg_sql} + 0.5 AS INTEGER)"

    def sqlfunc_floor(self, arg: ast._Node) -> str:
        """
        Render ``floor`` with the SQL:2003 ``FLOOR`` function.

        The previous hand-rolled ``CASE`` emulation was not valid SQL
        (comparison operators in a simple ``CASE``, no ``THEN``, unbalanced
        parentheses), so it could not be executed by any backend.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"FLOOR({arg_sql})"

    def sqlfunc_ceiling(self, arg: ast._Node) -> str:
        """
        Render ``ceiling`` with the SQL:2003 ``CEILING`` function.

        The previous hand-rolled ``CASE`` emulation was not valid SQL
        (comparison operators in a simple ``CASE``, no ``THEN``), so it could
        not be executed by any backend.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"CEILING({arg_sql})"

    def sqlfunc_hassubset(self, *args: ast._Node) -> str:
        """
        ``hassubset`` has no SQL rendering here.

        Raises:
            UnsupportedFunctionException: Always.

        :meta private:
        """
        raise exceptions.UnsupportedFunctionException("hassubset")