JDK17 + Spring AI + PostgreSQL + pgvector 实现简易对话、历史对话记录保存及 PDF 文件 Embedding

一、实现本文需要先准备一个 DeepSeek 的 ApiKey 以及智谱 AI 的 ApiKey。
对应的官方网站:
DeepSeek:DeepSeek 开放平台(https://platform.deepseek.com)
智谱:智谱AI开放平台(https://open.bigmodel.cn)

二、直接下载项目
Gitee 地址:https://gitee.com/xu-xx2385/spring-ai

三、贴上全部代码,基本都有注释。实现得比较简易,可自行优化。

1、ChatController

package com.springai.controller;

import com.springai.chatData.ChatRedisMemory;
import com.springai.entity.ChatSession;
import com.springai.service.DocumentService;
import com.springai.session.ChatSessionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.deepseek.DeepSeekChatModel;
import org.springframework.ai.deepseek.DeepSeekChatOptions;
import org.springframework.ai.deepseek.api.ResponseFormat;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import reactor.core.publisher.Flux;

import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * REST controller exposing three endpoints under {@code /chat}: plain streaming chat,
 * retrieval-augmented (RAG) streaming chat, and PDF upload for embedding.
 *
 * <p>Both chat endpoints stream the model answer as a {@link Flux} of text chunks and
 * share one {@link ChatClient} that is configured with the Redis-backed chat memory.
 */
@Slf4j
@RestController
@RequestMapping("/chat")
public class ChatController {

    /** Largest accepted upload, in bytes (10 MB). */
    private static final long MAX_UPLOAD_BYTES = 10L * 1024 * 1024;

    /** Number of nearest documents fetched from the vector store per RAG query. */
    private static final int RAG_TOP_K = 2;

    private final ChatClient chatClient;

    private final PromptTemplate deepThinkPromptTemplate;

    private final ChatRedisMemory chatRedisMemory;

    private final VectorStore vectorStore;

    private final DocumentService documentService;

    private final ChatSessionService chatSessionService;

    /**
     * Wires the collaborators and builds the shared {@link ChatClient} with the memory
     * advisor, a logging advisor and a default temperature of 0.7.
     */
    public ChatController(DeepSeekChatModel chatModel, PromptTemplate deepThinkPromptTemplate, ChatRedisMemory chatRedisMemory, VectorStore vectorStore, DocumentService documentService, ChatSessionService chatSessionService) {
        this.chatRedisMemory = chatRedisMemory;
        this.deepThinkPromptTemplate = deepThinkPromptTemplate;
        this.vectorStore = vectorStore;
        this.documentService = documentService;
        this.chatSessionService = chatSessionService;

        this.chatClient = ChatClient.builder(chatModel)
                                    .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatRedisMemory).build())
                                    .defaultAdvisors(new SimpleLoggerAdvisor())
                                    .defaultOptions(DeepSeekChatOptions.builder().temperature(0.7d).build())
                                    .build();
    }

    /**
     * Plain chat endpoint: the request body is the user prompt.
     *
     * @param prompt user prompt (request body)
     * @param model  model name, defaults to {@code deepseek-chat}
     * @param chatId conversation id used as the chat-memory key
     * @return streamed answer chunks
     */
    @PostMapping(value = "/deepThing")
    public Flux<String> customOptions (@Validated @RequestBody String prompt,
                                       @RequestParam(value = "model", required = false,defaultValue = "deepseek-chat") String model,
                                       @RequestParam(value = "chatId", required = false, defaultValue = "spring-ai-alibaba-playground-deepthink-chat") String chatId) {
        return deepThinkingChat(chatId,model,prompt);
    }

    /**
     * Plain (non-RAG) chat.
     *
     * <p>The system prompt is the rendered template with an empty document context, so the
     * model no longer receives the raw {@code {current_date}}/{@code {context}}/{@code {input}}
     * placeholders.
     *
     * @param chatId conversation id used as the chat-memory key
     * @param model  model name
     * @param prompt user prompt
     * @return streamed answer chunks
     */
    public Flux<String> deepThinkingChat(String chatId, String model, String prompt) {
        return chatClient.prompt()
                         .options(buildOptions(model))
                         .system(renderSystemPrompt(prompt, ""))
                         .user(prompt)
                         .advisors(memoryAdvisor -> memoryAdvisor.param(ChatRedisMemory.CONVERSATION_ID, chatId))
                         .stream().content();
    }

    /**
     * RAG chat endpoint.
     *
     * @param query  user question
     * @param model  model name, defaults to {@code deepseek-chat}
     * @param chatId conversation id used as the chat-memory key
     * @param useId  optional user id stored with the saved session
     * @return streamed answer chunks
     */
    @GetMapping("/query")
    public Flux<String> processQuery(String query,
                                     @RequestParam(value = "model", required = false, defaultValue = "deepseek-chat") String model,
                                     @RequestParam(value = "chatId", required = false, defaultValue = "spring-ai-alibaba-playground-deepthink-chat") String chatId,
                                     @RequestParam(value = "useId", required = false) String useId) {
        return process(chatId, model, query,useId);
    }

    /**
     * Retrieves the most similar documents, renders them into the system prompt, streams
     * the answer, and stores the question/answer pair once the stream completes.
     *
     * @param chatId conversation id used as the chat-memory key
     * @param model  model name
     * @param query  user question
     * @param useId  user id stored with the saved session; may be {@code null}
     * @return streamed answer chunks
     */
    public Flux<String> process(String chatId, String model, String query,String useId) {
        // 1. Retrieve related documents; topK bounds how many are returned.
        List<Document> similarDocuments = vectorStore.similaritySearch(
                SearchRequest.builder().query(query).topK(RAG_TOP_K).build());
        if (similarDocuments == null) {
            similarDocuments = List.of();
        }
        log.info("检索到相关文档:{}", similarDocuments.size());

        // 2. Join the retrieved documents into one context string.
        String context = similarDocuments.stream()
                                         .map(Document::getFormattedContent)
                                         .collect(Collectors.joining("\n\n"));

        // 3. Fill the template placeholders instead of appending a Map.toString() dump.
        String systemPrompt = renderSystemPrompt(query, context);

        // defer: every subscription gets its own buffer, so a re-subscription cannot
        // accumulate the answer twice.
        return Flux.defer(() -> {
            StringBuilder answerBuilder = new StringBuilder();
            return chatClient.prompt()
                             .options(buildOptions(model))
                             .system(systemPrompt)
                             .user(query)
                             .advisors(memoryAdvisor -> memoryAdvisor.param(ChatRedisMemory.CONVERSATION_ID, chatId))
                             .stream().content()
                             .doOnNext(answerBuilder::append)
                             .doOnComplete(() -> persistExchange(chatId, useId, query, answerBuilder.toString()));
        });
    }

    /**
     * Accepts a PDF upload, stores it under {@code <user.dir>/data} and hands it to the
     * document service for embedding.
     *
     * @param file uploaded multipart file
     * @return a human-readable success or failure message
     */
    @PostMapping("/upload")
    public String initKnowledgeBase(@RequestParam("file") MultipartFile file) {
        try {
            if (file.isEmpty()) {
                return "上传失败:文件为空";
            }
            // Only PDF files are accepted (case-insensitive extension check).
            String originalFilename = file.getOriginalFilename();
            if (originalFilename == null
                    || !originalFilename.toLowerCase(java.util.Locale.ROOT).endsWith(".pdf")) {
                return "上传失败:仅支持 PDF 文件格式";
            }
            if (file.getSize() > MAX_UPLOAD_BYTES) {
                return "上传失败:文件大小不能超过 10MB";
            }
            // Create the data directory under the project root if it is missing.
            File dataDir = new File(System.getProperty("user.dir"), "data");
            if (!dataDir.exists() && !dataDir.mkdirs()) {
                throw new IOException("无法创建 data 目录: " + dataDir.getAbsolutePath());
            }
            // The client controls the filename: keep only its last path segment so that
            // values such as "../../x.pdf" cannot escape the data directory.
            String baseName = new File(originalFilename.replace('\\', '/')).getName();
            // A timestamp prefix keeps names unique.
            File savedFile = new File(dataDir, System.currentTimeMillis() + "-" + baseName);

            file.transferTo(savedFile);

            documentService.loadAndStoreDocuments(savedFile);

            return "文件上传并处理成功,保存位置:" + savedFile.getAbsolutePath();
        } catch (Exception e) {
            log.error("文件上传失败: {}", e.getMessage(), e);
            return "文件上传失败:" + e.getMessage();
        }
    }

    /** Builds the per-request model options shared by both chat flows. */
    private static DeepSeekChatOptions buildOptions(String model) {
        return DeepSeekChatOptions.builder()
                                  .model(model)
                                  .temperature(0.8)
                                  .responseFormat(ResponseFormat.builder().type(ResponseFormat.Type.TEXT).build())
                                  .build();
    }

    /** Renders the system prompt template with today's date, the user input and the document context. */
    private String renderSystemPrompt(String input, String context) {
        Map<String, Object> variables = new HashMap<>();
        variables.put("current_date", LocalDateTime.now().toLocalDate().toString());
        variables.put("input", input == null ? "" : input);
        variables.put("context", context == null ? "" : context);
        return deepThinkPromptTemplate.render(variables);
    }

    /**
     * Stores one question/answer exchange. Failures are logged rather than rethrown:
     * the answer has already been streamed to the client at this point.
     */
    private void persistExchange(String chatId, String userId, String question, String answer) {
        try {
            chatSessionService.saveSession(ChatSession.builder()
                                                      .chatId(chatId)
                                                      .userId(userId)
                                                      // newline-delimited so appended exchanges stay readable
                                                      .content("question:" + question + "\nanswer:" + answer + "\n")
                                                      .createdAt(LocalDateTime.now())
                                                      .build());
        } catch (Exception e) {
            log.error("Failed to persist chat session, chatId={}, userId={}", chatId, userId, e);
        }
    }

}

