refactor(prompts-core): 重构代码结构和优化工具函数
This commit is contained in:
@@ -1,197 +1,81 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/container/gvar"
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertToMessages 将原始数据转换为消息列表
|
|
||||||
func ConvertToMessages(raw any) []map[string]any {
|
|
||||||
if raw == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
j := gjson.New(raw)
|
|
||||||
messages := j.Get("messages")
|
|
||||||
if !messages.IsNil() {
|
|
||||||
return gconv.Maps(messages.Val())
|
|
||||||
}
|
|
||||||
return []map[string]any{j.Map()}
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
|
||||||
func FormToJSON(form []map[string]any) string {
|
|
||||||
if form == nil {
|
|
||||||
return "[]"
|
|
||||||
}
|
|
||||||
b, _ := json.Marshal(form)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
|
|
||||||
func UserFormToJSON(form []map[string]any) string {
|
|
||||||
if form == nil {
|
|
||||||
return "{}"
|
|
||||||
}
|
|
||||||
|
|
||||||
b, _ := json.Marshal(form)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustMarshalToMap 将对象序列化为 map[string]any,失败时返回空 map
|
|
||||||
func MustMarshalToMap(v any) map[string]any {
|
|
||||||
b, err := json.Marshal(v)
|
|
||||||
if err != nil {
|
|
||||||
return make(map[string]any)
|
|
||||||
}
|
|
||||||
var m map[string]any
|
|
||||||
json.Unmarshal(b, &m)
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSONPretty 将任意类型转为格式化的 JSON 字符串
|
|
||||||
func JSONPretty(v any) string {
|
|
||||||
if gv, ok := v.(*gvar.Var); ok {
|
|
||||||
v = gconv.Map(gv.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
var tmp map[string]any
|
|
||||||
if err := gconv.Struct(v, &tmp); err != nil {
|
|
||||||
return gconv.String(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
b, _ := json.Marshal(tmp)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
|
|
||||||
func ParseJSONFieldFromGvar(source any, target any) {
|
|
||||||
if source == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch v := source.(type) {
|
|
||||||
case *gvar.Var:
|
|
||||||
if v.IsNil() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试获取 map
|
|
||||||
if m := v.Map(); len(m) > 0 {
|
|
||||||
data, _ := json.Marshal(m)
|
|
||||||
json.Unmarshal(data, target)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 尝试解析 JSON 字符串
|
|
||||||
str := v.String()
|
|
||||||
if str != "" && str != "<nil>" {
|
|
||||||
json.Unmarshal([]byte(str), target)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
// 其他类型走原来的逻辑
|
|
||||||
data, _ := json.Marshal(source)
|
|
||||||
json.Unmarshal(data, target)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
|
||||||
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
|
||||||
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1) 获取 consult 数组
|
|
||||||
consult := gconv.Interfaces(req["consult"])
|
consult := gconv.Interfaces(req["consult"])
|
||||||
if len(consult) == 0 {
|
if len(consult) == 0 {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) 获取配置
|
|
||||||
targetPath := gconv.String(extendMapping["target_content_path"])
|
targetPath := gconv.String(extendMapping["target_content_path"])
|
||||||
if targetPath == "" {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
templates := gconv.Map(extendMapping["attachment_templates"])
|
templates := gconv.Map(extendMapping["attachment_templates"])
|
||||||
if len(templates) == 0 {
|
if targetPath == "" || len(templates) == 0 {
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 转为 gjson 操作
|
|
||||||
msgJson := gjson.New(messages)
|
msgJson := gjson.New(messages)
|
||||||
|
|
||||||
// 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath}
|
// rounds 路径修正
|
||||||
if arr := msgJson.Get("rounds.0").Array(); arr != nil {
|
if !msgJson.Get("rounds.0").IsNil() {
|
||||||
targetPath = "rounds.0." + targetPath
|
targetPath = "rounds.0." + targetPath
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 遍历 consult,按类型生成附件并追加
|
// 遍历追加
|
||||||
for _, item := range consult {
|
for _, item := range consult {
|
||||||
itemJson := gjson.New(item)
|
itemJson := gjson.New(item)
|
||||||
|
|
||||||
itemType := itemJson.Get("type").String()
|
itemType := itemJson.Get("type").String()
|
||||||
if itemType == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查找对应模板
|
|
||||||
tmpl := gconv.Map(templates[itemType])
|
tmpl := gconv.Map(templates[itemType])
|
||||||
if len(tmpl) == 0 {
|
if itemType == "" || len(tmpl) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成附件对象
|
|
||||||
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
|
||||||
if attachment == nil {
|
if attachment == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前数组长度,用索引追加
|
idx := len(msgJson.Get(targetPath).Array())
|
||||||
arr := msgJson.Get(targetPath).Array()
|
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
|
||||||
idx := len(arr)
|
|
||||||
indexPath := fmt.Sprintf("%s.%d", targetPath, idx)
|
|
||||||
_ = msgJson.Set(indexPath, attachment)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return msgJson.Map()
|
return msgJson.Map()
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildAttachment 根据模板和用户数据生成附件对象
|
|
||||||
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
func buildAttachment(tmpl map[string]any, url string) map[string]any {
|
||||||
typ := gconv.String(tmpl["type"])
|
typ := gconv.String(tmpl["type"])
|
||||||
if typ == "" || url == "" {
|
if typ == "" || url == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 深拷贝 body 并填充 url
|
|
||||||
body := gconv.Map(tmpl["body"])
|
body := gconv.Map(tmpl["body"])
|
||||||
bodyJson := gjson.New(body)
|
fillEmptyInPlace(body, url)
|
||||||
bodyJson = fillEmpty(bodyJson, url)
|
|
||||||
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"type": typ,
|
"type": typ,
|
||||||
typ: bodyJson.Map(),
|
typ: body,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fillEmpty 递归查找空字符串并替换
|
func fillEmptyInPlace(m map[string]any, value string) {
|
||||||
func fillEmpty(j *gjson.Json, value string) *gjson.Json {
|
|
||||||
m := j.Map()
|
|
||||||
for k, v := range m {
|
for k, v := range m {
|
||||||
switch vv := v.(type) {
|
switch vv := v.(type) {
|
||||||
case string:
|
case string:
|
||||||
if vv == "" {
|
if vv == "" {
|
||||||
_ = j.Set(k, value)
|
m[k] = value
|
||||||
}
|
}
|
||||||
case map[string]any:
|
case map[string]any:
|
||||||
_ = j.Set(k, fillEmpty(gjson.New(vv), value).Map())
|
fillEmptyInPlace(vv, value)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return j
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReverseMap 映射 payload 到 mapping
|
// ReverseMap 映射 payload 到 mapping
|
||||||
@@ -20,80 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
|||||||
return jsonObj.Map()
|
return jsonObj.Map()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构
|
// ExtractUserText 从 messages 中提取所有 user 文本
|
||||||
func ExtractUserText(messages map[string]any) map[string]any {
|
func ExtractUserText(messages map[string]any) map[string]any {
|
||||||
|
msgJson := gjson.New(messages)
|
||||||
|
|
||||||
|
msgs := msgJson.Get("rounds.0.messages")
|
||||||
|
if msgs.IsNil() {
|
||||||
|
msgs = msgJson.Get("messages")
|
||||||
|
}
|
||||||
var texts []string
|
var texts []string
|
||||||
|
for _, m := range msgs.Array() {
|
||||||
// 1) rounds 结构:遍历每轮
|
msg := gjson.New(m)
|
||||||
if rounds, ok := messages["rounds"].([]any); ok {
|
if msg.Get("role").String() != "user" {
|
||||||
for _, round := range rounds {
|
continue
|
||||||
if rm, ok := round.(map[string]any); ok {
|
}
|
||||||
if msgs, ok := rm["messages"].([]any); ok {
|
content := msg.Get("content").Val()
|
||||||
texts = append(texts, extractTextFromRoleUser(msgs)...)
|
switch c := content.(type) {
|
||||||
|
case string:
|
||||||
|
texts = append(texts, c)
|
||||||
|
case []any:
|
||||||
|
for _, item := range c {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if t := gconv.String(m["text"]); t != "" {
|
||||||
|
texts = append(texts, t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if msgs, ok := messages["messages"].([]any); ok {
|
|
||||||
// 2) messages 结构
|
|
||||||
texts = extractTextFromRoleUser(msgs)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3) 构建返回结构
|
|
||||||
return map[string]any{
|
return map[string]any{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": strings.Join(texts, "\n"),
|
"content": strings.Join(texts, "\n"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本
|
|
||||||
func extractTextFromRoleUser(msgs []any) []string {
|
|
||||||
var texts []string
|
|
||||||
for _, msg := range msgs {
|
|
||||||
m, ok := msg.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if role, _ := m["role"].(string); role != "user" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
texts = append(texts, extractAllText(m["content"])...)
|
|
||||||
}
|
|
||||||
return texts
|
|
||||||
}
|
|
||||||
|
|
||||||
// extractAllText 从 content 中提取所有文本(递归,最大兼容)
|
|
||||||
func extractAllText(content any) []string {
|
|
||||||
switch c := content.(type) {
|
|
||||||
case string:
|
|
||||||
return []string{c}
|
|
||||||
|
|
||||||
case []any:
|
|
||||||
var texts []string
|
|
||||||
for _, item := range c {
|
|
||||||
m, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if t, ok := m["text"].(string); ok && t != "" {
|
|
||||||
texts = append(texts, t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, v := range m {
|
|
||||||
texts = append(texts, extractAllText(v)...)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return texts
|
|
||||||
|
|
||||||
case map[string]any:
|
|
||||||
if t, ok := c["text"].(string); ok && t != "" {
|
|
||||||
return []string{t}
|
|
||||||
}
|
|
||||||
var texts []string
|
|
||||||
for _, v := range c {
|
|
||||||
texts = append(texts, extractAllText(v)...)
|
|
||||||
}
|
|
||||||
return texts
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -11,3 +11,7 @@ const (
|
|||||||
BuildTypeNode = 2 //节点构建
|
BuildTypeNode = 2 //节点构建
|
||||||
BuildTypeStruct = 3 //结构构建
|
BuildTypeStruct = 3 //结构构建
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ModelTypeInference = 100 // 推理模型
|
||||||
|
)
|
||||||
|
|||||||
@@ -26,8 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
|
|||||||
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
|
||||||
return promptService.GetComposeTask(ctx, req.TaskId)
|
return promptService.GetComposeTask(ctx, req.TaskId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPromptText 纯文本prompt调用接口(测试专用)
|
|
||||||
func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) {
|
|
||||||
return promptService.GetPromptText(ctx, req)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -25,12 +25,6 @@ type ComposeMessagesRes struct {
|
|||||||
TaskId string `json:"taskId" dc:"任务ID"`
|
TaskId string `json:"taskId" dc:"任务ID"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// MultiRoundResult 多轮返回结果
|
|
||||||
type MultiRoundResult struct {
|
|
||||||
TotalRounds int `json:"total_rounds"` // 总轮数
|
|
||||||
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
|
|
||||||
}
|
|
||||||
|
|
||||||
type CallbackReq struct {
|
type CallbackReq struct {
|
||||||
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调:callbackUrl/{bizName}"`
|
||||||
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
|
||||||
@@ -55,16 +49,7 @@ type GetComposeTaskRes struct {
|
|||||||
Status string `json:"status" dc:"业务状态"`
|
Status string `json:"status" dc:"业务状态"`
|
||||||
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
GatewayState int `json:"gatewayState" dc:"网关状态"`
|
||||||
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
|
||||||
Messages any `json:"messages" dc:"最终消息数组"`
|
Messages map[string]any `json:"messages" dc:"最终消息数组"`
|
||||||
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
OssFile string `json:"ossFile" dc:"结果文件地址"`
|
||||||
FileType string `json:"fileType" dc:"结果文件类型"`
|
FileType string `json:"fileType" dc:"结果文件类型"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetPromptTextReq struct {
|
|
||||||
g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"`
|
|
||||||
Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type GetPromptTextRes struct {
|
|
||||||
Messages any `json:"messages" dc:"历史消息"`
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,11 +13,10 @@ import (
|
|||||||
|
|
||||||
"gitea.com/red-future/common/utils"
|
"gitea.com/red-future/common/utils"
|
||||||
"github.com/gogf/gf/v2/encoding/gjson"
|
"github.com/gogf/gf/v2/encoding/gjson"
|
||||||
"github.com/gogf/gf/v2/util/gconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
//1) 构建系统提示词
|
//1) 构建系统提示词
|
||||||
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
|
||||||
ir.AddSystem(systemPrompt)
|
ir.AddSystem(systemPrompt)
|
||||||
@@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
|||||||
}
|
}
|
||||||
|
|
||||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
ir.AddUser(NodeBuild(ctx, req))
|
ir.AddUser(NodeBuild(ctx, req))
|
||||||
return compileToProviderRequest(ctx, ir, chatModel, req)
|
return compileToProviderRequest(ctx, ir, chatModel, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
// buildStructTypeRequest 构建结构体类型请求(BuildType=3)
|
||||||
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
|
||||||
// 提取 userForm 中的 prompt 作为自定义提示词
|
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
|
||||||
var customPrompt string
|
|
||||||
for _, item := range req.UserForm {
|
|
||||||
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
|
|
||||||
customPrompt = gconv.String(prompt)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 用户消息
|
|
||||||
ir.AddSystem(customPrompt)
|
ir.AddSystem(customPrompt)
|
||||||
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
ir.AddUser(buildUserPrompt(ctx, req, ""))
|
||||||
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compileToProviderRequest 编译为 Provider 请求
|
// compileToProviderRequest 编译为 Provider 请求
|
||||||
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
|
||||||
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
|
||||||
if err != nil || protocol == nil {
|
if err != nil || protocol == nil {
|
||||||
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
|
||||||
@@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// promptBuildWithRounds 构建提示词
|
||||||
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
|
||||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||||
ProviderName: chatModel.OperatorName,
|
ProviderName: chatModel.OperatorName,
|
||||||
@@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
|||||||
if err != nil || providerProtocol == nil {
|
if err != nil || providerProtocol == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
|
||||||
|
|
||||||
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
|
||||||
outputJSON, //【输出结构】 %s
|
outputJSON, //【输出结构】 %s
|
||||||
@@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkOverallContent 检查整体内容是否超出窗口
|
// checkOverallContent 检查整体内容是否超出窗口
|
||||||
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
|
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
|
||||||
fullContent := ir.String()
|
fullContent := ir.String()
|
||||||
return util.CountToken(fullContent, model.TokenConfig)
|
return util.CountToken(fullContent, model.TokenConfig)
|
||||||
}
|
}
|
||||||
@@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
|||||||
return b.String()
|
return b.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildUserFormText 构建用户表单内容字符串
|
|
||||||
func buildUserFormText(form []map[string]any) string {
|
func buildUserFormText(form []map[string]any) string {
|
||||||
if len(form) == 0 {
|
if len(form) == 0 {
|
||||||
return ""
|
return ""
|
||||||
@@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string {
|
|||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
for _, item := range form {
|
for _, item := range form {
|
||||||
for k, v := range item {
|
for k, v := range item {
|
||||||
|
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
||||||
switch val := v.(type) {
|
switch val := v.(type) {
|
||||||
case []any:
|
case []any:
|
||||||
// 数组类型:逐条列出
|
|
||||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
|
||||||
for i, elem := range val {
|
for i, elem := range val {
|
||||||
|
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
||||||
if m, ok := elem.(map[string]any); ok {
|
if m, ok := elem.(map[string]any); ok {
|
||||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
|
||||||
for mk, mv := range m {
|
for mk, mv := range m {
|
||||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
||||||
}
|
}
|
||||||
builder.WriteString("\n")
|
|
||||||
} else {
|
} else {
|
||||||
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
|
builder.WriteString(fmt.Sprint(elem))
|
||||||
}
|
|
||||||
}
|
|
||||||
case []map[string]any:
|
|
||||||
builder.WriteString(fmt.Sprintf("%s:\n", k))
|
|
||||||
for i, m := range val {
|
|
||||||
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
|
|
||||||
for mk, mv := range m {
|
|
||||||
builder.WriteString(fmt.Sprintf("%s:%v ", mk, mv))
|
|
||||||
}
|
}
|
||||||
builder.WriteString("\n")
|
builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
builder.WriteString(fmt.Sprintf("%s:%v\n", k, v))
|
builder.WriteString(fmt.Sprintf(" %v\n", v))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
|
|||||||
if promptTpl == "" {
|
if promptTpl == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
return fmt.Sprintf(promptTpl,
|
||||||
formStr := util.FormToJSON(req.Form)
|
gjson.New(req.Form).MustToJsonString(),
|
||||||
userFormStr := util.UserFormToJSON(req.UserForm)
|
gjson.New(req.UserForm).MustToJsonString(),
|
||||||
|
)
|
||||||
return fmt.Sprintf(promptTpl, formStr, userFormStr)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package prompt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/service/session"
|
"prompts-core/service/session"
|
||||||
@@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
|
|||||||
// handleBuild 通用构建处理
|
// handleBuild 通用构建处理
|
||||||
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
|
||||||
// 1) 处理表单分批
|
// 1) 处理表单分批
|
||||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||||
}
|
}
|
||||||
@@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
|||||||
var taskReq map[string]any
|
var taskReq map[string]any
|
||||||
switch req.BuildType {
|
switch req.BuildType {
|
||||||
case public.BuildTypePrompt:
|
case public.BuildTypePrompt:
|
||||||
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
|
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
|
||||||
case public.BuildTypeNode:
|
case public.BuildTypeNode:
|
||||||
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
|
||||||
case public.BuildTypeStruct:
|
case public.BuildTypeStruct:
|
||||||
@@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
|
|||||||
SkillName: req.SkillName,
|
SkillName: req.SkillName,
|
||||||
BuildType: req.BuildType,
|
BuildType: req.BuildType,
|
||||||
CallbackUrl: req.CallbackUrl,
|
CallbackUrl: req.CallbackUrl,
|
||||||
RequestPayload: util.MustMarshalToMap(req),
|
RequestPayload: gconv.Map(req),
|
||||||
Status: public.ComposeStatusPending,
|
Status: public.ComposeStatusPending,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleCallbackSuccess 处理回调成功
|
||||||
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
|
||||||
// 1) 获取模型配置
|
// 1) 获取模型配置
|
||||||
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
|
||||||
@@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
Status: 1,
|
Status: 1,
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3) 获取历史消息
|
// 3) 获取历史消息 + 保存当前轮
|
||||||
payload := composeTask.RequestPayload
|
payload := composeTask.RequestPayload
|
||||||
sessionId := gconv.String(payload["sessionId"])
|
sessionId := gconv.String(payload["sessionId"])
|
||||||
nodeId := gconv.String(payload["nodeId"])
|
nodeId := gconv.String(payload["nodeId"])
|
||||||
var history []dto.FlatMessage
|
var history []dto.FlatMessage
|
||||||
if sessionId != "" && nodeId != "" {
|
var epicycleId int64
|
||||||
|
|
||||||
|
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
|
||||||
|
// 3.1 获取历史
|
||||||
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
||||||
SessionId: sessionId,
|
SessionId: sessionId,
|
||||||
NodeId: nodeId,
|
NodeId: nodeId,
|
||||||
@@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
if h != nil {
|
if h != nil {
|
||||||
history = h.Messages
|
history = h.Messages
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3.2 保存当前轮(先存,下次查询就能拿到)
|
||||||
|
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
||||||
|
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||||
|
NodeId: nodeId,
|
||||||
|
SessionId: sessionId,
|
||||||
|
RequestContent: userMsg,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4) 合并附加结构
|
// 4) 合并附加结构
|
||||||
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
|
||||||
// 5) 注入历史到 rounds 中
|
// 5) 注入历史
|
||||||
if protocol != nil && len(history) > 0 {
|
if len(history) > 0 {
|
||||||
messages = InjectHistory(messages, history, protocol)
|
messages = InjectHistory(messages, history, protocol)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7) 存储历史
|
|
||||||
var epicycleId int64
|
|
||||||
if sessionId != "" && nodeId != "" {
|
|
||||||
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
|
|
||||||
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
|
||||||
NodeId: nodeId,
|
|
||||||
SessionId: sessionId,
|
|
||||||
RequestContent: userMsg,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 8) 回调业务方
|
// 8) 回调业务方
|
||||||
if composeTask.CallbackUrl != "" {
|
if composeTask.CallbackUrl != "" {
|
||||||
composeTask.Status = public.ComposeStatusSuccess
|
composeTask.Status = public.ComposeStatusSuccess
|
||||||
@@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetComposeTask 查询任务结果
|
// InjectHistory 插入历史会话
|
||||||
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
|
||||||
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: taskID,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("查询任务失败: %w", err)
|
|
||||||
}
|
|
||||||
if record == nil {
|
|
||||||
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
|
|
||||||
}
|
|
||||||
|
|
||||||
messages := parseMessagesForResponse(record.ResultJson)
|
|
||||||
|
|
||||||
return &dto.GetComposeTaskRes{
|
|
||||||
TaskId: record.TaskId,
|
|
||||||
Status: record.Status,
|
|
||||||
ErrorMessage: record.ErrorMessage,
|
|
||||||
Messages: messages,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseMessagesForResponse 解析用于响应的消息
|
|
||||||
func parseMessagesForResponse(messages any) any {
|
|
||||||
str, ok := messages.(string)
|
|
||||||
if !ok || str == "" {
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
var parsed any
|
|
||||||
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
|
|
||||||
return parsed
|
|
||||||
}
|
|
||||||
|
|
||||||
return messages
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
|
|
||||||
// 1) 获取协议配置
|
|
||||||
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
|
||||||
ProviderName: "火山引擎",
|
|
||||||
Status: 1,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2) 获取历史消息
|
|
||||||
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
|
|
||||||
SessionId: "88888888",
|
|
||||||
NodeId: "node1",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3) 模拟roundsData数据
|
|
||||||
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
|
||||||
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
fmt.Println("[打印数据]", task.ResultJson)
|
|
||||||
fmt.Println("[打印历史]", history.Messages)
|
|
||||||
fmt.Println("[打印协议]", protocol)
|
|
||||||
return &dto.GetPromptTextRes{
|
|
||||||
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
|
||||||
if protocol == nil || len(history) == 0 {
|
if protocol == nil || len(history) == 0 {
|
||||||
return roundsData
|
return roundsData
|
||||||
@@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco
|
|||||||
firstRound["messages"] = result
|
firstRound["messages"] = result
|
||||||
return roundsData
|
return roundsData
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetComposeTask 查询任务结果
|
||||||
|
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
|
||||||
|
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
|
||||||
|
TaskId: taskID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("查询任务失败: %w", err)
|
||||||
|
}
|
||||||
|
return &dto.GetComposeTaskRes{
|
||||||
|
TaskId: record.TaskId,
|
||||||
|
Status: record.Status,
|
||||||
|
ErrorMessage: record.ErrorMessage,
|
||||||
|
Messages: record.ResultJson,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||||
|
if skillName == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
skillResp, err := gateway.GetSkillUser(ctx, skillName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)
|
||||||
|
|||||||
@@ -2,9 +2,7 @@ package prompt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"prompts-core/common/util"
|
|
||||||
"prompts-core/service/gateway"
|
"prompts-core/service/gateway"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -14,8 +12,8 @@ import (
|
|||||||
"github.com/gogf/gf/v2/util/gconv"
|
"github.com/gogf/gf/v2/util/gconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PromptIR 统一 Prompt 中间表示
|
// IR 统一 Prompt 中间表示
|
||||||
type PromptIR struct {
|
type IR struct {
|
||||||
System []Segment `json:"system"`
|
System []Segment `json:"system"`
|
||||||
History []Segment `json:"history"`
|
History []Segment `json:"history"`
|
||||||
User []Segment `json:"user"`
|
User []Segment `json:"user"`
|
||||||
@@ -46,8 +44,8 @@ type ContentMapping struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewPromptIR 创建空 PromptIR
|
// NewPromptIR 创建空 PromptIR
|
||||||
func NewPromptIR() *PromptIR {
|
func NewPromptIR() *IR {
|
||||||
return &PromptIR{
|
return &IR{
|
||||||
System: make([]Segment, 0),
|
System: make([]Segment, 0),
|
||||||
History: make([]Segment, 0),
|
History: make([]Segment, 0),
|
||||||
User: make([]Segment, 0),
|
User: make([]Segment, 0),
|
||||||
@@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
|
||||||
func (ir *PromptIR) String() string {
|
func (ir *IR) String() string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -81,7 +79,7 @@ func (ir *PromptIR) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
|
||||||
func (ir *PromptIR) GetTotalContent() string {
|
func (ir *IR) GetTotalContent() string {
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddSystem 添加系统提示
|
// AddSystem 添加系统提示
|
||||||
func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
func (ir *IR) AddSystem(content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
ir.System = append(ir.System, Segment{Type: "text", Content: content})
|
||||||
}
|
}
|
||||||
@@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUser 添加用户消息
|
// AddUser 添加用户消息
|
||||||
func (ir *PromptIR) AddUser(content string) *PromptIR {
|
func (ir *IR) AddUser(content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
ir.User = append(ir.User, Segment{Type: "text", Content: content})
|
||||||
}
|
}
|
||||||
@@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddHistory 添加历史消息
|
// AddHistory 添加历史消息
|
||||||
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
func (ir *IR) AddHistory(role, content string) *IR {
|
||||||
if content != "" {
|
if content != "" {
|
||||||
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
|
||||||
}
|
}
|
||||||
@@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
// ToMessages 转换为 OpenAI 兼容的 messages 格式(MVP 默认)
|
||||||
func (ir *PromptIR) ToMessages() []map[string]any {
|
func (ir *IR) ToMessages() []map[string]any {
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
for _, seg := range ir.System {
|
for _, seg := range ir.System {
|
||||||
@@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
|
|||||||
|
|
||||||
// parseProtocol 将 DB entity 转为编译用协议配置
|
// parseProtocol 将 DB entity 转为编译用协议配置
|
||||||
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
|
||||||
p := &ProviderProtocol{
|
return &ProviderProtocol{
|
||||||
TargetField: e.TargetField,
|
TargetField: e.TargetField,
|
||||||
SystemPromptTemplate: e.SystemPromptTemplate,
|
SystemPromptTemplate: e.SystemPromptTemplate,
|
||||||
|
MergeOrder: e.MergeOrder,
|
||||||
|
RoleMapping: gconv.MapStrStr(e.RoleMapping),
|
||||||
|
ContentMapping: ContentMapping{
|
||||||
|
Type: gconv.String(e.ContentMapping["type"]),
|
||||||
|
Field: gconv.String(e.ContentMapping["field"]),
|
||||||
|
},
|
||||||
|
RequestTemplate: e.RequestTemplate,
|
||||||
|
Capabilities: e.Capabilities,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用通用解析方法处理各个字段
|
|
||||||
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
|
|
||||||
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
|
|
||||||
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
|
|
||||||
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
|
|
||||||
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
|
|
||||||
return p
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
// Compile 将 PromptIR 按协议配置编译为 Provider Request
|
||||||
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
|
||||||
if ir == nil || p == nil {
|
if ir == nil || p == nil {
|
||||||
return nil, fmt.Errorf("ir and protocol are required")
|
return nil, fmt.Errorf("ir and protocol are required")
|
||||||
}
|
}
|
||||||
@@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mergeByOrder 按协议配置顺序拼接消息
|
// mergeByOrder 按协议配置顺序拼接消息
|
||||||
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
|
func mergeByOrder(ir *IR, order []string) []map[string]any {
|
||||||
|
roleMap := map[string][]Segment{
|
||||||
|
"system": ir.System,
|
||||||
|
"history": ir.History,
|
||||||
|
"user": ir.User,
|
||||||
|
}
|
||||||
|
|
||||||
var messages []map[string]any
|
var messages []map[string]any
|
||||||
|
|
||||||
for _, part := range order {
|
for _, part := range order {
|
||||||
switch part {
|
for _, seg := range roleMap[part] {
|
||||||
case "system":
|
msg := map[string]any{"content": seg.Content}
|
||||||
for _, seg := range ir.System {
|
if part == "history" {
|
||||||
messages = append(messages, map[string]any{
|
msg["role"] = seg.Role
|
||||||
"role": "system",
|
} else {
|
||||||
"content": seg.Content,
|
msg["role"] = part
|
||||||
})
|
|
||||||
}
|
}
|
||||||
case "history":
|
messages = append(messages, msg)
|
||||||
for _, seg := range ir.History {
|
|
||||||
messages = append(messages, map[string]any{
|
|
||||||
"role": seg.Role,
|
|
||||||
"content": seg.Content,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case "user":
|
|
||||||
for _, seg := range ir.User {
|
|
||||||
messages = append(messages, map[string]any{
|
|
||||||
"role": "user",
|
|
||||||
"content": seg.Content,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
|
|||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
// mapContent 内容字段映射
|
|
||||||
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
|
||||||
for _, msg := range messages {
|
if cm.Field == "" || cm.Field == "content" {
|
||||||
content := msg["content"]
|
return messages
|
||||||
delete(msg, "content")
|
}
|
||||||
|
|
||||||
|
for i, msg := range messages {
|
||||||
|
if content, ok := msg["content"]; ok {
|
||||||
|
delete(msg, "content")
|
||||||
switch cm.Type {
|
switch cm.Type {
|
||||||
case "parts":
|
case "parts":
|
||||||
msg["parts"] = []map[string]any{
|
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
|
||||||
{cm.Field: content},
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
msg[cm.Field] = content
|
messages[i][cm.Field] = content
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// renderTemplate 简单的 {{key}} 模板替换
|
// renderTemplate 模板渲染
|
||||||
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
|
||||||
b, _ := json.Marshal(p.RequestTemplate)
|
result := make(map[string]any, len(p.RequestTemplate)+1)
|
||||||
str := string(b)
|
for k, v := range p.RequestTemplate {
|
||||||
|
result[k] = v
|
||||||
if chatModel != nil {
|
|
||||||
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msgBytes, _ := json.Marshal(messages)
|
if chatModel != nil {
|
||||||
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
|
result["model"] = chatModel.ModelName
|
||||||
|
}
|
||||||
var result map[string]any
|
result["messages"] = messages
|
||||||
_ = json.Unmarshal([]byte(str), &result)
|
|
||||||
|
|
||||||
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
|
||||||
result["max_tokens"] = maxTokens
|
result["max_tokens"] = maxTokens
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ import (
|
|||||||
|
|
||||||
// Callback 会话回调
|
// Callback 会话回调
|
||||||
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
|
||||||
|
fmt.Println("打印会话回调", req)
|
||||||
req.Messages["role"] = "assistant"
|
req.Messages["role"] = "assistant"
|
||||||
|
|
||||||
// 1) 更新 DB
|
// 1) 更新 DB
|
||||||
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
|
||||||
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
|
||||||
@@ -163,23 +163,15 @@ func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteS
|
|||||||
|
|
||||||
// entityToHistoryRound entity → HistoryRound
|
// entityToHistoryRound entity → HistoryRound
|
||||||
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
|
||||||
reqMsgs := util.ConvertToMessages(s.RequestContent)
|
return &dto.HistoryRound{
|
||||||
respMsgs := util.ConvertToMessages(s.ResponseContent)
|
|
||||||
|
|
||||||
round := &dto.HistoryRound{
|
|
||||||
Id: s.Id,
|
Id: s.Id,
|
||||||
SessionId: s.SessionId,
|
SessionId: s.SessionId,
|
||||||
NodeId: s.NodeId,
|
NodeId: s.NodeId,
|
||||||
CreatedAt: gconv.String(s.CreatedAt),
|
CreatedAt: gconv.String(s.CreatedAt),
|
||||||
UpdatedAt: gconv.String(s.UpdatedAt),
|
UpdatedAt: gconv.String(s.UpdatedAt),
|
||||||
|
User: s.RequestContent,
|
||||||
|
Assistant: s.ResponseContent,
|
||||||
}
|
}
|
||||||
if len(reqMsgs) > 0 {
|
|
||||||
round.User = reqMsgs[0]
|
|
||||||
}
|
|
||||||
if len(respMsgs) > 0 {
|
|
||||||
round.Assistant = respMsgs[0]
|
|
||||||
}
|
|
||||||
return round
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sessionsToHistoryRounds 批量转换
|
// sessionsToHistoryRounds 批量转换
|
||||||
|
|||||||
Reference in New Issue
Block a user