From a3ad38e8f655b42745afab54ab46c1bab857e524 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E6=96=8C?= <259278618@qq.com> Date: Fri, 9 Jan 2026 16:46:35 +0800 Subject: [PATCH] =?UTF-8?q?common=E5=A2=9E=E5=8A=A0nats=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E9=98=9F=E5=88=97=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nats/nats.go | 401 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 362 insertions(+), 39 deletions(-) diff --git a/nats/nats.go b/nats/nats.go index d1ec88e..f04bab1 100644 --- a/nats/nats.go +++ b/nats/nats.go @@ -2,7 +2,9 @@ package nats import ( "context" + "encoding/json" "fmt" + "reflect" "sync" "sync/atomic" "time" @@ -515,102 +517,424 @@ func CreateConsumer(ctx context.Context, streamName, consumerName string, config return consumer, nil } -// SubscribeRequest 订阅 RPC 请求 -// B服务作为服务提供者,订阅主题并响应请求时使用此方法 -// subject: 订阅的主题名,与 Request 调用时使用相同的 subject -func SubscribeRequest(subject string, handler func(subject string, data []byte) ([]byte, error)) (*nats.Subscription, error) { +// ============ RPC 服务封装 ============ +// 以下方法提供了完全抽象的 RPC 调用接口 +// 调用方和响应方完全不需要知道底层使用的是 NATS 的发布订阅模式 + +// RPC 服务注册表 +var ( + rpcServices map[string]RPCHandler + rpcSubs map[string]*nats.Subscription // 服务名 -> 订阅 + rpcServicesMu sync.RWMutex + queueRPCServices map[string]map[string]RPCHandler // queueName -> subject -> handler + queueRPCSubs map[string]map[string]*nats.Subscription // queueName -> serviceName -> 订阅 + queueRPCMu sync.RWMutex +) + +// RPCHandler RPC 处理函数类型 +// 实现方只需要关注请求参数和返回值,无需了解底层 NATS 实现 +type RPCHandler func(ctx context.Context, req []byte) ([]byte, error) + +// RegisterRPCService 注册 RPC 服务(单实例) +// serviceName: 服务名称,调用方通过此名称调用服务 +// handler: 服务处理函数,接收请求并返回响应 +func RegisterRPCService(serviceName string, handler RPCHandler) error { if !checkConnected() { - return nil, fmt.Errorf("NATS 未连接") + return fmt.Errorf("NATS 未连接") } + rpcServicesMu.Lock() + if rpcServices == nil { + rpcServices = make(map[string]RPCHandler) + } + if rpcSubs == nil { + rpcSubs = make(map[string]*nats.Subscription) + } + + // 如果已存在该服务,先取消之前的订阅 + if oldSub, exists := rpcSubs[serviceName]; exists { + oldSub.Unsubscribe() + } + + rpcServices[serviceName] = handler + rpcServicesMu.Unlock() + + // 订阅服务主题 + subject := fmt.Sprintf("rpc.%s", serviceName) sub, err := nc.Subscribe(subject, func(msg *nats.Msg) { - // 处理请求 - response, err := handler(msg.Subject, msg.Data) + ctx := context.Background() + response, err := handler(ctx, msg.Data) if err != nil { - // 处理错误,发送错误响应 errMsg := fmt.Sprintf("处理失败: %v", err) if err = msg.Respond([]byte(errMsg)); err != nil { - g.Log().Errorf(context.Background(), "RPC 错误响应失败: %v", err) + g.Log().Errorf(ctx, "RPC 错误响应失败: %v", err) } return } - // 发送成功响应 if err = msg.Respond(response); err != nil { - g.Log().Errorf(context.Background(), "RPC 响应失败: %v", err) + g.Log().Errorf(ctx, "RPC 响应失败: %v", err) } }) if err != nil { - return nil, fmt.Errorf("订阅 RPC 请求失败: %w", err) + return fmt.Errorf("注册 RPC 服务失败: %w", err) } - return sub, nil + rpcSubs[serviceName] = sub + metrics.SubscribeCount.Add(1) + g.Log().Infof(context.Background(), "✅ RPC 服务已注册: %s", serviceName) + return nil } -// SubscribeQueueRequest 订阅队列模式的 RPC 请求(负载均衡) -// 多个服务实例订阅同一主题,实现负载均衡 -// subject: 订阅的主题名,与 Request 调用时使用相同的 subject -// queueName: 队列组名,同一队列组的实例之间实现负载均衡 -func SubscribeQueueRequest(subject, queueName string, handler func(subject string, data []byte) ([]byte, error)) (*nats.Subscription, error) { +// RegisterQueueRPCService 注册 RPC 服务(集群模式) +// 多个服务实例注册同一服务时,请求会自动负载均衡 +// serviceName: 服务名称 +// queueName: 队列组名,同一队列组的实例共享请求 +// handler: 服务处理函数 +func RegisterQueueRPCService(serviceName, queueName string, handler RPCHandler) error { if !checkConnected() { - return nil, fmt.Errorf("NATS 未连接") + return fmt.Errorf("NATS 未连接") } + queueRPCMu.Lock() + if queueRPCServices == nil { + queueRPCServices = make(map[string]map[string]RPCHandler) + } + if queueRPCSubs == nil { + queueRPCSubs = make(map[string]map[string]*nats.Subscription) + } + if queueRPCServices[queueName] == nil { + queueRPCServices[queueName] = make(map[string]RPCHandler) + } + if queueRPCSubs[queueName] == nil { + queueRPCSubs[queueName] = make(map[string]*nats.Subscription) + } + + // 如果已存在该服务,先取消之前的订阅 + if oldSub, exists := queueRPCSubs[queueName][serviceName]; exists { + oldSub.Unsubscribe() + } + + queueRPCServices[queueName][serviceName] = handler + queueRPCMu.Unlock() + + // 订阅服务主题(队列模式) + subject := fmt.Sprintf("rpc.%s", serviceName) sub, err := nc.QueueSubscribe(subject, queueName, func(msg *nats.Msg) { - // 处理请求 - response, err := handler(msg.Subject, msg.Data) + ctx := context.Background() + response, err := handler(ctx, msg.Data) if err != nil { - // 处理错误,发送错误响应 errMsg := fmt.Sprintf("处理失败: %v", err) if err = msg.Respond([]byte(errMsg)); err != nil { - g.Log().Errorf(context.Background(), "RPC 错误响应失败: %v", err) + g.Log().Errorf(ctx, "RPC 错误响应失败: %v", err) } return } - // 发送成功响应 - if err := msg.Respond(response); err != nil { - g.Log().Errorf(context.Background(), "RPC 响应失败: %v", err) + if err = msg.Respond(response); err != nil { + g.Log().Errorf(ctx, "RPC 响应失败: %v", err) } }) if err != nil { - return nil, fmt.Errorf("订阅队列 RPC 请求失败: %w", err) + return fmt.Errorf("注册队列 RPC 服务失败: %w", err) } - return sub, nil + queueRPCMu.Lock() + queueRPCSubs[queueName][serviceName] = sub + queueRPCMu.Unlock() + + metrics.SubscribeCount.Add(1) + g.Log().Infof(context.Background(), "✅ 队列 RPC 服务已注册: %s (队列组: %s)", serviceName, queueName) + return nil } -// Request RPC 请求-响应模式 -// A服务调用B服务查询接口时使用此方法 -func Request(ctx context.Context, subject string, data []byte, timeout time.Duration) ([]byte, error) { +// CallRPC 调用 RPC 服务 +// serviceName: 服务名称 +// req: 请求数据 +// timeout: 超时时间 +// 返回: 响应数据和错误 +func CallRPC(ctx context.Context, serviceName string, req []byte, timeout time.Duration) ([]byte, error) { if !checkConnected() { return nil, fmt.Errorf("NATS 未连接") } metrics.RequestCount.Add(1) - // 使用 timeout 参数创建超时上下文 + // 检查本地是否有注册的单实例服务,如果有则直接调用(优化性能) + rpcServicesMu.RLock() + if localHandler, exists := rpcServices[serviceName]; exists { + rpcServicesMu.RUnlock() + // 本地直接调用,避免网络开销 + response, err := localHandler(ctx, req) + if err != nil { + metrics.RequestError.Add(1) + return nil, fmt.Errorf("本地调用 RPC 服务失败 [%s]: %w", serviceName, err) + } + return response, nil + } + rpcServicesMu.RUnlock() + + // 通过 NATS 网络调用远程服务 timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - msg, err := nc.RequestWithContext(timeoutCtx, subject, data) + subject := fmt.Sprintf("rpc.%s", serviceName) + msg, err := nc.RequestWithContext(timeoutCtx, subject, req) if err != nil { metrics.RequestError.Add(1) - return nil, fmt.Errorf("RPC 请求失败: %w", err) + return nil, fmt.Errorf("调用 RPC 服务失败 [%s]: %w", serviceName, err) } if msg == nil { metrics.RequestError.Add(1) - return nil, fmt.Errorf("RPC 响应为空") + return nil, fmt.Errorf("RPC 响应为空 [%s]", serviceName) } return msg.Data, nil } -// Close 关闭 NATS 连接 -func Close() error { +// RegisterServiceOption 注册选项类型 +type RegisterServiceOption func(*registerServiceConfig) + +type registerServiceConfig struct { + queueName string // 队列组名(用于集群模式) + excludeMethods []string +} + +// WithQueueGroup 设置队列组名(集群模式) +func WithQueueGroup(queueName string) RegisterServiceOption { + return func(cfg *registerServiceConfig) { + cfg.queueName = queueName + } +} + +// WithExcludeMethods 排除不需要注册的方法 +func WithExcludeMethods(methods ...string) RegisterServiceOption { + return func(cfg *registerServiceConfig) { + cfg.excludeMethods = append(cfg.excludeMethods, methods...) + } +} + +// registerService 注册单个服务的所有公开方法(内部函数) +func registerService(service interface{}, serviceNamePrefix string, options ...RegisterServiceOption) error { + if !checkConnected() { + return fmt.Errorf("NATS 未连接") + } + + // 应用选项 + cfg := ®isterServiceConfig{} + for _, opt := range options { + opt(cfg) + } + + // 创建排除方法集合 + excludeSet := make(map[string]struct{}) + for _, method := range cfg.excludeMethods { + excludeSet[method] = struct{}{} + } + + // 获取 service 的类型 + serviceType := reflect.TypeOf(service) + + // 遍历所有方法 + registeredCount := 0 + for i := 0; i < serviceType.NumMethod(); i++ { + method := serviceType.Method(i) + + // 只注册导出方法(首字母大写) + if !method.IsExported() { + continue + } + + // 排除指定的方法 + if _, exists := excludeSet[method.Name]; exists { + continue + } + + // 检查方法签名:必须是 func(ctx context.Context, request) (response, error) + if method.Type.NumIn() < 2 { + g.Log().Warningf(context.Background(), "方法 %s 的参数数量不足,跳过注册", method.Name) + continue + } + + // 第一个参数必须是 context.Context + if !method.Type.In(0).Implements(reflect.TypeOf((*context.Context)(nil)).Elem()) { + g.Log().Warningf(context.Background(), "方法 %s 的第一个参数必须是 context.Context,跳过注册", method.Name) + continue + } + + // 返回值必须是 (result, error) 或 error + if method.Type.NumOut() < 1 || method.Type.NumOut() > 2 { + g.Log().Warningf(context.Background(), "方法 %s 的返回值数量不正确,跳过注册", method.Name) + continue + } + + if !method.Type.Out(method.Type.NumOut() - 1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { + g.Log().Warningf(context.Background(), "方法 %s 的最后一个返回值必须是 error,跳过注册", method.Name) + continue + } + + // 生成服务名称:前缀.方法名(保持原始方法名) + serviceName := fmt.Sprintf("%s.%s", serviceNamePrefix, method.Name) + + // 创建 RPC handler + handler := func(ctx context.Context, req []byte) ([]byte, error) { + // 准备方法调用参数 + args := make([]reflect.Value, 2) + args[0] = reflect.ValueOf(ctx) + + // 解析请求参数 + if len(req) > 0 { + // 如果方法有第二个参数,尝试解析 JSON + if method.Type.NumIn() > 1 { + reqValuePtr := reflect.New(method.Type.In(1)) + if err := json.Unmarshal(req, reqValuePtr.Interface()); err != nil { + return nil, fmt.Errorf("解析请求参数失败: %w", err) + } + args[1] = reqValuePtr.Elem() + } + } else if method.Type.NumIn() > 1 { + // 如果方法需要参数但请求为空,创建零值 + args[1] = reflect.Zero(method.Type.In(1)) + } + + // 调用方法 + results := method.Func.Call(args) + + // 处理返回值 + var err error + var result interface{} + + if len(results) == 1 { + // 只有 error + if !results[0].IsNil() { + err = results[0].Interface().(error) + } + } else if len(results) == 2 { + // (result, error) + result = results[0].Interface() + if !results[1].IsNil() { + err = results[1].Interface().(error) + } + } + + if err != nil { + return nil, err + } + + // 序列化返回值 + if result == nil || (reflect.ValueOf(result).Kind() == reflect.Ptr && reflect.ValueOf(result).IsNil()) { + return []byte("{}"), nil + } + + return json.Marshal(result) + } + + // 注册 RPC 服务 + var err error + if cfg.queueName != "" { + err = RegisterQueueRPCService(serviceName, cfg.queueName, handler) + } else { + err = RegisterRPCService(serviceName, handler) + } + + if err != nil { + g.Log().Errorf(context.Background(), "注册服务 %s 失败: %v", serviceName, err) + continue + } + + registeredCount++ + g.Log().Infof(context.Background(), "✅ 已自动注册 RPC 服务: %s -> %s", serviceName, method.Name) + } + + if registeredCount == 0 { + g.Log().Warningf(context.Background(), "未注册任何方法,请检查 %v 的方法签名", serviceNamePrefix) + return fmt.Errorf("未找到可注册的方法") + } + + g.Log().Infof(context.Background(), "✅ Service %v 共注册了 %d 个 RPC 方法", serviceNamePrefix, registeredCount) + return nil +} + +// AutoRegisterServices 自动注册多个服务的所有公开方法 +// serviceInstances: map[包名]service实例,如 map[string]interface{}{"user": userService, "order": orderService} +// options: 注册选项(可选) +// 示例: +// +// AutoRegisterServices(map[string]interface{}{ +// "user": userService, +// "order": orderService, +// }) +// 或 +// AutoRegisterServices(map[string]interface{}{ +// "order": orderService, +// }, WithQueueGroup("order-group")) +func AutoRegisterServices(serviceInstances map[string]interface{}, options ...RegisterServiceOption) error { + if len(serviceInstances) == 0 { + return fmt.Errorf("service 实例列表不能为空") + } + + totalRegistered := 0 + + // 遍历每个 service 实例 + for pkgName, serviceInstance := range serviceInstances { + // 注册服务 + err := registerService(serviceInstance, pkgName, options...) + if err != nil { + g.Log().Errorf(context.Background(), "注册 %s 服务失败: %v", pkgName, err) + continue + } + + totalRegistered++ + g.Log().Infof(context.Background(), "✅ %s 服务已自动注册", pkgName) + } + + if totalRegistered == 0 { + return fmt.Errorf("未能注册任何服务") + } + + g.Log().Infof(context.Background(), "✅ 共自动注册了 %d 个服务", totalRegistered) + return nil +} + +// Shutdown 优雅关闭:自动注销所有已注册的服务并关闭 NATS 连接 +func Shutdown() error { + ctx := context.Background() + g.Log().Info(ctx, "开始优雅关闭 NATS RPC 服务...") + + // 注销所有单实例服务 + rpcServicesMu.Lock() + singleServiceCount := len(rpcServices) + for serviceName := range rpcServices { + if sub, exists := rpcSubs[serviceName]; exists { + if err := sub.Unsubscribe(); err != nil { + g.Log().Errorf(ctx, "注销服务 %s 失败: %v", serviceName, err) + } + } + delete(rpcSubs, serviceName) + delete(rpcServices, serviceName) + } + rpcServicesMu.Unlock() + + // 注销所有队列服务 + queueRPCMu.Lock() + queueServiceCount := 0 + for queueName, servicesMap := range queueRPCServices { + queueServiceCount += len(servicesMap) + for serviceName, sub := range queueRPCSubs[queueName] { + if err := sub.Unsubscribe(); err != nil { + g.Log().Errorf(ctx, "注销队列服务 %s (队列: %s) 失败: %v", serviceName, queueName, err) + } + } + delete(queueRPCSubs, queueName) + delete(queueRPCServices, queueName) + } + queueRPCMu.Unlock() + + g.Log().Infof(ctx, "已注销 %d 个单实例服务和 %d 个队列服务", singleServiceCount, queueServiceCount) + mu.Lock() defer mu.Unlock() @@ -624,8 +948,7 @@ func Close() error { nc.Close() connected = false inited = false - g.Log().Info(context.Background(), "NATS 连接已关闭") } - + g.Log().Info(ctx, "NATS RPC 服务已优雅关闭") return nil }