475 lines
13 KiB
Go
475 lines
13 KiB
Go
package flow
|
||
|
||
import (
|
||
"ai-agent/workflow/model/dto"
|
||
flowDto "ai-agent/workflow/model/dto/flow"
|
||
"bytes"
|
||
"context"
|
||
"fmt"
|
||
"io"
|
||
"mime/multipart"
|
||
"net/http"
|
||
"regexp"
|
||
"strconv"
|
||
"strings"
|
||
|
||
commonHttp "gitea.redpowerfuture.com/red-future/common/http"
|
||
"gitea.redpowerfuture.com/red-future/common/utils"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
"github.com/gogf/gf/v2/util/gconv"
|
||
)
|
||
|
||
func GetIsChatModel(ctx context.Context) (res *flowDto.GetIsChatModelRes, err error) {
|
||
headers := make(map[string]string)
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
for k, v := range r.Request.Header {
|
||
if len(v) > 0 {
|
||
headers[k] = v[0]
|
||
}
|
||
}
|
||
}
|
||
res = new(flowDto.GetIsChatModelRes)
|
||
err = commonHttp.Get(ctx, "model-gateway/model/getIsChatModel", headers, res, nil)
|
||
return
|
||
}
|
||
|
||
func ComposeMessages(ctx context.Context, req *flowDto.ComposeMessagesReq) (res *flowDto.ComposeMessagesRes, err error) {
|
||
headers := make(map[string]string)
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
for k, v := range r.Request.Header {
|
||
if len(v) > 0 {
|
||
headers[k] = v[0]
|
||
}
|
||
}
|
||
}
|
||
res = new(flowDto.ComposeMessagesRes)
|
||
err = commonHttp.Post(ctx, "prompts-core/prompt/composeMessages", headers, res, &req)
|
||
return
|
||
}
|
||
|
||
func GetModelResult(ctx context.Context, modelName, skillName string, form, userFrom map[string]any, fileUrl []string, sessionId string, cause string) (mapTaskResult map[string]any, err error) {
|
||
msgReq := flowDto.ComposeMessagesReq{
|
||
BuildType: 1,
|
||
ModelName: modelName,
|
||
SkillName: skillName,
|
||
Cause: cause,
|
||
Form: form,
|
||
UserForm: userFrom,
|
||
UserFiles: fileUrl,
|
||
SessionId: sessionId,
|
||
}
|
||
msg, err := ComposeMessages(ctx, &msgReq)
|
||
if err != nil {
|
||
return
|
||
}
|
||
if g.IsEmpty(msg.Messages) {
|
||
return nil, fmt.Errorf("msg is empty")
|
||
}
|
||
var taskResult any
|
||
taskResult, err = GatewayTask(ctx, msg.EpicycleId, modelName, msg.Messages)
|
||
if err != nil {
|
||
return
|
||
}
|
||
var getTaskResult *flowDto.TaskCallback
|
||
getTaskResult, err = GetTaskResult(ctx, taskResult)
|
||
if err != nil {
|
||
return
|
||
}
|
||
mapTaskResult = gconv.Map(getTaskResult.Text)
|
||
return mapTaskResult, nil
|
||
}
|
||
|
||
func GatewayTask(ctx context.Context, epicycleId int64, model string, content map[string]any) (any, error) {
|
||
modelTaskId, err := CreateGatewayTask(ctx, &flowDto.CreateTaskReq{
|
||
ModelName: model,
|
||
BizName: g.Cfg().MustGet(ctx, "server.name").String(),
|
||
CallbackUrl: "/flow/execution/modelCallback",
|
||
RequestPayload: content,
|
||
EpicycleId: epicycleId,
|
||
})
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return Wait(ctx, modelTaskId)
|
||
}
|
||
|
||
func CreateGatewayTask(ctx context.Context, req *flowDto.CreateTaskReq) (string, error) {
|
||
headers := make(map[string]string)
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
for k, v := range r.Request.Header {
|
||
if len(v) > 0 {
|
||
headers[k] = v[0]
|
||
}
|
||
}
|
||
}
|
||
res := new(flowDto.CreateTaskRes)
|
||
err := commonHttp.Post(ctx, "model-gateway/task/createTask", headers, res, &req)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return res.TaskId, nil
|
||
}
|
||
|
||
func GetTaskResult(ctx context.Context, result any) (*flowDto.TaskCallback, error) {
|
||
task := new(flowDto.TaskCallback)
|
||
if err := gconv.Struct(result, task); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
url, err := utils.GetFileAddressPrefix(ctx)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 获取远程文件内容
|
||
file, err := FetchRemoteJsonFile(ctx, url+task.OssFile)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
task.Text = gconv.String(file)
|
||
|
||
return task, nil
|
||
}
|
||
|
||
func FetchRemoteJsonFile(ctx context.Context, fileUrl string) ([]byte, error) {
|
||
// 1. 下载文件
|
||
resp, err := g.Client().Get(ctx, fileUrl)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("get file failed: %w", err)
|
||
}
|
||
defer resp.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("http status error: %d", resp.StatusCode)
|
||
}
|
||
|
||
return io.ReadAll(resp.Body)
|
||
}
|
||
|
||
// GetFileBytesFromURL downloads url with http.Get and returns the full body.
//
// Failures are printed to stdout and returned as errors:
//   - transport errors from http.Get are returned unchanged;
//   - any status other than 200 OK yields an error naming the URL and status
//     (previously this case returned nil bytes with a nil error);
//   - body read errors are returned unchanged.
//
// On error the returned byte slice is always nil.
//
// NOTE(review): http.Get uses http.DefaultClient, which has no timeout, so a
// stalled server blocks this call indefinitely; consider a client with Timeout.
func GetFileBytesFromURL(url string) (all []byte, err error) {
	resp, err := http.Get(url)
	if err != nil {
		fmt.Printf("请求失败 %s: %v\n", url, err)
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		fmt.Printf("请求失败,状态码: %d\n", resp.StatusCode)
		return nil, fmt.Errorf("get %s: unexpected http status %d", url, resp.StatusCode)
	}

	all, err = io.ReadAll(resp.Body)
	if err != nil {
		fmt.Printf("读取内容失败 %s: %v\n", url, err)
		return nil, err
	}
	return all, nil
}
|
||
|
||
func Upload(ctx context.Context, req *dto.UploadFileBytesReq) (*dto.UploadFileBytesRes, error) {
|
||
body := &bytes.Buffer{}
|
||
writer := multipart.NewWriter(body)
|
||
|
||
part, err := writer.CreateFormFile("file", req.FileName)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if _, err = part.Write(req.FileBytes); err != nil {
|
||
return nil, err
|
||
}
|
||
if err = writer.Close(); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
headers := make(map[string]string)
|
||
headers["Content-Type"] = writer.FormDataContentType()
|
||
if r := g.RequestFromCtx(ctx); r != nil {
|
||
if auth := r.Header.Get("Authorization"); auth != "" {
|
||
headers["Authorization"] = auth
|
||
}
|
||
}
|
||
|
||
// 发起上传请求
|
||
res := &dto.UploadFileBytesRes{}
|
||
url := "oss/file/uploadFile"
|
||
if err = commonHttp.Post(ctx, url, headers, res, body.Bytes()); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
g.Log().Infof(ctx, "[Upload] success url=%s size=%d", res.FileURL, res.FileSize)
|
||
return res, nil
|
||
}
|
||
|
||
func BuildText(text string) string {
|
||
// 生成单条HTML
|
||
var htmlBuilder strings.Builder
|
||
htmlBuilder.WriteString(`
|
||
<!DOCTYPE html>
|
||
<html lang="zh-CN">
|
||
<head>
|
||
<meta charset="UTF-8">
|
||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||
<style>
|
||
* {
|
||
margin: 0;
|
||
padding: 0;
|
||
box-sizing: border-box;
|
||
}
|
||
body {
|
||
font-family: "Microsoft YaHei", "PingFang SC", Arial, sans-serif;
|
||
background: #f5f5f5;
|
||
color: #333;
|
||
line-height: 1.8;
|
||
padding: 20px;
|
||
}
|
||
.container {
|
||
max-width: 900px;
|
||
margin: 0 auto;
|
||
background: #fff;
|
||
border-radius: 12px;
|
||
box-shadow: 0 2px 12px rgba(0, 0, 0, 0.08);
|
||
overflow: hidden;
|
||
}
|
||
.item {
|
||
padding: 30px;
|
||
}
|
||
.image-group img {
|
||
width: 100%;
|
||
height: auto;
|
||
display: block;
|
||
margin-bottom: 6px;
|
||
border-radius: 8px;
|
||
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
|
||
}
|
||
.image-group img:last-child {
|
||
margin-bottom: 0;
|
||
}
|
||
.image-group {
|
||
margin-bottom: 25px;
|
||
}
|
||
.text {
|
||
padding: 0;
|
||
font-size: 15px;
|
||
line-height: 1.4;
|
||
color: #555;
|
||
}
|
||
.text h2 {
|
||
font-size: 28px;
|
||
font-weight: bold;
|
||
color: #1a1a1a;
|
||
margin-bottom: 15px;
|
||
line-height: 1.2;
|
||
}
|
||
.text h3 {
|
||
font-size: 20px;
|
||
font-weight: 600;
|
||
color: #2c3e50;
|
||
margin: 20px 0 12px;
|
||
padding-left: 12px;
|
||
border-left: 4px solid #409eff;
|
||
}
|
||
.text p {
|
||
margin-bottom: 12px;
|
||
text-align: justify;
|
||
}
|
||
.text strong {
|
||
color: #e74c3c;
|
||
font-weight: 600;
|
||
}
|
||
.text ul {
|
||
list-style: none;
|
||
padding: 0;
|
||
margin: 8px 0;
|
||
}
|
||
.text ul li {
|
||
padding: 10px 0 10px 30px;
|
||
position: relative;
|
||
line-height: 1.2;
|
||
}
|
||
.text ul li:before {
|
||
content: "●";
|
||
color: #409eff;
|
||
font-size: 12px;
|
||
position: absolute;
|
||
left: 12px;
|
||
top: 12px;
|
||
}
|
||
@media (max-width: 768px) {
|
||
body {
|
||
padding: 10px;
|
||
}
|
||
.text h2 {
|
||
font-size: 24px;
|
||
}
|
||
.text h3 {
|
||
font-size: 18px;
|
||
}
|
||
}
|
||
</style>
|
||
</head>
|
||
<body>
|
||
<div class="container">
|
||
<div class="item">
|
||
`)
|
||
// 🔥 写入文案前:删除 <p class="image-count">需要配图:X 张</p>
|
||
if text != "" {
|
||
// 写入清理后的文案
|
||
htmlBuilder.WriteString(fmt.Sprintf(`<div class="text">%s</div>`, ImageTagRegex(text)))
|
||
}
|
||
htmlBuilder.WriteString(`</div>
|
||
</div>
|
||
</body>
|
||
</html>`)
|
||
|
||
return htmlBuilder.String()
|
||
}
|
||
|
||
// buildHTMLAttrEscaper escapes a value for use inside a quoted HTML attribute.
// It performs the same replacements as html.EscapeString.
var buildHTMLAttrEscaper = strings.NewReplacer(
	`&`, "&amp;",
	`'`, "&#39;",
	`<`, "&lt;",
	`>`, "&gt;",
	`"`, "&#34;",
)

