diff --git a/http/http.go b/http/http.go index 8d84822..c675ec4 100644 --- a/http/http.go +++ b/http/http.go @@ -6,6 +6,7 @@ import ( "fmt" _ "gitee.com/red-future---jilin-g/common/consul" "gitee.com/red-future---jilin-g/common/jaeger" + "gitee.com/red-future---jilin-g/common/middleware" "gitee.com/red-future---jilin-g/common/utils" "github.com/gogf/gf/contrib/registry/consul/v2" "github.com/gogf/gf/v2/frame/g" @@ -38,9 +39,33 @@ func init() { //s.Use(common.Cors) //中间件验证 //s.EnablePProf() //启用性能分析 Httpserver.SetOpenApiPath("/api.json") - Httpserver.SetSwaggerPath("/swagger") //api文档访问路径 - Httpserver.SetDumpRouterMap(true) //关闭打印路由注册信息 - Httpserver.BindMiddlewareDefault(ghttp.MiddlewareHandlerResponse, jaeger.NewTracer) //使用默认http返回结构 + Httpserver.SetSwaggerPath("/docs") //api文档访问路径 + Httpserver.SetDumpRouterMap(true) //关闭打印路由注册信息 + Httpserver.SetSwaggerUITemplate(` + + + + + + + SwaggerUI + + + +
+ + + + + `) + Httpserver.BindMiddlewareDefault(ghttp.MiddlewareCORS, ghttp.MiddlewareHandlerResponse, middleware.Limiter, jaeger.NewTracer) //使用默认http返回结构 go Httpserver.Run() consulCfg, _ := g.Cfg().Get(context.Background(), "consul.address") @@ -51,7 +76,6 @@ func init() { } gsvc.SetRegistry(registry) gsel.SetBuilder(gsel.NewBuilderRoundRobin()) - Httpclient.SetHeader("Authorization", g.RequestFromCtx(context.TODO()).GetHeader("Authorization")) Httpclient.SetDiscovery(gsvc.GetRegistry()) } func RouteRegister(controllers []interface{}) { @@ -71,6 +95,7 @@ func doRequest(ctx context.Context, method string, url string, target any, data if err != nil { return } + Httpclient.SetHeader("Authorization", g.RequestFromCtx(ctx).GetHeader("Authorization")) response, err := Httpclient.DoRequest(ctx, method, url, data) if err != nil { return diff --git a/middleware/middleware.go b/middleware/middleware.go new file mode 100644 index 0000000..593c25c --- /dev/null +++ b/middleware/middleware.go @@ -0,0 +1,55 @@ +package middleware + +import ( + "context" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/ghttp" + "github.com/gogf/gf/v2/os/gtime" + "github.com/gogf/gf/v2/text/gstr" + "golang.org/x/time/rate" +) + +// Logger 中间件 +func Logger(r *ghttp.Request) { + startTime := gtime.TimestampMilli() + r.Middleware.Next() + endTime := gtime.TimestampMilli() + g.Log().Infof(r.GetCtx(), + "request: %s %s | status: %d | time: %dms", + r.Method, + r.URL.Path, + r.Response.Status, + endTime-startTime, + ) +} + +var rateLimit, _ = g.Cfg().Get(context.TODO(), "rate.limit") +var rateBurst, _ = g.Cfg().Get(context.TODO(), "rate.burst") +var limiter = rate.NewLimiter(rate.Limit(rateLimit.Int()), rateBurst.Int()) + +func Limiter(r *ghttp.Request) { + if !limiter.Allow() { + r.Response.WriteStatusExit(429) // Return 429 Too Many Requests + r.ExitAll() + } + r.Middleware.Next() +} +func Auth(r *ghttp.Request) { + token := r.Header.Get("Authorization") + if token == "" || !gstr.HasPrefix(token, "Bearer ") { + r.Response.WriteStatusExit(401, "Unauthorized") + return + } + + // 验证 token + if !validateToken(gstr.SubStrFrom(token, "7")) { + r.Response.WriteStatusExit(401, "Unauthorized") + return + } + + r.Middleware.Next() +} +func validateToken(token string) bool { + // 实现 token 验证逻辑 + return token == "valid-token" +}