Files
ai-agent/digital-human/service/http_wrapper.go
2026-04-27 14:02:43 +08:00

220 lines
6.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
stdhttp "net/http"
"os"
"path/filepath"
"sync"
"time"
commonHttp "gitea.com/red-future/common/http"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/errors/gerror"
"github.com/gogf/gf/v2/frame/g"
)
// commonHttpTransportMu serializes callers that adjust settings on the
// transport of the shared commonHttp.Httpclient.
var (
	commonHttpTransportMu sync.Mutex
)
// asyncCtx derives a context for background work that must outlive the
// incoming request.
//
// The result keeps every value of ctx but is detached from its cancellation
// (context.WithoutCancel). On top of that it pins:
//   - "token": the Authorization header of the GoFrame request attached to ctx,
//     when there is such a request and the header is non-empty;
//   - "user": the value returned by utils.GetUserInfo(ctx), when that lookup
//     succeeds with a non-nil result.
//
// NOTE(review): plain string keys trip staticcheck SA1029, but forwardHeaders
// reads "token" back by that exact key (and "user" is presumably read
// elsewhere), so the keys are kept as-is.
func asyncCtx(ctx context.Context) context.Context {
	detached := context.WithoutCancel(ctx)

	if req := g.RequestFromCtx(ctx); req != nil {
		if auth := req.Header.Get("Authorization"); auth != "" {
			detached = context.WithValue(detached, "token", auth)
		}
	}

	user, userErr := utils.GetUserInfo(ctx)
	if userErr != nil || user == nil {
		return detached
	}
	return context.WithValue(detached, "user", user)
}
// setCommonHttpResponseHeaderTimeout raises the ResponseHeaderTimeout of the
// shared commonHttp.Httpclient transport to at least d, so long-running
// inference calls are not cut off by a shorter default (e.g. 30s).
//
// The timeout is only ever relaxed, never tightened:
//   - d <= 0 is ignored;
//   - a transport whose ResponseHeaderTimeout is 0 is left alone, because in
//     net/http 0 means "no timeout", and assigning d would impose a limit
//     that did not exist before;
//   - a Transport that is not a *net/http.Transport is left alone.
//
// NOTE(review): commonHttpTransportMu only serializes concurrent callers of
// this function. net/http's Transport reads ResponseHeaderTimeout on each
// round trip without this lock, so mutating it while requests are in flight
// is a data race. Prefer calling this during initialization.
func setCommonHttpResponseHeaderTimeout(d time.Duration) {
	if d <= 0 {
		return
	}
	commonHttpTransportMu.Lock()
	defer commonHttpTransportMu.Unlock()

	tr, ok := commonHttp.Httpclient.Transport.(*stdhttp.Transport)
	if !ok || tr == nil {
		return
	}
	// Zero already means "wait forever", which is more permissive than any d.
	if tr.ResponseHeaderTimeout > 0 && tr.ResponseHeaderTimeout < d {
		tr.ResponseHeaderTimeout = d
	}
}
// forwardHeaders builds the headers that must be propagated on outbound
// service calls.
//
// Authorization comes from the "token" context value (as pinned by asyncCtx)
// when it is a non-empty string. Otherwise it falls back to the Authorization
// header of the GoFrame request attached to ctx. X-User-Info is copied from
// that request when non-empty. Empty values are never set, so the result may
// be an empty (but non-nil) map.
func forwardHeaders(ctx context.Context) map[string]string {
	out := map[string]string{}

	pinned, _ := ctx.Value("token").(string)
	req := g.RequestFromCtx(ctx)

	auth := pinned
	if auth == "" && req != nil {
		auth = req.Header.Get("Authorization")
	}
	if auth != "" {
		out["Authorization"] = auth
	}

	if req != nil {
		if info := req.Header.Get("X-User-Info"); info != "" {
			out["X-User-Info"] = info
		}
	}
	return out
}
// commonPostJSON POSTs req as JSON to url and decodes the raw response body
// into resp.
//
// It uses a clone of the shared commonHttp.Httpclient. Unlike commonHttp.Post,
// it does not expect the unified response envelope: the body is unmarshalled
// directly into resp.
//
// If ctx has a deadline that has not yet passed, the remaining time becomes
// the client timeout. Non-empty headers are applied to the request.
//
// Errors:
//   - transport errors are returned unchanged;
//   - a body-read failure, a status other than 200, or a JSON decode failure
//     is returned as a gerror value;
//   - the last two include the raw body in the message.
func commonPostJSON(ctx context.Context, url string, headers map[string]string, req any, resp any) error {
	cli := commonHttp.Httpclient.Clone().ContentJson()
	if deadline, hasDeadline := ctx.Deadline(); hasDeadline {
		if remaining := time.Until(deadline); remaining > 0 {
			cli.SetTimeout(remaining)
		}
	}
	if len(headers) != 0 {
		cli.SetHeaderMap(headers)
	}

	res, err := cli.DoRequest(ctx, stdhttp.MethodPost, url, req)
	if err != nil {
		return err
	}
	defer res.Close()

	payload, err := io.ReadAll(res.Body)
	switch {
	case err != nil:
		return gerror.Wrap(err, "读取响应失败")
	case res.StatusCode != stdhttp.StatusOK:
		return gerror.Newf("HTTP状态码异常: %d, body: %s", res.StatusCode, string(payload))
	}

	if err = json.Unmarshal(payload, resp); err != nil {
		return gerror.Wrapf(err, "解析响应失败, body: %s", string(payload))
	}
	return nil
}
// commonPostMultipartFile POSTs a multipart/form-data request to url and
// decodes the raw JSON response body into resp.
//
// Request body:
//   - each entry of form with a non-empty value is written as a text field;
//     empty values are skipped;
//   - the file at filePath is attached under fileField, using its base name
//     as the upload filename.
//
// Headers: the caller's headers are sent together with the multipart
// Content-Type. The caller's headers map is NOT modified.
//
// Timeout: if ctx has a deadline that has not yet passed, the remaining time
// becomes the client timeout.
//
// Errors: transport errors are returned unchanged. Form-building, read,
// non-200 status, and JSON decode failures are returned as gerror values.
//
// NOTE(review): the whole file is buffered in memory before sending, so memory
// use grows with the file size.
func commonPostMultipartFile(ctx context.Context, url string, headers map[string]string, form map[string]string, fileField string, filePath string, resp any) error {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	for k, v := range form {
		if v == "" {
			continue
		}
		if err := writer.WriteField(k, v); err != nil {
			return gerror.Wrapf(err, "写入表单字段失败: %s", k)
		}
	}
	f, err := os.Open(filePath)
	if err != nil {
		return gerror.Wrapf(err, "打开文件失败: %s", filePath)
	}
	defer f.Close()
	part, err := writer.CreateFormFile(fileField, filepath.Base(filePath))
	if err != nil {
		return gerror.Wrapf(err, "创建表单文件失败: %s", fileField)
	}
	if _, err := io.Copy(part, f); err != nil {
		return gerror.Wrap(err, "写入文件内容失败")
	}
	// The Content-Type carries the boundary, which is only known once the
	// writer exists; Close writes the trailing boundary.
	contentType := writer.FormDataContentType()
	if err := writer.Close(); err != nil {
		return gerror.Wrap(err, "关闭表单写入器失败")
	}
	client := commonHttp.Httpclient.Clone()
	if deadline, ok := ctx.Deadline(); ok {
		if d := time.Until(deadline); d > 0 {
			client.SetTimeout(d)
		}
	}
	// Copy into a fresh map instead of writing into the caller's: mutating it
	// would leak this request's boundary into later reuse of the map (e.g. a
	// follow-up JSON call) and race if the map is shared across goroutines.
	reqHeaders := make(map[string]string, len(headers)+1)
	for k, v := range headers {
		reqHeaders[k] = v
	}
	reqHeaders["Content-Type"] = contentType
	client.SetHeaderMap(reqHeaders)
	r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, body.Bytes())
	if err != nil {
		return err
	}
	defer r.Close()
	raw, err := io.ReadAll(r.Body)
	if err != nil {
		return gerror.Wrap(err, "读取响应失败")
	}
	if r.StatusCode != stdhttp.StatusOK {
		return gerror.Newf("HTTP状态码异常: %d, body: %s", r.StatusCode, string(raw))
	}
	if err := json.Unmarshal(raw, resp); err != nil {
		return gerror.Wrapf(err, "解析响应失败, body: %s", string(raw))
	}
	return nil
}
// -------------------------- model-asynch client wrappers --------------------------

