import logging
import struct
from datetime import datetime
from typing import Iterator, List, Optional, Union

import numpy as np

from data_drivers import BigTableDriver, HBaseDriver

# Module-level logger used by DataProcessor; it is forced to DEBUG level on import.
# NOTE(review): the logger name mentions 'bigtable_mock', which may not match this
# module, and setting DEBUG here overrides the application's logging config —
# confirm both are intended.
log = logging.getLogger(name='numerous_bigtable.bigtable_mock')
log.setLevel(level=logging.DEBUG)


class DataProcessor:
    """Store and read column data, logs and metadata through a storage driver.

    Column values are written as big-endian float64 (``'>d'``) blocks, one row per
    block, keyed ``prefix_key#key#column#<block number zero-padded to Z_FILL>``.
    The ``'_index'`` column holds the timestamps the other columns are aligned with.
    Metadata rows use ``prefix_key#key#meta#<meta_key>`` and
    ``prefix_key#key#block_counter``; log rows use ``prefix_key#key``.
    """

    # Digits used for block numbers in row keys. The fixed width stops the prefix of
    # one block from matching another (block 1 -> '0000000001' vs block 10).
    Z_FILL = 10

    # Highest block number scanned by read_block_range when no end block is given.
    _MAX_BLOCK = 9999

    # Open-ended upper bound for log queries: 9999-12-31 00:00 UTC as a POSIX
    # timestamp. It is computed by subtraction because a naive datetime's
    # .timestamp() goes through the platform's local-time conversion. That
    # conversion may not handle year 9999 everywhere and depends on the local timezone.
    _MAX_LOG_TIMESTAMP = (datetime(9999, 12, 31) - datetime(1970, 1, 1)).total_seconds()

    @staticmethod
    def _form_row_key(*args: str) -> bytes:
        """Join the non-empty ``args`` with ``'#'`` and return the UTF-8 encoded key."""
        return '#'.join(arg for arg in args if arg).encode()

    @staticmethod
    def _get_block(stats: dict, t: float) -> int:
        """Return the index of the last block whose minimum is ``<= t`` (clamped to 0).

        ``stats['t_block_interval']`` holds one ``[min, max]`` pair per block, as built
        by :meth:`read_data_stats`. Times before the first block map to block 0. Times
        in a gap between two blocks map to the earlier one. Times past the end map to
        the last block. Returns 0 when there are no intervals. Assumes the intervals
        are sorted and non-overlapping.
        """
        bounds = np.asarray(stats['t_block_interval'], dtype=float).ravel()
        if bounds.size == 0:
            return 0
        if t >= bounds.max():
            # Beyond every bound: behave as if the match were one past the end.
            i = bounds.size
        else:
            # First flattened bound strictly greater than t. Index 2k is block k's
            # min, and index 2k+1 is block k's max.
            i = int(np.argmax(t < bounds))
        return max((i - 1) // 2, 0)

    @staticmethod
    def _unpack_doubles(bindata: bytes) -> list:
        """Decode ``bindata`` as consecutive big-endian float64 values.

        Trailing bytes that do not fill a whole 8-byte value are ignored.
        """
        count = len(bindata) // 8
        return list(struct.unpack(f'>{count}d', bindata[:8 * count]))

    def __init__(self, driver):
        """Wrap ``driver``, the storage backend every read and write is delegated to."""
        self._driver: Union[BigTableDriver, HBaseDriver] = driver

    def clear(self, prefix_key, key):
        """Ask the driver to delete data rows with row-key prefix ``prefix_key#key``."""
        # NOTE(review): the prefix has no trailing '#'. If the driver does a plain
        # prefix match, this may also hit keys that merely start with `key`
        # (e.g. 'k' vs 'k2') — confirm against the driver.
        self._driver.delete_data(row_prefix=self._form_row_key(prefix_key, key))

    def delete_columns(self, prefix_key, key, columns):
        """Delete the data rows of every column in ``columns`` under ``prefix_key#key``.

        Raises:
            ValueError: if any column name is empty. All names are checked before
                anything is deleted. An empty name would be dropped from the row key,
                and the delete would then target every row under ``prefix_key#key``.
        """
        columns = list(columns)
        if not all(columns):
            raise ValueError('column cannot be empty string')
        for column in columns:
            self._driver.delete_data(row_prefix=self._form_row_key(prefix_key, key, column))

    def clear_logs(self, prefix_key, key):
        """Ask the driver to delete the log rows with row-key prefix ``prefix_key#key``."""
        self._driver.delete_logs(row_prefix=self._form_row_key(prefix_key, key))

    def read_time_range(self, prefix_key, key, tags: List[str], start: int, end: int) -> Iterator[tuple]:
        """Yield data chunks of ``tags`` for the blocks overlapping ``[start, end]``.

        Block boundaries come from the ``'_index'`` statistics (:meth:`read_data_stats`),
        and the chunks come from :meth:`read_block_range`. Whole blocks are returned, so
        the values may extend outside the requested time range. Yields nothing when no
        index data is stored.
        """
        stats = self.read_data_stats(prefix_key, key)
        log.debug(f"stats: {stats}")

        if stats['min'] is None:
            return

        start_block = self._get_block(stats, start)
        end_block = max(start_block, self._get_block(stats, end)) + 1
        log.debug(f'start_block: {start_block}')
        log.debug(f'end_block: {end_block}')

        # NOTE(review): read_block_range treats `end` as inclusive and maps end == 0 to
        # "all blocks". The +1 above avoids that mapping but also reads one extra block.
        for chunk in self.read_block_range(prefix_key, key, tags, start=start_block, end=end_block):
            yield chunk

    def read_block_range(self, prefix_key, key, tags: List[str], start: Optional[int] = None,
                         end: Optional[int] = None, *, chunk_size: int = 1_000_000):
        """Yield the decoded values of ``tags`` for blocks ``start`` through ``end`` (inclusive).

        Args:
            tags: Column names to read. Nothing is yielded when empty.
            start: First block number (default 0).
            end: Last block number, inclusive. ``None`` or ``0`` means "up to block
                ``_MAX_BLOCK``". An ``end`` smaller than ``start`` reads only ``start``.
            chunk_size: Once more than this many values have been collected, they are
                yielded as a partial chunk before the next value is read.

        Yields:
            ``(rows, flag, flag)`` tuples, where ``rows`` is a non-empty list of
            ``{'tag': column, 'values': [float, ...]}`` dicts. Partial chunks carry
            ``(False, False)``. The final chunk of each block carries ``(True, True)``.
            Iteration stops at the first block number with no stored rows.
        """
        if not tags:
            return

        start = 0 if start is None else int(start)
        end = self._MAX_BLOCK if not end else int(end)
        stop = max(end, start) + 1

        # Position of the tag among the '#'-separated fields of a row key. The key
        # head may be empty when both prefix_key and key are empty.
        head = self._form_row_key(prefix_key, key)
        tag_field = head.count(b'#') + 1 if head else 0

        for block in range(start, stop):
            block_id = str(block).zfill(self.Z_FILL)
            row_prefixes = [self._form_row_key(prefix_key, key, tag, block_id) for tag in tags]

            chunk = []
            size = 0
            for row_id, values in self._driver.get_data(row_prefixes=row_prefixes):
                tag = row_id.decode().split('#')[tag_field]
                for bindata in values:
                    # Flush before adding more, never after. That way the last chunk of
                    # a block is never empty, and an empty chunk reliably means "no data".
                    if size > chunk_size:
                        yield chunk, False, False
                        chunk, size = [], 0
                    vals = self._unpack_doubles(bindata)
                    chunk.append(dict(tag=tag, values=vals))
                    size += len(vals)

            if not chunk:
                # Nothing stored for this block number: treat it as the end of the data.
                break
            yield chunk, True, True

    def _put_block(self, batch, prefix_key, key, column, block_number, packed_values, timestamp):
        """Add one block row ``prefix_key#key#column#<padded block_number>`` to ``batch``."""
        row_key = self._form_row_key(prefix_key, key, column, str(block_number).zfill(self.Z_FILL))
        self._driver.put_data(batch, row_key, b''.join(packed_values), timestamp)

    def push_data_version_dict(self, prefix_key, key, block_counter: int, data: dict, block_length: int = 1e6):
        """Write every column of ``data`` as blocks of at most ``block_length`` float64 values.

        ``data`` maps column names to value sequences and must contain ``'_index'``
        (the per-row timestamps). Each column is zipped with ``'_index'``, so values
        past its length are dropped. Every column restarts numbering at
        ``block_counter`` (``None``/``0`` -> 0) and is split at the same row
        boundaries, so block ``n`` of each column covers the same rows. Each block is
        written with the timestamp of its first row. All writes go through one driver
        batch.

        Returns:
            The block number following the last block written for the last column.

        Raises:
            ValueError: if ``block_length`` is less than 1.
            KeyError: if ``data`` has no ``'_index'`` column.
        """
        if block_length < 1:
            raise ValueError('block_length must be at least 1')
        time_stamps = data['_index']
        next_block = 0
        batch = self._driver.begin_put_data_batch()

        for column, col_data in data.items():
            # Intentionally reset for every column, so that all columns use the same
            # block numbers for the same rows.
            next_block = block_counter if block_counter else 0
            block = []
            block_ts = None
            for t, value in zip(time_stamps, col_data):
                if len(block) >= block_length:
                    self._put_block(batch, prefix_key, key, column, next_block, block, block_ts)
                    next_block += 1
                    block = []
                if not block:
                    block_ts = t  # a block is stamped with the time of its first row
                block.append(struct.pack('>d', value))

            if block:
                self._put_block(batch, prefix_key, key, column, next_block, block, block_ts)
                next_block += 1

        self._driver.end_put_data_batch(batch)
        return next_block

    def get_block_counter(self, prefix_key, key):
        """Return the integer stored at ``prefix_key#key#block_counter``, or ``None`` if unset."""
        row_key = self._form_row_key(prefix_key, key, 'block_counter')
        meta = self._driver.get_meta(row_key)
        if not meta:
            return None

        block_counter_str = meta.decode()
        log.debug('block: %s', block_counter_str)
        return int(block_counter_str)

    def set_block_counter(self, prefix_key, key, block_counter):
        """Store ``block_counter`` (as its decimal string) at ``prefix_key#key#block_counter``."""
        row_key = self._form_row_key(prefix_key, key, 'block_counter')
        self._driver.put_meta(row_key, str(block_counter).encode())

    def read_data_stats(self, prefix_key, key, tag='_index'):
        """Summarise the values stored for ``tag`` (by default the time index).

        Statistics are gathered per non-empty chunk from :meth:`read_block_range`,
        using the first row of each chunk.

        Returns:
            A dict with these keys:

            - ``'min'``/``'max'``: extreme values, or ``None`` when nothing is stored.
            - ``'equi_space'``: whether every chunk with two or more values is evenly
              spaced, within a relative tolerance of 1e-5 of its first step
              (``None`` when nothing is stored).
            - ``'spacing'``: the first step of the last chunk checked for even spacing
              (``None`` if none was checked).
            - ``'n_blocks'``: the number of non-empty chunks.
            - ``'equi_block_len'``: true when at most one chunk differs in length from
              the first (``None`` when nothing is stored).
            - ``'block_len0'``/``'block_len_last'``: lengths of the first non-empty
              chunk and of the last chunk.
            - ``'total_val_len'``.
            - ``'offset'``: always 0.
            - ``'t_block_interval'``: ``[min, max]`` per non-empty chunk, in read order.
        """
        min_t = np.inf
        max_t = -np.inf  # not 0: the values may all be negative
        n_blocks = 0
        val_counter = 0
        mismatched_len = 0
        equi_space = True
        spacing = None
        block_len0 = None
        l_vals = 0
        t_block_interval = []

        for chunk, _, _ in self.read_block_range(prefix_key, key, tags=[tag], start=0, end=0):
            vals = chunk[0]['values']
            l_vals = len(vals)
            if l_vals == 0:
                continue

            if block_len0 is None:
                block_len0 = l_vals
            elif l_vals != block_len0 and mismatched_len < 2:
                mismatched_len += 1

            n_blocks += 1
            val_counter += l_vals
            chunk_min, chunk_max = min(vals), max(vals)
            t_block_interval.append([chunk_min, chunk_max])
            min_t = min(min_t, chunk_min)
            max_t = max(max_t, chunk_max)

            if equi_space and l_vals > 1:
                diffs = np.diff(vals)
                spacing = diffs[0]
                # Tolerance relative to the first step. Multiplying instead of dividing
                # avoids a zero division. abs() keeps the check correct for decreasing
                # values.
                equi_space = bool(np.all(np.abs(diffs - spacing) < 1e-5 * abs(spacing)))

        if n_blocks == 0:
            min_t = max_t = spacing = equi_len = equi_space = None
        else:
            equi_len = mismatched_len < 2

        return {
            'min': min_t,
            'max': max_t,
            'equi_space': equi_space,
            'spacing': spacing,
            'n_blocks': n_blocks,
            'equi_block_len': equi_len,  # allows one chunk of a different length
            'block_len0': block_len0,
            'block_len_last': l_vals,
            'total_val_len': val_counter,
            'offset': 0,
            't_block_interval': t_block_interval,
        }

    def push_log_entries(self, prefix_key, key, timestamps, log_entries):
        """Write ``log_entries`` to the log row ``prefix_key#key``, grouped by time.

        Entries are accumulated until one arrives whose timestamp is more than 1
        greater than the timestamp of the previously written group (initially 0).
        That entry closes the group. The group is newline-joined and written with that
        entry's timestamp. Leftover entries are written with the last entry's
        timestamp. Entries and timestamps are zipped, so extra items in either are
        ignored. All writes go through one driver batch.
        """
        # NOTE(review): the trailing group can reuse a timestamp that was already
        # written (e.g. duplicate timestamps). Whether that overwrites the earlier
        # cell depends on the driver — confirm.
        row_key = self._form_row_key(prefix_key, key)
        current_group = []
        ts_last = 0
        ts = None
        batch = self._driver.begin_put_logs_batch()

        for entry, ts in zip(log_entries, timestamps):
            current_group.append(entry)
            if ts_last + 1 < ts:
                ts_last = ts
                self._driver.put_logs(batch, row_key, '\n'.join(current_group).encode(), timestamp=ts)
                current_group = []

        if current_group:
            self._driver.put_logs(batch, row_key, '\n'.join(current_group).encode(), timestamp=ts)

        self._driver.end_put_logs_batch(batch)

    def read_logs_time_range(self, prefix_key, key, start: int, end: int):
        """Yield ``(value, timestamp)`` log records under ``prefix_key#key`` from ``start`` to ``end``.

        ``end <= 0`` means "no upper bound" (a year-9999 sentinel). Whether the bounds
        are inclusive is decided by the driver's ``get_logs``.
        """
        row_key_prefix = self._form_row_key(prefix_key, key)
        if end <= 0:
            end = self._MAX_LOG_TIMESTAMP

        for value, timestamp in self._driver.get_logs(row_key_prefix, start, end):
            yield value, timestamp

    def get_meta_data(self, prefix_key: str, key: str, meta_key: str = 'default') -> str:
        """Return the decoded metadata at ``prefix_key#key#meta#meta_key``, or ``'{}'`` if unset."""
        row_key = self._form_row_key(prefix_key, key, 'meta', meta_key)
        meta = self._driver.get_meta(row_key)
        return meta.decode() if meta else '{}'

    def set_meta_data(self, prefix_key, key, meta: str, meta_key='default') -> None:
        """Store the string ``meta`` (UTF-8 encoded) at ``prefix_key#key#meta#meta_key``."""
        row_key = self._form_row_key(prefix_key, key, 'meta', meta_key)
        self._driver.put_meta(row_key, meta.encode())
