import datetime
import functools
import logging
import os
from enum import Enum

import jwt
from grpc import StatusCode

# Logger for this module.
# NOTE(review): forcing DEBUG here overrides whatever level the application
# configures for this logger; consider leaving the level to logging config.
log = logging.getLogger('numerous_bigtable.tokens')
log.setLevel(logging.DEBUG)

# Presumably intended as the default refresh-token lifetime (None = no expiry).
# Not referenced in this module's visible code: generate_refresh_token takes its
# own `lifetime` argument (in seconds) instead.
REFRESH_TOKEN_EXP = None  # datetime.timedelta(hours=24)
# Lifetime of access tokens minted by generate_access_token (written to `exp`).
ACCESS_TOKEN_EXP = datetime.timedelta(minutes=10)
# JWT algorithm used for both signing and verification (HMAC-SHA256 with the
# shared secret returned by _get_token_key).
TOKEN_ALGORITHM = "HS256"


class ValidationException(Exception):
    """Error type for token validation failures.

    Raised directly when decoded claims are not of the expected token type,
    and handed to ``ref.set_error`` when a request fails authentication.
    """


class TokenType(Enum):
    """Kind of JWT issued by this module.

    The member *name* (not the value) is what gets stored in a token's
    ``type`` claim, so renaming members would invalidate existing tokens.
    """
    REFRESH = 0  # issued by generate_refresh_token; exchanged for access tokens
    ACCESS = 1  # minted by generate_access_token, expiring after ACCESS_TOKEN_EXP


class AccessLevel(str, Enum):
    """Permission levels, ordered from least (ANY) to most (ADMIN) privileged.

    NOTE: because of the ``str`` mixin each member *is* a string ('0'..'5'),
    not an int. It is therefore serialised into JWT claims as a JSON string,
    and ``<``/``>=`` between members compare lexicographically. That ordering
    only matches the numeric one while every value is a single digit -- keep
    it that way (or compare via ``int(...)``) if more levels are added.
    """
    ANY = 0
    READ = 1
    WRITE = 2
    DEVELOPER = 3
    OWNER = 4
    ADMIN = 5  # _validate_access_token only grants this to tokens with the `admin` claim


def _set_validation_error(ref, msg: str):
    """Report an authentication failure through ``ref.set_error``.

    Passes ``ValidationException`` as the error type and
    ``grpc.StatusCode.UNAUTHENTICATED`` as the status, together with ``msg``.

    NOTE(review): this was described as "raising" a validation error, but
    whether ``ref.set_error`` actually raises is not visible in this module --
    confirm before relying on this call alone to stop request handling.
    """
    ref.set_error(ValidationException, StatusCode.UNAUTHENTICATED, msg)


def validated_request(
        access_level: AccessLevel
):
    """Decorator factory that guards an endpoint with token validation.

    The decorated method is expected to be called as
    ``method(ref, request, context, ...)``: ``request`` may carry a ``prefix``
    attribute, and ``context`` must provide ``invocation_metadata()``
    (presumably a ``grpc.ServicerContext`` -- confirm against callers).

    The token is read from the metadata entry whose key is ``'token'`` (the
    last one wins if several are present). It is then checked by
    ``_validate_access_token`` against ``access_level`` and the request
    prefix. Once the token decodes, its claims are stored on
    ``context.claims``. The wrapped method is only invoked when validation
    succeeds; every failure is reported via ``_set_validation_error``
    (UNAUTHENTICATED).

    Args:
        access_level: minimum AccessLevel the endpoint requires.

    Returns:
        A decorator to apply to the endpoint method.
    """

    def wrapper(method):
        @functools.wraps(method)
        def inner(ref, *args, **kwargs):
            request = args[0]
            context = args[1]

            prefix_from_request = getattr(request, 'prefix', None)

            # Last 'token' metadata entry wins.
            token = None
            for meta in context.invocation_metadata():
                if meta.key == 'token':
                    token = meta.value

            if token is None:
                _set_validation_error(ref, f"Token validation failed! No token present in context. Normal: '{token}'")
                # Defensive: never fall through to the protected endpoint,
                # even if ref.set_error does not raise.
                return None

            try:
                validated, claims = _validate_access_token(token, prefix_from_request, access_level)
            except jwt.exceptions.ExpiredSignatureError:
                _set_validation_error(ref, "Token expired!")
                return None
            except jwt.exceptions.InvalidTokenError:
                # Malformed or tampered tokens (bad signature, undecodable, ...)
                # are authentication failures, not internal server errors.
                _set_validation_error(ref, "Token validation failed!")
                return None

            context.claims = claims
            if not validated:
                # Never log the token itself: it is a bearer credential.
                log.warning(
                    "Not validated! Info:\nAccess Level: %s\nEndpoint: %s\nRequested Prefix: %s\n",
                    access_level.name, method.__qualname__, prefix_from_request,
                )
                _set_validation_error(ref, "Token validation failed!")
                return None

            return method(ref, *args, **kwargs)

        return inner

    return wrapper


