from model.combined_model import CombinedKVUModel
from model.kvu_model import KVUModel
from model.document_kvu_model import DocumentKVUModel


# Maps each supported ``cfg.stage`` value to the model class built for it.
_STAGE_TO_MODEL = {
    1: CombinedKVUModel,
    2: KVUModel,
    3: DocumentKVUModel,
}


def get_model(cfg):
    """Build the model that corresponds to the training stage ``cfg.stage``.

    Args:
        cfg: Configuration object. Its ``stage`` attribute selects the model
            class, and the whole object is forwarded to that class's
            constructor as ``cfg=cfg``.

    Returns:
        A ``CombinedKVUModel`` for stage 1, a ``KVUModel`` for stage 2, or a
        ``DocumentKVUModel`` for stage 3, each constructed with ``cfg=cfg``.

    Raises:
        ValueError: If ``cfg.stage`` is not one of the supported stages
            (1, 2 or 3). ``ValueError`` subclasses ``Exception``, so callers
            that previously caught the generic ``Exception`` still catch it.
    """
    model_cls = _STAGE_TO_MODEL.get(cfg.stage)
    if model_cls is None:
        # Report the offending value and the accepted ones so a misconfigured
        # run is easy to diagnose.
        raise ValueError(
            f"[ERROR] Training stage is wrong: expected one of "
            f"{sorted(_STAGE_TO_MODEL)}, got {cfg.stage!r}"
        )
    return model_cls(cfg=cfg)