2、DocumentService

package com.springai.service;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.reader.ExtractedTextFormatter;
import org.springframework.ai.reader.pdf.PagePdfDocumentReader;
import org.springframework.ai.reader.pdf.ParagraphPdfDocumentReader;
import org.springframework.ai.reader.pdf.config.PdfDocumentReaderConfig;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.FileSystemResource;
import org.springframework.core.io.Resource;
import org.springframework.stereotype.Service;

import java.io.File;
import java.util.List;
/**
 * Reads a PDF file, splits it into token-sized chunks and stores the chunks in the vector store.
 */
@Slf4j
@Service
@RequiredArgsConstructor
public class DocumentService {

    private final VectorStore vectorStore;


    /**
     * Extracts the text of the given PDF, splits it and adds the result to the vector store.
     * Nothing is stored when no content can be extracted.
     *
     * @param file PDF file on the local file system
     * @throws RuntimeException wrapping any failure during reading, splitting or storing
     */
    public void loadAndStoreDocuments(File file) {
        try {
            List<Document> documents = readPdf(new FileSystemResource(file));

            if (documents.isEmpty()) {
                log.warn("未能从 PDF 中提取任何内容,跳过存储操作。");
                return;
            }

            List<Document> splitDocs = new TokenTextSplitter().apply(documents);

            vectorStore.add(splitDocs);

            log.info("成功加载并存储文档到向量数据库,条目数:{}", splitDocs.size());
        } catch (Exception e) {
            log.error("处理文档时发生异常: {}", e.getMessage(), e);
            throw new RuntimeException("文档处理失败,请检查文档内容或格式是否正确", e);
        }
    }

