Browse Source

【增加】增加模型联网搜索,重写用户 prompt

cherishsince 8 tháng trước cách đây
mục cha
commit
cb2e9e044f

+ 2 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/controller/admin/chat/vo/message/AiChatMessageSendReqVO.java

@@ -22,4 +22,6 @@ public class AiChatMessageSendReqVO {
     @Schema(description = "是否携带上下文", example = "true")
     private Boolean useContext;
 
+    @Schema(description = "搜索enable", example = "true")
+    private Boolean searchEnable;
 }

+ 58 - 7
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -23,6 +23,8 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
 import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
 import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
 import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
+import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService;
+import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO;
 import jakarta.annotation.Resource;
 import lombok.extern.slf4j.Slf4j;
 import org.springframework.ai.chat.messages.Message;
@@ -60,7 +62,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
     @Resource
     private AiChatMessageMapper chatMessageMapper;
-
     @Resource
     private AiChatConversationService chatConversationService;
     @Resource
@@ -69,6 +70,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
     private AiApiKeyService apiKeyService;
     @Resource
     private AiKnowledgeSegmentService knowledgeSegmentService;
+    @Resource
+    private WebSearchService webSearchService;
 
     @Transactional(rollbackFor = Exception.class)
     public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
@@ -93,8 +96,11 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         // 3.2 召回段落
         List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
 
+        // 3.3 联网搜索内容
+        List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
+
         // 3.3 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
         ChatResponse chatResponse = chatModel.call(prompt);
 
         // 3.4 段式返回
@@ -124,12 +130,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
                 userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
 
-
         // 3.2 召回段落
         List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
 
+        // 3.3 联网搜索
+        // todo count 看是否需要放到配置文件
+        List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
+
         // 3.3 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
+        Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 3.4 流式返回
@@ -155,6 +164,29 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         }).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
     }
 
+    /**
+     * 获取 web search
+     *
+     * @param prompt 提示词
+     * @param searchEnable 查询到会否开启
+     * @param count 查询数量
+     * @return 返回查询结果
+     */
+    private List<WebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) {
+        if (searchEnable != null && searchEnable) {
+            List<WebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count);
+            Map<String, String> webCrawlerRespMap
+                    = webSearchService.webCrawler(webSearchRespList.stream().map(WebSearchRespVO::getUrl).toList());
+            for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
+                if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) {
+                    continue;
+                }
+                webSearchRespVO.setContent(webCrawlerRespMap.get(webSearchRespVO.getUrl()));
+            }
+        }
+        return Collections.emptyList();
+    }
+
     private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
         if (Objects.isNull(knowledgeId)) {
             return Collections.emptyList();
@@ -162,8 +194,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
     }
 
-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
-                               AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+                               List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model,
+                               AiChatMessageSendReqVO sendReqVO, List<WebSearchRespVO> webSearchRespList) {
         // 1. 构建 Prompt Message 列表
         List<Message> chatMessages = new ArrayList<>();
 
@@ -184,7 +217,25 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
         contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
         // 1.4 user message 新发送消息
-        chatMessages.add(new UserMessage(sendReqVO.getContent()));
+        if (sendReqVO.getSearchEnable() != null
+                && sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) {
+
+            StringBuilder promptBuilder = StrUtil.builder();
+            promptBuilder.append("## 以下是联网搜索内容: \n");
+            int i = 1;
+            for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
+                promptBuilder.append("[内容%s begin]".formatted(i)).append("\n");
+                promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n");
+                promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n");
+                promptBuilder.append("内容:").append(webSearchRespVO.getContent()).append("\n");
+                promptBuilder.append("[内容%s end]".formatted(i)).append("\n");
+                i++;
+            }
+            promptBuilder.append("## 用户问题如下: \n");
+            promptBuilder.append(sendReqVO.getContent()).append("\n");
+        } else {
+            chatMessages.add(new UserMessage(sendReqVO.getContent()));
+        }
 
         // 2. 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());

+ 4 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/websearch/vo/WebSearchRespVO.java

@@ -22,4 +22,8 @@ public class WebSearchRespVO {
      * 摘要
      */
     private String snippet;
+    /**
+     * 网站内容
+     */
+    private String content;
 } 

+ 1 - 0
yudao-module-ai/yudao-module-ai-biz/src/test/java/cn/iocoder/yudao/module/ai/package-info.java

@@ -0,0 +1 @@
+package cn.iocoder.yudao.module.ai;

+ 3 - 1
yudao-server/src/main/resources/application-local.yaml

@@ -228,7 +228,9 @@ yudao:
   wxa-subscribe-message:
     miniprogram-state: developer # 跳转小程序类型:开发版为 “developer”;体验版为 “trial”为;正式版为 “formal”
   tencent-lbs-key: TVDBZ-TDILD-4ON4B-PFDZA-RNLKH-VVF6E # QQ 地图的密钥 https://lbs.qq.com/service/staticV2/staticGuide/staticDoc
-
+  web-search:
+      bing-api-key: xx
+      google-api-key: xx
 justauth:
   enabled: true
   type: