first commit
This commit is contained in:
91
service/cleaner.go
Normal file
91
service/cleaner.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
var Cleaner = &cleaner{}
|
||||
|
||||
type cleaner struct{}
|
||||
|
||||
func (c *cleaner) Start(ctx context.Context) {
|
||||
if !g.Cfg().MustGet(ctx, "asynch.cleaner.enabled", true).Bool() {
|
||||
g.Log().Warningf(ctx, "[cleaner] asynch.cleaner.enabled=false,cleaner 未启动")
|
||||
return
|
||||
}
|
||||
intervalStr := g.Cfg().MustGet(ctx, "asynch.cleaner.interval", "10m").String()
|
||||
interval, _ := time.ParseDuration(intervalStr)
|
||||
if interval <= 0 {
|
||||
interval = 10 * time.Minute
|
||||
}
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
g.Log().Infof(ctx, "[cleaner] started, interval=%s", interval)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.runOnce(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (c *cleaner) runOnce(ctx context.Context) {
|
||||
// 1) 清理已下载(state=4)且过期的任务(硬删除 + OSS)
|
||||
expired, err := dao.Task.ListExpiredDownloadedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list expired(downloaded) error: %v", err)
|
||||
} else {
|
||||
for _, t := range expired {
|
||||
_ = Storage.DeleteByTask(ctx, t)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] expired(downloaded) cleaned, count=%d", len(expired))
|
||||
}
|
||||
|
||||
// 2) 超时任务标失败
|
||||
timeoutStr := g.Cfg().MustGet(ctx, "asynch.worker.taskTimeout", "30m").String()
|
||||
timeout, _ := time.ParseDuration(timeoutStr)
|
||||
if timeout > 0 {
|
||||
list, err := dao.Task.ListTimeoutTasksGlobal(ctx, timeout, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list timeout error: %v", err)
|
||||
} else {
|
||||
for _, t := range list {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "任务超时自动失败")
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] timeout cleaned, count=%d", len(list))
|
||||
}
|
||||
}
|
||||
|
||||
// 3) 失败(state=3)的任务按模型配置 retry_times 重新入队(放到队尾)
|
||||
retryable, err := dao.Task.ListFailedRetryableGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed retryable error: %v", err)
|
||||
} else {
|
||||
for _, t := range retryable {
|
||||
_ = dao.Task.RequeueForRetryGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed retryable cleaned, count=%d", len(retryable))
|
||||
}
|
||||
|
||||
// 4) 超过重试次数仍失败(state=3)的任务:硬删除 + OSS
|
||||
exhausted, err := dao.Task.ListFailedExhaustedGlobal(ctx, 200)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[cleaner] list failed exhausted error: %v", err)
|
||||
} else {
|
||||
for _, t := range exhausted {
|
||||
_ = Storage.DeleteByTask(ctx, t)
|
||||
_ = dao.Task.HardDeleteByIDGlobal(ctx, t.Id)
|
||||
}
|
||||
g.Log().Infof(ctx, "[cleaner] failed exhausted cleaned, count=%d", len(exhausted))
|
||||
}
|
||||
}
|
||||
35
service/file_detect.go
Normal file
35
service/file_detect.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DetectFileType 根据返回的二进制内容推断 contentType + 扩展名(尽量稳定)
|
||||
func DetectFileType(data []byte) (contentType string, ext string) {
|
||||
if len(data) == 0 {
|
||||
return "application/octet-stream", ""
|
||||
}
|
||||
ct := http.DetectContentType(data)
|
||||
switch ct {
|
||||
case "audio/mpeg":
|
||||
return ct, ".mp3"
|
||||
case "audio/wave", "audio/wav", "audio/x-wav":
|
||||
return ct, ".wav"
|
||||
case "video/mp4":
|
||||
return ct, ".mp4"
|
||||
case "image/png":
|
||||
return ct, ".png"
|
||||
case "image/jpeg":
|
||||
return ct, ".jpg"
|
||||
case "application/pdf":
|
||||
return ct, ".pdf"
|
||||
default:
|
||||
// 兜底:尝试从 ct 截取 subtype 作为后缀(例如 application/json)
|
||||
if parts := strings.Split(ct, "/"); len(parts) == 2 {
|
||||
return ct, "." + parts[1]
|
||||
}
|
||||
return ct, ""
|
||||
}
|
||||
}
|
||||
|
||||
54
service/headers.go
Normal file
54
service/headers.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// asyncCtx 固化异步执行所需的 token/user,避免请求结束后丢失(仅在“同请求内起 goroutine”有用)。
|
||||
// 本项目当前是“落库 + 后台 worker”模式,因此还会把必要信息持久化到任务表的 request_payload 中。
|
||||
func asyncCtx(ctx context.Context) context.Context {
|
||||
asyncCtx := context.WithoutCancel(ctx)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "token", token)
|
||||
}
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
asyncCtx = context.WithValue(asyncCtx, "xUserInfo", userInfo)
|
||||
}
|
||||
}
|
||||
if user, err := utils.GetUserInfo(ctx); err == nil && user != nil {
|
||||
asyncCtx = context.WithValue(asyncCtx, "user", user)
|
||||
}
|
||||
return asyncCtx
|
||||
}
|
||||
|
||||
// forwardHeaders 透传调用链路中必须的头信息(优先使用 ctx 里固化的 token / xUserInfo)。
|
||||
func forwardHeaders(ctx context.Context) map[string]string {
|
||||
headers := make(map[string]string)
|
||||
|
||||
if token, ok := ctx.Value("token").(string); ok && token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
if x, ok := ctx.Value("xUserInfo").(string); ok && x != "" {
|
||||
headers["X-User-Info"] = x
|
||||
}
|
||||
|
||||
// 兜底:从请求头拿
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
if headers["Authorization"] == "" {
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
headers["Authorization"] = token
|
||||
}
|
||||
}
|
||||
if headers["X-User-Info"] == "" {
|
||||
if userInfo := r.Header.Get("X-User-Info"); userInfo != "" {
|
||||
headers["X-User-Info"] = userInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
151
service/model_invoker.go
Normal file
151
service/model_invoker.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
func parseAPIKeyHeader(apiKey string) (k, v string) {
|
||||
apiKey = strings.TrimSpace(apiKey)
|
||||
if apiKey == "" {
|
||||
return "", ""
|
||||
}
|
||||
// 支持两种写法:
|
||||
// 1) HeaderName:HeaderValue(推荐)
|
||||
// 2) HeaderName=HeaderValue(兼容)
|
||||
if strings.Contains(apiKey, ":") {
|
||||
parts := strings.SplitN(apiKey, ":", 2)
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
if strings.Contains(apiKey, "=") {
|
||||
parts := strings.SplitN(apiKey, "=", 2)
|
||||
return strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1])
|
||||
}
|
||||
// 只给了 value:不做注入(避免注入非法 header)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
func payloadToQuery(payload any) (url.Values, error) {
|
||||
if payload == nil {
|
||||
return url.Values{}, nil
|
||||
}
|
||||
// 统一转成 map[string]any
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := map[string]any{}
|
||||
if err := json.Unmarshal(b, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := url.Values{}
|
||||
for k, v := range m {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
// 复杂类型直接 json 字符串化
|
||||
switch vv := v.(type) {
|
||||
case string:
|
||||
q.Set(k, vv)
|
||||
case float64, bool, int, int64, uint64:
|
||||
q.Set(k, fmt.Sprintf("%v", vv))
|
||||
default:
|
||||
bs, _ := json.Marshal(v)
|
||||
q.Set(k, string(bs))
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// InvokeModel 调用模型服务,返回二进制结果
|
||||
func InvokeModel(ctx context.Context, m *entity.AsynchModel, payload any) ([]byte, error) {
|
||||
if m == nil || m.BaseURL == "" {
|
||||
return nil, fmt.Errorf("模型配置不完整")
|
||||
}
|
||||
url := strings.TrimRight(m.BaseURL, "/") + "/" + strings.TrimLeft(m.Route, "/")
|
||||
if strings.TrimSpace(m.Route) == "" {
|
||||
url = strings.TrimRight(m.BaseURL, "/")
|
||||
}
|
||||
|
||||
timeout := time.Duration(m.TimeoutMs) * time.Millisecond
|
||||
if timeout <= 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
method := strings.ToUpper(strings.TrimSpace(m.HttpMethod))
|
||||
if method == "" {
|
||||
method = http.MethodPost
|
||||
}
|
||||
|
||||
var (
|
||||
req *http.Request
|
||||
err error
|
||||
)
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
q, err := payloadToQuery(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(q) > 0 {
|
||||
if strings.Contains(url, "?") {
|
||||
url = url + "&" + q.Encode()
|
||||
} else {
|
||||
url = url + "?" + q.Encode()
|
||||
}
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
default:
|
||||
bodyBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyBytes))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 透传必要头部(如 Authorization / X-User-Info),以及注入模型配置里的 api_key
|
||||
for k, v := range forwardHeaders(ctx) {
|
||||
if v != "" {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
}
|
||||
if hk, hv := parseAPIKeyHeader(m.APIKey); hk != "" && hv != "" {
|
||||
req.Header.Set(hk, hv)
|
||||
}
|
||||
if method != http.MethodGet {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
b, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
// 尽量把错误体带回去,方便排查
|
||||
msg := string(b)
|
||||
if len(msg) > 2000 {
|
||||
msg = msg[:2000]
|
||||
}
|
||||
return nil, fmt.Errorf("模型服务返回非2xx: %d, body=%s", resp.StatusCode, msg)
|
||||
}
|
||||
return b, nil
|
||||
}
|
||||
109
service/model_service.go
Normal file
109
service/model_service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
var Model = &modelService{}
|
||||
|
||||
type modelService struct{}
|
||||
|
||||
func (s *modelService) Create(ctx context.Context, req *dto.CreateModelReq) (res *dto.CreateModelRes, err error) {
|
||||
m := &entity.AsynchModel{
|
||||
ModelName: req.ModelName,
|
||||
BaseURL: req.BaseURL,
|
||||
Route: req.Route,
|
||||
HttpMethod: req.HttpMethod,
|
||||
APIKey: req.APIKey,
|
||||
Enabled: req.Enabled,
|
||||
MaxConcurrency: req.MaxConcurrency,
|
||||
QueueLimit: req.QueueLimit,
|
||||
TimeoutMs: req.TimeoutMs,
|
||||
RetryTimes: req.RetryTimes,
|
||||
AutoCleanSeconds: req.AutoCleanSeconds,
|
||||
Remark: req.Remark,
|
||||
}
|
||||
if m.HttpMethod == "" {
|
||||
m.HttpMethod = "POST"
|
||||
}
|
||||
if m.Enabled == 0 {
|
||||
m.Enabled = 1
|
||||
}
|
||||
if m.MaxConcurrency <= 0 {
|
||||
m.MaxConcurrency = 10
|
||||
}
|
||||
if m.QueueLimit <= 0 {
|
||||
m.QueueLimit = 1000
|
||||
}
|
||||
if m.TimeoutMs <= 0 {
|
||||
m.TimeoutMs = 60000
|
||||
}
|
||||
if m.AutoCleanSeconds <= 0 {
|
||||
m.AutoCleanSeconds = 86400
|
||||
}
|
||||
id, err := dao.Model.Insert(ctx, m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateModelRes{ID: id}, nil
|
||||
}
|
||||
|
||||
func (s *modelService) Update(ctx context.Context, req *dto.UpdateModelReq) error {
|
||||
data := map[string]any{}
|
||||
if req.BaseURL != "" {
|
||||
data[entity.AsynchModelCol.BaseURL] = req.BaseURL
|
||||
}
|
||||
if req.Route != "" {
|
||||
data[entity.AsynchModelCol.Route] = req.Route
|
||||
}
|
||||
if req.HttpMethod != nil && *req.HttpMethod != "" {
|
||||
data[entity.AsynchModelCol.HttpMethod] = *req.HttpMethod
|
||||
}
|
||||
if req.APIKey != nil {
|
||||
data[entity.AsynchModelCol.APIKey] = *req.APIKey
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
data[entity.AsynchModelCol.Enabled] = *req.Enabled
|
||||
}
|
||||
if req.MaxConcurrency != nil {
|
||||
data[entity.AsynchModelCol.MaxConcurrency] = *req.MaxConcurrency
|
||||
}
|
||||
if req.QueueLimit != nil {
|
||||
data[entity.AsynchModelCol.QueueLimit] = *req.QueueLimit
|
||||
}
|
||||
if req.TimeoutMs != nil {
|
||||
data[entity.AsynchModelCol.TimeoutMs] = *req.TimeoutMs
|
||||
}
|
||||
if req.RetryTimes != nil {
|
||||
data[entity.AsynchModelCol.RetryTimes] = *req.RetryTimes
|
||||
}
|
||||
if req.AutoCleanSeconds != nil {
|
||||
data[entity.AsynchModelCol.AutoCleanSeconds] = *req.AutoCleanSeconds
|
||||
}
|
||||
if req.Remark != nil {
|
||||
data[entity.AsynchModelCol.Remark] = *req.Remark
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return errors.New("无可更新字段")
|
||||
}
|
||||
_, err := dao.Model.UpdateByID(ctx, req.ID, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Delete(ctx context.Context, id int64) error {
|
||||
_, err := dao.Model.DeleteByID(ctx, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *modelService) Get(ctx context.Context, id int64) (*entity.AsynchModel, error) {
|
||||
return dao.Model.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *modelService) List(ctx context.Context, pageNum, pageSize int) (list []*entity.AsynchModel, total int64, err error) {
|
||||
return dao.Model.List(ctx, pageNum, pageSize)
|
||||
}
|
||||
25
service/payload.go
Normal file
25
service/payload.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "github.com/gogf/gf/v2/util/gconv"
|
||||
|
||||
// parseStoredPayload 解析入库的 request_payload,拆出模型调用 payload 与透传 headers
|
||||
// 入库格式:{"payload": <any>, "headers": {"Authorization": "...", "X-User-Info":"..."}}
|
||||
func parseStoredPayload(v any) (payload any, headers map[string]string) {
|
||||
if v == nil {
|
||||
return nil, nil
|
||||
}
|
||||
m := gconv.Map(v)
|
||||
if len(m) == 0 {
|
||||
return v, nil
|
||||
}
|
||||
if h, ok := m["headers"]; ok {
|
||||
headers = gconv.MapStrStr(h)
|
||||
}
|
||||
if p, ok := m["payload"]; ok {
|
||||
payload = p
|
||||
} else {
|
||||
payload = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
56
service/semaphore.go
Normal file
56
service/semaphore.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var acquireLua = `
|
||||
local current = tonumber(redis.call("GET", KEYS[1]) or "0")
|
||||
local max = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
if current >= max then
|
||||
return 0
|
||||
end
|
||||
current = redis.call("INCR", KEYS[1])
|
||||
if current == 1 then
|
||||
redis.call("EXPIRE", KEYS[1], ttl)
|
||||
end
|
||||
if current > max then
|
||||
redis.call("DECR", KEYS[1])
|
||||
return 0
|
||||
end
|
||||
return 1
|
||||
`
|
||||
|
||||
var releaseLua = `
|
||||
local current = tonumber(redis.call("DECR", KEYS[1]) or "0")
|
||||
if current <= 0 then
|
||||
redis.call("DEL", KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`
|
||||
|
||||
func acquireSemaphore(ctx context.Context, key string, max int, ttlSeconds int64) (bool, error) {
|
||||
if max <= 0 {
|
||||
// 不限制
|
||||
return true, nil
|
||||
}
|
||||
if ttlSeconds <= 0 {
|
||||
ttlSeconds = 3600
|
||||
}
|
||||
r, err := g.Redis().Do(ctx, "EVAL", acquireLua, 1, key, max, ttlSeconds)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取并发令牌失败: %w", err)
|
||||
}
|
||||
return gconv.Int(r) == 1, nil
|
||||
}
|
||||
|
||||
func releaseSemaphore(ctx context.Context, key string) error {
|
||||
_, err := g.Redis().Do(ctx, "EVAL", releaseLua, 1, key)
|
||||
return err
|
||||
}
|
||||
|
||||
19
service/storage.go
Normal file
19
service/storage.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
)
|
||||
|
||||
// StorageService 结果存储(OSS/MinIO)抽象
|
||||
type StorageService interface {
|
||||
UploadByTask(ctx context.Context, t *entity.AsynchTask, data []byte, fileExt string, contentType string) (ossURL string, err error)
|
||||
DeleteByTask(ctx context.Context, t *entity.AsynchTask) error
|
||||
}
|
||||
|
||||
// Storage 默认存储实现(优先对接你们的 oss 文件服务;必要时也可以切到 MinIO)
|
||||
var Storage StorageService = &ossStorage{}
|
||||
|
||||
var ErrStorageNotConfigured = errors.New("存储未配置")
|
||||
90
service/storage_oss.go
Normal file
90
service/storage_oss.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
|
||||
"model-asynch/model/entity"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// 对接你们的 oss 文件服务:POST oss/file/uploadFile (multipart/form-data)
|
||||
type ossStorage struct{}
|
||||
|
||||
type uploadFileResponse struct {
|
||||
FileURL string `json:"fileURL"` // 文件 URL
|
||||
FileSize int `json:"fileSize"` // 文件大小(字节)
|
||||
FileName string `json:"fileName"` // 文件名
|
||||
FileFormat string `json:"fileFormat"` // 文件格式
|
||||
FileAddressPrefix string `json:"fileAddressPrefix"` // 文件地址前缀
|
||||
}
|
||||
|
||||
func (s *ossStorage) UploadByTask(ctx context.Context, _ *entity.AsynchTask, data []byte, fileExt string, _ string) (ossURL string, err error) {
|
||||
// ossUrl := g.Cfg().MustGet(ctx, "oss.addr", "192.168.3.30:9000").String()
|
||||
|
||||
// multipart
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
ext := fileExt
|
||||
if ext == "" {
|
||||
ext = ".bin"
|
||||
}
|
||||
if ext[0] != '.' {
|
||||
ext = "." + ext
|
||||
}
|
||||
|
||||
part, err := writer.CreateFormFile("file", "")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
contentType := writer.FormDataContentType()
|
||||
if err := writer.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := forwardHeaders(ctx)
|
||||
headers["Content-Type"] = contentType
|
||||
|
||||
fullURL := "oss/file/uploadFile"
|
||||
g.Log().Infof(ctx, "[OSS] upload start url=%s size=%d", fullURL, len(data))
|
||||
|
||||
var resp uploadFileResponse
|
||||
if err := commonHttp.Post(ctx, fullURL, headers, &resp, body.Bytes()); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if resp.FileURL == "" {
|
||||
return "", fmt.Errorf("OSS服务返回错误: 上传失败")
|
||||
}
|
||||
g.Log().Infof(ctx, "[OSS] upload success url=%s filename=%s size=%d format=%s", resp.FileURL, resp.FileName, resp.FileSize, resp.FileFormat)
|
||||
return resp.FileURL, nil
|
||||
}
|
||||
|
||||
func (s *ossStorage) DeleteByTask(ctx context.Context, t *entity.AsynchTask) error {
|
||||
// 你说当前 oss 暂时没有删除接口:这里保留方法占位,后续补接口时直接实现
|
||||
_ = ctx
|
||||
_ = t
|
||||
return nil
|
||||
}
|
||||
|
||||
// setTaskHeadersToCtx 把任务入库时保存的 header 信息注入 ctx,给 worker 调 OSS 用
|
||||
func setTaskHeadersToCtx(ctx context.Context, headers map[string]string) context.Context {
|
||||
if headers == nil {
|
||||
return ctx
|
||||
}
|
||||
if v := gconv.String(headers["Authorization"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "token", v)
|
||||
}
|
||||
if v := gconv.String(headers["X-User-Info"]); v != "" {
|
||||
ctx = context.WithValue(ctx, "xUserInfo", v)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
129
service/task_service.go
Normal file
129
service/task_service.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/dto"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
var Task = &taskService{}
|
||||
|
||||
type taskService struct{}
|
||||
|
||||
func (s *taskService) Create(ctx context.Context, req *dto.CreateTaskReq) (res *dto.CreateTaskRes, err error) {
|
||||
// 固化 token/user 等信息
|
||||
ctx = asyncCtx(ctx)
|
||||
|
||||
// 1) 检查模型配置
|
||||
m, err := dao.Model.GetByModelName(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m == nil || m.Enabled != 1 {
|
||||
return nil, errors.New("模型不存在或未启用")
|
||||
}
|
||||
|
||||
// 2) 排队上限(近似控制)
|
||||
if m.QueueLimit > 0 {
|
||||
cnt, err := dao.Task.CountActiveByModel(ctx, req.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cnt >= int64(m.QueueLimit) {
|
||||
return nil, errors.New("任务排队已满,请稍后再试")
|
||||
}
|
||||
}
|
||||
|
||||
taskID := uuid.NewString()
|
||||
// 将调用模型的 payload 与透传头信息一起存入 request_payload,供后台 worker 使用
|
||||
storedPayload := map[string]any{
|
||||
"payload": req.RequestPayload,
|
||||
"headers": forwardHeaders(ctx),
|
||||
}
|
||||
|
||||
t := &entity.AsynchTask{
|
||||
ModelName: req.ModelName,
|
||||
TaskID: taskID,
|
||||
State: 0,
|
||||
InputRef: req.InputRef,
|
||||
RequestPayload: storedPayload,
|
||||
}
|
||||
_, err = dao.Task.Insert(ctx, t)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &dto.CreateTaskRes{TaskID: taskID}, nil
|
||||
}
|
||||
|
||||
func (s *taskService) GetResult(ctx context.Context, taskID string) (res *dto.GetTaskResultRes, err error) {
|
||||
t, err := dao.Task.GetByTaskID(ctx, taskID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if t == nil {
|
||||
return nil, errors.New("任务不存在")
|
||||
}
|
||||
return &dto.GetTaskResultRes{
|
||||
OssFile: t.OssFile,
|
||||
State: t.State,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetBatch 批量查询任务;将成功(state=2)的任务更新为已下载(state=4),并写入过期时间
|
||||
func (s *taskService) GetBatch(ctx context.Context, req *dto.GetTaskBatchReq) (res *dto.GetTaskBatchRes, err error) {
|
||||
if req == nil || len(req.TaskIDs) == 0 {
|
||||
return &dto.GetTaskBatchRes{List: []dto.GetTaskBatchItem{}}, nil
|
||||
}
|
||||
// 1) 先查当前租户下的任务列表
|
||||
list, err := dao.Task.ListByTaskIDs(ctx, req.TaskIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 对成功(state=2)的任务:标记为已下载(state=4)并写入 expire_at
|
||||
now := time.Now()
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
if t.State != 2 {
|
||||
continue
|
||||
}
|
||||
// 按模型配置决定保留时间
|
||||
m, err := dao.Model.GetByModelName(ctx, t.ModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
retainSeconds := 86400
|
||||
if m != nil && m.AutoCleanSeconds > 0 {
|
||||
retainSeconds = m.AutoCleanSeconds
|
||||
}
|
||||
expireAt := gtime.New(now.Add(time.Duration(retainSeconds) * time.Second))
|
||||
_ = dao.Task.MarkDownloadedByID(ctx, t.Id, expireAt)
|
||||
|
||||
// 为了本次返回一致性,内存里也更新
|
||||
t.State = 4
|
||||
t.ExpireAt = expireAt
|
||||
}
|
||||
|
||||
// 3) 组装返回
|
||||
items := make([]dto.GetTaskBatchItem, 0, len(list))
|
||||
for _, t := range list {
|
||||
if t == nil {
|
||||
continue
|
||||
}
|
||||
items = append(items, dto.GetTaskBatchItem{
|
||||
TaskID: t.TaskID,
|
||||
State: t.State,
|
||||
OssFile: t.OssFile,
|
||||
})
|
||||
}
|
||||
return &dto.GetTaskBatchRes{List: items}, nil
|
||||
}
|
||||
177
service/worker.go
Normal file
177
service/worker.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"model-asynch/dao"
|
||||
"model-asynch/model/entity"
|
||||
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/grpool"
|
||||
)
|
||||
|
||||
var AsyncWorker = &asyncWorker{}
|
||||
|
||||
type asyncWorker struct {
|
||||
mu sync.Mutex
|
||||
pool *grpool.Pool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (w *asyncWorker) Start(ctx context.Context) {
|
||||
if !g.Cfg().MustGet(ctx, "asynch.worker.enabled", true).Bool() {
|
||||
g.Log().Warningf(ctx, "[worker] asynch.worker.enabled=false,worker 未启动")
|
||||
return
|
||||
}
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.pool != nil && !w.pool.IsClosed() {
|
||||
return
|
||||
}
|
||||
limit := g.Cfg().MustGet(ctx, "asynch.worker.goroutines", 4).Int()
|
||||
if limit <= 0 {
|
||||
limit = 1
|
||||
}
|
||||
w.pool = grpool.New(limit)
|
||||
w.closed = false
|
||||
go w.pollLoop(ctx)
|
||||
g.Log().Infof(ctx, "[worker] started, grpool limit=%d", limit)
|
||||
}
|
||||
|
||||
// Stop 关闭协程池,确保 Ctrl+C 能完整退出。
|
||||
func (w *asyncWorker) Stop(ctx context.Context) {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
if w.pool != nil && !w.pool.IsClosed() {
|
||||
w.pool.Close()
|
||||
w.closed = true
|
||||
g.Log().Infof(ctx, "[worker] stopped")
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncWorker) pollLoop(ctx context.Context) {
|
||||
pollIntervalStr := g.Cfg().MustGet(ctx, "asynch.worker.pollInterval", "1s").String()
|
||||
pollInterval, _ := time.ParseDuration(pollIntervalStr)
|
||||
if pollInterval <= 0 {
|
||||
pollInterval = time.Second
|
||||
}
|
||||
batchSize := g.Cfg().MustGet(ctx, "asynch.worker.batchSize", 5).Int()
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
|
||||
g.Log().Infof(ctx, "[worker] poll loop started, poll=%s batch=%d", pollInterval, batchSize)
|
||||
ticker := time.NewTicker(pollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
w.Stop(ctx)
|
||||
return
|
||||
case <-ticker.C:
|
||||
tasks, err := dao.Task.ClaimPendingGlobal(ctx, batchSize)
|
||||
if err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] claim pending error: %v", err)
|
||||
continue
|
||||
}
|
||||
if len(tasks) == 0 {
|
||||
continue
|
||||
}
|
||||
for _, t := range tasks {
|
||||
task := t // 防止闭包捕获循环变量
|
||||
w.mu.Lock()
|
||||
p := w.pool
|
||||
w.mu.Unlock()
|
||||
if p == nil || p.IsClosed() {
|
||||
// 池已关闭,回滚任务
|
||||
_ = w.rollbackToPending(ctx, task.Id)
|
||||
continue
|
||||
}
|
||||
_ = p.AddWithRecover(ctx, func(ctx context.Context) {
|
||||
w.handleOne(ctx, task)
|
||||
}, func(ctx context.Context, err error) {
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, task.Id, fmt.Sprintf("worker panic: %v", err))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncWorker) handleOne(ctx context.Context, t *entity.AsynchTask) {
|
||||
// 从任务入库的 request_payload 里恢复 payload + headers,给 OSS 上传透传鉴权用
|
||||
payload, headers := parseStoredPayload(t.RequestPayload)
|
||||
if len(headers) > 0 {
|
||||
ctx = setTaskHeadersToCtx(ctx, headers)
|
||||
}
|
||||
|
||||
// 1) 拉取模型配置
|
||||
m, err := dao.Model.GetByModelNameForTenant(ctx, t.TenantId, t.ModelName)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
if m == nil || m.Enabled != 1 {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, "模型不存在或未启用")
|
||||
return
|
||||
}
|
||||
|
||||
// 2) 分布式并发限制(按 tenant+model)
|
||||
semKey := fmt.Sprintf("asynch:sem:%d:%s", t.TenantId, t.ModelName)
|
||||
leaseSeconds := int64(3600) // 兜底1小时
|
||||
acquired, err := acquireSemaphore(ctx, semKey, m.MaxConcurrency, leaseSeconds)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
// 并发满了:放回排队(重新置回 state=0),下一轮再抢占
|
||||
_ = w.rollbackToPending(ctx, t.Id)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = releaseSemaphore(ctx, semKey)
|
||||
}()
|
||||
|
||||
// 3) 调用模型服务
|
||||
if payload == nil {
|
||||
payload = map[string]any{
|
||||
"taskId": t.TaskID,
|
||||
"inputRef": t.InputRef,
|
||||
}
|
||||
}
|
||||
data, err := InvokeModel(ctx, m, payload)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 4) 存储 OSS/MinIO
|
||||
contentType, ext := DetectFileType(data)
|
||||
ossURL, err := Storage.UploadByTask(ctx, t, data, ext, contentType)
|
||||
if err != nil {
|
||||
_ = dao.Task.UpdateFailedGlobal(ctx, t.Id, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 5) 更新任务状态成功
|
||||
// 注意:expire_at 的计算改为“已下载(state=4)后开始计时”,因此成功(state=2)不写 expire_at。
|
||||
fileType := strings.TrimPrefix(ext, ".")
|
||||
if fileType == "" {
|
||||
fileType = contentType
|
||||
}
|
||||
if err := dao.Task.UpdateSuccessGlobal(ctx, t.Id, ossURL, fileType, int64(len(data)), nil); err != nil {
|
||||
g.Log().Errorf(ctx, "[worker] update success failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (w *asyncWorker) rollbackToPending(ctx context.Context, id int64) error {
|
||||
return dao.Task.RollbackToPendingGlobal(ctx, id)
|
||||
}
|
||||
Reference in New Issue
Block a user