refactor(service): 重构服务模块结构并优化模型配置

This commit is contained in:
2026-05-29 17:54:19 +08:00
parent d74559ae74
commit 55eb436639
7 changed files with 204 additions and 53 deletions

View File

@@ -8,6 +8,11 @@ import (
"github.com/gogf/gf/v2/util/gconv"
)
// GetServerName 获取服务名称
func GetServerName(ctx context.Context) string {
return g.Cfg().MustGet(ctx, "server.name", "").String()
}
// GetServerPort 从配置获取服务端口
func GetServerPort(ctx context.Context) string {
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()

View File

@@ -2,51 +2,34 @@ package util
import (
"encoding/json"
"fmt"
"strconv"
"github.com/gogf/gf/v2/container/gvar"
"github.com/gogf/gf/v2/encoding/gjson"
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/util/gconv"
tGjson "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ParseOutput 解析模型输出为 JSON 格式
func ParseOutput(text string) (map[string]any, error) {
j, err := gjson.LoadJson([]byte(text))
if err != nil {
return nil, fmt.Errorf("解析模型输出失败: %w", err)
}
return j.Map(), nil
}
// ConvertToMessages 将原始数据转换为消息列表
func ConvertToMessages(raw any) []map[string]any {
if raw == nil {
return nil
}
j, err := gjson.LoadJson(gconv.Bytes(raw))
if err != nil {
return nil
j := gfgjson.New(raw)
messages := j.Get("messages")
if !messages.IsNil() {
return gconv.Maps(messages.Val())
}
if j.Contains("messages") {
return gconv.Maps(j.Get("messages").Array())
}
return []map[string]any{j.Map()}
}
// FormToJSON 将表单数据转换为 JSON 字符串
func FormToJSON(form map[string]any) string {
func FormToJSON(form []map[string]any) string {
if form == nil {
return "{}"
return "[]"
}
b, _ := json.Marshal(form)
return string(b)
}

148
common/util/mapping.go Normal file
View File

@@ -0,0 +1,148 @@
package util
import (
"fmt"
"net/url"
"prompts-core/model/entity"
"strings"
"github.com/gogf/gf/v2/encoding/gjson"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
// 1) 获取校验配置,并取值
requestMapping := model.RequestMapping
contentKey := ""
for k := range model.ResponseBody {
contentKey = k
break
}
contentStr, ok := raw[contentKey].(string)
if !ok || contentStr == "" {
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
}
// 2) 解析 content 为 JSON 数组
var rounds []map[string]any
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
}
if len(rounds) == 0 {
return fmt.Errorf("content 数组为空")
}
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
for i, round := range rounds {
for path, defaultValue := range requestMapping {
if !g.IsEmpty(defaultValue) {
continue
}
if gjson.New(round).Get(path).IsNil() {
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
}
}
}
return nil
}
// ReverseMap 映射 payload 到 mapping
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
jsonObj := gjson.New("{}")
for path, defaultValue := range mapping {
val := gjson.New(payload).Get(path)
if !val.IsNil() {
_ = jsonObj.Set(path, val.Val())
} else if defaultValue != nil {
_ = jsonObj.Set(path, defaultValue)
}
}
return jsonObj.Map()
}
// MapResponsePayload 映射模型响应为标准格式
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
if len(mapping) == 0 {
return responseBytes, nil
}
responseJson := gjson.New(responseBytes)
resultJson := gjson.New("{}")
for standardField, modelPath := range mapping {
path := gconv.String(modelPath)
if path == "" {
continue
}
val := responseJson.Get(path)
if val.IsNil() {
continue
}
resultJson.Set(standardField, val.Val())
}
return []byte(resultJson.String()), nil
}
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
// 示例:
// - X-API-Key:qwen3-tts-key,operation:true,count:123
// - X-API-Key:"qwen3-tts-key",operation:"true"
//
// 说明:
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
func ParseHeadMsgHeaders(headMsg string) map[string]string {
headMsg = strings.TrimSpace(headMsg)
if headMsg == "" {
return nil
}
out := map[string]string{}
parts := strings.Split(headMsg, ",")
for _, p := range parts {
p = strings.TrimSpace(p)
if p == "" {
continue
}
// HeaderName:HeaderValue推荐 / HeaderName=HeaderValue兼容
if strings.Contains(p, ":") {
kv := strings.SplitN(p, ":", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
if strings.Contains(p, "=") {
kv := strings.SplitN(p, "=", 2)
k := strings.TrimSpace(kv[0])
v := strings.TrimSpace(kv[1])
v = strings.Trim(v, "\"")
if k != "" && v != "" {
out[k] = v
}
continue
}
}
if len(out) == 0 {
return nil
}
return out
}
// PayloadToQuery 将 payload 转为 url.Values
func PayloadToQuery(payload map[string]any) (url.Values, error) {
q := url.Values{}
for k, v := range payload {
if v == nil {
continue
}
q.Set(k, gconv.String(v))
}
return q, nil
}

View File

@@ -26,3 +26,15 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
err = r.Struct(&m)
return
}
// GetsByModelName 批量获取模型
func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) {
err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
OmitEmpty().
Where(entity.AsynchModelCol.Creator, creator).
WhereIn(entity.AsynchModelCol.ModelName, modelNames).
Fields(fields).
Scan(&list)
return
}

View File

@@ -6,10 +6,10 @@ type ComposeMessagesReq struct {
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schemaform 作为系统表单userForm 作为用户表单,结合 userFiles 调用 model-gateway并直接返回最终 messages"`
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
Form map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
Form []map[string]any `p:"form" json:"form" dc:"系统表单form 下所有字段都作为系统提示词来源"`
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`

View File

@@ -20,7 +20,7 @@ import (
type UserPromptPayload struct {
Model string `json:"model"`
PromptInfo string `json:"promptInfo"`
Form map[string]any `json:"form"`
Form any `json:"form"`
UserForm any `json:"userForm"`
Consult []dto.ConsultItem `json:"consult"`
UserFilesText map[string]string `json:"userFilesText"`
@@ -30,6 +30,7 @@ type UserPromptPayload struct {
// buildInferenceRequest 构建推理请求
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
//1) 处理表单分批
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
if err != nil {
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
@@ -47,9 +48,10 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
// buildPromptTypeRequest 构建提示词类型请求BuildType=1
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
//1) 构建系统提示词
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
ir.AddSystem(systemPrompt)
//2) 构建历史对话
for _, msg := range history {
role := gconv.String(msg["role"])
if role != "user" && role != "assistant" {
@@ -57,7 +59,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
}
ir.AddHistory(role, gconv.String(msg["content"]))
}
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
ir.AddUser(userPrompt)
if !checkOverallContent(ir, aiModel) {
@@ -70,7 +71,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
// buildNodeTypeRequest 构建节点类型请求BuildType=2
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
ir.AddUser(NodeBuild(ctx, req))
return compileToProviderRequest(ctx, ir, chatModel)
}
@@ -90,33 +90,33 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
return map[string]any{
"modelName": chatModel.ModelName,
"bizName": "prompts-core",
"bizName": util.GetServerName(ctx),
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
"requestPayload": providerReq,
}, nil
}
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
// promptBuildWithRounds 构建系统提示词
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
ProviderName: model.OperatorName,
ProviderName: chatModel.OperatorName,
Status: 1,
})
if err != nil || providerProtocol == nil {
return ""
}
outputJSON := util.JSONPretty(model.RequestMapping)
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
availableWindow := util.GetAvailableWindow(model.TokenConfig)
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
formContent := buildUserFormContent(req.Form)
userFormContent := buildUserFormContent(req.UserForm)
formInfo := fmt.Sprintf(`
【系统表单(系统提示词/参数)】
%s
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
%s
`, util.FormToJSON(req.Form), userFormContent)
`, formContent, userFormContent)
inputInfo := fmt.Sprintf(`
目标模型: %s
@@ -129,11 +129,8 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, mod
req.ModelName, // %s 目标模型名称
maxWindowSize, // %d 最大窗口
availableWindow, // %d 可用窗口
totalRounds, // %d 数组长度(多轮输出要求)
totalRounds, // %d 数组长度(结构铁律)
outputJSON, // %s 输出结构
inputInfo, // %s 完整输入信息
totalRounds, // %d 数组长度(最后一行)
)
}
@@ -157,7 +154,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
payload := UserPromptPayload{
Model: req.ModelName,
PromptInfo: prompt,
Form: req.Form,
Form: prepareUserFormPayload(req.Form),
UserForm: prepareUserFormPayload(req.UserForm),
Consult: req.Consult,
UserFilesText: ExtractFileTexts(ctx, req.Consult),

View File

@@ -21,13 +21,16 @@ import (
// ComposeMessages 核心拼接提示词主流程
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
//1) 获取模型信息
chatModel, aiModel, err := GetModelMessage(ctx, req)
if err != nil {
return nil, err
}
//2) 校验用户表单
if err = validateUserForm(req, aiModel); err != nil {
return nil, err
}
//3) 处理不同类型
switch req.BuildType {
case public.BuildTypePrompt:
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
@@ -54,7 +57,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
if chatModel == nil {
return nil, nil, errors.New("当前没有对话模型,请添加")
}
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
ModelName: req.ModelName,
})
@@ -62,10 +65,10 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
return nil, nil, err
}
if aiModel == nil {
if aiModels == nil {
return nil, nil, errors.New("需要构建的模型不存在")
}
return chatModel, aiModel, nil
return chatModel, aiModels, nil
}
// validateUserForm 校验用户表单
@@ -150,12 +153,15 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
if err != nil {
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
}
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
id := int64(0)
if req.SessionId != "" {
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
SessionId: req.SessionId,
RequestContent: util.GetUserMessage(taskReq),
})
if err != nil {
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
}
}
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
if err != nil {