# Source code for forml.flow._graph.port
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Graph node port functionality.
"""
import collections
import typing
from .. import _exception
from . import atomic # pylint: disable=unused-import
if typing.TYPE_CHECKING:
from forml import flow
class Type(int):
    """Port base class.

    Ports are integers tagged with their class: two ports are equal only if they are instances of
    the exact same class carrying the same integer value.
    """

    def __repr__(self):
        return f'{self.__class__.__name__}[{int(self)}]'

    def __hash__(self):
        # mix the class into the hash so equal integers of different port types hash differently
        return hash(self.__class__) ^ int.__hash__(self)

    def __eq__(self, other):
        return other.__class__ is self.__class__ and int.__eq__(self, other)

    def __ne__(self, other):
        # Without this override int.__ne__ would be inherited and compare just the integer values,
        # making ``!=`` disagree with ``__eq__`` (e.g. Apply(1) != Label() would be False).
        return not self.__eq__(other)
class Meta(type):
    """Metaclass turning its classes into singletons with a static value.

    The body of each class using this metaclass must define ``VALUE``. That attribute is removed
    from the class namespace and becomes the integer argument for the first base's ``__new__``
    when the single instance is created lazily on the first instantiation. Every further
    instantiation returns that very same instance.
    """

    def __new__(mcs, name: str, bases: tuple[type], namespace: dict[str, typing.Any]):
        static = namespace.pop('VALUE')
        slot = [None]  # holds the lazily created singleton instance

        def new(cls):
            """Return the singleton instance, creating it with the static value on first use."""
            if slot[0] is None:
                slot[0] = bases[0].__new__(cls, static)
            return slot[0]

        namespace['__new__'] = new
        return super().__new__(mcs, name, bases, namespace)
class Train(Type, metaclass=Meta):
    """Train input port.

    Singleton port type: every instantiation returns the same instance, whose integer value is
    the ``VALUE`` given below.
    """

    VALUE = 0  # consumed (and removed from the class) by the Meta metaclass
class Label(Type, metaclass=Meta):
    """Label input port.

    Singleton port type: every instantiation returns the same instance, whose integer value is
    the ``VALUE`` given below.
    """

    VALUE = 1  # consumed (and removed from the class) by the Meta metaclass
class Apply(Type):
    """Apply input/output port at the given index.

    Unlike the singleton *Train*/*Label* ports, the integer value of this port is the index passed
    to its constructor (e.g. ``Apply(2)``).
    """
class Subscription(collections.namedtuple('Subscription', 'node, port')):
    """Descriptor representing a subscription of a node input port of a given type.

    Creating a subscription validates it against the ports already subscribed on the same node and
    registers the port in a class-level registry. The port is released from the registry again
    when the descriptor gets finalized (``__del__``).
    """

    # registry of ports subscribed on given node
    _PORTS: dict['flow.Node', set[Type]] = collections.defaultdict(set)  # TO-DO: switch to weakref

    def __new__(cls, subscriber: 'flow.Node', port: Type):
        """Validate and register a new subscription.

        Args:
            subscriber: Node whose input port is being subscribed.
            port: The port to subscribe.

        Returns:
            The new subscription descriptor.

        Raises:
            _exception.TopologyError: If the port is already subscribed on the node, if *Apply* and
                non-*Apply* ports would get mixed on the same node, if a *Train*/*Label* port is
                requested on a node with any truthy ``output`` entry, or if the subscriber is a
                *Future* node.
        """
        if port in cls._PORTS[subscriber]:
            raise _exception.TopologyError('Double subscription')
        # a node can be subscribed either only on Apply ports or only on non-Apply ports
        if cls._PORTS[subscriber] and (
            isinstance(port, Apply) ^ any(isinstance(s, Apply) for s in cls._PORTS[subscriber])
        ):
            raise _exception.TopologyError('Apply/Train collision')
        if isinstance(port, (Train, Label)) and any(subscriber.output):
            raise _exception.TopologyError('Publishing node trained')
        if isinstance(subscriber, atomic.Future):
            raise _exception.TopologyError('Future node subscribing')
        cls._PORTS[subscriber].add(port)
        return super().__new__(cls, subscriber, port)

    def __repr__(self):
        return f'{self.node}@{self.port}'

    def __hash__(self):
        return hash(self.node) ^ hash(self.port)

    def __eq__(self, other: typing.Any):
        return isinstance(other, self.__class__) and self.node == other.node and self.port == other.port

    @classmethod
    def ports(cls, subscriber: 'flow.Node') -> typing.Iterable[Type]:
        """Get the subscribed ports of the given node.

        Args:
            subscriber: Node whose subscribed ports should be retrieved.

        Returns:
            Subscribed ports of the given node (empty if none).
        """
        # .get() rather than item access so that querying does not insert empty registry entries
        return frozenset(cls._PORTS.get(subscriber, ()))

    def __del__(self):
        # Release the port from the registry. The entry may be missing, and the previous ``{}``
        # fallback was a dict without ``.discard()``, so it raised AttributeError here.
        registered = self._PORTS.get(self.node)
        if registered is not None:
            registered.discard(self.port)
class Applicable:
    """Common base of the publishing/subscribing port-reference proxies.

    Keeps the referenced node together with the port index the proxy refers to.
    """

    _node: 'flow.Node'
    _index: int

    def __init__(self, node: 'flow.Node', index: int):
        self._node, self._index = node, index
class Publishable(Applicable):
    """Output *Apply* port reference that can be used just for publishing."""

    @property
    def szout(self) -> int:
        """Size of publisher node output.

        Returns:
            Output size.
        """
        return self._node.szout

    def publish(self, subscriber: 'flow.Node', port: Type) -> None:
        """Publish new subscription.

        If the subscriber is a *Future* node other than this publisher's own node, the request is
        delegated to it as ``subscriber[port].subscribe(self)``. Otherwise a new *Subscription* is
        created and published via ``republish``.

        Args:
            subscriber: Node to publish to.
            port: Port to publish to.

        Raises:
            Whatever the *Subscription* creation raises (e.g. a topology error), or whatever
            ``republish`` raises. In the latter case the port registration is rolled back first.
        """
        if isinstance(subscriber, atomic.Future) and subscriber is not self._node:
            subscriber[port].subscribe(self)
            return
        subscription = Subscription(subscriber, port)
        try:
            self.republish(subscription)
        except Exception:
            # undo the registration done by Subscription() so the port can be subscribed again
            # TO-DO: use weakref
            Subscription._PORTS[subscriber].discard(port)  # pylint: disable=protected-access
            raise  # bare raise keeps the original traceback intact

    def republish(self, subscription: 'flow.Subscription') -> None:
        """Publish existing subscription.

        Args:
            subscription: Existing subscription descriptor.
        """
        self._node._publish(self._index, subscription)  # pylint: disable=protected-access
class Subscriptable(Applicable):
    """Input *Apply* port reference that can be used just for subscribing."""

    @property
    def szin(self) -> int:
        """Size of the subscriber node input (the node this reference points to).

        Returns:
            Input size.
        """
        return self._node.szin

    def subscribe(self, publisher: 'flow.Publishable') -> None:
        """Subscribe to the given publisher.

        The subscription is delegated to the publisher, which publishes into this node's *Apply*
        port at this reference's index.

        Args:
            publisher: *Publishable* to subscribe to.
        """
        publisher.publish(self._node, Apply(self._index))
class PubSub(Publishable, Subscriptable):
    """Input or output *Apply* port reference that can be used for both subscribing and publishing."""

    @property
    def publisher(self) -> 'flow.Publishable':
        """Narrow this reference down to its publishing capability only.

        Returns:
            *Publishable* reference to the same node and port index.
        """
        node, index = self._node, self._index
        return Publishable(node, index)

    @property
    def subscriber(self) -> 'flow.Subscriptable':
        """Narrow this reference down to its subscribing capability only.

        Returns:
            *Subscriptable* reference to the same node and port index.
        """
        node, index = self._node, self._index
        return Subscriptable(node, index)