feat: 新增操作日志、任务分页查询与模型失败重试优化
- 新增操作日志表(asynch_op_log)及对应 DAO,记录任务创建等操作的审计信息
- 新增任务分页查询接口(ListTask)及对应 DTO、Service 和 DAO 方法
- 优化模型调用失败重试逻辑:支持配置重试排队策略(插队到队首或排到队尾)
- 新增临时文件存储机制:当模型调用成功但 OSS 上传失败时,下次仅重试 OSS 上传
- 模型配置新增 retry_queue_max_seconds 字段,用于控制失败重试的排队策略
- 更新数据库表结构(asynch_models、asynch_task,新增 asynch_op_log)并同步更新 SQL
- 配置文件调整:超时单位改为秒,更新服务地址和轮询间隔
- 修复模型列表查询不支持按名称模糊搜索的问题
This commit is contained in:
@@ -45,6 +45,7 @@ func (c *cleaner) runOnce(ctx context.Context) {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
@@ -71,7 +72,20 @@ func (c *cleaner) runOnce(ctx context.Context) {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
||||
} else {
|
||||
for _, t := range retryable {
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id)
|
||||
// retry_queue_max_seconds 控制失败重试的排队策略:
|
||||
// - =0:失败重试插队到队首
|
||||
// - >0:当任务从创建到现在的排队时长 >= maxSeconds,则插队到队首;否则仍放到队尾
|
||||
now := time.Now()
|
||||
enqueueAt := now
|
||||
maxSeconds := t.RetryQueueMaxSeconds
|
||||
if maxSeconds == 0 {
|
||||
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
|
||||
} else if maxSeconds > 0 && t.CreatedAt != nil {
|
||||
if now.Sub(t.CreatedAt.Time) >= time.Duration(maxSeconds)*time.Second {
|
||||
enqueueAt = now.Add(-100 * 365 * 24 * time.Hour)
|
||||
}
|
||||
}
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id, enqueueAt)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
||||
}
|
||||
@@ -82,6 +96,7 @@ func (c *cleaner) runOnce(ctx context.Context) {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
deleteTmpResult(t.TmpFile)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
|
||||
@@ -14,24 +14,52 @@ import (
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
// parseAPIKeyHeaders parses the model's APIKey setting into a set of HTTP
// headers. Multiple bindings are separated by commas, for example:
//
//   - X-API-Key:qwen3-tts-key,operation:true,count:123
//   - X-API-Key:"qwen3-tts-key",operation:"true"
//
// Each binding is written as HeaderName:HeaderValue (recommended) or
// HeaderName=HeaderValue (compatible). The binding is split at whichever of
// the two separators appears first, so a value may itself contain ':' or '='
// (e.g. "Authorization=Bearer a:b" or "X-Key:abc==").
//
// HTTP header values are always strings. A value wrapped in double quotes has
// the quotes removed before it is returned, and commas inside a quoted value
// do not start a new binding.
//
// Bindings with an empty name or empty value, and bare values without a
// separator, are skipped so that no invalid header is injected. The result is
// nil when no usable binding is found.
func parseAPIKeyHeaders(apiKey string) map[string]string {
	apiKey = strings.TrimSpace(apiKey)
	if apiKey == "" {
		return nil
	}

	out := map[string]string{}
	for _, binding := range splitHeaderBindings(apiKey) {
		binding = strings.TrimSpace(binding)
		if binding == "" {
			continue
		}
		// Neither ':' nor '=' can be part of a header name, so the first
		// occurrence of either one is the name/value boundary.
		sep := strings.IndexAny(binding, ":=")
		if sep < 0 {
			// Only a value was given: do not inject anything.
			continue
		}
		name := strings.TrimSpace(binding[:sep])
		value := strings.Trim(strings.TrimSpace(binding[sep+1:]), "\"")
		if name != "" && value != "" {
			out[name] = value
		}
	}

	if len(out) == 0 {
		return nil
	}
	return out
}

// splitHeaderBindings splits s on commas that are not inside a double-quoted
// section. When the quotes in s are unbalanced it falls back to splitting on
// every comma.
func splitHeaderBindings(s string) []string {
	if strings.Count(s, "\"")%2 != 0 {
		return strings.Split(s, ",")
	}

	var (
		parts    []string
		start    int
		inQuotes bool
	)
	// Byte-wise scanning is safe: '"' and ',' are ASCII and never occur
	// inside a multi-byte UTF-8 sequence.
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case '"':
			inQuotes = !inQuotes
		case ',':
			if !inQuotes {
				parts = append(parts, s[start:i])
				start = i + 1
			}
		}
	}
	return append(parts, s[start:])
}
|
||||
|
||||
func payloadToQuery(payload any) (url.Values, error) {
|
||||
@@ -76,7 +104,7 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byt
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
}
|
||||
|
||||
timeout := time.Duration(m.TimeoutMs) * time.Millisecond
|
||||
timeout := time.Duration(m.TimeoutSeconds) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
@@ -122,7 +150,7 @@ func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byt
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
|
||||
for hk, hv := range parseAPIKeyHeaders(m.APIKey) {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
|
||||
@@ -15,18 +15,19 @@ type modelService struct{}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
m := &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
BaseURL: req.BaseURL,
|
||||
Route: req.Route,
|
||||
HttpMethod: req.HttpMethod,
|
||||
APIKey: req.APIKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutMs: req.TimeoutMs,
|
||||
RetryTimes: req.RetryTimes,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
ModelName: req.ModelName,
|
||||
BaseURL: req.BaseURL,
|
||||
Route: req.Route,
|
||||
HttpMethod: req.HttpMethod,
|
||||
APIKey: req.APIKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutSeconds: req.TimeoutSeconds,
|
||||
RetryTimes: req.RetryTimes,
|
||||
RetryQueueMaxSecs: req.RetryQueueMaxSeconds,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
if m.HttpMethod == "" {
|
||||
m.HttpMethod = "POST"
|
||||
@@ -40,8 +41,8 @@ func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res
|
||||
if m.QueueLimit <= 0 {
|
||||
m.QueueLimit = 1000
|
||||
}
|
||||
if m.TimeoutMs <= 0 {
|
||||
m.TimeoutMs = 60000
|
||||
if m.TimeoutSeconds <= 0 {
|
||||
m.TimeoutSeconds = 60
|
||||
}
|
||||
if m.AutoCleanSeconds <= 0 {
|
||||
m.AutoCleanSeconds = 86400
|
||||
@@ -76,12 +77,15 @@ func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) erro
|
||||
if req.QueueLimit != nil {
|
||||
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
|
||||
}
|
||||
if req.TimeoutMs != nil {
|
||||
data[entity.AsynchModelCol.TimeoutMs] = *req.TimeoutMs
|
||||
if req.TimeoutSeconds != nil {
|
||||
data[entity.AsynchModelCol.TimeoutSeconds] = *req.TimeoutSeconds
|
||||
}
|
||||
if req.RetryTimes != nil {
|
||||
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
|
||||
}
|
||||
if req.RetryQueueMaxSeconds != nil {
|
||||
data[entity.AsynchModelCol.RetryQueueMaxSecs] = *req.RetryQueueMaxSeconds
|
||||
}
|
||||
if req.AutoCleanSeconds != nil {
|
||||
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
|
||||
}
|
||||
@@ -104,6 +108,6 @@ func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel,
|
||||
return dao.Model.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
return dao.Model.List(ctx, pageNum, pageSize)
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int, modelNameLike string) (list []*entity.AsynchModel, total int64, err error) {
|
||||
return dao.Model.List(ctx, pageNum, pageSize, modelNameLike)
|
||||
}
|
||||
|
||||
@@ -5,12 +5,14 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/gogf/gf/v2/util/guid"
|
||||
)
|
||||
|
||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
||||
@@ -25,8 +27,6 @@ type uploadFileResponse struct {
|
||||
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// ossUrl := g.Cfg().MustGet(ctx, "oss.addr", "192.168.3.30:9000").String()
|
||||
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
@@ -39,7 +39,8 @@ func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, dat
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
part, err := writer.CreateFormFile("file", "")
|
||||
filename := fmt.Sprintf("asynch_%d_%s%s", time.Now().Unix(), guid.S(), ext)
|
||||
part, err := writer.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -55,16 +56,14 @@ func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, dat
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s size=%d", fullURL, len(data))
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s filename=%s size=%d", fullURL, filename, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.FileURL == "" {
|
||||
return "", fmt.Errorf("OSS服务返回错误: 上传失败")
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s filename=%s size=%d format=%s", resp.FileURL, resp.FileName, resp.FileSize, resp.FileFormat)
|
||||
fmt.Println("打印结果 resp:", resp)
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s size=%d format=%s", resp.FileURL, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -18,6 +20,7 @@ var Task = &taskService{}
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
startAt := time.Now()
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
@@ -59,6 +62,35 @@ func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3) 写操作日志(尽量不影响主流程,失败忽略)
|
||||
ip := ""
|
||||
ua := ""
|
||||
apiPath := "/task/createTask"
|
||||
httpMethod := "POST"
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
ip = r.GetClientIp()
|
||||
ua = r.UserAgent()
|
||||
apiPath = r.URL.Path
|
||||
httpMethod = r.Method
|
||||
}
|
||||
_, _ = dao.OpLog.Insert(ctx, &entity.AsynchOpLog{
|
||||
IP: ip,
|
||||
UserAgent: ua,
|
||||
APIPath: apiPath,
|
||||
HttpMethod: httpMethod,
|
||||
BizName: req.BizName,
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
OpType: "createTask",
|
||||
Success: 1,
|
||||
ErrorMsg: "",
|
||||
CostMs: time.Since(startAt).Milliseconds(),
|
||||
RequestPayload: storedPayload,
|
||||
ResponsePayload: gdb.Map{
|
||||
"taskId": taskID,
|
||||
},
|
||||
})
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
@@ -127,3 +159,28 @@ func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (r
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) List(ctx context.Context, req *dto.ListTaskReq) (res *dto.ListTaskRes, err error) {
|
||||
pageNum, pageSize := 1, 10
|
||||
if req != nil && req.Page != nil {
|
||||
if req.Page.PageNum > 0 {
|
||||
pageNum = int(req.Page.PageNum)
|
||||
}
|
||||
if req.Page.PageSize > 0 {
|
||||
pageSize = int(req.Page.PageSize)
|
||||
}
|
||||
}
|
||||
modelName := ""
|
||||
taskID := ""
|
||||
var state *int
|
||||
if req != nil {
|
||||
modelName = req.ModelName
|
||||
taskID = req.TaskID
|
||||
state = req.State
|
||||
}
|
||||
list, total, err := dao.Task.List(ctx, pageNum, pageSize, modelName, taskID, state)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.ListTaskRes{List: list, Total: total}, nil
|
||||
}
|
||||
|
||||
38
service/tmp_store.go
Normal file
38
service/tmp_store.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
// saveTmpResult writes a model's output to a temporary file so that, if the
// OSS upload fails afterwards, the next attempt can retry only the upload.
//
// The file is stored as <os.TempDir()>/model-asynch/<taskID><ext>. An empty
// ext defaults to ".bin" and a missing leading dot is added. The returned
// string is the final file path.
//
// An error is returned when the directory cannot be created, when taskID/ext
// would produce a name containing a path separator (which would place the file
// outside the directory), or when writing fails. On error no file is left at
// the final path by this call.
func saveTmpResult(taskID string, data []byte, ext string) (string, error) {
	dir := filepath.Join(os.TempDir(), "model-asynch")
	if err := os.MkdirAll(dir, 0o755); err != nil {
		return "", err
	}
	if ext == "" {
		ext = ".bin"
	}
	if ext[0] != '.' {
		ext = "." + ext
	}

	// Refuse names that are not a single path element; otherwise the joined
	// path could point outside dir.
	name := taskID + ext
	if name != filepath.Base(name) {
		return "", fmt.Errorf("invalid temp file name %q", name)
	}
	path := filepath.Join(dir, name)

	// Write to a staging file in the same directory and rename it into place,
	// so a reader never sees a partially written result at path.
	staging, err := os.CreateTemp(dir, name+".*.tmp")
	if err != nil {
		return "", err
	}
	stagingPath := staging.Name()
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(stagingPath) // best-effort cleanup of the staging file
		}
	}()

	if _, err := staging.Write(data); err != nil {
		_ = staging.Close()
		return "", err
	}
	if err := staging.Close(); err != nil {
		return "", err
	}
	// CreateTemp creates the file with mode 0600; keep the 0644 mode the
	// result file has always had.
	if err := os.Chmod(stagingPath, 0o644); err != nil {
		return "", err
	}
	if err := os.Rename(stagingPath, path); err != nil {
		return "", err
	}
	committed = true
	return path, nil
}
|
||||
|
||||
// loadTmpResult reads the whole file at path and returns its bytes together
// with whatever error os.ReadFile reports.
func loadTmpResult(path string) ([]byte, error) {
	content, err := os.ReadFile(path)
	return content, err
}
|
||||
|
||||
// deleteTmpResult removes the file at path. An empty path is a no-op, and any
// error from os.Remove is deliberately ignored (best-effort cleanup).
func deleteTmpResult(path string) {
	if path != "" {
		_ = os.Remove(path)
	}
}
|
||||
|
||||
@@ -146,17 +146,43 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
"inputRef": t.InputRef,
|
||||
}
|
||||
}
|
||||
data, err := InvokeModel(ctx, m, payload)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
var (
|
||||
data []byte
|
||||
contentType string
|
||||
ext string
|
||||
)
|
||||
|
||||
// phase=1 表示模型已成功但 OSS 上传失败:优先从临时文件加载,避免重复跑模型
|
||||
if t.Phase == 1 && strings.TrimSpace(t.TmpFile) != "" {
|
||||
data, err = loadTmpResult(t.TmpFile)
|
||||
if err == nil && len(data) > 0 {
|
||||
contentType, ext = DetectFileType(data)
|
||||
} else {
|
||||
// 临时文件不可用:回退重新调用模型
|
||||
data = nil
|
||||
}
|
||||
}
|
||||
if data == nil {
|
||||
data, err = InvokeModel(ctx, m, payload)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
contentType, ext = DetectFileType(data)
|
||||
// 将模型输出写入临时文件,后续若 OSS 失败可只重试 OSS
|
||||
tmpPath, err := saveTmpResult(t.TaskID, data, ext)
|
||||
if err == nil && tmpPath != "" {
|
||||
t.TmpFile = tmpPath
|
||||
t.Phase = 1
|
||||
_ = dao.Task.UpdateTmpAfterModelGlobal(ctx, t.Id, tmpPath)
|
||||
}
|
||||
}
|
||||
|
||||
// 4) 存储 OSS/MinIO
|
||||
contentType, ext := DetectFileType(data)
|
||||
// 4) 存储 OSS
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
// OSS 阶段失败:保留临时文件,下一轮仅重试 OSS
|
||||
_ = dao.Task.UpdateFailedKeepTmpGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -170,6 +196,8 @@ func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
}
|
||||
// 成功后清理临时文件
|
||||
deleteTmpResult(t.TmpFile)
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
|
||||
Reference in New Issue
Block a user