重构了一下 rag的方法, 使用 goframe的框架, 还有redis连接部分
This commit is contained in:
200
redis/redis.go
200
redis/redis.go
@@ -3,33 +3,50 @@ package redis
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/gogf/gf/v2/database/gredis"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/os/gtime"
|
||||
"github.com/gogf/gf/v2/util/gconv"
|
||||
)
|
||||
|
||||
// GRedisClient GoFrame gredis 客户端,统一使用(懒加载)
|
||||
var GRedisClient *gredis.Redis
|
||||
var (
|
||||
// redisClient 单例 Redis 客户端
|
||||
redisClient *gredis.Redis
|
||||
// redisOnce 确保只初始化一次
|
||||
redisOnce sync.Once
|
||||
// RedisClient 兼容导出(供 mongo.go 使用)
|
||||
// 注意:这是一个指向单例的指针,首次调用 GetRedisClient() 后生效
|
||||
RedisClient *gredis.Redis
|
||||
)
|
||||
|
||||
// RedisClient GRedisClient 的别名,保持向后兼容
|
||||
var RedisClient *gredis.Redis
|
||||
|
||||
// GetRedisClient 获取 Redis 客户端(懒加载)
|
||||
// GetRedisClient 获取 Redis 客户端(单例模式)
|
||||
func GetRedisClient() *gredis.Redis {
|
||||
if GRedisClient == nil {
|
||||
GRedisClient = g.Redis()
|
||||
RedisClient = GRedisClient
|
||||
}
|
||||
return GRedisClient
|
||||
redisOnce.Do(func() {
|
||||
redisClient = g.Redis()
|
||||
RedisClient = redisClient // 同步更新兼容导出
|
||||
})
|
||||
return redisClient
|
||||
}
|
||||
|
||||
// init 包初始化时自动初始化 Redis 客户端
|
||||
func init() {
|
||||
GetRedisClient()
|
||||
}
|
||||
|
||||
// Stream 和消费者组常量
|
||||
const (
|
||||
// RAGFlow 请求 Stream Key
|
||||
RAGFlowRequestStreamKey = "ragflow:request:stream"
|
||||
// RAGFlow 消费者组名称
|
||||
// RAGFlow 响应 Stream Key
|
||||
RAGFlowResponseStreamKey = "ragflow:response:stream"
|
||||
// RAGFlow 请求消费者组名称
|
||||
RAGFlowRequestConsumerGroup = "ragflow:request:consumer:group"
|
||||
// RAGFlow 响应消费者组名称
|
||||
RAGFlowResponseConsumerGroup = "ragflow:response:consumer:group"
|
||||
// RAGFlow 消费者组名称(兼容旧代码)
|
||||
RAGFlowConsumerGroup = "ragflow:consumer:group"
|
||||
// 会话最后活跃时间 Key 前缀
|
||||
SessionLastActiveKeyPrefix = "ragflow:session:"
|
||||
@@ -79,6 +96,9 @@ func AddToStream(ctx context.Context, streamKey string, values map[string]interf
|
||||
// ReadFromStream 从 Stream 读取消息(消费者组模式)
|
||||
// 使用 gredis Do() 方法执行 XREADGROUP 命令
|
||||
func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName string, count int64, blockMs int64) ([]StreamMessage, error) {
|
||||
glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP GROUP %s %s COUNT %d BLOCK %d STREAMS %s >",
|
||||
groupName, consumerName, count, blockMs, streamKey)
|
||||
|
||||
// XREADGROUP GROUP groupName consumerName COUNT count BLOCK blockMs STREAMS streamKey >
|
||||
result, err := GetRedisClient().Do(ctx,
|
||||
"XREADGROUP", "GROUP", groupName, consumerName,
|
||||
@@ -88,66 +108,89 @@ func ReadFromStream(ctx context.Context, streamKey, groupName, consumerName stri
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
glog.Errorf(ctx, "[DEBUG Redis] XREADGROUP 错误: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析返回值
|
||||
// 格式: [[streamKey, [[msgID, [field1, value1, field2, value2, ...]], ...]]]
|
||||
glog.Debugf(ctx, "[DEBUG Redis] XREADGROUP 返回: %+v", result)
|
||||
|
||||
// 预分配容量,避免动态扩容
|
||||
messages := make([]StreamMessage, 0, int(count))
|
||||
|
||||
if result == nil {
|
||||
if result == nil || result.IsEmpty() {
|
||||
// 超时或没有数据
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
// 类型断言:result.Val() 返回 interface{}
|
||||
streamsArray, ok := result.Val().([]interface{})
|
||||
if !ok || len(streamsArray) == 0 {
|
||||
return messages, nil
|
||||
}
|
||||
// GoFrame gredis 返回格式: map[streamKey:[[msgID [field1 value1 field2 value2 ...]] ...]]
|
||||
resultVal := result.Val()
|
||||
|
||||
// 遍历每个 stream
|
||||
for _, streamData := range streamsArray {
|
||||
streamArray, ok := streamData.([]interface{})
|
||||
if !ok || len(streamArray) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// streamArray[0] 是 streamKey, streamArray[1] 是消息数组
|
||||
messagesArray, ok := streamArray[1].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析每条消息
|
||||
for _, msgData := range messagesArray {
|
||||
msgArray, ok := msgData.([]interface{})
|
||||
if !ok || len(msgArray) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// msgArray[0] 是 ID, msgArray[1] 是字段数组
|
||||
msgID := gconv.String(msgArray[0])
|
||||
fieldsArray, ok := msgArray[1].([]interface{})
|
||||
// 尝试 map 格式(GoFrame gredis 返回)
|
||||
if streamsMap, ok := resultVal.(map[interface{}]interface{}); ok {
|
||||
for _, streamMsgs := range streamsMap {
|
||||
msgsArray, ok := streamMsgs.([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析字段为 map,预分配容量,避免动态扩容
|
||||
values := make(map[string]interface{}, len(fieldsArray)/2)
|
||||
for i := 0; i < len(fieldsArray); i += 2 {
|
||||
if i+1 < len(fieldsArray) {
|
||||
key := gconv.String(fieldsArray[i])
|
||||
val := fieldsArray[i+1]
|
||||
values[key] = val
|
||||
for _, msgData := range msgsArray {
|
||||
msgArray, ok := msgData.([]interface{})
|
||||
if !ok || len(msgArray) < 2 {
|
||||
continue
|
||||
}
|
||||
msgID := gconv.String(msgArray[0])
|
||||
fieldsArray, ok := msgArray[1].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
values := make(map[string]interface{}, len(fieldsArray)/2)
|
||||
for i := 0; i < len(fieldsArray); i += 2 {
|
||||
if i+1 < len(fieldsArray) {
|
||||
key := gconv.String(fieldsArray[i])
|
||||
values[key] = fieldsArray[i+1]
|
||||
}
|
||||
}
|
||||
messages = append(messages, StreamMessage{
|
||||
ID: msgID,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
}
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
messages = append(messages, StreamMessage{
|
||||
ID: msgID,
|
||||
Values: values,
|
||||
})
|
||||
// 尝试数组格式(标准 Redis 返回)
|
||||
if streamsArray, ok := resultVal.([]interface{}); ok && len(streamsArray) > 0 {
|
||||
for _, streamData := range streamsArray {
|
||||
streamArray, ok := streamData.([]interface{})
|
||||
if !ok || len(streamArray) < 2 {
|
||||
continue
|
||||
}
|
||||
messagesArray, ok := streamArray[1].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, msgData := range messagesArray {
|
||||
msgArray, ok := msgData.([]interface{})
|
||||
if !ok || len(msgArray) < 2 {
|
||||
continue
|
||||
}
|
||||
msgID := gconv.String(msgArray[0])
|
||||
fieldsArray, ok := msgArray[1].([]interface{})
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
values := make(map[string]interface{}, len(fieldsArray)/2)
|
||||
for i := 0; i < len(fieldsArray); i += 2 {
|
||||
if i+1 < len(fieldsArray) {
|
||||
key := gconv.String(fieldsArray[i])
|
||||
values[key] = fieldsArray[i+1]
|
||||
}
|
||||
}
|
||||
messages = append(messages, StreamMessage{
|
||||
ID: msgID,
|
||||
Values: values,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,16 +243,16 @@ func GetPendingMessages(ctx context.Context, streamKey, groupName string, start,
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return []PendingMessage{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 解析返回值:[[ID, consumer, idle, retryCount], ...]
|
||||
pendingArray, ok := result.Val().([]interface{})
|
||||
if !ok {
|
||||
return []PendingMessage{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var messages []PendingMessage
|
||||
messages := make([]PendingMessage, 0, len(pendingArray))
|
||||
for _, item := range pendingArray {
|
||||
itemArray, ok := item.([]interface{})
|
||||
if !ok || len(itemArray) < 4 {
|
||||
@@ -242,13 +285,13 @@ func ClaimPendingMessage(ctx context.Context, streamKey, groupName, consumerName
|
||||
}
|
||||
|
||||
if result == nil {
|
||||
return []StreamMessage{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 解析返回值:类似 XREADGROUP
|
||||
messagesArray, ok := result.Val().([]interface{})
|
||||
if !ok {
|
||||
return []StreamMessage{}, nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 预分配容量,避免动态扩容
|
||||
@@ -344,6 +387,43 @@ func SetSessionCache(ctx context.Context, userId, sessionId string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 限流相关常量
|
||||
const (
|
||||
// RateLimitKeyPrefix 限流计数器 Key 前缀
|
||||
RateLimitKeyPrefix = "ragflow:ratelimit:"
|
||||
)
|
||||
|
||||
// IncrRateLimit 增加限流计数器,返回当前计数
|
||||
// windowSeconds: 时间窗口(秒)
|
||||
func IncrRateLimit(ctx context.Context, key string, windowSeconds int64) (count int64, err error) {
|
||||
fullKey := RateLimitKeyPrefix + key
|
||||
result, err := GetRedisClient().Do(ctx, "INCR", fullKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
count = result.Int64()
|
||||
|
||||
// 首次设置过期时间
|
||||
if count == 1 {
|
||||
GetRedisClient().Do(ctx, "EXPIRE", fullKey, windowSeconds)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// GetRateLimit 获取当前限流计数
|
||||
func GetRateLimit(ctx context.Context, key string) (count int64, err error) {
|
||||
fullKey := RateLimitKeyPrefix + key
|
||||
result, err := GetRedisClient().Get(ctx, fullKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if result.IsEmpty() {
|
||||
return 0, nil
|
||||
}
|
||||
count = result.Int64()
|
||||
return
|
||||
}
|
||||
|
||||
// GetSessionCache 获取缓存的 RAGFlow Session ID
|
||||
// 使用 gredis Get 方法
|
||||
func GetSessionCache(ctx context.Context, userId string) (string, error) {
|
||||
|
||||
@@ -37,3 +37,68 @@ func (m *BatchStreamMessage) ToMap() map[string]interface{} {
|
||||
"index": m.Index,
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseStreamMessage RAGFlow 响应消息结构(写入结果 Stream)
|
||||
type ResponseStreamMessage struct {
|
||||
UserId string `json:"user_id"` // 用户ID
|
||||
Platform string `json:"platform"` // 平台标识
|
||||
Question string `json:"question"` // 用户问题
|
||||
Content string `json:"content"` // RAGFlow 回复内容
|
||||
SessionId string `json:"session_id"` // RAGFlow Session ID
|
||||
Timestamp int64 `json:"timestamp"` // 时间戳(秒)
|
||||
MessageId string `json:"message_id"` // 原始消息ID
|
||||
}
|
||||
|
||||
// ToMap 转换为 map[string]interface{} 用于 Stream 存储
|
||||
func (m *ResponseStreamMessage) ToMap() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"user_id": m.UserId,
|
||||
"platform": m.Platform,
|
||||
"question": m.Question,
|
||||
"content": m.Content,
|
||||
"session_id": m.SessionId,
|
||||
"timestamp": m.Timestamp,
|
||||
"message_id": m.MessageId,
|
||||
}
|
||||
}
|
||||
|
||||
// FollowUpMessage 追问消息结构(RabbitMQ 延时队列)
|
||||
type FollowUpMessage struct {
|
||||
UserId string `json:"user_id"` // 用户ID
|
||||
Platform string `json:"platform"` // 平台标识
|
||||
Content string `json:"content"` // 追问内容
|
||||
FollowUpType int `json:"follow_up_type"` // 追问类型:1=30s, 2=60s, 3=180s
|
||||
Timestamp int64 `json:"timestamp"` // 发送时间戳
|
||||
}
|
||||
|
||||
// 追问话术常量
|
||||
const (
|
||||
FollowUpType1 = 1 // 30秒追问
|
||||
FollowUpType2 = 2 // 60秒追问
|
||||
FollowUpType3 = 3 // 180秒追问
|
||||
)
|
||||
|
||||
// 追问话术内容
|
||||
var FollowUpContents = map[int]string{
|
||||
FollowUpType1: "还有其他问题吗?",
|
||||
FollowUpType2: "如果需要帮助,随时告诉我~",
|
||||
FollowUpType3: "我一直在线,有问题随时找我~",
|
||||
}
|
||||
|
||||
// 追问延时时间(秒)
|
||||
var FollowUpDelays = map[int]int{
|
||||
FollowUpType1: 30,
|
||||
FollowUpType2: 60,
|
||||
FollowUpType3: 180,
|
||||
}
|
||||
|
||||
// ArchiveMessage 会话归档消息结构(RabbitMQ 延时队列)
|
||||
type ArchiveMessage struct {
|
||||
UserId string `json:"user_id"` // 用户ID
|
||||
Platform string `json:"platform"` // 平台标识
|
||||
SessionId string `json:"session_id"` // RAGFlow Session ID
|
||||
Timestamp int64 `json:"timestamp"` // 发送时间戳
|
||||
}
|
||||
|
||||
// 归档延时时间(秒)
|
||||
const ArchiveDelaySeconds = 3600 // 60分钟
|
||||
|
||||
Reference in New Issue
Block a user