package service

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"maps"
	"mime/multipart"
	stdhttp "net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	commonHttp "gitea.com/red-future/common/http"
	"gitea.com/red-future/common/utils"
	"github.com/gogf/gf/v2/errors/gerror"
	"github.com/gogf/gf/v2/frame/g"
)

// commonHttpTransportMu serializes writers in setCommonHttpResponseHeaderTimeout.
var commonHttpTransportMu sync.Mutex

// asyncCtx builds a context for asynchronous work derived from ctx.
//
// The returned context is detached from ctx's cancellation (context.WithoutCancel)
// but keeps its values. It additionally stores, under plain string keys:
//   - "token": the request's Authorization header, if a request is attached to ctx
//     and the header is non-empty;
//   - "user": the result of utils.GetUserInfo(ctx), if that call succeeds with a
//     non-nil user.
//
// NOTE(review): string context keys are flagged by staticcheck (SA1029). They are
// kept as-is because forwardHeaders (and presumably other code in the project)
// looks values up by these exact strings.
func asyncCtx(ctx context.Context) context.Context {
	out := context.WithoutCancel(ctx)
	if r := g.RequestFromCtx(ctx); r != nil {
		if token := r.Header.Get("Authorization"); token != "" {
			out = context.WithValue(out, "token", token)
		}
	}
	if user, uErr := utils.GetUserInfo(ctx); uErr == nil && user != nil {
		out = context.WithValue(out, "user", user)
	}
	return out
}

// setCommonHttpResponseHeaderTimeout raises the shared common HTTP client's
// response-header timeout to at least d, so long-running inference calls are not
// cut off by the (assumed) 30s default. It never lowers the timeout, and it does
// nothing when d <= 0 or when the client's Transport is not a *net/http.Transport.
//
// NOTE(review): the mutex only serializes concurrent callers of this function;
// the Transport field is still written while other goroutines may be reading it
// during in-flight requests, which the race detector would report. Prefer calling
// this once during startup.
func setCommonHttpResponseHeaderTimeout(d time.Duration) {
	if d <= 0 {
		return
	}
	commonHttpTransportMu.Lock()
	defer commonHttpTransportMu.Unlock()
	if tr, ok := commonHttp.Httpclient.Transport.(*stdhttp.Transport); ok && tr != nil {
		if tr.ResponseHeaderTimeout < d {
			tr.ResponseHeaderTimeout = d
		}
	}
}

// forwardHeaders collects the headers that must be propagated to downstream calls.
//
// Authorization is taken from the "token" context value (as stored by asyncCtx)
// when present, otherwise from the incoming request's Authorization header.
// X-User-Info is copied from the incoming request when present. The returned map
// is always non-nil and freshly allocated.
func forwardHeaders(ctx context.Context) map[string]string {
	headers := make(map[string]string)
	if token, ok := ctx.Value("token").(string); ok && token != "" {
		headers["Authorization"] = token
	}
	if r := g.RequestFromCtx(ctx); r != nil {
		if headers["Authorization"] == "" {
			if token := r.Header.Get("Authorization"); token != "" {
				headers["Authorization"] = token
			}
		}
		if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
			headers["X-User-Info"] = userInfo
		}
	}
	return headers
}

// remainingUntilDeadline reports how much time is left before ctx's deadline.
// ok is false when ctx has no deadline or the deadline has already passed.
func remainingUntilDeadline(ctx context.Context) (d time.Duration, ok bool) {
	deadline, has := ctx.Deadline()
	if !has {
		return 0, false
	}
	d = time.Until(deadline)
	return d, d > 0
}

// decodeJSONResponse reads body in full, rejects any status other than 200 OK
// (embedding the raw body in the error), and unmarshals the body into resp.
func decodeJSONResponse(statusCode int, body io.Reader, resp any) error {
	raw, err := io.ReadAll(body)
	if err != nil {
		return gerror.Wrap(err, "读取响应失败")
	}
	if statusCode != stdhttp.StatusOK {
		return gerror.Newf("HTTP状态码异常: %d, body: %s", statusCode, string(raw))
	}
	if err := json.Unmarshal(raw, resp); err != nil {
		return gerror.Wrapf(err, "解析响应失败, body: %s", string(raw))
	}
	return nil
}

// commonPostJSON POSTs req as JSON to url using a clone of the common/http client
// and decodes the raw JSON response into resp. Unlike commonHttp.Post, it does not
// expect the project's unified response envelope.
//
// If ctx has a future deadline, the client timeout is set to the remaining time.
// headers (may be nil) are added to the request. Any status other than 200 OK is
// returned as an error that includes the response body.
func commonPostJSON(ctx context.Context, url string, headers map[string]string, req any, resp any) error {
	client := commonHttp.Httpclient.Clone().ContentJson()
	if d, ok := remainingUntilDeadline(ctx); ok {
		client.SetTimeout(d)
	}
	if len(headers) > 0 {
		client.SetHeaderMap(headers)
	}
	r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, req)
	if err != nil {
		return err
	}
	defer r.Close()
	return decodeJSONResponse(r.StatusCode, r.Body, resp)
}

