import functools
import inspect
import logging
import os
import sys
import threading
import traceback

import jwt
from google.protobuf.json_format import MessageToDict

from config import BIGTABLE_MOCK
from data_drivers import BigTableDriver, HBaseDriver
from data_processor import DataProcessor

# Root logging at INFO (adds a default stderr handler if the root logger has none).
logging.basicConfig(level=logging.INFO)

# stdout handler that is attached to this service's own logger further down.
streamhandler = logging.StreamHandler(stream=sys.stdout)

def tb_str(e):
    """Return the formatted traceback of exception *e* as a list of stripped strings.

    Each element is one chunk produced by ``traceback.format_exception`` with
    surrounding whitespace removed. A chunk may still span several lines, for
    example a "File ..., line N" entry followed by its source line.
    """
    # Positional arguments work on every Python 3 version. The former
    # ``etype=`` keyword was removed in Python 3.10, where passing it raises
    # TypeError.
    chunks = traceback.format_exception(type(e), e, e.__traceback__)
    return [chunk.strip() for chunk in chunks]

# Local development only: presumably populates os.environ from a YAML secrets
# file (see env_from_yaml_secrets) — TODO confirm.
if os.environ.get('NUMEROUS_LOCAL') == 'True':
    from env_from_yaml_secrets import load_env_from_yaml

    load_env_from_yaml()

from concurrent import futures
import time

import grpc
from tokens import validated_request, AccessLevel
import tokens as token_manager

import datetime
from grpc_interceptor import ServerInterceptor
from grpc_reflection.v1alpha import reflection

import numerous_bigtable_pb2_grpc, numerous_bigtable_pb2


from numerous_cert_server.server import serve_env


# Module-wide data processor: backed by HBaseDriver when BIGTABLE_MOCK is set,
# otherwise by BigTableDriver.
if BIGTABLE_MOCK:
    bigtable = DataProcessor(HBaseDriver())
else:
    bigtable = DataProcessor(BigTableDriver())


def json_serial(obj):
    """JSON serializer for objects not serializable by default json code.

    Intended as the ``default=`` hook of ``json.dumps``.

    * ``datetime.datetime`` values become their POSIX timestamp
      (``obj.timestamp()``). Naive values are interpreted in local time, as
      ``timestamp()`` does.
    * Plain ``datetime.date`` values become the timestamp of that day's
      local midnight.

    Raises:
        TypeError: for any other type, which is what ``json.dumps`` expects.
    """
    # datetime is a subclass of date, so it must be checked first.
    if isinstance(obj, datetime.datetime):
        return obj.timestamp()
    if isinstance(obj, datetime.date):
        # date has no .timestamp(); the previous code raised AttributeError here.
        return datetime.datetime.combine(obj, datetime.time()).timestamp()
    raise TypeError("Type %s not serializable" % type(obj))

#my_own_url = str(os.getenv('NUMEROUS_BIGTABLE_API_SERVER'))
#my_own_port = str(os.getenv('NUMEROUS_BIGTABLE_API_PORT'))
#my_own_secure = str(os.getenv('SECURE_CHANNEL'))
#my_own_insecure_address = str(os.getenv('NUMEROUS_BIGTABLE_SERVICE_ADDRESS'))
#my_own_insecure_port = str(os.getenv('NUMEROUS_BIGTABLE_SERVICE_PORT'))


