# Source code for forml.provider.feed.alchemy

"""SQLAlchemy based feed implementation."""
import functools
import hashlib
import logging
import pathlib
import re
import types
import typing

import pandas
import sqlalchemy
from sqlalchemy import sql

import forml
from forml import io, setup
from forml.io import dsl as dslmod
from forml.provider.feed.reader import alchemy

if typing.TYPE_CHECKING:
    # Imported only for static type checking: the string annotations (e.g. ``'dsl.Source'``) refer
    # to ``dsl`` while the runtime code uses the ``dslmod`` alias of the same module.
    from forml.io import dsl  # pylint: disable=reimported

# Module-level logger used by the result cache for cache hit/miss diagnostics.
LOGGER = logging.getLogger(__name__)

class Results:
    """Filesystem backed result cache.

    Results are kept in two tiers: an in-process dictionary of dataframes and a directory of
    Parquet files (one file per distinct query key) that persists on disk.
    """

    def __init__(self, path: pathlib.Path):
        """
        Args:
            path: Directory for storing the cached Parquet files (created if it does not exist).
        """
        self._frames: dict[str, pandas.DataFrame] = {}
        self._path: pathlib.Path = path
        self._path.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _statement2key(statement: sql.Selectable) -> str:
        """Get the key for the given statement.

        The statement is compiled with ``literal_binds`` so any bound parameter values are rendered
        into the SQL text and thus contribute to the key.

        Args:
            statement: Query statement.

        Returns:
            Query key (hex-encoded SHA-256 digest of the compiled statement text).
        """
        return hashlib.sha256(str(statement.compile(compile_kwargs={'literal_binds': True})).encode()).hexdigest()

    def _key2path(self, key: str) -> pathlib.Path:
        """Get the filesystem path for the given key.

        Args:
            key: Query key.

        Returns:
            Filesystem path for the cached query.
        """
        return self._path / f'{key}.parquet'

    def exists(self, statement: sql.Selectable) -> bool:
        """Check the result can be provided without execution.

        Args:
            statement: Query statement.

        Returns:
            True if the query result is already known (either in memory or on disk).
        """
        key = self._statement2key(statement)
        return key in self._frames or self._key2path(key).exists()

    def get_or_exec(
        self, statement: sql.Selectable, loader: typing.Callable[[sql.Selectable], pandas.DataFrame]
    ) -> pandas.DataFrame:
        """Get the result from the cache or execute the loader.

        Lookup order is memory, then disk, then the ``loader``. A freshly loaded result is
        persisted to disk, and any result found on disk or loaded is kept in memory.

        Args:
            statement: Query statement representing the expected result.
            loader: Callback for loading the result data.

        Returns:
            Query result as a Pandas dataframe.
        """
        key = self._statement2key(statement)
        if key not in self._frames:
            path = self._key2path(key)
            if path.exists():
                LOGGER.debug('Disk cache hit for %s', statement)
                frame = pandas.read_parquet(path)
            else:
                LOGGER.debug('Disk cache miss for %s', statement)
                frame = loader(statement)
                # The index is not persisted, so frames read back from disk get a default index.
                frame.to_parquet(path, index=False)
            self._frames[key] = frame
        else:
            LOGGER.debug('Memory cache hit for %s', statement)
        return self._frames[key]

class Feed(io.Feed[sql.Selectable, sql.ColumnElement], alias='alchemy'):
    """Generic SQL feed based on :doc:`SQLAlchemy <sqlalchemy:index>`.

    All the hosted datasets need to be declared using a proper :ref:`content resolver
    <io-resolution>` mapping specified using the ``sources`` option with keys representing the
    fully qualified schema name formatted as ``<full.module.path>:<qualified.Class.Name>`` and the
    values should refer to the physical table names like ``<database>.<table>``.

    Attention:
        All the referenced :ref:`schema catalogs <io-catalog>` must be installed.

    Args:
        sources: The mapping of :ref:`schema catalogs <io-catalog>` to the DB tables.
        readerkw: Optional keywords typically for the :func:`pandas.read_sql
                  <pandas:pandas.read_sql>`.

    The provider can be enabled using the following :ref:`platform configuration
    <platform-config>`:

    .. code-block:: toml
       :caption: config.toml

        [FEED.sql]
        provider = "alchemy"
        connection = "mysql+pymysql://john:smith@localhost/"
        [FEED.sql.sources]
        "openschema.kaggle:Titanic" = "kaggle.titanic"
        "foobar.schemas:Foo.Baz" = "foobar.baz"

    Important:
        Select the ``sql`` :ref:`extras to install <install-extras>` ForML together with the
        SQLAlchemy support.
    """

    # Physical table name with an optional (possibly dotted) schema prefix: ``[<schema>.]<table>``.
    _TABLE_NAME = re.compile(r'(?:([\w.]+)\.)?(\w+)')

    class Reader(alchemy.Reader):
        """SQLAlchemy reader with its results served through a shared filesystem-backed cache."""

        # Class-level cache shared by all readers; its directory is created when this body executes.
        RESULTS: Results = Results(setup.USRDIR / '.cache' / 'alchemy')

        @classmethod
        def read(cls, statement: sql.Selectable, **kwargs) -> pandas.DataFrame:
            """Read the statement results, using the cache when they are already known.

            Args:
                statement: Query statement to execute.
                kwargs: Extra keywords passed through to the parent reader on a cache miss.

            Returns:
                Query result as a Pandas dataframe.
            """
            # NOTE(review): the cache key is derived from the statement alone, so ``kwargs`` (e.g. the
            # connection) do not partition the cache - confirm this is intended.
            return cls.RESULTS.get_or_exec(statement, functools.partial(super().read, **kwargs))

    def __init__(
        self,
        sources: typing.Mapping[typing.Union['dsl.Source', str], str],
        **readerkw,
    ):
        """
        Args:
            sources: Mapping of schema sources (or their ``<full.module.path>:<qualified.Class.Name>``
                     string paths) to physical table names formatted as ``[<schema>.]<table>``.
            readerkw: Keyword arguments forwarded to the parent feed constructor.

        Raises:
            forml.InvalidError: If any of the table names does not match the expected format.
        """

        def ensure_source(src: typing.Union['dsl.Source', str]) -> 'dsl.Source':
            """Resolve a string path to its schema; pass actual source instances through."""
            if isinstance(src, str):
                src = dslmod.Schema.from_path(src)
            return src

        def table(name: str) -> sqlalchemy.TableClause:
            """Parse the ``[<schema>.]<table>`` name into a table clause with a quoted table name."""
            if not (match := self._TABLE_NAME.fullmatch(name)):
                raise forml.InvalidError(f'Invalid table name: {name}')
            schema, tablename = match.groups()
            return sqlalchemy.table(sql.quoted_name(tablename, quote=True), schema=schema)

        # Resolved before calling the parent constructor, matching the original initialization order.
        self._sources: typing.Mapping['dsl.Source', sql.Selectable] = {
            ensure_source(s): table(t) for s, t in sources.items()
        }
        super().__init__(**readerkw)

    @property
    def sources(self) -> typing.Mapping['dsl.Source', sql.Selectable]:
        """Read-only view of the source-to-table mapping."""
        return types.MappingProxyType(self._sources)