refactor(prompts-core): 重构代码结构和优化工具函数

This commit is contained in:
2026-06-10 14:51:25 +08:00
parent 1c1db7e30c
commit b69e7386e2
10 changed files with 164 additions and 432 deletions

View File

@@ -1,197 +1,81 @@
package util package util
import ( import (
"encoding/json"
"fmt" "fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j := gjson.New(raw)
messages := j.Get("messages")
if !messages.IsNil() {
return gconv.Maps(messages.Val())
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form []map[string]any) string {
if form == nil {
return "[]"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshalToMap 将对象序列化为 map[string]any失败时返回空 map
func MustMarshalToMap(v any) map[string]any {
b, err := json.Marshal(v)
if err != nil {
return make(map[string]any)
}
var m map[string]any
json.Unmarshal(b, &m)
return m
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.Marshal(tmp)
return string(b)
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中 // MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any { func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 { if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
return messages return messages
} }
// 1) 获取 consult 数组
consult := gconv.Interfaces(req["consult"]) consult := gconv.Interfaces(req["consult"])
if len(consult) == 0 { if len(consult) == 0 {
return messages return messages
} }
// 2) 获取配置
targetPath := gconv.String(extendMapping["target_content_path"]) targetPath := gconv.String(extendMapping["target_content_path"])
if targetPath == "" {
return messages
}
templates := gconv.Map(extendMapping["attachment_templates"]) templates := gconv.Map(extendMapping["attachment_templates"])
if len(templates) == 0 { if targetPath == "" || len(templates) == 0 {
return messages return messages
} }
// 3) 转为 gjson 操作
msgJson := gjson.New(messages) msgJson := gjson.New(messages)
// 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath} // rounds 路径修正
if arr := msgJson.Get("rounds.0").Array(); arr != nil { if !msgJson.Get("rounds.0").IsNil() {
targetPath = "rounds.0." + targetPath targetPath = "rounds.0." + targetPath
} }
// 4) 遍历 consult按类型生成附件并追加 // 遍历追加
for _, item := range consult { for _, item := range consult {
itemJson := gjson.New(item) itemJson := gjson.New(item)
itemType := itemJson.Get("type").String() itemType := itemJson.Get("type").String()
if itemType == "" {
continue
}
// 查找对应模板
tmpl := gconv.Map(templates[itemType]) tmpl := gconv.Map(templates[itemType])
if len(tmpl) == 0 { if itemType == "" || len(tmpl) == 0 {
continue continue
} }
// 生成附件对象
attachment := buildAttachment(tmpl, itemJson.Get("url").String()) attachment := buildAttachment(tmpl, itemJson.Get("url").String())
if attachment == nil { if attachment == nil {
continue continue
} }
// 获取当前数组长度,用索引追加 idx := len(msgJson.Get(targetPath).Array())
arr := msgJson.Get(targetPath).Array() _ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
idx := len(arr)
indexPath := fmt.Sprintf("%s.%d", targetPath, idx)
_ = msgJson.Set(indexPath, attachment)
} }
return msgJson.Map() return msgJson.Map()
} }
// buildAttachment 根据模板和用户数据生成附件对象
func buildAttachment(tmpl map[string]any, url string) map[string]any { func buildAttachment(tmpl map[string]any, url string) map[string]any {
typ := gconv.String(tmpl["type"]) typ := gconv.String(tmpl["type"])
if typ == "" || url == "" { if typ == "" || url == "" {
return nil return nil
} }
// 深拷贝 body 并填充 url
body := gconv.Map(tmpl["body"]) body := gconv.Map(tmpl["body"])
bodyJson := gjson.New(body) fillEmptyInPlace(body, url)
bodyJson = fillEmpty(bodyJson, url)
return map[string]any{ return map[string]any{
"type": typ, "type": typ,
typ: bodyJson.Map(), typ: body,
} }
} }
// fillEmpty 递归查找空字符串并替换 func fillEmptyInPlace(m map[string]any, value string) {
func fillEmpty(j *gjson.Json, value string) *gjson.Json {
m := j.Map()
for k, v := range m { for k, v := range m {
switch vv := v.(type) { switch vv := v.(type) {
case string: case string:
if vv == "" { if vv == "" {
_ = j.Set(k, value) m[k] = value
} }
case map[string]any: case map[string]any:
_ = j.Set(k, fillEmpty(gjson.New(vv), value).Map()) fillEmptyInPlace(vv, value)
} }
} }
return j
} }

