45 lines
1.6 KiB
Python
45 lines
1.6 KiB
Python
|
|
import os
|
|
import torch
|
|
from sdsvkvu.model.kvu_model import KVUModel
|
|
from sdsvkvu.model.combined_model import ComKVUModel
|
|
from sdsvkvu.model.document_kvu_model import DocKVUModel
|
|
from sdsvkvu.model.sbt_model import SBTModel
|
|
|
|
def get_model(cfg):
    """Build the model selected by ``cfg.mode``.

    Mode to class mapping:
        0 or 1 -> ComKVUModel
        2      -> KVUModel
        3      -> DocKVUModel
        4      -> SBTModel

    Args:
        cfg: Configuration object exposing a ``mode`` attribute; it is also
            forwarded unchanged to the chosen model constructor as ``cfg``.

    Returns:
        The newly constructed model instance.

    Raises:
        ValueError: If ``cfg.mode`` is not one of the supported modes.
    """
    mode = cfg.mode

    # Guard-clause dispatch: each supported mode returns immediately.
    if mode in (0, 1):
        return ComKVUModel(cfg=cfg)
    if mode == 2:
        return KVUModel(cfg=cfg)
    if mode == 3:
        return DocKVUModel(cfg=cfg)
    if mode == 4:
        return SBTModel(cfg=cfg)

    raise ValueError(f'[ERROR] Model mode of {cfg.mode} is not supported')
|
|
|
|
def load_checkpoint(ckpt_path, model, key_include):
    """Load the weights of one sub-module from a checkpoint into ``model``.

    The checkpoint is read with ``torch.load`` onto the CPU and must contain a
    top-level ``"state_dict"`` entry. Only entries whose key contains
    ``".{key_include}."`` are kept. Each kept key is renamed by dropping its
    first four characters and removing ``"{key_include}."`` from the
    remainder, then the result is loaded with ``strict=True``.

    Args:
        ckpt_path: Path to the checkpoint file.
        model: Module that receives the filtered state dict.
        key_include: Name of the checkpoint sub-module whose weights to load.

    Returns:
        The same ``model`` object, after its weights have been loaded.

    Raises:
        FileNotFoundError: If ``ckpt_path`` does not exist.
        RuntimeError: Propagated from ``load_state_dict`` when the filtered
            keys do not exactly match the model's keys (``strict=True``).
    """
    # Explicit exception instead of ``assert`` so the check survives ``python -O``.
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Ckpt path at {ckpt_path} not exist!")

    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # come from a trusted source.
    checkpoint_state = torch.load(ckpt_path, map_location='cpu')['state_dict']

    marker = f'.{key_include}.'
    removable = key_include + '.'

    # Build a fresh dict rather than renaming in place: with in-place renaming a
    # renamed key could collide with a not-yet-visited raw key and then be
    # deleted by the filter step, silently dropping the weight.
    # NOTE(review): ``key[4:]`` presumably strips a leading "net." (the original
    # comment says "remove net.something."), and ``replace`` removes every
    # occurrence of "{key_include}.", not just the leading one - confirm this
    # is intended for nested modules that reuse the same name.
    selected_state = {
        key[4:].replace(removable, ""): value
        for key, value in checkpoint_state.items()
        if marker in key
    }

    model.load_state_dict(selected_state, strict=True)
    print(f"Load checkpoint at {ckpt_path}")
    return model
|
|
|
|
def load_model_weight(net, pretrained_model_file):
    """Load pretrained weights from a checkpoint file into ``net``.

    The file is read with ``torch.load`` onto the CPU and must contain a
    top-level ``"state_dict"`` entry. A leading ``"net."`` is stripped from
    every key that has it; other keys are used as-is. The result is passed
    to ``net.load_state_dict`` with its default arguments.

    Args:
        net: Module that receives the weights.
        pretrained_model_file: Path to the checkpoint file.

    Returns:
        None. ``net`` is updated in place.
    """
    prefix = "net."
    checkpoint = torch.load(pretrained_model_file, map_location="cpu")
    source_state = checkpoint["state_dict"]

    # Drop the wrapper prefix only where it is actually present.
    renamed_state = {
        (name[len(prefix):] if name.startswith(prefix) else name): tensor
        for name, tensor in source_state.items()
    }

    net.load_state_dict(renamed_state)