import os
#os.environ["GRPC_TRACE"] = "transport_security,tsi"
#os.environ["GRPC_VERBOSITY"] = "DEBUG"
import json
from . import numerous_bigtable_pb2, numerous_bigtable_pb2_grpc
from urllib.parse import urlparse
import grpc
import logging
import logging.handlers
from numerous_api_client.headers.ValidationInterceptor import ValidationInterceptor
import threading
from time import  time
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from numerous_cert_server.cert_helper.get_cert import get_cert

# Logger for this client module. Its level is forced to DEBUG at import time.
# NOTE(review): a library forcing its own level overrides application logging
# config — confirm this is intended.
log = logging.getLogger(name='numerous_bigtable_client')
log.setLevel(level=logging.DEBUG)


def attr_to_dict(obj, attrs):
    """Return a dict mapping each name in ``attrs`` to ``getattr(obj, name)``.

    Keys keep the iteration order of ``attrs``. A missing attribute raises
    ``AttributeError``.
    """
    snapshot = {}
    for attr_name in attrs:
        snapshot[attr_name] = getattr(obj, attr_name)
    return snapshot

# Organization name sent as the ``prefix`` of every Bigtable request. It is
# required: importing this module fails fast when it is unset or empty.
org = os.getenv('NUMEROUS_ORGANIZATION')

if not org:
    raise ValueError('No organization specified! Set the NUMEROUS_ORGANIZATION environment variable.')

class NumerousBigtableGrpcClient:
    """gRPC client base for the Numerous Bigtable service.

    On construction it:

    * opens a TLS channel to the host/port named in ``url``;
    * wraps the channel in a ``ValidationInterceptor`` that reads the current
      access token through :meth:`_get_current_token`;
    * exchanges the refresh token from ``NUMEROUS_BIGTABLE_API_REFRESH_TOKEN``
      for an access token, once synchronously and then every 9 minutes on a
      background timer.

    Call :meth:`close` when done to stop the refresher and close the channel.
    """

    # Seconds between background access-token refreshes.
    _TOKEN_REFRESH_INTERVAL = 9 * 60

    def __init__(self, url):
        """Connect to ``url`` (``scheme://host:port[/path]``) and fetch an access token.

        Raises:
            ValueError: if ``url`` has no host or no explicit/valid port.
        """
        server, port, route, secure = self._parse_url(url)
        self._access_token = None

        self._refresh_token = os.getenv('NUMEROUS_BIGTABLE_API_REFRESH_TOKEN')
        self._prefix = os.getenv('NUMEROUS_BIGTABLE_API_PREFIX')

        self.channel = self._init_channel(server=server, port=port, route=route, secure=secure)
        self.stub = numerous_bigtable_pb2_grpc.BigtableStub(self.channel)

        try:
            # run_initially=True performs the first refresh synchronously, so a
            # failure surfaces here rather than on a background thread.
            self._access_token_refresher = RepeatedFunction(
                interval=self._TOKEN_REFRESH_INTERVAL, function=self._refresh_access_token,
                run_initially=True, refresh_token=self._refresh_token)
        except BaseException:
            # Don't leak the open channel when the initial token fetch fails.
            self.channel.close()
            raise
        self._access_token_refresher.start()

    @staticmethod
    def _parse_url(url):
        """Split ``url`` into ``(host, port, path, secure)``.

        ``secure`` is True only for the ``https`` scheme.

        Raises:
            ValueError: if the URL lacks a host or an explicit port, or the port
                is not a valid number.
        """
        parsed = urlparse(url)
        # .hostname/.port ignore any ``user:password@`` part, unlike splitting
        # the netloc on ':'.
        host = parsed.hostname
        port = parsed.port  # raises ValueError for a non-numeric/out-of-range port
        if not host or port is None:
            raise ValueError(f'Expected a URL of the form scheme://host:port, got {url!r}')
        return host, port, parsed.path, parsed.scheme == 'https'

    def _init_channel(self, server, port, route, secure=None, instance_id=None):
        """Open a TLS channel to ``server:port`` wrapped in a ValidationInterceptor.

        The server certificate is downloaded from ``https://{server}:4443/cert``
        and used as the trusted root. The certificate's subject CN (if any) is
        passed as ``grpc.ssl_target_name_override``, so the TLS hostname check
        uses the CN rather than ``server``.

        NOTE(review): ``route`` and ``secure`` are accepted but unused. The
        channel is always TLS, and the URL path is not applied to method names.
        """
        cert = str.encode(get_cert(f'https://{server}:4443/cert'))
        creds = grpc.ssl_channel_credentials(cert)
        cert_decoded = x509.load_pem_x509_certificate(cert, default_backend())
        cn_attrs = cert_decoded.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
        if cn_attrs:
            options = (('grpc.ssl_target_name_override', cn_attrs[0].value),)
        else:
            log.warning('Server certificate for %s has no CN; not overriding TLS target name', server)
            options = ()
        channel = grpc.secure_channel(f'{server}:{port}', creds, options)

        vi = ValidationInterceptor(token=self._access_token, token_callback=self._get_current_token, instance=instance_id)
        self._instance = vi.instance
        return grpc.intercept_channel(channel, vi)

    def _get_current_token(self):
        """Return the most recently fetched access token (None before the first fetch)."""
        return self._access_token

    def _refresh_access_token(self, refresh_token):
        """Exchange ``refresh_token`` for a new access token via GetAccessToken and store it."""
        token = self.stub.GetAccessToken(
            numerous_bigtable_pb2.RefreshRequest(
                refresh_token=numerous_bigtable_pb2.Token(val=refresh_token), prefix=self._prefix
            )
        )
        self._access_token = token.val

    def close(self):
        """Stop the background token refresher and close the gRPC channel."""
        self._access_token_refresher.stop()
        self.channel.close()