def _get_token_key():
    """Return the shared secret used to sign and verify tokens.

    The secret is read from the ``NUMEROUS_TOKEN_SECRET`` environment
    variable on every call.

    Raises:
        ValueError: if the variable is unset or set to an empty string.
    """
    # TODO: Find a better way to manage keys (e.g. a dedicated secret store).
    secret = os.getenv('NUMEROUS_TOKEN_SECRET')

    if secret is None:
        raise ValueError('Token secret is None!')
    # An empty HMAC key would make every token trivially forgeable.
    if not secret:
        raise ValueError('Token secret is empty!')

    return secret


def generate_refresh_token(prefix, admin=False, user_id=None, organization_id=None, agent=None, purpose=None,
                           access_level=AccessLevel.READ, lifetime=None):
    """Generate a signed refresh token carrying the given claims.

    Args:
        prefix: resource prefix the token is scoped to; access tokens derived
            from it are only accepted for requests with the same prefix.
        admin: if truthy, derived access tokens pass every access check.
        user_id, organization_id, agent, purpose: copied verbatim into the
            payload.
        access_level: AccessLevel granted by the token.
        lifetime: lifetime in seconds. None (the default) issues a token
            without an ``exp`` claim, i.e. one that never expires.

    Returns:
        The encoded JWT, as returned by ``jwt.encode``.

    Raises:
        ValueError: from ``_get_token_key`` if no signing secret is configured.
    """
    # A single timezone-aware UTC "now" for every time-based claim. Calling
    # `.timestamp()` on a naive `datetime.utcnow()` would interpret it as
    # *local* time and skew `creation_utc_timestamp` by the host's UTC offset;
    # `utcnow()` is also deprecated since Python 3.12.
    now = datetime.datetime.now(datetime.timezone.utc)

    claims = {
        "type": TokenType.REFRESH.name,
        "admin": admin,
        "prefix": prefix,
        "user_id": user_id,
        "organization_id": organization_id,
        "agent": agent,
        "creation_utc_timestamp": now.timestamp(),
        "purpose": purpose,
        "access_level": access_level,
    }
    if lifetime is not None:
        claims["exp"] = now + datetime.timedelta(seconds=lifetime)

    return jwt.encode(
        claims,
        key=_get_token_key(), algorithm=TOKEN_ALGORITHM
    )


def _validate_refresh_token(refresh_token_data):
    """Ensure decoded claims describe a refresh token.

    Args:
        refresh_token_data: decoded JWT payload (a dict).

    Raises:
        ValidationException: if the ``type`` claim is missing or is not
            ``TokenType.REFRESH.name``.
    """
    token_type = refresh_token_data.get('type')
    if token_type == TokenType.REFRESH.name:
        return

    raise ValidationException("Invalid or missing type during access token creation! Expected a refresh token.")


