Files
ai-agent/workflow/service/flow/lambda_node_imp.go
2026-06-08 13:39:20 +08:00

600 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package flow
import (
"ai-agent/workflow/consts/flow"
"ai-agent/workflow/consts/node"
flowDao "ai-agent/workflow/dao/flow"
"ai-agent/workflow/model/dto"
flowDto "ai-agent/workflow/model/dto/flow"
"ai-agent/workflow/model/entity"
"context"
"fmt"
"regexp"
"strconv"
"strings"
"time"
"gitea.com/red-future/common/utils"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// getNodeInfo scans the execution's node input parameters and extracts what
// the single-node lambdas need:
//   - htmlUrl: every OutputParams value ending in ".html", collected once and
//     only when the flow contains a result-merge node;
//   - textModelName / textResultFrom / textModelResponse: the text model
//     node's model name, its form (each value wrapped as {"value": v}) and its
//     response mapping;
//   - imgModelName / imgResultFrom / imgModelResponse: the same for the image
//     model node.
//
// If several nodes of the same type exist, the model name and response come
// from the last one, while the form maps accumulate keys from all of them.
func getNodeInfo(flowInfo *entity.FlowExecution) (htmlUrl []string, textModelName string, textResultFrom map[string]any, textModelResponse map[string]any, imgModelName string, imgResultFrom map[string]any, imgModelResponse map[string]any) {
	textResultFrom = make(map[string]any)
	textModelResponse = make(map[string]any)
	imgResultFrom = make(map[string]any)
	imgModelResponse = make(map[string]any)

	hasMergeNode := false
	for _, item := range flowInfo.NodeInputParams {
		switch item.NodeCode {
		case node.NodeTypeMerge:
			hasMergeNode = true
		case node.NodeTypeTextModel:
			textModelName = item.ModelConfig.ModelName
			textModelResponse = item.ModelConfig.ModelResponse
			for key, value := range item.ModelConfig.ModelForm {
				textResultFrom[key] = map[string]any{"value": value}
			}
		case node.NodeTypeImageModel:
			imgModelName = item.ModelConfig.ModelName
			imgModelResponse = item.ModelConfig.ModelResponse
			for key, value := range item.ModelConfig.ModelForm {
				imgResultFrom[key] = map[string]any{"value": value}
			}
		}
	}

	// The merged HTML pages are read from the execution's OutputParams, which
	// does not depend on the merge node itself. Collect them once so that
	// multiple merge nodes do not produce duplicate entries.
	if hasMergeNode {
		for _, outputParamsItem := range flowInfo.OutputParams {
			for _, value := range gconv.Map(outputParamsItem) {
				if s := gconv.String(value); strings.HasSuffix(s, ".html") {
					htmlUrl = append(htmlUrl, s)
				}
			}
		}
	}
	return htmlUrl, textModelName, textResultFrom, textModelResponse, imgModelName, imgResultFrom, imgModelResponse
}
// TextImgModelSingleLambda regenerates a single text+image item of a flow
// execution. It proceeds as follows:
//  1. Run the text model.
//  2. Feed the HTML-stripped text into the image model.
//  3. Render one HTML page with BuildHtml and upload it.
//  4. Rewrite OutputParams: every entry with a non-empty value contained in
//     req.ResultUrl is replaced by the new page URL, keyed by a millisecond
//     timestamp. Every other entry is kept exactly once.
//
// The execution is marked successful (with the rewritten OutputParams) only
// when the resulting list is non-empty. Model, upload and DAO errors are
// returned unchanged.
func TextImgModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) {
	_, textModelName, textResultFrom, textModelResponse, imgModelName, imgResultFrom, imgModelResponse := getNodeInfo(flowInfo)
	resultUserFrom := make(map[string]any)
	resultUserFrom["desc"] = req.Desc

	var textNode []node.NodeFormField
	textNode, err = TextNode(ctx, req.SessionId, textModelName, req.SkillName, textResultFrom, resultUserFrom, textModelResponse, req.FileUrl)
	if err != nil {
		return
	}
	// A single item is expected here; if several are returned the last one wins.
	var textContent, textUrl string
	for _, item := range textNode {
		if strings.Contains(item.Field, "text_content") {
			textContent = StripHtmlTags(item.Value)
		}
		if strings.Contains(item.Field, "text_url") {
			textUrl = item.Value
		}
	}
	resultUserFrom["text_content"] = textContent

	var imgNode []node.NodeFormField
	imgNode, err = ImgNode(ctx, req.SessionId, imgModelName, req.SkillName, imgResultFrom, resultUserFrom, imgModelResponse, req.FileUrl)
	if err != nil {
		return
	}
	var imgUrl []string
	for _, item := range imgNode {
		if strings.Contains(item.Field, "img_url") {
			imgUrl = append(imgUrl, item.Value)
		}
	}

	// Build the single HTML page and upload it as its own file.
	htmlContent := BuildHtml(textUrl, imgUrl)
	fileName := fmt.Sprintf("item_%d_%d.html", 0, time.Now().UnixMilli())
	var ossResult *dto.UploadFileBytesRes
	ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{
		FileBytes: []byte(htmlContent),
		FileName:  fileName,
	})
	if err != nil {
		return
	}
	g.Log().Infof(ctx, "上传OSS成功%s", ossResult.FileURL)

	summaryResult := make([]map[string]interface{}, 0, len(flowInfo.OutputParams))
	for _, outputParamsItem := range flowInfo.OutputParams {
		matched := false
		for _, mapValue := range gconv.Map(outputParamsItem) {
			// Skip empty values: strings.Contains(x, "") is always true.
			if value := gconv.String(mapValue); value != "" && strings.Contains(req.ResultUrl, value) {
				matched = true
				break
			}
		}
		if !matched {
			// Keep the entry once, however many values its map holds.
			summaryResult = append(summaryResult, outputParamsItem)
			continue
		}
		// The replacement entry is keyed by a millisecond timestamp.
		timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
		summaryResult = append(summaryResult, map[string]interface{}{timeKey: ossResult.FileURL})
	}
	if len(summaryResult) > 0 {
		executionReq := flowDto.UpdateFlowExecutionReq{
			Id:           flowInfo.Id,
			Status:       flow.FlowExecutionStatusSuccess.Code(),
			OutputParams: summaryResult,
		}
		_, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq)
	}
	return
}
// ImgModelSingleLambda regenerates a single image of a flow execution and
// re-links it everywhere it is referenced. It proceeds as follows:
//  1. Run the image model.
//  2. For every distinct merged HTML page from getNodeInfo whose
//     GetAllImgSrcFromHtml list contains req.ResultUrl, replace that src
//     with the new image URL via ReplaceImgSrc and upload the page as a new
//     file.
//  3. Rewrite OutputParams. An entry whose value is one of the rewritten
//     pages is replaced by that page's new URL. An entry whose non-empty
//     value is contained in req.ResultUrl is replaced by the new image URL.
//     Every other entry is kept once. Replacement entries are keyed by a
//     millisecond timestamp.
//
// The execution is marked successful only when the resulting list is
// non-empty. If the model produced no img_url field, nothing is uploaded or
// updated and nil is returned. Prefix, download, upload, model and DAO errors
// are returned as-is.
func ImgModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) {
	var url string
	url, err = utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return
	}
	htmlUrl, _, _, _, imgModelName, imgResultFrom, imgModelResponse := getNodeInfo(flowInfo)
	resultUserFrom := make(map[string]any)
	resultUserFrom["desc"] = req.Desc

	var imgNode []node.NodeFormField
	imgNode, err = ImgNode(ctx, req.SessionId, imgModelName, req.SkillName, imgResultFrom, resultUserFrom, imgModelResponse, req.FileUrl)
	if err != nil {
		return
	}
	var imgUrl string
	for _, item := range imgNode {
		if strings.Contains(item.Field, "img_url") {
			imgUrl = item.Value
		}
	}
	// With no new image there is nothing to link to. Previously the pages were
	// still rewritten with an empty src and uploaded, then never recorded.
	if imgUrl == "" {
		return
	}

	// Maps each relative page path to the URL of its rewritten, re-uploaded copy.
	// Every rewritten page is tracked, not only the last one.
	replacedPages := make(map[string]string)
	seen := make(map[string]bool, len(htmlUrl))
	for i, page := range htmlUrl {
		if seen[page] {
			continue
		}
		seen[page] = true

		var htmlBytes []byte
		htmlBytes, err = GetFileBytesFromURL(url + page)
		if err != nil {
			return
		}
		htmlContent := string(htmlBytes)

		referenced := false
		for _, imgSrc := range GetAllImgSrcFromHtml(htmlContent) {
			if imgSrc == req.ResultUrl {
				referenced = true
				break
			}
		}
		if !referenced {
			continue
		}

		// Swap the old image for the new one and upload the page as its own file.
		htmlContent = ReplaceImgSrc(htmlContent, req.ResultUrl, imgUrl)
		fileName := fmt.Sprintf("item_%d_%d.html", i, time.Now().UnixMilli())
		var ossResult *dto.UploadFileBytesRes
		ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{
			FileBytes: []byte(htmlContent),
			FileName:  fileName,
		})
		if err != nil {
			return
		}
		g.Log().Infof(ctx, "上传OSS成功%s", ossResult.FileURL)
		replacedPages[page] = ossResult.FileURL
	}

	var summaryResult []map[string]interface{}
	for _, outputParamsItem := range flowInfo.OutputParams {
		var replacements []string
		for _, mapValue := range gconv.Map(outputParamsItem) {
			value := gconv.String(mapValue)
			// Skip empty values: strings.Contains(x, "") is always true.
			if value == "" {
				continue
			}
			if newPage, ok := replacedPages[value]; ok {
				replacements = append(replacements, newPage)
			}
			if strings.Contains(req.ResultUrl, value) {
				replacements = append(replacements, imgUrl)
			}
		}
		if len(replacements) == 0 {
			// Keep the entry once, however many values its map holds.
			summaryResult = append(summaryResult, outputParamsItem)
			continue
		}
		for _, replacement := range replacements {
			timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
			summaryResult = append(summaryResult, map[string]interface{}{timeKey: replacement})
		}
	}
	if len(summaryResult) > 0 {
		executionReq := flowDto.UpdateFlowExecutionReq{
			Id:           flowInfo.Id,
			Status:       flow.FlowExecutionStatusSuccess.Code(),
			OutputParams: summaryResult,
		}
		_, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq)
	}
	return
}
// incURLAssignRegex matches the `incUrl = "<url>"` assignment inside a merged
// HTML page; submatch 1 is the referenced text-fragment URL. It is compiled
// once at package level rather than on every loop iteration.
var incURLAssignRegex = regexp.MustCompile(`incUrl\s*=\s*"([^"]+)"`)

