import _queue
import threading
from datetime import datetime
from queue import Queue
from time import sleep

import numpy as np
from scipy.interpolate import interp1d

from numerous_api_client.client.common import Interp, log


class StaticSource:
    """Input source that returns the same row of values at every time.

    ``get_at_time`` stamps the stored row with the requested time and hands
    back that very dict object (it is updated in place, not copied).

    Args:
        row_data: Dict of tag -> value returned on every call.
        t0: Reference time; ``'_index_relative'`` is ``t - t0``.
        t_end: Accepted for interface compatibility; not used.
        dt: Accepted for interface compatibility; not used.
        generate_index: Stored as ``self.generate_index``; not used here.
    """

    def __init__(self, row_data, t0=0, t_end=None, dt=3600, generate_index=True):
        # Static sources never stream or subscribe to live data.
        self.source_type = 'static'
        self.subscribe = False
        self.row = row_data
        self.t0 = t0
        # Time of the last request; infinite until get_at_time is called.
        self.t = np.inf
        self.generate_index = generate_index

    def get_at_time(self, t):
        """Stamp the stored row with time ``t`` and return it (same dict object)."""
        self.t = t
        row = self.row
        row['_index'] = t
        row['_index_relative'] = t - self.t0
        return row


class DynamicSource:
    """Input source that streams time-series rows of a scenario from the API.

    A daemon thread pulls rows from ``client.read_data_as_row`` into a bounded
    queue. ``__next__`` pops one row, applies each tag's ``scale``/``offset``
    from ``spec``, and shifts ``_index`` by the combined offset.
    ``get_at_time`` interpolates the buffered rows at a requested time.

    Args:
        client: API client providing ``get_timeseries_meta_data``,
            ``get_dataset_closed``, ``get_timeseries_stats`` and
            ``read_data_as_row``.
        scenario: Scenario identifier to read from (also stored as ``name``).
        execution: Execution identifier to read from.
        spec: Source specification. ``spec['tags']`` maps an output tag name to
            a dict with ``'tag'`` (source tag), ``'scale'`` and ``'offset'``.
            An optional top-level ``spec['offset']`` (default 0, written back
            into the caller's ``spec``) is added to the dataset meta offset.
        t0: Start time in the offset-adjusted time base of ``'_index'``.
            The read window starts at ``t0 - offset`` in the dataset's own time.
        te: End time in the same time base as ``t0``.
        timeout: Seconds ``__next__`` waits for a queued row before raising
            ``TimeoutError``; ``None`` waits indefinitely.
        source_type: Stored as ``self.source_type``.
        interpolation_method: ``Interp`` member passed to ``RowBuffer``.

    Raises:
        Exception: if the dataset's first sample, shifted by the combined
            offset, lies after ``t0``.
    """

    def __init__(self, client, scenario, execution, spec, t0=0, te=0, timeout=None, source_type=None,
                 interpolation_method=Interp.zero):
        self.timeout = timeout
        self.t0 = t0
        self.meta = client.get_timeseries_meta_data(scenario=scenario, execution=execution)
        self.closed = client.get_dataset_closed(scenario=scenario, execution=execution)
        self.name = scenario
        self.spec = spec

        # Default the spec-level offset to 0 (this writes into the caller's dict).
        self.spec.setdefault('offset', 0)

        # Combined shift from the dataset's time base to the '_index' time base.
        offset = self.meta.offset + self.spec['offset']
        self.offset = offset

        # Read window expressed in the dataset's original time base.
        start = t0 - offset
        end = te - offset
        # TODO Get info from scenario
        stats = client.get_timeseries_stats(None, scenario, execution)

        # If the data set is still open, listen for data as it is being generated.
        self.subscribe = not self.closed
        # Always True: when data runs out the reader thread re-queries instead of stopping.
        continuous = True

        if stats.min + offset > t0:
            raise Exception(f'dynamic input scenario {scenario} with tags {[t.name for t in self.meta.tags]} '
                            f'is offset to the future {stats.min + offset}>{t0}')

        self.start = start
        self.end = end

        log.debug(f'meta off: {datetime.fromtimestamp(self.meta.offset)}')
        log.debug(f'Combined offset: {offset} ({datetime.fromtimestamp(offset)})')
        log.debug(f't0: {datetime.fromtimestamp(t0)}, te: {datetime.fromtimestamp(te)}')
        log.debug(f'start: {datetime.fromtimestamp(start)} ({start}), '
                  f'end: {datetime.fromtimestamp(end)} ({end})')

        tags = {s['tag'] for s in spec['tags'].values()}
        log.debug('Datasource Tags: ' + str(tags))
        self.generator = client.read_data_as_row(scenario=scenario, tags=tags,
                                                 execution=execution, start=start, end=end,
                                                 subscribe=self.subscribe)
        self.source_type = source_type
        # Bounded hand-off: the reader thread produces, __next__ consumes.
        self.row_queue = Queue(100)
        # '_index' of the last row returned by __next__; -1 before the first row.
        self.t = -1
        self.row_buffer = RowBuffer(tags=list(spec['tags'].keys()), offset=offset, method=interpolation_method)

        def _next():
            """Reader-thread loop: move rows from ``self.generator`` into ``self.row_queue``."""
            while True:
                if self.row_queue.full():
                    sleep(.01)
                    continue
                first = True
                # With continuous=True this inner loop never exits on its own.
                while continuous or first:
                    first = False
                    try:
                        row_ = self.generator.__next__()
                        # Queue.put blocks while the queue is full.
                        self.row_queue.put(row_)

                    # TODO: check subscription logic
                    except StopIteration:
                        log.debug(f"all data read: {self.t}")
                        if self.subscribe:
                            stats_ = client.get_timeseries_stats(None, scenario, execution)
                            if stats_.max + self.offset > self.t:
                                # New data exists past the last consumed row: read from there.
                                self.start = self.t - self.offset
                                self.end = stats_.max
                                self.generator = client.read_data_as_row(tags=tags,
                                                                         scenario=scenario, execution=execution,
                                                                         start=self.start, end=self.end,
                                                                         subscribe=self.subscribe)
                            else:
                                sleep(10)

                        if not continuous:
                            self.row_queue.put('STOP')
                            return

                        # NOTE(review): this unconditional re-query replaces any generator
                        # just created in the subscription branch above - confirm intent.
                        log.debug('Requerying data')
                        self.generator = client.read_data_as_row(tags=tags,
                                                                 scenario=scenario, execution=execution,
                                                                 start=self.start, end=self.end,
                                                                 subscribe=self.subscribe,
                                                                 offset=(self.end - self.start))

        threading.Thread(target=_next, daemon=True).start()

    def __next__(self):
        """Return the next row, scaled and shifted by the combined offset.

        Returns:
            dict containing:
            - each output tag of ``spec['tags']``, as ``raw * scale + offset``;
            - ``'_index'``, the dataset index plus the combined offset;
            - ``'_index_relative'``, which is ``'_index' - t0``;
            - ``'_datetime_utc'``, a naive UTC datetime of ``'_index'``.

        Raises:
            TimeoutError: no row became available within ``self.timeout``.
            StopIteration: the reader thread queued the ``'STOP'`` sentinel.
        """
        try:
            row_ = self.row_queue.get(timeout=self.timeout)
        except _queue.Empty:
            log.debug('timeout')
            raise TimeoutError('Timed out!')

        # isinstance guard avoids an elementwise/odd __eq__ on non-dict row types.
        if isinstance(row_, str) and row_ == 'STOP':
            log.debug('Iter done')
            raise StopIteration('Iterator consumed')

        row = {}
        for tag, s in self.spec['tags'].items():
            try:
                row[tag] = row_[s['tag']] * s['scale'] + s['offset']
            except Exception:
                # Log the offending raw row, then propagate the original error.
                log.debug(str(row_))
                raise

        row['_index'] = row_['_index'] + self.offset
        # Relative to t0 in the same offset-adjusted time base as '_index'
        # (previously used the raw dataset index, mixing time bases).
        row['_index_relative'] = row['_index'] - self.t0
        row['_datetime_utc'] = datetime.utcfromtimestamp(row['_index'])
        self.t = row['_index']
        return row

    # Get at simulation time
    def get_at_time(self, t):
        """Return the tag values interpolated at time ``t``.

        Rows are consumed until the last one read has ``'_index'`` strictly
        greater than ``t``, each being pushed into ``self.row_buffer``.
        The buffer is then interpolated at ``t``.

        Raises:
            TimeoutError: propagated from ``__next__``.
            StopIteration: propagated from ``__next__``.
        """
        while self.t <= t:
            part_row = self.__next__()

            self.row_buffer.add(part_row)
        part_row = self.row_buffer.get(t)
        return part_row


