import datetime
import functools
import logging
import os
import traceback
from enum import Enum
from typing import Optional, Tuple

import jwt
from firebase_admin import auth
from grpc import StatusCode

from . import firebase as fire
from .token_validation import validation
# Module logger; its own level is set to DEBUG here.
log = logging.getLogger('numerous_api.tokens')
log.setLevel(logging.DEBUG)


# No expiry is configured for refresh tokens here (a 24 hour value is left commented out).
# NOTE(review): this constant is not referenced in the visible part of the module.
REFRESH_TOKEN_EXP = None  # datetime.timedelta(hours=24)
# Added to the current UTC time to form the "exp" claim of access tokens.
ACCESS_TOKEN_EXP = datetime.timedelta(minutes=10)
# Algorithm name passed to jwt.encode / jwt.decode.
TOKEN_ALGORITHM = "HS256"


class TokenRequestException(Exception):
    """Error type for token requests.

    NOTE(review): nothing in the visible part of this module raises it.
    """


class ValidationException(Exception):
    """Raised when a token cannot be validated.

    The class itself is also handed to ``ref.set_error`` by
    ``_set_validation_error``.
    """


class TokenType(Enum):
    """Kinds of tokens issued by this module; the member *name* is stored in the "type" claim."""

    # Token that is exchanged for access tokens.
    REFRESH = 0
    # Token issued with an "exp" claim by generate_access_token.
    ACCESS = 1

def _get_obj_with_id(request, ids: list) -> Optional[str]:
    """Return the value of the first attribute in ``ids`` that exists on ``request``.

    Args:
        request: Object whose attributes are inspected.
        ids: Attribute names, tried in order.

    Returns:
        The value of the first attribute that ``request`` has (even if that value
        is ``None``), or ``None`` when ``request`` has none of them.
    """
    # The return annotation used to be ``str or None``, which evaluates to plain ``str``.
    return next(
        (getattr(request, name) for name in ids if hasattr(request, name)),
        None,
    )


def _get_project_id_from_request(request, custom_project_id_path: Optional[str]) -> Optional[str]:
    """Get the project ID from ``request``.

    Args:
        request: The incoming request object.
        custom_project_id_path: Dotted attribute path to the ID, or ``None`` to
            look at the standard keys ``project_id`` and ``project`` (in that order).

    Returns:
        The value found, or ``None`` if nothing was found.
    """
    if custom_project_id_path is not None:
        return _get_object_id_from_request(request, custom_project_id_path)

    return _get_obj_with_id(request, ['project_id', 'project'])


def _get_scenario_id_from_request(request, custom_scenario_id_path: Optional[str]) -> Optional[str]:
    """Get the scenario ID from ``request``.

    Args:
        request: The incoming request object.
        custom_scenario_id_path: Dotted attribute path to the ID, or ``None`` to
            look at the standard keys ``scenario_id`` and ``scenario`` (in that order).

    Returns:
        The value found, or ``None`` if nothing was found.
    """
    if custom_scenario_id_path is not None:
        return _get_object_id_from_request(request, custom_scenario_id_path)

    return _get_obj_with_id(request, ['scenario_id', 'scenario'])


def _get_object_id_from_request(request, custom_path: str) -> Optional[str]:
    """Resolve a dotted attribute path (e.g. ``"a.b.c"``) on ``request``, one level at a time.

    Args:
        request: Object the path is resolved against.
        custom_path: Attribute names separated by ``.``.

    Returns:
        The value at the end of the path. Missing attributes resolve to ``None``.
        If an intermediate level is ``None`` while more levels remain, the name of
        that level is looked up directly on ``request`` and that value (or
        ``None``) is returned instead.
    """
    path = custom_path.split('.')
    current = getattr(request, path[0], None)

    # ``depth`` is the index in ``path`` of the level that produced ``current``.
    for depth, attr_name in enumerate(path[1:]):
        if current is None:
            # NOTE(review): falls back to the root request rather than returning None;
            # kept as-is because callers may rely on it - confirm this is intended.
            return getattr(request, path[depth], None)
        current = getattr(current, attr_name, None)

    return current


def _set_validation_error(ref, msg: str):
    """Report a failed validation through ``ref.set_error``.

    ``ref.set_error`` receives the ``ValidationException`` class, the gRPC
    ``UNAUTHENTICATED`` status code and ``msg``.
    """
    error_type = ValidationException
    status = StatusCode.UNAUTHENTICATED
    ref.set_error(error_type, status, msg)