    /**
     * Reads the PDF paragraph by paragraph and falls back to page-based reading when the
     * paragraph reader yields nothing or rejects the document.
     */
    private List<Document> readPdf(Resource resource) {
        List<Document> documents = List.of();
        try {
            PdfDocumentReaderConfig config = PdfDocumentReaderConfig.builder()
                    .withPageExtractedTextFormatter(
                            new ExtractedTextFormatter.Builder()
                                    .withNumberOfTopTextLinesToDelete(0)
                                    .build())
                    .build();
            documents = new ParagraphPdfDocumentReader(resource, config).get();
            log.info("Total number of documents (paragraph-based): {}", documents.size());
        } catch (IllegalArgumentException e) {
            // The paragraph reader needs a document outline (TOC) and rejects PDFs without
            // one; without this catch the page-based fallback below was never reached.
            log.warn("Paragraph-based PDF reading not possible, falling back to page-based reading: {}",
                    e.getMessage());
        }

        if (documents.isEmpty()) {
            documents = new PagePdfDocumentReader(resource).get();
            log.info("Total number of documents (page-based): {}", documents.size());
        }
        return documents;
    }
}

3、三个 Config 文件


package com.springai.config;

import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * Declares the prompt template bean used as the system prompt of the chat endpoints.
 */
