Source code for odata_query.sql.athena

import re

from odata_query import ast, exceptions, typing

from .base import AstToSqlVisitor

# Matches any single character that is not an ASCII letter, a digit or an
# underscore. Used by ``clean_athena_identifier`` to find the characters that
# get replaced in an identifier.
UNSAFE_CHARS = re.compile(r"[^a-zA-Z0-9_]")


def clean_athena_identifier(identifier: str) -> str:
    """
    Normalize ``identifier`` so it satisfies Athena's naming rules.

    The rules, and how this helper relates to each of them:

    - Table names and table column names in Athena must be lowercase, so the
      input is lowercased.
    - Athena table, view, database, and column names allow only underscore
      special characters, so every character matched by ``UNSAFE_CHARS`` is
      replaced with an underscore.
    - Names should be quoted or backticked when starting with a number or
      underscore. This function does not add any quoting itself; that is left
      to the caller.

    Source:
    https://docs.aws.amazon.com/athena/latest/ug/tables-databases-columns-names.html

    Args:
        identifier: The raw identifier to clean.

    Returns:
        The lowercased identifier with unsafe characters replaced by ``_``.
    """
    return UNSAFE_CHARS.sub("_", identifier.lower())
class AstToAthenaSqlVisitor(AstToSqlVisitor):
    """
    :class:`NodeVisitor` that transforms an :term:`AST` into an Athena SQL
    ``WHERE`` clause.

    Args:
        table_alias: Optional alias for the root table.
    """

    def visit_Identifier(self, node: ast.Identifier) -> str:
        """
        Render an identifier as a double-quoted, cleaned Athena name.

        When ``self.table_alias`` is truthy, the result is prefixed with the
        double-quoted alias and a dot.

        :meta private:
        """
        # Double quotes for column names, according to the SQL standard.
        sql_id = f'"{clean_athena_identifier(node.name)}"'
        if self.table_alias:
            sql_id = f'"{self.table_alias}".' + sql_id
        return sql_id

    def visit_DateTime(self, node: ast.DateTime) -> str:
        """
        Render a datetime literal as a ``FROM_ISO8601_TIMESTAMP`` call.

        :meta private:
        """
        return f"FROM_ISO8601_TIMESTAMP('{node.val}')"

    def sqlfunc_length(self, arg: ast._Node) -> str:
        """
        Render ``length(arg)`` for either a string or a list argument.

        Raises:
            exceptions.ArgumentTypeException: If the inferred type of ``arg``
                is neither a string, a list, nor unknown.

        :meta private:
        """
        arg_sql = self.visit(arg)
        inferred_type = typing.infer_type(arg)

        # If the input is a string or default, assume str-length:
        if inferred_type is ast.String or inferred_type is None:
            return f"LENGTH({arg_sql})"

        # If the input is a list, assume list-length:
        if inferred_type is ast.List:
            return f"CARDINALITY({arg_sql})"

        raise exceptions.ArgumentTypeException("length")

    def sqlfunc_substring(self, *args: ast._Node) -> str:
        """
        Render ``substring(value, start[, length])`` for a string or a list.

        The incoming ``start`` is treated as zero-based and shifted by one in
        both branches, because Athena's ``SUBSTR`` and ``SLICE`` both count
        from 1.

        Raises:
            exceptions.ArgumentTypeException: If the inferred type of the
                first argument is unsupported, or the argument count is
                neither 2 nor 3.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        inferred_type = typing.infer_type(args[0])

        # If the first input is a string or default, assume str-substr:
        if inferred_type is ast.String or inferred_type is None:
            if len(args) == 2:
                return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1)"
            if len(args) == 3:
                return f"SUBSTR({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})"

        # If the first input is a list, assume list-substr:
        if inferred_type is ast.List:
            if len(args) == 2:
                # SLICE always needs a length. The array's own cardinality is
                # an upper bound that selects everything from ``start`` on.
                return (
                    f"SLICE({args_sql[0]}, {args_sql[1]} + 1, "
                    f"CARDINALITY({args_sql[0]}))"
                )
            if len(args) == 3:
                return f"SLICE({args_sql[0]}, {args_sql[1]} + 1, {args_sql[2]})"

        raise exceptions.ArgumentTypeException("substring")

    def sqlfunc_round(self, arg: ast._Node) -> str:
        """
        Render ``round(arg)`` as ``ROUND(...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"ROUND({arg_sql})"

    def sqlfunc_floor(self, arg: ast._Node) -> str:
        """
        Render ``floor(arg)`` as ``FLOOR(...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"FLOOR({arg_sql})"

    def sqlfunc_ceiling(self, arg: ast._Node) -> str:
        """
        Render ``ceiling(arg)`` as ``CEILING(...)``.

        :meta private:
        """
        arg_sql = self.visit(arg)
        return f"CEILING({arg_sql})"

    def sqlfunc_hassubset(self, *args: ast._Node) -> str:
        """
        Render ``hassubset(a, b)``: true when every element of ``b`` is in ``a``.

        The check compares the size of the intersection of both arrays with
        the number of distinct elements in ``b``.

        :meta private:
        """
        args_sql = [self.visit(arg) for arg in args]
        # ARRAY_INTERSECT yields de-duplicated elements, so the right-hand
        # side must be de-duplicated too. Otherwise a subset that repeats an
        # element would never match.
        return (
            f"CARDINALITY(ARRAY_INTERSECT({args_sql[0]}, {args_sql[1]}))"
            f" = CARDINALITY(ARRAY_DISTINCT({args_sql[1]}))"
        )