Source code for forml.io.asset._directory.level.minor

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Generic assets directory.
"""
import collections
import datetime
import logging
import operator
import types
import typing
import uuid

import toml

from forml.io import dsl

from ... import _directory, _persistent

if typing.TYPE_CHECKING:
    from . import case as prjmod
    from . import major as lngmod

LOGGER = logging.getLogger(__name__)


[docs]class Tag(collections.namedtuple('Tag', 'training, tuning, states')): """Generation metadata. Args: training: Generation training information. tuning: Generation tuning information. states: Sequence of state asset IDs. """ class Mode(types.SimpleNamespace): """Mode metadata.""" class Proxy(tuple): """Mode attributes proxy.""" _tag = property(operator.itemgetter(0)) _mode = property(operator.itemgetter(1)) def __new__(cls, tag: 'Tag', mode: 'Tag.Mode'): return super().__new__(cls, (tag, mode)) def __repr__(self): return f'Mode{repr(self._mode)}' def __bool__(self): return bool(self._mode) def __getattr__(self, item): return getattr(self._mode, item) def __eq__(self, other) -> bool: # pylint: disable=protected-access return isinstance(other, self.__class__) and self._mode == other._mode def replace(self, **kwargs) -> 'Tag': """Mode attributes setter. Args: **kwargs: Keyword parameters to be set on given mode attributes. Returns: New tag instance with new values. """ mode = self._mode.__class__(**self._mode.__dict__ | kwargs) return Tag(**{k: mode if v is self._mode else v for k, v in self._tag._asdict().items()}) def trigger(self, timestamp: typing.Optional[datetime.datetime] = None) -> 'Tag': """Create new tag with given mode triggered (all attributes reset and timestamp set to now). Returns: New tag. """ return self.replace(timestamp=(timestamp or datetime.datetime.utcnow())) def __init__(self, timestamp: typing.Optional[datetime.datetime], **kwargs: typing.Any): super().__init__(timestamp=timestamp, **kwargs) def __bool__(self): return bool(self.timestamp) class Training(Mode): """Training mode attributes.""" def __init__( self, timestamp: typing.Optional[datetime.datetime] = None, ordinal: typing.Optional[dsl.Native] = None ): super().__init__(timestamp, ordinal=ordinal) class Tuning(Mode): """Tuning mode attributes.""" def __init__(self, timestamp: typing.Optional[datetime.datetime] = None, score: typing.Optional[float] = None): super().__init__(timestamp, score=score) def __new__( cls, training: typing.Optional[Training] = None, tuning: typing.Optional[Tuning] = None, states: typing.Optional[typing.Sequence[uuid.UUID]] = None, ): return super().__new__(cls, training or cls.Training(), tuning or cls.Tuning(), tuple(states or [])) def __bool__(self): return bool(self.training or self.tuning) def __getattribute__(self, name: str) -> typing.Any: attribute = super().__getattribute__(name) if isinstance(attribute, Tag.Mode): attribute = self.Mode.Proxy(self, attribute) return attribute def replace(self, **kwargs) -> 'Tag': """Replace the given non-mode attributes. Args: **kwargs: Non-mode attributes to be replaced. Returns: New tag instance. """ if not {k for k, v in self._asdict().items() if not isinstance(v, Tag.Mode)}.issuperset(kwargs.keys()): raise ValueError('Invalid replacement') return self._replace(**kwargs) def dumps(self) -> bytes: """Dump the tag into a string of bytes. Returns: String of bytes representation. """ return toml.dumps( { 'training': {'timestamp': self.training.timestamp, 'ordinal': self.training.ordinal}, 'tuning': {'timestamp': self.tuning.timestamp, 'score': self.tuning.score}, 'states': [str(s) for s in self.states], }, ).encode('utf-8') @classmethod def loads(cls, raw: bytes) -> 'Tag': """Load the previously dumped tag. Args: raw: Serialized tag representation to be loaded. Returns: Tag instance. """ meta = toml.loads(raw.decode('utf-8')) return cls( training=cls.Training(timestamp=meta['training']['timestamp'], ordinal=meta['training'].get('ordinal')), tuning=cls.Tuning(timestamp=meta['tuning'].get('timestamp'), score=meta['tuning'].get('score')), states=(uuid.UUID(s) for s in meta['states']), )
NOTAG = Tag() TAGS = _directory.Cache(_persistent.Registry.open) STATES = _directory.Cache(_persistent.Registry.read) # pylint: disable=unsubscriptable-object; https://github.com/PyCQA/pylint/issues/2822 class Generation(_directory.Level): """Snapshot of project states in its particular training iteration.""" class Key(_directory.Level.Key, int): """Project model generation key - i.e. generation *sequence number*. This must be a natural integer starting from 1. """ MIN = 1 def __new__(cls, key: typing.Optional[typing.Union[str, int, 'Generation.Key']] = MIN): try: instance = super().__new__(cls, str(key)) except ValueError as err: raise cls.Invalid(f'Invalid key {key} (not an integer)') from err if instance < cls.MIN: raise cls.Invalid(f'Invalid key {key} (not natural)') return instance @property def next(self) -> 'Generation.Key': return self.__class__(self + 1) def __init__( self, release: 'lngmod.Release', key: typing.Optional[typing.Union[str, int, 'Generation.Key']] = None ): super().__init__(key, parent=release) @property def project(self) -> 'prjmod.Project': """Get the project of this generation. Returns: Project of this generation. """ return self.release.project @property def release(self) -> 'lngmod.Release': """Get the release key of this generation. Returns: Release key of this generation. """ return self._parent @property def tag(self) -> 'Tag': """Generation metadata. In case of implicit generation and empty release this returns a "null" tag (a Tag object with all fields empty). Returns: Generation tag (metadata) object. """ # project/release must exist so let's fetch it outside of try-except project = self.project.key release = self.release.key try: generation = self.key except self.Listing.Empty: # generation doesn't exist LOGGER.debug('No previous generations found - using a null tag') return NOTAG return TAGS(self.registry, project, release, generation) def list(self) -> _directory.Level.Listing: """Return the listing of this level. Returns: Level listing. """ return self.Listing(self.tag.states) def get(self, key: typing.Union[uuid.UUID, int]) -> bytes: """Load the state based on provided id or positional index. Args: key: Index or absolute id of the state object to be loaded. Returns: Serialized state. """ if not self.tag.training: return b'' if isinstance(key, int): key = self.tag.states[key] if key not in self.tag.states: raise Generation.Invalid(f'Unknown state reference for {self}: {key}') LOGGER.debug('%s: Getting state %s', self, key) return STATES(self.registry, self.project.key, self.release.key, self.key, key)