Source code for nesta.core.routines.arxiv.arxiv_collect_iterative_task

'''
Collection task
===============

Luigi routine to collect new and updated articles from the arXiv API and load them into MySQL.
'''
from datetime import datetime
import luigi
import logging

from nesta.packages.arxiv.collect_arxiv import add_new_articles, retrieve_all_arxiv_rows, update_existing_articles
from nesta.packages.misc_utils.batches import BatchWriter
from nesta.core.orms.arxiv_orm import Article, Category
from nesta.core.orms.orm_utils import get_mysql_engine, db_session
from nesta.core.luigihacks import misctools
from nesta.core.luigihacks.mysqldb import MySqlTarget


class CollectNewTask(luigi.Task):
    '''Collect new data from the arXiv api and dump the data in the MySQL server.

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        test (bool): If True (the default) the 'dev' database is used and
                     processing stops after 1600 articles, otherwise the
                     'production' database is used
        db_config_env (str): environmental variable pointing to the db config file
        db_config_path (str): The output database configuration
        insert_batch_size (int): number of records to insert into the database at once
        articles_from_date (str): new and updated articles from this date will be
                                  retrieved. Must be in YYYY-MM-DD format
    '''
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    test = luigi.BoolParameter(default=True)
    db_config_env = luigi.Parameter()
    db_config_path = luigi.Parameter()
    insert_batch_size = luigi.IntParameter(default=500)
    articles_from_date = luigi.Parameter()

    def output(self):
        '''Points to the output database engine'''
        db_config = misctools.get_config(self.db_config_path, "mysqldb")
        db_config["database"] = 'dev' if self.test else 'production'
        db_config["table"] = "arXlive <dummy>"  # Note, not a real table
        update_id = f"ArxivIterativeCollect_{self.date}"
        return MySqlTarget(update_id=update_id, **db_config)

    def run(self):
        '''Retrieve articles from the arXiv api and write them to MySQL.

        Articles whose id is not yet in the database are inserted; articles
        already present are updated, unless the stored record is at least as
        recent as the retrieved one. Any category that is not yet in the
        Category table is added on the fly.

        Raises:
            ValueError: if `articles_from_date` is not in YYYY-MM-DD format
        '''
        # fail early, before touching the database or the api
        try:
            datetime.strptime(self.articles_from_date, '%Y-%m-%d')
        except ValueError as err:
            raise ValueError("From date for articles is invalid or not in "
                             f"YYYY-MM-DD format: {self.articles_from_date}") from err

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        with db_session(self.engine) as session:
            # create lookup for categories (less than 200) and set of article ids
            all_categories_lookup = {cat.id: cat
                                     for cat in session.query(Category).all()}
            logging.info(f"{len(all_categories_lookup)} existing categories")

            all_article_ids = {article.id
                               for article in session.query(Article.id).all()}
            logging.info(f"{len(all_article_ids)} existing articles")

            # only the id and updated columns are needed here, so avoid
            # loading complete Article entities
            already_updated = {article.id: article.updated
                               for article in (session
                                               .query(Article.id, Article.updated)
                                               .filter(Article.updated >= self.articles_from_date)
                                               .all())}
            logging.info(f"{len(already_updated)} records exist in the database "
                         "with a date on or after the update date.")

            new_count = 0
            existing_count = 0
            new_articles_batch = BatchWriter(self.insert_batch_size,
                                             add_new_articles, session)
            existing_articles_batch = BatchWriter(self.insert_batch_size,
                                                  update_existing_articles, self.engine)

            # retrieve and process, while inserting any missing categories
            # ('from' is a reserved word, so it has to be passed via a dict)
            for row in retrieve_all_arxiv_rows(**{'from': self.articles_from_date}):
                try:
                    # update only newer data
                    if row['updated'] <= already_updated[row['id']]:
                        continue
                except KeyError:
                    # no stored record on or after the from date to compare with
                    pass

                # check for missing categories
                for cat_id in row.get('categories', []):
                    if cat_id in all_categories_lookup:
                        continue
                    logging.warning(f"Missing category: '{cat_id}' for article "
                                    f"{row['id']}. Adding to Category table")
                    new_cat = Category(id=cat_id)
                    session.add(new_cat)
                    session.commit()
                    # keep the Category object itself in the lookup: it is
                    # used below when building the categories of new Articles
                    all_categories_lookup[cat_id] = new_cat

                if row['id'] not in all_article_ids:
                    # create new Article and append to batch,
                    # converting category ids to Category objects
                    row['categories'] = [all_categories_lookup[cat_id]
                                         for cat_id in row.get('categories', [])]
                    new_articles_batch.append(Article(**row))
                    new_count += 1
                else:
                    # append to existing articles batch
                    existing_articles_batch.append(row)
                    existing_count += 1

                count = new_count + existing_count
                if not count % 1000:
                    logging.info(f"Processed {count} articles")
                if self.test and count == 1600:
                    logging.warning("Limiting to 1600 rows while in test mode")
                    break

            # insert any remaining new and existing articles
            logging.info("Processing final batches")
            if new_articles_batch:
                new_articles_batch.write()
            if existing_articles_batch:
                existing_articles_batch.write()

        logging.info(f"Total {new_count} new articles added")
        logging.info(f"Total {existing_count} existing articles updated")

        # mark as done
        logging.warning("Task complete")
        self.output().touch()