Source code for forml.provider.runner.spark
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Spark runner.
"""
import logging
import typing

import pyspark

from forml import flow, runtime

if typing.TYPE_CHECKING:
    from forml import io
    from forml.io import asset

LOGGER = logging.getLogger(__name__)
class Runner(runtime.Runner, alias='spark'):
"""ForML runner utilizing :doc:`Apache Spark <pyspark:index>` as a distributed executor.
Args:
kwargs: Any `Spark Configuration options
<https://spark.apache.org/docs/latest/configuration.html>`_.
The provider can be enabled using the following :ref:`platform configuration <platform-config>`:
.. code-block:: toml
:caption: config.toml
[RUNNER.compute]
provider = "spark"
"spark.driver.cores" = 1
"spark.driver.memory" = "1g"
"spark.executor.cores" = 2
"spark.executor.memory" = "1g"
"spark.executor.pyspark.memory" = "1g"
Important:
Select the ``spark`` :ref:`extras to install <install-extras>` ForML together with the Spark
support.
Note:
ForML uses Spark purely as an *executor* without any deeper integration with its robust data
management API.
"""
    DEFAULTS = {'spark.app.name': 'ForML'}
    def __init__(
        self,
        instance: typing.Optional['asset.Instance'] = None,
        feed: typing.Optional['io.Feed'] = None,
        sink: typing.Optional['io.Sink'] = None,
        **kwargs,
    ):
        super().__init__(instance, feed, sink)
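        # Merge the class-level DEFAULTS with any user-supplied Spark options; user values
        # take precedence due to the dict-union order (DEFAULTS | kwargs).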
        self._config: pyspark.SparkConf = pyspark.SparkConf().setAll((self.DEFAULTS | kwargs).items())
        self._context: typing.Optional[pyspark.SparkContext] = None
    def start(self) -> None:
        self._context = pyspark.SparkContext.getOrCreate(self._config)
    def close(self) -> None:
        self._context.stop()
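        # Stopping the context alone is not enough for a clean restart; the Py4J gateway and
        # the SparkContext class-level references are torn down explicitly as well, presumably
        # so that a subsequent start() can bootstrap a completely fresh JVM gateway.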
        # pylint: disable=protected-access
        self._context._gateway.shutdown()
        self._context._gateway.proc.kill()
        pyspark.SparkContext._jvm = None
        pyspark.SparkContext._active_spark_context = None
        pyspark.SparkContext._gateway = None
        self._context = None
    @staticmethod
    def _submit(spark: pyspark.SparkContext, symbols: typing.Collection[flow.Symbol]) -> typing.Iterable[pyspark.RDD]:
        """Build and submit the task graph in Spark representation.

        Args:
            spark: Spark context to submit the tasks through.
            symbols: Internal DAG representation in form of the compiled symbols.

        Returns:
            Leaf nodes of the constructed DAG.
        """
        def apply(instruction: flow.Instruction, *args: pyspark.RDD) -> pyspark.RDD:
            """Perform the instruction using the given RDDs as arguments.

            Args:
                instruction: Flow instruction to be performed.
                *args: RDDs to be used as arguments.

            Returns:
                Result in form of an RDD.
            """
            if not args:
                return spark.parallelize([instruction()])
            if len(args) == 1:
                return args[0].map(instruction)
            return spark.parallelize([instruction(*(a.collect()[0] for a in args))])
        def link(leaf: flow.Instruction) -> pyspark.RDD:
            """Recursively link the given leaf to its upstream branch.

            Args:
                leaf: The leaf node to be linked upstream.

            Returns:
                The leaf node linked to its upstream branch.
            """
            if leaf not in nodes:
                nodes[leaf] = apply(leaf, *(link(a) for a in arguments.get(leaf, [])))
            else:
                nodes[leaf].cache()
            return nodes[leaf]
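        # Each symbol maps an instruction to the sequence of its upstream (argument) instructions;
        # the leaves are the instructions that never appear as an argument of any other instruction.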
        arguments: typing.Mapping[flow.Instruction, typing.Sequence[flow.Instruction]] = dict(symbols)
        assert len(arguments) == len(symbols), 'Duplicated symbols in DAG sequence'
        leaves = set(arguments).difference(p for a in arguments.values() for p in a)
        assert leaves, 'Not acyclic'
        nodes: dict[flow.Instruction, pyspark.RDD] = {}
        return (link(d) for d in leaves)
    @classmethod
    def run(cls, symbols: typing.Collection[flow.Symbol], **kwargs) -> None:
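        # Build the Spark DAG from the compiled symbols and force its evaluation by collecting
        # every leaf RDD; the collected values themselves are discarded.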
        for result in cls._submit(pyspark.SparkContext.getOrCreate(), symbols):
            result.collect()