def validated_request(
        access_level: validation.AccessLevel,
        validation_object_type: validation.ValidationObjectType = validation.ValidationObjectType.SCENARIO,
        use_user_token: bool = False,
        custom_project_id_path: Optional[str] = None,
        custom_scenario_id_path: Optional[str] = None
):
    """Decorator that can be added to an endpoint to enable customizable validation.

    The decorated endpoint is called as ``method(ref, request, context, ...)``. The
    ``token`` and ``authorization`` entries are read from
    ``context.invocation_metadata()``; the claims returned by the validator are
    stored on ``context.claims`` before the endpoint runs.

    Args:
        access_level: Access level the endpoint requires.
        validation_object_type: Passed through to the token validators.
        use_user_token: When true and an ``authorization`` entry is present, the
            user token is validated instead of the ``token`` entry.
        custom_project_id_path: Dotted path to the project ID in the request, or
            ``None`` to use the standard keys.
        custom_scenario_id_path: Dotted path to the scenario ID in the request, or
            ``None`` to use the standard keys.

    Raises:
        ValidationException: If no usable token is present or validation fails,
            unless ``ref.set_error`` raises something else first.
    """
    def reject(ref, msg: str):
        """Report the failure and guarantee the wrapped endpoint is never reached."""
        _set_validation_error(ref, msg)
        # _set_validation_error is documented as raising. Should ref.set_error ever
        # return instead, fail closed rather than run the endpoint unauthenticated.
        raise ValidationException(msg)

    def wrapper(method):
        @functools.wraps(method)  # keep the endpoint's name/docstring for logs and introspection
        def inner(ref, *args, **kwargs):
            request = args[0]
            context = args[1]
            meta_data = context.invocation_metadata()

            # Find request sent by user
            project_id = _get_project_id_from_request(request, custom_project_id_path)
            scenario_id = _get_scenario_id_from_request(request, custom_scenario_id_path)

            # Get token values from context
            token = None
            user_token = None
            for meta in meta_data:
                if meta.key == 'token':
                    token = meta.value
                elif meta.key == 'authorization':
                    user_token = meta.value

            # Handle validation of user_token
            if use_user_token and user_token is not None:
                validated, claims = _validate_user_token(
                    user_token, project_id, scenario_id, access_level, validation_object_type
                )
                context.claims = claims
                if not validated:
                    reject(ref, "Token validation failed for user token!")

            # If we are not validating on a user token, just do standard validation
            elif token is not None:
                try:
                    validated, claims = _validate_access_token(
                        token, project_id, scenario_id, access_level, validation_object_type
                    )
                except jwt.exceptions.ExpiredSignatureError:
                    reject(ref, "Token expired!")
                except jwt.exceptions.InvalidTokenError:
                    # Bad signature, malformed token, ...: report it as an
                    # authentication failure instead of an unhandled server error.
                    reject(ref, "Token validation failed!")

                context.claims = claims
                if not validated:
                    # The token itself is deliberately not logged: it is a credential.
                    log.warning(
                        f"Not validated! Info:\nAccess Level: {access_level.name}\n"
                        f"Validation Object Type: {validation_object_type}\n"
                        f"Endpoint: {method.__qualname__}\nRequested Project: {project_id}\n"
                        f"Requested Scenario: {scenario_id}"
                    )
                    reject(ref, "Token validation failed!")

            # Raise an error if no token was found
            else:
                # Only report whether tokens were present, never their values.
                reject(
                    ref,
                    f"Token validation failed! No token present in context. "
                    f"Access token present: {token is not None}, "
                    f"User token present: {user_token is not None}, "
                    f"Use User Token: '{use_user_token}'"
                )

            return method(ref, *args, **kwargs)
        return inner
    return wrapper


def _get_token_key() -> str:
    """Return the token signing secret from the ``NUMEROUS_TOKEN_SECRET`` environment variable.

    Raises:
        ValueError: If the variable is not set.
    """
    token_key = os.environ.get('NUMEROUS_TOKEN_SECRET')
    if token_key is not None:
        return token_key
    raise ValueError('Token secret is None!')


def validate_token_request(token_request, user_id: str) -> bool:
    """Check that a token request does not ask for more access than the user has.

    The user is loaded with ``fire.get_user``; the upper-cased ``userRole`` is looked
    up by name in ``validation.AccessLevel`` and its value is compared with
    ``token_request.access_level``.

    Returns:
        ``False`` when an access level is requested and it is greater than the
        user's level, otherwise ``True``.
    """
    user = fire.get_user(user_id)
    requested_level = token_request.access_level
    if requested_level is None:
        return True

    role_level = validation.AccessLevel[user['userRole'].upper()].value
    return not requested_level > role_level
    