class InstanceInterceptor(ServerInterceptor):
    """Server interceptor that exposes call metadata and debug-logs requests.

    For every call it stores the invocation metadata as a plain dict on
    ``context.metadata``. When the module logger is at DEBUG it also logs the
    endpoint and the request contents, with any ``data`` field replaced by a
    placeholder note.
    """

    def intercept(self, method, request_or_iterator, context: grpc.ServicerContext, endpoint,  *args):
        """Record metadata, optionally log the request, then invoke *method*.

        Args:
            method: the next handler in the interceptor chain.
            request_or_iterator: the request message for unary-request calls,
                or an iterator of requests for client-streaming calls.
            context: the gRPC servicer context of the call.
            endpoint: identifier of the called RPC, used only for logging.

        Returns:
            Whatever *method* returns.
        """
        # Metadata as a dict; for a repeated key the last value wins.
        context.metadata = {t[0]: t[1] for t in context.invocation_metadata()}
        # Skip the request-to-dict conversion entirely when it would not be logged.
        if log.isEnabledFor(logging.DEBUG):
            log.debug('Requested endpoint: ' + str(endpoint))
            self._log_request(endpoint, request_or_iterator)
        return method(request_or_iterator, context)

    @staticmethod
    def _log_request(endpoint, request):
        """Debug-log *request* as a dict, omitting its ``data`` field."""
        # Client-streaming calls pass an iterator rather than a message.
        # MessageToDict cannot convert it, and it must not be consumed here, so
        # skip it instead of logging an error for every streaming call.
        if not hasattr(request, 'DESCRIPTOR'):
            log.debug(f'endpoint {endpoint} - request: <stream>')
            return
        try:
            msg = MessageToDict(request)
        except Exception as e:
            log.error(f'could not translate message {tb_str(e)}')
            return
        if 'data' in msg:
            msg.pop('data')
            msg.update({'note': "data omitted"})
        log.debug(f'endpoint {endpoint} - request: {msg}')


def get_env(val, env, default=None):
    """Resolve a setting from an explicit value, an env var, or a default.

    Precedence: *val* if it is not ``None``; otherwise the environment
    variable *env* if it is set; otherwise *default* if it is not ``None``.

    Raises:
        KeyError: when *val* is ``None``, *env* is unset and *default* is
            ``None``. A ``None`` default therefore means "required".
    """
    if val is not None:
        return val

    from_environment = os.getenv(env)
    if from_environment is not None:
        return from_environment

    if default is not None:
        return default

    raise KeyError(f'ENV var <{env}> is not set.')




log = logging.getLogger('numerous_bigtable.server')

# NUMEROUS_LOGGING_LEVEL_SERVER may hold a level name in any case ("INFO",
# "debug") or a numeric level ("10"). Logger.setLevel accepts only ints or
# upper-case level names and raises ValueError otherwise, so normalize first.
_log_level = get_env(None, env='NUMEROUS_LOGGING_LEVEL_SERVER', default=logging.DEBUG)
if isinstance(_log_level, str):
    _log_level = _log_level.strip()
    _log_level = int(_log_level) if _log_level.isdigit() else _log_level.upper()
log.setLevel(_log_level)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
streamhandler.setFormatter(formatter)
log.addHandler(streamhandler)
# The root logger has its own handler (logging.basicConfig at the top of the
# module). Without this, every record from `log` would be emitted twice: once
# by the stdout handler and once by the root handler.
log.propagate = False


def standard_error_handling():
    """Build a decorator that maps servicer exceptions onto gRPC status codes.

    The wrapped method has the signature ``(self, request, context)``. Two
    kinds of failure are distinguished.

    * Expected errors raised through ``Error_Handler.set_error``. That call
      records ``self.code``/``self.msg`` before raising. The recorded values
      are copied onto *context*, then cleared, and the call completes
      normally: unary endpoints return ``numerous_bigtable_pb2.Empty()`` and
      streaming endpoints end their stream.
    * Any other exception. *context* is marked ``INTERNAL`` with a generic
      message and the exception is re-raised, so gRPC aborts the call with
      that status. Nothing is stored on ``self``, so one unexpected failure
      cannot change how a later one is handled.

    Generator (server-streaming) methods are wrapped by a generator, so errors
    raised while the response stream is produced are handled as well.

    NOTE(review): unary endpoints return ``Empty`` on expected errors
    regardless of their declared response type; confirm this is intended.
    """
    internal_msg = 'Internal error occurred in endpoint. Please check server logs.'

    def _apply_error(self, context):
        """Copy the recorded (or default INTERNAL) error onto *context*; return True if it was recorded."""
        code = getattr(self, 'code', None)
        if code is None:
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(internal_msg)
            return False
        context.set_code(code)
        context.set_details(self.msg)
        # Clear the recorded error so it cannot leak into a later call.
        self.code = None
        self.msg = None
        return True

    def decorator(f):
        if inspect.isgeneratorfunction(f):
            @functools.wraps(f)
            def streaming_wrapper(self, request, context):
                try:
                    yield from f(self, request, context)
                except Exception:
                    if not _apply_error(self, context):
                        raise
            return streaming_wrapper

        @functools.wraps(f)
        def unary_wrapper(self, request, context):
            try:
                return f(self, request, context)
            except Exception:
                if not _apply_error(self, context):
                    raise
                return numerous_bigtable_pb2.Empty()
        return unary_wrapper

    return decorator


