import base64
import json
import logging
from dataclasses import dataclass
from time import sleep, time
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse

import grpc

import data_port_pb2
import data_port_pb2_grpc
from numerous.client_common.get_cert import get_cert_cn
from numerous.client_common.grpc_retry import grpc_retry
from numerous.client_common.request_response_stream import RequestResponseStream
from numerous.client_common.validation_interceptor import ValidationInterceptor
from numerous.tokens.tokens_refresh import TokenRefreshable, is_expired, refresh_access_token

from .config import DATA_STREAM_DEFAULT_TIMEOUT_SLEEP_FOR, GRPC_MAX_MESSAGE_SIZE, NUMEROUS_DATA_PORT_API_PREFIX, \
    NUMEROUS_DATA_PORT_API_REFRESH_TOKEN, NUMEROUS_DATA_PORT_SERVER_CERT, NUMEROUS_DATA_PORT_SERVER_URL, \
    NUMEROUS_ORGANIZATION

# Module-level logger, named after this module.
log: logging.Logger = logging.getLogger(__name__)

# Retry policy for DataPortClient._refresh_access_token; these are handed to
# grpc_retry as max_attempts and start_delay respectively.
MAX_REFRESH_TOKEN_ATTEMPTS: int = 5
START_REFRESH_TOKEN_DELAY: int = 2


def attr_to_dict(obj, attrs):
    """Collect the named attributes of *obj* into a ``{name: value}`` dict.

    Keys appear in the order given by *attrs*; a missing attribute raises
    ``AttributeError`` from ``getattr``.
    """
    collected = {}
    for attr_name in attrs:
        collected[attr_name] = getattr(obj, attr_name)
    return collected


# TODO: move calls to form_row_key() to Data-port service
def form_row_key(*args):
    """Join the non-empty parts of a row key with ``#``.

    Empty strings are skipped, so ``form_row_key("", "exec")`` yields
    ``"exec"`` rather than ``"#exec"``.

    Returns:
        The joined key, or ``""`` when no parts are given or all are empty.
    """
    return "#".join(part for part in args if part != "")


class TimeoutBreakStatus(Exception):
    """Raised when a waiting data-stream read is aborted by a break status.

    Attributes:
        break_status: the data-stream status value that triggered the abort.
    """

    def __init__(self, break_status):
        # Forward the status to Exception so ``args`` and ``str()`` carry it.
        # This also makes the exception picklable: with empty ``args``,
        # unpickling would call ``__init__()`` without the required argument.
        super().__init__(break_status)
        self.break_status = break_status


