first commit

This commit is contained in:
2026-04-23 13:53:09 +08:00
commit 9de47fa5b8
34 changed files with 2764 additions and 0 deletions

91
service/cleaner.go Normal file
View File

@@ -0,0 +1,91 @@
package service
import (
"context"
"time"
"model-asynch/dao"
"github.com/gogf/gf/v2/frame/g"
)
var Cleaner = &cleaner{}
type cleaner struct{}
func (c *cleaner) Start(ctx context.Context) {
if !g.Cfg().MustGet(ctx, "asynch.cleaner.enabled", true).Bool() {
g.Log().Warningf(ctx, "[cleaner] asynch.cleaner.enabled=falsecleaner 未启动")
return
}
intervalStr := g.Cfg().MustGet(ctx, "asynch.cleaner.interval", "10m").String()
interval, _ := time.ParseDuration(intervalStr)
if interval <= 0 {
interval = 10 * time.Minute
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
g.Log().Infof(ctx, "[cleaner] started, interval=%s", interval)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
c.runOnce(ctx)
}
}
}()
}
func (c *cleaner) runOnce(ctx context.Context) {
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
} else {
for _, t := range expired {
_ = Storage.DeleteByTask(ctx, t)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
}
// 2) 超时任务标失败
timeoutStr := g.Cfg().MustGet(ctx, "asynch.worker.taskTimeout", "30m").String()
timeout, _ := time.ParseDuration(timeoutStr)
if timeout > 0 {
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, timeout, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
} else {
for _, t := range list {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
}
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
}
}
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
} else {
for _, t := range retryable {
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
}
// 4) 超过重试次数仍失败(state=3)的任务:硬删除 + OSS
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
if err != nil {
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
} else {
for _, t := range exhausted {
_ = Storage.DeleteByTask(ctx, t)
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
}
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
}
}

35
service/file_detect.go Normal file
View File

@@ -0,0 +1,35 @@
package service
import (
"net/http"
"strings"
)
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
func DetectFileType(data []byte) (contentType string, ext string) {
if len(data) == 0 {
return "application/octet-stream", ""
}
ct := http.DetectContentType(data)
switch ct {
case "audio/mpeg":
return ct, ".mp3"
case "audio/wave", "audio/wav", "audio/x-wav":
return ct, ".wav"
case "video/mp4":
return ct, ".mp4"
case "image/png":
return ct, ".png"
case "image/jpeg":
return ct, ".jpg"
case "application/pdf":
return ct, ".pdf"
default:
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json
if parts := strings.Split(ct, "/"); len(parts) == 2 {
return ct, "." + parts[1]
}
return ct, ""
}
}

54
service/headers.go Normal file
View File

@@ -0,0 +1,54 @@
package service
import (
"context"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
)
// asyncCtx 固化异步执行所需的 token/user避免请求结束后丢失仅在“同请求内起 goroutine”有用
// 本项目当前是“落库 + 后台 worker”模式因此还会把必要信息持久化到任务表的 request_payload 中。
func asyncCtx(ctx context.Context) context.Context {
asyncCtx := context.WithoutCancel(ctx)
if r := g.RequestFromCtx(ctx); r != nil {
if token := r.Header.Get("Authorization"); token != "" {
asyncCtx = context.WithValue(asyncCtx, "token", token)
}
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
}
}
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
asyncCtx = context.WithValue(asyncCtx, "user", user)
}
return asyncCtx
}
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo
func forwardHeaders(ctx context.Context) map[string]string {
headers := make(map[string]string)
if token, ok := ctx.Value("token").(string); ok && token != "" {
headers["Authorization"] = token
}
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
headers["X-User-Info"] = x
}
// 兜底:从请求头拿
if r := g.RequestFromCtx(ctx); r != nil {
if headers["Authorization"] == "" {
if token := r.Header.Get("Authorization"); token != "" {
headers["Authorization"] = token
}
}
if headers["X-User-Info"] == "" {
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
headers["X-User-Info"] = userInfo
}
}
}
return headers
}