def create_refresh_token(token_request, user_id: str):
    """Create a refresh token for ``user_id`` from the fields of ``token_request``.

    The "access_level" claim is the user's own level (looked up from ``userRole``
    via ``validation.AccessLevel``) unless the request asks for a level, in which
    case the lower of the two is used.

    Args:
        token_request: Object providing ``admin``, ``project_id``, ``scenario_id``,
            ``execution_id``, ``job_id``, ``organization_id`` and ``access_level``.
        user_id: ID passed to ``fire.get_user`` and stored in the token.

    Returns:
        The encoded token as returned by ``jwt.encode``.

    Raises:
        ValueError: If the signing secret is not configured (see ``_get_token_key``).
    """
    user = fire.get_user(user_id)
    user_access_level = validation.AccessLevel[user['userRole'].upper()]

    requested_level = token_request.access_level
    if requested_level is None:
        granted_level = user_access_level.value
    else:
        # Never grant more than the user's own level.
        granted_level = min(requested_level, user_access_level.value)

    claims = {
        "type": TokenType.REFRESH.name,
        "admin": token_request.admin,
        "project_id": token_request.project_id,
        "scenario_id": token_request.scenario_id,
        "execution_id": token_request.execution_id,
        "job_id": token_request.job_id,
        "user_id": user_id,
        "organization_id": token_request.organization_id,
        "prefix": None,
        "agent": None,
        "creation_utc_timestamp": datetime.datetime.now(datetime.timezone.utc).timestamp(),
        "purpose": None,
        "access_level": granted_level
    }
    # Same key helper and algorithm constant as the other encoders/decoders in this
    # module, so all tokens are guaranteed to be signed and verified the same way.
    return jwt.encode(claims, key=_get_token_key(), algorithm=TOKEN_ALGORITHM)


def generate_refresh_token(
        project_id=None, scenario_id=None, admin=False, execution_id=None, job_id=None, user_id=None,
        organization_id=None, agent=None, purpose=None, access_level=validation.AccessLevel.READ, lifetime=None
):
    """Generate a refresh token from the claims provided on function call.

    Args:
        project_id, scenario_id, execution_id, job_id, user_id, organization_id,
        agent, purpose, admin: Stored in the token unchanged.
        access_level: ``validation.AccessLevel`` member; its ``value`` is stored.
        lifetime: Optional lifetime in seconds. When given, an "exp" claim is
            added; otherwise the token carries no expiry.

    Returns:
        The encoded token as returned by ``jwt.encode``.
    """
    # Use an aware UTC datetime: ``utcnow().timestamp()`` interprets the naive value
    # as local time and is therefore off by the host's UTC offset.
    now = datetime.datetime.now(datetime.timezone.utc)

    claims = {
        "type": TokenType.REFRESH.name,
        "admin": admin,
        "project_id": project_id,
        "scenario_id": scenario_id,
        "execution_id": execution_id,
        "job_id": job_id,
        "user_id": user_id,
        "organization_id": organization_id,
        "agent": agent,
        "creation_utc_timestamp": now.timestamp(),
        "purpose": purpose,
        "access_level": access_level.value
    }

    if lifetime is not None:
        claims["exp"] = now + datetime.timedelta(seconds=lifetime)

    return jwt.encode(claims, key=_get_token_key(), algorithm=TOKEN_ALGORITHM)


def _submit_execution_or_raise(project_id, scenario_id, job_id, execution_id, instance_id) -> None:
    """Register an execution via ``fire.submit_execution``, turning ``TypeError`` into ``ValidationException``."""
    try:
        fire.submit_execution(project_id, scenario_id, job_id, execution_id, None, instance_id=instance_id)
    except TypeError as te:
        raise ValidationException(
            f"Submit execution failed {te}, {project_id}, {scenario_id}, {execution_id}"
        ) from te


