# Source code for forml.flow._code.compiler

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Runtime symbols compilation.
"""

import collections
import functools
import itertools
import logging
import typing
import uuid

from .. import _exception
from .._graph import atomic, span
from . import target
from .target import system, user

if typing.TYPE_CHECKING:
    from forml import flow
    from forml.io import asset


LOGGER = logging.getLogger(__name__)


class Table(span.Visitor, typing.Iterable):
    """Dynamic builder of the runtime symbols. Table uses node UIDs and GIDs where possible as instruction keys.

    The table is populated by visiting worker nodes (each one is passed to ``add``). Iterating over it then produces
    the runtime symbols, each combining an instruction with its resolved positional argument instructions.
    """

    class Linkage:
        """Structure for registering instruction dependency tree as relations between target (receiving) instruction
        and its upstream dependency instructions representing its positional arguments.
        """

        def __init__(self):
            # Arguments stored at explicit (absolute) positional offsets; unfilled slots are padded with None.
            self._absolute: dict[uuid.UUID, list[typing.Optional[uuid.UUID]]] = collections.defaultdict(list)
            # System arguments to be prefixed in front of the absolute ones (stored in append order, read reversed).
            self._prefixed: dict[uuid.UUID, list[typing.Optional[uuid.UUID]]] = collections.defaultdict(list)

        def __getitem__(self, instruction: uuid.UUID) -> typing.Sequence[typing.Optional[uuid.UUID]]:
            """Get the complete positional argument list of the given instruction.

            The lookup uses ``.get`` rather than the defaultdict item access. This means reading never registers
            an empty entry for an unknown instruction as a side effect.

            Args:
                instruction: Key of the target (receiver) instruction.

            Returns:
                Prefixed arguments (in their prepend order) followed by the absolute positional arguments.
            """
            prefixed = self._prefixed.get(instruction, ())
            absolute = self._absolute.get(instruction, ())
            return tuple(itertools.chain(reversed(prefixed), absolute))

        @property
        def leaves(self) -> typing.AbstractSet[uuid.UUID]:
            """Return the leaf nodes, i.e. the linked instructions that are not anyone's dependency.

            Only instructions with at least one registered argument are considered.

            Returns:
                Leaf nodes (an empty set if there are no linked instructions at all).
            """
            linked = set(self._absolute).union(self._prefixed)
            parents = {i for a in itertools.chain(self._absolute.values(), self._prefixed.values()) for i in a}
            children = linked.difference(parents)
            # An empty linkage is trivially acyclic. Otherwise at least one linked instruction must be nobody's
            # dependency.
            assert children or not linked, 'Not acyclic'
            return children

        def insert(self, instruction: uuid.UUID, argument: uuid.UUID, index: typing.Optional[int] = None) -> None:
            """Store given argument as a positional parameter of given instruction at absolute offset given by index.

            Index can be omitted for single-argument instructions.

            Args:
                instruction: Target (receiver) instruction.
                argument: Positional argument to be stored.
                index: Position offset of given argument.
            """
            args = self._absolute[instruction]
            argcnt = len(args)
            if index is None:
                assert argcnt <= 1, f'Index required for multiarg ({argcnt}) instruction'
                index = 0
            assert index >= 0, 'Invalid positional index'
            if argcnt <= index:
                # pad the argument list with empty slots up to (and including) the requested offset
                args.extend([None] * (index - argcnt + 1))
            assert not args[index], 'Link collision'
            args[index] = argument

        def update(self, node: 'flow.Worker', getter: typing.Callable[[int], uuid.UUID]) -> None:
            """Register given node (its eventual functor) as an absolute positional argument of all of its subscribers.

            For multi-output nodes the output needs to be passed through Getter instructions that are extracting
            individual items.

            Args:
                node: Worker node (representing its actual functor) as an positional argument of its subscribers.
                getter: Callback for creating a Getter instruction for given positional index and returning its key.
            """
            if node.szout == 1:
                for subscriber in node.output[0]:
                    self.insert(subscriber.node.uid, node.uid, subscriber.port)
            else:
                for index, output in enumerate(node.output):
                    source = getter(index)
                    self.insert(source, node.uid)
                    for subscriber in output:
                        self.insert(subscriber.node.uid, source, subscriber.port)

        def prepend(self, instruction: uuid.UUID, argument: uuid.UUID) -> None:
            """In contrast to the absolute positional arguments we can potentially prepend these with various system
            arguments that should eventually prefix the absolute ones.

            Here we just append these to a list but during iteration we read them in reverse to reflect the prepend
            order.

            Args:
                instruction: Key of the target (receiver) instruction.
                argument: Argument (instruction key) to be prepended to the list of the absolute arguments.
            """
            self._prefixed[instruction].append(argument)

    class Index:
        """Mapping of the stored instructions. Same instruction might be stored under multiple keys."""

        def __init__(self):
            self._instructions: dict[uuid.UUID, 'flow.Instruction'] = {}

        def __contains__(self, key: uuid.UUID) -> bool:
            return key in self._instructions

        def __getitem__(self, key: uuid.UUID) -> 'flow.Instruction':
            return self._instructions[key]

        @property
        def instructions(self) -> 'typing.Iterator[tuple[flow.Instruction, typing.Iterator[uuid.UUID]]]':
            """Iterator over tuples of instructions plus iterator of its keys.

            The keys are grouped using ``itertools.groupby``, which merges only *consecutive* keys that map to equal
            instructions. ``Table.add`` therefore registers all aliases of one functor immediately after each other.
            NOTE(review): grouping relies on instruction equality. If instructions implement value-based ``__eq__``,
            distinct but equal instructions stored consecutively would be merged. Confirm.

            Returns:
                Instruction-keys tuples iterator.
            """
            return itertools.groupby(self._instructions.keys(), self._instructions.__getitem__)

        def set(self, instruction: 'flow.Instruction', key: typing.Optional[uuid.UUID] = None) -> uuid.UUID:
            """Store given instruction by provided or generated key.

            It is an error to store instruction with existing key (to avoid, use the reset method).

            Args:
                instruction: Runtime instruction to be stored.
                key: Optional key to be used as instruction reference.

            Returns:
                Key associated with the instruction.
            """
            if key is None:
                key = uuid.uuid4()
            assert key not in self, 'Instruction collision'
            self._instructions[key] = instruction
            return key

        def reset(self, orig: uuid.UUID, new: typing.Optional[uuid.UUID] = None) -> uuid.UUID:
            """Re-register instruction under given key to a new key (provided or generate).

            Args:
                orig: Original key of the instruction to be re-registered.
                new: Optional new key to re-register the instruction with.

            Returns:
                New key associated with the instruction.

            Raises:
                KeyError: If there is no instruction registered under the original key.
            """
            return self.set(self._instructions.pop(orig), new)

    def __init__(self, assets: typing.Optional['asset.State']):
        self._assets: typing.Optional['asset.State'] = assets
        self._linkage: Table.Linkage = self.Linkage()
        self._index: Table.Index = self.Index()
        # key of the (single shared) Committer instruction - created lazily by the first persistent trained node
        self._committer: typing.Optional[uuid.UUID] = None

    def __iter__(self) -> typing.Iterator['flow.Symbol']:
        """Generate the runtime symbols for all the registered instructions.

        Getter instructions that are leaves (nobody consumes their output) are pruned as unused stubs.

        Raises:
            _exception.AssemblyError: If any of the instruction arguments cannot be resolved.

        Returns:
            Iterator of the runtime symbols.
        """

        def merge(
            value: typing.Iterable[typing.Optional[uuid.UUID]], element: typing.Iterable[typing.Optional[uuid.UUID]]
        ) -> typing.Iterable[uuid.UUID]:
            """Merge two iterables with at most one of them having non-null value on each offset into single iterable
            with this non-null values picked.

            Args:
                value: Left iterable.
                element: Right iterable.

            Returns:
                Merged iterable.
            """

            def pick(left: typing.Optional[uuid.UUID], right: typing.Optional[uuid.UUID]) -> typing.Optional[uuid.UUID]:
                """Pick the non-null value from the two arguments.

                Args:
                    left: Left input argument to pick from.
                    right: Right input argument to pick from.

                Returns:
                    The non-null value of the two (if any).
                """
                assert not (left and right), 'Expecting at most one non-null value'
                return left if left else right

            return (pick(a, b) for a, b in itertools.zip_longest(value, element))

        stubs = {s for s in (self._index[n] for n in self._linkage.leaves) if isinstance(s, system.Getter)}
        for instruction, keys in self._index.instructions:
            if instruction in stubs:
                LOGGER.debug('Pruning stub getter %s', instruction)
                continue
            try:
                # the same instruction can be linked under each of its alias keys - merge all the argument lists
                arguments = tuple(self._index[a] for a in functools.reduce(merge, (self._linkage[k] for k in keys)))
            except KeyError as err:
                raise _exception.AssemblyError(f'Argument mismatch for instruction {instruction}') from err
            yield target.Symbol(instruction, arguments)

    def add(self, node: 'flow.Worker') -> None:
        """Populate the symbol table to implement the logical flow of given node.

        Args:
            node: Node to be added - compiled into symbols.
        """
        assert node.uid not in self._index, f'Node collision ({node})'
        assert isinstance(node, atomic.Worker), f'Not a worker node ({node})'

        LOGGER.debug('Adding node %s into the symbol table', node)
        functor = user.Apply().functor(node.builder)
        aliases = [node.uid]
        if node.stateful:
            state = node.gid
            persistent = self._assets and state in self._assets
            if persistent and state not in self._index:
                # a single loader per state GID - shared by all nodes of the same group
                self._index.set(system.Loader(self._assets, state), state)
            if node.trained:
                functor = user.Train().functor(node.builder)
                # the trainer takes over the GID key so that references to the state resolve to the trained output
                aliases.append(state)
                if persistent:
                    if self._committer is None:
                        self._committer = self._index.set(system.Committer(self._assets))
                    dumper = self._index.set(system.Dumper(self._assets))
                    self._linkage.insert(dumper, node.uid)
                    self._linkage.insert(self._committer, dumper, self._assets.offset(state))
                    state = self._index.reset(state)  # re-register loader under its own id
            if persistent or node.derived:
                functor = functor.preset_state()
                self._linkage.prepend(node.uid, state)
        # aliases are registered consecutively as required by the Index.instructions grouping
        for key in aliases:
            self._index.set(functor, key)
        if not node.trained:
            self._linkage.update(node, lambda index: self._index.set(system.Getter(index)))

    def visit_node(self, node: 'flow.Worker') -> None:
        """Visitor entrypoint.

        Args:
            node: Node to be visited.
        """
        self.add(node)


def compile(  # pylint: disable=redefined-builtin
    segment: 'flow.Segment', assets: typing.Optional['asset.State'] = None
) -> typing.Collection['flow.Symbol']:
    """Generate the portable low-level runtime symbol table representing the given flow topology segment augmented
    with all the necessary system instructions.

    Args:
        segment: Flow topology segment to generate the symbol table for.
        assets: Runtime state asset accessors for all the involved persistent workers.

    Returns:
        The portable runtime symbol table.
    """
    table = Table(assets)
    # the segment visitor feeds every worker node into the table
    segment.accept(table)
    return tuple(table)