# Source code for bootleg.utils.model_utils

"""Model utils."""
import logging

from bootleg import log_rank_0_debug

logger = logging.getLogger(__name__)


def count_parameters(model, requires_grad, logger):
    """Count the number of parameters.

    Every parameter whose ``requires_grad`` flag matches is also reported
    through ``log_rank_0_debug`` as ``<name> <num elements> <size> MB``.

    Args:
        model: model to count; must provide ``named_parameters()`` and
            ``parameters()``
        requires_grad: whether to look at grad or no grad params. The flag is
            compared by identity (``is``), so pass a real ``bool``
        logger: logger handed to ``log_rank_0_debug``

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` flag
        matches ``requires_grad``.
    """
    for name, param in model.named_parameters():
        if param.requires_grad is not requires_grad:
            continue
        num_elements = param.numel()
        log_rank_0_debug(
            logger,
            "{:s} {:d} {:.2f} MB".format(
                name,
                num_elements,
                # Size estimate hard-codes 4 bytes per element.
                num_elements * 4 / 1024**2,
            ),
        )
    return sum(
        param.numel()
        for param in model.parameters()
        if param.requires_grad is requires_grad
    )
def get_max_candidates(entity_symbols, data_config):
    """
    Get max candidates.

    Returns the maximum number of candidates used in the model, taking into
    account train_in_candidates. If train_in_candidates is False, we add a NC
    entity candidate (for null candidate).

    Args:
        entity_symbols: entity symbols; its ``max_candidates`` attribute is
            the base count
        data_config: data config; its ``train_in_candidates`` attribute
            decides whether the extra slot is added

    Returns:
        ``entity_symbols.max_candidates``, plus one when
        ``data_config.train_in_candidates`` is falsy.
    """
    # int(not flag) is 1 when train_in_candidates is falsy, else 0.
    return entity_symbols.max_candidates + int(not data_config.train_in_candidates)