sbt-idp/cope2n-ai-fi/modules/_sdsvkvu/sdsvkvu/model/__init__.py
2023-11-30 18:22:16 +07:00

45 lines
1.6 KiB
Python

import os
import torch
from sdsvkvu.model.kvu_model import KVUModel
from sdsvkvu.model.combined_model import ComKVUModel
from sdsvkvu.model.document_kvu_model import DocKVUModel
from sdsvkvu.model.sbt_model import SBTModel
def get_model(cfg):
    """Build the KVU model variant selected by ``cfg.mode``.

    Mapping of ``cfg.mode`` to the class that is instantiated as
    ``ModelClass(cfg=cfg)``:

    * 0 or 1 -> ``ComKVUModel``
    * 2      -> ``KVUModel``
    * 3      -> ``DocKVUModel``
    * 4      -> ``SBTModel``

    Raises:
        ValueError: if ``cfg.mode`` is not one of the values above.
    """
    mode = cfg.mode
    if mode in (0, 1):
        return ComKVUModel(cfg=cfg)
    if mode == 2:
        return KVUModel(cfg=cfg)
    if mode == 3:
        return DocKVUModel(cfg=cfg)
    if mode == 4:
        return SBTModel(cfg=cfg)
    raise ValueError(f'[ERROR] Model mode of {cfg.mode} is not supported')
def load_checkpoint(ckpt_path, model, key_include):
    """Load the ``key_include`` sub-network weights from a checkpoint into ``model``.

    Only entries of ``checkpoint['state_dict']`` whose name contains
    ``.{key_include}.`` are kept. Each kept name is renamed by dropping its
    first four characters (the ``net.`` prefix) and then the first
    dot-aligned ``{key_include}.`` segment. For example, with
    ``key_include='backbone'``, ``net.backbone.encoder.weight`` becomes
    ``encoder.weight``. The renamed mapping is loaded with ``strict=True``.

    Args:
        ckpt_path: path to a file readable by ``torch.load`` whose top-level
            object has a ``'state_dict'`` mapping.
        model: object exposing ``load_state_dict`` (e.g. an ``nn.Module``).
        key_include: name of the sub-network whose weights are extracted.

    Returns:
        ``model``, after its weights have been loaded.

    Raises:
        AssertionError: if ``ckpt_path`` does not exist.
    """
    import re

    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``. AssertionError is kept so callers see the same type.
    if not os.path.exists(ckpt_path):
        raise AssertionError(f"Ckpt path at {ckpt_path} not exist!")

    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    marker = f".{key_include}."
    # Match ``key_include.`` only at the start of the name or right after a
    # dot. This stops e.g. ``model.`` from matching inside ``kvu_model.``.
    # Only the first occurrence is removed, so nested modules keep their name.
    segment = re.compile(r"(^|\.)" + re.escape(key_include) + r"\.")

    # Build a new dict instead of renaming in place. In-place renaming can
    # collide with a key that has not been visited yet, which then deletes or
    # overwrites the freshly renamed entry.
    filtered = {}
    for key, value in state_dict.items():
        if marker not in key:
            continue
        # key[4:] drops the leading "net." prefix (assumed present, not checked).
        new_key = segment.sub(r"\1", key[4:], count=1)
        filtered[new_key] = value

    model.load_state_dict(filtered, strict=True)
    print(f"Load checkpoint at {ckpt_path}")
    return model
def load_model_weight(net, pretrained_model_file):
    """Load a checkpoint's ``state_dict`` into ``net``, dropping a leading ``net.``.

    The file is loaded onto CPU with ``torch.load``. Every key that begins with
    ``"net."`` has that prefix removed, and other keys are passed through
    unchanged. The result goes to ``net.load_state_dict`` with its default
    arguments. Nothing is returned.
    """
    prefix = "net."
    checkpoint = torch.load(pretrained_model_file, map_location="cpu")
    renamed = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in checkpoint["state_dict"].items()
    }
    net.load_state_dict(renamed)