package auth

import (
	"context"
	"github.com/go-kratos/kratos/v2/errors"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/middleware/selector"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/golang-jwt/jwt/v4"
	"strings"
)

type (
	// authKey is the unexported context key under which NewContext stores the
	// claims; using a private type keeps it from colliding with keys defined by
	// other packages.
	authKey struct{}

	// Claims is the JWT payload accepted by this package. It currently carries
	// only the registered (standard) claims.
	Claims struct {
		jwt.RegisteredClaims
	}
)

const (
	// bearerWord is the authentication scheme expected in the Authorization
	// header; it is compared case-insensitively.
	bearerWord = "Bearer"

	// authorizationKey is the name of the request header that carries the JWT.
	authorizationKey = "Authorization"

	// reason is the error reason attached to the errors declared in this file.
	reason = "UNAUTHORIZED"
)

// Errors returned by the Validate middleware. Each is an Unauthorized (401)
// error carrying the reason "UNAUTHORIZED"; they differ only in their message.
var (
	// ErrMissingJwtToken reports a request without a Bearer token in its
	// Authorization header.
	ErrMissingJwtToken = errors.Unauthorized(reason, "JWT token is missing")

	// ErrTokenInvalid reports a malformed or otherwise invalid token.
	ErrTokenInvalid = errors.Unauthorized(reason, "Token is invalid")

	// ErrTokenExpired reports a token rejected by its time-based claims
	// (expired, or not yet valid).
	ErrTokenExpired = errors.Unauthorized(reason, "JWT token has expired")

	// ErrTokenParseFail reports a token that could not be verified for any
	// other reason.
	ErrTokenParseFail = errors.Unauthorized(reason, "Fail to parse JWT token")

	// ErrUnSupportSigningMethod reports a token that is not signed with an
	// HMAC algorithm.
	ErrUnSupportSigningMethod = errors.Unauthorized(reason, "Wrong signing method")

	// ErrWrongContext reports that the middleware ran without a server
	// transport in the request context.
	ErrWrongContext = errors.Unauthorized(reason, "Wrong context for middleware")
)

// Server holds the configuration of the JWT authentication middleware. Create
// it with New and obtain the middleware itself from Validate.
type Server struct {
	// secret is the HMAC key used to verify token signatures.
	secret string
	// skipRouter lists the operations that are exempt from authentication.
	skipRouter []string
}

// New creates a Server and applies each option to it, in the order given.
func New(options ...Option) *Server {
	s := new(Server)
	for i := range options {
		options[i](s)
	}
	return s
}