class Error_Handler:
    """Mixin that lets an endpoint fail with a specific gRPC status code.

    ``set_error`` records a status ``code`` and ``msg`` and then raises. An
    error-handling wrapper can read the recorded values back and put them on
    the gRPC context.

    The recorded values are kept in thread-local storage. A single servicer
    instance is shared by all worker threads of the server's thread pool, and
    plain instance attributes would let one request's error status be read or
    cleared by another request failing concurrently.
    """

    def _error_state(self):
        """Return this instance's thread-local error storage, creating it on first use."""
        state = self.__dict__.get('_error_handler_state')
        if state is None:
            # setdefault is a single dict operation under CPython's GIL, so
            # threads racing on first use all end up with the same object.
            state = self.__dict__.setdefault('_error_handler_state', threading.local())
        return state

    @property
    def code(self):
        """gRPC status code recorded by ``set_error`` on the current thread, or ``None``."""
        return getattr(self._error_state(), 'code', None)

    @code.setter
    def code(self, value):
        self._error_state().code = value

    @property
    def msg(self):
        """Error message recorded by ``set_error`` on the current thread, or ``None``."""
        return getattr(self._error_state(), 'msg', None)

    @msg.setter
    def msg(self, value):
        self._error_state().msg = value

    def set_error(self, exception_cls, code: grpc.StatusCode=grpc.StatusCode.INTERNAL, msg:str='Internal Error Occured!'):
        """Record *code* and *msg* for the current thread, then raise ``exception_cls(msg)``.

        Raises:
            exception_cls: always.
        """
        self.code = code
        self.msg = msg

        raise exception_cls(self.msg)


