from contextlib import contextmanager
import json
import os
import logging
import logging.handlers
import pandas as pd
from datetime import datetime
import pytz
import grpc
from numerous_api_client.python_protos import spm_pb2_grpc, spm_pb2
from queue import Queue
import _queue
import threading
from time import time, sleep
import requests
import signal
from enum import Enum
from numerous_api_client.headers.ValidationInterceptor import ValidationInterceptor
import traceback
from uuid import uuid4

class JobStatus(Enum):
    """Job states, each identified by a fixed integer code."""

    ready = 0
    running = 1
    finished = 2
    request_termination = 3
    terminated = 4
    failed = 5
    requested = 6
    initializing = 7


def get_env(val, env, default=None):
    """Resolve a setting from an explicit value, the environment, or a default.

    Args:
        val: Explicit value; returned unchanged whenever it is not None.
        env: Name of the environment variable consulted when ``val`` is None.
        default: Fallback used when the variable is unset. A ``default`` of
            None means "no fallback".

    Returns:
        ``val``, else the environment variable's value, else ``default``.

    Raises:
        KeyError: ``val`` is None, the variable is unset and ``default`` is None.
    """
    if val is not None:
        return val

    from_environment = os.getenv(env)
    if from_environment is not None:
        return from_environment

    if default is not None:
        return default

    raise KeyError(f'ENV var <{env}> is not set.')

# The UTC timezone object from pytz.
tzl = pytz.utc

# Module logger. Its level is DEBUG, so no record is filtered out at this logger.
log = logging.getLogger('numerous_client')
log.setLevel(logging.DEBUG)


def handle_term(signum, frame):
    """Signal handler: log a warning, then unwind by raising KeyboardInterrupt.

    ``signum`` and ``frame`` are the standard signal-handler arguments and are
    not used.
    """
    log.warning('Terminated')
    raise KeyboardInterrupt


# Installed at import time: both SIGINT and SIGTERM are routed through
# handle_term, so either one surfaces as KeyboardInterrupt.
signal.signal(signal.SIGINT, handle_term)
signal.signal(signal.SIGTERM, handle_term)


class NumerousBufferedWriter:
    """Queue-backed row writer that streams rows from a background thread.

    Rows passed to :meth:`write_row` are put on ``write_queue``. A worker thread
    started in ``__init__`` hands a generator over that queue to
    ``write_method`` together with a ``must_flush`` callback that decides when
    buffered rows should be sent.

    Args:
        write_method: Callable invoked as ``write_method(rows, must_flush=fn)``
            where ``rows`` is a generator of row dicts.
        writer_closed_method: Callable invoked with ``scenario`` once the
            worker thread has finished during :meth:`close`.
        scenario: Value passed to ``writer_closed_method``.
        buffer_size: Row count above which a flush is allowed (subject to
            ``min_elapse_flush``).
    """

    def __init__(self, write_method, writer_closed_method, scenario:str=None, buffer_size:int=24):
        self.scenario = scenario
        self._tags = None
        self.ix = 0  # next auto-generated '_index' value
        self._write_method = write_method
        self._writer_closed = writer_closed_method
        self._buffer = None
        self.buffer_size = buffer_size
        self.write_queue = Queue()
        self.closed = False

        self.last_flush = time()
        self.max_elapse_flush = 600  # seconds after which a flush is always due
        self.min_elapse_flush = 1    # minimum seconds between size-triggered flushes
        self.force_flush = False

        self.buffer_number_size = 1e6  # flush when rows * row width exceeds this

        def must_flush(n_rows, rows_size):
            """Return True when the buffered rows should be sent now."""
            tic = time()
            buffer_counter = rows_size * n_rows
            if self.force_flush:
                self.force_flush = False
                due = True
            elif (n_rows > self.buffer_size and tic > self.min_elapse_flush + self.last_flush) \
                    or buffer_counter > self.buffer_number_size:
                due = True
            else:
                due = tic > (self.last_flush + self.max_elapse_flush)

            if due:
                self.last_flush = tic
            return due

        def writer_thread_func():
            stop_seen = False

            def write_generator():
                nonlocal stop_seen
                while True:
                    row = self.write_queue.get()
                    if isinstance(row, str) and row == 'STOP':
                        stop_seen = True
                        return
                    yield row

            while True:
                try:
                    self._write_method(write_generator(), must_flush=must_flush)
                    return
                except Exception:
                    logging.getLogger('numerous_client').exception('Writing rows failed')
                    if stop_seen:
                        # The sentinel is already consumed: retrying would block
                        # forever on the empty queue and deadlock close().
                        return

        self.writer_thread = threading.Thread(target=writer_thread_func)
        self.writer_thread.start()

    def _init_buffer(self):
        self._buffer_count = 0
        #self._buffer = {t: [] for t in self._tags}
        self._buffer_timestamp = []

    def write_row(self, data):
        """Queue one row dict for writing and return its ``'_index'`` value.

        A missing ``'_index'`` is filled from an internal counter; a datetime
        ``'_index'`` is replaced by its POSIX timestamp. ``data`` is modified
        in place.

        Raises:
            ValueError: The writer has been closed.
        """
        # Check first so a rejected row neither consumes an index nor is mutated.
        if self.closed:
            raise ValueError('Queue is closed!')

        if '_index' not in data:
            data['_index'] = self.ix
            self.ix += 1

        if isinstance(data['_index'], datetime):
            data['_index'] = data['_index'].timestamp()

        self.write_queue.put(data)
        return data['_index']

    def flush(self):
        """Request that the next ``must_flush`` check returns True."""
        self.force_flush = True
        #self._write_method(self._buffer, self.scenario)
        self._init_buffer()

    def close(self):
        """Flush, stop the worker thread, and notify ``writer_closed_method``.

        Calling it again after a successful close does nothing.
        """
        if not self.closed:
            self.flush()

            self.write_queue.put('STOP')

            self.writer_thread.join()

            self._writer_closed(self.scenario)

        self.closed = True

class NumerousLogHandler:
    """Handler object that forwards log records to ``client.push_scenario_log``.

    It only provides ``handle(record)``, which is the method a
    ``logging.handlers.QueueListener`` calls on its handlers.
    """

    def __init__(self, client):
        self.client = client

    def handle(self, record):
        """Push the record's formatted message to the client as a string."""
        # getMessage() applies any %-args and always returns str, whereas
        # record.msg may be an arbitrary object (e.g. an exception instance).
        self.client.push_scenario_log(record.getMessage())

    #TODO make a handler thread to schedule update to server!