def _validate_refresh_token(refresh_token_data, instance_id, project_id, scenario_id, job_id, execution_id) -> None:
    """Validate the refresh token. If validation was successful, no exception will be raised.

    Besides validating, this registers the execution / instance through ``fire``.

    Args:
        refresh_token_data: Decoded claims of the refresh token.
        instance_id: Instance requesting the access token.
        project_id, scenario_id, job_id, execution_id: IDs supplied by the caller.
            They are only used for the empty-execution check and for admin tokens;
            for other tokens the IDs stored in the token are used instead.

    Raises:
        ValidationException: Wrong token type, execution submission failed, another
            instance is already registered, or the job has a different active execution.
        KeyError: The job referenced by the token does not exist.
    """
    # Check that token type is refresh
    if refresh_token_data.get('type', None) != TokenType.REFRESH.name:
        raise ValidationException("Invalid or missing type during access token creation! Expected a refresh token.")

    # No execution requested: nothing to register or check.
    # NOTE(review): this accepts any refresh token regardless of its claims - confirm intended.
    if execution_id is None or execution_id == "":
        return

    # Admin refresh tokens register the requested execution directly, without job checks
    if refresh_token_data.get('admin', False):
        if project_id is not None and project_id != "":
            _submit_execution_or_raise(project_id, scenario_id, job_id, execution_id, instance_id)
        return

    # Get variables from token (these replace the caller-supplied values from here on)
    project_id = refresh_token_data.get('project_id', None)
    scenario_id = refresh_token_data.get('scenario_id', None)
    job_id = refresh_token_data.get('job_id', None)
    execution_id = refresh_token_data.get('execution_id', None)

    # Register execution
    job = fire.get_job(project_id, scenario_id, job_id)
    if job is None:
        # Module logger instead of the root logger (logging.warning would also
        # implicitly call basicConfig when the root logger has no handlers).
        log.warning(f"JOB: {job} from: {project_id} : {scenario_id} : {job_id}")
        raise KeyError('Job not found!')

    log.warning('Job found')
    active_exe = job['active_execution'] if 'active_execution' in job else None
    log.debug(f"active_exe: {active_exe}")

    if active_exe is None:
        log.warning('No current active executions')
        _submit_execution_or_raise(project_id, scenario_id, job_id, execution_id, instance_id)
        return

    if active_exe == execution_id:
        log.warning('Execution matching active one')
        exe = fire.get_execution(active_exe)
        registered_instance = exe['instance']

        if registered_instance is None:
            log.warning(f'setting active instance to {instance_id}')
            fire.set_execution_instance(execution_id, instance_id)
            return

        if registered_instance == instance_id:
            return

        log.warning('Instance already present!')
        log.warning(f'Active is {registered_instance}')
        log.warning(f'Tried with {instance_id}')
        raise ValidationException("Instance already present!")

    # The job is running a different execution than the one in the token
    log.warning(f'Active is {active_exe}')
    log.warning(f'Tried with {execution_id}')
    raise ValidationException(f"Job has active exe with a registered instance! {project_id}, {scenario_id}, {job_id}")


def generate_access_token(refresh_token, instance_id, project_id, scenario_id, job_id, execution_id):
    """Exchange a refresh token for a short-lived access token.

    The refresh token is decoded and validated with ``_validate_refresh_token``
    (which may register the execution/instance). The access token copies
    ``project_id``, ``scenario_id``, ``access_level``, ``user_id`` and ``admin`` from
    the refresh token and expires ``ACCESS_TOKEN_EXP`` from now.

    Args:
        refresh_token: Encoded refresh token.
        instance_id, project_id, scenario_id, job_id, execution_id: Forwarded to
            ``_validate_refresh_token``; ``project_id`` / ``scenario_id`` are also
            used as fallback when the refresh token lacks those keys.

    Returns:
        The encoded access token as returned by ``jwt.encode``.

    Raises:
        ValidationException: The refresh token signature is invalid or validation fails.
        KeyError: Propagated from ``_validate_refresh_token`` when the job is missing.
    """
    try:
        refresh_token_data = _wrapped_jwt_token_read(refresh_token)
    except jwt.exceptions.InvalidSignatureError as err:
        # Chain the cause so the underlying jwt error stays visible in tracebacks.
        # NOTE(review): other jwt errors (expired, malformed) still propagate unchanged.
        raise ValidationException("Error reading refresh token!") from err

    _validate_refresh_token(refresh_token_data, instance_id, project_id, scenario_id, job_id, execution_id)

    claims = {
        "project_id": refresh_token_data.get('project_id', project_id),
        "scenario_id": refresh_token_data.get('scenario_id', scenario_id),
        "access_level": refresh_token_data.get('access_level'),
        "user_id": refresh_token_data.get("user_id"),
        "admin": refresh_token_data.get('admin', False),
        "type": TokenType.ACCESS.name,
        # Aware UTC datetime; datetime.utcnow() is deprecated and naive.
        "exp": datetime.datetime.now(datetime.timezone.utc) + ACCESS_TOKEN_EXP
    }

    # Return new token
    return jwt.encode(claims, key=_get_token_key(), algorithm=TOKEN_ALGORITHM)


