diff --git a/workflow/service/flow/lambda_node_imp.go b/workflow/service/flow/lambda_node_imp.go new file mode 100644 index 0000000..53925be --- /dev/null +++ b/workflow/service/flow/lambda_node_imp.go @@ -0,0 +1,593 @@ +package flow + +import ( + "ai-agent/workflow/consts/flow" + "ai-agent/workflow/consts/node" + flowDao "ai-agent/workflow/dao/flow" + "ai-agent/workflow/model/dto" + flowDto "ai-agent/workflow/model/dto/flow" + "ai-agent/workflow/model/entity" + "context" + "fmt" + "regexp" + "strconv" + "strings" + "time" + + "gitea.com/red-future/common/utils" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/util/gconv" +) + +func getNodeInfo(flowInfo *entity.FlowExecution) (htmlUrl []string, textModelName string, textResultFrom map[string]any, textModelResponse map[string]any, imgModelName string, imgResultFrom map[string]any, imgModelResponse map[string]any) { + // 查询节点中是否包含结果合并节点 + for _, item := range flowInfo.NodeInputParams { + if item.NodeCode == node.NodeTypeMerge { + for _, outputParamsItem := range flowInfo.OutputParams { + outputParamsMap := gconv.Map(outputParamsItem) + for _, mapItem := range outputParamsMap { + if strings.HasSuffix(gconv.String(mapItem), ".html") { + htmlUrl = append(htmlUrl, gconv.String(mapItem)) + } + } + } + } + if item.NodeCode == node.NodeTypeTextModel { + textModelName = item.ModelConfig.ModelName + textModelResponse = item.ModelConfig.ModelResponse + for key, modelFormItem := range item.ModelConfig.ModelForm { + textResultFrom[key] = map[string]any{ + "value": modelFormItem, + } + } + } + if item.NodeCode == node.NodeTypeImageModel { + imgModelName = item.ModelConfig.ModelName + imgModelResponse = item.ModelConfig.ModelResponse + for key, modelFormItem := range item.ModelConfig.ModelForm { + imgResultFrom[key] = map[string]any{ + "value": modelFormItem, + } + } + } + } + + return htmlUrl, textModelName, textResultFrom, textModelResponse, imgModelName, imgResultFrom, imgModelResponse + +} + +func TextImgModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) { + _, textModelName, textResultFrom, textModelResponse, imgModelName, imgResultFrom, imgModelResponse := getNodeInfo(flowInfo) + + resultUserFrom := make(map[string]any) + resultUserFrom["desc"] = req.Desc + + var textNode []node.NodeFormField + textNode, err = TextNode(ctx, req.SessionId, textModelName, req.SkillName, textResultFrom, resultUserFrom, textModelResponse, req.FileUrl) + if err != nil { + return + } + var textContent string + var textUrl string + for _, item := range textNode { + if strings.Contains(item.Field, "text_content") { + textContent = StripHtmlTags(item.Value) + } + if strings.Contains(item.Field, "text_url") { + textUrl = item.Value + } + } + + resultUserFrom["text_content"] = textContent + var imgNode []node.NodeFormField + imgNode, err = ImgNode(ctx, req.SessionId, imgModelName, req.SkillName, imgResultFrom, resultUserFrom, imgModelResponse, req.FileUrl) + if err != nil { + return + } + var imgUrl []string + for _, item := range imgNode { + if strings.Contains(item.Field, "img_url") { + imgUrl = append(imgUrl, item.Value) + } + } + + // 生成单条HTML + htmlContent := BuildHtml(textUrl, imgUrl) + // 上传OSS(每条独立上传) + fileName := fmt.Sprintf("item_%d_%d.html", 0, time.Now().UnixMilli()) + var ossResult *dto.UploadFileBytesRes + ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{ + FileBytes: []byte(htmlContent), + FileName: fileName, + }) + if err != nil { + return + } + fmt.Printf("上传OSS成功:%s", ossResult.FileURL) + + var summaryResult []map[string]interface{} + for _, outputParamsItem := range flowInfo.OutputParams { + mapItem := gconv.Map(outputParamsItem) + for _, mapValue := range mapItem { + if strings.Contains(req.ResultUrl, gconv.String(mapValue)) { + // 生成 毫秒时间戳 作为 KEY + timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10) + item := make(map[string]interface{}) + item[timeKey] = ossResult.FileURL + summaryResult = append(summaryResult, item) + continue + } + summaryResult = append(summaryResult, outputParamsItem) + } + } + if !g.IsEmpty(summaryResult) { + executionReq := flowDto.UpdateFlowExecutionReq{ + Id: flowInfo.Id, + Status: flow.FlowExecutionStatusSuccess.Code(), + OutputParams: summaryResult, + } + _, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq) + } + return +} + +func ImgModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) { + var url string + url, err = utils.GetFileAddressPrefix(ctx) + if err != nil { + return + } + + htmlUrl, _, _, _, imgModelName, imgResultFrom, imgModelResponse := getNodeInfo(flowInfo) + + resultUserFrom := make(map[string]any) + resultUserFrom["desc"] = req.Desc + + var imgNode []node.NodeFormField + imgNode, err = ImgNode(ctx, req.SessionId, imgModelName, req.SkillName, imgResultFrom, resultUserFrom, imgModelResponse, req.FileUrl) + if err != nil { + return + } + + var imgUrl string + for _, item := range imgNode { + if strings.Contains(item.Field, "img_url") { + imgUrl = item.Value + } + } + + var htmlContentUrl string + var oldHtmlUrl string + if !g.IsEmpty(htmlUrl) { + for i, item := range htmlUrl { + var htmlBytes []byte + htmlBytes, err = GetFileBytesFromURL(url + item) + if err != nil { + return + } + htmlContent := string(htmlBytes) + + imgSrcFromHtml := GetAllImgSrcFromHtml(htmlContent) + // 3. 标记是否需要替换 + needReplace := false + for _, imgSrc := range imgSrcFromHtml { + if imgSrc == req.ResultUrl { + needReplace = true + break // 找到一个就可以替换 + } + } + + // 4. 如果匹配到,执行替换(把旧的 req.ResultUrl 替换成 新链接) + if needReplace { + oldHtmlUrl = url + item + htmlContent = ReplaceImgSrc(htmlContent, req.ResultUrl, imgUrl) + // 上传OSS(每条独立上传) + fileName := fmt.Sprintf("item_%d_%d.html", i, time.Now().UnixMilli()) + var ossResult *dto.UploadFileBytesRes + ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{ + FileBytes: []byte(htmlContent), + FileName: fileName, + }) + if err != nil { + return + } + fmt.Printf("上传OSS成功:%s", ossResult.FileURL) + htmlContentUrl = ossResult.FileURL + } + + } + } + + var summaryResult []map[string]interface{} + if !g.IsEmpty(imgUrl) { + for _, outputParamsItem := range flowInfo.OutputParams { + mapItem := gconv.Map(outputParamsItem) + for _, mapValue := range mapItem { + if strings.Contains(oldHtmlUrl, gconv.String(mapValue)) || strings.Contains(req.ResultUrl, gconv.String(mapValue)) { + if strings.Contains(oldHtmlUrl, gconv.String(mapValue)) { + // 生成 毫秒时间戳 作为 KEY + timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10) + item := make(map[string]interface{}) + item[timeKey] = htmlContentUrl + summaryResult = append(summaryResult, item) + } + if strings.Contains(req.ResultUrl, gconv.String(mapValue)) { + // 生成 毫秒时间戳 作为 KEY + timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10) + item := make(map[string]interface{}) + item[timeKey] = imgUrl + summaryResult = append(summaryResult, item) + } + continue + } + summaryResult = append(summaryResult, outputParamsItem) + } + + } + } + if !g.IsEmpty(summaryResult) { + executionReq := flowDto.UpdateFlowExecutionReq{ + Id: flowInfo.Id, + Status: flow.FlowExecutionStatusSuccess.Code(), + OutputParams: summaryResult, + } + _, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq) + } + return +} + +func TextModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) { + var url string + url, err = utils.GetFileAddressPrefix(ctx) + if err != nil { + return + } + htmlUrl, textModelName, textResultFrom, textModelResponse, _, _, _ := getNodeInfo(flowInfo) + + resultUserFrom := make(map[string]any) + resultUserFrom["desc"] = req.Desc + + var textNode []node.NodeFormField + textNode, err = TextNode(ctx, req.SessionId, textModelName, req.SkillName, textResultFrom, resultUserFrom, textModelResponse, req.FileUrl) + if err != nil { + return + } + var textUrl string + for _, item := range textNode { + if strings.Contains(item.Field, "text_url") { + textUrl = item.Value + } + } + + var htmlContentUrl string + var oldHtmlUrl string + if !g.IsEmpty(htmlUrl) { + for i, item := range htmlUrl { + var htmlBytes []byte + htmlBytes, err = GetFileBytesFromURL(url + item) + if err != nil { + return + } + htmlContent := string(htmlBytes) + + // 1) 匹配出 incUrl 的值 + incRegex := regexp.MustCompile(`incUrl\s*=\s*"([^"]+)"`) + match := incRegex.FindStringSubmatch(htmlContent) + // 2) 获取模板里原来的 incUrl + oldIncUrl := "" + if len(match) >= 2 { + oldIncUrl = match[1] // 这是模板里的旧链接 + } + // 3) 对比:不一样才替换 + if oldIncUrl == req.ResultUrl { + oldHtmlUrl = url + item + // 替换成新的链接 + htmlContent = incRegex.ReplaceAllString(htmlContent, fmt.Sprintf(`incUrl = "%s"`, url+textUrl)) + // 上传OSS(每条独立上传) + fileName := fmt.Sprintf("item_%d_%d.html", i, time.Now().UnixMilli()) + var ossResult *dto.UploadFileBytesRes + ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{ + FileBytes: []byte(htmlContent), + FileName: fileName, + }) + if err != nil { + return + } + fmt.Printf("上传OSS成功:%s", ossResult.FileURL) + htmlContentUrl = ossResult.FileURL + } + } + } + + var summaryResult []map[string]interface{} + if !g.IsEmpty(textUrl) { + for _, outputParamsItem := range flowInfo.OutputParams { + mapItem := gconv.Map(outputParamsItem) + for _, mapValue := range mapItem { + if strings.Contains(oldHtmlUrl, gconv.String(mapValue)) || strings.Contains(req.ResultUrl, gconv.String(mapValue)) { + if strings.Contains(oldHtmlUrl, gconv.String(mapValue)) { + // 生成 毫秒时间戳 作为 KEY + timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10) + item := make(map[string]interface{}) + item[timeKey] = htmlContentUrl + summaryResult = append(summaryResult, item) + } + if strings.Contains(req.ResultUrl, gconv.String(mapValue)) { + // 生成 毫秒时间戳 作为 KEY + timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10) + item := make(map[string]interface{}) + item[timeKey] = textUrl + summaryResult = append(summaryResult, item) + } + continue + } + summaryResult = append(summaryResult, outputParamsItem) + } + } + } + if !g.IsEmpty(summaryResult) { + executionReq := flowDto.UpdateFlowExecutionReq{ + Id: flowInfo.Id, + Status: flow.FlowExecutionStatusSuccess.Code(), + OutputParams: summaryResult, + } + _, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq) + } + return +} + +func TextNode(ctx context.Context, sessionId, modelName, skillName string, form, userForm, modelResponse map[string]any, fileUrl []string) ([]node.NodeFormField, error) { + contentStr := "你是专业内容生成助手,请严格按以下规则输出内容:1、输出标准 HTML 片段,不要 Markdown,不要 ``` 符号,不要多余解释,2、整体用
,6、列表使用
需要配图:N 张
N 是这条文案需要的图片数量,只能是数字,不能是其他文字,11、只输出 HTML 结构,不输出任何额外文字" + userForm["prompt"] = contentStr + + mapTaskResult, err := GetModelResult(ctx, modelName, skillName, form, userForm, fileUrl, sessionId, "文案生成") + if err != nil { + return nil, err + } + + resultContent := "" + for key, _ := range modelResponse { + resultContent = gconv.String(mapTaskResult[key]) + } + + // 拆分多条文案 + contentList := SplitMultiContents(resultContent) + + outputRes := make([]node.NodeFormField, 0) + for i, contentItem := range contentList { + outputRes = append(outputRes, node.NodeFormField{ + Field: fmt.Sprintf("text_content_%d", i), + Value: contentItem, + Label: fmt.Sprintf("文案内容_%d", i), + Type: "string", + Expand: ExtractImageCount(contentItem), + }) + + // 1. 构建html文本 + plainText := BuildText(contentItem) + // 2. 上传纯文本到 OSS + textFileName := fmt.Sprintf("ai_text_%d_%d.inc", time.Now().UnixMilli(), i) + var textUrl *dto.UploadFileBytesRes + textUrl, err = Upload(ctx, &dto.UploadFileBytesReq{ + FileBytes: []byte(plainText), + FileName: textFileName, + }) + if err != nil { + return nil, err + } + // 3. 把纯文本地址存入输出 + outputRes = append(outputRes, node.NodeFormField{ + Field: fmt.Sprintf("text_url:%d", i), + Value: textUrl.FileURL, + Label: fmt.Sprintf("文案纯文本_txt_%d", i), + Type: "string", + Expand: ExtractImageCount(contentItem), + }) + } + return outputRes, nil +} + +func ImgNode(ctx context.Context, sessionId, modelName, skillName string, form, userForm, modelResponse map[string]any, fileUrl []string) ([]node.NodeFormField, error) { + + mapTaskResult, err := GetModelResult(ctx, modelName, skillName, form, userForm, fileUrl, sessionId, "图片生成") + if err != nil { + return nil, err + } + + var resultContent []string + for key, _ := range modelResponse { + resultContent = gconv.Strings(mapTaskResult[key]) + } + + var images []string + for _, item := range resultContent { + mapItem := gconv.Map(item) + for _, value := range mapItem { + values, ok := value.(string) + if !ok { + return nil, fmt.Errorf("图片地址类型错误") + } + // 下载官方临时图片 + var imgBytes []byte + imgBytes, err = GetFileBytesFromURL(values) + if err != nil { + return nil, fmt.Errorf("下载图片失败: %w", err) + } + // 构造文件名 + fileName := fmt.Sprintf("ai_image_%d.png", time.Now().UnixMilli()) + // 上传到你的OSS(你项目已有的Upload方法) + var upResp *dto.UploadFileBytesRes + upResp, err = Upload(ctx, &dto.UploadFileBytesReq{ + FileName: fileName, + FileBytes: imgBytes, + }) + if err != nil { + return nil, fmt.Errorf("上传OSS失败: %w", err) + } + images = append(images, upResp.FileURL) + } + } + + var url string + url, err = utils.GetFileAddressPrefix(ctx) + if err != nil { + return nil, err + } + outputRes := make([]node.NodeFormField, 0) + + for i, item := range images { + // 图片:image_0, image_1, image_2... + outputRes = append(outputRes, node.NodeFormField{ + Field: fmt.Sprintf("image_%d", i), + Value: fmt.Sprintf("%s%s", url, item), + Label: fmt.Sprintf("图片_%d", i), + Type: "string", + }) + // 额外存储关联关系 + outputRes = append(outputRes, node.NodeFormField{ + Field: fmt.Sprintf("img_url:%d", i), + Value: fmt.Sprintf("%s%s", url, item), + Label: fmt.Sprintf("图片_img_%d关联文案ID", i), + Type: "string", + }) + } + + return outputRes, nil +} + +func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom, resultUserFrom map[string]any) { + // 1. 直接用你原来的方法(返回两个 map) + inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config) + var outputResult []node.NodeFormField + for _, valueAny := range inputMap { + if field, ok := valueAny.(node.NodeFormField); ok { + outputResult = append(outputResult, field) + } + } + + resultUserFrom = make(map[string]any) + for _, valueAny := range outputMap { + if field, ok := valueAny.(node.NodeFormField); ok { + if !strings.Contains(field.Field, "text_url") && !strings.Contains(field.Field, "img_url") { + if strings.Contains(field.Field, "text_content") { + field.Value = StripHtmlTags(field.Value) + } + resultUserFrom[field.Label] = field + } + } + } + for _, valueAny := range modelMap { + if field, ok := valueAny.(node.NodeFormField); ok { + outputResult = append(outputResult, field) + } + } + if !nodeInput.Global.IsDialogue { + for _, item := range outputResult { + resultUserFrom[item.Label] = item + } + for _, item := range nodeInput.Config.FormConfig { + resultUserFrom[item.Label] = item + } + } + if !g.IsEmpty(nodeInput.Global.Desc) { + resultUserFrom["desc"] = node.NodeFormField{ + Value: nodeInput.Global.Desc, + Field: "desc", + Label: "描述", + Type: "text", + } + } + + resultFrom = make(map[string]any) + for key, item := range nodeInput.Config.ModelConfig.ModelForm { + resultFrom[key] = map[string]any{ + "value": item, + } + } + skillName = nodeInput.Config.SkillName + if g.IsEmpty(nodeInput.Config.SkillName) { + skillName = nodeInput.Global.SkillName + } + + return skillName, resultFrom, resultUserFrom +} + +func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, node *entity.FlowNode) (map[string]any, map[string]any, map[string]any) { + input := make(map[string]any) + output := make(map[string]any) + model := make(map[string]any) + // 1. 有引用 → 取引用节点的字段值 + if len(node.InputSource) > 0 { + for _, source := range node.InputSource { + refNodeID := source.NodeId + isQuoteOutput := source.QuoteOutput + fields := source.Field + + refNode, ok := execInput.ConfigMap[refNodeID] + if !ok { + continue + } + + inputMap := buildInputMap(refNode) + outputMap := mergeOutput(refNode.OutputResult) + modelMap := mergeModel(refNode.ModelConfig) + if isQuoteOutput { + for k, v := range outputMap { + output[k] = v + } + } + if len(fields) > 0 { + // 取指定字段 + for _, f := range fields { + + if v, ok := inputMap[f]; ok { + input[f] = v + } + if v, ok := modelMap[f]; ok { + model[f] = v + } + } + } else { + // 取全部 + for k, v := range inputMap { + input[k] = v + } + for k, v := range modelMap { + model[k] = v + } + } + } + } + return input, output, model +} + +// buildInputMap 从 FormConfig 构造输入map +func buildInputMap(node *entity.FlowNode) map[string]any { + m := make(map[string]any) + for _, item := range node.FormConfig { + m[item.Label] = item + } + return m +} + +// mergeOutput 合并节点输出 []map → 单map +func mergeOutput(output []node.NodeFormField) map[string]any { + m := make(map[string]any) + for _, item := range output { + m[item.Label] = item + } + return m +} + +// mergeOutput 合并节点输出 []map → 单map +func mergeModel(output node.ModelItem) map[string]any { + m := make(map[string]any) + // 遍历 output.ModelForm 里的每一个 key 和原始值 + for key, rawValue := range output.ModelForm { + // 包装成 { "value": 原始值 } + m[key] = map[string]any{ + "value": rawValue, + } + } + return m +}