class NumerousClient:
    def __init__(self, job_id=None, project=None, scenario=None, server=None, port=None, refresh_token=None, clear_data=None, no_log=False, instance_id=None, secure=None):
        """Connect to the API and start the client's background workers.

        ``project``, ``scenario``, ``job_id`` and ``refresh_token`` fall back to
        the environment variables ``NUMEROUS_PROJECT``, ``NUMEROUS_SCENARIO``,
        ``SPM_JOB_ID`` and ``NUMEROUS_API_REFRESH_TOKEN`` when passed as None
        (``get_env`` raises KeyError when neither is available).

        Args:
            job_id, project, scenario, refresh_token: See above.
            server, port, secure: Forwarded to ``_init_channel``.
            clear_data: When None, read from ``NUMEROUS_CLEAR_DATA`` (default
                ``'False'``); timeseries data is cleared only if the resolved
                value equals the string ``'True'``.
            no_log: When true, ``_init_logger`` is not called.
            instance_id: Instance identifier; a random UUID string when None.
        """
        self.job_states = JobStatus
        self._project = get_env(project, 'NUMEROUS_PROJECT')
        self._scenario = get_env(scenario, 'NUMEROUS_SCENARIO')
        self._spm_job_id = get_env(job_id, 'SPM_JOB_ID')
        self._refresh_token = get_env(refresh_token, 'NUMEROUS_API_REFRESH_TOKEN')
        self._access_token = None

        self._instance_id = str(uuid4()) if instance_id is None else instance_id
        self._execution_id = get_env(None, 'NUMEROUS_EXECUTION_ID', 'not_found')
        if self._execution_id == 'not_found' or self._execution_id is None:
            self._execution_id = str(uuid4())
            print("EX", self._execution_id)
            log.debug(f'No execution id found - generating new: {self._execution_id}')

        self._complete = True
        self.writers = []

        channel = self._init_channel(server=server, port=port, secure=secure, instance_id=self._instance_id)
        stub = spm_pb2_grpc.SPMStub(channel)

        # TODO: Maybe reconsider both of these!
        self._job_manager = spm_pb2_grpc.JobManagerStub(channel)
        self._token_manager = spm_pb2_grpc.TokenManagerStub(channel)
        self._base_client = stub

        # Refresh token every 9 minutes (?)
        self._access_token_refresher = RepeatedFunction(
            interval=9*60, function=self._refresh_access_token, run_initially=True,
            refresh_token=self._refresh_token, instance_id=self._instance_id, execution_id=self._execution_id
        )
        self._access_token_refresher.start()

        self._terminated = False
        if not no_log:
            self._init_logger()
        log.info('init client')

        # Debounce state for set_scenario_progress / set_state (seconds).
        self.last_progress = -999999
        self.progress_debounce = 10

        self.last_state_set = -999999
        self.state_debounce = 60

        clear_data = get_env(clear_data, 'NUMEROUS_CLEAR_DATA', default='False') == 'True'

        if clear_data:
            self.clear_timeseries_data()

        # daemon=True replaces the deprecated Thread.setDaemon(True); the
        # listener must not keep the process alive.
        self._listen = threading.Thread(target=self._listen_terminate, daemon=True)
        self._listen.start()

    def complete_execution(self):
        """Send a CompleteExecution request for this client's project, scenario, job and execution."""
        request = spm_pb2.ExecutionMessage(
            project_id=self._project,
            scenario_id=self._scenario,
            job_id=self._spm_job_id,
            execution_id=self._execution_id,
        )
        self._base_client.CompleteExecution(request)

    def _init_channel(self, server, port, secure=None, instance_id=None):
        """Create the gRPC channel and wrap it with the validation interceptor.

        ``secure``, ``server`` and ``port`` fall back to the environment
        variables ``SECURE_CHANNEL``, ``NUMEROUS_API_SERVER`` and
        ``NUMEROUS_API_PORT``. A secure channel is used only when the resolved
        ``secure`` value stringifies to ``'True'``; its certificate is read from
        ``certs/server.crt`` next to this file.

        Returns:
            The intercepted channel. Also stores the interceptor's ``instance``
            on ``self._instance``.
        """
        secure_channel = get_env(secure, 'SECURE_CHANNEL')
        server = get_env(server, 'NUMEROUS_API_SERVER')
        port = get_env(port, 'NUMEROUS_API_PORT')
        target = f'{server}:{port}'

        log.info(f"Client connecting to: {server}:{port}, using secure channel: {secure_channel}")
        print(server,':',port,' secure: ', str(secure_channel))

        if str(secure_channel) != 'True':
            raw_channel = grpc.insecure_channel(target)
        else:
            print('secure connect!')
            cert_path = f'{os.path.dirname(os.path.abspath(__file__))}/certs/server.crt'
            with open(cert_path, 'rb') as cert_file:
                credentials = grpc.ssl_channel_credentials(cert_file.read())
            raw_channel = grpc.secure_channel(target, credentials)

        interceptor = ValidationInterceptor(
            token=self._access_token,
            token_callback=self._get_current_token,
            instance=instance_id,
        )
        self._instance = interceptor.instance
        return grpc.intercept_channel(raw_channel, interceptor)

    def _get_current_token(self):
        return self._access_token

    def _init_logger(self):
        """Attach a queue handler to the root logger and start its listener.

        The root logger's level is set to ERROR. Records placed on the queue are
        delivered by a ``QueueListener`` to a ``NumerousLogHandler`` bound to
        this client; the listener is kept on ``self.log_listener``.
        """
        pending_records = Queue(-1)

        root_logger = logging.getLogger()
        root_logger.setLevel(logging.ERROR)
        root_logger.addHandler(logging.handlers.QueueHandler(queue=pending_records))

        forwarder = NumerousLogHandler(self)
        self.log_listener = logging.handlers.QueueListener(pending_records, forwarder)
        self.log_listener.start()

    def _refresh_access_token(self, refresh_token, instance_id, execution_id):
        """Exchange the refresh token for an access token and store it.

        Sends a ``RefreshRequest`` through the token manager stub and saves the
        returned token value on ``self._access_token``.
        """
        # Credentials must never be written to stdout or logs, so neither the
        # refresh token nor the new access token is echoed here.
        token = self._token_manager.GetAccessToken(
            spm_pb2.RefreshRequest(
                refresh_token=spm_pb2.Token(val=refresh_token), instance_id=instance_id, execution_id=execution_id
            )
        )
        self._access_token = token.val
        log.debug('Access token refreshed')

    def _listen_terminate(self):
        """Watch this job's COMMAND channel and raise SIGTERM on a 'terminate' command.

        Stops listening after the first terminate command.
        """
        command_channel = ".".join(['COMMAND', self._project, self._scenario, self._spm_job_id])

        for channel, message in self.subscribe_messages([command_channel]):
            log.debug(f'Received message: {message!s}\n on channel: {channel!s}')

            if 'command' not in message:
                continue
            if message['command'] != 'terminate':
                continue

            log.warning('Received Termination Command')
            signal.raise_signal(signal.SIGTERM)
            break

    def get_scenario_document(self, scenario=None):
        """Fetch a scenario and return ``(parsed_document, files)``.

        ``scenario`` defaults to the client's own scenario. The document is
        decoded from JSON; ``files`` is returned as received.
        """
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(project=self._project, scenario=target)
        response = self._base_client.GetScenario(request)

        document = json.loads(response.scenario_document)
        return document, response.files

    def get_job(self, scenario_doc=None):
        if scenario_doc is None:
            scenario_doc, files = self.get_scenario_document()

        return scenario_doc['jobs'][self._spm_job_id]


    def get_group_document(self, group):
        """Fetch ``group`` in the client's project and return its JSON-decoded document."""
        request = spm_pb2.Group(project=self._project, group=group)
        response = self._base_client.GetGroup(request)
        return json.loads(response.group_document)

    def get_project_document(self):
        """Fetch the client's project and return its JSON-decoded document."""
        request = spm_pb2.Project(project=self._project)
        response = self._base_client.GetProject(request)
        return json.loads(response.project_document)


    def listen_scenario_document(self, scenario=None):
        """Yield ``(parsed_document, files)`` for each update streamed by ListenScenario.

        ``scenario`` defaults to the client's own scenario; the project is
        always the client's project.
        """
        if scenario is None:
            scenario = self._scenario

        # The request previously referenced an undefined name `project`, which
        # raised NameError on first iteration; the client's project is used.
        request = spm_pb2.Scenario(project=self._project, scenario=scenario)
        for doc in self._base_client.ListenScenario(request):
            yield json.loads(doc.scenario_document), doc.files



    def get_scenario(self, scenario=None, path='.'):
        """Fetch a scenario and its model, downloading all attached files.

        Args:
            scenario: Scenario to fetch; defaults to the client's own scenario.
            path: Directory prefix the files are written under.

        Returns:
            ``(scenario_data, model_data, file_paths_local)`` where
            ``file_paths_local`` maps each file name to the path written.
        """
        log.debug('Get scenario')
        if scenario is None:
            scenario = self._scenario

        scenario_data, scenario_files = self.get_scenario_document(scenario)
        model_data, model_files = self.get_model(scenario_data['systemID'], project_id=self._project, scenario_id=scenario)

        downloads = [
            {'name': f.name, 'url': f.url, 'path': f.path}
            for f in list(scenario_files) + list(model_files)
        ]

        file_paths_local = {}
        for f in downloads:
            f_path = path + '/' + f['name']
            file_paths_local[f['name']] = f_path
            r = requests.get(f['url'], allow_redirects=True)
            # Context manager so the handle is closed (and flushed) right away.
            with open(f_path, 'wb') as out:
                out.write(r.content)

        return scenario_data, model_data, file_paths_local

    def get_model(self, model_id, project_id=None, scenario_id=None):
        """Fetch a model and return ``(parsed_model, files)``; the model is JSON-decoded."""
        request = spm_pb2.Model(model_id=model_id, project_id=project_id, scenario_id=scenario_id)
        response = self._base_client.GetModel(request)
        return json.loads(response.model), response.files

    def set_scenario_data_tags(self, tags, scenario=None):
        """Send ``tags`` as the data tags of ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioDataTags(project=self._project, scenario=target, tags=tags)
        self._base_client.SetScenarioDataTags(request)

    def set_scenario_progress(self, message, status, progress=None, clean=False, scenario=None, force=False):
        if scenario is None:
            scenario = self._scenario
        tic = time()

        if tic > (self.last_progress + self.progress_debounce) or force:

            self.last_progress=tic
            self._base_client.SetScenarioProgress(spm_pb2.ScenarioProgress(
                project=self._project, scenario=scenario, spm_job_id=self._spm_job_id,
                message=message, status=status, clean=clean, progress=progress
            ))

    def clear_scenario_results(self, scenario=None):
        """Send ClearScenarioResults for ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(project=self._project, scenario=target)
        self._base_client.ClearScenarioResults(request)

    def set_scenario_results(self, names:list, values:list, units:list, scenario=None):
        """Send result ``names``, ``values`` and ``units`` for ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioResults(
            project=self._project, scenario=target, names=names, values=values, units=units
        )
        self._base_client.SetScenarioResults(request)

    def get_scenario_results(self, scenario=None):
        """Return ``(names, values, units)`` of the results for ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(project=self._project, scenario=target)
        response = self._base_client.GetScenarioResults(request)
        return response.names, response.values, response.units

    def push_scenario_error(self, error, scenario=None):
        """Send ``error`` for this job on ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioError(
            project=self._project, scenario=target, error=error, spm_job_id=self._spm_job_id
        )
        self._base_client.PushScenarioError(request)

    def push_scenario_log(self, message, initialize=False, scenario=None):
        """Send one log entry for this job on ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        entry = spm_pb2.ScenarioLogEntry(
            project=self._project, scenario=target, message=message,
            spm_job_id=self._spm_job_id, initialize=initialize
        )
        self._base_client.PushScenarioLogEntry(entry)

    def push_scenario_logs(self, logs):
        """Send a batch of log entries for the current execution.

        Each item of ``logs`` is indexed as ``item[0]`` (an object with a
        ``timestamp()`` method) and ``item[1]`` (the log entry).
        """
        messages = [entry[1] for entry in logs]
        stamps = [entry[0].timestamp() for entry in logs]
        request = spm_pb2.LogEntries(
            execution_id=self._execution_id, log_entries=messages,
            timestamps=stamps
        )
        self._base_client.PushExecutionLogEntries(request)

    def delete_scenario(self, scenario=None):
        # NOTE(review): despite its name, this body sends a
        # PushScenarioFormattedError request and deletes nothing. It reads
        # `message`, `hint`, `category`, `exception_object_type`,
        # `exception_object_message` and `full_traceback`, none of which are
        # parameters or locals of this method, so a call raises NameError unless
        # those names exist at module level (none are visible here).
        # Looks like a mis-pasted body - TODO confirm the intended RPC and fix.
        self._base_client.PushScenarioFormattedError(
            spm_pb2.ScenarioFormattedError(project=self._project, scenario=scenario, message=message, hint=hint,
                                           category=category, exception_object_type=exception_object_type,
                                           exception_object_message=exception_object_message,
                                           full_traceback=full_traceback))

    def get_download_files(self, files, local_path='.', scenario=None):
        """Download scenario files through signed URLs.

        Args:
            files: File paths, each requested as ``<scenario>/<file>``.
            local_path: Directory prefix the files are written under.
            scenario: Defaults to the client's own scenario.

        Returns:
            Dict mapping each returned file name to the local path written.
        """
        if scenario is None:
            scenario = self._scenario

        signed = self._base_client.GetSignedURLs(
            spm_pb2.FileSignedUrls(files=[spm_pb2.FileSignedUrl(path=scenario+'/'+f) for f in files]))

        files_out = {}
        for f in signed.files:
            r = requests.get(f.url, allow_redirects=True)
            local_file_path = local_path+'/'+f.name
            # Context manager so the handle is closed (and flushed) right away.
            with open(local_file_path, 'wb') as out:
                out.write(r.content)
            files_out[f.name] = local_file_path
        return files_out

    def upload_file(self, local_file, file_id, file=None, scenario=None):
        """Upload a local file to a signed URL generated for the scenario.

        Args:
            local_file: Path of the file to upload.
            file_id: Identifier passed along when requesting the upload URL.
            file: Remote name; defaults to the last ``/``-separated component
                of ``local_file``.
            scenario: Defaults to the client's own scenario.
        """
        if scenario is None:
            scenario = self._scenario

        if file is None:
            file = local_file.split('/')[-1]
        print('File: ', file)
        upload_url = self._base_client.GenerateScenarioUploadSignedURL(spm_pb2.ScenarioFilePath(project=self._project, scenario=scenario, path=file, file_id=file_id))

        headers = {"Content-Type": "multipart/related"}
        params = {
            "name": file,
            "mimeType": "text/html"
        }
        # Context manager so the file handle is released once the request has
        # been sent (it was previously left open).
        with open(local_file, 'rb') as payload:
            requests.post(
                upload_url.url,
                headers=headers,
                params=params,
                data=payload
            )


    def get_timeseries_meta_data(self, scenario=None):
        """Return the GetScenarioMetaData response for ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(scenario=target)
        return self._base_client.GetScenarioMetaData(request)

    def get_timeseries_custom_meta_data(self, scenario=None, key=None):
        """Return the ``meta`` field stored under ``key`` for ``scenario`` (default: the client's scenario)."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioCustomMetaData(scenario=target, key=key)
        response = self._base_client.GetScenarioCustomMetaData(request)
        return response.meta

    def set_timeseries_custom_meta_data(self, meta:str, scenario=None, key=None):
        """Store ``meta`` under ``key`` for ``scenario`` (default: the client's scenario) and return the response."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioCustomMetaData(scenario=target, key=key, meta=meta)
        return self._base_client.SetScenarioCustomMetaData(request)

    def get_state(self, scenario=None):
        try:
            state_json = self.get_timeseries_custom_meta_data(scenario=None, key='state')
            return json.loads(state_json)
        except json.decoder.JSONDecodeError:
            #print('bad json: ', state_json)
            return None

    def set_state(self, state, scenario=None, force=False):
        tic = time()

        if tic > (self.last_state_set + self.state_debounce) or force:
            log.debug('State set.')
            self.last_state_set = tic
            return self.set_timeseries_custom_meta_data(meta=json.dumps(state),scenario=scenario, key='state')

    def data_read_df(self, tags=None, start:datetime=None, end:datetime=None, scenario=None):
        """Read timeseries data for ``tags`` into a pandas DataFrame.

        Args:
            tags: Tags to read; None is treated as an empty list.
            start, end: Optional bounds. datetime values are converted to POSIX
                timestamps, as ``read_data_stream`` does.
            scenario: Defaults to the client's own scenario.

        Returns:
            DataFrame with one column per tag received. Columns are truncated
            to the shortest one so that all have equal length.
        """
        if tags is None:
            tags = []
        if not scenario:
            scenario = self._scenario

        if isinstance(start, datetime):
            start = start.timestamp()
        if isinstance(end, datetime):
            end = end.timestamp()

        # `end` was previously sent as `start`, so the upper bound was ignored.
        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(scenario=scenario, tags=tags, start=start, end=end))

        data = {}
        for l in read_data:
            for r in l.data:
                data.setdefault(r.tag, []).extend(r.values)

        min_len = min((len(v) for v in data.values()), default=0)
        return pd.DataFrame({k: v[:min_len] for k, v in data.items()})

    def read_data_stream(self, tags, start:datetime=None, end:datetime=None, scenario=None):
        """Yield ``(tag, values)`` for every data block streamed by ReadData.

        datetime bounds are converted to POSIX timestamps. The request is
        flagged as a time range when either bound is a datetime or a float;
        bounds left as None are sent as 0. ``scenario`` defaults to the
        client's own scenario.
        """
        def _normalise(bound):
            # Returns (value to send, whether this bound marks a time range).
            if isinstance(bound, datetime):
                return bound.timestamp(), True
            return bound, isinstance(bound, float)

        start, start_is_time = _normalise(start)
        end, end_is_time = _normalise(end)
        time_range = start_is_time or end_is_time

        start = 0 if start is None else start
        end = 0 if end is None else end

        if not scenario:
            scenario = self._scenario

        request = spm_pb2.ReadScenario(scenario=scenario, tags=tags, start=start, end=end, time_range = time_range)
        for batch in self._base_client.ReadData(request):
            for block in batch.data:
                yield block.tag, block.values

    def read_data_as_row(self, tags=None, start:int=0, end:int=0, scenario=None, subscribe=False, offset=0):
        """Yield timeseries data row by row as dicts mapping tag to value.

        Blocks received from ReadData are accumulated per tag until a batch is
        flagged ``row_complete``; the accumulated columns are then emitted one
        row at a time with ``offset`` added to each row's ``'_index'``.

        Args:
            tags: Tags to read; None is treated as an empty list.
            start, end: Bounds; a value of 0 means "not set". When either is
                positive, rows are filtered on their ``'_index'`` value,
                otherwise on their position within the completed batch.
            scenario: Defaults to the client's own scenario.
            subscribe: Sent to the server as the ``listen`` flag.
            offset: Added to the ``'_index'`` of every yielded row.

        Raises:
            ValueError: A tag has a different number of values than ``'_index'``
                when a batch is completed.
        """
        log.debug('Reading data as row')
        log.debug('Subscribed: ' + str(subscribe) + ', starting from: '+str(start))
        if tags is None:
            tags = []
        if not scenario:
            scenario = self._scenario
        time_range = (start > 0 or end > 0)
        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(scenario=scenario, tags=tags, start=start, end=end, time_range=time_range, listen=subscribe))
        current_row = {}
        # An unset end means "no upper bound". The previous magic cap of
        # 9999999 silently dropped rows whose index was at or above it.
        end = end if end > 0 else float('inf')
        for b in read_data:

            for d in b.data:
                if d.tag not in current_row:
                    current_row[d.tag] = list(d.values)
                else:
                    current_row[d.tag] += list(d.values)

            if b.row_complete:
                len_index = len(current_row['_index'])

                wrong_len_tags = [
                    (k, len(v)) for k, v in current_row.items() if len(v) != len_index
                ]
                if len(wrong_len_tags)>0:
                    raise ValueError(f'The following tags have a wrong number of data points returned (_index has {len_index}: '+"\n".join([tag[0] +': '+str(tag[1]) for tag in wrong_len_tags]))

                for i in range(len_index):
                    current_ix = current_row['_index'][i] if time_range else i

                    if current_ix >= start and current_ix < end:

                        r = {k: v[i] for k,v in current_row.items()}
                        r['_index']+=offset
                        yield r

                current_row = {}

        # Data left over from batches that were never flagged row_complete is
        # emitted without the start/end filter.
        if len(current_row) > 0:
            for i, ix in enumerate(current_row['_index']):
                r = {k: v[i] for k, v in current_row.items()}
                r['_index'] += offset
                yield r

        log.debug('Reading as row complete')
        return

    def clear_timeseries_data(self, scenario=None):
        """Send ClearData for ``scenario`` (default: the client's scenario), logging a warning first."""
        log.warning('Clearing data!')

        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(scenario=target)
        self._base_client.ClearData(request)

    def data_write_df(self, df:pd.DataFrame, scenario=None, clear=True):
        """Write a DataFrame as timeseries data through WriteDataList.

        The frame must contain an ``'_index'`` column. Every column is sent as
        one tag, split into chunks of 10000 rows.

        Args:
            df: Data to write.
            scenario: Defaults to the client's own scenario.
            clear: When true, existing timeseries data is cleared first.
        """
        if scenario is None:
            scenario = self._scenario

        if clear:
            self.clear_timeseries_data(scenario=scenario)

        columns = list(df)
        len_ix = len(df['_index'])
        chunk_rows = 10000
        data = [
            [{'tag': c, 'values': df[c].values[i:i + chunk_rows]} for c in columns]
            for i in range(0, len_ix, chunk_rows)
        ]

        def data_writer():
            size = 0
            b_size = len_ix  # size estimate added per block (total row count)

            blocks = []
            for r in data:
                for d in r:
                    size += b_size
                    blocks.append(spm_pb2.DataBlock(**d))
                    if size>2e5:
                        yield spm_pb2.DataList(scenario=scenario, data=blocks, clear=False, block_complete=False, row_complete=False)
                        size = 0
                        blocks = []
                if size>0:
                    yield spm_pb2.DataList(scenario=scenario, data=blocks, clear=False, block_complete=True,
                                           row_complete=False)
                    # Reset after sending; otherwise the blocks of this chunk
                    # are sent again with every following chunk.
                    size = 0
                    blocks = []

        self._base_client.WriteDataList(data_writer())
        #self._base_client.CloseData(spm_pb2.DataCompleted(scenario=scenario))

    def _rows_to_blocks(self, rows):
        blocks = []
        data = {}
        for r in rows:
            for k, v in r.items():
                if not k in data:
                    data[k] = []
                data[k].append(v)

        for k, v in data.items():

            blocks.append(spm_pb2.DataBlock(tag=k, values=v))


        return blocks

    def write_with_row_generator(self, data_generator,  scenario:str=None, clear = False, must_flush=None):
        """Stream rows from ``data_generator`` to the server via WriteDataList.

        Rows are buffered and sent as a ``DataList`` whenever
        ``must_flush(n_rows, row_width)`` returns true. When the generator is
        exhausted the remaining rows (possibly none) are sent in a final batch
        flagged ``block_complete``.

        Args:
            data_generator: Iterable of row dicts. A ``None`` item adds no row
                but still triggers a flush check on the buffered rows.
            scenario: Defaults to the client's own scenario.
            clear: Sent as the ``clear`` flag of the first batch only.
            must_flush: Callable ``(n_rows, row_width) -> bool``; defaults to
                flushing after every row.
        """
        if must_flush is None:
            # Takes both arguments it is called with; the previous default
            # accepted only one and raised TypeError on first use.
            def must_flush(n_rows, rows_size):
                return True

        if scenario is None:
            scenario = self._scenario

        def gen(clear):
            rows = []
            for row in data_generator:
                if row is not None:
                    rows.append(row)
                # Measure the newest buffered row so a None item does not
                # break the width computation.
                if rows and must_flush(len(rows), len(rows[-1])):
                    blocks = self._rows_to_blocks(rows)
                    rows = []

                    yield spm_pb2.DataList(scenario=scenario, data=blocks, clear=clear, block_complete=False, row_complete=True)
                    clear = False

            blocks = self._rows_to_blocks(rows)
            yield spm_pb2.DataList(scenario=scenario, data=blocks, clear=clear, block_complete=True, row_complete=True)

        self._base_client.WriteDataList(gen(clear))

    def new_writer(self, buffer_size=100, clear=False):
        """Create a buffered writer for the client's scenario.

        The caller is responsible for closing the returned writer. It is also
        registered in ``self.writers`` (as ``open_writer`` does) so that
        ``close()`` on the client closes it if the caller has not.

        Args:
            buffer_size: Passed to ``NumerousBufferedWriter``.
            clear: When true, existing timeseries data is cleared first.
        """
        if clear:
            self.clear_timeseries_data(self._scenario)

        writer_ = NumerousBufferedWriter(self.write_with_row_generator, self.writer_closed, self._scenario, buffer_size)
        self.writers.append(writer_)
        return writer_

    def writer_closed(self, scenario):
        """Send CloseData for ``scenario`` with ``eta=-1`` and ``finalized=True``."""
        log.debug('Writer Closed')
        completion = spm_pb2.DataCompleted(scenario=scenario, eta=-1, finalized=True)
        self._base_client.CloseData(completion)

    def read_configuration(self, name):
        """Return the ReadConfiguration response for the configuration called ``name``."""
        request = spm_pb2.ConfigurationRequest(name=name)
        return self._base_client.ReadConfiguration(request)

    def store_configuration(self, name, description, datetime_, user, configuration, comment, tags):
        """Send a configuration through StoreConfiguration and return the response."""
        request = spm_pb2.Configuration(
            name=name,
            description=description,
            datetime=datetime_,
            user=user,
            configuration=configuration,
            comment=comment,
            tags=tags,
        )
        return self._base_client.StoreConfiguration(request)

    @contextmanager
    def open_writer(self, aliases:dict=None, clear=False, buffer_size=24*7) -> NumerousBufferedWriter:
        """Context manager yielding a buffered writer for the client's scenario.

        The writer is registered in ``self.writers`` and closed when the
        ``with`` block exits, including on error. ``aliases`` is accepted but
        not used. When ``clear`` is true, existing timeseries data is cleared
        first.
        """
        if clear:
            self.clear_timeseries_data(self._scenario)

        writer_ = NumerousBufferedWriter(
            write_method=self.write_with_row_generator,
            writer_closed_method=self.writer_closed,
            scenario=self._scenario,
            buffer_size=buffer_size,
        )
        self.writers.append(writer_)

        try:
            yield writer_
        finally:
            writer_.close()

    def get_inputs(self, scenario_data, t0=0, dt=3600, tag_prefix='Stub', tag_seperator='.', timeout=10, subscribe=False):
        """Build an ``InputManager`` for the input variables of a scenario.

        Every enabled main component or subcomponent in
        ``scenario_data['simComponents']`` contributes its ``inputVariables``.
        Static variables go into a single ``StaticSource``; variables of type
        ``'scenario'``, ``'dataset'`` or ``'csv'`` are grouped by
        ``dataSourceID`` into one ``DynamicSource`` each.

        Components referenced from another component's ``subcomponents`` list
        get ``isSubcomponent = True`` set on them in ``scenario_data``.

        Args:
            scenario_data: Scenario document.
            t0, dt: Start index and step passed to the sources.
            tag_prefix, tag_seperator: Tags are built as
                prefix / component name / variable id joined by the separator,
                skipping empty parts.
            timeout, subscribe: Passed to each ``DynamicSource``.

        Raises:
            ValueError: An input variable has an unknown ``dataSourceType``.
        """
        input_sources = {}
        static_data = {}
        input_source_types = {}
        only_static = True

        components = scenario_data['simComponents']

        # Mark every component that is referenced as someone's subcomponent.
        subcomponent_uuids = {
            sub['uuid']
            for c in components if "subcomponents" in c
            for sub in c["subcomponents"]
        }
        if subcomponent_uuids:
            for sc_ in components:
                if sc_['uuid'] in subcomponent_uuids:
                    sc_['isSubcomponent'] = True

        # Get tags to update and sources
        for sc in components:
            # .get: components never marked above may lack 'isSubcomponent'.
            if sc['disabled'] or not (sc['isMainComponent'] or sc.get('isSubcomponent', False)):
                continue

            for iv in sc['inputVariables']:
                tag_ = tag_seperator.join(filter(None, [tag_prefix, sc['name'], iv['id']]))
                source_type = iv['dataSourceType']

                if source_type == 'static':
                    static_data[tag_] = iv['value']

                elif source_type in ('scenario', 'dataset', 'csv'):
                    only_static = False
                    source_id = iv['dataSourceID']
                    input_sources.setdefault(source_id, {})[tag_] = {
                        'tag': iv['tagSource']['tag'], 'scale': iv['scaling'], 'offset': iv['offset']
                    }
                    # 'dataset' sources are handled as 'csv'.
                    input_source_types[source_id] = 'scenario' if source_type == 'scenario' else 'csv'

                else:
                    raise ValueError('Unknown data source type: ' + iv['dataSourceType'])

        # First add the static source
        log.info('static only: '+ str(only_static))
        inputs = [StaticSource(static_data, t0=t0, dt=dt, generate_index=only_static)]
        log.debug('input source types: ' + str(input_source_types))
        # Then add dynamic sources
        for input_source, spec in input_sources.items():
            inputs.append(DynamicSource(client=self, scenario=input_source, spec=spec, t0=t0, timeout=timeout, source_type=input_source_types[input_source], subscribe=subscribe))

        input_manager = InputManager(inputs)
        return input_manager

    def iter_inputs(self, scenario_data, t0=0, dt=3600, tag_prefix='Stub', tag_seperator='.', timeout=10, subscribe=False):

        while True:
            for r in self.get_inputs(scenario_data, t0=t0, dt=dt, tag_prefix=tag_prefix, tag_seperator=tag_seperator, timeout=timeout, subscribe=subscribe):
                yield r
            return

    def publish_message(self, message):
        """Publish ``message`` (JSON-encoded) on this scenario's EVENT channel."""
        channel = ".".join(['EVENT', self._project, self._scenario])
        payload = json.dumps(message)

        outgoing = spm_pb2.SubscriptionMessage(channel=channel, message=payload)
        self._base_client.PublishSubcsriptionMessage(outgoing)
    #def subscribe_channels(self, channels):
    #    self._base_client.SubscribeForUpdates(spm_pb2.)
    def subscribe_messages(self, channel_patterns):
        """Yield ``(channel, decoded_message)`` for each update matching ``channel_patterns``.

        Message payloads are decoded from JSON.
        """
        subscription = spm_pb2.Subscription(channel_patterns=channel_patterns)
        for update in self._base_client.SubscribeForUpdates(subscription):
            yield update.channel, json.loads(update.message)

    def refresh_token(self, refresh_token):
        """Call GetAccessToken with ``refresh_token`` wrapped in a ``Token`` and return the response."""
        request = spm_pb2.Token(val=refresh_token)
        return self._token_manager.GetAccessToken(request)

    def close(self):
        #close all writers
        for w in self.writers:
            w.close()
        self._access_token_refresher.stop()

        if hasattr(self,'log_listener'):
            self.log_listener.stop()
        if self._complete:
            self.complete_execution()


@contextmanager
def open_client(job_id=None, project=None, scenario=None, clear_data=None, no_log=False, instance_id=None, refresh_token=None) -> NumerousClient:
    """Context manager that creates a ``NumerousClient`` and always closes it.

    An ``Exception`` raised inside the ``with`` block is reported through
    ``push_scenario_error`` (with the formatted traceback) and is not
    re-raised. The client is closed in every case.
    """
    numerous_client_ = NumerousClient(
        job_id=job_id,
        project=project,
        scenario=scenario,
        clear_data=clear_data,
        no_log=no_log,
        instance_id=instance_id,
        refresh_token=refresh_token,
    )

    try:
        yield numerous_client_
    except Exception:
        failure_report = traceback.format_exc()
        numerous_client_.push_scenario_error(error=failure_report)
    finally:
        numerous_client_.close()

class StaticSource:
    """Iterator that returns the same row of static values on every step.

    Args:
        row_data: Dict returned by every ``__next__`` call. It is the same
            object each time and is updated in place.
        t0: First ``'_index'`` value.
        t_end: Optional limit; iteration stops once the advanced time exceeds
            it. None means no limit.
        dt: Amount the index advances per step.
        generate_index: When true, ``'_index'`` is written into the row and
            advanced on every step; when false the row is returned untouched.
    """

    def __init__(self, row_data, t0=0, t_end=None, dt=3600, generate_index=True):
        self.row = row_data
        self.source_type = 'static'
        self.dt = dt
        self.t = t0
        # Keep the caller's limit; it was previously discarded (always None),
        # which made the StopIteration branch below unreachable.
        self.t_end = t_end

        self.generate_index = generate_index

    def __iter__(self):
        return self

    def __next__(self):
        if self.generate_index:
            self.row['_index'] = self.t
            self.t += self.dt

            if self.t_end is not None and self.t>self.t_end:
                raise StopIteration('Reached t end of '+str(self.t_end))

        return self.row

class DynamicSource:
    """Input source that pulls rows from ``client.read_data_as_row`` on a background thread.

    Raw rows are buffered in a bounded queue (100 entries) by a daemon thread;
    ``__next__`` takes one raw row from the queue and maps it to the tags in
    ``spec`` (``raw[s['tag']] * s['scale'] + s['offset']`` for every entry).

    ``spec`` is a dict whose values are dicts with at least the keys
    ``'tag'``, ``'scale'`` and ``'offset'``, as read by this class.
    """

    def __init__(self, client, scenario, spec, t0=0, timeout=None, source_type=None, subscribe=False):
        """Start reading rows for *scenario* and launch the producer thread.

        :param client: object providing ``read_data_as_row``.
        :param scenario: forwarded to ``read_data_as_row``.
        :param spec: mapping of output tag -> dict with ``'tag'``, ``'scale'``, ``'offset'``.
        :param t0: start value forwarded as ``start`` (divided by 3600 for csv sources).
        :param timeout: seconds ``__next__`` waits for a row; ``None`` waits indefinitely.
        :param source_type: ``'csv'`` switches the source to continuous mode.
        :param subscribe: forwarded to ``read_data_as_row``; forced to ``False`` for csv sources.
        """
        self.timeout = timeout

        if source_type == 'csv':
            # NOTE(review): presumably converts a start given in seconds to the
            # hour-based index used by CSV data - confirm against the reader.
            t0/=3600
            # Continuous mode: when the reader is exhausted it is re-queried
            # instead of ending the source (see _next below).
            continuous = True
            subscribe = False
            log.debug('Continous mode set on input source')
        else:
            continuous = False

        self.generator = client.read_data_as_row(tags=[s['tag'] for s in spec.values()], scenario=scenario, start=t0, subscribe=subscribe)
        self.spec = spec
        self.source_type = source_type
        # Bounded buffer between the producer thread and __next__.
        self.row_queue = Queue(100)

        def _next():
            """Producer loop: move rows from the reader into ``row_queue``.

            Only ``StopIteration`` is handled here; any other exception raised
            by the reader ends this thread without signalling the consumer.
            """
            # Last '_index' seen; passed as ``offset`` when re-querying.
            offset = 0
            while True:
                if not self.row_queue.full():
                    first = True
                    # Non-continuous: exactly one row per pass of the outer loop.
                    # Continuous: this inner loop never exits normally; put()
                    # blocks while the queue is full.
                    while continuous or first:
                        first = False
                        try:
                            row_ = self.generator.__next__()
                            #print(row_['_index'])
                            offset = row_['_index']
                            self.row_queue.put(row_)

                        except StopIteration:
                            if not continuous:
                                # Sentinel telling __next__ that the source is done;
                                # the producer thread ends here.
                                self.row_queue.put('STOP')

                                return
                            else:
                                # Reader exhausted in continuous mode: start a new
                                # read with start=0 and the last seen index as offset.
                                log.debug('Requerying data')
                                self.generator = client.read_data_as_row(tags=[s['tag'] for s in self.spec.values()],
                                                                         scenario=scenario, start=0,
                                                                         subscribe=subscribe, offset=offset)

                else:
                    # Queue is full: back off briefly before checking again.
                    sleep(.01)



        # Daemon thread, so it does not keep the interpreter alive on exit.
        threading.Thread(target=_next, daemon=True).start()

    def __next__(self):
        """Return the next row mapped to the spec's tags.

        :return: dict with one entry per key of ``spec`` plus ``'_index'``
            copied from the raw row.
        :raises TimeoutError: no row arrived within ``self.timeout`` seconds.
        :raises StopIteration: the producer signalled the end of the data.
        """
        try:

            row_ = self.row_queue.get(timeout=self.timeout)
        except _queue.Empty:
            log.debug('timeout')
            raise TimeoutError('Timed out!')

        # The producer enqueues the 'STOP' sentinel once and then exits, so
        # further calls after this point wait on an empty queue.
        if row_ == 'STOP':
            log.debug('Iter done')
            raise StopIteration('Iterator consumed')

        row = {}
        for tag, s in self.spec.items():
            # Linear transform of the raw value: raw * scale + offset.
            row[tag] = row_[s['tag']]*s['scale'] + s['offset']
        row['_index'] = row_['_index']

        return row

class InputManager:
    """Combine one row from each input source into a single row per step.

    :param sources: iterable of source objects. Each must provide
        ``__next__()`` returning a dict and a ``source_type`` attribute
        (read only when the returned row contains ``'_index'``).
    """

    def __init__(self, sources):
        self.sources = sources

    def get_data_from_sources(self):
        """Advance every source by one step and merge the rows.

        Rows are merged in source order, so later sources overwrite equal keys
        of earlier ones. If several sources provide ``'_index'``, the values
        must agree once expressed in the same unit.

        :return: the merged row.
        :raises IndexError: two sources report different ``'_index'`` values.
        :raises ValueError: a source returned something that is not a dict.
        :raises StopIteration: propagated from an exhausted source.
        """
        row = {}
        for source in self.sources:
            part_row = source.__next__()
            if not isinstance(part_row, dict):
                raise ValueError(
                    f'Input source returned {type(part_row).__name__}, expected a dict row!'
                )

            if '_index' not in part_row:
                row.update(part_row)
                continue

            index = part_row['_index']
            # TODO REMOVE HOT FIX BECAUSE CSV DATA IS USING HOURS NOT SECONDS!!!
            # The conversion is done BEFORE the alignment check: the merged
            # row already holds converted values, so comparing it against a
            # raw csv index reported a misalignment for aligned sources.
            if source.source_type == 'csv':
                index = index * 3600

            if '_index' in row and row['_index'] != index:
                raise IndexError('Input sources indices misaligned or not equal!')

            row.update(part_row)
            # Set the converted index on the merged row only; the dict handed
            # out by the source is left unmodified.
            row['_index'] = index
        return row

    def __iter__(self):
        """Yield merged rows until any source raises ``StopIteration``."""
        while True:
            try:
                row = self.get_data_from_sources()
            except StopIteration:
                return

            yield row


class RepeatedFunction:
    """Call a function repeatedly on a fixed interval using ``threading.Timer``.

    Call times are derived from a running ``next_call`` timestamp (initialised
    at construction and advanced by ``interval`` on every ``start()``) rather
    than from the moment the previous call finished.

    :param interval: seconds between calls.
    :param function: callable to invoke.
    :param run_initially: when true, ``function`` is called once, synchronously,
        inside the constructor.
    :param args: positional arguments passed to ``function``.
    :param kwargs: keyword arguments passed to ``function``.
    """

    def __init__(self, interval, function, run_initially=False, *args, **kwargs):
        self._timer = None
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.is_running = False
        self.next_call = time()

        if run_initially:
            self.function(*self.args, **self.kwargs)

    def _run(self):
        """Timer callback: schedule the next call, then invoke the function."""
        self.is_running = False
        # Re-arm before calling so that an exception raised by the function
        # does not stop the repetition.
        self.start()
        self.function(*self.args, **self.kwargs)

    def start(self):
        """Schedule the next call; does nothing if a call is already scheduled."""
        if not self.is_running:
            self.next_call += self.interval
            self._timer = threading.Timer(self.next_call - time(), self._run)
            self._timer.start()
            self.is_running = True

    def stop(self):
        """Cancel the pending call, if any.

        Safe to call before ``start()``: previously this raised
        ``AttributeError`` because no timer had been created yet.
        """
        if self._timer is not None:
            self._timer.cancel()
        self.is_running = False


class NumerousAdminClient(NumerousClient):
    """Client variant exposing job-management and execution-administration calls.

    The constructor does not call ``NumerousClient.__init__``; it sets up the
    channel, stubs and the access-token refresher itself.
    """

    def __init__(self, server=None, port=None, secure=None, refresh_token=None, instance_id=None):
        """Open the channel, create the stubs and start the token refresher.

        :param server: forwarded to ``_init_channel``.
        :param port: forwarded to ``_init_channel``.
        :param secure: channel security setting; falls back to env ``SECURE_CHANNEL``.
        :param refresh_token: falls back to env ``NUMEROUS_API_REFRESH_TOKEN``.
        :param instance_id: a new uuid4 string is generated when omitted.
        """
        self.writers = []
        self._user = 'test_user'
        self._organization = 'EnergyMachines'
        self._execution_id = get_env(None, 'NUMEROUS_EXECUTION_ID', 'not_found')

        # self._token = get_env(token, 'NUMEROUS_API_TOKEN')
        self._access_token = None
        self._refresh_token = get_env(refresh_token, 'NUMEROUS_API_REFRESH_TOKEN')
        self._instance_id = instance_id if instance_id is not None else str(uuid4())
        secure = get_env(secure, 'SECURE_CHANNEL')

        channel = self._init_channel(server=server, port=port, secure=secure)
        stub = spm_pb2_grpc.SPMStub(channel)
        self._token_manager = spm_pb2_grpc.TokenManagerStub(channel)

        # Refresh token every 9 minutes (?)
        # The refresher receives the execution id as read above, i.e. before
        # a replacement id is generated further down.
        refresh_kwargs = {
            'refresh_token': self._refresh_token,
            'instance_id': self._instance_id,
            'execution_id': self._execution_id,
        }
        self._access_token_refresher = RepeatedFunction(
            interval=9 * 60,
            function=self._refresh_access_token,
            run_initially=True,
            **refresh_kwargs
        )
        self._access_token_refresher.start()

        # TODO: Maybe reconsider
        self._job_manager = spm_pb2_grpc.JobManagerStub(channel)

        self._base_client = stub
        self._terminated = False
        self._complete = False

        self._execution_id = get_env(None, 'NUMEROUS_EXECUTION_ID', 'not_found')
        if self._execution_id in ('not_found', None):
            self._execution_id = str(uuid4())
            log.debug(f'No execution id found - generating new: {self._execution_id}')

        log.info('init admin client')

    # Disabled code, kept for reference:
    #
    # # TODO: Remove? The client should not be allowed to stop/start instances of clients!
    # def start_image(self, scenario, image, name, namespace, spm_job_id, google_project='simulationwidgets'):
    #     deployed_image = self._job_manager.StartImage(spm_pb2.ImageInformation(
    #         project_id=self._project, scenario_id=scenario,
    #         google_project=google_project, image=image,
    #         deployment_name=name, namespace=namespace,
    #         spm_job_id=spm_job_id
    #     ))
    #
    #     return {'kubernetes_id': deployed_image.kubernetes_id, 'namespace': deployed_image.namespace}

    def _job_message(self, project, scenario, job):
        """Build the ``spm_pb2.Job`` request shared by the job-manager calls."""
        return spm_pb2.Job(
            project_id=project,
            scenario_id=scenario,
            job_id=job,
            user_id=self._user,
            organization_id=self._organization,
        )

    def launch_job(self, project, scenario, job):
        """Send ``StartJob`` for the given project/scenario/job."""
        self._job_manager.StartJob(self._job_message(project, scenario, job))

    def terminate_job(self, project, scenario, job):
        """Send ``TerminateJob`` for the given project/scenario/job."""
        self._job_manager.TerminateJob(self._job_message(project, scenario, job))

    def reset_job(self, project, scenario, job):
        """Send ``ResetJob`` for the given project/scenario/job."""
        self._job_manager.ResetJob(self._job_message(project, scenario, job))

    def listen_executions(self):
        """Yield each document from ``ListenExecutions``, decoded from JSON."""
        stream = self._base_client.ListenExecutions(spm_pb2.Empty())
        for document in stream:
            yield json.loads(document.json)

    def update_execution(self, update_):
        """Send *update_* (JSON-encoded) through ``UpdateExecution``."""
        payload = spm_pb2.Json(json=json.dumps(update_))
        self._base_client.UpdateExecution(payload)

    def get_execution_status(self, exe_id):
        """Return the ``json`` field of the ``GetExecutionStatus`` response for *exe_id*."""
        request = spm_pb2.ExecutionId(execution_id=exe_id)
        response = self._job_manager.GetExecutionStatus(request)
        return response.json

    def update_job_by_backend(self, project, scenario, job, exe, message, status, log=None, complete=False):
        """Set a job's progress and optionally push a log entry and complete it.

        :param project: project id.
        :param scenario: scenario id.
        :param job: job id.
        :param exe: execution id, used only when *complete* is true.
        :param message: progress message.
        :param status: progress status.
        :param log: optional log message; pushed as a scenario log entry when not ``None``.
        :param complete: when true, ``CompleteExecutionIgnoreInstance`` is called.
        :return: a new ``spm_pb2.Empty``.
        """
        progress = spm_pb2.ScenarioProgress(
            project=project, scenario=scenario, spm_job_id=job,
            message=message, status=status, progress=0,
        )
        self._base_client.SetScenarioProgress(progress)

        if log is not None:
            log_entry = spm_pb2.ScenarioLogEntry(
                project=project, scenario=scenario, spm_job_id=job,
                initialize=False, message=log,
            )
            self._base_client.PushScenarioLogEntry(log_entry)

        if complete:
            completion = spm_pb2.ExecutionMessage(
                project_id=project, scenario_id=scenario, job_id=job,
                execution_id=exe,
            )
            self._base_client.CompleteExecutionIgnoreInstance(completion)

        return spm_pb2.Empty()

    def read_execution_logs(self):
        """Yield ``(log_entry, datetime)`` for the log entries of this client's execution.

        The timestamp is converted with ``datetime.fromtimestamp`` (no tz argument).
        """
        request = spm_pb2.ExecutionReadLogs(execution_id=self._execution_id, start=0.0, end=0.0)
        for entry in self._base_client.ReadExecutionLogEntries(request):
            yield entry.log_entry, datetime.fromtimestamp(entry.timestamp)

@contextmanager
def open_admin_client(server=None, port=None, secure=None, refresh_token=None) -> NumerousAdminClient:
    """Context manager that yields a ``NumerousAdminClient`` and always closes it.

    Exceptions raised inside the ``with`` body propagate to the caller after
    ``close()`` has been called.
    """
    admin_client = NumerousAdminClient(
        server=server,
        port=port,
        secure=secure,
        refresh_token=refresh_token,
    )

    try:
        yield admin_client
    finally:
        admin_client.close()