"""Type symbols class."""
import copy
import os
from typing import Dict, List, Optional, Set, Union
from tqdm.auto import tqdm
from bootleg.symbols.constants import edit_op
from bootleg.utils import utils
from bootleg.utils.classes.nested_vocab_tries import TwoLayerVocabularyScoreTrie
def _convert_to_trie(qid2typenames, max_types):
    """Build a ``TwoLayerVocabularyScoreTrie`` from a QID -> type names dict.

    Args:
        qid2typenames: Dict of QID to list of type names
        max_types: maximum number of type names kept per QID

    Returns: TwoLayerVocabularyScoreTrie over the truncated mapping
    """
    # The vocabulary is collected from the *untruncated* lists, so it also
    # contains type names that are cut off by the ``max_types`` limit below.
    vocabulary = {
        typename for typenames in qid2typenames.values() for typename in typenames
    }
    truncated = {
        qid: typenames[:max_types] for qid, typenames in qid2typenames.items()
    }
    return TwoLayerVocabularyScoreTrie(
        input_dict=truncated,
        vocabulary=vocabulary,
        max_value=max_types,
    )
class TypeSymbols:
    """Type Symbols class for managing type metadata.

    Stores, for every entity QID, a list of at most ``max_types`` type names.
    Two storage modes exist:

    * read-only mode (``edit_mode=False``): the mapping is held in a
      ``TwoLayerVocabularyScoreTrie``; ``_all_typenames`` and
      ``_typename2qids`` are ``None``.
    * edit mode (``edit_mode=True``): the mapping is a plain dict, and the
      set ``_all_typenames`` plus the reverse index ``_typename2qids``
      (type name -> set of QIDs) are kept in sync by the edit operations.

    Methods in the "EDIT MODE OPERATIONS" section are wrapped with
    ``edit_op`` (imported from ``bootleg.symbols.constants``) and rely on the
    edit-mode structures being populated.
    """

    def __init__(
        self,
        qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie],
        max_types: Optional[int] = 10,
        edit_mode: Optional[bool] = False,
        verbose: Optional[bool] = False,
    ):
        """Type Symbols initializer.

        Args:
            qid2typenames: Dict of QID to list of type names, or an already
                built trie holding the same mapping
            max_types: maximum number of types kept per QID (must be > 0)
            edit_mode: whether to load the editable dict representation
            verbose: whether to show a progress bar while building the
                edit mode structures

        Raises:
            ValueError: if ``max_types`` is not greater than 0
        """
        if max_types <= 0:
            raise ValueError("max_types must be greater than 0")
        self.max_types = max_types
        self.edit_mode = edit_mode
        self.verbose = verbose
        if self.edit_mode:
            self._load_edit_mode(qid2typenames)
        else:
            self._load_non_edit_mode(qid2typenames)

    def _load_edit_mode(
        self, qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie]
    ):
        """Load qid to type mappings in edit mode.

        Args:
            qid2typenames: Dict of QID to list of type names, or a trie
        """
        self._qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie]
        if isinstance(qid2typenames, TwoLayerVocabularyScoreTrie):
            self._qid2typenames = qid2typenames.to_dict(keep_score=False)
        else:
            # Slicing copies each list, so later edits never touch the
            # caller's input lists.
            self._qid2typenames = {
                q: typs[: self.max_types] for q, typs in qid2typenames.items()
            }
        self._all_typenames: Union[Set[str], None] = {
            t for typeset in self._qid2typenames.values() for t in typeset
        }
        # Reverse index. Every name in ``_all_typenames`` comes from some
        # QID's list above, so each one receives an entry in this loop.
        self._typename2qids: Union[Dict[str, set], None] = {}
        for qid in tqdm(
            self._qid2typenames,
            total=len(self._qid2typenames),
            desc="Building edit mode objs",
            disable=not self.verbose,
        ):
            for typname in self._qid2typenames[qid]:
                self._typename2qids.setdefault(typname, set()).add(qid)

    def _load_non_edit_mode(
        self, qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie]
    ):
        """Load qid to type mappings in non edit mode (read only mode).

        Args:
            qid2typenames: Dict of QID to list of type names, or a trie
        """
        self._qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie]
        if isinstance(qid2typenames, dict):
            self._qid2typenames = _convert_to_trie(qid2typenames, self.max_types)
        else:
            self._qid2typenames = qid2typenames
        # The edit mode structures are not built in read only mode.
        self._all_typenames: Union[Set[str], None] = None
        self._typename2qids: Union[Dict[str, set], None] = None

    def save(self, save_dir, prefix=""):
        """Dump the type symbols.

        Writes ``config.json`` (holding ``max_types``) and the trie dump
        ``{prefix}qid2typenames`` inside ``save_dir``. Note that ``prefix``
        is not applied to ``config.json``.

        Args:
            save_dir: directory string to save
            prefix: prefix to add to beginning to file
        """
        utils.ensure_dir(str(save_dir))
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_types": self.max_types,
            },
        )
        trie = self._qid2typenames
        if isinstance(trie, dict):
            # Edit mode keeps a dict; it is always persisted as a trie.
            trie = _convert_to_trie(trie, self.max_types)
        trie.dump(os.path.join(save_dir, f"{prefix}qid2typenames"))

    @classmethod
    def load_from_cache(cls, load_dir, prefix="", edit_mode=False, verbose=False):
        """Load type symbols from load_dir.

        Args:
            load_dir: directory to load from
            prefix: prefix to add to beginning to file
            edit_mode: edit mode flag
            verbose: verbose flag

        Returns: TypeSymbols
        """
        config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
        max_types = config["max_types"]
        # For backwards compatibility, check if trie directory exists, otherwise load from json
        type_load_dir = os.path.join(load_dir, f"{prefix}qid2typenames")
        qid2typenames: Union[Dict[str, List[str]], TwoLayerVocabularyScoreTrie]
        if os.path.exists(type_load_dir):
            qid2typenames = TwoLayerVocabularyScoreTrie(
                load_dir=type_load_dir, max_value=max_types
            )
        else:
            qid2typenames = utils.load_json_file(
                filename=os.path.join(load_dir, f"{prefix}qid2typenames.json")
            )
        return cls(qid2typenames, max_types, edit_mode, verbose)

    def get_all_types(self):
        """Return all typenames.

        Returns: Set of typename strings (a new set on every call)
        """
        if isinstance(self._qid2typenames, dict):
            # Copy so callers cannot mutate the internal type vocabulary.
            return set(self._all_typenames)
        return set(self._qid2typenames.vocab_keys())

    def get_types(self, qid):
        """Get the type names associated with the given QID.

        Args:
            qid: QID

        Returns: list of typename strings (empty if the QID is unknown)
        """
        if isinstance(self._qid2typenames, dict):
            # Copy so callers cannot mutate the stored list and
            # desynchronise it from ``_typename2qids``.
            return list(self._qid2typenames.get(qid, []))
        if self._qid2typenames.is_key_in_trie(qid):
            # TwoLayerVocabularyScoreTrie assumes values are list of pairs - we only want type name which is first
            return self._qid2typenames.get_value(qid, keep_score=False)
        return []

    def get_qid2typename_dict(self):
        """Return dictionary of qid to typenames.

        Returns: Dict of QID to list of typenames.
        """
        if isinstance(self._qid2typenames, dict):
            return copy.deepcopy(self._qid2typenames)
        return self._qid2typenames.to_dict(keep_score=False)

    # ============================================================
    # EDIT MODE OPERATIONS
    # ============================================================
    @edit_op
    def get_entities_of_type(self, typename):
        """Get all entity QIDs of type ``typename``.

        Args:
            typename: typename

        Returns: Set of QIDs (a copy, safe to iterate while editing)

        Raises:
            ValueError: if ``typename`` is not a known type
        """
        if typename not in self._all_typenames:
            raise ValueError(f"{typename} is not a type in the typesystem")
        # This will not be None as we are in edit mode
        return set(self._typename2qids.get(typename, ()))

    @edit_op
    def add_type(self, qid, typename):
        """Add the type to the QID.

        If the QID already has maximum types, the
        last type is removed and replaced by ``typename``.
        A previously unseen ``typename`` is added to the type vocabulary.

        Args:
            qid: QID (must already be present, otherwise ``KeyError``)
            typename: type name
        """
        if typename not in self._all_typenames:
            self._all_typenames.add(typename)
            self._typename2qids[typename] = set()
        # Update qid->type mappings
        qid_types = self._qid2typenames[qid]
        if typename not in qid_types:
            # Remove last type if too many types
            if len(qid_types) >= self.max_types:
                self.remove_type(qid, qid_types[-1])
            qid_types.append(typename)
        # As we are in edit mode, self._typename2qids will not be None
        self._typename2qids[typename].add(qid)

    @edit_op
    def remove_type(self, qid, typename):
        """Remove the type from the QID.

        Does nothing if the QID does not currently have ``typename``.

        Args:
            qid: QID (must already be present, otherwise ``KeyError``)
            typename: type name to remove

        Raises:
            ValueError: if ``typename`` is not a known type
        """
        if typename not in self._all_typenames:
            raise ValueError(
                f"The type {typename} is not in our vocab. We only support removing types in our vocab."
            )
        qid_types = self._qid2typenames[qid]
        if typename not in qid_types:
            return
        assert (
            typename in self._typename2qids
        ), f"Invalid state: {typename} is missing from self._typename2qids for {qid}"
        qid_types.remove(typename)
        # As we are in edit mode, self._typename2qids will not be None
        # Further, we want to keep the typename even if list is empty as our type system doesn't change
        # ``discard`` tolerates a reverse index that already lacks the QID.
        self._typename2qids[typename].discard(qid)

    @edit_op
    def add_entity(self, qid, types):
        """Add an entity QID with its types to our mappings.

        Only the first ``max_types`` entries of ``types`` are stored for the
        QID, but every type name is added to the type vocabulary. If the QID
        already exists, its previous types are replaced.

        Args:
            qid: QID
            types: list of type names
        """
        # Materialise once so one-shot iterables are handled correctly.
        types = list(types)
        # When re-adding an existing QID, detach it from its old types first;
        # otherwise the reverse index would keep stale memberships.
        for old_typename in self._qid2typenames.get(qid, []):
            self._typename2qids[old_typename].discard(qid)
        for typename in types:
            if typename not in self._all_typenames:
                self._all_typenames.add(typename)
                self._typename2qids[typename] = set()
        # Cutdown to max types
        self._qid2typenames[qid] = types[: self.max_types]
        # Add to typenames to qids
        for typename in self._qid2typenames[qid]:
            self._typename2qids[typename].add(qid)

    @edit_op
    def reidentify_entity(self, old_qid, new_qid):
        """Rename ``old_qid`` to ``new_qid``.

        Args:
            old_qid: old QID
            new_qid: new QID

        Raises:
            AssertionError: if ``old_qid`` is unknown or ``new_qid`` is
                already present
        """
        # Raised explicitly (not via ``assert``) so the check still protects
        # the mappings when Python runs with ``-O``.
        if old_qid not in self._qid2typenames or new_qid in self._qid2typenames:
            raise AssertionError(
                f"Internal Error: checks on existing versus new qid for {old_qid} and {new_qid} failed"
            )
        # Update qid2typenames
        self._qid2typenames[new_qid] = self._qid2typenames.pop(old_qid)
        # Update typename2qids
        for typename in self._qid2typenames[new_qid]:
            self._typename2qids[typename].discard(old_qid)
            self._typename2qids[typename].add(new_qid)

    @edit_op
    def prune_to_entities(self, entities_to_keep):
        """Remove all entities except those in ``entities_to_keep``.

        The type vocabulary is left unchanged.

        Args:
            entities_to_keep: Set of entities to keep
        """
        # Normalise once: gives O(1) membership tests and lets any iterable
        # (including a one-shot generator) be used for both passes below.
        entities_to_keep = set(entities_to_keep)
        # Update qid2typenames
        self._qid2typenames = {
            k: v for k, v in self._qid2typenames.items() if k in entities_to_keep
        }
        # Update typename2qids, keeping the typenames even if empty sets
        for typename in self._typename2qids:
            self._typename2qids[typename] = self._typename2qids[typename].intersection(
                entities_to_keep
            )