# Source code for app

import argparse
import base64
import json
import logging
import os
import traceback
import uuid
from glob import glob
from typing import Dict, Optional

import allspark
import cv2
import numpy as np
from easycv.predictors.feature_extractor import (TorchFaceFeatureExtractor,
                                                 TorchFeatureExtractor)
from easycv.utils import json_utils
from easynlp.appzoo import CLIPPredictor
from ev_error import (Error_Import, Error_InputFormat, Error_JsonParse,
                      Error_UnExpectedServer)
from utils.eas_utils import get_image, torch_model_type_check

logging.basicConfig(
    format='[%(levelname)s] %(asctime)s %(filename)s:%(lineno)d : %(message)s',
    level=logging.INFO)

FLAGS = None


def get_predictor(model_type):
    """
    Look up the predictor class registered under ``model_type``.

    Parameters
    ----------
    model_type : str
        Registry key. Available options are:
        - 'torch_feature_extractor'
        - 'dummy'
        - 'exception'
        - 'torch_face_feature_extractor'
        - 'mutil_model_extractor'

    Returns
    -------
    type
        The registered predictor class (not an instance); the caller is
        responsible for instantiating it.

    Raises
    ------
    AssertionError
        If ``model_type`` is not one of the registered keys.
    """

    class DummyPredictor:
        """Stand-in predictor that returns a fixed result per image."""

        def __init__(self, model_dir):
            # Nothing to load; ``model_dir`` is accepted only so the
            # constructor signature matches the real predictors.
            pass

        def predict(self, images, batch_size=1):
            """
            Produce one placeholder result for every input image.

            Parameters
            ----------
            images : sequence
                The inputs; only their count is used.
            batch_size : int, optional
                Accepted for signature compatibility and ignored.

            Returns
            -------
            list of dict
                One ``{'result1': 'dummy', 'result2': 'dummy'}`` per image.
            """
            outputs = []
            for _ in range(len(images)):
                outputs.append({'result1': 'dummy', 'result2': 'dummy'})
            return outputs

    class ExceptionPredictor:
        """Stand-in predictor whose ``predict`` always fails."""

        def __init__(self, model_dir):
            # Nothing to load; see DummyPredictor.
            pass

        def predict(self, images, batch_size=1):
            """
            Always fail, regardless of the arguments.

            Raises
            ------
            RuntimeError
                Unconditionally, with the message 'dummy exception'.
            """
            raise RuntimeError('dummy exception')

    registry = dict([
        ('torch_feature_extractor', TorchFeatureExtractor),
        ('dummy', DummyPredictor),
        ('exception', ExceptionPredictor),
        ('torch_face_feature_extractor', TorchFaceFeatureExtractor),
        ('mutil_model_extractor', CLIPPredictor),
    ])

    known_types = list(registry)
    assert model_type in registry, \
        'invalid model_type %s available ones are %s' \
        % (model_type, known_types)

    predictor_cls = registry[model_type]
    return predictor_cls


def imdecode(image_data):
    """
    Decode an encoded image (for example JPEG or PNG bytes) using OpenCV.

    ``cv2.imdecode`` needs a 1-D ``uint8`` array and an explicit read flag,
    so a bytes-like input is first wrapped with ``np.frombuffer``.

    Parameters
    ----------
    image_data : bytes, bytearray, memoryview or numpy.ndarray
        The encoded image. An ``ndarray`` is passed to OpenCV unchanged.

    Returns
    -------
    numpy.ndarray
        The decoded image, read with ``cv2.IMREAD_COLOR``.

    Raises
    ------
    ValueError
        If OpenCV cannot decode the buffer (``cv2.imdecode`` returned None).
    """
    if isinstance(image_data, np.ndarray):
        buffer = image_data
    else:
        # Zero-copy view over the raw bytes; OpenCV only reads from it.
        buffer = np.frombuffer(image_data, dtype=np.uint8)

    decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
    if decoded is None:
        # Fail loudly so callers can report a bad input instead of silently
        # continuing without an image.
        raise ValueError('cv2.imdecode failed: not a valid encoded image')
    return decoded


