"""Bootleg NED Dataset."""
import logging
import multiprocessing
import os
import re
import shutil
import sys
import time
import traceback
import warnings
from collections import defaultdict
import numpy as np
import torch
import ujson
from emmental.data import EmmentalDataset
from tqdm.auto import tqdm
from bootleg import log_rank_0_debug, log_rank_0_info
from bootleg.layers.alias_to_ent_encoder import AliasEntityTable
from bootleg.symbols.constants import ANCHOR_KEY, PAD_ID, STOP_WORDS
from bootleg.symbols.entity_symbols import EntitySymbols
from bootleg.symbols.kg_symbols import KGSymbols
from bootleg.symbols.type_symbols import TypeSymbols
from bootleg.utils import data_utils, utils
# Ignore warnings whose message matches the "Could not import the lzma module" notice.
warnings.filterwarnings(
    "ignore",
    message="Could not import the lzma module. Your installed Python is incomplete. "
    "Attempting to use lzma compression will result in a RuntimeError.",
)
# Ignore warnings whose message matches the "(type, 1) or '1type'" FutureWarning text
# (presumably emitted for the shape-1 dtype fields used in this module -- TODO confirm).
warnings.filterwarnings(
    "ignore",
    message="FutureWarning: Passing (type, 1) or '1type'*",
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Removes warnings about TOKENIZERS_PARALLELISM
os.environ["TOKENIZERS_PARALLELISM"] = "false"
def get_structural_entity_str(items, max_tok_len, sep_tok):
    """Return sep_tok joined list of items of structural resources.

    Items are added in order for as long as the whitespace-token count of the
    result (separator tokens included) stays within ``max_tok_len``. The first
    item is always kept, even when it alone is over the budget.

    Args:
        items: list of structural resources
        max_tok_len: maximum token length
        sep_tok: token to separate out resources

    Returns:
        result string, and 1 if some item went beyond ``max_tok_len``
        (0 otherwise)
    """
    sep = f" {sep_tok} "
    # Every kept item costs its own tokens plus one separator in front of it.
    sep_tok_cnt = len(sep_tok.split())
    num_kept = 0
    tok_cnt = 0
    over_len = 0
    for item in items:
        tok_cnt += sep_tok_cnt + len(item.split())
        if tok_cnt > max_tok_len:
            # Flag the overflow for every item that does not fit, including
            # the last one in the list.
            over_len = 1
            break
        num_kept += 1
    # Keep at least one item so the entity string is never left without any
    # structural resource when one is available.
    res = sep + sep.join(items[: max(1, num_kept)])
    return res, over_len
def get_entity_string(
    qid,
    constants,
    entity_symbols,
    kg_symbols,
    type_symbols,
):
    """
    Get string representation of entity.

    For each entity, generates a string that is fed into a language model to
    generate an entity embedding. The string is laid out as
    ``title [ent_type] types [ent_kg] relations [ent_desc] description``,
    where each section is only present if enabled in ``constants``.
    Also returns the positions of all words that are part of the title of the
    entity (even if in the description).

    Args:
        qid: QID
        constants: Dict of constants. Reads ``use_desc``, ``use_kg``,
            ``use_types`` and, when the matching section is enabled,
            ``max_ent_kg_len`` / ``max_ent_type_len``
        entity_symbols: entity symbols
        kg_symbols: kg symbols
        type_symbols: type symbols

    Returns: entity string, word indexes (into the whitespace-split entity
        string) of title words, types-over-max-length flag,
        relations-over-max-length flag
    """
    over_kg_len = 0
    over_type_len = 0
    desc_str = (
        "[ent_desc] " + entity_symbols.get_desc(qid) if constants["use_desc"] else ""
    )
    title_str = entity_symbols.get_title(qid) if entity_symbols.qid_exists(qid) else ""
    # To encourage mention similarity, we remove the (<type>) from titles
    title_str = re.sub(r"(\(.*\))", r"", title_str).strip()
    # To add kgs, sep by "[ent_kg]" and then truncate to max_ent_kg_len
    # Then merge with description text
    if constants["use_kg"]:
        # Triples stores "relation tail_qid_title" (e.g. "is member of Manchester United" for qid = David Beckham)
        triples = []
        for rel, tail_qids in kg_symbols.get_relations_tails_for_qid(qid).items():
            for tail_q in tail_qids:
                if not entity_symbols.qid_exists(tail_q):
                    continue
                triples.append(rel + " " + entity_symbols.get_title(tail_q))
        kg_str, over_len = get_structural_entity_str(
            triples,
            constants["max_ent_kg_len"],
            "[ent_kg]",
        )
        over_kg_len += over_len
        desc_str = " ".join([kg_str, desc_str])
    # To add types, sep by "[ent_type]" and then truncate to max_type_ent_len
    # Then merge with description text
    if constants["use_types"]:
        type_str, over_len = get_structural_entity_str(
            type_symbols.get_types(qid),
            constants["max_ent_type_len"],
            "[ent_type]",
        )
        over_type_len += over_len
        desc_str = " ".join([type_str, desc_str])
    ent_str = " ".join([title_str, desc_str])
    # Remove double spaces
    ent_split = ent_str.split()
    ent_str = " ".join(ent_split)
    title_spans = []
    if len(title_str) > 0:
        # Find all occurrences of title words in the ent_str (helps if description has abbreviated name)
        # Make sure you don't mask any types or kg relations
        title_pieces = set(title_str.split())
        # Both structural sections start a skipped region. Checking "[ent_kg]"
        # as well keeps relations unmasked when types are disabled and the kg
        # section is therefore not preceded by "[ent_type]".
        structural_markers = ("[ent_type]", "[ent_kg]")
        to_skip = False
        for e_id, ent_w in enumerate(ent_split):
            if ent_w in structural_markers:
                to_skip = True
            elif ent_w == "[ent_desc]":
                to_skip = False
            if to_skip:
                continue
            if ent_w in title_pieces and ent_w not in STOP_WORDS:
                title_spans.append(e_id)
    return ent_str, title_spans, over_type_len, over_kg_len
def create_examples_initializer(constants_dict, tokenizer):
    """Store the constants dict and tokenizer as module globals.

    Used as the ``initializer`` of the example-creation pool so that
    ``create_examples_hlp`` can read both from each worker process.
    """
    global constants_global, tokenizer_global
    constants_global, tokenizer_global = constants_dict, tokenizer
def create_examples(
    dataset,
    create_ex_indir,
    create_ex_outdir,
    meta_file,
    data_config,
    dataset_threads,
    use_weak_label,
    split,
    is_bert,
    tokenizer,
):
    """Create examples from the raw input data.

    With a single process the dataset file is processed directly. Otherwise it
    is chunked into ``create_ex_indir`` and each chunk is processed by a pool
    worker. In both cases the output file names and their example counts are
    dumped to ``meta_file``.

    Args:
        dataset: data file to read
        create_ex_indir: temporary directory where input files are stored
        create_ex_outdir: temporary directory to store output files from method
        meta_file: metadata file to save the file names/paths for the next step in prep pipeline
        data_config: data config
        dataset_threads: number of threads
        use_weak_label: whether to use weak labeling or not
        split: data split
        is_bert: is the tokenizer a BERT one
        tokenizer: tokenizer
    """
    start = time.time()
    # int(0.8 * cpu_count()) is 0 on a single-core machine; always keep one worker.
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    )
    qidcnt_file = os.path.join(data_config.data_dir, data_config.qid_cnt_map)
    log_rank_0_debug(logger, "Counting lines")
    # Read with the same encoding the workers use and close the handle afterwards.
    with open(dataset, "r", encoding="utf-8") as count_f:
        total_input = sum(1 for _ in count_f)
    constants_dict = {
        "is_bert": is_bert,
        "use_weak_label": use_weak_label,
        "split": split,
        "qidcnt_file": qidcnt_file,
        "max_seq_len": data_config.max_seq_len,
        "max_seq_window_len": data_config.max_seq_window_len,
    }
    if not os.path.exists(qidcnt_file):
        log_rank_0_info(
            logger, f"{qidcnt_file} does not exist. Using uniform counts..."
        )
    # Store output files and counts for saving in next step
    files_and_counts = {}
    if num_processes == 1:
        out_file_name = os.path.join(create_ex_outdir, os.path.basename(dataset))
        res = create_examples_single(
            in_file_idx=0,
            in_file_name=dataset,
            in_file_lines=total_input,
            out_file_name=out_file_name,
            constants_dict=constants_dict,
            tokenizer=tokenizer,
        )
        total_output = res["total_lines"]
        files_and_counts[res["output_filename"]] = res["total_lines"]
    else:
        log_rank_0_info(
            logger, f"Starting to extract examples using {num_processes} processes"
        )
        chunk_input = int(np.ceil(total_input / num_processes))
        log_rank_0_debug(
            logger,
            f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
        )
        total_input_from_chunks, input_files_dict = utils.chunk_file(
            dataset, create_ex_indir, chunk_input
        )
        input_files = list(input_files_dict.keys())
        input_file_lines = [input_files_dict[k] for k in input_files]
        output_files = [
            in_file_name.replace(create_ex_indir, create_ex_outdir)
            for in_file_name in input_files
        ]
        assert (
            total_input == total_input_from_chunks
        ), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
        log_rank_0_debug(logger, "Done chunking files. Starting pool.")
        input_args = list(
            zip(
                list(range(len(input_files))),
                input_files,
                input_file_lines,
                output_files,
            )
        )
        total_output = 0
        # The context manager terminates the workers if anything below raises.
        with multiprocessing.Pool(
            processes=num_processes,
            initializer=create_examples_initializer,
            initargs=[constants_dict, tokenizer],
        ) as pool:
            for res in pool.imap_unordered(
                create_examples_hlp, input_args, chunksize=1
            ):
                total_output += res["total_lines"]
                files_and_counts[res["output_filename"]] = res["total_lines"]
            pool.close()
            pool.join()
    utils.dump_json_file(
        meta_file, {"num_mentions": total_output, "files_and_counts": files_and_counts}
    )
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. "
        f"Total lines seen {total_input}. Total lines kept {total_output}.",
    )