class BigtableServicer(numerous_bigtable_pb2_grpc.BigtableServicer, Error_Handler):
    """gRPC servicer for the Bigtable service, delegating to the module-level ``bigtable`` processor.

    Every endpoint is wrapped by ``standard_error_handling``, so errors raised
    through ``Error_Handler.set_error`` become the chosen gRPC status. All
    endpoints except ``GetAccessToken`` and ``WriteDataList`` are also wrapped
    by ``validated_request`` with the access level they need. ``WriteDataList``
    authorizes its first streamed request inside ``write_data``.
    """

    # Batch limit used by ReadData, in units of sys.getsizeof().
    # NOTE(review): sys.getsizeof() is shallow. It measures the Python wrapper
    # object, not the serialized payload, so this effectively caps the number
    # of blocks per message rather than its byte size. Consider
    # DataBlock.ByteSize() against the gRPC message size limit.
    _READ_BATCH_SIZE_LIMIT = 4e6 / 70

    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def ClearData(self, request, context):
        """Accept a clear request without touching the backend.

        The backend call is disabled (kept below for reference), so this
        endpoint currently only returns ``Empty``.
        """
        # bigtable.clear_data(request.prefix, request.key)
        return numerous_bigtable_pb2.Empty()

    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def ClearDataTags(self, request, context):
        """Call ``bigtable.delete_columns`` for the request's prefix, key and tags."""
        bigtable.delete_columns(request.prefix, request.key, request.tags)
        return numerous_bigtable_pb2.Empty()

    # TODO offload deletion to worker - remember only one delete API call is allowed with bigtable
    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def SubmitDeleteData(self, request, context):
        """Call ``bigtable.clear`` once for every key in ``request.key``."""
        for k in request.key:
            bigtable.clear(request.prefix, k)
        return numerous_bigtable_pb2.Empty()

    # TODO offload deletion to worker - remember only one delete API call is allowed with bigtable
    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def SubmitDeleteLogs(self, request, context):
        """Call ``bigtable.clear_logs`` once for every key in ``request.key``."""
        for k in request.key:
            bigtable.clear_logs(request.prefix, k)
        return numerous_bigtable_pb2.Empty()

    @validated_request(access_level=AccessLevel.READ)
    @standard_error_handling()
    def ReadData(self, request: numerous_bigtable_pb2.Spec, context):
        """Stream the requested data as a sequence of ``DataList`` messages.

        Data comes from ``bigtable.read_time_range`` when ``request.time_range``
        is set, otherwise from ``bigtable.read_block_range``. Each item ``d``
        produced by the reader is used as ``(blocks, row_complete,
        block_complete)`` via ``d[0]``, ``d[1]`` and ``d[2]``. Each entry of
        ``blocks`` is a dict of ``DataBlock`` fields.

        The blocks of one item are batched into ``DataList`` messages no larger
        than ``_READ_BATCH_SIZE_LIMIT``. Only the last message of an item
        carries that item's completion flags. Earlier messages of the same item
        report ``False``, because more of its blocks still follow.
        """
        if request.time_range:
            data = bigtable.read_time_range(request.prefix, request.key, request.tags, request.start, request.end)
        else:
            data = bigtable.read_block_range(request.prefix, request.key, request.tags, request.start, request.end)

        for d in data:
            blocks, row_complete, block_complete = d[0], d[1], d[2]

            batch = []
            batch_size = 0
            for block_fields in blocks:
                new_block = numerous_bigtable_pb2.DataBlock(**block_fields)
                new_block_size = sys.getsizeof(new_block)

                # Flush before exceeding the limit. Never send an empty batch.
                # The flushed batch is never complete: new_block still follows.
                if batch and batch_size + new_block_size > self._READ_BATCH_SIZE_LIMIT:
                    yield numerous_bigtable_pb2.DataList(prefix=request.prefix, key=request.key, data=batch,
                                                         row_complete=False, block_complete=False)
                    batch = []
                    batch_size = 0

                batch_size += new_block_size
                batch.append(new_block)

            if batch:
                yield numerous_bigtable_pb2.DataList(prefix=request.prefix, key=request.key, data=batch,
                                                     row_complete=row_complete, block_complete=block_complete)
        log.debug('Reading data completed')

    def write_data(self, request_iterator):
        """Push a sequence of ``DataList`` requests to the backend.

        The first request is authorized for WRITE access before anything is
        read from or written to the backend, and it fixes the prefix/key for
        the whole sequence. The starting block counter is 0 when a request
        sets ``reset_block_counter``. Otherwise it is fetched once with
        ``bigtable.get_block_counter``. Each request's blocks are pushed with
        ``bigtable.push_data_version_dict``, and the counter that call returns
        is stored with ``bigtable.set_block_counter``.

        Returns:
            ``DataBlockCounter`` holding the counter returned by the last push.

        Raises:
            ValueError: if a later request changes the prefix or key, or if
                *request_iterator* yields nothing. Both are recorded as
                INVALID_ARGUMENT via ``set_error``.
        """
        validated = False
        prefix = None
        key = None
        start_counter = None
        block_counter = None
        for r in request_iterator:
            if not validated:
                # Authorize before touching the backend.
                # NOTE(review): this applies validated_request's decorator to
                # the request message itself. Confirm in `tokens` that this
                # really checks the token, since WriteDataList has no other
                # authorization.
                validated_request(access_level=AccessLevel.WRITE)(r)
                validated = True
                prefix = r.prefix
                key = r.key
            elif r.key != key or r.prefix != prefix:
                self.set_error(ValueError, grpc.StatusCode.INVALID_ARGUMENT,
                               'Cannot change key or prefix during upload blocks!')

            if r.reset_block_counter:
                start_counter = 0
            elif start_counter is None:
                start_counter = bigtable.get_block_counter(prefix, key)

            # NOTE(review): every request is pushed at the same start_counter.
            # The counter returned by the push is stored but not carried into
            # the next iteration. Confirm this is intended for multi-message
            # streams, e.g. when messages split tags rather than blocks.
            block_counter = bigtable.push_data_version_dict(prefix, key, start_counter, {b.tag: b.values for b in r.data})

            bigtable.set_block_counter(prefix, key, block_counter)

        if not validated:
            # Previously an empty stream crashed with UnboundLocalError.
            self.set_error(ValueError, grpc.StatusCode.INVALID_ARGUMENT, 'No data received in upload stream.')

        return numerous_bigtable_pb2.DataBlockCounter(block_counter=block_counter)

    @standard_error_handling()
    def WriteDataList(self, request_iterator, context):
        """Client-streaming upload; authorization happens inside ``write_data``."""
        return self.write_data(request_iterator)

    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def PushDataList(self, request, context):
        """Unary upload of a single ``DataList`` via ``write_data``."""
        return self.write_data([request])

    @validated_request(access_level=AccessLevel.READ)
    @standard_error_handling()
    def GetBlockCounter(self, request, context):
        """Return the stored block counter for the request's prefix/key."""
        return numerous_bigtable_pb2.DataBlockCounter(block_counter=bigtable.get_block_counter(request.prefix, request.key))

    @validated_request(access_level=AccessLevel.READ)
    @standard_error_handling()
    def ReadDataStats(self, request, context):
        """Return the stats dict from ``bigtable.read_data_stats`` as a ``SpecStats`` message."""
        stats = bigtable.read_data_stats(request.prefix, request.key, request.tag)
        return numerous_bigtable_pb2.SpecStats(
            prefix=request.prefix, key=request.key,
            min=stats['min'], max=stats['max'], equi_space=stats['equi_space'],
            spacing=stats['spacing'], n_blocks=stats['n_blocks'],
            equi_block_len=stats['equi_block_len'],
            block_len0=stats['block_len0'],
            block_len_last=stats['block_len_last'], total_val_len=stats['total_val_len'],
        )

    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def PushLogEntries(self, request, context):
        """Store log entries with their timestamps.

        ``timestamps`` and ``log_entries`` must have the same length. Otherwise
        the call fails with INVALID_ARGUMENT.
        """
        if len(request.timestamps) != len(request.log_entries):
            self.set_error(IndexError, grpc.StatusCode.INVALID_ARGUMENT, "logs and timestamps must have same length.")
        bigtable.push_log_entries(request.prefix, request.key, request.timestamps, request.log_entries)
        return numerous_bigtable_pb2.Empty()

    @validated_request(access_level=AccessLevel.READ)
    @standard_error_handling()
    def ReadEntries(self, request, context):
        """Stream ``(log_entry, timestamp)`` pairs from ``bigtable.read_logs_time_range``."""
        for entry, timestamp in bigtable.read_logs_time_range(request.prefix, request.key, request.start, request.end):
            yield numerous_bigtable_pb2.LogsSpecKey(prefix=request.prefix, key=request.key, log_entry=entry, timestamp=timestamp)

    @validated_request(access_level=AccessLevel.READ)
    @standard_error_handling()
    def GetMetaData(self, request, context):
        """Return the metadata stored under ``request.meta_key``."""
        meta = bigtable.get_meta_data(request.prefix, request.key, request.meta_key)
        return numerous_bigtable_pb2.MetaSpec(prefix=request.prefix, key=request.key, meta_key=request.meta_key, data=meta)

    @validated_request(access_level=AccessLevel.WRITE)
    @standard_error_handling()
    def SetMetaData(self, request, context):
        """Store ``request.data`` as metadata under ``request.meta_key``."""
        bigtable.set_meta_data(request.prefix, request.key, meta_key=request.meta_key, meta=request.data)
        return numerous_bigtable_pb2.Empty()

    @standard_error_handling()
    def GetAccessToken(self, request, context):
        """Exchange a refresh token for an access token.

        Validation and JWT decode failures map to UNAUTHENTICATED, and a
        ``KeyError`` from the token manager maps to NOT_FOUND.
        """
        access_token = None
        try:
            access_token = token_manager.generate_access_token(request.refresh_token.val)
        except token_manager.ValidationException as e:
            self.set_error(token_manager.ValidationException, grpc.StatusCode.UNAUTHENTICATED, str(e))
        except jwt.exceptions.DecodeError as e:
            self.set_error(token_manager.ValidationException, grpc.StatusCode.UNAUTHENTICATED, str(e))
        except KeyError as e:
            self.set_error(KeyError, grpc.StatusCode.NOT_FOUND, str(e))
        return numerous_bigtable_pb2.Token(val=access_token)

    @validated_request(access_level=AccessLevel.DEVELOPER)
    @standard_error_handling()
    def CreateRefreshToken(self, request, context):
        """Create a refresh token for the given prefix/user/organization/agent/purpose/access level.

        Validation failures map to UNAUTHENTICATED, and a ``KeyError`` maps
        to NOT_FOUND.
        """
        refresh_token = None
        try:
            refresh_token = token_manager.generate_refresh_token(
                prefix=request.prefix, user_id=request.user_id,
                organization_id=request.organization_id,
                agent=request.agent,
                purpose=request.purpose, access_level=request.access_level
            )
        except token_manager.ValidationException as e:
            self.set_error(token_manager.ValidationException, grpc.StatusCode.UNAUTHENTICATED, str(e))
        except KeyError as e:
            self.set_error(KeyError, grpc.StatusCode.NOT_FOUND, str(e))

        # Never log the token itself: it is a long-lived credential.
        log.info(f'Refresh token created for prefix {request.prefix}')

        return numerous_bigtable_pb2.Token(val=refresh_token)

