一、实现本文的功能需要先准备一个 DeepSeek 的 API Key 以及一个智谱 AI 的 API Key;此外还需要可用的 PostgreSQL(需安装 pgvector 扩展)和 Redis,连接信息见文末的 application.yml。
对应的官方网站:
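DeepSeek:DeepSeek 开放平台(https://platform.deepseek.com)
智谱:智谱 AI 开放平台(https://open.bigmodel.cn)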
二、直接下载项目
Gitee 地址:https://gitee.com/xu-xx2385/spring-ai
三、下面贴出全部代码,基本都有注释。实现得比较简易,请自行优化。
1、ChatController
package com.springai.controller;
import com.springai.chatData.ChatRedisMemory;
import com.springai.entity.ChatSession;
import com.springai.service.DocumentService;
import com.springai.session.ChatSessionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.ResponseFormat;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * REST controller for chatting with DeepSeek.
 *
 * <p>Endpoints:
 * <ul>
 *   <li>{@code POST /chat/deepThing} - plain chat without retrieval;</li>
 *   <li>{@code GET /chat/query} - RAG chat: the most similar chunks from the vector store are
 *       rendered into the system prompt, and the exchange is saved via {@link ChatSessionService};</li>
 *   <li>{@code POST /chat/upload} - saves an uploaded PDF under {@code <user.dir>/data} and
 *       ingests it through {@link DocumentService}.</li>
 * </ul>
 *
 * <p>Conversation memory is keyed by {@code chatId} through {@link ChatRedisMemory}.
 * NOTE(review): the default {@code chatId} is a fixed constant, so every caller that omits it
 * shares one conversation history.
 */
@Slf4j
@RestController
@RequestMapping("/chat")
public class ChatController {

    /** Number of similar document chunks retrieved for each RAG query. */
    private static final int RAG_TOP_K = 2;

    /** Maximum accepted upload size (10 MB). */
    private static final long MAX_UPLOAD_BYTES = 10L * 1024 * 1024;

    private final ChatClient chatClient;
    private final PromptTemplate deepThinkPromptTemplate;
    private final ChatRedisMemory chatRedisMemory;
    private final VectorStore vectorStore;
    private final DocumentService documentService;
    private final ChatSessionService chatSessionService;

    public ChatController(DeepSeekChatModel chatModel, PromptTemplate deepThinkPromptTemplate,
            ChatRedisMemory chatRedisMemory, VectorStore vectorStore,
            DocumentService documentService, ChatSessionService chatSessionService) {
        this.chatRedisMemory = chatRedisMemory;
        this.deepThinkPromptTemplate = deepThinkPromptTemplate;
        this.vectorStore = vectorStore;
        this.documentService = documentService;
        this.chatSessionService = chatSessionService;
        this.chatClient = ChatClient.builder(chatModel)
                // Redis-backed conversation memory plus request/response logging.
                .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatRedisMemory).build(),
                        new SimpleLoggerAdvisor())
                .defaultOptions(DeepSeekChatOptions.builder().temperature(0.7d).build())
                .build();
    }

    /**
     * Plain (non-RAG) streaming chat endpoint.
     *
     * @param prompt the user's question (request body)
     * @param model  DeepSeek model name
     * @param chatId conversation id used as the memory key
     * @return the streamed answer
     */
    @PostMapping(value = "/deepThing")
    public Flux<String> customOptions(@Validated @RequestBody String prompt,
            @RequestParam(value = "model", required = false, defaultValue = "deepseek-chat") String model,
            @RequestParam(value = "chatId", required = false, defaultValue = "spring-ai-alibaba-playground-deepthink-chat") String chatId) {
        return deepThinkingChat(chatId, model, prompt);
    }

    /**
     * Streams an answer without document retrieval; the template's context placeholder is left empty.
     */
    public Flux<String> deepThinkingChat(String chatId, String model, String prompt) {
        return chatClient.prompt()
                .options(requestOptions(model))
                .system(renderSystemPrompt("", prompt))
                .user(prompt)
                .advisors(memoryAdvisor -> memoryAdvisor.param(ChatRedisMemory.CONVERSATION_ID, chatId))
                .stream().content();
    }

    /**
     * RAG streaming chat endpoint.
     *
     * @param query  the user's question
     * @param model  DeepSeek model name
     * @param chatId conversation id used as the memory key
     * @param useId  optional user id stored with the saved session
     * @return the streamed answer
     */
    @GetMapping("/query")
    public Flux<String> processQuery(String query,
            @RequestParam(value = "model", required = false, defaultValue = "deepseek-chat") String model,
            @RequestParam(value = "chatId", required = false, defaultValue = "spring-ai-alibaba-playground-deepthink-chat") String chatId,
            @RequestParam(value = "useId", required = false) String useId) {
        return process(chatId, model, query, useId);
    }

    /**
     * Retrieves similar chunks, renders them into the system prompt, streams the answer and, once
     * the stream completes, saves the question/answer pair.
     */
    public Flux<String> process(String chatId, String model, String query, String useId) {
        // 1. Retrieve the most similar chunks; RAG_TOP_K bounds how many are used.
        List<Document> similarDocuments = vectorStore.similaritySearch(
                SearchRequest.builder().query(query).topK(RAG_TOP_K).build());
        log.info("检索到相关文档:{}", similarDocuments.size());

        // 2. Build the context string from the chunk contents.
        String context = similarDocuments.stream()
                .map(Document::getFormattedContent)
                .collect(Collectors.joining("\n\n"));

        // 3. Fill the template's placeholders (instead of appending a Map#toString() dump).
        String systemPrompt = renderSystemPrompt(context, query);

        StringBuilder answerBuilder = new StringBuilder();
        return chatClient.prompt()
                .options(requestOptions(model))
                .system(systemPrompt)
                .user(query)
                .advisors(memoryAdvisor -> memoryAdvisor.param(ChatRedisMemory.CONVERSATION_ID, chatId))
                .stream().content()
                .doOnNext(answerBuilder::append)
                .doOnComplete(() -> saveExchange(chatId, useId, query, answerBuilder.toString()));
    }

    /**
     * Accepts a PDF upload (max 10 MB), stores it under {@code <user.dir>/data} with a timestamp
     * prefix and ingests it into the vector store.
     *
     * @param file the uploaded file
     * @return a human-readable status message (errors are reported in the message, not as HTTP errors)
     */
    @PostMapping("/upload")
    public String initKnowledgeBase(@RequestParam("file") MultipartFile file) {
        try {
            if (file.isEmpty()) {
                return "上传失败:文件为空";
            }
            // Accept only .pdf files (case-insensitive, locale-independent).
            String originalFilename = file.getOriginalFilename();
            if (originalFilename == null
                    || !originalFilename.toLowerCase(java.util.Locale.ROOT).endsWith(".pdf")) {
                return "上传失败:仅支持 PDF 文件格式";
            }
            if (file.getSize() > MAX_UPLOAD_BYTES) {
                return "上传失败:文件大小不能超过 10MB";
            }
            File dataDir = new File(System.getProperty("user.dir"), "data");
            if (!dataDir.exists() && !dataDir.mkdirs()) {
                throw new IOException("无法创建 data 目录: " + dataDir.getAbsolutePath());
            }
            // Timestamp prefix avoids name collisions; baseName() drops any client-supplied
            // directory components so the file always lands directly inside dataDir.
            String safeFilename = System.currentTimeMillis() + "-" + baseName(originalFilename);
            File savedFile = new File(dataDir, safeFilename);
            file.transferTo(savedFile);
            documentService.loadAndStoreDocuments(savedFile);
            return "文件上传并处理成功,保存位置:" + savedFile.getAbsolutePath();
        } catch (Exception e) {
            log.error("文件上传失败: {}", e.getMessage(), e);
            return "文件上传失败:" + e.getMessage();
        }
    }

    /** Builds the per-request DeepSeek options shared by both chat paths. */
    private static DeepSeekChatOptions requestOptions(String model) {
        return DeepSeekChatOptions.builder()
                .model(model)
                .temperature(0.8)
                .responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.TEXT).build())
                .build();
    }

    /** Renders the system prompt template with today's date, the given context and the user input. */
    private String renderSystemPrompt(String context, String input) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("current_date", java.time.LocalDate.now().toString());
        variables.put("context", context == null ? "" : context);
        variables.put("input", input == null ? "" : input);
        return deepThinkPromptTemplate.render(variables);
    }

    /**
     * Saves one question/answer pair. Failures are logged rather than rethrown, because the answer
     * has already been streamed to the client.
     * NOTE(review): this runs a blocking JDBC call on the stream's completion thread.
     */
    private void saveExchange(String chatId, String userId, String query, String answer) {
        try {
            chatSessionService.saveSession(ChatSession.builder()
                    .chatId(chatId)
                    .userId(userId)
                    // Newlines keep successive appended exchanges readable.
                    .content("question:" + query + "\nanswer:" + answer + "\n")
                    .createdAt(LocalDateTime.now())
                    .build());
        } catch (RuntimeException e) {
            log.error("Failed to save chat session for chatId: {}", chatId, e);
        }
    }

    /** Returns the last path segment of a client-supplied filename (handles '/' and '\'). */
    private static String baseName(String filename) {
        String normalized = filename.replace('\\', '/');
        return normalized.substring(normalized.lastIndexOf('/') + 1);
    }
}
2、DocumentService
package com.springai.service;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;
import java.io.File;
import java.util.List;
/**
 * Reads PDF files, splits them into token-sized chunks and stores the chunks in the vector store.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class DocumentService {

    private final VectorStore vectorStore;

    /**
     * Extracts text from the given PDF, splits it with {@link TokenTextSplitter} and adds the
     * chunks to the vector store.
     *
     * <p>Paragraph-based extraction is tried first. Page-based extraction is used when that yields
     * no documents or cannot be used at all (ParagraphPdfDocumentReader rejects PDFs without an
     * outline / table of contents).
     *
     * @param file the PDF file to ingest
     * @throws RuntimeException wrapping any failure while reading, splitting or storing
     */
    public void loadAndStoreDocuments(File file) {
        try {
            PdfDocumentReaderConfig config = PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(ExtractedTextFormatter.builder()
                            .withNumberOfTopTextLinesToDelete(0)
                            .build())
                    .build();
            Resource resource = new FileSystemResource(file);

            List<Document> documents = readParagraphs(resource, config);
            if (documents.isEmpty()) {
                documents = new PagePdfDocumentReader(resource, config).get();
                log.info("Total number of documents (page-based): {}", documents.size());
            }
            if (documents.isEmpty()) {
                log.warn("未能从 PDF 中提取任何内容,跳过存储操作。");
                return;
            }

            List<Document> splitDocs = new TokenTextSplitter().apply(documents);
            vectorStore.add(splitDocs);
            log.info("成功加载并存储文档到向量数据库,条目数:{}", splitDocs.size());
        } catch (Exception e) {
            log.error("处理文档时发生异常: {}", e.getMessage(), e);
            throw new RuntimeException("文档处理失败,请检查文档内容或格式是否正确", e);
        }
    }

    /**
     * Attempts paragraph-based extraction; returns an empty list when the reader cannot be used
     * so that the caller falls back to page-based extraction.
     */
    private List<Document> readParagraphs(Resource resource, PdfDocumentReaderConfig config) {
        try {
            List<Document> documents = new ParagraphPdfDocumentReader(resource, config).get();
            log.info("Total number of documents (paragraph-based): {}", documents.size());
            return documents;
        } catch (RuntimeException e) {
            // The reader's constructor fails (IllegalArgumentException) for PDFs without an
            // outline. Caught broadly on purpose: a genuinely corrupt PDF will fail again in the
            // page-based reader and surface through the caller's error handling.
            log.info("Paragraph-based extraction unavailable ({}), falling back to page-based", e.getMessage());
            return List.of();
        }
    }
}
3、三个 Config 文件
package com.springai.config;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
/**
 * Declares the system prompt template shared by the chat endpoints.
 *
 * <p>The template text contains three placeholders: {@code current_date}, {@code context} and
 * {@code input}.
 */
@Configuration
public class DeepThinkPromptTemplateConfig {

    /**
     * Raw template text.
     * NOTE(review): the prompt tells the model to use web-search results and to cite
     * "联网搜索", but no search tool is registered on the ChatClient built in ChatController;
     * verify, otherwise the model may invent such citations.
     */
    private static final String DEEP_THINK_TEMPLATE = """
            你是一位深思熟虑且专业的AI助手,你的任务是根据提供的文档内容和必要的联网搜索结果,准确、清晰地回答用户的问题。
            请遵循以下规则进行回答:
            1. 请先判断用户的问题是否清晰明确。如果问题过于模糊、缺乏语义(例如只输入了数字“1”或无意义字符),请友好地请求用户重新提问。
            2. 回答优先严格基于文档(context)中的内容;如果文档中没有相关信息,再进行联网搜索获取补充内容。
            3. 回答前,请先思考并将你的思路写在 `<think></think>` 标签中,说明你是如何分析用户问题、查找信息并组织答案的。
            4. 回答内容需专业、友好,结构清晰,使用合理的小标题(如“背景信息”、“解答内容”、“附加建议”等)。
            5. 如果引用的是文档内容,请在回答中明确标注“【来源:文档】”;如果引用的是联网搜索结果,请标注“【来源:联网搜索】”。
            6. 如文档和网络中均无相关信息,请礼貌地回答:“抱歉,我在文档和网上都未能找到相关信息。”
            7. 注意判断上下文关系,如果用户问的问题与之前的对话内容相关,请结合上下文进行回答。否则,直接基于当前文档和联网搜索结果回答。
            请思考并完成以下任务:
            当前日期:{current_date}
            文档内容如下:
            {context}
            用户提问:
            {input}
            请按照上述规范思考并输出你的思路,包裹在 <think></think> 中,然后正式开始回答用户问题。
            """;

    /** Exposes the template as a bean so it can be injected by type. */
    @Bean
    public PromptTemplate deepThinkPromptTemplate() {
        PromptTemplate template = new PromptTemplate(DEEP_THINK_TEMPLATE);
        return template;
    }
}
package com.springai.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
/**
 * Redis configuration providing a {@code RedisTemplate<String, Object>} with String keys and
 * Jackson-serialized JSON values (for both plain and hash entries).
 */
@Configuration
public class RedisConfig {

    /**
     * Builds the template used for chat-message storage.
     * NOTE(review): {@code builder} is injected but unused; the value serializers use their own
     * default ObjectMapper.
     */
    @Bean
    public RedisTemplate<String, Object> messageRedisTemplate(RedisConnectionFactory factory, Jackson2ObjectMapperBuilder builder) {
        StringRedisSerializer keySerializer = new StringRedisSerializer();

        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(factory);

        // Keys (plain and hash) are stored as plain strings.
        redisTemplate.setKeySerializer(keySerializer);
        redisTemplate.setHashKeySerializer(keySerializer);

        // Values (plain and hash) are stored as JSON.
        redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Object.class));
        redisTemplate.setHashValueSerializer(new Jackson2JsonRedisSerializer<>(Object.class));

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
}
package com.springai.config;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;
import javax.sql.DataSource;
@Configuration
public class VectorStoreConfig {
@Bean
public VectorStore vectorStore(@Qualifier("zhiPuAiEmbeddingModel") EmbeddingModel embeddingClient, JdbcTemplate jdbcTemplate) {
return PgVectorStore.builder(jdbcTemplate, embeddingClient).build();
}
@Bean
public JdbcTemplate jdbcTemplate(DataSource dataSource) {
return new JdbcTemplate(dataSource);
}
}
4、两个实体
package com.springai.entity;
import lombok.Data;
/**
* A single chat message as serialized (with fastjson) into Redis by ChatRedisMemory.
* Holds the conversation ID, message type, text, timestamp and sender.
*/
@Data
public class ChatEntity {
// ID of the conversation this message belongs to (ChatRedisMemory stores the conversationId
// here), so it is shared by all messages of a conversation rather than unique per message.
private String chatId;
// Message type as written by ChatRedisMemory: the MessageType value, read back as
// "user", "assistant" or "system".
private String type;
// Message text.
private String text;
// Epoch milliseconds at which the entry was written (ChatRedisMemory uses System.currentTimeMillis()).
private Long timestamp;
// Sender; not populated by ChatRedisMemory.
private String sender;
}
package com.springai.entity;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.time.LocalDateTime;
/**
 * MyBatis-Plus entity for the {@code chat_session} table, identified by {@code chatId}
 * ({@link TableId}); {@code content} accumulates the question/answer text of the chat.
 *
 * <p>{@code @NoArgsConstructor} + {@code @AllArgsConstructor} accompany {@code @Builder} so that
 * MyBatis can instantiate the entity when mapping query results (with {@code @Builder} alone,
 * only a package-private all-args constructor exists).
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
@TableName("chat_session")
public class ChatSession implements Serializable {

    private static final long serialVersionUID = 1L;

    // Conversation id; mapped as the table's id column.
    @TableId
    @Schema(description = "会话ID")
    private String chatId;

    // Id of the user the conversation belongs to; may be null.
    @Schema(description = "用户ID")
    private String userId;

    // Creation time of the session row.
    @Schema(description = "会话时间")
    private LocalDateTime createdAt;

    // Accumulated conversation text.
    @Schema(description = "对话内容")
    private String content;
}
5、将聊天记录保存到 PostgreSQL(PG)
package com.springai.session;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.springai.entity.ChatSession;
import com.springai.mapper.ChatSessionMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.util.List;
/**
 * Persists chat sessions to PostgreSQL through MyBatis-Plus.
 */
@Service
@RequiredArgsConstructor
public class ChatSessionService {

    // Final so that Lombok's @RequiredArgsConstructor injects it through the constructor.
    private final ChatSessionMapper chatSessionMapper;

    /**
     * Saves a session. If a row with the same user id and chat id exists, the new content is
     * appended to it; otherwise the session is inserted as a new row.
     *
     * @param chatSession the session to save; a null user id matches rows whose user_id IS NULL
     */
    public void saveSession(ChatSession chatSession) {
        QueryWrapper<ChatSession> query = new QueryWrapper<>();
        String userId = chatSession.getUserId();
        if (userId == null) {
            // "user_id = NULL" never matches in SQL, which made anonymous sessions always insert
            // (and collide on the chatId id column); IS NULL finds them.
            query.isNull("user_id");
        } else {
            query.eq("user_id", userId);
        }
        query.eq("chat_id", chatSession.getChatId());

        ChatSession exist = chatSessionMapper.selectOne(query);
        if (exist == null) {
            chatSessionMapper.insert(chatSession);
            return;
        }
        // Append the new content, treating missing content on either side as empty.
        String previous = exist.getContent() == null ? "" : exist.getContent();
        String addition = chatSession.getContent() == null ? "" : chatSession.getContent();
        exist.setContent(previous + addition);
        chatSessionMapper.updateById(exist);
    }

    /**
     * Returns all sessions of the given user, newest first (by created_at).
     *
     * @param userId the user id
     * @return the matching sessions
     */
    public List<ChatSession> getSessions(String userId) {
        QueryWrapper<ChatSession> query = new QueryWrapper<>();
        query.eq("user_id", userId).orderByDesc("created_at");
        return chatSessionMapper.selectList(query);
    }
}
6、将对话记忆保存到 Redis(ChatMemory 实现)
package com.springai.chatData;
import com.alibaba.fastjson.JSON;
import com.springai.entity.ChatEntity;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * {@link ChatMemory} implementation that keeps each conversation's messages in a Redis list.
 *
 * <p>Each message is one list element: a fastjson-serialized {@link ChatEntity} stored under the
 * key {@code chat:history:<conversationId>}. The key expires 30 minutes after the last write.
 */
@Slf4j
@Component
public class ChatRedisMemory implements ChatMemory {

    private static final String KEY_PREFIX = "chat:history:";

    /** Expiry applied to a conversation key; refreshed on every add. */
    private static final long TTL_MINUTES = 30;

    private final RedisTemplate<String, Object> redisTemplate;

    public ChatRedisMemory(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Appends the messages to the conversation's list and refreshes its expiry.
     * Redis failures are logged and not rethrown.
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            // RPUSH needs at least one value, and there is nothing to store.
            return;
        }
        String key = KEY_PREFIX + conversationId;
        // Deliberately List<Object>: with a List<String>, rightPushAll(K, Collection<V>) does not
        // apply (V is Object), so javac selects the varargs overload and pushes the whole list as
        // ONE element, which get() then cannot parse back into messages.
        List<Object> jsonList = new ArrayList<>(messages.size());
        for (Message msg : messages) {
            jsonList.add(toJson(conversationId, msg));
        }
        try {
            redisTemplate.opsForList().rightPushAll(key, jsonList);
            redisTemplate.expire(key, TTL_MINUTES, TimeUnit.MINUTES);
        } catch (Exception e) {
            log.error("Failed to add chat messages to Redis for conversationId: {}", conversationId, e);
        }
    }

    /**
     * Returns all stored messages of the conversation in insertion order; unparseable entries
     * and unknown types are skipped with a warning.
     */
    @Override
    public List<Message> get(String conversationId) {
        String key = KEY_PREFIX + conversationId;
        List<Object> stored = redisTemplate.opsForList().range(key, 0, -1);
        List<Message> messages = new ArrayList<>();
        if (stored == null || stored.isEmpty()) {
            return messages;
        }
        for (Object element : stored) {
            if (element instanceof List<?> batch) {
                // Entries written before the overload fix hold a whole batch in one element.
                for (Object item : batch) {
                    addParsed(messages, item);
                }
            } else {
                addParsed(messages, element);
            }
        }
        return messages;
    }

    /** Deletes the conversation's history. */
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(KEY_PREFIX + conversationId);
    }

    /** Serializes one message as a ChatEntity JSON string. */
    private static String toJson(String conversationId, Message msg) {
        ChatEntity ent = new ChatEntity();
        ent.setChatId(conversationId);
        ent.setType(msg.getMessageType().getValue());
        ent.setText(msg.getText());
        ent.setTimestamp(System.currentTimeMillis());
        return JSON.toJSONString(ent);
    }

    /** Parses one stored element and appends the resulting message, if any, to {@code target}. */
    private static void addParsed(List<Message> target, Object raw) {
        if (raw == null) {
            return;
        }
        try {
            Message msg = toMessage(JSON.parseObject(raw.toString(), ChatEntity.class));
            if (msg != null) {
                target.add(msg);
            }
        } catch (Exception e) {
            log.warn("Failed to parse chat message from Redis: {}", raw, e);
        }
    }

    /** Maps a stored entity back to a Spring AI message; returns null for unknown/missing types. */
    private static Message toMessage(ChatEntity ent) {
        String type = ent == null ? null : ent.getType();
        if (type == null) {
            log.warn("Unknown message type: {}", type);
            return null;
        }
        String content = ent.getText();
        return switch (type) {
            case "user" -> new UserMessage(content);
            case "assistant" -> new AssistantMessage(content);
            case "system" -> new SystemMessage(content);
            default -> {
                log.warn("Unknown message type: {}", type);
                yield null;
            }
        };
    }
}
7、mapper
package com.springai.mapper;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.springai.entity.ChatSession;
import org.apache.ibatis.annotations.Mapper;
/**
 * MyBatis-Plus mapper for {@link ChatSession}; declares no methods of its own and relies on
 * the CRUD methods inherited from {@link BaseMapper}.
 */
@Mapper
public interface ChatSessionMapper extends BaseMapper<ChatSession> {
}
8、下面是用到的 application.yml 配置文件:
# NOTE: prefer supplying secrets (DB password, API keys) via environment variables,
# e.g. api-key: ${DEEPSEEK_API_KEY}, instead of committing them here.
spring:
  datasource:
    url: jdbc:postgresql://localhost:5433/postgres
    username: postgres
    password: 123456
  ai:
    vectorstore: # vector store settings
      pgvector:
        initialize-schema: true # create the table automatically
        index-type: HNSW # index algorithm
        distance-type: COSINE_DISTANCE # use cosine distance as the similarity metric
        dimensions: 1024 # must match the embedding model: ZhipuAI embedding-2 outputs 1024-dim vectors (384 was for all-minilm)
    deepseek:
      api-key: "key"
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
    openai:
      api-key: "sk-fake-key"
    zhipuai:
      embedding:
        api-key: "key"
        options:
          model: embedding-2
      api-key: "key"
  data:
    redis:
      host: 127.0.0.1
      port: 46379
      database: 0