"""Bootleg slice dataset."""
import hashlib
import logging
import multiprocessing
import os
import shutil
import time
import traceback
from collections import defaultdict
import numpy as np
import ujson
from tqdm.auto import tqdm
from bootleg import log_rank_0_debug, log_rank_0_info
from bootleg.symbols.constants import ANCHOR_KEY, FINAL_LOSS
from bootleg.utils import data_utils, utils
logger = logging.getLogger(__name__)
[docs]def get_slice_values(slice_names, line):
"""
Results a dictionary of all slice values for an input example.
Any mention with a slice value of > 0.5 gets assigned that slice. If some
slices are missing from the input, we assign all mentions as not being in
that slice (getting a 0 label value). We also check that slices are
formatted correctly.
Args:
slice_names: slice names to evaluate on
line: input data json line
Returns: Dict of slice name to alias index string to float value of if mention is in a slice or not.
"""
slices = {}
if "slices" in line:
assert type(line["slices"]) is dict
aliases = line["aliases"]
slices = line["slices"]
# remove slices that we don't need
for slice_name in list(slices.keys()):
if slice_name not in slice_names:
del slices[slice_name]
else:
assert len(slices[slice_name]) == len(
aliases
), "Must have a prob label for each mention"
# FINAL_LOSS and BASE_SLICE are in slice_names but are generated by us so we do not want them to be in slices
assert (
FINAL_LOSS not in slices
), f"You can't have {FINAL_LOSS} be slice names. You have {slices.keys()}. {FINAL_LOSS} is used internally."
for slice_name in slice_names:
if slice_name in [FINAL_LOSS]:
continue
# Add slices that are empty
if slice_name not in slices:
slices[slice_name] = {str(i): 0.0 for i in range(len(aliases))}
return slices
[docs]def create_examples_initializer(
data_config, slice_names, use_weak_label, split, train_in_candidates
):
"""Create example multiprocessing initialiezr."""
global constants_global
constants_global = {
"slice_names": slice_names,
"use_weak_label": use_weak_label,
"split": split,
"train_in_candidates": train_in_candidates,
}
[docs]def create_examples(
dataset,
create_ex_indir,
create_ex_outdir,
meta_file,
data_config,
dataset_threads,
slice_names,
use_weak_label,
split,
):
"""Create examples from the raw input data.
Args:
dataset: dataset file
create_ex_indir: temporary directory where input files are stored
create_ex_outdir: temporary directory to store output files from method
meta_file: metadata file to save the file names/paths for the next step in prep pipeline
data_config: data config
dataset_threads: number of threads
slice_names: list of slices to evaluate on
use_weak_label: whether to use weak labeling or not
split: data split
"""
log_rank_0_debug(logger, "Starting to extract subsentences")
start = time.time()
num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
log_rank_0_debug(logger, "Counting lines")
total_input = sum(1 for _ in open(dataset))
if num_processes == 1:
out_file_name = os.path.join(create_ex_outdir, os.path.basename(dataset))
constants_dict = {
"slice_names": slice_names,
"use_weak_label": use_weak_label,
"split": split,
"train_in_candidates": data_config.train_in_candidates,
}
files_and_counts = {}
res = create_examples_single(
dataset, total_input, out_file_name, constants_dict
)
total_output = res["total_lines"]
max_alias2pred = res["max_alias2pred"]
files_and_counts[res["output_filename"]] = res["total_lines"]
else:
log_rank_0_info(
logger, f"Strating to extract examples with {num_processes} threads"
)
log_rank_0_debug(
logger, "Parallelizing with " + str(num_processes) + " threads."
)
chunk_input = int(np.ceil(total_input / num_processes))
log_rank_0_debug(
logger,
f"Chunking up {total_input} lines into subfiles of size {chunk_input} lines",
)
total_input_from_chunks, input_files_dict = utils.chunk_file(
dataset, create_ex_indir, chunk_input
)
input_files = list(input_files_dict.keys())
input_file_lines = [input_files_dict[k] for k in input_files]
output_files = [
in_file_name.replace(create_ex_indir, create_ex_outdir)
for in_file_name in input_files
]
assert (
total_input == total_input_from_chunks
), f"Lengths of files {total_input} doesn't mathc {total_input_from_chunks}"
log_rank_0_debug(logger, "Done chunking files")
pool = multiprocessing.Pool(
processes=num_processes,
initializer=create_examples_initializer,
initargs=[
data_config,
slice_names,
use_weak_label,
split,
data_config.train_in_candidates,
],
)
total_output = 0
max_alias2pred = 0
input_args = list(zip(input_files, input_file_lines, output_files))
# Store output files and counts for saving in next step
files_and_counts = {}
for res in pool.imap_unordered(create_examples_hlp, input_args, chunksize=1):
total_output += res["total_lines"]
max_alias2pred = max(max_alias2pred, res["max_alias2pred"])
files_and_counts[res["output_filename"]] = res["total_lines"]
pool.close()
pool.join()
utils.dump_json_file(
meta_file,
{
"num_mentions": total_output,
"files_and_counts": files_and_counts,
"max_alias2pred": max_alias2pred,
},
)
log_rank_0_debug(
logger,
f"Done with extracting examples in {time.time()-start}. Total lines seen {total_input}. "
f"Total lines kept {total_output}.",
)
return
[docs]def create_examples_hlp(args):
"""Create examples wrapper helper."""
in_file_name, in_file_lines, out_file_name = args
return create_examples_single(
in_file_name, in_file_lines, out_file_name, constants_global
)
[docs]def create_examples_single(in_file_name, in_file_lines, out_file_name, constants_dict):
"""Create examples multiprocessing helper."""
split = constants_dict["split"]
use_weak_label = constants_dict["use_weak_label"]
slice_names = constants_dict["slice_names"]
with open(out_file_name, "w") as out_f:
total_subsents = 0
# The memmap stores things differently when you have two integers and we want to keep a2p as an array
# Therefore for force max the minimum max_a2p to be 2
max_a2pred = 2
for ex in tqdm(
open(in_file_name), total=in_file_lines, desc=f"Reading in {in_file_name}"
):
line = ujson.loads(ex)
assert "sent_idx_unq" in line
assert "aliases" in line
assert ANCHOR_KEY in line
sent_idx = line["sent_idx_unq"]
# aliases are assumed to be lower-cased in candidate map
aliases = [alias.lower() for alias in line["aliases"]]
num_alias2pred = len(aliases)
slices = get_slice_values(slice_names, line)
# We need to only use True anchors for eval
anchor = [True for i in range(len(aliases))]
if ANCHOR_KEY in line:
anchor = line[ANCHOR_KEY]
assert len(aliases) == len(anchor)
assert all(isinstance(a, bool) for a in anchor)
if split != "train":
# Reindex aliases to predict to be where anchor == True because we only ever want to predict
# those (it will see all aliases in the forward pass but we will only score the True anchors)
for slice_name in slices:
aliases_to_predict = slices[slice_name]
slices[slice_name] = {
i: aliases_to_predict[i]
for i in aliases_to_predict
if anchor[int(i)] is True
}
# Add in FINAL LOSS slice
if split != "train":
slices[FINAL_LOSS] = {
str(i): 1.0 for i in range(len(aliases)) if anchor[i] is True
}
else:
slices[FINAL_LOSS] = {str(i): 1.0 for i in range(len(aliases))}
# If not use_weak_label, only the anchor is True aliases will be given to the model
# We must re-index alias to predict to be in terms of anchors == True
# Ex: anchors = [F, T, T, F, F, T]
# If dataset_is_eval, let
# a2p = [2,5] (a2p must only be for T anchors)
# AFTER NOT USE_WEAK_LABEL, DATA WILL BE ONLY THE TRUE ANCHORS
# a2p needs to be [1, 2] for the 3rd and 6th true become the 2nd and 3rd after not weak labelling
# If dataset_is_eval is False, let
# a2p = [0,2,4,5] (a2p can be anything)
if not use_weak_label:
assert (
ANCHOR_KEY in line
), "Cannot toggle off data weak labelling without anchor info"
# The number of aliases will be reduced to the number of true anchors
num_alias2pred = sum(anchor)
# We must correct this mapping because the indexing will change when we remove False anchors (see
# comment example above)
slices = data_utils.correct_not_augmented_dict_values(anchor, slices)
# print("ANCHOR", anchor, "LINE", line, "SLICeS", slices)
# Remove slices that have no aliases to predict
for slice_name in list(slices.keys()):
if len(slices[slice_name]) <= 0:
del slices[slice_name]
all_false_anchors = all([anc is False for anc in anchor])
# For nicer code downstream, we make sure FINAL_LOSS is in here
# Only cases where it won't be is if use_weak_labels is True and the split is train
# (then we may have all false anchors)
if FINAL_LOSS not in slices:
assert (
all_false_anchors
), f"If {FINAL_LOSS} isn't in slice, it must be that all anchors are False. This is not true"
assert (
split != "train" or not use_weak_label
), "As all anchors are false, this must happen if you are evaling or training and using weak labels"
# TODO: optimizer here
# for i in range(0, num_alias2pred, 1):
# subset_slices = {}
# for slice_name in list(slices.keys()):
# subset_slices[slice_name] = dict(str(j):slice[slice_name][str(j)] for
# j in range(i:i+1))
# ex = InputExample(
# sent_idx=sent_idx,
# subslice_idx=i,
# anchor=anchor,
# num_alias2pred=num_alias2pred,
# slices=slices)
# examples.append(ex)
subslice_idx = 0
total_subsents += 1
max_a2pred = max(max_a2pred, num_alias2pred)
out_f.write(
ujson.dumps(
InputExample(
sent_idx=sent_idx,
subslice_idx=subslice_idx,
anchor=anchor,
num_alias2pred=num_alias2pred,
slices=slices,
).to_dict()
)
+ "\n"
)
return {
"total_lines": total_subsents,
"output_filename": out_file_name,
"max_alias2pred": max_a2pred,
}
[docs]def convert_examples_to_features_and_save_initializer(save_dataset_name, storage):
"""Convert to features multiprocessing initializer."""
global mmap_file_global
mmap_file_global = np.memmap(save_dataset_name, dtype=storage, mode="r+")
[docs]def convert_examples_to_features_and_save(
meta_file, dataset_threads, slice_names, save_dataset_name, storage
):
"""Convert the prepped examples into input features.
Saves in memmap files. These are used in the __get_item__ method.
Args:
meta_file: metadata file where input file paths are saved
dataset_threads: number of threads
slice_names: list of slice names to evaluation on
save_dataset_name: data file name to save
storage: data storage type (for memmap)
"""
log_rank_0_debug(logger, "Starting to extract subsentences")
start = time.time()
num_processes = min(dataset_threads, int(0.8 * multiprocessing.cpu_count()))
log_rank_0_info(
logger, f"Starting to build and save features with {num_processes} threads"
)
log_rank_0_debug(logger, "Counting lines")
total_input = utils.load_json_file(meta_file)["num_mentions"]
max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
files_and_counts = utils.load_json_file(meta_file)["files_and_counts"]
# IMPORTANT: for distributed writing to memmap files, you must create them in w+ mode before
# being opened in r+ mode by workers
memmap_file = np.memmap(
save_dataset_name, dtype=storage, mode="w+", shape=(total_input,), order="C"
)
# Save -1 in sent_idx to check that things are loaded correctly later
memmap_file[slice_names[0]]["sent_idx"][:] = -1
input_args = []
# Saves where in memap file to start writing
offset = 0
for i, in_file_name in enumerate(files_and_counts.keys()):
input_args.append(
{
"file_name": in_file_name,
"in_file_lines": files_and_counts[in_file_name],
"save_file_offset": offset,
"ex_print_mod": int(np.ceil(total_input / 20)),
"slice_names": slice_names,
"max_alias2pred": max_alias2pred,
}
)
offset += files_and_counts[in_file_name]
if num_processes == 1:
assert len(input_args) == 1
total_output = convert_examples_to_features_and_save_single(
input_args[0], memmap_file
)
else:
log_rank_0_debug(
logger,
"Initializing pool. This make take a few minutes.",
)
pool = multiprocessing.Pool(
processes=num_processes,
initializer=convert_examples_to_features_and_save_initializer,
initargs=[save_dataset_name, storage],
)
total_output = 0
for res in pool.imap_unordered(
convert_examples_to_features_and_save_hlp, input_args, chunksize=1
):
total_output += res
pool.close()
pool.join()
# Verify that sentences are unique and saved correctly
mmap_file = np.memmap(save_dataset_name, dtype=storage, mode="r")
all_uniq_ids = set()
for i in tqdm(range(total_input), desc="Checking sentence uniqueness"):
assert (
mmap_file[slice_names[0]]["sent_idx"][i] != -1
), f"Index {i} has -1 sent idx"
uniq_id = str(
f"{mmap_file[slice_names[0]]['sent_idx'][i]}.{mmap_file[slice_names[0]]['subslice_idx'][i]}"
)
assert (
uniq_id not in all_uniq_ids
), f"Idx {uniq_id} is not unique and already in data"
all_uniq_ids.add(uniq_id)
log_rank_0_debug(
logger,
f"Done with extracting examples in {time.time() - start}. Total lines seen {total_input}. "
f"Total lines kept {total_output}",
)
return
[docs]def convert_examples_to_features_and_save_hlp(input_dict):
"""Convert to features helper."""
return convert_examples_to_features_and_save_single(input_dict, mmap_file_global)
[docs]def convert_examples_to_features_and_save_single(input_dict, mmap_file):
"""Convert examples to features multiprocessing helper."""
file_name = input_dict["file_name"]
in_file_lines = input_dict["in_file_lines"]
save_file_offset = input_dict["save_file_offset"]
ex_print_mod = input_dict["ex_print_mod"]
max_alias2pred = input_dict["max_alias2pred"]
slice_names = input_dict["slice_names"]
total_saved_features = 0
for idx, in_line in tqdm(
enumerate(open(file_name)),
total=in_file_lines,
desc=f"Processing slice {file_name}",
):
example = InputExample.from_dict(ujson.loads(in_line))
example_idx = save_file_offset + idx
sent_idx = example.sent_idx
subslice_idx = example.subslice_idx
slices = example.slices
row_data = {}
for slice_name in slice_names:
# We use the information in "slices" key to generate two pieces of info
# 1. Binary info for if a mention is in the slice
# 2. Probabilistic info for the prob the mention is in the slice, this is used to train indicator heads
if slice_name in slices:
# Set indices of aliases to predict relevant to slice to 1-hot vector
slice_indexes = np.array([0] * (max_alias2pred))
for idx in slices[slice_name]:
# We consider an example as "in the slice" if it's probability is greater than 0.5
slice_indexes[int(idx)] = slices[slice_name][idx] > 0.5
alias_slice_incidence = slice_indexes
else:
# Set to zero for all aliases if no aliases in example occur in the slice
alias_slice_incidence = np.array([0] * max_alias2pred)
# Add probabilistic labels for training indicators
if slice_name in slices:
# padded values are -1 so they are masked in score function
slices_padded = np.array([-1.0] * (max_alias2pred))
for idx in slices[slice_name]:
# The indexes needed to be string for json
slices_padded[int(idx)] = slices[slice_name][idx]
alias2pred_probs = slices_padded
else:
alias2pred_probs = np.array([-1] * max_alias2pred)
total_saved_features += 1
# Write slice indices into record array
feature = InputFeatures(
sent_idx=sent_idx,
subslice_idx=subslice_idx,
alias_slice_incidence=alias_slice_incidence,
alias2pred_probs=alias2pred_probs,
)
# We are storing mmap file in column format, so column name first
mmap_file[slice_name]["sent_idx"][example_idx] = feature.sent_idx
mmap_file[slice_name]["subslice_idx"][example_idx] = feature.subslice_idx
mmap_file[slice_name]["alias_slice_incidence"][
example_idx
] = feature.alias_slice_incidence
mmap_file[slice_name]["prob_labels"][example_idx] = feature.alias2pred_probs
if example_idx % ex_print_mod == 0:
for slice_name in row_data:
# Make one string for distributed computation consistency
output_str = ""
output_str += f'*** Example Slice "{slice_name}" ***' + "\n"
output_str += (
f"sent_idx: {example.sent_idx}" + "\n"
)
output_str += (
f"subslice_idx: {example.subslice_idx}" + "\n"
)
output_str += (
f"anchor: {example.anchor}" + "\n"
)
output_str += (
f"slices: {example.slices.get(slice_name, {})}"
+ "\n"
) # Sometimes slices are emtpy if all anchors are false
output_str += "*** Feature ***" + "\n"
output_str += (
f"alias_slice_incidence: {row_data[slice_name].alias_slice_incidence}"
+ "\n"
)
output_str += (
f"alias2pred_probs: {row_data[slice_name].alias2pred_probs}"
+ "\n"
)
print(output_str)
mmap_file.flush()
return total_saved_features
[docs]class BootlegSliceDataset:
"""
Slice dataset class.
Our dataset class for holding data slices (or subpopulations).
Each mention can be part of 0 or more slices. When running eval, we use
the SliceDataset to determine which mentions are part of what slices. Importantly, although the model
"sees" all mentions, only GOLD anchor links are evaluated for eval (splits of test/dev).
Args:
main_args: main arguments
dataset: dataset file
use_weak_label: whether to use weak labeling or not
entity_symbols: entity symbols
dataset_threads: number of processes to use
split: data split
"""
def __init__(
self,
main_args,
dataset,
use_weak_label,
entity_symbols,
dataset_threads,
split="train",
):
"""Slice dataset initializer."""
global_start = time.time()
log_rank_0_info(logger, f"Building slice dataset for {split} from {dataset}.")
spawn_method = main_args.run_config.spawn_method
data_config = main_args.data_config
orig_spawn = multiprocessing.get_start_method()
multiprocessing.set_start_method(spawn_method, force=True)
self.slice_names = data_utils.get_eval_slices(data_config.eval_slices)
self.get_slice_dt = lambda max_a2p: np.dtype(
[
("sent_idx", int),
("subslice_idx", int),
("alias_slice_incidence", int, (max_a2p,)),
("prob_labels", float, (max_a2p,)),
]
)
self.get_storage = lambda max_a2p: np.dtype(
[
(slice_name, self.get_slice_dt(max_a2p))
for slice_name in self.slice_names
]
)
# Folder for all mmap saved files -> if call from cand_gen code, save_data_folder will fail
try:
save_dataset_folder = data_utils.get_save_data_folder(
data_config, use_weak_label, dataset
)
except AttributeError:
save_dataset_folder = data_utils.get_save_data_folder_candgen(
data_config, use_weak_label, dataset
)
utils.ensure_dir(save_dataset_folder)
# Folder for temporary output files
temp_output_folder = os.path.join(
data_config.data_dir, data_config.data_prep_dir, f"prep_{split}_slice_files"
)
utils.ensure_dir(temp_output_folder)
# Input step 1
create_ex_indir = os.path.join(temp_output_folder, "create_examples_input")
utils.ensure_dir(create_ex_indir)
# Input step 2
create_ex_outdir = os.path.join(temp_output_folder, "create_examples_output")
utils.ensure_dir(create_ex_outdir)
# Meta data saved files
meta_file = os.path.join(temp_output_folder, "meta_data.json")
# File for standard training data
hash = hashlib.sha1(str(self.slice_names).encode("UTF-8")).hexdigest()[:10]
self.save_dataset_name = os.path.join(
save_dataset_folder, f"ned_slices_{hash}.bin"
)
self.save_data_config_name = os.path.join(
save_dataset_folder, "ned_slices_config.json"
)
# =======================================================================================
# SLICE DATA
# =======================================================================================
log_rank_0_debug(logger, "Loading dataset...")
log_rank_0_debug(logger, f"Seeing if {self.save_dataset_name} exists")
if data_config.overwrite_preprocessed_data or (
not os.path.exists(self.save_dataset_name)
):
st_time = time.time()
try:
log_rank_0_info(
logger,
f"Building dataset from scratch. Saving to {save_dataset_folder}",
)
create_examples(
dataset,
create_ex_indir,
create_ex_outdir,
meta_file,
data_config,
dataset_threads,
self.slice_names,
use_weak_label,
split,
)
max_alias2pred = utils.load_json_file(meta_file)["max_alias2pred"]
convert_examples_to_features_and_save(
meta_file,
dataset_threads,
self.slice_names,
self.save_dataset_name,
self.get_storage(max_alias2pred),
)
utils.dump_json_file(
self.save_data_config_name, {"max_alias2pred": max_alias2pred}
)
log_rank_0_debug(
logger, f"Finished prepping data in {time.time() - st_time}"
)
except Exception as e:
tb = traceback.TracebackException.from_exception(e)
logger.error(e)
logger.error("\n".join(tb.stack.format()))
shutil.rmtree(save_dataset_folder, ignore_errors=True)
raise
log_rank_0_info(
logger,
f"Loading data from {self.save_dataset_name} and {self.save_data_config_name}",
)
max_alias2pred = utils.load_json_file(self.save_data_config_name)[
"max_alias2pred"
]
self.data, self.sent_to_row_id_dict = self.build_data_dict(
self.save_dataset_name, self.get_storage(max_alias2pred)
)
assert len(self.data) > 0
assert len(self.sent_to_row_id_dict) > 0
log_rank_0_debug(logger, "Removing temporary output files")
shutil.rmtree(temp_output_folder, ignore_errors=True)
# Set spawn back to original/default, which is "fork" or "spawn". This is needed for the Meta.config to
# be correctly passed in the collate_fn.
multiprocessing.set_start_method(orig_spawn, force=True)
log_rank_0_info(
logger,
f"Final slice data initialization time from {split} is {time.time() - global_start}s",
)
[docs] @classmethod
def build_data_dict(cls, save_dataset_name, storage):
"""
Build the slice dataset from saved file.
Loads the memmap slice dataset and create a mapping from sentence
index to row index.
Args:
save_dataset_name: saved memmap file name
storage: storage type of memmap file
Returns: numpy memmap data, Dict of sentence index to row in data
"""
sent_to_row_id_dict = defaultdict(list)
data = np.expand_dims(
np.memmap(save_dataset_name, dtype=storage, mode="r").view(np.recarray),
axis=1,
)
# Get any slice name for getting the sentence index
slice_name = data[0].dtype.names[0]
for i in tqdm(range(len(data)), desc="Building sent idx to row idx mapping"):
sent_idx = data[i][slice_name]["sent_idx"][0]
sent_to_row_id_dict[sent_idx].append(i)
return data, dict(sent_to_row_id_dict)
[docs] def contains_sentidx(self, sent_idx):
"""Return true if the sentence index is in the dataset.
Args:
sent_idx: sentence index
Returns: bool whether in dataset or not
"""
return sent_idx in self.sent_to_row_id_dict
[docs] def get_slice_incidence_arr(self, sent_idx, alias_orig_list_pos):
"""
Get slice incident array.
Given the sentence index and the list of aliases to get slice
indices for (may have -1 indicating no alias), return a dictionary of
slice_name -> 0/1 incidence array of if each alias in
alias_orig_list_pos was in the slice or not (-1 for no alias).
Args:
sent_idx: sentence index
alias_orig_list_pos: list of alias positions in input data list
(due to sentence splitting, aliases may be split up)
Returns: Dict of slice name -> 0/1 incidence array
"""
assert (
sent_idx in self.sent_to_row_id_dict
), f"Sentence {sent_idx} not in {self.save_dataset_name}"
alias_orig_list_pos = np.array(alias_orig_list_pos)
row_ids = self.sent_to_row_id_dict[sent_idx]
slices_to_return = {}
for row_i in row_ids:
for slice_name in self.slice_names:
slices_to_return[slice_name] = self.data[row_i][slice_name][
"alias_slice_incidence"
][0][alias_orig_list_pos]
slices_to_return[slice_name][alias_orig_list_pos == -1] = -1
return slices_to_return