sbt-idp/cope2n-ai-fi/modules/ocr_engine/externals/sdsvtd/sdsvtd/factory.py
2023-12-11 13:15:11 +00:00

75 lines
3.1 KiB
Python

import hashlib
import os
import shutil
import subprocess
import warnings
def sha256sum(filename):
    """Return the hex-encoded SHA-256 digest of the file at *filename*.

    The file is opened unbuffered and read in 128 KiB chunks into one
    reusable buffer, so memory use does not grow with the file size.
    """
    digest = hashlib.sha256()
    buffer = bytearray(128 * 1024)
    view = memoryview(buffer)
    with open(filename, 'rb', buffering=0) as stream:
        while True:
            n_read = stream.readinto(view)
            if n_read == 0:
                # readinto() returns 0 at end of file
                break
            # Hash only the bytes actually filled on this read
            digest.update(view[:n_read])
    return digest.hexdigest()
# Registry of checkpoints that can be fetched from the model hub.
# Each version name maps to its download URL and the expected SHA-256 hex
# digest that downloaded/cached files are checked against.
online_model_factory = {
    name: {'url': url, 'hash': digest}
    for name, url, digest in (
        ('yolox-s-general-text-pretrain-20221226',
         'https://github.com/moewiee/satrn-model-factory/raw/main/62j266xm8r.pth',
         '89bff792685af454d0cfea5d6d673be6914d614e4c2044e786da6eddf36f8b50'),
        ('yolox-s-checkbox-20220726',
         'https://github.com/moewiee/satrn-model-factory/raw/main/1647d7eys7.pth',
         '7c1e188b7375dcf0b7b9d317675ebd92a86fdc29363558002249867249ee10f8'),
        ('yolox-s-idcard-5c-20221027',
         'https://github.com/moewiee/satrn-model-factory/raw/main/jr0egad3ix.pth',
         '73a7772594c1f6d3f6d6a98b6d6e4097af5026864e3bd50531ad9e635ae795a7'),
        ('yolox-s-handwritten-text-line-20230228',
         'https://github.com/moewiee/satrn-model-factory/raw/main/rb07rtwmgi.pth',
         'a31d1bf8fc880479d2e11463dad0b4081952a13e553a02919109b634a1190ef1'),
    )
}
# Live view of the registered version names; _get() checks membership here
# to decide whether a version is downloaded or treated as a local path.
__hub_available_versions__ = online_model_factory.keys()
def _get_from_hub(file_path, version, version_url):
    """Download a registered checkpoint with ``wget`` and verify its hash.

    Args:
        file_path: Destination path the download is written to.
        version: Key into ``online_model_factory``; its ``'hash'`` entry is
            the expected SHA-256 hex digest of the file.
        version_url: URL handed to ``wget``.

    Raises:
        AssertionError: if ``wget`` cannot be launched or no file was
            produced at *file_path*.
        ValueError: if the downloaded file's SHA-256 digest does not match
            the registered hash (the bad file is deleted first).
    """
    # Argument-list form (no shell): paths or URLs containing spaces or
    # shell metacharacters are passed through verbatim instead of being
    # re-parsed by /bin/sh as the old os.system() string was.
    try:
        subprocess.run(['wget', '-O', file_path, version_url], check=False)
    except OSError as err:
        # e.g. wget is not installed; report it the same way as a failed run.
        raise AssertionError(
            'wget failed while trying to retrieve from hub.') from err
    # Explicit raise rather than `assert` so the check survives `python -O`.
    if not os.path.exists(file_path):
        raise AssertionError('wget failed while trying to retrieve from hub.')
    # If wget fails after creating the output file it may leave an empty or
    # partial file; the digest comparison below is what rejects that case.
    downloaded_hash = sha256sum(file_path)
    if downloaded_hash != online_model_factory[version]['hash']:
        os.remove(file_path)
        raise ValueError('sha256 hash doesnt match for version retrieved from hub.')
def _get(version):
    """Resolve *version* to a checkpoint file cached in the ``hub`` directory.

    *version* is either a name registered in ``online_model_factory`` (the
    file is downloaded from the hub) or a path to an existing local file
    (the file is copied into the hub). An already-cached file is reused only
    if its SHA-256 digest matches the registered hash (online) or the source
    file's digest (local); otherwise it is deleted, a warning is emitted,
    and it is re-downloaded / re-copied.

    Args:
        version: Registered hub version name, or a local file path.

    Returns:
        Path of the cached checkpoint inside the ``hub`` directory.

    Raises:
        ValueError: if *version* is neither a registered hub version nor an
            existing local path. Errors from ``_get_from_hub`` (download or
            hash-verification failures) propagate unchanged.
    """
    use_online = version in __hub_available_versions__
    if not use_online and not os.path.exists(version):
        raise ValueError(f'Model version {version} not found online and not found local.')

    # Cache dir: 'hub' inside the parent of this module's directory.
    hub_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'hub')
    # exist_ok avoids the exists()/makedirs() race (FileExistsError) when
    # several processes initialise models at the same time.
    os.makedirs(hub_path, exist_ok=True)

    if use_online:
        version_url = online_model_factory[version]['url']
        file_path = os.path.join(hub_path, os.path.basename(version_url))
        if os.path.exists(file_path):
            if sha256sum(file_path) == online_model_factory[version]['hash']:
                return file_path
            # Cached copy is stale or corrupt: drop it and fetch a fresh one.
            os.remove(file_path)
            warnings.warn('existing hub version sha256 hash doesnt match, now re-download from hub.')
        _get_from_hub(file_path, version, version_url)
        return file_path

    file_path = os.path.join(hub_path, os.path.basename(version))
    if os.path.exists(file_path):
        if sha256sum(file_path) == sha256sum(version):
            return file_path
        # A different file with the same basename is cached: replace it.
        os.remove(file_path)
        warnings.warn('existing local version sha256 hash doesnt match, now replace with new local version.')
    shutil.copy2(version, file_path)
    return file_path