Source code for nesta.core.routines.crunchbase.crunchbase_parent_id_collect_task

'''
Merge in parent organisations
=============================

This task picks up the missed org_parents table from the Crunchbase data dump and combines this with organizations.
'''
import boto3
import logging
import luigi

from nesta.core.routines.crunchbase.crunchbase_health_label_task import HealthLabelTask
from nesta.packages.crunchbase.crunchbase_collect import get_files_from_tar
from nesta.packages.misc_utils.batches import split_batches
from nesta.core.luigihacks import misctools
from nesta.core.luigihacks.mysqldb import MySqlTarget
from nesta.core.orms.crunchbase_orm import Organization
from nesta.core.orms.orm_utils import get_mysql_engine, db_session


# Module-level S3 handles. NOTE: everything below runs at import time, so
# importing this module talks to AWS.
S3 = boto3.resource('s3')
_BUCKET = S3.Bucket("nesta-production-intermediate")
# Keys of every object currently in the bucket, collected once on import.
DONE_KEYS = {obj.key for obj in _BUCKET.objects.all()}


class ParentIdCollectTask(luigi.Task):
    '''Append parent ids from the Crunchbase export to the organizations table.

    Reads the org_parents file from the Crunchbase tar and writes a parent_id
    onto every organization that is present in the database and does not
    already have one.

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        test (bool): If set, the 'dev' database is used instead of
                     'production' and run() stops after two batches
        db_config_env (str): The output database envariable
        db_config_path (str): The output database configuration
        insert_batch_size (int): number of rows to insert into the db in a batch
    '''
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    test = luigi.BoolParameter()
    db_config_env = luigi.Parameter()
    db_config_path = luigi.Parameter()
    insert_batch_size = luigi.IntParameter(default=500)

    def requires(self):
        '''Yield the upstream HealthLabelTask, forwarding this task's parameters.'''
        upstream_kwargs = dict(
            date=self.date,
            _routine_id=self._routine_id,
            test=self.test,
            insert_batch_size=self.insert_batch_size,
            db_config_env=self.db_config_env,
            bucket='nesta-crunchbase-models',
            vectoriser_key='vectoriser.pickle',
            classifier_key='clf.pickle',
        )
        yield HealthLabelTask(**upstream_kwargs)

    def output(self):
        '''Return the MySqlTarget marker used to record completion of this task.'''
        target_config = misctools.get_config(self.db_config_path, "mysqldb")
        if self.test:
            target_config["database"] = 'dev'
        else:
            target_config["database"] = 'production'
        # Note, not a real table: the target is only used as a "done" marker
        target_config["table"] = "Crunchbase <dummy>"
        marker_id = "CrunchbaseParentIdCollect_{}".format(self.date)
        return MySqlTarget(update_id=marker_id, **target_config)

    def run(self):
        '''Collect org_parents, filter to updatable orgs and write parent ids in batches.'''
        # database setup
        if self.test:
            database = 'dev'
        else:
            database = 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # collect file
        logging.info(f"Collecting org_parents from crunchbase tar")
        extracted = get_files_from_tar(['org_parents'])
        org_parents = extracted[0]
        logging.info(f"{len(org_parents)} parent ids in crunchbase export")

        # collect previously processed orgs
        logging.info("Extracting previously processed organisations")
        # NOTE(review): the scraped source does not show where this `with`
        # block ends; the sets are built inside it here - confirm upstream.
        with db_session(self.engine) as session:
            id_pairs = session.query(Organization.id,
                                     Organization.parent_id).all()
            all_orgs = {org_id for (org_id, _) in id_pairs}
            logging.info(f"{len(all_orgs)} total orgs in database")
            processed_orgs = {org_id for (org_id, parent_id) in id_pairs
                              if parent_id is not None}
            logging.info(f"{len(processed_orgs)} previously processed orgs")

        # reformat into a list of dicts, removing orgs that already have a
        # parent_id or are missing from the database
        org_parents = org_parents[['uuid', 'parent_uuid']].rename(
            columns={'uuid': 'id', 'parent_uuid': 'parent_id'})
        in_database = org_parents['id'].isin(all_orgs)
        already_done = org_parents['id'].isin(processed_orgs)
        org_parents = org_parents[in_database & ~already_done]
        org_parents = org_parents.to_dict(orient='records')
        logging.info(f"{len(org_parents)} organisations to update in MYSQL")

        # insert parent_ids into db in batches
        batches = split_batches(org_parents, self.insert_batch_size)
        for count, batch in enumerate(batches, start=1):
            # a fresh session per batch
            with db_session(self.engine) as writer:
                writer.bulk_update_mappings(Organization, batch)
            logging.info(f"{count} batch{'es' if count > 1 else ''} written to db")
            if self.test and count >= 2:
                logging.info("Breaking after 2 batches while in test mode")
                break

        # mark as done
        logging.warning("Task complete")
        self.output().touch()