import os
import logging
import logging.config

import torch
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import DDPPlugin

from utils.ema_callbacks import EMA
from utils.logging.logging import LOGGER_CONFIG

# Load the logging configuration
logging.config.dictConfig(LOGGER_CONFIG)

# Get the logger
logger = logging.getLogger(__name__)

def _update_config(cfg):
    cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
    cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")

    # Set the per-GPU batch size. Guard against CPU-only machines, where
    # torch.cuda.device_count() is 0 and the division would raise.
    num_devices = max(torch.cuda.device_count(), 1)
    logger.info("No. devices: %d", num_devices)
    for mode in ["train", "val"]:
        cfg[mode].batch_size = cfg[mode].batch_size // num_devices

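# Usage sketch (hypothetical values): with cfg.workspace == "runs/exp1",
# cfg.train.batch_size == 32, and 4 visible GPUs, _update_config(cfg) sets
# cfg.save_weight_dir to "runs/exp1/checkpoints" and cfg.train.batch_size
# to 32 // 4 == 8, i.e. the configured batch size is global, not per-GPU.
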
def _get_config_from_cli():
    cfg_cli = OmegaConf.from_cli()
    cli_keys = list(cfg_cli.keys())
    for cli_key in cli_keys:
        # Normalize argparse-style "--key" flags to plain OmegaConf keys.
        if "--" in cli_key:
            cfg_cli[cli_key.replace("--", "")] = cfg_cli[cli_key]
            del cfg_cli[cli_key]

    return cfg_cli

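# Example (hypothetical invocation): `python train.py --config=base.yaml
# train.lr=0.1` gives OmegaConf.from_cli() the top-level keys "--config"
# and "train"; the loop above renames "--config" to "config", so the
# returned config exposes cfg.config and cfg.train.lr.
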
def get_callbacks(cfg):
    callback_list = []
    checkpoint_callback = ModelCheckpoint(dirpath=cfg.save_weight_dir,
                                          filename='best_model',
                                          save_last=True,
                                          save_top_k=1,
                                          save_weights_only=True,
                                          verbose=True,
                                          monitor='val_f1', mode='max')
    checkpoint_callback.FILE_EXTENSION = ".pth"
    checkpoint_callback.CHECKPOINT_NAME_LAST = "last_model"
    callback_list.append(checkpoint_callback)
    # A decay of -1 is the sentinel for "EMA disabled"; otherwise use the
    # configured decay rather than a hard-coded constant.
    if cfg.callbacks.ema.decay != -1:
        ema_callback = EMA(decay=cfg.callbacks.ema.decay)
        callback_list.append(ema_callback)
    return callback_list if len(callback_list) > 1 else checkpoint_callback

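# Usage sketch (assumes a standard pytorch_lightning.Trainer setup; the
# Trainer arguments here are illustrative, not taken from this file):
#
#   trainer = pl.Trainer(
#       callbacks=get_callbacks(cfg),
#       plugins=get_plugins(cfg),
#       logger=get_loggers(cfg),
#   )
#
# Note that get_callbacks() returns a bare ModelCheckpoint rather than a
# list when EMA is disabled; Trainer accepts either form.
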
def get_plugins(cfg):
    plugins = []
    if cfg.train.strategy.type == "ddp":
        plugins.append(DDPPlugin())

    return plugins

def get_loggers(cfg):
    loggers = []

    loggers.append(
        TensorBoardLogger(
            cfg.tensorboard_dir, name="", version="", default_hp_metric=False
        )
    )

    return loggers

def cfg_to_hparams(cfg, hparam_dict, parent_str=""):
|
|
for key, val in cfg.items():
|
|
if isinstance(val, DictConfig):
|
|
hparam_dict = cfg_to_hparams(val, hparam_dict, parent_str + key + "__")
|
|
else:
|
|
hparam_dict[parent_str + key] = str(val)
|
|
return hparam_dict
|
|
|
|
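# Example: cfg_to_hparams flattens a nested config into a single-level dict
# with "__"-joined keys and stringified values, e.g. (hypothetical config)
# {"train": {"lr": 0.001}, "seed": 42} becomes
# {"train__lr": "0.001", "seed": "42"}, which TensorBoard can log as hparams.
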
def get_specific_pl_logger(pl_loggers, logger_type):
    for pl_logger in pl_loggers:
        if isinstance(pl_logger, logger_type):
            return pl_logger
    return None

def get_class_names(dataset_root_path):
    class_names_file = os.path.join(dataset_root_path[0], "class_names.txt")
    # Use a context manager so the file handle is closed after reading.
    with open(class_names_file, "r", encoding="utf-8") as f:
        class_names = f.read().strip().split("\n")
    return class_names

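# Expected file layout (assumed from the parsing above): class_names.txt
# holds one label per line under the first dataset root, e.g.
#
#   person
#   car
#   bicycle
#
# so get_class_names(["data/my_dataset"]) would return
# ["person", "car", "bicycle"] (the path is hypothetical).
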
def create_exp_dir(save_dir=''):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("Directory already exists.")
    logger.info("Experiment dir: %s", save_dir)

def create_dir(save_dir=''):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    else:
        logger.info("Directory already exists.")
    logger.info("Save dir: %s", save_dir)

def load_checkpoint(ckpt_path, model, key_include):
    assert os.path.exists(ckpt_path), f"Checkpoint path {ckpt_path} does not exist!"
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    for key in list(state_dict.keys()):
        if f".{key_include}." not in key:
            # Drop weights that do not belong to the requested submodule.
            del state_dict[key]
        else:
            # Strip the leading "net." and the submodule prefix, e.g.
            # "net.<key_include>.layer.weight" -> "layer.weight".
            state_dict[key[4:].replace(key_include + ".", "")] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=True)
    logger.info("Loaded checkpoint from %s", ckpt_path)
    return model

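# Example (hypothetical checkpoint): with key_include="encoder", a stored key
# "net.encoder.conv1.weight" is kept and renamed to "conv1.weight", while
# "net.decoder.conv1.weight" is dropped, so the filtered state dict can be
# loaded directly into the bare encoder module.
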
def load_model_weight(net, pretrained_model_file):
    pretrained_model_state_dict = torch.load(pretrained_model_file, map_location="cpu")[
        "state_dict"
    ]
    new_state_dict = {}
    for k, v in pretrained_model_state_dict.items():
        # Lightning modules typically store the wrapped model under "net.";
        # strip that prefix so the weights load into the bare network.
        new_k = k
        if new_k.startswith("net."):
            new_k = new_k[len("net."):]
        new_state_dict[new_k] = v
    net.load_state_dict(new_state_dict)
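# Usage sketch (hypothetical paths and model): restoring pretrained weights
# before fine-tuning might look like
#
#   net = MyBackbone()
#   load_model_weight(net, "pretrained/best_model.pth")
#
# load_state_dict() is called with the default strict=True here, so the
# checkpoint and the network must agree on every key.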