@Configuration
public class DeepThinkPromptTemplateConfig {

    /**
     * Template text with three placeholders: {@code {current_date}}, {@code {context}}
     * and {@code {input}}.
     */
    private static final String DEEP_THINK_TEMPLATE =
            """
            你是一位深思熟虑且专业的AI助手,你的任务是根据提供的文档内容和必要的联网搜索结果,准确、清晰地回答用户的问题。

            请遵循以下规则进行回答:

            1. 请先判断用户的问题是否清晰明确。如果问题过于模糊、缺乏语义(例如只输入了数字“1”或无意义字符),请友好地请求用户重新提问。
            2. 回答优先严格基于文档(context)中的内容;如果文档中没有相关信息,再进行联网搜索获取补充内容。
            3. 回答前,请先思考并将你的思路写在 `<think></think>` 标签中,说明你是如何分析用户问题、查找信息并组织答案的。
            4. 回答内容需专业、友好,结构清晰,使用合理的小标题(如“背景信息”、“解答内容”、“附加建议”等)。
            5. 如果引用的是文档内容,请在回答中明确标注“【来源:文档】”;如果引用的是联网搜索结果,请标注“【来源:联网搜索】”。
            6. 如文档和网络中均无相关信息,请礼貌地回答:“抱歉,我在文档和网上都未能找到相关信息。”
            7. 注意判断上下文关系,如果用户问的问题与之前的对话内容相关,请结合上下文进行回答。否则,直接基于当前文档和联网搜索结果回答。

            请思考并完成以下任务:

            当前日期:{current_date}

            文档内容如下:
            {context}

            用户提问:
            {input}

            请按照上述规范思考并输出你的思路,包裹在 <think></think> 中,然后正式开始回答用户问题。
            """;

    /**
     * Creates the prompt template bean from the template text above.
     *
     * @return a new {@link PromptTemplate} wrapping the template text
     */
    @Bean
    public PromptTemplate deepThinkPromptTemplate() {
        PromptTemplate template = new PromptTemplate(DEEP_THINK_TEMPLATE);
        return template;
    }
}
package com.springai.config;

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;

/**
 * Declares the {@link RedisTemplate} used to store chat messages.
 */
@Configuration
public class RedisConfig {

    /**
     * Builds a template whose keys and hash keys are serialized as plain strings and whose
     * values and hash values are serialized as JSON.
     *
     * @param factory connection factory the template talks through
     * @param builder accepted as a bean-method parameter but not used in the body
     * @return the initialized template
     */
    @Bean
    public RedisTemplate<String, Object> messageRedisTemplate(RedisConnectionFactory factory, Jackson2ObjectMapperBuilder builder) {
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(factory);

        // Plain-string serialization for keys and hash keys.
        StringRedisSerializer plainKeys = new StringRedisSerializer();
        StringRedisSerializer plainHashKeys = new StringRedisSerializer();
        // JSON serialization for values and hash values.
        Jackson2JsonRedisSerializer<Object> jsonValues = new Jackson2JsonRedisSerializer<>(Object.class);
        Jackson2JsonRedisSerializer<Object> jsonHashValues = new Jackson2JsonRedisSerializer<>(Object.class);

        redisTemplate.setKeySerializer(plainKeys);
        redisTemplate.setValueSerializer(jsonValues);
        redisTemplate.setHashKeySerializer(plainHashKeys);
        redisTemplate.setHashValueSerializer(jsonHashValues);

        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }
}
package com.springai.config;

import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.pgvector.PgVectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.jdbc.core.JdbcTemplate;

import javax.sql.DataSource;

@Configuration
public class VectorStoreConfig {

    @Bean
    public VectorStore vectorStore(@Qualifier("zhiPuAiEmbeddingModel") EmbeddingModel embeddingClient, JdbcTemplate jdbcTemplate) {
        return PgVectorStore.builder(jdbcTemplate, embeddingClient).build();
    }

    @Bean
    public JdbcTemplate jdbcTemplate(DataSource dataSource) {
        return new JdbcTemplate(dataSource);
    }
}

4、两个实体
 

package com.springai.entity;

import lombok.Data;

/**
 * Chat message entity representing one stored chat record.
 * Holds the conversation id, message type, content, timestamp and sender.
 */