151
service/model_invoker.go Normal file
View File

@@ -0,0 +1,151 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"model-asynch/model/entity"
)
func parseAPIKeyHeader(apiKey string) (k, v string) {
apiKey = strings.TrimSpace(apiKey)
if apiKey == "" {
return "", ""
}
// 支持两种写法:
// 1) HeaderName:HeaderValue推荐
// 2) HeaderName=HeaderValue兼容
if strings.Contains(apiKey, ":") {
parts := strings.SplitN(apiKey, ":", 2)
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
}
if strings.Contains(apiKey, "=") {
parts := strings.SplitN(apiKey, "=", 2)
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
}
// 只给了 value不做注入避免注入非法 header
return "", ""
}
func payloadToQuery(payload any) (url.Values, error) {
if payload == nil {
return url.Values{}, nil
}
// 统一转成 map[string]any
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
m := map[string]any{}
if err := json.Unmarshal(b, &m); err != nil {
return nil, err
}
q := url.Values{}
for k, v := range m {
if v == nil {
continue
}
// 复杂类型直接 json 字符串化
switch vv := v.(type) {
case string:
q.Set(k, vv)
case float64, bool, int, int64, uint64:
q.Set(k, fmt.Sprintf("%v", vv))
default:
bs, _ := json.Marshal(v)
q.Set(k, string(bs))
}
}
return q, nil
}
// InvokeModel 调用模型服务,返回二进制结果
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byte, error) {
if m == nil || m.BaseURL == "" {
return nil, fmt.Errorf("模型配置不完整")
}
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
if strings.TrimSpace(m.Route) == "" {
url = strings.TrimRight(m.BaseURL, "/")
}
timeout := time.Duration(m.TimeoutMs) * time.Millisecond
if timeout <= 0 {
timeout = 60 * time.Second
}
client := &http.Client{Timeout: timeout}
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
if method == "" {
method = http.MethodPost
}
var (
req *http.Request
err error
)
switch method {
case http.MethodGet:
q, err := payloadToQuery(payload)
if err != nil {
return nil, err
}
if len(q) > 0 {
if strings.Contains(url, "?") {
url = url + "&" + q.Encode()
} else {
url = url + "?" + q.Encode()
}
}
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
default:
bodyBytes, err := json.Marshal(payload)
if err != nil {
return nil, err
}
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
}
if err != nil {
return nil, err
}
// 透传必要头部(如 Authorization / X-User-Info以及注入模型配置里的 api_key
for k, v := range forwardHeaders(ctx) {
if v != "" {
req.Header.Set(k, v)
}
}
if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
req.Header.Set(hk, hv)
}
if method != http.MethodGet {
req.Header.Set("Content-Type", "application/json")
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// 尽量把错误体带回去,方便排查
msg := string(b)
if len(msg) > 2000 {
msg = msg[:2000]
}
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
}
return b, nil
}

109
service/model_service.go Normal file
View File

