Source code for bootleg.data

"""Bootleg data creation."""
import copy
import logging
import os
from collections import defaultdict
from typing import Any, Dict, List, Tuple, Union

import torch
from emmental import Meta
from emmental.data import EmmentalDataLoader, emmental_collate_fn
from emmental.utils.utils import list_to_tensor
from torch.utils.data import DistributedSampler, RandomSampler

from bootleg import log_rank_0_info
from bootleg.dataset import BootlegDataset, BootlegEntityDataset
from bootleg.slicing.slice_dataset import BootlegSliceDataset
from bootleg.task_config import BATCH_CANDS_LABEL, CANDS_LABEL

logger = logging.getLogger(__name__)


def get_slicedatasets(args, splits, entity_symbols):
    """Get the slice datasets.

    Args:
        args: main args
        splits: splits to get datasets for
        entity_symbols: entity symbols

    Returns: Dict of slice datasets
    """
    datasets = {}
    for split in splits:
        dataset_path = os.path.join(
            args.data_config.data_dir, args.data_config[f"{split}_dataset"].file
        )
        datasets[split] = BootlegSliceDataset(
            main_args=args,
            dataset=dataset_path,
            use_weak_label=args.data_config[f"{split}_dataset"].use_weak_label,
            entity_symbols=entity_symbols,
            dataset_threads=args.run_config.dataset_threads,
            split=split,
        )
    return datasets
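
# Usage sketch (illustrative only; the split names and surrounding variables are
# assumptions -- any splits with a "{split}_dataset" entry in args.data_config work
# the same way):
#
#     slice_datasets = get_slicedatasets(
#         args, splits=["train", "dev"], entity_symbols=entity_symbols
#     )
#     dev_slices = slice_datasets["dev"]  # BootlegSliceDataset for the dev split
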
def get_dataloaders(
    args,
    tasks,
    use_batch_cands,
    load_entity_data,
    splits,
    entity_symbols,
    tokenizer,
    dataset_offsets: Dict[str, List[int]] = None,
):
    """Get the dataloaders.

    Args:
        args: main args
        tasks: task names
        use_batch_cands: whether to use candidates across a batch (train and eval_batch_cands)
        load_entity_data: whether to load entity data
        splits: data splits to generate dataloaders for
        entity_symbols: entity symbols
        tokenizer: tokenizer
        dataset_offsets: [start, end] offsets for each split to index into the dataset.
            Dataset len is end-start. If end is None, end is the length of the dataset.

    Returns: list of dataloaders
    """
    if dataset_offsets is None:
        dataset_offsets = {split: None for split in splits}

    task_to_label_dict = {
        t: BATCH_CANDS_LABEL if use_batch_cands else CANDS_LABEL for t in tasks
    }
    is_bert = True
    datasets = {}
    for split in splits:
        if dataset_offsets[split] is not None and not isinstance(
            dataset_offsets[split], list
        ):
            raise TypeError(
                "dataset_offsets must be dict from split to list of indexes to subselect."
            )
        dataset_path = os.path.join(
            args.data_config.data_dir, args.data_config[f"{split}_dataset"].file
        )
        datasets[split] = BootlegDataset(
            main_args=args,
            name="Bootleg",
            dataset=dataset_path,
            use_weak_label=args.data_config[f"{split}_dataset"].use_weak_label,
            load_entity_data=load_entity_data,
            tokenizer=tokenizer,
            entity_symbols=entity_symbols,
            dataset_threads=args.run_config.dataset_threads,
            split=split,
            is_bert=is_bert,
            dataset_range=dataset_offsets[split],
        )

    dataloaders = []
    for split, dataset in datasets.items():
        if split in args.learner_config.train_split:
            dataset_sampler = (
                RandomSampler(dataset)
                if Meta.config["learner_config"]["local_rank"] == -1
                else DistributedSampler(
                    dataset, seed=Meta.config["meta_config"]["seed"]
                )
            )
        else:
            dataset_sampler = None
            if Meta.config["learner_config"]["local_rank"] != -1:
                log_rank_0_info(
                    logger,
                    "You are using distributed computing for eval. We are not using a distributed sampler. "
                    "Please use DataParallel and not DDP.",
                )
        dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict=task_to_label_dict,
                dataset=dataset,
                sampler=dataset_sampler,
                split=split,
                collate_fn=bootleg_collate_fn
                if use_batch_cands
                else emmental_collate_fn,
                batch_size=args.train_config.batch_size
                if split in args.learner_config.train_split
                or args.run_config.eval_batch_size is None
                else args.run_config.eval_batch_size,
                num_workers=args.run_config.dataloader_threads,
                pin_memory=False,
            )
        )
        log_rank_0_info(
            logger,
            f"Built dataloader for {split} set with {len(dataset)} samples and "
            f"{args.run_config.dataloader_threads} threads "
            f"(Shuffle={split in args.learner_config.train_split}, "
            f"Batch size={dataloaders[-1].batch_size}).",
        )
    return dataloaders
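
# Usage sketch (illustrative only; the task name, split names, and dataset_offsets
# values are assumptions). When provided, dataset_offsets must map each split to a
# list of example indices to subselect, per the TypeError check above.
#
#     train_loader, dev_loader = get_dataloaders(
#         args,
#         tasks=["NED"],
#         use_batch_cands=True,
#         load_entity_data=True,
#         splits=["train", "dev"],
#         entity_symbols=entity_symbols,
#         tokenizer=tokenizer,
#         dataset_offsets={"train": None, "dev": [0, 1, 2, 3]},
#     )
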
def get_entity_dataloaders(
    args,
    tasks,
    entity_symbols,
    tokenizer,
):
    """Get the entity dataloader.

    Args:
        args: main args
        tasks: task names
        entity_symbols: entity symbols
        tokenizer: tokenizer

    Returns: entity dataloader
    """
    task_to_label_dict = {t: None for t in tasks}
    split = "test"
    dataset_path = os.path.join(
        args.data_config.data_dir, args.data_config[f"{split}_dataset"].file
    )
    dataset = BootlegEntityDataset(
        main_args=args,
        name="Bootleg",
        dataset=dataset_path,
        tokenizer=tokenizer,
        entity_symbols=entity_symbols,
        dataset_threads=args.run_config.dataset_threads,
        split=split,
    )
    dataset_sampler = None
    if Meta.config["learner_config"]["local_rank"] != -1:
        log_rank_0_info(
            logger,
            "You are using distributed computing for eval. We are not using a distributed sampler. "
            "Please use DataParallel and not DDP.",
        )

    dataloader = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dataset,
        sampler=dataset_sampler,
        split=split,
        collate_fn=emmental_collate_fn,
        batch_size=args.train_config.batch_size
        if split in args.learner_config.train_split
        or args.run_config.eval_batch_size is None
        else args.run_config.eval_batch_size,
        num_workers=args.run_config.dataloader_threads,
        pin_memory=False,
    )
    log_rank_0_info(
        logger,
        f"Built dataloader for {split} set with {len(dataset)} samples and "
        f"{args.run_config.dataloader_threads} threads "
        f"(Shuffle={split in args.learner_config.train_split}, "
        f"Batch size={dataloader.batch_size}).",
    )
    return dataloader
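
# Usage sketch (illustrative only; the task name is an assumption). Builds a single
# EmmentalDataLoader over the "test" split's entity data.
#
#     entity_loader = get_entity_dataloaders(
#         args, tasks=["NED"], entity_symbols=entity_symbols, tokenizer=tokenizer
#     )
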
def bootleg_collate_fn(
    batch: Union[
        List[Tuple[Dict[str, Any], Dict[str, torch.Tensor]]], List[Dict[str, Any]]
    ]
) -> Union[Tuple[Dict[str, Any], Dict[str, torch.Tensor]], Dict[str, Any]]:
    """Collate function (modified from emmental collate fn).

    The main difference is that our collate function merges candidates from
    across the batch for disambiguation.

    Args:
        batch: The batch to collate.

    Returns: The collated batch.
    """
    X_batch: defaultdict = defaultdict(list)
    # In Bootleg, we may have a nested dictionary in x_dict; we want to keep this
    # structure but collate the subtensors
    X_sub_batch: defaultdict = defaultdict(lambda: defaultdict(list))
    Y_batch: defaultdict = defaultdict(list)

    # A learnable batch is a pair of dicts (x_dict, y_dict), while a non-learnable
    # batch is a single x_dict
    is_learnable = not isinstance(batch[0], dict)

    if is_learnable:
        for x_dict, y_dict in batch:
            if isinstance(x_dict, dict) and isinstance(y_dict, dict):
                for field_name, value in x_dict.items():
                    if isinstance(value, list):
                        X_batch[field_name] += value
                    elif isinstance(value, dict):
                        # We reinstantiate the field_name here
                        # This keeps the field_name key intact
                        if field_name not in X_sub_batch:
                            X_sub_batch[field_name] = defaultdict(list)
                        for sub_field_name, sub_value in value.items():
                            if isinstance(sub_value, list):
                                X_sub_batch[field_name][sub_field_name] += sub_value
                            else:
                                X_sub_batch[field_name][sub_field_name].append(
                                    sub_value
                                )
                    else:
                        X_batch[field_name].append(value)
                for label_name, value in y_dict.items():
                    if isinstance(value, list):
                        Y_batch[label_name] += value
                    else:
                        Y_batch[label_name].append(value)
    else:
        for x_dict in batch:  # type: ignore
            for field_name, value in x_dict.items():  # type: ignore
                if isinstance(value, list):
                    X_batch[field_name] += value
                elif isinstance(value, dict):
                    # We reinstantiate the field_name here
                    # This keeps the field_name key intact
                    if field_name not in X_sub_batch:
                        X_sub_batch[field_name] = defaultdict(list)
                    for sub_field_name, sub_value in value.items():
                        if isinstance(sub_value, list):
                            X_sub_batch[field_name][sub_field_name] += sub_value
                        else:
                            X_sub_batch[field_name][sub_field_name].append(sub_value)
                else:
                    X_batch[field_name].append(value)

    field_names = copy.deepcopy(list(X_batch.keys()))
    for field_name in field_names:
        values = X_batch[field_name]
        # Only merge list of tensors
        if isinstance(values[0], torch.Tensor):
            item_tensor, item_mask_tensor = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )
            X_batch[field_name] = item_tensor

    field_names = copy.deepcopy(list(X_sub_batch.keys()))
    for field_name in field_names:
        sub_field_names = copy.deepcopy(list(X_sub_batch[field_name].keys()))
        for sub_field_name in sub_field_names:
            values = X_sub_batch[field_name][sub_field_name]
            # Only merge list of tensors
            if isinstance(values[0], torch.Tensor):
                item_tensor, item_mask_tensor = list_to_tensor(
                    values,
                    min_len=Meta.config["data_config"]["min_data_len"],
                    max_len=Meta.config["data_config"]["max_data_len"],
                )
                X_sub_batch[field_name][sub_field_name] = item_tensor
    # Add sub batch to batch
    for field_name in field_names:
        X_batch[field_name] = dict(X_sub_batch[field_name])

    if is_learnable:
        for label_name, values in Y_batch.items():
            Y_batch[label_name] = list_to_tensor(
                values,
                min_len=Meta.config["data_config"]["min_data_len"],
                max_len=Meta.config["data_config"]["max_data_len"],
            )[0]

    # ACROSS BATCH CANDIDATE MERGING
    # Turns from b x m x k to E where E is the number of unique entities
    all_uniq_eids = []
    all_uniq_eid_idx = []
    label = []
    for k, batch_eids in enumerate(X_batch["entity_cand_eid"]):
        for j, eid in enumerate(batch_eids):
            # Skip if already in batch or if it's the unk...we don't use masking in
            # the softmax for batch_cands data loading (training and during train eval)
            if (
                eid in all_uniq_eids
                or bool(X_batch["entity_cand_eval_mask"][k][j].item())
            ):
                continue
            all_uniq_eids.append(eid)
            all_uniq_eid_idx.append([k, j])

    for eid in X_batch["gold_eid"]:
        men_label = []
        if eid not in all_uniq_eids:
            men_label.append(-2)
        else:
            men_label.append(all_uniq_eids.index(eid))
        label.append(men_label)

    # Super rare edge case if doing eval during training on small batch sizes and
    # have an entire batch where the alias is -2 (i.e., we don't have it in our dump)
    if len(all_uniq_eids) == 0:
        # Give the unq entity in this case -> we want the model to get the wrong
        # answer anyways and it will
        # all_uniq_eids = [X_batch["entity_cand_eid"][0][0]]
        all_uniq_eid_idx = [[0, 0]]

    all_uniq_eid_idx = torch.LongTensor(all_uniq_eid_idx)
    assert len(all_uniq_eid_idx.size()) == 2 and all_uniq_eid_idx.size(1) == 2

    for key in X_batch.keys():
        # Don't transform the mask as that's only used for no batch cands
        if (
            key.startswith("entity_")
            and key != "entity_cand_eval_mask"
            and key != "entity_to_mask"
        ):
            X_batch[key] = X_batch[key][all_uniq_eid_idx[:, 0], all_uniq_eid_idx[:, 1]]
    Y_batch["gold_unq_eid_idx"] = torch.LongTensor(label)

    if is_learnable:
        return dict(X_batch), dict(Y_batch)
    return dict(X_batch)
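
# Sketch of the across-batch candidate merging (illustrative only; the keys shown are
# the ones this collate function reads, while the eid values and the y_dict key are
# assumptions, and emmental's Meta must already be initialized so list_to_tensor can
# read its min/max data lengths). Two mentions with two candidates each collapse to
# the unique, unmasked eids [11, 12, 13]; gold labels are re-indexed into that axis,
# with -2 when the gold entity is not among the kept candidates.
#
#     batch = [
#         (
#             {
#                 "entity_cand_eid": torch.tensor([11, 12]),
#                 "entity_cand_eval_mask": torch.tensor([False, False]),
#                 "gold_eid": 11,
#             },
#             {"gold_cand_K_idx": torch.tensor(0)},
#         ),
#         (
#             {
#                 "entity_cand_eid": torch.tensor([12, 13]),
#                 "entity_cand_eval_mask": torch.tensor([False, False]),
#                 "gold_eid": 13,
#             },
#             {"gold_cand_K_idx": torch.tensor(1)},
#         ),
#     ]
#     X, Y = bootleg_collate_fn(batch)
#     # X["entity_cand_eid"]  -> tensor([11, 12, 13])
#     # Y["gold_unq_eid_idx"] -> tensor([[0], [2]])
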