@Data
public class ChatEntity {
    // Conversation id the message belongs to (ChatRedisMemory stores the conversationId here)
    private String chatId;
    // Message type; ChatRedisMemory writes the message type value and maps "user", "assistant" and "system" back to messages
    private String type;
    // Message content
    private String text;
    // Send time as a timestamp (ChatRedisMemory fills it with System.currentTimeMillis())
    private Long timestamp;
    // Sender of the message; not populated by ChatRedisMemory
    private String sender;

}
package com.springai.entity;


import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Builder;
import lombok.Data;

import java.io.Serializable;
import java.time.LocalDateTime;

@Data
@Builder
@TableName("chat_session")
public class ChatSession implements Serializable {
    @TableId
    @Schema(description = "会话ID")
    private String chatId;
    @Schema(description = "用户ID")
    private String userId;
    @Schema(description = "会话时间")
    private LocalDateTime createdAt;
    @Schema(description = "对话内容")
    private String content;
}

5、聊天记录保存到 PG

package com.springai.session;


import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.springai.entity.ChatSession;
import com.springai.mapper.ChatSessionMapper;
import lombok.RequiredArgsConstructor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.List;

/**
 * Stores and reads chat sessions in PostgreSQL through MyBatis-Plus.
 */
@Service
@RequiredArgsConstructor
public class ChatSessionService {

    @Autowired
    private ChatSessionMapper chatSessionMapper;

    /**
     * Saves a session. When a row with the same user id and chat id exists, the new content
     * is appended to it; otherwise a new row is inserted.
     *
     * @param chatSession session to save; its content is treated as empty when {@code null}
     */
    public void saveSession(ChatSession chatSession) {
        QueryWrapper<ChatSession> sessionKey = sessionKey(chatSession.getUserId(), chatSession.getChatId());
        ChatSession exist = chatSessionMapper.selectOne(sessionKey);
        if (exist == null) {
            chatSessionMapper.insert(chatSession);
            return;
        }
        // Append; a null on either side counts as empty so the text "null" is never stored.
        String previous = exist.getContent() == null ? "" : exist.getContent();
        String addition = chatSession.getContent() == null ? "" : chatSession.getContent();
        exist.setContent(previous + addition);
        // Update with the same (user_id, chat_id) condition used for the lookup rather than
        // by chat id alone, so only the row that was read gets modified.
        chatSessionMapper.update(exist, sessionKey);
    }

    /**
     * Returns the sessions of one user, newest first.
     *
     * @param userId user id to filter by
     * @return matching sessions ordered by {@code created_at} descending
     */
    public List<ChatSession> getSessions(String userId) {
        QueryWrapper<ChatSession> query = new QueryWrapper<>();
        query.eq("user_id", userId).orderByDesc("created_at");
        return chatSessionMapper.selectList(query);
    }

    /**
     * Builds the (user_id, chat_id) condition. A {@code null} value is matched with
     * {@code IS NULL}; an equality comparison against NULL never matches in SQL, which made
     * every save for a request without a user id take the insert path.
     */
    private static QueryWrapper<ChatSession> sessionKey(String userId, String chatId) {
        QueryWrapper<ChatSession> query = new QueryWrapper<>();
        if (userId == null) {
            query.isNull("user_id");
        } else {
            query.eq("user_id", userId);
        }
        if (chatId == null) {
            query.isNull("chat_id");
        } else {
            query.eq("chat_id", chatId);
        }
        return query;
    }
}

6、添加到 Redis 方法

package com.springai.chatData;

import com.alibaba.fastjson.JSON;
import com.springai.entity.ChatEntity;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.*;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;


/**
 * {@link ChatMemory} implementation that keeps each conversation as a Redis list of
 * JSON-encoded {@link ChatEntity} entries under the key {@code chat:history:<conversationId>}.
 */
@Slf4j
@Component
public class ChatRedisMemory implements ChatMemory {

    private static final String KEY_PREFIX = "chat:history:";

    /** Lifetime of a conversation key, refreshed on every write. */
    private static final long HISTORY_TTL_MINUTES = 30;

    private final RedisTemplate<String, Object> redisTemplate;

