feat: 新增模型扩展映射与查询配置字段
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
|
||||
"prompts-core/common/util"
|
||||
@@ -122,7 +123,7 @@ func saveComposeTask(ctx context.Context, taskID string, req *dto.ComposeMessage
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshal(req),
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
})
|
||||
return err
|
||||
@@ -182,13 +183,13 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
|
||||
}
|
||||
|
||||
// createDefaultResult 创建默认结果
|
||||
func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
|
||||
func createDefaultResult(data map[string]any) map[string]any {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []map[string]any{data},
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{data},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -196,14 +197,19 @@ func createDefaultResult(data map[string]any) *dto.MultiRoundResult {
|
||||
func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
g.Log().Infof(ctx, "[Callback][RECV] taskId=%s state=%d ossFile=%s fileType=%s textLen=%d",
|
||||
req.TaskId, req.State, req.OssFile, req.FileType, len(req.Text))
|
||||
// 查询任务
|
||||
composeTask, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if composeTask == nil {
|
||||
return fmt.Errorf("任务不存在: %s", req.TaskId)
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: composeTask.Creator},
|
||||
ModelName: composeTask.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
//处理失败
|
||||
if req.State == 3 {
|
||||
@@ -232,16 +238,18 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
//处理成功
|
||||
if req.State == 2 {
|
||||
// 1. 根据 BuildType 解析结果
|
||||
var messages any
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt: // 提示词构建解析
|
||||
messages = parsePromptResult(req.Text)
|
||||
messages = ParsePromptResult(req.Text)
|
||||
case public.BuildTypeNode: // 节点构建解析
|
||||
messages = parseNodeResult(req.Text)
|
||||
messages = ParseNodeResult(req.Text)
|
||||
default:
|
||||
messages = req.Text
|
||||
messages = gjson.New(req.Text).Map()
|
||||
}
|
||||
// 2. 更新数据库
|
||||
// 2. 处理附加字段
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 3. 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
@@ -269,8 +277,8 @@ func Callback(ctx context.Context, req *dto.CallbackReq) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// parsePromptResult 解析提示词构建结果
|
||||
func parsePromptResult(raw string) *dto.MultiRoundResult {
|
||||
// ParsePromptResult 解析提示词构建结果
|
||||
func ParsePromptResult(raw string) map[string]any {
|
||||
var wrapper map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &wrapper); err != nil {
|
||||
return createDefaultResult(map[string]any{"raw": raw})
|
||||
@@ -283,17 +291,17 @@ func parsePromptResult(raw string) *dto.MultiRoundResult {
|
||||
|
||||
// 先尝试解析为数组
|
||||
if roundsArray := tryParseAsMapArray(contentStr); roundsArray != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: len(roundsArray),
|
||||
Rounds: roundsArray,
|
||||
return map[string]any{
|
||||
"total_rounds": len(roundsArray),
|
||||
"rounds": roundsArray,
|
||||
}
|
||||
}
|
||||
|
||||
// 再尝试解析为单个对象
|
||||
if singleRound := tryParseAsMap(contentStr); singleRound != nil {
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []map[string]any{singleRound},
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{singleRound},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,8 +330,8 @@ func tryParseAsMap(jsonStr string) map[string]any {
|
||||
return obj
|
||||
}
|
||||
|
||||
// parseNodeResult 解析节点构建结果
|
||||
func parseNodeResult(raw string) *dto.MultiRoundResult {
|
||||
// ParseNodeResult 解析节点构建结果
|
||||
func ParseNodeResult(raw string) map[string]any {
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(raw), &result); err != nil {
|
||||
return createDefaultResult(map[string]any{"raw": raw})
|
||||
@@ -335,10 +343,9 @@ func parseNodeResult(raw string) *dto.MultiRoundResult {
|
||||
result = inner
|
||||
}
|
||||
}
|
||||
|
||||
return &dto.MultiRoundResult{
|
||||
TotalRounds: 1,
|
||||
Rounds: []map[string]any{result},
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": []map[string]any{result},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user