Forráskód Böngészése

!1323 feat:AI工具新增ToolContext
Merge pull request !1323 from Ren/feature/ai

芋道源码 6 hónapja
szülő
commit
fe2122d3be

+ 12 - 4
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -4,10 +4,12 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
@@ -115,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
 
         // 3.2 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
         // 3.3 更新响应内容
@@ -164,7 +166,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
 
         // 4.2 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 4.3 流式返回
@@ -220,7 +222,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegments;
     }
 
-    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+    private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List<AiChatMessageDO> messages,
             List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
             AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
@@ -247,16 +249,22 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
 
         // 2.1 查询 tool 工具
         Set<String> toolNames = null;
+        Map<String,Object> toolContext = Map.of();
         if (conversation.getRoleId() != null) {
             AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
             if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
                 toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
+                // 2.1.1 构建 Function Calling 的上下文参数
+                toolContext = Map.of(
+                    AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId())
+                    .setRoleId(conversation.getRoleId())
+                    .setConversationId(conversation.getId()));
             }
         }
         // 2.2 构建 ChatOptions 对象
         AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
         ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
-                conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
+                conversation.getTemperature(), conversation.getMaxTokens(), toolNames, toolContext);
         return new Prompt(chatMessages, chatOptions);
     }
 

+ 43 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserIdQueryToolFunction.java

@@ -0,0 +1,43 @@
+package cn.iocoder.yudao.module.ai.service.model.tool;
+
+import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.ai.chat.model.ToolContext;
+import org.springframework.stereotype.Component;
+
+import java.util.function.BiFunction;
+
+/**
+ * 工具:用户ID查询(上下文参数Demo)
+ *
+ * @author Ren
+ */
+@Component("userid_query")
+public class UserIdQueryToolFunction
+        implements BiFunction<UserIdQueryToolFunction.Request, ToolContext, UserIdQueryToolFunction.Response> {
+
+    @Data
+    @JsonClassDescription("用户ID查询")
+    public static class Request { }
+
+    @Data
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class Response {
+        /**
+         * 用户ID
+         */
+        private Long UserId;
+
+    }
+    @Override
+    public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) {
+        // 获取当前登录用户
+        AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY);
+
+        return new Response(context.getUserId());
+    }
+}

+ 37 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/pojo/AiToolContext.java

@@ -0,0 +1,37 @@
+package cn.iocoder.yudao.framework.ai.core.pojo;
+
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.ai.chat.model.ChatModel;
+import org.springframework.ai.chat.model.StreamingChatModel;
+
+/**
+ * 工具上下文参数 DTO,让AI工具可以处理当前用户的相关信息
+ *
+ */
+@Data
+@NoArgsConstructor
+@AllArgsConstructor
+public class AiToolContext {
+    public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT";
+    /**
+     * 用户ID
+     */
+    private Long userId;
+
+    /**
+     * 聊天模型
+     */
+    private StreamingChatModel chatModel;
+
+    /**
+     * 关联的聊天角色Id
+     */
+    private Long roleId;
+
+    /**
+     * 会话Id
+     */
+    private Long conversationId;
+}

+ 10 - 9
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -15,6 +15,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
 
 import java.util.Collections;
+import java.util.Map;
 import java.util.Set;
 
 /**
@@ -25,28 +26,28 @@ import java.util.Set;
 public class AiUtils {
 
     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        return buildChatOptions(platform, model, temperature, maxTokens, null);
+        return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of());
     }
 
     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
-                                               Set<String> toolNames) {
+                                               Set<String> toolNames, Map<String, Object> toolContext) {
         toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
         // noinspection EnhancedSwitchMigration
         switch (platform) {
             case TONG_YI:
                 return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
-                        .withFunctions(toolNames).build();
+                        .withFunctions(toolNames).withToolContext(toolContext).build();
             case YI_YAN:
                 return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
             case ZHI_PU:
                 return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case MINI_MAX:
                 return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case MOONSHOT:
                 return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .functions(toolNames).build();
+                        .functions(toolNames).toolContext(toolContext).build();
             case OPENAI:
             case DEEP_SEEK: // 复用 OpenAI 客户端
             case DOU_BAO: // 复用 OpenAI 客户端
@@ -55,14 +56,14 @@ public class AiUtils {
             case SILICON_FLOW: // 复用 OpenAI 客户端
             case BAI_CHUAN: // 复用 OpenAI 客户端
                 return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             case AZURE_OPENAI:
                 // TODO 芋艿:貌似没 model 字段???!
                 return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             case OLLAMA:
                 return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
-                        .toolNames(toolNames).build();
+                        .toolNames(toolNames).toolContext(toolContext).build();
             default:
                 throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
         }