// buildHTMLJSStringEscaper escapes a value for use inside a double-quoted
// JavaScript string literal placed in an inline <script> element. '<', '>' and
// '&' become \u escapes so the value can never close the <script> element; the
// browser still decodes them back to the original characters.
var buildHTMLJSStringEscaper = strings.NewReplacer(
	`\`, `\\`,
	`"`, `\"`,
	"\n", `\n`,
	"\r", `\r`,
	"\u2028", `\u2028`,
	"\u2029", `\u2029`,
	`<`, `\u003c`,
	`>`, `\u003e`,
	`&`, `\u0026`,
)

// BuildHtml renders a standalone HTML page that shows the given images (zero,
// one or many, in order) followed by content loaded in the browser.
//
// text is a URL. The page's inline script fetches it and inserts the response
// into #content via innerHTML, showing "加载失败..." when the fetch fails.
// Both text and each image URL are escaped for the context they are written
// into, so quotes or "</script>" in them can no longer break the page or
// inject script. Ordinary URLs, including ones containing '&', behave exactly
// as before in the browser.
//
// NOTE(review): the fetched fragment is inserted with innerHTML, so it must
// come from a trusted source.
func BuildHtml(text string, images []string) string {
	var htmlBuilder strings.Builder
	htmlBuilder.WriteString(`<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: "Microsoft YaHei", sans-serif;
padding: 20px;
background-color: #f6f6f6;
line-height: 1.7;
font-size: 16px;
color: #333;
}
.container {
max-width: 750px;
margin: 0 auto;
background: #fff;
padding: 30px;
border-radius: 12px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.06);
}

</style>
</head>
<body>
<div class="container">
`)

	// Images (0, 1 or many) go first, each in its own <img> tag.
	if len(images) > 0 {
		htmlBuilder.WriteString(`<div class="image-group">`)
		for _, imgUrl := range images {
			htmlBuilder.WriteString(`<img src="`)
			htmlBuilder.WriteString(buildHTMLAttrEscaper.Replace(imgUrl))
			htmlBuilder.WriteString(`" alt="图片"/>`)
		}
		htmlBuilder.WriteString(`</div>`)
	}

	htmlBuilder.WriteString(`
<div id="content">加载中...</div>
</div>

<script>
const incUrl = "` + buildHTMLJSStringEscaper.Replace(text) + `";
fetch(incUrl)
.then(res => {
if (!res.ok) throw new Error("加载失败");
return res.text();
})
.then(text => {
document.getElementById("content").innerHTML = text;
})
.catch(err => {
document.getElementById("content").innerHTML = "加载失败:" + err.message;
});
</script>
</body>
</html>`)

	return htmlBuilder.String()
}
|
||
|
||
// imageCountParagraphRe matches an image-count marker paragraph and captures
// its inner content up to the first closing </p>. It accepts single or double
// quotes, whitespace around '=', extra attributes, and content spanning lines.
var imageCountParagraphRe = regexp.MustCompile(`<p\s+class\s*=\s*['"]image-count['"][^>]*>([\s\S]*?)</p>`)

