Source code for covalent_awslambda_plugin.awslambda

# Copyright 2021 Agnostiq Inc.
#
# This file is part of Covalent.
#
# Licensed under the GNU Affero General Public License 3.0 (the "License").
# A copy of the License may be obtained with this software package or at
#
#      https://www.gnu.org/licenses/agpl-3.0.en.html
#
# Use of this file is prohibited except in compliance with the License. Any
# modifications or derivative works of this file must retain this copyright
# notice, and modified files must contain a notice indicating that they have
# been altered from the originals.
#
# Covalent is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the License for more details.
#
# Relief from the License may be granted by purchasing a commercial license.

import asyncio
import json
import os
from contextlib import contextmanager
from typing import Callable, Dict, List, Tuple

import boto3
import botocore.exceptions
import cloudpickle as pickle
from boto3.session import Session
from covalent._shared_files import logger
from covalent._shared_files.config import get_config
from covalent_aws_plugins import AWSExecutor

app_log = logger.app_log
log_stack_info = logger.log_stack_info

executor_plugin_name = "AWSLambdaExecutor"

_EXECUTOR_PLUGIN_DEFAULTS = {
    "function_name": "covalent-awslambda-executor",
    "credentials_file": "",
    "profile": "",
    "region": "",
    "s3_bucket_name": "covalent-lambda-job-resources",
    "execution_role": "CovalentLambdaExecutionRole",
    "poll_freq": 5,
    "timeout": 900,
}

FUNC_FILENAME = "func-{dispatch_id}-{node_id}.pkl"
RESULT_FILENAME = "result-{dispatch_id}-{node_id}.pkl"
EXCEPTION_FILENAME = "exception-{dispatch_id}-{node_id}.json"


class AWSLambdaExecutor(AWSExecutor):
    """AWS Lambda executor plugin

    Args:
        function_name: Name of an existing lambda function to use during execution (default: `covalent-awslambda-executor`)
        s3_bucket_name: Name of an AWS S3 bucket that the executor can use to store temporary files (default: `covalent-lambda-job-resources`)
        execution_role: Name of the IAM role assigned to the AWS Lambda function
        credentials_file: Path to AWS credentials file (default: `~/.aws/credentials`)
        profile: AWS profile (default: `default`)
        region: AWS region (default: `us-east-1`)
        poll_freq: Time interval in seconds between successive polls to the lambda function (default: `5`)
        timeout: Duration in seconds to poll the Lambda function for results (default: `900`)
    """

    def __init__(
        self,
        function_name: str = None,
        s3_bucket_name: str = None,
        credentials_file: str = None,
        profile: str = None,
        region: str = None,
        execution_role: str = "",
        poll_freq: int = None,
        timeout: int = 900,
    ) -> None:
        # AWSExecutor parameters
        required_attrs = {
            "credentials_file": credentials_file
            or get_config("executors.awslambda.credentials_file"),
            "profile": profile or get_config("executors.awslambda.profile"),
            "region": region or get_config("executors.awslambda.region"),
            "s3_bucket_name": s3_bucket_name
            or get_config("executors.awslambda.s3_bucket_name"),
            "execution_role": execution_role
            or get_config("executors.awslambda.execution_role"),
        }

        super().__init__(**required_attrs)

        # Lambda executor parameters
        self.function_name = (
            function_name
            or get_config("executors.awslambda.function_name")
            or "covalent-awslambda-executor"
        )
        self.poll_freq = poll_freq or get_config("executors.awslambda.poll_freq")
        self.timeout = timeout or get_config("executors.awslambda.timeout")
    @contextmanager
    def get_session(self) -> Session:
        """Yield a boto3 session to be used for instantiating AWS service clients/resources

        Args:
            None

        Returns:
            session: AWS boto3.Session object
        """
        yield boto3.Session(**self.boto_session_options())
    def _upload_task_sync(self, workdir: str, func_filename: str):
        """
        Upload the function file to the S3 bucket

        Args:
            workdir: Local working directory containing the function file
            func_filename: Name of the function file

        Returns:
            None
        """
        app_log.debug(f"Uploading function to S3 bucket {self.s3_bucket_name}")

        with self.get_session() as session:
            client = session.client("s3")
            try:
                with open(os.path.join(workdir, func_filename), "rb") as f:
                    client.upload_fileobj(f, self.s3_bucket_name, func_filename)
            except botocore.exceptions.ClientError as ce:
                app_log.exception(ce)
                raise

        app_log.debug(f"Function {func_filename} uploaded to S3 bucket {self.s3_bucket_name}")

    async def _upload_task(self, workdir: str, func_filename: str):
        """Upload the task asynchronously by running the blocking helper in a thread pool."""
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(None, self._upload_task_sync, workdir, func_filename)
        await fut
    def submit_task_sync(
        self, function_name: str, func_filename: str, result_filename: str, exception_filename: str
    ) -> Dict:
        """The actual (blocking) submit_task function"""
        app_log.debug(f"Invoking AWS Lambda function {function_name}")

        with self.get_session() as session:
            client = session.client("lambda")
            try:
                return client.invoke(
                    FunctionName=function_name,
                    Payload=json.dumps(
                        {
                            "S3_BUCKET_NAME": self.s3_bucket_name,
                            "COVALENT_TASK_FUNC_FILENAME": func_filename,
                            "RESULT_FILENAME": result_filename,
                            "EXCEPTION_FILENAME": exception_filename,
                        }
                    ),
                    InvocationType="Event",
                )
            except botocore.exceptions.ClientError as ce:
                app_log.exception(ce)
                raise
    async def submit_task(
        self, function_name: str, func_filename: str, result_filename: str, exception_filename: str
    ) -> Dict:
        """
        Submit the task by invoking the AWS Lambda function

        Args:
            function_name: AWS Lambda function name
            func_filename: Name of the pickled function file in the S3 bucket
            result_filename: Name of the result file the task writes to the S3 bucket
            exception_filename: Name of the exception file the task writes to the S3 bucket

        Returns:
            response: AWS boto3 client invoke lambda response
        """
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(
            None,
            self.submit_task_sync,
            function_name,
            func_filename,
            result_filename,
            exception_filename,
        )
        return await fut
    def get_status_sync(self, object_key: str) -> bool:
        """Blocking check for whether `object_key` exists in the S3 bucket."""
        with self.get_session() as session:
            s3_client = session.client("s3")
            try:
                s3_client.head_object(Bucket=self.s3_bucket_name, Key=object_key)
            except botocore.exceptions.ClientError:
                return False

            return True
    async def get_status(self, object_key: str):
        """
        Return status of availability of the result object on the remote machine

        Args:
            object_key: Name of the S3 object

        Returns:
            bool indicating whether the object exists in the S3 bucket
        """
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(None, self.get_status_sync, object_key)
        return await fut
    async def _poll_task(self, object_keys: List[str]) -> str:
        """
        Poll the S3 bucket until one of the task's output objects is ready

        Args:
            object_keys: Names of the objects to check for in the S3 bucket

        Returns:
            The key of the first object found

        Raises:
            TimeoutError: If none of the objects appear within `timeout` seconds
        """
        time_left = self.timeout

        while time_left > 0:
            for object_key in object_keys:
                app_log.debug(f"Polling object: {object_key}")
                status = await self.get_status(object_key)
                if status:
                    return object_key
            await asyncio.sleep(self.poll_freq)
            time_left -= self.poll_freq

        raise TimeoutError(f"{object_keys} not found in {self.s3_bucket_name}")
    def query_task_exception_sync(self, workdir: str, exception_filename: str):
        """
        Fetch the exception raised by the task from the S3 bucket

        Args:
            workdir: Path on the local file system where the exception JSON dump is downloaded
            exception_filename: Name of the exception file in the S3 bucket

        Returns:
            The deserialized exception message
        """
        with self.get_session() as session:
            s3_client = session.client("s3")

            # Download file
            try:
                s3_client.download_file(
                    self.s3_bucket_name,
                    exception_filename,
                    os.path.join(workdir, exception_filename),
                )
            except botocore.exceptions.ClientError as ce:
                app_log.exception(ce)
                raise

        with open(os.path.join(workdir, exception_filename), "r") as f:
            task_exception = json.load(f)

        return task_exception
    async def query_task_exception(self, workdir: str, exception_filename: str):
        """Fetch the task exception asynchronously by running the blocking helper in a thread pool."""
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(
            None, self.query_task_exception_sync, workdir, exception_filename
        )
        return await fut
    def query_result_sync(self, workdir: str, result_filename: str):
        """
        Fetch the result object from the S3 bucket

        Args:
            workdir: Path on the local file system where the pickled object is downloaded
            result_filename: Name of the pickled result file in the S3 bucket

        Returns:
            The unpickled result object
        """
        with self.get_session() as session:
            s3_client = session.client("s3")

            # Download file
            try:
                s3_client.download_file(
                    self.s3_bucket_name,
                    result_filename,
                    os.path.join(workdir, result_filename),
                )
            except botocore.exceptions.ClientError as ce:
                app_log.exception(ce)
                raise

        with open(os.path.join(workdir, result_filename), "rb") as f:
            result_object = pickle.load(f)

        return result_object
    async def query_result(self, workdir: str, result_filename: str):
        """Fetch the result object asynchronously by running the blocking helper in a thread pool."""
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(None, self.query_result_sync, workdir, result_filename)
        return await fut
    def _pickle_func_sync(
        self, function: Callable, workdir: str, func_filename: str, args: List, kwargs: Dict
    ):
        """Method to pickle function synchronously."""
        app_log.debug("Pickling function, args and kwargs..")
        with open(os.path.join(workdir, func_filename), "wb") as f:
            pickle.dump((function, args, kwargs), f)

    async def _pickle_func(
        self, function: Callable, workdir: str, func_filename: str, args: List, kwargs: Dict
    ):
        """Pickle function asynchronously."""
        loop = asyncio.get_running_loop()
        fut = loop.run_in_executor(
            None, self._pickle_func_sync, function, workdir, func_filename, args, kwargs
        )
        return await fut
    async def run(self, function: Callable, args: List, kwargs: Dict, task_metadata: Dict):
        """Run the executor

        Args:
            function: Python callable to be executed on the remote executor
            args: List of positional arguments to be passed to the function
            kwargs: Keyword arguments to be passed into the function
            task_metadata: Dictionary containing the task dispatch_id and node_id

        Returns:
            The result object returned by the task
        """
        dispatch_id = task_metadata["dispatch_id"]
        node_id = task_metadata["node_id"]
        workdir = self.cache_dir

        func_filename = FUNC_FILENAME.format(dispatch_id=dispatch_id, node_id=node_id)
        result_filename = RESULT_FILENAME.format(dispatch_id=dispatch_id, node_id=node_id)
        exception_filename = EXCEPTION_FILENAME.format(dispatch_id=dispatch_id, node_id=node_id)

        app_log.debug(f"In run for task - {dispatch_id} - {node_id} ... ")

        # Pickle the function, args and kwargs asynchronously
        await self._pickle_func(function, workdir, func_filename, args, kwargs)

        # Upload the pickled file to the S3 bucket
        await self._upload_task(workdir, func_filename)

        # Invoke the lambda function
        lambda_invocation_response = await self.submit_task(
            self.function_name, func_filename, result_filename, exception_filename
        )

        app_log.debug(f"Lambda function response: {lambda_invocation_response}")

        if "FunctionError" in lambda_invocation_response:
            error = lambda_invocation_response["Payload"].read().decode("utf-8")
            raise RuntimeError(
                f"Exception occurred while running task {dispatch_id}:{node_id}: {error}"
            )

        # Poll for either the result or the exception object
        object_key = await self._poll_task([result_filename, exception_filename])

        if object_key == exception_filename:
            # Download the raised exception
            app_log.debug(
                f"Retrieving exception raised during task execution - {dispatch_id} - {node_id}"
            )
            exception = await self.query_task_exception(workdir, exception_filename)
            app_log.debug(f"Exception retrieved for task - {dispatch_id} - {node_id}")
            raise RuntimeError(exception)

        if object_key == result_filename:
            # Download the result object
            app_log.debug(f"Retrieving result for task - {dispatch_id} - {node_id}")
            result_object = await self.query_result(workdir, result_filename)
            app_log.debug(f"Result retrieved for task - {dispatch_id} - {node_id}")
            return result_object
    def cancel(self) -> None:
        """
        Cancel execution
        """
        raise NotImplementedError("Cancellation is currently not supported")
    # copied from RemoteExecutor
    @staticmethod
    async def run_async_subprocess(cmd) -> Tuple:
        """
        Invokes an async subprocess to run a command.
        """
        proc = await asyncio.create_subprocess_shell(
            cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
        )
        stdout, stderr = await proc.communicate()

        if stdout:
            app_log.debug(stdout)

        if stderr:
            app_log.debug(stderr)

        return proc, stdout, stderr
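

A minimal usage sketch, assuming the backing AWS resources (the Lambda function, S3 bucket and execution role named in _EXECUTOR_PLUGIN_DEFAULTS) have already been deployed. The workflow below is illustrative and the resource names should match your own deployment:

import covalent as ct

from covalent_awslambda_plugin.awslambda import AWSLambdaExecutor

# Illustrative resource names -- replace with the ones provisioned for your account
executor = AWSLambdaExecutor(
    function_name="covalent-awslambda-executor",
    s3_bucket_name="covalent-lambda-job-resources",
    region="us-east-1",
    poll_freq=5,
    timeout=900,
)


@ct.electron(executor=executor)
def add(x, y):
    return x + y


@ct.lattice
def workflow(x, y):
    return add(x, y)


dispatch_id = ct.dispatch(workflow)(1, 2)
result = ct.get_result(dispatch_id, wait=True)
print(result.result)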
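

For reference, the invoke payload built in submit_task_sync implies the following contract for the deployed Lambda function: download the pickled (function, args, kwargs) tuple from S3, execute it, and upload either the pickled result or a JSON-encoded exception message. The handler below is only a sketch of that contract (the handler name and the use of /tmp are assumptions), not the handler shipped with the plugin:

import json
import os

import boto3
import cloudpickle as pickle

s3 = boto3.client("s3")


def handler(event, context):
    # Object names passed by AWSLambdaExecutor.submit_task_sync
    bucket = event["S3_BUCKET_NAME"]
    func_filename = event["COVALENT_TASK_FUNC_FILENAME"]
    result_filename = event["RESULT_FILENAME"]
    exception_filename = event["EXCEPTION_FILENAME"]

    # Fetch and unpickle the task
    local_func = os.path.join("/tmp", func_filename)
    s3.download_file(bucket, func_filename, local_func)
    with open(local_func, "rb") as f:
        function, args, kwargs = pickle.load(f)

    try:
        result = function(*args, **kwargs)
    except Exception as ex:
        # On failure, upload a JSON dump that query_task_exception_sync can read
        local_exc = os.path.join("/tmp", exception_filename)
        with open(local_exc, "w") as f:
            json.dump(str(ex), f)
        s3.upload_file(local_exc, bucket, exception_filename)
        return

    # On success, upload the pickled result that query_result_sync can read
    local_result = os.path.join("/tmp", result_filename)
    with open(local_result, "wb") as f:
        pickle.dump(result, f)
    s3.upload_file(local_result, bucket, result_filename)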