# Source code for forml.provider.runner.dask
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Dask runner.
"""
import logging
import typing
import dask
import distributed
from dask.delayed import Delayed
from forml import flow, runtime
if typing.TYPE_CHECKING:
from forml import io
from forml.io import asset
# Module-level logger named after this module (standard logging convention).
LOGGER = logging.getLogger(__name__)
class Runner(runtime.Runner, alias='dask'):
    """ForML runner implementation using the :doc:`Dask computing library <dask:index>` as the
    execution platform.

    Args:
        kwargs: Any :doc:`Dask Configuration options <dask:configuration>`.

            Noteworthy parameters:

            * ``scheduler`` selects the :doc:`scheduling implementation <dask:scheduling>`
              (valid options are: ``synchronous``, ``threads``, ``processes``,
              ``distributed``)
            * to submit to a remote :doc:`Dask Cluster <distributed:index>`, set the
              ``scheduler`` to ``distributed`` and provide the master ``scheduler-address``

    The provider can be enabled using the following :ref:`platform configuration <platform-config>`:

    .. code-block:: toml
       :caption: config.toml

        [RUNNER.compute]
        provider = "dask"
        scheduler = "processes"

    Important:
        Select the ``dask`` :ref:`extras to install <install-extras>` ForML together with the Dask
        support.
    """

    # Default Dask configuration options applied unless overridden via the constructor kwargs.
    DEFAULTS = {
        'scheduler': 'processes',
    }

    def __init__(
        self,
        instance: typing.Optional['asset.Instance'] = None,
        feed: typing.Optional['io.Feed'] = None,
        sink: typing.Optional['io.Sink'] = None,
        **kwargs,
    ):
        super().__init__(instance, feed, sink)
        # Merge caller-supplied options over the defaults and push them into the
        # global Dask configuration (kwargs win on key collision).
        dask.config.set(self.DEFAULTS | kwargs)
        self._client: typing.Optional[distributed.Client] = None

    def start(self) -> None:
        """Spin up a ``distributed.Client`` when the ``distributed`` scheduler is configured.

        For the local schedulers (``synchronous``/``threads``/``processes``) no setup is needed.
        """
        if dask.config.get('scheduler') == 'distributed':
            self._client = distributed.Client()
            self._client.start()

    def close(self) -> None:
        """Shut down the distributed client (if one was started) and drop the reference."""
        if self._client:
            self._client.close()
            self._client = None

    @staticmethod
    def _mkjob(symbols: typing.Collection[flow.Symbol]) -> typing.Iterable[Delayed]:
        """Construct the linked task graph in Dask representation.

        Args:
            symbols: Internal DAG representation in form of the compiled symbols.

        Returns:
            Leaf nodes of the constructed DAG.
        """

        def link(leaf: flow.Instruction) -> Delayed:
            """Recursive linking of the given leaf to its upstream branch.

            Args:
                leaf: The leaf node to be linked upstream.

            Returns:
                The leaf node linked to its upstream branch.
            """
            # Memoize each instruction so shared upstream nodes are linked only once.
            if leaf not in branches:
                branches[leaf] = dask.delayed(leaf, pure=True, traverse=False)(
                    *(link(a) for a in args.get(leaf, []))
                )
            return branches[leaf]

        args: typing.Mapping[flow.Instruction, typing.Sequence[flow.Instruction]] = dict(symbols)
        assert len(args) == len(symbols), 'Duplicated symbols in DAG sequence'
        # Leaves are the instructions that never appear as anyone's upstream argument.
        leaves = set(args).difference(p for a in args.values() for p in a)
        assert leaves, 'Not acyclic'
        branches: dict[flow.Instruction, Delayed] = {}
        return (link(d) for d in leaves)

    @classmethod
    def run(cls, symbols: typing.Collection[flow.Symbol], **kwargs) -> None:
        """Execute the compiled DAG using the configured Dask scheduler.

        Args:
            symbols: Internal DAG representation in form of the compiled symbols.
            kwargs: Unused; accepted for interface compatibility.
        """
        # Unpack the leaf nodes explicitly: dask.compute only traverses
        # list/tuple/set/dict arguments, so passing the generator from _mkjob as a
        # single argument would treat it as a literal and never execute the graph.
        dask.compute(*cls._mkjob(symbols))