# Source code for covalent._workflow.electron

# Copyright 2021 Agnostiq Inc.
# This file is part of Covalent.
# Licensed under the GNU Affero General Public License 3.0 (the "License").
# A copy of the License may be obtained with this software package or at
#
#      https://www.gnu.org/licenses/agpl-3.0.en.html
#
# Use of this file is prohibited except in compliance with the License. Any
# modifications or derivative works of this file must retain this copyright
# notice, and modified files must contain a notice indicating that they have
# been altered from the originals.
# Covalent is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details.
# Relief from the License may be granted by purchasing a commercial license.

"""Class corresponding to computation nodes."""

import inspect
import operator
from builtins import list
from dataclasses import asdict
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union

from .._file_transfer.enums import Order
from .._file_transfer.file_transfer import FileTransfer
from .._shared_files import logger
from .._shared_files.context_managers import active_lattice_manager
from .._shared_files.defaults import (
from .._shared_files.utils import (
from .depsbash import DepsBash
from .depscall import RESERVED_RETVAL_KEY__FILES, DepsCall
from .depspip import DepsPip
from .lattice import Lattice
from .transport import TransportableObject, encode_metadata

# Metadata keys that Electron.__call__ never back-fills from the lattice or
# from the defaults, even when the electron leaves them empty.
consumable_constraints = ["budget", "time_limit"]

# Default electron metadata as a plain dict. Used as the fallback in
# Electron.__call__ and as the base metadata of auto-generated helper nodes.
DEFAULT_METADATA_VALUES = asdict(DefaultMetadataValues())

# These names are only referenced in string annotations, so import them for
# static type checkers only (avoids runtime import cycles).
if TYPE_CHECKING:
    from ..executor import BaseExecutor
    from .transport import _TransportGraph

app_log = logger.app_log
log_stack_info = logger.log_stack_info

class Electron:
    """
    An electron (or task) object that is a modular component of a
    work flow and is returned by :obj:`electron <covalent.electron>`.

    Attributes:
        function: Function to be executed.
        node_id: Node id of the electron in the active lattice's transport graph.
        metadata: Metadata to be used for the function execution.
    """

    def __init__(
        self, function: Callable, node_id: Optional[int] = None, metadata: Optional[dict] = None
    ) -> None:
        if metadata is None:
            metadata = {}

        self.function = function
        self.node_id = node_id
        self.metadata = metadata

    def set_metadata(self, name: str, value: Any) -> None:
        """
        Function to add/edit metadata of given name and value
        to electron's metadata.

        Args:
            name: Name of the metadata to be added/edited.
            value: Value of the metadata to be added/edited.

        Returns:
            None
        """

        self.metadata[name] = value

    def get_metadata(self, name: str) -> Any:
        """
        Get value of the metadata of given name.

        Args:
            name: Name of the metadata whose value is needed.

        Returns:
            value: Value of the metadata of given name.

        Raises:
            KeyError: If metadata of given name is not present.
        """

        return self.metadata[name]

    def get_op_function(
        self, operand_1: Union[Any, "Electron"], operand_2: Union[Any, "Electron"], op: str
    ) -> "Electron":
        """
        Function to handle binary operations with electrons as operands.
        This will not execute the operation but rather create another electron
        which will be postponed to be executed according to the default electron
        configuration/metadata.

        This also makes sure that if these operations are being performed outside
        of a lattice, then they are performed as is.

        Args:
            operand_1: First (left) operand of the binary operation.
            operand_2: Second (right) operand of the binary operation.
            op: Operator to be used in the binary operation: one of "+", "-", "*", "/".

        Returns:
            electron: Electron object corresponding to the operation execution.
                      Behaves as a normal function call if outside a lattice.
        """

        op_table = {
            "+": operator.add,
            "-": operator.sub,
            "*": operator.mul,
            "/": operator.truediv,
        }

        def rename(op1: Any, op: str, op2: Any) -> Callable:
            """
            Decorator to rename a function according to the operation being performed.

            Args:
                op1: First operand
                op: Operator
                op2: Second operand

            Returns:
                function: Renamed decorated function.
            """

            def decorator(f):
                # Electrons are named after their wrapped function; plain values as-is.
                op1_name = op1
                if hasattr(op1, "function") and op1.function:
                    op1_name = op1.function.__name__

                op2_name = op2
                if hasattr(op2, "function") and op2.function:
                    op2_name = op2.function.__name__

                f.__name__ = f"{op1_name}_{op}_{op2_name}"
                return f

            return decorator

        @electron
        @rename(operand_1, op, operand_2)
        def func_for_op(arg_1: Union[Any, "Electron"], arg_2: Union[Any, "Electron"]) -> Any:
            """
            Intermediate function for the binary operation.

            Args:
                arg_1: First operand
                arg_2: Second operand

            Returns:
                result: Result of the binary operation.
            """

            return op_table[op](arg_1, arg_2)

        return func_for_op(arg_1=operand_1, arg_2=operand_2)

    def __add__(self, other):
        return self.get_op_function(self, other, "+")

    def __radd__(self, other):
        # Reflected operator: evaluates ``other + self``. ``other`` must stay the
        # left operand so non-commutative ``+`` (str/list concatenation) is kept.
        return self.get_op_function(other, self, "+")

    def __sub__(self, other):
        return self.get_op_function(self, other, "-")

    def __rsub__(self, other):
        return self.get_op_function(other, self, "-")

    def __mul__(self, other):
        return self.get_op_function(self, other, "*")

    def __rmul__(self, other):
        # Reflected operator: evaluates ``other * self`` with operand order preserved.
        return self.get_op_function(other, self, "*")

    def __truediv__(self, other):
        return self.get_op_function(self, other, "/")

    def __rtruediv__(self, other):
        return self.get_op_function(other, self, "/")

    # Numeric conversions return neutral placeholders (0, 0.0, 0j); the
    # electron's actual result is not available from here.
    def __int__(self):
        return int()

    def __float__(self):
        return float()

    def __complex__(self):
        return complex()

    def __iter__(self):
        """
        Support tuple-unpacking of an electron's result inside a lattice.

        The caller's bytecode is inspected to determine how many targets are
        being unpacked (presumably the argument of the unpacking instruction —
        this relies on CPython bytecode layout), and one ``[i]`` item electron
        is yielded per target. Yields nothing when fewer than two targets are
        expected or when no lattice is active.
        """
        last_frame = inspect.currentframe().f_back
        bytecode = last_frame.f_code.co_code
        expected_unpack_values = bytecode[last_frame.f_lasti + 1]

        if expected_unpack_values < 2:
            return

        for i in range(expected_unpack_values):
            if active_lattice := active_lattice_manager.get_active_lattice():
                try:
                    node_name = prefix_separator + self.function.__name__ + "()" + f"[{i}]"

                except AttributeError:
                    # The case when nested iter calls are made on the same electron
                    node_name = prefix_separator + active_lattice.transport_graph.get_node_value(
                        self.node_id, "name"
                    )
                    node_name += f"[{i}]"

                def get_item(e, key):
                    return e[key]

                get_item.__name__ = node_name

                iterable_metadata = self.metadata.copy()

                # File-transfer pre-hooks inject the "files" kwarg; the item
                # accessor does not take it, so drop those call_before deps.
                filtered_call_before = []
                for elem in iterable_metadata["call_before"]:
                    if elem["attributes"]["retval_keyword"] != "files":
                        filtered_call_before.append(elem)

                iterable_metadata["call_before"] = filtered_call_before

                get_item_electron = Electron(function=get_item, metadata=iterable_metadata)
                yield get_item_electron(self, i)

    def __getattr__(self, attr: str) -> "Electron":
        """
        Inside a lattice, turn attribute access on an electron into a new
        electron that performs ``getattr`` on the result at execution time.

        Raises:
            AttributeError: For dunder names, for ``keys``, and whenever no
                lattice is active.
        """
        # This is to handle the cases where magic functions are attempted
        # to be accessed. For example, in the case of pickling, sometimes
        # __getstate__ is called and we don't want to return an electron
        # object in that case.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

        if attr == "keys":
            raise AttributeError(
                "`keys` attribute should not be used in Electron objects due to conflict with `dict.keys`",
                "Please change the name of the attribute you want to use.",
            )

        if active_lattice_manager.get_active_lattice():

            def get_attr(e, attr):
                return getattr(e, attr)

            get_attr.__name__ = prefix_separator + self.function.__name__ + ".__getattr__"
            get_attr_electron = Electron(function=get_attr, metadata=self.metadata.copy())
            return get_attr_electron(self, attr)

        # ``object`` defines no ``__getattr__``, so report the missing attribute directly.
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __getitem__(self, key: Union[int, str]) -> "Electron":
        """
        Inside a lattice, turn indexing of an electron into a new electron that
        performs ``e[key]`` at execution time.

        Raises:
            StopIteration: When no lattice is active.
        """
        if active_lattice_manager.get_active_lattice():

            def get_item(e, key):
                return e[key]

            get_item.__name__ = prefix_separator + self.function.__name__ + ".__getitem__"

            get_item_electron = Electron(function=get_item, metadata=self.metadata.copy())
            return get_item_electron(self, key)

        raise StopIteration

    def __call__(self, *args, **kwargs) -> Union[Any, "Electron"]:
        """
        Function to execute the electron.

        This behaves differently if the execution call is made inside a lattice
        and just adds the electron as a node to the lattice's transport graph.
        If the execution call is made outside of a lattice, then it executes
        the electron as a normal function call.

        Also contains a postprocessing part where the lattice's function is
        executed after all the nodes in the lattice's transport graph are
        executed. Then the execution call to the electron is replaced by its
        corresponding result, consumed in call order from
        ``active_lattice.electron_outputs``.
        """

        # Check if inside a lattice and if not, perform a direct invocation of the function
        active_lattice = active_lattice_manager.get_active_lattice()
        if active_lattice is None:
            return self.function(*args, **kwargs)

        if active_lattice.post_processing:
            _, output = active_lattice.electron_outputs.pop(0)
            return output.get_deserialized()

        # Setting metadata for default values according to lattice's metadata.
        for k in self.metadata:
            if (
                k not in consumable_constraints
                and k in DEFAULT_METADATA_VALUES
                and not self.get_metadata(k)
            ):
                meta = active_lattice.get_metadata(k)
                if not meta:
                    meta = DEFAULT_METADATA_VALUES[k]

                self.set_metadata(k, meta)

        # Add a node to the transport graph of the active lattice
        self.node_id = active_lattice.transport_graph.add_node(
            name=sublattice_prefix + self.function.__name__
            if isinstance(self.function, Lattice)
            else self.function.__name__,
            function=self.function,
            metadata=self.metadata.copy(),
            function_string=get_serialized_function_str(self.function),
        )

        if self.function:
            named_args, named_kwargs = get_named_params(self.function, args, kwargs)

            # For positional arguments
            # We use the fact that as of Python 3.6, dict order == insertion order
            for arg_index, (key, value) in enumerate(named_args.items()):
                self.connect_node_with_others(
                    self.node_id, key, value, "arg", arg_index, active_lattice.transport_graph
                )

            # For keyword arguments
            # Filter out kwargs to be injected by call_before calldeps at execution.
            # An electron constructed without metadata has no "call_before" entry.
            call_before = self.metadata.get("call_before", [])
            retval_keywords = {item["attributes"]["retval_keyword"] for item in call_before}
            for key, value in named_kwargs.items():
                if key in retval_keywords:
                    app_log.debug(
                        f"kwarg {key} for function {self.function.__name__} to be injected at runtime"
                    )
                    continue

                self.connect_node_with_others(
                    self.node_id, key, value, "kwarg", None, active_lattice.transport_graph
                )

        return Electron(
            self.function,
            metadata=self.metadata,
            node_id=self.node_id,
        )

    def _collection_metadata(self) -> dict:
        """Default metadata for auto-generated helper nodes, inheriting this electron's executor."""
        metadata = encode_metadata(DEFAULT_METADATA_VALUES.copy())
        if "executor" in self.metadata:
            metadata["executor"] = self.metadata["executor"]
            metadata["executor_data"] = self.metadata["executor_data"]
        return metadata

    def connect_node_with_others(
        self,
        node_id: int,
        param_name: str,
        param_value: Union[Any, "Electron"],
        param_type: str,
        arg_index: Optional[int],
        transport_graph: "_TransportGraph",
    ):
        """
        Adds a node along with connecting edges for all the arguments to the electron.

        Electron values are linked directly; list and dict values are wrapped in
        an auto-generated collection electron; any other value becomes a
        parameter node holding its transportable encoding.

        Args:
            node_id: Node number of the electron
            param_name: Name of the parameter
            param_value: Value of the parameter
            param_type: Type of parameter, positional ("arg") or keyword ("kwarg")
            arg_index: Position of a positional argument; None for keyword arguments
            transport_graph: Transport graph of the lattice

        Returns:
            None
        """

        if isinstance(param_value, Electron):
            transport_graph.add_edge(
                param_value.node_id,
                node_id,
                edge_name=param_name,
                param_type=param_type,
                arg_index=arg_index,
            )

        elif isinstance(param_value, list):

            def _auto_list_node(*args, **kwargs):
                return list(args)

            list_electron = Electron(
                function=_auto_list_node, metadata=self._collection_metadata()
            )
            bound_electron = list_electron(*param_value)
            transport_graph.set_node_value(bound_electron.node_id, "name", electron_list_prefix)
            transport_graph.add_edge(
                list_electron.node_id,
                node_id,
                edge_name=param_name,
                param_type=param_type,
                arg_index=arg_index,
            )

        elif isinstance(param_value, dict):

            def _auto_dict_node(*args, **kwargs):
                return dict(kwargs)

            dict_electron = Electron(
                function=_auto_dict_node, metadata=self._collection_metadata()
            )
            bound_electron = dict_electron(**param_value)
            transport_graph.set_node_value(bound_electron.node_id, "name", electron_dict_prefix)
            transport_graph.add_edge(
                dict_electron.node_id,
                node_id,
                edge_name=param_name,
                param_type=param_type,
                arg_index=arg_index,
            )

        else:
            encoded_param_value = TransportableObject.make_transportable(param_value)
            parameter_node = transport_graph.add_node(
                name=parameter_prefix + str(param_value),
                function=None,
                metadata=encode_metadata(DEFAULT_METADATA_VALUES.copy()),
                value=encoded_param_value,
            )
            transport_graph.add_edge(
                parameter_node,
                node_id,
                edge_name=param_name,
                param_type=param_type,
                arg_index=arg_index,
            )

    def add_collection_node_to_graph(self, graph: "_TransportGraph", prefix: str) -> int:
        """
        Adds the node to lattice's transport graph in the case
        where a collection of electrons is passed as an argument
        to another electron.

        Args:
            graph: Transport graph of the lattice
            prefix: Prefix of the node

        Returns:
            node_id: Node id of the added node
        """

        node_id = graph.add_node(
            name=prefix,
            function=to_decoded_electron_collection,
            metadata=self._collection_metadata(),
            function_string=get_serialized_function_str(to_decoded_electron_collection),
        )

        return node_id

    def wait_for(self, electrons: Union["Electron", Iterable["Electron"]]):
        """
        Waits for the given electrons to complete before executing this one.
        Adds the necessary edges between this and those electrons without explicitly
        connecting their inputs/outputs.

        Useful when execution of this electron relies on a side-effect from the another one.

        Must be called while a lattice is active (it uses the active lattice's
        transport graph).

        Args:
            electrons: Electron(s) which will be waited for to complete execution
                       before starting execution for this one

        Returns:
            Electron
        """

        active_lattice = active_lattice_manager.get_active_lattice()

        # Just using list(electrons) will not work since we are overriding the __iter__
        # method for an Electron which results in it essentially disappearing, thus using
        # [electrons] to create the list if there's a single electron
        electrons = [electrons] if isinstance(electrons, Electron) else list(electrons)

        for el in electrons:
            active_lattice.transport_graph.add_edge(
                el.node_id,
                self.node_id,
                edge_name=WAIT_EDGE_NAME,
                wait_for=True,
            )

        return Electron(
            self.function,
            metadata=self.metadata,
            node_id=self.node_id,
        )

    @property
    def as_transportable_dict(self) -> Dict:
        """Get transportable electron object and metadata."""
        return {
            "name": self.function.__name__,
            "function": TransportableObject(self.function).to_dict(),
            "function_string": get_serialized_function_str(self.function),
            "metadata": filter_null_metadata(self.metadata),
        }
def electron(
    _func: Optional[Callable] = None,
    *,
    backend: Optional[str] = None,
    executor: Optional[
        Union[List[Union[str, "BaseExecutor"]], Union[str, "BaseExecutor"]]
    ] = None,
    # Add custom metadata fields here
    files: Optional[List[FileTransfer]] = None,
    deps_bash: Union[DepsBash, List, str] = None,
    deps_pip: Union[DepsPip, list] = None,
    call_before: Union[List[DepsCall], DepsCall, None] = None,
    call_after: Union[List[DepsCall], DepsCall, None] = None,
) -> Callable:
    """Electron decorator to be called upon a function. Returns the wrapper function with the same functionality as `_func`.

    Args:
        _func: function to be decorated

    Keyword Args:
        backend: DEPRECATED: Same as `executor`.
        executor: Alternative executor object to be used by the electron execution. If not passed, the dask
            executor is used by default.
        deps_bash: An optional DepsBash object (or a list/str of shell commands) specifying shell commands
            to run before `_func`
        deps_pip: An optional DepsPip object (or a list of package names) specifying PyPI packages to
            install before running `_func`
        call_before: An optional DepsCall or list of DepsCall objects specifying python functions to invoke
            before the electron
        call_after: An optional DepsCall or list of DepsCall objects specifying python functions to invoke
            after the electron
        files: An optional list of FileTransfer objects which copy files to/from remote or local filesystems.

    Returns:
        :obj:`Electron <covalent._workflow.electron.Electron>` : Electron object inside which the decorated function exists.
    """

    if backend:
        # No ``exc_info`` here: passing an exception class makes logging fall back to
        # ``sys.exc_info()`` and append a spurious "NoneType: None" traceback.
        app_log.warning(
            "backend is deprecated and will be removed in a future release. Please use executor keyword instead."
        )
        executor = backend

    deps = {}

    if isinstance(deps_bash, DepsBash):
        deps["bash"] = deps_bash
    elif isinstance(deps_bash, (list, str)):
        deps["bash"] = DepsBash(commands=deps_bash)

    if isinstance(deps_pip, DepsPip):
        deps["pip"] = deps_pip
    elif isinstance(deps_pip, list):
        deps["pip"] = DepsPip(packages=deps_pip)

    internal_call_before_deps = []
    internal_call_after_deps = []

    for file_transfer in files or []:
        _file_transfer_pre_hook_, _file_transfer_call_dep_ = file_transfer.cp()

        # pre-file transfer hook to create any necessary temporary files
        internal_call_before_deps.append(
            DepsCall(
                _file_transfer_pre_hook_,
                retval_keyword=RESERVED_RETVAL_KEY__FILES,
                override_reserved_retval_keys=True,
            )
        )

        if file_transfer.order == Order.AFTER:
            internal_call_after_deps.append(DepsCall(_file_transfer_call_dep_))
        else:
            internal_call_before_deps.append(DepsCall(_file_transfer_call_dep_))

    # Normalise user-supplied call deps to lists (None -> [], single DepsCall -> [dep]).
    if call_before is None:
        call_before = []
    elif isinstance(call_before, DepsCall):
        call_before = [call_before]

    if call_after is None:
        call_after = []
    elif isinstance(call_after, DepsCall):
        call_after = [call_after]

    # Internal file-transfer deps run ahead of user-supplied ones.
    call_before = internal_call_before_deps + list(call_before)
    call_after = internal_call_after_deps + list(call_after)

    constraints = {
        "executor": executor,
        "deps": deps,
        "call_before": call_before,
        "call_after": call_after,
    }

    constraints = encode_metadata(constraints)

    def decorator_electron(func=None):
        """Electron decorator function"""

        electron_object = Electron(func)
        for k, v in constraints.items():
            electron_object.set_metadata(k, v)
        electron_object.__doc__ = func.__doc__

        @wraps(func)
        def wrapper(*args, **kwargs):
            return electron_object(*args, **kwargs)

        wrapper.electron_object = electron_object

        return wrapper

    if _func is None:  # decorator is called with arguments
        return decorator_electron
    else:  # decorator is called without arguments
        return decorator_electron(_func)


def wait(child, parents):
    """Instructs Covalent that an electron should wait for some other
    tasks to complete before it is dispatched.

    Args:
        child: the dependent electron
        parents: Electron(s) which must complete before `child` starts

    Returns:
        The waiting electron when a lattice is being built; otherwise `child` unchanged.

    Useful when execution of an electron relies on a side-effect from
    another one.
    """

    active_lattice = active_lattice_manager.get_active_lattice()

    # Only add wait edges while the graph is being built; during post-processing
    # (or outside a lattice) the child is returned untouched.
    if active_lattice and not active_lattice.post_processing:
        return child.wait_for(parents)
    else:
        return child


@electron
def to_decoded_electron_collection(**x):
    """Deserialize a collection of TransportableObjects.

    Takes the first keyword argument's value; if it is a list or dict, returns it
    with its elements deserialized via TransportableObject. Any other type yields None.
    """

    collection = list(x.values())[0]
    if isinstance(collection, list):
        return TransportableObject.deserialize_list(collection)
    elif isinstance(collection, dict):
        return TransportableObject.deserialize_dict(collection)