def _validate_access_token(access_token, requested_project_id, requested_scenario_id, required_access_level: validation.AccessLevel, validation_object_type) -> Tuple[bool, dict]:
    """Decode an access token and check it against the requested project/scenario.

    The decision itself is delegated to ``validation.validate_access_token``.

    Returns:
        A tuple of that function's result and a claims dict holding the token's
        ``user_id`` (empty string when the token has none).

    Raises:
        Whatever ``_wrapped_jwt_token_read`` raises for an unreadable token.
    """
    payload = _wrapped_jwt_token_read(access_token)

    # NOTE(review): the fallback for access_level is an AccessLevel member, while
    # tokens created in this module store the member's ``value`` - confirm both are accepted.
    token_fields = {
        'token_project_id': payload.get('project_id', ''),
        'token_scenario_id': payload.get('scenario_id', ''),
        'token_admin': payload.get('admin', False),
        'token_access_level': payload.get('access_level', validation.AccessLevel.ANY),
    }
    claims = {'user_id': payload.get('user_id', '')}

    is_valid = validation.validate_access_token(
        access_token=access_token,
        required_access_level=required_access_level,
        requested_project_id=requested_project_id,
        requested_scenario_id=requested_scenario_id,
        validation_object_type=validation_object_type,
        **token_fields
    )
    return is_valid, claims




def _validate_user_token(user_token, requested_project_id, requested_scenario_id, access_level, validation_object_type) -> Tuple[bool, dict]:
    """Validate a user token against the role of the user it belongs to.

    The token is decoded with ``_wrapped_user_token_read`` (which raises if it cannot
    be decoded), the user is loaded with ``fire.get_user`` and the user's
    ``userRole`` is compared with the roles allowed for ``access_level``:

    * ADMIN / OWNER: ``owner``
    * DEVELOPER: ``owner``, ``developer``
    * any other level that is ``>= READ``: ``owner``, ``developer``, ``simulator``
    * anything lower: nobody

    ``requested_project_id``, ``requested_scenario_id`` and
    ``validation_object_type`` are currently not used.

    Returns:
        ``(validated, {'user_id': uid})``.
    """
    # TODO: Maybe make more sophisticated validation?
    decoded_token = _wrapped_user_token_read(user_token)
    uid = decoded_token['uid']
    log.warning(f"User authenticated: {uid}")
    log.debug(decoded_token)
    user = fire.get_user(uid)
    claims = {'user_id': uid}

    # Pick the allowed roles from the required level FIRST. Previously a failed
    # ADMIN/OWNER or DEVELOPER role check fell through to the ``>= READ`` branch,
    # which let lower-privileged roles pass endpoints requiring a higher level.
    if access_level == validation.AccessLevel.ADMIN or access_level == validation.AccessLevel.OWNER:
        allowed_roles = ('owner',)
    elif access_level == validation.AccessLevel.DEVELOPER:
        allowed_roles = ('owner', 'developer')
    elif access_level >= validation.AccessLevel.READ:
        allowed_roles = ('owner', 'developer', 'simulator')
    else:
        allowed_roles = ()

    # Only read the role when some role could match (as before for levels below READ).
    validated = bool(allowed_roles) and user['userRole'] in allowed_roles
    return validated, claims

def _wrapped_jwt_token_read(token, key=None, algorithm=TOKEN_ALGORITHM):
    """Decode and verify ``token`` with ``jwt.decode``.

    Args:
        token: Encoded JWT.
        key: Verification key. ``None`` (the default) means the secret returned by
            ``_get_token_key()`` at call time.
        algorithm: The single algorithm accepted for the signature.

    Returns:
        The decoded claims.

    Raises:
        ValueError: ``key`` is ``None`` and the secret is not configured.
        jwt exceptions: As raised by ``jwt.decode`` for invalid or expired tokens.
    """
    # Resolve the secret lazily: a ``key=_get_token_key()`` default is evaluated once
    # at import time, which freezes the secret and makes importing this module fail
    # when the environment variable is not set yet.
    if key is None:
        key = _get_token_key()
    return jwt.decode(jwt=token, key=key, algorithms=[algorithm])


def _wrapped_user_token_read(token):
    """Verify ``token`` with ``firebase_admin.auth.verify_id_token`` and return the result."""
    decoded_token = auth.verify_id_token(token)
    return decoded_token