def get_result_str(request_id, result_dict=None, error=None):
    """
    Build the serialized response body and status code for a request.

    Parameters
    ----------
    request_id : str
        ID of the request; always included under 'request_id'.
    result_dict : dict, optional
        Result fields merged into the response on success. Its keys may
        override 'request_id' / 'success' because it is applied last.
    error : object, optional
        Error describing a failure; must expose ``code`` and ``msg``.
        Takes precedence over ``result_dict`` when both are given.

    Returns
    -------
    tuple
        ``(result_bytes, status)``: the utf8-encoded JSON produced by
        ``json_utils.compat_dumps`` and either ``error.code`` or 200.
    """
    result = {'request_id': request_id}

    if error is not None:
        result['success'] = False
        result['error_code'] = error.code
        result['error_msg'] = error.msg
        stat = error.code
    else:
        # With neither an error nor a result this used to fall through with
        # ``stat`` unbound (UnboundLocalError); treat it as a success that
        # simply carries no payload.
        result['success'] = True
        if result_dict is not None:
            result.update(result_dict)
        stat = 200

    result_str = json_utils.compat_dumps(result).encode('utf8')
    return result_str, stat


def init_er_service(self, backend='oss', **kwargs):
    """
    Initialize the retrieval service.

    Invoked when a request sets ``function_name`` to 'init'.

    Args:
        self: The processor instance. ``backend`` (and ``oss_root`` for the
            oss backend) are stored on it, and its ``_predictor`` is
            attached to the created service.
        backend: Backend identifier, 'oss' or 'elasticsearch'.
            When backend is 'oss', the keyword arguments must contain
            'root_path' and 'oss_io_config'.
            The 'elasticsearch' backend is not implemented yet.
        root_path: Path passed to EasyRetrievalDocArray as ``root_path``.
        oss_io_config: OSS configuration passed to EasyRetrievalDocArray
            as ``oss_config``.

    Returns:
        A 3-tuple ``(result, er_service, init_status)``. ``result`` is a dict
        with a 'status' flag and an 'info' field; ``er_service`` is the
        created service, or None on failure; ``init_status`` is True only
        when the service was created.

    Raises:
        AssertionError: If ``backend`` is not a supported identifier.
    """
    self.backend = backend
    assert self.backend in ['oss', 'elasticsearch']

    if self.backend == 'oss':
        # Pop with a default so a missing key reaches the explicit failure
        # response below instead of raising KeyError.
        root_path = kwargs.pop('root_path', None)
        oss_io_config = kwargs.pop('oss_io_config', None)
        if root_path is None or oss_io_config is None:
            return {
                'status': False,
                'info': 'root_path & oss_io_config '
                        'should be set when use oss backend!'
            }, None, False

        self.oss_root = root_path
        try:
            # Imported lazily: only needed once a service is requested.
            from easy_retrieval.pai_docarray import EasyRetrievalDocArray
        except Exception as e:
            # NOTE(review): 'info' carries the Error_Import object itself,
            # as before; confirm the response serializer can encode it.
            error = Error_Import(str(e))
            return {'info': error, 'status': False}, None, False

        er_service = EasyRetrievalDocArray(root_path=root_path,
                                           oss_config=oss_io_config)
        er_service.predictor = self._predictor
        init_status = True
        return {'info': 'oss backend success!', 'status': True}, \
            er_service, init_status

    # Only 'elasticsearch' can reach this point (see the assert above).
    # NOTE(review): 'info' is a NotImplementedError instance, as before;
    # confirm the response serializer can encode it.
    error = NotImplementedError()
    return {'info': error, 'status': False}, None, False