// buildMultipartBody encodes the non-empty entries of form plus the file at
// filePath (under field fileField, named by its base name) as multipart/form-data.
// It returns the encoded body and the matching Content-Type (including boundary).
func buildMultipartBody(form map[string]string, fileField string, filePath string) ([]byte, string, error) {
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	for k, v := range form {
		if v == "" {
			continue
		}
		if err := writer.WriteField(k, v); err != nil {
			return nil, "", gerror.Wrapf(err, "写入表单字段失败: %s", k)
		}
	}
	f, err := os.Open(filePath)
	if err != nil {
		return nil, "", gerror.Wrapf(err, "打开文件失败: %s", filePath)
	}
	defer f.Close()
	part, err := writer.CreateFormFile(fileField, filepath.Base(filePath))
	if err != nil {
		return nil, "", gerror.Wrapf(err, "创建表单文件失败: %s", fileField)
	}
	if _, err := io.Copy(part, f); err != nil {
		return nil, "", gerror.Wrap(err, "写入文件内容失败")
	}
	contentType := writer.FormDataContentType()
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, "", gerror.Wrap(err, "关闭表单写入器失败")
	}
	return body.Bytes(), contentType, nil
}

// commonPostMultipartFile POSTs a multipart/form-data request to url containing
// the non-empty form fields and the file at filePath (field fileField), then
// decodes the raw JSON response into resp.
//
// The whole file is buffered in memory before sending. headers (may be nil) are
// sent with the request; the caller's map is NOT modified. If ctx has a future
// deadline, the client timeout is set to the remaining time. Any status other
// than 200 OK is returned as an error that includes the response body.
func commonPostMultipartFile(ctx context.Context, url string, headers map[string]string, form map[string]string, fileField string, filePath string, resp any) error {
	body, contentType, err := buildMultipartBody(form, fileField, filePath)
	if err != nil {
		return err
	}
	client := commonHttp.Httpclient.Clone()
	if d, ok := remainingUntilDeadline(ctx); ok {
		client.SetTimeout(d)
	}
	// Copy instead of writing into the caller's map, which may be shared/reused.
	reqHeaders := make(map[string]string, len(headers)+1)
	maps.Copy(reqHeaders, headers)
	reqHeaders["Content-Type"] = contentType
	client.SetHeaderMap(reqHeaders)
	r, err := client.DoRequest(ctx, stdhttp.MethodPost, url, body)
	if err != nil {
		return err
	}
	defer r.Close()
	return decodeJSONResponse(r.StatusCode, r.Body, resp)
}

// -------------------------- model-asynch client wrappers --------------------------

const modelAsynchServiceName = "model-asynch"

type modelAsynchCreateTaskReq struct {
	ModelName      string `json:"modelName"`
	InputRef       string `json:"inputRef,omitempty"`
	RequestPayload any    `json:"requestPayload"`
}

type modelAsynchCreateTaskRes struct {
	TaskID string `json:"taskId"`
}

// modelAsynchAddr returns the configured model-asynch base address
// ("model-asynch.addr", default "127.0.0.1:8080") without a trailing slash, so
// paths can be appended as "%s/path" without producing "//".
func modelAsynchAddr(ctx context.Context) string {
	addr := g.Cfg().MustGet(ctx, "model-asynch.addr", "127.0.0.1:8080").String()
	return strings.TrimRight(addr, "/")
}

// createModelAsynchTask asks model-asynch to create a task for modelName with
// the given request payload and optional input reference, and returns the new
// task ID. Authorization/X-User-Info are forwarded via forwardHeaders.
//
// The endpoint called is POST {model-asynch.addr}/task/createTask.
// NOTE(review): this path must match the route model-asynch actually registers —
// confirm if that service's routing (e.g. GoFrame's default naming) changes.
func createModelAsynchTask(ctx context.Context, modelName string, payload any, inputRef string) (taskID string, err error) {
	headers := forwardHeaders(ctx)
	req := &modelAsynchCreateTaskReq{
		ModelName:      modelName,
		InputRef:       inputRef,
		RequestPayload: payload,
	}
	var res modelAsynchCreateTaskRes
	if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/createTask", modelAsynchAddr(ctx)), headers, &res, req); err != nil {
		return "", err
	}
	return res.TaskID, nil
}

type modelAsynchBatchReq struct {
	TaskIDs []string `json:"taskIds"`
}

type modelAsynchBatchItem struct {
	TaskID  string `json:"taskId"`
	State   int    `json:"state"`
	OssFile string `json:"ossFile"`
}

type modelAsynchBatchRes struct {
	List []modelAsynchBatchItem `json:"list"`
}

// getModelAsynchTaskBatch queries model-asynch for the status of taskIDs via
// POST {model-asynch.addr}/task/getTaskBatch and returns the reported items.
// Per the original note, the state transition on success (2 -> 4) is handled
// inside model-asynch, not here.
func getModelAsynchTaskBatch(ctx context.Context, taskIDs []string) (items []modelAsynchBatchItem, err error) {
	headers := forwardHeaders(ctx)
	req := &modelAsynchBatchReq{TaskIDs: taskIDs}
	var res modelAsynchBatchRes
	if err := commonHttp.Post(ctx, fmt.Sprintf("%s/task/getTaskBatch", modelAsynchAddr(ctx)), headers, &res, req); err != nil {
		return nil, err
	}
	return res.List, nil
}