LLM大语言模型RAG核心代码


1. RAG系统构建的核心代码整理

Spring Boot 脚手架下基于Langchain4j的RAG流程核心代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
/**
* 处理文档上传请求
* 该接口用于上传文档,并将其与特定的知识库关联
*
* @param file 用户上传的文件
* @param knowledgeId 知识库的唯一标识符
* @return 返回一个表示操作结果的响应对象
*/
@PostMapping("/docs/{knowledgeId}")
@SaCheckPermission("aigc:embedding:docs")
public R docs(MultipartFile file, @PathVariable String knowledgeId) {
// 获取当前用户的ID,并将其转换为字符串形式
String userId = String.valueOf(AuthUtil.getUserId());

// 将用户上传的文件保存到云存储,并获取保存后的文件信息
AigcOss oss = aigcOssService.upload(file, userId);

// 创建一个文档对象,用于存储上传文档的相关信息
AigcDocs data = new AigcDocs()
.setName(oss.getOriginalFilename()) // 设置文档的原始名称
.setSliceStatus(false) // 设置文档切片状态为未切片
.setUrl(oss.getUrl()) // 设置文档的访问URL
.setSize(file.getSize()) // 设置文档的大小
.setType(EmbedConst.ORIGIN_TYPE_UPLOAD) // 设置文档的类型为上传类型
.setKnowledgeId(knowledgeId); // 设置文档所属的知识库ID

// 将文档信息添加到知识库中
aigcKnowledgeService.addDocs(data);

// 提交一个异步任务,用于处理文档的嵌入操作
TaskManager.submitTask(userId, Executors.callable(() -> {
embeddingService.embedDocsSlice(data, oss.getUrl());
}));

// 返回一个表示操作成功的响应对象
return R.ok();
}


/**
* 重写embeddingDocs方法,用于处理文档向量解析
* 该方法接收一个ChatReq对象作为参数,其中包含了知识ID和文档名称等信息
* 返回一个EmbeddingR对象的列表,每个对象包含文档段的向量ID和文本
*
* @param req ChatReq对象,包含知识库ID和文档名称等信息
* @return List<EmbeddingR> 文档向量解析结果列表
*/
@Override
public List<EmbeddingR> embeddingDocs(ChatReq req) {
// 日志记录文档向量解析开始信息
log.info(">>>>>>>>>>>>>> Docs文档向量解析开始,KnowledgeId={}, DocsName={}", req.getKnowledgeId(), req.getDocsName());

// 从URL加载文档,并使用Apache Tika解析器进行解析
Document document = UrlDocumentLoader.load(req.getUrl(), new ApacheTikaDocumentParser());

// 在文档元数据中添加知识ID和文档名称
document.metadata().put(EmbedConst.KNOWLEDGE, req.getKnowledgeId()).put(EmbedConst.FILENAME, req.getDocsName());

// 创建自定义清理器,用于清理文档中的HTML标签和多余空格
DocumentTransformer htmlCleaner = doc -> {
// 定义匹配HTML标签的正则表达式
Pattern htmlPattern = Pattern.compile("<[^>]+>");
// 定义匹配连续空格的正则表达式
Pattern spacePattern = Pattern.compile("\\s+");

// 清理HTML标签
String cleanedText = htmlPattern.matcher(doc.text()).replaceAll("");
// 合并连续空格为单个空格
cleanedText = spacePattern.matcher(cleanedText).replaceAll(" ");

// 更新元数据,标记文档已清理
Metadata metadata = doc.metadata();
metadata.put("cleaned", "true");

// 返回清理后的文档
return Document.from(cleanedText, metadata);
};

// 使用自定义清理器清理文档
document = htmlCleaner.transform(document);

// 初始化用于存储文档向量解析结果的列表
List<EmbeddingR> list = new ArrayList<>();

try {
// 创建文档分段器
DocumentSplitter splitter = EmbeddingProvider.splitter();
// 将文档分段
List<TextSegment> segments = splitter.split(document);

// 获取嵌入模型
EmbeddingModel embeddingModel = embeddingProvider.getEmbeddingModel(req.getKnowledgeId());
// 获取嵌入存储
EmbeddingStore<TextSegment> embeddingStore = embeddingProvider.getEmbeddingStore(req.getKnowledgeId());

// 对所有分段进行嵌入,并获取嵌入结果
List<Embedding> embeddings = embeddingModel.embedAll(segments).content();
// 将嵌入结果存储,并获取存储的ID列表
List<String> ids = embeddingStore.addAll(embeddings,segments);

// 遍历ID列表,创建并添加EmbeddingR对象到结果列表
for (int i = 0; i < ids.size(); i++) {
list.add(new EmbeddingR().setVectorId(ids.get(i)).setText(segments.get(i).text()));
}
} catch (Exception e) {
// 打印异常信息
e.printStackTrace();
}

// 日志记录文档向量解析结束信息
log.info(">>>>>>>>>>>>>> Docs文档向量解析结束,KnowledgeId={}, DocsName={}", req.getKnowledgeId(), req.getDocsName());

// 返回文档向量解析结果列表
return list;
}

