refactor(prompts-core): 重构代码结构和优化工具函数
This commit is contained in:
@@ -1,197 +1,81 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ConvertToMessages 将原始数据转换为消息列表
|
||||
func ConvertToMessages(raw any) []map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
j := gjson.New(raw)
|
||||
messages := j.Get("messages")
|
||||
if !messages.IsNil() {
|
||||
return gconv.Maps(messages.Val())
|
||||
}
|
||||
return []map[string]any{j.Map()}
|
||||
}
|
||||
|
||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||
func FormToJSON(form []map[string]any) string {
|
||||
if form == nil {
|
||||
return "[]"
|
||||
}
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
|
||||
func UserFormToJSON(form []map[string]any) string {
|
||||
if form == nil {
|
||||
return "{}"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// MustMarshalToMap 将对象序列化为 map[string]any,失败时返回空 map
|
||||
func MustMarshalToMap(v any) map[string]any {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return make(map[string]any)
|
||||
}
|
||||
var m map[string]any
|
||||
json.Unmarshal(b, &m)
|
||||
return m
|
||||
}
|
||||
|
||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
||||
func JSONPretty(v any) string {
|
||||
if gv, ok := v.(*gvar.Var); ok {
|
||||
v = gconv.Map(gv.String())
|
||||
}
|
||||
|
||||
var tmp map[string]any
|
||||
if err := gconv.Struct(v, &tmp); err != nil {
|
||||
return gconv.String(v)
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(tmp)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
||||
func ParseJSONFieldFromGvar(source any, target any) {
|
||||
if source == nil {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := source.(type) {
|
||||
case *gvar.Var:
|
||||
if v.IsNil() {
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试获取 map
|
||||
if m := v.Map(); len(m) > 0 {
|
||||
data, _ := json.Marshal(m)
|
||||
json.Unmarshal(data, target)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解析 JSON 字符串
|
||||
str := v.String()
|
||||
if str != "" && str != "<nil>" {
|
||||
json.Unmarshal([]byte(str), target)
|
||||
}
|
||||
|
||||
default:
|
||||
// 其他类型走原来的逻辑
|
||||
data, _ := json.Marshal(source)
|
||||
json.Unmarshal(data, target)
|
||||
}
|
||||
}
|
||||
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 1) 获取 consult 数组
|
||||
consult := gconv.Interfaces(req["consult"])
|
||||
if len(consult) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 2) 获取配置
|
||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||
if targetPath == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||
if len(templates) == 0 {
|
||||
if targetPath == "" || len(templates) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 3) 转为 gjson 操作
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
// 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath}
|
||||
if arr := msgJson.Get("rounds.0").Array(); arr != nil {
|
||||
// rounds 路径修正
|
||||
if !msgJson.Get("rounds.0").IsNil() {
|
||||
targetPath = "rounds.0." + targetPath
|
||||
}
|
||||
|
||||
// 4) 遍历 consult,按类型生成附件并追加
|
||||
// 遍历追加
|
||||
for _, item := range consult {
|
||||
itemJson := gjson.New(item)
|
||||
|
||||
itemType := itemJson.Get("type").String()
|
||||
if itemType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查找对应模板
|
||||
tmpl := gconv.Map(templates[itemType])
|
||||
if len(tmpl) == 0 {
|
||||
if itemType == "" || len(tmpl) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 生成附件对象
|
||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||
if attachment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取当前数组长度,用索引追加
|
||||
arr := msgJson.Get(targetPath).Array()
|
||||
idx := len(arr)
|
||||
indexPath := fmt.Sprintf("%s.%d", targetPath, idx)
|
||||
_ = msgJson.Set(indexPath, attachment)
|
||||
idx := len(msgJson.Get(targetPath).Array())
|
||||
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
||||
}
|
||||
|
||||
return msgJson.Map()
|
||||
}
|
||||
|
||||
// buildAttachment 根据模板和用户数据生成附件对象
|
||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||
typ := gconv.String(tmpl["type"])
|
||||
if typ == "" || url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 深拷贝 body 并填充 url
|
||||
body := gconv.Map(tmpl["body"])
|
||||
bodyJson := gjson.New(body)
|
||||
bodyJson = fillEmpty(bodyJson, url)
|
||||
fillEmptyInPlace(body, url)
|
||||
|
||||
return map[string]any{
|
||||
"type": typ,
|
||||
typ: bodyJson.Map(),
|
||||
typ: body,
|
||||
}
|
||||
}
|
||||
|
||||
// fillEmpty 递归查找空字符串并替换
|
||||
func fillEmpty(j *gjson.Json, value string) *gjson.Json {
|
||||
m := j.Map()
|
||||
func fillEmptyInPlace(m map[string]any, value string) {
|
||||
for k, v := range m {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
if vv == "" {
|
||||
_ = j.Set(k, value)
|
||||
m[k] = value
|
||||
}
|
||||
case map[string]any:
|
||||
_ = j.Set(k, fillEmpty(gjson.New(vv), value).Map())
|
||||
fillEmptyInPlace(vv, value)
|
||||
}
|
||||
}
|
||||
return j
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ReverseMap 映射 payload 到 mapping
|
||||
@@ -20,80 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构
|
||||
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||
func ExtractUserText(messages map[string]any) map[string]any {
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
msgs := msgJson.Get("rounds.0.messages")
|
||||
if msgs.IsNil() {
|
||||
msgs = msgJson.Get("messages")
|
||||
}
|
||||
var texts []string
|
||||
|
||||
// 1) rounds 结构:遍历每轮
|
||||
if rounds, ok := messages["rounds"].([]any); ok {
|
||||
for _, round := range rounds {
|
||||
if rm, ok := round.(map[string]any); ok {
|
||||
if msgs, ok := rm["messages"].([]any); ok {
|
||||
texts = append(texts, extractTextFromRoleUser(msgs)...)
|
||||
for _, m := range msgs.Array() {
|
||||
msg := gjson.New(m)
|
||||
if msg.Get("role").String() != "user" {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content").Val()
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
texts = append(texts, c)
|
||||
case []any:
|
||||
for _, item := range c {
|
||||
if m, ok := item.(map[string]any); ok {
|
||||
if t := gconv.String(m["text"]); t != "" {
|
||||
texts = append(texts, t)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if msgs, ok := messages["messages"].([]any); ok {
|
||||
// 2) messages 结构
|
||||
texts = extractTextFromRoleUser(msgs)
|
||||
}
|
||||
|
||||
// 3) 构建返回结构
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": strings.Join(texts, "\n"),
|
||||
}
|
||||
}
|
||||
|
||||
// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本
|
||||
func extractTextFromRoleUser(msgs []any) []string {
|
||||
var texts []string
|
||||
for _, msg := range msgs {
|
||||
m, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if role, _ := m["role"].(string); role != "user" {
|
||||
continue
|
||||
}
|
||||
texts = append(texts, extractAllText(m["content"])...)
|
||||
}
|
||||
return texts
|
||||
}
|
||||
|
||||
// extractAllText 从 content 中提取所有文本(递归,最大兼容)
|
||||
func extractAllText(content any) []string {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return []string{c}
|
||||
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, item := range c {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if t, ok := m["text"].(string); ok && t != "" {
|
||||
texts = append(texts, t)
|
||||
continue
|
||||
}
|
||||
for _, v := range m {
|
||||
texts = append(texts, extractAllText(v)...)
|
||||
}
|
||||
}
|
||||
return texts
|
||||
|
||||
case map[string]any:
|
||||
if t, ok := c["text"].(string); ok && t != "" {
|
||||
return []string{t}
|
||||
}
|
||||
var texts []string
|
||||
for _, v := range c {
|
||||
texts = append(texts, extractAllText(v)...)
|
||||
}
|
||||
return texts
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,3 +11,7 @@ const (
|
||||
BuildTypeNode = 2 //节点构建
|
||||
BuildTypeStruct = 3 //结构构建
|
||||
)
|
||||
|
||||
const (
|
||||
ModelTypeInference = 100 // 推理模型
|
||||
)
|
||||
|
||||
@@ -26,8 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||
}
|
||||
|
||||
// GetPromptText 纯文本prompt调用接口(测试专用)
|
||||
func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) {
|
||||
return promptService.GetPromptText(ctx, req)
|
||||
}
|
||||
|
||||
@@ -25,12 +25,6 @@ type ComposeMessagesRes struct {
|
||||
TaskId string `json:"taskId" dc:"任务ID"`
|
||||
}
|
||||
|
||||
// MultiRoundResult 多轮返回结果
|
||||
type MultiRoundResult struct {
|
||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
||||
}
|
||||
|
||||
type CallbackReq struct {
|
||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||
@@ -55,16 +49,7 @@ type GetComposeTaskRes struct {
|
||||
Status string `json:"status" dc:"业务状态"`
|
||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||
Messages any `json:"messages" dc:"最终消息数组"`
|
||||
Messages map[string]any `json:"messages" dc:"最终消息数组"`
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
}
|
||||
|
||||
type GetPromptTextReq struct {
|
||||
g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"`
|
||||
Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"`
|
||||
}
|
||||
|
||||
type GetPromptTextRes struct {
|
||||
Messages any `json:"messages" dc:"历史消息"`
|
||||
}
|
||||
|
||||
@@ -13,11 +13,10 @@ import (
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
//1) 构建系统提示词
|
||||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||
ir.AddSystem(systemPrompt)
|
||||
@@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
// 提取 userForm 中的 prompt 作为自定义提示词
|
||||
var customPrompt string
|
||||
for _, item := range req.UserForm {
|
||||
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
|
||||
customPrompt = gconv.String(prompt)
|
||||
break
|
||||
}
|
||||
}
|
||||
// 用户消息
|
||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil || protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||
@@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuildWithRounds 构建提示词
|
||||
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: chatModel.OperatorName,
|
||||
@@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
||||
if err != nil || providerProtocol == nil {
|
||||
return ""
|
||||
}
|
||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
outputJSON, //【输出结构】 %s
|
||||
@@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
||||
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
||||
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
return util.CountToken(fullContent, model.TokenConfig)
|
||||
}
|
||||
@@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// buildUserFormText 构建用户表单内容字符串
|
||||
func buildUserFormText(form []map[string]any) string {
|
||||
if len(form) == 0 {
|
||||
return ""
|
||||
@@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for _, item := range form {
|
||||
for k, v := range item {
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
switch val := v.(type) {
|
||||
case []any:
|
||||
// 数组类型:逐条列出
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
for i, elem := range val {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
if m, ok := elem.(map[string]any); ok {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
for mk, mv := range m {
|
||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
} else {
|
||||
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
|
||||
}
|
||||
}
|
||||
case []map[string]any:
|
||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||
for i, m := range val {
|
||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||
for mk, mv := range m {
|
||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||
builder.WriteString(fmt.Sprint(elem))
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
default:
|
||||
builder.WriteString(fmt.Sprintf("%s:%v\n", k, v))
|
||||
builder.WriteString(fmt.Sprintf(" %v\n", v))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
||||
if promptTpl == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
formStr := util.FormToJSON(req.Form)
|
||||
userFormStr := util.UserFormToJSON(req.UserForm)
|
||||
|
||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
||||
return fmt.Sprintf(promptTpl,
|
||||
gjson.New(req.Form).MustToJsonString(),
|
||||
gjson.New(req.UserForm).MustToJsonString(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
@@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
|
||||
// handleBuild 通用构建处理
|
||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||
// 1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
}
|
||||
@@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
var taskReq map[string]any
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
|
||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||
case public.BuildTypeNode:
|
||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||
case public.BuildTypeStruct:
|
||||
@@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
||||
SkillName: req.SkillName,
|
||||
BuildType: req.BuildType,
|
||||
CallbackUrl: req.CallbackUrl,
|
||||
RequestPayload: util.MustMarshalToMap(req),
|
||||
RequestPayload: gconv.Map(req),
|
||||
Status: public.ComposeStatusPending,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
@@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
||||
return err
|
||||
}
|
||||
|
||||
// handleCallbackSuccess 处理回调成功
|
||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||
// 1) 获取模型配置
|
||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||
@@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 3) 获取历史消息
|
||||
// 3) 获取历史消息 + 保存当前轮
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
var history []dto.FlatMessage
|
||||
if sessionId != "" && nodeId != "" {
|
||||
var epicycleId int64
|
||||
|
||||
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||
// 3.1 获取历史
|
||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
@@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if h != nil {
|
||||
history = h.Messages
|
||||
}
|
||||
|
||||
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
// 5) 注入历史到 rounds 中
|
||||
if protocol != nil && len(history) > 0 {
|
||||
// 5) 注入历史
|
||||
if len(history) > 0 {
|
||||
messages = InjectHistory(messages, history, protocol)
|
||||
}
|
||||
|
||||
@@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return err
|
||||
}
|
||||
|
||||
// 7) 存储历史
|
||||
var epicycleId int64
|
||||
if sessionId != "" && nodeId != "" {
|
||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 8) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
@@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
if record == nil {
|
||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
||||
}
|
||||
|
||||
messages := parseMessagesForResponse(record.ResultJson)
|
||||
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: messages,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseMessagesForResponse 解析用于响应的消息
|
||||
func parseMessagesForResponse(messages any) any {
|
||||
str, ok := messages.(string)
|
||||
if !ok || str == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
||||
return parsed
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取协议配置
|
||||
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: "火山引擎",
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 获取历史消息
|
||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||
SessionId: "88888888",
|
||||
NodeId: "node1",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 模拟roundsData数据
|
||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Println("[打印数据]", task.ResultJson)
|
||||
fmt.Println("[打印历史]", history.Messages)
|
||||
fmt.Println("[打印协议]", protocol)
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InjectHistory 插入历史会话
|
||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||
if protocol == nil || len(history) == 0 {
|
||||
return roundsData
|
||||
@@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco
|
||||
firstRound["messages"] = result
|
||||
return roundsData
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
TaskId: taskID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||
}
|
||||
return &dto.GetComposeTaskRes{
|
||||
TaskId: record.TaskId,
|
||||
Status: record.Status,
|
||||
ErrorMessage: record.ErrorMessage,
|
||||
Messages: record.ResultJson,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
||||
}
|
||||
|
||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
if skillName == "" {
|
||||
return ""
|
||||
}
|
||||
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
||||
|
||||
@@ -2,9 +2,7 @@ package prompt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/service/gateway"
|
||||
"strings"
|
||||
|
||||
@@ -14,8 +12,8 @@ import (
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// PromptIR 统一 Prompt 中间表示
|
||||
type PromptIR struct {
|
||||
// IR 统一 Prompt 中间表示
|
||||
type IR struct {
|
||||
System []Segment `json:"system"`
|
||||
History []Segment `json:"history"`
|
||||
User []Segment `json:"user"`
|
||||
@@ -46,8 +44,8 @@ type ContentMapping struct {
|
||||
}
|
||||
|
||||
// NewPromptIR 创建空 PromptIR
|
||||
func NewPromptIR() *PromptIR {
|
||||
return &PromptIR{
|
||||
func NewPromptIR() *IR {
|
||||
return &IR{
|
||||
System: make([]Segment, 0),
|
||||
History: make([]Segment, 0),
|
||||
User: make([]Segment, 0),
|
||||
@@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR {
|
||||
}
|
||||
|
||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||
func (ir *PromptIR) String() string {
|
||||
func (ir *IR) String() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -81,7 +79,7 @@ func (ir *PromptIR) String() string {
|
||||
}
|
||||
|
||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||
func (ir *PromptIR) GetTotalContent() string {
|
||||
func (ir *IR) GetTotalContent() string {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
|
||||
}
|
||||
|
||||
// AddSystem 添加系统提示
|
||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
func (ir *IR) AddSystem(content string) *IR {
|
||||
if content != "" {
|
||||
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
||||
}
|
||||
@@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
||||
}
|
||||
|
||||
// AddUser 添加用户消息
|
||||
func (ir *PromptIR) AddUser(content string) *PromptIR {
|
||||
func (ir *IR) AddUser(content string) *IR {
|
||||
if content != "" {
|
||||
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
||||
}
|
||||
@@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
|
||||
}
|
||||
|
||||
// AddHistory 添加历史消息
|
||||
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
func (ir *IR) AddHistory(role, content string) *IR {
|
||||
if content != "" {
|
||||
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
||||
}
|
||||
@@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
||||
}
|
||||
|
||||
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
||||
func (ir *IR) ToMessages() []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, seg := range ir.System {
|
||||
@@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
||||
|
||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||
p := &ProviderProtocol{
|
||||
return &ProviderProtocol{
|
||||
TargetField: e.TargetField,
|
||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||
MergeOrder: e.MergeOrder,
|
||||
RoleMapping: gconv.MapStrStr(e.RoleMapping),
|
||||
ContentMapping: ContentMapping{
|
||||
Type: gconv.String(e.ContentMapping["type"]),
|
||||
Field: gconv.String(e.ContentMapping["field"]),
|
||||
},
|
||||
RequestTemplate: e.RequestTemplate,
|
||||
Capabilities: e.Capabilities,
|
||||
}
|
||||
|
||||
// 使用通用解析方法处理各个字段
|
||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
||||
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
|
||||
return p
|
||||
}
|
||||
|
||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||
if ir == nil || p == nil {
|
||||
return nil, fmt.Errorf("ir and protocol are required")
|
||||
}
|
||||
@@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
|
||||
}
|
||||
|
||||
// mergeByOrder 按协议配置顺序拼接消息
|
||||
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
||||
func mergeByOrder(ir *IR, order []string) []map[string]any {
|
||||
roleMap := map[string][]Segment{
|
||||
"system": ir.System,
|
||||
"history": ir.History,
|
||||
"user": ir.User,
|
||||
}
|
||||
|
||||
var messages []map[string]any
|
||||
|
||||
for _, part := range order {
|
||||
switch part {
|
||||
case "system":
|
||||
for _, seg := range ir.System {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "system",
|
||||
"content": seg.Content,
|
||||
})
|
||||
for _, seg := range roleMap[part] {
|
||||
msg := map[string]any{"content": seg.Content}
|
||||
if part == "history" {
|
||||
msg["role"] = seg.Role
|
||||
} else {
|
||||
msg["role"] = part
|
||||
}
|
||||
case "history":
|
||||
for _, seg := range ir.History {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": seg.Role,
|
||||
"content": seg.Content,
|
||||
})
|
||||
}
|
||||
case "user":
|
||||
for _, seg := range ir.User {
|
||||
messages = append(messages, map[string]any{
|
||||
"role": "user",
|
||||
"content": seg.Content,
|
||||
})
|
||||
messages = append(messages, msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
||||
return messages
|
||||
}
|
||||
|
||||
// mapContent 内容字段映射
|
||||
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||
for _, msg := range messages {
|
||||
content := msg["content"]
|
||||
delete(msg, "content")
|
||||
if cm.Field == "" || cm.Field == "content" {
|
||||
return messages
|
||||
}
|
||||
|
||||
for i, msg := range messages {
|
||||
if content, ok := msg["content"]; ok {
|
||||
delete(msg, "content")
|
||||
switch cm.Type {
|
||||
case "parts":
|
||||
msg["parts"] = []map[string]any{
|
||||
{cm.Field: content},
|
||||
}
|
||||
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
|
||||
default:
|
||||
msg[cm.Field] = content
|
||||
messages[i][cm.Field] = content
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
@@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
|
||||
}
|
||||
}
|
||||
|
||||
// renderTemplate 简单的 {{key}} 模板替换
|
||||
// renderTemplate 模板渲染
|
||||
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||
b, _ := json.Marshal(p.RequestTemplate)
|
||||
str := string(b)
|
||||
|
||||
if chatModel != nil {
|
||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
||||
result := make(map[string]any, len(p.RequestTemplate)+1)
|
||||
for k, v := range p.RequestTemplate {
|
||||
result[k] = v
|
||||
}
|
||||
|
||||
msgBytes, _ := json.Marshal(messages)
|
||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
||||
|
||||
var result map[string]any
|
||||
_ = json.Unmarshal([]byte(str), &result)
|
||||
if chatModel != nil {
|
||||
result["model"] = chatModel.ModelName
|
||||
}
|
||||
result["messages"] = messages
|
||||
|
||||
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||
result["max_tokens"] = maxTokens
|
||||
|
||||
@@ -21,8 +21,8 @@ import (
|
||||
|
||||
// Callback 会话回调
|
||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||
fmt.Println("打印会话回调", req)
|
||||
req.Messages["role"] = "assistant"
|
||||
|
||||
// 1) 更新 DB
|
||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||
@@ -163,23 +163,15 @@ func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteS
|
||||
|
||||
// entityToHistoryRound entity → HistoryRound
|
||||
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
||||
|
||||
round := &dto.HistoryRound{
|
||||
return &dto.HistoryRound{
|
||||
Id: s.Id,
|
||||
SessionId: s.SessionId,
|
||||
NodeId: s.NodeId,
|
||||
CreatedAt: gconv.String(s.CreatedAt),
|
||||
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||
User: s.RequestContent,
|
||||
Assistant: s.ResponseContent,
|
||||
}
|
||||
if len(reqMsgs) > 0 {
|
||||
round.User = reqMsgs[0]
|
||||
}
|
||||
if len(respMsgs) > 0 {
|
||||
round.Assistant = respMsgs[0]
|
||||
}
|
||||
return round
|
||||
}
|
||||
|
||||
// sessionsToHistoryRounds 批量转换
|
||||
|
||||
Reference in New Issue
Block a user