sbt-idp/cope2n-ai-fi/common/AnyKey_Value/lightning_modules/classifier.py

import time

import torch
import torch.utils.data
from overrides import overrides
from pytorch_lightning import LightningModule
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR

from lightning_modules.schedulers import (
    cosine_scheduler,
    linear_scheduler,
    multistep_scheduler,
)
from model import get_model
from utils import cfg_to_hparams, get_specific_pl_logger


class ClassifierModule(LightningModule):
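    """LightningModule wrapping the key-value classifier network built by ``get_model``.

    The network, optimizer, and learning-rate schedule are all constructed from
    the ``cfg`` object passed to ``__init__``.
    """
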
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.net = get_model(self.cfg)
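        # -100 matches the default ignore_index of torch.nn.CrossEntropyLoss.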
self.ignore_index = -100
self.time_tracker = None
self.optimizer_types = {
"sgd": SGD,
"adam": Adam,
"adamw": AdamW,
        }

@overrides
def setup(self, stage):
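        # Start the wall-clock timer that _log_shell() uses to report elapsed time.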
        self.time_tracker = time.time()

@overrides
def configure_optimizers(self):
optimizer = self._get_optimizer()
scheduler = self._get_lr_scheduler(optimizer)
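        # Lightning scheduler config: step the scheduler every training step and
        # expose it to LR-monitoring callbacks under the name "learning_rate".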
scheduler = {
"scheduler": scheduler,
"name": "learning_rate",
"interval": "step",
}
        return [optimizer], [scheduler]

def _get_lr_scheduler(self, optimizer):
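        """Build the LR scheduler selected by cfg.train.optimizer.lr_schedule.method.

        ``None`` keeps the learning rate constant; "step", "cosine", and "linear"
        map to the schedulers defined in lightning_modules.schedulers.
        """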
cfg_train = self.cfg.train
lr_schedule_method = cfg_train.optimizer.lr_schedule.method
lr_schedule_params = cfg_train.optimizer.lr_schedule.params
if lr_schedule_method is None:
scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
elif lr_schedule_method == "step":
scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
elif lr_schedule_method == "cosine":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = cosine_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
elif lr_schedule_method == "linear":
total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
total_batch_size = cfg_train.batch_size * self.trainer.world_size
max_iter = total_samples / total_batch_size
scheduler = linear_scheduler(
optimizer, training_steps=max_iter, **lr_schedule_params
)
else:
raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")
        return scheduler

def _get_optimizer(self):
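        """Instantiate the optimizer named in cfg.train.optimizer.method.

        Entries under cfg.train.optimizer.params are forwarded as keyword
        arguments to the optimizer constructor.
        """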
opt_cfg = self.cfg.train.optimizer
method = opt_cfg.method.lower()
if method not in self.optimizer_types:
raise ValueError(f"Unknown optimizer method={method}")
kwargs = dict(opt_cfg.params)
kwargs["params"] = self.net.parameters()
optimizer = self.optimizer_types[method](**kwargs)
        return optimizer

@rank_zero_only
@overrides
def on_fit_end(self):
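        # On rank zero only, log the full config as hyperparameters to TensorBoard,
        # together with a placeholder metric.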
hparam_dict = cfg_to_hparams(self.cfg, {})
metric_dict = {"metric/dummy": 0}
tb_logger = get_specific_pl_logger(self.logger, TensorBoardLogger)
if tb_logger:
            tb_logger.log_hyperparams(hparam_dict, metric_dict)

@overrides
def training_epoch_end(self, training_step_outputs):
        # Average the per-step losses so the logged value is a per-epoch mean
        # rather than a sum over steps.
        avg_loss = torch.tensor(0.0).to(self.device)
        for step_out in training_step_outputs:
            avg_loss += step_out["loss"]
        avg_loss /= max(len(training_step_outputs), 1)

        log_dict = {"train_loss": avg_loss}
        self._log_shell(log_dict, prefix="train ")

def _log_shell(self, log_info, prefix=""):
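        """Format ``log_info`` as a one-line status string with epoch, learning
        rate, and elapsed time since the last call; printing is currently
        disabled below.
        """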
log_info_shell = {}
for k, v in log_info.items():
new_v = v
            if isinstance(new_v, torch.Tensor):
new_v = new_v.item()
log_info_shell[k] = new_v
out_str = prefix.upper()
if prefix.upper().strip() in ["TRAIN", "VAL"]:
out_str += f"[epoch: {self.current_epoch}/{self.cfg.train.max_epochs}]"
if self.training:
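            # Read the current LR from the first optimizer's first param group
            # (self.trainer._lightning_optimizers is a private Lightning attribute).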
lr = self.trainer._lightning_optimizers[0].param_groups[0]["lr"]
log_info_shell["lr"] = lr
for key, value in log_info_shell.items():
out_str += f" || {key}: {round(value, 5)}"
out_str += f" || time: {round(time.time() - self.time_tracker, 1)}"
out_str += " secs."
# self.print(out_str)
self.time_tracker = time.time()
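

# Minimal usage sketch (hypothetical, not part of this module): assumes an
# OmegaConf-style config and a data module defined elsewhere in this repo.
#
#   from omegaconf import OmegaConf
#   import pytorch_lightning as pl
#
#   cfg = OmegaConf.load("path/to/experiment_config.yaml")  # hypothetical path
#   module = ClassifierModule(cfg)
#   trainer = pl.Trainer(max_epochs=cfg.train.max_epochs)
#   trainer.fit(module, datamodule=...)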