View File

@@ -4,6 +4,7 @@ import (
"strings" "strings"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
) )
// ReverseMap 映射 payload 到 mapping // ReverseMap 映射 payload 到 mapping
@@ -20,80 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
return jsonObj.Map() return jsonObj.Map()
} }
// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构 // ExtractUserText 从 messages 中提取所有 user 文本
func ExtractUserText(messages map[string]any) map[string]any { func ExtractUserText(messages map[string]any) map[string]any {
msgJson := gjson.New(messages)
msgs := msgJson.Get("rounds.0.messages")
if msgs.IsNil() {
msgs = msgJson.Get("messages")
}
var texts []string var texts []string
for _, m := range msgs.Array() {
// 1) rounds 结构:遍历每轮 msg := gjson.New(m)
if rounds, ok := messages["rounds"].([]any); ok { if msg.Get("role").String() != "user" {
for _, round := range rounds { continue
if rm, ok := round.(map[string]any); ok { }
if msgs, ok := rm["messages"].([]any); ok { content := msg.Get("content").Val()
texts = append(texts, extractTextFromRoleUser(msgs)...) switch c := content.(type) {
case string:
texts = append(texts, c)
case []any:
for _, item := range c {
if m, ok := item.(map[string]any); ok {
if t := gconv.String(m["text"]); t != "" {
texts = append(texts, t)
}
} }
} }
} }
} else if msgs, ok := messages["messages"].([]any); ok {
// 2) messages 结构
texts = extractTextFromRoleUser(msgs)
} }
// 3) 构建返回结构
return map[string]any{ return map[string]any{
"role": "user", "role": "user",
"content": strings.Join(texts, "\n"), "content": strings.Join(texts, "\n"),
} }
} }
// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本
func extractTextFromRoleUser(msgs []any) []string {
var texts []string
for _, msg := range msgs {
m, ok := msg.(map[string]any)
if !ok {
continue
}
if role, _ := m["role"].(string); role != "user" {
continue
}
texts = append(texts, extractAllText(m["content"])...)
}
return texts
}
// extractAllText 从 content 中提取所有文本(递归,最大兼容)
func extractAllText(content any) []string {
switch c := content.(type) {
case string:
return []string{c}
case []any:
var texts []string
for _, item := range c {
m, ok := item.(map[string]any)
if !ok {
continue
}
if t, ok := m["text"].(string); ok && t != "" {
texts = append(texts, t)
continue
}
for _, v := range m {
texts = append(texts, extractAllText(v)...)
}
}
return texts
case map[string]any:
if t, ok := c["text"].(string); ok && t != "" {
return []string{t}
}
var texts []string
for _, v := range c {
texts = append(texts, extractAllText(v)...)
}
return texts
}
return nil
}

View File

@@ -11,3 +11,7 @@ const (
BuildTypeNode = 2 //节点构建 BuildTypeNode = 2 //节点构建
BuildTypeStruct = 3 //结构构建 BuildTypeStruct = 3 //结构构建
) )
const (
ModelTypeInference = 100 // 推理模型
)

View File

@@ -26,8 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) { func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId) return promptService.GetComposeTask(ctx, req.TaskId)
} }
// GetPromptText 纯文本prompt调用接口测试专用
func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) {
return promptService.GetPromptText(ctx, req)
}

View File

@@ -25,12 +25,6 @@ type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"` TaskId string `json:"taskId" dc:"任务ID"`
} }
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
}
type CallbackReq struct { type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"` g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"` TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
@@ -55,16 +49,7 @@ type GetComposeTaskRes struct {
Status string `json:"status" dc:"业务状态"` Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"` GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"` ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"` Messages map[string]any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"` OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"` FileType string `json:"fileType" dc:"结果文件类型"`
} }
type GetPromptTextReq struct {
g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"`
Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"`
}
type GetPromptTextRes struct {
Messages any `json:"messages" dc:"历史消息"`
}

View File

