前面介绍了使用关系数据库来存储对话历史,接下来我们更进一步,通过扩展ChatMemoryRepository来实现基于Redis的持久化
一、Redis持久化对话历史 1. 创建项目 创建一个SpringAI项目,基本流程同 创建一个SpringAI-Demo工程
2. 添加依赖 在pom.xml中添加关键依赖
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 <dependencies > <dependency > <groupId > org.springframework.boot</groupId > <artifactId > spring-boot-starter-web</artifactId > </dependency > <dependency > <groupId > org.springframework.ai</groupId > <artifactId > spring-ai-starter-model-zhipuai</artifactId > </dependency > <dependency > <groupId > org.springframework.ai</groupId > <artifactId > spring-ai-starter-model-chat-memory</artifactId > </dependency > <dependency > <groupId > org.springframework.data</groupId > <artifactId > spring-data-redis</artifactId > </dependency > <dependency > <groupId > redis.clients</groupId > <artifactId > jedis</artifactId > </dependency > </dependencies >
直接依赖spring-ai-starter-model-chat-memory来实现聊天记录,这里采用redis来进行持久化,因此我们添加redis的相关依赖(这里一jedis客户端为例,对于lettuce或者redisson实际上没有太大的差别)
3. 配置 在配置文件 application.yml 文件中,指定redis相关配置
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 spring: ai: zhipuai: api-key: ${zhipuai-api-key} chat: options: model: GLM-4-Flash data: redis: host: localhost port: 6379 password: logging: level: org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor: debug
我们这里使用单节点的redis进行演示;对于集群的场景,除了配置不一样之外,实际上对于我们的实现并没有太大的差别(因为我们利用的是Spring封装的RedisTemplate进行操作,因此对于集群的场景,实际上也是可以的)
4. 实现RedisChatMemoryRepository 通过实现接口ChatMemoryRepository来完成我们的 RedisChatMemoryRepository
对于用户聊天记录,我们采用redis的列表的数据结构进行存储(具体实现除了这种方案之外,也可以使用hash数据结构来存储,field为会话id,value为对话记录)
key: chat:${conversationId}
value: [] , 消息列表
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 @Component public class RedisChatMemoryRepository implements ChatMemoryRepository { @Autowired private StringRedisTemplate redisTemplate; private static final String PREFIX = "chat:" ; @Override public List<String> findConversationIds () { Set<String> ans = redisTemplate.keys(PREFIX + "*" ); return ans.stream().map(key -> key.substring(PREFIX.length())).collect(Collectors.toList()); } @Override public List<Message> findByConversationId (String conversationId) { String key = PREFIX + conversationId; List<String> ans = redisTemplate.opsForList().range(key, 0 , -1 ); if (CollectionUtils.isEmpty(ans)) { return Collections.emptyList(); } return ans.stream().map(item -> JsonUtil.toObj(item, Message.class)).collect(Collectors.toList()); } @Override public void saveAll (String conversationId, List<Message> messages) { String key = PREFIX + conversationId; List<String> messageJsons = messages.stream().map(JsonUtil::toStr).toList(); redisTemplate.delete(key); redisTemplate.opsForList().leftPushAll(key, messageJsons); } @Override public void deleteByConversationId (String conversationId) { redisTemplate.delete(PREFIX + conversationId); } }
基于redis的存储实现相对来讲并不算复杂,无非就是Message的读取与写入,这里我们采用json进行序列化与反序列化
我们实现一个自定义的反序列化策略,在获取数据时,只要Message中的 messageType + text,因为这些会一并传递给大模型,其他的元数据并没有太大意义,还会消耗我们的token
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 public class MessageDeserializer extends JsonDeserializer <Message > { private static final Logger logger = LoggerFactory.getLogger(MessageDeserializer.class); private final Map<String, Function<String, Message>> msgFactor = Map.of( "USER" , UserMessage::new , "SYSTEM" , UserMessage::new , "ASSISTANT" , UserMessage::new ); @Override public Message deserialize (JsonParser p, DeserializationContext ctxt) throws IOException, JacksonException { JsonNode node = p.getCodec().readTree(p); if (node.isTextual()) { return new UserMessage(node.asText()); } String type = extractMessageType(node); String content = extractContent(node); return Optional.ofNullable(type).map(String::toUpperCase).map(msgFactor::get).orElseGet(() -> { if (type == null ) { logger.warn("Message type not found, defaulting to USER" ); } else { logger.warn("Unknown message type: {}, defaulting to USER" , type); } return msgFactor.get("USER" ); }).apply(content); } private String extractMessageType (JsonNode node) { return Optional.ofNullable(node.get("messageType" )) .map(JsonNode::asText) .orElse(null ); } private String extractContent (JsonNode node) { return Optional.ofNullable(node.get("text" )) .map(JsonNode::asText) .orElseGet(node::toString); } }
对应的JsonUtil的实现则比较基础,一个统一的json工具类
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 public class JsonUtil { private static ObjectMapper mapper = new ObjectMapper(); static { mapper.findAndRegisterModules(); SimpleModule module = new SimpleModule(); module .addDeserializer(Message.class, new MessageDeserializer()); mapper.registerModule(module ); mapper.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); } public static String toStr (Object o) { try { return mapper.writeValueAsString(o); } catch (IOException e) { throw new RuntimeException(e); } } public static <T> T toObj (String s, Class<T> clazz) { try { return mapper.readValue(s, clazz); } catch (IOException e) { throw new RuntimeException(e); } } }
5. 定义ChatMemory 定义一个自定义的ChatMemory,这里我们使用 RedisChatMemoryRepository 来作为消息的存储仓库
1 2 3 4 5 6 7 8 9 @Configuration public class MemConfig { @Bean public ChatMemory jdbcChatMemory (RedisChatMemoryRepository chatMemoryRepository) { return MessageWindowChatMemory.builder() .chatMemoryRepository(chatMemoryRepository) .build(); } }
6. 示例测试 提供一个聊天接口,第一个参数为用户标识,用于区分用户的聊天记录
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 @RestController public class ChatController { private final ChatClient chatClient; public ChatController (ChatModel chatModel, ChatMemory chatMemory) { this .chatClient = ChatClient.builder(chatModel) .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build(), new SimpleLoggerAdvisor()) .build(); } @GetMapping ("/{user}/chat" ) public Object chat (@PathVariable String user, String msg) { return chatClient.prompt().user(msg) .advisors(a -> a.param(ChatMemory.CONVERSATION_ID, user)) .call().content(); } }
二、小结 本文介绍的是如何通过实现SpringAI定义的ChatMemoryRepository接口来完成自定义的对话记录持久化,通过这种方式,可以实现更灵活的上下文管控
文中所有涉及到的代码,可以到项目中获取 https://github.com/liuyueyi/spring-ai-demo
微信公众号: 一灰灰Blog 尽信书则不如,以上内容,纯属一家之言,因个人能力有限,难免有疏漏和错误之处,如发现bug或者有更好的建议,欢迎批评指正,不吝感激
下面一灰灰的个人博客,记录所有学习和工作中的博文,欢迎大家前去逛逛
打赏
如果觉得我的文章对您有帮助,请随意打赏。
微信打赏
支付宝打赏