class FIFO:
    """Per-key sliding window holding the last ``buffer_len`` values.

    The windows are NumPy float arrays. They are created zero-filled from the
    keys of the first row passed to ``add``. Every row shifts each window left
    by one slot and writes its value into the last slot.
    """

    def __init__(self, buffer_len=2):
        self.buffer_len = buffer_len  # number of samples kept per key
        self.buffer = None  # {key: np.ndarray}; created lazily by add()

    def _init_buffer(self, row):
        """Allocate a zero-filled window for each key of ``row``."""
        width = self.buffer_len
        self.buffer = {key: np.zeros(width) for key in row.keys()}

    def __len__(self):
        """Number of keys being tracked (0 before the first row arrives)."""
        return len(self.buffer) if self.buffer is not None else 0

    def add(self, row):
        """Append one value per key of ``row``, dropping that key's oldest value.

        Keys must all have been present in the first row added (else ``KeyError``).
        """
        if self.buffer is None:
            self._init_buffer(row)
        windows = self.buffer
        for key, value in row.items():
            windows[key] = np.roll(windows[key], -1)
            windows[key][-1] = value


class RowBuffer:
    """Buffer the most recent rows of a source and interpolate them at time ``t``.

    Only the configured tags plus ``'_index'`` are kept, in a ``FIFO`` of
    length 2. ``get`` interpolates every tag over the buffered ``'_index'``
    values with ``scipy.interpolate.interp1d``.

    NOTE: FIFO windows start zero-filled, so until two rows have been added
    the oldest slot holds index 0 and value 0, and interpolation uses them.

    Args:
        tags: Names of the tags to track. The list is copied, so the caller's
            list is not mutated. ``None`` means no value tags.
        offset: Stored as ``self.offset``; not used by this class.
        method: ``Interp.zero`` (zeroth-order spline / step) or ``Interp.linear``.

    Raises:
        NotImplementedError: for any other ``method``.
    """

    def __init__(self, tags=None, offset=None, method=Interp.zero):
        # Copy (and accept the None default) instead of appending to the caller's list.
        self.tags = [] if tags is None else list(tags)
        # '_index' is always the last entry; get() relies on that to skip it.
        self.tags.append('_index')
        self.offset = offset
        self.method = method
        if method == Interp.zero:
            self.buffer_len = 2
        elif method == Interp.linear:
            self.buffer_len = 2
        else:
            raise NotImplementedError
        self._buffer = FIFO(self.buffer_len)

    def _interp_kind(self):
        """Return the ``interp1d`` ``kind`` string for ``self.method``."""
        if self.method == Interp.zero:
            return 'zero'
        if self.method == Interp.linear:
            return 'linear'
        raise NotImplementedError('Unknown interpolation type')

    def add(self, row):
        """Push the tracked tags of ``row``; absent tags are passed on as ``None``."""
        self._buffer.add({tag: row.get(tag) for tag in self.tags})

    def get(self, t):
        """Interpolate every tracked tag (except ``'_index'``) at time ``t``.

        Returns:
            dict mapping tag name to ``interp1d(...)(t).tolist()``, which is a
            Python scalar for a scalar ``t``.

        Raises:
            NotImplementedError: unknown interpolation method.
            ValueError: raised by ``interp1d`` (default ``bounds_error``) when
                ``t`` lies outside the buffered ``'_index'`` range.
        """
        kind = self._interp_kind()
        window = self._buffer.buffer
        irow = {}
        for tag in self.tags[:-1]:
            ifun = interp1d(window['_index'], window[tag], kind=kind)
            irow[tag] = ifun(t).tolist()
        return irow


class InputManager:
    """Merge the rows of several input sources into one row per time step.

    Each source must expose ``get_at_time(t)`` returning a dict, and a
    ``subscribe`` flag.
    """

    def __init__(self, sources, t, t0):
        self.sources = sources
        self.t = t
        self.t0 = t0

    def get_at_time(self, t):
        """Return the merged row at time ``t``, or ``None`` if a live source timed out.

        Later sources overwrite keys from earlier ones. ``'_index'`` is always
        set to ``t``. A ``TimeoutError`` is swallowed (returning ``None``) only
        for sources with ``subscribe`` set; otherwise it propagates.
        """
        merged = {}
        for src in self.sources:
            try:
                part = src.get_at_time(t)
            except TimeoutError:
                if not src.subscribe:
                    raise
                # A subscribed source may simply not have produced data yet.
                return None
            merged.update(part)
        # The requested time overrides any '_index' provided by a source.
        merged['_index'] = t
        return merged
