# Source code for nesta.core.routines.crunchbase.crunchbase_health_label_task

"""
Organisation health labeling
============================

Luigi routine to determine if crunchbase orgs are involved in health and apply a label
to the data in MYSQL.
"""

import boto3
import luigi
import logging
import os
import pickle

from nesta.core.routines.crunchbase.crunchbase_geocode_task import OrgGeocodeTask
from nesta.packages.crunchbase.crunchbase_collect import predict_health_flag
from nesta.packages.misc_utils.batches import split_batches
from nesta.core.luigihacks.misctools import get_config, find_filepath_from_pathstub
from nesta.core.luigihacks.mysqldb import MySqlTarget
from nesta.core.orms.crunchbase_orm import Base, Organization, OrganizationCategory
from nesta.core.orms.orm_utils import get_mysql_engine, try_until_allowed, db_session


class HealthLabelTask(luigi.Task):
    """Apply health labels to the organisation data in MYSQL.

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        test (bool): True if in test mode
        insert_batch_size (int): Number of rows to insert into the db in a batch
        db_config_env (str): The output database envariable
        bucket (str): S3 bucket where the models are stored
        vectoriser_key (str): S3 key for the vectoriser model
        classifier_key (str): S3 key for the classifier model
    """
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    test = luigi.BoolParameter()
    insert_batch_size = luigi.IntParameter(default=500)
    db_config_env = luigi.Parameter()
    bucket = luigi.Parameter()
    vectoriser_key = luigi.Parameter()
    classifier_key = luigi.Parameter()

    def requires(self):
        """Geocode the organisations (via AWS batch) before labelling them."""
        yield OrgGeocodeTask(date=self.date,
                             _routine_id=self._routine_id,
                             test=self.test,
                             db_config_env="MYSQLDB",
                             city_col=Organization.city,
                             country_col=Organization.country,
                             location_key_col=Organization.location_id,
                             insert_batch_size=self.insert_batch_size,
                             env_files=[find_filepath_from_pathstub("nesta/nesta/"),
                                        find_filepath_from_pathstub("config/mysqldb.config")],
                             job_def="py36_amzn1_image",
                             job_name=f"CrunchBaseOrgGeocodeTask-{self._routine_id}",
                             job_queue="HighPriority",
                             region_name="eu-west-2",
                             poll_time=10,
                             memory=4096,
                             max_live_jobs=2)

    def output(self):
        """Points to the output database engine.

        Returns:
            MySqlTarget: marker row used by luigi to record task completion.
        """
        self.db_config_path = os.environ[self.db_config_env]
        db_config = get_config(self.db_config_path, "mysqldb")
        db_config["database"] = 'dev' if self.test else 'production'
        db_config["table"] = "Crunchbase health labels <dummy>"  # Note, not a real table
        update_id = "CrunchbaseHealthLabel_{}".format(self.date)
        return MySqlTarget(update_id=update_id, **db_config)

    def run(self):
        """Apply health labels using model."""
        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)
        try_until_allowed(Base.metadata.create_all, self.engine)

        # collect and unpickle models from s3
        # NOTE(security): pickle.loads on S3-hosted blobs executes arbitrary
        # code on load — acceptable only because the bucket holds our own
        # trusted model artefacts; never point these keys at untrusted data.
        logging.info("Collecting models from S3")
        s3 = boto3.resource('s3')
        vectoriser_obj = s3.Object(self.bucket, self.vectoriser_key)
        # Use the public StreamingBody.read() API rather than the private
        # _raw_stream attribute, which botocore does not guarantee.
        vectoriser = pickle.loads(vectoriser_obj.get()['Body'].read())
        classifier_obj = s3.Object(self.bucket, self.classifier_key)
        classifier = pickle.loads(classifier_obj.get()['Body'].read())

        # retrieve organisations and categories
        nrows = 1000 if self.test else None
        logging.info("Collecting organisations from database")
        with db_session(self.engine) as session:
            # Only label orgs not yet flagged, so the task is resumable.
            orgs = (session
                    .query(Organization.id)
                    .filter(Organization.is_health.is_(None))
                    .limit(nrows)
                    .all())

        for batch_count, batch in enumerate(split_batches(orgs,
                                                          self.insert_batch_size), 1):
            batch_orgs_with_cats = []
            for (org_id, ) in batch:
                with db_session(self.engine) as session:
                    categories = (session
                                  .query(OrganizationCategory.category_name)
                                  .filter(OrganizationCategory.organization_id == org_id)
                                  .all())
                # categories should be a list of str, comma separated:
                # ['cat,cat,cat', 'cat,cat']
                categories = ','.join(cat_name for (cat_name, ) in categories)
                batch_orgs_with_cats.append({'id': org_id, 'categories': categories})

            logging.debug(f"{len(batch_orgs_with_cats)} organisations retrieved from database")

            logging.debug("Predicting health flags")
            batch_orgs_with_flag = predict_health_flag(batch_orgs_with_cats,
                                                       vectoriser, classifier)

            logging.debug(f"{len(batch_orgs_with_flag)} organisations to update")
            with db_session(self.engine) as session:
                session.bulk_update_mappings(Organization, batch_orgs_with_flag)
            logging.info(f"{batch_count} batches health labeled and written to db")

        # mark as done
        logging.warning("Task complete")
        self.output().touch()