Source code for nesta.core.routines.arxiv.deepchange_analysis_task

"""
"Deep learning, Deep Change" analysis
=====================================

Luigi routine to perform the analysis from the
"Deep learning, Deep Change" paper, placing the results in an
S3 bucket to be picked up by the `arXlive <https://arxlive.org>`_ front end.
"""
import logging
import luigi
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import pandas as pd
from datetime import datetime as dt

from nesta.packages.arxiv import deepchange_analysis as dc
from nesta.core.luigihacks.misctools import find_filepath_from_pathstub as f3p
from nesta.core.luigihacks.misctools import get_config
from nesta.core.luigihacks import mysqldb
from nesta.core.luigihacks.luigi_logging import set_log_level
from nesta.core.orms.orm_utils import get_mysql_engine
from nesta.core.routines.arxiv.arxiv_topic_tasks import WriteTopicTask
from nesta.core.luigihacks.parameter import DictParameterPlus

# Figure styling applied globally to every chart drawn by this module
matplotlib.rcParams.update({
    'figure.figsize': (20, 13),
    'font.size': 25,
    "axes.labelpad": 10,
})

# SQL files, located via find_filepath_from_pathstub, listed in the order
# in which `sql_queries` yields them
ORDERED_QUERIES = list(map(f3p, ('arxlive1_filter_cats.sql',
                                 'arxlive2_join_insts.sql',
                                 'arxlive3_group_cats.sql',
                                 'arxlive4_read_final.sql')))
YEAR_THRESHOLD = 2012  # before / after split year used outside of test mode
MIN_RCA_YEAR = 2007  # minimum year when calculating rca pre 2012
N_TOP = 15  # number of countries / cities / categories to show
COLOR_A = '#631607'  # darker of the two chart colours
COLOR_B = '#d68b7a'  # lighter of the two chart colours
STATIC_FILES_BUCKET = 'arxlive-static-files'  # bucket passed to dc.plot_to_s3

def sql_queries():
    """Read the files listed in ``ORDERED_QUERIES`` one at a time, in order.

    Yields:
        (str, bool): the full text of the query file, and a flag which is
            True only for the final entry of ``ORDERED_QUERIES``.
    """
    for position, sql_path in enumerate(ORDERED_QUERIES):
        with open(sql_path) as sql_file:
            sql_text = sql_file.read()
        yield sql_text, position + 1 == len(ORDERED_QUERIES)
