Source code for bootleg.symbols.kg_symbols

"""KG symbols class."""
import copy
import os
import re
from typing import Dict, List, Optional, Set, Union

from tqdm.auto import tqdm

from bootleg.symbols.constants import edit_op
from bootleg.utils import utils
from bootleg.utils.classes.nested_vocab_tries import ThreeLayerVocabularyTrie


def _convert_to_trie(qid2relations, max_connections):
    """Build a ThreeLayerVocabularyTrie from a nested relation dictionary.

    Args:
        qid2relations: dict of head QID -> relation -> list of tail QIDs
        max_connections: max number of tail QIDs kept for each relation

    Returns:
        ThreeLayerVocabularyTrie built from the truncated dictionary
    """
    relation_vocab = set()
    tail_vocab = set()
    truncated = {}
    for head_qid, relations in qid2relations.items():
        per_head = {}
        for relation, tails in relations.items():
            relation_vocab.add(relation)
            # The value vocabulary is collected from the full tail lists,
            # i.e. before they are cut down to max_connections.
            tail_vocab.update(tails)
            per_head[relation] = tails[:max_connections]
        truncated[head_qid] = per_head
    return ThreeLayerVocabularyTrie(
        input_dict=truncated,
        key_vocabulary=relation_vocab,
        value_vocabulary=tail_vocab,
        max_value=max_connections,
    )