@@ -13,11 +13,10 @@ import (
"gitea.com/red-future/common/utils" "gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson" "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
) )
// buildPromptTypeRequest 构建提示词类型请求BuildType=1 // buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) { func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
//1) 构建系统提示词 //1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel) systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt) ir.AddSystem(systemPrompt)
@@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
} }
// buildNodeTypeRequest 构建节点类型请求BuildType=2 // buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req)) ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel, req) return compileToProviderRequest(ctx, ir, chatModel, req)
} }
// buildStructTypeRequest 构建结构体类型请求BuildType=3 // buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) { func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
// 提取 userForm 中的 prompt 作为自定义提示词 customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
var customPrompt string
for _, item := range req.UserForm {
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
customPrompt = gconv.String(prompt)
break
}
}
// 用户消息
ir.AddSystem(customPrompt) ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, "")) ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt) return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
} }
// compileToProviderRequest 编译为 Provider 请求 // compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) { func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName) protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil || protocol == nil { if err != nil || protocol == nil {
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err) return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
@@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
}, nil }, nil
} }
// promptBuildWithRounds 构建提示词
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string { func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{ providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName, ProviderName: chatModel.OperatorName,
@@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
if err != nil || providerProtocol == nil { if err != nil || providerProtocol == nil {
return "" return ""
} }
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{})) outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate, return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
outputJSON, //【输出结构】 %s outputJSON, //【输出结构】 %s
@@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
} }
// checkOverallContent 检查整体内容是否超出窗口 // checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool { func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
fullContent := ir.String() fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig) return util.CountToken(fullContent, model.TokenConfig)
} }
@@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
return b.String() return b.String()
} }
// buildUserFormText 构建用户表单内容字符串
func buildUserFormText(form []map[string]any) string { func buildUserFormText(form []map[string]any) string {
if len(form) == 0 { if len(form) == 0 {
return "" return ""
@@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string {
var builder strings.Builder var builder strings.Builder
for _, item := range form { for _, item := range form {
for k, v := range item { for k, v := range item {
builder.WriteString(fmt.Sprintf("%s\n", k))
switch val := v.(type) { switch val := v.(type) {
case []any: case []any:
// 数组类型:逐条列出
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, elem := range val { for i, elem := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
if m, ok := elem.(map[string]any); ok { if m, ok := elem.(map[string]any); ok {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m { for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv)) builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
} }
builder.WriteString("\n")
} else { } else {
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem)) builder.WriteString(fmt.Sprint(elem))
}
}
case []map[string]any:
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, m := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
} }
builder.WriteString("\n") builder.WriteString("\n")
} }
default: default:
builder.WriteString(fmt.Sprintf("%s%v\n", k, v)) builder.WriteString(fmt.Sprintf(" %v\n", v))
} }
} }
} }
@@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
if promptTpl == "" { if promptTpl == "" {
return "" return ""
} }
return fmt.Sprintf(promptTpl,
formStr := util.FormToJSON(req.Form) gjson.New(req.Form).MustToJsonString(),
userFormStr := util.UserFormToJSON(req.UserForm) gjson.New(req.UserForm).MustToJsonString(),
)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
} }

View File

