# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ETL DSL parser.
"""
import abc
import collections
import functools
import logging
import types
import typing
from forml.io import dsl
from forml.io.dsl import function
if typing.TYPE_CHECKING:
    # Imported for static type checkers only; the string annotations below ('parser.Source',
    # 'parser.Feature') refer to this module itself.
    from forml.io.dsl import parser  # pylint: disable=import-self

#: Module-level logger named after this module.
LOGGER = logging.getLogger(__name__)

#: Generic storage-native representation of :class:`dsl.Source <forml.io.dsl.Source>`.
Source = typing.TypeVar('Source')

#: Generic storage-native representation of :class:`dsl.Feature <forml.io.dsl.Feature>`.
Feature = typing.TypeVar('Feature')

#: Generic storage-native representation of any instruction.
Symbol = typing.TypeVar('Symbol')
class Container(typing.Generic[Symbol]):
    """Base parser structure.

    When used as a context manager the internal structure is exclusive to the given context and is
    checked for total depletion on exit.

    Contexts nest: entering stashes the currently active context (possibly ``None``) on an internal
    stack and activates a fresh one; exiting re-activates the stashed one.
    """

    class Context:
        """Storage context."""

        class Symbols:
            """Stack for parsed symbols."""

            def __init__(self):
                self._stack: list[Symbol] = []

            def __bool__(self) -> bool:
                # Truthy while at least one symbol is still pending on the stack.
                return bool(self._stack)

            def push(self, item: Symbol) -> None:
                """Push new parsed item to the stack.

                Args:
                    item: Item to be added.
                """
                self._stack.append(item)

            def pop(self) -> Symbol:
                """Remove and return a value from the top of the stack.

                Returns:
                    Item from the stack top.

                Raises:
                    RuntimeError: If the stack is empty.
                """
                if not self._stack:
                    raise RuntimeError('Empty context')
                return self._stack.pop()

        class Tables:
            """Container for segments of all tables."""

            class Segment(collections.namedtuple('Segment', 'fields, factors')):
                """Frame segment specification as a list of features (vertical) and row predicates
                (horizontal).
                """

                def __new__(cls):
                    # Every segment starts with its own pair of empty (mutable) sets.
                    return super().__new__(cls, set(), set())

                @property
                def predicate(self) -> typing.Optional['dsl.Predicate']:
                    """Combine the factors into single predicate.

                    The factors are sorted and folded together using ``function.Or``.

                    Returns:
                        Predicate expression or ``None`` if no factors are registered.
                    """
                    if not self.factors:
                        return None
                    return functools.reduce(function.Or, sorted(self.factors))

            def __init__(self):
                # Segments are created lazily upon the first access for a given table.
                self._segments: dict['dsl.Table', 'Container.Context.Tables.Segment'] = collections.defaultdict(
                    self.Segment
                )

            def items(self) -> typing.ItemsView['dsl.Table', 'Container.Context.Tables.Segment']:
                """Get the key-value pairs of this mapping.

                Returns:
                    Key-value mapping items.
                """
                return self._segments.items()

            def __getitem__(self, table: 'dsl.Table') -> 'Container.Context.Tables.Segment':
                return self._segments[table]

            def select(self, *feature: 'dsl.Feature') -> None:
                """Extract fields from given list of features and register them into segments of
                their relevant tables.

                Args:
                    feature: Features to be extracted and registered.
                """
                for field in dsl.Column.dissect(*feature):
                    self[field.origin].fields.add(field)

            def filter(self, expression: 'dsl.Predicate') -> None:
                """Extract predicate factors from given expression and register them into segments
                of their relevant tables. Also register the whole expression using :meth:`select`.

                Args:
                    expression: Expression to be extracted and registered.
                """
                self.select(expression)
                for table, factor in expression.factors.items():
                    self[table].factors.add(factor)

        def __init__(self):
            self.symbols: Container.Context.Symbols = self.Symbols()
            self.tables: Container.Context.Tables = self.Tables()
            self.origins: dict['dsl.Origin', 'parser.Source'] = {}

        @property
        def dirty(self) -> bool:
            """Check the context is safe to be closed.

            Returns:
                True if not safe for closing (there are still symbols pending on the stack).
            """
            return bool(self.symbols)

    def __init__(self):
        self._context: typing.Optional[Container.Context] = None
        self._stack: list[typing.Optional[Container.Context]] = []

    @property
    def context(self) -> 'Container.Context':
        """Context accessor.

        Returns:
            The currently active context.

        Raises:
            RuntimeError: If there is no active context.
        """
        if self._context is None:
            raise RuntimeError('Invalid context')
        return self._context

    def __enter__(self) -> 'Container':
        # Stash whatever is active (``None`` at the outermost level) and start with a clean context.
        self._stack.append(self._context)
        self._context = self.Context()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Always re-activate the enclosing context first so that a failure inside the block (or an
        # unfetched context) cannot leave this container pointing at a stale inner context.
        closing = self._context
        self._context = self._stack.pop()
        if exc_type is not None or exc_val is not None or exc_tb is not None:
            return  # an error is already propagating - skip the depletion check
        if closing is not None and closing.dirty:
            raise RuntimeError('Context not fetched')

    def fetch(self) -> Symbol:
        """Storage retrieval.

        Must be called exactly once and at the point where there is exactly one symbol pending
        in the context. Successful fetch will kill the context.

        Returns:
            Last symbol from the context.

        Raises:
            RuntimeError: If there is no active context, the context is empty, or more than one
                symbol is pending.
        """
        # Go through the accessor so a missing context raises the explicit RuntimeError rather than
        # an AttributeError on None.
        context = self.context
        symbol = context.symbols.pop()
        if context.dirty:
            raise RuntimeError('Premature fetch')
        self._context = None
        return symbol
def bypass(override: typing.Callable[['Container', typing.Any], 'parser.Source']) -> typing.Callable:
    """Bypass the (result of) the particular visit_* implementation if the supplied override
    resolver provides an alternative value.

    The decorated method is always executed first; only the symbol it left on top of the context
    stack gets replaced if the override resolves.

    Args:
        override: Callable resolver that returns an explicit value for given subject or raises
                  ``dsl.UnprovisionedError`` for unknown mapping.

    Returns:
        Visitor method decorator.
    """

    def decorator(
        method: typing.Callable[['Container', typing.Any], None]
    ) -> typing.Callable[['Container', typing.Any], None]:
        """Visitor method decorator with added bypassing capability.

        Args:
            method: Visitor method to be decorated.

        Returns:
            Decorated version of the visit_* method.
        """

        @functools.wraps(method)
        def wrapped(self: 'Container', subject: typing.Any) -> None:
            """Decorated version of the visit_* method.

            Args:
                self: Visitor instance.
                subject: Visited subject.
            """
            # Run the regular visit first - it is expected to leave its result on the symbol stack.
            method(self, subject)
            try:
                new = override(self, subject)
            except dsl.UnprovisionedError:
                # No explicit mapping for this subject - keep the generated symbol.
                return
            old = self.context.symbols.pop()
            LOGGER.debug('Overriding result for %s (%s -> %s)', subject, old, new)
            self.context.symbols.push(new)

        return wrapped

    return decorator
class Visitor(
    typing.Generic[Source, Feature],
    Container[typing.Union[Source, Feature]],
    dsl.Source.Visitor,
    dsl.Feature.Visitor,
    metaclass=abc.ABCMeta,
):
    """Abstract base class for DSL query statement parser implementations.

    In this context, *parsing* essentially means conversion between the generic DSL-based
    instance of the particular query and its native representation matching a selected target
    storage layer.

    Conceptually, the parser is implemented as a combination of a *visitor* traversing the query
    statement structure and a *push-down automaton* assembling the generated instructions in their
    storage-native representation.

    The parser assumes resolving the native representation of all the leaves of the query statement
    tree (the :class:`dsl.Table <forml.io.dsl.Table>` and :class:`dsl.Column <forml.io.dsl.Column>`
    instances) or possibly entire branches can be accomplished via the provided initial mappings
    from which the parser builds the complete query up.

    Upon failing to resolve any particular *source*/*feature* using the initial mappings, the parser
    raises the :exc:`dsl.UnprovisionedError <forml.io.dsl.UnprovisionedError>` indicating
    unavailability of the given data source.

    Args:
        sources: Explicit mapping of generic DSL *sources* (typically :class:`dsl.Table
                 <forml.io.dsl.Table>`) to their native representations.
        features: Explicit mapping of generic DSL *features* (typically :class:`dsl.Column
                  <forml.io.dsl.Column>`) to their native representations.
    """

    def __init__(
        self,
        sources: typing.Mapping['dsl.Source', 'parser.Source'],
        features: typing.Mapping['dsl.Feature', 'parser.Feature'],
    ):
        super().__init__()
        # Read-only views so the parser can not modify the caller's mappings.
        self._sources: typing.Mapping['dsl.Source', 'parser.Source'] = types.MappingProxyType(sources)
        self._features: typing.Mapping['dsl.Feature', 'parser.Feature'] = types.MappingProxyType(features)

    def resolve_feature(self, feature: 'dsl.Feature') -> 'parser.Feature':
        """Get a custom target code for a feature value.

        Args:
            feature: Feature instance.

        Returns:
            Feature in target code representation.

        Raises:
            dsl.UnprovisionedError: If the feature is not present in the explicit mapping.
        """
        try:
            return self._features[feature]
        except KeyError as err:
            raise dsl.UnprovisionedError(f'Unknown mapping for feature {feature}') from err

    # NOTE(review): lru_cache on an instance method keys on (self, feature) - it requires both to be
    # hashable, keeps the visitor instance referenced from the cache and on a cache hit returns the
    # previously generated symbol without re-visiting the feature in the current context. Kept as
    # is to preserve the existing behaviour - confirm this is intended.
    @functools.lru_cache
    def generate_feature(self, feature: 'dsl.Feature') -> 'parser.Feature':
        """Generate target code for the generic feature type.

        Args:
            feature: Feature instance

        Returns:
            Feature in target code.
        """
        feature.accept(self)
        return self.context.symbols.pop()

    @abc.abstractmethod
    def generate_element(self, origin: 'parser.Source', element: 'parser.Feature') -> 'parser.Feature':
        """Generate an element code.

        Args:
            origin: Origin value already in target code.
            element: Element symbol to be used for given feature.

        Returns:
            Element in target code.
        """

    @abc.abstractmethod
    def generate_alias(self, feature: 'parser.Feature', alias: str) -> 'parser.Feature':
        """Generate feature alias code.

        Args:
            feature: Feature value already in target code.
            alias: Alias to be used for given feature.

        Returns:
            Aliased feature in target code.
        """

    @abc.abstractmethod
    def generate_literal(self, value: typing.Any, kind: 'dsl.Any') -> 'parser.Feature':
        """Generate target code for a literal value.

        Args:
            value: Literal value instance.
            kind: Literal value type.

        Returns:
            Literal in target code representation.
        """

    @abc.abstractmethod
    def generate_expression(
        self, expression: type['dsl.Expression'], arguments: typing.Sequence[typing.Any]
    ) -> 'parser.Feature':
        """Generate target code for an expression of given arguments.

        Args:
            expression: Operator or function implementing the expression.
            arguments: Expression arguments.

        Returns:
            Expression in target code representation.
        """

    def visit_aliased(self, feature: 'dsl.Aliased') -> None:
        """Replace the symbol on the stack top with its aliased version (using ``feature.name``)."""
        super().visit_aliased(feature)
        self.context.symbols.push(self.generate_alias(self.context.symbols.pop(), feature.name))

    def visit_literal(self, feature: 'dsl.Literal') -> None:
        """Push the target code of the literal onto the symbol stack."""
        super().visit_literal(feature)
        self.context.symbols.push(self.generate_literal(feature.value, feature.kind))

    def visit_element(self, feature: 'dsl.Element') -> None:
        """Push the target code of the element onto the symbol stack.

        The element's origin must have been registered in the context origins already (see
        :meth:`visit_table` and :meth:`visit_reference`).
        """
        super().visit_element(feature)
        self.context.symbols.push(
            self.generate_element(self.context.origins[feature.origin], self.resolve_feature(feature))
        )

    @bypass(resolve_feature)
    def visit_expression(self, feature: 'dsl.Expression') -> None:
        """Assemble the expression from its already generated arguments and push it onto the stack."""
        super().visit_expression(feature)
        # Feature-typed arguments are taken from the stack (hence walked in reverse and flipped
        # back afterwards); any other argument is passed through as is.
        arguments = tuple(
            reversed([self.context.symbols.pop() if isinstance(c, dsl.Feature) else c for c in reversed(feature)])
        )
        self.context.symbols.push(self.generate_expression(feature.__class__, arguments))

    @bypass(resolve_feature)
    def visit_window(self, feature: 'dsl.Window') -> None:
        """Window functions are not supported - always raises ``RuntimeError``."""
        raise RuntimeError('Window functions not yet supported')

    def resolve_source(self, source: 'dsl.Source') -> 'parser.Source':
        """Get a custom target code for a source type.

        Args:
            source: Source instance.

        Returns:
            Target code for the source instance.

        Raises:
            dsl.UnprovisionedError: If the source is not present in the explicit mapping.
        """
        try:
            return self._sources[source]
        except KeyError as err:
            raise dsl.UnprovisionedError(f'Unknown mapping for source {source}') from err

    def generate_table(
        self,
        table: 'parser.Source',
        features: typing.Iterable['parser.Feature'],  # pylint: disable=unused-argument
        predicate: typing.Optional['parser.Feature'],  # pylint: disable=unused-argument
    ) -> 'parser.Source':  # pylint: disable=unused-argument
        """Generate a target code for a table instance given its actual field requirements.

        This default implementation returns the table unchanged.

        Args:
            table: Table (already in target code based on the provided mapping) to be generated.
            features: List of fields to be retrieved from the table (potentially subset of all
                      available).
            predicate: Row filter to be possibly pushed down when retrieving the data from given
                       table.

        Returns:
            Table target code potentially optimized based on field requirements.
        """
        return table

    @abc.abstractmethod
    def generate_reference(self, instance: 'parser.Source', name: str) -> tuple['parser.Source', 'parser.Source']:
        """Generate reference code.

        Args:
            instance: Instance value already in target code.
            name: Reference name.

        Returns:
            Tuple of referenced origin and the bare reference handle both in target code.
        """

    @abc.abstractmethod
    def generate_join(
        self,
        left: 'parser.Source',
        right: 'parser.Source',
        condition: typing.Optional['parser.Feature'],
        kind: 'dsl.Join.Kind',
    ) -> 'parser.Source':
        """Generate target code for a join operation using the left/right terms, given condition
        and a join type.

        Args:
            left: Left side of the join pair.
            right: Right side of the join pair.
            condition: Join condition.
            kind: Join type.

        Returns:
            Target code for the join operation.
        """

    @abc.abstractmethod
    def generate_set(self, left: 'parser.Source', right: 'parser.Source', kind: 'dsl.Set.Kind') -> 'parser.Source':
        """Generate target code for a set operation using the left/right terms, given a set type.

        Args:
            left: Left side of the set pair.
            right: Right side of the set pair.
            kind: Set type.

        Returns:
            Target code for the set operation.
        """

    @abc.abstractmethod
    def generate_query(
        self,
        source: 'parser.Source',
        features: typing.Sequence['parser.Feature'],
        where: typing.Optional['parser.Feature'],
        groupby: typing.Sequence['parser.Feature'],
        having: typing.Optional['parser.Feature'],
        orderby: typing.Sequence[tuple['parser.Feature', 'dsl.Ordering.Direction']],
        rows: typing.Optional['dsl.Rows'],
    ) -> 'parser.Source':
        """Generate query statement code.

        Args:
            source: Source already in target code.
            features: Sequence of selected features in target code.
            where: Where condition in target code.
            groupby: Sequence of grouping specifiers in target code.
            having: Having condition in target code.
            orderby: Ordering specifier in target code.
            rows: Limit spec tuple.

        Returns:
            Query in target code.
        """

    def visit_table(self, source: 'dsl.Table') -> None:
        """Register the table origin and push its target code onto the symbol stack.

        The table is generated using only the fields and the predicate collected for it in the
        current context.
        """
        # The origin needs to be registered before generating the fields as those look it up.
        self.context.origins[source] = origin = self.resolve_source(source)
        features = [self.generate_feature(f) for f in sorted(self.context.tables[source].fields)]
        predicate = self.context.tables[source].predicate
        if predicate is not None:
            predicate = self.generate_feature(predicate)
        super().visit_table(source)
        self.context.symbols.push(self.generate_table(origin, features, predicate))

    def visit_reference(self, source: 'dsl.Reference') -> None:
        """Turn the symbol on the stack top into a named reference.

        The bare reference handle is registered as the origin for the given source while the
        referenced origin is pushed back onto the symbol stack.
        """
        super().visit_reference(source)
        origin, handle = self.generate_reference(self.context.symbols.pop(), source.name)
        self.context.origins[source] = handle
        self.context.symbols.push(origin)

    @bypass(resolve_source)
    def visit_join(self, source: 'dsl.Join') -> None:
        """Combine the two symbols on the stack top into a join and push it back."""
        # Register the condition before visiting the join sides so that its fields/factors are
        # known by the time the involved tables get generated.
        if source.condition is not None:
            self.context.tables.filter(source.condition)
        super().visit_join(source)
        # Popped in the reverse order of pushing.
        right = self.context.symbols.pop()
        left = self.context.symbols.pop()
        expression = self.generate_feature(source.condition) if source.condition is not None else None
        self.context.symbols.push(self.generate_join(left, right, expression, source.kind))

    @bypass(resolve_source)
    def visit_set(self, source: 'dsl.Set') -> None:
        """Combine the two symbols on the stack top into a set operation and push it back."""
        super().visit_set(source)
        # Popped in the reverse order of pushing.
        right = self.context.symbols.pop()
        left = self.context.symbols.pop()
        self.context.symbols.push(self.generate_set(left, right, source.kind))

    @bypass(resolve_source)
    def visit_query(self, source: 'dsl.Query') -> None:
        """Generate the query within its own nested context and push it onto the enclosing one."""
        with self:
            # Collect the field/predicate requirements first - they drive the table generation
            # triggered by the visit below.
            self.context.tables.select(*source.features)
            if source.prefilter is not None:
                self.context.tables.filter(source.prefilter)
            if source.postfilter is not None:
                self.context.tables.select(source.postfilter)
            self.context.tables.select(*source.grouping)
            self.context.tables.select(*(c for c, _ in source.ordering))
            super().visit_query(source)
            features = [self.generate_feature(c) for c in source.features]
            where = self.generate_feature(source.prefilter) if source.prefilter is not None else None
            groupby = [self.generate_feature(c) for c in source.grouping]
            having = self.generate_feature(source.postfilter) if source.postfilter is not None else None
            orderby = [(self.generate_feature(c), o) for c, o in source.ordering]
            query = self.generate_query(
                self.context.symbols.pop(), features, where, groupby, having, orderby, source.rows
            )
        # The nested context must be left depleted (checked on exit) - the result belongs to the
        # enclosing context.
        self.context.symbols.push(query)