Files
common/message/nats_rpc.go
qhd 55a6ec0374 重构消息队列连接管理,支持多数据源配置
主要变更:
1. 重构NATS、RabbitMQ和Redis连接管理模块,支持多数据源配置
2. 统一连接管理接口,增加数据源名称参数
3. 优化连接状态检查和错误处理
4. 增加连接池管理和资源清理机制
5. 改进日志输出格式和内容
2026-03-12 08:51:45 +08:00

771 lines
23 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package message
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/gogf/gf/v2/frame/g"
"github.com/nats-io/nats.go"
"go.opentelemetry.io/otel/trace"
"reflect"
"sync"
)
// ============ RPC 服务封装 ============
// 以下方法提供了完全抽象的 RPC 调用接口
// 调用方和响应方完全不需要知道底层使用的是 NATS 的发布订阅模式
// RPC 服务注册表
var (
rpcServices map[string]rpcHandler
rpcSubs map[string]*nats.Subscription // 服务名 -> 订阅
rpcServicesMu sync.RWMutex
queueRPCServices map[string]map[string]rpcHandler // queueName -> subject -> handler
queueRPCSubs map[string]map[string]*nats.Subscription // queueName -> serviceName -> 订阅
queueRPCMu sync.RWMutex
// ============ TraceID 主动取消支持 ============
// 全局映射表TraceID -> CancelFunc并发安全
traceCancelMap map[string]context.CancelFunc
traceCancelMu sync.RWMutex
// 取消主题前缀
cancelSubjectPrefix = "ctx.cancel.otel."
// RPC 使用的默认数据源名称
rpcDefaultDatasource = "default"
)
// rpcHandler RPC 处理函数类型
// 实现方只需要关注请求参数和返回值,无需了解底层 NATS 实现
// 返回值可以是任意类型,会被自动序列化为 JSON
type rpcHandler func(ctx context.Context, req []byte) (any, error)
// registerRPCService 注册 RPC 服务(单实例)
// serviceName: 服务名称,调用方通过此名称调用服务
// handler: 服务处理函数,接收请求并返回响应
func registerRPCService(serviceName string, handler rpcHandler) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
rpcServicesMu.Lock()
if rpcServices == nil {
rpcServices = make(map[string]rpcHandler)
}
if rpcSubs == nil {
rpcSubs = make(map[string]*nats.Subscription)
}
// 如果已存在该服务,先取消之前的订阅
if oldSub, exists := rpcSubs[serviceName]; exists {
oldSub.Unsubscribe()
}
rpcServices[serviceName] = handler
rpcServicesMu.Unlock()
// 订阅服务主题
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
subject := fmt.Sprintf("rpc.%s", serviceName)
sub, err := nc.Subscribe(subject, func(msg *nats.Msg) {
// 执行处理函数
executeHandler(handler, msg)
})
if err != nil {
return fmt.Errorf("注册 RPC 服务失败: %w", err)
}
rpcSubs[serviceName] = sub
g.Log().Infof(context.Background(), "✅ RPC 服务已注册: %s", serviceName)
return nil
}
// registerQueueRPCService 注册 RPC 服务(集群模式)
// 多个服务实例注册同一服务时,请求会自动负载均衡
// serviceName: 服务名称
// queueName: 队列组名,同一队列组的实例共享请求
// handler: 服务处理函数
func registerQueueRPCService(serviceName, queueName string, handler rpcHandler) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
queueRPCMu.Lock()
if queueRPCServices == nil {
queueRPCServices = make(map[string]map[string]rpcHandler)
}
if queueRPCSubs == nil {
queueRPCSubs = make(map[string]map[string]*nats.Subscription)
}
if queueRPCServices[queueName] == nil {
queueRPCServices[queueName] = make(map[string]rpcHandler)
}
if queueRPCSubs[queueName] == nil {
queueRPCSubs[queueName] = make(map[string]*nats.Subscription)
}
// 如果已存在该服务,先取消之前的订阅
if oldSub, exists := queueRPCSubs[queueName][serviceName]; exists {
oldSub.Unsubscribe()
}
queueRPCServices[queueName][serviceName] = handler
queueRPCMu.Unlock()
// 订阅服务主题(队列模式)
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
subject := fmt.Sprintf("rpc.%s", serviceName)
sub, err := nc.QueueSubscribe(subject, queueName, func(msg *nats.Msg) {
// 执行处理函数
executeHandler(handler, msg)
})
if err != nil {
return fmt.Errorf("注册队列 RPC 服务失败: %w", err)
}
queueRPCMu.Lock()
queueRPCSubs[queueName][serviceName] = sub
queueRPCMu.Unlock()
g.Log().Infof(context.Background(), "✅ 队列 RPC 服务已注册: %s (队列组: %s)", serviceName, queueName)
return nil
}
// executeHandler 执行 RPC 处理函数
func executeHandler(handler rpcHandler, msg *nats.Msg) {
// 响应
var respData []byte
// 从消息头重建上下文
ctx := headersToContext(context.Background(), msg.Header)
// 提取 TraceID创建可取消的 context
ctx = createCancelContext(ctx, msg.Header.Get(traceIDKey))
// 检查 context 是否已取消(在调用 handler 之前)
select {
case <-ctx.Done():
// context 已取消,返回取消错误
g.Log().Infof(ctx, "RPC 请求已取消traceID: %s", msg.Header.Get(traceIDKey))
// 仍然需要发送响应以避免客户端超时
respData = []byte(`{"_err":"请求已取消"}`)
// 清理取消映射表
cleanupTraceCancel(msg.Header.Get(traceIDKey))
return
default:
}
// 执行业务处理
response, err := handler(ctx, msg.Data)
if err != nil {
// 错误时返回 {"_err": "错误信息"}
if respData, err = json.Marshal(map[string]any{"_err": err.Error()}); err != nil {
g.Log().Errorf(ctx, "RPC 错误响应序列化失败: %v", err)
respData = []byte(`{"_err":"错误响应序列化失败"}`)
}
} else if response == nil {
// 空响应时返回空对象(或 {"_err": ""}
respData = []byte(`{}`)
} else {
// 成功时返回业务数据
if respData, err = json.Marshal(response); err != nil {
g.Log().Errorf(ctx, "RPC 响应序列化失败: %v", err)
respData = []byte(`{"_err":"响应序列化失败"}`)
}
}
// 发送响应(必须执行) 如果客户端用 nc.Request(...) 发送消息 → 双向模式,服务端必须 msg.Respond
if err = msg.Respond(respData); err != nil {
g.Log().Errorf(ctx, "RPC 响应失败: %v", err)
}
// 请求结束,清理取消映射表
cleanupTraceCancel(msg.Header.Get(traceIDKey))
}
// createCancelContext 创建可取消的 context 并注册到取消映射表
// 返回可取消的 context如果 traceID 为空则返回原 context
func createCancelContext(ctx context.Context, traceID string) context.Context {
if g.IsEmpty(traceID) {
return ctx
}
// 创建带取消功能的 context
taskCtx, cancel := context.WithCancel(ctx)
// 注册到取消映射表
traceCancelMu.Lock()
if traceCancelMap == nil {
traceCancelMap = make(map[string]context.CancelFunc)
}
// 如果同一 TraceID 已有 CancelFunc先调用它
if oldCancel, exists := traceCancelMap[traceID]; exists {
oldCancel()
}
traceCancelMap[traceID] = cancel
traceCancelMu.Unlock()
return taskCtx
}
// ============ TraceID 主动取消功能 ============
// 以下函数实现了基于 OpenTelemetry TraceID 的跨进程任务取消机制
// SetupCancelListener 设置取消监听器
// 订阅取消主题,监听取消指令
// 使用示例:
//
// sub, err := nats.SetupCancelListener(ctx)
func setupCancelListener(ctx context.Context) (*nats.Subscription, error) {
if !natsPing(ctx, rpcDefaultDatasource) {
return nil, fmt.Errorf("NATS 未连接")
}
if traceCancelMap == nil {
traceCancelMap = make(map[string]context.CancelFunc)
}
// 修复问题3订阅取消主题格式: ctx.cancel.otel.*
// 使用 * 通配符而不是 >,因为 TraceID 是最后一部分
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return nil, fmt.Errorf("NATS 连接不存在")
}
cancelSubject := cancelSubjectPrefix + "*"
sub, err := nc.Subscribe(cancelSubject, func(msg *nats.Msg) {
// 从主题中解析 TraceID (去除前缀)
prefixLen := len(cancelSubjectPrefix)
if len(msg.Subject) <= prefixLen {
g.Log().Warningf(ctx, "取消消息主题格式错误: %s", msg.Subject)
return
}
traceID := msg.Subject[prefixLen:]
if traceID == "" {
g.Log().Warning(ctx, "取消消息主题缺少 TraceID")
return
}
// 从映射表获取 CancelFunc 并执行取消
traceCancelMu.RLock()
cancel, ok := traceCancelMap[traceID]
traceCancelMu.RUnlock()
if ok {
cancel()
g.Log().Infof(ctx, "📢 取消信号已发送traceID: %s", traceID)
} else {
g.Log().Infof(ctx, "⚠️ 未找到对应的可取消任务traceID: %s", traceID)
}
})
if err != nil {
return nil, fmt.Errorf("设置取消监听器失败: %w", err)
}
g.Log().Infof(ctx, "✅ 取消监听器已设置: %s", cancelSubject)
return sub, nil
}
// publishCancel 发布取消指令
// 向指定 TraceID 发送取消信号
// 使用示例:
//
// err := nats.publishCancel(ctx, traceID)
func publishCancel(ctx context.Context, traceID string) error {
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
if traceID == "" {
return fmt.Errorf("TraceID 不能为空")
}
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
cancelSubject := cancelSubjectPrefix + traceID
err := nc.Publish(cancelSubject, nil)
if err != nil {
return fmt.Errorf("发布取消信号失败: %w", err)
}
g.Log().Infof(ctx, "📤 已发送取消信号traceID: %s主题: %s", traceID, cancelSubject)
return nil
}
// cleanupTraceCancel 清理取消映射表中的条目
// 任务取消/正常结束后必须调用此函数,避免内存泄漏
// 使用示例:
//
// defer nats.cleanupTraceCancel(traceID)
func cleanupTraceCancel(traceID string) {
if traceID == "" {
return
}
traceCancelMu.Lock()
defer traceCancelMu.Unlock()
if _, ok := traceCancelMap[traceID]; ok {
delete(traceCancelMap, traceID)
g.Log().Infof(context.Background(), "✅ 已清理取消映射表traceID: %s", traceID)
}
}
// CallRPC 调用 RPC 服务
// serviceName: 服务名称
// req: 请求数据
// 返回: 响应数据(任意类型)和错误
func CallRPC(ctx context.Context, serviceName string, req any, resp any) (err error) {
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
// 验证 resp 必须是指针类型
respValue := reflect.ValueOf(resp)
if respValue.Kind() != reflect.Ptr {
return fmt.Errorf("resp 参数必须是指针类型(当前类型: %T", resp)
}
// 构建请求体
var reqBody []byte
if !g.IsEmpty(req) {
reqValue := reflect.ValueOf(req)
if !(reqValue.Kind() == reflect.Ptr && reqValue.IsNil()) && !reqValue.IsZero() {
reqData, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("序列化请求参数失败: %w", err)
}
reqBody = reqData
}
}
// 检查本地是否有注册的单实例服务,如果有则直接调用(优化性能)
rpcServicesMu.RLock()
if localHandler, exists := rpcServices[serviceName]; exists {
rpcServicesMu.RUnlock()
// 修复问题1本地调用也需要处理取消机制
var traceID string
if traceID, err = getTraceID(ctx); err != nil {
return err
}
// 提取 TraceID创建可取消的 context
cancelCtx := createCancelContext(ctx, traceID)
// 执行本地调用
var response interface{}
if response, err = localHandler(cancelCtx, reqBody); err != nil {
return fmt.Errorf("本地调用 RPC 服务失败 [%s]: %w", serviceName, err)
}
// 请求结束,清理取消映射表
cleanupTraceCancel(traceID)
// 检查是否为错误消息:尝试解析为 map看是否包含 "_err" 字段
var respMap map[string]any
if json.Unmarshal(response.([]byte), &respMap) == nil {
if errMsg, ok := respMap["_err"]; ok {
return fmt.Errorf("%v", errMsg)
}
}
// 正常数据直接返回
// responseMsg.Data 已经是 []byte 类型(来自 msg.Data直接反序列化
if err = json.Unmarshal(response.([]byte), resp); err != nil {
return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, response)
}
return
}
rpcServicesMu.RUnlock()
subject := fmt.Sprintf("rpc.%s", serviceName)
// 创建消息并将上下文元数据写入消息头
msg := nats.NewMsg(subject)
msg.Data = reqBody
headers, err := contextToHeaders(ctx)
if err != nil {
return fmt.Errorf("上下文转换失败: %w", err)
}
msg.Header = headers
// 修复问题5优化 go 协程避免资源泄漏
// 使用 done channel 来确保 goroutine 能正确退出
done := make(chan struct{})
var closeDoneOnce sync.Once
closeDone := func() {
closeDoneOnce.Do(func() {
close(done)
})
}
if msg.Header.Get(traceIDKey) != "" {
go func() {
defer closeDone()
select {
case <-ctx.Done():
// context 被取消时,发送取消信号给服务端
if errors.Is(ctx.Err(), context.Canceled) {
if err := publishCancel(context.Background(), msg.Header.Get(traceIDKey)); err != nil {
g.Log().Errorf(ctx, "发送 RPC 取消信号失败: %v", err)
} else {
g.Log().Infof(ctx, "RPC 调用已取消traceID: %s", msg.Header.Get(traceIDKey))
}
}
case <-done:
// 请求已完成,无需发送取消信号
return
}
}()
}
// 发送请求
nc := getNatsConn(rpcDefaultDatasource)
if nc == nil {
return fmt.Errorf("NATS 连接不存在")
}
responseMsg, err := nc.RequestMsgWithContext(ctx, msg)
// 关闭 done channel通知 goroutine 退出
closeDone()
if err != nil {
return fmt.Errorf("调用 RPC 服务失败 [%s]: %w", serviceName, err)
}
if responseMsg == nil {
return fmt.Errorf("RPC 响应为空 [%s]", serviceName)
}
// 解析响应
if len(responseMsg.Data) > 0 {
// 检查是否为错误消息:尝试解析为 map看是否包含 "_err" 字段
var respMap map[string]any
if json.Unmarshal(responseMsg.Data, &respMap) == nil {
if errMsg, ok := respMap["_err"]; ok {
return fmt.Errorf("%v", errMsg)
}
}
// 正常数据直接返回
// responseMsg.Data 已经是 []byte 类型(来自 msg.Data直接反序列化
if err = json.Unmarshal(responseMsg.Data, resp); err != nil {
return fmt.Errorf("解析响应失败: %w (响应内容: %s)", err, responseMsg.Data)
}
}
return
}
// RegisterServiceOption 注册选项类型
type registerServiceOption func(*registerServiceConfig)
type registerServiceConfig struct {
queueName string // 队列组名(用于集群模式)
excludeMethods []string
}
// WithQueueGroup 设置队列组名(集群模式)
func WithQueueGroup(queueName string) registerServiceOption {
return func(cfg *registerServiceConfig) {
cfg.queueName = queueName
}
}
// WithExcludeMethods 排除不需要注册的方法
func WithExcludeMethods(methods ...string) registerServiceOption {
return func(cfg *registerServiceConfig) {
cfg.excludeMethods = append(cfg.excludeMethods, methods...)
}
}
// AutoRegisterServices 自动注册多个服务的所有公开方法
// serviceInstances: map[包名]service实例如 map[string]interface{}{"user": userService, "order": orderService}
// options: 注册选项(可选)
// 示例:
//
// AutoRegisterServices(map[string]interface{}{
// "user": userService,
// "order": orderService,
// })
// 或
// AutoRegisterServices(map[string]interface{}{
// "order": orderService,
// }, WithQueueGroup("order-group"))
func AutoRegisterServices(ctx context.Context, serviceInstances map[string]interface{}, options ...registerServiceOption) error {
// 先注册 RPC 服务(如果 NATS 不可用则记录警告但不阻塞启动)
if !natsPing(ctx, rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接RPC 服务未注册")
}
if len(serviceInstances) == 0 {
return fmt.Errorf("service 实例列表不能为空")
}
totalRegistered := 0
// 遍历每个 service 实例
for pkgName, serviceInstance := range serviceInstances {
// 注册服务
err := registerService(serviceInstance, pkgName, options...)
if err != nil {
g.Log().Errorf(ctx, "注册 %s 服务失败: %v", pkgName, err)
continue
}
totalRegistered++
g.Log().Infof(ctx, "✅ %s 服务已自动注册", pkgName)
}
if totalRegistered == 0 {
return fmt.Errorf("未能注册任何服务")
}
// 设置取消监听器(监听基于 TraceID 的取消请求)
if _, err := setupCancelListener(ctx); err != nil {
g.Log().Errorf(ctx, "设置取消监听器失败: %v", err)
} else {
g.Log().Infof(ctx, "✅ 取消监听器已自动设置")
}
g.Log().Infof(ctx, "✅ 共自动注册了 %d 个服务", totalRegistered)
return nil
}
// registerService 注册单个服务的所有公开方法(内部函数)
func registerService(service interface{}, serviceNamePrefix string, options ...registerServiceOption) (err error) {
if !natsPing(context.Background(), rpcDefaultDatasource) {
return fmt.Errorf("NATS 未连接")
}
// 应用选项
cfg := &registerServiceConfig{}
for _, opt := range options {
opt(cfg)
}
// 创建排除方法集合
excludeSet := make(map[string]struct{})
for _, method := range cfg.excludeMethods {
excludeSet[method] = struct{}{}
}
// 获取 service 的类型
serviceType := reflect.TypeOf(service)
// 遍历所有方法
registeredCount := 0
for i := 0; i < serviceType.NumMethod(); i++ {
method := serviceType.Method(i)
// 只注册导出方法(首字母大写)
if !method.IsExported() {
continue
}
// 排除指定的方法
if _, exists := excludeSet[method.Name]; exists {
continue
}
// 检查方法签名:必须是 func(ctx context.Context, request) (response, error)
// 注意method.Type.NumIn() 包含接收者,所以实际参数数量需要减去 1
// 要求:接收者 + context.Context + request总共3个参数
if method.Type.NumIn() != 3 {
g.Log().Warningf(context.Background(), "方法 %s 必须有2个参数context.Context 和请求参数),跳过注册", method.Name)
continue
}
// 第一个参数(接收者之后的第一个参数)必须是 context.Context
// method.Type.In(0) 是接收者method.Type.In(1) 才是第一个参数
if !method.Type.In(1).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) {
g.Log().Warningf(context.Background(), "方法 %s 的第一个参数必须是 context.Context跳过注册", method.Name)
continue
}
// 第二个参数必须是结构体指针或数组
reqType := method.Type.In(2)
if reqType.Kind() != reflect.Ptr && reqType.Kind() != reflect.Slice && reqType.Kind() != reflect.Array {
g.Log().Warningf(context.Background(), "方法 %s 的第二个参数必须是结构体指针或数组,跳过注册", method.Name)
continue
}
// 返回值必须是 (result, error)即2个返回值
if method.Type.NumOut() != 2 {
g.Log().Warningf(context.Background(), "方法 %s 必须有2个返回值result 和 error跳过注册", method.Name)
continue
}
// 最后一个返回值必须是 error
if !method.Type.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) {
g.Log().Warningf(context.Background(), "方法 %s 的最后一个返回值必须是 error跳过注册", method.Name)
continue
}
// 生成服务名称:前缀.方法名(保持原始方法名)
serviceName := fmt.Sprintf("%s.%s", serviceNamePrefix, method.Name)
// 创建 RPC handler
handler := func(ctx context.Context, req []byte) (any, error) {
// 准备方法调用参数
// args[0] 是接收者, args[1] 是 ctx, args[2] 是请求参数
args := make([]reflect.Value, 3)
args[0] = reflect.ValueOf(service) // 接收者
args[1] = reflect.ValueOf(ctx) // context.Context
// 解析请求参数
if len(req) > 0 {
reqValuePtr := reflect.New(reqType)
// 解析 JSON
if err := json.Unmarshal(req, reqValuePtr.Interface()); err != nil {
// 根据参数类型提供更友好的错误提示
var typeHint string
if reqType.Kind() == reflect.Ptr {
typeHint = fmt.Sprintf("(期望类型: %s", reqType.Elem().Name())
} else { // reflect.Slice 或 reflect.Array
typeHint = fmt.Sprintf("(期望类型: %s请确保客户端传递的是JSON数组格式", reqType.String())
}
return nil, fmt.Errorf("解析请求参数失败%s: %w", typeHint, err)
}
args[2] = reqValuePtr.Elem()
} else {
// 请求为空,创建零值
args[2] = reflect.Zero(method.Type.In(2))
}
// 调用方法
results := method.Func.Call(args)
// 处理返回值
var result any
if len(results) == 1 {
// 只有 error
if !results[0].IsNil() {
err = results[0].Interface().(error)
}
} else if len(results) == 2 {
// (result, error)
result = results[0].Interface()
if !results[1].IsNil() {
err = results[1].Interface().(error)
}
}
if err != nil {
return nil, err
}
return result, nil
}
// 注册 RPC 服务
var err error
if cfg.queueName != "" {
err = registerQueueRPCService(serviceName, cfg.queueName, handler)
} else {
err = registerRPCService(serviceName, handler)
}
if err != nil {
g.Log().Errorf(context.Background(), "注册服务 %s 失败: %v", serviceName, err)
continue
}
registeredCount++
g.Log().Infof(context.Background(), "✅ 已自动注册 RPC 服务: %s -> %s", serviceName, method.Name)
}
if registeredCount == 0 {
g.Log().Warningf(context.Background(), "未注册任何方法,请检查 %v 的方法签名", serviceNamePrefix)
return fmt.Errorf("未找到可注册的方法")
}
g.Log().Infof(context.Background(), "✅ Service %v 共注册了 %d 个 RPC 方法", serviceNamePrefix, registeredCount)
return nil
}
// ============ 上下文元数据工具函数 ============
// 以下函数用于在 context 和 NATS 消息头之间互转元数据
// 定义常见的上下文元数据 key私有
const (
traceIDKey = "trace_id"
tokenKey = "token"
)
func getTraceID(ctx context.Context) (traceID string, err error) {
// 提取 traceId首先尝试从 OpenTelemetry Span 中提取,从 context 中提取 TraceID
span := trace.SpanFromContext(ctx)
if span != nil && span.SpanContext().HasTraceID() {
traceID = span.SpanContext().TraceID().String()
} else if tid := ctx.Value(traceIDKey); tid != nil {
traceID = fmt.Sprintf("%v", tid)
}
if traceID == "" {
return traceID, fmt.Errorf("context 中没有 TraceID")
}
return
}
// contextToHeaders 将 context 中的元数据转换为 NATS 消息头
// 支持提取 user_id、tenant_id、trace_id、token 等常见字段
func contextToHeaders(ctx context.Context) (nats.Header, error) {
headers := make(nats.Header)
// 提取 traceId首先尝试从 OpenTelemetry Span 中提取
if traceID, err := getTraceID(ctx); err != nil {
return headers, err
} else {
headers.Set(traceIDKey, traceID)
}
// 提取 token优先级context value > HTTP Authorization header
token := ""
if t := ctx.Value(tokenKey); t != nil {
token = fmt.Sprintf("%v", t)
} else if r := g.RequestFromCtx(ctx); r != nil {
// 从 HTTP 请求的 Authorization header 中提取 token
auth := r.GetHeader("Authorization")
if auth != "" {
// 移除 "Bearer " 前缀
if len(auth) > 7 && auth[:7] == "Bearer " {
token = auth[7:]
} else {
token = auth
}
}
}
if token != "" {
headers.Set(tokenKey, token)
}
return headers, nil
}
// headersToContext 从 NATS 消息头重建 context
// 支持还原 user_id、tenant_id、trace_id、token 等字段
func headersToContext(ctx context.Context, headers nats.Header) context.Context {
if headers == nil {
return ctx
}
// 恢复 trace_id
if traceID := headers.Get(traceIDKey); traceID != "" {
ctx = context.WithValue(ctx, traceIDKey, traceID)
}
// 恢复 token
if token := headers.Get(tokenKey); token != "" {
ctx = context.WithValue(ctx, tokenKey, token)
}
return ctx
}