Source code for covalent._workflow.transport

# Copyright 2021 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the GNU Affero General Public License 3.0 (the "License").
# A copy of the License may be obtained with this software package or at
#
#      https://www.gnu.org/licenses/agpl-3.0.en.html
#
# Use of this file is prohibited except in compliance with the License. Any
# modifications or derivative works of this file must retain this copyright
# notice, and modified files must contain a notice indicating that they have
# been altered from the originals.
#
# Covalent is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details.
#
# Relief from the License may be granted by purchasing a commercial license.

"""Class implementation of the transport graph in the workflow graph."""

import base64
import json
import platform
from copy import deepcopy
from typing import Any, Callable, Dict

import cloudpickle
import networkx as nx

from .._shared_files.defaults import parameter_prefix
from .._shared_files.util_classes import RESULT_STATUS


class TransportableObject:
    """
    A function is converted to a transportable object by serializing it using cloudpickle
    and then whenever executing it, the transportable object is deserialized. The object
    will also contain additional info like the python version used to serialize it.

    Attributes:
        _object: The serialized object.
        python_version: The python version used on the client's machine.
    """

    def __init__(self, obj: Any) -> None:
        self._object = base64.b64encode(cloudpickle.dumps(obj)).decode("utf-8")
        self.python_version = platform.python_version()

        self.object_string = str(obj)

        try:
            self._json = json.dumps(obj)

        except TypeError as ex:
            self._json = ""

        self.attrs = {"doc": getattr(obj, "__doc__", ""), "name": getattr(obj, "__name__", "")}

    def __eq__(self, obj) -> bool:
        return self.__dict__ == obj.__dict__ if isinstance(obj, TransportableObject) else False

