基于 LangChain 自定义 Embeddings

基于 LangChain 自定义 Embeddings

在 LangChain 中支持 OpenAI、LLAMA 等大模型 Embeddings 的调用接口,不过没有内置所有大模型,但是允许用户自定义 Embeddings 类型。 接下来以 ZhipuAI 为例,基于 LangChain 自定义 Embeddings。

设计思路

  • 要实现自定义 Embeddings,需要定义一个自定义类继承自 LangChain 的 Embeddings 基类,然后定义三个函数
    • _embed: 接受一个字符串,并返回一个存放 Embeddings 的 List[float],即模型的核心调用
    • embed_query: 用于对单个字符串 (query) 进行 embedding
    • embed_documents: 用于对字符串列表 (documents) 进行 embedding

第三方库

1
2
3
4
5
6
7
8
from __future__ import annotations

import logging
from typing import Any, Dict, List, Optional

from langchain.embeddings.base import Embeddings
from langchain.pydantic_v1 import BaseModel, root_validator
from langchain.utils import get_from_dict_or_env

自定义 Embedding

ZhipuAIEmbeddings

定义一个继承自 Embeddings 类的自定义 Embeddings 类:

1
2
3
4
5
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""`Zhipuai Embeddings` embedding models."""

zhipuai_api_key: Optional[str] = None
"""Zhipuai application apikey"""

root_validator 接收一个函数作为参数,该函数包含需要校验的逻辑。函数应该返回一个字典,其中包含经过校验的数据。如果校验失败,则抛出一个 ValueError 异常。

装饰器 root_validator 确保导入了相关的包和并配置了相关的 API_Key 这里取巧,在确保导入 zhipuai model 后直接将 zhipuai.model_api 绑定到 client 上,减少和其他 Embeddings 类的差异。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
验证环境变量或配置文件中的zhipuai_api_key是否可用。

Args:

values (Dict): 包含配置信息的字典,必须包含 zhipuai_api_key 的字段
Returns:

values (Dict): 包含配置信息的字典。如果环境变量或配置文件中未提供 zhipuai_api_key,则将返回原始值;否则将返回包含 zhipuai_api_key 的值。
Raises:

ValueError: zhipuai package not found, please install it with `pip install
zhipuai`
"""
values["zhipuai_api_key"] = get_from_dict_or_env(
values,
"zhipuai_api_key",
"ZHIPUAI_API_KEY",
)

try:
import zhipuai
zhipuai.api_key = values["zhipuai_api_key"]
values["client"] = zhipuai.model_api

except ImportError:
raise ValueError(
"Zhipuai package not found, please install it with "
"`pip install zhipuai`"
)
return values

Override _embed

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def _embed(self, texts: str) -> List[float]:
"""
生成输入文本的 embedding。

Args:
texts (str): 要生成 embedding 的文本。

Return:
embeddings (List[float]): 输入文本的 embedding,一个浮点数值列表。
"""
try:
resp = self.client.invoke(
model="text_embedding",
prompt=texts
)
except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")

if resp["code"] != 200:
raise ValueError(
"Error raised by inference API HTTP code: %s, %s"
% (resp["code"], resp["msg"])
)
embeddings = resp["data"]["embedding"]
return embeddings

Override embed_documents

1
2
3
4
5
6
7
8
9
10
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
生成输入文本列表的 embedding。
Args:
texts (List[str]): 要生成 embedding 的文本列表.

Returns:
List[List[float]]: 输入列表中每个文档的 embedding 列表。每个 embedding 都表示为一个浮点值列表。
"""
return [self._embed(text) for text in texts]

Override embed_query

embed_query 是对单个文本计算 embedding 的方法,因为我们已经定义好对文档列表计算 embedding 的方法 embed_documents 了,这里可以直接将单个文本组装成 list 的形式传给 embed_documents

1
2
3
4
5
6
7
8
9
10
11
12
def embed_query(self, text: str) -> List[float]:
"""
生成输入文本的 embedding。

Args:
text (str): 要生成 embedding 的文本。

Return:
List [float]: 输入文本的 embedding,一个浮点数值列表。
"""
resp = self.embed_documents([text])
return resp[0]

本文作者:jujimeizuo
本文地址https://blog.jujimeizuo.cn/2024/01/29/custom-embeddings-based-on-langchain/
本博客所有文章除特别声明外,均采用 CC BY-SA 3.0 协议。转载请注明出处!