# Source code for forml.flow._graph.atomic

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Graph node entities.

Output ports:
* apply (multi-port)

Input ports:
* apply (multi-port)
* train
* label

Each port can have at most one publisher.
Apply and train input port subscriptions are exclusive.
Trained node cannot be copied.
"""
import abc
import collections
import typing
import uuid

from .. import _exception
from . import port

if typing.TYPE_CHECKING:
    from forml import flow


class Visitor:
    """Interface of a visitor that can be accepted by graph nodes."""

    def visit_node(self, node: 'flow.Node') -> None:
        """Handle the visit of a single node.

        This base implementation intentionally does nothing.

        Args:
            node: The node being visited.
        """
        return None


class Port(typing.Iterable[port.Subscription]):
    """Ordered set of the subscriptions registered on a single output port.

    The insertion order is kept because the equality comparison of two nodes depends on it.
    """

    def __init__(self):
        # Only the keys matter - the mapping merely serves as an insertion-ordered set.
        self._subscriptions: dict['flow.Subscription', None] = collections.OrderedDict()

    def add(self, subscription: 'flow.Subscription') -> None:
        """Register a new subscription on this port.

        Adding a subscription equal to an already registered one keeps the existing entry at its position.

        Args:
            subscription: Subscription to be registered.
        """
        self._subscriptions.setdefault(subscription, None)

    def __iter__(self):
        return iter(self._subscriptions)


class Node(metaclass=abc.ABCMeta):
    """Abstract primitive task graph node.

    Args:
        szin: Number of input *Apply* ports.
        szout: Number of output *Apply* ports.

    Raises:
        ValueError: If any of the sizes is negative or if both of them are zero.
    """

    def __init__(self, szin: int, szout: int):
        if min(szin, szout) < 0 or szin == szout == 0:
            raise ValueError('Invalid node shape')
        self.szin: int = szin
        self.uid: uuid.UUID = uuid.uuid4()
        # One ordered subscription container per each output apply port.
        self._output: tuple[Port, ...] = tuple(Port() for _ in range(szout))

    def __repr__(self):
        return f'{self.__class__.__name__}[uid={self.uid}]'

    def __getitem__(self, index: int) -> 'flow.PubSub':
        """Semantical construct for creating PubSub port instance.

        Args:
            index: Input/output apply port index.

        Returns:
            Applicable instance
        """
        return port.PubSub(self, index)

    def __eq__(self, other: typing.Any) -> bool:
        """Compare this node with another object.

        Two nodes of *different* types are equal if they have the same output width, this node has at least one
        output subscription and the subscriptions of all the output ports match pairwise. In any other case
        the equality is based on object identity.

        Args:
            other: Object to compare with.

        Returns:
            True if equal.
        """
        if isinstance(other, Node) and other.__class__ is not self.__class__:
            mine, theirs = self.output, other.output
            return (
                self.szout == other.szout
                # Test the materialized subscription tuples rather than the port containers themselves - the
                # truthiness of a container without __len__/__bool__ does not reflect whether it is empty.
                and any(mine)
                and all(s == o for s, o in zip(mine, theirs))
            )
        return self is other

    def __hash__(self) -> int:
        """Hash the node purely by its shape.

        A Future node needs to appear identical to a Worker node of the same shape and subscriptions (so that
        the Future can represent a placeholder for that Worker). Both therefore have to hash into the same
        hashcode, which is why only the shape is used.

        Returns:
            Node hashcode.
        """
        return hash(self.szin) ^ hash(self.szout)

    def accept(self, visitor: Visitor) -> None:
        """Visitor entrypoint.

        Args:
            visitor: Accepted visitor.
        """
        visitor.visit_node(self)

    @property
    def szout(self) -> int:
        """Width of the output apply port.

        Returns:
            Output apply port width.
        """
        return len(self._output)

    @property
    def output(self) -> typing.Sequence[typing.Iterable['flow.Subscription']]:
        """Get list of output subscriptions per each port.

        Returns:
            Output subscriptions.
        """
        return tuple(tuple(s) for s in self._output)

    def _publish(self, index: int, subscription: 'flow.Subscription') -> None:
        """Publish an output port based on the given subscription.

        Args:
            index: Output port index to publish from.
            subscription: Subscriber node and port to publish to.

        Raises:
            TopologyError: If the subscriber is this very node.
        """
        assert 0 <= index < self.szout, 'Invalid output index'
        if self is subscription.node:
            raise _exception.TopologyError('Self subscription')
        self._output[index].add(subscription)

    @abc.abstractmethod
    def subscribed(self, publisher: 'flow.Node') -> bool:
        """Checking we are on given node's subscription list.

        Args:
            publisher: Node to check for being its subscriber.

        Returns:
            True if we are given node's subscriber.
        """

    @abc.abstractmethod
    def fork(self) -> 'flow.Node':
        """Create new node with same shape and actor as self but without any subscriptions.

        Returns:
            Forked node.
        """
class Worker(Node):
    """Main primitive node type.

    Args:
        builder: Actor builder instance.
        group: Worker group container.
        szin: Number of input *Apply* ports.
        szout: Number of output *Apply* ports.
    """

    class Group(set):
        """Container holding all the workers forked from each other.

        Args:
            builder: Actor builder instance.
        """

        def __init__(self, builder: 'flow.Builder'):
            super().__init__()
            self.builder: 'flow.Builder' = builder
            self.uid: uuid.UUID = uuid.uuid4()

        def __repr__(self):
            return f'{self.builder}[uid={self.uid}]'

    @typing.overload
    def __init__(self, builder: 'flow.Builder', /, szin: int, szout: int):
        """Constructor for a new independent worker.

        Args:
            builder: Actor builder instance.
            szin: Number of input *Apply* ports.
            szout: Number of output *Apply* ports.
        """

    @typing.overload
    def __init__(self, group: 'flow.Worker.Group', /, szin: int, szout: int):
        """Constructor for a new worker belonging to the same group (sharing their state).

        Args:
            group: Worker group container.
            szin: Number of input *Apply* ports.
            szout: Number of output *Apply* ports.
        """

    def __init__(self, group_or_builder, /, szin, szout):
        super().__init__(szin, szout)
        # An existing group is joined as is, anything else is treated as a builder starting a new group.
        if isinstance(group_or_builder, Worker.Group):
            group = group_or_builder
        else:
            group = self.Group(group_or_builder)
        self._group: Worker.Group = group
        self._group.add(self)

    def __repr__(self):
        return repr(self._group)

    @property
    def builder(self) -> 'flow.Builder':
        """Task builder in this worker.

        Returns:
            Task builder.
        """
        return self._group.builder

    def _publish(self, index: int, subscription: 'flow.Subscription') -> None:
        """Publish an output port based on the given subscription.

        Trained node must not be publishing.

        Args:
            index: Output port index to publish from.
            subscription: Subscriber node and port to publish to.

        Raises:
            TopologyError: If this node is trained.
        """
        if self.trained:
            raise _exception.TopologyError('Trained node publishing')
        super()._publish(index, subscription)

    @property
    def input(self) -> typing.Iterable[port.Type]:
        """Get subscribed input ports.

        Returns:
            Ports.
        """
        return port.Subscription.ports(self)

    @property
    def trained(self) -> bool:
        """Check if this node is subscribed for training data.

        Returns:
            True if trained.
        """
        for subscribed_port in self.input:
            if isinstance(subscribed_port, (port.Train, port.Label)):
                return True
        return False

    @property
    def stateful(self) -> bool:
        """Check this actor is stateful.

        Returns:
            True if stateful.
        """
        actor = self._group.builder.actor
        return actor.is_stateful()

    @property
    def derived(self) -> bool:
        """Check this node is a state receiver in a trained group.

        Returns:
            True if persistent.
        """
        stateful = self.stateful
        if not stateful:
            return stateful
        return any(peer.trained for peer in self.group if peer is not self)

    @property
    def gid(self) -> uuid.UUID:
        """Return the group ID shared by all forks of this worker.

        Returns:
            Group ID.
        """
        return self._group.uid

    @property
    def group(self) -> typing.AbstractSet['flow.Worker']:
        """Set of forked workers in the same fork group.

        Returns:
            Workers in same fork group.
        """
        return frozenset(self._group)

    def train(self, train: 'flow.Publishable', label: 'flow.Publishable') -> None:
        """Subscribe this node *Train* and *Label* ports to the given publishers.

        Args:
            train: Train port publisher.
            label: Label port publisher.

        Raises:
            TopologyError: If this node is stateless or if any worker of this group is already trained.
        """
        if not self.stateful:
            raise _exception.TopologyError('Stateless node training')
        for member in self._group:
            if member.trained:
                raise _exception.TopologyError('Fork train collision')
        train.publish(self, port.Train())
        label.publish(self, port.Label())

    def subscribed(self, publisher: 'flow.Node') -> bool:
        """Checking we are on given node's subscription list.

        Args:
            publisher: Node to check for being its subscriber.

        Returns:
            True if we are given node's subscriber.
        """
        for subscriptions in publisher.output:
            for subscription in subscriptions:
                if subscription.node is self:
                    return True
        return False

    def fork(self) -> 'flow.Worker':
        """Create a new node belonging to the same group (having the same shape and actor as self) but without
        any subscriptions.

        Returns:
            Forked node.
        """
        return Worker(self._group, self.szin, self.szout)

    @classmethod
    def fgen(cls, builder: 'flow.Builder', szin: int, szout: int) -> typing.Iterator['flow.Worker']:
        """Generator producing forks of the same node belonging to the same group.

        The generator never terminates - the first item is the newly created node, each following item is
        a fresh fork of it.

        Args:
            builder: Actor builder.
            szin: Worker input apply port size.
            szout: Worker output apply port size.

        Returns:
            Generator producing same-group worker forks.
        """
        origin = cls(builder, szin, szout)
        yield origin
        while True:
            yield origin.fork()
class Future(Node):
    """Fake transparent *Apply* port node that can be used as a lazy publisher/subscriber that disappears from
    the chain once it gets connected to another apply node(s).
    """

    class PubSub(port.PubSub):
        """Overridden implementation that does the proxied publishing/subscription."""

        def __init__(
            self,
            node: 'flow.Future',
            index: int,
            register: typing.Callable[[port.Publishable], None],
        ):
            super().__init__(node, index)
            self._register: typing.Callable[[port.Publishable], None] = register

        def subscribe(self, publisher: 'flow.Publishable') -> None:
            """Register publisher for future subscriptions.

            Args:
                publisher: Actual left side publisher to be used for all the interim subscriptions.
            """
            self._register(publisher)

    def __init__(self, szin: int = 1, szout: int = 1):
        super().__init__(szin, szout)
        # Registered left side publishers mapped to the port index they have been registered under.
        self._input: dict[port.Publishable, int] = {}

    def __getitem__(self, index) -> 'flow.PubSub':
        def register(publisher: 'flow.Publishable') -> None:
            """Callback for publisher input registration.

            Args:
                publisher: Left side publisher
            """
            if publisher in self._input:
                raise _exception.TopologyError('Publisher collision')
            self._input[publisher] = index
            self._collapse()

        return self.PubSub(self, index, register)

    def subscribed(self, publisher: 'flow.Node') -> bool:
        """Overridden subscription checker. Future node checks the subscriptions in its proxy registrations.

        Args:
            publisher: Node to check for being its subscriber.

        Returns:
            True if we are given node's subscriber.
        """
        # pylint: disable=protected-access
        for registered in self._input:
            if registered._node is publisher or registered._node.subscribed(publisher):
                return True
        return False

    def _collapse(self) -> None:
        """Callback for interconnecting proxied registrations."""
        for publisher, index in self._input.items():
            for subscription in self._output[index]:
                publisher.republish(subscription)

    def _publish(self, index: int, subscription: 'flow.Subscription') -> None:
        """Publish an output port based on the given subscription.

        Upstream publish followed by proxy synchronization.

        Args:
            index: Output port index to publish from.
            subscription: Subscriber node and port to publish to.
        """
        super()._publish(index, subscription)
        self._collapse()

    def fork(self) -> 'flow.Future':
        """There is nothing to copy on a Future node so just create a new one.

        Returns:
            new Future node.
        """
        return Future(self.szin, self.szout)