# Source code for the ETL DSL parser module.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""ETL DSL parser."""
import abc
import collections
import functools
import logging
import types
import typing

from import dsl
from import function

if typing.TYPE_CHECKING:
    from import parser  # pylint: disable=import-self

# Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

#: Generic storage-native representation of :class:`dsl.Source`.
Source = typing.TypeVar('Source')
#: Generic storage-native representation of :class:`dsl.Feature`.
Feature = typing.TypeVar('Feature')
#: Generic storage-native representation of any instruction.
Symbol = typing.TypeVar('Symbol')


class Container(typing.Generic[Symbol]):
    """Base parser structure.

    When used as a context manager the internal structure is exclusive to given context and is
    checked for total depletion on exit.

    Contexts nest: entering saves the current context and installs a fresh one, a clean exit
    restores the saved one.
    """

    class Context:
        """Storage context."""

        class Symbols:
            """Stack for parsed symbols."""

            def __init__(self):
                self._stack: list[Symbol] = []

            def __bool__(self):
                return bool(self._stack)

            def push(self, item: Symbol) -> None:
                """Push new parsed item to the stack.

                Args:
                    item: Item to be added.
                """
                self._stack.append(item)

            def pop(self) -> Symbol:
                """Remove and return a value from the top of the stack.

                Returns:
                    Item from the stack top.

                Raises:
                    RuntimeError: If the stack is empty.
                """
                if not self._stack:
                    raise RuntimeError('Empty context')
                return self._stack.pop()

        class Tables:
            """Container for segments of all tables."""

            class Segment(collections.namedtuple('Segment', 'fields, factors')):
                """Frame segment specification as a list of features (vertical) and row predicates
                (horizontal).
                """

                def __new__(cls):
                    # Both members start as empty sets that are filled in place later on.
                    return super().__new__(cls, set(), set())

                @property
                def predicate(self) -> typing.Optional['dsl.Predicate']:
                    """Combine the factors into single predicate.

                    Returns:
                        Predicate expression or None if no factors are registered.
                    """
                    return functools.reduce(function.Or, sorted(self.factors)) if self.factors else None

            def __init__(self):
                # Segments get created lazily upon first access to the given table.
                self._segments: dict['dsl.Table', Container.Context.Tables.Segment] = collections.defaultdict(
                    self.Segment
                )

            def items(self) -> typing.ItemsView['dsl.Table', 'Container.Context.Tables.Segment']:
                """Get the key-value pairs of this mapping.

                Returns:
                    Key-value mapping items.
                """
                return self._segments.items()

            def __getitem__(self, table: 'dsl.Table') -> 'Container.Context.Tables.Segment':
                return self._segments[table]

            def select(self, *feature: 'dsl.Feature') -> None:
                """Extract fields from given list of features and register them into segments of
                their relevant tables.

                Args:
                    feature: Features to be extracted and registered.
                """
                for field in dsl.Column.dissect(*feature):
                    # NOTE(review): loop body was missing in the reviewed source; assumes each
                    # dissected field exposes its owning table as ``.origin`` - confirm.
                    self[field.origin].fields.add(field)

            def filter(self, expression: 'dsl.Predicate') -> None:
                """Extract predicate factors from given expression and register them into segments
                of their relevant tables. Also register the whole expression using :attr:`select`.

                Args:
                    expression: Expression to be extracted and registered.
                """
                for table, factor in expression.factors.items():
                    self[table].factors.add(factor)
                self.select(expression)

        def __init__(self):
            self.symbols: Container.Context.Symbols = self.Symbols()
            self.tables: Container.Context.Tables = self.Tables()
            # NOTE(review): the attribute name was lost in the reviewed source; ``origins`` is
            # inferred from the ``dsl.Origin`` key type - confirm.
            self.origins: dict['dsl.Origin', 'parser.Source'] = {}

        @property
        def dirty(self) -> bool:
            """Check the context is safe to be closed.

            Returns:
                True if not safe for closing.
            """
            return bool(self.symbols)

    def __init__(self):
        self._context: typing.Optional[Container.Context] = None
        self._stack: list[Container.Context] = []

    @property
    def context(self) -> 'Container.Context':
        """Context accessor.

        Raises:
            RuntimeError: If there is no active context.
        """
        if not self._context:
            raise RuntimeError('Invalid context')
        return self._context

    def __enter__(self) -> 'Container':
        # Save the enclosing context (possibly None) so that __exit__ has something to restore.
        self._stack.append(self._context)
        self._context = self.Context()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type or exc_val or exc_tb:
            # An error is already propagating - don't mask it with the depletion check.
            # NOTE(review): this branch had no body in the reviewed source; plain return assumed.
            return
        if self._context and self._context.dirty:
            raise RuntimeError('Context not fetched')
        self._context = self._stack.pop()

    def fetch(self) -> Symbol:
        """Storage retrieval.

        Must be called exactly once and at the point where there is exactly one symbol pending
        in the context. Successful fetch will kill the context.

        Returns:
            Last symbol from the context.

        Raises:
            RuntimeError: If the context is empty or still holds other symbols after the fetch.
        """
        symbol = self._context.symbols.pop()
        if self._context.dirty:
            raise RuntimeError('Premature fetch')
        self._context = None
        return symbol