@@ -0,0 +1,109 @@
package service
import (
"context"
"errors"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
)
var Model = &modelService{}
type modelService struct{}
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
m := &entity.AsynchModel{
ModelName: req.ModelName,
BaseURL: req.BaseURL,
Route: req.Route,
HttpMethod: req.HttpMethod,
APIKey: req.APIKey,
Enabled: req.Enabled,
MaxConcurrency: req.MaxConcurrency,
QueueLimit: req.QueueLimit,
TimeoutMs: req.TimeoutMs,
RetryTimes: req.RetryTimes,
AutoCleanSeconds: req.AutoCleanSeconds,
Remark: req.Remark,
}
if m.HttpMethod == "" {
m.HttpMethod = "POST"
}
if m.Enabled == 0 {
m.Enabled = 1
}
if m.MaxConcurrency <= 0 {
m.MaxConcurrency = 10
}
if m.QueueLimit <= 0 {
m.QueueLimit = 1000
}
if m.TimeoutMs <= 0 {
m.TimeoutMs = 60000
}
if m.AutoCleanSeconds <= 0 {
m.AutoCleanSeconds = 86400
}
id, err := dao.Model.Insert(ctx, m)
if err != nil {
return nil, err
}
return &dto.CreateModelRes{ID: id}, nil
}
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
data := map[string]any{}
if req.BaseURL != "" {
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
}
if req.Route != "" {
data[entity.AsynchModelCol.Route] = req.Route
}
if req.HttpMethod != nil && *req.HttpMethod != "" {
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
}
if req.APIKey != nil {
data[entity.AsynchModelCol.APIKey] = *req.APIKey
}
if req.Enabled != nil {
data[entity.AsynchModelCol.Enabled] = *req.Enabled
}
if req.MaxConcurrency != nil {
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
}
if req.QueueLimit != nil {
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
}
if req.TimeoutMs != nil {
data[entity.AsynchModelCol.TimeoutMs] = *req.TimeoutMs
}
if req.RetryTimes != nil {
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
}
if req.AutoCleanSeconds != nil {
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
}
if req.Remark != nil {
data[entity.AsynchModelCol.Remark] = *req.Remark
}
if len(data) == 0 {
return errors.New("无可更新字段")
}
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
return err
}
func (s *modelService) Delete(ctx context.Context, id int64) error {
_, err := dao.Model.DeleteByID(ctx, id)
return err
}
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
return dao.Model.GetByID(ctx, id)
}
func (s *modelService) List(ctx context.Context, pageNum, pageSize int) (list []*entity.AsynchModel, total int64, err error) {
return dao.Model.List(ctx, pageNum, pageSize)
}

25
service/payload.go Normal file
View File

@@ -0,0 +1,25 @@
package service
import "github.com/gogf/gf/v2/util/gconv"
// parseStoredPayload 解析入库的 request_payload拆出模型调用 payload 与透传 headers
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
func parseStoredPayload(v any) (payload any, headers map[string]string) {
if v == nil {
return nil, nil
}
m := gconv.Map(v)
if len(m) == 0 {
return v, nil
}
if h, ok := m["headers"]; ok {
headers = gconv.MapStrStr(h)
}
if p, ok := m["payload"]; ok {
payload = p
} else {
payload = v
}
return
}

56
service/semaphore.go Normal file
View File

@@ -0,0 +1,56 @@
package service
import (
"context"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
var acquireLua = `
local current = tonumber(redis.call("GET", KEYS[1]) or "0")
local max = tonumber(ARGV[1])
local ttl = tonumber(ARGV[2])
if current >= max then
return 0
end
current = redis.call("INCR", KEYS[1])
if current == 1 then
redis.call("EXPIRE", KEYS[1], ttl)
end
if current > max then
redis.call("DECR", KEYS[1])
return 0
end
return 1
`
var releaseLua = `
local current = tonumber(redis.call("DECR", KEYS[1]) or "0")
if current <= 0 then
redis.call("DEL", KEYS[1])
end
return 1
`
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
if max <= 0 {
// 不限制
return true, nil
}
if ttlSeconds <= 0 {
ttlSeconds = 3600
}
r, err := g.Redis().Do(ctx, "EVAL", acquireLua, 1, key, max, ttlSeconds)
if err != nil {
return false, fmt.Errorf("获取并发令牌失败: %w", err)
}
return gconv.Int(r) == 1, nil
}
func releaseSemaphore(ctx context.Context, key string) error {
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
return err
}

19
service/storage.go Normal file
View File

@@ -0,0 +1,19 @@
package service
import (
"context"
"errors"
"model-asynch/model/entity"
)
// StorageService 结果存储OSS/MinIO抽象
type StorageService interface {
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
DeleteByTask(ctx context.Context, t *entity.AsynchTask) error
}
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO
var Storage StorageService = &ossStorage{}
var ErrStorageNotConfigured = errors.New("存储未配置")

90
service/storage_oss.go Normal file
View File

