Source code for forml.flow._task

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Flow actor abstraction."""

import abc
import collections
import io
import logging
import pickle
import types
import typing

import joblib

import forml

# Module-level logger named after this module; used for the actor state get/set debug messages.
LOGGER = logging.getLogger(__name__)

def name(actor: typing.Any, *args, **kwargs) -> str:
    """Infer the task name of given instance or type.

    Args:
        actor: Type or actor instance.
        *args: Optional positional parameters.
        **kwargs: Optional keyword parameters.

    Returns:
        String name representation (``actor`` followed by the parenthesized params if any are given).
    """

    def extract(obj: typing.Any) -> str:
        """Extract the name of given object.

        Args:
            obj: Object whose name to be extracted.

        Returns:
            Extracted name.
        """
        # Objects exposing ``__name__`` (e.g. classes, functions) are named by it; anything else by its repr.
        return obj.__name__ if hasattr(obj, '__name__') else repr(obj)

    value = extract(actor)
    # Positional params first, then keyword params rendered as ``key=value``.
    params = [extract(a) for a in args] + [f'{k}={extract(v)}' for k, v in kwargs.items()]
    if params:
        value += '(' + ', '.join(params) + ')'
    return value

# Generic type parameters of the actor interface: ``Features`` is the type of the feature input (train/apply),
# ``Labels`` the type of the labels input (train) and ``Result`` the type returned by apply.
Features = typing.TypeVar('Features')
Labels = typing.TypeVar('Labels')
Result = typing.TypeVar('Result')

class Actor(typing.Generic[Features, Labels, Result], metaclass=abc.ABCMeta):
    """Abstract interface of an actor."""

    @classmethod
    def spec(cls, *args, **kwargs: typing.Any) -> 'Spec[Features, Labels, Result]':
        """Shortcut for creating a spec of this actor.

        Args:
            *args: Positional params.
            **kwargs: Keyword params.

        Returns:
            Actor spec instance.
        """
        return Spec(cls, *args, **kwargs)

    @classmethod
    def is_stateful(cls) -> bool:
        """Check whether this actor is stateful.

        By default this is determined based on the existence of a user-overridden train method.

        Returns:
            True if stateful.
        """
        # A subclass overriding ``train`` carries a different code object than the base (which just raises).
        return cls.train.__code__ is not Actor.train.__code__

    def train(self, features: Features, labels: Labels) -> None:  # pylint: disable=no-self-use
        """Train the actor using the provided features and labels.

        Optional method engaging the *Train* (``features``) and *Label* (``labels``) ports on stateful actors.

        Args:
            features: Table of feature vectors.
            labels: Table of labels.

        Raises:
            RuntimeError: If not overridden (the actor is stateless).
        """
        raise RuntimeError('Stateless actor')

    @abc.abstractmethod
    def apply(self, *features: Features) -> Result:
        """Pass features through the apply function (typically transform or predict).

        Mandatory M:N input-output *Apply* ports.

        Args:
            features: Table(s) of feature vectors.

        Returns:
            Transformed features (ie predictions).
        """

    def get_params(self) -> typing.Mapping[str, typing.Any]:  # pylint: disable=no-self-use
        """Get hyper-parameters of this actor.

        Mandatory input and output *Params* ports.

        Returns:
            Dictionary of the name-value of the hyperparameters. All of the returned parameters must be acceptable
            by the companion set_params.
        """
        return {}

    def set_params(self, **params: typing.Any) -> None:
        """Set hyper-parameters of this actor.

        Args:
            params: Dictionary of hyper parameters.

        Raises:
            RuntimeError: If any params are given (this base implementation accepts none).
        """
        if params:
            raise RuntimeError(f'Params setter for {params} not implemented on {self}')

    def get_state(self) -> bytes:
        """Return the internal state of the actor.

        Returns:
            State as bytes (empty for stateless actors).
        """
        if not self.is_stateful():
            return bytes()
        LOGGER.debug('Getting %s state', self)
        with io.BytesIO() as bio:
            joblib.dump(self.__dict__, bio, protocol=pickle.HIGHEST_PROTOCOL)
            # Read the buffer before the context manager closes it.
            return bio.getvalue()

    def set_state(self, state: bytes) -> None:
        """Set new internal state of the actor.

        Note this doesn't change the setting of the actor hyper-parameters.

        Args:
            state: Bytes to be used as internal state (empty state is ignored).

        Raises:
            forml.UnexpectedError: If non-empty state is provided to a stateless actor.
        """
        if not state:
            return
        if not self.is_stateful():
            raise forml.UnexpectedError('State provided but actor stateless')
        LOGGER.debug('Setting %s state (%d bytes)', self, len(state))
        params = self.get_params()  # keep the original hyper-params
        with io.BytesIO(state) as bio:
            # NOTE: joblib.load unpickles the payload - only ever feed it state produced by a trusted get_state.
            self.__dict__.update(joblib.load(bio))
        self.set_params(**params)  # restore the original hyper-params

    def __repr__(self):
        """Render the actor as its class name followed by its hyper-parameters."""
        return name(self.__class__, **self.get_params())
class Spec(typing.Generic[Features, Labels, Result], collections.namedtuple('Spec', 'actor, args, kwargs')):
    """Wrapper of actor class and init params."""

    actor: type[Actor[Features, Labels, Result]]
    args: tuple[typing.Any, ...]
    kwargs: typing.Mapping[str, typing.Any]

    def __new__(cls, actor: type[Actor], *args: typing.Any, **kwargs: typing.Any):
        """Create the spec capturing the actor class together with its init params.

        Args:
            actor: Actor class.
            *args: Positional init params.
            **kwargs: Keyword init params (stored as a read-only mapping proxy).
        """
        return super().__new__(cls, actor, args, types.MappingProxyType(kwargs))

    def __repr__(self):
        """Render the spec as the actor name followed by its init params."""
        return name(self.actor, *self.args, **self.kwargs)

    def __getnewargs_ex__(self):
        """Pickling support.

        The custom ``__new__`` takes the init params as varargs, so they are passed back unpacked. The mapping
        proxy itself is not picklable, hence the plain dict copy.
        """
        return (self.actor, *self.args), dict(self.kwargs)

    def __call__(self, *args, **kwargs) -> Actor[Features, Labels, Result]:
        """Instantiate the actor using the captured params.

        Args:
            *args: Positional params entirely replacing the captured ones (if any are given).
            **kwargs: Keyword params merged on top of the captured ones.

        Returns:
            New actor instance.
        """
        return self.actor(*(args or self.args), **self.kwargs | kwargs)