def bypass(override: typing.Callable[[Container, typing.Any], 'parser.Source']) -> typing.Callable:
    """Bypass the (result of) the particular visit_* implementation if the supplied override
    resolver provides an alternative value.

    Args:
        override: Callable resolver that returns an explicit value for given subject or raises
                  ``dsl.UnprovisionedError`` for unknown mapping.

    Returns:
        Visitor method decorator.
    """

    def decorator(method: typing.Callable[[Container, typing.Any], typing.ContextManager[None]]) -> typing.Callable:
        """Visitor method decorator with added bypassing capability.

        Args:
            method: Visitor method to be decorated.

        Returns:
            Decorated version of the visit_* method.
        """

        @functools.wraps(method)
        def wrapped(self: Container, subject: typing.Any) -> None:
            """Decorated version of the visit_* method.

            Args:
                self: Visitor instance.
                subject: Visited subject.
            """
            # The regular visit always runs first so the symbol stack ends up in the same shape
            # regardless of whether an override exists.
            method(self, subject)
            try:
                new = override(self, subject)
            except dsl.UnprovisionedError:
                # No explicit mapping for this subject - keep what the visitor produced.
                return
            old = self.context.symbols.pop()
            LOGGER.debug('Overriding result for %s (%s -> %s)', subject, old, new)
            self.context.symbols.push(new)

        return wrapped

    return decorator