@@ -0,0 +1,90 @@
package service
import (
"bytes"
"context"
"fmt"
"mime/multipart"
"model-asynch/model/entity"
commonHttp "gitea.com/red-future/common/http"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// 对接你们的 oss 文件服务POST oss/file/uploadFile (multipart/form-data)
type ossStorage struct{}
type uploadFileResponse struct {
FileURL string `json:"fileURL"` // 文件 URL
FileSize int `json:"fileSize"` // 文件大小(字节)
FileName string `json:"fileName"` // 文件名
FileFormat string `json:"fileFormat"` // 文件格式
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
}
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
// ossUrl := g.Cfg().MustGet(ctx, "oss.addr", "192.168.3.30:9000").String()
// multipart
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
ext := fileExt
if ext == "" {
ext = ".bin"
}
if ext[0] != '.' {
ext = "." + ext
}
part, err := writer.CreateFormFile("file", "")
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
contentType := writer.FormDataContentType()
if err := writer.Close(); err != nil {
return "", err
}
headers := forwardHeaders(ctx)
headers["Content-Type"] = contentType
fullURL := "oss/file/uploadFile"
g.Log().Infof(ctx, "[OSS] upload start url=%s size=%d", fullURL, len(data))
var resp uploadFileResponse
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
return "", err
}
if resp.FileURL == "" {
return "", fmt.Errorf("OSS服务返回错误: 上传失败")
}
g.Log().Infof(ctx, "[OSS] upload success url=%s filename=%s size=%d format=%s", resp.FileURL, resp.FileName, resp.FileSize, resp.FileFormat)
return resp.FileURL, nil
}
func (s *ossStorage) DeleteByTask(ctx context.Context, t *entity.AsynchTask) error {
// 你说当前 oss 暂时没有删除接口:这里保留方法占位,后续补接口时直接实现
_ = ctx
_ = t
return nil
}
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx给 worker 调 OSS 用
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
if headers == nil {
return ctx
}
if v := gconv.String(headers["Authorization"]); v != "" {
ctx = context.WithValue(ctx, "token", v)
}
if v := gconv.String(headers["X-User-Info"]); v != "" {
ctx = context.WithValue(ctx, "xUserInfo", v)
}
return ctx
}

129
service/task_service.go Normal file
View File

@@ -0,0 +1,129 @@
package service
import (
"context"
"errors"
"time"
"model-asynch/dao"
"model-asynch/model/dto"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/os/gtime"
"github.com/google/uuid"
)
var Task = &taskService{}
type taskService struct{}
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
// 固化 token/user 等信息
ctx = asyncCtx(ctx)
// 1) 检查模型配置
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
if err != nil {
return nil, err
}
if m == nil || m.Enabled != 1 {
return nil, errors.New("模型不存在或未启用")
}
// 2) 排队上限(近似控制)
if m.QueueLimit > 0 {
cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
if err != nil {
return nil, err
}
if cnt >= int64(m.QueueLimit) {
return nil, errors.New("任务排队已满,请稍后再试")
}
}
taskID := uuid.NewString()
// 将调用模型的 payload 与透传头信息一起存入 request_payload供后台 worker 使用
storedPayload := map[string]any{
"payload": req.RequestPayload,
"headers": forwardHeaders(ctx),
}
t := &entity.AsynchTask{
ModelName: req.ModelName,
TaskID: taskID,
State: 0,
InputRef: req.InputRef,
RequestPayload: storedPayload,
}
_, err = dao.Task.Insert(ctx, t)
if err != nil {
return nil, err
}
return &dto.CreateTaskRes{TaskID: taskID}, nil
}
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
t, err := dao.Task.GetByTaskID(ctx, taskID)
if err != nil {
return nil, err
}
if t == nil {
return nil, errors.New("任务不存在")
}
return &dto.GetTaskResultRes{
OssFile: t.OssFile,
State: t.State,
}, nil
}
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
if req == nil || len(req.TaskIDs) == 0 {
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
}
// 1) 先查当前租户下的任务列表
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
if err != nil {
return nil, err
}
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
now := time.Now()
for _, t := range list {
if t == nil {
continue
}
if t.State != 2 {
continue
}
// 按模型配置决定保留时间
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
if err != nil {
return nil, err
}
retainSeconds := 86400
if m != nil && m.AutoCleanSeconds > 0 {
retainSeconds = m.AutoCleanSeconds
}
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
// 为了本次返回一致性,内存里也更新
t.State = 4
t.ExpireAt = expireAt
}
// 3) 组装返回
items := make([]dto.GetTaskBatchItem, 0, len(list))
for _, t := range list {
if t == nil {
continue
}
items = append(items, dto.GetTaskBatchItem{
TaskID: t.TaskID,
State: t.State,
OssFile: t.OssFile,
})
}
return &dto.GetTaskBatchRes{List: items}, nil
}