def create_examples_hlp(args):
    """Unpack one pool work item and run ``create_examples_single`` on it.

    ``args`` is a 4-tuple of (input file index, input file name, input line
    count, output file name). The constants and tokenizer come from the
    globals set by ``create_examples_initializer``.
    """
    file_idx, file_name, file_lines, out_name = args
    return create_examples_single(
        in_file_idx=file_idx,
        in_file_name=file_name,
        in_file_lines=file_lines,
        out_file_name=out_name,
        constants_dict=constants_global,
        tokenizer=tokenizer_global,
    )
def create_examples_single(
    in_file_idx, in_file_name, in_file_lines, out_file_name, constants_dict, tokenizer
):
    """Create one example per mention for every sentence of a JSONL file.

    Each input line must hold ``sent_idx_unq``, ``aliases``, ``qids``,
    ``char_spans``, ``sentence`` and the anchor key. Every mention that is
    kept is written to ``out_file_name`` as one JSON-serialized
    ``InputExample`` per line.

    Args:
        in_file_idx: index of the input file (used as progress bar label/position)
        in_file_name: JSONL file to read
        in_file_lines: number of lines in the input file (progress bar total)
        out_file_name: file to write the examples to
        constants_dict: dict with ``split``, ``max_seq_window_len``,
            ``use_weak_label`` and ``qidcnt_file``
        tokenizer: tokenizer

    Returns: dict with the number of written examples (``total_lines``) and
        the output file name (``output_filename``)
    """
    split = constants_dict["split"]
    max_seq_window_len = constants_dict["max_seq_window_len"]
    use_weak_label = constants_dict["use_weak_label"]
    qidcnt_file = constants_dict["qidcnt_file"]
    qid2cnt = {}
    quantile_buckets = [float(i / 100) for i in list(range(0, 101, 5))]
    # If not qid2cnt, the quantile_bucket will be 1.0
    quants = np.array([-1 for _ in quantile_buckets])
    quants[-1] = 0
    if os.path.exists(qidcnt_file):
        with open(qidcnt_file) as cnt_f:
            qid2cnt = ujson.load(cnt_f)
        quants = np.quantile(list(qid2cnt.values()), quantile_buckets)
    total_subsents = 0
    with open(out_file_name, "w", encoding="utf-8") as out_f, open(
        in_file_name, "r", encoding="utf-8"
    ) as in_f:
        for ex in tqdm(
            in_f,
            total=in_file_lines,
            desc=f"{in_file_idx}",
            position=in_file_idx,
        ):
            line = ujson.loads(ex)
            assert "sent_idx_unq" in line
            assert "aliases" in line
            assert "qids" in line
            assert "char_spans" in line, (
                'Require "char_spans" to be input. '
                "See utils/preprocessing/convert_to_char_spans.py"
            )
            assert "sentence" in line
            assert ANCHOR_KEY in line
            sent_idx = line["sent_idx_unq"]
            # aliases are assumed to be lower-cased in candidate map
            aliases = [alias.lower() for alias in line["aliases"]]
            qids = line["qids"]
            spans = line["char_spans"]
            phrase = line["sentence"]
            assert (
                len(spans) == len(aliases) == len(qids)
            ), "lengths of alias-related values not equal"
            # For datasets, we see all aliases, unless use_weak_label is turned off
            anchor = [True for _ in aliases]
            if ANCHOR_KEY in line:
                anchor = line[ANCHOR_KEY]
                assert len(aliases) == len(anchor)
                assert all(isinstance(a, bool) for a in anchor)
            for span in spans:
                assert (
                    len(span) == 2
                ), f"Span should be len 2. Your span {span} is {len(span)}"
                assert span[1] <= len(
                    phrase
                ), f"You have span {span} that is beyond the length of the sentence {phrase}"
            if not use_weak_label:
                # Keep only the true anchors
                kept = [i for i, is_anchor in enumerate(anchor) if is_anchor is True]
                aliases = [aliases[i] for i in kept]
                qids = [qids[i] for i in kept]
                spans = [spans[i] for i in kept]
                anchor = [True for _ in aliases]
            # Happens if use weak labels is False
            if len(aliases) == 0:
                continue
            for subsent_idx in range(len(aliases)):
                span = spans[subsent_idx]
                alias_anchor = anchor[subsent_idx]
                alias = aliases[subsent_idx]
                qid = qids[subsent_idx]
                context = extract_context(span, phrase, max_seq_window_len, tokenizer)
                # Get the percentile bucket between [0, 1]
                # Large counts will be closer to 1
                qid_cnt_mask_score = quantile_buckets[sum(qid2cnt.get(qid, 0) > quants)]
                assert 0 <= qid_cnt_mask_score <= 100
                # Character span of the mention, markers included, inside the context
                new_span = [
                    context.index("[ent_start]"),
                    context.index("[ent_end]") + len("[ent_end]"),
                ]
                # alias_to_predict_arr is an index into idxs_arr/anchor_arr/aliases_arr.
                # It should only include true anchors if eval dataset.
                # During training want to backpropagate on false anchors as well
                if split != "train":
                    alias_to_predict = 0 if alias_anchor is True else -1
                else:
                    alias_to_predict = 0
                total_subsents += 1
                out_f.write(
                    ujson.dumps(
                        InputExample(
                            sent_idx=sent_idx,
                            subsent_idx=subsent_idx,
                            alias_list_pos=subsent_idx,
                            alias_to_predict=alias_to_predict,
                            span=new_span,
                            phrase=context,
                            alias=alias,
                            qid=qid,
                            qid_cnt_mask_score=qid_cnt_mask_score,
                        ).to_dict(),
                        ensure_ascii=False,
                    )
                    + "\n"
                )
    return {"total_lines": total_subsents, "output_filename": out_file_name}