// imageCountDigitsRe finds the first run of ASCII digits.
var imageCountDigitsRe = regexp.MustCompile(`\d+`)

// ExtractImageCount returns the number written inside the first
// <p class="image-count">...</p> marker in content.
//
// Single- or double-quoted class values, whitespace around '=', extra
// attributes, and newlines inside the paragraph are supported. Only digits
// inside that paragraph are considered. It returns 0 when there is no marker,
// the marker has no digits, or the number does not fit in an int.
func ExtractImageCount(content string) int {
	paragraph := imageCountParagraphRe.FindStringSubmatch(content)
	if paragraph == nil {
		return 0
	}
	digits := imageCountDigitsRe.FindString(paragraph[1])
	if digits == "" {
		return 0
	}
	num, err := strconv.Atoi(digits)
	if err != nil {
		return 0
	}
	return num
}
|
||
|
||
// imageCountTagRe matches a whole image-count marker paragraph: single or
// double quotes, whitespace around '=', extra attributes, and content spanning
// lines. It is compiled once instead of on every call.
var imageCountTagRe = regexp.MustCompile(`<p\s+class\s*=\s*['"]image-count['"][^>]*>[\s\S]*?</p>`)

// ImageTagRegex returns html with every <p class="image-count">...</p> marker
// paragraph (tag and content) removed. All other markup is left untouched.
func ImageTagRegex(html string) string {
	return imageCountTagRe.ReplaceAllLiteralString(html, "")
}
|
||
|
||
// Regexes used by StripHtmlTags, compiled once.
var (
	// stripImageCountRe matches a whole image-count marker paragraph.
	stripImageCountRe = regexp.MustCompile(`<p\s+class\s*=\s*['"]image-count['"][^>]*>[\s\S]*?</p>`)
	// stripBlockTagRe matches opening/closing block-level tags that should become line breaks.
	stripBlockTagRe = regexp.MustCompile(`</?(div|p|h1|h2|h3|h4|h5|h6|li|ul|ol|br|tr|td|th)[^>]*>`)
	// stripAnyTagRe matches any remaining tag.
	stripAnyTagRe = regexp.MustCompile(`<[^>]+>`)
	// stripBlankLinesRe matches a run of blank (or whitespace-only) lines.
	stripBlankLinesRe = regexp.MustCompile(`\n\s*\n`)
)

