feat: 新增模型扩展映射与查询配置字段
This commit is contained in:
@@ -63,18 +63,27 @@ func UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileEx
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
// TriggerCallback 任务成功后的回调:
|
||||
// - JSON body 参数:task_id/state/oss_file/file_type/text(可选)
|
||||
// CallbackPayload 回调请求体
|
||||
type CallbackPayload struct {
|
||||
TaskId string `json:"task_id"`
|
||||
State int `json:"state"`
|
||||
OssFile string `json:"oss_file"`
|
||||
FileType string `json:"file_type"`
|
||||
Text string `json:"text"`
|
||||
ErrorMsg string `json:"error_msg"`
|
||||
}
|
||||
|
||||
// TriggerCallback 任务成功后的回调
|
||||
func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"task_id": t.TaskID,
|
||||
"state": t.State,
|
||||
"oss_file": t.OssFile,
|
||||
"file_type": t.FileType,
|
||||
"text": t.TextResult,
|
||||
"error_msg": t.ErrorMsg,
|
||||
var resp struct{}
|
||||
payload := CallbackPayload{
|
||||
TaskId: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
FileType: t.FileType,
|
||||
Text: t.TextResult,
|
||||
ErrorMsg: t.ErrorMsg,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -84,7 +93,7 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
g.Log().Infof(ctx, "[回调] 开始发送 taskId=%s 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.TaskID, t.CallbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, t.CallbackURL, headers, &req, jsonData)
|
||||
err = commonHttp.Post(ctx, t.CallbackURL, headers, &resp, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[回调] 发送失败 taskId=%s 回调地址=%s 错误=%v", t.TaskID, t.CallbackURL, err)
|
||||
return
|
||||
@@ -92,15 +101,20 @@ func TriggerCallback(ctx context.Context, t *entity.AsynchTask) {
|
||||
g.Log().Infof(ctx, "[回调] 发送成功 taskId=%s 回调地址=%s 消息体大小=%d字节", t.TaskID, t.CallbackURL, len(jsonData))
|
||||
}
|
||||
|
||||
// PromptsCallbackPayload 提示词回调请求体
|
||||
type PromptsCallbackPayload struct {
|
||||
EpicycleId int64 `json:"epicycleId"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// TriggerPromptsCallback 任务成功后的提示词回调
|
||||
// - JSON body 参数:epicycleId(轮次id)/textResult(模型回答消息)
|
||||
func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleId int64) {
|
||||
callbackURL := "prompts-core/session/sessionCallback"
|
||||
headers := util.ForwardHeaders(ctx)
|
||||
var req struct{}
|
||||
payload := map[string]interface{}{
|
||||
"epicycleId": epicycleId,
|
||||
"text": t.TextResult,
|
||||
var resp struct{}
|
||||
payload := PromptsCallbackPayload{
|
||||
EpicycleId: epicycleId,
|
||||
Text: t.TextResult,
|
||||
}
|
||||
jsonData, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
@@ -110,7 +124,7 @@ func TriggerPromptsCallback(ctx context.Context, t *entity.AsynchTask, epicycleI
|
||||
g.Log().Infof(ctx, "[提示词回调] 开始发送 epicycleId=%d 回调地址=%s 请求头数量=%d 消息体大小=%d字节",
|
||||
t.EpicycleId, callbackURL, len(headers), len(jsonData))
|
||||
|
||||
err = commonHttp.Post(ctx, callbackURL, headers, &req, jsonData)
|
||||
err = commonHttp.Post(ctx, callbackURL, headers, &resp, jsonData)
|
||||
if err != nil {
|
||||
g.Log().Warningf(ctx, "[提示词回调] 发送失败 epicycleId=%d 回调地址=%s 错误=%v", t.EpicycleId, callbackURL, err)
|
||||
return
|
||||
|
||||
@@ -6,13 +6,12 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"model-gateway/model/entity"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/container/gvar"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -196,6 +195,59 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelK
|
||||
return mappedResponse, nil
|
||||
}
|
||||
|
||||
//// InvokeModel 调用模型服务,返回二进制结果
|
||||
//func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any, modelKey string) ([]byte, error) {
|
||||
// if m == nil || m.BaseURL == "" {
|
||||
// return nil, fmt.Errorf("模型配置不完整")
|
||||
// }
|
||||
// // 请求参数映射
|
||||
// mappedPayload, err := mapRequestPayload(m.RequestMapping, payload)
|
||||
// if err != nil {
|
||||
// return nil, fmt.Errorf("请求参数映射失败: %w", err)
|
||||
// }
|
||||
// // 合并请求头
|
||||
// headers := util.ForwardHeaders(ctx)
|
||||
// for hk, hv := range parseHeadMsgHeaders(m.HeadMsg) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
// for hk, hv := range parseHeadMsgHeaders(modelKey) {
|
||||
// headers[hk] = hv
|
||||
// }
|
||||
//
|
||||
// // 设置超时
|
||||
// timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
// if timeout <= 0 {
|
||||
// timeout = 600 * time.Second
|
||||
// }
|
||||
// ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
// defer cancel()
|
||||
//
|
||||
// invokeUrl := strings.TrimRight(m.BaseURL, "/")
|
||||
// method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
// if method == "" {
|
||||
// method = http.MethodPost
|
||||
// }
|
||||
//
|
||||
// var respBytes []byte
|
||||
//
|
||||
// switch method {
|
||||
// case http.MethodGet:
|
||||
// err = commonHttp.Get(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// default:
|
||||
// err = commonHttp.Post(ctx, invokeUrl, headers, &respBytes, mappedPayload)
|
||||
// }
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// // 响应参数映射
|
||||
// mappedResponse, err := mapResponsePayload(m.ResponseMapping, respBytes)
|
||||
// if err != nil {
|
||||
// g.Log().Warningf(ctx, "响应参数映射失败: %v,返回原始数据", err)
|
||||
// return respBytes, nil
|
||||
// }
|
||||
// return mappedResponse, nil
|
||||
//}
|
||||
|
||||
// ============================================
|
||||
// 映射相关函数
|
||||
// ============================================
|
||||
|
||||
@@ -3,7 +3,6 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"model-gateway/common/util"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/dao"
|
||||
"model-gateway/model/dto"
|
||||
@@ -87,6 +86,8 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -150,6 +151,8 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -195,6 +198,8 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -225,6 +230,8 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
IsOwner: req.IsOwner,
|
||||
OperatorName: req.OperatorName,
|
||||
TokenConfig: req.TokenConfig,
|
||||
ExtendMapping: req.ExtendMapping,
|
||||
QueryConfig: req.QueryConfig,
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -237,17 +244,20 @@ func (s *modelService) Delete(ctx context.Context, req *dto.DeleteModelReq) erro
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, req *dto.GetModelReq) (*dto.GetModelRes, error) {
|
||||
user, err := utils.GetUserInfo(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model, err := dao.Model.Get(ctx, &entity.AsynchModel{
|
||||
SQLBaseDO: beans.SQLBaseDO{Id: req.ID},
|
||||
SQLBaseDO: beans.SQLBaseDO{
|
||||
Id: req.ID,
|
||||
Creator: user.UserName,
|
||||
},
|
||||
ModelName: req.ModelName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
model.TokenConfig = util.ParseJSONField(model.TokenConfig)
|
||||
return &dto.GetModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
@@ -277,14 +287,6 @@ func (s *modelService) List(ctx context.Context, req *dto.ListModelReq) (res *dt
|
||||
return
|
||||
}
|
||||
|
||||
// 处理列表中每条记录的 JSONB 字段
|
||||
for _, m := range models {
|
||||
m.Form = util.ParseJSONField(m.Form)
|
||||
m.RequestMapping = util.ParseJSONField(m.RequestMapping)
|
||||
m.ResponseMapping = util.ParseJSONField(m.ResponseMapping)
|
||||
m.ResponseBody = util.ParseJSONField(m.ResponseBody)
|
||||
m.TokenConfig = util.ParseJSONField(m.TokenConfig)
|
||||
}
|
||||
return &dto.ListModelRes{
|
||||
List: models,
|
||||
Total: total,
|
||||
@@ -381,11 +383,6 @@ func (s *modelService) GetIsChatModel(ctx context.Context) (*dto.GetIsChatModelR
|
||||
if model == nil {
|
||||
return nil, nil
|
||||
}
|
||||
model.Form = util.ParseJSONField(model.Form)
|
||||
model.RequestMapping = util.ParseJSONField(model.RequestMapping)
|
||||
model.ResponseMapping = util.ParseJSONField(model.ResponseMapping)
|
||||
model.ResponseBody = util.ParseJSONField(model.ResponseBody)
|
||||
model.TokenConfig = util.ParseJSONField(model.TokenConfig)
|
||||
return &dto.GetIsChatModelRes{
|
||||
Model: model,
|
||||
}, nil
|
||||
|
||||
@@ -262,8 +262,8 @@ func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// GetExpendTokens 根据映射路径从 textResult 中提取消耗 token 值
|
||||
func GetExpendTokens(tokenMapping string, textResult string) int {
|
||||
value := gjson.Get(textResult, tokenMapping)
|
||||
func GetExpendTokens(responseTokenField string, textResult string) int {
|
||||
value := gjson.Get(textResult, responseTokenField)
|
||||
if value.Exists() {
|
||||
return int(value.Int())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user