Sfoglia il codice sorgente

feat:【AI 大模型】增加 AI ToolContext 上下文

YunaiV 6 mesi fa
parent
commit
e11ee654ef

+ 7 - 14
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/chat/AiChatMessageServiceImpl.java

@@ -4,12 +4,10 @@ import cn.hutool.core.collection.CollUtil;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
-import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
 import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
 import cn.iocoder.yudao.framework.common.pojo.CommonResult;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
-import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
 import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessagePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.chat.vo.message.AiChatMessageRespVO;
@@ -103,8 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         ChatModel chatModel = modalService.getChatModel(model.getId());
 
         // 2. 知识库找回
-        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
-                conversation);
+        List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation);
 
         // 3. 插入 user 发送消息
         AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@@ -117,7 +114,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
 
         // 3.2 创建 chat 需要的 Prompt
-        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         ChatResponse chatResponse = chatModel.call(prompt);
 
         // 3.3 更新响应内容
@@ -166,7 +163,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
                 knowledgeSegments);
 
         // 4.2 构建 Prompt,并进行调用
-        Prompt prompt = buildPrompt(chatModel, conversation, historyMessages, knowledgeSegments, model, sendReqVO);
+        Prompt prompt = buildPrompt(conversation, historyMessages, knowledgeSegments, model, sendReqVO);
         Flux<ChatResponse> streamResponse = chatModel.stream(prompt);
 
         // 4.3 流式返回
@@ -222,9 +219,9 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
         return knowledgeSegments;
     }
 
-    private Prompt buildPrompt(StreamingChatModel chatModel, AiChatConversationDO conversation, List<AiChatMessageDO> messages,
-            List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
-            AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
+    private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
+                               List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
+                               AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
         List<Message> chatMessages = new ArrayList<>();
         // 1.1 System Context 角色设定
         if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
@@ -254,11 +251,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
             AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
             if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
                 toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
-                // 2.1.1 构建 Function Calling 的上下文参数
-                toolContext = Map.of(
-                    AiToolContext.CONTEXT_KEY, new AiToolContext().setChatModel(chatModel).setUserId(SecurityFrameworkUtils.getLoginUserId())
-                    .setRoleId(conversation.getRoleId())
-                    .setConversationId(conversation.getId()));
+                toolContext = AiUtils.buildCommonToolContext();
             }
         }
         // 2.2 构建 ChatOptions 对象

+ 7 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeService.java

@@ -29,6 +29,13 @@ public interface AiKnowledgeService {
      */
     void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
 
+    /**
+     * 删除知识库
+     *
+     * @param id 知识库编号
+     */
+    void deleteKnowledge(Long id);
+
     /**
      * 获得知识库
      *

+ 5 - 2
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/knowledge/AiKnowledgeServiceImpl.java

@@ -1,13 +1,11 @@
 package cn.iocoder.yudao.module.ai.service.knowledge;
 
 import cn.hutool.core.util.ObjUtil;
-import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
 import cn.iocoder.yudao.framework.common.pojo.PageResult;
 import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
 import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
 import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
-import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
 import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
 import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
 import cn.iocoder.yudao.module.ai.service.model.AiModelService;
@@ -67,6 +65,11 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
         }
     }
 
+    @Override
+    public void deleteKnowledge(Long id) {
+
+    }
+
     @Override
     public AiKnowledgeDO getKnowledge(Long id) {
         return knowledgeMapper.selectById(id);

+ 0 - 43
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserIdQueryToolFunction.java

@@ -1,43 +0,0 @@
-package cn.iocoder.yudao.module.ai.service.model.tool;
-
-import cn.iocoder.yudao.framework.ai.core.pojo.AiToolContext;
-import com.fasterxml.jackson.annotation.JsonClassDescription;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.springframework.ai.chat.model.ToolContext;
-import org.springframework.stereotype.Component;
-
-import java.util.function.BiFunction;
-
-/**
- * 工具:用户ID查询(上下文参数Demo)
- *
- * @author Ren
- */
-@Component("userid_query")
-public class UserIdQueryToolFunction
-        implements BiFunction<UserIdQueryToolFunction.Request, ToolContext, UserIdQueryToolFunction.Response> {
-
-    @Data
-    @JsonClassDescription("用户ID查询")
-    public static class Request { }
-
-    @Data
-    @AllArgsConstructor
-    @NoArgsConstructor
-    public static class Response {
-        /**
-         * 用户ID
-         */
-        private Long UserId;
-
-    }
-    @Override
-    public UserIdQueryToolFunction.Response apply(UserIdQueryToolFunction.Request request, ToolContext toolContext) {
-        // 获取当前登录用户
-        AiToolContext context = (AiToolContext) toolContext.getContext().get(AiToolContext.CONTEXT_KEY);
-
-        return new Response(context.getUserId());
-    }
-}

+ 75 - 0
yudao-module-ai/yudao-module-ai-biz/src/main/java/cn/iocoder/yudao/module/ai/service/model/tool/UserProfileQueryToolFunction.java

@@ -0,0 +1,75 @@
+package cn.iocoder.yudao.module.ai.service.model.tool;
+
+import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
+import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
+import cn.iocoder.yudao.framework.security.core.LoginUser;
+import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
+import cn.iocoder.yudao.module.system.api.user.AdminUserApi;
+import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
+import com.fasterxml.jackson.annotation.JsonClassDescription;
+import jakarta.annotation.Resource;
+import lombok.AllArgsConstructor;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+import org.springframework.ai.chat.model.ToolContext;
+import org.springframework.stereotype.Component;
+
+import java.util.function.BiFunction;
+
+/**
+ * 工具:当前用户信息查询
+ *
+ * 同时,也是展示 ToolContext 上下文的使用
+ *
+ * @author Ren
+ */
+@Component("user_profile_query")
+public class UserProfileQueryToolFunction
+        implements BiFunction<UserProfileQueryToolFunction.Request, ToolContext, UserProfileQueryToolFunction.Response> {
+
+    @Resource
+    private AdminUserApi adminUserApi;
+
+    @Data
+    @JsonClassDescription("当前用户信息查询")
+    public static class Request { }
+
+    @Data
+    @AllArgsConstructor
+    @NoArgsConstructor
+    public static class Response {
+
+        /**
+         * 用户ID
+         */
+        private Long id;
+        /**
+         * 用户昵称
+         */
+        private String nickname;
+
+        /**
+         * 手机号码
+         */
+        private String mobile;
+        /**
+         * 用户头像
+         */
+        private String avatar;
+
+    }
+
+    @Override
+    public UserProfileQueryToolFunction.Response apply(UserProfileQueryToolFunction.Request request, ToolContext toolContext) {
+        LoginUser loginUser = (LoginUser) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_LOGIN_USER);
+        Long tenantId = (Long) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_TENANT_ID);
+        if (loginUser == null | tenantId == null) {
+            return null;
+        }
+        return TenantUtils.execute(tenantId, () -> {
+            AdminUserRespDTO user = adminUserApi.getUser(loginUser.getId());
+            return BeanUtils.toBean(user, Response.class);
+        });
+    }
+
+}

