import datetime
import logging
from functools import wraps
from pprint import PrettyPrinter
from typing import Any, Dict, Iterable, Optional, Type

import grpc
from google.protobuf.json_format import MessageToDict

from numerous.server_common.exceptions import AlreadyExistsError, InvalidArgumentError, NotFoundError, \
    NumerousBaseError, UnknownError
from numerous.tokens import BaseValidationError

# Module-level logger; handle_errors uses it to record aborted requests and
# unexpected exceptions raised by endpoints.
log = logging.getLogger(__name__)


def json_serial(obj):
    """Convert date-like objects into POSIX timestamps for JSON encoding.

    Intended as a ``default=`` style hook: it is given an object and either
    returns a serializable replacement or raises.

    Args:
        obj: The object to convert.

    Returns:
        ``obj.timestamp()`` for a ``datetime.datetime``; for a plain
        ``datetime.date``, the timestamp of a naive datetime at midnight
        of that day.

    Raises:
        TypeError: If ``obj`` is neither a ``datetime.datetime`` nor a
            ``datetime.date``.
    """
    # datetime.datetime subclasses datetime.date, so this single check rejects
    # everything that is not date-like.
    if not isinstance(obj, datetime.date):
        raise TypeError(f'Type {type(obj)} not serializable')

    if isinstance(obj, datetime.datetime):
        return obj.timestamp()

    # Plain date: promote to a naive datetime at 00:00 of the same day.
    midnight = datetime.datetime(year=obj.year, month=obj.month, day=obj.day)
    return midnight.timestamp()


# Maps exception types to the gRPC status code reported to the client.
# _error_to_grpc_status_code walks this mapping in insertion order and returns
# the code of the first entry the raised exception is an instance of, so when
# entries are related by inheritance the more specific type must come first.
BASE_ERROR_TO_GRPC_STATUS_CODE: dict[Type[Exception], grpc.StatusCode] = {
    AlreadyExistsError: grpc.StatusCode.ALREADY_EXISTS,
    InvalidArgumentError: grpc.StatusCode.INVALID_ARGUMENT,
    NotFoundError: grpc.StatusCode.NOT_FOUND,
    UnknownError: grpc.StatusCode.UNKNOWN,
    BaseValidationError: grpc.StatusCode.UNAUTHENTICATED
}


def getattr_short_circuit(o: Any, attrs: Iterable[str]) -> Any:
    """Return the value of the first attribute in ``attrs`` that ``o`` has.

    The names are tried in iteration order and the lookup stops at the first
    one that resolves.

    Args:
        o: Object to read the attribute from.
        attrs: Candidate attribute names, in order of preference.

    Returns:
        The value of the first attribute found, or ``None`` when none of the
        names resolve (including when ``attrs`` is empty).
    """
    # A unique sentinel distinguishes "attribute missing" from an attribute
    # whose value is legitimately None.
    missing = object()
    for name in attrs:
        value = getattr(o, name, missing)
        if value is not missing:
            return value
    return None


def handle_errors(func):
    """Decorate an endpoint method so that exceptions abort the gRPC context.

    The wrapped callable is invoked as ``func(self, request, context)``.
    Its return value is passed through unchanged when it succeeds.

    * ``NumerousBaseError`` / ``BaseValidationError``: the error's ``msg``
      becomes the abort details and the status code is looked up with
      ``_error_to_grpc_status_code``.
    * Any other ``Exception``: it is logged with its traceback together with
      a dump of the request, and the context is aborted with
      ``grpc.StatusCode.INTERNAL`` and a generic message.

    Args:
        func: The endpoint method to wrap.

    Returns:
        The wrapping function, carrying ``func``'s metadata via ``wraps``.
    """
    @wraps(func)
    def wrapper(self, request: Any, context: grpc.ServicerContext):
        try:
            result = func(self, request, context)
        except (NumerousBaseError, BaseValidationError) as known_error:
            # Expected, domain-level failure: surface its own message.
            message, status = known_error.msg, _error_to_grpc_status_code(known_error)
        except Exception as unexpected:
            # Unexpected failure: keep details in the server log only.
            log.exception(
                "Unhandled %s occurred handling request: %s",
                type(unexpected).__name__,
                _get_request_dump(request),
            )
            message = 'Internal error occurred in endpoint. Please check server logs.'
            status = grpc.StatusCode.INTERNAL
        else:
            return result

        log.info("Aborting context. Code: %s, details: %s", status, message)
        context.abort(status, message)

    return wrapper


def _error_to_grpc_status_code(e: Exception) -> grpc.StatusCode:
    """Translate an exception into the gRPC status code to report.

    Args:
        e: The exception to classify.

    Returns:
        The code of the first entry in ``BASE_ERROR_TO_GRPC_STATUS_CODE``
        (in insertion order) whose type ``e`` is an instance of, or
        ``grpc.StatusCode.UNKNOWN`` when no entry matches.
    """
    matching_codes = (
        status
        for exc_type, status in BASE_ERROR_TO_GRPC_STATUS_CODE.items()
        if isinstance(e, exc_type)
    )
    return next(matching_codes, grpc.StatusCode.UNKNOWN)


def _get_request_dump(request):
    """Render a request as a pretty-printed string for logging.

    The request is converted with ``MessageToDict``. If that raises
    ``AttributeError`` (presumably because ``request`` is a sequence of
    messages rather than a single message -- TODO confirm), the first element
    is converted instead, or ``None`` is used when ``request`` is falsy.

    Args:
        request: The incoming request object.

    Returns:
        The ``PrettyPrinter`` (indent 2, width 120, depth 6) formatting of
        the converted request.
    """
    try:
        payload: Optional[Dict[str, Any]] = MessageToDict(request)
    except AttributeError:
        if request:
            # Only the first element is dumped.
            payload = MessageToDict(request[0])
        else:
            payload = None

    printer = PrettyPrinter(indent=2, width=120, depth=6)
    return printer.pformat(payload)