[docs] def get_deserialized(self) -> Callable: """ Get the deserialized transportable object. Args: None Returns: function: The deserialized object/callable function. """ return cloudpickle.loads(base64.b64decode(self._object.encode("utf-8")))
@property def json(self): return self._json
[docs] def to_dict(self) -> dict: """Return a JSON-serializable dictionary representation of self""" return {"type": "TransportableObject", "attributes": self.__dict__.copy()}
[docs] @staticmethod def from_dict(object_dict) -> "TransportableObject": """Rehydrate a dictionary representation Args: object_dict: a dictionary representation returned by `to_dict` Returns: A `TransportableObject` represented by `object_dict` """ sc = TransportableObject(None) sc.__dict__ = object_dict["attributes"] return sc
[docs] def get_serialized(self) -> str: """ Get the serialized transportable object. Args: None Returns: object: The serialized transportable object. """ return self._object
[docs] def serialize(self) -> bytes: """ Serialize the transportable object. Args: None Returns: pickled_object: The serialized object alongwith the python version. """ return cloudpickle.dumps( { "object": self.get_serialized(), "object_string": self.object_string, "json": self._json, "attrs": self.attrs, "py_version": self.python_version, } )
[docs] def serialize_to_json(self) -> str: """ Serialize the transportable object to JSON. Args: None Returns: A JSON string representation of the transportable object """ return json.dumps(self.to_dict())
[docs] @staticmethod def deserialize_from_json(json_string: str) -> str: """ Reconstruct a transportable object from JSON Args: json_string: A JSON string representation of a TransportableObject Returns: A TransportableObject instance """ object_dict = json.loads(json_string) return TransportableObject.from_dict(object_dict)
[docs] @staticmethod def make_transportable(obj) -> "TransportableObject": if isinstance(obj, TransportableObject): return obj else: return TransportableObject(obj)
[docs] @staticmethod def deserialize(data: bytes) -> "TransportableObject": """ Deserialize the transportable object. Args: data: Cloudpickled function. Returns: object: The deserialized transportable object. """ obj = cloudpickle.loads(data) sc = TransportableObject(None) sc._object = obj["object"] sc._json = obj["json"] sc.attrs = obj["attrs"] sc.python_version = obj["py_version"] return sc
[docs] @staticmethod def deserialize_list(collection: list) -> list: """ Recursively deserializes a list of TransportableObjects. More precisely, `collection` is a list, each of whose entries is assumed to be either a `TransportableObject`, a list, or dict` """ new_list = [] for item in collection: if isinstance(item, TransportableObject): new_list.append(item.get_deserialized()) elif isinstance(item, list): new_list.append(TransportableObject.deserialize_list(item)) elif isinstance(item, dict): new_list.append(TransportableObject.deserialize_dict(item)) else: raise TypeError("Couldn't deserialize collection") return new_list
[docs] @staticmethod def deserialize_dict(collection: dict) -> dict: """ Recursively deserializes a dict of TransportableObjects. More precisely, `collection` is a dict, each of whose entries is assumed to be either a `TransportableObject`, a list, or dict` """ new_dict = {} for k, item in collection.items(): if isinstance(item, TransportableObject): new_dict[k] = item.get_deserialized() elif isinstance(item, list): new_dict[k] = TransportableObject.deserialize_list(item) elif isinstance(item, dict): new_dict[k] = TransportableObject.deserialize_dict(item) else: raise TypeError("Couldn't deserialize collection") return new_dict
# Functions for encoding the transport graph def encode_metadata(metadata: dict) -> dict: # Idempotent # Special handling required for: executor, workflow_executor, deps, call_before/after, triggers encoded_metadata = deepcopy(metadata) if "executor" in metadata: if "executor_data" not in metadata: encoded_metadata["executor_data"] = {} if metadata["executor"] is not None and not isinstance(metadata["executor"], str): encoded_executor = metadata["executor"].to_dict() encoded_metadata["executor"] = encoded_executor["short_name"] encoded_metadata["executor_data"] = encoded_executor if "workflow_executor" in metadata: if "workflow_executor_data" not in metadata: encoded_metadata["workflow_executor_data"] = {} if metadata["workflow_executor"] is not None and not isinstance( metadata["workflow_executor"], str ): encoded_wf_executor = metadata["workflow_executor"].to_dict() encoded_metadata["workflow_executor"] = encoded_wf_executor["short_name"] encoded_metadata["workflow_executor_data"] = encoded_wf_executor # Bash Deps, Pip Deps, Env Deps, etc if "deps" in metadata and metadata["deps"] is not None: for dep_type, dep_object in metadata["deps"].items(): if dep_object and not isinstance(dep_object, dict): encoded_metadata["deps"][dep_type] = dep_object.to_dict() # call_before/after if "call_before" in metadata and metadata["call_before"] is not None: for i, dep in enumerate(metadata["call_before"]): if not isinstance(dep, dict): encoded_metadata["call_before"][i] = dep.to_dict() if "call_after" in metadata and metadata["call_after"] is not None: for i, dep in enumerate(metadata["call_after"]): if not isinstance(dep, dict): encoded_metadata["call_after"][i] = dep.to_dict() # triggers if "triggers" in metadata: if isinstance(metadata["triggers"], list): encoded_metadata["triggers"] = [] for tr in metadata["triggers"]: if isinstance(tr, dict): encoded_metadata["triggers"].append(tr) else: encoded_metadata["triggers"].append(tr.to_dict()) else: encoded_metadata["triggers"] = metadata["triggers"] return encoded_metadata class _TransportGraph: """ A TransportGraph is the most essential part of the whole workflow. This contains all the information about each electron and lattice required for determining how, when, and where to execute the workflow. The TransportGraph contains a directed graph which is used to determine the execution order of the nodes. Each node in this graph is an electron which is ready to be executed. Attributes: _graph: The directed graph object of type networkx.DiGraph(). lattice_metadata: The lattice metadata of the transport graph. """ def __init__(self) -> None: self._graph = nx.MultiDiGraph() self.lattice_metadata = None # IDs of nodes modified during the workflow run self.dirty_nodes = [] self._default_node_attrs = { "start_time": None, "end_time": None, "status": RESULT_STATUS.NEW_OBJECT, "output": None, "error": None, "sub_dispatch_id": None, "sublattice_result": None, "stdout": None, "stderr": None, } def add_node(self, name: str, function: Callable, metadata: Dict, **attr) -> int: """ Adds a node to the graph. Args: name: The name of the node. function: The function to be executed. metadata: The metadata of the node. attr: Any other attributes that need to be added to the node. Returns: node_key: The node id. """ node_id = len(self._graph.nodes) self._graph.add_node( node_id, name=name, function=TransportableObject(function), metadata=metadata, **attr, ) return node_id def add_edge(self, x: int, y: int, edge_name: Any, **attr) -> None: """ Adds an edge to the graph and assigns a name to it. Edge insertion order is not preserved in networkx. So in case of positional arguments passed into the electron, we need to preserve the order when we deserialize the request in the lattice. Args: x: The node id for first node. y: The node id for second node. edge_name: The name to be assigned to the edge. attr: Any other attributes that need to be added to the edge. Returns: None Raises: ValueError: If the edge already exists. """ self._graph.add_edge(x, y, edge_name=edge_name, **attr) def reset(self) -> None: """ Resets the graph. Args: None Returns: None """ self._graph = nx.MultiDiGraph() def get_node_value(self, node_key: int, value_key: str) -> Any: """ Get a specific value from a node depending upon the value key. Args: node_key: The node id. value_key: The value key. Returns: value: The value from the node stored at the value key. Raises: KeyError: If the value key or node key is not found. """ return self._graph.nodes[node_key][value_key] def set_node_value(self, node_key: int, value_key: int, value: Any) -> None: """ Set a certain value of a node. This allows for saving custom data in the graph nodes. Args: node_key: The node id. value_key: The value key. value: The value to be set at value_key position of the node. Returns: None Raises: KeyError: If the node key is not found. """ self.dirty_nodes.append(node_key) self._graph.nodes[node_key][value_key] = value def get_edge_data(self, dep_key: int, node_key: int) -> Any: """ Get the metadata for all edges between two nodes. Args: dep_key: The node id for first node. node_key: The node id for second node. Returns: values: A dict {edge_key : value} Raises: KeyError: If the edge is not found. """ return self._graph.get_edge_data(dep_key, node_key) def get_dependencies(self, node_key: int) -> list: """ Gets the parent node ids of a node. Args: node_key: The node id. Returns: parents: The dependencies of the node. """ return list(self._graph.predecessors(node_key)) def get_internal_graph_copy(self) -> nx.MultiDiGraph: """ Get a copy of the internal directed graph to avoid modifying the original graph. Args: None Returns: graph: A copy of the internal directed graph. """ return self._graph.copy() def reset_node(self, node_id: int) -> None: """Reset node values to starting state.""" for node_attr, default_val in self._default_node_attrs.items(): self.set_node_value(node_id, node_attr, default_val) def _replace_node(self, node_id: int, new_attrs: Dict[str, Any]) -> None: """Replace node data with new attribute values and flag descendants (used in re-dispatching).""" metadata = self.get_node_value(node_id, "metadata") metadata.update(new_attrs["metadata"]) serialized_callable = TransportableObject.from_dict(new_attrs["function"]) self.set_node_value(node_id, "function", serialized_callable) self.set_node_value(node_id, "function_string", new_attrs["function_string"]) self.set_node_value(node_id, "name", new_attrs["name"]) self._reset_descendants(node_id) def _reset_descendants(self, node_id: int) -> None: """Reset node and all its descendants to starting state.""" try: if self.get_node_value(node_id, "status") == RESULT_STATUS.NEW_OBJECT: return except Exception: return self.reset_node(node_id) for successor in self._graph.neighbors(node_id): self._reset_descendants(successor) def apply_electron_updates(self, electron_updates: Dict[str, Callable]) -> None: """Replace transport graph node data based on the electrons that need to be updated during re-dispatching.""" for n in self._graph.nodes: name = self.get_node_value(n, "name") if name in electron_updates: self._replace_node(n, electron_updates[name]) def serialize(self, metadata_only: bool = False) -> bytes: """ Convert transport graph object to JSON to be used in the workflow scheduler. Convert transport graph networkx.DiGraph object into JSON format, filter out computation specific attributes and lastly add the lattice metadata. This also serializes the function Callable into by base64 encoding the cloudpickled result. Args: metadata_only: If true, only serialize the metadata. Returns: str: json string representation of transport graph """ # Convert networkx.DiGraph to a format that can be converted to json . data = nx.readwrite.node_link_data(self._graph) # process each node for idx, node in enumerate(data["nodes"]): data["nodes"][idx]["function"] = data["nodes"][idx].pop("function").serialize() if metadata_only: parameter_node_id = [ i for i, node in enumerate(data["nodes"]) if node["name"].startswith(parameter_prefix) ] for node in data["nodes"].copy(): if node["id"] in parameter_node_id: data["nodes"].remove(node) # Remove the non-metadata fields such as 'function', 'name', etc from the scheduler workflow input data. for idx, node in enumerate(data["nodes"]): for field in data["nodes"][idx].copy(): if field != "metadata": data["nodes"][idx].pop(field, None) # Remove the non-source-target fields from the scheduler workflow input data. for idx, node in enumerate(data["links"]): for name in data["links"][idx].copy(): if name not in ["source", "target"]: data["links"][idx].pop("edge_name", None) data["lattice_metadata"] = self.lattice_metadata return cloudpickle.dumps(data) def serialize_to_json(self, metadata_only: bool = False) -> str: """ Convert transport graph object to JSON to be used in the workflow scheduler. Convert transport graph networkx.DiGraph object into JSON format, filter out computation specific attributes and lastly add the lattice metadata. This also serializes the function Callable into by base64 encoding the cloudpickled result. Args: metadata_only: If true, only serialize the metadata. Returns: str: json string representation of transport graph Note: serialize_to_json converts metadata objects into dictionary representations. """ # Convert networkx.DiGraph to a format that can be converted to json . data = nx.readwrite.node_link_data(self._graph) # process each node for idx, node in enumerate(data["nodes"]): data["nodes"][idx]["function"] = data["nodes"][idx].pop("function").to_dict() if "value" in node: node["value"] = node["value"].to_dict() if "metadata" in node: node["metadata"] = encode_metadata(node["metadata"]) if metadata_only: parameter_node_id = [ i for i, node in enumerate(data["nodes"]) if node["name"].startswith(parameter_prefix) ] for node in data["nodes"].copy(): if node["id"] in parameter_node_id: data["nodes"].remove(node) # Remove the non-metadata fields such as 'function', 'name', etc from the scheduler workflow input data. for idx, node in enumerate(data["nodes"]): for field in data["nodes"][idx].copy(): if field != "metadata": data["nodes"][idx].pop(field, None) # Remove the non-source-target fields from the scheduler workflow input data. for idx, node in enumerate(data["links"]): for name in data["links"][idx].copy(): if name not in ["source", "target"]: data["links"][idx].pop("edge_name", None) data["lattice_metadata"] = encode_metadata(self.lattice_metadata) return json.dumps(data) def deserialize(self, pickled_data: bytes) -> None: """ Load pickled representation of transport graph into the transport graph instance. This overwrites anything currently set in the transport graph and deserializes the base64 encoded cloudpickled function into a Callable. Args: pickled_data: Cloudpickled representation of the transport graph Returns: None """ node_link_data = cloudpickle.loads(pickled_data) if "lattice_metadata" in node_link_data: self.lattice_metadata = node_link_data["lattice_metadata"] for idx, _ in enumerate(node_link_data["nodes"]): function_ser = node_link_data["nodes"][idx].pop("function") node_link_data["nodes"][idx]["function"] = TransportableObject.deserialize( function_ser ) self._graph = nx.readwrite.node_link_graph(node_link_data) def deserialize_from_json(self, json_data: str) -> None: """Load JSON representation of transport graph into the transport graph instance. This overwrites anything currently set in the transport graph. Note that metadata (node and lattice-level) need to be reconstituted from their dictionary representations when needed. Args: json_data: JSON representation of the transport graph Returns: None """ node_link_data = json.loads(json_data) if "lattice_metadata" in node_link_data: self.lattice_metadata = node_link_data["lattice_metadata"] for idx, node in enumerate(node_link_data["nodes"]): function_ser = node_link_data["nodes"][idx].pop("function") node_link_data["nodes"][idx]["function"] = TransportableObject.from_dict(function_ser) if "value" in node: node["value"] = TransportableObject.from_dict(node["value"]) self._graph = nx.readwrite.node_link_graph(node_link_data)