import _queue
import threading
from datetime import datetime
from queue import Queue
from time import sleep, time

import numpy as np
from scipy.interpolate import interp1d

from numerous_api_client.client.common import Interp, log


class StaticSource:
    """Input source that serves one fixed row for every requested time.

    Only the time columns ``_index`` and ``_index_relative`` change between
    calls. ``t_end`` and ``dt`` are accepted for interface compatibility but
    are not used.
    """

    def __init__(self, row_data, t0=0, t_end=None, dt=3600, generate_index=True):
        self.source_type = 'static'
        self.subscribe = False
        self.generate_index = generate_index
        self.row = row_data
        self.t0 = t0
        # No time has been requested yet.
        self.t = np.inf

    def get_at_time(self, t):
        """Stamp the stored row with time ``t`` and return it.

        The row object is mutated in place; the same object is returned on
        every call. ``_index_relative`` is ``t`` measured from ``t0``.
        """
        self.t = t
        stamped = self.row
        stamped['_index'] = t
        stamped['_index_relative'] = t - self.t0
        return stamped


class DynamicSource:
    """Stream rows of a scenario execution as simulation input.

    A background daemon thread pulls rows from ``client.read_data_as_row``
    into a bounded queue (100 rows); ``__next__`` / ``get_at_time`` consume
    them. Each row's tag values are rescaled according to ``spec`` and its
    ``_index`` is shifted by ``spec['offset']`` from input base time into
    simulation time.

    If the dataset is still open the source subscribes to new data. If it is
    closed the source runs in "continuous" mode: once the data is exhausted
    it is re-queried and its ``_index`` shifted so the series is extended.
    """

    def __init__(self, client, scenario, execution, spec, t0=0, te=0, timeout=None, source_type=None,
                 interpolation_method=Interp.zero):
        """
        :param client: API client used to read metadata, stats and rows.
        :param scenario: scenario providing the data (also used as ``name``).
        :param execution: execution providing the data.
        :param spec: input spec. ``spec['tags']`` maps each output tag to a
            dict with ``'tag'`` (source tag), ``'scale'`` and ``'offset'``.
            The optional top-level ``spec['offset']`` (defaulted to 0 and
            written back into ``spec``) is the full offset from input base
            time to simulation time.
        :param t0: simulation start time (same time base as ``spec['offset']``).
        :param te: simulation end time; the query end is ``max(te - offset, 0)``.
        :param timeout: seconds ``__next__`` may wait for a row before raising
            ``TimeoutError``; ``None`` waits indefinitely.
        :param source_type: label stored on the instance as-is.
        :param interpolation_method: ``Interp`` member passed to the row buffer.
        :raises Exception: if the offset data starts after ``t0``.
        """
        self.timeout = timeout
        self.t0 = t0
        self.meta = client.get_timeseries_meta_data(scenario=scenario, execution=execution)
        self.closed = client.get_dataset_closed(scenario=scenario, execution=execution)
        self.name = scenario
        self.spec = spec
        # Default the offset in the spec itself so later reads of spec['offset'] work.
        self.spec.setdefault('offset', 0)

        # The offset set in spec is the full offset from 0 - not relative to meta.offset.
        offset = self.spec['offset']
        self.offset = offset

        # Query window expressed in input base time (simulation time minus offset).
        start = t0 - offset
        end = max(te - offset, 0)

        stats = client.get_timeseries_stats(None, scenario, execution)
        # If the dataset is open, listen for data as it is being generated.
        self.subscribe = not self.closed
        # If the dataset is closed, re-query once we run out of data ("continuous" mode).
        continuous = not self.subscribe

        if stats.min + offset > t0:
            raise Exception(f'dynamic input scenario {scenario} with tags {[t.name for t in self.meta.tags]} '
                            f'is offset to the future {stats.min + offset}>{t0}')

        self.start = start
        self.end = end

        print('scenario: ', scenario)
        print('scenario data is closed: ', self.closed)
        print('Offset in spec: ', datetime.fromtimestamp(spec['offset']))
        print('Offset in meta: ', datetime.fromtimestamp(self.meta.offset))
        print('meta off: ', datetime.fromtimestamp(self.meta.offset))
        # Reuse the stats fetched above instead of issuing a second remote call.
        print('Data stats: ', stats)
        print('Combined offset: ', offset)
        print('Combined offset: ', datetime.fromtimestamp(offset))
        print('t0: ', datetime.fromtimestamp(t0))
        print('te: ', datetime.fromtimestamp(te))
        print('start: ', datetime.fromtimestamp(start))
        print('end: ', datetime.fromtimestamp(end))
        print('start: ', start)
        print('end: ', end)

        # Distinct source tags to request from the backend.
        tags = {s['tag'] for s in spec['tags'].values()}

        self.generator = client.read_data_as_row(scenario=scenario, tags=tags,
                                                 execution=execution, start=start, end=end,
                                                 subscribe=self.subscribe)
        self.source_type = source_type
        self.row_queue = Queue(100)
        self.t = -1
        # Rows read from the data source; if nothing was read, re-querying is pointless.
        self._rows_read = 0
        self.row_buffer = RowBuffer(tags=list(spec['tags']), offset=offset, method=interpolation_method)
        # Last (presumably max) _index read; in continuous mode it offsets the re-query.
        self.max_ix_read = 0
        self.last_get = None

        def _next():
            """Producer loop: move rows from ``self.generator`` into ``row_queue``.

            Puts ``'STOP'`` on the queue when the source is empty or exhausted
            in non-continuous mode. In continuous mode it re-queries and
            shifts ``_index`` by ``max_ix_read``. It exits once ``max_ix_read``
            reaches ``self.end`` (when ``self.end > 0``).

            NOTE(review): exiting because ``end`` was reached queues no
            ``'STOP'``, so a consumer asking for more rows waits until
            ``timeout`` (forever if ``None``). An exception other than
            StopIteration from the generator ends this thread silently.
            Confirm both behaviors are intended.
            """
            complete = False
            while not complete:
                if not self.row_queue.full():
                    try:
                        row_ = self.generator.__next__()
                        self._rows_read += 1
                        self.row_queue.put(row_)
                        self.max_ix_read = row_['_index']

                    # TODO: check subscription logic
                    except StopIteration:
                        print(f"{scenario}: All data read: {self.t-offset} of {start} to {end}, read n rows: {self._rows_read}")

                        if self._rows_read <= 0:
                            # Nothing was returned: the data source is empty, so re-querying is useless.
                            print(f'No data in  {scenario}!')
                            self.row_queue.put('STOP')
                            return
                        elif not continuous:
                            # Not continuous: no re-query allowed.
                            print(f'End is reached for {scenario} and not continuous!')
                            self.row_queue.put('STOP')
                            return
                        else:
                            # Continuous mode: re-query and shift _index to extend the current data.
                            print(self.name,': requerying')
                            log.debug('Requerying data for '+scenario)
                            # TODO: Assess if we need to keep the query in memory instead of requery.
                            # Requery could lead to inconsistency if the data changed since the last query.
                            self.generator = client.read_data_as_row(tags=tags,
                                                                     scenario=scenario, execution=execution,
                                                                     start=0, end=self.end,
                                                                     # we should never subscribe on requery
                                                                     subscribe=False,
                                                                     # offset the requeried _index so we extend the data
                                                                     offset=self.max_ix_read)
                else:
                    sleep(.01)

                if self.end > 0:
                    complete = self.max_ix_read >= self.end

        threading.Thread(target=_next, daemon=True).start()

    def __next__(self):
        """Return the next row, transformed into simulation tags and time.

        Polls the queue in 0.1 s slices and skips ``None`` entries.

        :raises TimeoutError: if ``timeout`` is set and no row arrived for
            longer than ``timeout`` seconds since the last successful get (or
            since the first call).
        :raises StopIteration: when the producer has queued ``'STOP'``.
        """
        while True:
            try:
                if self.last_get is None:
                    self.last_get = time()
                row_ = self.row_queue.get(block=True, timeout=0.1)
                self.last_get = time()
                if row_ is not None:
                    break
            except _queue.Empty:
                if self.timeout is not None and time() - self.last_get > self.timeout:
                    log.debug('timeout')
                    raise TimeoutError('Timed out!')

        if row_ == 'STOP':
            log.debug('Iter done')
            raise StopIteration(self.name,': Iterator consumed')

        row = {}
        for tag, s in self.spec['tags'].items():
            try:
                row[tag] = row_[s['tag']] * s['scale'] + s['offset']
            except Exception:
                # Log the offending raw row before propagating the error.
                log.debug(str(row_))
                raise
        # read_data_as_row returns input base time; add the spec offset to get simulation time.
        row['_index'] = row_['_index'] + self.offset
        # Relative to the simulation start t0, measured in simulation time (like StaticSource).
        row['_index_relative'] = row['_index'] - self.t0
        row['_datetime_utc'] = datetime.utcfromtimestamp(row['_index'])
        self.t = row['_index']
        return row

    def get_at_time(self, t):
        """Return the buffered tag values interpolated at simulation time ``t``.

        Consumes rows until the stream has moved past ``t`` and the
        interpolation buffer is full. Propagates ``TimeoutError`` /
        ``StopIteration`` from ``__next__``.
        """
        while self.t <= t or not self.row_buffer.full:
            part_row = self.__next__()
            if part_row is None:
                return None

            self.row_buffer.add(part_row)
        part_row = self.row_buffer.get(t)
        return part_row


