Source code for bootleg.utils.utils

"""Bootleg utils."""
import collections
import json
import logging
import math
import os
import pathlib
import shutil
import time
import unicodedata
from itertools import chain, islice

import marisa_trie
import ujson
import yaml

from bootleg import log_rank_0_info
from bootleg.symbols.constants import USE_LOWER, USE_STRIP
from bootleg.utils.classes.dotted_dict import DottedDict

logger = logging.getLogger(__name__)


def ensure_dir(d):
    """
    Create directory ``d`` (and any missing parent directories) if absent.

    Calling this on a directory that already exists is a no-op.

    Args:
        d: path
    """
    target = pathlib.Path(d)
    target.mkdir(parents=True, exist_ok=True)
def exists_dir(d):
    """
    Report whether the path ``d`` exists.

    Note that this is true for any existing filesystem entry (a regular file
    as well as a directory), not only for directories.

    Args:
        d: path

    Returns:
        bool
    """
    path = pathlib.Path(d)
    found = path.exists()
    return found
def dump_json_file(filename, contents, ensure_ascii=False):
    """
    Serialize ``contents`` as JSON into ``filename``.

    Missing parent directories are created first. ujson is tried first; if it
    raises OverflowError the stdlib ``json`` module writes the data instead
    (presumably for values ujson cannot encode, such as very large ints).

    Args:
        filename: file to write to
        contents: dictionary to save
        ensure_ascii: ensure ascii
    """
    path = pathlib.Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        try:
            ujson.dump(contents, handle, ensure_ascii=ensure_ascii)
        except OverflowError:
            # NOTE(review): assumes ujson writes nothing before raising;
            # otherwise this fallback would append to partial output.
            json.dump(contents, handle, ensure_ascii=ensure_ascii)
def dump_yaml_file(filename, contents):
    """
    Serialize ``contents`` as YAML into ``filename``.

    Missing parent directories are created first; an existing file is
    overwritten.

    Args:
        filename: file to write to
        contents: dictionary to save
    """
    path = pathlib.Path(filename)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as handle:
        yaml.dump(contents, handle)
def load_json_file(filename):
    """
    Read and parse the JSON document stored in ``filename`` (via ujson).

    Args:
        filename: file to read from

    Returns:
        Dict
    """
    with open(filename, "r") as handle:
        return ujson.load(handle)
def load_yaml_file(filename):
    """
    Read and parse the YAML document stored in ``filename``.

    Args:
        filename: file to read from

    Returns:
        Dict
    """
    # NOTE(security): FullLoader resolves more tags than SafeLoader and has had
    # past vulnerabilities; only use this on trusted files (consider
    # yaml.safe_load if configs can come from untrusted sources).
    with open(filename) as handle:
        return yaml.load(handle, Loader=yaml.FullLoader)
def recurse_redict(d):
    """
    Convert a mapping and any nested DottedDict/dict values to plain dicts.

    Useful for YAML dumping. Only values that are themselves DottedDict or dict
    instances are converted (recursively); dicts nested inside lists or other
    containers are left as they are.

    Args:
        d: Dict

    Returns:
        Dict with no DottedDicts
    """
    plain = dict(d)
    return {
        key: recurse_redict(dict(val)) if isinstance(val, (DottedDict, dict)) else val
        for key, val in plain.items()
    }
def write_to_file(filename, value):
    """
    Write a generic value to a file, followed by a newline.

    If value is not a string, it is converted with ``str()``. The parent
    directory is created if needed and an existing file is overwritten.

    Args:
        filename: file to write to
        value: content to write
    """
    ensure_dir(os.path.dirname(filename))
    if not isinstance(value, str):
        value = str(value)
    # ``with`` guarantees the handle is closed even if the write raises.
    with open(filename, "w") as fout:
        fout.write(value + "\n")
def write_jsonl(filepath, values, ensure_ascii=False):
    """
    Write a sequence of dicts to a JSON-lines file, one object per line.

    Any existing file at ``filepath`` is overwritten.

    Args:
        filepath: file to write to
        values: list of dictionary data to write
        ensure_ascii: ensure_ascii for json
    """
    with open(filepath, "w") as out_f:
        out_f.writelines(
            ujson.dumps(record, ensure_ascii=ensure_ascii) + "\n" for record in values
        )
def chunks(iterable, n):
    """
    Lazily split an iterable into consecutive chunks of at most ``n`` items.

    chunks(ABCDE, 2) => AB CD E.

    Each yielded chunk is an iterator that shares the underlying iterator with
    this generator. It must be fully consumed before the next chunk is
    requested; otherwise its unread items are picked up by later chunks.

    Args:
        iterable: iterable input
        n: maximum number of items per chunk (must be >= 1)

    Yields:
        an iterator over the next chunk

    Raises:
        ValueError: if ``n`` is less than 1
    """
    if n < 1:
        raise ValueError(f"chunk size must be a positive integer, got {n}")
    iterator = iter(iterable)
    while True:
        try:
            first = next(iterator)
        except StopIteration:
            return
        # Prepend the already-consumed first item to the next n - 1 items.
        yield chain([first], islice(iterator, n - 1))