+ 12 - 0
yudao-module-ai/yudao-spring-boot-starter-ai/pom.xml

@@ -24,6 +24,18 @@
             <artifactId>yudao-common</artifactId>
         </dependency>
 
+        <!-- 业务组件 -->
+        <dependency>
+            <groupId>cn.iocoder.boot</groupId>
+            <artifactId>yudao-spring-boot-starter-biz-tenant</artifactId>
+        </dependency>
+
+        <!-- Web 相关 -->
+        <dependency>
+            <groupId>cn.iocoder.boot</groupId>
+            <artifactId>yudao-spring-boot-starter-security</artifactId>
+        </dependency>
+
         <!-- Spring AI Model 模型接入 -->
         <dependency>
             <groupId>org.springframework.ai</groupId>

+ 0 - 37
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/pojo/AiToolContext.java

@@ -1,37 +0,0 @@
-package cn.iocoder.yudao.framework.ai.core.pojo;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.chat.model.StreamingChatModel;
-
-/**
- * 工具上下文参数 DTO,让AI工具可以处理当前用户的相关信息
- *
- */
-@Data
-@NoArgsConstructor
-@AllArgsConstructor
-public class AiToolContext {
-    public static final String CONTEXT_KEY = "AI_TOOL_CONTEXT";
-    /**
-     * 用户ID
-     */
-    private Long userId;
-
-    /**
-     * 聊天模型
-     */
-    private StreamingChatModel chatModel;
-
-    /**
-     * 关联的聊天角色Id
-     */
-    private Long roleId;
-
-    /**
-     * 会话Id
-     */
-    private Long conversationId;
-}

+ 14 - 1
yudao-module-ai/yudao-spring-boot-starter-ai/src/main/java/cn/iocoder/yudao/framework/ai/core/util/AiUtils.java

@@ -3,6 +3,8 @@ package cn.iocoder.yudao.framework.ai.core.util;
 import cn.hutool.core.util.ObjUtil;
 import cn.hutool.core.util.StrUtil;
 import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
+import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
+import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
 import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
 import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
 import org.springframework.ai.chat.messages.*;
@@ -15,6 +17,7 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
 import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
 
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
@@ -25,8 +28,11 @@ import java.util.Set;
  */
 public class AiUtils {
 
+    public static final String TOOL_CONTEXT_LOGIN_USER = "LOGIN_USER";
+    public static final String TOOL_CONTEXT_TENANT_ID = "TENANT_ID";
+
     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
-        return buildChatOptions(platform, model, temperature, maxTokens, null, Map.of());
+        return buildChatOptions(platform, model, temperature, maxTokens, null, null);
     }
 
     public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
@@ -85,4 +91,11 @@ public class AiUtils {
         throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
     }
 
+    public static Map<String, Object> buildCommonToolContext() {
+        Map<String, Object> context = new HashMap<>();
+        context.put(TOOL_CONTEXT_LOGIN_USER, SecurityFrameworkUtils.getLoginUser());
+        context.put(TOOL_CONTEXT_TENANT_ID, TenantContextHolder.getTenantId());
+        return context;
+    }
+
 }