class MyProcessor(allspark.BaseProcessor):
    """
    Retrieval processor.

    MyProcessor is an example; you can send a message like this to predict:
    curl -v http://127.0.0.1:8080/api/predict/service_name -d '2 105'
    """

    def initialize(self, unittest_config: Optional[Dict] = None):
        """
        Load the predictor model and set the processor attributes.

        This method does not need to be called by user code.

        Args:
            unittest_config: Optional overrides used by unit tests. When
                given (and non-empty) it must contain 'model_dir', 'type',
                'feature_dimension' and 'distance_metric'.

        Attributes set:
            _model_config: Dict with the predictor type, feature dimension
                and distance metric.
            model_type: Type of the model in use.
            feature_dimension: Dimension of the predictor feature vector.
            distance_metric: Distance metric used to compare features.
            _predictor: Instance of the predictor class.
            er_service: Retrieval service; None until an 'init' request.
            init: False until an 'init' request succeeds.
            attr_set_by_filelist_log: Empty dict.
        """
        model_dir = '../../model/'
        defaults = allspark.default_properties()

        # this abstract use for modelhub model_type
        # - (hyperparameter)config set,
        # in modelhub every model is unique with model_type
        # (we should know this is not a good design for easy_predict)
        modelhub_set = {
            # ------------- retrieval & tf retrieval-------------
            'image_retrieval': {
                'model_type': 'retrieval',
                'predictor_cls_name': 'TorchFeatureExtractor',
                'feature_dimension': 1536,
                'distance_metric': 'Cosine'
            },
            'faceid_retrieval': {
                'model_type': 'retrieval',
                'predictor_cls_name': 'TorchFaceFeatureExtractor',
                'feature_dimension': 512,
                'distance_metric': 'SquaredEuclidean'
            },
            'imagetext_retrieval_cn': {
                'model_type': 'retrieval',
                'predictor_cls_name': 'CLIPPredictor',
                'feature_dimension': 512,
                'distance_metric': 'Cosine',
                'model_cls': 'CLIPApp',
                'second_sequence': 'base64imagestr'
            },
            'imagetext_retrieval_en': {
                'model_type': 'retrieval',
                'predictor_cls_name': 'CLIPPredictor',
                'feature_dimension': 512,
                'distance_metric': 'Cosine',
                'model_cls': 'CLIPApp',
                'second_sequence': 'base64imagestr'
            },
            'text_retrieval': {
                'model_type': 'retrieval',
                'predictor_cls_name': 'FeatureVectorizationPredictor',
                'feature_dimension': 768,
                'distance_metric': 'Cosine',
                'model_cls': 'FeatureVectorization',
                'first_sequence': 'text'
            }
        }

        if getattr(FLAGS, 'local_debug', False):
            model_dir = FLAGS.model_dir
            model_type = FLAGS.model_type
            # Every value is stringified, matching the previous '%s'
            # formatting; json.dumps also takes care of escaping.
            debug_config = {
                'type': str(model_type),
                'predictor_cls_name': str(FLAGS.feature_name),
                'feature_dimension': str(FLAGS.feature_dimension),
                'distance_metric': str(FLAGS.distance_metric),
            }
            defaults.put(b'model.model_config',
                         json.dumps(debug_config).encode('utf8'))
            defaults.put(b'rpc.worker_threads', b'2')

        # start when UT test
        if unittest_config:
            model_dir = unittest_config['model_dir']
            ut_config = {
                'type': str(unittest_config['type']),
                'feature_dimension':
                str(unittest_config['feature_dimension']),
                'distance_metric': str(unittest_config['distance_metric']),
            }
            defaults.put(b'model.model_config',
                         json.dumps(ut_config).encode('utf8'))

        # NOTE(review): the key is read as str here but written as bytes
        # above; left unchanged - confirm allspark treats both the same.
        model_config_str = defaults.get('model.model_config',
                                        b'{}').decode('utf-8')
        logging.info('model config : %s', model_config_str)
        self._model_config = json.loads(model_config_str)
        assert 'type' in self._model_config, \
            'type is missing from model_config'
        logging.info('create model type %s', self._model_config['type'])
        self.model_type = self._model_config['type']
        assert self.model_type in modelhub_set, \
            'type does not meet the requirements'

        logging.info('model type check before torch_check:%s',
                     self.model_type)
        self.model_type = torch_model_type_check(self.model_type, model_dir)
        self._model_config['type'] = self.model_type
        logging.info('model type check after torch_check:%s',
                     self.model_type)

        if self.model_type in modelhub_set:
            logging.info('Reset by modelhub_set, type: %s', self.model_type)
            self._model_config = modelhub_set[self.model_type]
            self.model_type = self._model_config['model_type']
            self._model_config['type'] = self._model_config['model_type']

        logging.info('final model config : %s', self._model_config)

        # Read once with a default: the config may lack the key when it was
        # not reset from modelhub_set (previously a KeyError just below).
        predictor_cls_name = self._model_config.get('predictor_cls_name',
                                                    'TorchFeatureExtractor')

        if predictor_cls_name not in ('TorchFeatureExtractor',
                                      'TorchFaceFeatureExtractor'):
            # Use the first sub-directory of model_dir, if there is one.
            model_dir = os.path.abspath(model_dir)
            for file_name in os.listdir(model_dir):
                candidate = os.path.join(model_dir, file_name)
                if os.path.isdir(candidate):
                    model_dir = candidate
                    break
        elif not model_dir.endswith(('pth', 'pt')):
            # model_dir is a directory: pick the first *.pt* file found.
            model_path = glob('%s/**/*.pt*' % model_dir,
                              recursive=True) + glob('%s/*.pt*' % model_dir,
                                                     recursive=True)
            assert len(model_path) > 0, f'model not found in {model_dir}'
            model_dir = model_path[0]

        self.er_service = None
        logging.info(
            'No init service for retrieval, use must init by http-api')
        self.init = False

        self.feature_dimension = self._model_config.get(
            'feature_dimension', 2048)
        self.distance_metric = self._model_config.get('distance_metric',
                                                      'Cosine')

        if 'feature_name' in self._model_config:
            # NOTE(review): get_predictor is keyed by names such as
            # 'torch_feature_extractor', not by class names - confirm that
            # predictor_cls_name is a valid key on this path.
            predictor_cls = get_predictor(predictor_cls_name)
            self._predictor = predictor_cls(
                model_dir, self._model_config['feature_name'])
        elif predictor_cls_name == 'CLIPPredictor':
            # import when we use
            try:
                from easynlp.appzoo import CLIPApp, CLIPPredictor
            except Exception as e:
                raise Error_Import(
                    msg='import CLIPApp/CLIPPredictor error '
                    'from easynlp.appzoo ' + str(e)) from e
            second_sequence_type = self._model_config.get(
                'second_sequence', 'base64imagestr')
            self._predictor = CLIPPredictor(
                model_dir,
                model_cls=CLIPApp,
                second_sequence_type=second_sequence_type)
        elif predictor_cls_name == 'FeatureVectorizationPredictor':
            # import when we use
            try:
                from easynlp.appzoo import (FeatureVectorization,
                                            FeatureVectorizationPredictor)
            except Exception as e:
                raise Error_Import(
                    msg='import FeatureVectorization/'
                    'FeatureVectorizationPredictor '
                    'error from easynlp.appzoo ' + str(e)) from e
            first_sequence_type = self._model_config.get(
                'first_sequence', 'text')
            self._predictor = FeatureVectorizationPredictor(
                model_dir,
                model_cls=FeatureVectorization,
                first_sequence=first_sequence_type)
        else:
            # finally we should remove eval
            # NOTE(review): eval runs on a name taken from the service
            # configuration; it must never be fed request-controlled data.
            self._predictor = eval(predictor_cls_name)(model_dir)

        logging.info('Feature_extractor : % s', type(self._predictor))
        self.attr_set_by_filelist_log = {}

    def process_retrieval(self, data):
        """
        Handle one raw request for the retrieval service.

        This method does not need to be called by user code.

        The request is expected to be a JSON object with 'function_name'
        and 'function_params'. An optional base64 'image' parameter is
        decoded, turned into a feature by the predictor and passed on as
        the 'feature' parameter.

        Args:
            data (bytes): Data received from the client.

        Returns:
            The ``(result_bytes, status)`` tuple built by get_result_str.
        """

        def data_process(data):
            """
            Parse the raw request into a dict.

            If the payload looks like a JSON object it is parsed. If its
            'function_params' contain an 'image' field, the image is
            base64-decoded and stored back either as a flat NumPy array
            (when 'numpy_input' is True) or as an OpenCV-decoded image.
            Any failure is reported through an 'error' entry in the
            returned dict instead of being raised.

            Args:
                data (bytes): The data to be processed.

            Returns:
                dict: The processed data.
            """
            ret_data = {}
            strip_data = data.strip()
            try:
                if strip_data.startswith(b'{') and strip_data.endswith(b'}'):
                    try:
                        ret_data = json.loads(data)
                    except Exception as e:
                        ret_data['error'] = Error_JsonParse(str(e))

                if 'error' not in ret_data:
                    # A payload without 'function_params' raises KeyError
                    # here and is reported as a parse error below.
                    params = ret_data['function_params']
                    if 'image' in params:
                        try:
                            raw = base64.b64decode(params['image'])
                            if params.get('numpy_input') is True:
                                # np.fromstring is deprecated for binary
                                # data; frombuffer with the same default
                                # dtype plus a copy gives the same
                                # writable array.
                                params['image'] = np.frombuffer(raw).copy()
                            else:
                                params['image'] = imdecode(raw)
                        except Exception as e:
                            ret_data['error'] = Error_InputFormat(
                                'Image base64 data decode error: ' + str(e))
            except Exception as e:
                ret_data['error'] = Error_JsonParse(str(e))
            return ret_data

        request_id = str(uuid.uuid4())
        try:
            data = data_process(data)
            if 'error' in data:
                return get_result_str(request_id, error=data['error'])

            # image -> oss / feature
            img = data['function_params'].pop('image', None)
            database_name = data['function_params'].get(
                'database_name', None)
            if img is not None:
                assert (database_name is not None)
                if hasattr(self.er_service, 'db_getattr') \
                        and database_name in self.er_service.db_dict:
                    # Known database: store the image under its root path
                    # and hand the stored path on to the service.
                    oss_database_dir = self.er_service.db_getattr(
                        database_name, 'root_path')
                    save_path = os.path.join(oss_database_dir, database_name)
                    img_save_name = data['function_params'].pop(
                        'savename', None)
                    img, oss_path = get_image(img, save_path, True,
                                              img_save_name)
                    data['function_params']['data'] = {
                        'raw_data_path': oss_path
                    }
                else:
                    img, oss_path = get_image(img)
                feature = self._predictor.predict(
                    [img])[0]['feature'].reshape([1, -1])
                data['function_params']['feature'] = feature

            kwargs = data.get('function_params', None)
            func_name = data['function_name']

            if func_name == 'init':
                outcome = init_er_service(self, **data['function_params'])
                if isinstance(outcome, tuple):
                    res, self.er_service, self.init = outcome
                else:
                    # A bare result dict signals a failed init; keep the
                    # current service state untouched.
                    res = outcome
            elif not self.init:
                res = {
                    'info':
                    'Call %s before init, should init service first!' %
                    func_name,
                    'status':
                    False
                }
            elif self.er_service is None:
                # Previously left `res` unbound and surfaced as a generic
                # server error.
                res = {
                    'info': 'Retrieval service is not available, '
                    'should init service first!',
                    'status': False
                }
            else:
                db = self.er_service
                # NOTE(review): func_name comes from the request and is
                # resolved with getattr on the service object - consider
                # restricting it to an allow-list of public methods.
                if func_name == 'set_predictor':
                    if kwargs['predictor'] == self._model_config.get(
                            'predictor_cls_name'):
                        res = getattr(db, func_name)(self._predictor,
                                                     kwargs['preprocess'])
                    else:
                        # Previously left `res` unbound and surfaced as a
                        # generic server error.
                        res = {
                            'info':
                            'predictor %s does not match the loaded '
                            'predictor %s' %
                            (kwargs['predictor'],
                             self._model_config.get('predictor_cls_name')),
                            'status':
                            False
                        }
                elif isinstance(kwargs, dict):
                    res = getattr(db, func_name)(**kwargs)
                elif isinstance(kwargs, list):
                    res = getattr(db, func_name)(*kwargs)
                else:
                    res = getattr(db, func_name)()

            if not isinstance(res, dict):
                res = {'info': res}
            return get_result_str(request_id, result_dict=res)
        except Exception:
            # Top-level boundary: log the traceback and answer with a
            # generic server error that carries the request id.
            logging.info('%s %s', request_id, traceback.format_exc())
            error = Error_UnExpectedServer('request_id: %s' % request_id)
            return get_result_str(request_id, error=error)

    def process(self, data):
        """
        Process the given data.

        This method does not need to be called by user code.

        Args:
            data: The data to be processed.

        Returns:
            Whatever process_retrieval returns for ``data``.
        """
        return self.process_retrieval(data)


