Source code for covalent.triggers.database_trigger
# Copyright 2023 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the Apache License 2.0 (the "License"). A copy of the
# License may be obtained with this software package or at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Use of this file is prohibited except in compliance with the License.
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from functools import partial
from threading import Event
from typing import List
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from covalent._shared_files import logger
from .base import BaseTrigger
app_log = logger.app_log
log_stack_info = logger.log_stack_info
[docs]class DatabaseTrigger(BaseTrigger):
"""
Database trigger which can read for changes in a database
and trigger workflows based on record changes.
Args:
db_path: Connection string for the database
table_name: Name of the table to observe
poll_interval: Time in seconds to wait for before reading the database again
where_clauses: List of "WHERE" conditions, e.g. ["id > 2", "status = pending"], to check when
polling the database
trigger_after_n: Number of times the event must happen after which the workflow will be triggered.
e.g value of 2 means workflow will be triggered once the event has occurred twice.
lattice_dispatch_id: Lattice dispatch id of the workflow to be triggered
dispatcher_addr: Address of the dispatcher server
triggers_server_addr: Address of the triggers server
Attributes:
self.db_path: Connection string for the database
self.table_name: Name of the table to observe
self.poll_interval: Time in seconds to wait for before reading the database again
self.where_clauses: List of "WHERE" conditions, e.g. ["id > 2", "status = pending"], to check when
polling the database
self.trigger_after_n: Number of times the event must happen after which the workflow will be triggered.
e.g value of 2 means workflow will be triggered once the event has occurred twice.
self.stop_flag: Thread safe flag used to check whether the stop condition has been met
"""
def __init__(
self,
db_path: str,
table_name: str,
poll_interval: int = 1,
where_clauses: List[str] = None,
trigger_after_n: int = 1,
lattice_dispatch_id: str = None,
dispatcher_addr: str = None,
triggers_server_addr: str = None,
):
super().__init__(lattice_dispatch_id, dispatcher_addr, triggers_server_addr)
self.db_path = db_path
self.table_name = table_name
self.poll_interval = poll_interval
self.where_clauses = where_clauses
self.trigger_after_n = trigger_after_n
self.stop_flag = None
[docs] def observe(self) -> None:
"""
Keep performing the trigger action as long as
where conditions are met or until stop has being called
"""
app_log.debug("Inside DatabaseTrigger's observe")
event_count = 0
try:
self.engine = create_engine(self.db_path)
with Session(self.engine) as db:
sql_poll_cmd = f"SELECT * FROM {self.table_name}"
if self.where_clauses:
sql_poll_cmd += " WHERE "
sql_poll_cmd += " AND ".join(list(self.where_clauses))
sql_poll_cmd += ";"
execute_cmd = partial(db.execute, sql_poll_cmd)
app_log.debug(f"Poll command: {sql_poll_cmd}")
self.stop_flag = Event()
while not self.stop_flag.is_set():
# Read the DB with specified command
try:
app_log.debug("About to execute...")
if rows := execute_cmd().all():
event_count += 1
if event_count == self.trigger_after_n:
app_log.debug("Invoking trigger")
self.trigger()
event_count = 0
except Exception:
pass
time.sleep(self.poll_interval)
except Exception as e:
app_log.debug("Failed to observe:")
raise
[docs] def stop(self) -> None:
"""
Stop the running `self.observe()` method by setting the `self.stop_flag` flag.
"""
self.stop_flag.set()