"""Entity symbols."""
import copy
import logging
import os
from typing import Callable, Dict, Optional, Union
from tqdm.auto import tqdm
import bootleg.utils.utils as utils
from bootleg.symbols.constants import edit_op
from bootleg.utils.classes.nested_vocab_tries import (
TwoLayerVocabularyScoreTrie,
VocabularyTrie,
)
# Module-level logger named after this module; used for edit-mode warnings below.
logger = logging.getLogger(name=__name__)
class EntitySymbols:
    """Entity Symbols class for managing entity metadata.

    Stores the alias -> ``[QID, score]`` candidate mapping, the QID -> title and
    QID -> description mappings, and the QID <-> EID and alias <-> alias-id indexes.

    In read-only mode (``edit_mode=False``) ``alias2qids``, ``qid2eid`` and ``alias2id``
    are held as tries. In edit mode every mapping is a plain dict (tries do not allow
    value changes) and reverse indexes are built so the edit operations are cheap.
    EID 0 is reserved for the "not in candidate list" entity and -1 for the NULL
    entity; neither appears in ``qid2eid``.
    """

    def __init__(
        self,
        alias2qids: Union[Dict[str, list], TwoLayerVocabularyScoreTrie],
        qid2title: Dict[str, str],
        qid2desc: Optional[Dict[str, str]] = None,
        qid2eid: Optional[Union[Dict[str, int], VocabularyTrie]] = None,
        alias2id: Optional[Union[Dict[str, int], VocabularyTrie]] = None,
        max_candidates: int = 30,
        alias_cand_map_dir: str = "alias2qids",
        alias_idx_dir: str = "alias2id",
        edit_mode: Optional[bool] = False,
        verbose: Optional[bool] = False,
    ):
        """Entity symbols initializer.

        Args:
            alias2qids: alias -> list of ``[QID, score]`` candidate pairs (dict or trie)
            qid2title: QID -> title
            qid2desc: optional QID -> description
            qid2eid: optional QID -> EID mapping; if ``None``, EIDs 1..N are assigned
                in ``qid2title`` key order (0 is reserved for the not-in-cand entity)
            alias2id: optional alias -> alias id mapping; if ``None``, ids 0..M-1 are
                assigned in ``alias2qids`` key order
            max_candidates: maximum number of candidates kept per alias
            alias_cand_map_dir: sub-directory name used to save ``alias2qids``
            alias_idx_dir: sub-directory name used to save ``alias2id``
            edit_mode: if True, keep everything as mutable dicts so edit ops work
            verbose: show progress bars while building edit-mode indexes
        """
        # We support different candidate mappings for the same set of entities
        self.alias_cand_map_dir = alias_cand_map_dir
        self.alias_idx_dir = alias_idx_dir
        self.max_candidates = max_candidates
        self.edit_mode = edit_mode
        self.verbose = verbose
        if qid2eid is None:
            # +1 as 0 is reserved for not-in-cand list entity
            qid2eid = {q: i + 1 for i, q in enumerate(qid2title.keys())}
        if alias2id is None:
            alias2id = {a: i for i, a in enumerate(alias2qids.keys())}
        # If edit mode is ON, we must load everything as a dictionary
        if self.edit_mode:
            self._load_edit_mode(alias2qids, qid2title, qid2desc, qid2eid, alias2id)
        else:
            self._load_non_edit_mode(
                alias2qids, qid2title, qid2desc, qid2eid, alias2id
            )
        # This assumes that eid of 0 is NO_CAND and eid of -1 is NULL entity;
        # neither are in dict
        self.num_entities = len(self._qid2eid)
        self.num_entities_with_pad_and_nocand = self.num_entities + 2

    def _load_edit_mode(
        self,
        alias2qids: Union[Dict[str, list], TwoLayerVocabularyScoreTrie],
        qid2title: Dict[str, str],
        qid2desc: Optional[Dict[str, str]],
        qid2eid: Union[Dict[str, int], VocabularyTrie],
        alias2id: Union[Dict[str, int], VocabularyTrie],
    ):
        """Load in edit mode.

        Loading in edit mode requires all inputs be cast to dictionaries, because tries
        do not allow value changes. Also builds the reverse indexes ``_id2alias``,
        ``_eid2qid`` and ``_qid2aliases`` used by the edit operations.
        """
        # Convert to dict for editing
        if isinstance(alias2qids, TwoLayerVocabularyScoreTrie):
            alias2qids = alias2qids.to_dict()
        self._alias2qids: Union[Dict[str, list], TwoLayerVocabularyScoreTrie] = alias2qids
        self._qid2title: Dict[str, str] = qid2title
        self._qid2desc: Optional[Dict[str, str]] = qid2desc
        # Sort by score and filter to max candidates
        self._sort_alias_cands(
            self._alias2qids, truncate=True, max_cands=self.max_candidates
        )
        # Cast to dicts in edit mode
        if isinstance(qid2eid, VocabularyTrie):
            qid2eid = qid2eid.to_dict()
        if isinstance(alias2id, VocabularyTrie):
            alias2id = alias2id.to_dict()
        self._qid2eid: Union[Dict[str, int], VocabularyTrie] = qid2eid
        self._alias2id: Union[Dict[str, int], VocabularyTrie] = alias2id
        # Generate reverse indexes for fast editing
        self._id2alias: Union[Dict[int, str], Callable[[int], str]] = {
            al_id: al for al, al_id in self._alias2id.items()
        }
        self._eid2qid: Union[Dict[int, str], Callable[[int], str]] = {
            eid: qid for qid, eid in self._qid2eid.items()
        }
        # Entities with no mention get no entry here; the edit ops below treat a
        # missing entry as an empty set.
        self._qid2aliases: Union[Dict[str, set], None] = {}
        for al in tqdm(
            self._alias2qids,
            total=len(self._alias2qids),
            desc="Building edit mode objs",
            disable=not self.verbose,
        ):
            for qid_pair in self._alias2qids[al]:
                self._qid2aliases.setdefault(qid_pair[0], set()).add(al)
        assert len(self._qid2eid) == len(self._eid2qid), (
            "The qid2eid mapping is not invertable. "
            "This means there is a duplicate id value."
        )
        assert -1 not in self._eid2qid, "-1 can't be an eid"
        assert (
            0 not in self._eid2qid
        ), "0 can't be an eid. It's reserved for null candidate"
        # For when we need to add new entities. The defaults keep empty mappings
        # usable: the next EID is then 1 (0 is reserved) and the next alias id is 0.
        self.max_eid = max(self._eid2qid.keys(), default=0)
        self.max_alid = max(self._id2alias.keys(), default=-1)

    def _load_non_edit_mode(
        self,
        alias2qids: Union[Dict[str, list], TwoLayerVocabularyScoreTrie],
        qid2title: Dict[str, str],
        qid2desc: Optional[Dict[str, str]],
        qid2eid: Union[Dict[str, int], VocabularyTrie],
        alias2id: Union[Dict[str, int], VocabularyTrie],
    ):
        """Load items in read-only Trie mode (dict inputs are converted to tries)."""
        # Convert to record trie
        if isinstance(alias2qids, dict):
            self._sort_alias_cands(
                alias2qids, truncate=True, max_cands=self.max_candidates
            )
            alias2qids = TwoLayerVocabularyScoreTrie(
                input_dict=alias2qids,
                vocabulary=qid2title,
                max_value=self.max_candidates,
            )
        self._alias2qids: Union[Dict[str, list], TwoLayerVocabularyScoreTrie] = alias2qids
        self._qid2title: Dict[str, str] = qid2title
        self._qid2desc: Optional[Dict[str, str]] = qid2desc
        # Convert to Tries for non edit mode
        if isinstance(qid2eid, dict):
            qid2eid = VocabularyTrie(input_dict=qid2eid)
        if isinstance(alias2id, dict):
            alias2id = VocabularyTrie(input_dict=alias2id)
        self._qid2eid: Union[Dict[str, int], VocabularyTrie] = qid2eid
        self._alias2id: Union[Dict[str, int], VocabularyTrie] = alias2id
        # Make reverse functions for ease of use
        self._id2alias: Union[
            Dict[int, str], Callable[[int], str]
        ] = lambda x: self._alias2id.get_key(x)
        self._eid2qid: Union[
            Dict[int, str], Callable[[int], str]
        ] = lambda x: self._qid2eid.get_key(x)
        # The QID -> aliases reverse index is only built in edit mode
        self._qid2aliases: Union[Dict[str, set], None] = None
        assert not self._qid2eid.is_value_in_trie(
            0
        ), "0 can't be an eid. It's reserved for null candidate"
        # For when we need to add new entities
        self.max_eid = self._qid2eid.get_max_id()
        self.max_alid = self._alias2id.get_max_id()

    def save(self, save_dir):
        """Dump the entity symbols.

        In edit mode the dict mappings are converted back to tries before dumping.

        Args:
            save_dir: directory string to save
        """
        utils.ensure_dir(save_dir)
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_candidates": self.max_candidates,
            },
        )
        # If in edit mode, must convert back to tries for saving
        alias2qids = self._alias2qids
        if isinstance(alias2qids, dict):
            alias2qids = TwoLayerVocabularyScoreTrie(
                input_dict=alias2qids,
                vocabulary=self._qid2title,
                max_value=self.max_candidates,
            )
        alias2qids.dump(os.path.join(save_dir, self.alias_cand_map_dir))
        alias2id = self._alias2id
        if isinstance(alias2id, dict):
            alias2id = VocabularyTrie(input_dict=alias2id)
        alias2id.dump(os.path.join(save_dir, self.alias_idx_dir))
        qid2eid = self._qid2eid
        if isinstance(qid2eid, dict):
            qid2eid = VocabularyTrie(input_dict=qid2eid)
        qid2eid.dump(os.path.join(save_dir, "qid2eid"))
        utils.dump_json_file(
            filename=os.path.join(save_dir, "qid2title.json"), contents=self._qid2title
        )
        if self._qid2desc is not None:
            utils.dump_json_file(
                filename=os.path.join(save_dir, "qid2desc.json"),
                contents=self._qid2desc,
            )

    @classmethod
    def load_from_cache(
        cls,
        load_dir,
        alias_cand_map_dir="alias2qids",
        alias_idx_dir="alias2id",
        edit_mode=False,
        verbose=False,
    ):
        """Load entity symbols from load_dir.

        Args:
            load_dir: directory to load from
            alias_cand_map_dir: alias2qid directory
            alias_idx_dir: alias2id directory
            edit_mode: edit mode flag
            verbose: verbose flag

        Returns: EntitySymbols
        """
        config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
        max_candidates = config["max_candidates"]
        # For backwards compatibility, check if folder exists - if not, load from json
        # Future versions will assume folders exist
        alias_load_dir = os.path.join(load_dir, alias_cand_map_dir)
        if not os.path.exists(alias_load_dir):
            alias2qids: Union[
                Dict[str, list], TwoLayerVocabularyScoreTrie
            ] = utils.load_json_file(filename=os.path.join(load_dir, "alias2qids.json"))
        else:
            alias2qids = TwoLayerVocabularyScoreTrie(load_dir=alias_load_dir)
        # Missing id folders are regenerated by the initializer
        alias_id_load_dir = os.path.join(load_dir, alias_idx_dir)
        alias2id: Optional[VocabularyTrie] = None
        if os.path.exists(alias_id_load_dir):
            alias2id = VocabularyTrie(load_dir=alias_id_load_dir)
        eid_load_dir = os.path.join(load_dir, "qid2eid")
        qid2eid: Optional[VocabularyTrie] = None
        if os.path.exists(eid_load_dir):
            qid2eid = VocabularyTrie(load_dir=eid_load_dir)
        qid2title: Dict[str, str] = utils.load_json_file(
            filename=os.path.join(load_dir, "qid2title.json")
        )
        qid2desc: Optional[Dict[str, str]] = None
        if os.path.exists(os.path.join(load_dir, "qid2desc.json")):
            qid2desc = utils.load_json_file(
                filename=os.path.join(load_dir, "qid2desc.json")
            )
        return cls(
            alias2qids=alias2qids,
            qid2title=qid2title,
            qid2desc=qid2desc,
            qid2eid=qid2eid,
            alias2id=alias2id,
            max_candidates=max_candidates,
            alias_cand_map_dir=alias_cand_map_dir,
            alias_idx_dir=alias_idx_dir,
            edit_mode=edit_mode,
            verbose=verbose,
        )

    @staticmethod
    def _cand_sort_key(qid_pair):
        """Sort key for a ``[QID, score]`` pair: score first, QID to break ties."""
        return qid_pair[1], qid_pair[0]

    def _sort_alias_cands(
        self, alias2qids: Dict[str, list], truncate: bool = False, max_cands: int = 30
    ):
        """Sort the candidates for each alias from largest to smallest score, truncating if desired.

        Sorts ``alias2qids`` in place and also returns it.
        """
        for alias in alias2qids:
            # Second key (QID) gives determinism in case of same counts
            alias2qids[alias] = sorted(
                alias2qids[alias], key=self._cand_sort_key, reverse=True
            )
            if truncate:
                alias2qids[alias] = alias2qids[alias][:max_cands]
        return alias2qids

    def get_qid2eid_dict(self):
        """
        Get the qid2eid mapping.

        Returns: Dict qid2eid mapping (a copy)
        """
        if isinstance(self._qid2eid, dict):
            return copy.deepcopy(self._qid2eid)
        return self._qid2eid.to_dict()

    def get_alias2qids_dict(self):
        """
        Get the alias2qids mapping.

        Key is alias, value is list of candidate tuple of length two of [QID, sort_value].

        Returns: Dict alias2qids mapping (a copy)
        """
        if isinstance(self._alias2qids, dict):
            return copy.deepcopy(self._alias2qids)
        return self._alias2qids.to_dict()

    def get_qid2title_dict(self):
        """
        Get the qid2title mapping.

        Returns: Dict qid2title mapping (a copy)
        """
        return copy.deepcopy(self._qid2title)

    def get_all_alias_vocabtrie(self):
        """
        Get a trie of all aliases.

        Returns: Vocab trie of all aliases.
        """
        if isinstance(self._alias2id, VocabularyTrie):
            return self._alias2id
        return VocabularyTrie(input_dict=self._alias2id)

    def get_all_qids(self):
        """
        Get all QIDs.

        Returns: Dict_keys of all QIDs
        """
        return self._qid2eid.keys()

    def get_all_aliases(self):
        """
        Get all aliases.

        Returns: Dict_keys of all aliases
        """
        return self._alias2qids.keys()

    def get_all_titles(self):
        """
        Get all QID titles.

        Returns: Dict_values of all titles
        """
        return self._qid2title.values()

    def get_qid(self, id):
        """Get the QID associated with EID.

        Args:
            id: EID

        Returns: QID string
        """
        if isinstance(self._eid2qid, dict):
            return self._eid2qid[id]
        return self._eid2qid(id)

    def alias_exists(self, alias):
        """Check alias existence.

        Args:
            alias: alias string

        Returns: boolean
        """
        if isinstance(self._alias2qids, dict):
            # alias2qids and alias2id share the same keys in edit mode
            return alias in self._alias2qids
        return self._alias2qids.is_key_in_trie(alias)

    def qid_exists(self, qid):
        """Check QID existence.

        Args:
            qid: QID string

        Returns: boolean
        """
        if isinstance(self._qid2eid, dict):
            return qid in self._qid2eid
        return self._qid2eid.is_key_in_trie(qid)

    def get_eid(self, id):
        """Get the EID for the QID.

        Args:
            id: QID string

        Returns: EID int
        """
        return self._qid2eid[id]

    def _get_qid_pairs(self, alias):
        """Get the qid pairs for an alias.

        Args:
            alias: alias

        Returns: List of QID pairs (the stored list itself in edit mode)
        """
        if isinstance(self._alias2qids, dict):
            return self._alias2qids[alias]
        return self._alias2qids.get_value(alias)

    def get_qid_cands(self, alias, max_cand_pad=False):
        """Get the QID candidates for an alias.

        Args:
            alias: alias
            max_cand_pad: whether to pad with '-1' or not if fewer than max_candidates candidates

        Returns: List of QID strings
        """
        qid_pairs = self._get_qid_pairs(alias)
        res = [qid_pair[0] for qid_pair in qid_pairs]
        if max_cand_pad:
            res = res + ["-1"] * (self.max_candidates - len(res))
        return res

    def get_qid_count_cands(self, alias, max_cand_pad=False):
        """Get the [QID, sort_value] candidates for an alias.

        Args:
            alias: alias
            max_cand_pad: whether to pad with ['-1',-1] or not if fewer than max_candidates candidates

        Returns: List of [QID, sort_value] (a new outer list; padding adds one pair per slot)
        """
        # Copy so callers cannot append/pop on the stored list (the live list in edit mode)
        res = list(self._get_qid_pairs(alias))
        if max_cand_pad:
            # One independent ["-1", -1] pair per missing slot (not a flat list of scalars)
            res.extend(["-1", -1] for _ in range(self.max_candidates - len(res)))
        return res

    def get_eid_cands(self, alias, max_cand_pad=False):
        """Get the EID candidates for an alias.

        Args:
            alias: alias
            max_cand_pad: whether to pad with -1 or not if fewer than max_candidates candidates

        Returns: List of EID ints
        """
        qid_pairs = self._get_qid_pairs(alias)
        res = [self._qid2eid[qid_pair[0]] for qid_pair in qid_pairs]
        if max_cand_pad:
            res = res + [-1] * (self.max_candidates - len(res))
        return res

    def get_title(self, id):
        """Get title for QID.

        Args:
            id: QID string

        Returns: title string
        """
        return self._qid2title[id]

    def get_desc(self, id):
        """Get description for QID.

        Args:
            id: QID string

        Returns: description string ("" if no description is known)
        """
        if self._qid2desc is None:
            return ""
        return self._qid2desc.get(id, "")

    def get_alias_idx(self, alias):
        """Get the numeric index of an alias.

        Args:
            alias: alias

        Returns: integer representation of alias
        """
        return self._alias2id[alias]

    def get_alias_from_idx(self, alias_idx):
        """Get the alias from the numeric index.

        Args:
            alias_idx: alias numeric index

        Returns: alias string
        """
        if isinstance(self._id2alias, dict):
            return self._id2alias[alias_idx]
        return self._id2alias(alias_idx)

    # ============================================================
    # EDIT MODE OPERATIONS
    # ============================================================
    def _get_aliases_of(self, qid):
        """Return the edit-mode alias set of ``qid``.

        Entities without any mention have no ``_qid2aliases`` entry; for those a fresh
        empty set is returned. Raises KeyError if ``qid`` is not a known entity.
        """
        if qid in self._qid2aliases:
            return self._qid2aliases[qid]
        if qid in self._qid2eid:
            return set()
        raise KeyError(qid)

    @edit_op
    def set_title(self, qid: str, title: str):
        """Set the title for a QID.

        Args:
            qid: QID
            title: title
        """
        assert qid in self._qid2eid
        self._qid2title[qid] = title

    @edit_op
    def set_desc(self, qid: str, desc: str):
        """Set the description for a QID.

        If no descriptions were loaded, a new description mapping is started.

        Args:
            qid: QID
            desc: description
        """
        assert qid in self._qid2eid
        if self._qid2desc is None:
            self._qid2desc = {}
        self._qid2desc[qid] = desc

    @edit_op
    def set_score(self, qid: str, mention: str, score: float):
        """Change the mention QID score and resorts candidates.

        Highest score is first; ties are broken by QID, as at load time.

        Args:
            qid: QID
            mention: mention
            score: score

        Raises:
            ValueError: if the mention is unknown or not associated with ``qid``
        """
        if mention not in self._alias2qids:
            raise ValueError(f"The mention {mention} is not in our mapping")
        cands = self._alias2qids[mention]
        qids_only = [qid_pair[0] for qid_pair in cands]
        if qid not in qids_only:
            raise ValueError(
                f"The qid {qid} is not already associated with that mention."
            )
        cands[qids_only.index(qid)][1] = score
        self._alias2qids[mention] = sorted(
            cands, key=self._cand_sort_key, reverse=True
        )

    @edit_op
    def add_mention(self, qid: str, mention: str, score: float):
        """Add mention to QID with the associated score.

        The mention is normalized with ``utils.get_lnrm`` first. If the pair already
        exists, a warning is logged (use ``set_score`` instead) and nothing changes.
        If there are already max candidates to that mention, the lowest-scored
        candidate(s) of the mention are removed in place of QID.

        Args:
            qid: QID
            mention: mention
            score: score

        Raises:
            ValueError: if ``qid`` is not a known entity
        """
        # Validate before touching any mapping so a bad qid cannot leave partial state
        if qid not in self._qid2eid:
            raise ValueError(
                f"The qid {qid} is not in our mapping. Use add_entity to add it first."
            )
        # Cast to lower and stripped for aliases
        mention = utils.get_lnrm(mention)
        if mention in self._alias2qids:
            # If mention is in mapping, make sure the qid is not
            if any(qid_pair[0] == qid for qid_pair in self._alias2qids[mention]):
                logger.warning(
                    f"The QID {qid} is already associated with {mention}. Use set_score if you want to change "
                    f"the score of an existing mention-qid pair"
                )
                return
        else:
            # If mention is not in mapping, add it with a fresh alias id.
            # NOTE: a new mention only takes effect after data is re-preprocessed.
            self._alias2qids[mention] = []
            self.max_alid += 1
            new_al_id = self.max_alid
            assert (
                new_al_id not in self._id2alias
            ), f"{new_al_id} already in self_id2alias"
            self._alias2id[mention] = new_al_id
            self._id2alias[new_al_id] = mention
        qid_aliases = self._qid2aliases.setdefault(qid, set())
        assert (
            mention not in qid_aliases
        ), f"{mention} was a mention for {qid} despite the alias mapping saying otherwise"
        cands = self._alias2qids[mention]
        # If adding will go beyond max candidates, evict the lowest-scored candidate(s).
        # Even if an evicted score is higher, the user still wants this mention added.
        # Evict directly instead of via remove_mention: that deletes the mention (and
        # its alias id) once its last candidate is gone, e.g. when max_candidates == 1.
        while cands and len(cands) >= self.max_candidates:
            evicted_qid = cands.pop()[0]
            evicted_aliases = self._qid2aliases.get(evicted_qid)
            if evicted_aliases is not None:
                evicted_aliases.discard(mention)
        # Add pair
        cands.append([qid, score])
        self._alias2qids[mention] = sorted(
            cands, key=self._cand_sort_key, reverse=True
        )
        qid_aliases.add(mention)

    @edit_op
    def remove_mention(self, qid, mention):
        """Remove the mention from those associated with the QID.

        No-op if the mention, or the (mention, QID) pair, is not in the mapping.
        Unlike ``add_mention``, ``mention`` is not normalized here. If the mention is
        left with no candidates, it is removed entirely, including its alias id.

        Args:
            qid: QID
            mention: mention to remove
        """
        # Make sure the mention and qid pair is already in the mapping
        if mention not in self._alias2qids:
            return
        cands = self._alias2qids[mention]
        qids_only = [qid_pair[0] for qid_pair in cands]
        if qid not in qids_only:
            return
        # Remove the QID
        cands.pop(qids_only.index(qid))
        # If the mention has NO candidates, remove it as a possible mention.
        # NOTE: a removed mention only takes effect after data is re-preprocessed.
        if not cands:
            del self._alias2qids[mention]
            al_id = self._alias2id.pop(mention)
            del self._id2alias[al_id]
            assert (
                mention not in self._alias2qids and mention not in self._alias2id
            ), f"Removal of no candidates mention {mention} failed"
        # Remove mention from inverse mapping (will be not None in edit mode)
        assert (
            mention in self._qid2aliases[qid]
        ), f"{mention} was not a mention for {qid} despite the reverse being true"
        self._qid2aliases[qid].remove(mention)

    @edit_op
    def add_entity(self, qid, mentions, title, desc=""):
        """Add entity QID to our mappings with its mentions and title.

        Args:
            qid: QID
            mentions: List of tuples [mention, score]
            title: title
            desc: description
        """
        assert (
            qid not in self._qid2eid
        ), "Something went wrong with the qid check that this entity doesn't exist"
        # Update eid
        new_eid = self.max_eid + 1
        assert new_eid not in self._eid2qid
        self._qid2eid[qid] = new_eid
        self._eid2qid[new_eid] = qid
        # Update title
        self._qid2title[qid] = title
        # Update description; without a loaded description mapping, only start one
        # when there is an actual description to store
        if self._qid2desc is not None:
            self._qid2desc[qid] = desc
        elif desc:
            self._qid2desc = {qid: desc}
        # Make empty set to add in add_mention
        self._qid2aliases[qid] = set()
        # Update mentions
        for mention_pair in mentions:
            self.add_mention(qid, mention_pair[0], mention_pair[1])
        # Update metrics at the end in case of failure
        self.max_eid += 1
        self.num_entities += 1
        self.num_entities_with_pad_and_nocand += 1

    @edit_op
    def reidentify_entity(self, old_qid, new_qid):
        """Rename ``old_qid`` to ``new_qid``.

        The EID, title, description and mentions move to the new QID.

        Args:
            old_qid: old QID
            new_qid: new QID
        """
        assert (
            old_qid in self._qid2eid and new_qid not in self._qid2eid
        ), f"Internal Error: checks on existing versus new qid for {old_qid} and {new_qid} failed"
        # Update qid2eid and reassign the eid
        eid = self._qid2eid.pop(old_qid)
        self._qid2eid[new_qid] = eid
        self._eid2qid[eid] = new_qid
        # Update qid2title
        self._qid2title[new_qid] = self._qid2title.pop(old_qid)
        # Update qid2desc (old_qid may lack a description, or none may be loaded)
        if self._qid2desc is not None:
            self._qid2desc[new_qid] = self._qid2desc.pop(old_qid, "")
        # Update qid2aliases (an entity with no mentions has no entry)
        mentions = self._qid2aliases.pop(old_qid, set())
        self._qid2aliases[new_qid] = mentions
        # Update alias2qids in place
        for mention in mentions:
            for qid_pair in self._alias2qids[mention]:
                if qid_pair[0] == old_qid:
                    qid_pair[0] = new_qid
                    break

    @edit_op
    def prune_to_entities(self, entities_to_keep):
        """Remove all entities except those in ``entities_to_keep``.

        EIDs and alias ids are reassigned compactly (sorted order).

        Args:
            entities_to_keep: Set (or any iterable) of entities to keep
        """
        # One set for O(1) membership; also tolerates lists/duplicates/generators
        keep = set(entities_to_keep)
        # Update qid based dictionaries
        self._qid2title = {k: v for k, v in self._qid2title.items() if k in keep}
        if self._qid2desc is not None:
            self._qid2desc = {k: v for k, v in self._qid2desc.items() if k in keep}
        self._qid2aliases = {
            k: v for k, v in self._qid2aliases.items() if k in keep
        }
        # Reindex the entities to compress the embedding matrix (when model is update)
        self._qid2eid = {k: i + 1 for i, k in enumerate(sorted(keep))}
        self._eid2qid = {eid: qid for qid, eid in self._qid2eid.items()}
        # Extract mentions to keep
        mentions_to_keep = set().union(*self._qid2aliases.values())
        # Reindex aliases
        self._alias2id = {v: i for i, v in enumerate(sorted(mentions_to_keep))}
        self._id2alias = {al_id: al for al, al_id in self._alias2id.items()}
        # Rebuild self._alias2qids
        new_alias2qids = {}
        for al in mentions_to_keep:
            kept_pairs = [
                qid_pair for qid_pair in self._alias2qids[al] if qid_pair[0] in keep
            ][: self.max_candidates]
            assert len(kept_pairs) > 0
            new_alias2qids[al] = kept_pairs
        self._alias2qids = new_alias2qids
        self.num_entities = len(self._qid2eid)
        self.num_entities_with_pad_and_nocand = self.num_entities + 2
        assert self.num_entities == len(keep)
        # For when we need to add new entities (defaults handle an empty result)
        self.max_eid = max(self._eid2qid.keys(), default=0)
        self.max_alid = max(self._id2alias.keys(), default=-1)

    @edit_op
    def get_mentions(self, qid):
        """Get the mentions for the QID.

        Args:
            qid: QID

        Returns: Set of mentions (empty if the entity has none)

        Raises:
            KeyError: if ``qid`` is not a known entity
        """
        # qid2aliases is only created in edit mode to allow for removal of mentions associated with a qid
        return self._get_aliases_of(qid)

    @edit_op
    def get_mentions_with_scores(self, qid):
        """Get the mentions and the associated score for the QID.

        Args:
            qid: QID

        Returns: List of tuples [mention, score], highest score first
        """
        res = []
        for men in self._get_aliases_of(qid):
            for qid_pair in self._alias2qids[men]:
                if qid_pair[0] == qid:
                    res.append([men, qid_pair[1]])
                    break
        return sorted(res, key=lambda x: x[1], reverse=True)