import argparse
import base64
import json
import logging
import os
import traceback
import uuid
from glob import glob
from typing import Dict, Optional
import allspark
import cv2
import numpy as np
from easycv.predictors.feature_extractor import (TorchFaceFeatureExtractor,
TorchFeatureExtractor)
from easycv.utils import json_utils
from easynlp.appzoo import CLIPPredictor
from ev_error import (Error_Import, Error_InputFormat, Error_JsonParse,
Error_UnExpectedServer)
from utils.eas_utils import get_image, torch_model_type_check
# Root logging for the whole process: INFO and above, tagged with the
# emitting file and line number.
logging.basicConfig(
    level=logging.INFO,
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
)

# Parsed command-line flags. Only assigned when this file runs as a script
# (the ``__main__`` section); it stays ``None`` when the module is imported.
FLAGS = None
def get_predictor(model_type):
    """Return the predictor class registered under ``model_type``.

    Args:
        model_type (str): Registry key. Available options are
            ``'torch_feature_extractor'``, ``'dummy'``, ``'exception'``,
            ``'torch_face_feature_extractor'`` and ``'mutil_model_extractor'``
            (the misspelled last key is kept for backward compatibility).

    Returns:
        type: The predictor class itself (not an instance).

    Raises:
        AssertionError: If ``model_type`` is not registered. It is raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``.
    """

    class DummyPredictor():
        """Stub predictor returning a fixed placeholder result per image."""

        def __init__(self, model_dir):
            pass

        def predict(self, images, batch_size=1):
            """Predict placeholder outputs for the given images.

            Args:
                images (list of numpy.ndarray): The images to predict.
                batch_size (int, optional): Accepted for interface
                    compatibility; ignored. Default is 1.

            Returns:
                list of dict: One ``{'result1': 'dummy', 'result2': 'dummy'}``
                dict per input image.
            """
            return [{'result1': 'dummy', 'result2': 'dummy'} for _ in images]

    class ExceptionPredictor():
        """Stub predictor that always fails; useful for error-path testing."""

        def __init__(self, model_dir):
            pass

        def predict(self, images, batch_size=1):
            """Always raise.

            Args:
                images (list of numpy.ndarray): Ignored.
                batch_size (int, optional): Ignored. Default is 1.

            Raises:
                RuntimeError: Always, with the message ``'dummy exception'``.
            """
            raise RuntimeError('dummy exception')

    register_map = {
        'torch_feature_extractor': TorchFeatureExtractor,
        'dummy': DummyPredictor,
        'exception': ExceptionPredictor,
        'torch_face_feature_extractor': TorchFaceFeatureExtractor,
        'mutil_model_extractor': CLIPPredictor,
    }
    if model_type not in register_map:
        raise AssertionError('invalid model_type %s available ones are %s' %
                             (model_type, list(register_map.keys())))
    return register_map[model_type]
def imdecode(image_data, flags=None):
    """Decode an encoded image (e.g. JPEG/PNG file contents) with OpenCV.

    Args:
        image_data (bytes | bytearray | memoryview | numpy.ndarray): The
            encoded image bytes.
        flags (int, optional): ``cv2.IMREAD_*`` flag forwarded to
            ``cv2.imdecode``. Defaults to ``cv2.IMREAD_COLOR``.

    Returns:
        numpy.ndarray: The decoded image.

    Raises:
        ValueError: If OpenCV cannot decode ``image_data``.
    """
    # cv2.imdecode takes two required arguments (buffer, flags); the previous
    # one-argument call raised TypeError for every request.
    if flags is None:
        flags = cv2.IMREAD_COLOR
    # Wrap raw bytes in a uint8 array, the buffer type OpenCV documents.
    if isinstance(image_data, (bytes, bytearray, memoryview)):
        buf = np.frombuffer(image_data, dtype=np.uint8)
    else:
        buf = np.asarray(image_data, dtype=np.uint8)
    img = cv2.imdecode(buf, flags)
    # OpenCV reports an undecodable buffer by returning None, not by raising.
    if img is None:
        raise ValueError('cv2.imdecode failed to decode image data')
    return img
