338 lines
6.9 KiB
Go
338 lines
6.9 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"prompts-core/model/dto"
|
||
"sort"
|
||
"strings"
|
||
|
||
"github.com/gogf/gf/v2/container/gvar"
|
||
"github.com/gogf/gf/v2/frame/g"
|
||
)
|
||
|
||
// ============================================
|
||
// 工具函数
|
||
// ============================================
|
||
|
||
func getField(item map[string]any, fallback string) string {
|
||
if field := asString(item["field"]); field != "" {
|
||
return field
|
||
}
|
||
return fallback
|
||
}
|
||
|
||
func getLabel(item map[string]any) string {
|
||
return asString(item["label"])
|
||
}
|
||
|
||
// getValue returns the raw item["value"] entry, or nil when the key is absent
// (or item itself is nil).
func getValue(item map[string]any) any {
	if v, ok := item["value"]; ok {
		return v
	}
	return nil
}
|
||
|
||
// cloneWithValue returns a shallow copy of item whose "value" entry is set to
// value. The input map is left untouched; nested values are shared, not copied.
func cloneWithValue(item map[string]any, value any) map[string]any {
	out := make(map[string]any, len(item)+1)
	for key, val := range item {
		out[key] = val
	}
	out["value"] = value
	return out
}
|
||
|
||
// isSensitiveField reports whether field names a credential ("apikey" or
// "authorization", compared case-insensitively) that must not be echoed.
func isSensitiveField(field string) bool {
	switch strings.ToLower(field) {
	case "apikey", "authorization":
		return true
	default:
		return false
	}
}
|
||
|
||
// isAPIKeyField reports whether field holds an API credential: "apikey" or
// "authorization", compared case-insensitively.
func isAPIKeyField(field string) bool {
	lowered := strings.ToLower(field)
	switch lowered {
	case "apikey", "authorization":
		return true
	}
	return false
}
|
||
|
||
// isTextType reports whether a form item carries free text. This is true when
// its field is "prompt"/"text", or its label is one of the prompt/text labels
// (Chinese or English). Both comparisons are case-insensitive.
func isTextType(field, label string) bool {
	switch strings.ToLower(field) {
	case "prompt", "text":
		return true
	}
	switch strings.ToLower(label) {
	case "提示词", "文本内容", "prompt", "text":
		return true
	}
	return false
}
|
||
|
||
// isDuplicate reports whether userText already mentions the item. A mention
// is a case-insensitive substring match on its label, its field name, or its
// value as rendered by asString, checked in that order. Empty candidates never
// match.
func isDuplicate(userText, field, label string, value any) bool {
	haystack := strings.ToLower(userText)
	mentions := func(needle string) bool {
		return needle != "" && strings.Contains(haystack, strings.ToLower(needle))
	}
	return mentions(label) || mentions(field) || mentions(asString(value))
}
|
||
|
||
func isEmptyValue(v any) bool {
|
||
if v == nil {
|
||
return true
|
||
}
|
||
if s, ok := v.(string); ok {
|
||
return strings.TrimSpace(s) == ""
|
||
}
|
||
return false
|
||
}
|
||
|
||
// isNilOrEmpty reports whether v is a nil interface or a whitespace-only
// string. Non-string values (numbers, bools, typed nil pointers) are never
// considered empty.
func isNilOrEmpty(v any) bool {
	if v == nil {
		return true
	}
	s, isString := v.(string)
	return isString && len(strings.TrimSpace(s)) == 0
}
|
||
|
||
// asString converts an arbitrary value to text.
//
// A string is returned unchanged and a nil interface yields "". Any other
// value is JSON-encoded. If the encoding is a JSON string (a named string
// type, or []byte encoded as base64), the decoded string content is returned.
// Otherwise the raw JSON text is returned (numbers, booleans, objects, arrays).
// Values encoding/json cannot encode (channels, funcs, NaN/Inf, ...) yield "".
func asString(v any) string {
	switch t := v.(type) {
	case nil:
		// json.Marshal(nil) is "null"; returning that would make a missing map
		// entry look like the literal value "null" to callers that test for "".
		return ""
	case string:
		return t
	}

	b, err := json.Marshal(v)
	if err != nil {
		return ""
	}
	// Decode JSON string results properly instead of trimming quote characters.
	// Trimming left escapes such as \" or \u003c in place and could strip quotes
	// that belong to the value itself.
	if len(b) > 0 && b[0] == '"' {
		var s string
		if json.Unmarshal(b, &s) == nil {
			return s
		}
	}
	return string(b)
}
|
||
|
||
// formatValue renders v with asString and strips surrounding whitespace.
func formatValue(v any) string {
	text := asString(v)
	return strings.TrimSpace(text)
}
|
||
|
||
// mapToText renders m as "key:value" pairs joined by commas, in ascending key
// order. Entries whose value is nil or a blank string are omitted. Values are
// rendered with formatValue. An empty map, or one whose entries are all
// omitted, yields "".
func mapToText(m map[string]any) string {
	var sb strings.Builder
	for _, key := range sortedKeys(m) {
		val := m[key]
		if isNilOrEmpty(val) {
			continue
		}
		if sb.Len() > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(key)
		sb.WriteByte(':')
		sb.WriteString(formatValue(val))
	}
	return sb.String()
}
|
||
|
||
// sortedKeys returns the keys of m in ascending lexical order. The result is
// a non-nil slice, empty when m is empty or nil.
func sortedKeys(m map[string]any) []string {
	out := make([]string, len(m))
	i := 0
	for key := range m {
		out[i] = key
		i++
	}
	sort.Strings(out)
	return out
}
|
||
|
||
// mustMarshal returns the JSON encoding of v. If v cannot be encoded, it
// returns "{}" instead of an error (it never panics, despite the name).
func mustMarshal(v any) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return "{}"
}
|
||
|
||
// formatTaskError builds the error reported for a failed task. The task ID is
// always included; errMsg is appended after ": " only when it is not blank.
func formatTaskError(taskID, errMsg string) error {
	base := fmt.Sprintf("任务失败(taskId=%s)", taskID)
	if strings.TrimSpace(errMsg) != "" {
		return fmt.Errorf("%s: %s", base, errMsg)
	}
	return fmt.Errorf("%s", base)
}
|
||
|
||
// getIntConfig reads the integer configuration value stored under key.
//
// fallback is returned when the value is empty (as judged by gvar's IsEmpty)
// or when the configuration lookup itself fails. A lookup failure is logged
// and falls back, rather than panicking as the previous g.Cfg().MustGet call
// did.
// NOTE(review): gvar's IsEmpty is believed to treat 0 as empty, so an
// explicitly configured 0 also yields fallback — confirm this is intended.
func getIntConfig(ctx context.Context, key string, fallback int) int {
	v, err := g.Cfg().Get(ctx, key)
	if err != nil {
		g.Log().Warningf(ctx, "[配置] 读取配置失败 key=%s err=%v", key, err)
		return fallback
	}
	if v == nil || v.IsEmpty() {
		return fallback
	}
	return v.Int()
}
|
||
|
||
// ============================================
|
||
// Schema 处理
|
||
// ============================================
|
||
|
||
func enrichSchemaWithValues(schema []any, values map[string]any) []any {
|
||
if len(schema) == 0 || len(values) == 0 {
|
||
return schema
|
||
}
|
||
|
||
result := make([]any, len(schema))
|
||
copy(result, schema)
|
||
|
||
for i, item := range result {
|
||
m, ok := item.(map[string]any)
|
||
if !ok {
|
||
continue
|
||
}
|
||
|
||
field := getField(m, "")
|
||
if field == "" {
|
||
continue
|
||
}
|
||
|
||
// 已有 value 则跳过
|
||
if _, hasValue := m["value"]; hasValue {
|
||
continue
|
||
}
|
||
|
||
// 补充 value
|
||
if v, exists := values[field]; exists {
|
||
m["value"] = v
|
||
result[i] = m
|
||
}
|
||
}
|
||
|
||
return result
|
||
}
|
||
|
||
// extractContentFromResponse extracts choices[0].message.content from a full
// chat-completion style JSON response body.
//
// It returns "" when text is not valid JSON for that shape, when there are no
// choices, or when the first choice's content is empty.
func extractContentFromResponse(text string) string {
	type message struct {
		Content string `json:"content"`
	}
	type choice struct {
		Message message `json:"message"`
	}
	var resp struct {
		Choices []choice `json:"choices"`
	}

	if json.Unmarshal([]byte(text), &resp) != nil || len(resp.Choices) == 0 {
		return ""
	}
	return resp.Choices[0].Message.Content
}
|
||
|
||
// ============================================
|
||
// 值提取
|
||
// ============================================
|
||
|
||
func extractSystemValues(req *dto.ComposeMessagesReq) map[string]any {
|
||
if req == nil {
|
||
return nil
|
||
}
|
||
|
||
values := make(map[string]any)
|
||
|
||
for _, value := range req.Form {
|
||
item, ok := value.(map[string]any)
|
||
if !ok || len(item) == 0 {
|
||
continue
|
||
}
|
||
|
||
field := getField(item, "")
|
||
if field == "" || isSensitiveField(field) {
|
||
continue
|
||
}
|
||
|
||
if v := getValue(item); !isNilOrEmpty(v) {
|
||
values[field] = v
|
||
}
|
||
}
|
||
|
||
return values
|
||
}
|
||
|
||
// extractModelKey returns the API credential found in form as a
// "Header:value" string, or "" if there is none.
//
// An entry qualifies when:
//   - it is a map[string]any,
//   - its "field" is an API-key field per isAPIKeyField, and
//   - its "value" is non-nil and non-blank after trimming.
//
// A value that already contains ':' is returned as is; otherwise it is
// prefixed with "Authorization:".
func extractModelKey(form map[string]any) string {
	// Visit entries in sorted key order. Ranging over the map directly would
	// pick an arbitrary entry whenever form holds more than one API-key field.
	names := make([]string, 0, len(form))
	for name := range form {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		item, ok := form[name].(map[string]any)
		if !ok || !isAPIKeyField(getField(item, "")) {
			continue
		}

		raw := getValue(item)
		if raw == nil {
			// Skip missing/null values explicitly so they can never be rendered
			// as the text "null" and sent as a credential.
			continue
		}

		key := strings.TrimSpace(asString(raw))
		if key == "" {
			continue
		}
		if strings.Contains(key, ":") {
			return key
		}
		return "Authorization:" + key
	}

	return ""
}
|
||
|
||
// ==================== 工具方法 ====================
|
||
|
||
// convertToMessages converts a message list loaded from the database (typed
// as any) into []Message.
//
// Accepted inputs:
//   - nil, or a nil/nil-valued/empty *gvar.Var → empty slice
//   - a JSON array given as string or []byte (blank input → empty slice)
//   - []any whose elements are JSON-encoded strings or JSON objects
//   - any other value that round-trips through encoding/json into an array
//
// Each element is decoded into a Message with encoding/json. Elements that
// fail to decode are logged and skipped; they are not appended as zero-value
// Messages. The result is never nil, so it encodes to JSON as [] rather than
// null.
// NOTE(review): the previous header also mentioned messages whose content is
// a string array. Nothing here special-cases that shape — presumably Message's
// own JSON decoding handles it. Confirm.
func (s *sessionService) convertToMessages(data any) []Message {
	if data == nil {
		return []Message{}
	}
	// This method receives no context, so logging uses a background context.
	ctx := context.Background()

	// Unwrap a *gvar.Var to its underlying value.
	if v, ok := data.(*gvar.Var); ok {
		if v == nil || v.IsNil() || v.IsEmpty() {
			return []Message{}
		}
		data = v.Val()
	}

	var rawList []any
	switch v := data.(type) {
	case string:
		if strings.TrimSpace(v) == "" {
			// Blank text holds no messages; treat it as empty instead of logging
			// a JSON parse warning.
			return []Message{}
		}
		if err := json.Unmarshal([]byte(v), &rawList); err != nil {
			g.Log().Warningf(ctx, "[会话] 解析JSON字符串失败 err=%v data=%.200s", err, v)
			return []Message{}
		}
	case []byte:
		if strings.TrimSpace(string(v)) == "" {
			return []Message{}
		}
		if err := json.Unmarshal(v, &rawList); err != nil {
			g.Log().Warningf(ctx, "[会话] 解析字节数组失败 err=%v", err)
			return []Message{}
		}
	case []any:
		rawList = v
	default:
		// Round-trip any other shape through JSON. A Marshal error used to be
		// discarded and only surfaced indirectly as an Unmarshal failure.
		b, err := json.Marshal(v)
		if err == nil {
			err = json.Unmarshal(b, &rawList)
		}
		if err != nil {
			g.Log().Warningf(ctx, "[会话] 解析未知类型失败 err=%v type=%T", err, v)
			return []Message{}
		}
	}

	// Pre-sized and non-nil, so an empty result still encodes as [].
	messages := make([]Message, 0, len(rawList))
	for _, item := range rawList {
		var msg Message
		switch val := item.(type) {
		case string:
			// Element stored as a JSON-encoded string.
			if err := json.Unmarshal([]byte(val), &msg); err != nil {
				g.Log().Warningf(ctx, "[会话] 解析消息元素失败 err=%v data=%s", err, val)
				continue
			}
		default:
			// map[string]any or any other JSON-compatible element. Errors here
			// used to be ignored, which appended a zero-value Message.
			b, err := json.Marshal(val)
			if err == nil {
				err = json.Unmarshal(b, &msg)
			}
			if err != nil {
				g.Log().Warningf(ctx, "[会话] 解析消息对象失败 err=%v type=%T", err, val)
				continue
			}
		}
		messages = append(messages, msg)
	}

	return messages
}
|