"""AliasEntityTable class."""
import logging
import os
import time
import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm
from bootleg import log_rank_0_debug
from bootleg.utils import data_utils, utils
from bootleg.utils.model_utils import get_max_candidates
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class AliasEntityTable(nn.Module):
    """Stores table of the K candidate entity ids for each alias.

    The table has one row per alias plus two extra rows: the second to last
    row is for aliases not in the alias map and the last row is for the PAD
    alias. Both extra rows are entirely -1.

    Args:
        data_config: data config
        entity_symbols: entity symbols
    """

    def __init__(self, data_config, entity_symbols):
        """Alias table initializer."""
        super(AliasEntityTable, self).__init__()
        self.num_entities_with_pad_and_nocand = (
            entity_symbols.num_entities_with_pad_and_nocand
        )
        # +2 rows: not-in-map alias (second to last row) and PAD alias (last row)
        self.num_aliases_with_pad_and_unk = len(entity_symbols.get_all_aliases()) + 2
        self.K = get_max_candidates(entity_symbols, data_config)
        self.alias2entity_table, self.prep_file = self.prep(
            data_config,
            entity_symbols,
            num_aliases_with_pad_and_unk=self.num_aliases_with_pad_and_unk,
            num_cands_K=self.K,
        )
        # NOTE: the table is stored as a plain tensor attribute (not registered
        # with register_buffer), so it is not moved by Module.to(device).
        # Small sanity check that loading was done correctly. This isn't a
        # catch-all, but it catches some cases where the table does not match the
        # expected layout: the not-in-map row (-2) and the PAD row (-1) must be all -1.
        empty_row = torch.full_like(self.alias2entity_table[-1], -1)
        assert torch.equal(
            self.alias2entity_table[-2], empty_row
        ), "The second to last row of the alias table isn't -1, something wasn't loaded right."
        assert torch.equal(
            self.alias2entity_table[-1], empty_row
        ), "The last row of the alias table isn't -1, something wasn't loaded right."

    @classmethod
    def prep(
        cls,
        data_config,
        entity_symbols,
        num_aliases_with_pad_and_unk,
        num_cands_K,
    ):
        """Preps the alias to entity EID table.

        Loads the table from the cached memmap file when it exists and
        ``data_config.overwrite_preprocessed_data`` is false; otherwise builds it
        from ``entity_symbols`` and writes it to that file.

        Args:
            data_config: data config
            entity_symbols: entity symbols
            num_aliases_with_pad_and_unk: number of aliases including pad and unk
            num_cands_K: number of candidates per alias (aka K)

        Returns:
            tuple of (int64 torch Tensor of shape
            (num_aliases_with_pad_and_unk, num_cands_K), path of the saved file)
        """
        # num_aliases_with_pad_and_unk and num_cands_K are passed in to remove the
        # dependence on entity_symbols when the alias table is already prepped
        data_shape = (num_aliases_with_pad_and_unk, num_cands_K)
        # the cache file name depends on the train_in_candidates flag
        prep_dir = data_utils.get_emb_prep_dir(data_config)
        alias_str = os.path.splitext(data_config.alias_cand_map.replace("/", "_"))[0]
        prep_file = os.path.join(
            prep_dir,
            f"alias2entity_table_{alias_str}_InC{int(data_config.train_in_candidates)}.pt",
        )
        log_rank_0_debug(logger, f"Looking for alias table in {prep_file}")
        if not data_config.overwrite_preprocessed_data and os.path.exists(prep_file):
            log_rank_0_debug(logger, f"Loading alias table from {prep_file}")
            start = time.time()
            alias2entity_table = np.memmap(
                prep_file, dtype="int64", mode="r+", shape=data_shape
            )
            log_rank_0_debug(
                logger, f"Loaded alias table in {round(time.time() - start, 2)}s"
            )
        else:
            start = time.time()
            log_rank_0_debug(logger, "Building alias table")
            utils.ensure_dir(prep_dir)
            alias2entity_table = cls.build_alias_table(data_config, entity_symbols)
            mmap_file = np.memmap(prep_file, dtype="int64", mode="w+", shape=data_shape)
            mmap_file[:] = alias2entity_table[:]
            mmap_file.flush()
            log_rank_0_debug(
                logger,
                f"Finished building and saving alias table in {round(time.time() - start, 2)}s.",
            )
        # Normalise to int64 so the build path and the cached (memmap) path return
        # the same dtype; this is a no-copy view when the data is already int64.
        alias2entity_table = torch.from_numpy(
            np.asarray(alias2entity_table, dtype=np.int64)
        )
        return alias2entity_table, prep_file

    @classmethod
    def build_alias_table(cls, data_config, entity_symbols):
        """Construct the alias to EID table.

        Args:
            data_config: data config
            entity_symbols: entity symbols

        Returns:
            int64 numpy array where the row is the alias ID and the columns are
            candidate EIDs (-1 for unused slots)
        """
        all_aliases = entity_symbols.get_all_aliases()
        # +2 is for the not-in-map alias (second to last row) and PAD alias (last row)
        num_aliases_with_pad_and_unk = len(all_aliases) + 2
        # loop invariant: computed once rather than per alias
        num_cands_K = get_max_candidates(entity_symbols, data_config)
        alias_eid_cands = (
            (entity_symbols.get_alias_idx(alias), entity_symbols.get_eid_cands(alias))
            for alias in tqdm(all_aliases, desc="Iterating over aliases")
        )
        return cls._fill_alias_table(
            alias_eid_cands,
            num_aliases_with_pad_and_unk,
            num_cands_K,
            data_config.train_in_candidates,
        )

    @staticmethod
    def _fill_alias_table(alias_eid_cands, num_rows, num_cands_K, train_in_candidates):
        """Return an int64 (num_rows x num_cands_K) table filled from (alias_id, eid_cands) pairs.

        Rows not written (e.g. the not-in-map and PAD rows) stay all -1. For each
        written row, column 0 is first set to 0 (the non-candidate entity); when
        ``train_in_candidates`` is false the candidates start at column 1, so the
        0 remains. When it is true the candidates start at column 0 and overwrite
        it (an alias with no candidates keeps the 0).
        """
        table = np.full((num_rows, num_cands_K), -1, dtype=np.int64)
        offset = int(not train_in_candidates)
        for alias_id, eid_cands in alias_eid_cands:
            # build a fresh row so a repeated alias_id fully overwrites the old row
            row = np.full(num_cands_K, -1, dtype=np.int64)
            row[0] = 0
            row[offset : offset + len(eid_cands)] = eid_cands
            table[alias_id, :] = row
        return table

    def get_alias_eid_priors(self, alias_indices):
        """Return the prior scores of the given alias_indices.

        Note: this class never assigns ``alias2entityprior_table`` itself; it must
        be set on the instance before calling this, otherwise attribute lookup
        raises AttributeError.

        Args:
            alias_indices: alias indices (B x M)

        Returns:
            entity candidate normalized scores (B x M x K x 1)
        """
        candidate_entity_scores = (
            self.alias2entityprior_table[alias_indices].unsqueeze(-1).float()
        )
        return candidate_entity_scores

    def forward(self, alias_indices):
        """Model forward.

        Args:
            alias_indices: alias indices (B x M)

        Returns:
            entity candidate EIDs (B x M x K)
        """
        candidate_entity_ids = self.alias2entity_table[alias_indices]
        return candidate_entity_ids

    def __getstate__(self):
        """Get state without the lookup tables.

        The EID table is reloaded from ``prep_file`` in ``__setstate__``.
        """
        state = self.__dict__.copy()
        # Not picklable (per the original note). The prior table is optional and
        # never set by this class, so pop with a default instead of `del`,
        # which raised KeyError whenever it was absent.
        state.pop("alias2entity_table", None)
        state.pop("alias2entityprior_table", None)
        return state

    def __setstate__(self, state):
        """Set state and reload the EID table (as an in-memory copy) from ``prep_file``.

        ``alias2entityprior_table`` is not restored here.
        """
        self.__dict__.update(state)
        self.alias2entity_table = torch.tensor(
            np.memmap(
                self.prep_file,
                dtype="int64",
                mode="r",
                shape=(self.num_aliases_with_pad_and_unk, self.K),
            )
        )