def get_result_str(request_id, result_dict=None, error=None):
    """Serialize the response for one request.

    Args:
        request_id (str): ID of the request, echoed back in the response.
        result_dict (dict, optional): Fields merged into the response on
            success.
        error (optional): Error object exposing ``code`` and ``msg``
            attributes (the ``ev_error`` types). Takes precedence over
            ``result_dict``.

    Returns:
        tuple: ``(result_bytes, status)``. ``result_bytes`` is the UTF-8
        encoded JSON response. ``status`` is ``error.code`` on failure and
        ``200`` otherwise.
    """
    result = {'request_id': request_id}
    if error is not None:
        result['success'] = False
        result['error_code'] = error.code
        result['error_msg'] = error.msg
        stat = error.code
    else:
        # With neither an error nor a result, report success with no extra
        # fields instead of failing on an unbound ``stat``.
        result['success'] = True
        if result_dict is not None:
            result.update(result_dict)
        stat = 200
    result_str = json_utils.compat_dumps(result).encode('utf8')
    return result_str, stat
def init_er_service(self, backend='oss', **kwargs):
    """Initialize the retrieval service.

    Called by ``MyProcessor.process_retrieval`` when a request's
    ``function_name`` is ``'init'``; the request's ``function_params`` are
    forwarded as keyword arguments.

    Args:
        self: The processor instance. ``backend`` (and, for OSS,
            ``oss_root``) is recorded on it, and its ``_predictor`` is
            attached to the new service.
        backend (str): Backend identifier, ``'oss'`` or ``'elasticsearch'``.
            For ``'oss'``, ``kwargs`` must contain ``root_path`` (where the
            EasyRetrievalDocArray is saved and loaded) and ``oss_io_config``
            (a dict of OSS settings). ``'elasticsearch'`` is not implemented.
        **kwargs: Backend-specific settings; unrecognized keys are ignored.

    Returns:
        tuple: ``(info, er_service, init_status)`` where ``info`` is a dict
        with a boolean ``'status'`` and a human-readable ``'info'`` message.
        On success ``er_service`` is the new service and ``init_status`` is
        True. On failure the processor's current ``er_service`` and ``init``
        values are returned unchanged, so a failed re-init keeps a working
        service.
    """

    # The caller always unpacks three values, so failures must be 3-tuples
    # too; returning a bare dict made the unpack raise and hid the message.
    def _fail(message):
        return ({'info': message, 'status': False},
                getattr(self, 'er_service', None),
                getattr(self, 'init', False))

    self.backend = backend
    supported = ['oss', 'elasticsearch']
    if backend not in supported:
        return _fail('unsupported backend %s, available ones are %s' %
                     (backend, supported))
    if backend == 'elasticsearch':
        return _fail('elasticsearch backend is not implemented yet!')

    # pop() with a default so a missing key reaches the friendly message
    # below instead of raising KeyError.
    root_path = kwargs.pop('root_path', None)
    oss_io_config = kwargs.pop('oss_io_config', None)
    if root_path is None or oss_io_config is None:
        return _fail(
            'root_path & oss_io_config should be set when use oss backend!')
    self.oss_root = root_path
    try:
        from easy_retrieval.pai_docarray import EasyRetrievalDocArray
    except Exception as e:
        return _fail('import EasyRetrievalDocArray error: %s' % e)
    er_service = EasyRetrievalDocArray(root_path=root_path,
                                       oss_config=oss_io_config)
    er_service.predictor = self._predictor
    return {'info': 'oss backend success!', 'status': True}, er_service, True