def convert_examples_to_features_and_save_initializer(
    tokenizer,
    data_config,
    save_dataset_name,
    save_labels_name,
    X_storage,
    Y_storage,
):
    """Set up the per-worker globals for the feature conversion pool.

    Stores the tokenizer, loads the entity symbols from the cache and opens
    the feature and label memmap files in ``r+`` mode.
    """
    global tokenizer_global, entitysymbols_global
    global mmap_file_global, mmap_label_file_global
    tokenizer_global = tokenizer
    entity_map_path = os.path.join(data_config.entity_dir, data_config.entity_map_dir)
    entitysymbols_global = EntitySymbols.load_from_cache(
        load_dir=entity_map_path,
        alias_cand_map_dir=data_config.alias_cand_map,
        alias_idx_dir=data_config.alias_idx_map,
    )
    mmap_file_global = np.memmap(save_dataset_name, dtype=X_storage, mode="r+")
    mmap_label_file_global = np.memmap(save_labels_name, dtype=Y_storage, mode="r+")
def convert_examples_to_features_and_save(
    meta_file,
    guid_dtype,
    data_config,
    dataset_threads,
    use_weak_label,
    split,
    is_bert,
    save_dataset_name,
    save_labels_name,
    X_storage,
    Y_storage,
    tokenizer,
    entity_symbols,
):
    """
    Create features from examples.

    Converts the prepped examples into input features and saves in memmap
    files. These are used in the __get_item__ method. After writing, every
    row is checked to have been filled and to have a unique
    ``sent_idx.subsent_idx`` id.

    Args:
        meta_file: metadata file where input file paths are saved
        guid_dtype: unique identifier dtype
        data_config: data config
        dataset_threads: number of threads
        use_weak_label: whether to use weak labeling or not
        split: data split
        is_bert: is the tokenizer a BERT tokenizer
        save_dataset_name: data features file name to save
        save_labels_name: data labels file name to save
        X_storage: data features storage type (for memmap)
        Y_storage: data labels storage type (for memmap)
        tokenizer: tokenizer
        entity_symbols: entity symbols
    """
    start = time.time()
    # int(0.8 * cpu_count()) is 0 on a single-core machine; always keep one worker.
    num_processes = max(
        1, min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
    )
    # One example per mention per candidate
    meta_data = utils.load_json_file(meta_file)
    total_input = meta_data["num_mentions"]
    files_and_counts = meta_data["files_and_counts"]
    # IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before being opened in r+
    memmap_file = np.memmap(
        save_dataset_name, dtype=X_storage, mode="w+", shape=(total_input,), order="C"
    )
    # Save -1 in sent_idx to check that things are loaded correctly later
    memmap_file["sent_idx"][:] = -1
    memmap_label_file = np.memmap(
        save_labels_name, dtype=Y_storage, mode="w+", shape=(total_input,), order="C"
    )
    input_args = []
    # Saves where in memap file to start writing
    offset = 0
    for i, (in_file_name, in_file_lines) in enumerate(files_and_counts.items()):
        input_args.append(
            {
                "file_name": in_file_name,
                "in_file_idx": i,
                "in_file_lines": in_file_lines,
                "save_file_offset": offset,
                "ex_print_mod": int(np.ceil(total_input / 20)),
                "guid_dtype": guid_dtype,
                "is_bert": is_bert,
                "use_weak_label": use_weak_label,
                "split": split,
                "max_seq_len": data_config.max_seq_len,
                "train_in_candidates": data_config.train_in_candidates,
                "print_examples": data_config.print_examples_prep,
            }
        )
        offset += in_file_lines
    if num_processes == 1:
        assert len(input_args) == 1
        total_output = convert_examples_to_features_and_save_single(
            input_args[0],
            tokenizer,
            entity_symbols,
            memmap_file,
            memmap_label_file,
        )
    else:
        log_rank_0_debug(
            logger,
            "Initializing pool. This make take a few minutes.",
        )
        total_output = 0
        # The context manager terminates the workers if anything below raises.
        with multiprocessing.Pool(
            processes=num_processes,
            initializer=convert_examples_to_features_and_save_initializer,
            initargs=[
                tokenizer,
                data_config,
                save_dataset_name,
                save_labels_name,
                X_storage,
                Y_storage,
            ],
        ) as pool:
            for num_saved in pool.imap_unordered(
                convert_examples_to_features_and_save_hlp, input_args, chunksize=1
            ):
                total_output += num_saved
            pool.close()
            # Wait for the workers to exit before reading the files back
            pool.join()
    # Verify that sentences are unique and saved correctly
    mmap_file = np.memmap(save_dataset_name, dtype=X_storage, mode="r")
    # Take the column views once instead of on every iteration
    sent_idxs = mmap_file["sent_idx"]
    subsent_idxs = mmap_file["subsent_idx"]
    all_uniq_ids = set()
    for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
        assert sent_idxs[i] != -1, f"Index {i} has -1 sent idx"
        uniq_id_without_al = f"{sent_idxs[i]}.{subsent_idxs[i]}"
        assert (
            uniq_id_without_al not in all_uniq_ids
        ), f"Idx {uniq_id_without_al} is not unique and already in data"
        all_uniq_ids.add(uniq_id_without_al)
    log_rank_0_debug(
        logger,
        f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
        f"Total lines kept {total_output}.",
    )
def convert_examples_to_features_and_save_hlp(input_dict):
    """Convert examples to features multiprocessing helper.

    Runs ``convert_examples_to_features_and_save_single`` on one work item
    with the tokenizer, entity symbols and memmap files stored in the
    worker globals by the pool initializer.
    """
    return convert_examples_to_features_and_save_single(
        input_dict=input_dict,
        tokenizer=tokenizer_global,
        entitysymbols=entitysymbols_global,
        mmap_file=mmap_file_global,
        mmap_label_file=mmap_label_file_global,
    )
