refactor(task): 重构任务服务和数据结构
This commit is contained in:
@@ -26,14 +26,12 @@ type ComposeMessagesRes struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type CallbackReq struct {
|
type CallbackReq struct {
|
||||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||||
State int `json:"state" dc:"网关任务状态"`
|
State int `json:"state" dc:"网关任务状态"`
|
||||||
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
OssFile string `json:"oss_file" dc:"结果文件地址"`
|
||||||
FileType string `json:"file_type" dc:"结果文件类型"`
|
FileType string `json:"file_type" dc:"结果文件类型"`
|
||||||
Messages map[string]any `json:"messages" dc:"消息数组"`
|
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
||||||
ErrorMsg string `json:"error_msg" dc:"错误信息"`
|
|
||||||
EpicycleId int64 `json:"epicycleId" dc:"轮次ID"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type CallbackRes struct {
|
type CallbackRes struct {
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
"prompts-core/common/util"
|
"prompts-core/common/util"
|
||||||
"prompts-core/model/entity"
|
"prompts-core/model/entity"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/os/gtime"
|
"github.com/gogf/gf/v2/os/gtime"
|
||||||
)
|
)
|
||||||
@@ -147,11 +148,10 @@ func GetSkillUser(ctx context.Context, name string) (*SkillUserVO, error) {
|
|||||||
|
|
||||||
// SendCallbackReq 发送回调的请求体
|
// SendCallbackReq 发送回调的请求体
|
||||||
type SendCallbackReq struct {
|
type SendCallbackReq struct {
|
||||||
TaskId string `json:"taskId"`
|
TaskId string `json:"taskId"`
|
||||||
Status string `json:"status"`
|
Status string `json:"status"`
|
||||||
Messages map[string]any `json:"messages,omitempty"`
|
EpicycleId int64 `json:"epicycleId"`
|
||||||
EpicycleId int64 `json:"epicycleId"`
|
ErrorMsg string `json:"errorMsg,omitempty"`
|
||||||
ErrorMsg string `json:"errorMsg,omitempty"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SendCallback 向业务方发送回调
|
// SendCallback 向业务方发送回调
|
||||||
@@ -164,18 +164,32 @@ func SendCallback(ctx context.Context, composeTask *entity.ComposeTask, epicycle
|
|||||||
req := SendCallbackReq{
|
req := SendCallbackReq{
|
||||||
TaskId: composeTask.TaskId,
|
TaskId: composeTask.TaskId,
|
||||||
Status: composeTask.Status,
|
Status: composeTask.Status,
|
||||||
Messages: composeTask.ResultJson,
|
|
||||||
ErrorMsg: composeTask.ErrorMessage,
|
ErrorMsg: composeTask.ErrorMessage,
|
||||||
EpicycleId: epicycleId,
|
EpicycleId: epicycleId,
|
||||||
}
|
}
|
||||||
// 3. 发送 POST 请求
|
// 3. 发送 POST 请求
|
||||||
headers := util.ForwardHeaders(ctx)
|
headers := util.ForwardHeaders(ctx)
|
||||||
var resp struct{}
|
var resp struct{}
|
||||||
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s 消息=%v",
|
g.Log().Infof(ctx, "[回调业务] 开始发送 taskId=%s 回调地址=%s",
|
||||||
composeTask.TaskId, composeTask.CallbackUrl, gjson.New(req.Messages).String())
|
composeTask.TaskId, composeTask.CallbackUrl)
|
||||||
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
if err := commonHttp.Post(ctx, composeTask.CallbackUrl, headers, &resp, req); err != nil {
|
||||||
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
return fmt.Errorf("[回调业务] 发送失败 taskId=%s url=%s err=%w", composeTask.TaskId, composeTask.CallbackUrl, err)
|
||||||
}
|
}
|
||||||
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
g.Log().Infof(ctx, "[回调业务] 发送成功 taskId=%s 回调地址=%s ", composeTask.TaskId, composeTask.CallbackUrl)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DownloadFile 从 OSS 下载文件内容
|
||||||
|
func DownloadFile(ossURL string) ([]byte, error) {
|
||||||
|
resp, err := http.Get(ossURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件失败: %w", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("下载OSS文件返回非200: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return io.ReadAll(resp.Body)
|
||||||
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
|
|
||||||
"gitea.redpowerfuture.com/red-future/common/beans"
|
"gitea.redpowerfuture.com/red-future/common/beans"
|
||||||
"gitea.redpowerfuture.com/red-future/common/utils"
|
"gitea.redpowerfuture.com/red-future/common/utils"
|
||||||
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/frame/g"
|
"github.com/gogf/gf/v2/frame/g"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
@@ -128,24 +129,43 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
|||||||
// Callback 回调处理
|
// Callback 回调处理
|
||||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||||
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
g.Log().Infof(ctx, "[开始回调处理] taskId=%s state=%d", req.TaskId, req.State)
|
||||||
|
|
||||||
// 1) 查询任务
|
// 1) 查询任务
|
||||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{TaskId: req.TaskId})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("查询任务失败: %w", err)
|
return fmt.Errorf("查询任务失败: %w", err)
|
||||||
}
|
}
|
||||||
// 2) 处理失败
|
|
||||||
|
// 2) 读取 OSS 文件内容
|
||||||
|
var ossContent []byte
|
||||||
|
if req.OssFile != "" {
|
||||||
|
ossContent, err = gateway.DownloadFile(req.OssFile)
|
||||||
|
if err != nil {
|
||||||
|
g.Log().Warningf(ctx, "[回调处理] 读取OSS失败 taskId=%s err=%v", req.TaskId, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3) 解析 OSS 内容为消息
|
||||||
|
var messages map[string]any
|
||||||
|
if len(ossContent) > 0 {
|
||||||
|
messages, _ = gjson.New(ossContent).Map(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4) 处理失败
|
||||||
if req.State == 3 {
|
if req.State == 3 {
|
||||||
return handleCallbackFailed(ctx, req, composeTask)
|
return handleCallbackFailed(ctx, req, composeTask, messages)
|
||||||
}
|
}
|
||||||
// 3) 处理成功
|
|
||||||
|
// 5) 处理成功
|
||||||
if req.State == 2 {
|
if req.State == 2 {
|
||||||
return handleCallbackSuccess(ctx, req, composeTask)
|
return handleCallbackSuccess(ctx, req, composeTask, messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleCallbackFailed 处理回调失败
|
// handleCallbackFailed 处理回调失败
|
||||||
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||||
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
_, err := dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||||
TaskId: req.TaskId,
|
TaskId: req.TaskId,
|
||||||
Status: public.ComposeStatusFailed,
|
Status: public.ComposeStatusFailed,
|
||||||
@@ -153,7 +173,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
GatewayState: req.State,
|
GatewayState: req.State,
|
||||||
OssFile: req.OssFile,
|
OssFile: req.OssFile,
|
||||||
FileType: req.FileType,
|
FileType: req.FileType,
|
||||||
ResultJson: req.Messages,
|
ResultJson: messages,
|
||||||
})
|
})
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusFailed
|
composeTask.Status = public.ComposeStatusFailed
|
||||||
@@ -164,7 +184,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
}
|
}
|
||||||
|
|
||||||
// handleCallbackSuccess 处理回调成功
|
// handleCallbackSuccess 处理回调成功
|
||||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask, messages map[string]any) error {
|
||||||
// 1) 获取模型配置
|
// 1) 获取模型配置
|
||||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||||
@@ -198,7 +218,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
NodeId: nodeId,
|
NodeId: nodeId,
|
||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
@@ -208,7 +228,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4) 合并附加结构
|
// 4) 合并附加结构
|
||||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||||
// 5) 注入历史
|
// 5) 注入历史
|
||||||
if len(history) > 0 {
|
if len(history) > 0 {
|
||||||
messages = InjectHistory(messages, history, protocol)
|
messages = InjectHistory(messages, history, protocol)
|
||||||
|
|||||||
Reference in New Issue
Block a user