refactor(model): 优化模型网关的数据解析和任务处理逻辑
This commit is contained in:
@@ -56,6 +56,7 @@ func (d *modelGatewayModelsDao) Get(ctx context.Context, req *entity.ModelGatewa
|
||||
Where(entity.ModelGatewayModelCol.Id, req.Id).
|
||||
Where(entity.ModelGatewayModelCol.Creator, req.Creator).
|
||||
Where(entity.ModelGatewayModelCol.ModelName, req.ModelName).
|
||||
Where(entity.ModelGatewayModelCol.IsChatModel, req.IsChatModel).
|
||||
Fields(fields).One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -122,7 +123,7 @@ func (d *modelGatewayModelsDao) GetByAcrossTenant(ctx context.Context, req *enti
|
||||
func (d *modelGatewayModelsDao) GetByCreatorAndPlatform(ctx context.Context, req *dto.ListModelReq) (list []*entity.ModelGatewayModel, total int, err error) {
|
||||
sql := `
|
||||
SELECT DISTINCT ON (model_name) *
|
||||
FROM asynch_models
|
||||
FROM ` + public.TableNameModel + `
|
||||
WHERE deleted_at IS NULL
|
||||
AND (? = '' OR model_name LIKE ?)
|
||||
`
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"model-gateway/model/entity"
|
||||
|
||||
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
@@ -128,32 +127,32 @@ func (d *modelGatewayTaskDao) GetPendingAsyncTasks(ctx context.Context, limit in
|
||||
|
||||
// ClaimByID 按主键抢占,返回抢占后的任务
|
||||
func (d *modelGatewayTaskDao) ClaimByID(ctx context.Context, id int64) (*entity.ModelGatewayTask, error) {
|
||||
// 1) 先查任务
|
||||
var task entity.ModelGatewayTask
|
||||
err := gfdb.DB(ctx, public.DbNameModelGateway).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
|
||||
r, err := tx.Model(public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
||||
Limit(1).
|
||||
LockUpdate().
|
||||
One()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
||||
}
|
||||
if err := r.Struct(&task); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Model(public.TableNameTask).
|
||||
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
OmitEmpty().
|
||||
Update()
|
||||
return err
|
||||
})
|
||||
r, err := gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending).
|
||||
One()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if r.IsEmpty() {
|
||||
return nil, fmt.Errorf("任务已被抢占或不存在: id=%d", id)
|
||||
}
|
||||
if err = r.Struct(&task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 2) 改为执行中
|
||||
_, err = gfdb.DB(ctx, public.DbNameModelGateway).Model(ctx, public.TableNameTask).
|
||||
Data(&entity.ModelGatewayTask{State: public.TaskStatusRunning}).
|
||||
Where(entity.ModelGatewayTaskCol.Id, id).
|
||||
Where(entity.ModelGatewayTaskCol.State, public.TaskStatusPending). // 防并发
|
||||
OmitEmpty().
|
||||
Update()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user