def convert_examples_to_features_and_save_single(
    input_dict,
    tokenizer,
    entitysymbols,
    mmap_file,
    mmap_label_file,
):
    """Convert the examples of one file to features and write them to the memmaps.

    Each line of ``input_dict["file_name"]`` is a JSON-serialized
    ``InputExample``. Line ``idx`` is written to row
    ``input_dict["save_file_offset"] + idx`` of ``mmap_file`` and
    ``mmap_label_file``.

    Args:
        input_dict: work item with keys ``file_name``, ``in_file_idx``,
            ``in_file_lines``, ``save_file_offset``, ``ex_print_mod``,
            ``guid_dtype``, ``print_examples``, ``max_seq_len``, ``split``
            and ``train_in_candidates``
        tokenizer: tokenizer
        entitysymbols: entity symbols
        mmap_file: memmap of the input features
        mmap_label_file: memmap of the labels

    Returns: number of features written

    Raises:
        ValueError: if the mention start cannot be mapped to a token even
            after re-tokenizing the phrase from the mention start
    """
    file_name = input_dict["file_name"]
    in_file_idx = input_dict["in_file_idx"]
    in_file_lines = input_dict["in_file_lines"]
    save_file_offset = input_dict["save_file_offset"]
    ex_print_mod = input_dict["ex_print_mod"]
    guid_dtype = input_dict["guid_dtype"]
    print_examples = input_dict["print_examples"]
    max_seq_len = input_dict["max_seq_len"]
    split = input_dict["split"]
    train_in_candidates = input_dict["train_in_candidates"]
    max_total_input_len = max_seq_len

    def _tokenize(text):
        # Pad/truncate every phrase to exactly max_seq_len tokens
        return tokenizer(
            text,
            is_split_into_words=False,
            padding="max_length",
            add_special_tokens=True,
            truncation=True,
            max_length=max_seq_len,
            return_overflowing_tokens=False,
        )

    total_saved_features = 0
    with open(file_name, "r", encoding="utf-8") as in_f:
        for idx, in_line in tqdm(
            enumerate(in_f),
            total=in_file_lines,
            desc=f"Processing {file_name}",
            position=in_file_idx,
        ):
            example = InputExample.from_dict(ujson.loads(in_line))
            example_idx = save_file_offset + idx
            alias_to_predict = (
                example.alias_to_predict
            )  # Stores -1 if dev data and false anchor
            alias_list_pos = example.alias_list_pos
            span_start_idx, span_end_idx = example.span
            alias = example.alias
            qid = example.qid
            candidate_sentence_input_ids = (
                np.ones(max_total_input_len) * tokenizer.pad_token_id
            )
            candidate_sentence_attn_msks = np.ones(max_total_input_len) * 0
            candidate_sentence_token_type_ids = np.ones(max_total_input_len) * 0
            candidate_mention_cnt_ratio = np.ones(max_total_input_len) * -1
            # ===========================================================
            # GET GOLD LABEL
            # ===========================================================
            # generate indexes into alias table; -2 if unk
            if not entitysymbols.alias_exists(alias):
                # if we do not have this alias in our set, we give it an index of -2, meaning we will
                # always get it wrong in eval
                assert split in ["test", "dev"], (
                    f"Expected split of 'test' or 'dev'. If you are training, "
                    f"the alias {alias} must be in our entity dump"
                )
                alias_trie_idx = -2
                alias_qids = []
            else:
                alias_trie_idx = entitysymbols.get_alias_idx(alias)
                alias_qids = entitysymbols.get_qid_cands(alias)
            # EID used in generating labels in dataloader - will set to 0 for NC
            eid = -1
            # EID used in final prediction dumping - keep as gold EID
            for_dump_eid = -1
            if entitysymbols.qid_exists(qid):
                eid = entitysymbols.get_eid(qid)
                for_dump_eid = eid
            if qid not in alias_qids:
                # if we are not training in candidates, we only assign 0 correct id if the alias is in our map;
                # otherwise we assign -2
                if not train_in_candidates and alias_trie_idx != -2:
                    # set class label to be "not in candidate set"
                    gold_cand_K_idx = 0
                    eid = 0
                else:
                    # if we are not using a NC (no candidate) but are in eval mode, we let the gold
                    # candidate not be in the candidate set we give in a true index of -2,
                    # meaning our model will always get this example incorrect
                    assert split in ["test", "dev"], (
                        f"Expected split of 'test' or 'dev' in sent {example.sent_idx}. If you are training, "
                        f"the QID {qid} must be in the candidate list for {alias} for "
                        f"data_args.train_in_candidates to be True"
                    )
                    gold_cand_K_idx = -2
            else:
                # Here we are getting the correct class label for training.
                # Our training is "which of the max_entities entity candidates is the right one
                # (class labels 1 to max_entities) or is it none of these (class label 0)".
                # + (not discard_noncandidate_entities) is to ensure label 0 is
                # reserved for "not in candidate set" class
                gold_cand_K_idx = np.nonzero(np.array(alias_qids) == qid)[0][0] + (
                    not train_in_candidates
                )
                assert gold_cand_K_idx < entitysymbols.max_candidates + int(
                    not train_in_candidates
                ), (
                    f"The qid {qid} and alias {alias} is not in the top {entitysymbols.max_candidates} max candidates. "
                    f"The QID must be within max candidates."
                )
            # Create input IDs here to ensure each entity is truncated properly
            inputs = _tokenize(example.phrase)
            # In the rare case that the pre-context goes beyond max_seq_len, retokenize starting from
            # ent start to guarantee the start/end tok will be there
            start_tok = inputs.char_to_token(span_start_idx)
            if start_tok is None:
                new_phrase = example.phrase[span_start_idx:]
                # Adjust spans
                span_dist = span_end_idx - span_start_idx
                span_start_idx = 0
                span_end_idx = span_start_idx + span_dist
                inputs = _tokenize(new_phrase)
                start_tok = inputs.char_to_token(span_start_idx)
                if start_tok is None:
                    # Fail with the offending example instead of an opaque "None + 1" TypeError
                    raise ValueError(
                        f"Could not map the mention start to a token for example {example}"
                    )
            # Includes the [ent_start]; we do not want to mask that so +1
            new_span_start = start_tok + 1
            # -1 to index the [ent_end] token, not the token after
            end_tok = inputs.char_to_token(span_end_idx - 1)
            if end_tok is None:
                # No token found for the end of the span, so the span runs to the end of the sequence
                new_span_end = len(inputs["input_ids"])
            else:
                new_span_end = end_tok
            final_toks = tokenizer.convert_ids_to_tokens(inputs["input_ids"])
            assert (
                final_toks[new_span_start - 1] == "[ent_start]"
            ), f"{final_toks} {new_span_start} {new_span_end} {span_start_idx} {span_end_idx}"
            assert (new_span_end == len(inputs["input_ids"])) or final_toks[
                new_span_end
            ] == "[ent_end]", f"{final_toks} {new_span_start} {new_span_end} {span_start_idx} {span_end_idx}"
            candidate_sentence_input_ids[: len(inputs["input_ids"])] = inputs["input_ids"]
            # Only the mention tokens carry the popularity score; all others stay -1
            candidate_mention_cnt_ratio[
                new_span_start:new_span_end
            ] = example.qid_cnt_mask_score
            candidate_sentence_attn_msks[: len(inputs["attention_mask"])] = inputs[
                "attention_mask"
            ]
            candidate_sentence_token_type_ids[: len(inputs["token_type_ids"])] = inputs[
                "token_type_ids"
            ]
            # this stores the true entity pos in the candidate list we use to compute loss -
            # all anchors for train and true anchors for dev/test
            # leave as -1 if it's not an alias we want to predict; we get these if we split a
            # sentence and need to only predict subsets
            example_true_cand_positions_for_loss = PAD_ID
            # this stores the true entity pos in the candidate list for all alias seen by model -
            # all anchors for both train and eval
            example_true_entity_eid = PAD_ID
            # checks if alias is gold or not - alias_to_predict will be -1 for non gold aliases for eval
            if alias_to_predict == 0:
                example_true_cand_positions_for_loss = gold_cand_K_idx
                example_true_entity_eid = eid
            example_true_cand_positions_for_train = gold_cand_K_idx
            # drop example if we have nothing to predict (no valid aliases) -- make sure this doesn't cause
            # problems when we start using unk aliases...
            if alias_trie_idx == PAD_ID:
                logging.error(
                    f"There were 0 aliases in this example {example}. This shouldn't happen."
                )
                sys.exit(0)
            total_saved_features += 1
            feature = InputFeatures(
                alias_idx=alias_trie_idx,
                word_input_ids=candidate_sentence_input_ids,
                word_token_type_ids=candidate_sentence_token_type_ids,
                word_attention_mask=candidate_sentence_attn_msks,
                word_qid_cnt_mask_score=candidate_mention_cnt_ratio,
                gold_eid=example_true_entity_eid,
                for_dump_gold_eid=for_dump_eid,  # Store the one that isn't -1 for non-gold aliases
                gold_cand_K_idx=example_true_cand_positions_for_loss,
                for_dump_gold_cand_K_idx_train=example_true_cand_positions_for_train,
                alias_list_pos=alias_list_pos,
                sent_idx=int(example.sent_idx),
                subsent_idx=int(example.subsent_idx),
                guid=np.array(
                    [
                        (
                            int(example.sent_idx),
                            int(example.subsent_idx),
                            [alias_list_pos],
                        )
                    ],
                    dtype=guid_dtype,
                ),
            )
            # Write feature
            # We are storing mmap file in column format, so column name first
            mmap_file["sent_idx"][example_idx] = feature.sent_idx
            mmap_file["subsent_idx"][example_idx] = feature.subsent_idx
            mmap_file["guids"][example_idx] = feature.guid
            mmap_file["alias_idx"][example_idx] = feature.alias_idx
            mmap_file["input_ids"][example_idx] = feature.word_input_ids
            mmap_file["token_type_ids"][example_idx] = feature.word_token_type_ids
            mmap_file["attention_mask"][example_idx] = feature.word_attention_mask
            mmap_file["word_qid_cnt_mask_score"][
                example_idx
            ] = feature.word_qid_cnt_mask_score
            mmap_file["alias_orig_list_pos"][example_idx] = feature.alias_list_pos
            mmap_file["for_dump_gold_cand_K_idx_train"][
                example_idx
            ] = feature.for_dump_gold_cand_K_idx_train
            mmap_file["gold_eid"][example_idx] = feature.gold_eid
            mmap_file["for_dump_gold_eid"][example_idx] = feature.for_dump_gold_eid
            mmap_label_file["gold_cand_K_idx"][example_idx] = feature.gold_cand_K_idx
            # Only build the (large) dump string when it is actually going to be printed
            if print_examples and example_idx % ex_print_mod == 0:
                dump_lines = [
                    "*** Example ***",
                    f"guid: {example.sent_idx} subsent {example.subsent_idx}",
                    f"phrase toks: {example.phrase}",
                    f"alias_to_predict: {example.alias_to_predict}",
                    f"alias_list_pos: {example.alias_list_pos}",
                    f"aliases: {example.alias}",
                    f"qids: {example.qid}",
                    "*** Feature ***",
                    f"gold_cand_K_idx: {feature.gold_cand_K_idx}",
                    f"gold_eid: {feature.gold_eid}",
                    f"for_dump_gold_eid: {feature.for_dump_gold_eid}",
                    f"for_dump_gold_cand_K_idx_train: {feature.for_dump_gold_cand_K_idx_train}",
                    f"input_ids: {' '.join([str(x) for x in feature.word_input_ids])}",
                    f"token_type_ids: {' '.join([str(x) for x in feature.word_token_type_ids])}",
                    f"attention_mask: {' '.join([str(x) for x in feature.word_attention_mask])}",
                    f"word_qid_cnt_mask_score: {' '.join([str(x) for x in feature.word_qid_cnt_mask_score])}",
                    f"guid: {feature.guid}",
                ]
                # Make one string for distributed computation consistency
                print("\n".join(dump_lines) + "\n")
    mmap_file.flush()
    mmap_label_file.flush()
    return total_saved_features