// StripHtmlTags converts HTML into plain text, keeping line structure and
// dropping the image-count marker lines.
//
// Steps:
//  1. Remove <p class="image-count">...</p> markers together with their text.
//  2. Turn block-level tags (div, p, h1-h6, li, ul, ol, br, tr, td, th) into newlines.
//  3. Remove every remaining tag.
//  4. Collapse consecutive blank lines into a single newline.
//  5. Trim leading and trailing whitespace.
//
// HTML entities are left as-is (not decoded).
func StripHtmlTags(html string) string {
	text := stripImageCountRe.ReplaceAllLiteralString(html, "")
	text = stripBlockTagRe.ReplaceAllLiteralString(text, "\n")
	text = stripAnyTagRe.ReplaceAllLiteralString(text, "")
	text = stripBlankLinesRe.ReplaceAllLiteralString(text, "\n")
	return strings.TrimSpace(text)
}
|
||
|
||
// contentItemRe matches <div class="content-item" id="content-N">...</div>
// (single or double quotes, optional whitespace around '=') and captures the
// inner HTML up to the first closing </div>. Nested <div>s inside an item are
// therefore not supported.
var contentItemRe = regexp.MustCompile(`<div\s+class\s*=\s*['"]content-item['"]\s+id\s*=\s*['"]content-\d+['"]\s*>([\s\S]*?)</div>`)

