# Copyright 2023 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import json
from abc import abstractmethod
import requests
from .._dispatcher_plugins import local
from .._shared_files import logger
from .._shared_files.config import get_config
from .._shared_files.util_classes import RESULT_STATUS, Status
# Module-level shortcuts to attributes of the shared `logger` module.
# `app_log` is the logger this module writes its debug messages to.
app_log = logger.app_log
# NOTE(review): `log_stack_info` is not referenced in the code visible here;
# presumably it is kept for importers of this module - confirm before removing.
log_stack_info = logger.log_stack_info
class BaseTrigger:
    """
    Base class to be subclassed by any custom defined trigger.

    Implements all the necessary methods used for interacting with dispatches, including
    getting their statuses and performing a redispatch of them whenever the trigger gets triggered.

    NOTE(review): `observe` and `stop` are decorated with `@abstractmethod`, but this class
    does not derive from `abc.ABC`, so the decorator is not enforced and `BaseTrigger`
    itself can still be instantiated.

    Args:
        lattice_dispatch_id: Dispatch ID of the workflow which has to be redispatched in case this trigger gets triggered
        dispatcher_addr: Address of dispatcher server used to retrieve info about or redispatch any dispatches
        triggers_server_addr: Address of the Triggers server (if there is any) to register this trigger to,
            uses the dispatcher's address by default

    Attributes:
        self.lattice_dispatch_id: Dispatch ID of the workflow which has to be redispatched in case this trigger gets triggered
        self.dispatcher_addr: Address of dispatcher server used to retrieve info about or redispatch any dispatches
        self.triggers_server_addr: Address of the Triggers server (if there is any) to register this trigger to,
            uses the dispatcher's address by default
        self.new_dispatch_ids: List of all the newly created dispatch ids from performing redispatch
        self.observe_blocks: Boolean to indicate whether the `self.observe` method is a blocking call
        self.event_loop: Event loop to be used if directly calling dispatcher's functions instead of the REST APIs
        self.use_internal_funcs: Boolean indicating whether to use dispatcher's functions directly instead of through API calls
        self.stop_flag: To handle stopping mechanism in a thread safe manner in case `self.observe()` is a blocking call (e.g. see TimeTrigger)
    """

    def __init__(
        self,
        lattice_dispatch_id: str = None,
        dispatcher_addr: str = None,
        triggers_server_addr: str = None,
    ):
        self.lattice_dispatch_id = lattice_dispatch_id
        self.dispatcher_addr = dispatcher_addr
        self.triggers_server_addr = triggers_server_addr

        self.new_dispatch_ids = []
        self.observe_blocks = True

        # Event loop to attach when directly using the dispatcher's functions.
        self.event_loop = None

        # Whether to use the dispatcher's functions directly instead of through API calls.
        self.use_internal_funcs = True

        # Handles the stopping mechanism in a thread safe manner in case
        # `observe()` is a blocking call (e.g. see TimeTrigger).
        self.stop_flag = None

    def register(self) -> None:
        """
        Register this trigger to the Triggers server and start observing.
        """

        self._register(self.to_dict())

    @staticmethod
    def _register(trigger_data) -> None:
        """
        Register a trigger to the Triggers server given only its dictionary format and start observing.

        Args:
            trigger_data: Dictionary representation of a trigger

        Raises:
            requests.HTTPError: If the server answers with an error status
                (raised by `raise_for_status()`).
        """

        triggers_server_addr = trigger_data.get("triggers_server_addr")

        if triggers_server_addr is None:
            # Fall back to the dispatcher's configured address and port.
            host = get_config("dispatcher.address")
            port = get_config("dispatcher.port")
            triggers_server_addr = f"{host}:{port}"

        register_trigger_url = f"http://{triggers_server_addr}/api/triggers/register"

        # NOTE(review): no `timeout` is passed, so this call waits for the server indefinitely.
        r = requests.post(register_trigger_url, json=trigger_data)
        r.raise_for_status()

    def _get_status(self) -> "Status":
        """
        Get status about the connected dispatch id to check whether it's a pending
        dispatch or new redispatch has to be made.

        Returns:
            status: Status
        """

        if self.use_internal_funcs:
            # Imported lazily so the dispatcher package is only needed on this path.
            from covalent_dispatcher._service.app import export_result

            # Schedule the coroutine on the attached event loop and block until it completes.
            response = asyncio.run_coroutine_threadsafe(
                export_result(self.lattice_dispatch_id, status_only=True),
                self.event_loop,
            ).result()

            if isinstance(response, dict):
                return response["status"]

            # Otherwise the response carries a JSON payload in `.body`.
            return json.loads(response.body.decode()).get("status")

        from .. import get_result

        return get_result(
            self.lattice_dispatch_id, status_only=True, dispatcher_addr=self.dispatcher_addr
        )["status"]

    def _do_redispatch(self, is_pending: bool = False) -> str:
        """
        Perform a redispatch of the connected dispatch id and return a new one.

        Args:
            is_pending: Whether the connected dispatch id is pending

        Returns:
            new_dispatch_id: Dispatch id of the newly dispatched workflow
        """

        if is_pending:
            return local.LocalDispatcher.start(self.lattice_dispatch_id, self.dispatcher_addr)

        # `redispatch` returns a callable; calling it without arguments performs the redispatch.
        return local.LocalDispatcher.redispatch(
            self.lattice_dispatch_id, self.dispatcher_addr
        )()

    def trigger(self) -> None:
        """
        Trigger this trigger and perform a redispatch of the connected dispatch id's workflow.
        Should be called within `self.observe()` whenever a trigger action is desired.

        Raises:
            RuntimeError: In case no dispatch id is connected to this trigger
        """

        if not self.lattice_dispatch_id:
            raise RuntimeError(
                "`lattice_dispatch_id` is None. Please attach this trigger to a lattice before triggering."
            )

        status = self._get_status()

        if status == str(RESULT_STATUS.NEW_OBJECT) or status is None:
            # To continue the pending dispatch
            same_dispatch_id = self._do_redispatch(True)
            app_log.debug(f"Initiating run for pending dispatch_id: {same_dispatch_id}")
        else:
            # To run new redispatch
            new_dispatch_id = self._do_redispatch()
            app_log.debug(f"Redispatching, new dispatch_id: {new_dispatch_id}")
            # Only ids created by a redispatch are recorded.
            self.new_dispatch_ids.append(new_dispatch_id)

    def to_dict(self) -> dict:
        """
        Return a dictionary representation of this trigger which can later be used to regenerate it.

        The result is a shallow copy of the instance attributes plus a `"name"` entry
        holding the class name.

        Returns:
            tr_dict: Dictionary representation of this trigger
        """

        tr_dict = self.__dict__.copy()
        tr_dict["name"] = str(self.__class__.__name__)
        return tr_dict

    @abstractmethod
    def observe(self):
        """
        Start observing for any change which can be used to trigger this trigger.
        To be implemented by the subclass.
        """

        pass

    @abstractmethod
    def stop(self):
        """
        Stop observing for changes.
        To be implemented by the subclass.
        """

        pass