Source code for bootleg.end2end.extract_mentions
"""
Extract mentions.
This file takes in a jsonlines file with sentences
and extracts aliases and spans using a pre-computed alias table.
"""
import argparse
import logging
import multiprocessing
import os
import time
import jsonlines
import numpy as np
from tqdm.auto import tqdm
from bootleg.symbols.constants import ANCHOR_KEY
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.utils.classes.nested_vocab_tries import VocabularyTrie
from bootleg.utils.mention_extractor_utils import (
ngram_spacy_extract_aliases,
spacy_extract_aliases,
)
logger = logging.getLogger(__name__)
MENTION_EXTRACTOR_OPTIONS = {
"ngram_spacy": ngram_spacy_extract_aliases,
"spacy": spacy_extract_aliases,
}
def parse_args():
"""Generate args."""
parser = argparse.ArgumentParser()
parser.add_argument(
"--in_file", type=str, required=True, help="File to extract mentions from"
)
parser.add_argument(
"--out_file",
type=str,
required=True,
help="File to write extracted mentions to",
)
parser.add_argument(
"--entity_db_dir", type=str, required=True, help="Path to entity db"
)
parser.add_argument(
"--extract_method",
type=str,
choices=list(MENTION_EXTRACTOR_OPTIONS.keys()),
default="ngram_spacy",
)
parser.add_argument("--min_alias_len", type=int, default=1)
parser.add_argument("--max_alias_len", type=int, default=6)
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--num_chunks", type=int, default=8)
parser.add_argument("--verbose", action="store_true")
return parser.parse_args()
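# Example invocation (paths are placeholders, not files shipped with the library):
#   python -m bootleg.end2end.extract_mentions \
#       --in_file data/sentences.jsonl \
#       --out_file data/sentences_with_mentions.jsonl \
#       --entity_db_dir models/entity_db \
#       --extract_method ngram_spacy \
#       --num_workers 8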
def create_out_line(sent_obj, final_aliases, final_spans, found_char_spans):
"""Create JSON output line.
Args:
sent_obj: input sentence JSON
final_aliases: list of final aliases
final_spans: list of final spans
found_char_spans: list of final char spans
Returns: JSON object
"""
sent_obj["aliases"] = final_aliases
sent_obj["spans"] = final_spans
sent_obj["char_spans"] = found_char_spans
    # we don't know the true QID (or even if there is one) at this stage,
    # so we assign the dummy QID "Q-1"; the commented-out line below would
    # instead assign the most popular candidate so models w/o NIL can also
    # evaluate this data
sent_obj["qids"] = ["Q-1"] * len(final_aliases)
# global alias2qids
# sent_obj["qids"] = [alias2qids[alias][0] for alias in final_aliases]
sent_obj[ANCHOR_KEY] = [True] * len(final_aliases)
return sent_obj
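# Illustrative example (hypothetical sentence; spans are assumed to be word-index
# pairs and char_spans character-offset pairs, as returned by the extractors):
#   >>> sent = {"sentence": "Abraham Lincoln was born in Kentucky"}
#   >>> create_out_line(sent, ["abraham lincoln"], [[0, 2]], [[0, 15]])
# returns the same dict extended with the aliases, spans, and char spans above,
# qids == ["Q-1"], and the ANCHOR_KEY field set to [True].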
def chunk_text_data(input_src, chunk_files, chunk_size, num_lines):
"""Chunk text input file into chunk_size chunks.
Args:
input_src: input file
chunk_files: list of chunk file names
chunk_size: chunk size in number of lines
num_lines: total number of lines
"""
logger.debug(f"Reading in {input_src}")
start = time.time()
# write out chunks as text data
chunk_id = 0
num_lines_in_chunk = 0
# keep track of what files are written
out_file = open(chunk_files[chunk_id], "w")
with open(input_src, "r", encoding="utf-8") as in_file:
for i, line in enumerate(in_file):
out_file.write(line)
num_lines_in_chunk += 1
# move on to new chunk when it hits chunk size
if num_lines_in_chunk == chunk_size:
chunk_id += 1
# reset number of lines in chunk and open new file if not at end
num_lines_in_chunk = 0
out_file.close()
if i < (num_lines - 1):
out_file = open(chunk_files[chunk_id], "w")
out_file.close()
logger.debug(f"Wrote out data chunks in {round(time.time() - start, 2)}s")
def subprocess(args):
"""
    Extract mentions in a single process.
Args:
args: subprocess args
"""
in_file = args["in_file"]
out_file = args["out_file"]
    extract_method = args["extract_method"]
min_alias_len = args["min_alias_len"]
max_alias_len = args["max_alias_len"]
verbose = args["verbose"]
all_aliases = VocabularyTrie(load_dir=args["all_aliases_trie_f"])
num_lines = sum(1 for _ in open(in_file))
with jsonlines.open(in_file) as f_in, jsonlines.open(out_file, "w") as f_out:
for line in tqdm(
f_in, total=num_lines, disable=not verbose, desc="Processing data"
):
found_aliases, found_spans, found_char_spans = MENTION_EXTRACTOR_OPTIONS[
                extract_method
](line["sentence"], all_aliases, min_alias_len, max_alias_len)
f_out.write(
create_out_line(line, found_aliases, found_spans, found_char_spans)
)
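# Illustrative args payload (mirrors the dicts built in extract_mentions below;
# file names shortened to their prep-directory basenames):
#   {
#       "in_file": "data_chunk_0_in.jsonl",
#       "out_file": "data_chunk_0_out.jsonl",
#       "extract_method": "ngram_spacy",
#       "min_alias_len": 1,
#       "max_alias_len": 6,
#       "all_aliases_trie_f": "mention_extract_alias.marisa",
#       "verbose": False,
#   }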
def merge_files(chunk_outfiles, out_filepath):
"""Merge output files.
Args:
chunk_outfiles: list of chunk files
out_filepath: final output file path
"""
sent_idx_unq = 0
with jsonlines.open(out_filepath, "w") as f_out:
for file in chunk_outfiles:
with jsonlines.open(file) as f_in:
for line in f_in:
if "sent_idx_unq" not in line:
line["sent_idx_unq"] = sent_idx_unq
f_out.write(line)
sent_idx_unq += 1
def extract_mentions(
in_filepath,
out_filepath,
entity_db_dir,
extract_method="ngram_spacy",
min_alias_len=1,
max_alias_len=6,
num_workers=8,
num_chunks=None,
verbose=False,
):
"""Extract mentions from file.
Args:
in_filepath: input file
out_filepath: output file
entity_db_dir: path to entity db
extract_method: mention extraction method
min_alias_len: minimum alias length (in words)
max_alias_len: maximum alias length (in words)
num_workers: number of multiprocessing workers
num_chunks: number of subchunks to feed to workers
verbose: verbose boolean
"""
assert os.path.exists(in_filepath), f"{in_filepath} does not exist"
entity_symbols: EntitySymbols = EntitySymbols.load_from_cache(entity_db_dir)
all_aliases_trie: VocabularyTrie = entity_symbols.get_all_alias_vocabtrie()
if num_chunks is None:
num_chunks = num_workers
start_time = time.time()
# multiprocessing
if num_workers > 1:
prep_dir = os.path.join(os.path.dirname(out_filepath), "prep")
os.makedirs(prep_dir, exist_ok=True)
all_aliases_trie_f = os.path.join(prep_dir, "mention_extract_alias.marisa")
all_aliases_trie.dump(all_aliases_trie_f)
# chunk file for multiprocessing
num_lines = sum([1 for _ in open(in_filepath)])
num_processes = min(num_workers, int(multiprocessing.cpu_count()))
num_chunks = min(num_lines, num_chunks)
logger.debug(f"Using {num_processes} workers...")
chunk_size = int(np.ceil(num_lines / num_chunks))
chunk_file_path = os.path.join(prep_dir, "data_chunk")
chunk_infiles = [
f"{chunk_file_path}_{chunk_id}_in.jsonl" for chunk_id in range(num_chunks)
]
chunk_text_data(in_filepath, chunk_infiles, chunk_size, num_lines)
logger.debug("Calling subprocess...")
# call subprocesses on chunks
pool = multiprocessing.Pool(processes=num_processes)
chunk_outfiles = [
f"{chunk_file_path}_{chunk_id}_out.jsonl" for chunk_id in range(num_chunks)
]
subprocess_args = [
{
"in_file": chunk_infiles[i],
"out_file": chunk_outfiles[i],
"extract_method": extract_method,
"min_alias_len": min_alias_len,
"max_alias_len": max_alias_len,
"all_aliases_trie_f": all_aliases_trie_f,
"verbose": verbose,
}
for i in range(num_chunks)
]
pool.map(subprocess, subprocess_args)
pool.close()
pool.join()
logger.debug("Merging files...")
# write all chunks back in single file
merge_files(chunk_outfiles, out_filepath)
logger.debug("Removing temporary files...")
# clean up and remove chunked files
for file in chunk_infiles:
try:
os.remove(file)
except PermissionError:
pass
for file in chunk_outfiles:
try:
os.remove(file)
except PermissionError:
pass
try:
os.remove(all_aliases_trie_f)
except PermissionError:
pass
# single process
else:
logger.debug("Using 1 worker...")
with jsonlines.open(in_filepath, "r") as in_file, jsonlines.open(
out_filepath, "w"
) as out_file:
sent_idx_unq = 0
for line in in_file:
(
found_aliases,
found_spans,
found_char_spans,
) = MENTION_EXTRACTOR_OPTIONS[extract_method](
line["sentence"], all_aliases_trie, min_alias_len, max_alias_len
)
new_line = create_out_line(
line, found_aliases, found_spans, found_char_spans
)
if "sent_idx_unq" not in new_line:
new_line["sent_idx_unq"] = sent_idx_unq
sent_idx_unq += 1
out_file.write(new_line)
logger.debug(
f"Finished in {time.time() - start_time} seconds. Wrote out to {out_filepath}"
)
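# Programmatic usage sketch (paths are placeholders; entity_db_dir must point at
# a directory loadable by EntitySymbols.load_from_cache):
#   extract_mentions(
#       "data/sentences.jsonl",
#       "data/sentences_with_mentions.jsonl",
#       entity_db_dir="models/entity_db",
#       extract_method="ngram_spacy",
#       num_workers=4,
#   )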
def main():
"""Run."""
args = parse_args()
in_file = args.in_file
out_file = args.out_file
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")
print(args)
extract_mentions(
in_file,
out_file,
        entity_db_dir=args.entity_db_dir,
        extract_method=args.extract_method,
min_alias_len=args.min_alias_len,
max_alias_len=args.max_alias_len,
num_workers=args.num_workers,
num_chunks=args.num_chunks,
verbose=args.verbose,
)
if __name__ == "__main__":
main()