// modelAsynchServiceName is the name of the model-asynch service.
const modelAsynchServiceName = "model-asynch"

type (
	// modelAsynchCreateTaskReq is the JSON body sent to model-asynch to create a task.
	modelAsynchCreateTaskReq struct {
		ModelName      string `json:"modelName"`
		InputRef       string `json:"inputRef,omitempty"` // dropped from the JSON when empty
		RequestPayload any    `json:"requestPayload"`
	}

	// modelAsynchCreateTaskRes is the JSON response of a create-task call.
	modelAsynchCreateTaskRes struct {
		TaskID string `json:"taskId"`
	}
)
// createModelAsynchTask asks the model-asynch service to create a task and
// returns the task ID it assigns.
//
// The service address comes from config key "model-asynch.addr" (default
// "127.0.0.1:8080"). The request is POSTed to "<addr>/task/createTask" through
// commonHttp.Post, forwarding the caller's headers built by forwardHeaders.
// inputRef is omitted from the JSON body when empty.
//
// NOTE(review): a previous comment said the GoFrame default route is usually
// /task/create-task, but this code calls /task/createTask. Confirm against the
// model-asynch route definitions.
//
// Returns an error when the call fails, or when the service replies without a
// task ID. An empty ID cannot be queried via getModelAsynchTaskBatch, so it is
// reported instead of being handed back as a valid result.
func createModelAsynchTask(ctx context.Context, modelName string, payload any, inputRef string) (taskID string, err error) {
	taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
	headers := forwardHeaders(ctx)
	req := &modelAsynchCreateTaskReq{
		ModelName:      modelName,
		InputRef:       inputRef,
		RequestPayload: payload,
	}
	var res modelAsynchCreateTaskRes
	if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/createTask", taskUrl), headers, &res, req); err != nil {
		return "", err
	}
	if res.TaskID == "" {
		return "", gerror.Newf("model-asynch 创建任务未返回任务ID, 模型: %s", modelName)
	}
	return res.TaskID, nil
}
type (
	// modelAsynchBatchReq is the JSON body for a batch task-status query.
	modelAsynchBatchReq struct {
		TaskIDs []string `json:"taskIds"`
	}

	// modelAsynchBatchItem is one task entry in a batch query response.
	// NOTE(review): State is an integer code defined by model-asynch; its
	// values are not documented in this file.
	modelAsynchBatchItem struct {
		TaskID  string `json:"taskId"`
		State   int    `json:"state"`
		OssFile string `json:"ossFile"`
	}

	// modelAsynchBatchRes is the JSON response of a batch task-status query.
	modelAsynchBatchRes struct {
		List []modelAsynchBatchItem `json:"list"`
	}
)
// getModelAsynchTaskBatch queries the model-asynch service for several tasks
// in a single call and returns the items it reports.
//
// The service address comes from config key "model-asynch.addr" (default
// "127.0.0.1:8080"). The request is POSTed to "<addr>/task/getTaskBatch"
// through commonHttp.Post, forwarding the caller's headers built by
// forwardHeaders.
//
// An empty taskIDs returns (nil, nil) without contacting the service; there is
// nothing to query.
//
// The success-state transition (2 -> 4) is handled inside the model-asynch
// middleware, not here.
func getModelAsynchTaskBatch(ctx context.Context, taskIDs []string) (items []modelAsynchBatchItem, err error) {
	if len(taskIDs) == 0 {
		return nil, nil
	}
	taskUrl := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080")
	headers := forwardHeaders(ctx)
	req := &modelAsynchBatchReq{TaskIDs: taskIDs}
	var res modelAsynchBatchRes
	if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/getTaskBatch", taskUrl), headers, &res, req); err != nil {
		return nil, err
	}
	return res.List, nil
}