// TextModelSingleLambda regenerates a single text fragment of a flow
// execution and re-links it everywhere it is referenced. It proceeds as
// follows:
//  1. Run the text model.
//  2. For every distinct merged HTML page from getNodeInfo whose first incUrl
//     assignment equals req.ResultUrl, point the incUrl at the new fragment
//     (file-address prefix + new path) and upload the page as a new file.
//  3. Rewrite OutputParams. An entry whose value is one of the rewritten
//     pages is replaced by that page's new URL. An entry whose non-empty
//     value is contained in req.ResultUrl is replaced by the new fragment
//     path. Every other entry is kept once. Replacement entries are keyed by
//     a millisecond timestamp.
//
// The execution is marked successful only when the resulting list is
// non-empty. If the model produced no text_url field, nothing is uploaded or
// updated and nil is returned. Prefix, download, upload, model and DAO errors
// are returned as-is.
func TextModelSingleLambda(ctx context.Context, req *flowDto.ExecuteReq, flowInfo *entity.FlowExecution) (err error) {
	var url string
	url, err = utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return
	}
	htmlUrl, textModelName, textResultFrom, textModelResponse, _, _, _ := getNodeInfo(flowInfo)
	resultUserFrom := make(map[string]any)
	resultUserFrom["desc"] = req.Desc

	var textNode []node.NodeFormField
	textNode, err = TextNode(ctx, req.SessionId, textModelName, req.SkillName, textResultFrom, resultUserFrom, textModelResponse, req.FileUrl)
	if err != nil {
		return
	}
	var textUrl string
	for _, item := range textNode {
		if strings.Contains(item.Field, "text_url") {
			textUrl = item.Value
		}
	}
	// With no new fragment there is nothing to link to. Previously the pages
	// were still rewritten to the bare prefix and uploaded.
	if textUrl == "" {
		return
	}
	newIncURL := url + textUrl

	// Maps each relative page path to the URL of its rewritten, re-uploaded copy.
	replacedPages := make(map[string]string)
	seen := make(map[string]bool, len(htmlUrl))
	for i, page := range htmlUrl {
		if seen[page] {
			continue
		}
		seen[page] = true

		var htmlBytes []byte
		htmlBytes, err = GetFileBytesFromURL(url + page)
		if err != nil {
			return
		}
		htmlContent := string(htmlBytes)

		// Rewrite only pages whose current incUrl IS the fragment being
		// regenerated. A page without an incUrl assignment never matches.
		match := incURLAssignRegex.FindStringSubmatch(htmlContent)
		if len(match) < 2 || match[1] != req.ResultUrl {
			continue
		}
		// Use a literal replacement: ReplaceAllString would expand "$" sequences in the URL.
		htmlContent = incURLAssignRegex.ReplaceAllLiteralString(htmlContent, fmt.Sprintf(`incUrl = "%s"`, newIncURL))

		// Upload each page as its own file.
		fileName := fmt.Sprintf("item_%d_%d.html", i, time.Now().UnixMilli())
		var ossResult *dto.UploadFileBytesRes
		ossResult, err = Upload(ctx, &dto.UploadFileBytesReq{
			FileBytes: []byte(htmlContent),
			FileName:  fileName,
		})
		if err != nil {
			return
		}
		g.Log().Infof(ctx, "上传OSS成功%s", ossResult.FileURL)
		replacedPages[page] = ossResult.FileURL
	}

	var summaryResult []map[string]interface{}
	for _, outputParamsItem := range flowInfo.OutputParams {
		var replacements []string
		for _, mapValue := range gconv.Map(outputParamsItem) {
			value := gconv.String(mapValue)
			// Skip empty values: strings.Contains(x, "") is always true.
			if value == "" {
				continue
			}
			if newPage, ok := replacedPages[value]; ok {
				replacements = append(replacements, newPage)
			}
			if strings.Contains(req.ResultUrl, value) {
				replacements = append(replacements, textUrl)
			}
		}
		if len(replacements) == 0 {
			// Keep the entry once, however many values its map holds.
			summaryResult = append(summaryResult, outputParamsItem)
			continue
		}
		for _, replacement := range replacements {
			timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
			summaryResult = append(summaryResult, map[string]interface{}{timeKey: replacement})
		}
	}
	if len(summaryResult) > 0 {
		executionReq := flowDto.UpdateFlowExecutionReq{
			Id:           flowInfo.Id,
			Status:       flow.FlowExecutionStatusSuccess.Code(),
			OutputParams: summaryResult,
		}
		_, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq)
	}
	return
}
// TextNode runs the text-generation model and turns its answer into output
// fields.
//
// A fixed formatting prompt is sent as the "prompt" entry of a copy of
// userForm. It asks for HTML fragments, one <div class='content-item'> per
// item, and a leading image-count line. The model answer is taken from the
// first key of modelResponse whose value in the task result is non-empty,
// then split into items with SplitMultiContents. For each item i, two fields
// are emitted:
//   - "text_content_<i>": the raw item;
//   - "text_url:<i>": the OSS path of the item rendered by BuildText and
//     uploaded as ai_text_<ms>_<i>.inc.
//
// Both fields carry ExtractImageCount(item) in Expand. The caller's userForm
// is not modified. Model and upload errors are returned as-is.
func TextNode(ctx context.Context, sessionId, modelName, skillName string, form, userForm, modelResponse map[string]any, fileUrl []string) ([]node.NodeFormField, error) {
	contentStr := "你是专业内容生成助手请严格按以下规则输出内容1、输出标准 HTML 片段,不要 Markdown不要 ``` 符号不要多余解释2、整体用 <div class='report-container'> 包裹3、主标题使用 <h2 class='title'>4、章节标题使用 <h3 class='section-title'>5、正文段落使用 <p class='paragraph'>6、列表使用 <ul class='list'><li>...</li></ul>7、重点内容使用 <strong> 加粗8、段落之间清晰分隔结构规整9、如果生成多条文案每条文案独立用 <div class='content-item' id='content-{序号}'> 包裹序号从1开始10、每条文案内部必须在最上方添加一行固定格式<p class='image-count'>需要配图N 张</p> N 是这条文案需要的图片数量只能是数字不能是其他文字11、只输出 HTML 结构,不输出任何额外文字"

	// Work on a copy so the text-generation prompt does not leak into the
	// caller's map. TextImgModelSingleLambda reuses that map for the image model.
	modelUserForm := make(map[string]any, len(userForm)+1)
	for key, value := range userForm {
		modelUserForm[key] = value
	}
	modelUserForm["prompt"] = contentStr

	mapTaskResult, err := GetModelResult(ctx, modelName, skillName, form, modelUserForm, fileUrl, sessionId, "文案生成")
	if err != nil {
		return nil, err
	}

	// Take the first non-empty response. Map order is random, so the old
	// "last key wins" loop could pick an empty value when several keys exist.
	resultContent := ""
	for key := range modelResponse {
		if content := gconv.String(mapTaskResult[key]); content != "" {
			resultContent = content
			break
		}
	}

	// Split the answer into individual items.
	contentList := SplitMultiContents(resultContent)
	outputRes := make([]node.NodeFormField, 0, 2*len(contentList))
	for i, contentItem := range contentList {
		outputRes = append(outputRes, node.NodeFormField{
			Field:  fmt.Sprintf("text_content_%d", i),
			Value:  contentItem,
			Label:  fmt.Sprintf("文案内容_%d", i),
			Type:   "string",
			Expand: ExtractImageCount(contentItem),
		})

		// Render the item with BuildText and upload it as an .inc fragment.
		plainText := BuildText(contentItem)
		textFileName := fmt.Sprintf("ai_text_%d_%d.inc", time.Now().UnixMilli(), i)
		var textUpload *dto.UploadFileBytesRes
		textUpload, err = Upload(ctx, &dto.UploadFileBytesReq{
			FileBytes: []byte(plainText),
			FileName:  textFileName,
		})
		if err != nil {
			return nil, err
		}

		// Expose the uploaded fragment's path.
		outputRes = append(outputRes, node.NodeFormField{
			Field:  fmt.Sprintf("text_url:%d", i),
			Value:  textUpload.FileURL,
			Label:  fmt.Sprintf("文案纯文本_txt_%d", i),
			Type:   "string",
			Expand: ExtractImageCount(contentItem),
		})
	}
	return outputRes, nil
}
// ImgNode runs the image-generation model, re-hosts every returned image on
// OSS and returns two fields per image, "image_<i>" and "img_url:<i>". Both
// hold the file-address prefix joined with the uploaded OSS path.
//
// The model answer is taken from the first key of modelResponse whose value
// in the task result converts to a non-empty string list. Each list entry is
// converted with gconv.Map, and every value in it must be a string URL; a
// non-string value yields an error. Prefix, model, download and upload errors
// are returned.
func ImgNode(ctx context.Context, sessionId, modelName, skillName string, form, userForm, modelResponse map[string]any, fileUrl []string) ([]node.NodeFormField, error) {
	// Resolve the address prefix first, so a failure here does not leave
	// freshly uploaded, unreferenced images behind.
	url, err := utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return nil, err
	}

	mapTaskResult, err := GetModelResult(ctx, modelName, skillName, form, userForm, fileUrl, sessionId, "图片生成")
	if err != nil {
		return nil, err
	}

	// Take the first non-empty response. Map order is random, so "last key
	// wins" could pick an empty value when several keys exist.
	var resultContent []string
	for key := range modelResponse {
		if values := gconv.Strings(mapTaskResult[key]); len(values) > 0 {
			resultContent = values
			break
		}
	}

	var images []string
	for _, item := range resultContent {
		for _, value := range gconv.Map(item) {
			imageURL, ok := value.(string)
			if !ok {
				return nil, fmt.Errorf("图片地址类型错误")
			}
			// Download the model's (presumably temporary) image...
			var imgBytes []byte
			imgBytes, err = GetFileBytesFromURL(imageURL)
			if err != nil {
				return nil, fmt.Errorf("下载图片失败: %w", err)
			}
			// ...and re-upload it. The index suffix keeps names unique when two
			// images are handled within the same millisecond.
			fileName := fmt.Sprintf("ai_image_%d_%d.png", time.Now().UnixMilli(), len(images))
			var upResp *dto.UploadFileBytesRes
			upResp, err = Upload(ctx, &dto.UploadFileBytesReq{
				FileName:  fileName,
				FileBytes: imgBytes,
			})
			if err != nil {
				return nil, fmt.Errorf("上传OSS失败: %w", err)
			}
			images = append(images, upResp.FileURL)
		}
	}

	outputRes := make([]node.NodeFormField, 0, 2*len(images))
	for i, item := range images {
		fullURL := url + item
		// The image itself: image_0, image_1, ...
		outputRes = append(outputRes, node.NodeFormField{
			Field: fmt.Sprintf("image_%d", i),
			Value: fullURL,
			Label: fmt.Sprintf("图片_%d", i),
			Type:  "string",
		})
		// The same URL under the img_url:<i> key used for linking.
		outputRes = append(outputRes, node.NodeFormField{
			Field: fmt.Sprintf("img_url:%d", i),
			Value: fullURL,
			Label: fmt.Sprintf("图片_img_%d关联文案ID", i),
			Type:  "string",
		})
	}
	return outputRes, nil
}
// BuildParam assembles the model-call parameters for a node execution.
//
// Returns:
//   - skillName: the node's own SkillName, or the global one when the
//     node's is empty.
//   - resultFrom: the node's ModelForm, with each value wrapped as
//     {"value": v}.
//   - resultUserFrom: fields keyed by Label, written in this order (later
//     writes win on equal labels):
//     1. the referenced nodes' output fields, excluding text_url/img_url
//        fields and with text_content values stripped of HTML tags;
//     2. outside dialogue mode only, the referenced nodes' form fields and
//        then the node's own FormConfig;
//     3. a "desc" field, when the global description is non-empty.
func BuildParam(nodeInput *flowDto.NodeExecutionInput) (skillName string, resultFrom, resultUserFrom map[string]any) {
	global, cfg := nodeInput.Global, nodeInput.Config
	inputMap, outputMap, modelMap := GetNodeContextContent(global, cfg)

	resultUserFrom = make(map[string]any)
	for _, raw := range outputMap {
		field, isField := raw.(node.NodeFormField)
		if !isField {
			continue
		}
		if strings.Contains(field.Field, "text_url") || strings.Contains(field.Field, "img_url") {
			continue
		}
		if strings.Contains(field.Field, "text_content") {
			field.Value = StripHtmlTags(field.Value)
		}
		resultUserFrom[field.Label] = field
	}

	if !global.IsDialogue {
		// NOTE(review): modelMap values come from mergeModel and have type
		// map[string]any{"value": ...}, not node.NodeFormField, so they never
		// pass this assertion. Confirm whether model fields were meant to be
		// forwarded here.
		for _, contextMap := range []map[string]any{inputMap, modelMap} {
			for _, raw := range contextMap {
				if field, isField := raw.(node.NodeFormField); isField {
					resultUserFrom[field.Label] = field
				}
			}
		}
		for _, formItem := range cfg.FormConfig {
			resultUserFrom[formItem.Label] = formItem
		}
	}

	if !g.IsEmpty(global.Desc) {
		resultUserFrom["desc"] = node.NodeFormField{
			Value: global.Desc,
			Field: "desc",
			Label: "描述",
			Type:  "text",
		}
	}

	resultFrom = make(map[string]any)
	for key, value := range cfg.ModelConfig.ModelForm {
		resultFrom[key] = map[string]any{"value": value}
	}

	if skillName = cfg.SkillName; g.IsEmpty(skillName) {
		skillName = global.SkillName
	}
	return skillName, resultFrom, resultUserFrom
}
// GetNodeContextContent gathers the context a node receives from the nodes
// it references through InputSource. It handles every source whose NodeId
// maps to a non-nil node in execInput.ConfigMap and returns:
//   - input: the referenced node's FormConfig entries keyed by Label,
//     restricted to source.Field when that list is non-empty;
//   - output: the referenced node's OutputResult keyed by Label, only when
//     source.QuoteOutput is set (never filtered by source.Field);
//   - model: the referenced node's ModelForm entries wrapped as
//     {"value": v}, restricted to source.Field when non-empty.
//
// Later sources overwrite earlier ones on key collisions. A nil execInput or
// flowNode yields three empty maps.
func GetNodeContextContent(execInput *flowDto.FlowExecutionInput, flowNode *entity.FlowNode) (map[string]any, map[string]any, map[string]any) {
	input := make(map[string]any)
	output := make(map[string]any)
	model := make(map[string]any)
	if execInput == nil || flowNode == nil {
		return input, output, model
	}

	for _, source := range flowNode.InputSource {
		refNode, ok := execInput.ConfigMap[source.NodeId]
		// Skip unknown references and nil entries, which the helpers below
		// would dereference.
		if !ok || refNode == nil {
			continue
		}

		if source.QuoteOutput {
			for k, v := range mergeOutput(refNode.OutputResult) {
				output[k] = v
			}
		}

		inputMap := buildInputMap(refNode)
		modelMap := mergeModel(refNode.ModelConfig)
		if len(source.Field) == 0 {
			// No explicit field list: take everything.
			for k, v := range inputMap {
				input[k] = v
			}
			for k, v := range modelMap {
				model[k] = v
			}
			continue
		}
		// Take only the listed fields.
		for _, f := range source.Field {
			if v, found := inputMap[f]; found {
				input[f] = v
			}
			if v, found := modelMap[f]; found {
				model[f] = v
			}
		}
	}
	return input, output, model
}
// buildInputMap indexes a node's FormConfig entries by their Label. When two
// entries share a Label, the later one wins. A nil node yields an empty map.
func buildInputMap(flowNode *entity.FlowNode) map[string]any {
	if flowNode == nil {
		return make(map[string]any)
	}
	m := make(map[string]any, len(flowNode.FormConfig))
	for _, item := range flowNode.FormConfig {
		m[item.Label] = item
	}
	return m
}
// mergeOutput flattens a node's output fields into a single map keyed by
// Label. When two fields share a Label, the later one wins.
func mergeOutput(output []node.NodeFormField) map[string]any {
	byLabel := map[string]any{}
	for idx := range output {
		field := output[idx]
		byLabel[field.Label] = field
	}
	return byLabel
}
// mergeModel flattens a model configuration's ModelForm into a map keyed by
// form field. Each raw value is wrapped as {"value": raw}.
func mergeModel(model node.ModelItem) map[string]any {
	m := make(map[string]any, len(model.ModelForm))
	for key, rawValue := range model.ModelForm {
		m[key] = map[string]any{"value": rawValue}
	}
	return m
}