import os

import torch

from sdsvkvu.model.kvu_model import KVUModel
from sdsvkvu.model.combined_model import ComKVUModel
from sdsvkvu.model.document_kvu_model import DocKVUModel
from sdsvkvu.model.sbt_model import SBTModel


class _CheckpointNotFoundError(FileNotFoundError, AssertionError):
    """Raised by ``load_checkpoint`` when the checkpoint path is missing.

    Derives from ``AssertionError`` as well so that callers written against
    the earlier ``assert``-based check keep catching it.
    """


def get_model(cfg):
    """Build the model selected by ``cfg.mode``.

    Mode 0 or 1 -> ``ComKVUModel``, 2 -> ``KVUModel``, 3 -> ``DocKVUModel``,
    4 -> ``SBTModel``. Each model is constructed with ``cfg=cfg``.

    Raises:
        ValueError: if ``cfg.mode`` is none of the values above.
    """
    if cfg.mode in (0, 1):
        return ComKVUModel(cfg=cfg)
    if cfg.mode == 2:
        return KVUModel(cfg=cfg)
    if cfg.mode == 3:
        return DocKVUModel(cfg=cfg)
    if cfg.mode == 4:
        return SBTModel(cfg=cfg)
    raise ValueError(f'[ERROR] Model mode of {cfg.mode} is not supported')


def load_checkpoint(ckpt_path, model, key_include):
    """Load the ``key_include`` sub-module weights of a checkpoint into ``model``.

    Only entries of the checkpoint's ``'state_dict'`` whose key contains
    ``'.<key_include>.'`` are kept. Each kept key is renamed by dropping its
    first four characters and removing every ``'<key_include>.'`` occurrence,
    and the result is loaded with ``strict=True``.

    Args:
        ckpt_path: Path of the checkpoint file.
        model: Object exposing ``load_state_dict``; it is updated in place.
        key_include: Name used to select and rename the checkpoint keys.

    Returns:
        The same ``model`` object, after loading.

    Raises:
        FileNotFoundError: if ``ckpt_path`` does not exist. The raised type
            also derives from ``AssertionError``.
    """
    # Explicit check instead of `assert`: asserts are removed under
    # `python -O`, which would silently drop this validation.
    if not os.path.exists(ckpt_path):
        raise _CheckpointNotFoundError(f"Ckpt path at {ckpt_path} not exist!")

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')['state_dict']

    marker = f'.{key_include}.'
    # Iterate over a snapshot of the keys because the dict is edited in the loop.
    for key in list(state_dict.keys()):
        if marker not in key:
            del state_dict[key]
            continue
        # Drop the first 4 characters (presumably a "net." prefix -- confirm
        # against the checkpoint layout) and every "<key_include>." segment.
        renamed = key[4:].replace(key_include + '.', "")
        state_dict[renamed] = state_dict.pop(key)

    model.load_state_dict(state_dict, strict=True)
    print(f"Load checkpoint at {ckpt_path}")
    return model


def _strip_net_prefix(name):
    """Return ``name`` without a leading ``"net."``, if it has one."""
    prefix = "net."
    return name[len(prefix):] if name.startswith(prefix) else name


def load_model_weight(net, pretrained_model_file):
    """Load a checkpoint's ``"state_dict"`` into ``net``.

    A leading ``"net."`` is stripped from every key before loading; keys
    without that prefix are passed through unchanged.

    Args:
        net: Object exposing ``load_state_dict``; it is updated in place.
        pretrained_model_file: Path (or file object) accepted by ``torch.load``.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(pretrained_model_file, map_location="cpu")
    weights = {
        _strip_net_prefix(name): tensor
        for name, tensor in checkpoint["state_dict"].items()
    }
    net.load_state_dict(weights)