prompts-core
This commit is contained in:
@@ -1,21 +1,59 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
)
|
||||
|
||||
// ============================================
|
||||
// 文件处理
|
||||
// 文件处理(配置直接内联 + zip 支持)
|
||||
// ============================================
|
||||
|
||||
func fetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
// 允许的文本类 MIME 类型前缀
|
||||
var allowedMIMEPrefixes = []string{
|
||||
"text/",
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"application/toml",
|
||||
"application/x-httpd-php",
|
||||
"application/x-sh",
|
||||
"application/x-python",
|
||||
"application/x-perl",
|
||||
"application/x-ruby",
|
||||
}
|
||||
|
||||
// 禁止的文件扩展名
|
||||
var bannedExtensions = map[string]bool{
|
||||
".png": true, ".jpg": true, ".jpeg": true, ".gif": true, ".bmp": true,
|
||||
".webp": true, ".svg": true, ".ico": true, ".tiff": true, ".tif": true,
|
||||
".mp3": true, ".wav": true, ".ogg": true, ".flac": true, ".aac": true,
|
||||
".wma": true, ".m4a": true,
|
||||
".mp4": true, ".avi": true, ".mkv": true, ".mov": true, ".wmv": true,
|
||||
".flv": true, ".webm": true,
|
||||
".tar": true, ".gz": true, ".rar": true, ".7z": true,
|
||||
".exe": true, ".dll": true, ".so": true, ".bin": true, ".dat": true,
|
||||
".class": true, ".pyc": true,
|
||||
".pdf": true, ".doc": true, ".docx": true, ".xls": true, ".xlsx": true,
|
||||
".ppt": true, ".pptx": true,
|
||||
}
|
||||
|
||||
var symbolCleaner = regexp.MustCompile(`[\x00-\x08\x0B\x0C\x0E-\x1F]`)
|
||||
|
||||
// FetchFileTexts 从 URL 列表获取文件内容(支持 zip 内文件)
|
||||
func FetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
if len(urls) == 0 {
|
||||
@@ -23,7 +61,7 @@ func fetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: 8 * time.Second,
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "userFiles.httpTimeoutSec", 8).Int()) * time.Second,
|
||||
}
|
||||
|
||||
for _, rawURL := range urls {
|
||||
@@ -32,13 +70,20 @@ func fetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
continue
|
||||
}
|
||||
|
||||
if isBannedExtension(url) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isZipExtension(url) {
|
||||
zipTexts := fetchZipFileTexts(ctx, client, url)
|
||||
for k, v := range zipTexts {
|
||||
result[k] = v
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
text, err := fetchFileContent(ctx, client, url)
|
||||
if err != nil {
|
||||
glog.Warningf(ctx,
|
||||
"[FetchFile] failed url=%s err=%v",
|
||||
url,
|
||||
err,
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -46,24 +91,131 @@ func fetchFileTexts(ctx context.Context, urls []string) map[string]string {
|
||||
continue
|
||||
}
|
||||
|
||||
text = cleanSymbols(text)
|
||||
result[url] = text
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func fetchFileContent(
|
||||
ctx context.Context,
|
||||
client *http.Client,
|
||||
url string,
|
||||
) (string, error) {
|
||||
func isZipExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
return ext == ".zip"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(
|
||||
ctx,
|
||||
http.MethodGet,
|
||||
url,
|
||||
nil,
|
||||
func fetchZipFileTexts(ctx context.Context, client *http.Client, url string) map[string]string {
|
||||
result := make(map[string]string)
|
||||
|
||||
zipBytes, err := downloadFile(client, url,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return result
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "userFiles.zipEntryMaxSizeKB", 500).Int()) * 1024
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
fileName := file.Name
|
||||
|
||||
if isBannedExtension(fileName) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isZipExtension(fileName) {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
contentType := http.DetectContentType(content)
|
||||
if !isReadableContentType(contentType) {
|
||||
continue
|
||||
}
|
||||
|
||||
text := cleanSymbols(string(content))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
key := url + "::" + fileName
|
||||
result[key] = text
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func downloadFile(client *http.Client, url string, maxSize int64) ([]byte, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(io.LimitReader(resp.Body, maxSize))
|
||||
}
|
||||
|
||||
func isBannedExtension(url string) bool {
|
||||
ext := strings.ToLower(filepath.Ext(url))
|
||||
if idx := strings.Index(ext, "?"); idx != -1 {
|
||||
ext = ext[:idx]
|
||||
}
|
||||
return bannedExtensions[ext]
|
||||
}
|
||||
|
||||
func isReadableContentType(contentType string) bool {
|
||||
if contentType == "" {
|
||||
return false
|
||||
}
|
||||
ct := strings.ToLower(contentType)
|
||||
for _, prefix := range allowedMIMEPrefixes {
|
||||
if strings.HasPrefix(ct, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cleanSymbols(text string) string {
|
||||
text = symbolCleaner.ReplaceAllString(text, "")
|
||||
text = strings.ReplaceAll(text, "\r\n", "\n")
|
||||
text = strings.ReplaceAll(text, "\r", "\n")
|
||||
text = regexp.MustCompile(`\n{3,}`).ReplaceAllString(text, "\n\n")
|
||||
return strings.TrimSpace(text)
|
||||
}
|
||||
|
||||
func fetchFileContent(ctx context.Context, client *http.Client, url string) (string, error) {
|
||||
req, err := http.NewRequest(http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -74,21 +226,19 @@ func fetchFileContent(
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// HTTP状态检查
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Content-Type检查
|
||||
contentType := strings.ToLower(resp.Header.Get("Content-Type"))
|
||||
|
||||
if !isTextContentType(contentType) {
|
||||
return "", fmt.Errorf("unsupported content-type: %s", contentType)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if !isReadableContentType(contentType) {
|
||||
return "", fmt.Errorf("unreadable content-type: %s", contentType)
|
||||
}
|
||||
|
||||
// 最大读取20KB
|
||||
body, err := io.ReadAll(
|
||||
io.LimitReader(resp.Body, 20*1024),
|
||||
io.LimitReader(resp.Body,
|
||||
int64(g.Cfg().MustGet(ctx, "userFiles.textFileMaxSizeKB", 500).Int())*1024,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -97,35 +247,94 @@ func fetchFileContent(
|
||||
return strings.TrimSpace(string(body)), nil
|
||||
}
|
||||
|
||||
// 判断是否为文本类型
|
||||
func isTextContentType(contentType string) bool {
|
||||
|
||||
// text/*
|
||||
if strings.HasPrefix(contentType, "text/") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 常见文本类型
|
||||
allowTypes := []string{
|
||||
"application/json",
|
||||
"application/xml",
|
||||
"application/javascript",
|
||||
"application/x-yaml",
|
||||
"application/yaml",
|
||||
"application/toml",
|
||||
}
|
||||
|
||||
for _, t := range allowTypes {
|
||||
if strings.Contains(contentType, t) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func sanitizeURL(raw string) string {
|
||||
s := strings.TrimSpace(raw)
|
||||
s = strings.Trim(s, "`\"")
|
||||
return s
|
||||
}
|
||||
|
||||
// SkillMdContent 根据 skillName 获取 zip 内所有 md 文件拼接内容
|
||||
func SkillMdContent(ctx context.Context, skillName string) string {
|
||||
// 1. 请求接口获取 SkillUserVO
|
||||
skillResp, err := GetSkillUser(ctx, skillName)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
fullUrl := skillResp.ImgAddressPrefix + skillResp.FileUrl
|
||||
// 2. 下载 zip 文件
|
||||
client := &http.Client{
|
||||
Timeout: time.Duration(g.Cfg().MustGet(ctx, "skillFiles.httpTimeoutSec", 30).Int()) * time.Second,
|
||||
}
|
||||
|
||||
zipBytes, err := downloadFile(client, fullUrl,
|
||||
int64(g.Cfg().MustGet(ctx, "skillFiles.zipMaxSizeMB", 10).Int())*1024*1024,
|
||||
)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 3. 解压 zip 并提取所有 md 文件内容
|
||||
mdContents, err := extractMdFiles(ctx, zipBytes)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(mdContents) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 4. 拼接所有 md 内容
|
||||
var builder strings.Builder
|
||||
builder.WriteString(fmt.Sprintf("# Skill: %s\n\n", skillResp.Name))
|
||||
if skillResp.Description != "" {
|
||||
builder.WriteString(fmt.Sprintf("> %s\n\n", skillResp.Description))
|
||||
}
|
||||
|
||||
for fileName, content := range mdContents {
|
||||
builder.WriteString(fmt.Sprintf("## %s\n\n", fileName))
|
||||
builder.WriteString(content)
|
||||
builder.WriteString("\n\n---\n\n")
|
||||
}
|
||||
|
||||
return strings.TrimSpace(builder.String())
|
||||
}
|
||||
|
||||
// extractMdFiles 解压 zip 并提取所有 .md 文件内容
|
||||
func extractMdFiles(ctx context.Context, zipBytes []byte) (map[string]string, error) {
|
||||
result := make(map[string]string)
|
||||
|
||||
reader, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
entryMaxSize := int64(g.Cfg().MustGet(ctx, "skillFiles.mdMaxSizeKB", 500).Int()) * 1024
|
||||
|
||||
for _, file := range reader.File {
|
||||
if file.FileInfo().IsDir() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !strings.HasSuffix(strings.ToLower(file.Name), ".md") {
|
||||
continue
|
||||
}
|
||||
|
||||
rc, err := file.Open()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
content, err := io.ReadAll(io.LimitReader(rc, entryMaxSize))
|
||||
rc.Close()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(content) > 0 {
|
||||
result[file.Name] = strings.TrimSpace(string(content))
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user