refactor(task): 重构异步任务处理流程

This commit is contained in:
2026-05-27 09:36:26 +08:00
parent 2548ffc7ac
commit d74559ae74
10 changed files with 162 additions and 212 deletions

View File

@@ -5,10 +5,10 @@ import (
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
"gitea.com/red-future/common/beans"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"prompts-core/common/util"
@@ -44,17 +44,27 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if err != nil {
return nil, nil, fmt.Errorf("获取用户信息失败: %w", err)
}
chatModel, err := getChatModel(ctx, userInfo.UserName)
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
IsChatModel: new(1),
})
if err != nil {
return nil, nil, err
}
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
if err != nil {
return nil, nil, err
}
aiModel, err := getAIModel(ctx, userInfo.UserName, req.ModelName)
if err != nil {
return nil, nil, err
if aiModel == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModel, nil
}
@@ -73,51 +83,24 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *entity.AsynchModel) er
return fmt.Errorf("UserForm 内容超出窗口大小: 超出 %d tokens可用窗口 %d tokens请精简后重试",
exceedTokens, availableWindow)
}
return nil
}
// handlePromptBuild 处理提示词构建BuildType=1
func handlePromptBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 获取历史会话
history, err := GetHistoryMessages(ctx, req.SessionId)
history, err := session.GetHistoryMessages(ctx, req.SessionId)
if err != nil {
g.Log().Errorf(ctx, "获取历史会话失败: %v将不使用历史会话", err)
history = nil
}
// 调用推理模型
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, history)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
if err = saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
if err := saveComposeTask(ctx, taskID, req); err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
}, nil
}
// saveComposeTask 保存组合任务记录
func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessagesReq) error {
_, err := dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
@@ -126,77 +109,70 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
return err
if err != nil {
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// getChatModel 获取聊天模型
func getChatModel(ctx context.Context, userName string) (*entity.AsynchModel, error) {
chatModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
IsChatModel: new(1),
// handleNodeBuild 处理节点构建BuildType=2
func handleNodeBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *entity.AsynchModel) (*dto.ComposeMessagesRes, error) {
taskID, id, err := callInferenceModel(ctx, req, chatModel, aiModel, nil)
if err != nil {
return nil, fmt.Errorf("调用推理模型失败: %w", err)
}
// 保存任务记录
_, err = dao.ComposeTask.Insert(ctx, &entity.ComposeTask{
TaskId: taskID,
ModelName: req.ModelName,
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
Status: public.ComposeStatusPending,
})
if err != nil {
return nil, fmt.Errorf("查询聊天模型失败: %w", err)
return nil, fmt.Errorf("保存任务记录失败: %w", err)
}
if chatModel == nil {
return nil, errors.New("当前没有对话模型,请添加")
}
return chatModel, nil
}
// getAIModel 获取AI模型
func getAIModel(ctx context.Context, userName, modelName string) (*entity.AsynchModel, error) {
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userName},
ModelName: modelName,
})
if err != nil {
return nil, fmt.Errorf("查询AI模型失败: %w", err)
}
if aiModel == nil {
return nil, fmt.Errorf("需要构建的模型 %s 不存在", modelName)
}
return aiModel, nil
return &dto.ComposeMessagesRes{
TaskId: taskID,
EpicycleId: id,
}, nil
}
// callInferenceModel 调用推理模型
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, idModel *entity.AsynchModel, history []map[string]any) (string, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, idModel, history)
func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (string, int64, error) {
taskReq, err := buildInferenceRequest(ctx, req, chatModel, aiModel, history)
if err != nil {
return "", fmt.Errorf("构建推理请求失败: %w", err)
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {
return "", fmt.Errorf("创建网关任务失败: %w", err)
return "", 0, fmt.Errorf("创建网关任务失败: %w", err)
}
if taskID == "" {
return "", errors.New("网关未返回taskId")
return "", 0, errors.New("网关未返回taskId")
}
return taskID, nil
}
// createDefaultResult 创建默认结果
func createDefaultResult(data map[string]any) map[string]any {
if data == nil {
data = make(map[string]any)
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{data},
}
return taskID, id, nil
}
// Callback 回调处理
func Callback(ctx context.Context, req *dto.CallbackReq) error {
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Messages))
// 查询任务
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: req.TaskId,
@@ -220,7 +196,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
ResultText: req.Messages,
})
// 用更新后的值发送回调
if composeTask.CallbackUrl != "" {
@@ -241,11 +217,11 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
var messages map[string]any
switch composeTask.BuildType {
case public.BuildTypePrompt: // 提示词构建解析
messages = ParsePromptResult(req.Text)
messages = ParsePromptResult(req.Messages)
case public.BuildTypeNode: // 节点构建解析
messages = ParseNodeResult(req.Text)
messages = ParseNodeResult(req.Messages)
default:
messages = gjson.New(req.Text).Map()
messages = req.Messages
}
// 2. 处理附加字段
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
@@ -257,7 +233,7 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
GatewayState: req.State,
OssFile: req.OssFile,
FileType: req.FileType,
ResultText: req.Text,
ResultText: req.Messages,
})
if err != nil {
g.Log().Errorf(ctx, "[Callback] 更新成功状态失败 taskId=%s err=%v", req.TaskId, err)
@@ -278,18 +254,12 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
}
// ParsePromptResult 解析提示词构建结果
func ParsePromptResult(raw string) map[string]any {
var wrapper map[string]any
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
contentStr, ok := wrapper["content"].(string)
func ParsePromptResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if !ok || contentStr == "" {
return createDefaultResult(wrapper)
return raw
}
// 先尝试解析为数组
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
return map[string]any{
"total_rounds": len(roundsArray),
@@ -297,7 +267,6 @@ func ParsePromptResult(raw string) map[string]any {
}
}
// 再尝试解析为单个对象
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
return map[string]any{
"total_rounds": 1,
@@ -305,7 +274,7 @@ func ParsePromptResult(raw string) map[string]any {
}
}
return createDefaultResult(map[string]any{"content": contentStr})
return map[string]any{"content": contentStr}
}
func tryParseAsMapArray(jsonStr string) []map[string]any {
@@ -330,22 +299,20 @@ func tryParseAsMap(jsonStr string) map[string]any {
return obj
}
// ParseNodeResult 解析节点构建结果
func ParseNodeResult(raw string) map[string]any {
var result map[string]any
if err := json.Unmarshal([]byte(raw), &result); err != nil {
return createDefaultResult(map[string]any{"raw": raw})
}
if contentStr, ok := result["content"].(string); ok && contentStr != "" {
func ParseNodeResult(raw map[string]any) map[string]any {
contentStr, ok := raw["content"].(string)
if ok && contentStr != "" {
var inner map[string]any
if err := json.Unmarshal([]byte(contentStr), &inner); err == nil {
result = inner
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{inner},
}
}
}
return map[string]any{
"total_rounds": 1,
"rounds": []map[string]any{result},
"rounds": []map[string]any{raw},
}
}