import logging
from datetime import datetime, timedelta
from time import sleep
from typing import Iterable, List, Optional, Tuple

from pytz import utc as UTC
from google.api_core.exceptions import ResourceExhausted
from google.auth.transport.requests import Request
from google.cloud.bigtable import column_family, row_filters, row_set
from google.cloud.bigtable.client import Client
from google.oauth2 import service_account

from config import BIGTABLE_DATA_MAX_VERSIONS, BIGTABLE_DATA_TABLE, BIGTABLE_INSTANCE, BIGTABLE_LOG_MAX_AGE_DAYS, \
    BIGTABLE_LOG_TABLE, \
    BIGTABLE_SERVICE_ACCOUNT

# Module-level logger used by the Bigtable driver below; its level is set to
# DEBUG here, independent of the root logger's level.
log = logging.getLogger('numerous_bigtable.bigtable')
log.setLevel(logging.DEBUG)


def _get_credentials():
    """Build service-account credentials for Bigtable and refresh them once.

    The credentials are created from the ``BIGTABLE_SERVICE_ACCOUNT`` info
    with the OAuth scopes listed below, refreshed through an HTTP
    ``Request`` transport, and then returned to the caller.
    """
    oauth_scopes = [
        'https://www.googleapis.com/auth/cloud-platform',
        'https://www.googleapis.com/auth/userinfo.email'
    ]
    creds = service_account.Credentials.from_service_account_info(
        BIGTABLE_SERVICE_ACCOUNT,
        scopes=oauth_scopes,
    )

    # Refresh eagerly, before the credentials are handed to the client.
    creds.refresh(Request())
    return creds


