import logging
import time
from typing import Any, Callable

import grpc
from grpc_interceptor import ServerInterceptor

from .metrics import GRPC_ERRORS_COUNTER, GRPC_REQUESTS_COUNTER, GRPC_REQUESTS_LATENCY

log = logging.getLogger(__name__)


class LoggingInterceptor(ServerInterceptor):
    """Server interceptor that logs every incoming RPC and its execution time.

    The method name is logged at INFO level. The invocation metadata and the
    handler's duration in milliseconds are logged at DEBUG level.
    """

    def intercept(self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str):
        """Log the call, invoke the next handler, then log how long it took.

        Args:
            method: The next handler (or interceptor) in the chain.
            request_or_iterator: The request message, presumably an iterator of
                messages for client-streaming RPCs.
            context: The servicer context of the current call.
            method_name: Name of the invoked RPC method.

        Returns:
            Whatever ``method`` returns. Exceptions propagate unchanged; the
            duration is logged in either case.
        """
        log.info('Request %s', method_name)
        # Building the metadata string is not free; skip it unless it will be emitted.
        if log.isEnabledFor(logging.DEBUG):
            # NOTE(review): metadata may carry credentials (e.g. an ``authorization``
            # header); consider redacting sensitive keys before logging.
            log.debug('Invocation-Metadata for %s: %s', method_name,
                      ', '.join(f'{m[0]}={m[1]}' for m in context.invocation_metadata()))
        # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
        # NOTE(review): for response-streaming RPCs ``method`` presumably returns an
        # iterator, so this measures only its creation, not the whole stream — confirm.
        start = time.perf_counter()
        try:
            return method(request_or_iterator, context)
        finally:
            log.debug('Executed %s, duration %.3fms', method_name, 1000 * (time.perf_counter() - start))


class PrometheusInterceptor(ServerInterceptor):
    """Server interceptor that records Prometheus metrics for every RPC.

    For each call it:

    * observes the handler latency in ``GRPC_REQUESTS_LATENCY`` (labelled by method);
    * increments ``GRPC_REQUESTS_COUNTER`` (labelled by method);
    * increments ``GRPC_ERRORS_COUNTER`` when the call ends with a non-OK status.
    """

    def intercept(self, method: Callable, request_or_iterator: Any, context: grpc.ServicerContext, method_name: str):
        """Invoke the next handler while timing it and counting requests/errors.

        Args:
            method: The next handler (or interceptor) in the chain.
            request_or_iterator: The request message, presumably an iterator of
                messages for client-streaming RPCs.
            context: The servicer context of the current call.
            method_name: Name of the invoked RPC method.

        Returns:
            Whatever ``method`` returns. Exceptions propagate unchanged after the
            metrics have been recorded.
        """
        with GRPC_REQUESTS_LATENCY.labels(method=method_name).time():
            raised = True
            try:
                response = method(request_or_iterator, context)
                raised = False
                return response
            finally:
                code = context.code()
                # An exception escaping the handler without an explicit status is
                # reported to the client by gRPC as UNKNOWN; count it as such
                # instead of silently treating it as a success.
                if code is None and raised:
                    code = grpc.StatusCode.UNKNOWN
                if isinstance(code, grpc.StatusCode) and code != grpc.StatusCode.OK:
                    # ``StatusCode.value`` is an (int, str) tuple; the label keeps its
                    # str() form so existing dashboards/alerts still match.
                    GRPC_ERRORS_COUNTER.labels(error=str(code.value)).inc()
                GRPC_REQUESTS_COUNTER.labels(method=method_name).inc()