class KGSymbols:
    """KG Symbols class for managing KG metadata.

    The relations are stored as head QID -> relation -> list of tail QIDs.
    In edit mode this is a plain nested dict, accompanied by a reverse index
    (tail QID -> set of head QIDs) and the set of all relations. In non-edit
    mode it is a ThreeLayerVocabularyTrie and the two auxiliary structures
    are ``None``.
    """

    def __init__(
        self,
        qid2relations: Union[Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie],
        max_connections: Optional[int] = 50,
        edit_mode: Optional[bool] = False,
        verbose: Optional[bool] = False,
    ):
        """KG initializer.

        max_connections acts as the max single number of connections for a given relation.
        max_connections * 2 is the max number of connections across all relations for a
        given entity (see ThreeLayerVocabularyTrie).

        Args:
            qid2relations: dict of head QID -> relation -> list of tail QIDs,
                or an already built ThreeLayerVocabularyTrie
            max_connections: max number of tail QIDs kept per relation
            edit_mode: if True, keep dict structures that support the edit operations
            verbose: if True, show a progress bar while building edit mode objects
        """
        self.max_connections = max_connections
        self.edit_mode = edit_mode
        self.verbose = verbose
        if self.edit_mode:
            self._load_edit_mode(qid2relations)
        else:
            self._load_non_edit_mode(qid2relations)

    def _load_edit_mode(
        self,
        qid2relations: Union[Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie],
    ):
        """Load relations in edit mode.

        Stores the relations as a nested dict (tail lists truncated to
        ``max_connections`` when a dict is given) and builds the reverse
        index and the set of all relations.
        """
        if isinstance(qid2relations, ThreeLayerVocabularyTrie):
            self._qid2relations: Union[
                Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie
            ] = qid2relations.to_dict()
        else:
            # Slicing also copies each tail list, so later edits do not
            # touch the lists of the caller's dictionary.
            self._qid2relations = {
                head_qid: {
                    rel: tail_qids[: self.max_connections]
                    for rel, tail_qids in rel_dict.items()
                }
                for head_qid, rel_dict in qid2relations.items()
            }
        self._obj2head: Union[Dict[str, set], None] = {}
        self._all_relations: Union[Set[str], None] = set()
        for qid in tqdm(
            self._qid2relations,
            total=len(self._qid2relations),
            desc="Checking relations and building edit mode objs",
            disable=not self.verbose,
        ):
            for rel, tail_qids in self._qid2relations[qid].items():
                self._all_relations.add(rel)
                for qid2 in tail_qids:
                    self._obj2head.setdefault(qid2, set()).add(qid)

    def _load_non_edit_mode(
        self,
        qid2relations: Union[Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie],
    ):
        """Load relations in not edit mode.

        A dict is converted to a trie; anything else is stored as given.
        """
        if isinstance(qid2relations, dict):
            self._qid2relations: Union[
                Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie
            ] = _convert_to_trie(qid2relations, self.max_connections)
        else:
            self._qid2relations = qid2relations
        # The edit-mode helper structures are not built in this mode.
        self._all_relations: Union[Set[str], None] = None
        self._obj2head: Union[Dict[str, set], None] = None

    def save(self, save_dir, prefix=""):
        """Dump the kg symbols.

        Writes ``config.json`` (without the prefix) holding ``max_connections``
        and dumps the relation trie under ``{prefix}qid2relations``.

        Args:
            save_dir: directory string to save
            prefix: prefix to add to beginning to file
        """
        utils.ensure_dir(str(save_dir))
        utils.dump_json_file(
            filename=os.path.join(save_dir, "config.json"),
            contents={
                "max_connections": self.max_connections,
            },
        )
        # Always persist the trie form, converting first when in dict form.
        if isinstance(self._qid2relations, dict):
            qid2relations = _convert_to_trie(self._qid2relations, self.max_connections)
        else:
            qid2relations = self._qid2relations
        qid2relations.dump(os.path.join(save_dir, f"{prefix}qid2relations"))

    @classmethod
    def load_from_cache(cls, load_dir, prefix="", edit_mode=False, verbose=False):
        """Load KG symbols from load_dir.

        Args:
            load_dir: directory to load from
            prefix: prefix to add to beginning to file
            edit_mode: edit mode
            verbose: verbose flag

        Returns: KGSymbols

        Raises:
            ValueError: if the legacy json mapping uses a PID (e.g. ``P31``)
                as relation name
        """
        config = utils.load_json_file(filename=os.path.join(load_dir, "config.json"))
        max_connections = config["max_connections"]
        # For backwards compatibility, check if trie directory exists, otherwise load from json
        rel_load_dir = os.path.join(load_dir, f"{prefix}qid2relations")
        if not os.path.exists(rel_load_dir):
            qid2relations: Union[
                Dict[str, Dict[str, List[str]]], ThreeLayerVocabularyTrie
            ] = utils.load_json_file(
                filename=os.path.join(load_dir, f"{prefix}qid2relations.json")
            )
            # Make sure relation is _not_ PID. The user should have the qid2relation dict that is pre-translated.
            # Only the first relation found is inspected; the default of None keeps an
            # empty mapping (or a leading entity without relations) from raising StopIteration.
            first_rel = next(
                (rel for rel_dict in qid2relations.values() for rel in rel_dict),
                None,
            )
            if first_rel is not None and re.match(r"^P[0-9]+$", first_rel):
                raise ValueError(
                    "Your qid2relations dict has a relation as a PID identifier. Please replace "
                    "with human readable strings for training. "
                    "See https://www.wikidata.org/wiki/Wikidata:Database_reports/List_of_properties/all"
                )
        else:
            qid2relations = ThreeLayerVocabularyTrie(
                load_dir=rel_load_dir, max_value=max_connections
            )
        return cls(qid2relations, max_connections, edit_mode, verbose)

    def get_qid2relations_dict(self):
        """Return a dictionary form of the head qid to relation mappings object.

        In dict form a deep copy is returned, so the result can be modified
        without affecting this object.

        Returns: Dict of head qid to relation to list of tail qids
        """
        if isinstance(self._qid2relations, dict):
            return copy.deepcopy(self._qid2relations)
        return self._qid2relations.to_dict()

    def get_all_relations(self):
        """Get all relations in our KG mapping.

        In dict form the internal set itself is returned (not a copy).

        Returns: Set
        """
        if isinstance(self._qid2relations, dict):
            return self._all_relations
        return set(self._qid2relations.key_vocab_keys())

    def get_relations_between(self, qid1, qid2):
        """Check if two QIDs are connected in KG and return the relations between them.

        Only the direction ``qid1`` (head) -> ``qid2`` (tail) is checked.

        Args:
            qid1: QID one
            qid2: QID two

        Returns: set of relation strings (empty set if not connected)
        """
        rel_dict = {}
        if isinstance(self._qid2relations, dict):
            rel_dict = self._qid2relations.get(qid1, {})
        elif self._qid2relations.is_key_in_trie(qid1):
            rel_dict = self._qid2relations.get_value(qid1)
        return {rel for rel, tail_qids in rel_dict.items() if qid2 in tail_qids}

    def get_relations_tails_for_qid(self, qid):
        """Get dict of relation to tail qids for given qid.

        In dict form the stored inner dict itself is returned for a known
        QID (not a copy).

        Args:
            qid: QID

        Returns: Dict relation to list of tail qids for that relation
        """
        if isinstance(self._qid2relations, dict):
            return self._qid2relations.get(qid, {})
        rel_dict = {}
        if self._qid2relations.is_key_in_trie(qid):
            rel_dict = self._qid2relations.get_value(qid)
        return rel_dict

    # ============================================================
    # EDIT MODE OPERATIONS
    # ============================================================
    @edit_op
    def add_relation(self, qid, relation, qid2):
        """Add a relationship triple to our mapping.

        If the QID already has max connection through ``relation``,
        the last ``other_qid`` is removed and replaced by ``qid2``.

        Args:
            qid: head entity QID
            relation: relation
            qid2: tail entity QID

        Raises:
            KeyError: if ``qid`` is not an entity of the mapping
        """
        self._all_relations.add(relation)
        tail_qids = self._qid2relations[qid].setdefault(relation, [])
        # Check if qid2 already in that relation
        if qid2 in tail_qids:
            return
        if len(tail_qids) >= self.max_connections:
            self.remove_relation(qid, relation, tail_qids[-1])
            # remove_relation deletes the relation key once its list is empty
            # (this happens when max_connections is 1), so make sure the list
            # exists again before it is used below.
            tail_qids = self._qid2relations[qid].setdefault(relation, [])
        assert len(tail_qids) < self.max_connections, (
            f"Something went wrong and we still have more that {self.max_connections} "
            f"relations when removing {qid}, {relation}, {qid2}"
        )
        tail_qids.append(qid2)
        self._obj2head.setdefault(qid2, set()).add(qid)
        return

    @edit_op
    def remove_relation(self, qid, relation, qid2):
        """Remove a relation triple from our mapping.

        Does nothing when ``qid`` has no such relation or ``qid2`` is not a
        tail of it. Note that the set of all relations is left untouched.

        Args:
            qid: head entity QID
            relation: relation
            qid2: tail entity QID

        Raises:
            KeyError: if ``qid`` is not an entity of the mapping
        """
        head_relations = self._qid2relations[qid]
        if relation not in head_relations:
            return
        if qid2 not in head_relations[relation]:
            return
        head_relations[relation].remove(qid2)
        # NOTE(review): the reverse index is keyed by tail only, so this drops
        # the head even if qid still reaches qid2 through another relation.
        self._obj2head[qid2].remove(qid)
        # If no connections, remove relation
        if not head_relations[relation]:
            del head_relations[relation]
        if not self._obj2head[qid2]:
            del self._obj2head[qid2]
        return

    @edit_op
    def add_entity(self, qid, relation_dict):
        """Add a new entity to our relation mapping.

        Args:
            qid: QID
            relation_dict: dictionary of relation -> list of connected other_qids by relation

        Raises:
            ValueError: if ``qid`` is already in the mapping
        """
        if qid in self._qid2relations:
            raise ValueError(f"{qid} is already in kg symbols")
        self._all_relations.update(relation_dict)
        # Slicing limits each list to max connections and copies it, so the
        # caller's lists are not shared with the stored ones.
        self._qid2relations[qid] = {
            rel: tail_qids[: self.max_connections]
            for rel, tail_qids in relation_dict.items()
        }
        # Use self._qid2relations[qid] rather than relation_dict as the former is limited by max connections
        for tail_qids in self._qid2relations[qid].values():
            for obj_qid in tail_qids:
                self._obj2head.setdefault(obj_qid, set()).add(qid)
        return

    @edit_op
    def reidentify_entity(self, old_qid, new_qid):
        """Rename ``old_qid`` to ``new_qid``.

        Args:
            old_qid: old QID
            new_qid: new QID

        Raises:
            ValueError: if ``old_qid`` is not a head entity or ``new_qid``
                already is one
        """
        if old_qid not in self._qid2relations or new_qid in self._qid2relations:
            raise ValueError(
                f"Either old qid {old_qid} is not in kg symbols or new qid {new_qid} is already in kg symbols"
            )
        # Update all object qids (aka subjects-object pairs where the object is the old qid)
        for subj_qid in self._obj2head.get(old_qid, {}):
            for tail_qids in self._qid2relations[subj_qid].values():
                for j, tail_qid in enumerate(tail_qids):
                    if tail_qid == old_qid:
                        tail_qids[j] = new_qid
        # Update all subject qids - take the set union in case a subject has the same object with different relations
        for obj_qid in set().union(
            *[set(tail_qids) for tail_qids in self._qid2relations[old_qid].values()]
        ):
            # May get a cyclic relationship and the obj qid will already have been transformed
            # NOTE(review): this assumes new_qid was not already a tail of
            # old_qid before the rename - confirm with callers.
            if obj_qid == new_qid:
                obj_qid = old_qid
            assert (
                old_qid in self._obj2head[obj_qid]
            ), f"{old_qid} {obj_qid} {self._obj2head[obj_qid]}"
            self._obj2head[obj_qid].remove(old_qid)
            self._obj2head[obj_qid].add(new_qid)
        # Update qid2relations and the object2head mappings
        self._qid2relations[new_qid] = self._qid2relations[old_qid]
        del self._qid2relations[old_qid]
        if old_qid in self._obj2head:
            self._obj2head[new_qid] = self._obj2head[old_qid]
            del self._obj2head[old_qid]

    @edit_op
    def prune_to_entities(self, entities_to_keep):
        """Remove all entities except those in ``entities_to_keep``.

        Both head entities and tail entities outside of ``entities_to_keep``
        are dropped and the reverse index is rebuilt.

        Args:
            entities_to_keep: Set of entities to keep
        """
        # Update qid2relations
        self._qid2relations = {
            k: v for k, v in self._qid2relations.items() if k in entities_to_keep
        }
        new_obj2head = {}
        # Update all object qids
        for qid, rel_dict in self._qid2relations.items():
            # Iterate over a copy of the keys as relations may get deleted
            for rel in list(rel_dict.keys()):
                filtered_object_ents = [
                    j for j in rel_dict[rel] if j in entities_to_keep
                ][: self.max_connections]
                # Keep relation only if at least one object is left
                if filtered_object_ents:
                    rel_dict[rel] = filtered_object_ents
                    for obj_qid in filtered_object_ents:
                        new_obj2head.setdefault(obj_qid, set()).add(qid)
                else:
                    del rel_dict[rel]
        # NOTE(review): the set of all relations is not recomputed here, so it
        # may still list relations that no longer have any triple.
        self._obj2head = new_obj2head