# Copyright 2021 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the GNU Affero General Public License 3.0 (the "License").
# A copy of the License may be obtained with this software package or at
#
# https://www.gnu.org/licenses/agpl-3.0.en.html
#
# Use of this file is prohibited except in compliance with the License. Any
# modifications or derivative works of this file must retain this copyright
# notice, and modified files must contain a notice indicating that they have
# been altered from the originals.
#
# Covalent is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details.
#
# Relief from the License may be granted by purchasing a commercial license.
"""
Module for defining a Dask executor that submits the input Python function to a Dask cluster,
waits for execution to finish, and then returns the result.
This is a plugin executor module; it is loaded if found and properly structured.
"""
import os
from typing import Callable, Dict, List, Literal
from dask.distributed import CancelledError, Client, Future
from covalent._shared_files import TaskRuntimeError, logger
# Relative imports are not allowed in executor plugins
from covalent._shared_files.config import get_config
from covalent._shared_files.exceptions import TaskCancelledError
from covalent._shared_files.utils import _address_client_mapper
from covalent.executor.base import AsyncBaseExecutor
from covalent.executor.utils.wrappers import io_wrapper as dask_wrapper
# The plugin class name must be given by the executor_plugin_name attribute:
EXECUTOR_PLUGIN_NAME = "DaskExecutor"

# Module-level aliases for the shared logger helpers.
app_log = logger.app_log
log_stack_info = logger.log_stack_info

# Default configuration values for this executor plugin.
_EXECUTOR_PLUGIN_DEFAULTS = {
    "log_stdout": "stdout.log",
    "log_stderr": "stderr.log",
    # Prefer $XDG_CACHE_HOME; otherwise fall back to ~/.cache. expanduser() is
    # used instead of os.environ["HOME"] so that importing this module does not
    # raise KeyError on systems where HOME is not defined.
    "cache_dir": os.path.join(
        os.environ.get("XDG_CACHE_HOME") or os.path.join(os.path.expanduser("~"), ".cache"),
        "covalent",
    ),
}
class DaskExecutor(AsyncBaseExecutor):
    """
    Dask executor class that submits the input function to a running dask cluster.

    Arg(s)
        scheduler_address: Address of the Dask scheduler. When empty, the
            "dask.scheduler_address" config entry is used if it exists.
        log_stdout: Passed unchanged to the AsyncBaseExecutor constructor.
        log_stderr: Passed unchanged to the AsyncBaseExecutor constructor.
        conda_env: Passed unchanged to the AsyncBaseExecutor constructor.
        cache_dir: Cache directory passed to the AsyncBaseExecutor constructor. When empty,
            defaults to "$XDG_CACHE_HOME/covalent", or "~/.cache/covalent" if
            XDG_CACHE_HOME is unset or empty.
        current_env_on_conda_fail: Passed unchanged to the AsyncBaseExecutor constructor.
    """

    def __init__(
        self,
        scheduler_address: str = "",
        log_stdout: str = "stdout.log",
        log_stderr: str = "stderr.log",
        conda_env: str = "",
        cache_dir: str = "",
        current_env_on_conda_fail: bool = False,
    ) -> None:
        if not cache_dir:
            # expanduser() instead of os.environ["HOME"]: the latter raises
            # KeyError when HOME is not defined in the environment.
            cache_dir = os.path.join(
                os.environ.get("XDG_CACHE_HOME") or os.path.join(os.path.expanduser("~"), ".cache"),
                "covalent",
            )

        if not scheduler_address:
            try:
                scheduler_address = get_config("dask.scheduler_address")
            except KeyError:
                # Best effort: the address stays empty and must be set by the caller.
                app_log.debug(
                    "No dask scheduler address found in config. Address must be set manually."
                )

        super().__init__(log_stdout, log_stderr, conda_env, cache_dir, current_env_on_conda_fail)

        self.scheduler_address = scheduler_address

    async def _get_dask_client(self) -> Client:
        """
        Return the shared asynchronous Dask client for this executor's scheduler address.

        A client is created, cached in `_address_client_mapper` and awaited the first
        time an address is seen; later calls reuse the cached client.

        Return(s)
            The Dask client associated with `self.scheduler_address`
        """

        dask_client = _address_client_mapper.get(self.scheduler_address)
        if not dask_client:
            dask_client = Client(address=self.scheduler_address, asynchronous=True)
            _address_client_mapper[self.scheduler_address] = dask_client
            try:
                await dask_client
            except Exception:
                # Evict the client that failed to connect; otherwise every later call
                # would find it in the cache and use it without awaiting a connection.
                _address_client_mapper.pop(self.scheduler_address, None)
                raise

        return dask_client

    async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict):
        """
        Submit the function and inputs to the dask cluster

        Arg(s)
            function: Callable to execute on the cluster
            args: Positional arguments for the function
            kwargs: Keyword arguments for the function
            task_metadata: Metadata associated with the task; must contain "node_id"

        Return(s)
            The result produced by the function

        Raise(s)
            TaskCancelledError: If cancellation was requested before submission, or the
                dask future was cancelled while awaiting it
            TaskRuntimeError: If the wrapped function reported a traceback
        """

        if await self.get_cancel_requested():
            app_log.debug("Task has cancelled")
            raise TaskCancelledError

        node_id = task_metadata["node_id"]
        dask_client = await self._get_dask_client()

        future = dask_client.submit(dask_wrapper, function, args, kwargs)

        # Record the future's key as the job handle so cancel() can refer to it later.
        await self.set_job_handle(future.key)
        app_log.debug(f"Submitted task {node_id} to dask with key {future.key}")

        try:
            result, worker_stdout, worker_stderr, tb = await future
        except CancelledError as err:
            raise TaskCancelledError() from err

        # Relay the output captured on the worker to this task's output streams.
        print(worker_stdout, end="", file=self.task_stdout)
        print(worker_stderr, end="", file=self.task_stderr)

        if tb:
            print(tb, end="", file=self.task_stderr)
            raise TaskRuntimeError(tb)

        return result

    async def cancel(self, task_metadata: Dict, job_handle) -> Literal[True]:
        """
        Cancel the task being executed by the dask executor currently

        Arg(s)
            task_metadata: Metadata associated with the task
            job_handle: Key assigned to the job by Dask

        Return(s)
            True by default
        """

        dask_client = await self._get_dask_client()

        # Build a future from the stored key so the submitted task can be cancelled.
        fut: Future = Future(key=job_handle, client=dask_client)
        await fut.cancel()
        app_log.debug(f"Cancelled future with key {job_handle}")
        return True