import json
import logging
import logging.handlers
import os
import signal
import threading
import traceback
from contextlib import contextmanager
from datetime import datetime
from queue import Queue
from time import time
from uuid import uuid4

import grpc
import numpy as np
import requests

from numerous_api_client.client.common import JobStatus, get_env, log, RepeatedFunction
from numerous_api_client.client.data_source import StaticSource, DynamicSource, InputManager
from numerous_api_client.client.numerous_buffered_writer import NumerousBufferedWriter
from numerous_api_client.client.numerous_log_handler import NumerousLogHandler
from numerous_api_client.headers.ValidationInterceptor import ValidationInterceptor
from numerous_api_client.python_protos import spm_pb2_grpc, spm_pb2, health_pb2_grpc

try:
    import pandas as pd
except:
    pass


class NumerousClient:
    def __init__(self, job_id=None, project=None, scenario=None, server=None, port=None, refresh_token=None,
                 clear_data=None, no_log=False, instance_id=None, secure=None, url=None):
        """Create a client connected to the numerous API and start its background workers.

        Arguments that are not given are resolved with ``get_env`` from the
        environment (``NUMEROUS_PROJECT``, ``NUMEROUS_SCENARIO``, ``JOB_ID``,
        ``NUMEROUS_API_REFRESH_TOKEN``, ``NUMEROUS_CLEAR_DATA``; the channel
        settings are resolved in ``_init_channel``).

        Args:
            job_id: Job identifier.
            project: Project identifier.
            scenario: Scenario identifier.
            server: API host name.
            port: API port.
            refresh_token: Token periodically exchanged for an access token.
            clear_data: ``'True'`` to clear this execution's time series data on start-up.
            no_log: If true, do not install the queue-based log forwarding.
            instance_id: Client instance id; a random UUID is generated when omitted.
            secure: Whether to use a TLS channel.
            url: Optional ``scheme://host[:port]``. When given it overrides ``server``
                and ``secure`` (``https`` -> secure), and ``port`` if the URL has one.
        """
        if url is not None:
            from urllib.parse import urlparse
            parsed_url = urlparse(url)

            # partition() tolerates URLs without an explicit port (the previous
            # split(':')[1] raised IndexError); then the ``port`` argument / env applies.
            server, _, url_port = parsed_url.netloc.partition(':')
            if url_port:
                port = url_port
            secure = parsed_url.scheme == "https"

        self._t = [0]
        self.job_states = JobStatus
        self._project = get_env(project, 'NUMEROUS_PROJECT')
        self._scenario = get_env(scenario, 'NUMEROUS_SCENARIO')
        self._job_id = get_env(job_id, 'JOB_ID')
        self._refresh_token = get_env(refresh_token, 'NUMEROUS_API_REFRESH_TOKEN')
        self._access_token = None

        self._instance_id = str(uuid4()) if instance_id is None else instance_id
        self._execution_id = get_env(None, 'NUMEROUS_EXECUTION_ID', 'not_found')
        if self._execution_id == 'not_found' or self._execution_id is None:
            self._execution_id = str(uuid4()).replace('-', '_')
            log.info(f'No execution id found - generating new: {self._execution_id}')

        self._complete = True
        self._hibernating = False
        self._error = False

        # Writers opened through open_writer() are tracked here.
        self.writers = []

        self.channel = self._init_channel(server=server, port=port, secure=secure, instance_id=self._instance_id)
        self.stub = spm_pb2_grpc.SPMStub(self.channel)
        self._token_manager = spm_pb2_grpc.TokenManagerStub(self.channel)
        self._health = health_pb2_grpc.HealthStub(self.channel)
        self._base_client = self.stub

        # Fetch an access token immediately, then refresh it every 9 minutes.
        # NOTE(review): interval presumably chosen to stay ahead of token expiry — confirm.
        self._access_token_refresher = RepeatedFunction(
            interval=9 * 60, function=self._refresh_access_token, run_initially=True,
            refresh_token=self._refresh_token, instance_id=self._instance_id, execution_id=self._execution_id,
            project_id=self._project, scenario_id=self._scenario, job_id=self._job_id
        )
        self._access_token_refresher.start()

        self._terminated = False
        if not no_log:
            self._init_logger()

        # Rate limits (in seconds) for set_scenario_progress and set_state.
        self.last_progress = -999999
        self.progress_debounce = 10

        self.last_state_set = -999999
        self.state_debounce = 60

        clear_data = get_env(clear_data, 'NUMEROUS_CLEAR_DATA', default='False') == 'True'

        # If clearing data, the job was not resumed and vice versa
        self._was_resumed = not clear_data

        self._hibernate_callback = self._default_hibernation_callback
        self._hibernate_callback_args = ([], {})

        if clear_data:
            self.clear_timeseries_data()

        # Daemon thread: the command subscription must not keep the process alive.
        # (Thread.setDaemon() is deprecated since Python 3.10.)
        self._listen = threading.Thread(target=self._command_listener, daemon=True)
        self._listen.start()

    def __enter__(self):
        return self

    def __exit__(self, _type, _value, _traceback):
        self.close()

    def complete_execution(self):
        """Tell the server this client's execution has finished.

        The request's ``hibernate`` flag mirrors ``self._hibernating``.
        """
        log.debug(
            f'Completing job {self._job_id} with execution {self._execution_id}. Hibernating: {self._hibernating}')
        execution = spm_pb2.ExecutionMessage(
            project_id=self._project,
            scenario_id=self._scenario,
            job_id=self._job_id,
            execution_id=self._execution_id,
        )
        request = spm_pb2.CompleteExecutionMessage(execution=execution, hibernate=self._hibernating)
        self._base_client.CompleteExecution(request)

    def _init_channel(self, server, port, secure=None, instance_id=None):
        """Build the intercepted gRPC channel used by all stubs.

        ``server``, ``port`` and ``secure`` fall back to ``NUMEROUS_API_SERVER``,
        ``NUMEROUS_API_PORT`` and ``SECURE_CHANNEL`` via ``get_env``. TLS is used
        when the resolved secure flag stringifies to ``'True'``. All calls go
        through a ``ValidationInterceptor`` that reads the current access token
        via ``_get_current_token``.

        Side effect: sets ``self._instance`` from the interceptor.
        """
        secure_channel = get_env(secure, 'SECURE_CHANNEL')
        server = get_env(server, 'NUMEROUS_API_SERVER')
        port = get_env(port, 'NUMEROUS_API_PORT')

        log.info(f"Client connecting to: {server}:{port}, using secure channel: {secure_channel}")
        target = f'{server}:{port}'
        if str(secure_channel) != 'True':
            raw_channel = grpc.insecure_channel(target)
        else:
            raw_channel = grpc.secure_channel(target, grpc.ssl_channel_credentials())

        interceptor = ValidationInterceptor(token=self._access_token, token_callback=self._get_current_token,
                                            instance=instance_id)
        self._instance = interceptor.instance
        return grpc.intercept_channel(raw_channel, interceptor)

    def _get_current_token(self):
        return self._access_token

    def _init_logger(self):
        """Forward root-logger records to a ``NumerousLogHandler`` through a queue.

        A ``QueueHandler`` is attached to the root logger, and a ``QueueListener``
        drains the queue into ``self.log_handler``. Note: this also sets the root
        logger's level to ERROR.
        """
        record_queue = Queue(-1)  # unbounded
        root = logging.getLogger()
        root.setLevel(logging.ERROR)
        root.addHandler(logging.handlers.QueueHandler(queue=record_queue))
        self.log_handler = NumerousLogHandler(self)
        self.log_listener = logging.handlers.QueueListener(record_queue, self.log_handler)
        self.log_listener.start()

    def _refresh_access_token(self, refresh_token, instance_id, execution_id, project_id=None, scenario_id=None,
                              job_id=None):
        """Exchange ``refresh_token`` for a new access token and store it in ``self._access_token``.

        Invoked periodically by the ``RepeatedFunction`` created in ``__init__``.
        """
        request = spm_pb2.RefreshRequest(
            refresh_token=spm_pb2.Token(val=refresh_token),
            instance_id=instance_id,
            execution_id=execution_id,
            project_id=project_id,
            scenario_id=scenario_id,
            job_id=job_id,
        )
        response = self._token_manager.GetAccessToken(request)
        self._access_token = response.val

    def _command_listener(self):
        """Loop over messages on this job's ``COMMAND.<project>.<scenario>.<job>`` channel.

        ``terminate`` raises SIGTERM in this process; ``hibernate`` calls
        ``self.hibernate``. Either command ends the loop. Messages without a
        ``command`` key, or with another command, are ignored.
        """
        topic = ".".join(['COMMAND', self._project, self._scenario, self._job_id])
        for channel, message in self.subscribe_messages([topic]):
            log.debug('Received message: ' + str(message) + '\n on channel: ' + str(channel))
            command = message['command'] if 'command' in message else None
            if command == 'terminate':
                log.warning('Received Termination Command')
                signal.raise_signal(signal.SIGTERM)
                break
            elif command == 'hibernate':
                log.warning("Received Hibernate Command")
                self.hibernate(message='Hibernating')
                break

    def get_scenario_document(self, scenario=None, project=None):
        """Fetch a scenario document.

        Args:
            scenario: Scenario id; defaults to this client's scenario.
            project: Project id; defaults to this client's project.

        Returns:
            ``(document, files)``: the JSON-decoded scenario document and the
            response's ``files`` field.
        """
        scenario = self._scenario if scenario is None else scenario
        project = self._project if project is None else project
        response = self._base_client.GetScenario(spm_pb2.Scenario(project=project, scenario=scenario))
        return json.loads(response.scenario_document), response.files

    def get_job(self, scenario_doc=None):
        if scenario_doc is None:
            scenario_doc, files = self.get_scenario_document()

        return scenario_doc['jobs'][self._job_id]

    def get_group_document(self, group):
        """Fetch and JSON-decode the document of ``group`` in this client's project."""
        request = spm_pb2.Group(project=self._project, group=group)
        response = self._base_client.GetGroup(request)
        return json.loads(response.group_document)

    def get_project_document(self):
        """Fetch and JSON-decode this client's project document."""
        response = self._base_client.GetProject(spm_pb2.Project(project=self._project))
        document = json.loads(response.project_document)
        return document

    def listen_scenario_document(self, scenario=None, project=None):
        """Stream scenario document updates.

        Yields ``(document, files)`` for every message of the ``ListenScenario``
        stream: the JSON-decoded document and the response's ``files``.
        ``scenario``/``project`` default to this client's ids.
        """
        scenario = self._scenario if scenario is None else scenario
        project = self._project if project is None else project
        stream = self._base_client.ListenScenario(spm_pb2.Scenario(project=project, scenario=scenario))
        yield from ((json.loads(update.scenario_document), update.files) for update in stream)

    def get_scenario(self, scenario=None, project=None, path='.'):
        """Download a scenario, its model, and every attached file.

        Args:
            scenario: Scenario id; defaults to this client's scenario.
            project: Project id; defaults to this client's project.
            path: Local directory the files are written to (created if missing).

        Returns:
            ``(scenario_data, model_data, file_paths_local)``: the decoded scenario
            and model documents, and a dict mapping each file name to the local
            path it was saved to.

        Raises:
            requests.HTTPError: if downloading any file returns an error status.
        """
        os.makedirs(path, exist_ok=True)

        log.debug('Get scenario')
        if scenario is None:
            scenario = self._scenario
        if project is None:
            project = self._project

        scenario_data, scenario_files = self.get_scenario_document(scenario, project)
        model_data, model_files = self.get_model(scenario_data['systemID'], project_id=project, scenario_id=scenario)

        file_paths_local = {}
        for f in list(scenario_files) + list(model_files):
            f_path = path + '/' + f.name
            file_paths_local[f.name] = f_path
            response = requests.get(f.url, allow_redirects=True)
            # Fail loudly rather than saving an HTTP error page as the file contents.
            response.raise_for_status()
            # Context manager ensures the handle is closed (it previously leaked).
            with open(f_path, 'wb') as fh:
                fh.write(response.content)

        return scenario_data, model_data, file_paths_local

    def get_model(self, model_id, project_id=None, scenario_id=None):
        """Fetch a model definition.

        Args:
            model_id: Model id.
            project_id: Project id; defaults to this client's project.
            scenario_id: Scenario id, sent as given (no default is substituted).

        Returns:
            ``(model, files)``: the JSON-decoded model and the response's ``files``.
        """
        request = spm_pb2.Model(
            model_id=model_id,
            project_id=self._project if project_id is None else project_id,
            scenario_id=scenario_id,
        )
        response = self._base_client.GetModel(request)
        return json.loads(response.model), response.files

    def set_scenario_job_image(self, name, path, project=None, scenario=None, job=None):
        """Register an image (``name`` / ``path``) for a job; ids default to this client's."""
        request = spm_pb2.ScenarioJobImage(
            project_id=self._project if project is None else project,
            scenario_id=self._scenario if scenario is None else scenario,
            job_id=self._job_id if job is None else job,
            name=name,
            path=path,
        )
        self._base_client.SetScenarioJobImage(request)

    def set_scenario_progress(self, message, status, progress=None, clean=False, scenario=None, force=False,
                              project=None, job=None):
        """Report job progress, rate-limited by ``self.progress_debounce`` seconds.

        A ``status`` of ``'failed'`` marks the client as errored even when the
        update itself is skipped. The update is sent only if more than
        ``progress_debounce`` seconds passed since the last sent update, or if
        ``force`` is true. Ids default to this client's.
        """
        if status == 'failed':
            self._error = True

        scenario = self._scenario if scenario is None else scenario
        project = self._project if project is None else project
        job = self._job_id if job is None else job

        now = time()
        due = now > (self.last_progress + self.progress_debounce)
        if not (due or force):
            return

        self.last_progress = now
        self._base_client.SetScenarioProgress(spm_pb2.ScenarioProgress(
            project=project, scenario=scenario, job_id=job,
            message=message, status=status, clean=clean, progress=progress
        ))

    def clear_scenario_results(self, scenario=None):
        """Clear the stored results of ``scenario`` (default: this client's scenario)."""
        target = scenario if scenario is not None else self._scenario
        self._base_client.ClearScenarioResults(spm_pb2.Scenario(project=self._project, scenario=target))

    def set_scenario_results(self, names: list, values: list, units: list, scenario=None):
        """Store parallel lists of result ``names``, ``values`` and ``units`` for a scenario."""
        target = scenario if scenario is not None else self._scenario
        payload = spm_pb2.ScenarioResults(project=self._project, scenario=target,
                                          names=names, values=values, units=units)
        self._base_client.SetScenarioResults(payload)

    def get_scenario_results(self, scenario=None):
        """Return ``(names, values, units)`` of a scenario's stored results."""
        target = scenario if scenario is not None else self._scenario
        response = self._base_client.GetScenarioResults(spm_pb2.Scenario(project=self._project, scenario=target))
        return response.names, response.values, response.units

    def get_scenario_results_document(self, scenario=None, execution=None):
        """Fetch and JSON-decode a scenario's results document.

        ``scenario`` defaults to this client's scenario. ``execution`` is sent as
        given; it is not defaulted to the current execution.
        NOTE(review): the setter does default it — confirm the asymmetry is intended.
        """
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(project=self._project, scenario=target, execution=execution)
        response = self._base_client.GetScenarioResultDocument(request)
        return json.loads(response.result)

    def set_scenario_results_document(self, results_dict, scenario=None, execution=None):
        """Store ``results_dict``, serialized to JSON, as an execution's results document.

        ``scenario`` and ``execution`` default to this client's ids.
        """
        target = self._scenario if scenario is None else scenario
        run = self._execution_id if execution is None else execution
        document = spm_pb2.ScenarioResultsDocument(
            project=self._project,
            scenario=target,
            execution=run,
            result=json.dumps(results_dict),
        )
        self._base_client.SetScenarioResultDocument(document)


    def push_scenario_error(self, error, scenario=None):
        """Report ``error`` for this client's job (scenario defaults to this client's)."""
        target = self._scenario if scenario is None else scenario
        report = spm_pb2.ScenarioError(project=self._project, scenario=target, error=error, spm_job_id=self._job_id)
        self._base_client.PushScenarioError(report)

    def push_scenario_logs(self, logs):
        """Send log entries for the current execution.

        ``logs`` is iterated twice. Each item is indexable: item ``[0]`` must
        provide ``.timestamp()`` (presumably a ``datetime`` — confirm with callers)
        and item ``[1]`` is sent as the log entry.
        """
        timestamps = [entry[0].timestamp() for entry in logs]
        messages = [entry[1] for entry in logs]
        request = spm_pb2.LogEntries(
            scenario=self._scenario,
            project_id=self._project,
            execution_id=self._execution_id,
            log_entries=messages,
            timestamps=timestamps,
        )
        self._base_client.PushExecutionLogEntries(request)

    def get_download_files(self, files, local_path='./tmp', scenario=None, project_id=None):
        """Download scenario files through server-issued signed URLs.

        Args:
            files: File paths relative to the scenario; each is requested as
                ``'<scenario>/<file>'``.
            local_path: Directory to save into (created if missing).
            scenario: Scenario id; defaults to this client's scenario.
            project_id: Project id; defaults to this client's project.

        Returns:
            dict mapping each returned file name to its local path.

        Raises:
            requests.HTTPError: if any download returns an error status.
        """
        os.makedirs(local_path, exist_ok=True)
        if scenario is None:
            scenario = self._scenario
        if project_id is None:
            project_id = self._project

        # Separate name so the ``files`` argument is not shadowed by the response.
        signed = self._base_client.GetSignedURLs(spm_pb2.FilePaths(
            files=[spm_pb2.FileSignedUrl(path=scenario + '/' + f) for f in files],
            scenario=scenario,
            project_id=project_id
        ))

        files_out = {}
        for f in signed.files:
            response = requests.get(f.url, allow_redirects=True)
            # Do not persist an HTTP error body as if it were the file.
            response.raise_for_status()
            local_file_path = local_path + '/' + f.name
            with open(local_file_path, 'wb') as fh:
                fh.write(response.content)
            files_out[f.name] = local_file_path
        return files_out

    def upload_file(self, local_file, file_id, file=None, scenario=None, content_type="text/html"):
        """Upload a local file to scenario storage via a server-issued signed URL.

        Args:
            local_file: Path of the file to upload.
            file_id: Identifier sent to the server when requesting the upload URL.
            file: Remote file name; defaults to the last ``/``-separated part of ``local_file``.
            scenario: Scenario id; defaults to this client's scenario.
            content_type: MIME type sent to the server and as the ``mimeType`` query param.
        """
        if scenario is None:
            scenario = self._scenario
        if file is None:
            file = local_file.split('/')[-1]
        upload_url = self._base_client.GenerateScenarioUploadSignedURL(
            spm_pb2.ScenarioFilePath(project=self._project, scenario=scenario, path=file, file_id=file_id,
                                     content_type=content_type))

        # The handle is closed once the request has been sent (previously it leaked).
        with open(local_file, 'rb') as fh:
            requests.post(
                upload_url.url,
                headers={"Content-Type": "multipart/related"},
                params={"name": file, "mimeType": content_type},
                data=fh,
            )

    def get_timeseries_stats(self, project=None, scenario=None, execution=None, tag='_index'):
        """Return the raw ``GetScenarioDataStats`` response for one tag (``'_index'`` by default).

        ``project``, ``scenario`` and ``execution`` default to this client's ids.
        """
        project = self._project if project is None else project
        scenario = self._scenario if scenario is None else scenario
        execution = self._execution_id if execution is None else execution
        log.debug('stat tag: %s', tag)
        stats_request = spm_pb2.ScenarioStatsRequest(project=project, scenario=scenario, execution=execution, tag=tag)
        return self._base_client.GetScenarioDataStats(stats_request)

    def get_dataset_closed(self, project=None, scenario=None, execution=None):
        """Return the ``is_closed`` flag of an execution's dataset; ids default to this client's."""
        request = spm_pb2.Scenario(
            project=self._project if project is None else project,
            scenario=self._scenario if scenario is None else scenario,
            execution=self._execution_id if execution is None else execution,
        )
        return self._base_client.GetDataSetClosed(request).is_closed

    def get_timeseries_meta_data(self, project=None, scenario=None, execution=None):
        """Return the raw ``GetScenarioMetaData`` response.

        ``project`` and ``scenario`` default to this client's ids. ``execution`` is
        sent as given and is not defaulted.
        NOTE(review): the setter defaults it to the current execution — confirm intended.
        """
        request = spm_pb2.Scenario(
            project=self._project if project is None else project,
            scenario=self._scenario if scenario is None else scenario,
            execution=execution,
        )
        return self._base_client.GetScenarioMetaData(request)

    def set_timeseries_meta_data(self, tags: list, aliases=None, offset=0, timezone='UTC', epoch_type='s',
                                 scenario=None, execution=None, project=None):
        """Store time series metadata (tag definitions and aliases) for an execution.

        Args:
            tags: Dicts describing tags. ``'name'`` is required. ``displayName``,
                ``unit`` and ``description`` default to ``''``, ``type`` to
                ``'double'``, ``scaling`` to 1 and ``offset`` to 0.
            aliases: Mapping of tag -> aliases (empty when omitted).
            offset, timezone, epoch_type: Stored as given on the metadata.
            scenario, execution, project: Default to this client's ids.

        Returns:
            The ``SetScenarioMetaData`` response.

        Raises:
            KeyError: if a tag dict lacks ``'name'``.
        """
        if aliases is None:
            aliases = {}

        # None marks a required field (no default).
        tag_defaults = {
            'name': None,
            'displayName': '',
            'unit': '',
            'description': '',
            'type': 'double',
            'scaling': 1,
            'offset': 0
        }

        def to_proto(tag):
            fields = {}
            for field, fallback in tag_defaults.items():
                if field in tag:
                    fields[field] = tag[field]
                elif fallback is not None:
                    fields[field] = fallback
                else:
                    raise KeyError(f'No default for key: {field}')
            return spm_pb2.Tag(**fields)

        project = self._project if project is None else project
        scenario = self._scenario if scenario is None else scenario
        execution = self._execution_id if execution is None else execution

        return self._base_client.SetScenarioMetaData(
            spm_pb2.ScenarioMetaData(
                project=project, scenario=scenario, execution=execution,
                tags=[to_proto(tag) for tag in tags],
                aliases=[spm_pb2.Alias(tag=name, aliases=names) for name, names in aliases.items()],
                offset=offset, timezone=timezone, epoch_type=epoch_type,
            )
        )

    def get_timeseries_custom_meta_data(self, scenario=None, project_id=None, execution=None, key=None):
        """Return the custom metadata stored under ``key``; ids default to this client's."""
        request = spm_pb2.ScenarioCustomMetaData(
            scenario=self._scenario if scenario is None else scenario,
            project=self._project if project_id is None else project_id,
            execution=self._execution_id if execution is None else execution,
            key=key,
        )
        return self._base_client.GetScenarioCustomMetaData(request).meta

    def set_timeseries_custom_meta_data(self, meta: str, scenario=None, project_id=None, execution=None, key=None):
        """Store the string ``meta`` under ``key``; ids default to this client's.

        Returns the ``SetScenarioCustomMetaData`` response.
        """
        request = spm_pb2.ScenarioCustomMetaData(
            scenario=self._scenario if scenario is None else scenario,
            project=self._project if project_id is None else project_id,
            execution=self._execution_id if execution is None else execution,
            key=key,
            meta=meta,
        )
        return self._base_client.SetScenarioCustomMetaData(request)

    def get_state(self, scenario=None, execution=None):
        try:
            state_json = self.get_timeseries_custom_meta_data(scenario=scenario, execution=execution, key='state')
            return json.loads(state_json)
        except json.decoder.JSONDecodeError:
            return None

    def set_state(self, state, scenario=None, execution=None, force=False):
        """Save ``state`` as JSON under the ``'state'`` key, at most once per ``state_debounce`` seconds.

        Returns the metadata write's result, or ``None`` when the call was
        rate-limited (``force=True`` bypasses the limit).
        """
        now = time()
        if not (now > (self.last_state_set + self.state_debounce) or force):
            return None
        log.debug('State set.')
        self.last_state_set = now
        return self.set_timeseries_custom_meta_data(meta=json.dumps(state), scenario=scenario, execution=execution,
                                                    key='state')

    def get_latest_main_execution(self, project=None, scenario=None):
        """Return the id of a scenario's latest main execution.

        Args:
            project: Project id; defaults to this client's project (previously an
                unset field was sent, unlike every other method of this client).
            scenario: Scenario id; defaults to this client's scenario.
        """
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario
        return self._base_client.GetLatestMainExecution(
            spm_pb2.Scenario(project=project, scenario=scenario)).execution_id

    def data_read_df(self, tags=None, start: datetime = None, end: datetime = None, project=None, scenario=None,
                     execution=None):
        """Read time series data into a pandas DataFrame.

        Args:
            tags: Tags to read (an empty list is sent when omitted).
            start, end: Optional bounds as ``datetime`` or epoch numbers. A time
                range is requested only if either bound is > 0.
            project, scenario: Default to this client's ids.
            execution: Defaults to the scenario's latest main execution.

        Returns:
            DataFrame with one column per returned tag. Every column is truncated
            to the length of the shortest one so the frame is rectangular.
        """
        if tags is None:
            tags = []
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self.get_latest_main_execution(project, scenario)
            log.debug('Exe: ' + execution)

        # datetime bounds (as annotated) cannot be compared with 0 below;
        # convert them to epoch seconds first.
        if isinstance(start, datetime):
            start = start.timestamp()
        if isinstance(end, datetime):
            end = end.timestamp()
        if start is None:
            start = 0
        if end is None:
            end = 0

        time_range = start > 0 or end > 0

        read_data = self._base_client.ReadData(spm_pb2.ReadScenario(
            project=project, scenario=scenario, execution=execution, tags=tags, start=start, end=end,
            time_range=time_range, listen=False))

        data = {}
        for message in read_data:
            for block in message.data:
                data.setdefault(block.tag, []).extend(block.values)

        # Truncate to the shortest column so DataFrame construction accepts ragged input.
        min_len = min((len(values) for values in data.values()), default=0)
        return pd.DataFrame({tag: values[:min_len] for tag, values in data.items()})

    def read_data_stream(self, tags, start: datetime = None, end: datetime = None, scenario=None, execution=None, project=None):
        """Stream raw time series blocks for ``tags``.

        Args:
            tags: Tags to read.
            start, end: Optional bounds. A ``datetime`` or ``float`` always requests
                a time range; a positive ``int`` does as well.
            scenario, execution, project: Default to this client's ids (an empty
                scenario string also falls back to the default).

        Yields:
            ``(tag, values)`` for every data block in the response stream.
        """
        def normalize(bound):
            # -> (value sent to the server, whether it requests a time range)
            if isinstance(bound, datetime):
                return bound.timestamp(), True
            if isinstance(bound, float):
                return bound, True
            if isinstance(bound, int) and not isinstance(bound, bool):
                # ints used to be ignored entirely, so e.g. start=1600000000 read the
                # whole series; like data_read_df, a positive int is treated as a bound.
                return bound, bound > 0
            return (0 if bound is None else bound), False

        start, start_is_bound = normalize(start)
        end, end_is_bound = normalize(end)
        time_range = start_is_bound or end_is_bound

        if not scenario:
            scenario = self._scenario

        if project is None:
            project = self._project

        if execution is None:
            execution = self._execution_id

        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(project=project, scenario=scenario, execution=execution, tags=tags, start=start, end=end,
                                 time_range=time_range))

        for message in read_data:
            for block in message.data:
                yield block.tag, block.values

    def read_data_as_row(self, tags=None, start: int = 0, end: int = 0, scenario=None, execution=None, subscribe=False,
                         offset=0, project=None, time_range=True):
        """Read time series data and yield it one row at a time.

        Args:
            tags: Tags to read (an empty list is sent when omitted).
            start, end: Range bounds. When ``time_range`` is true and ``end`` is not
                positive, a far-future bound is used so the range is open-ended.
            scenario, execution, project: Default to this client's ids (an empty
                scenario string also falls back to the default).
            subscribe: Sent as ``listen`` to the server.
            offset: Added to each row's ``'_index'``.
            time_range: Whether the server should apply ``start``/``end``.

        Yields:
            dict mapping tag -> value for one row (including ``'_index'``).

        Raises:
            ValueError: if a completed block has tags whose length differs from ``'_index'``.
            KeyError: if a non-empty completed block has no ``'_index'`` column.
        """
        if tags is None:
            tags = []
        if project is None:
            project = self._project
        if not scenario:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        log.debug('Reading data as row')
        log.debug('Subscribed: ' + str(subscribe) + ', starting from: ' + str(start))

        # Open-ended range: 10,000 years expressed in seconds.
        end = end if end > 0 or not time_range else 3600 * 24 * 365 * 10000

        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(project=project, scenario=scenario, execution=execution, tags=tags, start=start,
                                 end=end, time_range=time_range, listen=subscribe))

        def rows_of(columns, n_rows):
            # Transpose column lists into per-row dicts, shifting '_index' by offset.
            for i in range(n_rows):
                row = {k: v[i] for k, v in columns.items()}
                row['_index'] += offset
                yield row

        current_row = {}
        for b in read_data:
            for d in b.data:
                if d.tag != "":
                    current_row.setdefault(d.tag, []).extend(d.values)

            if not (b.row_complete or b.block_complete):
                continue
            if not current_row:
                # Completion marker without buffered data (previously KeyError on '_index').
                continue

            len_index = len(current_row['_index'])
            wrong_len_tags = [(k, len(v)) for k, v in current_row.items() if len(v) != len_index]
            if wrong_len_tags:
                raise ValueError(
                    f'The following tags have a wrong number of data points returned (_index has {len_index}):\n' +
                    "\n".join(tag + ': ' + str(n) for tag, n in wrong_len_tags))

            yield from rows_of(current_row, len_index)
            current_row = {}

        # Flush data left over when the stream ends without a completion marker.
        if current_row:
            yield from rows_of(current_row, len(current_row['_index']))

        log.debug('Reading as row complete')

    def clear_timeseries_data(self, scenario=None, execution=None):
        """Delete an execution's stored time series data.

        ``scenario`` and ``execution`` default to this client's ids. No project is
        included in the request.
        """
        log.warning('Clearing data!')
        target_scenario = self._scenario if scenario is None else scenario
        target_execution = self._execution_id if execution is None else execution
        self._base_client.ClearData(spm_pb2.Scenario(scenario=target_scenario, execution=target_execution))

    def data_write_df(self, df, scenario=None, execution=None, clear=True):
        """Write every column of a DataFrame as time series data.

        The frame must contain an ``'_index'`` column. Rows are sent in chunks of
        10,000. Within a chunk, the per-column blocks are batched into messages
        holding more than ~2e5 values only when unavoidable, and the final
        message of each chunk is flagged ``block_complete``.

        Args:
            df: DataFrame to write; each column is sent as a tag.
            scenario, execution: Default to this client's ids.
            clear: Clear the execution's existing data first.
        """
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        if clear:
            self.clear_timeseries_data(scenario=scenario, execution=execution)

        rows_per_chunk = 10000
        max_values_per_message = 2e5
        columns = list(df)
        n_rows = len(df['_index'])
        chunks = [[{'tag': c, 'values': df[c].values[i:i + rows_per_chunk]} for c in columns]
                  for i in range(0, n_rows, rows_per_chunk)]

        def data_writer():
            for chunk in chunks:
                # Buffers are per chunk; previously they were never reset after a
                # chunk's final message, so earlier chunks were re-sent.
                blocks = []
                size = 0
                last = len(chunk) - 1
                for j, d in enumerate(chunk):
                    blocks.append(spm_pb2.DataBlock(**d))
                    # Count the values in this block (previously the whole frame length
                    # was added per block).
                    size += len(d['values'])
                    # Keep the chunk's last block for the block_complete message below.
                    if size > max_values_per_message and j < last:
                        yield spm_pb2.DataList(scenario=scenario, execution=execution, data=blocks, clear=False,
                                               block_complete=False, row_complete=False)
                        size = 0
                        blocks = []
                if blocks:
                    yield spm_pb2.DataList(scenario=scenario, execution=execution, data=blocks, clear=False,
                                           block_complete=True,
                                           row_complete=False)

        self._base_client.WriteDataList(data_writer())

    def _estimate_block_size(self, data):
        m = 0
        for k, v in data.items():
            m += len(json.dumps(v)) / 1e6
        return m

    def _rows_to_blocks(self, rows):
        if len(rows) == 0:
            return []

        data = {k: [float(r[k]) for r in rows] for k in rows[0]}

        mdata = self._estimate_block_size(data)
        mdatamax = 3.5  # allow some overhead
        estchunks = np.ceil(mdata / mdatamax).astype(int)

        blocks = [list() for i in range(estchunks)]

        # split data into chunks of equal size (except last chunk)

        for k, v in data.items():

            vs = np.array_split(np.array(v), estchunks)

            for i, vsi in enumerate(vs):
                blocks[i].append(spm_pb2.DataBlock(tag=k, values=vsi.tolist()))

        return blocks

    def write_with_row_generator(self, data_generator, scenario: str = None, clear=False, must_flush=None,
                                 execution=None):
        """Stream rows produced by ``data_generator`` to the server as time series data.

        Args:
            data_generator: Iterable of rows (mapping tag -> value). ``None`` items
                are not buffered but still trigger a flush check.
            scenario, execution: Default to this client's ids.
            clear: Sent with the first message(s) to clear existing data.
            must_flush: ``must_flush(n_buffered_rows, row_length) -> bool`` deciding
                when to send buffered rows (``row_length`` is 0 for ``None`` items).
                By default every row is flushed immediately.
        """
        if must_flush is None:
            def must_flush_(n, m):
                return True

            must_flush = must_flush_

        if scenario is None:
            scenario = self._scenario

        if execution is None:
            execution = self._execution_id

        def gen(clear):
            # Opening message carries only the clear flag (no data).
            yield spm_pb2.DataList(scenario=scenario, execution=execution, data=[], clear=clear,
                                   block_complete=False, row_complete=False)
            rows = []
            for row in data_generator:
                if row is not None:
                    rows.append(row)
                # len(None) previously raised TypeError when a None item followed
                # buffered rows.
                if rows and must_flush(len(rows), len(row) if row is not None else 0):
                    blocks = self._rows_to_blocks(rows)
                    rows = []
                    for block in blocks:
                        yield spm_pb2.DataList(scenario=scenario, execution=execution, data=block, clear=clear,
                                               block_complete=False, row_complete=True)
                    clear = False

            # Send the remaining rows, flagged as completing the block.
            for block in self._rows_to_blocks(rows):
                yield spm_pb2.DataList(scenario=scenario, execution=execution, data=block, clear=clear,
                                       block_complete=True, row_complete=True)

        self._base_client.WriteDataList(gen(clear))

    def new_writer(self, buffer_size=100, clear=False):
        """Create a ``NumerousBufferedWriter`` for this client's scenario.

        With ``clear=True`` the scenario's existing data is cleared first. Unlike
        ``open_writer``, the writer is not added to ``self.writers``
        (presumably the caller is responsible for closing it).
        """
        if clear:
            self.clear_timeseries_data(self._scenario)
        writer = NumerousBufferedWriter(self.write_with_row_generator, self.writer_closed, self._scenario, buffer_size)
        return writer

    def writer_closed(self, scenario):
        """Writer-close callback: report the current execution's data as finalized (``eta=-1``)."""
        log.debug('Writer Closed')
        completion = spm_pb2.DataCompleted(scenario=scenario, execution=self._execution_id, eta=-1, finalized=True)
        self._base_client.CloseData(completion)

    def read_configuration(self, name):
        """Return the raw ``ReadConfiguration`` response for the configuration ``name``."""
        request = spm_pb2.ConfigurationRequest(name=name)
        return self._base_client.ReadConfiguration(request)

    def store_configuration(self, name, description, datetime_, user, configuration, comment, tags):
        """Store a named configuration and return the ``StoreConfiguration`` response."""
        configuration_msg = spm_pb2.Configuration(
            name=name,
            description=description,
            datetime=datetime_,
            user=user,
            configuration=configuration,
            comment=comment,
            tags=tags,
        )
        return self._base_client.StoreConfiguration(configuration_msg)

    @contextmanager
    def open_writer(self, aliases: dict = None, clear=False, buffer_size=24 * 7) -> NumerousBufferedWriter:
        """Context manager yielding a ``NumerousBufferedWriter`` that is closed on exit.

        Args:
            aliases: Accepted but not used by this method.
            clear: Clear this client's scenario data before opening the writer.
            buffer_size: Passed through to the writer.

        The writer is appended to ``self.writers``.
        NOTE(review): it is never removed from that list here.
        """
        if clear:
            self.clear_timeseries_data(self._scenario)

        buffered = NumerousBufferedWriter(self.write_with_row_generator, self.writer_closed, self._scenario,
                                          buffer_size)
        self.writers.append(buffered)

        try:
            yield buffered
        finally:
            # Always close, even if the with-body raised.
            buffered.close()

    def get_inputs(self, scenario_data, t0=0, te=0, dt=3600, tag_prefix='Stub', tag_seperator='.', timeout=10):
        """
        Build an InputManager that feeds this job's input variables.

        Input variables with ``dataSourceType == 'static'`` are collected into a
        single StaticSource. Those with ``dataSourceType == 'scenario'`` are
        grouped by ``dataSourceID``, and each group becomes one DynamicSource.
        Only enabled components that are main components or subcomponents of
        another component are considered.

        Side effect: components referenced in another component's
        ``subcomponents`` list get ``isSubcomponent = True`` in ``scenario_data``.
        The used entries of ``inputScenarios`` get a ``tags`` dict.

        :param scenario_data: scenario document; its ``simComponents`` and
            ``inputScenarios`` lists are read.
        :param t0: start time passed to every source and to the InputManager.
        :param te: end time passed to the dynamic sources.
        :param dt: time step passed to the static source.
        :param tag_prefix: first tag segment (omitted if empty/None).
        :param tag_seperator: separator joining prefix, component name and variable id.
        :param timeout: timeout passed to the dynamic sources.
        :return: an InputManager over all created sources.
        :raises ValueError: if an input variable has an unsupported ``dataSourceType``.
        """
        components = scenario_data['simComponents']
        input_scenarios = scenario_data['inputScenarios']
        input_scenario_map_spec = {s['scenarioID']: s for s in input_scenarios}

        # Mark every component referenced as a subcomponent. Collecting the uuids
        # first keeps this linear instead of rescanning all components per reference.
        subcomponent_uuids = {sub['uuid'] for c in components for sub in c.get('subcomponents', ())}
        if subcomponent_uuids:
            for component in components:
                if component['uuid'] in subcomponent_uuids:
                    component['isSubcomponent'] = True

        input_sources = {}
        static_data = {}
        input_source_types = {}
        only_static = True

        # Collect tags to update, grouped by source.
        for component in components:
            # 'isSubcomponent' is only ever set (to True) above, so it may be absent.
            if component['disabled'] or not (component['isMainComponent'] or component.get('isSubcomponent', False)):
                continue
            for iv in component['inputVariables']:
                tag_ = tag_seperator.join(filter(None, [tag_prefix, component['name'], iv['id']]))
                source_type = iv['dataSourceType']

                if source_type == 'static':
                    static_data[tag_] = iv['value']
                elif source_type == 'scenario':
                    only_static = False
                    source_id = iv['dataSourceID']
                    if source_id not in input_sources:
                        input_sources[source_id] = input_scenario_map_spec[source_id]
                        input_sources[source_id]['tags'] = {}
                    input_sources[source_id]['tags'][tag_] = {'tag': iv['tagSource']['tag'],
                                                              'scale': iv['scaling'],
                                                              'offset': iv['offset']}
                    input_source_types[source_id] = 'scenario'
                else:
                    # 'dataset'/'csv' sources are not supported here.
                    raise ValueError(f'Unknown data source type: {source_type}')

        # First add the static source.
        log.info('static only: ' + str(only_static))
        inputs = [StaticSource(static_data, t0=t0, dt=dt, generate_index=only_static)]
        log.debug('input source types: ' + str(input_source_types))

        # Then add one dynamic source per referenced input scenario.
        input_scenario_map = {s['scenarioID']: s.get('executionID') for s in input_scenarios}
        for input_source, spec in input_sources.items():
            execution = input_scenario_map[input_source]
            log.debug(spec)
            # execution may be None (no 'executionID'), so format instead of concatenating.
            log.debug(f'Setting up input sources, scenario: {input_source} exe:{execution}')
            inputs.append(
                DynamicSource(client=self, scenario=input_source, execution=execution, spec=spec,
                              t0=t0, te=te, timeout=timeout, source_type=input_source_types[input_source]))

        return InputManager(inputs, self._t, t0)

    def publish_message(self, message, channel=None):
        """
        Publish a JSON-serialised message on a subscription channel.

        :param message: any JSON-serialisable object; sent as ``json.dumps(message)``.
        :param channel: target channel. It defaults to ``EVENT.<project>.<scenario>``
            for this client.
        """
        target_channel = ".".join(('EVENT', self._project, self._scenario)) if channel is None else channel
        payload = json.dumps(message)
        self._base_client.PublishSubscriptionMessage(
            spm_pb2.SubscriptionMessage(channel=target_channel, message=payload))

    def subscribe_messages(self, channel_patterns):
        """
        Generator over messages on channels matching ``channel_patterns``.

        Subscribes for this client's project and scenario via the
        ``SubscribeForUpdates`` RPC, lazily on first iteration.

        :param channel_patterns: channel patterns to subscribe to.
        :return: yields ``(channel, decoded_json_message)`` tuples.
        """
        subscription = spm_pb2.Subscription(channel_patterns=channel_patterns, scenario=self._scenario,
                                            project_id=self._project)
        stream = self._base_client.SubscribeForUpdates(subscription)
        for update in stream:
            yield update.channel, json.loads(update.message)

    def refresh_token(self, refresh_token):
        """
        Exchange a refresh token for an access token.

        :param refresh_token: the refresh token string.
        :return: the response of the token manager's ``GetAccessToken`` RPC.
        """
        token_message = spm_pb2.Token(val=refresh_token)
        return self._token_manager.GetAccessToken(token_message)

    def was_resumed(self) -> bool:
        """Returns True if job was resumed and False otherwise"""
        return self._was_resumed

    @staticmethod
    def _default_hibernation_callback(*args, **kwargs) -> None:
        """Default function to be called when job is hibernated. Does it need functionality?"""
        print(f"Default Hibernation Callback with arguments. args: '{args}' kwargs: '{kwargs}'.")

    def set_hibernation_callback(self, func, *args, **kwargs) -> None:
        """
        Use this function to set a callback to be used when job is hibernated.
        :param func: function to be called
        """
        self._hibernate_callback = func
        self._hibernate_callback_args = (args, kwargs)

    def hibernate(self, message: str = 'Hibernating') -> None:
        """
        Put the client into hibernation.

        Flags the client as hibernating and complete, reports a ``'hibernating'``
        scenario status with ``message``, then raises SIGTERM in this process.
        The aim is that the client can reload its state when restarted later.

        :param message: progress message sent with the status update.
        """
        log.info('Client received hibernation signal.')
        self._hibernating = self._complete = True
        self.set_scenario_progress(message=message, status='hibernating', force=True)
        signal.raise_signal(signal.SIGTERM)

    def close(self):
        """
        Shut the client down.

        The steps run in this order:

        1. Close every writer in ``self.writers``.
        2. Stop the access-token refresher.
        3. Call the registered hibernation callback if ``self._hibernating`` is set.
        4. If ``self._complete`` is set, mark the scenario finished (unless it is
           hibernating or errored) and call ``complete_execution()``.

        The refresher is stopped, and the log handler/listener are shut down, in
        ``finally`` blocks. They are therefore released even if an earlier step
        raises; the exception still propagates.
        """
        try:
            try:
                # Close all writers.
                for writer in self.writers:
                    writer.close()
            finally:
                # Stop the thread refreshing the access token.
                self._access_token_refresher.stop()

            # If the job has been hibernated, call the registered callback.
            if self._hibernating:
                print("hibernate...")
                self._hibernate_callback(*self._hibernate_callback_args[0], **self._hibernate_callback_args[1])

            if self._complete:
                if not self._hibernating and not self._error:
                    print("setting status...")
                    self.set_scenario_progress(message='Finished', status='finished', force=True)

                print("complete...")
                self.complete_execution()
        finally:
            if hasattr(self, 'log_handler'):
                log.debug('flushing log')
                self.log_handler.close()

            log.debug('Client closing')
            if hasattr(self, 'log_listener'):
                self.log_listener.stop()


