# Source code for bootleg.utils.preprocessing.get_train_qid_counts

"""
Compute QID counts.

Helper function that computes a dictionary of QID -> count in training data.

If a QID is not in this dictionary, it has a count of zero.
"""

import argparse
import multiprocessing
from collections import defaultdict

import ujson
from tqdm.auto import tqdm

from bootleg.utils import utils


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv``,
            which is the behavior of a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with ``train_file`` and ``out_file`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train_file",
        type=str,
        default="/dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0114/train.jsonl",
        help="Path to the training data file (one JSON object per line)",
    )
    parser.add_argument(
        "--out_file",
        type=str,
        default="/dfs/scratch0/lorr1/projects/bootleg-data/data/wiki_title_0114/train_qidcnt.json",
        # The previous help text ("Regularization of each qid") did not
        # describe this option, which is the path the counts are written to.
        help="Path of the JSON file to write the QID -> count dictionary to",
    )
    return parser.parse_args(argv)
def get_counts(num_processes, file):
    """Count how often each QID occurs across all lines of ``file``.

    Every line of ``file`` is handed to ``get_counts_hlp`` in a worker
    process, and the per-line counts are summed into a single dictionary.

    Args:
        num_processes: number of worker processes for the pool.
        file: path of the file to read; each line is passed to
            ``get_counts_hlp`` unchanged.

    Returns:
        defaultdict(int) mapping QID -> total count over the file.
    """
    # Count lines first (only used for the progress bar total) and close the
    # handle before any worker process is started.
    with open(file) as in_f:
        num_lines = sum(1 for _ in in_f)

    qid_cnts = defaultdict(int)
    pool = multiprocessing.Pool(processes=num_processes)
    try:
        with open(file) as in_f:
            for res in tqdm(
                pool.imap_unordered(get_counts_hlp, in_f, chunksize=1000),
                total=num_lines,
                desc="Gathering counts",
            ):
                for qid, cnt in res.items():
                    qid_cnts[qid] += cnt
    except BaseException:
        # Do not leave worker processes behind if reading or a worker fails.
        pool.terminate()
        raise
    else:
        pool.close()
    finally:
        pool.join()
    return qid_cnts
def get_counts_hlp(line):
    """Count the QIDs of a single serialized line.

    Args:
        line: JSON string that decodes to an object with a ``"qids"`` entry
            holding an iterable of QIDs.

    Returns:
        defaultdict(int) mapping each QID in ``"qids"`` to the number of
        times it appears in that line.
    """
    tally = defaultdict(int)  # qid -> number of occurrences in this line
    record = ujson.loads(line)
    for seen_qid in record["qids"]:
        tally[seen_qid] = tally[seen_qid] + 1
    return tally
def main():
    """Compute QID counts for the training file and dump them as JSON.

    Parses the command-line arguments, counts QIDs in ``args.train_file``
    with a pool of worker processes, and writes the resulting dictionary to
    ``args.out_file`` via ``utils.dump_json_file``.
    """
    args = parse_args()
    print(ujson.dumps(vars(args), indent=4))
    # Use ~80% of the cores, but never fewer than one: on a single-core
    # machine int(0.8 * 1) is 0 and multiprocessing.Pool rejects that.
    num_processes = max(1, int(0.8 * multiprocessing.cpu_count()))
    print(f"Getting slice counts from {args.train_file}")
    qid_cnts = get_counts(num_processes, args.train_file)
    utils.dump_json_file(args.out_file, qid_cnts)
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()