refactor(service): 重构模型网关服务结构
This commit is contained in:
23
dao/model_gateway_logs_op.go
Normal file
23
dao/model_gateway_logs_op.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
)
|
||||
|
||||
var ModelGatewayLogsOp = &modelGatewayLogsOpDao{}
|
||||
|
||||
type modelGatewayLogsOpDao struct{}
|
||||
|
||||
// Insert 插入操作日志
|
||||
func (d *modelGatewayLogsOpDao) Insert(ctx context.Context, req *entity.ModelGatewayLogsOp) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).Insert(req)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
52
dao/model_gateway_logs_stat_dao.go
Normal file
52
dao/model_gateway_logs_stat_dao.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ModelGatewayLogsStat = &modelGatewayLogsStatDao{}
|
||||
|
||||
type modelGatewayLogsStatDao struct{}
|
||||
|
||||
// IncRequestCount 原子累加:按天+租户+创建人+模型 +1
|
||||
func (d *modelGatewayLogsStatDao) IncRequestCount(ctx context.Context, day time.Time, tenantId uint64, creator, modelName string) error {
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
|
||||
Data(&entity.ModelGatewayLogsStat{
|
||||
Day: gtime.New(day),
|
||||
TenantId: tenantId,
|
||||
Creator: creator,
|
||||
ModelName: modelName,
|
||||
RequestCount: 1,
|
||||
}).
|
||||
OnDuplicate("request_count", "request_count+1").
|
||||
Insert()
|
||||
return err
|
||||
}
|
||||
|
||||
// List 分页查询统计
|
||||
func (d *modelGatewayLogsStatDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayLogsStat) (list []*entity.ModelGatewayLogsStat, total int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameStat).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayLogsStatCols.Creator, req.Creator).
|
||||
WhereLike(entity.ModelGatewayLogsStatCols.ModelName, "%"+req.ModelName+"%").
|
||||
OrderDesc(entity.ModelGatewayLogsStatCols.Day).
|
||||
OrderDesc(entity.ModelGatewayLogsStatCols.RequestCount)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
@@ -9,71 +9,64 @@ import (
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Model = &modelDao{}
|
||||
var ModelGatewayModels = &modelGatewayModelsDao{}
|
||||
|
||||
type modelDao struct{}
|
||||
type modelGatewayModelsDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *modelDao) Insert(ctx context.Context, req *entity.AsynchModel) (id int64, err error) {
|
||||
m := new(entity.AsynchModel)
|
||||
err = gconv.Struct(req, &m)
|
||||
func (d *modelGatewayModelsDao) Insert(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).Insert(req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (d *modelDao) Update(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
func (d *modelGatewayModelsDao) Update(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Data(&req).
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Data(req).
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Delete 删除
|
||||
func (d *modelDao) Delete(ctx context.Context, req *entity.AsynchModel) (rows int64, err error) {
|
||||
func (d *modelGatewayModelsDao) Delete(ctx context.Context, req *entity.ModelGatewayModel) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Get 获取模型
|
||||
func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
var m entity.ModelGatewayModel
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
return &m, err
|
||||
}
|
||||
|
||||
//// Get 按ID获取(带租户隔离,只查当前租户)
|
||||
//func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
//func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
// var whereCondition strings.Builder
|
||||
// var queryParams []interface{}
|
||||
// if !g.IsEmpty(req.Id) {
|
||||
@@ -108,25 +101,25 @@ func (d *modelDao) Get(ctx context.Context, req *entity.AsynchModel, fields ...s
|
||||
// return
|
||||
//}
|
||||
|
||||
// GetByAcrossTenant 按ID获取(跨租户,查所有租户)
|
||||
func (d *modelDao) GetByAcrossTenant(ctx context.Context, req *entity.AsynchModel, fields ...string) (m *entity.AsynchModel, err error) {
|
||||
// GetByAcrossTenant 跨租户查询
|
||||
func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *entity.ModelGatewayModel, fields ...string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameModel).
|
||||
NoTenantId(ctx).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchModelCol.Id, req.Id).
|
||||
Where(entity.AsynchModelCol.Creator, req.Creator).
|
||||
Where(entity.AsynchModelCol.IsChatModel, req.IsChatModel).
|
||||
Where(entity.AsynchModelCol.ModelName, req.ModelName).
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
var m entity.ModelGatewayModel
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
return &m, err
|
||||
}
|
||||
|
||||
// GetByCreatorAndPlatform 按创建者、平台获取
|
||||
func (d *modelDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.AsynchModel, total int, err error) {
|
||||
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
||||
sql := `
|
||||
SELECT DISTINCT ON (model_name) *
|
||||
FROM asynch_models
|
||||
@@ -186,7 +179,7 @@ WHERE deleted_at IS NULL
|
||||
}
|
||||
|
||||
// GetByModelNameForTenant 后台任务使用:按 tenant_id + model_name 查询,不依赖 gfdb Hook/Trace/用户上下文
|
||||
func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (m *entity.AsynchModel, err error) {
|
||||
func (d *modelGatewayModelsDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64, modelName string) (*entity.ModelGatewayModel, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx,
|
||||
"SELECT * FROM "+public.TableNameModel+" WHERE tenant_id=? AND model_name=? AND deleted_at IS NULL LIMIT 1",
|
||||
tenantId, modelName,
|
||||
@@ -197,7 +190,7 @@ func (d *modelDao) GetByModelNameForTenant(ctx context.Context, tenantId uint64,
|
||||
if r.IsEmpty() {
|
||||
return nil, nil
|
||||
}
|
||||
var list []*entity.AsynchModel
|
||||
var list []*entity.ModelGatewayModel
|
||||
if err := r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
159
dao/model_gateway_task_dao.go
Normal file
159
dao/model_gateway_task_dao.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var ModelGatewayTask = &modelGatewayTaskDao{}
|
||||
|
||||
type modelGatewayTaskDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *modelGatewayTaskDao) Insert(ctx context.Context, req *entity.ModelGatewayTask) (id int64, err error) {
|
||||
m := new(entity.ModelGatewayTask)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新(按ID)
|
||||
func (d *modelGatewayTaskDao) Update(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Data(req).
|
||||
Where(entity.ModelGatewayTaskCol.Id, req.Id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Get 获取(按TaskID 或 ID)
|
||||
func (d *modelGatewayTaskDao) Get(ctx context.Context, req *entity.ModelGatewayTask) (m *entity.ModelGatewayTask, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
|
||||
Where(entity.ModelGatewayTaskCol.Id, req.Id).
|
||||
One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// List 分页查询
|
||||
func (d *modelGatewayTaskDao) List(ctx context.Context, pageNum, pageSize int, req *entity.ModelGatewayTask) (list []*entity.ModelGatewayTask, total int64, err error) {
|
||||
model := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Where(entity.ModelGatewayTaskCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayTaskCol.ModelName, "%"+req.ModelName+"%").
|
||||
Where(entity.ModelGatewayTaskCol.BizName, req.BizName).
|
||||
Where(entity.ModelGatewayTaskCol.State, req.State).
|
||||
Where(entity.ModelGatewayTaskCol.TaskID, req.TaskID).
|
||||
OrderDesc(entity.ModelGatewayTaskCol.CreatedAt)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
model = model.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := model.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete 删除(软删,按ID)
|
||||
func (d *modelGatewayTaskDao) Delete(ctx context.Context, req *entity.ModelGatewayTask) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.Id, req.Id).
|
||||
Delete()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// ListByTaskIDs 批量查询
|
||||
func (d *modelGatewayTaskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (list []*entity.ModelGatewayTask, err error) {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
WhereIn(entity.ModelGatewayTaskCol.TaskID, taskIDs).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// MarkDownloadedByID 标记已下载
|
||||
func (d *modelGatewayTaskDao) MarkDownloadedByID(ctx context.Context, id int64) error {
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
Where(entity.ModelGatewayTaskCol.State, 2).
|
||||
Data(map[string]any{entity.ModelGatewayTaskCol.State: 4}).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetPendingAsyncTasks 获取进行中的异步任务
|
||||
func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.ModelGatewayTask, error) {
|
||||
var tasks []*entity.ModelGatewayTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.State, 1).
|
||||
Limit(limit).
|
||||
Scan(&tasks)
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ======================== 事务抢占 ========================
|
||||
|
||||
// ClaimByID 按主键抢占,返回抢占后的任务
|
||||
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
||||
var task entity.ModelGatewayTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
r, err := tx.Model(public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
||||
Limit(1).
|
||||
LockUpdate().
|
||||
One()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
||||
}
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Model(public.TableNameTask).
|
||||
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
OmitEmpty().
|
||||
Update()
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
type opLogDao struct{}
|
||||
|
||||
var OpLog = &opLogDao{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *opLogDao) Insert(ctx context.Context, req *entity.LogsModelOp) (id int64, err error) {
|
||||
m := new(entity.LogsModelOp)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameOpLog).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
type statDao struct{}
|
||||
|
||||
var Stat = &statDao{}
|
||||
|
||||
// IncRequestCount 原子累加(支持分布式/多协程):按天+租户+创建人+模型 +1
|
||||
func (d *statDao) IncRequestCount(ctx context.Context, day time.Time, tenantId int64, creator, modelName string) error {
|
||||
sql := fmt.Sprintf(`
|
||||
INSERT INTO %s(day, tenant_id, creator, model_name, request_count, created_at, updated_at)
|
||||
VALUES(?, ?, ?, ?, 1, NOW(), NOW())
|
||||
ON CONFLICT (day, tenant_id, creator, model_name)
|
||||
DO UPDATE SET request_count = %s.request_count + 1, updated_at = NOW()`,
|
||||
public.TableNameStat, public.TableNameStat,
|
||||
)
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, gtime.New(day).Format("Y-m-d"), tenantId, creator, modelName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (d *statDao) List(ctx context.Context, pageNum, pageSize int, startDay, endDay string, tenantId *int64, creator, modelName string) (list []*entity.LogsModelStat, total int64, err error) {
|
||||
m := gfdb.DB(ctx).Model(ctx, public.TableNameStat).Where("1=1")
|
||||
if startDay != "" {
|
||||
m = m.Where("day >= ?", startDay)
|
||||
}
|
||||
if endDay != "" {
|
||||
m = m.Where("day <= ?", endDay)
|
||||
}
|
||||
if tenantId != nil {
|
||||
m = m.Where("tenant_id = ?", *tenantId)
|
||||
}
|
||||
if creator != "" {
|
||||
m = m.WhereLike("creator", "%"+creator+"%")
|
||||
}
|
||||
if modelName != "" {
|
||||
m = m.WhereLike("model_name", "%"+modelName+"%")
|
||||
}
|
||||
m = m.OrderDesc("day").OrderDesc("request_count")
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
m = m.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := m.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
125
dao/task_dao.go
125
dao/task_dao.go
@@ -1,125 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
var Task = &taskDao{}
|
||||
|
||||
type taskDao struct{}
|
||||
|
||||
// Insert 插入
|
||||
func (d *taskDao) Insert(ctx context.Context, req *entity.AsynchTask) (id int64, err error) {
|
||||
m := new(entity.AsynchTask)
|
||||
err = gconv.Struct(req, &m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
r, err := gfdb.DB(ctx).Model(ctx, public.TableNameTask).
|
||||
Insert(m)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.LastInsertId()
|
||||
}
|
||||
|
||||
// Update 更新
|
||||
func (d *taskDao) Update(ctx context.Context, req *entity.AsynchTask) (rows int64, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Data(&req).
|
||||
Where(entity.AsynchTaskCol.Id, req.Id).
|
||||
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||
Update()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// Get 获取
|
||||
func (d *taskDao) Get(ctx context.Context, req *entity.AsynchTask, fields ...string) (m *entity.AsynchTask, err error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
Where(entity.AsynchTaskCol.TaskID, req.TaskID).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = r.Struct(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// ListByTaskIDs 批量查询任务(会受 gfdb 的租户 Hook 影响,只返回当前租户数据)
|
||||
func (d *taskDao) ListByTaskIDs(ctx context.Context, taskIDs []string) (m []*entity.AsynchTask, err error) {
|
||||
if len(taskIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
OmitEmpty().
|
||||
WhereIn(entity.AsynchTaskCol.TaskID, taskIDs).
|
||||
All()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = r.Structs(&m)
|
||||
return
|
||||
}
|
||||
|
||||
// MarkDownloadedByID 将成功任务标记为已下载(state=4),并写入过期时间
|
||||
func (d *taskDao) MarkDownloadedByID(ctx context.Context, id int64, expireAt *gtime.Time) error {
|
||||
data := gdb.Map{
|
||||
entity.AsynchTaskCol.State: 4,
|
||||
entity.AsynchTaskCol.ExpireAt: expireAt,
|
||||
entity.AsynchTaskCol.Updater: "",
|
||||
}
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Where(entity.AsynchTaskCol.State, 2).
|
||||
Data(data).
|
||||
Update()
|
||||
return err
|
||||
}
|
||||
|
||||
// List 任务分页查询(受 gfdb 租户 Hook 影响)
|
||||
func (d *taskDao) List(ctx context.Context, pageNum, pageSize int, modelNameLike, taskIDLike string, state *int) (list []*entity.AsynchTask, total int64, err error) {
|
||||
m := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).Where("deleted_at IS NULL")
|
||||
if modelNameLike != "" {
|
||||
m = m.WhereLike(entity.AsynchTaskCol.ModelName, "%"+modelNameLike+"%")
|
||||
}
|
||||
if taskIDLike != "" {
|
||||
m = m.WhereLike(entity.AsynchTaskCol.TaskID, "%"+taskIDLike+"%")
|
||||
}
|
||||
if state != nil {
|
||||
m = m.Where(entity.AsynchTaskCol.State, *state)
|
||||
}
|
||||
m = m.OrderDesc(entity.AsynchTaskCol.CreatedAt)
|
||||
if pageNum > 0 && pageSize > 0 {
|
||||
m = m.Page(pageNum, pageSize)
|
||||
}
|
||||
r, totalInt, err := m.AllAndCount(false)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = gconv.Int64(totalInt)
|
||||
err = r.Structs(&list)
|
||||
return
|
||||
}
|
||||
|
||||
// GetPendingAsyncTasks 获取进行中的异步任务
|
||||
func (d *taskDao) GetPendingAsyncTasks(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where("state", 1).
|
||||
Where("deleted_at IS NULL").
|
||||
Limit(limit).
|
||||
Scan(&tasks)
|
||||
return tasks, err
|
||||
}
|
||||
@@ -1,220 +0,0 @@
|
||||
package dao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"model-gateway/consts/public"
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
)
|
||||
|
||||
// ======================== 查询辅助 ========================
|
||||
|
||||
const taskColumns = `id, tenant_id, creator, model_name, task_id, biz_name, callback_url, model_key, retry_count, input_ref, request_payload, phase, tmp_file`
|
||||
|
||||
// ======================== 通用 CRUD ========================
|
||||
|
||||
// UpdateFields 更新指定字段(map 版,用于必须更新零值的场景)
|
||||
func (d *taskDao) UpdateFields(ctx context.Context, id int64, data map[string]any) (int64, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).
|
||||
Model(ctx, public.TableNameTask).
|
||||
Data(data).
|
||||
Where(entity.AsynchTaskCol.Id, id).
|
||||
Update()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return r.RowsAffected()
|
||||
}
|
||||
|
||||
// execUpdate 内部辅助:执行原生 UPDATE,自动补 updated_at
|
||||
func execUpdate(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := gfdb.DB(ctx, public.DbNameModelGateway).Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// ======================== 事务抢占 ========================
|
||||
|
||||
func claimTasks(ctx context.Context, where string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 %s LIMIT 1 FOR UPDATE SKIP LOCKED`,
|
||||
taskColumns, public.TableNameTask, where,
|
||||
)
|
||||
r, err := tx.GetOne(sql, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
var task entity.AsynchTask
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, task.Id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tasks = []*entity.AsynchTask{&task}
|
||||
return nil
|
||||
})
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingGlobal 批量抢占 pending 任务
|
||||
func (d *taskDao) ClaimPendingGlobal(ctx context.Context, batchSize int) ([]*entity.AsynchTask, error) {
|
||||
if batchSize <= 0 {
|
||||
batchSize = 1
|
||||
}
|
||||
var tasks []*entity.AsynchTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
sql := fmt.Sprintf(
|
||||
`SELECT %s FROM %s WHERE deleted_at IS NULL AND state = 0 ORDER BY enqueue_at ASC LIMIT %d FOR UPDATE SKIP LOCKED`,
|
||||
taskColumns, public.TableNameTask, batchSize,
|
||||
)
|
||||
r, err := tx.GetAll(sql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil
|
||||
}
|
||||
if err := r.Structs(&tasks); err != nil {
|
||||
return err
|
||||
}
|
||||
now := time.Now()
|
||||
for _, t := range tasks {
|
||||
_, err = tx.Exec(
|
||||
fmt.Sprintf(`UPDATE %s SET state=1, started_at=?, updated_at=? WHERE id=?`, public.TableNameTask),
|
||||
now, now, t.Id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ClaimPendingByTaskIDGlobal 按 task_id 抢占
|
||||
func (d *taskDao) ClaimPendingByTaskIDGlobal(ctx context.Context, taskID string) (*entity.AsynchTask, error) {
|
||||
if taskID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
tasks, err := claimTasks(ctx, "AND task_id = ?", taskID)
|
||||
if err != nil || len(tasks) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return tasks[0], nil
|
||||
}
|
||||
|
||||
// ======================== 业务更新方法 ========================
|
||||
|
||||
// RollbackToPendingGlobal 回滚到 pending 状态
|
||||
func (d *taskDao) RollbackToPendingGlobal(ctx context.Context, id int64) error {
|
||||
// state=0 可能被 OmitEmpty 跳过,所以用原生 SQL + 条件 state=1 防并发
|
||||
return execUpdate(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, enqueue_at=NOW(), updated_at=NOW() WHERE id=? AND state=1`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
// IncRetryCountGlobal 重试计数 +1
|
||||
func (d *taskDao) IncRetryCountGlobal(ctx context.Context, id int64) error {
|
||||
return execUpdate(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET retry_count=retry_count+1, updated_at=NOW() WHERE id=?`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
// RequeueForRetryGlobal 重新入队
|
||||
func (d *taskDao) RequeueForRetryGlobal(ctx context.Context, id int64, enqueueAt time.Time) error {
|
||||
return execUpdate(ctx,
|
||||
fmt.Sprintf(`UPDATE %s SET state=0, retry_count=retry_count+1, enqueue_at=?, updated_at=NOW() WHERE id=? AND state=3 AND deleted_at IS NULL`, public.TableNameTask),
|
||||
enqueueAt, id,
|
||||
)
|
||||
}
|
||||
|
||||
// ======================== 列表查询 ========================
|
||||
|
||||
// ListExpiredDownloadedGlobal 查询已过期下载的任务
|
||||
func (d *taskDao) ListExpiredDownloadedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx,
|
||||
fmt.Sprintf(`SELECT * FROM %s WHERE deleted_at IS NULL AND state=4 AND expire_at IS NOT NULL AND expire_at < ? LIMIT ?`, public.TableNameTask),
|
||||
gtime.Now(), clampLimit(limit, 200),
|
||||
)
|
||||
}
|
||||
|
||||
// ListFailedRetryableGlobal 查询可重试的失败任务
|
||||
func (d *taskDao) ListFailedRetryableGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx,
|
||||
fmt.Sprintf(
|
||||
`SELECT t.*, m.retry_queue_max_seconds FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count < m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
|
||||
public.TableNameTask, public.TableNameModel,
|
||||
),
|
||||
clampLimit(limit, 200),
|
||||
)
|
||||
}
|
||||
|
||||
// ListFailedExhaustedGlobal 查询重试次数耗尽的失败任务
|
||||
func (d *taskDao) ListFailedExhaustedGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx,
|
||||
fmt.Sprintf(
|
||||
`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state=3 AND t.retry_count >= m.retry_times ORDER BY t.updated_at ASC LIMIT ?`,
|
||||
public.TableNameTask, public.TableNameModel,
|
||||
),
|
||||
clampLimit(limit, 200),
|
||||
)
|
||||
}
|
||||
|
||||
// ListTimeoutTasksGlobal 查询超时任务
|
||||
func (d *taskDao) ListTimeoutTasksGlobal(ctx context.Context, limit int) ([]*entity.AsynchTask, error) {
|
||||
return queryTasks(ctx,
|
||||
fmt.Sprintf(
|
||||
`SELECT t.* FROM %s t JOIN %s m ON t.tenant_id=m.tenant_id AND t.model_name=m.model_name WHERE t.deleted_at IS NULL AND t.state IN (0,1) AND m.expected_seconds > 0 AND t.created_at < (NOW() - (m.expected_seconds || ' seconds')::interval) LIMIT ?`,
|
||||
public.TableNameTask, public.TableNameModel,
|
||||
),
|
||||
clampLimit(limit, 200),
|
||||
)
|
||||
}
|
||||
|
||||
// HardDeleteByIDGlobal 物理删除任务
|
||||
func (d *taskDao) HardDeleteByIDGlobal(ctx context.Context, id int64) error {
|
||||
return execUpdate(ctx,
|
||||
fmt.Sprintf(`DELETE FROM %s WHERE id=?`, public.TableNameTask),
|
||||
id,
|
||||
)
|
||||
}
|
||||
|
||||
// ======================== 内部辅助 ========================
|
||||
|
||||
func queryTasks(ctx context.Context, sql string, args ...any) ([]*entity.AsynchTask, error) {
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).GetAll(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var list []*entity.AsynchTask
|
||||
if err = r.Structs(&list); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func clampLimit(limit, defaultVal int) int {
|
||||
if limit <= 0 {
|
||||
return defaultVal
|
||||
}
|
||||
return limit
|
||||
}
|
||||
Reference in New Issue
Block a user