class Visitor(
    typing.Generic[Source, Feature],
    Container[typing.Union[Source, Feature]],
    dsl.Source.Visitor,
    dsl.Feature.Visitor,
    metaclass=abc.ABCMeta,
):
    """Abstract base class for DSL query statement parser implementations.

    In this context, *parsing* essentially means conversion between the generic DSL-based instance
    of the particular query and its native representation matching a selected target storage
    layer.

    Conceptually, the parser is implemented as a combination of a *visitor* traversing the query
    statement structure and a *push-down automaton* assembling the generated instructions in their
    storage-native representation.

    The parser assumes resolving the native representation of all the leaves of the query
    statement tree (the :class:`dsl.Table` and :class:`dsl.Column` instances) or possibly entire
    branches can be accomplished via the provided initial mappings from which the parser builds
    the complete query up. Upon failing to resolve any particular *source*/*feature* using the
    initial mappings, the parser raises the :exc:`dsl.UnprovisionedError` indicating
    unavailability of the given data source.

    Args:
        sources: Explicit mapping of generic DSL *sources* (typically :class:`dsl.Table`) to their
                 native representations.
        features: Explicit mapping of generic DSL *features* (typically :class:`dsl.Column`) to
                  their native representations.
    """

    def __init__(
        self,
        sources: typing.Mapping['dsl.Source', 'parser.Source'],
        features: typing.Mapping['dsl.Feature', 'parser.Feature'],
    ):
        super().__init__()
        # Read-only views so the parser can never alter the caller's mappings.
        self._sources: typing.Mapping['dsl.Source', 'parser.Source'] = types.MappingProxyType(sources)
        self._features: typing.Mapping['dsl.Feature', 'parser.Feature'] = types.MappingProxyType(features)

    def resolve_feature(self, feature: 'dsl.Feature') -> 'parser.Feature':
        """Get a custom target code for a feature value.

        Args:
            feature: Feature instance.

        Returns:
            Feature in target code representation.

        Raises:
            dsl.UnprovisionedError: If there is no explicit mapping for the feature.
        """
        try:
            return self._features[feature]
        except KeyError as err:
            raise dsl.UnprovisionedError(f'Unknown mapping for feature {feature}') from err

    # The cache is shared on the class and keyed by (self, feature), so both need to be hashable.
    @functools.lru_cache
    def generate_feature(self, feature: 'dsl.Feature') -> 'parser.Feature':
        """Generate target code for the generic feature type.

        Args:
            feature: Feature instance

        Returns:
            Feature in target code.
        """
        feature.accept(self)
        return self.context.symbols.pop()

    @abc.abstractmethod
    def generate_element(self, origin: 'parser.Source', element: 'parser.Feature') -> 'parser.Feature':
        """Generate an element code.

        Args:
            origin: Origin value already in target code.
            element: Element symbol to be used for given feature.

        Returns:
            Element in target code.
        """

    @abc.abstractmethod
    def generate_alias(self, feature: 'parser.Feature', alias: str) -> 'parser.Feature':
        """Generate feature alias code.

        Args:
            feature: Feature value already in target code.
            alias: Alias to be used for given feature.

        Returns:
            Aliased feature in target code.
        """

    @abc.abstractmethod
    def generate_literal(self, value: typing.Any, kind: 'dsl.Any') -> 'parser.Feature':
        """Generate target code for a literal value.

        Args:
            value: Literal value instance.
            kind: Literal value type.

        Returns:
            Literal in target code representation.
        """

    @abc.abstractmethod
    def generate_expression(
        self, expression: type['dsl.Expression'], arguments: typing.Sequence[typing.Any]
    ) -> 'parser.Feature':
        """Generate target code for an expression of given arguments.

        Args:
            expression: Operator or function implementing the expression.
            arguments: Expression arguments.

        Returns:
            Expression in target code representation.
        """

    def visit_aliased(self, feature: 'dsl.Aliased') -> None:
        super().visit_aliased(feature)
        # NOTE(review): the alias argument was lost in the reviewed source; ``feature.name`` is
        # assumed - confirm against dsl.Aliased.
        self.context.symbols.push(self.generate_alias(self.context.symbols.pop(), feature.name))

    def visit_literal(self, feature: 'dsl.Literal') -> None:
        super().visit_literal(feature)
        self.context.symbols.push(self.generate_literal(feature.value, feature.kind))

    def visit_element(self, feature: 'dsl.Element') -> None:
        super().visit_element(feature)
        # The element's origin has been registered by visit_table/visit_reference beforehand.
        self.context.symbols.push(
            self.generate_element(self.context.origins[feature.origin], self.resolve_feature(feature))
        )

    @bypass(resolve_feature)
    def visit_expression(self, feature: 'dsl.Expression') -> None:
        super().visit_expression(feature)
        # Feature-type arguments are waiting on the stack in visiting order, hence popped in
        # reverse and flipped back; any other arguments are passed through as they are.
        arguments = tuple(
            reversed([self.context.symbols.pop() if isinstance(c, dsl.Feature) else c for c in reversed(feature)])
        )
        self.context.symbols.push(self.generate_expression(feature.__class__, arguments))

    @bypass(resolve_feature)
    def visit_window(self, feature: 'dsl.Window') -> typing.ContextManager[None]:
        raise RuntimeError('Window functions not yet supported')

    def resolve_source(self, source: 'dsl.Source') -> 'parser.Source':
        """Get a custom target code for a source type.

        Args:
            source: Source instance.

        Returns:
            Target code for the source instance.

        Raises:
            dsl.UnprovisionedError: If there is no explicit mapping for the source.
        """
        try:
            return self._sources[source]
        except KeyError as err:
            raise dsl.UnprovisionedError(f'Unknown mapping for source {source}') from err

    def generate_table(
        self,
        table: 'parser.Source',
        features: typing.Iterable['parser.Feature'],  # pylint: disable=unused-argument
        predicate: typing.Optional['parser.Feature'],  # pylint: disable=unused-argument
    ) -> 'parser.Source':  # pylint: disable=unused-argument
        """Generate a target code for a table instance given its actual field requirements.

        This default implementation returns the table as is.

        Args:
            table: Table (already in target code based on the provided mapping) to be generated.
            features: List of fields to be retrieved from the table (potentially subset of all
                      available).
            predicate: Row filter to be possibly pushed down when retrieving the data from given
                       table.

        Returns:
            Table target code potentially optimized based on field requirements.
        """
        return table

    @abc.abstractmethod
    def generate_reference(self, instance: 'parser.Source', name: str) -> tuple['parser.Source', 'parser.Source']:
        """Generate reference code.

        Args:
            instance: Instance value already in target code.
            name: Reference name.

        Returns:
            Tuple of referenced origin and the bare reference handle both in target code.
        """

    @abc.abstractmethod
    def generate_join(
        self,
        left: 'parser.Source',
        right: 'parser.Source',
        condition: typing.Optional['parser.Feature'],
        kind: 'dsl.Join.Kind',
    ) -> 'parser.Source':
        """Generate target code for a join operation using the left/right terms, given condition
        and a join type.

        Args:
            left: Left side of the join pair.
            right: Right side of the join pair.
            condition: Join condition.
            kind: Join type.

        Returns:
            Target code for the join operation.
        """

    @abc.abstractmethod
    def generate_set(self, left: 'parser.Source', right: 'parser.Source', kind: 'dsl.Set.Kind') -> 'parser.Source':
        """Generate target code for a set operation using the left/right terms, given a set type.

        Args:
            left: Left side of the set pair.
            right: Right side of the set pair.
            kind: Set type.

        Returns:
            Target code for the set operation.
        """

    @abc.abstractmethod
    def generate_query(
        self,
        source: 'parser.Source',
        features: typing.Sequence['parser.Feature'],
        where: typing.Optional['parser.Feature'],
        groupby: typing.Sequence['parser.Feature'],
        having: typing.Optional['parser.Feature'],
        orderby: typing.Sequence[tuple['parser.Feature', 'dsl.Ordering.Direction']],
        rows: typing.Optional['dsl.Rows'],
    ) -> 'parser.Source':
        """Generate query statement code.

        Args:
            source: Source already in target code.
            features: Sequence of selected features in target code.
            where: Where condition in target code.
            groupby: Sequence of grouping specifiers in target code.
            having: Having condition in target code.
            orderby: Ordering specifier in target code.
            rows: Limit spec tuple.

        Returns:
            Query in target code.
        """

    def visit_table(self, source: 'dsl.Table') -> None:
        # Register the origin first - generating the features below resolves elements against it.
        self.context.origins[source] = origin = self.resolve_source(source)
        features = [self.generate_feature(f) for f in sorted(self.context.tables[source].fields)]
        predicate = self.context.tables[source].predicate
        if predicate is not None:
            predicate = self.generate_feature(predicate)
        super().visit_table(source)
        self.context.symbols.push(self.generate_table(origin, features, predicate))

    def visit_reference(self, source: 'dsl.Reference') -> None:
        super().visit_reference(source)
        # NOTE(review): the name argument was lost in the reviewed source; ``source.name`` is
        # assumed - confirm against dsl.Reference.
        origin, handle = self.generate_reference(self.context.symbols.pop(), source.name)
        self.context.origins[source] = handle
        self.context.symbols.push(origin)

    @bypass(resolve_source)
    def visit_join(self, source: 'dsl.Join') -> None:
        # Explicit None check to match the condition test a few lines below.
        if source.condition is not None:
            self.context.tables.filter(source.condition)
        super().visit_join(source)
        # The operands were pushed left first, so they come off the stack right first.
        right = self.context.symbols.pop()
        left = self.context.symbols.pop()
        expression = self.generate_feature(source.condition) if source.condition is not None else None
        self.context.symbols.push(self.generate_join(left, right, expression, source.kind))

    @bypass(resolve_source)
    def visit_set(self, source: 'dsl.Set') -> None:
        super().visit_set(source)
        right = self.context.symbols.pop()
        left = self.context.symbols.pop()
        self.context.symbols.push(self.generate_set(left, right, source.kind))

    @bypass(resolve_source)
    def visit_query(self, source: 'dsl.Query') -> None:
        with self:
            # Register all the referenced fields and filters before descending into the source
            # so that visit_table knows the exact requirements of each table.
            self.context.tables.select(*source.features)
            if source.prefilter is not None:
                self.context.tables.filter(source.prefilter)
            if source.postfilter is not None:
                self.context.tables.filter(source.postfilter)
            self.context.tables.select(*source.grouping)
            self.context.tables.select(*(c for c, _ in source.ordering))
            super().visit_query(source)
            features = [self.generate_feature(c) for c in source.features]
            where = self.generate_feature(source.prefilter) if source.prefilter is not None else None
            groupby = [self.generate_feature(c) for c in source.grouping]
            having = self.generate_feature(source.postfilter) if source.postfilter is not None else None
            orderby = [(self.generate_feature(c), o) for c, o in source.ordering]
            query = self.generate_query(
                self.context.symbols.pop(), features, where, groupby, having, orderby, source.rows
            )
        # Pushed only after leaving the nested context, which has to be fully depleted on exit.
        self.context.symbols.push(query)