"""
Compute statistics over data.

Helper file for computing various statistics over our data, such as mention
frequency and mention text frequency in the data (even if not labeled as an
anchor), among others.
"""
import argparse
import logging
import multiprocessing
import os
import time
from collections import Counter
import marisa_trie
import nltk
import numpy as np
import ujson
import ujson as json
from tqdm.auto import tqdm
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.utils import utils
from bootleg.utils.utils import get_lnrm
# Configure the root logger at import time: INFO level, timestamp-prefixed messages.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
def parse_args(argv=None):
    """Parse the command-line arguments of the statistics script.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which
            is the original behavior of this function.

    Returns:
        An ``argparse.Namespace`` with the attributes ``data_dir``,
        ``save_dir``, ``train_file``, ``entity_symbols_dir``, ``lower``,
        ``strip`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir", type=str, default="data/", help="Data dir for training data"
    )
    parser.add_argument(
        "--save_dir", type=str, default="data/", help="Data dir for saving stats"
    )
    parser.add_argument("--train_file", type=str, default="train.jsonl")
    parser.add_argument(
        "--entity_symbols_dir",
        type=str,
        default="entity_db/entity_mappings",
        help="Path to entities inside data_dir",
    )
    parser.add_argument("--lower", action="store_true", help="Lower aliases")
    parser.add_argument("--strip", action="store_true", help="Strip punc aliases")
    parser.add_argument(
        "--num_workers", type=int, help="Number of workers to parallelize", default=2
    )
    # Passing argv through lets callers parse an explicit argument list
    # instead of always reading the process command line.
    return parser.parse_args(argv)
def compute_histograms(save_dir, entity_symbols):
    """Dump a histogram of candidate-list sizes over all aliases.

    For every alias returned by ``entity_symbols.get_all_aliases()`` the
    length of ``entity_symbols.get_qid_cands(alias)`` is tallied; the
    resulting counter (size -> number of aliases with that size) is written
    to ``candidate_counts.json`` inside ``save_dir``.

    Args:
        save_dir: directory that receives ``candidate_counts.json``.
        entity_symbols: object exposing ``get_all_aliases`` and
            ``get_qid_cands``.
    """
    candidate_sizes = (
        len(entity_symbols.get_qid_cands(alias))
        for alias in entity_symbols.get_all_aliases()
    )
    histogram = Counter(candidate_sizes)
    out_path = os.path.join(save_dir, "candidate_counts.json")
    utils.dump_json_file(filename=out_path, contents=histogram)
def get_num_lines(input_src):
    """Count the lines of a UTF-8 encoded text file.

    Args:
        input_src: path of the file to read.

    Returns:
        The number of lines as an ``int``. If reading the file fails part-way
        (for example on invalid UTF-8), the error is logged and ``0`` is
        returned.

    Raises:
        OSError: if the file cannot be opened; only errors raised while
            reading are logged and swallowed.
    """
    with open(input_src, "r", encoding="utf-8") as in_file:
        try:
            return sum(1 for _ in in_file)
        except Exception as e:
            logging.error("ERROR READING IN TRAINING DATA")
            logging.error(e)
            # Return an int (not a list) so callers that do arithmetic on the
            # result keep working; 0 is still falsy like the former [].
            return 0
def chunk_text_data(input_src, chunk_files, chunk_size, num_lines):
    """Split a text file into consecutive chunk files of ``chunk_size`` lines.

    Lines are copied in order: the first ``chunk_size`` lines go to
    ``chunk_files[0]``, the next ``chunk_size`` to ``chunk_files[1]`` and so
    on. Every path in ``chunk_files`` exists when the function returns; chunk
    files that receive no lines are created (or truncated) as empty files so
    that readers never hit a missing or stale file.

    Args:
        input_src: path of the UTF-8 input file.
        chunk_files: list of output paths, one per chunk.
        chunk_size: number of lines per chunk. If it is never reached (e.g.
            ``0``), all lines end up in the first chunk.
        num_lines: total number of lines in ``input_src``. No longer needed
            because chunk files are opened lazily; kept so existing callers
            stay valid.

    Raises:
        IndexError: if the input holds more lines than
            ``len(chunk_files) * chunk_size``.
    """
    logging.info(f"Reading in {input_src}")
    start = time.time()
    chunk_id = 0
    num_lines_in_chunk = 0
    num_opened = 0
    out_file = None
    try:
        with open(input_src, "r", encoding="utf-8") as in_file:
            for line in in_file:
                # Open a chunk only once there is a line to put into it.
                if out_file is None:
                    # Default encoding on purpose: compute_occurrences_single
                    # reads the chunk files back without an explicit encoding.
                    out_file = open(chunk_files[chunk_id], "w")
                    num_opened += 1
                out_file.write(line)
                num_lines_in_chunk += 1
                # move on to new chunk when it hits chunk size
                if num_lines_in_chunk == chunk_size:
                    out_file.close()
                    out_file = None
                    chunk_id += 1
                    num_lines_in_chunk = 0
    finally:
        # Never leak the current handle, even if reading or writing fails.
        if out_file is not None:
            out_file.close()
    # Create/truncate the chunks that got no data so each listed path exists
    # and cannot carry leftovers from an earlier run.
    for unused_path in chunk_files[num_opened:]:
        with open(unused_path, "w"):
            pass
    logging.info(f"Wrote out data chunks in {round(time.time() - start, 2)}s")
def compute_occurrences_single(args, max_alias_len=6):
    """Compute occurrence statistics over one chunk of data (single process).

    Args:
        args: sequence ``(data_file, aliases_file, lower, strip)`` packed
            into one argument (``compute_occurrences`` dispatches this
            function through ``pool.map``). ``data_file`` is a JSON-lines
            file whose records are read for the keys ``"sentence"``,
            ``"unswap_aliases"`` and ``"qids"``; ``aliases_file`` is a JSON
            file holding the aliases loaded into the trie; ``lower`` and
            ``strip`` are forwarded to ``get_lnrm``.
        max_alias_len: n-grams of ``max_alias_len + 1`` words down to one
            word are looked up in the alias trie.
            NOTE(review): the ``+ 1`` means 7-grams are checked for the
            default of 6 - confirm this is intended.

    Returns:
        Dict with five ``Counter`` objects: ``ent_occurrences`` (per qid),
        ``alias_occurrences`` (per labeled alias), ``alias_text_occurrences``
        (per normalized n-gram found in the alias trie),
        ``alias_pair_occurrences`` (per number of aliases in a sentence) and
        ``alias_entity_pair`` (per ``"alias|qid"`` string).
    """
    data_file, aliases_file, lower, strip = args
    # Count lines first so the progress bar has a total; close the handle.
    with open(data_file) as in_file:
        num_lines = sum(1 for _ in in_file)
    with open(aliases_file) as in_file:
        all_aliases = marisa_trie.Trie(ujson.load(in_file))
    # entity histogram
    ent_occurrences = Counter()
    # alias histogram
    alias_occurrences = Counter()
    # alias text occurrences
    alias_text_occurrences = Counter()
    # number of aliases per sentence
    alias_pair_occurrences = Counter()
    # alias|entity histogram
    alias_entity_pair = Counter()
    with open(data_file, "r") as in_file:
        for raw_line in tqdm(in_file, total=num_lines):
            line = json.loads(raw_line.strip())
            # Tokenize once per sentence instead of once per n-gram size.
            words = line["sentence"].split()
            for n in range(max_alias_len + 1, 0, -1):
                for gram_words in nltk.ngrams(words, n):
                    gram_attempt = get_lnrm(" ".join(gram_words), lower, strip)
                    if gram_attempt in all_aliases:
                        alias_text_occurrences[gram_attempt] += 1
            # Get aliases in wikipedia _before_ the swapping - these represent
            # the true textual aliases
            aliases = line["unswap_aliases"]
            qids = line["qids"]
            for qid, alias in zip(qids, aliases):
                ent_occurrences[qid] += 1
                alias_occurrences[alias] += 1
                alias_entity_pair[alias + "|" + qid] += 1
            alias_pair_occurrences[len(aliases)] += 1
    results = {
        "ent_occurrences": ent_occurrences,
        "alias_occurrences": alias_occurrences,
        "alias_text_occurrences": alias_text_occurrences,
        "alias_pair_occurrences": alias_pair_occurrences,
        "alias_entity_pair": alias_entity_pair,
    }
    return results
def compute_occurrences(save_dir, data_file, entity_dump, lower, strip, num_workers=8):
    """Compute occurrence statistics over ``data_file`` in parallel and save them.

    The data is split into per-worker chunk files under ``<save_dir>/tmp``,
    each chunk is processed by ``compute_occurrences_single`` in a
    multiprocessing pool, the per-chunk counters are summed, and the totals
    are written as JSON files into ``save_dir``: ``entity_count.json``,
    ``alias_counts.json``, ``alias_text_counts.json``,
    ``alias_pair_occurrences.json`` and ``alias_entity_counts.json``.
    The temporary files under ``tmp`` are not removed by this function.

    Args:
        save_dir: directory receiving the JSON outputs and the ``tmp`` folder.
        data_file: path of the JSON-lines data file.
        entity_dump: object whose ``get_all_aliases()`` supplies the aliases.
        lower: forwarded to the workers (lower-casing option for ``get_lnrm``).
        strip: forwarded to the workers (stripping option for ``get_lnrm``).
        num_workers: upper bound on worker processes; the CPU count is a
            second bound and at least one process is always used.
    """
    all_aliases = entity_dump.get_all_aliases()
    chunk_file_path = os.path.join(save_dir, "tmp")
    all_aliases_f = os.path.join(chunk_file_path, "all_aliases.json")
    utils.ensure_dir(chunk_file_path)
    # Close (and thereby flush) the alias file before any worker reads it.
    with open(all_aliases_f, "w") as out_f:
        ujson.dump(all_aliases, out_f, ensure_ascii=False)
    # divide up data into chunks
    num_lines = get_num_lines(data_file)
    # Clamp to >= 1 so a zero/negative num_workers cannot cause a division
    # by zero or an invalid pool size.
    num_processes = max(1, min(num_workers, int(multiprocessing.cpu_count())))
    logging.info(f"Using {num_processes} workers...")
    chunk_size = int(np.ceil(num_lines / (num_processes)))
    chunk_infiles = [
        os.path.join(f"{chunk_file_path}", f"data_chunk_{chunk_id}_in.jsonl")
        for chunk_id in range(num_processes)
    ]
    chunk_text_data(data_file, chunk_infiles, chunk_size, num_lines)
    # Only the first ceil(num_lines / chunk_size) chunks receive data (e.g.
    # 5 lines over 4 workers fill just 3 chunks of size 2); dispatching the
    # rest would read files that were never written in this run.
    if chunk_size > 0:
        num_chunks = max(1, int(np.ceil(num_lines / chunk_size)))
    else:
        num_chunks = 1
    subprocess_args = [
        [chunk_infiles[i], all_aliases_f, lower, strip] for i in range(num_chunks)
    ]
    pool = multiprocessing.Pool(processes=num_processes)
    try:
        results = pool.map(compute_occurrences_single, subprocess_args)
    finally:
        # Shut the workers down even if a chunk raised.
        pool.close()
        pool.join()
    logging.info("Finished collecting counts")
    logging.info("Merging counts....")
    # Result key -> output file name, in the order the files are written.
    output_files = {
        "ent_occurrences": "entity_count.json",
        "alias_occurrences": "alias_counts.json",
        "alias_text_occurrences": "alias_text_counts.json",
        "alias_pair_occurrences": "alias_pair_occurrences.json",
        "alias_entity_pair": "alias_entity_counts.json",
    }
    # merge counters together
    merged = {key: Counter() for key in output_files}
    for result_set in tqdm(results, desc="Merging"):
        for key in output_files:
            merged[key] += result_set[key]
    # save counters
    for key, file_name in output_files.items():
        utils.dump_json_file(
            filename=os.path.join(save_dir, file_name), contents=merged[key]
        )
def main():
    """Run the statistics computation end to end.

    Parses the command-line arguments, loads the entity symbols from
    ``<data_dir>/<entity_symbols_dir>``, and writes occurrence statistics for
    ``<data_dir>/<train_file>`` into ``<save_dir>/stats``.
    """
    args = parse_args()
    logging.info(json.dumps(vars(args), indent=4))
    symbols_dir = os.path.join(args.data_dir, args.entity_symbols_dir)
    entity_symbols = EntitySymbols.load_from_cache(load_dir=symbols_dir)
    train_file = os.path.join(args.data_dir, args.train_file)
    save_dir = os.path.join(args.save_dir, "stats")
    logging.info(f"Will save data to {save_dir}")
    utils.ensure_dir(save_dir)
    # The candidate histogram step is currently disabled:
    # compute_histograms(save_dir, entity_symbols)
    compute_occurrences(
        save_dir,
        train_file,
        entity_symbols,
        args.lower,
        args.strip,
        num_workers=args.num_workers,
    )
# Run the statistics pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()