@contextmanager
def open_client(job_id=None, project=None, scenario=None, clear_data=None, no_log=False, instance_id=None,
                refresh_token=None, server=None, port=None, secure=None, url=None) -> NumerousClient:
    """
    Context manager that creates a NumerousClient and always closes it.

    An ``Exception`` raised inside the ``with`` block is logged and reported to
    the server via ``push_scenario_error``. It is then suppressed (not re-raised),
    so execution continues after the ``with`` statement. ``close()`` is always
    called on exit.

    :param server: forwarded to NumerousClient (default None, as before).
    :param port: forwarded to NumerousClient (default None, as before).
    :param secure: forwarded to NumerousClient (default None, as before).
    :param url: forwarded to NumerousClient; when given, NumerousClient derives
        server, port and secure from it.
    The remaining parameters are forwarded to NumerousClient unchanged.
    """
    numerous_client_ = NumerousClient(job_id, project, scenario, server=server, port=port,
                                      refresh_token=refresh_token, clear_data=clear_data, no_log=no_log,
                                      instance_id=instance_id, secure=secure, url=url)
    try:
        yield numerous_client_

    except Exception as e:
        # Format the traceback once; it is reported through several channels.
        tb = traceback.format_exc()
        logging.warning(f"CLIENT FAILED WITH ERROR CODE - w: {e} | {tb}")
        log.warning(f"CLIENT FAILED WITH ERROR CODE - wl: {e} | {tb}")
        print(f"CLIENT FAILED WITH ERROR CODE: {e} | {tb}")
        numerous_client_.push_scenario_error(error=tb)

    finally:
        logging.warning("Closing!")
        numerous_client_.close()