@@ -2,7 +2,6 @@ package prompt
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"prompts-core/service/session" "prompts-core/service/session"
@@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
// handleBuild 通用构建处理 // handleBuild 通用构建处理
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) { func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 1) 处理表单分批 // 1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel) processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil { if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err) return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
} }
@@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
var taskReq map[string]any var taskReq map[string]any
switch req.BuildType { switch req.BuildType {
case public.BuildTypePrompt: case public.BuildTypePrompt:
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches) taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
case public.BuildTypeNode: case public.BuildTypeNode:
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir) taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
case public.BuildTypeStruct: case public.BuildTypeStruct:
@@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
SkillName: req.SkillName, SkillName: req.SkillName,
BuildType: req.BuildType, BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl, CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req), RequestPayload: gconv.Map(req),
Status: public.ComposeStatusPending, Status: public.ComposeStatusPending,
}); err != nil { }); err != nil {
return nil, err return nil, err
@@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
return err return err
} }
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error { func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
// 1) 获取模型配置 // 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{ model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
@@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
Status: 1, Status: 1,
}) })
// 3) 获取历史消息 // 3) 获取历史消息 + 保存当前轮
payload := composeTask.RequestPayload payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"]) sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"]) nodeId := gconv.String(payload["nodeId"])
var history []dto.FlatMessage var history []dto.FlatMessage
if sessionId != "" && nodeId != "" { var epicycleId int64
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
// 3.1 获取历史
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{ h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId, SessionId: sessionId,
NodeId: nodeId, NodeId: nodeId,
@@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if h != nil { if h != nil {
history = h.Messages history = h.Messages
} }
// 3.2 保存当前轮(先存,下次查询就能拿到)
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
} }
// 4) 合并附加结构 // 4) 合并附加结构
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping) messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
// 5) 注入历史到 rounds 中 // 5) 注入历史
if protocol != nil && len(history) > 0 { if len(history) > 0 {
messages = InjectHistory(messages, history, protocol) messages = InjectHistory(messages, history, protocol)
} }
@@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return err return err
} }
// 7) 存储历史
var epicycleId int64
if sessionId != "" && nodeId != "" {
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 8) 回调业务方 // 8) 回调业务方
if composeTask.CallbackUrl != "" { if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess composeTask.Status = public.ComposeStatusSuccess
@@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return nil return nil
} }
// GetComposeTask 查询任务结果 // InjectHistory 插入历史会话
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.ResultJson)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
// 1) 获取协议配置
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "火山引擎",
Status: 1,
})
if err != nil {
return nil, err
}
// 2) 获取历史消息
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: "88888888",
NodeId: "node1",
})
if err != nil {
return nil, err
}
// 3) 模拟roundsData数据
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
})
if err != nil {
return nil, err
}
fmt.Println("[打印数据]", task.ResultJson)
fmt.Println("[打印历史]", history.Messages)
fmt.Println("[打印协议]", protocol)
return &dto.GetPromptTextRes{
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
}, nil
}
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any { func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
if protocol == nil || len(history) == 0 { if protocol == nil || len(history) == 0 {
return roundsData return roundsData
@@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco
firstRound["messages"] = result firstRound["messages"] = result
return roundsData return roundsData
} }
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: record.ResultJson,
}, nil
}

View File