class FIFO:
    """Fixed-length, per-tag first-in-first-out value buffer.

    Each tag gets a zero-filled float array of ``buffer_len`` slots, created
    on the first ``add``. Every ``add`` shifts each array left by one and
    writes the new value into the last slot.
    """

    def __init__(self, buffer_len=2):
        self.buffer_len = buffer_len
        # Created lazily on the first add(), keyed by that row's tags.
        self.buffer = None

    def _init_buffer(self, row):
        size = self.buffer_len
        self.buffer = {key: np.zeros(size) for key in row}

    def __len__(self):
        """Number of buffered tags (not rows); 0 before the first add()."""
        return 0 if self.buffer is None else len(self.buffer)

    def add(self, row):
        """Push one value per tag of ``row`` into the buffer."""
        if self.buffer is None:
            self._init_buffer(row)

        store = self.buffer
        for key in row:
            shifted = np.roll(store[key], -1)
            shifted[-1] = row[key]
            store[key] = shifted


class RowBuffer:
    """Keep the most recent rows of a source and interpolate tags at a time.

    ``self.tags`` is the requested tag list with ``'_index'`` (the row time)
    appended as its last entry; ``'_index'`` is used as the interpolation
    axis and is not itself returned by ``get``.
    """

    def __init__(self, tags=None, offset=None, method=Interp.zero):
        """
        :param tags: tags to buffer; copied, so the caller's list is not modified.
            ``None`` means no tags besides ``'_index'``.
        :param offset: stored on the instance; not used by this class.
        :param method: ``Interp.zero`` or ``Interp.linear``.
        :raises NotImplementedError: for any other interpolation method.
        """
        # Copy rather than appending to the caller's list (and accept None).
        self.tags = list(tags) if tags is not None else []
        self.tags.append('_index')
        self.offset = offset
        self.method = method
        self.full = False
        self._rows = 0
        if method == Interp.zero or method == Interp.linear:
            # Both zero-order hold and linear interpolation need two bracketing points.
            self.buffer_len = 2
        else:
            raise NotImplementedError(f'Unknown interpolation type: {method}')
        self._buffer = FIFO(self.buffer_len)

    def add(self, row):
        """Buffer the configured tags of ``row`` (missing tags are stored as ``None``)."""
        row_tags = {tag: row.get(tag) for tag in self.tags}
        self._buffer.add(row_tags)

        # Track fill level until the buffer holds buffer_len rows.
        if not self.full:
            self._rows += 1
            if self._rows >= self.buffer_len:
                self.full = True

    def _interp_kind(self):
        """Map ``self.method`` to the scipy ``interp1d`` ``kind`` string."""
        if self.method == Interp.zero:
            return 'zero'
        if self.method == Interp.linear:
            return 'linear'
        raise NotImplementedError('Unknown interpolation type')

    def get(self, t):
        """Return ``{tag: value}`` interpolated at ``t`` for every tag except ``'_index'``.

        scipy's ``interp1d`` raises ``ValueError`` when ``t`` lies outside the
        buffered ``_index`` range.
        """
        kind = self._interp_kind()
        columns = self._buffer.buffer
        irow = {}
        # '_index' is always the last tag and serves as the x-axis.
        for tag in self.tags[:-1]:
            ifun = interp1d(columns['_index'], columns[tag], kind=kind)
            irow[tag] = ifun(t).tolist()
        return irow


class InputManager:
    """Merge the rows of several input sources into one row per simulation time."""

    def __init__(self, sources, t, t0):
        self.sources = sources
        self.t = t
        self.t0 = t0

    def get_at_time(self, t):
        """Collect every source's row at ``t`` and merge them in source order.

        Returns ``None`` as soon as a source yields ``None``, or when a
        subscribing source times out. A timeout from a non-subscribing source
        is re-raised. The merged row's ``_index`` is always set to ``t``.
        """
        merged = {}
        for src in self.sources:
            try:
                part = src.get_at_time(t)
            except TimeoutError as err:
                if not src.subscribe:
                    raise err
                return None
            if part is None:
                return None
            merged.update(part)
        # Report simulation time, overriding any source-provided _index.
        merged['_index'] = t
        return merged
