import logging
import logging.config
import os

import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

from utils.ema_callbacks import EMA
from utils.logging.logging import LOGGER_CONFIG

# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)

# Get the logger
logger = logging.getLogger(__name__)


def _update_config(cfg):
    cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
    cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")

    # Set per-GPU batch size; fall back to 1 device so CPU-only runs
    # do not divide by zero.
    num_devices = max(torch.cuda.device_count(), 1)
    logger.info("No. devices: %d", num_devices)
    for mode in ["train", "val"]:
        new_batch_size = cfg[mode].batch_size // num_devices
        cfg[mode].batch_size = new_batch_size


def _get_config_from_cli():
    # Strip leading "--" from CLI override keys so they merge cleanly.
    cfg_cli = OmegaConf.from_cli()
    cli_keys = list(cfg_cli.keys())
    for cli_key in cli_keys:
        if "--" in cli_key:
            cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key]
            del cfg_cli[cli_key]
    return cfg_cli


def get_callbacks(cfg):
    callback_list = []
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.save_weight_dir,
        filename="best_model",
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        verbose=True,
        monitor="val_f1",
        mode="max",
    )
    checkpoint_callback.FILE_EXTENSION = ".pth"
    checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model"
    callback_list.append(checkpoint_callback)

    if cfg.callbacks.ema.decay != -1:
        ema_callback = EMA(decay=0.9999)
        callback_list.append(ema_callback)

    # Returns a list when EMA is enabled, otherwise the single checkpoint callback.
    return callback_list if len(callback_list) > 1 else checkpoint_callback


def get_plugins(cfg):
    plugins = []
    if cfg.train.strategy.type == "ddp":
        plugins.append(DDPPlugin())
    return plugins


def get_loggers(cfg):
    loggers = []
    loggers.append(
        TensorBoardLogger(
            cfg.tensorboard_dir, name="", version="", default_hp_metric=False
        )
    )
    return loggers


def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
    # Flatten a nested DictConfig into {"parent__child": "value"} entries.
    for key, val in cfg.items():
        if isinstance(val, DictConfig):
            hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__")
        else:
            hparam_dict[parent_str + key] = str(val)
    return hparam_dict


def get_specific_pl_logger(pl_loggers, logger_type):
    for pl_logger in pl_loggers:
        if isinstance(pl_logger, logger_type):
            return pl_logger
    return None


def get_class_names(dataset_root_path):
    class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
    with open(class_names_file, "r", encoding="utf-8") as f:
        class_names = f.read().strip().split("\n")
    return class_names


def create_exp_dir(save_dir=""):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("DIR already exists.")
    logger.info("Experiment dir: %s", save_dir)


def create_dir(save_dir=""):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("DIR already exists.")
    logger.info("Save dir: %s", save_dir)


def load_checkpoint(ckpt_path, model, key_include):
    assert os.path.exists(ckpt_path), f"Ckpt path at {ckpt_path} does not exist!"

    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    for key in list(state_dict.keys()):
        if f".{key_include}." not in key:
            del state_dict[key]
        else:
            # Remove the "net.<key_include>." prefix so keys match the bare module.
            state_dict[key[4:].replace(key_include + ".", "")] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=True)
    logger.info("Loaded checkpoint from %s", ckpt_path)
    return model


def load_model_weight(net, pretrained_model_file):
    pretrained_model_state_dict = torch.load(
        pretrained_model_file, map_location="cpu"
    )["state_dict"]
    new_state_dict = {}
    for k, v in pretrained_model_state_dict.items():
        new_k = k
        # Drop the "net." prefix added by the LightningModule wrapper.
        if new_k.startswith("net."):
            new_k = new_k[len("net."):]
        new_state_dict[new_k] = v
    net.load_state_dict(new_state_dict)
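

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the helpers
# above could be wired into a pytorch_lightning.Trainer. The config path
# "configs/base.yaml" and the keys it is assumed to contain (workspace,
# train.*, val.*, callbacks.ema.decay) are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import pytorch_lightning as pl

    # Merge a base YAML config (hypothetical path) with CLI overrides,
    # then derive checkpoint/tensorboard dirs and per-GPU batch sizes.
    cfg = OmegaConf.merge(OmegaConf.load("configs/base.yaml"), _get_config_from_cli())
    _update_config(cfg)
    create_exp_dir(cfg.workspace)

    trainer = pl.Trainer(
        gpus=torch.cuda.device_count(),
        callbacks=get_callbacks(cfg),
        logger=get_loggers(cfg),
        plugins=get_plugins(cfg),
        max_epochs=cfg.train.max_epochs,  # assumed config key
    )
    # trainer.fit(model, datamodule)  # model/datamodule come from the rest of the repo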