|
@@ -23,6 +23,8 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
|
|
|
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
|
import cn.iocoder.yudao.module.ai.service.knowledge.AiKnowledgeSegmentService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
|
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
|
|
|
|
|
+import cn.iocoder.yudao.module.ai.service.websearch.WebSearchService;
|
|
|
|
|
+import cn.iocoder.yudao.module.ai.service.websearch.vo.WebSearchRespVO;
|
|
|
import jakarta.annotation.Resource;
|
|
import jakarta.annotation.Resource;
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
import org.springframework.ai.chat.messages.Message;
|
|
import org.springframework.ai.chat.messages.Message;
|
|
@@ -60,7 +62,6 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
|
|
|
|
|
@Resource
|
|
@Resource
|
|
|
private AiChatMessageMapper chatMessageMapper;
|
|
private AiChatMessageMapper chatMessageMapper;
|
|
|
-
|
|
|
|
|
@Resource
|
|
@Resource
|
|
|
private AiChatConversationService chatConversationService;
|
|
private AiChatConversationService chatConversationService;
|
|
|
@Resource
|
|
@Resource
|
|
@@ -69,6 +70,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
private AiApiKeyService apiKeyService;
|
|
private AiApiKeyService apiKeyService;
|
|
|
@Resource
|
|
@Resource
|
|
|
private AiKnowledgeSegmentService knowledgeSegmentService;
|
|
private AiKnowledgeSegmentService knowledgeSegmentService;
|
|
|
|
|
+ @Resource
|
|
|
|
|
+ private WebSearchService webSearchService;
|
|
|
|
|
|
|
|
@Transactional(rollbackFor = Exception.class)
|
|
@Transactional(rollbackFor = Exception.class)
|
|
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
|
public AiChatMessageSendRespVO sendMessage(AiChatMessageSendReqVO sendReqVO, Long userId) {
|
|
@@ -93,8 +96,11 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
// 3.2 召回段落
|
|
// 3.2 召回段落
|
|
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
|
|
|
|
|
|
|
|
+ // 3.3 联网搜索内容
|
|
|
|
|
+ List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
|
|
|
|
|
+
|
|
|
// 3.3 创建 chat 需要的 Prompt
|
|
// 3.3 创建 chat 需要的 Prompt
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
|
|
|
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
|
|
|
ChatResponse chatResponse = chatModel.call(prompt);
|
|
ChatResponse chatResponse = chatModel.call(prompt);
|
|
|
|
|
|
|
|
// 3.4 段式返回
|
|
// 3.4 段式返回
|
|
@@ -124,12 +130,15 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
AiChatMessageDO assistantMessage = createChatMessage(conversation.getId(), userMessage.getId(), model,
|
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
userId, conversation.getRoleId(), MessageType.ASSISTANT, "", sendReqVO.getUseContext());
|
|
|
|
|
|
|
|
-
|
|
|
|
|
// 3.2 召回段落
|
|
// 3.2 召回段落
|
|
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
List<AiKnowledgeSegmentDO> segmentList = recallSegment(sendReqVO.getContent(), conversation.getKnowledgeId());
|
|
|
|
|
|
|
|
|
|
+ // 3.3 联网搜索
|
|
|
|
|
+ // todo count 看是否需要放到配置文件
|
|
|
|
|
+ List<WebSearchRespVO> webSearch = getWebSearch(sendReqVO.getContent(), sendReqVO.getSearchEnable(), 10);
|
|
|
|
|
+
|
|
|
// 3.3 构建 Prompt,并进行调用
|
|
// 3.3 构建 Prompt,并进行调用
|
|
|
- Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO);
|
|
|
|
|
|
|
+ Prompt prompt = buildPrompt(conversation, historyMessages, segmentList, model, sendReqVO, webSearch);
|
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
|
|
|
|
|
|
|
|
// 3.4 流式返回
|
|
// 3.4 流式返回
|
|
@@ -155,6 +164,29 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
|
}).onErrorResume(error -> Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR)));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
|
+ * 获取 web search
|
|
|
|
|
+ *
|
|
|
|
|
+ * @param prompt 提示词
|
|
|
|
|
+ * @param searchEnable 查询到会否开启
|
|
|
|
|
+ * @param count 查询数量
|
|
|
|
|
+ * @return 返回查询结果
|
|
|
|
|
+ */
|
|
|
|
|
+ private List<WebSearchRespVO> getWebSearch(String prompt, Boolean searchEnable, int count) {
|
|
|
|
|
+ if (searchEnable != null && searchEnable) {
|
|
|
|
|
+ List<WebSearchRespVO> webSearchRespList = webSearchService.bingSearch(prompt, count);
|
|
|
|
|
+ Map<String, String> webCrawlerRespMap
|
|
|
|
|
+ = webSearchService.webCrawler(webSearchRespList.stream().map(WebSearchRespVO::getUrl).toList());
|
|
|
|
|
+ for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
|
|
|
|
|
+ if (!webCrawlerRespMap.containsKey(webSearchRespVO.getUrl())) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ webSearchRespVO.setContent(webCrawlerRespMap.get(webSearchRespVO.getUrl()));
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+ return Collections.emptyList();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
|
|
private List<AiKnowledgeSegmentDO> recallSegment(String content, Long knowledgeId) {
|
|
|
if (Objects.isNull(knowledgeId)) {
|
|
if (Objects.isNull(knowledgeId)) {
|
|
|
return Collections.emptyList();
|
|
return Collections.emptyList();
|
|
@@ -162,8 +194,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
|
|
return knowledgeSegmentService.similaritySearch(new AiKnowledgeSegmentSearchReqVO().setKnowledgeId(knowledgeId).setContent(content));
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,List<AiKnowledgeSegmentDO> segmentList,
|
|
|
|
|
- AiChatModelDO model, AiChatMessageSendReqVO sendReqVO) {
|
|
|
|
|
|
|
+ private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
|
|
|
|
|
+ List<AiKnowledgeSegmentDO> segmentList, AiChatModelDO model,
|
|
|
|
|
+ AiChatMessageSendReqVO sendReqVO, List<WebSearchRespVO> webSearchRespList) {
|
|
|
// 1. 构建 Prompt Message 列表
|
|
// 1. 构建 Prompt Message 列表
|
|
|
List<Message> chatMessages = new ArrayList<>();
|
|
List<Message> chatMessages = new ArrayList<>();
|
|
|
|
|
|
|
@@ -184,7 +217,25 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
|
|
|
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
|
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
|
|
|
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
|
contextMessages.forEach(message -> chatMessages.add(AiUtils.buildMessage(message.getType(), message.getContent())));
|
|
|
// 1.4 user message 新发送消息
|
|
// 1.4 user message 新发送消息
|
|
|
- chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
|
|
|
|
|
|
+ if (sendReqVO.getSearchEnable() != null
|
|
|
|
|
+ && sendReqVO.getSearchEnable() && CollUtil.isNotEmpty(webSearchRespList)) {
|
|
|
|
|
+
|
|
|
|
|
+ StringBuilder promptBuilder = StrUtil.builder();
|
|
|
|
|
+ promptBuilder.append("## 以下是联网搜索内容: \n");
|
|
|
|
|
+ int i = 1;
|
|
|
|
|
+ for (WebSearchRespVO webSearchRespVO : webSearchRespList) {
|
|
|
|
|
+ promptBuilder.append("[内容%s begin]".formatted(i)).append("\n");
|
|
|
|
|
+ promptBuilder.append("标题:").append(webSearchRespVO.getTitle()).append("\n");
|
|
|
|
|
+ promptBuilder.append("地址:").append(webSearchRespVO.getUrl()).append("\n");
|
|
|
|
|
+ promptBuilder.append("内容:").append(webSearchRespVO.getContent()).append("\n");
|
|
|
|
|
+ promptBuilder.append("[内容%s end]".formatted(i)).append("\n");
|
|
|
|
|
+ i++;
|
|
|
|
|
+ }
|
|
|
|
|
+ promptBuilder.append("## 用户问题如下: \n");
|
|
|
|
|
+ promptBuilder.append(sendReqVO.getContent()).append("\n");
|
|
|
|
|
+ } else {
|
|
|
|
|
+ chatMessages.add(new UserMessage(sendReqVO.getContent()));
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
// 2. 构建 ChatOptions 对象
|
|
// 2. 构建 ChatOptions 对象
|
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|
|
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
|