// SplitMultiContents splits a model response that contains several pieces of
// copy into individual items.
//
// The structured form (content-item divs, see contentItemRe) is preferred.
// When no such item is found, the response is split on the "===分隔符==="
// fallback separator instead. In both paths each item is whitespace-trimmed
// and empty items are dropped, so the result may be empty (nil).
func SplitMultiContents(htmlContent string) []string {
	var contents []string
	for _, match := range contentItemRe.FindAllStringSubmatch(htmlContent, -1) {
		if trimmed := strings.TrimSpace(match[1]); trimmed != "" {
			contents = append(contents, trimmed)
		}
	}
	if len(contents) > 0 {
		return contents
	}

	// Fallback: the prompt may instead ask the model to separate items with this marker.
	for _, part := range strings.Split(htmlContent, "===分隔符===") {
		if trimmed := strings.TrimSpace(part); trimmed != "" {
			contents = append(contents, trimmed)
		}
	}
	return contents
}
|
||
|
||
// imgSrcAttrRe matches an <img> tag and captures the value of its src
// attribute. The whitespace required before "src" keeps attributes such as
// data-src from being picked up instead of the real src.
var imgSrcAttrRe = regexp.MustCompile(`<img[^>]*\ssrc\s*=\s*["']([^"']+)["']`)

// GetAllImgSrcFromHtml returns the src value of every <img> tag in html, in
// document order, or nil if there are none. Empty src values are skipped.
func GetAllImgSrcFromHtml(html string) []string {
	matches := imgSrcAttrRe.FindAllStringSubmatch(html, -1)
	if len(matches) == 0 {
		return nil
	}
	srcs := make([]string, 0, len(matches))
	for _, m := range matches {
		srcs = append(srcs, m[1])
	}
	return srcs
}
|
||
|
||
// ReplaceImgSrc replaces oldSrc with newSrc in every <img> src attribute whose
// value is exactly oldSrc (single or double quoted), leaving the rest of html
// untouched.
//
// oldSrc is matched literally (regexp.QuoteMeta). newSrc is also inserted
// literally: any '$' it contains is escaped, so it is not expanded as a
// regexp group reference.
func ReplaceImgSrc(html string, oldSrc string, newSrc string) string {
	re := regexp.MustCompile(`(<img[^>]*src\s*=\s*["'])` + regexp.QuoteMeta(oldSrc) + `(["'])`)
	return re.ReplaceAllString(html, "${1}"+strings.ReplaceAll(newSrc, "$", "$$")+"${2}")
}
|