feat: 新增主动拉取与多类型回调功能
- 新增 ActivePull 实体、DAO、DTO 及 Service,支持主动拉取任务管理 - 新增 ComposeCallback、VideoCallback、HttpNodeCallback 多类型回调接口 - FlowExecution 增加 NodeGroupId 和 TotalTokens 字段,支持节点组追踪与 Token 统计 - ExecutedNodes 结构由字符串列表改为包含执行状态的节点对象列表 - 重构回调通知机制,统一 Notify 函数调用 - 优化输出项类型判断逻辑,新增文件类型标识
This commit is contained in:
@@ -5,8 +5,10 @@ import (
|
||||
"ai-agent/workflow/consts/node"
|
||||
fileDao "ai-agent/workflow/dao/file"
|
||||
flowDao "ai-agent/workflow/dao/flow"
|
||||
nodeDao "ai-agent/workflow/dao/node"
|
||||
fileDto "ai-agent/workflow/model/dto/file"
|
||||
flowDto "ai-agent/workflow/model/dto/flow"
|
||||
nodeDto "ai-agent/workflow/model/dto/node"
|
||||
"ai-agent/workflow/model/entity"
|
||||
"context"
|
||||
"errors"
|
||||
@@ -15,12 +17,14 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/cloudwego/eino/compose"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
@@ -121,22 +125,30 @@ func (s *flowExecutionService) List(ctx context.Context, req *flowDto.ListFlowEx
|
||||
item := &tempItems[idx]
|
||||
val := item.Content
|
||||
suffix := "内容"
|
||||
|
||||
switch {
|
||||
case strings.Contains(val, "img") || strings.Contains(val, "png") || strings.Contains(val, "jpg"):
|
||||
ext := ""
|
||||
ext = GetFileTypeByPath(val)
|
||||
if ext == "image" {
|
||||
suffix = "图片"
|
||||
case strings.Contains(val, "html") || strings.Contains(val, "HTML"):
|
||||
suffix = "HTML"
|
||||
case strings.Contains(val, "inc") || len(val) > 50:
|
||||
}
|
||||
if ext == "video" {
|
||||
suffix = "视频"
|
||||
}
|
||||
if ext == "audio" {
|
||||
suffix = "音频"
|
||||
}
|
||||
if ext == "text" {
|
||||
suffix = "文案"
|
||||
}
|
||||
|
||||
if ext == "html" {
|
||||
suffix = "HTML"
|
||||
}
|
||||
suffixCount[suffix]++
|
||||
item.Type = ext
|
||||
item.Label = fmt.Sprintf("%s_%d", suffix, suffixCount[suffix])
|
||||
}
|
||||
|
||||
// 组装节点
|
||||
node := flowDto.FlowNode{
|
||||
flowNode := flowDto.FlowNode{
|
||||
FlowName: displayFlowName,
|
||||
Id: execution.Id,
|
||||
SessionId: gconv.String(execution.SessionId),
|
||||
@@ -147,7 +159,7 @@ func (s *flowExecutionService) List(ctx context.Context, req *flowDto.ListFlowEx
|
||||
dateMap[createDate] = &[]flowWrap{}
|
||||
}
|
||||
*dateMap[createDate] = append(*dateMap[createDate], flowWrap{
|
||||
flowNode: node,
|
||||
flowNode: flowNode,
|
||||
createdAt: execution.CreatedAt,
|
||||
})
|
||||
}
|
||||
@@ -188,6 +200,12 @@ func (s *flowExecutionService) List(ctx context.Context, req *flowDto.ListFlowEx
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ComposeCallback 提示词回调接口
|
||||
func (s *flowExecutionService) ComposeCallback(ctx context.Context, req *flowDto.ComposeCallbackReq) (err error) {
|
||||
Notify(req.TaskId, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModelCallback 模型回调接口
|
||||
func (s *flowExecutionService) ModelCallback(ctx context.Context, req *flowDto.ModelCallbackReq) (err error) {
|
||||
// 唤醒等待的任务
|
||||
@@ -195,43 +213,19 @@ func (s *flowExecutionService) ModelCallback(ctx context.Context, req *flowDto.M
|
||||
return nil
|
||||
}
|
||||
|
||||
// 全局等待任务回调的工具
|
||||
var (
|
||||
asyncMu sync.Mutex
|
||||
asyncTasks = make(map[string]chan any)
|
||||
)
|
||||
|
||||
// Wait 阻塞等待回调结果
|
||||
// 调用后会一直卡住,直到 Notify 唤醒 或 超时/取消
|
||||
func Wait(ctx context.Context, taskId string) (any, error) {
|
||||
asyncMu.Lock()
|
||||
ch := make(chan any, 1)
|
||||
asyncTasks[taskId] = ch
|
||||
asyncMu.Unlock()
|
||||
|
||||
select {
|
||||
case result := <-ch:
|
||||
return result, nil
|
||||
case <-ctx.Done():
|
||||
asyncMu.Lock()
|
||||
delete(asyncTasks, taskId)
|
||||
asyncMu.Unlock()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
// VideoCallback 视频拼接回调接口
|
||||
func (s *flowExecutionService) VideoCallback(ctx context.Context, req *flowDto.VideoCallbackReq) (err error) {
|
||||
// 唤醒等待的任务
|
||||
Notify(req.TaskId, req)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Notify 回调时调用,唤醒等待的任务
|
||||
func Notify(taskId string, result any) {
|
||||
asyncMu.Lock()
|
||||
defer asyncMu.Unlock()
|
||||
|
||||
ch, exist := asyncTasks[taskId]
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
|
||||
ch <- result
|
||||
delete(asyncTasks, taskId)
|
||||
// HttpNodeCallback http节点回调接口
|
||||
func (s *flowExecutionService) HttpNodeCallback(ctx context.Context) (err error) {
|
||||
r := g.RequestFromCtx(ctx)
|
||||
taskId := r.Get("task_id").String()
|
||||
Notify(taskId, r)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ===================== 核心改造:替换为 sync.Map 存储取消上下文 =====================
|
||||
@@ -298,11 +292,13 @@ func (s *flowExecutionService) Execute(ctx context.Context, req *flowDto.Execute
|
||||
}
|
||||
var executionId int64
|
||||
var isDialogue bool
|
||||
var nodeGroupId = uuid.NewString()
|
||||
if flowInfo == nil {
|
||||
isDialogue = false
|
||||
var r = new(flowDto.CreateFlowExecutionReq)
|
||||
r.FlowUserId = req.FlowId
|
||||
r.FlowName = req.FlowName
|
||||
r.NodeGroupId = nodeGroupId
|
||||
r.TriggerType = flow.FlowExecutionTriggerTypeManual.Code()
|
||||
r.FlowContent = req.FlowContent
|
||||
r.NodeInputParams = req.NodeInputParams
|
||||
@@ -327,9 +323,10 @@ func (s *flowExecutionService) Execute(ctx context.Context, req *flowDto.Execute
|
||||
cancelMap.Store(traceId, cancel)
|
||||
}
|
||||
executionReq := flowDto.UpdateFlowExecutionReq{
|
||||
Id: executionId,
|
||||
Status: flow.FlowExecutionStatusRunning.Code(),
|
||||
TraceId: traceId,
|
||||
Id: executionId,
|
||||
NodeGroupId: nodeGroupId,
|
||||
Status: flow.FlowExecutionStatusRunning.Code(),
|
||||
TraceId: traceId,
|
||||
}
|
||||
_, err = flowDao.FlowExecutionDao.Update(ctx, &executionReq)
|
||||
if err != nil {
|
||||
@@ -352,6 +349,7 @@ func (s *flowExecutionService) Execute(ctx context.Context, req *flowDto.Execute
|
||||
}
|
||||
|
||||
if isDialogue && !g.IsEmpty(flowInfo) && !g.IsEmpty(req.ResultUrl) {
|
||||
req.NodeGroupId = nodeGroupId
|
||||
if strings.HasSuffix(gconv.String(req.ResultUrl), ".inc") {
|
||||
err = TextModelSingleLambda(ctx, req, flowInfo)
|
||||
return
|
||||
@@ -440,6 +438,7 @@ func (s *flowExecutionService) Execute(ctx context.Context, req *flowDto.Execute
|
||||
// ✅【第4步】构建全局执行入参(现在 schemaMap 是有值的!)
|
||||
// =========================================================================
|
||||
execInput := &flowDto.FlowExecutionInput{
|
||||
NodeGroupId: nodeGroupId,
|
||||
IsDialogue: isDialogue,
|
||||
ExecutionId: executionId,
|
||||
ConfigMap: configMap,
|
||||
@@ -476,6 +475,17 @@ func (s *flowExecutionService) Execute(ctx context.Context, req *flowDto.Execute
|
||||
|
||||
// BuildGraphFromFlowContent 根据前端保存的工作流JSON,自动构建执行图
|
||||
func BuildGraphFromFlowContent(ctx context.Context, flowContent *entity.FlowInfo, judge2IntentNodeMap map[string]string, summaryNodeID string) (compose.Runnable[any, any], error) {
|
||||
// 注册自定义合并函数:处理 *flowDto.FlowExecutionInput 类型合并
|
||||
// 由于 ConfigMap 是 map 引用类型,所有并行分支修改已经写入共享内存
|
||||
// 直接返回第一个实例即可,所有修改都已经可见
|
||||
compose.RegisterValuesMergeFunc(func(values []*flowDto.FlowExecutionInput) (*flowDto.FlowExecutionInput, error) {
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
// 返回第一个实例,ConfigMap 是指针,所有修改都已经写入共享数据结构
|
||||
return values[0], nil
|
||||
})
|
||||
|
||||
graph := compose.NewGraph[any, any]()
|
||||
nodeMap := make(map[string]entity.FlowNode)
|
||||
|
||||
@@ -582,7 +592,7 @@ func BuildGraphFromFlowContent(ctx context.Context, flowContent *entity.FlowInfo
|
||||
}
|
||||
_ = graph.AddEdge(summaryNodeID, compose.END)
|
||||
|
||||
return graph.Compile(ctx, compose.WithGraphName("auto_build_workflow"))
|
||||
return graph.Compile(ctx, compose.WithGraphName("auto_build_workflow"), compose.WithNodeTriggerMode(compose.AllPredecessor))
|
||||
}
|
||||
|
||||
// -------------------------- 节点自动注册器(核心分发) --------------------------
|
||||
@@ -606,7 +616,7 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
|
||||
}
|
||||
|
||||
// 获取入参 - 适配切片类型:遍历所有来源节点
|
||||
var realInput any
|
||||
realInput := new(flowDto.NodeExecutionInput)
|
||||
if len(flowNode.InputSource) > 0 { // 改为判断切片长度
|
||||
// 遍历所有指定的来源节点,聚合输出结果
|
||||
for _, inputSource := range flowNode.InputSource { // 遍历切片
|
||||
@@ -621,19 +631,54 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
|
||||
Config: currentConfig,
|
||||
Global: execInput, // ✅ 把【全部节点】的对象直接塞进来
|
||||
}
|
||||
|
||||
// 执行节点
|
||||
output, err := lambda(ctx, realInput)
|
||||
// ✅ 插入节点执行记录,初始状态为运行中
|
||||
startTime := time.Now()
|
||||
nodeExecutionId, err := nodeDao.NodeExecutionDao.Insert(ctx, &nodeDto.CreateNodeExecutionReq{
|
||||
FlowExecutionId: execInput.ExecutionId,
|
||||
NodeId: nodeID,
|
||||
NodeName: flowNode.Name,
|
||||
NodeGroupId: execInput.NodeGroupId,
|
||||
InputParams: realInput,
|
||||
Status: node.NodeExecutionStatusRunning.Code(),
|
||||
})
|
||||
if err != nil {
|
||||
// 记录失败到已执行列表
|
||||
execInput.ExecutedNodes = append(execInput.ExecutedNodes, flowDto.ExecutedNode{
|
||||
NodeId: nodeID,
|
||||
Status: node.NodeExecutionStatusFailed.Code(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
// ✅ 自动把当前节点ID 加入已执行列表
|
||||
execInput.ExecutedNodes = append(execInput.ExecutedNodes, nodeID)
|
||||
|
||||
// 输出存入 FlowNodeConfig
|
||||
if outConfig, ok := output.(*entity.FlowNode); ok {
|
||||
currentConfig.OutputResult = outConfig.OutputResult
|
||||
realInput.NodeExecutionId = nodeExecutionId
|
||||
// 执行节点
|
||||
_, err = lambda(ctx, realInput)
|
||||
durationMs := time.Since(startTime).Milliseconds()
|
||||
updateReq := &nodeDto.UpdateNodeExecutionReq{
|
||||
Id: nodeExecutionId,
|
||||
InputParams: realInput,
|
||||
DurationMs: durationMs,
|
||||
}
|
||||
if err != nil {
|
||||
// 执行失败,更新状态
|
||||
updateReq.Status = node.NodeExecutionStatusFailed.Code()
|
||||
updateReq.ErrorMessage = err.Error()
|
||||
_, _ = nodeDao.NodeExecutionDao.Update(ctx, updateReq)
|
||||
// 记录失败到已执行列表
|
||||
execInput.ExecutedNodes = append(execInput.ExecutedNodes, flowDto.ExecutedNode{
|
||||
NodeId: nodeID,
|
||||
Status: node.NodeExecutionStatusFailed.Code(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行成功,更新状态
|
||||
updateReq.Status = node.NodeExecutionStatusSuccess.Code()
|
||||
_, _ = nodeDao.NodeExecutionDao.Update(ctx, updateReq)
|
||||
// 记录成功到已执行列表
|
||||
execInput.ExecutedNodes = append(execInput.ExecutedNodes, flowDto.ExecutedNode{
|
||||
NodeId: nodeID,
|
||||
Status: node.NodeExecutionStatusSuccess.Code(),
|
||||
})
|
||||
|
||||
// ✅ 关键:返回整个 execInput,让下一个节点继续用!
|
||||
return execInput, nil
|
||||
@@ -654,6 +699,16 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(VideoModelLambda)))
|
||||
case node.NodeTypeAudioModel:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(AudioModelLambda)))
|
||||
case node.NodeTypeBatchModel:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(BatchModelLambda)))
|
||||
case node.NodeTypeDataConversionModel:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(DataConversionLambda)))
|
||||
//case node.NodeTypeSenseOptimizeModel:
|
||||
// _ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(SenseOptimizeModelLambda)))
|
||||
//case node.NodeTypeStoryOptimizeModel:
|
||||
// _ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(StoryOptimizeModelLambda)))
|
||||
//case node.NodeTypeScriptOptimizeModel:
|
||||
// _ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(ScriptOptimizeModelLambda)))
|
||||
case node.NodeTypeCustomNode:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(CustomLambda)))
|
||||
case node.NodeTypeForm:
|
||||
@@ -662,6 +717,10 @@ func registerNodeToGraph(graph *compose.Graph[any, any], flowNode entity.FlowNod
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(IntentLambda)))
|
||||
case node.NodeTypeMerge:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(MergeLambda)))
|
||||
case node.NodeTypeDataMerge:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(DataMergeLambda)))
|
||||
case node.NodeTypeHttp:
|
||||
_ = graph.AddLambdaNode(nodeID, compose.InvokableLambda(wrapLambda(HttpLambda)))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@ package flow
|
||||
|
||||
import (
|
||||
"ai-agent/workflow/consts/flow"
|
||||
"ai-agent/workflow/consts/node"
|
||||
flowDao "ai-agent/workflow/dao/flow"
|
||||
flowDto "ai-agent/workflow/model/dto/flow"
|
||||
"ai-agent/workflow/model/entity"
|
||||
"ai-agent/workflow/service"
|
||||
"context"
|
||||
|
||||
"gitea.com/red-future/common/beans"
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
@@ -18,25 +19,8 @@ var FlowUserService = &flowUserService{}
|
||||
|
||||
type flowUserService struct{}
|
||||
|
||||
// IsAdmin 调用admin-go服务检查是否是管理员
|
||||
func IsAdmin(ctx context.Context) (res bool, err error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
var r = make(map[string]bool)
|
||||
if err = commonHttp.Get(ctx, "admin-go/api/v1/system/user/checkIsSuperAdmin", headers, &r); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return r["isSuperAdmin"], err
|
||||
}
|
||||
|
||||
func (s *flowUserService) Create(ctx context.Context, req *flowDto.CreateFlowUserReq) (res *flowDto.CreateFlowUserRes, err error) {
|
||||
admin, err := IsAdmin(ctx)
|
||||
admin, err := service.UtilService.IsAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -57,7 +41,7 @@ func (s *flowUserService) Create(ctx context.Context, req *flowDto.CreateFlowUse
|
||||
}
|
||||
|
||||
func (s *flowUserService) Update(ctx context.Context, req *flowDto.UpdateFlowUserReq) (err error) {
|
||||
admin, err := IsAdmin(ctx)
|
||||
admin, err := service.UtilService.IsAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -97,6 +81,26 @@ func (s *flowUserService) Update(ctx context.Context, req *flowDto.UpdateFlowUse
|
||||
}
|
||||
|
||||
func ExtractFlowNodeFrom(flowContent *entity.FlowInfo) []*entity.FlowNode {
|
||||
// 构建每个节点的上游节点映射
|
||||
upstreamMap := make(map[string][]string)
|
||||
for _, edge := range flowContent.Edges {
|
||||
upstreamMap[edge.To] = append(upstreamMap[edge.To], edge.From)
|
||||
}
|
||||
|
||||
// 同时更新 flowContent.Nodes 中的 DataMerge 节点
|
||||
for i := range flowContent.Nodes {
|
||||
n := &flowContent.Nodes[i]
|
||||
// 对于 DataMerge 节点,自动根据边关系填充 InputSource
|
||||
if n.NodeCode == node.NodeTypeDataMerge {
|
||||
n.InputSource = nil
|
||||
for _, fromId := range upstreamMap[n.Id] {
|
||||
n.InputSource = append(n.InputSource, entity.FlowNodeInputSource{
|
||||
NodeId: fromId,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var flowNodes []*entity.FlowNode
|
||||
for _, item := range flowContent.Nodes {
|
||||
flowNodes = append(flowNodes, &item)
|
||||
@@ -105,7 +109,7 @@ func ExtractFlowNodeFrom(flowContent *entity.FlowInfo) []*entity.FlowNode {
|
||||
}
|
||||
|
||||
func (s *flowUserService) Delete(ctx context.Context, req *flowDto.DeleteFlowUserReq) (err error) {
|
||||
admin, err := IsAdmin(ctx)
|
||||
admin, err := service.UtilService.IsAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -146,7 +150,7 @@ func (s *flowUserService) Get(ctx context.Context, req *flowDto.GetFlowUserReq)
|
||||
}
|
||||
|
||||
func (s *flowUserService) List(ctx context.Context, req *flowDto.ListFlowUserReq) (res *flowDto.ListFlowRes, err error) {
|
||||
admin, err := IsAdmin(ctx)
|
||||
admin, err := service.UtilService.IsAdmin(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitea.com/red-future/common/db/gfdb"
|
||||
@@ -58,7 +59,6 @@ func JudgeLambda(ctx context.Context, input any) (string, error) {
|
||||
outputResult = append(outputResult, field)
|
||||
}
|
||||
}
|
||||
|
||||
contextParts := ""
|
||||
for _, v := range nodeInput.Config.FormConfig {
|
||||
contextParts = fmt.Sprintf("%s,%s:%s", contextParts, v.Label, v.Value)
|
||||
@@ -68,51 +68,128 @@ func JudgeLambda(ctx context.Context, input any) (string, error) {
|
||||
contextParts = fmt.Sprintf("%s,%s:%s", contextParts, v.Label, v.Value)
|
||||
}
|
||||
}
|
||||
|
||||
if !g.IsEmpty(nodeInput.Global.Desc) {
|
||||
contextParts = fmt.Sprintf("%s,%s:%s", contextParts, "描述", nodeInput.Global.Desc)
|
||||
}
|
||||
configMap := gconv.Map(nodeInput.Config.Config)
|
||||
ids := gconv.Strings(configMap["branch_ids"])
|
||||
branchIdNameMap := gconv.Map(configMap["branch_id_name_map"])
|
||||
|
||||
// 【重构】构建提示词:展示ID和对应的名称
|
||||
var branchIdNameLines []string
|
||||
for _, id := range ids {
|
||||
name := gconv.String(branchIdNameMap[id])
|
||||
branchIdNameLines = append(branchIdNameLines, fmt.Sprintf("%s: %s", id, name))
|
||||
}
|
||||
|
||||
getIsChatModel, err := GetIsChatModel(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req := flowDto.ComposeMessagesReq{
|
||||
BuildType: 2,
|
||||
ModelName: getIsChatModel.ModelName,
|
||||
SkillName: "",
|
||||
Cause: "判断节点",
|
||||
Form: map[string]any{"prompt": strings.Join(branchIdNameLines, "\n")},
|
||||
UserForm: map[string]any{"prompt": contextParts},
|
||||
UserFiles: nodeInput.Global.FileUrl,
|
||||
SessionId: nodeInput.Global.SessionId,
|
||||
}
|
||||
msg, err := ComposeMessages(ctx, &req)
|
||||
composeResult, err := GetComposeResult(ctx, 2, getIsChatModel.Model.ModelName, "", "", []map[string]any{{"prompt": strings.Join(branchIdNameLines, "\n")}}, []map[string]any{{"prompt": contextParts}}, nodeInput.Global.FileUrl, nodeInput.Global.SessionId, nodeInput.Config.Id, "判断节点")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if g.IsEmpty(msg.Messages) {
|
||||
if g.IsEmpty(composeResult.TaskId) {
|
||||
return "", fmt.Errorf("msg is empty")
|
||||
}
|
||||
|
||||
content := ""
|
||||
for key, _ := range getIsChatModel.ResponseBody {
|
||||
content = gconv.String(msg.Messages[key])
|
||||
for key, _ := range getIsChatModel.Model.ResponseBody {
|
||||
content = gconv.String(composeResult.Messages.Rounds[0][key])
|
||||
}
|
||||
fmt.Printf("JudgeLambda路由:目标节点ID=%s\n", gconv.String(content))
|
||||
return content, nil
|
||||
}
|
||||
|
||||
func BatchModelLambda(ctx context.Context, input any) (any, error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
reqMap := make([]map[string]any, 0)
|
||||
for _, userItem := range userFrom {
|
||||
m := gconv.Map(userItem)
|
||||
for _, i := range nodeInput.Config.InputSource {
|
||||
for _, f := range i.Field {
|
||||
val := m[f]
|
||||
if !g.IsEmpty(val) {
|
||||
if g.NewVar(val).IsSlice() {
|
||||
slice := gconv.SliceAny(val)
|
||||
for _, item := range slice {
|
||||
reqMap = append(reqMap, map[string]any{f: item})
|
||||
}
|
||||
} else {
|
||||
reqMap = append(reqMap, map[string]any{f: val})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 结果按索引存放,保证顺序
|
||||
res := make([][]node.NodeFormField, len(reqMap))
|
||||
var wg sync.WaitGroup
|
||||
// 用一个通道标记是否完成
|
||||
done := make(chan struct{})
|
||||
// 错误只存一个
|
||||
var execErr error
|
||||
|
||||
// 并发执行
|
||||
for idx, item := range reqMap {
|
||||
wg.Add(1)
|
||||
go func(idx int, userItem map[string]any) {
|
||||
defer wg.Done()
|
||||
|
||||
singleUserFrom := []map[string]any{userItem}
|
||||
output, err := TextNode(ctx, nodeInput, skillName, from, singleUserFrom)
|
||||
if err != nil {
|
||||
// 并发安全赋值错误
|
||||
if execErr == nil {
|
||||
execErr = err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 直接按原索引写,顺序绝对正确
|
||||
res[idx] = output
|
||||
}(idx, item)
|
||||
}
|
||||
|
||||
fmt.Printf("JudgeLambda路由:目标节点ID=%s\n", gconv.String(content))
|
||||
// 后台等待所有协程完成,然后关闭 done 通道
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
return content, nil
|
||||
// 等待全部完成
|
||||
<-done
|
||||
|
||||
// 如果有错误,直接返回
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
}
|
||||
|
||||
// 全局自增 i
|
||||
var globalIndex int
|
||||
var outputRes []node.NodeFormField
|
||||
for _, items := range res {
|
||||
for _, item := range items {
|
||||
// 1. 拿到原来的 Field:例如 "text_content:2:0"
|
||||
oldField := item.Field
|
||||
// 2. 找到最后一个 : 的位置
|
||||
if idx := strings.LastIndex(oldField, ":"); idx != -1 {
|
||||
// 3. 截断前面部分,拼接上新的 globalIndex
|
||||
item.Field = oldField[:idx+1] + fmt.Sprint(globalIndex)
|
||||
}
|
||||
// Label 同理
|
||||
oldLabel := item.Label
|
||||
if idx := strings.LastIndex(oldLabel, ":"); idx != -1 {
|
||||
item.Label = oldLabel[:idx+1] + fmt.Sprint(globalIndex)
|
||||
}
|
||||
outputRes = append(outputRes, item)
|
||||
}
|
||||
globalIndex++
|
||||
}
|
||||
|
||||
nodeInput.Config.OutputResult = outputRes
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
// TextModelLambda 构建文案
|
||||
@@ -122,7 +199,7 @@ func TextModelLambda(ctx context.Context, input any) (any, error) {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
outputRes, err := TextNode(ctx, nodeInput.Global.SessionId, nodeInput.Config.ModelConfig.ModelName, skillName, from, userFrom, nodeInput.Config.ModelConfig.ModelResponse, nodeInput.Global.FileUrl)
|
||||
outputRes, err := TextNode(ctx, nodeInput, skillName, from, userFrom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -137,7 +214,7 @@ func ImageModelLambda(ctx context.Context, input any) (any, error) {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
outputRes, err := ImgNode(ctx, nodeInput.Global.SessionId, nodeInput.Config.ModelConfig.ModelName, skillName, from, userFrom, nodeInput.Config.ModelConfig.ModelResponse, nodeInput.Global.FileUrl)
|
||||
outputRes, err := ImgNode(ctx, nodeInput, skillName, from, userFrom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -145,7 +222,213 @@ func ImageModelLambda(ctx context.Context, input any) (any, error) {
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
// AudioModelLambda 构建音频
|
||||
func AudioModelLambda(ctx context.Context, input any) (any, error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
outputRes, err := AudioOptimizeNode(ctx, nodeInput, skillName, from, userFrom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodeInput.Config.OutputResult = outputRes
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
// VideoModelLambda 构建视频
|
||||
func VideoModelLambda(ctx context.Context, input any) (any, error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
res, err := VideoOptimizeNode(ctx, nodeInput, skillName, from, userFrom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
videoURL := make([]string, 0)
|
||||
for _, v := range res {
|
||||
if strings.Contains(v.Field, "content") {
|
||||
videoURL = append(videoURL, gconv.String(v.Value))
|
||||
}
|
||||
}
|
||||
if g.IsEmpty(videoURL) {
|
||||
return nil, fmt.Errorf("视频合成失败:模型生成视频失败")
|
||||
}
|
||||
waitRes, err := VideoConcat(ctx, videoURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := new(flowDto.VideoCallbackReq)
|
||||
if err = gconv.Struct(waitRes, msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
urlPrefix, err := utils.GetFileAddressPrefix(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputRes := make([]node.NodeFormField, 0)
|
||||
if nodeInput.Config.IsSaveFile {
|
||||
outputRes = append(outputRes, node.NodeFormField{
|
||||
Field: fmt.Sprintf("video_oss_url:content:%d", 0),
|
||||
Value: msg.FileURL,
|
||||
Label: fmt.Sprintf("video_oss_url:content:%d", 0),
|
||||
Type: "string",
|
||||
})
|
||||
} else {
|
||||
outputRes = append(outputRes, node.NodeFormField{
|
||||
Field: fmt.Sprintf("concat_video_url:content:%d", 0),
|
||||
Value: urlPrefix + msg.FileURL,
|
||||
Label: fmt.Sprintf("concat_video_url:content:%d", 0),
|
||||
Type: "string",
|
||||
})
|
||||
}
|
||||
nodeInput.Config.OutputResult = outputRes
|
||||
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
// HttpLambda 构建HTTP(S)接口
|
||||
func HttpLambda(ctx context.Context, input any) (any, error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
outputRes := make([]node.NodeFormField, 0)
|
||||
var err error
|
||||
outputRes, err = HttpNode(ctx, nodeInput)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodeInput.Config.OutputResult = outputRes
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
// DataConversionLambda 构建数据转换
|
||||
func DataConversionLambda(ctx context.Context, input any) (any, error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("入参类型错误")
|
||||
}
|
||||
skillName, from, userFrom := BuildParam(nodeInput)
|
||||
outputRes, err := DataConversionNode(ctx, nodeInput, skillName, from, userFrom)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nodeInput.Config.OutputResult = outputRes
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
func DataMergeLambda(ctx context.Context, input any) (res any, err error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("参数合并入参类型错误")
|
||||
}
|
||||
|
||||
// var nodeIds []string
|
||||
// for _, item := range nodeInput.Config.InputSource {
|
||||
// nodeIds = append(nodeIds, item.NodeId)
|
||||
// }
|
||||
//
|
||||
// // 检查是否所有输入节点都执行完成,并且检查是否有节点失败
|
||||
// checkAllExecuted := func() (allExecuted bool, hasFailed bool, failedNode string) {
|
||||
// executedCount := 0
|
||||
// for _, executedNode := range nodeInput.Global.ExecutedNodes {
|
||||
// // 检查是否是我们需要的输入节点,并且它失败了
|
||||
// for _, targetId := range nodeIds {
|
||||
// if executedNode.NodeId == targetId {
|
||||
// if executedNode.Status == node.NodeExecutionStatusFailed.Code() {
|
||||
// return false, true, targetId
|
||||
// }
|
||||
// executedCount++
|
||||
// break
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// return executedCount == len(nodeIds), false, ""
|
||||
// }
|
||||
//
|
||||
// // 初次检查
|
||||
// allExecuted, hasFailed, failedNode := checkAllExecuted()
|
||||
// if hasFailed {
|
||||
// return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
|
||||
// }
|
||||
//
|
||||
// // 如果不是全部都已执行,阻塞等待直到全部完成、上下文取消或有节点失败
|
||||
// if !allExecuted {
|
||||
// // 轮询检查,每500ms检查一次,依赖ctx超时控制
|
||||
// ticker := time.NewTicker(500 * time.Millisecond)
|
||||
// defer ticker.Stop()
|
||||
//
|
||||
// for {
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// // 如果上下文已经取消,说明已有节点报错,直接退出
|
||||
// return nil, ctx.Err()
|
||||
// case <-ticker.C:
|
||||
// // 重新检查所有节点
|
||||
// allExecuted, hasFailed, failedNode := checkAllExecuted()
|
||||
// if hasFailed {
|
||||
// // 有一个输入节点失败,直接退出
|
||||
// return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
|
||||
// }
|
||||
// if allExecuted {
|
||||
// // 全部执行完成,退出循环继续执行
|
||||
// goto allDone
|
||||
// }
|
||||
//
|
||||
// // 再次检查上下文是否已经取消,如果已经取消则立即退出
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//allDone:
|
||||
//
|
||||
// // 最终检查:所有输入节点都成功了吗
|
||||
// _, hasFailed, failedNode = checkAllExecuted()
|
||||
// if hasFailed {
|
||||
// // 有一个输入节点失败,直接退出
|
||||
// return nil, fmt.Errorf("输入节点[%s]执行失败", failedNode)
|
||||
// }
|
||||
//
|
||||
// // 构建已执行节点ID的map,方便合并时查找
|
||||
// executedMap := make(map[string]*flowDto.ExecutedNode, len(nodeInput.Global.ExecutedNodes))
|
||||
// for _, en := range nodeInput.Global.ExecutedNodes {
|
||||
// executedMap[en.NodeId] = &en
|
||||
// }
|
||||
//
|
||||
// // 合并所有输入源节点的输出结果
|
||||
// for _, inputSource := range nodeInput.Config.InputSource {
|
||||
// // 每次循环都检查上下文是否已取消,提前退出
|
||||
// select {
|
||||
// case <-ctx.Done():
|
||||
// return nil, ctx.Err()
|
||||
// default:
|
||||
// }
|
||||
// // 再次检查该节点是否失败
|
||||
// if en, ok := executedMap[inputSource.NodeId]; ok && en.Status == node.NodeExecutionStatusFailed.Code() {
|
||||
// return nil, fmt.Errorf("输入节点[%s]执行失败", inputSource.NodeId)
|
||||
// }
|
||||
// sourceNodeConfig := nodeInput.Global.ConfigMap[inputSource.NodeId]
|
||||
// if sourceNodeConfig != nil && len(sourceNodeConfig.OutputResult) > 0 {
|
||||
// nodeInput.Config.OutputResult = append(nodeInput.Config.OutputResult, sourceNodeConfig.OutputResult...)
|
||||
// }
|
||||
// }
|
||||
|
||||
return nodeInput, nil
|
||||
}
|
||||
|
||||
func MergeLambda(ctx context.Context, input any) (res any, err error) {
|
||||
nodeInput, ok := input.(*flowDto.NodeExecutionInput)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("汇总节点入参类型错误")
|
||||
@@ -155,7 +438,8 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
dataMap := make(map[string]node.NodeFormField)
|
||||
_, outputMap, _ := GetNodeContextContent(nodeInput.Global, nodeInput.Config)
|
||||
for _, valueAny := range outputMap {
|
||||
if field, ok := valueAny.(node.NodeFormField); ok {
|
||||
field := node.NodeFormField{}
|
||||
if field, ok = valueAny.(node.NodeFormField); ok {
|
||||
dataMap[field.Field] = field
|
||||
}
|
||||
}
|
||||
@@ -163,7 +447,7 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
// 2. 提取所有文案:text_content_0,1,2...
|
||||
var contents []node.NodeFormField
|
||||
for i := 0; ; i++ {
|
||||
key := fmt.Sprintf("text_url:%d", i)
|
||||
key := fmt.Sprintf("text_content:%d", i)
|
||||
val, has := dataMap[key]
|
||||
if !has || val.Value == "" {
|
||||
break
|
||||
@@ -179,7 +463,7 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
if !has || val.Value == "" {
|
||||
break
|
||||
}
|
||||
images = append(images, val.Value)
|
||||
images = append(images, gconv.String(val.Value))
|
||||
}
|
||||
|
||||
// 4. 🔥 核心算法:图片按顺序连续归属给每条文案
|
||||
@@ -232,8 +516,8 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
if len(contents) > 0 {
|
||||
for i, val := range contents {
|
||||
item := Item{
|
||||
Content: url + val.Value, // 文案
|
||||
Images: textImgMap[i], // 自动绑定该条目的图片(没有则为空切片)
|
||||
Content: url + gconv.String(val.Value), // 文案
|
||||
Images: textImgMap[i], // 自动绑定该条目的图片(没有则为空切片)
|
||||
}
|
||||
allItems = append(allItems, item)
|
||||
}
|
||||
@@ -254,24 +538,8 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
|
||||
// 遍历所有【独立图文条目】 → 每条生成独立HTML、独立上传OSS、独立输出记录
|
||||
for idx, item := range allItems {
|
||||
// item 结构包含:Content(string) + Images([]string)
|
||||
// 支持任意来源:文生图、图生文、单独文、单独图、文图合并
|
||||
|
||||
// 生成单条HTML
|
||||
htmlContent := BuildHtml(item.Content, item.Images)
|
||||
|
||||
// 上传OSS(每条独立上传)
|
||||
fileName := fmt.Sprintf("item_%d_%d.html", idx, time.Now().UnixMilli())
|
||||
ossResult, err := Upload(ctx, &dto.UploadFileBytesReq{
|
||||
FileBytes: []byte(htmlContent),
|
||||
FileName: fileName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 拼接成一条输出记录
|
||||
// 每条记录包含:HTML内容 + 访问URL + 文案 + 图片列表
|
||||
outputRecords = append(outputRecords,
|
||||
node.NodeFormField{
|
||||
Field: fmt.Sprintf("item_html_%d", idx),
|
||||
@@ -279,25 +547,26 @@ func MergeLambda(ctx context.Context, input any) (any, error) {
|
||||
Label: fmt.Sprintf("条目%d HTML", idx+1),
|
||||
Type: "textarea",
|
||||
},
|
||||
node.NodeFormField{
|
||||
Field: fmt.Sprintf("item_html_url_%d", idx),
|
||||
Value: ossResult.FileURL,
|
||||
Label: fmt.Sprintf("条目%d 地址", idx+1),
|
||||
Type: "text",
|
||||
},
|
||||
node.NodeFormField{
|
||||
Field: fmt.Sprintf("item_txt_url_%d", idx),
|
||||
Value: item.Content,
|
||||
Label: fmt.Sprintf("条目%d 文案", idx+1),
|
||||
Type: "text",
|
||||
},
|
||||
node.NodeFormField{
|
||||
Field: fmt.Sprintf("item_image_url_%d", idx),
|
||||
Value: strings.Join(item.Images, ","),
|
||||
Label: fmt.Sprintf("条目%d 图片", idx+1),
|
||||
Type: "text",
|
||||
},
|
||||
)
|
||||
if nodeInput.Config.IsSaveFile {
|
||||
// 上传OSS(每条独立上传)
|
||||
fileName := fmt.Sprintf("item_%d_%d.html", idx, time.Now().UnixMilli())
|
||||
ossResult, err := Upload(ctx, &dto.UploadFileBytesReq{
|
||||
FileBytes: []byte(htmlContent),
|
||||
FileName: fileName,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outputRecords = append(outputRecords,
|
||||
node.NodeFormField{
|
||||
Field: fmt.Sprintf("item_html_url_%d", idx),
|
||||
Value: ossResult.FileURL,
|
||||
Label: fmt.Sprintf("条目%d 地址", idx+1),
|
||||
Type: "text",
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 最终输出多条记录
|
||||
@@ -313,11 +582,12 @@ func SummaryLambda(ctx context.Context, input any) (any, error) {
|
||||
|
||||
// 聚合所有已执行节点的输出结果
|
||||
var summaryResult []map[string]interface{}
|
||||
for _, nodeID := range execInput.Global.ExecutedNodes {
|
||||
for _, executedNode := range execInput.Global.ExecutedNodes {
|
||||
nodeID := executedNode.NodeId
|
||||
nodeConfig := execInput.Global.ConfigMap[nodeID]
|
||||
if nodeConfig != nil && len(nodeConfig.OutputResult) > 0 {
|
||||
for _, field := range nodeConfig.OutputResult {
|
||||
if strings.Contains(field.Field, "item_html_url") || strings.Contains(field.Field, "img_url") || strings.Contains(field.Field, "text_url") {
|
||||
if strings.Contains(field.Field, "http_file_url") || strings.Contains(field.Field, "audio_oss_url") || strings.Contains(field.Field, "video_oss_url") || strings.Contains(field.Field, "item_html_url") || strings.Contains(field.Field, "img_oss_url") || strings.Contains(field.Field, "text_url") {
|
||||
// 生成 毫秒时间戳 作为 KEY
|
||||
timeKey := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
item := make(map[string]interface{})
|
||||
@@ -376,18 +646,6 @@ func SummaryLambda(ctx context.Context, input any) (any, error) {
|
||||
return execInput, err
|
||||
}
|
||||
|
||||
// VideoModelLambda 构建视频
|
||||
func VideoModelLambda(ctx context.Context, input any) (any, error) {
|
||||
fmt.Println("VideoModelLambda:", input)
|
||||
return input, nil
|
||||
}
|
||||
|
||||
// AudioModelLambda 构建音频
|
||||
func AudioModelLambda(ctx context.Context, input any) (any, error) {
|
||||
fmt.Println("AudioModelLambda:", input)
|
||||
return input, nil
|
||||
}
|
||||
|
||||
// CustomLambda 构建自定义
|
||||
func CustomLambda(ctx context.Context, input any) (any, error) {
|
||||
fmt.Println("CustomLambda:", input)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +1,74 @@
|
||||
package flow
|
||||
|
||||
import (
|
||||
"ai-agent/workflow/consts/node"
|
||||
nodeDao "ai-agent/workflow/dao/node"
|
||||
"ai-agent/workflow/model/dto"
|
||||
flowDto "ai-agent/workflow/model/dto/flow"
|
||||
nodeDto "ai-agent/workflow/model/dto/node"
|
||||
"ai-agent/workflow/model/entity"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
commonHttp "gitea.com/red-future/common/http"
|
||||
"gitea.com/red-future/common/utils"
|
||||
"github.com/gogf/gf/v2/errors/gerror"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// 全局等待任务回调的工具
|
||||
var (
|
||||
asyncMu sync.Mutex
|
||||
asyncTasks = make(map[string]chan any)
|
||||
)
|
||||
|
||||
// Wait 阻塞等待回调结果
|
||||
// 调用后会一直卡住,直到 Notify 唤醒 或 超时/取消
|
||||
func Wait(ctx context.Context, taskId string) (any, error) {
|
||||
asyncMu.Lock()
|
||||
ch := make(chan any, 1)
|
||||
asyncTasks[taskId] = ch
|
||||
asyncMu.Unlock()
|
||||
|
||||
defer close(ch)
|
||||
for {
|
||||
select {
|
||||
case result := <-ch:
|
||||
return result, nil
|
||||
case <-ctx.Done():
|
||||
asyncMu.Lock()
|
||||
delete(asyncTasks, taskId)
|
||||
asyncMu.Unlock()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify 回调时调用,唤醒等待的任务
|
||||
func Notify(taskId string, result any) {
|
||||
asyncMu.Lock()
|
||||
defer asyncMu.Unlock()
|
||||
|
||||
ch, exist := asyncTasks[taskId]
|
||||
if !exist {
|
||||
return
|
||||
}
|
||||
ch <- result
|
||||
delete(asyncTasks, taskId)
|
||||
}
|
||||
|
||||
func GetIsChatModel(ctx context.Context) (res *flowDto.GetIsChatModelRes, err error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
@@ -33,7 +83,7 @@ func GetIsChatModel(ctx context.Context) (res *flowDto.GetIsChatModelRes, err er
|
||||
return
|
||||
}
|
||||
|
||||
func ComposeMessages(ctx context.Context, req *flowDto.ComposeMessagesReq) (res *flowDto.ComposeMessagesRes, err error) {
|
||||
func GetModelInfo(ctx context.Context, req *flowDto.GetModelInfoReq) (res *flowDto.GetModelInfoRes, err error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
@@ -42,58 +92,71 @@ func ComposeMessages(ctx context.Context, req *flowDto.ComposeMessagesReq) (res
|
||||
}
|
||||
}
|
||||
}
|
||||
res = new(flowDto.ComposeMessagesRes)
|
||||
err = commonHttp.Post(ctx, "prompts-core/prompt/composeMessages", headers, res, &req)
|
||||
res = new(flowDto.GetModelInfoRes)
|
||||
err = commonHttp.Get(ctx, "model-gateway/model/getModel", headers, res, req)
|
||||
return
|
||||
}
|
||||
|
||||
func GetModelResult(ctx context.Context, modelName, skillName string, form, userFrom map[string]any, fileUrl []string, sessionId string, cause string) (mapTaskResult map[string]any, err error) {
|
||||
func GetComposeResult(ctx context.Context, buildType int, modelName, promptContent, skillName string, form []map[string]any, userForm []map[string]any, fileUrl []string, sessionId, nodeId string, cause string) (res *flowDto.ComposeCallbackReq, err error) {
|
||||
if !g.IsEmpty(promptContent) {
|
||||
userForm = append(userForm, map[string]any{
|
||||
"prompt": promptContent,
|
||||
})
|
||||
}
|
||||
var callbackUrl = utils.GetCallbackURL(ctx, "/flow/execution/composeCallBack")
|
||||
var consult = make([]flowDto.Consult, 0)
|
||||
var collectFileUrls func(val any)
|
||||
collectFileUrls = func(val any) {
|
||||
switch {
|
||||
case g.NewVar(val).IsSlice():
|
||||
slice := gconv.SliceAny(val)
|
||||
for _, item := range slice {
|
||||
collectFileUrls(item)
|
||||
}
|
||||
case g.NewVar(val).IsMap():
|
||||
m := gconv.Map(val)
|
||||
for _, item := range m {
|
||||
collectFileUrls(item)
|
||||
}
|
||||
default:
|
||||
s := gconv.String(val)
|
||||
if s != "" {
|
||||
getFileTypeByPath := GetFileTypeByPath(s)
|
||||
if getFileTypeByPath != "" {
|
||||
consult = append(consult, flowDto.Consult{
|
||||
Type: getFileTypeByPath,
|
||||
Url: s,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, m := range userForm {
|
||||
for _, v := range gconv.Map(m) {
|
||||
collectFileUrls(v)
|
||||
}
|
||||
}
|
||||
for _, v := range fileUrl {
|
||||
getFileTypeByPath := GetFileTypeByPath(gconv.String(v))
|
||||
if getFileTypeByPath != "" {
|
||||
consult = append(consult, flowDto.Consult{
|
||||
Type: getFileTypeByPath,
|
||||
Url: gconv.String(v),
|
||||
})
|
||||
}
|
||||
}
|
||||
msgReq := flowDto.ComposeMessagesReq{
|
||||
BuildType: 1,
|
||||
ModelName: modelName,
|
||||
SkillName: skillName,
|
||||
Cause: cause,
|
||||
Form: form,
|
||||
UserForm: userFrom,
|
||||
UserFiles: fileUrl,
|
||||
SessionId: sessionId,
|
||||
BuildType: buildType,
|
||||
ModelName: modelName,
|
||||
SkillName: skillName,
|
||||
CallbackUrl: callbackUrl,
|
||||
Cause: cause,
|
||||
Form: form,
|
||||
UserForm: userForm,
|
||||
Consult: consult,
|
||||
SessionId: sessionId,
|
||||
NodeId: nodeId,
|
||||
}
|
||||
msg, err := ComposeMessages(ctx, &msgReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if g.IsEmpty(msg.Messages) {
|
||||
return nil, fmt.Errorf("msg is empty")
|
||||
}
|
||||
var taskResult any
|
||||
taskResult, err = GatewayTask(ctx, msg.EpicycleId, modelName, msg.Messages)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var getTaskResult *flowDto.TaskCallback
|
||||
getTaskResult, err = GetTaskResult(ctx, taskResult)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mapTaskResult = gconv.Map(getTaskResult.Text)
|
||||
return mapTaskResult, nil
|
||||
}
|
||||
|
||||
func GatewayTask(ctx context.Context, epicycleId int64, model string, content map[string]any) (any, error) {
|
||||
modelTaskId, err := CreateGatewayTask(ctx, &flowDto.CreateTaskReq{
|
||||
ModelName: model,
|
||||
BizName: g.Cfg().MustGet(ctx, "server.name").String(),
|
||||
CallbackUrl: "/flow/execution/modelCallback",
|
||||
RequestPayload: content,
|
||||
EpicycleId: epicycleId,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Wait(ctx, modelTaskId)
|
||||
}
|
||||
|
||||
func CreateGatewayTask(ctx context.Context, req *flowDto.CreateTaskReq) (string, error) {
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
@@ -102,69 +165,330 @@ func CreateGatewayTask(ctx context.Context, req *flowDto.CreateTaskReq) (string,
|
||||
}
|
||||
}
|
||||
}
|
||||
res := new(flowDto.CreateTaskRes)
|
||||
msgRes := new(flowDto.ComposeMessagesRes)
|
||||
err = commonHttp.Post(ctx, "prompts-core/prompt/composeMessages", headers, msgRes, &msgReq)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if g.IsEmpty(msgRes.TaskId) {
|
||||
return nil, fmt.Errorf("msg is empty")
|
||||
}
|
||||
waitRes, err := Wait(ctx, msgRes.TaskId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := new(flowDto.ComposeCallbackReq)
|
||||
if err = gconv.Struct(waitRes, msg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !g.IsEmpty(msg.ErrorMsg) {
|
||||
return nil, fmt.Errorf(msg.ErrorMsg)
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func CreateGatewayTask(ctx context.Context, epicycleId int64, model string, content map[string]any) (map[string]any, error) {
|
||||
taskId, err := createGatewayTaskOnly(ctx, epicycleId, model, content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return waitGatewayResult(ctx, taskId)
|
||||
}
|
||||
|
||||
// createGatewayTaskOnly creates a gateway task and returns the taskId only
|
||||
// doesn't wait for completion
|
||||
func createGatewayTaskOnly(ctx context.Context, epicycleId int64, model string, content map[string]any) (string, error) {
|
||||
callbackUrl := utils.GetCallbackURL(ctx, "/flow/execution/modelCallback")
|
||||
req := flowDto.ModelGatewayReq{
|
||||
ModelName: model,
|
||||
BizName: g.Cfg().MustGet(ctx, "server.name").String(),
|
||||
CallbackUrl: callbackUrl,
|
||||
RequestPayload: content,
|
||||
EpicycleId: epicycleId,
|
||||
}
|
||||
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
res := new(flowDto.ModelGatewayRes)
|
||||
err := commonHttp.Post(ctx, "model-gateway/task/createTask", headers, res, &req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if g.IsEmpty(res.TaskId) {
|
||||
return "", fmt.Errorf("创建模型任务失败,taskId为空")
|
||||
}
|
||||
|
||||
return res.TaskId, nil
|
||||
}
|
||||
|
||||
func GetTaskResult(ctx context.Context, result any) (*flowDto.TaskCallback, error) {
|
||||
task := new(flowDto.TaskCallback)
|
||||
if err := gconv.Struct(result, task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
url, err := utils.GetFileAddressPrefix(ctx)
|
||||
// waitGatewayResult waits for a created gateway task to complete and returns the result
|
||||
func waitGatewayResult(ctx context.Context, taskId string) (map[string]any, error) {
|
||||
waitRes, err := Wait(ctx, taskId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取远程文件内容
|
||||
file, err := FetchRemoteJsonFile(ctx, url+task.OssFile)
|
||||
if err != nil {
|
||||
task := new(flowDto.ModelCallbackReq)
|
||||
if err := gconv.Struct(waitRes, task); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Text = gconv.String(file)
|
||||
if task.State == 3 || !g.IsEmpty(task.ErrorMsg) {
|
||||
return nil, fmt.Errorf("模型执行失败:%s", task.ErrorMsg)
|
||||
}
|
||||
if g.IsEmpty(task.Messages) {
|
||||
return nil, fmt.Errorf("模型返回结果为空")
|
||||
}
|
||||
|
||||
return task, nil
|
||||
return task.Messages, nil
|
||||
}
|
||||
|
||||
func FetchRemoteJsonFile(ctx context.Context, fileUrl string) ([]byte, error) {
|
||||
// 1. 下载文件
|
||||
// updateTokenCount updates the token count in node execution
|
||||
func updateTokenCount(ctx context.Context, nodeExecutionId int64, responseField string, result map[string]any) {
|
||||
if responseField == "" {
|
||||
return
|
||||
}
|
||||
_, _ = nodeDao.NodeExecutionDao.Update(ctx, &nodeDto.UpdateNodeExecutionReq{
|
||||
Id: nodeExecutionId,
|
||||
CompletionTokens: gconv.Int(result[responseField]),
|
||||
TotalTokens: gconv.Int(result[responseField]),
|
||||
})
|
||||
}
|
||||
|
||||
func GetModelResult(ctx context.Context, sessionId string, nodeInput *flowDto.NodeExecutionInput, skillName string, form []map[string]any, userForm []map[string]any) (mapTaskResult []map[string]any, err error) {
|
||||
buildType := 1
|
||||
if nodeInput.Config.NodeCode == node.NodeTypeDataConversionModel {
|
||||
buildType = 3
|
||||
}
|
||||
|
||||
composeResult, err := GetComposeResult(ctx, buildType, nodeInput.Config.ModelConfig.ModelName, nodeInput.Config.PromptContent, skillName, form, userForm, nodeInput.Global.FileUrl, sessionId, nodeInput.Config.Id, nodeInput.Config.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
modelInfo, err := GetModelInfo(ctx, &flowDto.GetModelInfoReq{ModelName: nodeInput.Config.ModelConfig.ModelName})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mapTaskResult = make([]map[string]any, len(composeResult.Messages.Rounds))
|
||||
var taskResultMap map[string]any
|
||||
|
||||
needSequential := false
|
||||
if buildType == 1 {
|
||||
if needSequential {
|
||||
for idx, item := range composeResult.Messages.Rounds {
|
||||
if !g.IsEmpty(taskResultMap) {
|
||||
var set string
|
||||
set, err = sjson.Set(gconv.String(item), modelInfo.Model.LastFrame, gconv.String(taskResultMap[modelInfo.Model.ResponseBody]))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item = gconv.Map(set)
|
||||
}
|
||||
|
||||
var taskResult map[string]any
|
||||
taskResult, err = CreateGatewayTask(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if g.IsEmpty(taskResult) {
|
||||
return nil, fmt.Errorf("模型返回结果为空")
|
||||
}
|
||||
|
||||
// Update taskResultMap for next round (used by VideoModel)
|
||||
if nodeInput.Config.NodeCode == node.NodeTypeVideoModel {
|
||||
ext := GetFileTypeByPath(gconv.String(taskResult[modelInfo.Model.ResponseBody]))
|
||||
if ext == "image" {
|
||||
taskResultMap = taskResult
|
||||
} else {
|
||||
taskResultMap = make(map[string]any)
|
||||
}
|
||||
} else {
|
||||
taskResultMap = make(map[string]any)
|
||||
}
|
||||
|
||||
mapTaskResult[idx] = taskResult
|
||||
updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult)
|
||||
}
|
||||
} else {
|
||||
taskIdList := make([]string, len(composeResult.Messages.Rounds))
|
||||
|
||||
for idx, item := range composeResult.Messages.Rounds {
|
||||
var taskId string
|
||||
taskId, err = createGatewayTaskOnly(ctx, composeResult.EpicycleId, nodeInput.Config.ModelConfig.ModelName, item)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
taskIdList[idx] = taskId
|
||||
}
|
||||
|
||||
// Step 2: Wait for all tasks in parallel
|
||||
var wg sync.WaitGroup
|
||||
errChan := make(chan error, len(taskIdList))
|
||||
|
||||
for idx, taskId := range taskIdList {
|
||||
wg.Add(1)
|
||||
|
||||
// Pass idx and taskId as parameters to avoid loop variable capture bug
|
||||
// This guarantees results are stored in the correct order matching original requests
|
||||
go func(idx int, taskId string) {
|
||||
defer wg.Done()
|
||||
|
||||
var taskResult map[string]any
|
||||
taskResult, err = waitGatewayResult(ctx, taskId)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
mapTaskResult[idx] = taskResult
|
||||
updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, taskResult)
|
||||
}(idx, taskId)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errChan)
|
||||
|
||||
if len(errChan) > 0 {
|
||||
return nil, <-errChan
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for idx, item := range composeResult.Messages.Rounds {
|
||||
mapTaskResult[idx] = item
|
||||
updateTokenCount(ctx, nodeInput.NodeExecutionId, modelInfo.Model.ResponseTokenField, item)
|
||||
}
|
||||
}
|
||||
|
||||
return mapTaskResult, nil
|
||||
}
|
||||
|
||||
func BuildNestedJson(body g.Map, mockConfigMap map[string]*entity.FlowNode) g.Map {
|
||||
jsonStr := "{}"
|
||||
for originKey, originItem := range body {
|
||||
bodyItemMap := gconv.Map(originItem)
|
||||
val := bodyItemMap["value"]
|
||||
if v, ok := bodyItemMap["value"]; ok {
|
||||
jsonStr, _ = sjson.Set(jsonStr, originKey, v)
|
||||
}
|
||||
// 判断 value 是不是引用结构(map)
|
||||
if g.NewVar(val).IsMap() {
|
||||
valMap := gconv.Map(val)
|
||||
nodeId := gconv.String(valMap["nodeId"])
|
||||
fieldName := gconv.String(valMap["field"])
|
||||
if configValue, ok := mockConfigMap[nodeId]; ok {
|
||||
if !g.IsEmpty(configValue.OutputResult) {
|
||||
for _, v := range configValue.OutputResult {
|
||||
if strings.Contains(v.Field, fieldName) {
|
||||
if configValue.NodeCode == node.NodeTypeDataConversionModel {
|
||||
switch {
|
||||
case g.NewVar(v.Value).IsSlice() || g.NewVar(v.Value).IsMap():
|
||||
// 核心:自动判断两种结构,精准赋值
|
||||
vm := gconv.Map(v.Value)
|
||||
// 先判断是否是 单个key包裹的对象(如 {"subtitle_style": {...}})
|
||||
if len(vm) == 1 {
|
||||
// 遍历取出唯一的 key 和 真实值
|
||||
for innerKey, innerVal := range vm {
|
||||
// 直接用 innerKey(subtitle_style)赋值
|
||||
jsonStr, _ = sjson.Set(jsonStr, innerKey, innerVal)
|
||||
}
|
||||
} else {
|
||||
// 直接是对象,用 originKey 赋值
|
||||
jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value)
|
||||
}
|
||||
default:
|
||||
jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value)
|
||||
}
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !g.IsEmpty(configValue.FormConfig) {
|
||||
for _, v := range configValue.FormConfig {
|
||||
if v.Field == fieldName {
|
||||
if v.Type == "uploadMultiple" {
|
||||
if g.NewVar(v.FieldConstraint).IsMap() {
|
||||
mapFieldConstraint := gconv.Map(v.FieldConstraint)
|
||||
for key, value := range mapFieldConstraint {
|
||||
if key == "maxFileCount" {
|
||||
if gconv.Int(value) == 1 {
|
||||
// 如果是单文件上传,则替换成字符串重新赋值给v.Value
|
||||
if g.NewVar(v.Value).IsSlice() {
|
||||
sliceVal := gconv.SliceAny(v.Value)
|
||||
if len(sliceVal) > 0 {
|
||||
v.Value = sliceVal[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, originKey, v.Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return gconv.Map(jsonStr)
|
||||
}
|
||||
|
||||
func VideoConcat(ctx context.Context, videoUrls []string) (r any, err error) {
|
||||
var httpUrl = "media/video/concat/async"
|
||||
headers := make(map[string]string)
|
||||
if r := g.RequestFromCtx(ctx); r != nil {
|
||||
for k, v := range r.Request.Header {
|
||||
if len(v) > 0 {
|
||||
headers[k] = v[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
var callbackUrl = utils.GetCallbackURL(ctx, "/flow/execution/videoCallback")
|
||||
var newBody = flowDto.VideoConcatReq{
|
||||
VideoUrls: videoUrls,
|
||||
Method: "auto",
|
||||
Upload: true,
|
||||
CallbackUrl: callbackUrl,
|
||||
}
|
||||
res := new(flowDto.VideoConcatRes)
|
||||
err = commonHttp.Post(ctx, httpUrl, headers, &res, newBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return Wait(ctx, res.TaskId)
|
||||
}
|
||||
|
||||
func GetFileBytesFromURL(ctx context.Context, fileUrl string) ([]byte, error) {
|
||||
// 使用 GoFrame 客户端(自带超时、追踪、日志等能力)
|
||||
resp, err := g.Client().Get(ctx, fileUrl)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get file failed: %w", err)
|
||||
return nil, gerror.Wrapf(err, "failed to request url: %s", fileUrl)
|
||||
}
|
||||
defer resp.Close()
|
||||
|
||||
// 校验状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("http status error: %d", resp.StatusCode)
|
||||
return nil, gerror.Newf("request failed with status code: %d, url: %s", resp.StatusCode, fileUrl)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
// 读取全部内容
|
||||
allBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, gerror.Wrapf(err, "failed to read response body, url: %s", fileUrl)
|
||||
}
|
||||
|
||||
func GetFileBytesFromURL(url string) (all []byte, err error) {
|
||||
resp, err := http.Get(url)
|
||||
if err != nil {
|
||||
fmt.Printf("请求失败 %s: %v", url, err)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Printf("请求失败,状态码: %d\n", resp.StatusCode)
|
||||
return
|
||||
}
|
||||
all, err = io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("读取内容失败 %s: %v", url, err)
|
||||
return
|
||||
}
|
||||
return
|
||||
return allBytes, nil
|
||||
}
|
||||
|
||||
func Upload(ctx context.Context, req *dto.UploadFileBytesReq) (*dto.UploadFileBytesRes, error) {
|
||||
@@ -192,8 +516,8 @@ func Upload(ctx context.Context, req *dto.UploadFileBytesReq) (*dto.UploadFileBy
|
||||
|
||||
// 发起上传请求
|
||||
res := &dto.UploadFileBytesRes{}
|
||||
url := "oss/file/uploadFile"
|
||||
if err = commonHttp.Post(ctx, url, headers, res, body.Bytes()); err != nil {
|
||||
httpUrl := "oss/file/uploadFile"
|
||||
if err = commonHttp.Post(ctx, httpUrl, headers, res, body.Bytes()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -201,6 +525,40 @@ func Upload(ctx context.Context, req *dto.UploadFileBytesReq) (*dto.UploadFileBy
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func GetFileTypeByPath(filePath string) string {
|
||||
if filePath == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 解析 URL,获取真实路径(兼容 http 链接)
|
||||
u, err := url.Parse(filePath)
|
||||
if err == nil {
|
||||
filePath = u.Path
|
||||
}
|
||||
|
||||
// 获取后缀(小写)
|
||||
ext := filepath.Ext(filePath)
|
||||
ext = strings.ToLower(ext)
|
||||
|
||||
// 判断类型
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg", ".png", ".gif", ".bmp", ".webp":
|
||||
return "image"
|
||||
case ".mp4", ".mov", ".avi", ".flv", ".wmv", ".mkv":
|
||||
return "video"
|
||||
case ".mp3", ".wav", ".m4a", ".flac", ".aac", ".ogg":
|
||||
return "audio"
|
||||
case ".txt", ".md", ".log", ".json", ".xml", ".inc":
|
||||
return "text"
|
||||
case ".html":
|
||||
return "html"
|
||||
case ".pdf", ".doc", ".docx", ".xls", ".xlsx", ".ppt", ".pptx":
|
||||
return "document"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func BuildText(text string) string {
|
||||
// 生成单条HTML
|
||||
var htmlBuilder strings.Builder
|
||||
@@ -354,7 +712,7 @@ func BuildHtml(text string, images []string) string {
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.06);
|
||||
}
|
||||
|
||||
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
@@ -457,8 +815,8 @@ func SplitMultiContents(htmlContent string) []string {
|
||||
func GetAllImgSrcFromHtml(html string) []string {
|
||||
var imgSrcList []string
|
||||
re := regexp.MustCompile(`<img[^>]*src\s*=\s*["']([^"']+)["']`)
|
||||
matchs := re.FindAllStringSubmatch(html, -1)
|
||||
for _, match := range matchs {
|
||||
submatch := re.FindAllStringSubmatch(html, -1)
|
||||
for _, match := range submatch {
|
||||
if len(match) >= 2 {
|
||||
imgSrcList = append(imgSrcList, match[1])
|
||||
}
|
||||
@@ -468,7 +826,7 @@ func GetAllImgSrcFromHtml(html string) []string {
|
||||
|
||||
// ReplaceImgSrc 替换img src的方法
|
||||
func ReplaceImgSrc(html string, oldSrc string, newSrc string) string {
|
||||
// 精准替换:找到 <img xxx src="oldSrc" xxx> 并替换
|
||||
// 精准替换:找到 <img xxx src="oldSrc" xxx>
|
||||
re := regexp.MustCompile(`(<img[^>]*src\s*=\s*["'])` + regexp.QuoteMeta(oldSrc) + `(["'])`)
|
||||
return re.ReplaceAllString(html, `${1}`+newSrc+`${2}`)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user