21 lines
743 B
Python
21 lines
743 B
Python
|
# dirty path export
|
||
|
from sdsvtr import StandaloneSATRNRunner
|
||
|
import yaml
|
||
|
|
||
|
class Predictor:
    """Run a ``StandaloneSATRNRunner`` over a sequence of images in batches.

    Configuration comes from a YAML file that must provide the
    ``batch_size`` and ``device`` keys.
    """

    def __init__(self, setting_file='./setting.yml'):
        """Read the YAML settings and construct the underlying runner.

        Args:
            setting_file: Path to the YAML settings file.
        """
        # safe_load (rather than load) so the settings file cannot
        # instantiate arbitrary Python objects.
        with open(setting_file) as stream:
            self.setting = yaml.safe_load(stream)

        config = self.setting
        self.batch_size = config['batch_size']
        self.runner = StandaloneSATRNRunner(
            version='satrn-lite-general-pretrain-20230106',
            return_confident=True,
            device=config['device'],
        )

    def __call__(self, images):
        """Feed ``images`` to the runner in slices of ``batch_size`` items.

        Args:
            images: A sliceable sequence supporting ``len()``.

        Returns:
            A flat list built from element ``[0]`` of every runner output,
            in batch order. (Presumably the predictions, with confidences
            in a later element because ``return_confident=True`` -- confirm
            against the sdsvtr API.)
        """
        step = self.batch_size
        collected = []
        for start in range(0, len(images), step):
            batch_output = self.runner(images[start:start + step])
            # Only the first element of the runner's output is kept.
            collected.extend(batch_output[0])
        return collected
|