/**
* 聊天功能实现方法
* 该方法负责处理聊天请求,包括模型加载、会话初始化、以及知识库的关联等步骤
*
* @param req 聊天请求对象,包含模型ID、会话ID、提示文本等信息
* @return 返回聊天的Token流,用于实时传输聊天响应
* @throws ServiceException 如果模型加载失败,抛出此异常
*/
@Override
public TokenStream chat(ChatReq req) {
// 加载聊天的语言模型
StreamingChatLanguageModel model = provider.stream(req.getModelId());
// 检查模型是否成功加载
if (ObjectUtils.isEmpty(model)) {
throw new ServiceException("模型加载失败");
}
// 如果会话ID为空,则生成新的会话ID
if (StrUtil.isBlank(req.getConversationId())) {
req.setConversationId(IdUtil.simpleUUID());
}
// 构建AI服务配置,包括模型、系统消息提供者、聊天记忆提供者等
AiServices<Agent> aiServices = AiServices.builder(Agent.class)
.streamingChatLanguageModel(model)
.systemMessageProvider(memoryId -> Optional.ofNullable(req.getPromptText()).orElse(""))
.chatMemoryProvider(memoryId -> MessageWindowChatMemory.builder()
.id(req.getConversationId())
.chatMemoryStore(new PersistentChatMemoryStore())
.maxMessages(chatProps.getMemoryMaxMessage())
.build())
;

// 如果提供了知识库ID,则添加到请求的知识库ID列表中
if (StrUtil.isNotBlank(req.getKnowledgeId())) {
req.getKnowledgeIds().add(req.getKnowledgeId());
}
// 是否关联了知识库,如果关联了知识库,说明是RAG应用,需要构建RAG优化链条
if (req.getKnowledgeIds() != null && !req.getKnowledgeIds().isEmpty()) {
// 构建内容检索器,用于从知识库中检索相关信息
Function<Query, Filter> filter = (query) -> metadataKey(KNOWLEDGE).isIn(req.getKnowledgeIds());
ContentRetriever embeddingStoreContentRetriever = EmbeddingStoreContentRetriever.builder()
.embeddingStore(embeddingProvider.getEmbeddingStore(req.getKnowledgeIds()))
.embeddingModel(embeddingProvider.getEmbeddingModel(req.getKnowledgeIds()))
.dynamicFilter(filter)
.maxResults(20)
.build();
// 将聊天的上下文进行压缩
CompressingQueryTransformer queryTransformer = new CompressingQueryTransformer(provider.text(req.getModelId()));
// 将问题转换成3个问题,以获得更全面的回答
ExpandingQueryTransformer expandingQueryTransformer = new ExpandingQueryTransformer(provider.text(req.getModelId()));

// 构建Web搜索引擎,用于从网络中检索相关信息
SearchApiWebSearchEngine webSearchEngine = SearchApiWebSearchEngine.builder()
.apiKey("p8SZVNAweqTtoZBBTVnXttcj")// 测试使用
.engine("google")
.build();
ContentRetriever webSearchContentRetriever = WebSearchContentRetriever.builder()
.webSearchEngine(webSearchEngine)
.maxResults(3)
.build();

// sql 数据库内容获取
ContentRetriever sqlDatabaseContentRetriever = SqlDatabaseContentRetriever.builder()
.dataSource(null)
.chatLanguageModel(null)
.build();

// Let's create a query router that will route each query to both retrievers.
QueryRouter queryRouter = new DefaultQueryRouter(embeddingStoreContentRetriever, webSearchContentRetriever,sqlDatabaseContentRetriever);

ScoringModel scoringModel = CohereScoringModel.builder()
.apiKey(System.getenv("COHERE_API_KEY"))
.modelName("rerank-multilingual-v3.0")
.build();
ContentAggregator contentAggregator = ReRankingContentAggregator.builder()
.scoringModel(scoringModel)
.minScore(0.8) // we want to present the LLM with only the truly relevant segments for the user's query
.build();



// 为AI服务添加RAG检索增强功能
aiServices.retrievalAugmentor(DefaultRetrievalAugmentor
.builder()
.queryTransformer(queryTransformer) // 构建用户查询转换器
.queryRouter(queryRouter) // 构建用户查询路由
.contentRetriever(webSearchContentRetriever) // 构建内容检索器,如果有了
.contentAggregator(contentAggregator) // 构建内容聚合器
.contentInjector(new DefaultContentInjector()) // 构建内容注入器
.executor(ragExecutor) // 自定义线程池,然后涉及多个内容检索器的时候会采用线程池并行的方式执行
.build());
}
// 构建AI服务代理对象
Agent agent = aiServices.build();
// 使用代理对象处理聊天请求,并返回Token流
return agent.stream(req.getConversationId(), req.getMessage());
}
  1. 模型加载与初始化