# Bigtable server endpoint ("scheme://host:port"), read once at import time;
# None when the variable is unset.
url = os.environ.get('NUMEROUS_BIGTABLE_SERVER_URL')


def form_row_key(*args):
    """Join the non-empty parts with ``'#'`` to build a row key.

    Empty strings are skipped so they never produce doubled or dangling
    separators, e.g. ``form_row_key('a', '', 'b') == 'a#b'``. With no
    non-empty parts the result is ``''``.
    """
    return "#".join(part for part in args if part != "")

class Bigtable(NumerousBigtableGrpcClient):
    """High-level client for the Numerous Bigtable service.

    Connects to the module-level ``url`` (``NUMEROUS_BIGTABLE_SERVER_URL``)
    and sends the module-level ``org`` as the ``prefix`` of every request.
    Data and metadata rows are keyed by ``form_row_key(scenario, execution)``.
    Log entries are keyed by ``execution`` alone.
    """

    def __init__(self):
        """Connect to the server named by ``NUMEROUS_BIGTABLE_SERVER_URL``."""
        super().__init__(url)

    @staticmethod
    def _row_keys(scenarios, executions):
        """Pair ``scenarios`` and ``executions`` element-wise into row keys.

        Raises:
            ValueError: if the two inputs differ in length. A bare ``zip``
                would silently drop the unmatched tail, leaving those rows
                undeleted.
        """
        scenarios, executions = list(scenarios), list(executions)
        if len(scenarios) != len(executions):
            raise ValueError(
                f'scenarios and executions must have the same length '
                f'({len(scenarios)} != {len(executions)})')
        return [form_row_key(scenario, execution) for scenario, execution in zip(scenarios, executions)]

    def push_log_entries(self, execution, log_entries, timestamps):
        """Send ``log_entries`` with their ``timestamps`` for ``execution`` via PushLogEntries."""
        self.stub.PushLogEntries(numerous_bigtable_pb2.LogEntries(
            prefix=org, key=execution, log_entries=log_entries, timestamps=timestamps))

    def read_logs_time_range(self, execution, start, end):
        """Yield ``(log_entry, timestamp)`` pairs from ReadEntries for ``execution`` in ``[start, end]``.

        NOTE(review): whether ``end`` is inclusive is decided by the server — confirm.
        """
        spec = numerous_bigtable_pb2.ReadLogsSpec(prefix=org, key=execution, start=start, end=end)
        for entry in self.stub.ReadEntries(spec):
            yield entry.log_entry, entry.timestamp

    def set_meta_data(self, scenario, execution, offset, tags, aliases, epoch_type, timezone):
        """Store the standard metadata fields as one JSON object under meta key ``"meta"``."""
        data = dict(
            offset=offset, tags=tags, aliases=aliases, epoch_type=epoch_type, timezone=timezone
        )
        self.set_custom_meta_data(scenario, execution, key="meta", meta=json.dumps(data))

    def get_meta_data(self, scenario, execution):
        """Return the decoded JSON stored by :meth:`set_meta_data`."""
        return self.get_custom_meta_data(scenario, execution, 'meta')

    def set_custom_meta_data(self, scenario, execution, key, meta):
        """Store ``meta`` under meta key ``key`` for the row.

        ``meta`` is sent as-is. It should be JSON text, because
        :meth:`get_custom_meta_data` decodes it with ``json.loads``.
        """
        self.stub.SetMetaData(
            numerous_bigtable_pb2.MetaSpec(prefix=org, key=form_row_key(scenario, execution), meta_key=key,
                                           data=meta))

    def get_custom_meta_data(self, scenario, execution, key):
        """Fetch meta key ``key`` for the row and return it decoded with ``json.loads``.

        Raises:
            json.JSONDecodeError: if the stored data is not valid JSON
                (presumably including an empty reply for a missing key — verify).
        """
        reply = self.stub.GetMetaData(
            numerous_bigtable_pb2.ReadMetaSpec(prefix=org, key=form_row_key(scenario, execution), meta_key=key))
        return json.loads(reply.data)

    def read(self, scenario, execution, tags, start, end, time_range=False):
        """Yield the replies streamed by ReadData for ``tags`` of the row.

        ``time_range`` is forwarded to the server. :meth:`read_time_range`
        sets it True, and :meth:`read_block_range` sets it False.
        """
        spec = numerous_bigtable_pb2.Spec(prefix=org, key=form_row_key(scenario, execution), tags=tags,
                                          start=start, end=end, time_range=time_range)
        yield from self.stub.ReadData(spec)

    def read_time_range(self, scenario, execution, tags, start, end):
        """:meth:`read` with ``time_range=True``."""
        yield from self.read(scenario, execution, tags, start, end, time_range=True)

    def read_block_range(self, scenario, execution, tags, start, end):
        """:meth:`read` with ``time_range=False``."""
        yield from self.read(scenario, execution, tags, start, end, time_range=False)

    def read_data_stats(self, scenario, execution, tag='_index'):
        """Return the ReadDataStats reply for ``tag`` of the row as a plain dict."""
        reply = self.stub.ReadDataStats(
            numerous_bigtable_pb2.ReadSpecStats(prefix=org, key=form_row_key(scenario, execution), tag=tag))

        return attr_to_dict(reply, ['min', 'max', 'equi_space', 'spacing', 'n_blocks', 'equi_block_len',
                                    'block_len0', 'block_len_last', 'total_val_len'])

    def push_data_version_dict(self, scenario, execution, data, block_length=1000):
        """Push ``data`` (mapping tag -> values) to the row and return the server's block counter.

        ``block_length`` is not used; it is kept for backward compatibility.
        Blocks are sent with ``reset_block_counter=False``.
        """
        blocks = [numerous_bigtable_pb2.DataBlock(tag=tag, values=values) for tag, values in data.items()]
        reply = self.stub.PushDataList(
            numerous_bigtable_pb2.DataList(
                prefix=org, key=form_row_key(scenario, execution), data=blocks, reset_block_counter=False
            )
        )
        return reply.block_counter

    def get_block_counter(self, scenario, execution):
        """Return the current block counter of the row as reported by GetBlockCounter."""
        reply = self.stub.GetBlockCounter(
            numerous_bigtable_pb2.KeySpec(
                prefix=org, key=form_row_key(scenario, execution),
            )
        )
        return reply.block_counter

    def clear(self, scenario, execution):
        """Request ClearData for the row."""
        self.stub.ClearData(numerous_bigtable_pb2.KeySpec(prefix=org, key=form_row_key(scenario, execution)))

    def submit_delete_data(self, scenarios, executions):
        """Submit SubmitDeleteData for each ``(scenario, execution)`` pair.

        Raises:
            ValueError: if ``scenarios`` and ``executions`` differ in length.
        """
        self.stub.SubmitDeleteData(numerous_bigtable_pb2.DeleteDataSpec(
            prefix=org, key=self._row_keys(scenarios, executions)))

    def submit_delete_logs(self, scenarios, executions):
        """Submit SubmitDeleteLogs for each ``(scenario, execution)`` pair.

        Raises:
            ValueError: if ``scenarios`` and ``executions`` differ in length.
        """
        self.stub.SubmitDeleteLogs(numerous_bigtable_pb2.DeleteDataSpec(
            prefix=org, key=self._row_keys(scenarios, executions)))

    def submit_delete_data_and_logs(self, scenarios, executions):
        """Submit deletion of both logs and data for each ``(scenario, execution)`` pair."""
        # Materialize once: both calls below iterate the inputs, so one-shot
        # iterators would otherwise be exhausted by the first call.
        scenarios, executions = list(scenarios), list(executions)
        self.submit_delete_logs(scenarios, executions)
        self.submit_delete_data(scenarios, executions)

    def delete_columns(self, scenario, execution, columns):
        """Request ClearDataTags for ``columns`` (tags) of the row."""
        self.stub.ClearDataTags(numerous_bigtable_pb2.KeySpec(prefix=org, key=form_row_key(scenario, execution), tags=columns))

