Source code for core.luigihacks.autobatch

'''Automatic preparation, submission and consolidation
of AWS batch tasks, as a single Luigi Task.
'''

from abc import ABC
from abc import abstractmethod
from collections import defaultdict
from nesta.core.luigihacks import batchclient
from subprocess import check_output
from subprocess import CalledProcessError
import time
import luigi
from nesta.core.luigihacks.misctools import get_config
import logging
import os

# Define a global timeout, set to 95% of the timeout time
# in order to give the Luigi worker some grace
_config = get_config("luigi.cfg", "worker")


def command_line(command, verbose=False):
    '''Execute command line tasks and return the final output line.
    This is particularly useful for extracting the AWS access keys
    directly from the OS, as well as executing the environment
    preparation script (:code:`core/scripts/nesta_prepare_batch.sh`).
    '''
    # Execute the command and decode the output
    out = check_output([command], shell=True)
    out_lines = out.decode("utf-8").split("\n")
    if verbose:
        for line in out_lines:
            logging.info(f"{os.getpid()}: "
                         ">>>\t'{}'".format(line.replace("\r", ' ')))
    # The second last output is the actual final output
    # (ignoring the status code, which is the last output)
    return out_lines[-2]
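
# Illustrative usage sketch (not part of the original module): assuming the
# AWS CLI is installed and configured, `command_line` returns the final line
# of a command's stdout, e.g.
#
#     key_id = command_line("aws configure get aws_access_key_id", verbose=True)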

class AutoBatchTask(luigi.Task, ABC):
    '''A base class for automatically preparing and submitting AWS batch
    tasks. Unlike regular Luigi :code:`Tasks`, which require the user to
    override the :code:`requires`, :code:`output` and :code:`run` methods,
    :code:`AutoBatchTask` instead effectively replaces :code:`run` with two
    new abstract methods: :code:`prepare` and :code:`combine`, which are
    respectively documented. With these abstract methods specified,
    :code:`AutoBatchTask` will automatically prepare, submit, and combine
    one batch task (specified in :code:`core.batchables`) per parameter set
    specified in the :code:`prepare` method. The :code:`combine` method
    will subsequently combine the outputs from the batch task.

    Args:
        batchable (str): Path to the directory containing the run.py batchable
        job_def (str): Name of the AWS job definition
        job_name (str): Name given to this AWS batch job
        job_queue (str): AWS batch queue
        region_name (str): AWS region from which to batch
        env_files (:obj:`list` of :obj:`str`, optional): List of names
            pointing to local environmental files (for example local imports
            or scripts) which should be zipped up with the AWS batch job
            environment. Defaults to [].
        vcpus (int, optional): Number of CPUs to request for the AWS batch
            job. Defaults to 1.
        memory (int, optional): Memory to request for the AWS batch job.
            Defaults to 512 MiB.
        max_runs (int, optional): Number of batch jobs to run, which is
            useful for testing a subset of the full pipeline, or making cost
            predictions for AWS computing time. Defaults to `None`, implying
            that all jobs should be run.
        poll_time (int, optional): Time in seconds between querying the AWS
            batch job status. Defaults to 60.
        success_rate (float, optional): If the fraction of FAILED jobs
            exceeds :code:`1 - success_rate` then the entire Task, along
            with any submitted AWS batch jobs, is killed. The fraction is
            calculated with respect to any jobs with RUNNING, SUCCEEDED or
            FAILED status. Defaults to 0.95.
    '''
    batchable = luigi.Parameter()
    job_def = luigi.Parameter()
    job_name = luigi.Parameter()
    job_queue = luigi.Parameter()
    region_name = luigi.Parameter()
    env_files = luigi.ListParameter(default=[])
    vcpus = luigi.IntParameter(default=1)
    memory = luigi.IntParameter(default=512)
    max_runs = luigi.IntParameter(default=None)  # For testing
    timeout = luigi.IntParameter(default=21600)
    poll_time = luigi.IntParameter(default=60)
    success_rate = luigi.FloatParameter(default=0.95)
    test = luigi.BoolParameter(default=True)
    max_live_jobs = luigi.IntParameter(default=25)
    worker_timeout = float('inf')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.failed_jobs = set()

    def run(self):
        '''DO NOT OVERRIDE THIS METHOD.

        An implementation of the :code:`luigi.Task.run` method, which is a
        wrapper around the :code:`prepare`, :code:`execute` and
        :code:`combine` methods. Instead of overriding this method, you
        should implement the :code:`prepare` and :code:`combine` methods
        in your class.
        '''
        pid = os.getpid()
        self.TIMEOUT = time.time() + int(_config["timeout"])
        # Generate the parameters for batches
        job_params = self.prepare()
        if self.test:
            if len(job_params) > 2:
                job_params = job_params[0:2]
            logging.info(f"Test mode ({pid}): running {len(job_params)} jobs")
        # Prepare the environment for batching
        env_files = " ".join(self.env_files)
        try:
            if self.test:
                logging.info(f"Test mode ({pid}): Preparing batch")
            s3file_timestamp = command_line("nesta_prepare_batch "
                                            "{} {}".format(self.batchable,
                                                           env_files),
                                            self.test)
            if self.test:
                logging.info(f"Test mode ({pid}): Prepared batch")
        except CalledProcessError:
            raise batchclient.BatchJobException("Invalid input "
                                                "or environment files")
        # Execute batch jobs
        self.execute(job_params, s3file_timestamp)
        # Combine the outputs
        self.combine(job_params)

    @abstractmethod
    def prepare(self):
        '''You should implement a method which returns a :code:`list` of
        :code:`dict`, where each :code:`dict` corresponds to inputs to the
        batchable. Each row of the output must at least contain the
        following keys:

        - **done** (`bool`): indicating whether the job has already been
          finished.
        - **outinfo** (`str`): Text indicating e.g. the location of the
          output, for use in the batch job and in the :code:`combine` method.

        Returns:
            :obj:`list` of :obj:`dict`
        '''
        pass

    @abstractmethod
    def combine(self, job_params):
        '''You should implement a method which collects the outputs
        specified by the **outinfo** key of :code:`job_params`, which is
        the output from the :code:`prepare` method. This method should
        finally write to the :code:`luigi.Target` output.

        Parameters:
            job_params (:obj:`list` of :obj:`dict`): The batchable job
                parameters, as returned from the :code:`prepare` method.
        '''
        pass

    def execute(self, job_params, s3file_timestamp):
        '''The secret sauce, which automatically submits and monitors the
        AWS batch jobs. Your AWS access key and ID are automatically
        retrieved via the AWS CLI.

        Parameters:
            job_params (:obj:`list` of :obj:`dict`): The batchable job
                parameters, as returned from the :code:`prepare` method.
                One job is submitted for each item in this :code:`list`.
                Each :code:`dict` key-value pair is converted into an
                environmental variable in the batch job, with the variable
                name formed from the key, prefixed by `BATCHPAR_`.
            s3file_timestamp (str): The timestamp of the batchable zip file
                to be found on S3 by the AWS batch job.
        '''
        pid = os.getpid()
        # Get AWS info to pass to the batch jobs
        aws_id = command_line("aws --profile default configure "
                              "get aws_access_key_id", self.test)
        aws_secret = command_line("aws --profile default configure "
                                  "get aws_secret_access_key", self.test)
        # Build a set of environmental variables to send to the jobs
        env_variables = [{"name": "AWS_ACCESS_KEY_ID", "value": aws_id},
                         {"name": "AWS_SECRET_ACCESS_KEY", "value": aws_secret},
                         {"name": "BATCHPAR_S3FILE_TIMESTAMP",
                          "value": s3file_timestamp}]
                        #{"name": "PYTHONIOENCODING", "value": "latin1"}]
        if self.test:
            logging.info(f"Test mode ({pid}): Got env variables")
        # Set up the batch client, and check that we haven't
        # already hit the time limit
        batch_client = batchclient.BatchClient(poll_time=self.poll_time,
                                               region_name=self.region_name)
        self._assert_timeout(batch_client, job_ids=[])
        if self.test:
            logging.info(f"Test mode ({pid}): Ready to batch")
        all_job_kwargs = []
        for i, params in enumerate(job_params):
            if params["done"]:
                continue
            # Break in case of testing
            if (self.max_runs is not None) and (i >= self.max_runs):
                break
            _env_variables = env_variables.copy()
            for k, v in params.items():
                new_row = dict(name="BATCHPAR_{}".format(k), value=str(v))
                _env_variables.append(new_row)
            # Add the environmental variables to the container overrides
            overrides = dict(environment=_env_variables,
                             memory=self.memory, vcpus=self.vcpus)
            job_kwargs = dict(jobDefinition=self.job_def,
                              jobName=self.job_name,
                              jobQueue=self.job_queue,
                              timeout=dict(attemptDurationSeconds=self.timeout),
                              containerOverrides=overrides)
            all_job_kwargs.append(job_kwargs)
        # Submit the jobs and wait for them to finish
        self._run_batch_jobs(batch_client, all_job_kwargs)
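
    # Illustrative note (not from the original source): a single row returned
    # by `prepare`, such as {"done": False, "outinfo": "s3://bucket/key", "age": 1}
    # (hypothetical values), is converted above into the container environment
    # variables BATCHPAR_done, BATCHPAR_outinfo and BATCHPAR_age, alongside the
    # AWS credentials and BATCHPAR_S3FILE_TIMESTAMP.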

    def _run_batch_jobs(self, batch_client, all_job_kwargs):
        '''Submit AWS batch jobs and monitor them until they have finished
        or failed.

        Parameters:
            batch_client (:obj:`BatchClient`): The AWS batch client.
            all_job_kwargs (:obj:`list` of :obj:`dict`): Submission keyword
                arguments for each AWS batch job.
        '''
        # Keep submitting until all jobs have been submitted
        all_job_ids = set()
        done_job_ids = set()
        submitted_job_idxs = set()
        logging.info(f"{os.getpid()}: "
                     "{} jobs to run".format(len(all_job_kwargs)))
        while len(all_job_kwargs) > len(all_job_ids):
            # Get the number of live jobs
            running_job_ids = all_job_ids - done_job_ids
            n_live = len(running_job_ids)
            n_done = len(done_job_ids)
            n_left = len(all_job_kwargs) - n_done - n_live
            logging.info(f"{os.getpid()}: "
                         "{} jobs are live, "
                         "{} are finished, "
                         "and {} are yet to be submitted".format(n_live,
                                                                 n_done,
                                                                 n_left))
            if n_live > 1:
                self._assert_timeout(batch_client, running_job_ids)
                self._assert_success(batch_client, all_job_ids, done_job_ids)
            # Submit new jobs until `self.max_live_jobs` is reached
            for ijob, job_kwargs in enumerate(all_job_kwargs):
                if n_live >= self.max_live_jobs:
                    break
                if ijob in submitted_job_idxs:
                    continue
                # Submit a new job
                id_ = batch_client.submit_job(**job_kwargs)
                all_job_ids.add(id_)
                submitted_job_idxs.add(ijob)
                n_live += 1
            # Wait before continuing
            logging.info(f"{os.getpid()}: Not done submitting...")
            time.sleep(self.poll_time)

        # Wait until all jobs have finished
        running_job_ids = all_job_ids - done_job_ids
        while len(running_job_ids) > 0:
            running_job_ids = all_job_ids - done_job_ids
            self._assert_timeout(batch_client, running_job_ids)
            self._assert_success(batch_client, all_job_ids, done_job_ids)
            # Wait before continuing
            #logging.info("Not finished waiting...")
            time.sleep(self.poll_time)

    def _assert_success(self, batch_client, job_ids, done_jobs):
        '''Assert that the success rate has not been breached.'''
        stats = defaultdict(int)  # Collection of failure vs total statistics
        # Check the status of each job
        for id_ in job_ids:
            status = batch_client.get_job_status(id_)
            if id_ not in done_jobs:
                logging.debug(f"{os.getpid()}: "
                              "{} {}".format(id_, status))
            if status == "FAILED":
                self.failed_jobs.add(id_)
            if status not in ("SUCCEEDED", "FAILED", "RUNNING"):
                continue
            stats[status] += 1
            if status != "RUNNING":
                done_jobs.add(id_)
        # Ignore if jobs are simply stalling
        if len(stats) == 0:
            logging.info(f"{os.getpid()}: "
                         "No jobs are currently running")
            return
        # Calculate the failure rate
        total = sum(stats.values())
        failure_rate = stats["FAILED"] / total
        if failure_rate <= (1 - self.success_rate):
            return
        reason = "Exiting due to high failure rate: {}%".format(int(failure_rate*100))
        reason += "\nFailed jobs are: {}".format(self.failed_jobs)
        batch_client.hard_terminate(job_ids=job_ids, reason=reason)

    def _assert_timeout(self, batch_client, job_ids):
        '''Assert that the timeout has not been breached.'''
        logging.info(f"{os.getpid()}: "
                     "{} seconds left".format(self.TIMEOUT - time.time()))
        if time.time() < self.TIMEOUT:
            return
        reason = f"{os.getpid()}: "
        reason += "Impending worker timeout, so killing live tasks"
        reason += "\nFailed jobs are: {}".format(self.failed_jobs)
        batch_client.hard_terminate(job_ids=job_ids, reason=reason)
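
As a usage illustration only (not part of the module above), a subclass supplies :code:`output`, :code:`prepare` and :code:`combine`; the class name, S3 bucket and parameters below are hypothetical, and a real task would point :code:`batchable` at a directory under :code:`core.batchables` containing a :code:`run.py`.

# A minimal sketch, assuming hypothetical names and paths, and assuming the
# batchable reads the BATCHPAR_* environment variables and writes to S3.
import json

import luigi

from nesta.core.luigihacks.autobatch import AutoBatchTask


class ExampleAutoBatchTask(AutoBatchTask):
    '''Hypothetical example: submit one batch job per row index,
    then gather the output locations.'''
    n_jobs = luigi.IntParameter(default=5)

    def output(self):
        # Hypothetical local target marking completion of the whole task
        return luigi.LocalTarget(f"example_output_{self.n_jobs}.json")

    def prepare(self):
        # One dict per batch job; 'done' and 'outinfo' are the required keys
        return [{"done": False,
                 "outinfo": f"s3://example-bucket/outputs/row-{i}",  # hypothetical bucket
                 "row_index": i}
                for i in range(self.n_jobs)]

    def combine(self, job_params):
        # Collect the output locations produced by the batch jobs and
        # write the final Luigi target
        collected = [params["outinfo"] for params in job_params]
        with self.output().open("w") as f:
            json.dump(collected, f)

The remaining parameters (:code:`batchable`, :code:`job_def`, :code:`job_name`, :code:`job_queue`, :code:`region_name`, and so on) would typically be supplied by the routine that schedules the task, for example via its :code:`luigi.cfg`.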