@@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
} }
func SkillMdContent(ctx context.Context, skillName string) string { func SkillMdContent(ctx context.Context, skillName string) string {
if skillName == "" {
return ""
}
skillResp, err := gateway.GetSkillUser(ctx, skillName) skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil { if err != nil {
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err) g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)

View File

@@ -2,9 +2,7 @@ package prompt
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"prompts-core/common/util"
"prompts-core/service/gateway" "prompts-core/service/gateway"
"strings" "strings"
@@ -14,8 +12,8 @@ import (
"github.com/gogf/gf/v2/util/gconv" "github.com/gogf/gf/v2/util/gconv"
) )
// PromptIR 统一 Prompt 中间表示 // IR 统一 Prompt 中间表示
type PromptIR struct { type IR struct {
System []Segment `json:"system"` System []Segment `json:"system"`
History []Segment `json:"history"` History []Segment `json:"history"`
User []Segment `json:"user"` User []Segment `json:"user"`
@@ -46,8 +44,8 @@ type ContentMapping struct {
} }
// NewPromptIR 创建空 PromptIR // NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR { func NewPromptIR() *IR {
return &PromptIR{ return &IR{
System: make([]Segment, 0), System: make([]Segment, 0),
History: make([]Segment, 0), History: make([]Segment, 0),
User: make([]Segment, 0), User: make([]Segment, 0),
@@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR {
} }
// String 返回 PromptIR 的完整内容字符串(用于 token 计算) // String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string { func (ir *IR) String() string {
var builder strings.Builder var builder strings.Builder
for _, seg := range ir.System { for _, seg := range ir.System {
@@ -81,7 +79,7 @@ func (ir *PromptIR) String() string {
} }
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算) // GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string { func (ir *IR) GetTotalContent() string {
var builder strings.Builder var builder strings.Builder
for _, seg := range ir.System { for _, seg := range ir.System {
@@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
} }
// AddSystem 添加系统提示 // AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR { func (ir *IR) AddSystem(content string) *IR {
if content != "" { if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content}) ir.System = append(ir.System, Segment{Type: "text", Content: content})
} }
@@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
} }
// AddUser 添加用户消息 // AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR { func (ir *IR) AddUser(content string) *IR {
if content != "" { if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content}) ir.User = append(ir.User, Segment{Type: "text", Content: content})
} }
@@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
} }
// AddHistory 添加历史消息 // AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR { func (ir *IR) AddHistory(role, content string) *IR {
if content != "" { if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role}) ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
} }
@@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
} }
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认) // ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any { func (ir *IR) ToMessages() []map[string]any {
var messages []map[string]any var messages []map[string]any
for _, seg := range ir.System { for _, seg := range ir.System {
@@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
// parseProtocol 将 DB entity 转为编译用协议配置 // parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol { func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{ return &ProviderProtocol{
TargetField: e.TargetField, TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate, SystemPromptTemplate: e.SystemPromptTemplate,
MergeOrder: e.MergeOrder,
RoleMapping: gconv.MapStrStr(e.RoleMapping),
ContentMapping: ContentMapping{
Type: gconv.String(e.ContentMapping["type"]),
Field: gconv.String(e.ContentMapping["field"]),
},
RequestTemplate: e.RequestTemplate,
Capabilities: e.Capabilities,
} }
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
return p
} }
// Compile 将 PromptIR 按协议配置编译为 Provider Request // Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) { func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil { if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required") return nil, fmt.Errorf("ir and protocol are required")
} }
@@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
} }
// mergeByOrder 按协议配置顺序拼接消息 // mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any { func mergeByOrder(ir *IR, order []string) []map[string]any {
roleMap := map[string][]Segment{
"system": ir.System,
"history": ir.History,
"user": ir.User,
}
var messages []map[string]any var messages []map[string]any
for _, part := range order { for _, part := range order {
switch part { for _, seg := range roleMap[part] {
case "system": msg := map[string]any{"content": seg.Content}
for _, seg := range ir.System { if part == "history" {
messages = append(messages, map[string]any{ msg["role"] = seg.Role
"role": "system", } else {
"content": seg.Content, msg["role"] = part
})
} }
case "history": messages = append(messages, msg)
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
} }
} }
}
return messages return messages
} }
@@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
return messages return messages
} }
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any { func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages { if cm.Field == "" || cm.Field == "content" {
content := msg["content"] return messages
delete(msg, "content") }
for i, msg := range messages {
if content, ok := msg["content"]; ok {
delete(msg, "content")
switch cm.Type { switch cm.Type {
case "parts": case "parts":
msg["parts"] = []map[string]any{ messages[i]["parts"] = []map[string]any{{cm.Field: content}}
{cm.Field: content},
}
default: default:
msg[cm.Field] = content messages[i][cm.Field] = content
}
} }
} }
return messages return messages
} }
@@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
} }
} }
// renderTemplate 简单的 {{key}} 模板替换 // renderTemplate 模板渲染
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any { func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(p.RequestTemplate) result := make(map[string]any, len(p.RequestTemplate)+1)
str := string(b) for k, v := range p.RequestTemplate {
result[k] = v
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
} }
msgBytes, _ := json.Marshal(messages) if chatModel != nil {
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes)) result["model"] = chatModel.ModelName
}
var result map[string]any result["messages"] = messages
_ = json.Unmarshal([]byte(str), &result)
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 { if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens result["max_tokens"] = maxTokens

View File

@@ -21,8 +21,8 @@ import (
// Callback 会话回调 // Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) { func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
fmt.Println("打印会话回调", req)
req.Messages["role"] = "assistant" req.Messages["role"] = "assistant"
// 1) 更新 DB // 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{ _, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId}, SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
@@ -163,23 +163,15 @@ func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteS
// entityToHistoryRound entity → HistoryRound // entityToHistoryRound entity → HistoryRound
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound { func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
reqMsgs := util.ConvertToMessages(s.RequestContent) return &dto.HistoryRound{
respMsgs := util.ConvertToMessages(s.ResponseContent)
round := &dto.HistoryRound{
Id: s.Id, Id: s.Id,
SessionId: s.SessionId, SessionId: s.SessionId,
NodeId: s.NodeId, NodeId: s.NodeId,
CreatedAt: gconv.String(s.CreatedAt), CreatedAt: gconv.String(s.CreatedAt),
UpdatedAt: gconv.String(s.UpdatedAt), UpdatedAt: gconv.String(s.UpdatedAt),
User: s.RequestContent,
Assistant: s.ResponseContent,
} }
if len(reqMsgs) > 0 {
round.User = reqMsgs[0]
}
if len(respMsgs) > 0 {
round.Assistant = respMsgs[0]
}
return round
} }
// sessionsToHistoryRounds 批量转换 // sessionsToHistoryRounds 批量转换