[docs]class BootlegDataset(EmmentalDataset):
"""Bootleg Dataset class.
Args:
main_args: input config
name: internal dataset name
dataset: dataset file
use_weak_label: whether to use weakly labeled mentions or not
load_entity_data: whether to load entity data or not
tokenizer: sentence tokenizer
entity_symbols: entity database class
dataset_threads: number of threads to use
split: data split
is_bert: is the tokenizer a BERT or not
dataset_range: offset into dataset
"""
def __init__(
    self,
    main_args,
    name,
    dataset,
    use_weak_label,
    load_entity_data,
    tokenizer,
    entity_symbols,
    dataset_threads,
    split="train",
    is_bert=True,
    dataset_range=None,
):
    """Bootleg dataset initializer.

    Builds (or reuses) the memmap files with the mention features and
    labels, optionally builds (or reuses) the entity token memmap file, and
    loads them as the X/Y dicts of the Emmental dataset. Partially written
    files are removed if preparation fails.

    Args:
        main_args: input config
        name: internal dataset name
        dataset: dataset file
        use_weak_label: whether to use weakly labeled mentions or not
        load_entity_data: whether to load entity data or not
        tokenizer: sentence tokenizer
        entity_symbols: entity database class
        dataset_threads: number of threads to use
        split: data split
        is_bert: is the tokenizer a BERT or not
        dataset_range: offset into dataset
    """
    log_rank_0_info(
        logger,
        f"Starting to build data for {split} from {dataset}",
    )
    global_start = time.time()
    data_config = main_args.data_config
    spawn_method = main_args.run_config.spawn_method
    log_rank_0_debug(logger, f"Setting spawn method to be {spawn_method}")
    orig_spawn = multiprocessing.get_start_method()
    multiprocessing.set_start_method(spawn_method, force=True)
    try:
        # Unique identifier is sentence index, subsentence index (due to sentence splitting), and aliases in split
        guid_dtype = np.dtype(
            [
                ("sent_idx", "i8", 1),
                ("subsent_idx", "i8", 1),
                ("alias_orig_list_pos", "i8", (1,)),
            ]
        )
        max_total_input_len = data_config.max_seq_len
        # Storage for saving the data.
        self.X_storage, self.Y_storage, self.X_entity_storage = (
            [
                ("guids", guid_dtype, 1),
                ("sent_idx", "i8", 1),
                ("subsent_idx", "i8", 1),
                ("alias_idx", "i8", 1),
                (
                    "input_ids",
                    "i8",
                    (max_total_input_len,),
                ),
                (
                    "token_type_ids",
                    "i8",
                    (max_total_input_len,),
                ),
                (
                    "attention_mask",
                    "i8",
                    (max_total_input_len,),
                ),
                (
                    "word_qid_cnt_mask_score",
                    "float",
                    (max_total_input_len,),
                ),
                ("alias_orig_list_pos", "i8", 1),
                (
                    "gold_eid",
                    "i8",
                    1,
                ),  # What the eid of the gold entity is
                (
                    "for_dump_gold_eid",
                    "i8",
                    1,
                ),  # What the eid of the gold entity is independent of gold alias or not
                (
                    "for_dump_gold_cand_K_idx_train",
                    "i8",
                    1,
                ),  # Which of the K candidates is correct. Only used in dump_pred to stitch sub-sentences together
            ],
            [
                (
                    "gold_cand_K_idx",
                    "i8",
                    1,
                ),  # Which of the K candidates is correct.
            ],
            [
                ("entity_input_ids", "i8", (data_config.max_ent_len)),
                ("entity_token_type_ids", "i8", (data_config.max_ent_len)),
                ("entity_attention_mask", "i8", (data_config.max_ent_len)),
                ("entity_to_mask", "i8", (data_config.max_ent_len)),
            ],
        )
        self.split = split
        self.popularity_mask = data_config.popularity_mask
        self.context_mask_perc = data_config.context_mask_perc
        self.tokenizer = tokenizer
        # Table to map from alias_idx to entity_cand_eid used in the __get_item__
        self.alias2cands_model = AliasEntityTable(
            data_config=data_config, entity_symbols=entity_symbols
        )
        # Total number of entities used in the __get_item__
        self.num_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
        )
        self.raw_filename = dataset
        # Folder for all mmap saved files
        save_dataset_folder = data_utils.get_save_data_folder(
            data_config, use_weak_label, self.raw_filename
        )
        utils.ensure_dir(save_dataset_folder)
        # Folder for entity mmap saved files
        save_entity_folder = data_utils.get_emb_prep_dir(data_config)
        utils.ensure_dir(save_entity_folder)
        # Folder for temporary output files
        temp_output_folder = os.path.join(
            data_config.data_dir,
            data_config.data_prep_dir,
            f"prep_{split}_dataset_files",
        )
        utils.ensure_dir(temp_output_folder)
        # Input step 1
        create_ex_indir = os.path.join(temp_output_folder, "create_examples_input")
        utils.ensure_dir(create_ex_indir)
        # Input step 2
        create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output")
        utils.ensure_dir(create_ex_outdir)
        # Meta data saved files
        meta_file = os.path.join(temp_output_folder, "meta_data.json")
        # File for standard training data
        self.save_dataset_name = os.path.join(save_dataset_folder, "ned_data.bin")
        # File for standard labels
        self.save_labels_name = os.path.join(save_dataset_folder, "ned_label.bin")
        # File for the entity token data; set in the ENTITY TOKENS section
        self.save_entity_dataset_name = None
        # =======================================================================================
        # STANDARD DISAMBIGUATION
        # =======================================================================================
        log_rank_0_debug(
            logger,
            f"Seeing if {self.save_dataset_name} exists and {self.save_labels_name} exists",
        )
        if (
            data_config.overwrite_preprocessed_data
            or (not os.path.exists(self.save_dataset_name))
            or (not os.path.exists(self.save_labels_name))
        ):
            st_time = time.time()
            log_rank_0_info(
                logger,
                f"Building dataset from scratch. Saving to {save_dataset_folder}.",
            )
            create_examples(
                dataset,
                create_ex_indir,
                create_ex_outdir,
                meta_file,
                data_config,
                dataset_threads,
                use_weak_label,
                split,
                is_bert,
                tokenizer,
            )
            try:
                convert_examples_to_features_and_save(
                    meta_file,
                    guid_dtype,
                    data_config,
                    dataset_threads,
                    use_weak_label,
                    split,
                    is_bert,
                    self.save_dataset_name,
                    self.save_labels_name,
                    self.X_storage,
                    self.Y_storage,
                    tokenizer,
                    entity_symbols,
                )
                log_rank_0_debug(
                    logger,
                    f"Finished prepping disambig training data in {time.time() - st_time}",
                )
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error("\n".join(tb.stack.format()))
                # The failure may have happened before a file was created; do not let a
                # FileNotFoundError from the cleanup replace the original exception.
                for partial_file in (self.save_dataset_name, self.save_labels_name):
                    if os.path.exists(partial_file):
                        os.remove(partial_file)
                shutil.rmtree(save_dataset_folder, ignore_errors=True)
                raise
        log_rank_0_info(
            logger,
            f"Loading data from {self.save_dataset_name} and {self.save_labels_name}",
        )
        X_dict, Y_dict = self.build_data_dicts(
            self.save_dataset_name,
            self.save_labels_name,
            self.X_storage,
            self.Y_storage,
        )
        # =======================================================================================
        # ENTITY TOKENS
        # =======================================================================================
        self.save_entity_dataset_name = os.path.join(
            save_entity_folder,
            f"entity_data"
            f"_type{int(data_config.entity_type_data.use_entity_types)}"
            f"_kg{int(data_config.entity_kg_data.use_entity_kg)}"
            f"_desc{int(data_config.use_entity_desc)}.bin",
        )
        log_rank_0_debug(logger, f"Seeing if {self.save_entity_dataset_name} exists")
        if load_entity_data:
            if data_config.overwrite_preprocessed_data or (
                not os.path.exists(self.save_entity_dataset_name)
            ):
                st_time = time.time()
                log_rank_0_info(logger, "Building entity data from scatch.")
                try:
                    # Creating/saving data
                    build_and_save_entity_inputs(
                        self.save_entity_dataset_name,
                        self.X_entity_storage,
                        data_config,
                        dataset_threads,
                        tokenizer,
                        entity_symbols,
                    )
                    log_rank_0_debug(
                        logger, f"Finished prepping data in {time.time() - st_time}"
                    )
                except Exception as e:
                    tb = traceback.TracebackException.from_exception(e)
                    logger.error(e)
                    logger.error(traceback.format_exc())
                    logger.error("\n".join(tb.stack.format()))
                    # Same as above: only remove the file if it was created
                    if os.path.exists(self.save_entity_dataset_name):
                        os.remove(self.save_entity_dataset_name)
                    raise
            X_entity_dict = self.build_data_entity_dicts(
                self.save_entity_dataset_name, self.X_entity_storage
            )
            self.X_entity_dict = X_entity_dict
        else:
            self.X_entity_dict = None
        log_rank_0_debug(logger, "Removing temporary output files")
        shutil.rmtree(temp_output_folder, ignore_errors=True)
        log_rank_0_info(
            logger,
            f"Final data initialization time for {split} is {time.time() - global_start}s",
        )
        # Default to every row of the first X_dict entry
        self.dataset_range = (
            list(range(len(X_dict[next(iter(X_dict.keys()))])))
            if dataset_range is None
            else dataset_range
        )
    finally:
        # Set spawn back to original/default, which is "fork" or "spawn".
        # This is needed for the Meta.config to be correctly passed in the collate_fn.
        # Done in a finally so a failed build does not leave the changed method behind.
        multiprocessing.set_start_method(orig_spawn, force=True)
    super().__init__(name, X_dict=X_dict, Y_dict=Y_dict, uid="guids")
