Compare commits
2 Commits
aa7804656f
...
1f9a2b9b5f
| Author | SHA1 | Date | |
|---|---|---|---|
| 1f9a2b9b5f | |||
| e1461cf0f0 |
@@ -2,13 +2,11 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
tGjson "github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertToMessages 将原始数据转换为消息列表
|
||||
@@ -17,7 +15,7 @@ func ConvertToMessages(raw any) []map[string]any {
|
||||
return nil
|
||||
}
|
||||
|
||||
j := gfgjson.New(raw)
|
||||
j := gjson.New(raw)
|
||||
messages := j.Get("messages")
|
||||
if !messages.IsNil() {
|
||||
return gconv.Maps(messages.Val())
|
||||
@@ -66,7 +64,7 @@ func JSONPretty(v any) string {
|
||||
return gconv.String(v)
|
||||
}
|
||||
|
||||
b, _ := json.MarshalIndent(tmp, "", " ")
|
||||
b, _ := json.Marshal(tmp)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
@@ -102,132 +100,98 @@ func ParseJSONFieldFromGvar(source any, target any) {
|
||||
}
|
||||
}
|
||||
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中。
|
||||
//
|
||||
// 参数说明:
|
||||
// - req: 请求参数 map,需包含 "consult" 字段,值为 []any,每个元素是 {"type":"xxx","url":"..."}
|
||||
// - messages: 模型生成的返回结构(如 rounds[...].messages[...].content 数组)
|
||||
// - extendMapping: 附加映射配置,格式:
|
||||
// {"attachments": {"image": {"template": {...}, "target_path": "...", "field_mapping": {...}}, ...}}
|
||||
//
|
||||
// 返回值:合并后的完整 map。
|
||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
reqJSON, _ := json.Marshal(req)
|
||||
msgJSON, _ := json.Marshal(messages)
|
||||
extJSON, _ := json.Marshal(extendMapping)
|
||||
|
||||
reqStr := string(reqJSON)
|
||||
msgStr := string(msgJSON)
|
||||
extStr := string(extJSON)
|
||||
|
||||
// 获取 consult 数组
|
||||
consultResult := tGjson.Get(reqStr, "consult")
|
||||
if !consultResult.Exists() || !consultResult.IsArray() {
|
||||
// 1) 获取 consult 数组
|
||||
consult := gconv.Interfaces(req["consult"])
|
||||
if len(consult) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
// 获取 attachments 配置
|
||||
attachmentsResult := tGjson.Get(extStr, "attachments")
|
||||
if !attachmentsResult.Exists() || !attachmentsResult.IsObject() {
|
||||
// 2) 获取配置
|
||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||
if targetPath == "" {
|
||||
return messages
|
||||
}
|
||||
|
||||
consultArr := consultResult.Array()
|
||||
attachmentsMap := attachmentsResult.Map()
|
||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||
if len(templates) == 0 {
|
||||
return messages
|
||||
}
|
||||
|
||||
for _, consultItem := range consultArr {
|
||||
if !consultItem.IsObject() {
|
||||
continue
|
||||
}
|
||||
// 3) 转为 gjson 操作
|
||||
msgJson := gjson.New(messages)
|
||||
|
||||
itemType := consultItem.Get("type").String()
|
||||
// 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath}
|
||||
if arr := msgJson.Get("rounds.0").Array(); arr != nil {
|
||||
targetPath = "rounds.0." + targetPath
|
||||
}
|
||||
|
||||
// 4) 遍历 consult,按类型生成附件并追加
|
||||
for _, item := range consult {
|
||||
itemJson := gjson.New(item)
|
||||
|
||||
itemType := itemJson.Get("type").String()
|
||||
if itemType == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 查找对应类型的附件配置
|
||||
attachResult, ok := attachmentsMap[itemType]
|
||||
if !ok || !attachResult.IsObject() {
|
||||
// 查找对应模板
|
||||
tmpl := gconv.Map(templates[itemType])
|
||||
if len(tmpl) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// 获取模板
|
||||
templateResult := attachResult.Get("template")
|
||||
if !templateResult.Exists() || !templateResult.IsObject() {
|
||||
// 生成附件对象
|
||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||
if attachment == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 深拷贝模板
|
||||
filledTemplateStr := templateResult.Raw
|
||||
// 获取当前数组长度,用索引追加
|
||||
arr := msgJson.Get(targetPath).Array()
|
||||
idx := len(arr)
|
||||
indexPath := fmt.Sprintf("%s.%d", targetPath, idx)
|
||||
_ = msgJson.Set(indexPath, attachment)
|
||||
}
|
||||
|
||||
// 应用字段映射
|
||||
fieldMappingResult := attachResult.Get("field_mapping")
|
||||
if fieldMappingResult.Exists() && fieldMappingResult.IsObject() {
|
||||
fieldMapping := fieldMappingResult.Map()
|
||||
for fieldPath, valueSource := range fieldMapping {
|
||||
sourceKey := valueSource.String()
|
||||
valueResult := consultItem.Get(sourceKey)
|
||||
if valueResult.Exists() {
|
||||
var err error
|
||||
filledTemplateStr, err = sjson.SetRaw(filledTemplateStr, fieldPath, valueResult.Raw)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return msgJson.Map()
|
||||
}
|
||||
|
||||
// buildAttachment 根据模板和用户数据生成附件对象
|
||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||
typ := gconv.String(tmpl["type"])
|
||||
if typ == "" || url == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 深拷贝 body 并填充 url
|
||||
body := gconv.Map(tmpl["body"])
|
||||
bodyJson := gjson.New(body)
|
||||
bodyJson = fillEmpty(bodyJson, url)
|
||||
|
||||
return map[string]any{
|
||||
"type": typ,
|
||||
typ: bodyJson.Map(),
|
||||
}
|
||||
}
|
||||
|
||||
// fillEmpty 递归查找空字符串并替换
|
||||
func fillEmpty(j *gjson.Json, value string) *gjson.Json {
|
||||
m := j.Map()
|
||||
for k, v := range m {
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
if vv == "" {
|
||||
_ = j.Set(k, value)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取目标路径
|
||||
targetPath := attachResult.Get("target_path").String()
|
||||
if targetPath == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查目标路径是否存在且为数组
|
||||
targetResult := tGjson.Get(msgStr, targetPath)
|
||||
if !targetResult.Exists() || !targetResult.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
// 追加到数组末尾
|
||||
arrLen := len(targetResult.Array())
|
||||
appendPath := targetPath + "." + strconv.Itoa(arrLen)
|
||||
var err error
|
||||
msgStr, err = sjson.SetRaw(msgStr, appendPath, filledTemplateStr)
|
||||
if err != nil {
|
||||
continue
|
||||
case map[string]any:
|
||||
_ = j.Set(k, fillEmpty(gjson.New(vv), value).Map())
|
||||
}
|
||||
}
|
||||
|
||||
// 转回 map[string]any
|
||||
var result map[string]any
|
||||
if err := json.Unmarshal([]byte(msgStr), &result); err != nil {
|
||||
return messages
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// GetUserMessage 获取用户消息
|
||||
func GetUserMessage(taskReq map[string]any) map[string]any {
|
||||
// 先取 requestPayload
|
||||
rp, ok := taskReq["requestPayload"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
// 再取 messages
|
||||
messages, ok := rp["messages"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
for _, msg := range messages {
|
||||
m, ok := msg.(map[string]any)
|
||||
if ok && m["role"] == "user" {
|
||||
return m
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return j
|
||||
}
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ReverseMap 映射 payload 到 mapping
|
||||
@@ -22,86 +20,80 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构
|
||||
func ExtractUserText(messages map[string]any) map[string]any {
|
||||
var texts []string
|
||||
|
||||
responseJson := gjson.New(responseBytes)
|
||||
resultJson := gjson.New("{}")
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
val := responseJson.Get(path)
|
||||
if val.IsNil() {
|
||||
continue
|
||||
}
|
||||
resultJson.Set(standardField, val.Val())
|
||||
}
|
||||
|
||||
return []byte(resultJson.String()), nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
// 1) rounds 结构:遍历每轮
|
||||
if rounds, ok := messages["rounds"].([]any); ok {
|
||||
for _, round := range rounds {
|
||||
if rm, ok := round.(map[string]any); ok {
|
||||
if msgs, ok := rm["messages"].([]any); ok {
|
||||
texts = append(texts, extractTextFromRoleUser(msgs)...)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
} else if msgs, ok := messages["messages"].([]any); ok {
|
||||
// 2) messages 结构
|
||||
texts = extractTextFromRoleUser(msgs)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
|
||||
// 3) 构建返回结构
|
||||
return map[string]any{
|
||||
"role": "user",
|
||||
"content": strings.Join(texts, "\n"),
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PayloadToQuery 将 payload 转为 url.Values
|
||||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本
|
||||
func extractTextFromRoleUser(msgs []any) []string {
|
||||
var texts []string
|
||||
for _, msg := range msgs {
|
||||
m, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
q.Set(k, gconv.String(v))
|
||||
if role, _ := m["role"].(string); role != "user" {
|
||||
continue
|
||||
}
|
||||
texts = append(texts, extractAllText(m["content"])...)
|
||||
}
|
||||
return q, nil
|
||||
return texts
|
||||
}
|
||||
|
||||
// extractAllText 从 content 中提取所有文本(递归,最大兼容)
|
||||
func extractAllText(content any) []string {
|
||||
switch c := content.(type) {
|
||||
case string:
|
||||
return []string{c}
|
||||
|
||||
case []any:
|
||||
var texts []string
|
||||
for _, item := range c {
|
||||
m, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if t, ok := m["text"].(string); ok && t != "" {
|
||||
texts = append(texts, t)
|
||||
continue
|
||||
}
|
||||
for _, v := range m {
|
||||
texts = append(texts, extractAllText(v)...)
|
||||
}
|
||||
}
|
||||
return texts
|
||||
|
||||
case map[string]any:
|
||||
if t, ok := c["text"].(string); ok && t != "" {
|
||||
return []string{t}
|
||||
}
|
||||
var texts []string
|
||||
for _, v := range c {
|
||||
texts = append(texts, extractAllText(v)...)
|
||||
}
|
||||
return texts
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,3 +26,8 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
|
||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||
}
|
||||
|
||||
// GetPromptText 纯文本prompt调用接口(测试专用)
|
||||
func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) {
|
||||
return promptService.GetPromptText(ctx, req)
|
||||
}
|
||||
|
||||
@@ -54,6 +54,7 @@ func (d *composeSessionDao) List(ctx context.Context, req *entity.ComposeSession
|
||||
OmitEmpty()
|
||||
model.Where(entity.ComposeSessionCol.Creator, req.Creator)
|
||||
model.Where(entity.ComposeSessionCol.SessionId, req.SessionId)
|
||||
model.Where(entity.ComposeSessionCol.NodeId, req.NodeId)
|
||||
model.OrderDesc(entity.ComposeSessionCol.CreatedAt)
|
||||
model.Page(page, size)
|
||||
r, total, err := model.AllAndCount(false)
|
||||
|
||||
6
go.mod
6
go.mod
@@ -3,12 +3,10 @@ module prompts-core
|
||||
go 1.26.1
|
||||
|
||||
require (
|
||||
gitea.com/red-future/common v0.0.20
|
||||
gitea.com/red-future/common v0.0.21
|
||||
github.com/gogf/gf/contrib/drivers/pgsql/v2 v2.10.2
|
||||
github.com/gogf/gf/contrib/nosql/redis/v2 v2.10.2
|
||||
github.com/gogf/gf/v2 v2.10.2
|
||||
github.com/tidwall/gjson v1.19.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -65,8 +63,6 @@ require (
|
||||
github.com/r3labs/diff/v2 v2.15.1 // indirect
|
||||
github.com/redis/go-redis/v9 v9.12.1 // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.0 // indirect
|
||||
github.com/tiger1103/gfast-token v1.0.10 // indirect
|
||||
github.com/vcaesar/cedar v0.30.0 // indirect
|
||||
github.com/vmihailenco/msgpack v4.0.4+incompatible // indirect
|
||||
|
||||
13
go.sum
13
go.sum
@@ -1,6 +1,6 @@
|
||||
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
|
||||
gitea.com/red-future/common v0.0.20 h1:KlKINnJFmOVkDzgkptEAFsdpMUZb0zK9BTdiXRxVfAo=
|
||||
gitea.com/red-future/common v0.0.20/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
||||
gitea.com/red-future/common v0.0.21 h1:8w30HmCVmFG/hphH3ODJs1KxDEGmRpq+/PXI0pQjJKc=
|
||||
gitea.com/red-future/common v0.0.21/go.mod h1:6/nqIucVzmjOyqDTIq71feYBXXFNBy0rFwzaQ0/Ueoo=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -288,15 +288,6 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||
github.com/tidwall/gjson v1.19.0 h1:xwxm7n691Uf3u5OFjzngavjGTh55KX5q/9w9xHW88JU=
|
||||
github.com/tidwall/gjson v1.19.0/go.mod h1:V37/opeE/JbLUOfH0QTXiNez2l0RUjYUhpT4szFQAfc=
|
||||
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
|
||||
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||
github.com/tiger1103/gfast-token v1.0.10 h1:fNiBE/Dq5iTHvTGlCx3DmXa2o4hr0NtumFpffZ39k6s=
|
||||
github.com/tiger1103/gfast-token v1.0.10/go.mod h1:a/21mxmj7zFeNvjhZSC0XpEAFHfb1aT2k6DXnufFU1s=
|
||||
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
|
||||
|
||||
@@ -59,3 +59,12 @@ type GetComposeTaskRes struct {
|
||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||
}
|
||||
|
||||
type GetPromptTextReq struct {
|
||||
g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"`
|
||||
Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"`
|
||||
}
|
||||
|
||||
type GetPromptTextRes struct {
|
||||
Messages any `json:"messages" dc:"最终消息数组"`
|
||||
}
|
||||
|
||||
@@ -58,8 +58,7 @@ type AsynchModel struct {
|
||||
ResponseTokenField string `orm:"response_token_field" json:"responseTokenField"`
|
||||
IsPrivate *int `orm:"is_private" json:"isPrivate"`
|
||||
IsChatModel int `orm:"is_chat_model" json:"isChatModel"`
|
||||
IsAsync *int `orm:"is_async" json:"isAsync"`
|
||||
IsStream *int `orm:"is_stream" json:"isStream"`
|
||||
CallModel int `orm:"call_model" json:"callModel"`
|
||||
ApiKey string `orm:"api_key" json:"apiKey"`
|
||||
Enabled *int `orm:"enabled" json:"enabled"`
|
||||
MaxConcurrency int `orm:"max_concurrency" json:"maxConcurrency"`
|
||||
|
||||
@@ -28,13 +28,13 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
availableWindow := util.GetAvailableWindow(aiModel.TokenConfig)
|
||||
return nil, fmt.Errorf("整体内容超出模型窗口大小限制(可用窗口=%d tokens),请精简后重试", availableWindow)
|
||||
}
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||
}
|
||||
|
||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||
@@ -50,18 +50,20 @@ func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ch
|
||||
// 用户消息
|
||||
ir.AddSystem(customPrompt)
|
||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||
return compileToProviderRequest(ctx, ir, chatModel, customPrompt)
|
||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||
}
|
||||
|
||||
// compileToProviderRequest 编译为 Provider 请求
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, customPrompt ...string) (map[string]any, error) {
|
||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||
if err != nil || protocol == nil {
|
||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||
}
|
||||
// 如果传了自定义提示词,替换掉协议模板
|
||||
if len(customPrompt) > 0 && customPrompt[0] != "" {
|
||||
protocol.SystemPromptTemplate = customPrompt[0]
|
||||
protocol.SystemPromptTemplate = customPrompt[0] +
|
||||
"【核心铁律】" +
|
||||
"1.【技能内容skill相关】必须完整拼接到System提示词中,作为System提示词的组成部分,不得拆分到其他位置。"
|
||||
}
|
||||
providerReq, err := Compile(ir, protocol, chatModel)
|
||||
if err != nil {
|
||||
@@ -72,6 +74,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": utils.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
"buildType": req.BuildType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -84,20 +87,12 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
||||
return ""
|
||||
}
|
||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||
|
||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||
outputJSON, //【输出结构】 %s
|
||||
)
|
||||
}
|
||||
|
||||
// buildUserFormContent 构建用户表单内容字符串
|
||||
func buildUserFormContent(userForm []map[string]any) string {
|
||||
var builder strings.Builder
|
||||
for _, item := range userForm {
|
||||
builder.WriteString(fmt.Sprintf("%v\n", item))
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// checkOverallContent 检查整体内容是否超出窗口
|
||||
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
||||
fullContent := ir.String()
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"prompts-core/service/session"
|
||||
|
||||
"prompts-core/common/util"
|
||||
"prompts-core/consts/public"
|
||||
@@ -173,24 +174,10 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return fmt.Errorf("查询模型失败: %w", err)
|
||||
}
|
||||
// 2) 根据运营商获取协议配置
|
||||
//protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
// ProviderName: model.OperatorName,
|
||||
//})
|
||||
|
||||
// 2) 解析结果
|
||||
var messages map[string]any
|
||||
switch composeTask.BuildType {
|
||||
case public.BuildTypePrompt, public.BuildTypeNode:
|
||||
messages = ParseResult(req.Messages, model.ResponseBody)
|
||||
case public.BuildTypeStruct:
|
||||
messages = ParseStructResult(req.Messages, model.ResponseBody)
|
||||
default:
|
||||
messages = req.Messages
|
||||
}
|
||||
// 3) 合并附加结构
|
||||
messages = util.MergeConsult(composeTask.RequestPayload, messages, model.ExtendMapping)
|
||||
// 4) 更新数据库
|
||||
// 2) 合并附加结构
|
||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||
// 3) 更新数据库
|
||||
_, err = dao.ComposeTask.Update(ctx, &entity.ComposeTask{
|
||||
TaskId: req.TaskId,
|
||||
Status: public.ComposeStatusSuccess,
|
||||
@@ -203,21 +190,31 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
//var userHistoryMsg map[string]any
|
||||
var epicycleId int64
|
||||
payload := composeTask.RequestPayload
|
||||
sessionId := gconv.String(payload["sessionId"])
|
||||
nodeId := gconv.String(payload["nodeId"])
|
||||
buildType := gconv.Int(payload["buildType"])
|
||||
if buildType == public.BuildTypePrompt && sessionId != "" && nodeId != "" {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: messages,
|
||||
})
|
||||
// 4) 获取历史内容并拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, sessionId, nodeId)
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// 5) 存储提示词结果作为历史请求
|
||||
if userMsg := util.ExtractUserText(messages); userMsg != nil {
|
||||
epicycleId, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
NodeId: nodeId,
|
||||
SessionId: sessionId,
|
||||
RequestContent: userMsg,
|
||||
})
|
||||
}
|
||||
}
|
||||
// 6) 拼接历史内容
|
||||
// 7) 回调业务方
|
||||
// 6) 回调业务方
|
||||
if composeTask.CallbackUrl != "" {
|
||||
composeTask.Status = public.ComposeStatusSuccess
|
||||
composeTask.Messages = messages
|
||||
@@ -226,95 +223,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseResult 解析结果
|
||||
func ParseResult(raw map[string]any, responseBody string) map[string]any {
|
||||
if responseBody == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
if contentVal == nil {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 已经是数组
|
||||
if arr, ok := contentVal.([]any); ok {
|
||||
rounds := gconv.Maps(arr)
|
||||
if len(rounds) > 0 {
|
||||
return map[string]any{"total_rounds": len(rounds), "rounds": rounds}
|
||||
}
|
||||
return raw
|
||||
}
|
||||
|
||||
// 是字符串
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" {
|
||||
return raw
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
return map[string]any{"total_rounds": len(arr), "rounds": arr}
|
||||
}
|
||||
|
||||
// 尝试解析为单对象
|
||||
var obj map[string]any
|
||||
if err := json.Unmarshal([]byte(contentStr), &obj); err == nil && len(obj) > 0 {
|
||||
return map[string]any{"total_rounds": 1, "rounds": []map[string]any{obj}}
|
||||
}
|
||||
|
||||
return map[string]any{"content": contentStr}
|
||||
}
|
||||
|
||||
func ParseStructResult(raw map[string]any, responseBody string) map[string]any {
|
||||
// 如果外层已有 rounds,直接返回
|
||||
if _, ok := raw["rounds"]; ok {
|
||||
return raw
|
||||
}
|
||||
|
||||
contentVal := raw[responseBody]
|
||||
|
||||
var rounds []map[string]any
|
||||
|
||||
// 是字符串,尝试解析
|
||||
contentStr := gconv.String(contentVal)
|
||||
if contentStr == "" || contentStr == "0" {
|
||||
rounds = append(rounds, map[string]any{responseBody: raw})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为数组
|
||||
var arr []any
|
||||
if err := json.Unmarshal([]byte(contentStr), &arr); err == nil && len(arr) > 0 {
|
||||
rounds = append(rounds, map[string]any{responseBody: arr})
|
||||
return map[string]any{
|
||||
"total_rounds": len(rounds),
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 尝试解析为单个对象
|
||||
var parsed any
|
||||
if err := json.Unmarshal([]byte(contentStr), &parsed); err == nil {
|
||||
rounds = append(rounds, map[string]any{responseBody: parsed})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// 兜底:原始字符串作为内容
|
||||
rounds = append(rounds, map[string]any{responseBody: contentStr})
|
||||
return map[string]any{
|
||||
"total_rounds": 1,
|
||||
"rounds": rounds,
|
||||
}
|
||||
}
|
||||
|
||||
// GetComposeTask 查询任务结果
|
||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||
@@ -351,3 +259,13 @@ func parseMessagesForResponse(messages any) any {
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
||||
// 1) 获取基础数据
|
||||
|
||||
// 4) 模拟历史拼接
|
||||
history, _ := session.GetHistoryMessages(ctx, "88888888", "node1")
|
||||
return &dto.GetPromptTextRes{
|
||||
Messages: history,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ func saveToRedis(ctx context.Context, session *entity.ComposeSession) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化会话数据失败: %w", err)
|
||||
}
|
||||
|
||||
if err = executeRedisCommands(ctx, key, string(b), maxRounds, expireSeconds); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -50,18 +49,15 @@ func executeRedisCommands(ctx context.Context, key string, value string, maxRoun
|
||||
if _, err := g.Redis().Do(ctx, "LTRIM", key, 0, maxRounds-1); err != nil {
|
||||
return fmt.Errorf("裁剪Redis列表失败: %w", err)
|
||||
}
|
||||
|
||||
if _, err := g.Redis().Do(ctx, "EXPIRE", key, expireSeconds); err != nil {
|
||||
return fmt.Errorf("设置过期时间失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getFromRedis 从Redis获取会话历史
|
||||
func getFromRedis(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
key := formatRedisKey(sessionId)
|
||||
|
||||
result, err := g.Redis().Do(ctx, "LRANGE", key, 0, -1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("从Redis获取数据失败: %w", err)
|
||||
|
||||
@@ -47,15 +47,35 @@ func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCal
|
||||
}
|
||||
|
||||
// GetHistoryMessages 获取历史信息
|
||||
func GetHistoryMessages(ctx context.Context, sessionId string) ([]map[string]any, error) {
|
||||
func GetHistoryMessages(ctx context.Context, sessionId string, nodeId string) ([]map[string]any, error) {
|
||||
// 1) 获取最大轮次
|
||||
maxRounds := g.Cfg().MustGet(ctx, "session.maxRounds", 10).Int()
|
||||
|
||||
// 2) 从 Redis 获取历史记录
|
||||
redisHistory, err := GetSessionHistoryForInference(ctx, sessionId)
|
||||
if err == nil && len(redisHistory) > 0 {
|
||||
return redisHistory, nil
|
||||
}
|
||||
|
||||
return getHistoryFromDatabase(ctx, sessionId, maxRounds)
|
||||
// 3) Redis 没有,从数据库查最新 maxRounds 条
|
||||
sessions, _, err := dao.ComposeSession.List(ctx, &entity.ComposeSession{
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
}, 1, maxRounds)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DB获取历史失败: %w", err)
|
||||
}
|
||||
// 4) 为空返回报错
|
||||
if len(sessions) == 0 {
|
||||
return nil, fmt.Errorf("会话不存在: sessionId=%s nodeId=%s", sessionId, nodeId)
|
||||
}
|
||||
// 5) 提取为统一格式
|
||||
messages := extractMessagesFromSessions(sessions)
|
||||
|
||||
// 6) 缓存 Redis 半小时
|
||||
//_ = CacheSessionHistoryForInference(ctx, sessionId, messages, 30*time.Minute)
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// getHistoryFromDatabase 从数据库获取历史记录
|
||||
@@ -77,12 +97,10 @@ func getHistoryFromDatabase(ctx context.Context, sessionId string, maxRounds int
|
||||
// extractMessagesFromSessions 从会话列表中提取消息
|
||||
func extractMessagesFromSessions(sessions []*entity.ComposeSession) []map[string]any {
|
||||
var messages []map[string]any
|
||||
|
||||
for _, session := range sessions {
|
||||
appendRequestMessages(session.RequestContent, &messages)
|
||||
appendResponseMessages(session.ResponseContent, &messages)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user