def generate_access_token(refresh_token):
    """Mint a short-lived access token from a refresh token.

    The refresh token is decoded and verified, then checked to be of type
    REFRESH. All of its claims are copied into the new token, with ``type``
    set to ACCESS and ``exp`` set to now + ``ACCESS_TOKEN_EXP``.

    Args:
        refresh_token: an encoded refresh JWT.

    Returns:
        The encoded access JWT, as returned by ``jwt.encode``.

    Raises:
        ValidationException: if ``refresh_token`` is not a refresh token.
        Whatever ``_wrapped_jwt_token_read`` raises for invalid or expired
        tokens (e.g. ``jwt.exceptions.ExpiredSignatureError``).
    """
    claims = _wrapped_jwt_token_read(refresh_token)
    _validate_refresh_token(claims)

    # Timezone-aware UTC instead of the deprecated naive `datetime.utcnow()`;
    # PyJWT converts a datetime `exp` to a UTC Unix timestamp either way.
    # NOTE(review): this `exp` overrides any refresh-token `exp`, so an access
    # token can outlive its refresh token by up to ACCESS_TOKEN_EXP.
    claims.update({
        "type": TokenType.ACCESS.name,
        "exp": datetime.datetime.now(datetime.timezone.utc) + ACCESS_TOKEN_EXP,
    })

    return jwt.encode(
        claims,
        key=_get_token_key(), algorithm=TOKEN_ALGORITHM
    )


def _validate_access_token(access_token, prefix_from_request, access_level) -> "tuple[bool, dict]":
    """Decode ``access_token`` and decide whether it grants ``access_level``.

    Args:
        access_token: the encoded JWT taken from the request metadata.
        prefix_from_request: the request's ``prefix`` field, or None.
        access_level: minimum AccessLevel the endpoint requires.

    Returns:
        ``(validated, claims)``, where ``claims`` is the decoded payload.
        Tokens with a truthy ``admin`` claim are always accepted. Otherwise:
        ADMIN endpoints are refused, the token's ``prefix`` must equal
        ``prefix_from_request``, and the token's ``access_level`` claim must
        be numerically at least ``access_level``.

    Raises:
        Whatever ``_wrapped_jwt_token_read`` raises for undecodable, tampered
        or expired tokens.

    NOTE(review): the ``type`` claim is not checked, so a refresh token
    (possibly without ``exp``) is accepted here as well. Confirm whether any
    endpoint relies on that before enforcing ``TokenType.ACCESS``.
    """
    claims = _wrapped_jwt_token_read(access_token)
    token_prefix = claims.get('prefix', None)

    # Admin rights - allow any action. A missing claim means "not admin".
    if claims.get('admin', False):
        return True, claims

    # If not admin, and admin is required, disallow.
    if access_level == AccessLevel.ADMIN:
        return False, claims

    if token_prefix != prefix_from_request:
        log.debug('Token prefix not matching request prefix: %s != %s', token_prefix, prefix_from_request)
        return False, claims

    # AccessLevel mixes in `str`, so levels travel as '0'..'5'. Compare them
    # as integers so the ordering is numeric. A missing or malformed claim is
    # denied rather than ordered lexicographically (or raising TypeError).
    try:
        granted = int(claims['access_level'])
        required = int(access_level)
    except (KeyError, TypeError, ValueError):
        return False, claims

    return granted >= required, claims


def _wrapped_jwt_token_read(token, key=None, algorithm=TOKEN_ALGORITHM):
    """Decode and verify a JWT, returning its payload.

    Args:
        token: the encoded JWT.
        key: verification key. When omitted (None), the current value of
            ``_get_token_key()`` is looked up on each call.
        algorithm: the only signing algorithm accepted. Pinning it prevents
            a token from choosing its own algorithm.

    Returns:
        The decoded claims dict.

    Raises:
        Whatever ``jwt.decode`` raises for invalid tokens (e.g.
        ``jwt.exceptions.ExpiredSignatureError``), and ``ValueError`` from
        ``_get_token_key`` when no secret is configured.
    """
    # Resolve the key lazily. A default of `_get_token_key()` would be
    # evaluated once at import time: importing the module would fail without
    # the secret, and verification would keep using a stale key while signing
    # re-reads the environment on every call.
    if key is None:
        key = _get_token_key()
    return jwt.decode(jwt=token, key=key, algorithms=[algorithm])


