import logging
from datetime import timedelta
from typing import Iterable, List, Optional, Tuple

from aiohappybase.sync import Connection

from config import BIGTABLE_DATA_MAX_VERSIONS, BIGTABLE_DATA_TABLE, BIGTABLE_LOG_MAX_AGE_DAYS, BIGTABLE_LOG_TABLE, \
    BIGTABLE_MOCK_HBASE_HOST, BIGTABLE_MOCK_HBASE_PORT

# Module logger. Its level is pinned to DEBUG so that DEBUG records emitted on
# this logger are not filtered out at the logger level (handlers may still
# apply their own level).
log = logging.getLogger(name='numerous_bigtable.bigtable_mock')
log.setLevel(level=logging.DEBUG)


class HBaseDriver:
    """Storage driver that keeps data, metadata and logs in HBase via Thrift.

    Judging by the logger name (``bigtable_mock``) and the ``BIGTABLE_MOCK_*``
    connection settings, this presumably stands in for a real Bigtable driver
    in local/test setups -- confirm against callers.

    Two tables are used:

    * the data table: value cells in ``tags:vals`` and one metadata cell in
      ``meta:meta`` per row (both families are created with
      ``max_versions=BIGTABLE_DATA_MAX_VERSIONS``);
    * the log table: log cells in ``logs:log`` (family created with a TTL of
      ``BIGTABLE_LOG_MAX_AGE_DAYS`` days).

    Tables that do not exist yet are created when the driver is constructed.
    """

    # Table names as bytes, so they can be compared with Connection.tables().
    DATA_TABLE = BIGTABLE_DATA_TABLE.encode()
    LOG_TABLE = BIGTABLE_LOG_TABLE.encode()

    # Data table: value cells live in the ``tags`` family ...
    COL_FAM_ID = 'tags'
    VALUES_COL = 'vals'

    # ... and per-row metadata lives in the ``meta`` family.
    COL_META_FAM_ID = 'meta'
    META_COL = 'meta'

    # Log table: log cells live in the ``logs`` family.
    LOG_COL_FAM_ID = 'logs'
    LOG_COL = 'log'

    # Fully qualified ``family:qualifier`` column names, as bytes.
    VALUES_COLUMN = f'{COL_FAM_ID}:{VALUES_COL}'.encode()
    META_COLUMN = f"{COL_META_FAM_ID}:{META_COL}".encode()
    LOG_COLUMN = f"{LOG_COL_FAM_ID}:{LOG_COL}".encode()

    def __init__(self):
        """Connect to HBase and open both tables, creating any that are missing.

        Raises:
            Whatever the connection or table-management calls raise. If the
            failure happens after the connection was opened, the connection
            is closed before the exception propagates.
        """
        self._hbase = Connection(BIGTABLE_MOCK_HBASE_HOST, BIGTABLE_MOCK_HBASE_PORT)

        try:
            tables = self._hbase.tables()

            # The table handles are always obtained via table() instead of
            # relying on create_table()'s return value (plain happybase's
            # create_table returns None).
            if self.DATA_TABLE not in tables:
                self._hbase.create_table(
                    self.DATA_TABLE,
                    {
                        self.COL_FAM_ID: dict(max_versions=BIGTABLE_DATA_MAX_VERSIONS),
                        self.COL_META_FAM_ID: dict(max_versions=BIGTABLE_DATA_MAX_VERSIONS),
                    }
                )
            self._data_table = self._hbase.table(self.DATA_TABLE)

            if self.LOG_TABLE not in tables:
                # HBase TTLs are an integer number of seconds; total_seconds()
                # returns a float, so convert it explicitly.
                logs_ttl = int(timedelta(days=BIGTABLE_LOG_MAX_AGE_DAYS).total_seconds())
                self._hbase.create_table(
                    self.LOG_TABLE,
                    {
                        # NOTE(review): no max_versions is set, so the server-side
                        # default applies (3 in the HBase Thrift IDL, if memory
                        # serves). get_logs() reads multiple versions per row --
                        # confirm the default retains enough of them.
                        self.LOG_COL_FAM_ID: dict(time_to_live=logs_ttl),
                    }
                )
            self._log_table = self._hbase.table(self.LOG_TABLE)

        except Exception:
            self._hbase.close()
            raise

    def __del__(self):
        """Close the HBase connection when the driver is garbage-collected."""
        # If Connection(...) raised in __init__, _hbase was never assigned;
        # guard so finalization does not raise a spurious AttributeError.
        hbase = getattr(self, '_hbase', None)
        if hbase is not None:
            hbase.close()

    @staticmethod
    def _delete_rows(table, row_prefix: bytes) -> None:
        """Delete, in a single batch, every row of ``table`` whose key starts with ``row_prefix``."""
        with table.batch() as batch:
            for row_id, _ in table.scan(row_prefix=row_prefix):
                batch.delete(row_id)

    def delete_data(self, row_prefix: bytes) -> None:
        """Delete every data-table row whose key starts with ``row_prefix``."""
        self._delete_rows(self._data_table, row_prefix)

    def delete_logs(self, row_prefix: bytes) -> None:
        """Delete every log-table row whose key starts with ``row_prefix``."""
        self._delete_rows(self._log_table, row_prefix)

    def get_data(self, row_prefixes: List[bytes]) -> Iterable[Tuple[bytes, List[bytes]]]:
        """Yield the value cells of every data row matching any of the prefixes.

        Args:
            row_prefixes: Row-key prefixes; a row matches if its key starts
                with at least one of them. An empty list matches nothing.

        Yields:
            ``(row_key, cells)`` pairs, where ``cells`` is what ``cells()``
            returns for the ``tags:vals`` column of that row.
        """
        if not row_prefixes:
            # An empty filter string is not a valid HBase filter expression,
            # and with no prefixes nothing can match anyway.
            return

        # Build the filter directly in bytes so prefixes need not be valid
        # UTF-8. The HBase filter language escapes a single quote inside a
        # quoted argument by doubling it.
        row_filter = b' OR '.join(
            b"PrefixFilter ('" + prefix.replace(b"'", b"''") + b"')"
            for prefix in row_prefixes
        )
        for row_id, _ in self._data_table.scan(filter=row_filter, columns=[self.VALUES_COLUMN]):
            # The scan is used only to find matching rows; cells() is called
            # per row, presumably to read more than the latest version.
            yield row_id, self._data_table.cells(row_id, self.VALUES_COLUMN)

    def begin_put_data_batch(self):
        """Return a new batch on the data table for use with put_data()."""
        return self._data_table.batch()

    def put_data(self, batch, row_key: bytes, data: bytes, timestamp: int) -> None:
        """Queue ``data`` into ``tags:vals`` of ``row_key`` at ``timestamp``.

        ``batch`` is expected to come from begin_put_data_batch().
        """
        batch.put(row_key, {self.VALUES_COLUMN: data}, timestamp=timestamp)

    def end_put_data_batch(self, batch):
        """Close a batch obtained from begin_put_data_batch()."""
        batch.close()

    def put_meta(self, row_key: bytes, data: bytes) -> None:
        """Write ``data`` into the ``meta:meta`` cell of ``row_key`` immediately."""
        self._data_table.put(row_key, {self.META_COLUMN: data})

    def get_meta(self, row_key: bytes) -> Optional[bytes]:
        """Return the ``meta:meta`` cell of ``row_key``, or None if it is absent."""
        row = self._data_table.row(row_key, columns=[self.META_COLUMN])
        return row.get(self.META_COLUMN)

    def begin_put_logs_batch(self):
        """Return a new batch on the log table for use with put_logs()."""
        return self._log_table.batch()

    def put_logs(self, batch, row_key: bytes, data: bytes, timestamp: int) -> None:
        """Queue ``data`` into ``logs:log`` of ``row_key`` at ``timestamp``.

        ``batch`` is expected to come from begin_put_logs_batch().
        """
        # Pass the timestamp by keyword, as put_data() does. A positional
        # third argument could bind to a different parameter (e.g. ``wal``
        # in happybase-style Batch.put signatures).
        batch.put(row_key, {self.LOG_COLUMN: data}, timestamp=timestamp)

    def end_put_logs_batch(self, batch):
        """Close a batch obtained from begin_put_logs_batch()."""
        batch.close()

    def get_logs(self, row_prefix: bytes, start: int, end: int) -> Iterable[Tuple[bytes, int]]:
        """Yield log cells of rows under ``row_prefix`` with ``start <= timestamp < end``.

        Yields:
            ``(value, timestamp)`` pairs, in the order returned by the scan
            and by ``cells()`` for each row.
        """
        for row_id, _ in self._log_table.scan(row_prefix=row_prefix, columns=[self.LOG_COLUMN]):
            for value, timestamp in self._log_table.cells(row_id, self.LOG_COLUMN, include_timestamp=True):
                # Half-open interval: start inclusive, end exclusive.
                if start <= timestamp < end:
                    yield value, timestamp