def chunk_file(in_file, out_dir, num_lines, prefix="out_"):
    """
    Split a text file into pieces of at most ``num_lines`` lines each.

    Pieces are written into ``out_dir`` as ``{prefix}{i}{ext}``. ``i`` counts up
    from 0 and ``ext`` is the extension of ``in_file``.

    Args:
        in_file: input file
        out_dir: output directory
        num_lines: number of lines in each chunk
        prefix: prefix for output files in out_dir

    Returns:
        total number of lines read, dictionary of output file path -> number
        of lines in that file (for tqdms)
    """
    ensure_dir(out_dir)
    out_files = {}
    total_lines = 0
    ending = os.path.splitext(in_file)[1]
    with open(in_file) as bigfile:
        # One generator over the shared file iterator. Each chunk is fully
        # consumed by the inner loop before the next one is requested, as
        # chunks() requires.
        for i, lines in enumerate(chunks(bigfile, num_lines)):
            file_split = os.path.join(out_dir, f"{prefix}{i}{ending}")
            total_file_lines = 0
            with open(file_split, "w") as f:
                for line in lines:
                    f.write(line)
                    total_file_lines += 1
            total_lines += total_file_lines
            out_files[file_split] = total_file_lines
    return total_lines, out_files
def create_single_item_trie(in_dict, out_file=""):
    """
    Build a marisa RecordTrie mapping each string key to one integer.

    We assume the dictionary has string keys and integer values.

    Args:
        in_dict: Dict[str] -> Int
        out_file: marisa file to save (useful for reading as memmap) (optional)

    Returns:
        marisa trie of in_dict
    """
    # RecordTrie stores each value as a tuple of fields. "<l" describes one
    # little-endian 4-byte signed integer field, so values presumably must
    # fit in int32.
    records = []
    for key in in_dict:
        val = in_dict[key]
        # NOTE(review): assert-based validation is skipped under `python -O`.
        assert type(val) is int
        records.append((key, (val,)))
    trie = marisa_trie.RecordTrie("<l", records)
    if out_file != "":
        trie.save(out_file)
    return trie
def load_single_item_trie(file):
    """
    Memory-map a RecordTrie with one "<l" integer per key from ``file``.

    Args:
        file: marisa input file

    Returns:
        marisa trie
    """
    assert exists_dir(file)
    trie = marisa_trie.RecordTrie("<l")
    return trie.mmap(file)
def get_lnrm(s, strip=USE_STRIP, lower=USE_LOWER):
    """
    Convert to lnrm form.

    Adapted from the lower-cased normalized form l(s) in
    http://nlp.stanford.edu/pubs/subctackbp.pdf Section 2.3.

    - If neither ``strip`` nor ``lower`` is set, ``s`` is returned unchanged
      (not even converted with ``str()``).
    - ``lower`` lower-cases the string.
    - ``strip`` applies NFD normalization and drops combining marks, which
      removes diacritics. It then drops every character that is neither
      alphanumeric (per ``str.isalnum``) nor a plain ASCII space, and trims
      the result.
        * Non-ASCII letters and digits are kept.
        * Punctuation of any script is removed.
        * Tabs, newlines and other non-ASCII-space whitespace are deleted,
          not turned into spaces, so tab-separated words are joined.
    - Finally, runs of whitespace are collapsed to a single space and
      leading/trailing whitespace is removed.

    NOTE(review): the output is presumably used as a lookup key (e.g. for
    alias tries); changing the normalization would likely invalidate stored
    data, so confirm before altering it.

    Args:
        s: input string
        strip: boolean for stripping alias or not
        lower: boolean for lowercasing alias or not

    Returns:
        the lnrm form of the string
    """
    if not strip and not lower:
        return s
    lnrm = str(s)
    if lower:
        lnrm = lnrm.lower()
    if strip:
        # NFD splits accented characters into a base character plus combining
        # marks, so the marks can be filtered out below.
        lnrm = unicodedata.normalize("NFD", lnrm)
        lnrm = "".join(
            [
                x
                for x in lnrm
                if (not unicodedata.combining(x) and x.isalnum() or x == " ")
            ]
        ).strip()
    # will remove if there are any duplicate white spaces e.g. "the  alias    is here"
    lnrm = " ".join(lnrm.split())
    return lnrm
def strip_nan(input_list):
    """
    Replace float('nan') with None, recursing into nested iterables.

    Used for ujson loading/dumping.
    - Strings are treated as scalars and kept as-is. Recursing into them would
      never terminate, because each character is itself an iterable string.
    - Scalars that ``math.isnan`` cannot evaluate (e.g. None, or ints too
      large for a float) are passed through unchanged.
    - Any other iterable item is converted to a (possibly nested) list; a dict
      becomes a list of its keys.

    Args:
        input_list: list of items to remove the NaNs from

    Returns:
        a new list (nested lists for nested iterables) with NaN replaced by None
    """
    final_list = []
    for item in input_list:
        if isinstance(item, str):
            final_list.append(item)
        elif isinstance(item, collections.abc.Iterable):
            final_list.append(strip_nan(item))
        else:
            try:
                is_nan = math.isnan(item)
            except (TypeError, OverflowError):
                # Not representable as a float, so it cannot be NaN.
                is_nan = False
            final_list.append(None if is_nan else item)
    return final_list
def try_rmtree(rm_dir, max_retries=5, wait_seconds=1):
    """
    Try to remove a directory tree.

    rmtree can fail while a resource is still open, so removal is retried up
    to ``max_retries`` times, sleeping ``wait_seconds`` between attempts.
    - If the path no longer exists (including when it never existed), the
      call returns immediately.
    - If every attempt fails, a message is logged and no exception is raised.

    Args:
        rm_dir: directory to remove
        max_retries: maximum number of rmtree attempts
        wait_seconds: seconds to sleep between failed attempts
    """
    for attempt in range(1, max_retries + 1):
        try:
            shutil.rmtree(rm_dir)
            return
        except OSError:
            # Nothing left to delete (never existed, or removed concurrently):
            # retrying and warning would only waste time.
            if not os.path.exists(rm_dir):
                return
            # No point sleeping after the last attempt.
            if attempt < max_retries:
                time.sleep(wait_seconds)
    log_rank_0_info(
        logger,
        f"{rm_dir} was not able to be deleted. This is okay but will have to manually be removed.",
    )