Source code for bootleg.utils.eval_utils

"""Bootleg eval utils."""
import glob
import logging
import math
import multiprocessing
import os
import shutil
import time
from collections import defaultdict

import emmental
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import ujson
from emmental.utils.utils import array_to_numpy, prob_to_pred
from tqdm.auto import tqdm

from bootleg import log_rank_0_debug, log_rank_0_info
from bootleg.task_config import NED_TASK
from bootleg.utils import data_utils, utils
from bootleg.utils.classes.nested_vocab_tries import (
    TwoLayerVocabularyScoreTrie,
    VocabularyTrie,
)
from bootleg.utils.utils import strip_nan, try_rmtree

logger = logging.getLogger(__name__)


def masked_class_logsoftmax(pred, mask, dim=2, temp=1.0, zero_delta=1e-45):
    """
    Masked logsoftmax.

    Mask of 0/False means mask value (ignore it).

    Args:
        pred: input tensor
        mask: mask
        dim: softmax dimension
        temp: softmax temperature
        zero_delta: small value to add so that vector + (mask+zero_delta).log() is not NaN for all 0s

    Returns:
        masked softmax tensor
    """
    assert temp > 0, "You can't have a temperature of 0"
    # pred is batch x M x K
    # https://github.com/allenai/allennlp/blob/b6cc9d39651273e8ec2a7e334908ffa9de5c2026/allennlp/nn/util.py#L272-L303
    pred = pred / temp
    # We could also use 1e-46, but 1e-45 feels safer
    # WARNING: might need 5e-16 with FP16 and training
    pred = pred + (mask + zero_delta).log()
    # compute softmax over the K dimension
    return F.log_softmax(input=pred, dim=dim)

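# Illustrative usage of masked_class_logsoftmax (added for exposition; the tensor
# values below are made up). With batch=1, M=2 mentions, and K=3 candidates, where
# the second mention has only 2 valid candidates, the masked slot ends up with
# probability ~0 after exponentiating.
def _example_masked_class_logsoftmax():
    pred = torch.tensor([[[2.0, 1.0, 0.5], [0.1, 0.3, 0.0]]])  # batch x M x K
    mask = torch.tensor([[[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]]])  # 0/False = ignore
    log_probs = masked_class_logsoftmax(pred, mask)
    # Each mention's unmasked probabilities sum to ~1; the masked entry is ~0
    return log_probs.exp()
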
def map_aliases_to_candidates(
    train_in_candidates, max_candidates, alias_cand_map, aliases
):
    """
    Get list of QID candidates for each alias.

    Args:
        train_in_candidates: whether the model assumes all gold QIDs are in the candidate lists
            (if False, a NC "no candidate" entity is prepended)
        max_candidates: maximum number of candidates
        alias_cand_map: alias -> candidate QIDs in dict or TwoLayerVocabularyScoreTrie format
        aliases: list of aliases

    Returns:
        List of lists of QIDs
    """
    not_tic = 1 - train_in_candidates
    res = []
    for al in aliases:
        if isinstance(alias_cand_map, dict):
            if al in alias_cand_map:
                cands = [qid_pair[0] for qid_pair in alias_cand_map[al]]
            else:
                cands = ["-1"] * max_candidates
        else:
            if alias_cand_map.is_key_in_trie(al):
                cands = alias_cand_map.get_value(al, keep_score=False)
            else:
                cands = ["-1"] * max_candidates
        cands = cands + ["-1"] * (max_candidates - len(cands))
        res.append(not_tic * ["NC"] + cands)
    return res

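# Illustrative usage (added for exposition; the alias and QIDs are made up). With
# train_in_candidates=False, an "NC" slot is prepended; unknown aliases and short
# candidate lists are padded with "-1" up to max_candidates.
def _example_map_aliases_to_candidates():
    alias_cand_map = {"lincoln": [("Q91", 1.0), ("Q28260", 0.5)]}  # alias -> (QID, score) pairs
    return map_aliases_to_candidates(
        train_in_candidates=False,
        max_candidates=3,
        alias_cand_map=alias_cand_map,
        aliases=["lincoln", "unseen alias"],
    )
    # -> [["NC", "Q91", "Q28260", "-1"], ["NC", "-1", "-1", "-1"]]
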
def map_candidate_qids_to_eid(candidate_qids, qid2eid):
    """
    Get list of EID candidates for each alias.

    Args:
        candidate_qids: list of lists of candidate QIDs
        qid2eid: mapping of QID to entity ID (dict or VocabularyTrie)

    Returns:
        List of lists of EIDs
    """
    res = []
    for cand_list in candidate_qids:
        res_cands = []
        for q in cand_list:
            if q == "NC":
                res_cands.append(0)
            elif q == "-1":
                res_cands.append(1)
            else:
                # dict and VocabularyTrie share the same item-lookup syntax
                res_cands.append(qid2eid[q])
        res.append(res_cands)
    return res

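# Illustrative usage (added for exposition; the toy qid2eid mapping is made up).
# "NC" always maps to EID 0 and the "-1" pad to EID 1, so real entity IDs sit
# above those reserved values in this sketch.
def _example_map_candidate_qids_to_eid():
    qid2eid = {"Q91": 2, "Q28260": 3}
    return map_candidate_qids_to_eid([["NC", "Q91", "Q28260", "-1"]], qid2eid)
    # -> [[0, 2, 3, 1]]
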
def get_eval_folder(file):
    """
    Return the eval folder for the given evaluation file.

    Stored in log_path/filename/model_name.

    Args:
        file: eval file

    Returns:
        eval folder
    """
    return os.path.join(
        emmental.Meta.log_path,
        os.path.splitext(file)[0],
        os.path.splitext(
            os.path.basename(emmental.Meta.config["model_config"]["model_path"])
        )[0],
    )

def write_disambig_metrics_to_csv(file_path, dictionary):
    """Save disambiguation metrics in the dictionary to file_path.

    Args:
        file_path: file path
        dictionary: dictionary of scores (output of Emmental score)
    """
    # Only saving NED; ignore Type. The dictionary has keys such as
    # "NED/Bootleg/dev/unif_HD/total_men", which corresponds to
    # task/dataset/split/slice/metric; the value is that metric as calculated
    # on the dataset. Sort keys to ensure that the code below stays in a
    # consistent order across slices.
    all_keys = [x for x in sorted(dictionary.keys()) if x.startswith(NED_TASK)]
    # This line uses endswith("total_men") because we just want one copy of each
    # task/dataset/split/slice combo. We are not actually using the total_men
    # information here (acc_boot, etc. would work equally well).
    task, dataset, split, slices = list(
        zip(*[x.split("/")[:4] for x in all_keys if x.endswith("total_men")])
    )
    acc_boot = [dictionary[x] for x in all_keys if x.endswith("acc_boot")]
    acc_boot_notNC = [dictionary[x] for x in all_keys if x.endswith("acc_notNC_boot")]
    mentions = [dictionary[x] for x in all_keys if x.endswith("total_men")]
    mentions_notNC = [dictionary[x] for x in all_keys if x.endswith("total_notNC_men")]
    acc_pop = [dictionary[x] for x in all_keys if x.endswith("acc_pop")]
    acc_pop_notNC = [dictionary[x] for x in all_keys if x.endswith("acc_notNC_pop")]
    df_info = {
        "task": task,
        "dataset": dataset,
        "split": split,
        "slice": slices,
        "mentions": mentions,
        "mentions_notNC": mentions_notNC,
        "acc_boot": acc_boot,
        "acc_boot_notNC": acc_boot_notNC,
        "acc_pop": acc_pop,
        "acc_pop_notNC": acc_pop_notNC,
    }
    df = pd.DataFrame(data=df_info)
    df.to_csv(file_path, index=False)

def get_sent_idx2num_mens(data_file):
    """Get the map from sentence index to number of mentions and to data.

    Used for calculating offsets and chunking the file.

    Args:
        data_file: eval file

    Returns:
        Dict of sentence index -> number of mentions per sentence,
        Dict of sentence index -> input line
    """
    sent_idx2num_mens = {}
    sent_idx2row = {}
    total_num_mentions = 0
    with open(data_file) as f:
        for line in tqdm(
            f,
            total=sum(1 for _ in open(data_file)),
            desc="Getting sentidx2line mapping",
        ):
            line = ujson.loads(line)
            # Keep track of the start idx in the condensed memory mapped file for
            # each sentence (varying number of aliases)
            assert (
                line["sent_idx_unq"] not in sent_idx2num_mens
            ), f'Sentence indices must be unique. {line["sent_idx_unq"]} already seen.'
            # Save as string for the Marisa Trie later
            sent_idx2row[str(line["sent_idx_unq"])] = line
            sent_idx2num_mens[str(line["sent_idx_unq"])] = len(line["aliases"])
            # We include false aliases for debugging (and alias_pos includes them)
            total_num_mentions += len(line["aliases"])
    log_rank_0_debug(
        logger, f"Total number of mentions across all sentences: {total_num_mentions}"
    )
    return sent_idx2num_mens, sent_idx2row

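# Data format sketch (added for exposition; the file path and line contents are
# made up). Each line of the eval file is JSON with at least "sent_idx_unq" and
# "aliases"; the returned maps are keyed by the stringified sentence index.
def _example_get_sent_idx2num_mens():
    demo_file = "sent_idx_demo.jsonl"  # hypothetical path
    with open(demo_file, "w") as f:
        f.write(ujson.dumps({"sent_idx_unq": 0, "aliases": ["lincoln", "washington"]}) + "\n")
    sent_idx2num_mens, sent_idx2row = get_sent_idx2num_mens(demo_file)
    # sent_idx2num_mens == {"0": 2}; sent_idx2row["0"] is the parsed line
    return sent_idx2num_mens, sent_idx2row
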
# Modified from
# https://github.com/SenWu/emmental/blob/master/src/emmental/model.py#L455
# to support dump_preds_accumulation_steps
@torch.no_grad()
def batched_pred_iter(
    model,
    dataloader,
    dump_preds_accumulation_steps,
    sent_idx2num_mens,
):
    """
    Predict from dataloader.

    Predict from dataloader taking into account eval accumulation steps. Will yield a
    new prediction set after each set of accumulation steps for writing out. If a
    sentence or batch doesn't have any mentions, it will not be returned by this method.

    Recall that we split up sentences that are too long to feed to the model. We use
    the sent_idx2num_mens dict to ensure full sentences have been evaluated before
    returning; otherwise we'd have incomplete sentences to merge together when dumping.

    Args:
        model: model
        dataloader: the dataloader to predict on
        dump_preds_accumulation_steps: number of eval steps to run before yielding
        sent_idx2num_mens: Dict of sentence index to number of mentions

    Returns:
        Iterator over result dicts.
    """

    def collect_result(uid_d, gold_d, pred_d, prob_d, out_d, cur_sentidx_nummen):
        """Merge results for the sentences where all mentions have been evaluated."""
        final_uid_d = defaultdict(list)
        final_prob_d = defaultdict(list)
        final_pred_d = defaultdict(list)
        final_gold_d = defaultdict(list)
        final_out_d = defaultdict(lambda: defaultdict(list))
        sentidxs_finalized = []
        log_rank_0_debug(logger, f"Collecting {len(cur_sentidx_nummen)} results")
        for sent_idx, cur_mention_set in cur_sentidx_nummen.items():
            assert (
                len(cur_mention_set) <= sent_idx2num_mens[str(sent_idx)]
            ), f"Too many mentions for {sent_idx}: {cur_mention_set} VS {sent_idx2num_mens[str(sent_idx)]}"
            if len(cur_mention_set) == sent_idx2num_mens[str(sent_idx)]:
                sentidxs_finalized.append(sent_idx)
                for task_name in uid_d:
                    final_uid_d[task_name].extend(uid_d[task_name][sent_idx])
                    final_prob_d[task_name].extend(prob_d[task_name][sent_idx])
                    final_pred_d[task_name].extend(pred_d[task_name][sent_idx])
                    final_gold_d[task_name].extend(gold_d[task_name][sent_idx])
                    if task_name in out_d.keys():
                        for action_name in out_d[task_name].keys():
                            final_out_d[task_name][action_name].extend(
                                out_d[task_name][action_name][sent_idx]
                            )
        # If batch size and accumulation steps are both close to 1,
        # we may reach a point where there are no complete sentences
        if len(sentidxs_finalized) == 0:
            return {}, sentidxs_finalized
        res = {
            "uids": final_uid_d,
            "golds": final_gold_d,
        }
        for task_name in final_prob_d.keys():
            final_prob_d[task_name] = array_to_numpy(final_prob_d[task_name])
        res["probs"] = final_prob_d
        for task_name in final_pred_d.keys():
            final_pred_d[task_name] = array_to_numpy(final_pred_d[task_name])
        res["preds"] = final_pred_d
        res["outputs"] = final_out_d
        return res, sentidxs_finalized

    model.eval()
    # Will store sent_idx -> task_name -> list output
    uid_dict = defaultdict(lambda: defaultdict(list))
    prob_dict = defaultdict(lambda: defaultdict(list))
    pred_dict = defaultdict(lambda: defaultdict(list))
    gold_dict = defaultdict(lambda: defaultdict(list))
    # Will store sent_idx -> task_name -> output key -> list output
    out_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    # List of all finalized and yielded sentences
    all_finalized_sentences = []
    # Currently stored sent idx -> set of unique mention positions (sentences that
    # aren't complete are held until they are)
    cur_sentidx2_nummentions = dict()
    num_eval_steps = 0
    # Collect dataloader information
    task_to_label_dict = dataloader.task_to_label_dict
    uid = dataloader.uid
    with torch.no_grad():
        for batch_num, bdict in tqdm(
            enumerate(dataloader),
            total=len(dataloader),
            desc=f"Evaluating {dataloader.data_name} ({dataloader.split})",
        ):
            num_eval_steps += 1
            X_bdict, Y_bdict = bdict
            (
                uid_bdict,
                loss_bdict,
                prob_bdict,
                gold_bdict,
                out_bdict,
            ) = model.forward(  # type: ignore
                X_bdict[uid],
                X_bdict,
                Y_bdict,
                task_to_label_dict,
                return_action_outputs=True,
                return_probs=True,
            )
            assert (
                NED_TASK in uid_bdict
            ), f"{NED_TASK} task needs to be returned in uid to get the number of mentions"
            for task_name in uid_bdict.keys():
                for ex_idx in range(len(uid_bdict[task_name])):
                    # Recall that our uid is
                    # ============================
                    # guid_dtype = np.dtype(
                    #     [
                    #         ("sent_idx", "i8", 1),
                    #         ("subsent_idx", "i8", 1),
                    #         ("alias_orig_list_pos", "i8", 1),
                    #     ]
                    # )
                    # ============================
                    # Index 0 -> sent_idx, Index 1 -> subsent_idx, Index 2 -> list of
                    # alias positions (-1 means no mention in train example)
                    sent_idx = uid_bdict[task_name][ex_idx][0]
                    # Only increment for NED_TASK
                    if task_name == NED_TASK:
                        # alias_pos_for_eval gives which mentions are meant to be
                        # evaluated in this batch (-1 means skip) for scoring. This
                        # will differ from the mentions seen by the model: we window
                        # sentences, so a mention may be seen multiple times but is
                        # only scored once. This includes both True and False anchors
                        # -- we dump all anchors for analysis.
                        alias_pos_for_eval = out_bdict[task_name][
                            "_input__for_dump_gold_cand_K_idx_train"
                        ][ex_idx]
                        # This is the number of mentions -- there should only be 1
                        assert len(uid_bdict[task_name][ex_idx][2]) == 1
                        alias_pos_in_og_list = uid_bdict[task_name][ex_idx][2][0]
                        if sent_idx not in cur_sentidx2_nummentions:
                            cur_sentidx2_nummentions[sent_idx] = set()
                        # Index 2 is the alias position in the original list (-1 means no mention)
                        if alias_pos_for_eval != -1:
                            cur_sentidx2_nummentions[sent_idx].add(alias_pos_in_og_list)
                    uid_dict[task_name][sent_idx].extend(
                        uid_bdict[task_name][ex_idx : ex_idx + 1]
                    )
                    prob_dict[task_name][sent_idx].extend(prob_bdict[task_name][ex_idx : ex_idx + 1])  # type: ignore
                    pred_dict[task_name][sent_idx].extend(  # type: ignore
                        prob_to_pred(prob_bdict[task_name][ex_idx : ex_idx + 1])
                    )
                    gold_dict[task_name][sent_idx].extend(
                        gold_bdict[task_name][ex_idx : ex_idx + 1]
                    )
                    if task_name in out_bdict.keys():
                        for action_name in out_bdict[task_name].keys():
                            out_dict[task_name][action_name][sent_idx].extend(
                                out_bdict[task_name][action_name][ex_idx : ex_idx + 1]
                            )
            if num_eval_steps >= dump_preds_accumulation_steps:
                # Collect the sentences that have all mentions collected
                res, finalized_sent_idxs = collect_result(
                    uid_dict,
                    gold_dict,
                    pred_dict,
                    prob_dict,
                    out_dict,
                    cur_sentidx2_nummentions,
                )
                all_finalized_sentences.extend([str(s) for s in finalized_sent_idxs])
                num_eval_steps = 0
                for final_sent_i in finalized_sent_idxs:
                    assert final_sent_i in cur_sentidx2_nummentions
                    del cur_sentidx2_nummentions[final_sent_i]
                    for task_name in uid_dict.keys():
                        del uid_dict[task_name][final_sent_i]
                        del prob_dict[task_name][final_sent_i]
                        del pred_dict[task_name][final_sent_i]
                        del gold_dict[task_name][final_sent_i]
                        if task_name in out_dict.keys():
                            for action_name in out_dict[task_name].keys():
                                del out_dict[task_name][action_name][final_sent_i]
                if len(res) > 0:
                    yield res
    res, finalized_sent_idxs = collect_result(
        uid_dict, gold_dict, pred_dict, prob_dict, out_dict, cur_sentidx2_nummentions
    )
    all_finalized_sentences.extend([str(s) for s in finalized_sent_idxs])
    for final_sent_i in finalized_sent_idxs:
        del cur_sentidx2_nummentions[final_sent_i]
    if len(res) > 0:
        yield res
    assert (
        len(cur_sentidx2_nummentions) == 0
    ), f"After eval, some sentences had left over mentions {cur_sentidx2_nummentions}"
    assert set(all_finalized_sentences).intersection(sent_idx2num_mens.keys()) == set(
        [k for k, v in sent_idx2num_mens.items() if v > 0]
    ), (
        f"Some sentences are left over "
        f"{[s for s in sent_idx2num_mens if s not in set(all_finalized_sentences) and sent_idx2num_mens[s] > 0]}"
    )
    return None

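# Usage sketch (added for exposition; model and dataloader are assumed to be set
# up as in dump_model_outputs, and the accumulation step count of 10 is made up).
# batched_pred_iter is a generator: each yielded result dict covers only sentences
# whose mentions have all been scored, so the consumer can write results out
# incrementally and bound peak memory.
def _example_batched_pred_iter(model, dataloader, sent_idx2num_mens):
    for res_dict in batched_pred_iter(model, dataloader, 10, sent_idx2num_mens):
        # res_dict keys: "uids", "golds", "probs", "preds", "outputs",
        # each mapping task name -> per-mention values for finalized sentences
        print(len(res_dict["uids"][NED_TASK]), "mentions finalized in this window")
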
def check_and_create_alias_cand_trie(save_folder, entity_symbols):
    """Create a mmap memory trie object for storing the alias-candidate mappings.

    Args:
        save_folder: save folder for alias trie
        entity_symbols: entity symbols
    """
    try:
        TwoLayerVocabularyScoreTrie(load_dir=save_folder)
    except FileNotFoundError:
        log_rank_0_debug(
            logger,
            "Creating the alias candidate trie for faster parallel processing. "
            "This is a one-time cost.",
        )
        alias_trie = entity_symbols._alias2qids
        alias_trie.dump(save_folder)
    return

def get_emb_file(save_folder):
    """Get the embedding numpy file for the batch.

    Args:
        save_folder: save folder

    Returns:
        string
    """
    return os.path.join(save_folder, "bootleg_emb_file.npy")

def get_result_file(save_folder):
    """Get the jsonl label file for the batch.

    Args:
        save_folder: save folder

    Returns:
        string
    """
    return os.path.join(save_folder, "bootleg_labels.jsonl")

def dump_model_outputs(
    model,
    dataloader,
    config,
    sentidx2num_mentions,
    save_folder,
    entity_symbols,
    task_name,
    overwrite_data,
):
    """Dump model outputs.

    Args:
        model: model
        dataloader: data loader
        config: config
        sentidx2num_mentions: Dict from sentence idx to number of mentions
        save_folder: save folder
        entity_symbols: entity symbols
        task_name: task name
        overwrite_data: overwrite saved mmap files

    Returns:
        memmap file name for saved outputs, dtype file name for loading the memmap file
    """
    # Write to file (M x hidden size for each data point -- the next step will deal
    # with recovering the original sentence indices for overflowing sentences)
    unmerged_memmap_dir = os.path.join(save_folder, "model_outputs_mmap")
    utils.ensure_dir(unmerged_memmap_dir)
    final_unmerged_memmap = os.path.join(save_folder, "model_outputs_final.mmap")
    emb_file_config = os.path.join(unmerged_memmap_dir, "model_outputs_config.npy")
    if (
        not overwrite_data
        and os.path.exists(final_unmerged_memmap)
        and os.path.exists(emb_file_config)
    ):
        log_rank_0_info(
            logger,
            f"Skipping GPU dump. {final_unmerged_memmap} already exists and overwrite is False.",
        )
        return final_unmerged_memmap, emb_file_config
    K = entity_symbols.max_candidates + (not config.data_config.train_in_candidates)
    unmerged_storage_type = np.dtype(
        [
            ("M", int),
            ("K", int),
            ("hidden_size", int),
            ("sent_idx", int),
            ("subsent_idx", int),
            ("alias_list_pos", int, 1),
            ("final_loss_true", int, 1),
            ("final_loss_pred", int, 1),
            ("final_loss_prob", float, 1),
            ("final_loss_cand_probs", float, K),
        ]
    )
    np.save(emb_file_config, unmerged_storage_type, allow_pickle=True)
    item_size = np.memmap(
        final_unmerged_memmap,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(1,),
    ).nbytes
    total_expected_size = item_size * len(dataloader.dataset) / 1024**3
    log_rank_0_info(
        logger,
        f"Expected size is {total_expected_size}GB.",
    )
    data_arr = np.memmap(
        final_unmerged_memmap,
        dtype=unmerged_storage_type,
        mode="w+",
        shape=(len(dataloader.dataset),),
    )
    # Init sent_idx to -1 for debugging
    data_arr[:]["sent_idx"] = -1
    arr_idx = 0
    for res_i, res_dict in enumerate(
        batched_pred_iter(
            model,
            dataloader,
            config.run_config.dump_preds_accumulation_steps,
            sentidx2num_mentions,
        )
    ):
        batch_size = len(res_dict["uids"][task_name])
        for i in tqdm(range(batch_size), total=batch_size, desc="Saving outputs"):
            # res_dict["outputs"][task_name] is a dict with keys
            # ['_input__alias_orig_list_pos', 'bootleg_pred_1', '_input__sent_idx',
            #  '_input__for_dump_gold_cand_K_idx_train', '_input__subsent_idx', 0, 1]
            sent_idx = res_dict["outputs"][task_name]["_input__sent_idx"][i]
            subsent_idx = res_dict["outputs"][task_name]["_input__subsent_idx"][i]
            alias_orig_list_pos = res_dict["outputs"][task_name][
                "_input__alias_orig_list_pos"
            ][i]
            gold_cand_K_idx_train = res_dict["outputs"][task_name][
                "_input__for_dump_gold_cand_K_idx_train"
            ][i]
            data_arr[arr_idx]["K"] = K
            data_arr[arr_idx]["hidden_size"] = config.model_config.hidden_size
            data_arr[arr_idx]["sent_idx"] = sent_idx
            data_arr[arr_idx]["subsent_idx"] = subsent_idx
            data_arr[arr_idx]["alias_list_pos"] = alias_orig_list_pos
            # This gives all aliases seen by the model during training, whether gold or not
            data_arr[arr_idx]["final_loss_true"] = gold_cand_K_idx_train
            # Get max for each alias; probs is K
            max_probs = res_dict["probs"][task_name][i].max(axis=0)
            pred_cands = res_dict["probs"][task_name][i].argmax(axis=0)
            data_arr[arr_idx]["final_loss_pred"] = pred_cands
            data_arr[arr_idx]["final_loss_prob"] = max_probs
            data_arr[arr_idx]["final_loss_cand_probs"] = res_dict["probs"][task_name][
                i
            ].reshape(1, -1)
            arr_idx += 1
        del res_dict
    log_rank_0_info(
        logger,
        f"Finished dumping {len(dataloader.dataset)} samples to memmap files. "
        f"Saved to {final_unmerged_memmap}.",
    )
    return final_unmerged_memmap, emb_file_config

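# Read-back sketch (added for exposition): the dumped memmap can be reopened with
# the dtype saved alongside it, mirroring what collect_and_merge_results does below.
def _example_read_dumped_outputs(final_unmerged_memmap, emb_file_config):
    dtype = np.dtype(np.load(emb_file_config, allow_pickle=True).tolist())
    data_arr = np.memmap(final_unmerged_memmap, dtype=dtype, mode="r")
    return data_arr["sent_idx"], data_arr["final_loss_prob"]
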
def collect_and_merge_results(
    unmerged_entity_emb_file,
    emb_file_config,
    config,
    sent_idx2num_mens,
    sent_idx2row,
    save_folder,
    entity_symbols,
):
    """Merge mentions, filter non-gold labels, and save predictions to file.

    Args:
        unmerged_entity_emb_file: memmap file from the dump step
        emb_file_config: config file for loading the memmap file
        config: model config
        sent_idx2num_mens: Dict of sentence idx to number of mentions
        sent_idx2row: Dict of sentence idx to row of eval data
        save_folder: folder to save results
        entity_symbols: entity symbols

    Returns:
        saved prediction file, total mentions seen
    """
    num_processes = min(
        config.run_config.dataset_threads, int(multiprocessing.cpu_count() * 0.9)
    )
    cache_dir = os.path.join(save_folder, "cache")
    utils.ensure_dir(cache_dir)
    trie_candidate_map_folder = None
    trie_qid2eid_file = None
    # Save the alias->QID candidate map and the QID->EID mapping in memory-efficient
    # structures for faster prediction dumping
    if num_processes > 1:
        entity_prep_dir = data_utils.get_emb_prep_dir(config.data_config)
        trie_candidate_map_folder = os.path.join(
            entity_prep_dir, "for_dumping_preds", "alias_cand_trie"
        )
        utils.ensure_dir(trie_candidate_map_folder)
        check_and_create_alias_cand_trie(trie_candidate_map_folder, entity_symbols)
        trie_qid2eid_file = os.path.join(
            entity_prep_dir, "for_dumping_preds", "qid2eid_trie"
        )
        if not os.path.exists(trie_qid2eid_file):
            assert isinstance(entity_symbols._qid2eid, VocabularyTrie)
            entity_symbols._qid2eid.dump(trie_qid2eid_file)
    # Write to file (M x hidden size for each data point -- the next step will deal
    # with recovering the original sentence indices for overflowing sentences)
    merged_entity_emb_file = os.path.join(save_folder, "entity_embs_unmerged.mmap")
    K = entity_symbols.max_candidates + (not config.data_config.train_in_candidates)
    merged_storage_type = np.dtype(
        [
            ("hidden_size", int),
            ("sent_idx", int),
            ("alias_list_pos", int),
            ("final_loss_pred", int),
            ("final_loss_prob", float),
            ("final_loss_cand_probs", float, K),
        ]
    )
    unmerged_storage_type = np.dtype(
        np.load(emb_file_config, allow_pickle=True).tolist()
    )
    result_file = get_result_file(save_folder)
    log_rank_0_debug(logger, f"Writing predictions to {result_file}...")
    merge_subsentences(
        num_processes=num_processes,
        subset_sent_idx2num_mens=sent_idx2num_mens,
        cache_folder=cache_dir,
        to_save_file=merged_entity_emb_file,
        to_save_storage=merged_storage_type,
        to_read_file=unmerged_entity_emb_file,
        to_read_storage=unmerged_storage_type,
    )
    write_data_labels(
        num_processes=num_processes,
        merged_entity_emb_file=merged_entity_emb_file,
        merged_storage_type=merged_storage_type,
        sent_idx2row=sent_idx2row,
        cache_folder=cache_dir,
        out_file=result_file,
        entity_dump=entity_symbols,
        train_in_candidates=config.data_config.train_in_candidates,
        max_candidates=entity_symbols.max_candidates,
        trie_candidate_map_folder=trie_candidate_map_folder,
        trie_qid2eid_file=trie_qid2eid_file,
    )
    filt_emb_data = np.memmap(
        merged_entity_emb_file, dtype=merged_storage_type, mode="r"
    )
    total_mentions_seen = len(filt_emb_data)
    filt_emb_data = None
    # Cleanup cache -- sometimes a file in cache_dir is still open, so we retry the delete
    try_rmtree(cache_dir)
    log_rank_0_debug(
        logger,
        f"Wrote predictions to {result_file} with {total_mentions_seen} mentions",
    )
    return result_file, total_mentions_seen

def merge_subsentences(
    num_processes,
    subset_sent_idx2num_mens,
    cache_folder,
    to_save_file,
    to_save_storage,
    to_read_file,
    to_read_storage,
):
    """
    Merge and flatten sentences over sub-sentences.

    Flattens all sentences back together over sub-sentences, removing the PAD aliases
    from the data. I.e., converts from sent_idx -> array of values to
    (sent_idx, alias_idx) -> value with varying numbers of aliases per sentence.

    Args:
        num_processes: number of processes
        subset_sent_idx2num_mens: Dict of sentence index to number of mentions for this batch
        cache_folder: cache directory
        to_save_file: memmap file to save results to
        to_save_storage: save file storage type
        to_read_file: memmap file to read predictions from
        to_read_storage: read file storage type
    """
    # Compute sent idx to offset so we know where to fill in mentions
    cur_offset = 0
    sentidx2offset = {}
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
    total_num_mentions = cur_offset
    full_pred_data = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    K = int(full_pred_data[0]["K"])
    hidden_size = int(full_pred_data[0]["hidden_size"])
    filt_emb_data = np.memmap(
        to_save_file, dtype=to_save_storage, mode="w+", shape=(total_num_mentions,)
    )
    filt_emb_data["hidden_size"] = hidden_size
    filt_emb_data["sent_idx"][:] = -1
    filt_emb_data["alias_list_pos"][:] = -1
    all_ids = list(range(0, len(full_pred_data)))
    start = time.time()
    if num_processes == 1:
        seen_ids = merge_subsentences_single(
            K,
            hidden_size,
            all_ids,
            filt_emb_data,
            full_pred_data,
            sentidx2offset,
        )
    else:
        # Get trie for sentence start map
        trie_folder = os.path.join(cache_folder, "bootleg_sent_idx2num_mens")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sentidx2offset, out_file=trie_file)
        # Chunk up data
        chunk_size = int(np.ceil(len(full_pred_data) / num_processes))
        row_idx_set_chunks = [
            all_ids[ids : ids + chunk_size]
            for ids in range(0, len(full_pred_data), chunk_size)
        ]
        # Start pool
        input_args = [[K, hidden_size, chunk] for chunk in row_idx_set_chunks]
        log_rank_0_debug(
            logger, f"Merging sentences together with {num_processes} processes"
        )
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=merge_subsentences_initializer,
            initargs=[
                to_save_file,
                to_save_storage,
                to_read_file,
                to_read_storage,
                trie_file,
            ],
        )
        seen_ids = set()
        for sent_ids_seen in pool.imap_unordered(
            merge_subsentences_hlp, input_args, chunksize=1
        ):
            for emb_id in sent_ids_seen:
                assert (
                    emb_id not in seen_ids
                ), f"{emb_id} already seen, something went wrong with sub-sentences"
                seen_ids.add(emb_id)
        pool.close()
        pool.join()
    logger.debug(f"Saw {len(seen_ids)} sentences")
    logger.debug(f"Time to merge sub-sentences {time.time() - start}s")
    return

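# Offset bookkeeping sketch (added for exposition; the counts are made up). This is
# the same prefix-sum logic merge_subsentences uses: mention j of sentence s lands
# at row sentidx2offset[s] + j of the flattened file.
def _example_sentidx2offset():
    subset_sent_idx2num_mens = {"0": 2, "1": 1, "2": 3}
    sentidx2offset, cur_offset = {}, 0
    for k, v in subset_sent_idx2num_mens.items():
        sentidx2offset[k] = cur_offset
        cur_offset += v
    # sentidx2offset == {"0": 0, "1": 2, "2": 3}; total_num_mentions == 6
    return sentidx2offset, cur_offset
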
def merge_subsentences_initializer(
    to_write_file, to_write_storage, to_read_file, to_read_storage, sentidx2offset_file
):
    """Merge subsentences initializer for multiprocessing.

    Args:
        to_write_file: file to write
        to_write_storage: mmap storage type
        to_read_file: file to read
        to_read_storage: mmap storage type
        sentidx2offset_file: sentence index to offset in mmap data
    """
    global filt_emb_data_global
    filt_emb_data_global = np.memmap(to_write_file, dtype=to_write_storage, mode="r+")
    global full_pred_data_global
    full_pred_data_global = np.memmap(to_read_file, dtype=to_read_storage, mode="r")
    global sentidx2offset_marisa_global
    sentidx2offset_marisa_global = utils.load_single_item_trie(sentidx2offset_file)

def merge_subsentences_hlp(args):
    """Merge subsentences multiprocessing helper."""
    K, hidden_size, r_idx_set = args
    return merge_subsentences_single(
        K,
        hidden_size,
        r_idx_set,
        filt_emb_data_global,
        full_pred_data_global,
        sentidx2offset_marisa_global,
    )

def merge_subsentences_single(
    K,
    hidden_size,
    r_idx_set,
    filt_emb_data,
    full_pred_data,
    sentidx2offset,
):
    """
    Merge subsentences in a single process.

    Flattens out the results from `full_pred_data` so each line of `filt_emb_data`
    is one alias prediction.

    Args:
        K: number of candidates
        hidden_size: hidden size
        r_idx_set: batch result indices
        filt_emb_data: mmap embedding file to write
        full_pred_data: mmap result file to read
        sentidx2offset: sentence to emb data offset

    Returns:
        set of emb row ids written
    """
    seen_ids = set()
    for r_idx in r_idx_set:
        row = full_pred_data[r_idx]
        # Get the corresponding row to start writing into the condensed memory mapped file
        sent_idx = str(row["sent_idx"])
        if isinstance(sentidx2offset, dict):
            sent_start_idx = sentidx2offset[sent_idx]
        else:
            # Get from Trie
            sent_start_idx = sentidx2offset[sent_idx][0][0]
        # For each VALID mention, we need to write into its original alias list position
        true_val = row["final_loss_true"]
        alias_orig_pos = row["alias_list_pos"]
        # Because we are using the mentions that include both true and false golds,
        # true_val == -1 only for padded mentions or sub-sentence mentions
        if true_val != -1:
            # id in condensed embedding
            emb_id = sent_start_idx + alias_orig_pos
            assert (
                emb_id not in seen_ids
            ), f"{emb_id} already seen, something went wrong with sub-sentences"
            seen_ids.add(emb_id)
            filt_emb_data["sent_idx"][emb_id] = sent_idx
            filt_emb_data["alias_list_pos"][emb_id] = alias_orig_pos
            filt_emb_data["final_loss_pred"][emb_id] = row["final_loss_pred"]
            filt_emb_data["final_loss_prob"][emb_id] = row["final_loss_prob"]
            filt_emb_data["final_loss_cand_probs"][emb_id] = row[
                "final_loss_cand_probs"
            ]
    return seen_ids

def get_sental2embid(merged_entity_emb_file, merged_storage_type):
    """Get the (sent_idx, alias_idx) -> emb idx mapping for quick lookup.

    Args:
        merged_entity_emb_file: memmap file after merging sentences
        merged_storage_type: file storage type

    Returns:
        Dict of f"{sent_idx}_{alias_idx}" -> index in merged_entity_emb_file
    """
    filt_emb_data = np.memmap(
        merged_entity_emb_file, dtype=merged_storage_type, mode="r"
    )
    sental2embid = {}
    for i, row in tqdm(
        enumerate(filt_emb_data),
        total=len(filt_emb_data),
        desc="Getting sental2embid map",
    ):
        sent_idx = row["sent_idx"]
        alias_idx = row["alias_list_pos"]
        assert (
            sent_idx != -1 and alias_idx != -1
        ), f"{i} {row} Has Sent {sent_idx}, Al {alias_idx}"
        # Keep as string for the Marisa Trie later
        sental2embid[f"{sent_idx}_{alias_idx}"] = i
    return sental2embid

def write_data_labels(
    num_processes,
    merged_entity_emb_file,
    merged_storage_type,
    sent_idx2row,
    cache_folder,
    out_file,
    entity_dump,
    train_in_candidates,
    max_candidates,
    trie_candidate_map_folder=None,
    trie_qid2eid_file=None,
):
    """Take the flattened data from merge_subsentences and write out predictions.

    Args:
        num_processes: number of processes
        merged_entity_emb_file: input memmap file after merging sentences
        merged_storage_type: input file storage type
        sent_idx2row: Dict of sentence idx to row relevant to this sub-batch
        cache_folder: folder to save temporary outputs
        out_file: final output file for predictions
        entity_dump: entity dump
        train_in_candidates: whether NC entities are not in candidate lists
        max_candidates: maximum number of candidates
        trie_candidate_map_folder: folder where the trie of the alias->candidate map is stored for parallel processing
        trie_qid2eid_file: file where the trie of the qid->eid map is stored for parallel processing
    """
    st = time.time()
    sental2embid = get_sental2embid(merged_entity_emb_file, merged_storage_type)
    log_rank_0_debug(logger, f"Finished getting sentence map {time.time() - st}s")
    total_input = len(sent_idx2row)
    if num_processes == 1:
        filt_emb_data = np.memmap(
            merged_entity_emb_file, dtype=merged_storage_type, mode="r"
        )
        write_data_labels_single(
            sentidx2row=sent_idx2row,
            output_file=out_file,
            filt_emb_data=filt_emb_data,
            sental2embid=sental2embid,
            alias_cand_map=entity_dump.get_alias2qids_dict(),
            qid2eid=entity_dump.get_qid2eid_dict(),
            train_in_cands=train_in_candidates,
            max_cands=max_candidates,
        )
    else:
        assert (
            trie_candidate_map_folder is not None
        ), "trie_candidate_map_folder is None and you have parallel turned on"
        assert (
            trie_qid2eid_file is not None
        ), "trie_qid2eid_file is None and you have parallel turned on"
        # Get trie of sentence map
        trie_folder = os.path.join(cache_folder, "bootleg_sental2embid")
        utils.ensure_dir(trie_folder)
        trie_file = os.path.join(trie_folder, "sentidx.marisa")
        utils.create_single_item_trie(sental2embid, out_file=trie_file)
        # Chunk file for parallel writing.
        # We do not use TemporaryFolders as the temp dir may not have enough space for large files
        create_ex_indir = os.path.join(cache_folder, "_bootleg_eval_temp_indir")
        utils.ensure_dir(create_ex_indir)
        create_ex_outdir = os.path.join(cache_folder, "_bootleg_eval_temp_outdir")
        utils.ensure_dir(create_ex_outdir)
        chunk_input = int(np.ceil(total_input / num_processes))
        logger.debug(
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines"
        )
        # Chunk up dictionary of data for parallel processing
        input_files = []
        i = 0
        cur_lines = 0
        file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
        open_file = open(file_split, "w")
        for s_idx in sent_idx2row:
            if cur_lines >= chunk_input:
                open_file.close()
                input_files.append(file_split)
                cur_lines = 0
                i += 1
                file_split = os.path.join(create_ex_indir, f"out{i}.jsonl")
                open_file = open(file_split, "w")
            line = sent_idx2row[s_idx]
            open_file.write(ujson.dumps(line, ensure_ascii=False) + "\n")
            cur_lines += 1
        open_file.close()
        input_files.append(file_split)
        # Generate input/output pairs
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        log_rank_0_debug(logger, "Done chunking files. Starting pool")
        pool = multiprocessing.Pool(
            processes=num_processes,
            initializer=write_data_labels_initializer,
            initargs=[
                merged_entity_emb_file,
                merged_storage_type,
                trie_file,
                train_in_candidates,
                max_candidates,
                trie_candidate_map_folder,
                trie_qid2eid_file,
            ],
        )
        input_args = list(zip(input_files, output_files))
        total = 0
        for res in pool.imap(write_data_labels_hlp, input_args, chunksize=1):
            total += 1
        pool.close()
        pool.join()
        # Merge output files into the final file
        log_rank_0_debug(logger, "Merging output files")
        with open(out_file, "wb") as outfile:
            for filename in glob.glob(os.path.join(create_ex_outdir, "*")):
                if filename == out_file:
                    # Don't copy the output file into itself
                    continue
                with open(filename, "rb") as readfile:
                    shutil.copyfileobj(readfile, outfile)

def write_data_labels_initializer(
    merged_entity_emb_file,
    merged_storage_type,
    sental2embid_file,
    train_in_candidates,
    max_cands,
    trie_candidate_map_folder,
    trie_qid2eid_file,
):
    """
    Write data labels multiprocessing initializer.

    Args:
        merged_entity_emb_file: flattened embedding input file
        merged_storage_type: mmap storage type
        sental2embid_file: sentence, alias -> embedding id mapping
        train_in_candidates: train in candidates flag
        max_cands: max candidates
        trie_candidate_map_folder: alias trie folder
        trie_qid2eid_file: qid to eid trie file
    """
    global filt_emb_data_global
    filt_emb_data_global = np.memmap(
        merged_entity_emb_file, dtype=merged_storage_type, mode="r"
    )
    global sental2embid_global
    sental2embid_global = utils.load_single_item_trie(sental2embid_file)
    global alias_cand_trie_global
    alias_cand_trie_global = TwoLayerVocabularyScoreTrie(
        load_dir=trie_candidate_map_folder
    )
    global qid2eid_global
    qid2eid_global = VocabularyTrie(load_dir=trie_qid2eid_file)
    global train_in_candidates_global
    train_in_candidates_global = train_in_candidates
    global max_cands_global
    max_cands_global = max_cands

def write_data_labels_hlp(args):
    """Write data labels multiprocessing helper function."""
    input_file, output_file = args
    s_idx2row = {}
    with open(input_file) as in_f:
        for line in in_f:
            line = ujson.loads(line)
            s_idx2row[str(line["sent_idx_unq"])] = line
    return write_data_labels_single(
        s_idx2row,
        output_file,
        filt_emb_data_global,
        sental2embid_global,
        alias_cand_trie_global,
        qid2eid_global,
        train_in_candidates_global,
        max_cands_global,
    )

def write_data_labels_single(
    sentidx2row,
    output_file,
    filt_emb_data,
    sental2embid,
    alias_cand_map,
    qid2eid,
    train_in_cands,
    max_cands,
):
    """Write data labels in a single subprocess.

    Takes the alias predictions and merges them back by sentence to be written out.

    Args:
        sentidx2row: sentence index to raw eval data row
        output_file: output file
        filt_emb_data: mmap embedding data (one prediction per row)
        sental2embid: sentence index, alias index -> embedding row id
        alias_cand_map: alias to candidate map
        qid2eid: qid to entity id map
        train_in_cands: train in candidates flag
        max_cands: maximum number of candidates
    """
    with open(output_file, "w") as f_out:
        for sent_idx in sentidx2row:
            line = sentidx2row[sent_idx]
            aliases = line["aliases"]
            char_spans = line["char_spans"]
            assert sent_idx == str(line["sent_idx_unq"])
            qids = []
            ctx_emb_ids = []
            entity_ids = []
            probs = []
            cands = []
            cand_probs = []
            entity_cands_qid = map_aliases_to_candidates(
                train_in_cands, max_cands, alias_cand_map, aliases
            )
            # eid is entity id
            entity_cands_eid = map_candidate_qids_to_eid(entity_cands_qid, qid2eid)
            for al_idx, alias in enumerate(aliases):
                sent_idx_key = f"{sent_idx}_{al_idx}"
                assert (
                    sent_idx_key in sental2embid
                ), f"Dumped prediction data does not match data file. Cannot find {sent_idx} - {al_idx}"
                if isinstance(sental2embid, dict):
                    emb_idx = sental2embid[sent_idx_key]
                else:
                    # Get from Trie
                    emb_idx = sental2embid[sent_idx_key][0][0]
                # We will concatenate all contextualized embeddings at the end
                # and need the row id to be the offset here
                ctx_emb_ids.append(emb_idx)
                prob = filt_emb_data[emb_idx]["final_loss_prob"]
                prob = prob if not math.isnan(prob) else None
                cand_prob = strip_nan(filt_emb_data[emb_idx]["final_loss_cand_probs"])
                pred_cand = filt_emb_data[emb_idx]["final_loss_pred"]
                eid = entity_cands_eid[al_idx][pred_cand]
                qid = entity_cands_qid[al_idx][pred_cand]
                qids.append(qid)
                probs.append(prob)
                cands.append(list(entity_cands_qid[al_idx]))
                cand_probs.append(list(cand_prob))
                entity_ids.append(eid)
            line["qids"] = qids
            line["probs"] = probs
            line["cands"] = cands
            line["cand_probs"] = cand_probs
            line["entity_ids"] = entity_ids
            line["char_spans"] = char_spans
            f_out.write(ujson.dumps(line, ensure_ascii=False) + "\n")

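# Output format sketch (added for exposition; the values below are made up).
# Each line that write_data_labels_single emits is the input row augmented with
# prediction fields, e.g.:
#   {"sent_idx_unq": 0, "aliases": ["lincoln"], "char_spans": [[0, 7]],
#    "qids": ["Q91"], "probs": [0.93], "cands": [["NC", "Q91", "Q28260", "-1"]],
#    "cand_probs": [[0.01, 0.93, 0.05, 0.01]], "entity_ids": [2]}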