def run(args_unused):
    """
    Run MyProcessor with the configured number of worker threads.

    The thread count is read from the allspark default properties
    ('rpc.worker_threads', default 5).

    Args:
        args_unused: Unused arguments.

    Returns:
        None
    """
    defaults = allspark.default_properties()
    thread_num = int(defaults.get(b'rpc.worker_threads', b'5'))
    runner = MyProcessor(worker_threads=thread_num)
    runner.run()


if __name__ == '__main__':
    parser = argparse.ArgumentParser('ev eas processor local runner')
    parser.add_argument('--model_dir',
                        type=str,
                        default='/workspace/modelpth/df2_1536_easycv061.pth',
                        help='local model dir')
    parser.add_argument('--model_type',
                        type=str,
                        default='image_retrieval',
                        help='model type')
    parser.add_argument('--feature_name',
                        type=str,
                        default='TorchFeatureExtractor',
                        help='output feature name')
    parser.add_argument('--feature_dimension',
                        type=int,
                        default=1536,
                        help='output feature dimension')
    parser.add_argument('--distance_metric',
                        type=str,
                        default='Cosine',
                        help='retrieval distance metric')
    parser.add_argument('--local_debug',
                        action='store_true',
                        help='in local debug mode')

    # Module-level FLAGS is read by MyProcessor.initialize.
    FLAGS = parser.parse_args()
    run(FLAGS)