基于 LangChain 自定义 Embeddings
在 LangChain 中支持 OpenAI、LLAMA 等大模型 Embeddings 的调用接口,不过没有内置所有大模型,但是允许用户自定义 Embeddings 类型。 接下来以 ZhipuAI 为例,基于 LangChain 自定义 Embeddings。
设计思路
要实现自定义 Embeddings,需要定义一个自定义类继承自 LangChain 的 Embeddings 基类,然后定义三个函数
_embed
: 接受一个字符串,并返回一个存放 Embeddings 的 List[float],即模型的核心调用
embed_query
: 用于对单个字符串 (query) 进行 embedding
embed_documents
: 用于对字符串列表 (documents) 进行 embedding
第三方库 1 2 3 4 5 6 7 8 from __future__ import annotationsimport loggingfrom typing import Any , Dict , List , Optional from langchain.embeddings.base import Embeddingsfrom langchain.pydantic_v1 import BaseModel, root_validatorfrom langchain.utils import get_from_dict_or_env
自定义 Embedding ZhipuAIEmbeddings 定义一个继承自 Embeddings 类的自定义 Embeddings 类:
1 2 3 4 5 class ZhipuAIEmbeddings (BaseModel, Embeddings): """`Zhipuai Embeddings` embedding models.""" zhipuai_api_key: Optional [str ] = None """Zhipuai application apikey"""
root_validator
接收一个函数作为参数,该函数包含需要校验的逻辑。函数应该返回一个字典,其中包含经过校验的数据。如果校验失败,则抛出一个 ValueError
异常。
装饰器 root_validator
确保导入了相关的包和并配置了相关的 API_Key 这里取巧,在确保导入 zhipuai model 后直接将 zhipuai.model_api
绑定到 client 上,减少和其他 Embeddings 类的差异。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 @root_validator() def validate_environment (cls, values: Dict ) -> Dict : """ 验证环境变量或配置文件中的zhipuai_api_key是否可用。 Args: values (Dict): 包含配置信息的字典,必须包含 zhipuai_api_key 的字段 Returns: values (Dict): 包含配置信息的字典。如果环境变量或配置文件中未提供 zhipuai_api_key,则将返回原始值;否则将返回包含 zhipuai_api_key 的值。 Raises: ValueError: zhipuai package not found, please install it with `pip install zhipuai` """ values["zhipuai_api_key" ] = get_from_dict_or_env( values, "zhipuai_api_key" , "ZHIPUAI_API_KEY" , ) try : import zhipuai zhipuai.api_key = values["zhipuai_api_key" ] values["client" ] = zhipuai.model_api except ImportError: raise ValueError( "Zhipuai package not found, please install it with " "`pip install zhipuai`" ) return values
Override _embed 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 def _embed (self, texts: str ) -> List [float ]: """ 生成输入文本的 embedding。 Args: texts (str): 要生成 embedding 的文本。 Return: embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表。 """ try : resp = self.client.invoke( model="text_embedding" , prompt=texts ) except Exception as e: raise ValueError(f"Error raised by inference endpoint: {e} " ) if resp["code" ] != 200 : raise ValueError( "Error raised by inference API HTTP code: %s, %s" % (resp["code" ], resp["msg" ]) ) embeddings = resp["data" ]["embedding" ] return embeddings
Override embed_documents 1 2 3 4 5 6 7 8 9 10 def embed_documents (self, texts: List [str ] ) -> List [List [float ]]: """ 生成输入文本列表的 embedding。 Args: texts (List[str]): 要生成 embedding 的文本列表. Returns: List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。 """ return [self._embed(text) for text in texts]
Override embed_query embed_query
是对单个文本计算 embedding 的方法,因为我们已经定义好对文档列表计算 embedding 的方法 embed_documents
了,这里可以直接将单个文本组装成 list 的形式传给 embed_documents
。
1 2 3 4 5 6 7 8 9 10 11 12 def embed_query (self, text: str ) -> List [float ]: """ 生成输入文本的 embedding。 Args: text (str): 要生成 embedding 的文本。 Return: List [float]: 输入文本的 embedding,一个浮点数值列表。 """ resp = self.embed_documents([text]) return resp[0 ]
本文作者:jujimeizuo 本文地址 : https://blog.jujimeizuo.cn/2024/01/29/custom-embeddings-based-on-langchain/ 本博客所有文章除特别声明外,均采用 CC BY-SA 3.0 协议。转载请注明出处!