Source code for bootleg.utils.parser.emm_parse_args

"""Overrides the Emmental parse_args."""
import argparse
from argparse import ArgumentParser
from typing import Any, Dict, Optional, Tuple

from emmental.utils.utils import (
    nullable_float,
    nullable_int,
    nullable_string,
    str2bool,
    str2dict,
)

from bootleg.utils.classes.dotted_dict import DottedDict, create_bool_dotted_dict


def parse_args(parser: Optional[ArgumentParser] = None) -> Tuple[ArgumentParser, Dict]:
    """Parse args.

    Overrides the default Emmental parser to add the "emmental." level to the
    parser so we can parse it correctly with the Bootleg config.

    Args:
        parser: Argument parser object, defaults to None.

    Returns:
        The updated argument parser object and the parser hierarchy dictionary.
    """
    if parser is None:
        parser = argparse.ArgumentParser(
            "Emmental configuration",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

    parser_hierarchy = {"emmental": {}}

    # Load meta configuration
    meta_config = parser.add_argument_group("Meta configuration")
    meta_config.add_argument(
        "--emmental.seed", type=nullable_int, default=1234,
        help="Random seed for all numpy/torch/cuda operations in model and learning",
    )
    meta_config.add_argument(
        "--emmental.verbose", type=str2bool, default=True,
        help="Whether to print the log information",
    )
    meta_config.add_argument(
        "--emmental.log_path", type=str, default="logs",
        help="Directory to save running log",
    )
    meta_config.add_argument(
        "--emmental.use_exact_log_path", type=str2bool, default=False,
        help="Whether to use the exact log directory",
    )
    parser_hierarchy["emmental"]["_global_meta"] = meta_config

    # Load data configuration
    data_config = parser.add_argument_group("Data configuration")
    data_config.add_argument(
        "--emmental.min_data_len", type=int, default=0, help="Minimal data length"
    )
    data_config.add_argument(
        "--emmental.max_data_len", type=int, default=0,
        help="Maximal data length (0 for no max_len)",
    )
    parser_hierarchy["emmental"]["_global_data"] = data_config

    # Load model configuration
    model_config = parser.add_argument_group("Model configuration")
    model_config.add_argument(
        "--emmental.model_path", type=nullable_string, default=None,
        help="Path to pretrained model",
    )
    model_config.add_argument(
        "--emmental.device", type=int, default=0,
        help="Which device to use (-1 for cpu or gpu id (e.g., 0 for cuda:0))",
    )
    model_config.add_argument(
        "--emmental.dataparallel", type=str2bool, default=False,
        help="Whether to use dataparallel or not",
    )
    model_config.add_argument(
        "--emmental.distributed_backend", type=str, default="nccl",
        choices=["nccl", "gloo"],
        help="Which backend to use for distributed training.",
    )
    parser_hierarchy["emmental"]["_global_model"] = model_config

    # Learning configuration
    learner_config = parser.add_argument_group("Learning configuration")
    learner_config.add_argument(
        "--emmental.optimizer_path", type=nullable_string, default=None,
        help="Path to optimizer state",
    )
    learner_config.add_argument(
        "--emmental.scheduler_path", type=nullable_string, default=None,
        help="Path to lr scheduler state",
    )
    learner_config.add_argument(
        "--emmental.fp16", type=str2bool, default=False,
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) "
        "instead of 32-bit",
    )
    learner_config.add_argument(
        "--emmental.fp16_opt_level", type=str, default="O1",
        help="Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']. "
        "See details at https://nvidia.github.io/apex/amp.html",
    )
"See details at https://nvidia.github.io/apex/amp.html", ) learner_config.add_argument( "--emmental.local_rank", type=int, default=-1, help="local_rank for distributed training on gpus", ) learner_config.add_argument( "--emmental.epochs_learned", type=int, default=0, help="Learning epochs learned" ) learner_config.add_argument( "--emmental.n_epochs", type=int, default=1, help="Total number of learning epochs", ) learner_config.add_argument( "--emmental.steps_learned", type=int, default=0, help="Learning steps learned" ) learner_config.add_argument( "--emmental.n_steps", type=int, default=None, help="Total number of learning steps", ) learner_config.add_argument( "--emmental.skip_learned_data", type=str2bool, default=False, help="Iterate through dataloader when steps or epochs learned is true", ) learner_config.add_argument( "--emmental.train_split", nargs="+", type=str, default=["train"], help="The split for training", ) learner_config.add_argument( "--emmental.valid_split", nargs="+", type=str, default=["dev"], help="The split for validation", ) learner_config.add_argument( "--emmental.test_split", nargs="+", type=str, default=["test"], help="The split for testing", ) learner_config.add_argument( "--emmental.ignore_index", type=nullable_int, default=None, help="The ignore index, uses for masking samples", ) learner_config.add_argument( "--emmental.online_eval", type=str2bool, default=False, help="Whether to perform online evaluation", ) parser_hierarchy["emmental"]["_global_learner"] = learner_config # Optimizer configuration optimizer_config = parser.add_argument_group("Optimizer configuration") optimizer_config.add_argument( "--emmental.optimizer", type=nullable_string, default="adamw", choices=[ "asgd", "adadelta", "adagrad", "adam", "adamw", "adamax", "lbfgs", "rms_prop", "r_prop", "sgd", "sparse_adam", "bert_adam", None, ], help="The optimizer to use", ) optimizer_config.add_argument( "--emmental.lr", type=float, default=1e-3, help="Learing rate" ) optimizer_config.add_argument( "--emmental.l2", type=float, default=0.0, help="l2 regularization" ) optimizer_config.add_argument( "--emmental.grad_clip", type=nullable_float, default=None, help="Gradient clipping", ) optimizer_config.add_argument( "--emmental.gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps", ) # ASGD config optimizer_config.add_argument( "--emmental.asgd_lambd", type=float, default=0.0001, help="ASGD lambd" ) optimizer_config.add_argument( "--emmental.asgd_alpha", type=float, default=0.75, help="ASGD alpha" ) optimizer_config.add_argument( "--emmental.asgd_t0", type=float, default=1000000.0, help="ASGD t0" ) # Adadelta config optimizer_config.add_argument( "--emmental.adadelta_rho", type=float, default=0.9, help="Adadelta rho" ) optimizer_config.add_argument( "--emmental.adadelta_eps", type=float, default=0.000001, help="Adadelta eps" ) # Adagrad config optimizer_config.add_argument( "--emmental.adagrad_lr_decay", type=float, default=0, help="Adagrad lr_decay" ) optimizer_config.add_argument( "--emmental.adagrad_initial_accumulator_value", type=float, default=0, help="Adagrad initial accumulator value", ) optimizer_config.add_argument( "--emmental.adagrad_eps", type=float, default=0.0000000001, help="Adagrad eps" ) # Adam config optimizer_config.add_argument( "--emmental.adam_betas", nargs="+", type=float, default=(0.9, 0.999), help="Adam betas", ) optimizer_config.add_argument( "--emmental.adam_eps", type=float, default=1e-6, help="Adam eps" ) optimizer_config.add_argument( 
"--emmental.adam_amsgrad", type=str2bool, default=False, help="Whether to use the AMSGrad variant of adam", ) # AdamW config optimizer_config.add_argument( "--emmental.adamw_betas", nargs="+", type=float, default=(0.9, 0.999), help="AdamW betas", ) optimizer_config.add_argument( "--emmental.adamw_eps", type=float, default=1e-6, help="AdamW eps" ) optimizer_config.add_argument( "--emmental.adamw_amsgrad", type=str2bool, default=False, help="Whether to use the AMSGrad variant of AdamW", ) # Adamax config optimizer_config.add_argument( "--emmental.adamax_betas", nargs="+", type=float, default=(0.9, 0.999), help="Adamax betas", ) optimizer_config.add_argument( "--emmental.adamax_eps", type=float, default=1e-6, help="Adamax eps" ) # LBFGS config optimizer_config.add_argument( "--emmental.lbfgs_max_iter", type=int, default=20, help="LBFGS max iter" ) optimizer_config.add_argument( "--emmental.lbfgs_max_eval", type=nullable_int, default=None, help="LBFGS max eval", ) optimizer_config.add_argument( "--emmental.lbfgs_tolerance_grad", type=float, default=1e-07, help="LBFGS tolerance grad", ) optimizer_config.add_argument( "--emmental.lbfgs_tolerance_change", type=float, default=1e-09, help="LBFGS tolerance change", ) optimizer_config.add_argument( "--emmental.lbfgs_history_size", type=int, default=100, help="LBFGS history size", ) optimizer_config.add_argument( "--emmental.lbfgs_line_search_fn", type=nullable_string, default=None, help="LBFGS line search fn", ) # RMSprop config optimizer_config.add_argument( "--emmental.rms_prop_alpha", type=float, default=0.99, help="RMSprop alpha" ) optimizer_config.add_argument( "--emmental.rms_prop_eps", type=float, default=1e-08, help="RMSprop eps" ) optimizer_config.add_argument( "--emmental.rms_prop_momentum", type=float, default=0, help="RMSprop momentum" ) optimizer_config.add_argument( "--emmental.rms_prop_centered", type=str2bool, default=False, help="RMSprop centered", ) # Rprop config optimizer_config.add_argument( "--emmental.r_prop_etas", nargs="+", type=float, default=(0.5, 1.2), help="Rprop etas", ) optimizer_config.add_argument( "--emmental.r_prop_step_sizes", nargs="+", type=float, default=(1e-06, 50), help="Rprop step sizes", ) # SGD config optimizer_config.add_argument( "--emmental.sgd_momentum", type=float, default=0, help="SGD momentum" ) optimizer_config.add_argument( "--emmental.sgd_dampening", type=float, default=0, help="SGD dampening" ) optimizer_config.add_argument( "--emmental.sgd_nesterov", type=str2bool, default=False, help="SGD nesterov" ) # SparseAdam config optimizer_config.add_argument( "--emmental.sparse_adam_betas", nargs="+", type=float, default=(0.9, 0.999), help="SparseAdam betas", ) optimizer_config.add_argument( "--emmental.sparse_adam_eps", type=float, default=1e-06, help="SparseAdam eps" ) # BertAdam config optimizer_config.add_argument( "--emmental.bert_adam_betas", nargs="+", type=float, default=(0.9, 0.999), help="BertAdam betas", ) optimizer_config.add_argument( "--emmental.bert_adam_eps", type=float, default=1e-06, help="BertAdam eps" ) parser_hierarchy["emmental"]["_global_optimizer"] = optimizer_config # Scheduler configuration scheduler_config = parser.add_argument_group("Scheduler configuration") scheduler_config.add_argument( "--emmental.lr_scheduler", type=nullable_string, default=None, choices=[ "linear", "exponential", "plateau", "step", "multi_step", "cyclic", "one_cycle", "cosine_annealing", ], help="Learning rate scheduler", ) scheduler_config.add_argument( "--emmental.lr_scheduler_step_unit", type=str, 
default="batch", choices=["batch", "epoch"], help="Learning rate scheduler step unit", ) scheduler_config.add_argument( "--emmental.lr_scheduler_step_freq", type=int, default=1, help="Learning rate scheduler step freq", ) scheduler_config.add_argument( "--emmental.warmup_steps", type=float, default=None, help="Warm up steps" ) scheduler_config.add_argument( "--emmental.warmup_unit", type=str, default="batch", choices=["batch", "epoch"], help="Warm up unit", ) scheduler_config.add_argument( "--emmental.warmup_percentage", type=float, default=None, help="Warm up percentage", ) scheduler_config.add_argument( "--emmental.min_lr", type=float, default=0.0, help="Minimum learning rate" ) scheduler_config.add_argument( "--emmental.reset_state", type=str2bool, default=False, help="Whether reset the state of the optimizer when lr changes", ) scheduler_config.add_argument( "--emmental.exponential_lr_scheduler_gamma", type=float, default=0.9, help="Gamma for exponential lr scheduler", ) # ReduceLROnPlateau lr scheduler config scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_metric", type=str, default="model/train/all/loss", help="Metric of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_mode", type=str, default="min", choices=["min", "max"], help="Mode of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_factor", type=float, default=0.1, help="Factor of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_patience", type=int, default=10, help="Patience for plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_threshold", type=float, default=0.0001, help="Threshold of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_threshold_mode", type=str, default="rel", choices=["rel", "abs"], help="Threshold mode of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_cooldown", type=int, default=0, help="Cooldown of plateau lr scheduler", ) scheduler_config.add_argument( "--emmental.plateau_lr_scheduler_eps", type=float, default=0.00000001, help="Eps of plateau lr scheduler", ) # Step lr scheduler config scheduler_config.add_argument( "--emmental.step_lr_scheduler_step_size", type=int, default=1, help="Period of learning rate decay", ) scheduler_config.add_argument( "--emmental.step_lr_scheduler_gamma", type=float, default=0.1, help="Multiplicative factor of learning rate decay", ) scheduler_config.add_argument( "--emmental.step_lr_scheduler_last_epoch", type=int, default=-1, help="The index of last epoch", ) scheduler_config.add_argument( "--emmental.multi_step_lr_scheduler_milestones", nargs="+", type=int, default=[1000], help="List of epoch indices. 
    scheduler_config.add_argument(
        "--emmental.multi_step_lr_scheduler_gamma", type=float, default=0.1,
        help="Multiplicative factor of learning rate decay",
    )
    scheduler_config.add_argument(
        "--emmental.multi_step_lr_scheduler_last_epoch", type=int, default=-1,
        help="The index of last epoch",
    )

    # Cyclic lr scheduler config
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_base_lr", nargs="+", type=float, default=0.001,
        help="Base lr of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_max_lr", nargs="+", type=float, default=0.1,
        help="Max lr of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_step_size_up", type=int, default=2000,
        help="Step size up of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_step_size_down", type=nullable_int,
        default=None,
        help="Step size down of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_mode", type=nullable_string,
        default="triangular",
        help="Mode of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_gamma", type=float, default=1.0,
        help="Gamma of cyclic lr scheduler",
    )
    # TODO: support cyclic_lr_scheduler_scale_fn
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_scale_mode", type=str, default="cycle",
        choices=["cycle", "iterations"],
        help="Scale mode of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_cycle_momentum", type=str2bool, default=True,
        help="Cycle momentum of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_base_momentum", nargs="+", type=float,
        default=0.8,
        help="Base momentum of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_max_momentum", nargs="+", type=float,
        default=0.9,
        help="Max momentum of cyclic lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cyclic_lr_scheduler_last_epoch", type=int, default=-1,
        help="Last epoch of cyclic lr scheduler",
    )

    # One cycle lr scheduler config
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_max_lr", nargs="+", type=float, default=0.1,
        help="Max lr of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_pct_start", type=float, default=0.3,
        help="Percentage start of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_anneal_strategy", type=str, default="cos",
        choices=["cos", "linear"],
        help="Anneal strategy of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_cycle_momentum", type=str2bool, default=True,
        help="Cycle momentum of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_base_momentum", nargs="+", type=float,
        default=0.85,
        help="Base momentum of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_max_momentum", nargs="+", type=float,
        default=0.95,
        help="Max momentum of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_div_factor", type=float, default=25,
        help="Div factor of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_final_div_factor", type=float, default=1e4,
        help="Final div factor of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.one_cycle_lr_scheduler_last_epoch", type=int, default=-1,
        help="Last epoch of one cycle lr scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.cosine_annealing_lr_scheduler_last_epoch", type=int, default=-1,
        help="The index of last epoch",
    )
    scheduler_config.add_argument(
        "--emmental.task_scheduler", type=str, default="round_robin",
        # choices=["sequential", "round_robin", "mixed"],
        help="Task scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.sequential_scheduler_fillup", type=str2bool, default=False,
        help="Whether to fill up in sequential scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.round_robin_scheduler_fillup", type=str2bool, default=False,
        help="Whether to fill up in round robin scheduler",
    )
    scheduler_config.add_argument(
        "--emmental.mixed_scheduler_fillup", type=str2bool, default=False,
        help="Whether to fill up in mixed scheduler",
    )
    parser_hierarchy["emmental"]["_global_scheduler"] = scheduler_config

    # Logging configuration
    logging_config = parser.add_argument_group("Logging configuration")
    logging_config.add_argument(
        "--emmental.counter_unit", type=str, default="epoch",
        choices=["epoch", "batch"],
        help="Logging unit (epoch, batch)",
    )
    logging_config.add_argument(
        "--emmental.evaluation_freq", type=float, default=1,
        help="Logging evaluation frequency",
    )
    logging_config.add_argument(
        "--emmental.writer", type=str, default="tensorboard",
        choices=["json", "tensorboard", "wandb"],
        help="The writer format (json, tensorboard, wandb)",
    )
    logging_config.add_argument(
        "--emmental.write_loss_per_step", type=bool, default=False,
        help="Whether to log loss per step",
    )
    logging_config.add_argument(
        "--emmental.wandb_project_name", type=nullable_string, default=None,
        help="Wandb project name",
    )
    logging_config.add_argument(
        "--emmental.wandb_run_name", type=nullable_string, default=None,
        help="Wandb run name",
    )
    logging_config.add_argument(
        "--emmental.wandb_watch_model", type=bool, default=False,
        help="Whether to use wandb to watch the model",
    )
    logging_config.add_argument(
        "--emmental.wandb_model_watch_freq", type=nullable_int, default=None,
        help="Wandb model watch frequency",
    )
    logging_config.add_argument(
        "--emmental.checkpointing", type=str2bool, default=True,
        help="Whether to checkpoint the model",
    )
    logging_config.add_argument(
        "--emmental.checkpoint_path", type=str, default=None, help="Checkpointing path"
    )
    logging_config.add_argument(
        "--emmental.checkpoint_freq", type=int, default=1,
        help="Checkpointing every k logging units",
    )
    logging_config.add_argument(
        "--emmental.checkpoint_metric", type=str2dict,
        default={"model/train/all/loss": "min"},
        help=(
            "Checkpointing metric (metric_name:mode), "
            "e.g., `model/train/all/loss:min`"
        ),
    )
    logging_config.add_argument(
        "--emmental.checkpoint_task_metrics", type=str2dict, default=None,
        help=(
            "Task specific checkpointing metric "
            "(metric_name1:mode1,metric_name2:mode2)"
        ),
    )
    logging_config.add_argument(
        "--emmental.checkpoint_runway", type=float, default=0,
        help="Checkpointing runway (no checkpointing before k checkpointing unit)",
    )
    logging_config.add_argument(
        "--emmental.checkpoint_all", type=str2bool, default=True,
        help="Whether to checkpoint all checkpoints",
    )
    logging_config.add_argument(
        "--emmental.clear_intermediate_checkpoints", type=str2bool, default=False,
        help="Whether to clear intermediate checkpoints",
    )
    logging_config.add_argument(
        "--emmental.clear_all_checkpoints", type=str2bool, default=False,
        help="Whether to clear all checkpoints",
    )
    parser_hierarchy["emmental"]["_global_logging"] = logging_config

    return parser, parser_hierarchy

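# Note: argparse keeps the dotted option names above as-is (only "-" is rewritten to "_"),
# so after parsing, the values live under "emmental.<option>" in the namespace and must be
# read with vars()/getattr rather than plain attribute access. A minimal sketch, with
# made-up flag values:
#
#     parser, parser_hierarchy = parse_args()
#     ns = parser.parse_args(["--emmental.lr", "5e-5", "--emmental.n_epochs", "3"])
#     vars(ns)["emmental.lr"]        # -> 5e-05
#     vars(ns)["emmental.n_epochs"]  # -> 3
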
def parse_args_to_config(args: DottedDict) -> Dict[str, Any]:
    """Parse the Emmental arguments to config dict.

    Args:
        args: Parsed namespace from argument parser.

    Returns:
        Emmental config dict.
    """
    config = {
        "meta_config": {
            "seed": args.seed,
            "verbose": args.verbose,
            "log_path": args.log_path,
            "use_exact_log_path": args.use_exact_log_path,
        },
        "data_config": {
            "min_data_len": args.min_data_len,
            "max_data_len": args.max_data_len,
        },
        "model_config": {
            "model_path": args.model_path,
            "device": args.device,
            "dataparallel": args.dataparallel,
            "distributed_backend": args.distributed_backend,
        },
        "learner_config": {
            "optimizer_path": args.optimizer_path,
            "scheduler_path": args.scheduler_path,
            "fp16": args.fp16,
            "fp16_opt_level": args.fp16_opt_level,
            "local_rank": args.local_rank,
            "epochs_learned": args.epochs_learned,
            "n_epochs": args.n_epochs,
            "steps_learned": args.steps_learned,
            "n_steps": args.n_steps,
            "skip_learned_data": args.skip_learned_data,
            "train_split": args.train_split,
            "valid_split": args.valid_split,
            "test_split": args.test_split,
            "ignore_index": args.ignore_index,
            "online_eval": args.online_eval,
            "optimizer_config": {
                "optimizer": args.optimizer,
                "lr": args.lr,
                "l2": args.l2,
                "grad_clip": args.grad_clip,
                "gradient_accumulation_steps": args.gradient_accumulation_steps,
                "asgd_config": {
                    "lambd": args.asgd_lambd,
                    "alpha": args.asgd_alpha,
                    "t0": args.asgd_t0,
                },
                "adadelta_config": {"rho": args.adadelta_rho, "eps": args.adadelta_eps},
                "adagrad_config": {
                    "lr_decay": args.adagrad_lr_decay,
                    "initial_accumulator_value": args.adagrad_initial_accumulator_value,
                    "eps": args.adagrad_eps,
                },
                "adam_config": {
                    "betas": args.adam_betas,
                    "amsgrad": args.adam_amsgrad,
                    "eps": args.adam_eps,
                },
                "adamw_config": {
                    "betas": args.adamw_betas,
                    "amsgrad": args.adamw_amsgrad,
                    "eps": args.adamw_eps,
                },
                "adamax_config": {"betas": args.adamax_betas, "eps": args.adamax_eps},
                "lbfgs_config": {
                    "max_iter": args.lbfgs_max_iter,
                    "max_eval": args.lbfgs_max_eval,
                    "tolerance_grad": args.lbfgs_tolerance_grad,
                    "tolerance_change": args.lbfgs_tolerance_change,
                    "history_size": args.lbfgs_history_size,
                    "line_search_fn": args.lbfgs_line_search_fn,
                },
                "rms_prop_config": {
                    "alpha": args.rms_prop_alpha,
                    "eps": args.rms_prop_eps,
                    "momentum": args.rms_prop_momentum,
                    "centered": args.rms_prop_centered,
                },
                "r_prop_config": {
                    "etas": args.r_prop_etas,
                    "step_sizes": args.r_prop_step_sizes,
                },
                "sgd_config": {
                    "momentum": args.sgd_momentum,
                    "dampening": args.sgd_dampening,
                    "nesterov": args.sgd_nesterov,
                },
                "sparse_adam_config": {
                    "betas": args.sparse_adam_betas,
                    "eps": args.sparse_adam_eps,
                },
                "bert_adam_config": {
                    "betas": args.bert_adam_betas,
                    "eps": args.bert_adam_eps,
                },
            },
            "lr_scheduler_config": {
                "lr_scheduler": args.lr_scheduler,
                "lr_scheduler_step_unit": args.lr_scheduler_step_unit,
                "lr_scheduler_step_freq": args.lr_scheduler_step_freq,
                "warmup_steps": args.warmup_steps,
                "warmup_unit": args.warmup_unit,
                "warmup_percentage": args.warmup_percentage,
                "min_lr": args.min_lr,
                "reset_state": args.reset_state,
                "exponential_config": {"gamma": args.exponential_lr_scheduler_gamma},
                "plateau_config": {
                    "metric": args.plateau_lr_scheduler_metric,
                    "mode": args.plateau_lr_scheduler_mode,
                    "factor": args.plateau_lr_scheduler_factor,
                    "patience": args.plateau_lr_scheduler_patience,
                    "threshold": args.plateau_lr_scheduler_threshold,
                    "threshold_mode": args.plateau_lr_scheduler_threshold_mode,
                    "cooldown": args.plateau_lr_scheduler_cooldown,
                    "eps": args.plateau_lr_scheduler_eps,
                },
                "step_config": {
                    "step_size": args.step_lr_scheduler_step_size,
"gamma": args.step_lr_scheduler_gamma, "last_epoch": args.step_lr_scheduler_last_epoch, }, "multi_step_config": { "milestones": args.multi_step_lr_scheduler_milestones, "gamma": args.multi_step_lr_scheduler_gamma, "last_epoch": args.multi_step_lr_scheduler_last_epoch, }, "cyclic_config": { "base_lr": args.cyclic_lr_scheduler_base_lr, "max_lr": args.cyclic_lr_scheduler_max_lr, "step_size_up": args.cyclic_lr_scheduler_step_size_up, "step_size_down": args.cyclic_lr_scheduler_step_size_down, "mode": args.cyclic_lr_scheduler_mode, "gamma": args.cyclic_lr_scheduler_gamma, "scale_fn": None, "scale_mode": args.cyclic_lr_scheduler_scale_mode, "cycle_momentum": args.cyclic_lr_scheduler_cycle_momentum, "base_momentum": args.cyclic_lr_scheduler_base_momentum, "max_momentum": args.cyclic_lr_scheduler_max_momentum, "last_epoch": args.cyclic_lr_scheduler_last_epoch, }, "one_cycle_config": { "max_lr": args.one_cycle_lr_scheduler_max_lr, "pct_start": args.one_cycle_lr_scheduler_pct_start, "anneal_strategy": args.one_cycle_lr_scheduler_anneal_strategy, "cycle_momentum": args.one_cycle_lr_scheduler_cycle_momentum, "base_momentum": args.one_cycle_lr_scheduler_base_momentum, "max_momentum": args.one_cycle_lr_scheduler_max_momentum, "div_factor": args.one_cycle_lr_scheduler_div_factor, "final_div_factor": args.one_cycle_lr_scheduler_final_div_factor, "last_epoch": args.one_cycle_lr_scheduler_last_epoch, }, "cosine_annealing_config": { "last_epoch": args.cosine_annealing_lr_scheduler_last_epoch }, }, "task_scheduler_config": { "task_scheduler": args.task_scheduler, "sequential_scheduler_config": { "fillup": args.sequential_scheduler_fillup }, "round_robin_scheduler_config": { "fillup": args.round_robin_scheduler_fillup }, "mixed_scheduler_config": {"fillup": args.mixed_scheduler_fillup}, }, }, "logging_config": { "counter_unit": args.counter_unit, "evaluation_freq": args.evaluation_freq, "writer_config": { "verbose": True, "writer": args.writer, "write_loss_per_step": args.write_loss_per_step, "wandb_project_name": args.wandb_project_name, "wandb_run_name": args.wandb_run_name, "wandb_watch_model": args.wandb_watch_model, "wandb_model_watch_freq": args.wandb_model_watch_freq, }, "checkpointing": args.checkpointing, "checkpointer_config": { "checkpoint_path": args.checkpoint_path, "checkpoint_freq": args.checkpoint_freq, "checkpoint_metric": args.checkpoint_metric, "checkpoint_task_metrics": args.checkpoint_task_metrics, "checkpoint_runway": args.checkpoint_runway, "checkpoint_all": args.checkpoint_all, "clear_intermediate_checkpoints": args.clear_intermediate_checkpoints, "clear_all_checkpoints": args.clear_all_checkpoints, }, }, } return create_bool_dotted_dict(config)