// Package auth provides a Kratos server middleware that authenticates
// requests carrying an HMAC-signed JWT bearer token.
package auth

import (
	"context"
	"strings"

	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/middleware/selector"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/golang-jwt/jwt/v4"
)

// authKey is the private context key under which the token claims are stored.
type authKey struct{}

// Claims is the token payload; it currently carries only the registered JWT
// claims.
type Claims struct {
	jwt.RegisteredClaims
}

const (
	// bearerWord is the scheme keyword expected in the authorization header.
	bearerWord string = "Bearer"

	// authorizationKey is the request header that carries the JWT token.
	authorizationKey string = "Authorization"

	// reason is the error reason attached to every error of this package.
	reason string = "UNAUTHORIZED"
)

// Errors returned by the middleware built by Server.Validate.
var (
	ErrMissingJwtToken        = errors.Unauthorized(reason, "JWT token is missing")
	ErrTokenInvalid           = errors.Unauthorized(reason, "Token is invalid")
	ErrTokenExpired           = errors.Unauthorized(reason, "JWT token has expired")
	ErrTokenParseFail         = errors.Unauthorized(reason, "Fail to parse JWT token ")
	ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method")
	ErrWrongContext           = errors.Unauthorized(reason, "Wrong context for middleware")
	// ErrMissingSecret is returned when no signing secret has been configured.
	ErrMissingSecret = errors.Unauthorized(reason, "JWT secret is missing")
)

// Server holds the configuration of the JWT validation middleware.
type Server struct {
	secret     string   // HMAC key used to verify token signatures
	skipRouter []string // operations that bypass token validation
}

// New builds a Server and applies the given options in order.
func New(options ...Option) *Server {
	server := &Server{}
	for _, apply := range options {
		apply(server)
	}
	return server
}

// Validate returns a middleware that verifies the bearer token found in the
// Authorization request header and, on success, stores the token claims in
// the context handed to the next handler (see FromContext).
//
// The middleware is wrapped in a selector whose match function returns false
// for the operations configured through SkipRouter.
//
// Errors produced by the middleware:
//   - ErrWrongContext: the context carries no server transport.
//   - ErrMissingSecret: no secret was configured.
//   - ErrMissingJwtToken: the header is absent or not of the form "Bearer <token>".
//   - ErrTokenInvalid: the token is malformed or reported as not valid.
//   - ErrTokenExpired: the token is expired or not valid yet.
//   - ErrUnSupportSigningMethod: the token is not signed with an HMAC method.
//   - ErrTokenParseFail: any other validation failure.
func (s *Server) Validate() middleware.Middleware {
	// keyFunc is built once and shared by every request. The signing method is
	// checked here, before the key is handed to the parser, so that a token
	// announcing a non-HMAC algorithm is rejected with a dedicated error.
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, ErrUnSupportSigningMethod
		}
		return []byte(s.secret), nil
	}

	authenticate := func(next middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				return nil, ErrWrongContext
			}
			// Fail closed: an empty HMAC key would accept tokens that anyone
			// can forge.
			if s.secret == "" {
				return nil, ErrMissingSecret
			}

			rawToken, ok := bearerToken(tr.RequestHeader().Get(authorizationKey))
			if !ok {
				return nil, ErrMissingJwtToken
			}

			tokenInfo, err := jwt.ParseWithClaims(rawToken, &Claims{}, keyFunc)
			if err != nil {
				return nil, parseError(err)
			}
			if !tokenInfo.Valid {
				return nil, ErrTokenInvalid
			}

			return next(NewContext(ctx, tokenInfo.Claims), req)
		}
	}

	return selector.Server(authenticate).
		Match(skipRouter(s.skipRouter)).
		Build()
}

// bearerToken extracts the token from an authorization header value of the
// form "Bearer <token>". The scheme keyword is matched case-insensitively.
// It reports false when the header does not have that form.
func bearerToken(header string) (string, bool) {
	parts := strings.SplitN(header, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], bearerWord) {
		return "", false
	}
	return parts[1], true
}

// parseError maps an error returned by jwt.ParseWithClaims to one of the
// errors exposed by this package.
func parseError(err error) error {
	ve, ok := err.(*jwt.ValidationError)
	if !ok {
		return errors.Unauthorized(reason, err.Error())
	}

	switch {
	case ve.Errors&jwt.ValidationErrorMalformed != 0:
		return ErrTokenInvalid
	case ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0:
		return ErrTokenExpired
	case ve.Inner == error(ErrUnSupportSigningMethod):
		// Raised by the key function; compared by identity so that no other
		// unauthorized error is mistaken for it.
		return ErrUnSupportSigningMethod
	default:
		return ErrTokenParseFail
	}
}

// skipRouter returns a match function that reports false for every operation
// listed in except (those routes need no token validation) and true for all
// other operations.
func skipRouter(except []string) selector.MatchFunc {
	whiteList := make(map[string]struct{}, len(except))
	for _, operation := range except {
		whiteList[operation] = struct{}{}
	}

	return func(ctx context.Context, operation string) bool {
		_, skipped := whiteList[operation]
		return !skipped
	}
}

// NewContext returns a copy of ctx that carries the given auth info.
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
	return context.WithValue(ctx, authKey{}, info)
}

// FromContext extracts the auth info stored by NewContext. ok is false when
// ctx carries none.
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
	token, ok = ctx.Value(authKey{}).(jwt.Claims)
	return token, ok
}

// Option configures a Server.
type Option func(s *Server)

// Secret sets the secret used to verify token signatures.
func Secret(secret string) Option {
	return func(s *Server) {
		s.secret = secret
	}
}

// SkipRouter sets the operations that do not require token validation.
func SkipRouter(skip []string) Option {
	return func(s *Server) {
		s.skipRouter = skip
	}
}