@classmethod
def build_data_dicts(
    cls, save_dataset_name, save_labels_name, X_storage, Y_storage
):
    """Open the saved memmap files and expose them as X_dict and Y_dict.

    Args:
        save_dataset_name: memmap file name with inputs
        save_labels_name: memmap file name with labels
        X_storage: memmap storage for inputs
        Y_storage: memmap storage labels

    Returns: X_dict of inputs and Y_dict of labels for Emmental datasets
    """
    inputs_mmap = np.memmap(save_dataset_name, dtype=X_storage, mode="r")
    labels_mmap = np.memmap(save_labels_name, dtype=Y_storage, mode="r")

    # Input fields that are wrapped as tensors, in the order they appear in X_dict
    tensor_fields = (
        "sent_idx",
        "subsent_idx",
        "alias_idx",
        "input_ids",
        "token_type_ids",
        "attention_mask",
        "word_qid_cnt_mask_score",
        # original position in the alias list this example is (see eval)
        "alias_orig_list_pos",
        # gold entity eids
        "gold_eid",
        "for_dump_gold_eid",
        # gold indices without subsentence masking (see eval)
        "for_dump_gold_cand_K_idx_train",
    )

    # uid doesn't need to be tensor, so "guids" stays a memmap field
    X_dict = {"guids": inputs_mmap["guids"]}
    for field in tensor_fields:
        X_dict[field] = torch.from_numpy(inputs_mmap[field])

    Y_dict = {"gold_cand_K_idx": torch.from_numpy(labels_mmap["gold_cand_K_idx"])}
    return X_dict, Y_dict
@classmethod
def build_data_entity_dicts(cls, save_dataset_name, X_storage):
    """Open the saved entity memmap file and expose its fields as tensors.

    Args:
        save_dataset_name: memmap file name with entity data
        X_storage: memmap storage type

    Returns: Dict mapping each entity field name to its tensor
    """
    entity_mmap = np.memmap(save_dataset_name, dtype=X_storage, mode="r")
    field_names = (
        "entity_input_ids",
        "entity_token_type_ids",
        "entity_attention_mask",
        "entity_to_mask",
    )
    return {name: torch.from_numpy(entity_mmap[name]) for name in field_names}