class BigTableDriver:
    """Driver for the Bigtable data table and log table.

    The data table holds value cells (family ``tags``) and metadata cells
    (family ``meta``); the log table holds log cells (family ``logs``).
    Both tables are created on construction if they do not exist yet.

    Writes to the data and log tables are batched: obtain a batch with
    ``begin_put_*_batch()``, add rows with ``put_data``/``put_logs`` and
    send them with ``end_put_*_batch()``.
    """

    DATA_TABLE = BIGTABLE_DATA_TABLE.encode()
    LOG_TABLE = BIGTABLE_LOG_TABLE.encode()

    COL_FAM_ID = 'tags'
    VALUES_COL = 'vals'

    COL_META_FAM_ID = 'meta'
    META_COL = 'meta'

    LOG_COL_FAM_ID = 'logs'
    LOG_COL = 'log'

    # b'family:qualifier' keys used to look cells up in ``row.to_dict()``.
    VALUES_COLUMN = f'{COL_FAM_ID}:{VALUES_COL}'.encode()
    META_COLUMN = f"{COL_META_FAM_ID}:{META_COL}".encode()
    LOG_COLUMN = f"{LOG_COL_FAM_ID}:{LOG_COL}".encode()

    # Second positional argument (timeout) passed to ``drop_by_prefix``.
    _DROP_TIMEOUT = 5000
    # Seconds to wait before retrying a drop that hit ResourceExhausted.
    _DROP_RETRY_DELAY = 1

    def __init__(self):
        """Connect to the configured instance and create missing tables."""
        self._client = Client(admin=True, credentials=_get_credentials())
        self._instance = self._client.instance(BIGTABLE_INSTANCE)

        self._table = self._instance.table(BIGTABLE_DATA_TABLE)
        if not self._table.exists():
            # Both data families keep at most BIGTABLE_DATA_MAX_VERSIONS cells.
            max_versions_rule = column_family.MaxVersionsGCRule(BIGTABLE_DATA_MAX_VERSIONS)
            self._table.create(column_families={
                self.COL_FAM_ID: max_versions_rule,
                self.COL_META_FAM_ID: max_versions_rule,
            })

        self._table_logs = self._instance.table(BIGTABLE_LOG_TABLE)
        if not self._table_logs.exists():
            # Log cells are garbage-collected by age.
            max_age_rule = column_family.MaxAgeGCRule(timedelta(days=BIGTABLE_LOG_MAX_AGE_DAYS))
            self._table_logs.create(column_families={self.LOG_COL_FAM_ID: max_age_rule})

    @staticmethod
    def _utc_datetime(timestamp) -> datetime:
        """Convert a POSIX timestamp (seconds) to a timezone-aware UTC datetime."""
        # Equivalent to utcfromtimestamp(...).replace(tzinfo=UTC), without the
        # naive-datetime API that is deprecated since Python 3.12.
        return datetime.fromtimestamp(timestamp, tz=UTC)

    @classmethod
    def _drop_by_prefix_with_retry(cls, table, row_prefix) -> None:
        """Drop all rows of *table* starting with *row_prefix*.

        Retries indefinitely, sleeping between attempts, for as long as the
        service answers with ``ResourceExhausted``. Other errors propagate.
        """
        while True:
            try:
                table.drop_by_prefix(row_prefix, cls._DROP_TIMEOUT)
            except ResourceExhausted:
                log.debug('Delete exhausted, trying again...')
                sleep(cls._DROP_RETRY_DELAY)
            else:
                return

    @staticmethod
    def _log_failed_mutations(statuses, table_name) -> None:
        """Log every non-OK status returned by ``Table.mutate_rows``."""
        for index, status in enumerate(statuses or ()):
            # Code 0 is OK; anything else means that row was not written.
            if status.code != 0:
                log.error('Mutation %d on table %s failed with code %s: %s',
                          index, table_name, status.code, status.message)

    def delete_data(self, row_prefix) -> None:
        """Delete every data-table row whose key starts with *row_prefix*."""
        self._drop_by_prefix_with_retry(self._table, row_prefix)

    def delete_logs(self, row_prefix) -> None:
        """Delete every log-table row whose key starts with *row_prefix*."""
        self._drop_by_prefix_with_retry(self._table_logs, row_prefix)

    def get_data(self, row_prefixes: List[bytes]) -> Iterable[Tuple[bytes, List[bytes]]]:
        """Yield ``(row_key, values)`` for data rows matching any prefix.

        Args:
            row_prefixes: Row-key prefixes (bytes) to read.

        Yields:
            The row key and the values of all cells in the values column.
            Rows that have no values column (for example rows that only
            carry metadata) yield an empty list.
        """
        if not row_prefixes:
            # An empty RowSet would not restrict the read at all, so bail out
            # instead of scanning the whole table.
            return

        data_row_set = row_set.RowSet()
        for row_prefix in row_prefixes:
            data_row_set.add_row_range_with_prefix(row_prefix.decode())

        for row in self._table.read_rows(row_set=data_row_set):
            # .get() returns None when the row lacks the values column.
            cells = row.to_dict().get(self.VALUES_COLUMN) or ()
            yield row.row_key, [cell.value for cell in cells]

    def begin_put_data_batch(self):
        """Return a new, empty batch to be filled by ``put_data``."""
        return []

    def put_data(self, batch, row_key: bytes, data: bytes, timestamp: int) -> None:
        """Queue a value cell for *row_key* in *batch*.

        Args:
            batch: List obtained from ``begin_put_data_batch``.
            row_key: Key of the row to write.
            data: Cell value.
            timestamp: Cell timestamp as POSIX seconds.
        """
        dr = self._table.direct_row(row_key)
        dr.set_cell(self.COL_FAM_ID, self.VALUES_COL.encode(), data,
                    timestamp=self._utc_datetime(timestamp))
        batch.append(dr)

    def end_put_data_batch(self, batch):
        """Send all rows queued in *batch* to the data table.

        Per-row failures reported by the service are logged, not raised.
        """
        if not batch:
            return
        statuses = self._table.mutate_rows(batch)
        self._log_failed_mutations(statuses, BIGTABLE_DATA_TABLE)

    def put_meta(self, row_key: bytes, data: bytes) -> None:
        """Write *data* to the metadata column of *row_key* immediately."""
        dr = self._table.direct_row(row_key)
        dr.set_cell(self.COL_META_FAM_ID, self.META_COL.encode(), data)
        dr.commit()

    def get_meta(self, row_key: bytes) -> Optional[bytes]:
        """Return the metadata stored for *row_key*.

        Returns:
            The value of the first metadata cell returned for the row, or
            ``None`` if the row or its metadata column does not exist.
        """
        row = self._table.read_row(row_key)
        if row is None:
            return None

        meta_cells = row.to_dict().get(self.META_COLUMN)
        if not meta_cells:
            return None

        return meta_cells[0].value

    def begin_put_logs_batch(self):
        """Return a new, empty batch to be filled by ``put_logs``."""
        return []

    def put_logs(self, batch, row_key: bytes, data: bytes, timestamp: int) -> None:
        """Queue a log cell for *row_key* in *batch*.

        Args:
            batch: List obtained from ``begin_put_logs_batch``.
            row_key: Key of the row to write.
            data: Cell value.
            timestamp: Cell timestamp as POSIX seconds.
        """
        dr = self._table_logs.direct_row(row_key)
        # Encode the qualifier like put_data/put_meta do, for consistency.
        dr.set_cell(self.LOG_COL_FAM_ID, self.LOG_COL.encode(), data,
                    timestamp=self._utc_datetime(timestamp))
        batch.append(dr)

    def end_put_logs_batch(self, batch):
        """Send all rows queued in *batch* to the log table.

        Per-row failures reported by the service are logged, not raised.
        """
        if not batch:
            return
        statuses = self._table_logs.mutate_rows(batch)
        self._log_failed_mutations(statuses, BIGTABLE_LOG_TABLE)

    def get_logs(self, row_prefix: bytes, start: int, end: int) -> Iterable[Tuple[bytes, float]]:
        """Yield ``(value, timestamp)`` for log cells within a time range.

        Args:
            row_prefix: Row-key prefix (bytes) of the log rows to read.
            start: Start of the timestamp range, as POSIX seconds.
            end: End of the timestamp range, as POSIX seconds.

        Yields:
            The cell value and the cell timestamp as POSIX seconds (float,
            as returned by ``datetime.timestamp()``).
        """
        data_row_set = row_set.RowSet()
        data_row_set.add_row_range_with_prefix(row_prefix.decode())

        time_range = row_filters.TimestampRange(start=self._utc_datetime(start),
                                                end=self._utc_datetime(end))
        rows = self._table_logs.read_rows(
            row_set=data_row_set,
            filter_=row_filters.TimestampRangeFilter(time_range),
        )
        for row in rows:
            # .get() returns None when the row lacks the log column.
            for cell in row.to_dict().get(self.LOG_COLUMN) or ():
                yield cell.value, cell.timestamp.timestamp()