class AnalysisTask(luigi.Task):
    """Extract and analyse arXiv data to produce data and charts for the arXlive
    front end to consume.

    Proposed charts:
        1. distribution of dl/non dl papers by country (horizontal bar)
        2. distribution of dl/non dl papers by city (horizontal bar)
        3. % ML papers by year (line)
        4. share of ML activity in arxiv subjects, pre/post 2012 (horizontal point / slope)
        5. rca, pre/post 2012 by country (horizontal point / slope)
        6. rca over time, citation > mean & top 50 countries (horizontal violin) [NOT DONE]

    Proposed table data:
        1. top countries by rca (moving window of last 12 months?) [NOT DONE]

    Args:
        date (datetime): Datetime used to label the outputs
        _routine_id (str): String used to label the AWS task
        test (bool): if True the 'dev' database is used, the before / after
                     year threshold is 2008 and the other filters are relaxed
        db_config_env (str): environmental variable pointing to the db config file
        db_config_path (str): The output database configuration
        mag_config_path (str): Microsoft Academic Graph Api key configuration path
        insert_batch_size (int): number of records to insert into the database at once
                                 (not used in this task but passed down to others)
        articles_from_date (str): new and updated articles from this date will be
                                  retrieved. Must be in YYYY-MM-DD format
                                  (not used in this task but passed down to others)
        s3_path_prefix (str): S3 prefix from which the paths handed to
                              WriteTopicTask are built
        raw_data_path (str): sub-path of `s3_path_prefix` used to build the
                             `data_path` handed to WriteTopicTask
        grid_task_kwargs (dict): passed unchanged to WriteTopicTask
        cherry_picked (str): passed unchanged to WriteTopicTask
    """
    date = luigi.DateParameter()
    _routine_id = luigi.Parameter()
    test = luigi.BoolParameter(default=True)
    db_config_env = luigi.Parameter()
    db_config_path = luigi.Parameter()
    mag_config_path = luigi.Parameter()
    insert_batch_size = luigi.IntParameter(default=500)
    articles_from_date = luigi.Parameter()
    s3_path_prefix = luigi.Parameter(default="s3://nesta-arxlive")
    raw_data_path = luigi.Parameter(default="raw-inputs")
    grid_task_kwargs = DictParameterPlus()
    cherry_picked = luigi.Parameter()

    def output(self):
        '''Points to the output database engine'''
        db_config = get_config(self.db_config_path, "mysqldb")
        db_config["database"] = 'dev' if self.test else 'production'
        db_config["table"] = "arXlive <dummy>"  # NB: not a real table
        update_id = "ArxivAnalysis_{}_{}".format(self.date, self.test)
        return mysqldb.MySqlTarget(update_id=update_id, **db_config)

    def requires(self):
        '''Yields the WriteTopicTask which must complete before this task runs'''
        s3_path_prefix = (f"{self.s3_path_prefix}/"
                          f"automl/{self.date}")
        data_path = (f"{self.s3_path_prefix}/"
                     f"{self.raw_data_path}/{self.date}")
        yield WriteTopicTask(raw_s3_path_prefix=self.s3_path_prefix,
                             s3_path_prefix=s3_path_prefix,
                             data_path=data_path,
                             date=self.date,
                             cherry_picked=self.cherry_picked,
                             test=self.test,
                             grid_task_kwargs=self.grid_task_kwargs)

    @staticmethod
    def _plot_dl_share_barh(df, group_column, bar_width, xlabel, s3_key):
        '''Plot the DL / non DL percentage split for the top N_TOP values of
        `group_column` as a horizontal bar chart, and send it to S3.

        Args:
            df (pd.DataFrame): data with an `is_dl` column and `group_column`
            group_column (str): column to group the rows by
            bar_width (float): width of the bars
            xlabel (str): label for the x axis
            s3_key (str): key of the image in the static files bucket
        '''
        pivot = (pd.pivot_table(df.groupby([group_column, 'is_dl'])
                                  .size()
                                  .reset_index(drop=False),
                                index=group_column,
                                columns='is_dl',
                                values=0)
                 # each column is rescaled so that it sums to 100
                 .apply(lambda x: 100 * (x / x.sum()))
                 .rename(columns={True: 'DL', False: 'non DL'}))

        fig, ax = plt.subplots()
        # take the top N_TOP by DL share, then sort ascending for plotting
        (pivot.sort_values('DL', ascending=False)[:N_TOP]
              .sort_values('DL')
              .plot.barh(ax=ax, color=[COLOR_B, COLOR_A], width=bar_width))
        ax.set_xlabel(xlabel)
        ax.set_ylabel('')
        # swap the two legend entries so that 'DL' is listed first
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(labels=[labels[1], labels[0]],
                  handles=[handles[1], handles[0]],
                  title='')
        dc.plot_to_s3(STATIC_FILES_BUCKET, s3_key, plt)

    def run(self):
        '''Builds the dataset, draws figures 1 to 5, sends each of them
        to S3 and finally marks the task as done.'''
        # Threshold for testing
        year_threshold = 2008 if self.test else YEAR_THRESHOLD
        test_label = 'test' if self.test else ''

        # database setup
        database = 'dev' if self.test else 'production'
        logging.warning(f"Using {database} database")
        self.engine = get_mysql_engine(self.db_config_env, 'mysqldb', database)

        # All queries except last prepare temporary tables
        # and the final query produces the dataframe
        # which collects data, such that there is one row per
        # article / institute / institute country
        for query, is_last in sql_queries():
            if is_last:
                df = pd.read_sql(query, self.engine)
            else:
                self.engine.execute(query)
        logging.info(f"Dataset contains {len(df)} articles")

        # Manual hack to factor Hong Kong outside of China
        for city in ["Hong Kong", "Tsuen Wan", "Tuen Mun", "Tai Po", "Sai Kung"]:
            df.loc[df.institute_city == f"{city}, CN", "institute_country"] = "Hong Kong"

        # Manual hack to factor out transnational corps
        countries = set(df.institute_country)
        df['is_multinational'] = df['institute_name'].apply(
            lambda x: dc.is_multinational(x, countries))
        # the "city" of a multinational is its name, with the text from
        # the final "(" onwards removed
        df.loc[df.is_multinational, 'institute_city'] = (
            df.loc[df.is_multinational, 'institute_name']
            .apply(lambda x: ''.join(x.split("(")[:-1])))
        df.loc[df.is_multinational, 'institute_country'] = "Transnationals"

        # collect topics, determine which represents
        # deep_learning and apply flag
        terms = ["deep", "deep_learning", "reinforcement",
                 "neural_networks", "neural_network"]
        min_weight = 0.1 if self.test else 0.3
        dl_topic_ids = dc.get_article_ids_by_terms(self.engine,
                                                   terms=terms,
                                                   min_weight=min_weight)
        df['is_dl'] = df.article_id.apply(lambda i: i in dl_topic_ids)
        logging.info(f"Flagged {df.is_dl.sum()} deep learning articles in dataset")

        # the updated date takes priority over the created date
        df['date'] = df.apply(lambda row: row.article_updated or row.article_created,
                              axis=1)
        df['year'] = df.date.apply(lambda date: date.year)
        df = dc.add_before_date_flag(df, date_column='date',
                                     before_year=year_threshold)

        # first plot - dl/non dl distribution by country (top n)
        self._plot_dl_share_barh(
            df,
            group_column='institute_country',
            bar_width=0.6,
            xlabel='Percentage of DL papers in arXiv CompSci\ncategories, by country',
            s3_key=f'static/figure_1{test_label}.png')

        # second plot - dl/non dl distribution by city (top n)
        self._plot_dl_share_barh(
            df,
            group_column='institute_city',
            bar_width=0.8,
            xlabel=('Percentage of DL papers in arXiv CompSci\n'
                    'categories, by city or multinational'),
            s3_key=f'static/figure_2{test_label}.png')

        # third plot - percentage of dl papers by year
        deduped = df.drop_duplicates('article_id')
        start_year = 2000
        papers_by_year = pd.crosstab(deduped['year'],
                                     deduped['is_dl']).loc[start_year:]
        papers_by_year = (100 * papers_by_year.apply(lambda x: x / x.sum(), axis=1))
        papers_by_year = papers_by_year.drop(False, axis=1)  # drop non-dl column
        fig, ax = plt.subplots(figsize=(20, 8))
        papers_by_year.plot(ax=ax, legend=None, color=COLOR_A, linewidth=10)
        plt.xlabel('\nYear of paper publication')
        plt.ylabel('Percentage of DL papers\nin arXiv CompSci categories\n')
        # Ticks and labels are both derived from the same contiguous range of
        # years, so they stay aligned even if a year is absent from the data
        years = np.arange(min(papers_by_year.index),
                          max(papers_by_year.index) + 1, 1)
        plt.xticks(years)
        ax.set_xticklabels(['' if i % 2 else y for i, y in enumerate(years)])
        dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_3{test_label}.png', plt)

        # fourth plot - share of DL activity by arxiv
        # subject pre/post threshold
        df_all_cats = pd.read_sql("SELECT * FROM arxiv_categories", self.engine)
        # NB: compare the column itself with 'stat.ML'. Comparing the `.str`
        # accessor object is always False, which silently dropped stat.ML
        condition = (df_all_cats.id.str.startswith('cs.') |
                     (df_all_cats.id == 'stat.ML'))
        all_categories = list(df_all_cats.loc[condition].description)
        _before = f'Before {year_threshold}'
        _after = f'After {year_threshold}'

        cat_period_container = []
        for cat in all_categories:
            subset = df.loc[[cat in x for x in df['arxiv_category_descs']], :]
            subset_ct = pd.crosstab(subset[f'before_{year_threshold}'],
                                    subset.is_dl, normalize=0)
            # This is true for some categories in dev mode
            # due to a smaller dataset
            if list(subset_ct.index) != [False, True]:
                continue
            subset_ct.index = [_after, _before]
            # this try / except may not be required when
            # running on the full dataset
            try:
                dl_share = subset_ct[True]
            except KeyError:
                # no DL papers at all in this category
                continue
            cat_period_container.append(pd.Series(dl_share, name=cat))

        cat_thres_df = (pd.concat(cat_period_container, axis=1)
                        .T
                        .sort_values(_after, ascending=False))
        # everything outside of the top N_TOP is averaged into one 'Other' row
        # (pd.concat is used since DataFrame.append was removed in pandas 2.0)
        other = cat_thres_df[N_TOP:].mean().rename('Other')
        cat_thres_df = pd.concat([cat_thres_df[:N_TOP], other.to_frame().T])

        fig, ax = plt.subplots()
        (100*cat_thres_df[_before]).plot(markeredgecolor=COLOR_B, marker='o',
                                         color=COLOR_B, ax=ax,
                                         markerfacecolor=COLOR_B, linewidth=7.5)
        (100*cat_thres_df[_after]).plot(markeredgecolor=COLOR_A, marker='o',
                                        color=COLOR_A, ax=ax,
                                        markerfacecolor=COLOR_A, linewidth=7.5)
        ax.vlines(np.arange(len(cat_thres_df)),
                  ymin=len(cat_thres_df) * [0],
                  ymax=100*cat_thres_df[_after],
                  linestyle=':', linewidth=2)
        ax.set_xticks(np.arange(len(cat_thres_df)))
        ax.set_xticklabels(cat_thres_df.index, rotation=40, ha='right')
        ax.set_ylabel('Percentage of DL papers,\n'
                      'by arXiv CompSci category')
        ax.legend()
        dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_4{test_label}.png', plt)

        # fifth chart - changes in specialisation before / after threshold (top n countries)
        # number of rows (whether DL or not) per country
        rows_per_country = df.groupby('institute_country')['is_dl'].count()
        # remove the countries at or below the 25th percentile of row counts
        is_large = rows_per_country > rows_per_country.quantile(0.25)
        top_countries = df.institute_country.isin(rows_per_country.loc[is_large].index)

        # Only highly cited papers
        if self.test:
            # every paper passes the citation filter in test mode
            highly_cited = np.array([True]*len(df))
        else:
            # median citation count by year, with a floor of 1
            avg_citation_counts = (df[['year', 'citation_count']]
                                   .groupby('year')
                                   .quantile(0.5))
            avg_citation_counts['citation_count'] = (
                avg_citation_counts['citation_count']
                .apply(lambda x: x if x > 0 else 1))
            highly_cited = np.array([dc.highly_cited(row, avg_citation_counts)
                                     for _, row in df.iterrows()])

        # Min year threshold
        min_year = (df.year >= MIN_RCA_YEAR if not self.test
                    else df.year >= 2000)

        # Apply filters before calculating RCA
        top_df = df.loc[top_countries & highly_cited & min_year]
        logging.info(f'Got {len(top_df)} rows for RCA calculation.\n'
                     'Breakdown (ctry, cite, yr) = '
                     f'{sum(top_countries)}, '
                     f'{sum(highly_cited)}, {sum(min_year)}')
        before_year = top_df[f'before_{year_threshold}']
        logging.info("Before is DL = "
                     f"{sum(top_df.loc[before_year].is_dl)}")
        logging.info("After is DL = "
                     f"{sum(top_df.loc[~before_year].is_dl)}")

        # Calculate revealed comparative advantage
        pre_threshold_rca = dc.calculate_rca_by_country(
            top_df[top_df[f'before_{year_threshold}']],
            country_column='institute_country',
            commodity_column='is_dl')
        post_threshold_rca = dc.calculate_rca_by_country(
            top_df[~top_df[f'before_{year_threshold}']],
            country_column='institute_country',
            commodity_column='is_dl')
        rca_combined = (pd.merge(pre_threshold_rca, post_threshold_rca,
                                 left_index=True, right_index=True,
                                 suffixes=('_before', '_after'))
                        .rename(columns={'is_dl_before': _before,
                                         'is_dl_after': _after})
                        .sort_values(_after, ascending=False))

        # only plot the N_TOP countries with the most DL rows
        top_dl_countries = list(top_df[['institute_country', 'is_dl']]
                                .groupby('institute_country')
                                .sum()
                                .sort_values('is_dl', ascending=False)[:N_TOP]
                                .index)
        condition = rca_combined.index.isin(top_dl_countries)
        rca_combined_top = rca_combined[condition]

        fig, ax = plt.subplots()
        rca_combined_top[_before].plot(markeredgecolor=COLOR_B, marker='o',
                                       markersize=20, color='white', ax=ax,
                                       markerfacecolor=COLOR_B, linewidth=0)
        rca_combined_top[_after].plot(markeredgecolor=COLOR_A, marker='o',
                                      markersize=20, color='white', ax=ax,
                                      markerfacecolor=COLOR_A, linewidth=0)
        # connector colour depends on whether the value rose after the threshold
        col = [COLOR_A if x > y else '#d18270'
               for x, y in zip(rca_combined_top[_after],
                               rca_combined_top[_before])]
        ax.vlines(np.arange(len(rca_combined_top)),
                  ymin=rca_combined_top[_before],
                  ymax=rca_combined_top[_after],
                  linestyle=':', color=col, linewidth=4)
        # reference line at y = 1
        ax.hlines(y=1, xmin=-0.5, xmax=len(rca_combined_top)-0.5,
                  color='darkgrey', linestyle='--', linewidth=4)
        ax.set_xticks(np.arange(len(rca_combined_top)))
        ax.set_xlim(-1, len(rca_combined_top))
        ax.set_xticklabels(rca_combined_top.index, rotation=40, ha='right')
        ax.legend()
        ax.set_ylabel('Specialisation in Deep Learning\n'
                      'relative to other arXiv CompSci categories')
        ax.set_xlabel('')
        dc.plot_to_s3(STATIC_FILES_BUCKET, f'static/figure_5{test_label}.png', plt)

        # mark as done
        logging.warning("Task complete")
        self.output().touch()