def get_sentidx_to_rowids(self):
    """Group the row ids of X_dict by the sentence index they belong to.

    Returns: Dict of sent idx (as a string) to list of row ids
    """
    rows_by_sent = {}
    for row_id, sent_id in enumerate(self.X_dict["sent_idx"]):
        # Saving/loading dict will convert numeric keys to strings - keep consistent
        sent_key = str(sent_id.item())
        rows_by_sent.setdefault(sent_key, []).append(row_id)
    return rows_by_sent
def __getitem__(self, index):
    r"""Get item by index.

    Args:
        index(index): The index of the item.

    Returns:
        Tuple[Dict[str, Any], Dict[str, Tensor]]: Tuple of x_dict and y_dict
    """
    row = self.dataset_range[index]
    x_dict = {name: feature[row] for name, feature in self.X_dict.items()}
    y_dict = {name: label[row] for name, label in self.Y_dict.items()}

    # Masking is only applied to the train split when popularity_mask is set
    apply_mask = self.split == "train" and self.popularity_mask

    # Mask the mention tokens
    if apply_mask:
        x_dict["input_ids"] = self._mask_input_ids(x_dict)

    # Get the entity_cand_eid
    entity_cand_eid = self.alias2cands_model(x_dict["alias_idx"]).long()

    entity_data = self.X_entity_dict
    if entity_data is not None:
        # Get the entity token ids, one row per candidate
        if apply_mask:
            cand_input_ids = [
                self._mask_entity_input_ids(x_dict, eid) for eid in entity_cand_eid
            ]
        else:
            cand_input_ids = [
                entity_data["entity_input_ids"][eid] for eid in entity_cand_eid
            ]
        cand_token_type_ids = [
            entity_data["entity_token_type_ids"][eid] for eid in entity_cand_eid
        ]
        cand_attention_mask = [
            entity_data["entity_attention_mask"][eid] for eid in entity_cand_eid
        ]
        # Stack the per-candidate rows along a new leading dimension
        x_dict["entity_cand_input_ids"] = torch.stack(cand_input_ids, dim=0)
        x_dict["entity_cand_token_type_ids"] = torch.stack(
            cand_token_type_ids, dim=0
        )
        x_dict["entity_cand_attention_mask"] = torch.stack(
            cand_attention_mask, dim=0
        )

    x_dict["entity_cand_eval_mask"] = entity_cand_eid == -1
    # Handles the index errors with -1 indexing into an embedding
    fallback_eid = torch.full_like(
        entity_cand_eid,
        self.num_entities_with_pad_and_nocand - 1,
        dtype=torch.long,
    )
    x_dict["entity_cand_eid"] = torch.where(
        entity_cand_eid >= 0, entity_cand_eid, fallback_eid
    )
    # Add dummy gold_unq_eid_idx for Emmental init - this gets overwritten in the collator in data.py
    y_dict["gold_unq_eid_idx"] = y_dict["gold_cand_K_idx"]
    return x_dict, y_dict
def _mask_input_ids(self, x_dict):
    """
    Mask input context ids.

    Tokens whose ``word_qid_cnt_mask_score`` is non-negative are replaced by
    the tokenizer's mask token with a probability chosen from the lowest
    score bucket present in the example. When ``context_mask_perc`` is
    positive, the remaining non-special, non-padding tokens are additionally
    masked with that probability.
    """
    mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
    # Work on a copy (the original author notes a core dump without it)
    masked_ids = torch.clone(x_dict["input_ids"])
    scores = x_dict["word_qid_cnt_mask_score"]

    # (inclusive lower bound, exclusive upper bound or None, mask probability);
    # the first bucket containing any score decides the probability
    buckets = (
        (0.0, 0.5, 0.5),
        (0.5, 0.65, 0.62),
        (0.65, 0.8, 0.73),
        (0.8, 0.95, 0.84),
        (0.95, None, 0.95),
    )
    mention_prob = 0.0
    for lower, upper, prob in buckets:
        in_bucket = lower <= scores
        if upper is not None:
            in_bucket = in_bucket & (scores < upper)
        if torch.any(in_bucket):
            mention_prob = prob
            break

    mention_probs = torch.zeros(scores.shape)
    mention_probs.masked_fill_(scores >= 0.0, value=mention_prob)
    masked_ids.masked_fill_(
        torch.bernoulli(mention_probs).bool(), value=mask_token_id
    )

    # Mask all tokens by context_mask_perc
    if self.context_mask_perc > 0.0:
        # We sample a few tokens in each sequence
        context_probs = torch.full(masked_ids.shape, self.context_mask_perc)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(
            masked_ids.tolist(), already_has_special_tokens=True
        )
        context_probs.masked_fill_(
            torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0
        )
        if self.tokenizer._pad_token is not None:
            is_padding = masked_ids.eq(self.tokenizer.pad_token_id)
            context_probs.masked_fill_(is_padding, value=0.0)
        masked_ids = masked_ids.masked_fill(
            torch.bernoulli(context_probs).bool(), mask_token_id
        )
    return masked_ids
def _mask_entity_input_ids(self, x_dict, eid):
    """
    Mask entity input ids.

    Positions flagged in the entity's ``entity_to_mask`` row are replaced by
    the tokenizer's mask token with a probability chosen from the lowest
    ``word_qid_cnt_mask_score`` bucket present in the example, so rarer
    mentions get a different masking rate than popular ones.

    Args:
        x_dict: example features; only ``word_qid_cnt_mask_score`` is read
        eid: index of the entity row in ``self.X_entity_dict``

    Returns:
        A masked copy of the entity's input ids; stored data is not modified
    """
    # Work on a copy (the original author notes a core dump without it)
    entity_input_ids = torch.clone(self.X_entity_dict["entity_input_ids"][eid])
    cnt_ratio = x_dict["word_qid_cnt_mask_score"]
    # Copy via detach().clone() rather than torch.tensor(<tensor>), which emits a
    # copy-construct UserWarning; the copy keeps the in-place fill below away
    # from the stored data.
    probability_matrix = (
        torch.as_tensor(self.X_entity_dict["entity_to_mask"][eid])
        .detach()
        .clone()
        .float()
    )

    # (inclusive lower bound, exclusive upper bound or None, mask probability);
    # the first bucket containing any score decides the probability
    buckets = (
        (0.0, 0.5, 0.5),
        (0.5, 0.65, 0.62),
        (0.65, 0.8, 0.73),
        (0.8, 0.95, 0.84),
        (0.95, None, 0.95),
    )
    fill_v = 0.0
    for lower, upper, prob in buckets:
        in_bucket = lower <= cnt_ratio
        if upper is not None:
            in_bucket = in_bucket & (cnt_ratio < upper)
        if torch.any(in_bucket):
            fill_v = prob
            break

    probability_matrix.masked_fill_(probability_matrix > 0.0, value=fill_v)
    masked_indices = torch.bernoulli(probability_matrix).bool()

    # Use the tokenizer's own mask token, as _mask_input_ids does, instead of
    # assuming it is spelled "[MASK]"; fall back to the legacy literal only
    # when the tokenizer defines no mask token.
    mask_token = getattr(self.tokenizer, "mask_token", None) or "[MASK]"
    entity_input_ids.masked_fill_(
        masked_indices, value=self.tokenizer.convert_tokens_to_ids(mask_token)
    )
    return entity_input_ids
def __getstate__(self):
    """Return the instance attributes for pickling, minus X_dict and Y_dict."""
    state = dict(self.__dict__)
    # These two are rebuilt from the saved memmap files in __setstate__
    for dropped in ("X_dict", "Y_dict"):
        del state[dropped]
    return state
def __setstate__(self, state):
    """Restore the attributes and rebuild X_dict and Y_dict from the saved files."""
    vars(self).update(state)
    rebuilt = self.build_data_dicts(
        self.save_dataset_name,
        self.save_labels_name,
        self.X_storage,
        self.Y_storage,
    )
    self.X_dict, self.Y_dict = rebuilt
    return state
