# Source code for forml.application._strategy

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
ForML application model rollout strategy.
"""
import abc
import collections
import dataclasses
import functools
import logging
import threading
import time
import typing

from forml.io import asset as assetmod

if typing.TYPE_CHECKING:
    from forml import application, runtime
    from forml.io import asset  # pylint: disable=reimported


LOGGER = logging.getLogger(__name__)


class Selector(abc.ABC):
    """Interface of the model selection strategies pluggable into the
    :class:`application.Generic <forml.application.Generic>` descriptors.

    Implementations decide which model instance is going to serve a particular request.
    """

    @abc.abstractmethod
    def select(
        self,
        registry: 'asset.Directory',
        context: typing.Any,
        stats: 'runtime.Stats',
    ) -> 'asset.Instance':
        """Choose the model instance that should serve the request.

        See Also:
            The purpose is the same as of the :meth:`application.Descriptor.select
            <forml.application.Descriptor.select>` method - just extracted into a separate
            object.

        Args:
            registry: Model registry to select the model from.
            context: Optional metadata carried over from the
                     :meth:`application.Descriptor.receive <forml.application.Descriptor.receive>`.
            stats: Application specific serving metrics.

        Returns:
            Model instance.
        """
class Explicit(Selector):
    """Model selection strategy always choosing an explicit model generation.

    The instance handle is created lazily on the first request and cached per registry, so the
    same selector can safely be used against more than one registry.

    Args:
        project: Project reference of the selected model.
        release: Project release reference of the selected model.
        generation: Project generation reference of the selected model.
    """

    def __init__(
        self,
        project: typing.Union[str, 'asset.Project.Key'],
        release: typing.Union[str, 'asset.Release.Key'],
        generation: typing.Union[str, int, 'asset.Generation.Key'],
    ):
        self._project: typing.Union[str, 'asset.Project.Key'] = project
        self._release: typing.Union[str, 'asset.Release.Key'] = release
        self._generation: typing.Union[str, int, 'asset.Generation.Key'] = generation
        # An instance is bound to the registry it was created for, hence the cache is keyed by
        # the registry (a single shared slot would leak the first registry into later calls).
        self._instances: dict['asset.Directory', 'asset.Instance'] = {}

    def select(self, registry: 'asset.Directory', context: typing.Any, stats: 'runtime.Stats') -> 'asset.Instance':
        """Return the explicitly configured model instance for the given registry.

        Args:
            registry: Model registry to select the model from.
            context: Ignored by this strategy.
            stats: Ignored by this strategy.

        Returns:
            Model instance of the configured project/release/generation within ``registry``.
        """
        try:
            return self._instances[registry]
        except KeyError:
            instance = self._instances[registry] = assetmod.Instance(
                registry=registry,
                project=self._project,
                release=self._release,
                generation=self._generation,
            )
            return instance
class Latest(Selector):
    """Model selection strategy choosing an instance of the most recent model release/generation.

    The chosen instance is cached per registry and periodically re-evaluated by a background
    daemon thread which gets started upon the first :meth:`select` call.

    Args:
        project: Project reference to choose the most recent generation from.
        release: Optional release to choose the most recent generation from.
        refresh: Interval in seconds for refreshing the latest instance from its registry.
    """

    DEFAULT_REFRESH = 30
    """Default refresh interval in seconds"""

    def __init__(
        self,
        project: typing.Union[str, 'asset.Project.Key'],
        release: typing.Optional[typing.Union[str, 'asset.Release.Key']] = None,
        refresh: float = DEFAULT_REFRESH,
    ):
        self._project: typing.Union[str, 'asset.Project.Key'] = project
        self._release: typing.Optional[typing.Union[str, 'asset.Release.Key']] = release
        self._interval: float = refresh
        self._cache: dict['asset.Directory', 'asset.Instance'] = {}
        self._lock: threading.RLock = threading.RLock()
        self._refresher: threading.Thread = self._spawn()

    def __reduce__(self):
        # The lock and the thread can't be pickled - rebuild from the constructor arguments.
        return self.__class__, (self._project, self._release, self._interval)

    def _spawn(self) -> threading.Thread:
        """Create (without starting) a new daemon thread running the refresh loop."""
        return threading.Thread(target=self._refresh, daemon=True)

    def _refresh(self) -> None:
        """Background loop re-picking the latest instance for each of the cached registries.

        A failure to refresh a particular registry is logged and the previously cached instance
        is kept, so that a transient registry error doesn't terminate the refresher thread.
        """
        while True:
            # Snapshot under the lock so the (potentially slow) registry lookups run unlocked.
            with self._lock:
                instances = tuple(self._cache.items())
            LOGGER.debug('Refreshing %d cached instances', len(instances))
            for registry, old in instances:
                try:
                    new = self._pick(registry)
                except Exception:  # pylint: disable=broad-except
                    LOGGER.exception('Unable to refresh the latest instance of %s', self._project)
                    continue
                if new != old:
                    LOGGER.info('Updating latest instance to %s', new)
                    with self._lock:
                        self._cache[registry] = new
            time.sleep(self._interval)

    def _pick(self, registry: 'asset.Directory') -> 'asset.Instance':
        """Resolve the instance to be served from the given registry.

        With an explicit release configured, the generation is passed to the instance as ``None``.
        Otherwise the project releases are walked in reverse listing order and the last generation
        of the first release having any generation is taken.

        Args:
            registry: Model registry to pick the instance from.

        Returns:
            Model instance.

        Raises:
            assetmod.Level.Listing.Empty: If none of the project releases has any generation.
        """
        release = self._release
        generation = None
        if not release:
            project = registry.get(self._project)
            for release in reversed(project.list()):
                try:
                    generation = project.get(release).list().last
                except assetmod.Level.Listing.Empty:
                    continue  # release without any generation - try the previous one
                break
            else:
                raise assetmod.Level.Listing.Empty(f'No models available for {self._project}')
        return assetmod.Instance(
            registry=registry,
            project=self._project,
            release=release,  # pylint: disable=undefined-loop-variable
            generation=generation,
        )

    def select(self, registry: 'asset.Directory', context: typing.Any, stats: 'runtime.Stats') -> 'asset.Instance':
        """Return the cached latest instance for the given registry.

        The first call for a particular registry picks the instance synchronously. The background
        refresher is started if it isn't running.

        Args:
            registry: Model registry to select the model from.
            context: Ignored by this strategy.
            stats: Ignored by this strategy.

        Returns:
            Model instance.
        """
        with self._lock:
            if registry not in self._cache:
                self._cache[registry] = self._pick(registry)
            if not self._refresher.is_alive():
                if self._refresher.ident is not None:
                    # Already started and finished - a thread can only be started once.
                    self._refresher = self._spawn()
                self._refresher.start()
            return self._cache[registry]
class ABTest(Selector):
    """Model selection strategy choosing among a number of defined variants according to their
    target weights.

    The target weights can be defined either as fractions in the float interval (0, 1) or as
    positive integers that will get normalized internally. If omitted for any variant, the
    particular target weight gets calculated either as a complement to 1 (if all other provided
    weights are floats below 1) or as a mean of the provided integer weights.

    Attention:
        Instances are expected to be created via the :class:`Builder`.

    Examples:
        Setting up an ABTest starts with the :meth:`compare` method:

        >>> selector = application.ABTest.compare(
        >>>     "forml-tutorial-titanic", "0.1.dev1", 1, 0.9
        >>> ).against(generation=2, target=0.1)
    """

    class Variant(collections.namedtuple('Variant', 'project, release, generation, target')):
        """Internal container for holding the variant parameters.

        Identity (hash and equality) is given just by the project/release/generation triplet -
        the target weight is deliberately excluded.
        """

        project: typing.Union[str, 'asset.Project.Key']
        release: typing.Union[str, 'asset.Release.Key']
        generation: typing.Union[str, int, 'asset.Generation.Key']
        target: typing.Optional[float]

        def __new__(
            cls,
            project: typing.Union[str, 'asset.Project.Key'],
            release: typing.Union[str, 'asset.Release.Key'],
            generation: typing.Union[str, int, 'asset.Generation.Key'],
            target: typing.Optional[float],
        ):
            if target is not None and target <= 0:
                raise ValueError(f'Positive target required: {target}')
            return super().__new__(cls, project, release, generation, target)

        def __hash__(self):
            return hash(self.project) ^ hash(self.release) ^ hash(self.generation)

        def __eq__(self, other):
            return (
                isinstance(other, self.__class__)
                and other.project == self.project
                and other.release == self.release
                and other.generation == self.generation
            )

    class Builder:
        """Internal builder for setting up the :class:`application.ABTest <forml.application.ABTest>`.

        Attention:
            Instances are expected to be created via the :meth:`application.ABTest.compare
            <compare>` method.
        """

        def __init__(
            self,
            project: typing.Union[str, 'asset.Project.Key'],
            release: typing.Union[str, 'asset.Release.Key'],
            generation: typing.Union[str, int, 'asset.Generation.Key'],
            target: typing.Optional[float],
        ):
            self._variants: list['ABTest.Variant'] = [ABTest.Variant(project, release, generation, target)]

        def over(
            self,
            generation: typing.Union[str, int, 'asset.Generation.Key'],
            *,
            release: typing.Optional[typing.Union[str, 'asset.Release.Key']] = None,
            project: typing.Optional[typing.Union[str, 'asset.Project.Key']] = None,
            target: typing.Optional[float] = None,
        ) -> 'application.ABTest.Builder':
            """Intermediate method for adding a non-last variant (ABTest of more than 2 variants).

            Args:
                project: Project reference of the selected model. Defaults to the project of the
                         previously added variant.
                release: Project release reference of the selected model. Defaults to the release
                         of the previously added variant.
                generation: Project generation reference of the selected model.
                target: Relative engagement share demand.

            Returns:
                ABTest builder.
            """
            last = self._variants[-1]
            self._variants.append(ABTest.Variant(project or last.project, release or last.release, generation, target))
            return self

        def against(
            self,
            generation: typing.Union[str, int, 'asset.Generation.Key'],
            *,
            release: typing.Optional[typing.Union[str, 'asset.Release.Key']] = None,
            project: typing.Optional[typing.Union[str, 'asset.Project.Key']] = None,
            target: typing.Optional[float] = None,
        ) -> 'application.ABTest':
            """ABTest builder completer for adding the last variant and constructing the ABTest
            instance.

            Args:
                project: Project reference of the selected model. Defaults to the project of the
                         previously added variant.
                release: Project release reference of the selected model. Defaults to the release
                         of the previously added variant.
                generation: Project generation reference of the selected model.
                target: Relative engagement share demand.

            Returns:
                ABTest instance.
            """
            self.over(generation, release=release, project=project, target=target)
            return ABTest(*self._variants)  # pylint: disable=no-value-for-parameter

    @dataclasses.dataclass
    class Slot:
        """Internal container for variant metadata."""

        variant: 'application.ABTest.Variant'
        target: float
        count: int = 0
        # Per-slot cache of the instances keyed by their registry. Kept out of init/repr/compare
        # so the dataclass contract stays the same; being per-slot (rather than a class-level
        # lru_cache on the method) it doesn't keep the slots and registries alive globally.
        _instances: typing.Dict['asset.Directory', 'asset.Instance'] = dataclasses.field(
            default_factory=dict, init=False, repr=False, compare=False
        )

        def __hash__(self):
            return hash(self.variant)

        def _instance(self, registry: 'asset.Directory') -> 'asset.Instance':
            """Get the (cached) instance of this slot's variant within the given registry."""
            try:
                return self._instances[registry]
            except KeyError:
                instance = self._instances[registry] = assetmod.Instance(
                    registry=registry,
                    project=self.variant.project,
                    release=self.variant.release,
                    generation=self.variant.generation,
                )
                return instance

        def hit(self, registry: 'asset.Directory') -> 'asset.Instance':
            """Select this slot for serving.

            Increments the hit counter of this slot.

            Args:
                registry: Model registry to select the model from.

            Returns:
                Model instance.
            """
            self.count += 1
            return self._instance(registry)

        def eligible(self, total: int) -> bool:
            """Check the eligibility of this slot given the provided total amount of requests.

            Args:
                total: Overall amount of requests served across all variants. Must be > 0.

            Returns:
                True if this slot is eligible for serving (its share of hits so far is below
                its target).

            Raises:
                ValueError: If the total is not positive.
            """
            if total <= 0:  # explicit check - an assert would vanish under ``python -O``
                raise ValueError(f'Positive total required: {total}')
            return (self.count / total) < self.target

    def __init__(
        self,
        avar: 'application.ABTest.Variant',
        bvar: 'application.ABTest.Variant',
        *others: 'application.ABTest.Variant',
    ):
        variants = (avar, bvar, *others)
        if len(set(variants)) != len(variants):
            raise ValueError('Exclusive variants required')
        targets = [v.target for v in variants if v.target]
        missing = sum(1 for v in variants if not v.target)
        if missing:
            explicit = sum(targets)
            # Fractions summing below 1 -> split the complement evenly among the unspecified
            # variants; otherwise fall back to the mean of the provided weights.
            implicit = (1 - explicit) / missing if explicit < 1 else explicit / len(targets)
            targets = [v.target or implicit for v in variants]
        combined = sum(targets)
        # Normalized slots ordered by their target share (highest first).
        self._slots: tuple[ABTest.Slot, ...] = tuple(
            sorted(
                (self.Slot(v, t / combined) for v, t in zip(variants, targets)), key=lambda s: s.target, reverse=True
            )
        )
        self._total: int = 0

    @classmethod
    def compare(
        cls,
        project: typing.Union[str, 'asset.Project.Key'],
        release: typing.Union[str, 'asset.Release.Key'],
        generation: typing.Union[str, int, 'asset.Generation.Key'],
        target: typing.Optional[float] = None,
    ) -> 'application.ABTest.Builder':
        """Bootstrap method for creating the ABTest builder.

        Args:
            project: Project reference of the selected model.
            release: Project release reference of the selected model.
            generation: Project generation reference of the selected model.
            target: Relative engagement share demand.

        Returns:
            ABTest builder.
        """
        return cls.Builder(project, release, generation, target)

    def select(self, registry: 'asset.Directory', context: typing.Any, stats: 'runtime.Stats') -> 'asset.Instance':
        """Choose the variant whose share of the served requests is below its target.

        The slots are probed in the order of their descending targets and the first eligible one
        is taken.

        Args:
            registry: Model registry to select the model from.
            context: Ignored by this strategy.
            stats: Ignored by this strategy.

        Returns:
            Model instance of the chosen variant.

        Raises:
            RuntimeError: If none of the slots is eligible.
        """
        self._total += 1
        for slot in self._slots:
            if slot.eligible(self._total):
                break
        else:
            raise RuntimeError('No eligible slots')
        return slot.hit(registry)