# sbt-idp/cope2n-ai-fi/common/AnyKey_Value/utils/__init__.py

import os
import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin
from utils.ema_callbacks import EMA
import logging
import logging.config
from utils.logging.logging import LOGGER_CONFIG
# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)
# Get the logger
logger = logging.getLogger(__name__)

def _update_config(cfg):
    cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
    cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")

    # set per-gpu batch size
    num_devices = max(torch.cuda.device_count(), 1)  # guard against CPU-only machines
    logger.info('No. devices: %d', num_devices)
    for mode in ["train", "val"]:
        new_batch_size = cfg[mode].batch_size // num_devices
        cfg[mode].batch_size = new_batch_size

def _get_config_from_cli():
    cfg_cli = OmegaConf.from_cli()
    cli_keys = list(cfg_cli.keys())
    for cli_key in cli_keys:
        if "--" in cli_key:
            cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key]
            del cfg_cli[cli_key]
    return cfg_cli
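
# Example (hypothetical sketch): merging CLI overrides on top of a YAML config.
# The config path and keys below are illustrative only, not part of this module.
#   cfg = OmegaConf.load("configs/train.yaml")
#   cfg = OmegaConf.merge(cfg, _get_config_from_cli())
#   _update_config(cfg)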

def get_callbacks(cfg):
    callback_list = []

    # checkpoint on the best validation F1 and also keep the most recent weights
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.save_weight_dir,
        filename='best_model',
        save_last=True,
        save_top_k=1,
        save_weights_only=True,
        verbose=True,
        monitor='val_f1',
        mode='max',
    )
    checkpoint_callback.FILE_EXTENSION = ".pth"
    checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model"
    callback_list.append(checkpoint_callback)

    # cfg.callbacks.ema.decay acts as an enable flag; the EMA decay itself is fixed at 0.9999
    if cfg.callbacks.ema.decay != -1:
        ema_callback = EMA(decay=0.9999)
        callback_list.append(ema_callback)

    # return the bare checkpoint callback when EMA is disabled
    return callback_list if len(callback_list) > 1 else checkpoint_callback

def get_plugins(cfg):
    plugins = []
    if cfg.train.strategy.type == "ddp":
        plugins.append(DDPPlugin())
    return plugins

def get_loggers(cfg):
    loggers = []
    loggers.append(
        TensorBoardLogger(
            cfg.tensorboard_dir, name="", version="", default_hp_metric=False
        )
    )
    return loggers
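
# Example (hypothetical sketch): wiring the callbacks, plugins and loggers into a
# Trainer, assuming `import pytorch_lightning as pl`; the arguments shown are
# illustrative only (note get_callbacks may return a single callback when EMA is off).
#   trainer = pl.Trainer(
#       callbacks=get_callbacks(cfg),
#       plugins=get_plugins(cfg),
#       logger=get_loggers(cfg),
#   )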

def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
    # recursively flatten a nested DictConfig into {"a__b": "value"} string pairs
    for key, val in cfg.items():
        if isinstance(val, DictConfig):
            hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__")
        else:
            hparam_dict[parent_str + key] = str(val)
    return hparam_dict
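
# Example (illustrative): a nested config flattens to double-underscore keys.
#   cfg = OmegaConf.create({"train": {"batch_size": 8}, "lr": 0.001})
#   cfg_to_hparams(cfg, {})  # -> {"train__batch_size": "8", "lr": "0.001"}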

def get_specific_pl_logger(pl_loggers, logger_type):
    for pl_logger in pl_loggers:
        if isinstance(pl_logger, logger_type):
            return pl_logger
    return None

def get_class_names(dataset_root_path):
    class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
    with open(class_names_file, "r", encoding="utf-8") as f:
        class_names = f.read().strip().split("\n")
    return class_names

def create_exp_dir(save_dir=''):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("Directory already exists.")
    logger.info('Experiment dir : {}'.format(save_dir))

def create_dir(save_dir=''):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("Directory already exists.")
    logger.info('Save dir : {}'.format(save_dir))

def load_checkpoint(ckpt_path, model, key_include):
    assert os.path.exists(ckpt_path), f"Checkpoint path {ckpt_path} does not exist!"
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

    for key in list(state_dict.keys()):
        if f'.{key_include}.' not in key:
            del state_dict[key]
        else:
            # drop the leading "net." and the "<key_include>." segment so the
            # remaining keys match the bare sub-module's own state dict
            state_dict[key[4:].replace(key_include + '.', "")] = state_dict[key]
            del state_dict[key]

    model.load_state_dict(state_dict, strict=True)
    logger.info(f"Loaded checkpoint from {ckpt_path}")
    return model
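
# Example (hypothetical sketch): loading only a "relation_tagger" sub-module's weights
# from a Lightning checkpoint; the module name and file path are illustrative only.
#   model.relation_tagger = load_checkpoint(
#       "checkpoints/best_model.pth", model.relation_tagger, "relation_tagger"
#   )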

def load_model_weight(net, pretrained_model_file):
    pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[
        "state_dict"
    ]
    new_state_dict = {}
    for k, v in pretrained_model_state_dict.items():
        new_k = k
        # strip an optional "net." prefix so keys match the bare network
        if new_k.startswith("net."):
            new_k = new_k[len("net."):]
        new_state_dict[new_k] = v
    net.load_state_dict(new_state_dict)
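
# Example (hypothetical sketch): restoring pretrained weights into the bare network
# before fine-tuning; the attribute and file name are illustrative only.
#   load_model_weight(model.net, "checkpoints/last_model.pth")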