def __repr__(self):
    """Describe the dataset by the files holding its inputs and labels."""
    data_part = f"Bootleg Dataset. Data at {self.save_dataset_name}. "
    labels_part = f"Labels at {self.save_labels_name}. "
    return data_part + labels_part
def __len__(self):
    """Return the number of rows selected by dataset_range."""
    selected_rows = self.dataset_range
    return len(selected_rows)
[docs]class BootlegEntityDataset(EmmentalDataset):
"""Bootleg Dataset class for entities.
Args:
main_args: input config
name: internal dataset name
dataset: dataset file
tokenizer: sentence tokenizer
entity_symbols: entity database class
dataset_threads: number of threads to use
split: data split
"""
def __init__(
    self,
    main_args,
    name,
    dataset,
    tokenizer,
    entity_symbols,
    dataset_threads,
    split="test",
):
    """Bootleg entity dataset initializer.

    Builds the tokenized entity inputs into a memmap file (or reuses the file
    already on disk) and hands them to the base dataset with "guids" as uid.

    Args:
        main_args: input config
        name: internal dataset name
        dataset: dataset file (only used in the log message here)
        tokenizer: sentence tokenizer
        entity_symbols: entity database class
        dataset_threads: number of threads to use
        split: data split; must be "test"

    Raises:
        AssertionError: if ``split`` is not "test"
        Exception: whatever building the entity data raised, after logging it
            and removing any partially written file
    """
    assert split == "test", "Split must be test split for EntityDataset"
    log_rank_0_info(
        logger,
        f"Starting to build data for {split} from {dataset}",
    )
    global_start = time.time()
    data_config = main_args.data_config
    spawn_method = main_args.run_config.spawn_method
    log_rank_0_debug(logger, f"Setting spawn method to be {spawn_method}")
    orig_spawn = multiprocessing.get_start_method()
    multiprocessing.set_start_method(spawn_method, force=True)
    try:
        # Storage for saving the data.
        self.X_entity_storage = [
            ("entity_input_ids", "i8", (data_config.max_ent_len)),
            ("entity_token_type_ids", "i8", (data_config.max_ent_len)),
            ("entity_attention_mask", "i8", (data_config.max_ent_len)),
            ("entity_to_mask", "i8", (data_config.max_ent_len)),
        ]
        self.split = split
        self.popularity_mask = data_config.popularity_mask
        self.context_mask_perc = data_config.context_mask_perc
        self.tokenizer = tokenizer

        # Table to map from alias_idx to entity_cand_eid used in the __get_item__
        self.alias2cands_model = AliasEntityTable(
            data_config=data_config, entity_symbols=entity_symbols
        )
        # Total number of entities used in the __get_item__
        self.num_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
        )

        # Folder for entity mmap saved files
        save_entity_folder = data_utils.get_emb_prep_dir(data_config)
        utils.ensure_dir(save_entity_folder)

        # ===================================================================
        # ENTITY TOKENS
        # ===================================================================
        self.save_entity_dataset_name = os.path.join(
            save_entity_folder,
            f"entity_data"
            f"_type{int(data_config.entity_type_data.use_entity_types)}"
            f"_kg{int(data_config.entity_kg_data.use_entity_kg)}"
            f"_desc{int(data_config.use_entity_desc)}.bin",
        )
        log_rank_0_debug(
            logger, f"Seeing if {self.save_entity_dataset_name} exists"
        )
        if data_config.overwrite_preprocessed_data or (
            not os.path.exists(self.save_entity_dataset_name)
        ):
            st_time = time.time()
            log_rank_0_info(logger, "Building entity data from scatch.")
            try:
                # Creating/saving data
                build_and_save_entity_inputs(
                    self.save_entity_dataset_name,
                    self.X_entity_storage,
                    data_config,
                    dataset_threads,
                    tokenizer,
                    entity_symbols,
                )
                log_rank_0_debug(
                    logger, f"Finished prepping data in {time.time() - st_time}"
                )
            except Exception as e:
                tb = traceback.TracebackException.from_exception(e)
                logger.error(e)
                logger.error(traceback.format_exc())
                logger.error("\n".join(tb.stack.format()))
                # Drop a partially written file so a later run rebuilds it. The
                # build may have failed before the file was created; removing
                # unconditionally would raise FileNotFoundError and hide the
                # real error.
                if os.path.exists(self.save_entity_dataset_name):
                    os.remove(self.save_entity_dataset_name)
                raise

        X_entity_dict = self.build_data_entity_dicts(
            self.save_entity_dataset_name, self.X_entity_storage
        )
        # Add the unique identified of EID (the embeddings are already in this order)
        X_entity_dict["guids"] = torch.arange(
            len(X_entity_dict["entity_input_ids"])
        )
        log_rank_0_info(
            logger,
            f"Final data initialization time for {split} is {time.time() - global_start}s",
        )
    finally:
        # Set spawn back to original/default, which is "fork" or "spawn".
        # This is needed for the Meta.config to be correctly passed in the collate_fn.
        # Done in ``finally`` so a failed build does not leave the process-wide
        # start method changed.
        multiprocessing.set_start_method(orig_spawn, force=True)
    super().__init__(name, X_dict=X_entity_dict, uid="guids")
@classmethod
def build_data_entity_dicts(cls, save_dataset_name, X_storage):
    """Read the entity memmap file and wrap each of its fields as a tensor.

    Args:
        save_dataset_name: memmap file name with entity data
        X_storage: memmap storage type

    Returns: Dict mapping each entity field name to its tensor
    """
    entity_mmap = np.memmap(save_dataset_name, dtype=X_storage, mode="r")
    X_dict = {}
    for field in (
        "entity_input_ids",
        "entity_token_type_ids",
        "entity_attention_mask",
        "entity_to_mask",
    ):
        X_dict[field] = torch.from_numpy(entity_mmap[field])
    return X_dict
def __getitem__(self, index):
    r"""Get item by index.

    Args:
        index(index): The index of the item.

    Returns:
        Dict[str, Any]: x_dict holding the ``index``-th entry of every feature
    """
    x_dict = {}
    for name, feature in self.X_dict.items():
        x_dict[name] = feature[index]
    return x_dict
def __getstate__(self):
    """Return the instance attributes for pickling, minus X_dict and Y_dict."""
    state = dict(self.__dict__)
    for dropped in ("X_dict", "Y_dict"):
        del state[dropped]
    return state
def __setstate__(self, state):
    """Restore the attributes and rebuild the data dropped by __getstate__.

    ``__getstate__`` removes ``X_dict`` and ``Y_dict`` from the pickled state.
    Without rebuilding them here an unpickled dataset has no ``X_dict`` and
    ``__getitem__`` fails, so the entity data is reloaded from the saved
    memmap file the same way ``__init__`` builds it.

    Args:
        state: attribute dict produced by ``__getstate__``

    Returns:
        The ``state`` that was passed in
    """
    self.__dict__.update(state)
    if "X_dict" not in self.__dict__:
        X_entity_dict = self.build_data_entity_dicts(
            self.save_entity_dataset_name, self.X_entity_storage
        )
        # Same uid column __init__ adds: one id per entity row, in file order
        X_entity_dict["guids"] = torch.arange(
            len(X_entity_dict["entity_input_ids"])
        )
        self.X_dict = X_entity_dict
    # __init__ passes no Y_dict to the base class; None is presumably the
    # base-class default -- TODO confirm against EmmentalDataset
    self.__dict__.setdefault("Y_dict", None)
    return state
def __repr__(self):
    """Describe the dataset by the file holding its entity data."""
    description = f"Bootleg Entity Dataset. Data at {self.save_entity_dataset_name}."
    return description