package flow import ( "ai-agent/workflow/consts/node" nodeDao "ai-agent/workflow/dao/node" "ai-agent/workflow/model/dto" flowDto "ai-agent/workflow/model/dto/flow" nodeDto "ai-agent/workflow/model/dto/node" "ai-agent/workflow/model/entity" "bytes" "context" "fmt" "io" "mime/multipart" "net/http" "net/url" "path/filepath" "regexp" "strconv" "strings" "sync" commonHttp "gitea.redpowerfuture.com/red-future/common/http" "gitea.redpowerfuture.com/red-future/common/utils" "github.com/gogf/gf/v2/errors/gerror" "github.com/gogf/gf/v2/frame/g" "github.com/gogf/gf/v2/util/gconv" "github.com/tidwall/sjson" ) // 全局等待任务回调的工具 var ( asyncMu sync.Mutex asyncTasks = make(map[string]chan any) ) // Wait 阻塞等待回调结果 // 调用后会一直卡住,直到 Notify 唤醒 或 超时/取消 func Wait(ctx context.Context, taskId string) (any, error) { asyncMu.Lock() ch := make(chan any, 1) asyncTasks[taskId] = ch asyncMu.Unlock() defer close(ch) for { select { case result := <-ch: return result, nil case <-ctx.Done(): asyncMu.Lock() delete(asyncTasks, taskId) asyncMu.Unlock() return nil, ctx.Err() } } } // Notify 回调时调用,唤醒等待的任务 func Notify(taskId string, result any) { asyncMu.Lock() defer asyncMu.Unlock() ch, exist := asyncTasks[taskId] if !exist { return } ch <- result delete(asyncTasks, taskId) } func GetIsChatModel(ctx context.Context) (res *flowDto.GetIsChatModelRes, err error) { headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } res = new(flowDto.GetIsChatModelRes) err = commonHttp.Get(ctx, "model-gateway/model/getIsChatModel", headers, res, nil) return } func GetModelInfo(ctx context.Context, req *flowDto.GetModelInfoReq) (res *flowDto.GetModelInfoRes, err error) { headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } res = new(flowDto.GetModelInfoRes) err = commonHttp.Get(ctx, "model-gateway/model/getModel", headers, res, req) return } func GetComposeResult(ctx context.Context, buildType int, modelName, promptContent, skillName string, form []map[string]any, userForm []map[string]any, fileUrl []string, sessionId, nodeId string, cause string) (res *flowDto.ComposeCallbackReq, err error) { if !g.IsEmpty(promptContent) { userForm = append(userForm, map[string]any{ "prompt": promptContent, }) } var callbackUrl = utils.GetCallbackURL(ctx, "/flow/execution/composeCallBack") var consult = make([]flowDto.Consult, 0) var collectFileUrls func(val any) (fullyConsumed bool) collectFileUrls = func(val any) (fullyConsumed bool) { switch { case g.NewVar(val).IsSlice(): slice := gconv.SliceAny(val) allConsumed := false for _, item := range slice { if collectFileUrls(item) { allConsumed = true } } return allConsumed case g.NewVar(val).IsMap(): m := gconv.Map(val) allConsumed := false for _, item := range m { if collectFileUrls(item) { allConsumed = true } } return allConsumed default: s := gconv.String(val) if s != "" { getFileTypeByPath := GetFileTypeByPath(s) if getFileTypeByPath != "" { consult = append(consult, flowDto.Consult{ Type: getFileTypeByPath, Url: s, }) return true } } return false } } var newUserForm []map[string]any for _, m := range userForm { for k, v := range m { if collectFileUrls(v) { delete(m, k) } } if len(m) > 0 { newUserForm = append(newUserForm, m) } } for _, v := range fileUrl { getFileTypeByPath := GetFileTypeByPath(gconv.String(v)) if getFileTypeByPath != "" { consult = append(consult, flowDto.Consult{ Type: getFileTypeByPath, Url: gconv.String(v), }) } } msgReq := flowDto.ComposeMessagesReq{ BuildType: buildType, ModelName: modelName, SkillName: skillName, CallbackUrl: callbackUrl, Cause: cause, Form: form, UserForm: newUserForm, Consult: consult, SessionId: sessionId, NodeId: nodeId, } headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } msgRes := new(flowDto.ComposeMessagesRes) err = commonHttp.Post(ctx, "prompts-core/prompt/composeMessages", headers, msgRes, &msgReq) if err != nil { return } if g.IsEmpty(msgRes.TaskId) { return nil, fmt.Errorf("msg is empty") } waitRes, err := Wait(ctx, msgRes.TaskId) if err != nil { return nil, err } msg := new(flowDto.ComposeCallbackReq) if err = gconv.Struct(waitRes, msg); err != nil { return nil, err } if !g.IsEmpty(msg.ErrorMsg) { return nil, fmt.Errorf(msg.ErrorMsg) } return msg, nil } func CreateGatewayTask(ctx context.Context, epicycleId int64, model string, content map[string]any) (map[string]any, error) { taskId, err := createGatewayTaskOnly(ctx, epicycleId, model, content) if err != nil { return nil, err } return waitGatewayResult(ctx, taskId) } // createGatewayTaskOnly creates a gateway task and returns the taskId only // doesn't wait for completion func createGatewayTaskOnly(ctx context.Context, epicycleId int64, model string, content map[string]any) (string, error) { callbackUrl := utils.GetCallbackURL(ctx, "/flow/execution/modelCallback") req := flowDto.ModelGatewayReq{ ModelName: model, BizName: g.Cfg().MustGet(ctx, "server.name").String(), CallbackUrl: callbackUrl, RequestPayload: content, EpicycleId: epicycleId, } headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } res := new(flowDto.ModelGatewayRes) err := commonHttp.Post(ctx, "model-gateway/task/createTask", headers, res, &req) if err != nil { return "", err } if g.IsEmpty(res.TaskId) { return "", fmt.Errorf("创建模型任务失败,taskId为空") } return res.TaskId, nil } // waitGatewayResult waits for a created gateway task to complete and returns the result func waitGatewayResult(ctx context.Context, taskId string) (map[string]any, error) { waitRes, err := Wait(ctx, taskId) if err != nil { return nil, err } task := new(flowDto.ModelCallbackReq) if err = gconv.Struct(waitRes, task); err != nil { return nil, err } if task.State == 3 || !g.IsEmpty(task.ErrorMsg) { return nil, fmt.Errorf("模型执行失败:%s", task.ErrorMsg) } if g.IsEmpty(task.Messages) { return nil, fmt.Errorf("模型返回结果为空") } // 获取远程文件内容 //file, err := GetFileBytesFromURL(ctx, task.OssFile) //if err != nil { // return nil, err //} //task.Messages = gconv.Map(file) return task.Messages, nil } // updateTokenCount updates the token count in node execution func updateTokenCount(ctx context.Context, nodeExecutionId int64, responseField string, result map[string]any) { if responseField == "" { return } _, _ = nodeDao.NodeExecutionDao.Update(ctx, &nodeDto.UpdateNodeExecutionReq{ Id: nodeExecutionId, CompletionTokens: gconv.Int(result[responseField]), TotalTokens: gconv.Int(result[responseField]), }) } func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.NodeExecutionInput, skillName string, form []map[string]any, userForm []map[string]any) (mapTaskResult []map[string]any, err error) { buildType := 1 if nodeInput.Config.NodeCode == node.NodeTypeDataConversionModel { buildType = 3 } if !nodeInput.Global.IsDialogue { sessionId = "" } composeResult, err := GetComposeResult(ctx, buildType, nodeInput.Config.ModelConfig.ModelName, nodeInput.Config.PromptContent, skillName, form, userForm, nodeInput.Global.FileUrl, sessionId, nodeInput.Config.Id, nodeInput.Config.Name) if err != nil { return nil, err } modelInfo, err := GetModelInfo(ctx, &flowDto.GetModelInfoReq{ModelName: nodeInput.Config.ModelConfig.ModelName}) if err != nil { return nil, err } mapTaskResult = make([]map[string]any, len(composeResult.Messages.Rounds)) var taskResultMap map[string]any needSequential := false if buildType == 1 { if needSequential { for idx, item := range composeResult.Messages.Rounds { if !g.IsEmpty(taskResultMap) { var set string set, err = sjson.Set(gconv.String(item), modelInfo.Model.LastFrame, gconv.String(taskResultMap[modelInfo.Model.ResponseBody])) if err != nil { return nil, err } item = gconv.Map(set) } var taskResult map[string]any taskResult, err = CreateGatewayTask(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item) if err != nil { return nil, err } if g.IsEmpty(taskResult) { return nil, fmt.Errorf("模型返回结果为空") } if nodeInput.Config.NodeCode == node.NodeTypeVideoModel { ext := GetFileTypeByPath(gconv.String(taskResult[modelInfo.Model.ResponseBody])) if ext == "image" { taskResultMap = taskResult } else { taskResultMap = make(map[string]any) } } else { taskResultMap = make(map[string]any) } mapTaskResult[idx] = taskResult updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult) } } else { taskIdList := make([]string, len(composeResult.Messages.Rounds)) for idx, item := range composeResult.Messages.Rounds { var taskId string taskId, err = createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item) if err != nil { return nil, err } taskIdList[idx] = taskId } var wg sync.WaitGroup errChan := make(chan error, len(taskIdList)) for idx, taskId := range taskIdList { wg.Add(1) go func(idx int, taskId string) { defer wg.Done() var taskResult map[string]any taskResult, err = waitGatewayResult(ctx, taskId) if err != nil { errChan <- err return } mapTaskResult[idx] = taskResult updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult) }(idx, taskId) } wg.Wait() close(errChan) if len(errChan) > 0 { return nil, <-errChan } } } else { for idx, item := range composeResult.Messages.Rounds { mapTaskResult[idx] = item updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, item) } } return mapTaskResult, nil } func BuildNestedJson(body g.Map, mockConfigMap map[string]*entity.FlowNode) g.Map { jsonStr := "{}" for originKey, originItem := range body { bodyItemMap := gconv.Map(originItem) val := bodyItemMap["value"] if v, ok := bodyItemMap["value"]; ok { jsonStr, _ = sjson.Set(jsonStr, originKey, v) } // 判断 value 是不是引用结构(map) if g.NewVar(val).IsMap() { valMap := gconv.Map(val) nodeId := gconv.String(valMap["nodeId"]) fieldName := gconv.String(valMap["field"]) if configValue, ok := mockConfigMap[nodeId]; ok { if !g.IsEmpty(configValue.OutputResult) { for _, v := range configValue.OutputResult { if strings.Contains(v.Field, fieldName) { if configValue.NodeCode == node.NodeTypeDataConversionModel { switch { case g.NewVar(v.Value).IsSlice() || g.NewVar(v.Value).IsMap(): // 核心:自动判断两种结构,精准赋值 vm := gconv.Map(v.Value) // 先判断是否是 单个key包裹的对象(如 {"subtitle_style": {...}}) if len(vm) == 1 { // 遍历取出唯一的 key 和 真实值 for innerKey, innerVal := range vm { // 直接用 innerKey(subtitle_style)赋值 jsonStr, _ = sjson.Set(jsonStr, innerKey, innerVal) } } else { // 直接是对象,用 originKey 赋值 jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value) } default: jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value) } } else { jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value) } } } } if !g.IsEmpty(configValue.FormConfig) { for _, v := range configValue.FormConfig { if v.Field == fieldName { if v.Type == "uploadMultiple" { if g.NewVar(v.FieldConstraint).IsMap() { mapFieldConstraint := gconv.Map(v.FieldConstraint) for key, value := range mapFieldConstraint { if key == "maxFileCount" { if gconv.Int(value) == 1 { // 如果是单文件上传,则替换成字符串重新赋值给v.Value if g.NewVar(v.Value).IsSlice() { sliceVal := gconv.SliceAny(v.Value) if len(sliceVal) > 0 { v.Value = sliceVal[0] } } } } } } } jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value) } } } } } } return gconv.Map(jsonStr) } func VideoConcat(ctx context.Context, videoUrls []string) (r any, err error) { var httpUrl = "media/video/concat/async" headers := make(map[string]string) if r := g.RequestFromCtx(ctx); r != nil { for k, v := range r.Request.Header { if len(v) > 0 { headers[k] = v[0] } } } var callbackUrl = utils.GetCallbackURL(ctx, "/flow/execution/videoCallback") var newBody = flowDto.VideoConcatReq{ VideoUrls: videoUrls, Method: "auto", Upload: true, CallbackUrl: callbackUrl, } res := new(flowDto.VideoConcatRes) err = commonHttp.Post(ctx, httpUrl, headers, &res, newBody) if err != nil { return nil, err } return Wait(ctx, res.TaskId) } func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) { // 使用 GoFrame 客户端(自带超时、追踪、日志等能力) resp, err := g.Client().Get(ctx, fileUrl) if err != nil { return nil, gerror.Wrapf(err, "failed to request url: %s", fileUrl) } defer resp.Close() // 校验状态码 if resp.StatusCode != http.StatusOK { return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, fileUrl) } // 读取全部内容 allBytes, err := io.ReadAll(resp.Body) if err != nil { return nil, gerror.Wrapf(err, "failed to read response body, url: %s", fileUrl) } return allBytes, nil } func Upload(ctx context.Context, req *dto.UploadFileBytesReq) (*dto.UploadFileBytesRes, error) { body := &bytes.Buffer{} writer := multipart.NewWriter(body) part, err := writer.CreateFormFile("file", req.FileName) if err != nil { return nil, err } if _, err = part.Write(req.FileBytes); err != nil { return nil, err } if err = writer.Close(); err != nil { return nil, err } headers := make(map[string]string) headers["Content-Type"] = writer.FormDataContentType() if r := g.RequestFromCtx(ctx); r != nil { if auth := r.Header.Get("Authorization"); auth != "" { headers["Authorization"] = auth } } // 发起上传请求 res := &dto.UploadFileBytesRes{} httpUrl := "oss/file/uploadFile" if err = commonHttp.Post(ctx, httpUrl, headers, res, body.Bytes()); err != nil { return nil, err } g.Log().Infof(ctx, "[Upload] success url=%s size=%d", res.FileURL, res.FileSize) return res, nil } func GetFileTypeByPath(filePath string) string { if filePath == "" { return "" } // 解析 URL,获取真实路径(兼容 http 链接) u, err := url.Parse(filePath) if err == nil { filePath = u.Path } // 获取后缀(小写) ext := filepath.Ext(filePath) ext = strings.ToLower(ext) // 判断类型 switch ext { case ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp": return "image" case ".mp4", ".mov", ".avi", ".flv", ".wmv", ".mkv": return "video" case ".mp3", ".wav", ".m4a", ".flac", ".aac", ".ogg": return "audio" case ".txt", ".md", ".log", ".json", ".xml", ".inc": return "text" case ".html": return "html" case ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx": return "document" default: return "" } } func BuildText(text string) string { // 生成单条HTML var htmlBuilder strings.Builder htmlBuilder.WriteString(`
需要配图:X 张
if text != "" { // 写入清理后的文案 htmlBuilder.WriteString(fmt.Sprintf(`]*>.*?(\d+).*?
`) match := re.FindStringSubmatch(content) if len(match) >= 2 { num, err := strconv.Atoi(match[1]) if err == nil { return num } } return 0 } func ImageTagRegex(html string) string { // 🔥 修复:支持单引号、双引号、空格、换行,100% 删除imageTagRegex := regexp.MustCompile(`
]*>[\s\S]*?
`) return imageTagRegex.ReplaceAllString(html, "") } // StripHtmlTags 去掉所有HTML标签,保留换行和文本结构,并删除配图标记行 func StripHtmlTags(html string) string { // 1. 替换块级标签为换行,保证排版 blockTags := regexp.MustCompile(`?(div|p|h1|h2|h3|h4|h5|h6|li|ul|ol|br|tr|td|th)[^>]*>`) text := blockTags.ReplaceAllString(html, "\n") // 2. 去掉所有剩余的 HTML 标签 allTags := regexp.MustCompile(`<[^>]+>`) text = allTags.ReplaceAllString(text, "") // 4. 清理多余空行(多个换行只保留一个) text = regexp.MustCompile(`\n\s*\n`).ReplaceAllString(text, "\n") // 5. 只去掉首尾空白,中间换行保留 text = strings.TrimSpace(text) return text } // SplitMultiContents 拆分模型返回的多条文案(基于HTML标签分隔) func SplitMultiContents(htmlContent string) []string { var contents []string // 正则匹配