class DataPortClient(TokenRefreshable):
    """gRPC client for the Numerous Data-Port service.

    The connection target is read from ``NUMEROUS_DATA_PORT_SERVER_URL`` and
    the channel is opened in ``__init__``. Every request is scoped to the
    organization prefix ``_org``, and most row keys are built with
    ``form_row_key(scenario, execution)``. Public RPC wrappers are decorated
    with ``refresh_access_token`` (presumably refreshing the access token
    before the call — see numerous.tokens.tokens_refresh).
    """

    _org = NUMEROUS_ORGANIZATION

    def __init__(self):
        log.debug("Data-Port client connecting to %s", NUMEROUS_DATA_PORT_SERVER_URL)
        parsed_url = urlparse(NUMEROUS_DATA_PORT_SERVER_URL)
        server, port = self._server_address(parsed_url)

        secure = parsed_url.scheme == "https"
        self._access_token = None
        self.channel = self._init_channel(server=server, port=port, route=parsed_url.path, secure=secure)
        self.stub = data_port_pb2_grpc.DataPortStub(self.channel)

    @staticmethod
    def _server_address(parsed_url):
        """Return the ``(host, port)`` pair to dial for an already-parsed URL.

        When the URL has no explicit port, the scheme default is used (443 for
        ``https``, 80 otherwise) instead of failing. IPv6 literals get their
        brackets restored so ``f"{host}:{port}"`` stays a valid target.

        Raises:
            ValueError: if the URL has no host or its port is not a valid number.
        """
        host = parsed_url.hostname
        if not host:
            raise ValueError(f"Data-Port server URL has no host: {parsed_url.geturl()!r}")
        if ":" in host:
            # urlparse strips the brackets from IPv6 literals; put them back.
            host = f"[{host}]"
        port = parsed_url.port
        if port is None:
            port = 443 if parsed_url.scheme == "https" else 80
        return host, port

    def _init_channel(self, server, port, route, secure: bool = True, instance_id=None):
        """Create the gRPC channel to ``server:port``.

        A plaintext channel is used when *secure* is false. Otherwise the
        channel uses TLS with the base64-decoded ``NUMEROUS_DATA_PORT_SERVER_CERT``
        and a target-name override taken from that certificate's CN.

        When ``NUMEROUS_DATA_PORT_API_REFRESH_TOKEN`` is set, the channel is
        wrapped with a ``ValidationInterceptor``. The interceptor reads the
        current access token through ``_get_current_token``, and its
        ``instance`` is stored on ``self._instance``.

        Note: *route* is accepted but not used by this method.
        """
        options = [
            ('grpc.max_message_length', GRPC_MAX_MESSAGE_SIZE),
            ('grpc.max_send_message_length', GRPC_MAX_MESSAGE_SIZE),
            ('grpc.max_receive_message_length', GRPC_MAX_MESSAGE_SIZE),
        ]

        if not secure:
            channel = grpc.insecure_channel(f'{server}:{port}', options)
        else:
            cert = base64.b64decode(NUMEROUS_DATA_PORT_SERVER_CERT)
            creds = grpc.ssl_channel_credentials(cert)
            options += [
                ('grpc.ssl_target_name_override', get_cert_cn(cert)),
            ]
            channel = grpc.secure_channel(f'{server}:{port}', creds, options)

        if NUMEROUS_DATA_PORT_API_REFRESH_TOKEN:
            vi = ValidationInterceptor(token=self._access_token, token_callback=self._get_current_token, instance=instance_id)
            self._instance = vi.instance
            channel = grpc.intercept_channel(channel, vi)

        return channel

    def _get_current_token(self):
        """Return the cached access token (``None`` until first refreshed)."""
        return self._access_token

    @grpc_retry(max_attempts=MAX_REFRESH_TOKEN_ATTEMPTS, start_delay=START_REFRESH_TOKEN_DELAY)
    def _refresh_access_token(self):
        """Fetch a new access token when one is needed.

        A token is requested only when a refresh token is configured and the
        cached token is missing or expired. The call is wrapped by
        ``grpc_retry`` with the module-level attempt and delay settings.
        """
        if not NUMEROUS_DATA_PORT_API_REFRESH_TOKEN or (self._access_token and not is_expired(self._access_token)):
            return
        token = self.stub.GetAccessToken(
            data_port_pb2.RefreshRequest(
                refresh_token=data_port_pb2.Token(val=NUMEROUS_DATA_PORT_API_REFRESH_TOKEN),
                prefix=NUMEROUS_DATA_PORT_API_PREFIX,
            )
        )
        self._access_token = token.val

    def close(self):
        """Close the underlying gRPC channel and release its resources.

        Safe to call on a client whose construction failed before the
        channel was created.
        """
        channel = getattr(self, "channel", None)
        if channel is not None:
            channel.close()

    @refresh_access_token
    def push_log_entries(self, execution, log_entries, timestamps):
        """Send *log_entries* with matching *timestamps* for *execution*."""
        self.stub.PushLogEntries(data_port_pb2.LogEntries(prefix=self._org, key=execution, log_entries=log_entries, timestamps=timestamps))

    @refresh_access_token
    def read_logs_time_range(self, execution, start, end):
        """Yield ``(log_entry, timestamp)`` pairs for *execution* between *start* and *end*."""
        for e in self.stub.ReadEntries(data_port_pb2.ReadLogsSpec(prefix=self._org, key=execution, start=start, end=end)):
            yield e.log_entry, e.timestamp

    @refresh_access_token
    def set_meta_data(self, scenario, execution, offset, tags, aliases, epoch_type, timezone):
        """Store the standard execution metadata as JSON under the ``"meta"`` key."""
        data = dict(
            offset=offset, tags=tags, aliases=aliases, epoch_type=epoch_type, timezone=timezone
        )
        self.set_custom_meta_data(scenario, execution, key="meta", meta=json.dumps(data))

    @refresh_access_token
    def get_meta_data(self, scenario, execution):
        """Return the decoded ``"meta"`` entry written by ``set_meta_data``."""
        return self.get_custom_meta_data(scenario, execution, 'meta')

    @refresh_access_token
    def set_custom_meta_data(self, scenario, execution, key, meta):
        """Store *meta* under metadata key *key* for the scenario/execution row.

        *meta* is sent as-is; ``get_custom_meta_data`` decodes it with
        ``json.loads``, so it should be a JSON-encoded string.
        """
        self.stub.SetMetaData(
            data_port_pb2.MetaSpec(prefix=self._org, key=form_row_key(scenario, execution), meta_key=key,
                                   data=meta))

    @refresh_access_token
    def get_custom_meta_data(self, scenario, execution, key):
        """Fetch metadata key *key* for the scenario/execution row and JSON-decode it."""
        reply = self.stub.GetMetaData(
            data_port_pb2.ReadMetaSpec(prefix=self._org, key=form_row_key(scenario, execution), meta_key=key)
        )
        return json.loads(reply.data)

    @refresh_access_token
    def read(self, scenario, execution, tags, start, end, time_range=False):
        """Yield each ``ReadData`` reply for *tags* between *start* and *end*.

        *time_range* is forwarded to the service to select time-based rather
        than block-based ranging.
        """
        for dl in self.stub.ReadData(data_port_pb2.Spec(prefix=self._org, key=form_row_key(scenario, execution), tags=tags, start=start, end=end, time_range=time_range)):
            yield dl

    @refresh_access_token
    def read_time_range(self, scenario, execution, tags, start, end):
        """``read`` with ``time_range=True``."""
        for r in self.read(scenario, execution, tags, start, end, time_range=True):
            yield r

    @refresh_access_token
    def read_block_range(self, scenario, execution, tags, start, end):
        """``read`` with ``time_range=False``."""
        for r in self.read(scenario, execution, tags, start, end, time_range=False):
            yield r

    @refresh_access_token
    def read_data_stats(self, scenario: str, execution: str, tags: List[str]):
        """Return the service's data statistics for *tags* as a plain dict."""
        request = data_port_pb2.ReadSpecStats(
            prefix=self._org,
            key=form_row_key(scenario, execution),
            tags=tags
        )
        reply = self.stub.ReadDataStats(request)
        return attr_to_dict(reply, ['min', 'max', 'equi_space', 'spacing', 'n_blocks', 'equi_block_len', 'block_len0',
                                    'block_len_last', 'total_val_len'])

    @refresh_access_token
    def push_data_version_dict(self, scenario, execution, data, reset_block_counter=False):
        """Push one ``DataBlock`` per ``tag -> values`` item in *data*.

        Returns:
            The block counter reported by the service.
        """
        block_counter_ = self.stub.PushDataList(
            data_port_pb2.DataList(
                prefix=self._org,
                key=form_row_key(scenario, execution),
                data=[data_port_pb2.DataBlock(tag=t, values=v) for t, v in data.items()],
                reset_block_counter=reset_block_counter,
            )
        )
        return block_counter_.block_counter

    @refresh_access_token
    def get_block_counter(self, scenario, execution):
        """Return the service's block counter for the scenario/execution row."""
        block_counter_ = self.stub.GetBlockCounter(
            data_port_pb2.KeySpec(
                prefix=self._org, key=form_row_key(scenario, execution),
            )
        )
        return block_counter_.block_counter

    @refresh_access_token
    def clear(self, scenario, execution):
        """Issue ``ClearData`` for the scenario/execution row."""
        self.stub.ClearData(data_port_pb2.KeySpec(prefix=self._org, key=form_row_key(scenario, execution)))

    @refresh_access_token
    def submit_delete_data(self, scenarios, executions):
        """Submit data deletion for each ``(scenario, execution)`` pair zipped from the inputs."""
        self.stub.SubmitDeleteData(data_port_pb2.DeleteDataSpec(prefix=self._org, key=[form_row_key(scenario, execution) for scenario, execution in zip(scenarios, executions)]))

    @refresh_access_token
    def submit_delete_logs(self, scenarios, executions):
        """Submit log deletion for each ``(scenario, execution)`` pair zipped from the inputs."""
        self.stub.SubmitDeleteLogs(data_port_pb2.DeleteDataSpec(prefix=self._org, key=[form_row_key(scenario, execution) for scenario, execution in zip(scenarios, executions)]))

    @refresh_access_token
    def submit_delete_data_and_logs(self, scenarios, executions):
        """Submit log deletion, then data deletion, for the given pairs."""
        self.submit_delete_logs(scenarios, executions)
        self.submit_delete_data(scenarios, executions)

    @refresh_access_token
    def delete_columns(self, scenario, execution, columns):
        """Issue ``ClearDataTags`` for the given *columns* (tags) of the row."""
        self.stub.ClearDataTags(data_port_pb2.KeySpec(prefix=self._org, key=form_row_key(scenario, execution), tags=columns))

    @refresh_access_token
    def set_data_stream_status(self, scenario, execution, status):
        """Store *status* (JSON-encoded) under the ``"status"`` metadata key."""
        self.set_custom_meta_data(scenario, execution, 'status', json.dumps(status))

    @refresh_access_token
    def get_data_stream_status(self, scenario, execution):
        """Return the decoded ``"status"`` metadata entry."""
        return self.get_custom_meta_data(scenario, execution, 'status')

    @refresh_access_token
    def open_write_data_stream(self, scenario, execution):
        """Open a ``WriteDataStream`` request/response stream for the row.

        The returned object's ``write(index, data, overwrite, update_stats)``
        sends one request and returns the response's ``index``.
        """
        class _WriteDataStream(RequestResponseStream):
            def __init__(stream, client: DataPortClient):
                super(_WriteDataStream, stream).__init__(
                    client.stub,
                    client.stub.WriteDataStream,
                )

            def write(
                    stream, index: List[float], data: Dict[str, List[float]], overwrite: bool = False,
                    update_stats: bool = False
            ):
                response = stream.send(
                    data_port_pb2.WriteDataStreamRequest(
                        prefix=self._org,
                        key=form_row_key(scenario, execution),
                        overwrite=overwrite,
                        index=index,
                        data={tag: data_port_pb2.StreamData(values=values) for tag, values in data.items()},
                        update_stats=update_stats
                    )
                )
                return response.index

        return _WriteDataStream(self)

    @refresh_access_token
    def open_read_data_stream(
            self,
            scenario: Optional[str] = None,
            execution: Optional[str] = None,
    ):
        """Open a ``ReadDataStream`` request/response stream for the row.

        The returned object's ``read(...)`` polls until a non-empty ``index``
        arrives or ``timeout_wait`` seconds have elapsed. It returns
        ``(index, {tag: values})``. If the stream status (see
        ``get_data_stream_status``) is in ``timeout_break_statuses`` while
        waiting, it raises ``TimeoutBreakStatus``.
        """
        class _ReadDataStream(RequestResponseStream):
            def __init__(stream, client: DataPortClient):
                super(_ReadDataStream, stream).__init__(
                    client.stub,
                    client.stub.ReadDataStream,
                )

            def read(
                    stream,
                    tags: List[str],
                    start: Optional[float] = None,
                    end: Optional[float] = None,
                    length: Optional[int] = None,
                    timeout_wait: float = 0.0,
                    timeout_sleep_for: float = DATA_STREAM_DEFAULT_TIMEOUT_SLEEP_FOR,
                    timeout_break_statuses: Optional[List[Any]] = None,
            ):
                deadline_time = time() + timeout_wait
                while True:
                    response = stream.send(
                        data_port_pb2.ReadDataStreamRequest(
                            prefix=self._org,
                            key=form_row_key(scenario, execution),
                            tags=tags,
                            start=start,
                            end=end,
                            length=length,
                        )
                    )

                    if response.index or time() >= deadline_time:
                        return response.index, {tag: stream_data.values for tag, stream_data in response.data.items()}

                    if timeout_break_statuses:
                        status = self.get_data_stream_status(scenario, execution)
                        if status in timeout_break_statuses:
                            raise TimeoutBreakStatus(status)

                    # Do not oversleep the deadline: once it passes, the next
                    # request is the final attempt.
                    remaining = deadline_time - time()
                    sleep(max(0.0, min(timeout_sleep_for, remaining)))

        return _ReadDataStream(self)

    @dataclass
    class DataStreamStats:
        """Per-tag statistics returned by ``get_data_stream_stats``."""
        min: Optional[float] = None
        max: Optional[float] = None
        len: Optional[int] = None

    @refresh_access_token
    def get_data_stream_stats(
            self, scenario: str, execution: str, tags: Optional[List[str]] = None
    ) -> Dict[str, DataStreamStats]:
        """Return ``{tag: DataStreamStats}`` for the row's data stream.

        *tags* is forwarded to the service as the tag filter.
        """
        response = self.stub.GetDataStreamStats(
            data_port_pb2.GetDataStreamStatsRequest(
                prefix=self._org,
                key=form_row_key(scenario, execution),
                tags=tags,
            )
        )
        stats = {
            tag: self.DataStreamStats(
                min=tag_stats.min,
                max=tag_stats.max,
                len=tag_stats.len,
            ) for tag, tag_stats in response.stats.items()
        }
        return stats


def prefix_path(channel, prefix="/test"):
    """Patch *channel* in place so every RPC method path is prefixed with *prefix*.

    The channel's ``unary_unary``, ``unary_stream``, ``stream_stream`` and
    ``stream_unary`` factories are replaced by wrappers. Each wrapper prepends
    *prefix* to the method path, logs the result, and delegates to the
    original factory.

    Args:
        channel: object exposing the four gRPC multi-callable factory methods.
        prefix: string prepended to every method path (defaults to ``"/test"``).

    Returns:
        The same *channel*, for convenience.
    """

    def decorate_channel_method(f):
        def wrap_prefix(method_path, *args, **kwargs):
            # Forward every other argument untouched. Newer grpc-generated stubs
            # pass extra keywords such as ``_registered_method``, and a fixed
            # signature would reject them with TypeError.
            prefixed_path = prefix + method_path
            log.info('%s', prefixed_path)
            return f(prefixed_path, *args, **kwargs)
        return wrap_prefix

    for factory_name in ("unary_unary", "unary_stream", "stream_stream", "stream_unary"):
        setattr(channel, factory_name, decorate_channel_method(getattr(channel, factory_name)))
    return channel