    public ChatRedisMemory(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    /**
     * Appends the messages to the conversation list and refreshes its expiry.
     * Redis failures are logged and not propagated.
     *
     * @param conversationId conversation the messages belong to
     * @param messages       messages to append; {@code null} or empty is a no-op
     */
    @Override
    public void add(String conversationId, List<Message> messages) {
        if (messages == null || messages.isEmpty()) {
            return;
        }
        String key = KEY_PREFIX + conversationId;
        // Must be List<Object>: with a List<String> argument the call below resolves to the
        // varargs overload rightPushAll(K, V...) and pushes the whole list as ONE element.
        List<Object> payloads = new ArrayList<>(messages.size());
        for (Message msg : messages) {
            ChatEntity ent = new ChatEntity();
            ent.setChatId(conversationId);
            ent.setType(msg.getMessageType().getValue());
            ent.setText(msg.getText());
            ent.setTimestamp(System.currentTimeMillis());
            payloads.add(JSON.toJSONString(ent));
        }
        try {
            redisTemplate.opsForList().rightPushAll(key, payloads);
            redisTemplate.expire(key, HISTORY_TTL_MINUTES, TimeUnit.MINUTES);
        } catch (Exception e) {
            log.error("Failed to add chat messages to Redis for conversationId: {}", conversationId, e);
        }
    }

    /**
     * Reads the whole conversation in insertion order. Entries that cannot be parsed or
     * have an unknown type are logged and skipped.
     *
     * @param conversationId conversation to read
     * @return the stored messages; empty when nothing is stored
     */
    @Override
    public List<Message> get(String conversationId) {
        String key = KEY_PREFIX + conversationId;
        List<Object> storedItems = redisTemplate.opsForList().range(key, 0, -1);

        List<Message> messages = new ArrayList<>();
        if (storedItems == null || storedItems.isEmpty()) {
            return messages;
        }
        for (Object stored : storedItems) {
            if (stored instanceof List<?> legacyBatch) {
                // Entries written before the overload fix hold a whole batch per element.
                for (Object item : legacyBatch) {
                    appendDecoded(item, messages);
                }
            } else {
                appendDecoded(stored, messages);
            }
        }
        return messages;
    }

    /** Removes the whole conversation. */
    @Override
    public void clear(String conversationId) {
        redisTemplate.delete(KEY_PREFIX + conversationId);
    }

    /** Decodes one stored JSON entry and appends the resulting message, if any, to {@code target}. */
    private void appendDecoded(Object stored, List<Message> target) {
        try {
            ChatEntity ent = JSON.parseObject(String.valueOf(stored), ChatEntity.class);
            if (ent == null || ent.getType() == null) {
                log.warn("Skipping chat message without type: {}", stored);
                return;
            }
            String content = ent.getText();
            Message msg = switch (ent.getType()) {
                case "user" -> new UserMessage(content);
                case "assistant" -> new AssistantMessage(content);
                case "system" -> new SystemMessage(content);
                default -> {
                    log.warn("Unknown message type: {}", ent.getType());
                    yield null;
                }
            };
            if (msg != null) {
                target.add(msg);
            }
        } catch (Exception e) {
            log.warn("Failed to parse chat message from Redis: {}", stored, e);
        }
    }
}

7、mapper

package com.springai.mapper;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.springai.entity.ChatSession;
import org.apache.ibatis.annotations.Mapper;

/**
 * MyBatis-Plus mapper for {@link ChatSession}. Declares no methods of its own;
 * everything callers use comes from {@link BaseMapper}.
 */
@Mapper
public interface ChatSessionMapper extends BaseMapper<ChatSession> {

}


8、下面是用到的 application.yml配置文件:

spring:
  datasource:
    url: jdbc:postgresql://localhost:5433/postgres
    username: postgres
    password: 123456
  ai:
    vectorstore: # vector store settings
      pgvector:
        initialize-schema: true # create the vector table automatically
        index-type: HNSW # index algorithm
        distance-type: COSINE_DISTANCE # use cosine distance as the metric
        dimensions: 1024 # output dimension of the ZhiPu embedding-2 model configured below
    deepseek:
      api-key: "key"
      base-url: https://api.deepseek.com
      chat:
        options:
          model: deepseek-chat
    openai:
      api-key: "sk-fake-key"
    zhipuai:
      embedding:
        api-key: "key"
        options:
          model: embedding-2
      api-key: "key"

  data:
    redis:
      host: 127.0.0.1
      port: 46379
      database: 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值