Files
ai-agent/workflow/service/flow/lambda_node.go

654 lines
19 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package flow
import (
"ai-agent/workflow/consts/flow"
"ai-agent/workflow/consts/node"
"ai-agent/workflow/consts/public"
fileDao "ai-agent/workflow/dao/file"
flowDao "ai-agent/workflow/dao/flow"
"ai-agent/workflow/model/dto"
fileDto "ai-agent/workflow/model/dto/file"
flowDto "ai-agent/workflow/model/dto/flow"
"context"
"fmt"
"strconv"
"strings"
"sync"
"time"
"gitea.redpowerfuture.com/red-future/common/db/gfdb"
"gitea.redpowerfuture.com/red-future/common/utils"
"github.com/gogf/gf/v2/database/gdb"
"github.com/gogf/gf/v2/frame/g"
"github.com/gogf/gf/v2/util/gconv"
)
// StartLambda is the flow's start node. It does no processing and forwards
// its input to the next node unchanged; it never returns an error.
func StartLambda(_ context.Context, input any) (any, error) {
	forwarded := input
	return forwarded, nil
}
// FormLambda is the form node. It currently does no processing and forwards
// its input to the next node unchanged; it never returns an error.
func FormLambda(_ context.Context, input any) (any, error) {
	forwarded := input
	return forwarded, nil
}
// IntentLambda is the intent node. It currently does no processing and
// forwards its input unchanged, so the downstream judge node receives the
// same value; it never returns an error.
func IntentLambda(_ context.Context, input any) (any, error) {
	forwarded := input
	return forwarded, nil
}
// JudgeLambda is the branch-decision node: it asks the chat model which branch
// to follow and returns the chosen branch's target node ID for routing.
//
// The prompt context is a ",label:value" string built from, in order:
//   - the node's Config.FormConfig entries;
//   - every upstream form field returned by GetNodeContextContent (input,
//     output and model maps), skipped when Global.IsDialogue is set;
//   - the flow description (Global.Desc) under the label "描述", if non-empty.
//
// The candidate branches come from Config.Config["branch_ids"] with display
// names from Config.Config["branch_id_name_map"], rendered one "id: name" per
// line.
//
// input must be a *flowDto.NodeExecutionInput. The returned ID is read from the
// first round of the compose result, under the key(s) listed in the chat
// model's ResponseBody, with surrounding whitespace trimmed. An error is
// returned when the model call fails, when no task ID comes back, or when the
// result contains no rounds.
func JudgeLambda(ctx context.Context, input any) (string, error) {
	nodeInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return "", fmt.Errorf("入参类型错误,期望 *flowDto.NodeExecutionInput实际 %T", input)
	}

	// Gather every form field exposed by the upstream context.
	// NOTE(review): these are map iterations, so the order of the fields in
	// the prompt is not deterministic.
	inputMap, outputMap, modelMap := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
	var upstreamFields []node.NodeFormField
	for _, valueAny := range inputMap {
		if field, isField := valueAny.(node.NodeFormField); isField {
			upstreamFields = append(upstreamFields, field)
		}
	}
	for _, valueAny := range outputMap {
		if field, isField := valueAny.(node.NodeFormField); isField {
			upstreamFields = append(upstreamFields, field)
		}
	}
	for _, valueAny := range modelMap {
		if field, isField := valueAny.(node.NodeFormField); isField {
			upstreamFields = append(upstreamFields, field)
		}
	}

	// Build the context with a builder instead of re-Sprintf-ing the whole
	// accumulated string on every step (quadratic). The formatting verbs are
	// unchanged, so the resulting text is identical, leading comma included.
	var contextBuf strings.Builder
	for _, v := range nodeInput.Config.FormConfig {
		fmt.Fprintf(&contextBuf, ",%s:%s", v.Label, v.Value)
	}
	if !nodeInput.Global.IsDialogue {
		for _, v := range upstreamFields {
			fmt.Fprintf(&contextBuf, ",%s:%s", v.Label, v.Value)
		}
	}
	if !g.IsEmpty(nodeInput.Global.Desc) {
		fmt.Fprintf(&contextBuf, ",%s:%s", "描述", nodeInput.Global.Desc)
	}
	contextParts := contextBuf.String()

	// One "id: name" line per configured branch, in branch_ids order.
	configMap := gconv.Map(nodeInput.Config.Config)
	branchIDs := gconv.Strings(configMap["branch_ids"])
	branchNames := gconv.Map(configMap["branch_id_name_map"])
	branchLines := make([]string, 0, len(branchIDs))
	for _, id := range branchIDs {
		branchLines = append(branchLines, id+": "+gconv.String(branchNames[id]))
	}

	chatModel, err := GetIsChatModel(ctx)
	if err != nil {
		return "", err
	}
	composeResult, err := GetComposeResult(ctx, 2, chatModel.Model.ModelName, "", "",
		[]map[string]any{{"prompt": strings.Join(branchLines, "\n")}},
		[]map[string]any{{"prompt": contextParts}},
		nodeInput.Global.FileUrl, nodeInput.Global.SessionId, nodeInput.Config.Id, "判断节点")
	if err != nil {
		return "", err
	}
	if g.IsEmpty(composeResult.TaskId) {
		return "", fmt.Errorf("msg is empty")
	}
	// Indexing Rounds[0] on an empty result used to panic the whole flow.
	if len(composeResult.Messages.Rounds) == 0 {
		return "", fmt.Errorf("判断节点模型返回结果为空")
	}

	// ResponseBody presumably lists the key(s) under which the model answer is
	// stored. Stop at the first non-empty value: previously the last key
	// visited won, and because map iteration order is random an empty key
	// could overwrite the real answer. Whitespace is trimmed so a trailing
	// newline from the model still matches the branch ID.
	round := composeResult.Messages.Rounds[0]
	content := ""
	for key := range chatModel.Model.ResponseBody {
		if content = strings.TrimSpace(gconv.String(round[key])); content != "" {
			break
		}
	}
	g.Log().Infof(ctx, "JudgeLambda路由目标节点ID=%s", content)
	return content, nil
}
// BatchModelLambda runs the text model once per individual input value and
// concatenates the results into Config.OutputResult.
//
// For every user item produced by BuildParam, each field listed in
// Config.InputSource[*].Field that has a non-empty value becomes one request.
// A slice value is expanded so that each element becomes its own request.
// Requests run concurrently through TextNode. If any request fails, the first
// error observed is returned; the remaining requests still run to completion.
//
// On success the outputs are concatenated in request order. The trailing
// ":<n>" segment of every Field and Label is rewritten to the request's
// position, so "text_content:2:0" produced by request 5 becomes
// "text_content:2:5".
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func BatchModelLambda(ctx context.Context, input any) (any, error) {
	nodeInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return nil, fmt.Errorf("入参类型错误")
	}
	skillName, from, userFrom := BuildParam(nodeInput)

	// Explode the user items into one single-field request per value.
	requests := make([]map[string]any, 0)
	for _, userItem := range userFrom {
		m := gconv.Map(userItem)
		for _, source := range nodeInput.Config.InputSource {
			for _, f := range source.Field {
				val := m[f]
				if g.IsEmpty(val) {
					continue
				}
				if g.NewVar(val).IsSlice() {
					for _, item := range gconv.SliceAny(val) {
						requests = append(requests, map[string]any{f: item})
					}
					continue
				}
				requests = append(requests, map[string]any{f: val})
			}
		}
	}

	// Each goroutine writes only its own slot, so the result order matches
	// the request order without locking. The shared error, however, used to
	// be read and written by several goroutines unsynchronized (a data race);
	// sync.Once now records exactly one error safely.
	// NOTE(review): TextNode is called concurrently with the same nodeInput
	// pointer; confirm it does not mutate shared state.
	res := make([][]node.NodeFormField, len(requests))
	var (
		wg       sync.WaitGroup
		firstErr sync.Once
		execErr  error
	)
	for idx, item := range requests {
		wg.Add(1)
		go func(idx int, userItem map[string]any) {
			defer wg.Done()
			output, err := TextNode(ctx, nodeInput, skillName, from, []map[string]any{userItem})
			if err != nil {
				firstErr.Do(func() { execErr = err })
				return
			}
			res[idx] = output
		}(idx, item)
	}
	// wg.Wait establishes happens-before with every goroutine, so reading
	// res and execErr afterwards is race-free.
	wg.Wait()
	if execErr != nil {
		return nil, execErr
	}

	// withIndex replaces everything after the last ':' with n; strings
	// without ':' are returned unchanged.
	withIndex := func(s string, n int) string {
		if cut := strings.LastIndex(s, ":"); cut != -1 {
			return s[:cut+1] + strconv.Itoa(n)
		}
		return s
	}
	var outputRes []node.NodeFormField
	for reqIdx, items := range res {
		for _, item := range items {
			item.Field = withIndex(item.Field, reqIdx)
			item.Label = withIndex(item.Label, reqIdx)
			outputRes = append(outputRes, item)
		}
	}
	nodeInput.Config.OutputResult = outputRes
	return nodeInput, nil
}
// TextModelLambda generates copy text: it builds the model parameters for the
// node with BuildParam, runs TextNode, and stores the produced fields in
// Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func TextModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	skill, form, userForm := BuildParam(execInput)
	fields, err := TextNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}
	execInput.Config.OutputResult = fields
	return execInput, nil
}
// ImageModelLambda generates images: it builds the model parameters for the
// node with BuildParam, runs ImgNode, and stores the produced fields in
// Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func ImageModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	skill, form, userForm := BuildParam(execInput)
	fields, err := ImgNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}
	execInput.Config.OutputResult = fields
	return execInput, nil
}
// AudioModelLambda generates audio: it builds the model parameters for the
// node with BuildParam, runs AudioOptimizeNode, and stores the produced fields
// in Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func AudioModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	skill, form, userForm := BuildParam(execInput)
	fields, err := AudioOptimizeNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}
	execInput.Config.OutputResult = fields
	return execInput, nil
}
// VideoModelLambda generates video clips with VideoOptimizeNode, concatenates
// every clip whose Field contains "content" via VideoConcat, and publishes a
// single output field for the merged video.
//
// When Config.IsSaveFile is set, the output is "video_oss_url:content:0" with
// the raw FileURL from the concat result. Otherwise it is
// "concat_video_url:content:0" with the file-address prefix prepended.
// An error is returned if no clip URL was produced, or if any model, concat,
// decode or prefix lookup step fails.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func VideoModelLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	skill, form, userForm := BuildParam(execInput)
	clips, err := VideoOptimizeNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}

	var clipURLs []string
	for _, clip := range clips {
		if !strings.Contains(clip.Field, "content") {
			continue
		}
		clipURLs = append(clipURLs, gconv.String(clip.Value))
	}
	if len(clipURLs) == 0 {
		return nil, fmt.Errorf("视频合成失败:模型生成视频失败")
	}

	concatResp, err := VideoConcat(ctx, clipURLs)
	if err != nil {
		return nil, err
	}
	callback := new(flowDto.VideoCallbackReq)
	if err = gconv.Struct(concatResp, callback); err != nil {
		return nil, err
	}
	urlPrefix, err := utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return nil, err
	}

	field := node.NodeFormField{
		Field: "concat_video_url:content:0",
		Value: urlPrefix + callback.FileURL,
		Label: "concat_video_url:content:0",
		Type:  "string",
	}
	if execInput.Config.IsSaveFile {
		field = node.NodeFormField{
			Field: "video_oss_url:content:0",
			Value: callback.FileURL,
			Label: "video_oss_url:content:0",
			Type:  "string",
		}
	}
	execInput.Config.OutputResult = []node.NodeFormField{field}
	return execInput, nil
}
// HttpLambda calls the node's configured HTTP(S) endpoint via HttpNode and
// stores the returned fields in Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func HttpLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	fields, err := HttpNode(ctx, execInput)
	if err != nil {
		return nil, err
	}
	execInput.Config.OutputResult = fields
	return execInput, nil
}
// DataConversionLambda runs the data-conversion step: it builds the model
// parameters for the node with BuildParam, runs DataConversionNode, and
// stores the produced fields in Config.OutputResult.
//
// input must be a *flowDto.NodeExecutionInput; the same pointer is returned.
func DataConversionLambda(ctx context.Context, input any) (any, error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("入参类型错误")
	}
	skill, form, userForm := BuildParam(execInput)
	fields, err := DataConversionNode(ctx, execInput, skill, form, userForm)
	if err != nil {
		return nil, err
	}
	execInput.Config.OutputResult = fields
	return execInput, nil
}
// DataMergeLambda is the parameter-merge node. It currently only checks that
// input is a *flowDto.NodeExecutionInput and returns it unchanged.
//
// An earlier implementation was disabled and has been removed from this file;
// see version control history. It did the following:
//   - polled every 500ms (respecting ctx cancellation) until every
//     Config.InputSource node appeared in Global.ExecutedNodes;
//   - failed fast if any of those nodes had a failed status;
//   - appended each source node's OutputResult (from Global.ConfigMap) to this
//     node's Config.OutputResult.
//
// NOTE(review): that code stored &en for a range variable, which aliases a
// single variable before Go 1.22. Fix it if the logic is ever restored.
func DataMergeLambda(ctx context.Context, input any) (res any, err error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("参数合并入参类型错误")
	}
	return execInput, nil
}
// MergeLambda is the aggregation node that turns upstream text and images into
// per-entry HTML documents.
//
// Steps:
//  1. Index every upstream output field (from GetNodeContextContent) by Field.
//  2. Collect "text_content:0", "text_content:1", ... and "img_url:0", ...,
//     stopping at the first index that is missing or has an empty Value.
//  3. Hand images out in order. Each text entry claims gconv.Int(Expand)
//     images, or whatever is left; entries with Expand <= 0 get none.
//  4. Build one entry per text (file-address prefix + text value, plus its
//     images). If there is no text at all, build one entry per image instead.
//  5. Render each entry with BuildHtml and emit it as "item_html_<i>". When
//     Config.IsSaveFile is set, also upload the HTML and emit its URL as
//     "item_html_url_<i>".
//
// The records replace Config.OutputResult. input must be a
// *flowDto.NodeExecutionInput; the same pointer is returned.
func MergeLambda(ctx context.Context, input any) (res any, err error) {
	execInput, isExecInput := input.(*flowDto.NodeExecutionInput)
	if !isExecInput {
		return nil, fmt.Errorf("汇总节点入参类型错误")
	}

	// Step 1: field name -> field.
	fieldsByName := make(map[string]node.NodeFormField)
	_, outputMap, _ := GetNodeContextContent(execInput.Global, execInput.Config)
	for _, raw := range outputMap {
		if f, isField := raw.(node.NodeFormField); isField {
			fieldsByName[f.Field] = f
		}
	}

	// Step 2: consecutive "<prefix>:<i>" fields until a gap or empty value.
	sequence := func(prefix string) []node.NodeFormField {
		var found []node.NodeFormField
		for i := 0; ; i++ {
			f, has := fieldsByName[fmt.Sprintf("%s:%d", prefix, i)]
			if !has || f.Value == "" {
				return found
			}
			found = append(found, f)
		}
	}
	texts := sequence("text_content")
	var images []string
	for _, f := range sequence("img_url") {
		images = append(images, gconv.String(f.Value))
	}

	// Step 3: text index -> the images it claims, assigned in order.
	imagesByText := make(map[int][]string)
	nextImg := 0
	for i := 0; i < len(texts) && nextImg < len(images); i++ {
		want := gconv.Int(texts[i].Expand)
		if want <= 0 {
			continue
		}
		end := len(images)
		if remaining := len(images) - nextImg; want < remaining {
			end = nextImg + want
		}
		imagesByText[i] = append([]string(nil), images[nextImg:end]...)
		nextImg = end
	}

	// Step 4: one entry per text, or one per image when there is no text.
	type htmlEntry struct {
		content string   // text (may be empty)
		images  []string // zero or more image URLs
	}
	prefix, err := utils.GetFileAddressPrefix(ctx)
	if err != nil {
		return nil, err
	}
	var entries []htmlEntry
	if len(texts) == 0 {
		for _, img := range images {
			entries = append(entries, htmlEntry{images: []string{img}})
		}
	} else {
		for i, t := range texts {
			entries = append(entries, htmlEntry{
				content: prefix + gconv.String(t.Value),
				images:  imagesByText[i],
			})
		}
	}

	// Step 5: render, optionally upload, and emit the records.
	var records []node.NodeFormField
	for idx, entry := range entries {
		rendered := BuildHtml(entry.content, entry.images)
		records = append(records, node.NodeFormField{
			Field: fmt.Sprintf("item_html_%d", idx),
			Value: rendered,
			Label: fmt.Sprintf("条目%d HTML", idx+1),
			Type:  "textarea",
		})
		if !execInput.Config.IsSaveFile {
			continue
		}
		uploaded, uploadErr := Upload(ctx, &dto.UploadFileBytesReq{
			FileBytes: []byte(rendered),
			FileName:  fmt.Sprintf("item_%d_%d.html", idx, time.Now().UnixMilli()),
		})
		if uploadErr != nil {
			return nil, uploadErr
		}
		records = append(records, node.NodeFormField{
			Field: fmt.Sprintf("item_html_url_%d", idx),
			Value: uploaded.FileURL,
			Label: fmt.Sprintf("条目%d 地址", idx+1),
			Type:  "text",
		})
	}
	execInput.Config.OutputResult = records
	return execInput, nil
}
// SummaryLambda is the terminal aggregation node.
//
// It scans Config.OutputResult of every node in Global.ExecutedNodes (looked
// up in Global.ConfigMap) for fields whose name contains one of the file/URL
// markers listed below. Each match is recorded as a single-entry map keyed by
// the current Unix-millisecond timestamp.
//
// Then, inside one transaction on public.DbNameBlackDeacon, it:
//   - loads the flow execution for Global.SessionId;
//   - marks execution Global.ExecutionId as successful, with the collected
//     results as its output parameters;
//   - if an execution was found, inserts one temp-file row per value in that
//     execution's OutputParams, prefixed with the file-address prefix.
//
// Any failure rolls the transaction back and is returned. input must be a
// *flowDto.NodeExecutionInput; the same pointer is returned on success.
func SummaryLambda(ctx context.Context, input any) (any, error) {
	execInput, ok := input.(*flowDto.NodeExecutionInput)
	if !ok {
		return nil, fmt.Errorf("汇总节点入参类型错误,实际是 %T", input)
	}

	// Field-name fragments that identify file/URL outputs worth summarising.
	urlFieldMarkers := []string{"http_file_url", "audio_oss_url", "video_oss_url", "item_html_url", "img_oss_url", "text_url"}
	isURLField := func(name string) bool {
		for _, marker := range urlFieldMarkers {
			if strings.Contains(name, marker) {
				return true
			}
		}
		return false
	}

	var summaryResult []map[string]interface{}
	for _, executedNode := range execInput.Global.ExecutedNodes {
		nodeConfig := execInput.Global.ConfigMap[executedNode.NodeId]
		if nodeConfig == nil {
			continue
		}
		for _, field := range nodeConfig.OutputResult {
			if !isURLField(field.Field) {
				continue
			}
			// NOTE(review): entries produced within the same millisecond get
			// the same key; consumers must not assume keys are unique.
			timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
			summaryResult = append(summaryResult, map[string]interface{}{timeKey: field.Value})
		}
	}
	g.Log().Infof(ctx, "结果汇总完成,汇总数据:%+v", summaryResult)

	err := gfdb.DB(ctx, public.DbNameBlackDeacon).Transaction(ctx, func(ctx context.Context, tx gdb.TX) error {
		// NOTE(review): flowInfo is read before the Update below, so its
		// OutputParams is the previously stored value, not summaryResult.
		// Confirm that is the intended source for the temp-file rows.
		flowInfo, err := flowDao.FlowExecutionDao.Get(ctx, &flowDto.GetFlowExecutionReq{
			SessionId: execInput.Global.SessionId,
		})
		if err != nil {
			return err
		}
		executionReq := flowDto.UpdateFlowExecutionReq{
			Id:           execInput.Global.ExecutionId,
			Status:       flow.FlowExecutionStatusSuccess.Code(),
			OutputParams: summaryResult,
		}
		// This error used to be overwritten or dropped, which committed the
		// transaction even though the execution status was never updated.
		if _, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq); err != nil {
			return err
		}
		if flowInfo == nil {
			return nil
		}

		urlPrefix, err := utils.GetFileAddressPrefix(ctx)
		if err != nil {
			return err
		}
		tempFiles := make([]*fileDto.CreateFileTempReq, 0, len(flowInfo.OutputParams))
		for _, param := range flowInfo.OutputParams {
			for _, v := range gconv.Map(param) {
				tempFiles = append(tempFiles, &fileDto.CreateFileTempReq{
					BusinessId: flowInfo.SessionId,
					FileUrl:    urlPrefix + gconv.String(v),
				})
			}
		}
		if len(tempFiles) == 0 {
			return nil
		}
		_, err = fileDao.FileTempDao.BatchInsert(ctx, tempFiles)
		return err
	})
	if err != nil {
		return nil, err
	}
	return execInput, nil
}
// CustomLambda is the placeholder for user-defined nodes. It prints the value
// it receives to stdout, prefixed with "CustomLambda: ", and returns that
// value unchanged; it never returns an error.
func CustomLambda(_ context.Context, input any) (any, error) {
	fmt.Printf("CustomLambda: %v\n", input)
	return input, nil
}