import logging
import os
import time
import tracemalloc
import warnings
from functools import wraps
from typing import Dict, List, Optional, Sequence, TypeVar, Union, overload
import numpy as np
import scipy
import torch
from docarray import Document, DocumentArray
from docarray.aliyun.ossio import ENVOSSIO, get_oss_config
OSSPREFIX = 'oss://'
DAPOSTFIX = '.dabin'
DOCARRAY_PREPOCESS_LIST = [
'load_uri_to_image_tensor',
'load_uri_to_base64str', # pai-specific
]
EMBEDDING_ATTR_LIST = ['uri', 'text', 'base64imagestr']
ArrayType = TypeVar(
'ArrayType',
np.ndarray,
scipy.sparse.spmatrix,
torch.Tensor,
Sequence[float],
)
def assert_database_exists(f):
@wraps(f)
def decorated(*args, **kwargs):
self = args[0]
if len(args) > 1:
database_name = args[1]
else:
database_name = kwargs.get('database_name')
if database_name is None:
return f(*args, **kwargs)
if self.docarray_dict.get(database_name, None) is not None:
return f(*args, **kwargs)
else:
fname = str(f.__name__)
return {fname: 'database %s is not exits' % database_name}
return decorated
def assert_predictor_exists(f):
@wraps(f)
def decorated(*args, **kwargs):
self = args[0]
if getattr(self, 'predictor', None) is not None:
return f(*args, **kwargs)
else:
raise Exception('EasyRetrievalDocArray predictor is not exits,\
should call set_predictor() before')
return decorated
def check_all_args_islist(*args, islist=True):
# check type
type_list = [type(i) for i in args]
tlength = -1
for k in args:
if k is not None:
if islist:
if type(k) is not list:
return False, type_list
else:
if tlength == -1:
tlength = len(k)
elif tlength != len(k):
return False, type_list
else:
if type(k) is list:
return False, type_list
return True, type_list
def func_memtimer(function):
@wraps(function)
def function_timer(*args, **kwargs):
print('[Function: {name} start...]'.format(name=function.__name__))
tracemalloc.start()
t0 = time.time()
result = function(*args, **kwargs)
t1 = time.time()
current, peak = tracemalloc.get_traced_memory()
current /= (1024 * 1024)
peak /= (1024 * 1024)
print('[Function: {name} finished, spent time: \
{time:.2f}ms, memory peak: {peak:.2f} \
MB, memory now: {curr:.2f} MB]'.format(
name=function.__name__,
time=(t1 - t0) * 1000,
peak=peak,
curr=current))
print(
'---------------------------------------------------------------')
return result
return function_timer
def format_docarray_return_list(da: DocumentArray):
return [{
'uri': d.uri,
'text': d.text,
'group_id': d.tags['group_id'],
'intra_id': d.tags['intra_id']
} for d in da]
def format_match_return_list(da: Document, scores_key: str):
return [{
'scores': d.scores[scores_key].value,
'uri': d.uri,
'text': d.text,
'group_id': d.tags['group_id'],
'intra_id': d.tags['intra_id']
} for d in da.matches]
class EasyRetrievalDocArray(object):
"""
初始化EasyRetrievalDocArray需要配置oss_config和root_path。
初始化后,实例化对象可用于添加、删除、获取、加载和保存EasyRetrievalDocArray数据库。
拥有EasyRetrievalDocArray数据库后,您可以添加、删除、获取和搜索指定数据库中的数据。
"""
@overload
def __init__(
self,
root_path: Optional[str],
):
# """Create an EasyRetrievalDocArray with root_path."""
...
@overload
def __init__(
self,
root_path: Optional[str],
oss_config: Optional[Dict[str, str]] = None,
):
# """Create an EasyRetrievalDocArray with oss_config."""
...
@overload
def __init__(
self,
root_path: Optional[str],
oss_config_file: Optional[str] = None,
):
# """Create an EasyRetrievalDocArray with oss_config_file."""
...
def __init__(
self,
root_path: Optional[str] = None,
oss_config: Optional[Dict[str, str]] = None,
oss_config_file: Optional[str] = None,
):
"""使用oss_config、root_path和oss_config_file创建一个EasyRetrievalDocArray。
oss_config和oss_config_file使用其中一个即可。
Args:
root_path (Optional[str]):
保存、加载和删除EasyRetrievalDocArray数据库的路径。
oss_config (Optional[Dict[str, str]]):
包含OSS配置信息的字典,用来设置oss.
oss_config_file (Optional[str]):
包含OSS配置信息的文件的路径。
Raises:
Exception: 如果root_path是oss地址但未提供oss_config或oss_config_file。
Returns:
无
"""
self.root_path = root_path
if root_path is not None:
if root_path.startswith(OSSPREFIX) and \
(oss_config is None and oss_config_file is None):
raise Exception('EasyRetrievalDocArray with oss root_path\
should init with oss_config')
if oss_config is not None:
self.oss_config = oss_config
ENVOSSIO.access_oss(**oss_config)
if oss_config_file is not None:
self.oss_config = get_oss_config(oss_config_file)
ENVOSSIO.access_oss_byconfig(oss_config_file)
self.docarray_dict = {}
self.load()
self.predictor = None
self.preprocess = None
[docs] def set_root_path(self, root_path: str):
"""设置EasyRetrievalDocArray的根路径(root_path)以保存、删除和加载数据库。
当您完成初始化后,想要修改存储路径可以调用这个函数。
Args:
root_path (str): 要设置的根路径(root_path).
Returns:
无
"""
if root_path.startswith(OSSPREFIX):
if ENVOSSIO.is_accessed('ENVOSSIO'):
target_path = os.path.join(root_path, '*' + DAPOSTFIX)
list_dir = ENVOSSIO.glob(target_path)
logging.info('root_path contains \
database {} '.format(list_dir))
self.root_path = root_path
return
[docs] def set_oss_config(self, oss_config: Optional[Dict[str, str]]):
"""为EasyRetrievalDocArray配置oss信息.
当您完成初始化后,想要修改oss配置信息可以调用这个函数。
Args:
oss_config (Optional[Dict[str, str]]):
包含OSS配置信息的dict。
Returns:
无
"""
ENVOSSIO.access_oss(**oss_config)
return
[docs] def set_predictor(self,
predictor,
preprocess: str = 'load_uri_image_to_tensor'):
"""为EasyRetrievalDocArray设置predictor,使其获得特征提取的能力。
这个函数默认会被调用。
Args:
predictor: 要设置的predictor的名称。
可选的predictor:
TorchFeatherExtractor,TorchFaceFeatherExtractor,
imagetext_retrieval_en, imagetext_retrieval_cn,
FeatureVectorizationPredictor
preprocess (str):
要使用的预处理函数。针对输入数据是uri图片数据。
必须是DOCARRAY_PREPOCESS_LIST存在的选项
[‘load_uri_to_image_tensor’,’load_uri_to_base64str’]。
Raises:
AssertionError: 如果preprocess不在DOCARRAY_PREPOCESS_LIST中。
[‘load_uri_to_image_tensor’,’load_uri_to_base64str’]
Returns:
无
"""
self.predictor = predictor
assert preprocess in DOCARRAY_PREPOCESS_LIST, 'set_predictor\
preprocess must in {}'.format(DOCARRAY_PREPOCESS_LIST)
self.preprocess = preprocess
return
@overload
def load(self):
"""从root_path加载EasyRetrievalDocArray数据库。"""
...
[docs] def load(self, path: Optional[str] = None):
"""从指定的路径加载EasyRetrievalDocArray数据库。
Args:
path (可选[str]):
加载EasyRetrievalDocArray数据库的路径。如果没有指定,将使用root_path。
Returns:
返回加载的EasyRetrievalDocArray数据库信息。
"""
if path is None:
path = self.root_path
if type(path) is str:
target_path = os.path.join(path, '*' + DAPOSTFIX)
list_dir = ENVOSSIO.glob(target_path)
# get all docarray object binaryfile in root_path, load them
da_bin = {}
for filepath in list_dir:
if filepath.endswith(DAPOSTFIX):
da_name = filepath.split('/')[-1][:-6]
da_bin[da_name] = da_bin
if filepath.startswith(OSSPREFIX):
self.docarray_dict[da_name] = DocumentArray.pull_oss(
filepath)
else:
self.docarray_dict[da_name] = DocumentArray.load(
filepath)
self.save()
return {'load': self.get()['get']}
@overload
def save(self):
# """Load an EasyRetrievalDocArray to root_path"""
...
[docs] def save(self, path: Optional[str] = None):
"""将EasyRetrievalDocArray数据库保存到指定的路径。
Args:
path (Optional[str]):
保存EasyRetrievalDocArray数据库的路径。如果没有指定,将使用root_path。
Returns:
返回成功保存的消息。
"""
if path is None:
path = self.root_path
for da_name, da in self.docarray_dict.items():
if type(path) is str:
da_output_path = os.path.join(path, da_name + DAPOSTFIX)
if path.startswith(OSSPREFIX):
da.push_oss(da_output_path)
else:
da.save(da_output_path)
return {'save': 'Done'}
# -------------------------------------------------------------------------------#
@overload
@assert_database_exists
def get(
self,
database_name: str,
):
# """get EasyRetrievalDocArray dababase info by database_name"""
...
@overload
def get(
self,
get_all: bool,
):
# """get EasyRetrievalDocArray all database info"""
...
[docs] @assert_database_exists
def get(self,
database_name: Optional[str] = None,
get_all: Optional[bool] = True):
"""获取指定名称的EasyRetrievalDocArray数据库
或所有EasyRetrievalDocArray数据库。
Args:
database_name (可选[str]):
要获取的数据库的名称。
get_all (可选[bool]):
指示是否返回所有数据库的标志。默认值为True。
Returns:
返回获取到的数据库的具体信息。
"""
if get_all:
return {
'get': [
self.docarray_dict.get(k)._get_raw_summary()
for k in self.docarray_dict.keys()
]
}
else:
return {
'get':
self.docarray_dict.get(database_name)._get_raw_summary()
}
[docs] def add(
self,
database_name: str,
feature_dim: int = 1536,
# storage=None, only memory storage support query easily
):
"""添加一个新的空EasyRetrievalDocArray数据库。
Args:
database_name (str):
新数据库的名称。
feature_dim (int):
数据库中特征向量的大小。主要跟predictor相关。默认值为1536。
Returns:
返回添加数据库的具体信息。
"""
# dummy_doc = [
# Document(
# uri='dummy',
# text='dummy',
# tags =dict(
# group_id='-1',
# intra_id=-1,
# feature_dim=feature_dim
# )
# )
# for _ in range(1)
# ]
self.docarray_dict[database_name] = DocumentArray()
self.save()
return {'add': self.get(database_name)['get']}
[docs] @assert_database_exists
def delete(self, database_name: str):
"""
删除加载数据库中的指定数据库。如果root path是oss地址,也会删除
oss地址的数据库文件。
Args:
database_name (str): 要删除的数据库名称。
Raises:
Exception:
如果root_path是oss地址,并且ENVOSSIO未能从root path中删除数据库文件。
即删除失败。
Returns:
提示删除库不再存在。(database (database_name) is not exits.)
"""
del self.docarray_dict[database_name]
if type(self.root_path) is str:
output_path = os.path.join(self.root_path,
database_name + DAPOSTFIX)
if ENVOSSIO.exists(output_path):
try:
ENVOSSIO.remove(output_path)
except BaseException:
raise Exception(
'ENVOSSIO remove {output_path} failed'.format(
output_path=output_path))
self.save()
return {'delete': self.get(database_name)}
# ------------------------------------------------------------------------------------#
[docs] @assert_database_exists
@func_memtimer
def db_set(
self,
database_name: str,
group_id: Union[str, List[str]],
uri: Optional[Union[str, List[str]]] = None,
text: Optional[Union[str, List[str]]] = None,
base64imagestr: Optional[Union[str, List[str]]] = None,
embedding_attr: Optional[str] = None,
embedding: Optional[ArrayType] = None,
intra_id: Optional[Union[int, List[int]]] = [0],
query_dict: Optional[Dict[str, dict]] = None,
device: Optional[str] = None,
):
"""
往指定的EasyRetrievalDocArray数据库中注册数据。
如果使用了embedding字段,则uri,text以及base64imagestr三个字段可以省略。
如果没使用embedding字段,则需要使用uri,text以及base64imagestr三个字段。
并可以选择在embedding_attr告知使用的字段。
(group_id,intra_id)组成了数据在数据库中的坐标信息。
Args:
database_name (str):
要注册进数据的数据库名称。
group_id (Union[str, List[str]]):
要注册的数据的group_id。
uri (Optional[Union[str, List[str]]], 可选):
要注册的数据的URI。
text (Optional[Union[str, List[str]]], 可选):
要注册的数据文本描述。
base64imagestr (Optional[Union[str, List[str]]], 可选):
要注册的数据的base64image字符串。
embedding_attr (Optional[str], 可选):
要注册数据的类型。uri,text或者base64imagestr。
embedding (Optional[ArrayType], 可选):
要注册的数据的embedding。
intra_id (Optional[Union[int, List[int]]], 可选):
同一个group_id下的内部ID。
query_dict (Optional[Dict[str, dict]], 可选):
要使用的查询的词典。
device (Optional[str], 可选):
用于生成要注册的数据的embedding的设备。
可选device : cpu或者gpu。
Raises:
AssertionError:
如果输入参数[group_id/uri/text/base64imagestr]
个数不同或输入形式不是列表。
AssertionError:
如果未指定embedding_attr,并且输入参数
[uri,text,base64imagestr]三者都是空值。
AssertionError:
如果embedding_attr不在EMBEDDING_ATTR_LIST中。需要是三者之一。
['uri', 'text', 'base64imagestr']
Returns:
返回注册数据的uri、文本描述、组ID和内部id信息。
"""
if query_dict is None:
query_dict = {'tags__group_id': {'$eq': group_id}}
if intra_id != 0:
query_dict['tags__intra_id'] = {'$eq': intra_id}
query_da = self.docarray_dict[database_name].find(query_dict)
# check polymorphism of
all_not_list, type_list = check_all_args_islist(group_id,
uri,
text,
base64imagestr,
islist=False)
is_all_list, type_list = check_all_args_islist(group_id,
uri,
text,
base64imagestr,
islist=True)
assert is_all_list ^ all_not_list, 'db_set : input param \
[group_id/uri/text/base64imagestr] should be all same \
length List or Not a List at all, now is {type_dict}'.format(
type_dict=type_list)
# refine to DocumentArray init param
if all_not_list:
group_id = [group_id]
uri = [uri]
text = [text]
base64imagestr = [base64imagestr]
if embedding is not None and len(embedding.shape) == 1:
embedding = [embedding]
if type(intra_id) is not list:
intra_id = [intra_id]
intra_id += (len(group_id) - len(intra_id)) * [0]
# init insert DocumentArray
tmp_doc = [
Document(
uri=uri[tidx] if uri is not None else None,
text=text[tidx] if text is not None else None,
tags=dict(
group_id=group_id[tidx],
intra_id=intra_id[tidx],
),
embedding=embedding[tidx] if embedding is not None else None,
) for tidx in range(len(uri))
]
if base64imagestr is not None:
for tidx, tdoc in enumerate(tmp_doc):
setattr(tdoc, 'base64imagestr', base64imagestr[tidx])
insert_da = DocumentArray(tmp_doc)
# generate embedding
if device is None:
device = 'gpu' if torch.cuda.is_available() else 'cpu'
if embedding is None:
# traits embedding_attr from input param
if embedding_attr is None:
input_attr_num = 0
for k in EMBEDDING_ATTR_LIST:
if eval(k) is not None and eval(k)[0] is not None:
embedding_attr = k
input_attr_num += 1
if input_attr_num != 1:
embedding_attr = None
assert embedding_attr is not None and embedding_attr \
in EMBEDDING_ATTR_LIST, 'db_set : with no embedding \
input and multi attr[uri/text/base64imagestr] input,\
must set embedding_attr[uri/text/base64imagestr]\
to call generate_embedding'
self.generate_embedding(insert_da, embedding_attr, device)
if len(query_da) == 0:
self.docarray_dict[database_name].extend(insert_da)
if len(query_da) >= 1:
warning_info = 'call db_set to set exist document id={ids} \
in database {db_name}'.format(ids=query_da[:, 'id'],
db_name=database_name)
warnings.warn(warning_info)
if len(tmp_doc) > 1:
assert len(query_da) == len(
tmp_doc), 'call db_set to set exist document id={ids} in \
database {db_name} with {input_num} input doc'.format(
ids=query_da[:, 'id'],
db_name=database_name,
input_num=len(tmp_doc))
self.docarray_dict[database_name][query_da[:, 'id']] = tmp_doc
else:
self.docarray_dict[database_name][
query_da[:, 'id']] = tmp_doc * len(query_da)
query_da = self.docarray_dict[database_name].find(query_dict)
return {'db_set': format_docarray_return_list(query_da)}
[docs] @assert_database_exists
@func_memtimer
def db_get(
self,
database_name: str,
group_id: Optional[str] = None,
intra_id: Optional[int] = None,
query_dict: Optional[Dict[str, dict]] = None,
):
"""
从指定的EasyRetrievalDocArray数据库获取数据。
Args:
database_name (str):
要获取数据的数据库名称。
group_id (Optional[str], 可选):
要获取的数据的group_id。
intra_id (Optional[int], 可选):
要获取的数据group下的intra_id。
query_dict (Optional[Dict[str, dict]], 可选):
要使用的查询的词典。
Returns:
返回所获取数据的uri、文本描述、组ID和内部ID信息。
"""
if query_dict is None:
query_dict = {'tags__group_id': {'$eq': group_id}}
if intra_id is not None:
query_dict['tags__intra_id'] = {'$eq': intra_id}
query_da = self.docarray_dict[database_name].find(query_dict)
return {'db_get': format_docarray_return_list(query_da)}
[docs] @assert_database_exists
@func_memtimer
def db_delete(
self,
database_name: str,
group_id: Optional[str] = None,
intra_id: Optional[int] = None,
query_dict: Optional[Dict[str, dict]] = None,
):
"""
从指定的EasyRetrievalDocArray数据库中删除数据。
Args:
database_name (str):
要删除数据的数据库名称。
group_id (Optional[str], 可选):
要删除的数据的group_id。
intra_id (Optional[int], 可选):
要删除的数据group下的intra_id。
query_dict (Optional[Dict[str, dict]], 可选):
要使用的查询的词典。
Returns:
返回已删除数据的uri、文本描述、组ID和内部ID信息。
"""
if query_dict is None:
query_dict = {'tags__group_id': {'$eq': group_id}}
if intra_id is not None:
query_dict['tags__intra_id'] = {'$eq': intra_id}
query_da = self.docarray_dict[database_name].find(query_dict)
delete_info = format_docarray_return_list(query_da)
delete_ids = query_da[:, 'id']
if len(delete_ids) > 0:
del self.docarray_dict[database_name][delete_ids]
return {'db_delete': delete_info}
[docs] @assert_database_exists
@func_memtimer
def db_search(
self,
database_name: str,
uri: Union[str, List[str]] = None,
text: Union[str, List[str]] = None,
base64imagestr: Union[str, List[str]] = None,
embedding_attr: Optional[str] = None,
embedding: Optional[ArrayType] = None,
search_topk: Optional[int] = 10,
metric: Optional[str] = 'cosine',
device: Optional[str] = None,
):
"""
从指定的EasyRetrievalDocArray数据库中搜索数据。
如果使用了embedding字段,则uri,text以及base64imagestr三个字段可以省略。
如果没使用embedding字段,则需要使用uri,text以及base64imagestr三个字段。
并可以选择在embedding_attr告知使用的字段。
Args:
database_name (str):
要搜索数据的数据库名称。
uri (Union[str, List[str]], 可选):
要搜索的数据的URI。
text (Union[str, List[str]], 可选):
要搜索的数据文本描述。
base64imagestr (Union[str, List[str]], 可选):
要搜索的数据的base64image字符串。
embedding_attr (Optional[str], 可选):
要搜索数据的类型。uri,text或者base64imagestr。
embedding (Optional[ArrayType], 可选):
要搜索的数据的embedding。
search_topk (Optional[int], 可选):
检索返回的记录数目。
metric (Optional[str], 可选): 用于搜索的指标。
可选metric : 'cosine','SquaredEuclidean'
device (Optional[str], 可选):用于搜索的设备。
可选device : cpu,gpu。
Raises:
AssertionError:
如果输入参数[group_id/uri/text/base64imagestr]
个数不同或输入形式不是列表。
AssertionError:
如果未指定embedding_attr,并且输入参数
[uri,text,base64imagestr]三者都是空值。
AssertionError:
如果embedding_attr不在EMBEDDING_ATTR_LIST中。需要是三者之一。
['uri', 'text', 'base64imagestr']
Returns:
返回搜索得到数据的相似性分数、uri、文本描述、组ID和内部ID信息。
"""
# check polymorphism of
all_not_list, type_list = check_all_args_islist(uri,
text,
base64imagestr,
islist=False)
is_all_list, type_list = check_all_args_islist(uri,
text,
base64imagestr,
islist=True)
input_all_none = (uri is None) and (text is None) and (base64imagestr
is None)
if not input_all_none:
assert is_all_list ^ all_not_list, 'db_search : \
input param [uri/text/base64imagestr] should \
be all same length List or Not a List at \
all, now is {type_dict}'.format(type_dict=type_list)
# refine to DocumentArray init param
if all_not_list:
uri = [uri]
text = [text]
base64imagestr = [base64imagestr]
if embedding is not None and len(embedding.shape) == 1:
embedding = [embedding]
# init insert DocumentArray
tmp_doc = [
Document(
uri=uri[tidx] if uri is not None else None,
text=text[tidx] if text is not None else None,
embedding=embedding[tidx] if embedding is not None else None,
) for tidx in range(len(uri))
]
if base64imagestr is not None:
for tidx, tdoc in enumerate(tmp_doc):
setattr(tdoc, 'base64imagestr', base64imagestr[tidx])
query_da = DocumentArray(tmp_doc)
if device is None:
device = 'gpu' if torch.cuda.is_available() else 'cpu'
if embedding is None:
if embedding_attr is None:
input_attr_num = 0
for k in EMBEDDING_ATTR_LIST:
if eval(k) is not None and eval(k)[0] is not None:
embedding_attr = k
input_attr_num += 1
if input_attr_num != 1:
embedding_attr = None
assert embedding_attr is not None and embedding_attr \
in EMBEDDING_ATTR_LIST, 'db_search : with no\
embedding input, must set embedding_attr[uri/\
text/base64imagestr] to call generate_embedding'
self.generate_embedding(query_da, embedding_attr, device)
assert getattr(
query_da, 'embeddings', None
) is not None, 'db_search build temperal query_da must has attr \
embeddings and embeddings not be None'
query_da.match(darray=self.docarray_dict[database_name],
metric=metric,
limit=search_topk,
device=device)
res = []
for q in query_da:
res.append(format_match_return_list(q, metric))
return {'db_search': res}
# --------------------------------------------------------------------------#
@assert_predictor_exists
def generate_embedding(self,
da: DocumentArray,
embedding_attr: str,
device: Optional[str] = 'cpu'):
# """
# Generates embeddings for a DocumentArray.
# Args:
# da (DocumentArray):
# The DocumentArray to generate embeddings for.
# embedding_attr (str):
# The attribute to use for generating embeddings.
# device (Optional[str], optional):
# The device to use for generating embeddings.
# Returns:
# Call different methods according to
# the data format to generate embedding.
# """
return getattr(
self, 'generate_embedding_{embedding_attr}'.format(
embedding_attr=embedding_attr))(da, device)
def generate_embedding_text(self, da, device):
# """
# Generates embeddings for a DocumentArray using text.
# Args:
# da (DocumentArray): The DocumentArray to generate embeddings for.
# device (str): The device to use for generating embeddings.
# Returns:
# None
# """
da.embed(self.predictor, device=device)
return
def generate_embedding_uri(self, da, device):
# """
# Generates embeddings for a DocumentArray using URI.
# Args:
# da (DocumentArray): The DocumentArray to generate embeddings for.
# device (str): The device to use for generating embeddings.
# Returns:
# None
# """
da.apply(lambda d: getattr(d, self.preprocess)(), num_worker=10)
da.embed(self.predictor, device=device)
return
def generate_embedding_base64imagestr(self, da, device):
# """
# Generates embeddings for a DocumentArray using base64imagestr.
# Args:
# da (DocumentArray): The DocumentArray to generate embeddings for.
# device (str): The device to use for generating embeddings.
# Returns:
# None
# """
da.embed(self.predictor, device=device)
return