class RepeatedFunction:
    """Call ``function(*args, **kwargs)`` every ``interval`` seconds on a timer thread.

    Calls are scheduled against a running deadline (``next_call``), so the
    period does not drift by the function's own run time. The next call is
    scheduled before the current one runs, so an exception raised by
    ``function`` on the timer thread does not stop the repetition.

    Timer threads are daemonic, so a forgotten :meth:`stop` does not keep the
    interpreter alive.

    Args:
        interval: seconds between calls.
        function: callable to invoke.
        run_initially: if True, call ``function`` once synchronously in the
            constructor; its exceptions propagate to the caller.
        *args, **kwargs: forwarded to ``function`` on every call. Positional
            extras are only reachable after ``run_initially`` is given.
    """

    def __init__(self, interval, function, run_initially=False, *args, **kwargs):
        self._timer = None
        # Guards _timer/is_running/next_call. start()/stop() run on the caller's
        # thread while _run() runs on the timer thread.
        self._lock = threading.Lock()
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.is_running = False
        self.next_call = time()

        if run_initially:
            self.function(*self.args, **self.kwargs)

    def _schedule(self):
        """Arm a timer for the next deadline. The caller must hold ``self._lock``."""
        self.next_call += self.interval
        self._timer = threading.Timer(self.next_call - time(), self._run)
        self._timer.daemon = True
        self._timer.start()
        self.is_running = True

    def _run(self):
        """Timer callback: schedule the next call, then invoke ``function``."""
        with self._lock:
            # A timer that fired just as stop() (or a stop/start restart) ran is
            # stale. Don't reschedule from it.
            if not self.is_running or threading.current_thread() is not self._timer:
                return
            self.is_running = False
            self._schedule()
        # Run outside the lock so a slow function never blocks stop().
        self.function(*self.args, **self.kwargs)

    def start(self):
        """Begin periodic calls; a no-op if already running."""
        with self._lock:
            if not self.is_running:
                self._schedule()

    def stop(self):
        """Cancel any pending call. Safe to call before :meth:`start` or more than once."""
        with self._lock:
            if self._timer is not None:
                self._timer.cancel()
            self.is_running = False



def prefix_path(channel, prefix="/test"):
    """Monkey-patch ``channel`` so every RPC method path is prefixed with ``prefix``.

    Replaces the channel's ``unary_unary``, ``unary_stream``, ``stream_stream``
    and ``stream_unary`` attributes in place. Each wrapper prepends ``prefix``
    to the method path, then delegates to the original.

    Args:
        channel: object exposing the four gRPC multi-callable factories.
        prefix: string prepended verbatim to each method path.
    """
    logger = logging.getLogger('numerous_bigtable_client')

    def decorate_channel_method(original):
        # Forward any extra arguments untouched. Newer grpcio-generated stubs
        # may pass additional keyword flags (e.g. ``_registered_method``),
        # which a fixed signature would reject with TypeError.
        def wrap_prefix(method_path, *args, **kwargs):
            prefixed_path = prefix + method_path
            logger.debug('prefixed gRPC method path: %s', prefixed_path)
            return original(prefixed_path, *args, **kwargs)

        return wrap_prefix

    for factory_name in ('unary_unary', 'unary_stream', 'stream_stream', 'stream_unary'):
        setattr(channel, factory_name, decorate_channel_method(getattr(channel, factory_name)))

