refactor(service): 重构服务模块结构并优化模型配置
This commit is contained in:
@@ -8,6 +8,11 @@ import (
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// GetServerName 获取服务名称
|
||||
func GetServerName(ctx context.Context) string {
|
||||
return g.Cfg().MustGet(ctx, "server.name", "").String()
|
||||
}
|
||||
|
||||
// GetServerPort 从配置获取服务端口
|
||||
func GetServerPort(ctx context.Context) string {
|
||||
address := g.Cfg().MustGet(ctx, "server.address", ":8080").String()
|
||||
|
||||
@@ -2,51 +2,34 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
gfgjson "github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
tGjson "github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ParseOutput 解析模型输出为 JSON 格式
|
||||
func ParseOutput(text string) (map[string]any, error) {
|
||||
j, err := gjson.LoadJson([]byte(text))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析模型输出失败: %w", err)
|
||||
}
|
||||
|
||||
return j.Map(), nil
|
||||
}
|
||||
|
||||
// ConvertToMessages 将原始数据转换为消息列表
|
||||
func ConvertToMessages(raw any) []map[string]any {
|
||||
if raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
j, err := gjson.LoadJson(gconv.Bytes(raw))
|
||||
if err != nil {
|
||||
return nil
|
||||
j := gfgjson.New(raw)
|
||||
messages := j.Get("messages")
|
||||
if !messages.IsNil() {
|
||||
return gconv.Maps(messages.Val())
|
||||
}
|
||||
|
||||
if j.Contains("messages") {
|
||||
return gconv.Maps(j.Get("messages").Array())
|
||||
}
|
||||
|
||||
return []map[string]any{j.Map()}
|
||||
}
|
||||
|
||||
// FormToJSON 将表单数据转换为 JSON 字符串
|
||||
func FormToJSON(form map[string]any) string {
|
||||
func FormToJSON(form []map[string]any) string {
|
||||
if form == nil {
|
||||
return "{}"
|
||||
return "[]"
|
||||
}
|
||||
|
||||
b, _ := json.Marshal(form)
|
||||
return string(b)
|
||||
}
|
||||
|
||||
148
common/util/mapping.go
Normal file
148
common/util/mapping.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"prompts-core/model/entity"
|
||||
"strings"
|
||||
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// ValidatePromptResult 校验模型返回结果的 JSON 结构完整性
|
||||
// 校验逻辑:只校验 requestMapping 中默认值为空的必填字段
|
||||
func ValidatePromptResult(raw map[string]any, model *entity.AsynchModel) error {
|
||||
// 1) 获取校验配置,并取值
|
||||
requestMapping := model.RequestMapping
|
||||
contentKey := ""
|
||||
for k := range model.ResponseBody {
|
||||
contentKey = k
|
||||
break
|
||||
}
|
||||
contentStr, ok := raw[contentKey].(string)
|
||||
if !ok || contentStr == "" {
|
||||
return fmt.Errorf("%s 字段为空或不是字符串", contentKey)
|
||||
}
|
||||
|
||||
// 2) 解析 content 为 JSON 数组
|
||||
var rounds []map[string]any
|
||||
if err := gjson.DecodeTo(contentStr, &rounds); err != nil {
|
||||
return fmt.Errorf("解析 content JSON 数组失败: %w", err)
|
||||
}
|
||||
if len(rounds) == 0 {
|
||||
return fmt.Errorf("content 数组为空")
|
||||
}
|
||||
|
||||
// 3) 逐条校验:只检查默认值为空的必填字段是否存在
|
||||
for i, round := range rounds {
|
||||
for path, defaultValue := range requestMapping {
|
||||
if !g.IsEmpty(defaultValue) {
|
||||
continue
|
||||
}
|
||||
if gjson.New(round).Get(path).IsNil() {
|
||||
return fmt.Errorf("rounds[%d] 缺少必填字段: %s", i, path)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReverseMap 映射 payload 到 mapping
|
||||
func ReverseMap(mapping map[string]any, payload map[string]any) map[string]any {
|
||||
jsonObj := gjson.New("{}")
|
||||
for path, defaultValue := range mapping {
|
||||
val := gjson.New(payload).Get(path)
|
||||
if !val.IsNil() {
|
||||
_ = jsonObj.Set(path, val.Val())
|
||||
} else if defaultValue != nil {
|
||||
_ = jsonObj.Set(path, defaultValue)
|
||||
}
|
||||
}
|
||||
return jsonObj.Map()
|
||||
}
|
||||
|
||||
// MapResponsePayload 映射模型响应为标准格式
|
||||
func MapResponsePayload(mapping map[string]any, responseBytes []byte) ([]byte, error) {
|
||||
if len(mapping) == 0 {
|
||||
return responseBytes, nil
|
||||
}
|
||||
|
||||
responseJson := gjson.New(responseBytes)
|
||||
resultJson := gjson.New("{}")
|
||||
|
||||
for standardField, modelPath := range mapping {
|
||||
path := gconv.String(modelPath)
|
||||
if path == "" {
|
||||
continue
|
||||
}
|
||||
val := responseJson.Get(path)
|
||||
if val.IsNil() {
|
||||
continue
|
||||
}
|
||||
resultJson.Set(standardField, val.Val())
|
||||
}
|
||||
|
||||
return []byte(resultJson.String()), nil
|
||||
}
|
||||
|
||||
// ParseHeadMsgHeaders 支持多个 header 绑定,逗号分隔:
|
||||
// 示例:
|
||||
// - X-API-Key:qwen3-tts-key,operation:true,count:123
|
||||
// - X-API-Key:"qwen3-tts-key",operation:"true"
|
||||
//
|
||||
// 说明:
|
||||
// - HTTP Header 最终都是字符串,这里做的是“值的字符串化表达”。
|
||||
// - 若 value 用双引号包裹,会去掉外层引号再注入,便于在配置中区分字符串/布尔/数字等表达(以及避免值中包含特殊字符时歧义)。
|
||||
func ParseHeadMsgHeaders(headMsg string) map[string]string {
|
||||
headMsg = strings.TrimSpace(headMsg)
|
||||
if headMsg == "" {
|
||||
return nil
|
||||
}
|
||||
out := map[string]string{}
|
||||
parts := strings.Split(headMsg, ",")
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
// HeaderName:HeaderValue(推荐) / HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(p, ":") {
|
||||
kv := strings.SplitN(p, ":", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
if strings.Contains(p, "=") {
|
||||
kv := strings.SplitN(p, "=", 2)
|
||||
k := strings.TrimSpace(kv[0])
|
||||
v := strings.TrimSpace(kv[1])
|
||||
v = strings.Trim(v, "\"")
|
||||
if k != "" && v != "" {
|
||||
out[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// PayloadToQuery 将 payload 转为 url.Values
|
||||
func PayloadToQuery(payload map[string]any) (url.Values, error) {
|
||||
q := url.Values{}
|
||||
for k, v := range payload {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
q.Set(k, gconv.String(v))
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
@@ -26,3 +26,15 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// GetsByModelName 批量获取模型
|
||||
func (d *modelDao) GetsByModelName(ctx context.Context, creator string, modelNames []string, fields ...string) (list []*entity.AsynchModel, err error) {
|
||||
err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Creator, creator).
|
||||
WhereIn(entity.AsynchModelCol.ModelName, modelNames).
|
||||
Fields(fields).
|
||||
Scan(&list)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ type ComposeMessagesReq struct {
|
||||
g.Meta `path:"/composeMessages" method:"post" tags:"提示词处理" summary:"拼接提示词" dc:"按 modelTypeId 读取 prompts_model_prompt.prompt_info 与 response_json_schema;form 作为系统表单,userForm 作为用户表单,结合 userFiles 调用 model-gateway,并直接返回最终 messages"`
|
||||
ModelName string `p:"modelName" json:"modelName" v:"required#modelName不能为空" dc:"实际请求的网关模型名称"`
|
||||
BuildType int `p:"buildType" json:"buildType" v:"required#buildType不能为空" dc:"构建类型"` //判断节点
|
||||
SessionId string `p:"sessionId" json:"sessionId" v:"required#sessionId不能为空" dc:"会话ID"`
|
||||
SessionId string `p:"sessionId" json:"sessionId" dc:"会话ID"` //v:"required#sessionId不能为空"
|
||||
Cause string `p:"cause" json:"cause" v:"required-if:IsBuilder,false#原因不能为空" dc:"原因"`
|
||||
CallbackUrl string `p:"callbackUrl" json:"callbackUrl" dc:"回调地址"`
|
||||
Form map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||
Form []map[string]any `p:"form" json:"form" dc:"系统表单:form 下所有字段都作为系统提示词来源"`
|
||||
UserForm []map[string]any `p:"userForm" json:"userForm" dc:"用户表单:userForm 下所有字段都作为用户提示词来源;若与 form 含义接近则严格覆盖系统字段"`
|
||||
Consult []ConsultItem `json:"consult" dc:"附件列表(图片/视频/音频)"`
|
||||
SkillName string `p:"skillName" json:"skillName" dc:"技能名称"`
|
||||
|
||||
@@ -20,7 +20,7 @@ import (
|
||||
type UserPromptPayload struct {
|
||||
Model string `json:"model"`
|
||||
PromptInfo string `json:"promptInfo"`
|
||||
Form map[string]any `json:"form"`
|
||||
Form any `json:"form"`
|
||||
UserForm any `json:"userForm"`
|
||||
Consult []dto.ConsultItem `json:"consult"`
|
||||
UserFilesText map[string]string `json:"userFilesText"`
|
||||
@@ -30,6 +30,7 @@ type UserPromptPayload struct {
|
||||
|
||||
// buildInferenceRequest 构建推理请求
|
||||
func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, history []map[string]any) (map[string]any, error) {
|
||||
//1) 处理表单分批
|
||||
processedReq, totalBatches, err := ProcessUserFormBatches(ctx, req, aiModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("处理用户表单分批失败: %w", err)
|
||||
@@ -47,9 +48,10 @@ func buildInferenceRequest(ctx context.Context, req *dto.ComposeMessagesReq, cha
|
||||
|
||||
// buildPromptTypeRequest 构建提示词类型请求(BuildType=1)
|
||||
func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, aiModel *entity.AsynchModel, chatModel *entity.AsynchModel, history []map[string]any, ir *PromptIR, totalBatches int) (map[string]any, error) {
|
||||
systemPrompt := promptBuildWithRounds(ctx, req, aiModel, totalBatches)
|
||||
//1) 构建系统提示词
|
||||
systemPrompt := promptBuildWithRounds(ctx, req, chatModel, aiModel, totalBatches)
|
||||
ir.AddSystem(systemPrompt)
|
||||
|
||||
//2) 构建历史对话
|
||||
for _, msg := range history {
|
||||
role := gconv.String(msg["role"])
|
||||
if role != "user" && role != "assistant" {
|
||||
@@ -57,7 +59,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
}
|
||||
ir.AddHistory(role, gconv.String(msg["content"]))
|
||||
}
|
||||
|
||||
userPrompt := buildUserPrompt(ctx, req, util.GetModelPrompt(ctx, aiModel.ModelType))
|
||||
ir.AddUser(userPrompt)
|
||||
if !checkOverallContent(ir, aiModel) {
|
||||
@@ -70,7 +71,6 @@ func buildPromptTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, ai
|
||||
// buildNodeTypeRequest 构建节点类型请求(BuildType=2)
|
||||
func buildNodeTypeRequest(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, ir *PromptIR) (map[string]any, error) {
|
||||
ir.AddUser(NodeBuild(ctx, req))
|
||||
|
||||
return compileToProviderRequest(ctx, ir, chatModel)
|
||||
}
|
||||
|
||||
@@ -90,33 +90,33 @@ func compileToProviderRequest(ctx context.Context, ir *PromptIR, chatModel *enti
|
||||
|
||||
return map[string]any{
|
||||
"modelName": chatModel.ModelName,
|
||||
"bizName": "prompts-core",
|
||||
"bizName": util.GetServerName(ctx),
|
||||
"callbackUrl": util.GetCallbackURL(ctx, "/prompt/callback"),
|
||||
"requestPayload": providerReq,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// promptBuildWithRounds 构建系统提示词(包含轮次信息)
|
||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, model *entity.AsynchModel, totalRounds int) string {
|
||||
// promptBuildWithRounds 构建系统提示词
|
||||
func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, chatModel *entity.AsynchModel, aiModel *entity.AsynchModel, batches int) string {
|
||||
providerProtocol, err := dao.ProviderProtocol.Get(ctx, &entity.ProviderProtocol{
|
||||
ProviderName: model.OperatorName,
|
||||
ProviderName: chatModel.OperatorName,
|
||||
Status: 1,
|
||||
})
|
||||
if err != nil || providerProtocol == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
outputJSON := util.JSONPretty(model.RequestMapping)
|
||||
maxWindowSize := util.GetMaxWindowSize(model.TokenConfig)
|
||||
availableWindow := util.GetAvailableWindow(model.TokenConfig)
|
||||
|
||||
outputJSON := util.JSONPretty(util.ReverseMap(aiModel.RequestMapping, map[string]any{}))
|
||||
maxWindowSize := util.GetMaxWindowSize(chatModel.TokenConfig)
|
||||
availableWindow := util.GetAvailableWindow(chatModel.TokenConfig)
|
||||
formContent := buildUserFormContent(req.Form)
|
||||
userFormContent := buildUserFormContent(req.UserForm)
|
||||
formInfo := fmt.Sprintf(`
|
||||
【系统表单(系统提示词/参数)】
|
||||
%s
|
||||
【用户表单全文(必须完整阅读,全部作为用户提示词来源)】
|
||||
%s
|
||||
`, util.FormToJSON(req.Form), userFormContent)
|
||||
`, formContent, userFormContent)
|
||||
|
||||
inputInfo := fmt.Sprintf(`
|
||||
目标模型: %s
|
||||
@@ -129,11 +129,8 @@ func promptBuildWithRounds(ctx context.Context, req *dto.ComposeMessagesReq, mod
|
||||
req.ModelName, // %s 目标模型名称
|
||||
maxWindowSize, // %d 最大窗口
|
||||
availableWindow, // %d 可用窗口
|
||||
totalRounds, // %d 数组长度(多轮输出要求)
|
||||
totalRounds, // %d 数组长度(结构铁律)
|
||||
outputJSON, // %s 输出结构
|
||||
inputInfo, // %s 完整输入信息
|
||||
totalRounds, // %d 数组长度(最后一行)
|
||||
)
|
||||
}
|
||||
|
||||
@@ -157,7 +154,7 @@ func buildUserPrompt(ctx context.Context, req *dto.ComposeMessagesReq, prompt st
|
||||
payload := UserPromptPayload{
|
||||
Model: req.ModelName,
|
||||
PromptInfo: prompt,
|
||||
Form: req.Form,
|
||||
Form: prepareUserFormPayload(req.Form),
|
||||
UserForm: prepareUserFormPayload(req.UserForm),
|
||||
Consult: req.Consult,
|
||||
UserFilesText: ExtractFileTexts(ctx, req.Consult),
|
||||
|
||||
@@ -21,13 +21,16 @@ import (
|
||||
|
||||
// ComposeMessages 核心拼接提示词主流程
|
||||
func ComposeMessages(ctx context.Context, req *dto.ComposeMessagesReq) (*dto.ComposeMessagesRes, error) {
|
||||
//1) 获取模型信息
|
||||
chatModel, aiModel, err := GetModelMessage(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//2) 校验用户表单
|
||||
if err = validateUserForm(req, aiModel); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//3) 处理不同类型
|
||||
switch req.BuildType {
|
||||
case public.BuildTypePrompt:
|
||||
return handlePromptBuild(ctx, req, chatModel, aiModel) // 提示词构建
|
||||
@@ -54,7 +57,7 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
||||
if chatModel == nil {
|
||||
return nil, nil, errors.New("当前没有对话模型,请添加")
|
||||
}
|
||||
aiModel, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
aiModels, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Creator: userInfo.UserName},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
@@ -62,10 +65,10 @@ func GetModelMessage(ctx context.Context, req *dto.ComposeMessagesReq) (*entity.
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if aiModel == nil {
|
||||
if aiModels == nil {
|
||||
return nil, nil, errors.New("需要构建的模型不存在")
|
||||
}
|
||||
return chatModel, aiModel, nil
|
||||
return chatModel, aiModels, nil
|
||||
}
|
||||
|
||||
// validateUserForm 校验用户表单
|
||||
@@ -150,12 +153,15 @@ func callInferenceModel(ctx context.Context, req *dto.ComposeMessagesReq, chatMo
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("构建推理请求失败: %w", err)
|
||||
}
|
||||
id, err := dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: util.GetUserMessage(taskReq),
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
||||
id := int64(0)
|
||||
if req.SessionId != "" {
|
||||
id, err = dao.ComposeSession.Insert(ctx, &entity.ComposeSession{
|
||||
SessionId: req.SessionId,
|
||||
RequestContent: util.GetUserMessage(taskReq),
|
||||
})
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("保存历史会话失败: %w", err)
|
||||
}
|
||||
}
|
||||
taskID, err := gateway.CreateGatewayTask(ctx, taskReq)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user