def _initialize_channel(server, port):
    """Attach a TLS listening port and server reflection to *server*.

    The PEM private key and certificate chain are read from the
    ``NUMEROUS_CERT_KEY`` and ``NUMEROUS_CERT_CRT`` environment variables.
    ``serve_env`` from ``numerous_cert_server`` is also called with those
    variable names and ``0.0.0.0:4443``.

    Args:
        server: the ``grpc.Server`` to configure, not yet started.
        port: port to listen on, on all interfaces.

    Returns:
        The same *server* object.

    Raises:
        KeyError: if either certificate environment variable is unset.
    """
    # get_env raises a clear KeyError. Before, a missing variable surfaced as
    # a cryptic `str.encode(None)` TypeError.
    private_key = get_env(None, 'NUMEROUS_CERT_KEY').encode()
    certificate_chain = get_env(None, 'NUMEROUS_CERT_CRT').encode()

    serve_env('NUMEROUS_CERT_KEY', 'NUMEROUS_CERT_CRT', '0.0.0.0', 4443)

    server_credentials = grpc.ssl_server_credentials(((private_key, certificate_chain),))
    server.add_secure_port(f'[::]:{port}', server_credentials)

    # Reflection must advertise fully-qualified proto service names, not the
    # Python servicer class name, plus the reflection service itself.
    service_names = [service.full_name for service in numerous_bigtable_pb2.DESCRIPTOR.services_by_name.values()]
    service_names.append(reflection.SERVICE_NAME)
    reflection.enable_server_reflection(service_names, server)

    return server

if __name__ == '__main__':
    port = 50051

    # Thread-pool server with InstanceInterceptor in front of every call; one
    # servicer instance handles all requests.
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=100),
        interceptors=[InstanceInterceptor()],
    )
    numerous_bigtable_pb2_grpc.add_BigtableServicer_to_server(BigtableServicer(), server)

    # Attach the TLS port and reflection, then start serving.
    server = _initialize_channel(server, port)
    server.start()
    log.info(f'Starting server. Listening on port {port}')

    # server.start() does not block, so park the main thread until Ctrl-C.
    try:
        while True:
            time.sleep(86400)
    except KeyboardInterrupt:
        server.stop(0)

