196 lines
5.9 KiB
Go
196 lines
5.9 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
|
||
"ai-agent/digital-human/consts"
|
||
"ai-agent/digital-human/consts/public"
|
||
"ai-agent/digital-human/dao"
|
||
"ai-agent/digital-human/model/dto"
|
||
"ai-agent/digital-human/model/entity"
|
||
|
||
"github.com/gogf/gf/v2/database/gdb"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/os/gtime"
|
||
)
|
||
|
||
// asyncTaskService is a stateless receiver that groups the async-task
// synchronisation logic of this package.
type asyncTaskService struct{}

// AsyncTask is the shared async-task synchronisation service, intended to be
// invoked by scheduled jobs or business-side polling.
var AsyncTask = &asyncTaskService{}
|
||
|
||
// Sync
|
||
// 1) 扫描 digital_human_async_task_ref 中 state=0/1 的记录(业务“生成中”)
|
||
// 2) 组装 task_id 批量请求 model-asynch /task/get-task-batch
|
||
// 3) 中间件状态映射到业务状态(业务只维护三态:0生成中/1成功/2失败):
|
||
// - 中间件 0/1/3(能查到 task_id) -> 业务 0(生成中)
|
||
// - 中间件 2/4(成功/已下载) -> 业务 1(成功)
|
||
// - 中间件 查不到 task_id(返回列表缺失) -> 业务 2(失败)
|
||
//
|
||
// 4) 绑定表仅用于“待同步列表”,因此:
|
||
// - 对中间件 0/1/3 不额外写库(减少查询/更新开销)
|
||
// - 对成功(2/4)与缺失(task_id 查不到)才更新绑定表
|
||
func (s *asyncTaskService) Sync(ctx context.Context, req *dto.SyncAsyncTasksReq) (res *dto.SyncAsyncTasksRes, err error) {
|
||
limit := 200
|
||
if req != nil && req.Limit > 0 {
|
||
limit = req.Limit
|
||
}
|
||
refs, err := dao.AsyncTaskRef.ListPending(ctx, limit)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
taskIDs := make([]string, 0, len(refs))
|
||
refMap := make(map[string]*entity.AsyncTaskRef, len(refs))
|
||
for _, r := range refs {
|
||
if r == nil || r.TaskID == "" {
|
||
continue
|
||
}
|
||
taskIDs = append(taskIDs, r.TaskID)
|
||
refMap[r.TaskID] = r
|
||
}
|
||
|
||
out := &dto.SyncAsyncTasksRes{
|
||
Total: len(taskIDs),
|
||
List: make([]dto.SyncAsyncTasksItem, 0, len(taskIDs)),
|
||
}
|
||
if len(taskIDs) == 0 {
|
||
return out, nil
|
||
}
|
||
items, err := getModelAsynchTaskBatch(ctx, taskIDs)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
seen := make(map[string]struct{}, len(items))
|
||
handled := 0
|
||
|
||
for _, it := range items {
|
||
r := refMap[it.TaskID]
|
||
if r == nil {
|
||
continue
|
||
}
|
||
seen[it.TaskID] = struct{}{}
|
||
|
||
switch it.State {
|
||
case 0, 1, 3:
|
||
// 排队中/执行中/失败(可能重试):业务侧仍视为生成中,不更新绑定表,减少更新开销
|
||
case 2, 4:
|
||
// 成功/已下载:业务侧写入 oss_file 并标记成功
|
||
if it.OssFile == "" {
|
||
errMsg := "中间件返回空oss地址"
|
||
_ = s.updateBizFailed(ctx, r, errMsg)
|
||
_, _ = dao.AsyncTaskRef.UpdateByTaskID(ctx, it.TaskID, gdb.Map{
|
||
entity.AsyncTaskRefCol.State: it.State,
|
||
entity.AsyncTaskRefCol.OssFile: "",
|
||
entity.AsyncTaskRefCol.ErrorMsg: errMsg,
|
||
})
|
||
out.List = append(out.List, dto.SyncAsyncTasksItem{
|
||
TaskID: it.TaskID,
|
||
State: it.State,
|
||
TableName: r.TableName,
|
||
BizID: fmt.Sprintf("%d", r.BizID),
|
||
OssFile: "",
|
||
ErrorMsg: errMsg,
|
||
})
|
||
continue
|
||
}
|
||
if err := s.updateBizSuccess(ctx, r, it.OssFile); err != nil {
|
||
errMsg := fmt.Sprintf("生成音频失败: %v", err)
|
||
_ = s.updateBizFailed(ctx, r, errMsg)
|
||
_, _ = dao.AsyncTaskRef.UpdateByTaskID(ctx, it.TaskID, gdb.Map{
|
||
entity.AsyncTaskRefCol.State: it.State,
|
||
entity.AsyncTaskRefCol.OssFile: it.OssFile,
|
||
entity.AsyncTaskRefCol.ErrorMsg: errMsg,
|
||
})
|
||
out.List = append(out.List, dto.SyncAsyncTasksItem{
|
||
TaskID: it.TaskID,
|
||
State: it.State,
|
||
TableName: r.TableName,
|
||
BizID: fmt.Sprintf("%d", r.BizID),
|
||
OssFile: it.OssFile,
|
||
ErrorMsg: errMsg,
|
||
})
|
||
continue
|
||
}
|
||
handled++
|
||
_, _ = dao.AsyncTaskRef.UpdateByTaskID(ctx, it.TaskID, gdb.Map{
|
||
entity.AsyncTaskRefCol.State: it.State,
|
||
entity.AsyncTaskRefCol.OssFile: it.OssFile,
|
||
entity.AsyncTaskRefCol.ErrorMsg: "",
|
||
})
|
||
default:
|
||
// 其他状态:不处理
|
||
}
|
||
|
||
out.List = append(out.List, dto.SyncAsyncTasksItem{
|
||
TaskID: it.TaskID,
|
||
State: it.State,
|
||
TableName: r.TableName,
|
||
BizID: fmt.Sprintf("%d", r.BizID),
|
||
OssFile: it.OssFile,
|
||
ErrorMsg: "",
|
||
})
|
||
}
|
||
|
||
// 处理“查不到 task_id”的情况:
|
||
// 中间件对失败重试耗尽的任务会硬删除,批量接口不会返回该 task_id。
|
||
// 业务侧把这种情况视为失败终态,并软删除绑定记录,避免重复轮询。
|
||
for _, taskID := range taskIDs {
|
||
if _, ok := seen[taskID]; ok {
|
||
continue
|
||
}
|
||
r := refMap[taskID]
|
||
if r == nil {
|
||
continue
|
||
}
|
||
msg := "模型任务不存在已失败"
|
||
_ = s.updateBizFailed(ctx, r, msg)
|
||
_, _ = dao.AsyncTaskRef.UpdateByTaskID(ctx, taskID, gdb.Map{
|
||
entity.AsyncTaskRefCol.State: 3,
|
||
entity.AsyncTaskRefCol.ErrorMsg: msg,
|
||
"deleted_at": gtime.Now(),
|
||
})
|
||
out.List = append(out.List, dto.SyncAsyncTasksItem{
|
||
TaskID: taskID,
|
||
State: 3,
|
||
TableName: r.TableName,
|
||
BizID: fmt.Sprintf("%d", r.BizID),
|
||
OssFile: "",
|
||
ErrorMsg: msg,
|
||
})
|
||
}
|
||
|
||
out.Handled = handled
|
||
g.Log().Infof(ctx, "[AsyncTask.Sync] total=%d handled=%d", out.Total, out.Handled)
|
||
return out, nil
|
||
}
|
||
|
||
// updateBizSuccess 更新业务侧状态为成功
|
||
func (s *asyncTaskService) updateBizSuccess(ctx context.Context, ref *entity.AsyncTaskRef, ossFile string) error {
|
||
switch ref.TableName {
|
||
case public.TableNameAudio:
|
||
_, err := dao.Audio.UpdateStatus(ctx, ref.BizID, consts.AudioStatusSuccess, "", ossFile, 0, "")
|
||
return err
|
||
case public.TableNameCustomVoice:
|
||
_, err := dao.CustomVoice.UpdateStatus(ctx, ref.BizID, 1, "", ossFile)
|
||
return err
|
||
default:
|
||
return fmt.Errorf("未知 table_name=%s", ref.TableName)
|
||
}
|
||
}
|
||
|
||
// updateBizFailed 更新业务侧状态为失败
|
||
func (s *asyncTaskService) updateBizFailed(ctx context.Context, ref *entity.AsyncTaskRef, msg string) error {
|
||
switch ref.TableName {
|
||
case public.TableNameAudio:
|
||
_, err := dao.Audio.UpdateStatus(ctx, ref.BizID, consts.AudioStatusFailed, msg, "", 0, "")
|
||
return err
|
||
case public.TableNameCustomVoice:
|
||
_, err := dao.CustomVoice.UpdateStatus(ctx, ref.BizID, 2, msg, "")
|
||
return err
|
||
default:
|
||
return nil
|
||
}
|
||
}
|