Spring Boot集成LangChain来实现Rag应用1.什么是rag? 检索增强生成(RAG)是指对大型语言模型
1.什么是rag?
检索增强生成(RAG)是指对大型语言模型输出进行优化,使其能够在生成响应之前引用训练数据来源之外的权威知识库。大型语言模型(LLM)用海量数据进行训练,使用数十亿个参数为回答问题、翻译语言和完成句子等任务生成原始输出。在 LLM 本就强大的功能基础上,RAG 将其扩展为能访问特定领域或组织的内部知识库,所有这些都无需重新训练模型。这是一种经济高效地改进 LLM 输出的方法,让它在各种情境下都能保持相关性、准确性和实用性。
为什么检索增强生成很重要?
LLM 是一项关键的人工智能(AI)技术,为智能聊天机器人和其他自然语言处理(NLP)应用程序提供支持。目标是通过交叉引用权威知识来源,创建能够在各种环境中回答用户问题的机器人。不幸的是,LLM 技术的本质在 LLM 响应中引入了不可预测性。此外,LLM 训练数据是静态的,并引入了其所掌握知识的截止日期。 LLM 面临的已知挑战包括:
- 在没有答案的情况下提供虚假信息。
- 当用户需要特定的当前响应时,提供过时或通用的信息。
- 从非权威来源创建响应。
- 由于术语混淆,不同的培训来源使用相同的术语来谈论不同的事情,因此会产生不准确的响应。
您可以将大型语言模型看作是一个过于热情的新员工,他拒绝随时了解时事,但总是会绝对自信地回答每一个问题。不幸的是,这种态度会对用户的信任产生负面影响,这是您不希望聊天机器人效仿的! RAG 是解决其中一些挑战的一种方法。它会重定向 LLM,从权威的、预先确定的知识来源中检索相关信息。组织可以更好地控制生成的文本输出,并且用户可以深入了解 LLM 如何生成响应。
检索增强生成的工作原理是什么?
如果没有 RAG,LLM 会接受用户输入,并根据它所接受训练的信息或它已经知道的信息创建响应。RAG 引入了一个信息检索组件,该组件利用用户输入首先从新数据源提取信息。用户查询和相关信息都提供给 LLM。LLM 使用新知识及其训练数据来创建更好的响应。以下各部分概述了该过程。
创建外部数据
LLM 原始训练数据集之外的新数据称为
外部数据
。它可以来自多个数据来源,例如 API、数据库或文档存储库。数据可能以各种格式存在,例如文件、数据库记录或长篇文本。另一种称为
嵌入语言模型
的 AI 技术将数据转换为数字表示形式并将其存储在向量数据库中。这个过程会创建一个生成式人工智能模型可以理解的知识库。
检索相关信息
下一步是执行相关性搜索。用户查询将转换为向量表示形式,并与向量数据库匹配。例如,考虑一个可以回答组织的人力资源问题的智能聊天机器人。如果员工搜索
:“我有多少年假?”
,系统将检索年假政策文件以及员工个人过去的休假记录。这些特定文件将被退回,因为它们与员工输入的内容高度相关。相关性是使用数学向量计算和表示法计算和建立的。
增强 LLM 提示
接下来,RAG 模型通过在上下文中添加检索到的相关数据来增强用户输入(或提示)。此步骤使用提示工程技术与 LLM 进行有效沟通。增强提示允许大型语言模型为用户查询生成准确的答案。
更新外部数据
下一个问题可能是——如果外部数据过时了怎么办? 要维护当前信息以供检索,请异步更新文档并更新文档的嵌入表示形式。您可以通过自动化实时流程或定期批处理来执行此操作。这是数据分析中常见的挑战——可以使用不同的数据科学方法进行变更管理。 下图显示了将 RAG 与 LLM 配合使用的概念流程。
2.什么是LangChain?
LangChain 是一个用于开发由语言模型驱动的应用程序的框架。他主要拥有 2 个能力:
- 可以将 LLM 模型与外部数据源进行连接
- 允许与 LLM 模型进行交互
LLM 模型:Large Language Model,大型语言模型
3.代码工程
实验目的
利用LangChain实现rag应用
pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.2.1</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>rag</artifactId>
<properties>
<java.version>17</java.version>
<langchain4j.version>0.23.0</langchain4j.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-thymeleaf</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-devtools</artifactId>
<scope>runtime</scope>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
<configuration>
<excludes>
<exclude>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</exclude>
</excludes>
</configuration>
</plugin>
</plugins>
</build>
</project>
controller
package com.et.rag.controller;
import com.et.rag.service.SBotService;
import lombok.RequiredArgsConstructor;
import org.springframework.http.ResponseEntity;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
@Controller
@RequiredArgsConstructor
public class SBotController {
private final SBotService sBotService;
@GetMapping
public String home() {
return "index";
}
@PostMapping("/ask")
public ResponseEntity<String> ask(@RequestBody String question) {
try {
return ResponseEntity.ok(sBotService.askQuestion(question));
} catch (Exception e) {
return ResponseEntity.badRequest().body("Sorry, I can't process your question right now.");
}
}
}
service
package com.et.rag.service;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@Service
@RequiredArgsConstructor
@Slf4j
public class SBotService {
private final ConversationalRetrievalChain chain;
public String askQuestion(String question) {
log.debug("======================================================");
log.debug("Question: " + question);
String answer = chain.execute(question);
log.debug("Answer: " + answer);
log.debug("======================================================");
return answer;
}
}
EmbeddingStoreLoggingRetriever
package com.et.rag.retriever;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.retriever.Retriever;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import java.util.List;
/**
* EmbeddingStoreLoggingRetriever is a logging-enhanced for an EmbeddingStoreRetriever.
* <p>
* This class logs the relevant TextSegments discovered by the supplied
* EmbeddingStoreRetriever for improved transparency and debugging.
* <p>
* Logging happens at INFO level, printing each relevant TextSegment found
* for a given input text once the findRelevant method is called.
*/
@RequiredArgsConstructor
@Slf4j
public class EmbeddingStoreLoggingRetriever implements Retriever<TextSegment> {
private final EmbeddingStoreRetriever retriever;
@Override
public List<TextSegment> findRelevant(String text) {
List<TextSegment> relevant = retriever.findRelevant(text);
relevant.forEach(segment -> {
log.debug("=======================================================");
log.debug("Found relevant text segment: {}", segment);
});
return relevant;
}
}
components
初始化documents
package com.et.rag.configuration;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.UrlDocumentLoader;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import static com.et.rag.constant.Constants.SPRING_BOOT_RESOURCES_LIST;
@Configuration
public class DocumentConfiguration {
@Bean
public List<Document> documents() {
return SPRING_BOOT_RESOURCES_LIST.stream()
.map(url -> {
try {
return UrlDocumentLoader.load(url);
} catch (Exception e) {
throw new RuntimeException("Failed to load document from " + url, e);
}
})
.toList();
}
}
初始化langchain
package com.et.rag.configuration;
import com.et.rag.retriever.EmbeddingStoreLoggingRetriever;
import dev.langchain4j.chain.ConversationalRetrievalChain;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.AllMiniLmL6V2EmbeddingModel;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.retriever.EmbeddingStoreRetriever;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.EmbeddingStoreIngestor;
import dev.langchain4j.store.embedding.inmemory.InMemoryEmbeddingStore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.time.Duration;
import java.util.List;
import static com.et.rag.constant.Constants.PROMPT_TEMPLATE_2;
@Configuration
@RequiredArgsConstructor
@Slf4j
public class LangChainConfiguration {
@Value("${langchain.api.key}")
private String apiKey;
@Value("${langchain.timeout}")
private Long timeout;
private final List<Document> documents;
@Bean
public ConversationalRetrievalChain chain() {
EmbeddingModel embeddingModel = new AllMiniLmL6V2EmbeddingModel();
EmbeddingStore<TextSegment> embeddingStore = new InMemoryEmbeddingStore<>();
EmbeddingStoreIngestor ingestor = EmbeddingStoreIngestor.builder()
.documentSplitter(DocumentSplitters.recursive(500, 0))
.embeddingModel(embeddingModel)
.embeddingStore(embeddingStore)
.build();
log.info("Ingesting Spring Boot Resources ...");
ingestor.ingest(documents);
log.info("Ingested {} documents", documents.size());
EmbeddingStoreRetriever retriever = EmbeddingStoreRetriever.from(embeddingStore, embeddingModel);
EmbeddingStoreLoggingRetriever loggingRetriever = new EmbeddingStoreLoggingRetriever(retriever);
/*MessageWindowChatMemory chatMemory = MessageWindowChatMemory.builder()
.maxMessages(10)
.build();*/
log.info("Building ConversationalRetrievalChain ...");
ConversationalRetrievalChain chain = ConversationalRetrievalChain.builder()
.chatLanguageModel(OpenAiChatModel.builder()
.apiKey(apiKey)
.timeout(Duration.ofSeconds(timeout))
.build()
)
.promptTemplate(PromptTemplate.from(PROMPT_TEMPLATE_2))
//.chatMemory(chatMemory)
.retriever(loggingRetriever)
.build();
log.info("Spring Boot knowledge base is ready!");
return chain;
}
}
application.yaml
langchain:
api:
# "demo" is a free API key for testing purposes only. Please replace it with your own API key.
key: demo
# key: OPEN_API_KEY
# API call to complete before it is timed out.
timeout: 30
index.html
<!DOCTYPE html>
<html lang="en"
xmlns="http://www.w3.org/1999/xhtml">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Spring Boot Doc Bot</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.3/css/all.min.css">
</head>
<body>
<nav class="bg-dark text-white py-3">
<div class="text-center d-flex justify-content-center align-items-center">
<img src="/logo.png" alt="Logo" style="width:60px; margin-right: 10px;">
<h2 style="margin: 0;">Welcome to Spring Boot Documentation Bot</h2>
</div>
</nav>
<div class="container mt-5">
<div class="row">
<div class="col-md-8 offset-2">
<h3 class="text-center mb-3">Ask your Spring related queries here!</h3>
<form>
<div class="mb-3">
<label for="questionInput" class="form-label">Question</label>
<input type="text" class="form-control" id="questionInput" name="question" placeholder="Enter your question" required>
</div>
<div class="mb-3 text-center">
<button id="submitBtn" type="button" class="btn btn-primary">Ask!</button>
<button id="clearBtn" type="button" class="btn btn-secondary">Clear</button>
</div>
</form>
</div>
</div>
<div class="row my-5">
<div class="col-md-8 offset-md-2">
<label for="answerBox" class="form-label"><h5>Answer</h5></label>
<div class="position-relative my-3">
<textarea class="form-control" rows="10" id="answerBox" disabled></textarea>
<a href="#" class="position-absolute top-0 end-0 m-2" id="copyBtn">
<i class="far fa-copy"></i>
</a>
</div>
</div>
</div>
</div>
<script src="https://code.jquery.com/jquery-3.7.1.min.js"></script>
<script>
$(document).ready(function () {
$("#submitBtn").click(function () {
let questionValue = $("#questionInput").val();
if (!questionValue) {
alert('Please enter your question');
return;
}
$("#answerBox").val('Please wait... fetching answer...');
$.ajax({
type: "POST",
url: "/ask",
data: JSON.stringify({ question: $("#questionInput").val() }),
//contentType: "application/json; charset=utf-8",
dataType: "text",
success: function (data) {
//console.log(typeof data);
//console.log(data);
$("#answerBox").val(data);
},
error: function (errMsg) {
alert(errMsg);
}
});
});
$("#clearBtn").click(function () {
$("#questionInput").val('');
$("#answerBox").val('');
});
document.getElementById("copyBtn").addEventListener("click", function() {
var copyText = document.getElementById("answerBox");
copyText.select();
copyText.setSelectionRange(0, 99999);
document.execCommand("copy");
alert("Copied: " + copyText.value);
});
});
</script>
</body>
</html>
只是一些关键代码,所有代码请参见下面代码仓库
代码仓库
4.测试
启动Spring Boot应用,访问http://127.0.0.1:8080/
5.引用
转载自:https://juejin.cn/post/7416962250731323433