class StandaloneAnalysisTask(AnalysisTask):
    """AnalysisTask with a default value for every parameter and with
    no upstream tasks, plus a `production` switch.

    Args:
        production (bool): if True then `test` is set to False
                           when `requires` is called
    """
    # NB: these defaults are evaluated once, when the module is imported
    date = luigi.DateParameter(default=dt.now())
    _routine_id = luigi.Parameter(default=f'StandaloneDLDC{dt.now()}')
    production = luigi.BoolParameter(default=False)
    test = luigi.BoolParameter(default=True)
    db_config_env = luigi.Parameter(default='MYSQLDB')
    db_config_path = luigi.Parameter(default='mysqldb.config')
    mag_config_path = luigi.Parameter(default=None)
    insert_batch_size = luigi.IntParameter(default=None)
    articles_from_date = luigi.Parameter(default=None)
    s3_path_prefix = luigi.Parameter(default=None)
    raw_data_path = luigi.Parameter(default=None)
    grid_task_kwargs = DictParameterPlus(default={})
    cherry_picked = luigi.Parameter(default=None)

    def requires(self):
        """Declares no dependencies (returns None); `test` is switched
        off here if `production` has been set."""
        if self.production:
            self.test = False
        # NOTE(review): the indentation of the original source was lost, so
        # it is assumed that this call is unconditional - confirm
        set_log_level(True)