class MyProcessor(allspark.BaseProcessor):
    """EAS processor serving EasyRetrieval requests over HTTP.

    Example request:
    ``curl -v http://127.0.0.1:8080/api/predict/service_name -d '<json>'``
    where ``<json>`` is an object such as
    ``{"function_name": "init", "function_params": {...}}``.

    ``function_name == 'init'`` builds the retrieval service through
    ``init_er_service``. Any other name is dispatched to the public method
    of the same name on that service. ``function_params`` is passed as
    keyword arguments (object), positional arguments (list) or nothing
    (absent/null). A base64 ``image`` in ``function_params`` is decoded, and
    its extracted ``feature`` is passed on instead.
    """

    # Supported model types and the predictor settings for each. In modelhub
    # every model is uniquely identified by its model_type; this table maps
    # that type to its hyperparameter config (admittedly not an ideal design
    # for easy_predict).
    _MODELHUB_SET = {
        # ------------- retrieval & tf retrieval-------------
        'image_retrieval': {
            'model_type': 'retrieval',
            'predictor_cls_name': 'TorchFeatureExtractor',
            'feature_dimension': 1536,
            'distance_metric': 'Cosine'
        },
        'faceid_retrieval': {
            'model_type': 'retrieval',
            'predictor_cls_name': 'TorchFaceFeatureExtractor',
            'feature_dimension': 512,
            'distance_metric': 'SquaredEuclidean'
        },
        'imagetext_retrieval_cn': {
            'model_type': 'retrieval',
            'predictor_cls_name': 'CLIPPredictor',
            'feature_dimension': 512,
            'distance_metric': 'Cosine',
            'model_cls': 'CLIPApp',
            'second_sequence': 'base64imagestr'
        },
        'imagetext_retrieval_en': {
            'model_type': 'retrieval',
            'predictor_cls_name': 'CLIPPredictor',
            'feature_dimension': 512,
            'distance_metric': 'Cosine',
            'model_cls': 'CLIPApp',
            'second_sequence': 'base64imagestr'
        },
        'text_retrieval': {
            'model_type': 'retrieval',
            'predictor_cls_name': 'FeatureVectorizationPredictor',
            'feature_dimension': 768,
            'distance_metric': 'Cosine',
            'model_cls': 'FeatureVectorization',
            'first_sequence': 'text'
        }
    }

    def initialize(self, unittest_config: Optional[Dict] = None):
        """Load the predictor model and set the processor's attributes.

        Not meant to be called directly (presumably invoked by
        ``allspark.BaseProcessor`` at start-up).

        The model config is read from the ``model.model_config`` property.
        In local debug (``--local_debug``) or unit-test mode, this method
        first writes that property from the flags or ``unittest_config``.
        The config's ``type`` must be a key of ``_MODELHUB_SET``. After
        ``torch_model_type_check``, the matching table entry becomes the
        final config.

        Args:
            unittest_config (dict, optional): Test-only override with keys
                ``model_dir``, ``type``, ``feature_dimension`` and
                ``distance_metric``.

        Sets:
            _model_config: Final predictor config dict.
            model_type: Model type after the modelhub reset.
            feature_dimension: Dimension of the predictor's feature vectors.
            distance_metric: Metric used to compare feature vectors.
            _predictor: The loaded predictor instance.
            er_service / init: ``None`` / ``False`` until an ``init`` request.
            attr_set_by_filelist_log: Empty dict.

        Raises:
            AssertionError: If the config has no ``type``, the type is
                unsupported, or no torch checkpoint is found.
        """
        model_dir = '../../model/'
        defaults = allspark.default_properties()

        # Local debug: build the model config from the command-line flags.
        # json.dumps keeps the JSON valid whatever the flag values contain.
        if getattr(FLAGS, 'local_debug', False):
            model_dir = FLAGS.model_dir
            defaults.put(
                b'model.model_config',
                json.dumps({
                    'type': str(FLAGS.model_type),
                    'predictor_cls_name': str(FLAGS.feature_name),
                    'feature_dimension': str(FLAGS.feature_dimension),
                    'distance_metric': str(FLAGS.distance_metric),
                }).encode('utf8'))
            defaults.put(b'rpc.worker_threads', b'2')

        # Unit tests: override the model directory and config.
        if unittest_config:
            model_dir = unittest_config['model_dir']
            defaults.put(
                b'model.model_config',
                json.dumps({
                    'type': str(unittest_config['type']),
                    'feature_dimension':
                    str(unittest_config['feature_dimension']),
                    'distance_metric': str(unittest_config['distance_metric']),
                }).encode('utf8'))

        # Read with the same bytes key used by the put() calls above (and by
        # run()).
        model_config_str = defaults.get(b'model.model_config',
                                        b'{}').decode('utf-8')
        logging.info('model config : %s' % model_config_str)
        self._model_config = json.loads(model_config_str)
        if 'type' not in self._model_config:
            raise AssertionError('type is missing from model_config')
        logging.info('create model type %s' % (self._model_config['type']))
        self.model_type = self._model_config['type']
        if self.model_type not in self._MODELHUB_SET:
            raise AssertionError('type does not meet the requirements')
        logging.info('model type check before torch_check:%s' %
                     self.model_type)
        self.model_type = torch_model_type_check(self.model_type, model_dir)
        self._model_config['type'] = self.model_type
        logging.info('model type check after torch_check:%s' % self.model_type)
        if self.model_type in self._MODELHUB_SET:
            logging.info('Reset by modelhub_set, type: %s' % self.model_type)
            # Copy so the shared class-level table is never mutated.
            self._model_config = dict(self._MODELHUB_SET[self.model_type])
            self.model_type = self._model_config['model_type']
            self._model_config['type'] = self._model_config['model_type']
        logging.info('final model config : %s' % self._model_config)

        # One default for the predictor class, used by every check below.
        predictor_cls_name = self._model_config.get('predictor_cls_name',
                                                    'TorchFeatureExtractor')
        if predictor_cls_name not in ('TorchFeatureExtractor',
                                      'TorchFaceFeatureExtractor'):
            # Non-torch (easynlp) predictors receive a directory: descend into
            # the first sub-directory, if any (sorted for a stable choice).
            model_dir = os.path.abspath(model_dir)
            for file_name in sorted(os.listdir(model_dir)):
                sub_dir = os.path.join(model_dir, file_name)
                if os.path.isdir(sub_dir):
                    model_dir = sub_dir
                    break
        elif not model_dir.endswith(('pth', 'pt')):
            # Torch extractors need a checkpoint file. With recursive=True,
            # '**' also matches zero directories, so top-level files are
            # found by this single glob; sorting makes the pick deterministic.
            model_path = sorted(
                glob('%s/**/*.pt*' % model_dir, recursive=True))
            if not model_path:
                raise AssertionError(f'model not found in {model_dir}')
            model_dir = model_path[0]

        self.er_service = None
        logging.info(
            'No init service for retrieval, use must init by http-api')
        self.init = False
        self.feature_dimension = self._model_config.get(
            'feature_dimension', 2048)
        self.distance_metric = self._model_config.get('distance_metric',
                                                      'Cosine')

        if 'feature_name' in self._model_config:
            predictor_cls = get_predictor(predictor_cls_name)
            self._predictor = predictor_cls(model_dir,
                                            self._model_config['feature_name'])
        elif predictor_cls_name == 'CLIPPredictor':
            # Imported lazily: only the image-text models need it.
            try:
                from easynlp.appzoo import CLIPApp, CLIPPredictor
            except Exception as e:
                raise Error_Import(
                    msg='import CLIPApp/CLIPPredictor error '
                    'from easynlp.appzoo ' + str(e)) from e
            self._predictor = CLIPPredictor(
                model_dir,
                model_cls=CLIPApp,
                second_sequence_type=self._model_config.get(
                    'second_sequence', 'base64imagestr'))
        elif predictor_cls_name == 'FeatureVectorizationPredictor':
            # Imported lazily: only the text model needs it.
            try:
                from easynlp.appzoo import (FeatureVectorization,
                                            FeatureVectorizationPredictor)
            except Exception as e:
                raise Error_Import(
                    msg='import FeatureVectorization/'
                    'FeatureVectorizationPredictor error '
                    'from easynlp.appzoo ' + str(e)) from e
            self._predictor = FeatureVectorizationPredictor(
                model_dir,
                model_cls=FeatureVectorization,
                first_sequence=self._model_config.get('first_sequence',
                                                      'text'))
        else:
            # Resolve the class by name from this module's globals instead of
            # eval(): same result for the supported class names, but a
            # tampered config string can no longer execute arbitrary code.
            predictor_cls = globals().get(predictor_cls_name)
            if predictor_cls is None:
                raise NameError("name '%s' is not defined" %
                                predictor_cls_name)
            self._predictor = predictor_cls(model_dir)
        logging.info('Feature_extractor : % s' % type(self._predictor))
        self.attr_set_by_filelist_log = {}

    def process_retrieval(self, data):
        """Handle one retrieval request received from the client.

        Not meant to be called directly; ``process`` delegates to it.

        Args:
            data (bytes): Raw request body, expected to be a JSON object with
                ``function_name`` and optional ``function_params``.

        Returns:
            tuple: ``(response_bytes, status)`` from ``get_result_str``.
            Unexpected failures are logged with the request id and reported
            as ``Error_UnExpectedServer``.
        """

        def data_process(data):
            """Parse the request body into a dict.

            If ``function_params`` is a dict containing ``image``, that
            base64 string is decoded in place. With ``numpy_input`` set to
            ``True`` it becomes a float64 NumPy array; otherwise it becomes
            an image via ``imdecode``. Failures are reported through an
            ``'error'`` key holding an ``ev_error`` instance; this helper
            never raises.
            """
            try:
                strip_data = data.strip()
                if not (strip_data.startswith(b'{')
                        and strip_data.endswith(b'}')):
                    return {
                        'error':
                        Error_JsonParse('request body must be a JSON object')
                    }
                try:
                    ret_data = json.loads(data)
                except Exception as e:
                    return {'error': Error_JsonParse(str(e))}
                params = ret_data.get('function_params')
                # function_params may legitimately be a list or absent; only
                # an object can carry an image.
                if isinstance(params, dict) and 'image' in params:
                    try:
                        img = base64.b64decode(params['image'])
                        if params.get('numpy_input') is True:
                            # frombuffer replaces the deprecated binary mode of
                            # np.fromstring; copy() keeps the result writable
                            # like fromstring's was.
                            params['image'] = np.frombuffer(img).copy()
                        else:
                            params['image'] = imdecode(img)
                    except Exception as e:
                        ret_data['error'] = Error_InputFormat(
                            'Image base64 data decode error: ' + str(e))
                return ret_data
            except Exception as e:
                return {'error': Error_JsonParse(str(e))}

        request_id = str(uuid.uuid4())
        try:
            data = data_process(data)
            if 'error' in data:
                return get_result_str(request_id, error=data['error'])
            func_name = data.get('function_name')
            if not isinstance(func_name, str):
                return get_result_str(
                    request_id,
                    error=Error_InputFormat('function_name is required'))
            kwargs = data.get('function_params', None)

            # image -> oss / feature
            if isinstance(kwargs, dict):
                img = kwargs.pop('image', None)
                database_name = kwargs.get('database_name', None)
                if img is not None:
                    if database_name is None:
                        return get_result_str(
                            request_id,
                            error=Error_InputFormat(
                                'database_name is required when image is '
                                'given'))
                    if hasattr(
                            self.er_service, 'db_getattr'
                    ) and database_name in self.er_service.db_dict.keys():
                        oss_database_dir = self.er_service.db_getattr(
                            database_name, 'root_path')
                        save_path = os.path.join(oss_database_dir,
                                                 database_name)
                        img_save_name = kwargs.pop('savename', None)
                        img, oss_path = get_image(img, save_path, True,
                                                  img_save_name)
                        kwargs['data'] = {'raw_data_path': oss_path}
                    else:
                        img, oss_path = get_image(img)
                    feature = self._predictor.predict(
                        [img])[0]['feature'].reshape([1, -1])
                    kwargs['feature'] = feature

            if func_name == 'init':
                if kwargs is not None and not isinstance(kwargs, dict):
                    return get_result_str(
                        request_id,
                        error=Error_InputFormat(
                            'function_params of init must be a JSON object'))
                res, self.er_service, self.init = \
                    init_er_service(self, **(kwargs or {}))
            elif self.init is False:
                res = {
                    'info':
                    'Call %s before init, should init service first!' %
                    func_name,
                    'status': False
                }
            elif func_name.startswith('_'):
                # function_name comes straight from the request; never let a
                # client reach private or dunder attributes of the service.
                res = {
                    'info': 'function %s is not allowed' % func_name,
                    'status': False
                }
            elif self.er_service is None:
                res = {
                    'info': 'retrieval service is not initialized',
                    'status': False
                }
            elif func_name == 'set_predictor':
                db = self.er_service
                expected = self._model_config.get('predictor_cls_name')
                if kwargs['predictor'] == expected:
                    res = db.set_predictor(self._predictor,
                                           kwargs['preprocess'])
                else:
                    res = {
                        'info':
                        'predictor %s does not match the loaded predictor %s'
                        % (kwargs['predictor'], expected),
                        'status': False
                    }
            else:
                method = getattr(self.er_service, func_name)
                if isinstance(kwargs, dict):
                    res = method(**kwargs)
                elif isinstance(kwargs, list):
                    res = method(*kwargs)
                else:
                    res = method()
            if not isinstance(res, dict):
                res = {'info': res}
            return get_result_str(request_id, result_dict=res)
        except Exception:
            logging.error('%s %s' % (request_id, traceback.format_exc()))
            error = Error_UnExpectedServer('request_id: %s' % request_id)
            return get_result_str(request_id, error=error)

    def process(self, data):
        """Handle one request by delegating to ``process_retrieval``.

        Not meant to be called directly.

        Args:
            data: Raw request body.

        Returns:
            tuple: ``(response_bytes, status)``.
        """
        return self.process_retrieval(data)
def run(args_unused):
    """Start ``MyProcessor`` and serve requests.

    The number of worker threads comes from the ``rpc.worker_threads``
    property, defaulting to 5.

    Args:
        args_unused: Parsed command-line flags; not used here.
    """
    props = allspark.default_properties()
    processor = MyProcessor(
        worker_threads=int(props.get(b'rpc.worker_threads', b'5')))
    processor.run()
if __name__ == '__main__':
    # Local runner: parse debug flags, then start the processor.
    cli = argparse.ArgumentParser('ev eas processor local runner')
    add_flag = cli.add_argument
    add_flag('--model_dir', type=str, help='local model dir',
             default='/workspace/modelpth/df2_1536_easycv061.pth')
    add_flag('--model_type', type=str, help='model type',
             default='image_retrieval')
    add_flag('--feature_name', type=str, help='output feature name',
             default='TorchFeatureExtractor')
    add_flag('--feature_dimension', type=int,
             help='output feature dimension', default=1536)
    add_flag('--distance_metric', type=str, help='retrieval distance metric',
             default='Cosine')
    add_flag('--local_debug', action='store_true', help='in local debug mode')
    # MyProcessor.initialize reads these through the module-level FLAGS.
    FLAGS = cli.parse_args()
    run(FLAGS)