根据请求中的modelId加载对应的流式聊天语言模型

验证模型是否成功加载

为会话生成唯一ID(如果未提供)

  1. AI服务构建

使用AiServices构建器创建Agent接口的实现

配置流式聊天语言模型

设置系统消息提供者(使用请求中的promptText或默认为空字符串)

配置聊天记忆提供者,使用MessageWindowChatMemory实现会话记忆功能

设置会话ID

使用PersistentChatMemoryStore持久化存储聊天历史

限制最大消息数量

  1. RAG(检索增强生成)实现

检查是否提供了知识库ID

如果关联了知识库,构建完整的RAG处理流程:

内容检索:构建EmbeddingStoreContentRetriever,用于从向量数据库检索相关内容

设置过滤器,只检索指定知识库的内容

配置嵌入模型和存储

限制最大结果数

查询转换:

使用CompressingQueryTransformer压缩聊天上下文

使用ExpandingQueryTransformer将问题扩展为多个问题,获取更全面的回答

Web搜索集成:

构建SearchApiWebSearchEngine,用于从Google搜索引擎获取信息

创建WebSearchContentRetriever,限制最大结果数

SQL数据库集成:

创建SqlDatabaseContentRetriever(示例中未实际配置数据源)

查询路由:

使用DefaultQueryRouter将查询路由到多个内容检索器(向量库、Web搜索、SQL数据库)

内容评分与排序:

使用CohereScoringModel对检索结果进行评分

通过ReRankingContentAggregator聚合并重新排序内容,设置最低相关性分数阈值

RAG增强器配置:

将所有组件组装到DefaultRetrievalAugmentor中

配置自定义线程池用于并行执行多个内容检索任务

执行聊天请求

构建完成AI服务代理对象

使用代理处理聊天请求,传入会话ID和用户消息

返回Token流(流式响应)

技术亮点

  1. 流式响应:使用StreamingChatLanguageModel实现实时响应,提升用户体验

  2. 多源检索:集成了多种数据源(向量库、Web搜索、SQL数据库)

  3. 查询优化:

​ 压缩查询减少token消耗

​ 扩展查询获取更全面的回答

  1. 内容质量控制:

​ 使用Cohere的重排序模型对检索结果进行评分

​ 设置最低分数阈值,确保只有高质量内容被提供给LLM

  1. 并行处理:使用自定义线程池并行执行多个内容检索任务,提高效率

  2. 会话管理:使用持久化的聊天记忆存储,维护对话上下文


文章作者: ring2
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 ring2 !
  目录