// Validate returns a server middleware that authenticates each request with an
// HMAC-signed JWT read from the "Authorization: Bearer <token>" request header.
//
// Operations registered with SkipRouter bypass the check. When the token is
// accepted, its claims (a *Claims) are attached to the context via NewContext
// and can be read back with FromContext. Otherwise the handler is not called and
// an Unauthorized error is returned:
//   - ErrMissingJwtToken: no Bearer header, or an empty token
//   - ErrTokenInvalid: the token is malformed or otherwise not valid
//   - ErrUnSupportSigningMethod: the token is not signed with an HMAC algorithm
//   - ErrTokenExpired: the exp or nbf claim rejects the token
//   - ErrTokenParseFail: any other failure, including a bad signature or an
//     empty (unconfigured) secret
//   - ErrWrongContext: the context carries no server transport
func (s *Server) Validate() middleware.Middleware {
	// keyFunc runs before the signature is verified. The algorithm has to be
	// checked here (as the jwt docs recommend): a non-HMAC token already fails
	// verification against a []byte key, so a check after parsing never fires.
	keyFunc := func(token *jwt.Token) (interface{}, error) {
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, ErrUnSupportSigningMethod
		}
		// HMAC accepts an empty key, which would let anyone mint valid tokens
		// when no Secret option was given; refuse to verify instead.
		if s.secret == "" {
			return nil, ErrTokenParseFail
		}
		return []byte(s.secret), nil
	}

	return selector.Server(func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (interface{}, error) {
			tr, ok := transport.FromServerContext(ctx)
			if !ok {
				// NOTE(review): selector.Server may already bypass this middleware
				// when no transport is present; kept as a defensive guard.
				return nil, ErrWrongContext
			}

			auths := strings.SplitN(tr.RequestHeader().Get(authorizationKey), " ", 2)
			if len(auths) != 2 || !strings.EqualFold(auths[0], bearerWord) {
				return nil, ErrMissingJwtToken
			}
			// "Bearer " followed by nothing (or only spaces) is a missing token,
			// not a malformed one.
			jwtToken := strings.TrimSpace(auths[1])
			if jwtToken == "" {
				return nil, ErrMissingJwtToken
			}

			tokenInfo, err := jwt.ParseWithClaims(jwtToken, &Claims{}, keyFunc)
			if err != nil {
				ve, ok := err.(*jwt.ValidationError)
				if !ok {
					return nil, errors.Unauthorized(reason, err.Error())
				}
				switch {
				case ve.Errors&jwt.ValidationErrorMalformed != 0:
					return nil, ErrTokenInvalid
				case ve.Inner == ErrUnSupportSigningMethod:
					// Identity comparison on purpose: kratos errors.Is matches on
					// code and reason, which every error in this package shares.
					return nil, ErrUnSupportSigningMethod
				case ve.Errors&(jwt.ValidationErrorUnverifiable|jwt.ValidationErrorSignatureInvalid) != 0:
					// The parser checks claims and signature independently and
					// reports both; the claims of an unverified token are untrusted,
					// so such a token must not be reported as "expired".
					return nil, ErrTokenParseFail
				case ve.Errors&(jwt.ValidationErrorExpired|jwt.ValidationErrorNotValidYet) != 0:
					// A not-yet-valid token (nbf in the future) is reported as expired too.
					return nil, ErrTokenExpired
				default:
					return nil, ErrTokenParseFail
				}
			}
			if !tokenInfo.Valid {
				return nil, ErrTokenInvalid
			}

			return handler(NewContext(ctx, tokenInfo.Claims), req)
		}
	}).
		Match(skipRouter(s.skipRouter)).
		Build()
}

// skipRouter builds the selector match function for Validate. It returns false
// (the auth middleware is skipped) for every operation listed in except, and
// true (authentication required) for all other operations. Names are matched
// exactly.
func skipRouter(except []string) selector.MatchFunc {
	skipped := make(map[string]struct{}, len(except))
	for _, op := range except {
		skipped[op] = struct{}{}
	}

	return func(_ context.Context, operation string) bool {
		_, skip := skipped[operation]
		return !skip
	}
}

// NewContext returns a context derived from ctx that carries the given JWT
// claims. Use FromContext to retrieve them.
func NewContext(ctx context.Context, info jwt.Claims) context.Context {
	derived := context.WithValue(ctx, authKey{}, info)
	return derived
}

// FromContext returns the JWT claims stored in ctx by NewContext. ok is false
// when ctx carries no claims.
func FromContext(ctx context.Context) (token jwt.Claims, ok bool) {
	value := ctx.Value(authKey{})
	token, ok = value.(jwt.Claims)
	return token, ok
}

// Option configures a Server; pass options to New.
type Option func(s *Server)

// Secret sets the HMAC key that Validate uses to verify token signatures. The
// key should be non-empty.
func Secret(secret string) Option {
	return func(srv *Server) { srv.secret = secret }
}

// SkipRouter sets the operations that Validate lets through without a token.
// Each entry is matched exactly against the operation name. The slice is copied,
// so later changes to the caller's slice do not affect the Server.
func SkipRouter(skip []string) Option {
	// Copy so the Server never aliases the caller's backing array.
	routes := append([]string(nil), skip...)
	return func(s *Server) {
		s.skipRouter = routes
	}
}