refactor(prompts-core): 重构代码结构和优化工具函数

This commit is contained in:
2026-06-10 14:51:25 +08:00
parent 1c1db7e30c
commit b69e7386e2
10 changed files with 164 additions and 432 deletions

View File

@@ -1,197 +1,81 @@
package util
import (
"encoding/json"
"fmt"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j := gjson.New(raw)
messages := j.Get("messages")
if !messages.IsNil() {
return gconv.Maps(messages.Val())
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form []map[string]any) string {
if form == nil {
return "[]"
}
b, _ := json.Marshal(form)
return string(b)
}
// UserFormToJSON 将用户表单数据转换为 JSON 字符串
func UserFormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
}
b, _ := json.Marshal(form)
return string(b)
}
// MustMarshalToMap 将对象序列化为 map[string]any失败时返回空 map
func MustMarshalToMap(v any) map[string]any {
b, err := json.Marshal(v)
if err != nil {
return make(map[string]any)
}
var m map[string]any
json.Unmarshal(b, &m)
return m
}
// JSONPretty 将任意类型转为格式化的 JSON 字符串
func JSONPretty(v any) string {
if gv, ok := v.(*gvar.Var); ok {
v = gconv.Map(gv.String())
}
var tmp map[string]any
if err := gconv.Struct(v, &tmp); err != nil {
return gconv.String(v)
}
b, _ := json.Marshal(tmp)
return string(b)
}
// ParseJSONFieldFromGvar 专门处理 *gvar.Var 类型的 JSON 字段解析
func ParseJSONFieldFromGvar(source any, target any) {
if source == nil {
return
}
switch v := source.(type) {
case *gvar.Var:
if v.IsNil() {
return
}
// 尝试获取 map
if m := v.Map(); len(m) > 0 {
data, _ := json.Marshal(m)
json.Unmarshal(data, target)
return
}
// 尝试解析 JSON 字符串
str := v.String()
if str != "" && str != "<nil>" {
json.Unmarshal([]byte(str), target)
}
default:
// 其他类型走原来的逻辑
data, _ := json.Marshal(source)
json.Unmarshal(data, target)
}
}
// MergeConsult 将 consult 附件合并到模型生成的 messages 结构中
func MergeConsult(req map[string]any, messages map[string]any, extendMapping map[string]any) map[string]any {
if len(req) == 0 || len(messages) == 0 || len(extendMapping) == 0 {
return messages
}
// 1) 获取 consult 数组
consult := gconv.Interfaces(req["consult"])
if len(consult) == 0 {
return messages
}
// 2) 获取配置
targetPath := gconv.String(extendMapping["target_content_path"])
if targetPath == "" {
return messages
}
templates := gconv.Map(extendMapping["attachment_templates"])
if len(templates) == 0 {
if targetPath == "" || len(templates) == 0 {
return messages
}
// 3) 转为 gjson 操作
msgJson := gjson.New(messages)
// 固定:如果有 rounds 结构,路径替换为 rounds.0.{targetPath}
if arr := msgJson.Get("rounds.0").Array(); arr != nil {
// rounds 路径修正
if !msgJson.Get("rounds.0").IsNil() {
targetPath = "rounds.0." + targetPath
}
// 4) 遍历 consult按类型生成附件并追加
// 遍历追加
for _, item := range consult {
itemJson := gjson.New(item)
itemType := itemJson.Get("type").String()
if itemType == "" {
continue
}
// 查找对应模板
tmpl := gconv.Map(templates[itemType])
if len(tmpl) == 0 {
if itemType == "" || len(tmpl) == 0 {
continue
}
// 生成附件对象
attachment := buildAttachment(tmpl, itemJson.Get("url").String())
if attachment == nil {
continue
}
// 获取当前数组长度,用索引追加
arr := msgJson.Get(targetPath).Array()
idx := len(arr)
indexPath := fmt.Sprintf("%s.%d", targetPath, idx)
_ = msgJson.Set(indexPath, attachment)
idx := len(msgJson.Get(targetPath).Array())
_ = msgJson.Set(fmt.Sprintf("%s.%d", targetPath, idx), attachment)
}
return msgJson.Map()
}
// buildAttachment 根据模板和用户数据生成附件对象
func buildAttachment(tmpl map[string]any, url string) map[string]any {
typ := gconv.String(tmpl["type"])
if typ == "" || url == "" {
return nil
}
// 深拷贝 body 并填充 url
body := gconv.Map(tmpl["body"])
bodyJson := gjson.New(body)
bodyJson = fillEmpty(bodyJson, url)
fillEmptyInPlace(body, url)
return map[string]any{
"type": typ,
typ: bodyJson.Map(),
typ: body,
}
}
// fillEmpty 递归查找空字符串并替换
func fillEmpty(j *gjson.Json, value string) *gjson.Json {
m := j.Map()
func fillEmptyInPlace(m map[string]any, value string) {
for k, v := range m {
switch vv := v.(type) {
case string:
if vv == "" {
_ = j.Set(k, value)
m[k] = value
}
case map[string]any:
_ = j.Set(k, fillEmpty(gjson.New(vv), value).Map())
fillEmptyInPlace(vv, value)
}
}
return j
}

View File

@@ -4,6 +4,7 @@ import (
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// ReverseMap 映射 payload 到 mapping
@@ -20,80 +21,37 @@ func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
return jsonObj.Map()
}
// ExtractUserText 从 messages map 中提取用户文本,返回标准的 user message 结构
// ExtractUserText 从 messages 中提取所有 user 文本
func ExtractUserText(messages map[string]any) map[string]any {
msgJson := gjson.New(messages)
msgs := msgJson.Get("rounds.0.messages")
if msgs.IsNil() {
msgs = msgJson.Get("messages")
}
var texts []string
// 1) rounds 结构:遍历每轮
if rounds, ok := messages["rounds"].([]any); ok {
for _, round := range rounds {
if rm, ok := round.(map[string]any); ok {
if msgs, ok := rm["messages"].([]any); ok {
texts = append(texts, extractTextFromRoleUser(msgs)...)
for _, m := range msgs.Array() {
msg := gjson.New(m)
if msg.Get("role").String() != "user" {
continue
}
content := msg.Get("content").Val()
switch c := content.(type) {
case string:
texts = append(texts, c)
case []any:
for _, item := range c {
if m, ok := item.(map[string]any); ok {
if t := gconv.String(m["text"]); t != "" {
texts = append(texts, t)
}
}
}
}
} else if msgs, ok := messages["messages"].([]any); ok {
// 2) messages 结构
texts = extractTextFromRoleUser(msgs)
}
// 3) 构建返回结构
return map[string]any{
"role": "user",
"content": strings.Join(texts, "\n"),
}
}
// extractTextFromRoleUser 从 messages 数组中提取所有 role=user 的文本
func extractTextFromRoleUser(msgs []any) []string {
var texts []string
for _, msg := range msgs {
m, ok := msg.(map[string]any)
if !ok {
continue
}
if role, _ := m["role"].(string); role != "user" {
continue
}
texts = append(texts, extractAllText(m["content"])...)
}
return texts
}
// extractAllText 从 content 中提取所有文本(递归,最大兼容)
func extractAllText(content any) []string {
switch c := content.(type) {
case string:
return []string{c}
case []any:
var texts []string
for _, item := range c {
m, ok := item.(map[string]any)
if !ok {
continue
}
if t, ok := m["text"].(string); ok && t != "" {
texts = append(texts, t)
continue
}
for _, v := range m {
texts = append(texts, extractAllText(v)...)
}
}
return texts
case map[string]any:
if t, ok := c["text"].(string); ok && t != "" {
return []string{t}
}
var texts []string
for _, v := range c {
texts = append(texts, extractAllText(v)...)
}
return texts
}
return nil
}

View File

@@ -11,3 +11,7 @@ const (
BuildTypeNode = 2 //节点构建
BuildTypeStruct = 3 //结构构建
)
const (
ModelTypeInference = 100 // 推理模型
)

View File

@@ -26,8 +26,3 @@ func (c *prompt) Callback(ctx context.Context, req *dto.CallbackReq) (res *dto.C
func (c *prompt) GetComposeTask(ctx context.Context, req *dto.GetComposeTaskReq) (res *dto.GetComposeTaskRes, err error) {
return promptService.GetComposeTask(ctx, req.TaskId)
}
// GetPromptText 纯文本prompt调用接口测试专用
func (c *prompt) GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (res *dto.GetPromptTextRes, err error) {
return promptService.GetPromptText(ctx, req)
}

View File

@@ -25,12 +25,6 @@ type ComposeMessagesRes struct {
TaskId string `json:"taskId" dc:"任务ID"`
}
// MultiRoundResult 多轮返回结果
type MultiRoundResult struct {
TotalRounds int `json:"total_rounds"` // 总轮数
Rounds []map[string]any `json:"rounds"` // 每轮详情(动态类型)
}
type CallbackReq struct {
g.Meta `path:"/callback" method:"post" tags:"提示词处理" summary:"model-gateway 回调" dc:"model-gateway 成功后 POST 回调callbackUrl/{bizName}"`
TaskId string `json:"task_id" v:"required#task_id不能为空" dc:"网关任务ID"`
@@ -55,16 +49,7 @@ type GetComposeTaskRes struct {
Status string `json:"status" dc:"业务状态"`
GatewayState int `json:"gatewayState" dc:"网关状态"`
ErrorMessage string `json:"errorMessage" dc:"错误信息"`
Messages any `json:"messages" dc:"最终消息数组"`
Messages map[string]any `json:"messages" dc:"最终消息数组"`
OssFile string `json:"ossFile" dc:"结果文件地址"`
FileType string `json:"fileType" dc:"结果文件类型"`
}
type GetPromptTextReq struct {
g.Meta `path:"/getPromptText" method:"get" tags:"提示词测试" summary:"测试文本生成" dc:"传入提示词,返回模型纯文本结果,用于接口连通性测试"`
Prompt string `p:"prompt" json:"prompt" dc:"测试用提示词"`
}
type GetPromptTextRes struct {
Messages any `json:"messages" dc:"历史消息"`
}

View File

@@ -13,11 +13,10 @@ import (
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
)
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *PromptIR, totalBatches int) (map[string]any, error) {
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *gateway.AsynchModel, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, chatModel, aiModel)
ir.AddSystem(systemPrompt)
@@ -32,29 +31,21 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel, req)
}
// buildStructTypeRequest 构建结构体类型请求BuildType=3
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *PromptIR) (map[string]any, error) {
// 提取 userForm 中的 prompt 作为自定义提示词
var customPrompt string
for _, item := range req.UserForm {
if prompt, ok := item["prompt"]; ok && gconv.String(prompt) != "" {
customPrompt = gconv.String(prompt)
break
}
}
// 用户消息
func buildStructTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *gateway.AsynchModel, ir *IR) (map[string]any, error) {
customPrompt := gjson.New(req.UserForm).Get("0.prompt").String()
ir.AddSystem(customPrompt)
ir.AddUser(buildUserPrompt(ctx, req, ""))
return compileToProviderRequest(ctx, ir, chatModel, req, customPrompt)
}
// compileToProviderRequest 编译为 Provider 请求
func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
func compileToProviderRequest(ctx context.Context, ir *IR, chatModel *gateway.AsynchModel, req *dto.ComposeMessagesReq, customPrompt ...string) (map[string]any, error) {
protocol, err := GetProtocolByProvider(ctx, chatModel.OperatorName)
if err != nil || protocol == nil {
return nil, fmt.Errorf("协议配置不存在或获取失败: %w", err)
@@ -78,6 +69,7 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *gate
}, nil
}
// promptBuildWithRounds 构建提示词
func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel, aiModel *gateway.AsynchModel) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: chatModel.OperatorName,
@@ -86,7 +78,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
outputJSON := gjson.New(util.ReverseMap(aiModel.RequestMapping, map[string]any{})).MustToJsonIndentString()
return fmt.Sprintf(providerProtocol.SystemPromptTemplate,
outputJSON, //【输出结构】 %s
@@ -94,7 +86,7 @@ func promptBuildWithRounds(ctx context.Context, chatModel *gateway.AsynchModel,
}
// checkOverallContent 检查整体内容是否超出窗口
func checkOverallContent(ir *PromptIR, model *gateway.AsynchModel) bool {
func checkOverallContent(ir *IR, model *gateway.AsynchModel) bool {
fullContent := ir.String()
return util.CountToken(fullContent, model.TokenConfig)
}
@@ -124,7 +116,6 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
return b.String()
}
// buildUserFormText 构建用户表单内容字符串
func buildUserFormText(form []map[string]any) string {
if len(form) == 0 {
return ""
@@ -132,32 +123,22 @@ func buildUserFormText(form []map[string]any) string {
var builder strings.Builder
for _, item := range form {
for k, v := range item {
builder.WriteString(fmt.Sprintf("%s\n", k))
switch val := v.(type) {
case []any:
// 数组类型:逐条列出
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, elem := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
if m, ok := elem.(map[string]any); ok {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
}
builder.WriteString("\n")
} else {
builder.WriteString(fmt.Sprintf(" %d. %v\n", i+1, elem))
}
}
case []map[string]any:
builder.WriteString(fmt.Sprintf("%s\n", k))
for i, m := range val {
builder.WriteString(fmt.Sprintf(" %d. ", i+1))
for mk, mv := range m {
builder.WriteString(fmt.Sprintf("%s%v ", mk, mv))
builder.WriteString(fmt.Sprint(elem))
}
builder.WriteString("\n")
}
default:
builder.WriteString(fmt.Sprintf("%s%v\n", k, v))
builder.WriteString(fmt.Sprintf(" %v\n", v))
}
}
}
@@ -170,9 +151,8 @@ func NodeBuild(ctx context.Context, req *dto.ComposeMessagesReq) string {
if promptTpl == "" {
return ""
}
formStr := util.FormToJSON(req.Form)
userFormStr := util.UserFormToJSON(req.UserForm)
return fmt.Sprintf(promptTpl, formStr, userFormStr)
return fmt.Sprintf(promptTpl,
gjson.New(req.Form).MustToJsonString(),
gjson.New(req.UserForm).MustToJsonString(),
)
}

View File

@@ -2,7 +2,6 @@ package prompt
import (
"context"
"encoding/json"
"errors"
"fmt"
"prompts-core/service/session"
@@ -80,7 +79,7 @@ func validateUserForm(req *dto.ComposeMessagesReq, model *gateway.AsynchModel) e
// handleBuild 通用构建处理
func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, aiModel *gateway.AsynchModel) (*dto.ComposeMessagesRes, error) {
// 1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
processedReq, _, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
}
@@ -90,7 +89,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
var taskReq map[string]any
switch req.BuildType {
case public.BuildTypePrompt:
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir, totalBatches)
taskReq, err = buildPromptTypeRequest(ctx, processedReq, aiModel, chatModel, ir)
case public.BuildTypeNode:
taskReq, err = buildNodeTypeRequest(ctx, req, chatModel, ir)
case public.BuildTypeStruct:
@@ -118,7 +117,7 @@ func handleBuild(ctx context.Context, req *dto.ComposeMessagesReq, chatModel, ai
SkillName: req.SkillName,
BuildType: req.BuildType,
CallbackUrl: req.CallbackUrl,
RequestPayload: util.MustMarshalToMap(req),
RequestPayload: gconv.Map(req),
Status: public.ComposeStatusPending,
}); err != nil {
return nil, err
@@ -164,6 +163,7 @@ func handleCallbackFailed(ctx context.Context, req *dto.CallbackReq, composeTask
return err
}
// handleCallbackSuccess 处理回调成功
func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTask *entity.ComposeTask) error {
// 1) 获取模型配置
model, err := gateway.GetModelConfig(ctx, &gateway.AsynchModel{
@@ -180,12 +180,15 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
Status: 1,
})
// 3) 获取历史消息
// 3) 获取历史消息 + 保存当前轮
payload := composeTask.RequestPayload
sessionId := gconv.String(payload["sessionId"])
nodeId := gconv.String(payload["nodeId"])
var history []dto.FlatMessage
if sessionId != "" && nodeId != "" {
var epicycleId int64
if sessionId != "" && nodeId != "" && model.ModelType == public.ModelTypeInference {
// 3.1 获取历史
h, _ := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: sessionId,
NodeId: nodeId,
@@ -193,12 +196,21 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
if h != nil {
history = h.Messages
}
// 3.2 保存当前轮(先存,下次查询就能拿到)
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 4) 合并附加结构
messages := util.MergeConsult(composeTask.RequestPayload, req.Messages, model.ExtendMapping)
// 5) 注入历史到 rounds 中
if protocol != nil && len(history) > 0 {
// 5) 注入历史
if len(history) > 0 {
messages = InjectHistory(messages, history, protocol)
}
@@ -215,18 +227,6 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return err
}
// 7) 存储历史
var epicycleId int64
if sessionId != "" && nodeId != "" {
if userMsg := util.ExtractUserText(req.Messages); userMsg != nil {
epicycleId, _ = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
NodeId: nodeId,
SessionId: sessionId,
RequestContent: userMsg,
})
}
}
// 8) 回调业务方
if composeTask.CallbackUrl != "" {
composeTask.Status = public.ComposeStatusSuccess
@@ -237,77 +237,7 @@ func handleCallbackSuccess(ctx context.Context, req *dto.CallbackReq, composeTas
return nil
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
if record == nil {
return nil, fmt.Errorf("未找到任务(taskId=%s)", taskID)
}
messages := parseMessagesForResponse(record.ResultJson)
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: messages,
}, nil
}
// parseMessagesForResponse 解析用于响应的消息
func parseMessagesForResponse(messages any) any {
str, ok := messages.(string)
if !ok || str == "" {
return messages
}
var parsed any
if err := json.Unmarshal([]byte(str), &parsed); err == nil {
return parsed
}
return messages
}
func GetPromptText(ctx context.Context, req *dto.GetPromptTextReq) (*dto.GetPromptTextRes, error) {
// 1) 获取协议配置
protocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: "火山引擎",
Status: 1,
})
if err != nil {
return nil, err
}
// 2) 获取历史消息
history, err := session.GetHistoryMessages(ctx, &dto.GetHistoryMessagesReq{
SessionId: "88888888",
NodeId: "node1",
})
if err != nil {
return nil, err
}
// 3) 模拟roundsData数据
task, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: "0e1872f0-0e73-42f1-9aa8-63d317300ffc",
})
if err != nil {
return nil, err
}
fmt.Println("[打印数据]", task.ResultJson)
fmt.Println("[打印历史]", history.Messages)
fmt.Println("[打印协议]", protocol)
return &dto.GetPromptTextRes{
Messages: InjectHistory(task.ResultJson, history.Messages, protocol),
}, nil
}
// InjectHistory 插入历史会话
func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protocol *entity.ProviderProtocol) map[string]any {
if protocol == nil || len(history) == 0 {
return roundsData
@@ -363,3 +293,19 @@ func InjectHistory(roundsData map[string]any, history []dto.FlatMessage, protoco
firstRound["messages"] = result
return roundsData
}
// GetComposeTask 查询任务结果
func GetComposeTask(ctx context.Context, taskID string) (*dto.GetComposeTaskRes, error) {
record, err := dao.ComposeTask.Get(ctx, &entity.ComposeTask{
TaskId: taskID,
})
if err != nil {
return nil, fmt.Errorf("查询任务失败: %w", err)
}
return &dto.GetComposeTaskRes{
TaskId: record.TaskId,
Status: record.Status,
ErrorMessage: record.ErrorMessage,
Messages: record.ResultJson,
}, nil
}

View File

@@ -190,6 +190,9 @@ func fetchFileContent(ctx context.Context, client *http.Client, url string) (str
}
func SkillMdContent(ctx context.Context, skillName string) string {
if skillName == "" {
return ""
}
skillResp, err := gateway.GetSkillUser(ctx, skillName)
if err != nil {
g.Log().Warningf(ctx, "[SkillMd] GetSkillUser 失败: %v", err)

View File

@@ -2,9 +2,7 @@ package prompt
import (
"context"
"encoding/json"
"fmt"
"prompts-core/common/util"
"prompts-core/service/gateway"
"strings"
@@ -14,8 +12,8 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
// PromptIR 统一 Prompt 中间表示
type PromptIR struct {
// IR 统一 Prompt 中间表示
type IR struct {
System []Segment `json:"system"`
History []Segment `json:"history"`
User []Segment `json:"user"`
@@ -46,8 +44,8 @@ type ContentMapping struct {
}
// NewPromptIR 创建空 PromptIR
func NewPromptIR() *PromptIR {
return &PromptIR{
func NewPromptIR() *IR {
return &IR{
System: make([]Segment, 0),
History: make([]Segment, 0),
User: make([]Segment, 0),
@@ -55,7 +53,7 @@ func NewPromptIR() *PromptIR {
}
// String 返回 PromptIR 的完整内容字符串(用于 token 计算)
func (ir *PromptIR) String() string {
func (ir *IR) String() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -81,7 +79,7 @@ func (ir *PromptIR) String() string {
}
// GetTotalContent 获取所有内容的拼接字符串(更精确的 token 计算)
func (ir *PromptIR) GetTotalContent() string {
func (ir *IR) GetTotalContent() string {
var builder strings.Builder
for _, seg := range ir.System {
@@ -103,7 +101,7 @@ func (ir *PromptIR) GetTotalContent() string {
}
// AddSystem 添加系统提示
func (ir *PromptIR) AddSystem(content string) *PromptIR {
func (ir *IR) AddSystem(content string) *IR {
if content != "" {
ir.System = append(ir.System, Segment{Type: "text", Content: content})
}
@@ -111,7 +109,7 @@ func (ir *PromptIR) AddSystem(content string) *PromptIR {
}
// AddUser 添加用户消息
func (ir *PromptIR) AddUser(content string) *PromptIR {
func (ir *IR) AddUser(content string) *IR {
if content != "" {
ir.User = append(ir.User, Segment{Type: "text", Content: content})
}
@@ -119,7 +117,7 @@ func (ir *PromptIR) AddUser(content string) *PromptIR {
}
// AddHistory 添加历史消息
func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
func (ir *IR) AddHistory(role, content string) *IR {
if content != "" {
ir.History = append(ir.History, Segment{Type: "text", Content: content, Role: role})
}
@@ -127,7 +125,7 @@ func (ir *PromptIR) AddHistory(role, content string) *PromptIR {
}
// ToMessages 转换为 OpenAI 兼容的 messages 格式MVP 默认)
func (ir *PromptIR) ToMessages() []map[string]any {
func (ir *IR) ToMessages() []map[string]any {
var messages []map[string]any
for _, seg := range ir.System {
@@ -168,22 +166,22 @@ func GetProtocolByProvider(ctx context.Context, providerName string) (*ProviderP
// parseProtocol 将 DB entity 转为编译用协议配置
func parseProtocol(e *entity.ProviderProtocol) *ProviderProtocol {
p := &ProviderProtocol{
return &ProviderProtocol{
TargetField: e.TargetField,
SystemPromptTemplate: e.SystemPromptTemplate,
MergeOrder: e.MergeOrder,
RoleMapping: gconv.MapStrStr(e.RoleMapping),
ContentMapping: ContentMapping{
Type: gconv.String(e.ContentMapping["type"]),
Field: gconv.String(e.ContentMapping["field"]),
},
RequestTemplate: e.RequestTemplate,
Capabilities: e.Capabilities,
}
// 使用通用解析方法处理各个字段
util.ParseJSONFieldFromGvar(e.MergeOrder, &p.MergeOrder)
util.ParseJSONFieldFromGvar(e.RoleMapping, &p.RoleMapping)
util.ParseJSONFieldFromGvar(e.ContentMapping, &p.ContentMapping)
util.ParseJSONFieldFromGvar(e.RequestTemplate, &p.RequestTemplate)
util.ParseJSONFieldFromGvar(e.Capabilities, &p.Capabilities)
return p
}
// Compile 将 PromptIR 按协议配置编译为 Provider Request
func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
func Compile(ir *IR, p *ProviderProtocol, chatModel *gateway.AsynchModel) (map[string]any, error) {
if ir == nil || p == nil {
return nil, fmt.Errorf("ir and protocol are required")
}
@@ -195,35 +193,25 @@ func Compile(ir *PromptIR, p *ProviderProtocol, chatModel *gateway.AsynchModel)
}
// mergeByOrder 按协议配置顺序拼接消息
func mergeByOrder(ir *PromptIR, order []string) []map[string]any {
func mergeByOrder(ir *IR, order []string) []map[string]any {
roleMap := map[string][]Segment{
"system": ir.System,
"history": ir.History,
"user": ir.User,
}
var messages []map[string]any
for _, part := range order {
switch part {
case "system":
for _, seg := range ir.System {
messages = append(messages, map[string]any{
"role": "system",
"content": seg.Content,
})
for _, seg := range roleMap[part] {
msg := map[string]any{"content": seg.Content}
if part == "history" {
msg["role"] = seg.Role
} else {
msg["role"] = part
}
case "history":
for _, seg := range ir.History {
messages = append(messages, map[string]any{
"role": seg.Role,
"content": seg.Content,
})
}
case "user":
for _, seg := range ir.User {
messages = append(messages, map[string]any{
"role": "user",
"content": seg.Content,
})
messages = append(messages, msg)
}
}
}
return messages
}
@@ -247,22 +235,22 @@ func mapRoles(messages []map[string]any, mapping map[string]string) []map[string
return messages
}
// mapContent 内容字段映射
func mapContent(messages []map[string]any, cm ContentMapping) []map[string]any {
for _, msg := range messages {
content := msg["content"]
delete(msg, "content")
if cm.Field == "" || cm.Field == "content" {
return messages
}
for i, msg := range messages {
if content, ok := msg["content"]; ok {
delete(msg, "content")
switch cm.Type {
case "parts":
msg["parts"] = []map[string]any{
{cm.Field: content},
}
messages[i]["parts"] = []map[string]any{{cm.Field: content}}
default:
msg[cm.Field] = content
messages[i][cm.Field] = content
}
}
}
return messages
}
@@ -277,20 +265,17 @@ func buildRequest(messages []map[string]any, p *ProviderProtocol, chatModel *gat
}
}
// renderTemplate 简单的 {{key}} 模板替换
// renderTemplate 模板渲染
func renderTemplate(p *ProviderProtocol, messages []map[string]any, chatModel *gateway.AsynchModel) map[string]any {
b, _ := json.Marshal(p.RequestTemplate)
str := string(b)
if chatModel != nil {
str = strings.ReplaceAll(str, `"{{model}}"`, `"`+chatModel.ModelName+`"`)
result := make(map[string]any, len(p.RequestTemplate)+1)
for k, v := range p.RequestTemplate {
result[k] = v
}
msgBytes, _ := json.Marshal(messages)
str = strings.ReplaceAll(str, `"{{messages}}"`, string(msgBytes))
var result map[string]any
_ = json.Unmarshal([]byte(str), &result)
if chatModel != nil {
result["model"] = chatModel.ModelName
}
result["messages"] = messages
if maxTokens := gconv.Int(p.Capabilities["max_tokens"]); maxTokens > 0 {
result["max_tokens"] = maxTokens

View File

@@ -21,8 +21,8 @@ import (
// Callback 会话回调
func Callback(ctx context.Context, req *dto.SessionCallbackReq) (*dto.SessionCallbackRes, error) {
fmt.Println("打印会话回调", req)
req.Messages["role"] = "assistant"
// 1) 更新 DB
_, err := dao.ComposeSession.Update(ctx, &entity.ComposeSession{
SQLBaseDO: beans.SQLBaseDO{Id: req.EpicycleId},
@@ -163,23 +163,15 @@ func DeleteSession(ctx context.Context, req *dto.DeleteSessionReq) (*dto.DeleteS
// entityToHistoryRound entity → HistoryRound
func entityToHistoryRound(s *entity.ComposeSession) *dto.HistoryRound {
reqMsgs := util.ConvertToMessages(s.RequestContent)
respMsgs := util.ConvertToMessages(s.ResponseContent)
round := &dto.HistoryRound{
return &dto.HistoryRound{
Id: s.Id,
SessionId: s.SessionId,
NodeId: s.NodeId,
CreatedAt: gconv.String(s.CreatedAt),
UpdatedAt: gconv.String(s.UpdatedAt),
User: s.RequestContent,
Assistant: s.ResponseContent,
}
if len(reqMsgs) > 0 {
round.User = reqMsgs[0]
}
if len(respMsgs) > 0 {
round.Assistant = respMsgs[0]
}
return round
}
// sessionsToHistoryRounds 批量转换