177
service/worker.go Normal file
View File

@@ -0,0 +1,177 @@
package service
import (
"context"
"fmt"
"strings"
"sync"
"time"
"model-asynch/dao"
"model-asynch/model/entity"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/os/grpool"
)
var AsyncWorker = &asyncWorker{}
type asyncWorker struct {
mu sync.Mutex
pool *grpool.Pool
closed bool
}
func (w *asyncWorker) Start(ctx context.Context) {
if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=falseworker 未启动")
return
}
w.mu.Lock()
defer w.mu.Unlock()
if w.pool != nil && !w.pool.IsClosed() {
return
}
limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
if limit <= 0 {
limit = 1
}
w.pool = grpool.New(limit)
w.closed = false
go w.pollLoop(ctx)
g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
}
// Stop 关闭协程池,确保 Ctrl+C 能完整退出。
func (w *asyncWorker) Stop(ctx context.Context) {
w.mu.Lock()
defer w.mu.Unlock()
if w.pool != nil && !w.pool.IsClosed() {
w.pool.Close()
w.closed = true
g.Log().Infof(ctx, "[worker] stopped")
}
}
func (w *asyncWorker) pollLoop(ctx context.Context) {
pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
pollInterval, _ := time.ParseDuration(pollIntervalStr)
if pollInterval <= 0 {
pollInterval = time.Second
}
batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
if batchSize <= 0 {
batchSize = 1
}
g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
w.Stop(ctx)
return
case <-ticker.C:
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
if err != nil {
g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
continue
}
if len(tasks) == 0 {
continue
}
for _, t := range tasks {
task := t // 防止闭包捕获循环变量
w.mu.Lock()
p := w.pool
w.mu.Unlock()
if p == nil || p.IsClosed() {
// 池已关闭,回滚任务
_ = w.rollbackToPending(ctx, task.Id)
continue
}
_ = p.AddWithRecover(ctx, func(ctx context.Context) {
w.handleOne(ctx, task)
}, func(ctx context.Context, err error) {
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err))
}
})
}
}
}
}
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
// 从任务入库的 request_payload 里恢复 payload + headers给 OSS 上传透传鉴权用
payload, headers := parseStoredPayload(t.RequestPayload)
if len(headers) > 0 {
ctx = setTaskHeadersToCtx(ctx, headers)
}
// 1) 拉取模型配置
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
if m == nil || m.Enabled != 1 {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
return
}
// 2) 分布式并发限制(按 tenant+model
semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
leaseSeconds := int64(3600) // 兜底1小时
acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
if !acquired {
// 并发满了:放回排队(重新置回 state=0下一轮再抢占
_ = w.rollbackToPending(ctx, t.Id)
return
}
defer func() {
_ = releaseSemaphore(ctx, semKey)
}()
// 3) 调用模型服务
if payload == nil {
payload = map[string]any{
"taskId": t.TaskID,
"inputRef": t.InputRef,
}
}
data, err := InvokeModel(ctx, m, payload)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
// 4) 存储 OSS/MinIO
contentType, ext := DetectFileType(data)
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
if err != nil {
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
return
}
// 5) 更新任务状态成功
// 注意expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
fileType := strings.TrimPrefix(ext, ".")
if fileType == "" {
fileType = contentType
}
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
return
}
}
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
return dao.Task.RollbackToPendingGlobal(ctx, id)
}