220 lines
6.3 KiB
Go
220 lines
6.3 KiB
Go
package service
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
stdhttp "net/http"
|
||
"os"
|
||
"path/filepath"
|
||
"sync"
|
||
"time"
|
||
|
||
commonHttp "gitea.com/red-future/common/http"
|
||
"gitea.com/red-future/common/utils"
|
||
"github.com/gogf/gf/v2/errors/gerror"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
)
|
||
|
||
// commonHttpTransportMu is held while this package adjusts settings on the
// shared common HTTP client's transport, so that such adjustments do not
// interleave with each other.
var commonHttpTransportMu sync.Mutex
|
||
|
||
// asyncCtx 异步上下文处理
|
||
func asyncCtx(ctx context.Context) context.Context {
|
||
asyncCtx := context.WithoutCancel(ctx)
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
if token := r.Header.Get("Authorization"); token != "" {
|
||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||
}
|
||
}
|
||
if user, uErr := utils.GetUserInfo(ctx); uErr == nil && user != nil {
|
||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||
}
|
||
return asyncCtx
|
||
}
|
||
|
||
// setCommonHttpResponseHeaderTimeout 调整公共 HTTP 客户端响应头超时,避免长时推理被 30s 默认值打断。
|
||
func setCommonHttpResponseHeaderTimeout(d time.Duration) {
|
||
if d <= 0 {
|
||
return
|
||
}
|
||
commonHttpTransportMu.Lock()
|
||
defer commonHttpTransportMu.Unlock()
|
||
if tr, ok := commonHttp.Httpclient.Transport.(*stdhttp.Transport); ok && tr != nil {
|
||
if tr.ResponseHeaderTimeout < d {
|
||
tr.ResponseHeaderTimeout = d
|
||
}
|
||
}
|
||
}
|
||
|
||
// forwardHeaders 透传调用链路中必须的头信息,优先使用异步上下文里固化的 token。
|
||
func forwardHeaders(ctx context.Context) map[string]string {
|
||
headers := make(map[string]string)
|
||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||
headers["Authorization"] = token
|
||
}
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
if headers["Authorization"] == "" {
|
||
if token := r.Header.Get("Authorization"); token != "" {
|
||
headers["Authorization"] = token
|
||
}
|
||
}
|
||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||
headers["X-User-Info"] = userInfo
|
||
}
|
||
}
|
||
return headers
|
||
}
|
||
|
||
// commonPostJSON 使用 common/http 的底层客户端直连 JSON 接口,适配非统一响应包装结构。
|
||
func commonPostJSON(ctx context.Context, url string, headers map[string]string, req any, resp any) error {
|
||
client := commonHttp.Httpclient.Clone().ContentJson()
|
||
if deadline, ok := ctx.Deadline(); ok {
|
||
if d := time.Until(deadline); d > 0 {
|
||
client.SetTimeout(d)
|
||
}
|
||
}
|
||
if len(headers) > 0 {
|
||
client.SetHeaderMap(headers)
|
||
}
|
||
r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, req)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer r.Close()
|
||
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
return gerror.Wrap(err, "读取响应失败")
|
||
}
|
||
if r.StatusCode != stdhttp.StatusOK {
|
||
return gerror.Newf("HTTP状态码异常: %d, body: %s", r.StatusCode, string(body))
|
||
}
|
||
if err := json.Unmarshal(body, resp); err != nil {
|
||
return gerror.Wrapf(err, "解析响应失败, body: %s", string(body))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func commonPostMultipartFile(ctx context.Context, url string, headers map[string]string, form map[string]string, fileField string, filePath string, resp any) error {
|
||
body := &bytes.Buffer{}
|
||
writer := multipart.NewWriter(body)
|
||
|
||
for k, v := range form {
|
||
if v == "" {
|
||
continue
|
||
}
|
||
if err := writer.WriteField(k, v); err != nil {
|
||
return gerror.Wrapf(err, "写入表单字段失败: %s", k)
|
||
}
|
||
}
|
||
|
||
f, err := os.Open(filePath)
|
||
if err != nil {
|
||
return gerror.Wrapf(err, "打开文件失败: %s", filePath)
|
||
}
|
||
defer f.Close()
|
||
|
||
part, err := writer.CreateFormFile(fileField, filepath.Base(filePath))
|
||
if err != nil {
|
||
return gerror.Wrapf(err, "创建表单文件失败: %s", fileField)
|
||
}
|
||
if _, err := io.Copy(part, f); err != nil {
|
||
return gerror.Wrap(err, "写入文件内容失败")
|
||
}
|
||
|
||
contentType := writer.FormDataContentType()
|
||
if err := writer.Close(); err != nil {
|
||
return gerror.Wrap(err, "关闭表单写入器失败")
|
||
}
|
||
|
||
client := commonHttp.Httpclient.Clone()
|
||
if deadline, ok := ctx.Deadline(); ok {
|
||
if d := time.Until(deadline); d > 0 {
|
||
client.SetTimeout(d)
|
||
}
|
||
}
|
||
if headers == nil {
|
||
headers = make(map[string]string)
|
||
}
|
||
headers["Content-Type"] = contentType
|
||
client.SetHeaderMap(headers)
|
||
|
||
r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, body.Bytes())
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer r.Close()
|
||
|
||
raw, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
return gerror.Wrap(err, "读取响应失败")
|
||
}
|
||
if r.StatusCode != stdhttp.StatusOK {
|
||
return gerror.Newf("HTTP状态码异常: %d, body: %s", r.StatusCode, string(raw))
|
||
}
|
||
if err := json.Unmarshal(raw, resp); err != nil {
|
||
return gerror.Wrapf(err, "解析响应失败, body: %s", string(raw))
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// -------------------------- model-asynch 调用封装 --------------------------
|
||
|
||
// modelAsynchServiceName is the name of the model-asynch service.
// NOTE(review): no use of this constant is visible in this part of the file —
// confirm where it is consumed.
const modelAsynchServiceName = "model-asynch"
|
||
|
||
// modelAsynchCreateTaskReq is the JSON request body sent by
// createModelAsynchTask to the model-asynch task-creation endpoint.
type modelAsynchCreateTaskReq struct {
	// ModelName is sent as "modelName".
	ModelName string `json:"modelName"`
	// InputRef is sent as "inputRef" and omitted from the JSON when empty.
	InputRef string `json:"inputRef,omitempty"`
	// RequestPayload is sent as "requestPayload"; it is always present in the
	// JSON and encodes as null when nil.
	RequestPayload any `json:"requestPayload"`
}
|
||
|
||
// modelAsynchCreateTaskRes is the response payload decoded by
// createModelAsynchTask after a task has been created.
type modelAsynchCreateTaskRes struct {
	// TaskID is read from the "taskId" JSON field.
	TaskID string `json:"taskId"`
}
|
||
|
||
// createModelAsynchTask 调用 model-asynch 创建任务
|
||
// 注意:路由以 GoFrame 默认输出为准(通常为 /task/create-task)
|
||
func createModelAsynchTask(ctx context.Context, modelName string, payload any, inputRef string) (taskID string, err error) {
|
||
taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
|
||
headers := forwardHeaders(ctx)
|
||
req := &modelAsynchCreateTaskReq{
|
||
ModelName: modelName,
|
||
InputRef: inputRef,
|
||
RequestPayload: payload,
|
||
}
|
||
var res modelAsynchCreateTaskRes
|
||
if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/createTask", taskUrl), headers, &res, req); err != nil {
|
||
return "", err
|
||
}
|
||
return res.TaskID, nil
|
||
}
|
||
|
||
// modelAsynchBatchReq is the JSON request body sent by
// getModelAsynchTaskBatch to query several tasks at once.
type modelAsynchBatchReq struct {
	// TaskIDs is sent as "taskIds"; a nil slice encodes as null.
	TaskIDs []string `json:"taskIds"`
}
|
||
|
||
// modelAsynchBatchItem is one entry of the batch task query response.
type modelAsynchBatchItem struct {
	// TaskID is read from the "taskId" JSON field.
	TaskID string `json:"taskId"`
	// State is read from the "state" JSON field.
	// NOTE(review): the meaning of the numeric values is defined by the
	// model-asynch service and is not visible here — confirm before relying
	// on specific codes.
	State int `json:"state"`
	// OssFile is read from the "ossFile" JSON field; presumably a reference to
	// the task's result file in object storage — verify against the service.
	OssFile string `json:"ossFile"`
}
|
||
|
||
type modelAsynchBatchRes struct {
|
||
List []modelAsynchBatchItem `json:"list"`
|
||
}
|
||
|
||
// getModelAsynchTaskBatch 批量查询任务(成功 2->4 的逻辑由中间件内部处理)
|
||
func getModelAsynchTaskBatch(ctx context.Context, taskIDs []string) (items []modelAsynchBatchItem, err error) {
|
||
taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
|
||
headers := forwardHeaders(ctx)
|
||
req := &modelAsynchBatchReq{TaskIDs: taskIDs}
|
||
var res modelAsynchBatchRes
|
||
if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/getTaskBatch", taskUrl), headers, &res, req); err != nil {
|
||
return nil, err
|
||
}
|
||
return res.List, nil
|
||
}
|