from contextlib import contextmanager
import json
import os
import logging
import logging.handlers
import tarfile
import io
try:
    import pandas as pd
except:
    pass
from datetime import datetime, timedelta
import pytz
import grpc
from numerous_api_client.python_protos import spm_pb2_grpc, spm_pb2
from queue import Queue
import _queue
import threading
from time import time, sleep
import requests
import signal
from enum import Enum
from numerous_api_client.headers.ValidationInterceptor import ValidationInterceptor
import traceback
from uuid import uuid4
from dirtools import Dir, DirState, filehash
from pathlib import Path, PurePosixPath
import sys
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from numerous_cert_server.cert_helper.get_cert import get_cert
import numpy as np
from scipy.interpolate import interp1d
import types
import urllib3
urllib3.disable_warnings()

class JobStatus(Enum):
    """Job states known to the client (exposed as ``NumerousClient.job_states``).

    Member order and integer values are kept fixed; code may compare or
    serialise them by value.
    """

    ready = 0
    running = 1
    finished = 2
    request_termination = 3
    terminated = 4
    failed = 5
    requested = 6
    initializing = 7


def get_env(val, env, default=None):
    """Resolve a setting from an explicit value, an environment variable or a default.

    Args:
        val: explicit value; returned unchanged whenever it is not None.
        env: name of the environment variable consulted when ``val`` is None.
        default: fallback used when the variable is unset. ``None`` means
            "no default", so the lookup is then mandatory.

    Returns:
        ``val``, else the environment value (a string), else ``default``.

    Raises:
        KeyError: if ``val`` is None, ``env`` is unset and no default is given.
    """
    if val is not None:
        return val

    from_env = os.getenv(env)
    if from_env is not None:
        return from_env

    if default is None:
        raise KeyError(f'ENV var <{env}> is not set.')
    return default

def datetime_to_json_converter(o):
    """``default=`` hook for ``json.dumps`` that renders datetimes via ``str()``.

    Any other object yields None, which ``json`` serialises as ``null``.
    """
    return str(o) if isinstance(o, datetime) else None

# Logger used throughout this module; DEBUG so its records are never
# filtered at the logger itself.
log = logging.getLogger('numerous_client.main')
log.setLevel(logging.DEBUG)
tzl = pytz.utc  # shared UTC tzinfo object


def handle_term(signum, frame):
    """Signal handler: log a warning and exit the process via ``SystemExit``.

    Both arguments are part of the ``signal`` handler protocol and are unused.
    """
    log.warning('Terminated')
    raise SystemExit


# Route Ctrl-C (SIGINT) and SIGTERM to handle_term at import time.
# NOTE: signal.signal() only works in the main thread; importing this module
# from another thread raises ValueError here.
for _sig in (signal.SIGINT, signal.SIGTERM):
    signal.signal(_sig, handle_term)
del _sig


class NumerousBufferedWriter:
    """Queue rows and hand them to a write method running on a background thread.

    On construction a thread is started that calls
    ``write_method(rows, must_flush=...)`` once. ``rows`` is a generator yielding
    every dict passed to :meth:`write_row` until :meth:`close` is called.
    ``must_flush(n_rows, rows_size)`` tells the write method when to push what it
    has accumulated. After the writer thread finishes, ``close`` calls
    ``writer_closed_method(scenario)``.
    """

    # Queue item that ends the row generator.
    _STOP = 'STOP'

    def __init__(self, write_method, writer_closed_method, scenario: str = None, buffer_size: int = 24 * 7):
        self.scenario = scenario
        self._tags = None
        self.ix = 0  # next auto-assigned '_index' for rows that lack one
        self._write_method = write_method
        self._writer_closed = writer_closed_method
        self._buffer = None
        self.buffer_size = buffer_size  # row count that permits a flush
        self.write_queue = Queue()
        self.closed = False

        self.last_flush = time()
        self.max_elapse_flush = 600  # seconds: always flush after this long
        self.min_elapse_flush = 1  # seconds: minimum gap for a row-count flush
        self.force_flush = False

        # Flush once n_rows * rows_size exceeds this, regardless of timing.
        self.buffer_number_size = 1e6

        self.writer_thread = threading.Thread(target=self._writer_thread_func)
        self.writer_thread.start()

    def _must_flush(self, n_rows, rows_size):
        """Return True (and record the flush time) when buffered rows should be written."""
        tic = time()
        if self.force_flush:
            self.force_flush = False
            self.last_flush = tic
            return True

        rows_due = n_rows > self.buffer_size and tic > self.min_elapse_flush + self.last_flush
        size_due = rows_size * n_rows > self.buffer_number_size
        stale = tic > self.last_flush + self.max_elapse_flush
        if rows_due or size_due or stale:
            self.last_flush = tic
            return True
        return False

    def _row_generator(self):
        """Yield queued rows until the stop sentinel arrives."""
        while True:
            row = self.write_queue.get()
            # isinstance guard: rows are dicts; never compare arbitrary row
            # objects to the sentinel string.
            if isinstance(row, str) and row == self._STOP:
                return
            yield row

    def _writer_thread_func(self):
        """Body of the writer thread: run the write method over the row stream."""
        try:
            self._write_method(self._row_generator(), must_flush=self._must_flush)
        except Exception:
            log.error(traceback.format_exc())
            raise KeyboardInterrupt('Writer failed')

    def _init_buffer(self):
        self._buffer_count = 0
        self._buffer_timestamp = []

    def write_row(self, data):
        """Queue one row (a dict) for writing and return its '_index'.

        Rows without '_index' get the next value of ``self.ix``. A datetime
        '_index' is converted to an epoch timestamp. ``data`` is updated in place.

        Raises:
            ValueError: if the writer is closed. Nothing is modified in that case.
        """
        # Check first so a rejected row neither consumes an index nor gets mutated.
        if self.closed:
            raise ValueError('Queue is closed!')

        if '_index' not in data:
            data['_index'] = self.ix
            self.ix += 1

        if isinstance(data['_index'], datetime):
            data['_index'] = data['_index'].timestamp()

        self.write_queue.put(data)
        return data['_index']

    def flush(self):
        """Request that the write method flush on its next ``must_flush`` check."""
        self.force_flush = True
        self._init_buffer()

    def close(self):
        """Flush, stop the writer thread, wait for it and notify ``writer_closed_method``.

        Calling ``close`` more than once is a no-op.
        """
        if self.closed:
            return
        # Mark closed before queueing the sentinel so no row can be enqueued
        # behind it, where the generator would never consume it.
        self.closed = True
        self.flush()
        self.write_queue.put(self._STOP)
        self.writer_thread.join()
        self._writer_closed(self.scenario)


class NumerousLogHandler:
    """Buffer log messages and push them to the server in debounced batches.

    Messages are stored as ``(datetime.utcnow(), message)`` tuples. They are
    sent via ``client.push_scenario_logs`` or, when ``admin_log_id`` is set,
    ``client.push_scenario_logs_admin(..., exe_id=admin_log_id)``. ``log`` pushes
    when more than ``debounce`` seconds have passed since the last push. A
    daemon thread performs the same check once per second, so quiet periods
    still get flushed.
    """

    def __init__(self, client, admin_log_id=None, echo=False):
        self.client = client
        self.buffered_records = []
        self.last_push = datetime.utcnow() - timedelta(hours=1)
        self.debounce = 5  # seconds between automatic pushes
        self.admin_log_id = admin_log_id
        self.echo = echo
        # The timer thread and the logging thread both touch the buffer;
        # the lock makes "take everything and reset" atomic.
        self._lock = threading.Lock()
        threading.Thread(target=self._timed_flush, daemon=True).start()

    def _claim_push(self, now_):
        """Return True and record ``now_`` if the debounce interval has elapsed."""
        # total_seconds(): timedelta.seconds drops the days component.
        if (now_ - self.last_push).total_seconds() > self.debounce:
            self.last_push = now_
            return True
        return False

    def _timed_flush(self):
        """Timer loop: once per second, flush if the debounce interval has elapsed."""
        while True:
            sleep(1)
            try:
                if self._claim_push(datetime.utcnow()):
                    self.flush()
            except Exception:
                # Keep the timer alive after a failed push. Write to stderr
                # rather than logging, which could feed back into this handler.
                traceback.print_exc()

    def handle(self, record):
        """``logging`` handler entry point: buffer the record's formatted message."""
        # getMessage() applies %-style args; record.msg alone would drop them.
        self.log(record.getMessage())

    def log(self, message):
        """Buffer ``message`` with the current UTC time; push if the debounce interval has passed."""
        now_ = datetime.utcnow()
        record = (now_, message)
        with self._lock:
            self.buffered_records.append(record)
        if self.echo:
            print(record)

        if self._claim_push(now_):
            self.flush()

    def flush(self):
        """Send all buffered records to the server; do nothing if there are none."""
        with self._lock:
            to_push, self.buffered_records = self.buffered_records, []
        if not to_push:
            return
        if self.admin_log_id is not None:
            self.client.push_scenario_logs_admin(to_push, exe_id=self.admin_log_id)
        else:
            self.client.push_scenario_logs(to_push)

    def close(self):
        """Push any remaining buffered records."""
        self.flush()


class NumerousClient:
    def __init__(self, job_id=None, project=None, scenario=None, server=None, port=None, refresh_token=None, clear_data=None, no_log=False, instance_id=None, secure=None, url=None, resumed=None):
        """Create a client bound to one project / scenario / job execution.

        Arguments left as None are read from the environment via ``get_env``:
        - Required: ``NUMEROUS_PROJECT``, ``NUMEROUS_SCENARIO``, ``JOB_ID``,
          ``NUMEROUS_API_REFRESH_TOKEN``.
        - Default ``'False'``: ``NUMEROUS_CLEAR_DATA``, ``NUMEROUS_JOB_RESUMED``.

        Server, port and secure-channel settings are resolved in
        ``_init_channel``. When ``url`` is given, its host, port and scheme
        (``https`` -> secure) replace ``server``, ``port`` and ``secure``.

        Side effects:
        - opens the gRPC channel and starts the access-token refresher;
        - installs the server log handler unless ``no_log``;
        - clears this execution's time-series data when ``clear_data`` resolves to True;
        - starts a daemon thread listening for server commands.

        Raises:
            KeyError: if a required environment variable is missing.
        """
        if url is not None:
            from urllib.parse import urlparse
            parsed_url = urlparse(url)

            # .hostname/.port handle URLs without an explicit port (the previous
            # netloc.split(':')[1] raised IndexError), IPv6 literals and
            # userinfo. A missing port falls back to NUMEROUS_API_PORT.
            server = parsed_url.hostname
            port = parsed_url.port
            secure = parsed_url.scheme == "https"

        self._t = [0]
        self.job_states = JobStatus
        self._project = get_env(project, 'NUMEROUS_PROJECT')
        self._scenario = get_env(scenario, 'NUMEROUS_SCENARIO')
        self._job_id = get_env(job_id, 'JOB_ID')
        self._refresh_token = get_env(refresh_token, 'NUMEROUS_API_REFRESH_TOKEN')
        self._access_token = None

        self._instance_id = str(uuid4()) if instance_id is None else instance_id
        self._execution_id = get_env(None, 'NUMEROUS_EXECUTION_ID', 'not_found')
        if self._execution_id == 'not_found' or self._execution_id is None:
            self._execution_id = str(uuid4()).replace('-','_')
            log.info(f'No execution id found - generating new: {self._execution_id}')

        self._complete = True
        self._hibernating = False
        self._error = False

        self.writers = []

        self.channel = self._init_channel(server=server, port=port, secure=secure, instance_id=self._instance_id)
        self.stub = spm_pb2_grpc.SPMStub(self.channel)
        self._token_manager = spm_pb2_grpc.TokenManagerStub(self.channel)
        self._base_client = self.stub

        # Refresh the access token every 9 minutes. interval appears to be in
        # seconds and run_initially presumably fetches one immediately — confirm
        # against RepeatedFunction.
        self._access_token_refresher = RepeatedFunction(
            interval=9*60, function=self._refresh_access_token, run_initially=True,
            refresh_token=self._refresh_token, instance_id=self._instance_id, execution_id=self._execution_id, project_id=self._project, scenario_id=self._scenario, job_id=self._job_id
        )
        self._access_token_refresher.start()

        self._terminated = False
        if not no_log:
            self._init_logger()

        # Debounce state for set_scenario_progress / set_state (seconds).
        self.last_progress = -999999
        self.progress_debounce = 10

        self.last_state_set = -999999
        self.state_debounce = 60

        clear_data = get_env(clear_data, 'NUMEROUS_CLEAR_DATA', default='False') == 'True'

        self._was_resumed = get_env(resumed, 'NUMEROUS_JOB_RESUMED', default='False') == 'True'

        self._hibernate_callback = self._default_hibernation_callback
        self._hibernate_callback_args = ([], {})

        if clear_data:
            self.clear_timeseries_data()

        # Thread.setDaemon() is deprecated since Python 3.10; pass daemon= instead.
        self._listen = threading.Thread(target=self._command_listener, daemon=True)
        self._listen.start()

    def __enter__(self):
        """Support ``with NumerousClient(...) as client:``; the client is its own context value."""
        client = self
        return client
    
    def __exit__(self, _type, _value, _traceback):
        """Close the client when leaving a ``with`` block; exceptions are not suppressed."""
        self.close()
        return None

    def complete_execution(self):
        """Report this execution as complete, passing the current hibernation flag."""
        log.debug(f'Completing job {self._job_id} with execution {self._execution_id}. Hibernating: {self._hibernating}')
        execution = spm_pb2.ExecutionMessage(
            project_id=self._project,
            scenario_id=self._scenario,
            job_id=self._job_id,
            execution_id=self._execution_id,
        )
        request = spm_pb2.CompleteExecutionMessage(execution=execution, hibernate=self._hibernating)
        self._base_client.CompleteExecution(request)

    def _init_channel(self, server, port, secure=None, instance_id=None):
        """Open the gRPC channel and wrap it with the token-validation interceptor.

        ``secure``, ``server`` and ``port`` fall back to the env vars
        ``SECURE_CHANNEL``, ``NUMEROUS_API_SERVER`` and ``NUMEROUS_API_PORT``.
        A secure channel is used only when ``str(secure)`` is ``'True'``.
        Also stores the interceptor's instance as ``self._instance``.

        Raises:
            KeyError: if a needed value is neither passed nor set in the environment.
        """
        secure_channel = get_env(secure, 'SECURE_CHANNEL')
        server = get_env(server, 'NUMEROUS_API_SERVER')
        port = get_env(port, 'NUMEROUS_API_PORT')

        log.info(f"Client connecting to: {server}:{port}, using secure channel: {secure_channel}")
        target = f'{server}:{port}'
        if str(secure_channel) != 'True':
            raw_channel = grpc.insecure_channel(target)
        else:
            # ssl_channel_credentials() with no arguments uses gRPC's default root certificates.
            raw_channel = grpc.secure_channel(target, grpc.ssl_channel_credentials())

        interceptor = ValidationInterceptor(token=self._access_token, token_callback=self._get_current_token, instance=instance_id)
        self._instance = interceptor.instance
        return grpc.intercept_channel(raw_channel, interceptor)

    def _get_current_token(self):
        """Return the most recent access token (used as the interceptor's token callback)."""
        token = self._access_token
        return token

    def _init_logger(self):
        """Forward records reaching the root logger to the server via a queue.

        Steps:
        - Adds a ``QueueHandler`` to the root logger and sets the root level to ERROR.
        - Starts a ``QueueListener`` that feeds a ``NumerousLogHandler`` for this client.

        NOTE: the root level only filters records logged on the root logger
        itself. Records from child loggers with their own level (such as this
        module's DEBUG logger) still propagate to the root's handlers.
        """
        records = Queue(-1)
        queue_handler = logging.handlers.QueueHandler(queue=records)
        root = logging.getLogger()
        root.setLevel(logging.ERROR)
        root.addHandler(queue_handler)
        self.log_handler = NumerousLogHandler(self)
        self.log_listener = logging.handlers.QueueListener(records, self.log_handler)
        self.log_listener.start()

    def _refresh_access_token(self, refresh_token, instance_id, execution_id, project_id=None, scenario_id=None, job_id=None):
        """Exchange the refresh token for a new access token and store it on the client."""
        request = spm_pb2.RefreshRequest(
            refresh_token=spm_pb2.Token(val=refresh_token),
            instance_id=instance_id,
            execution_id=execution_id,
            project_id=project_id,
            scenario_id=scenario_id,
            job_id=job_id,
        )
        response = self._token_manager.GetAccessToken(request)
        self._access_token = response.val

    def _command_listener(self):
        """Handle server commands on channel ``COMMAND.<project>.<scenario>.<job_id>``.

        Commands:
        - ``'terminate'``: raises SIGTERM in this process.
        - ``'hibernate'``: calls ``self.hibernate(message='Hibernating')``.

        Either command ends the listener. Other messages are only logged.
        """
        topic = ".".join(['COMMAND', self._project, self._scenario, self._job_id])
        for channel, message in self.subscribe_messages([topic]):
            log.debug('Received message: '+str(message)+'\n on channel: '+str(channel))
            if 'command' not in message:
                continue

            command = message['command']
            if command == 'terminate':
                log.warning('Received Termination Command')
                signal.raise_signal(signal.SIGTERM)
                return
            elif command == 'hibernate':
                log.warning("Received Hibernate Command")
                self.hibernate(message='Hibernating')
                return

    def get_scenario_document(self, scenario=None, project=None):
        """Fetch a scenario (default: this client's) and return ``(parsed JSON document, files)``."""
        project = self._project if project is None else project
        scenario = self._scenario if scenario is None else scenario

        response = self._base_client.GetScenario(spm_pb2.Scenario(project=project, scenario=scenario))
        document = json.loads(response.scenario_document)
        return document, response.files

    def get_job(self, scenario_doc=None):
        """Return this client's job entry (``doc['jobs'][job_id]``) from a scenario document.

        The document is fetched with ``get_scenario_document`` when not supplied.
        """
        doc = scenario_doc if scenario_doc is not None else self.get_scenario_document()[0]
        return doc['jobs'][self._job_id]

    def get_group_document(self, group):
        """Fetch ``group`` in this client's project and return its parsed JSON document."""
        request = spm_pb2.Group(project=self._project, group=group)
        response = self._base_client.GetGroup(request)
        return json.loads(response.group_document)

    def get_project_document(self):
        """Fetch this client's project and return its parsed JSON document."""
        request = spm_pb2.Project(project=self._project)
        response = self._base_client.GetProject(request)
        return json.loads(response.project_document)

    def listen_scenario_document(self, scenario=None, project=None):
        """Yield ``(parsed JSON document, files)`` for each scenario update the server streams.

        Args:
            scenario: scenario id; defaults to this client's scenario.
            project: project id; defaults to this client's project. (The
                previous body referenced an undefined ``project`` name and
                raised NameError on the first iteration.)
        """
        if scenario is None:
            scenario = self._scenario
        if project is None:
            project = self._project

        request = spm_pb2.Scenario(project=project, scenario=scenario)
        for doc in self._base_client.ListenScenario(request):
            yield json.loads(doc.scenario_document), doc.files



    def get_scenario(self, scenario=None, project=None, path='.'):
        """Download a scenario's document, its model and all attached files.

        Args:
            scenario: scenario id; defaults to this client's scenario.
            project: project id; defaults to this client's project.
            path: directory to download into (created if missing).

        Returns:
            ``(scenario_data, model_data, file_paths_local)``. ``file_paths_local``
            maps each file name to ``path + '/' + name``.
        """
        os.makedirs(path, exist_ok=True)

        log.debug('Get scenario')
        if scenario is None:
            scenario = self._scenario
        if project is None:
            project = self._project

        scenario_data, scenario_files = self.get_scenario_document(scenario, project)
        model_data, model_files = self.get_model(scenario_data['systemID'], project_id=project, scenario_id=scenario)

        file_paths_local = {}
        for f in list(scenario_files) + list(model_files):
            f_path = path + '/' + f.name
            file_paths_local[f.name] = f_path
            # NOTE(review): the HTTP status is not checked, so an error response
            # body would be written as the file content.
            r = requests.get(f.url, allow_redirects=True)
            # Context manager closes the handle; the old open(...).write leaked it.
            with open(f_path, 'wb') as fh:
                fh.write(r.content)

        return scenario_data, model_data, file_paths_local

    def get_model(self, model_id, project_id=None, scenario_id=None):
        """Fetch a model and return ``(parsed JSON model, files)``; project defaults to this client's."""
        request = spm_pb2.Model(
            model_id=model_id,
            project_id=self._project if project_id is None else project_id,
            scenario_id=scenario_id,
        )
        model = self._base_client.GetModel(request)
        return json.loads(model.model), model.files

    #def set_scenario_data_tags(self, tags, scenario=None):
    #    if scenario is None:
    #        scenario = self._scenario

    #    self._base_client.SetScenarioDataTags(spm_pb2.ScenarioDataTags(project=self._project, scenario=scenario, tags=tags))
    def set_scenario_job_image(self, name, path, project=None, scenario=None, job=None):
        """Set the image (``name`` and ``path``) for a job; ids default to this client's."""
        request = spm_pb2.ScenarioJobImage(
            project_id=self._project if project is None else project,
            scenario_id=self._scenario if scenario is None else scenario,
            job_id=self._job_id if job is None else job,
            name=name,
            path=path,
        )
        self._base_client.SetScenarioJobImage(request)

    def set_scenario_progress(self, message, status, progress=None, clean=False, scenario=None, force=False, project=None, job=None):
        if status == 'failed':
            self._error = True

        #log.debug('Status: '+ status)
        if scenario is None:
            scenario = self._scenario

        if project is None:
            project = self._project

        if job is None:
            job = self._job_id

        tic = time()

        if tic > (self.last_progress + self.progress_debounce) or force:

            self.last_progress=tic
            self._base_client.SetScenarioProgress(spm_pb2.ScenarioProgress(
                project=project, scenario=scenario, job_id=job,
                message=message, status=status, clean=clean, progress=progress
            ))

    def clear_scenario_results(self, scenario=None):
        """Clear the stored results of ``scenario`` (default: this client's) in this project."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.Scenario(project=self._project, scenario=target)
        self._base_client.ClearScenarioResults(request)

    def set_scenario_results(self, names:list, values:list, units:list, scenario=None):
        """Store result ``names``/``values``/``units`` (parallel lists) for ``scenario``."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioResults(
            project=self._project, scenario=target,
            names=names, values=values, units=units,
        )
        self._base_client.SetScenarioResults(request)

    def get_scenario_results(self, scenario=None):
        """Return ``(names, values, units)`` stored for ``scenario`` (default: this client's)."""
        target = self._scenario if scenario is None else scenario
        response = self._base_client.GetScenarioResults(spm_pb2.Scenario(project=self._project, scenario=target))
        names, values, units = response.names, response.values, response.units
        return names, values, units

    def push_scenario_error(self, error, scenario=None):
        """Report ``error`` for ``scenario`` (default: this client's), tagged with this job id."""
        target = self._scenario if scenario is None else scenario
        request = spm_pb2.ScenarioError(
            project=self._project, scenario=target,
            error=error, spm_job_id=self._job_id,
        )
        self._base_client.PushScenarioError(request)

    #def push_scenario_log(self, message, initialize=False, scenario=None):
    #    if scenario is None:
    #        scenario = self._scenario

    #    self._base_client.PushScenarioLogEntr(spm_pb2.ScenarioLogEntry(
    #        project=self._project, scenario=scenario, message=message,
    #        spm_job_id=self._job_id, initialize=initialize
    #    ))

    def push_scenario_logs(self, logs):
        """Send log entries for this client's execution to the server.

        Args:
            logs: sequence of ``(datetime, message)`` pairs, as buffered by
                ``NumerousLogHandler`` (which stamps them with ``datetime.utcnow()``).
                Naive datetimes are therefore interpreted as UTC.
        """
        timestamps = []
        for entry in logs:
            stamp = entry[0]
            # datetime.timestamp() treats naive values as *local* time, which
            # shifted utcnow()-stamped entries by the host's UTC offset.
            if stamp.utcoffset() is None:
                stamp = stamp.replace(tzinfo=pytz.utc)
            timestamps.append(stamp.timestamp())

        self._base_client.PushExecutionLogEntries(spm_pb2.LogEntries(
            scenario=self._scenario,
            execution_id=self._execution_id, log_entries=[l[1] for l in logs],
            timestamps=timestamps
        ))

    def delete_scenario(self, scenario=None):
        """Not implemented by this client.

        The previous body was a copy of a formatted-error push that referenced
        undefined names (``message``, ``hint``, ``category``, ...), so every call
        failed with NameError. The explicit error makes this clear to callers.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError('delete_scenario is not implemented by this client')

    def get_download_files(self, files, local_path='./tmp', scenario=None):
        """Download scenario files through signed URLs.

        Args:
            files: file names relative to the scenario; requested as ``scenario + '/' + name``.
            local_path: target directory (created if missing).
            scenario: scenario id; defaults to this client's scenario.

        Returns:
            dict mapping each returned file name to ``local_path + '/' + name``.
        """
        os.makedirs(local_path, exist_ok=True)
        if scenario is None:
            scenario = self._scenario

        request = spm_pb2.FileSignedUrls(files=[spm_pb2.FileSignedUrl(path=scenario+'/'+f) for f in files], scenario=scenario)
        signed = self._base_client.GetSignedURLs(request)

        files_out = {}
        for f in signed.files:
            r = requests.get(f.url, allow_redirects=True)
            local_file_path = local_path+'/'+f.name
            # Context manager closes the handle; the old open(...).write leaked it.
            with open(local_file_path, 'wb') as fh:
                fh.write(r.content)
            files_out[f.name] = local_file_path
        return files_out

    def upload_file(self, local_file, file_id, file=None, scenario=None, content_type="text/html"):
        """Upload ``local_file`` to a signed URL generated for the scenario.

        Args:
            local_file: path of the file to upload.
            file_id: id passed to the signed-URL request.
            file: remote file name; defaults to the last ``/``-separated part of ``local_file``.
            scenario: scenario id; defaults to this client's scenario.
            content_type: MIME type sent with the URL request and as ``mimeType``.
        """
        if scenario is None:
            scenario = self._scenario

        if file is None:
            file = local_file.split('/')[-1]
        upload_url = self._base_client.GenerateScenarioUploadSignedURL(spm_pb2.ScenarioFilePath(project=self._project, scenario=scenario, path=file, file_id=file_id, content_type=content_type))

        headers = {"Content-Type": "multipart/related"}
        params = {
            "name": file,
            "mimeType": content_type
        }
        # Context manager closes the upload handle; it was previously leaked.
        with open(local_file, 'rb') as fh:
            requests.post(
                upload_url.url,
                headers=headers,
                params=params,
                data=fh
            )

    def get_timeseries_stats(self, project=None, scenario=None, execution=None, tag='_index'):
        """Return the server's data statistics for ``tag`` of an execution.

        project, scenario and execution default to this client's values.
        """
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id
        # Was a stdout print; library code should log instead.
        log.debug(f'stat tag: {tag}')
        return self._base_client.GetScenarioDataStats(spm_pb2.ScenarioStatsRequest(project=project, scenario=scenario, execution=execution, tag=tag))

    def get_dataset_closed(self, project=None, scenario=None, execution=None):
        """Return the server's ``is_closed`` flag for a dataset.

        project and scenario default to this client's; ``execution`` is passed
        through unchanged (no default is applied).
        """
        request = spm_pb2.Scenario(
            project=self._project if project is None else project,
            scenario=self._scenario if scenario is None else scenario,
            execution=execution,
        )
        return self._base_client.GetDataSetClosed(request).is_closed

    def get_timeseries_meta_data(self, project=None, scenario=None, execution=None):
        """Return the time-series metadata response.

        project and scenario default to this client's; ``execution`` is passed
        through unchanged.
        """
        request = spm_pb2.Scenario(
            project=self._project if project is None else project,
            scenario=self._scenario if scenario is None else scenario,
            execution=execution,
        )
        return self._base_client.GetScenarioMetaData(request)

    def set_timeseries_meta_data(self, tags:list, aliases:dict=None, offset=0, timezone='UTC', epoch_type='s', scenario= None, execution=None, project=None):
        """Store time-series metadata (tag definitions and aliases) for an execution.

        Args:
            tags: list of dicts. ``'name'`` is required. Missing optional keys
                default to: displayName '', unit '', description '',
                type 'double', scaling 1, offset 0.
            aliases: mapping tag -> aliases; defaults to none. (Was a shared
                mutable ``{}`` default.)
            offset, timezone, epoch_type: passed to the metadata message.
            scenario, execution, project: default to this client's values.

        Raises:
            KeyError: if a tag dict lacks ``'name'``.
        """
        tag_defaults = {
            'name': None,  # None marks "required"
            'displayName': '',
            'unit': '',
            'description': '',
            'type': 'double',
            'scaling': 1,
            'offset': 0
        }

        def wrap_tag(tag):
            kwargs = {}
            for key, default in tag_defaults.items():
                if key in tag:
                    kwargs[key] = tag[key]
                elif default is not None:
                    kwargs[key] = default
                else:
                    raise KeyError(f'No default for key: {key}')
            return spm_pb2.Tag(**kwargs)

        if aliases is None:
            aliases = {}
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        return self._base_client.SetScenarioMetaData(
            spm_pb2.ScenarioMetaData(project=project, scenario=scenario, execution=execution,
                                     tags=[wrap_tag(tag) for tag in tags],
                                     aliases=[spm_pb2.Alias(tag=k, aliases=v) for k, v in aliases.items()],
                                     offset=offset, timezone=timezone, epoch_type=epoch_type
                                     )
        )

    def get_timeseries_custom_meta_data(self, scenario=None, execution=None, key=None):
        """Return the custom metadata string stored under ``key``; scenario/execution default to this client's."""
        request = spm_pb2.ScenarioCustomMetaData(
            scenario=self._scenario if scenario is None else scenario,
            execution=self._execution_id if execution is None else execution,
            key=key,
        )
        return self._base_client.GetScenarioCustomMetaData(request).meta

    def set_timeseries_custom_meta_data(self, meta:str, scenario=None, execution=None, key=None):
        """Store the string ``meta`` under ``key`` and return the server response."""
        request = spm_pb2.ScenarioCustomMetaData(
            scenario=self._scenario if scenario is None else scenario,
            execution=self._execution_id if execution is None else execution,
            key=key,
            meta=meta,
        )
        return self._base_client.SetScenarioCustomMetaData(request)

    def get_state(self, scenario=None, execution=None):
        """Return the JSON-decoded ``'state'`` custom metadata, or None if it is not valid JSON.

        Args:
            scenario: scenario id; defaults to this client's scenario. (It was
                previously ignored: ``scenario=None`` was passed through.)
            execution: execution id; defaults to this client's execution.
        """
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        state_json = self.get_timeseries_custom_meta_data(scenario=scenario, execution=execution, key='state')
        try:
            return json.loads(state_json)
        except json.decoder.JSONDecodeError:
            return None

    def set_state(self, state, scenario=None, execution=None, force=False):
        tic = time()
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        if tic > (self.last_state_set + self.state_debounce) or force:
            log.debug('State set.')
            self.last_state_set = tic
            return self.set_timeseries_custom_meta_data(meta=json.dumps(state),scenario=scenario, execution=execution, key='state')

    def get_latest_main_execution(self, project=None, scenario=None):
        """Return the id of the latest main execution of a scenario.

        project and scenario now default to this client's values, like the
        sibling methods (previously None was sent as-is).
        """
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario
        # Was a stdout print; library code should log instead.
        log.debug(f'p: {project}')
        return self._base_client.GetLatestMainExecution(spm_pb2.Scenario(project=project, scenario=scenario)).execution_id


    def data_read_df(self, tags=None, start:datetime=None, end:datetime=None, project=None, scenario=None, execution=None):
        """Read time-series data into a pandas DataFrame (one column per tag).

        Args:
            tags: tags to read; defaults to none (an empty list is sent).
            start, end: bounds as datetimes (converted to epoch seconds, as in
                ``read_data_stream``) or numbers. None/0 means unbounded. A
                range is requested when either bound is > 0.
            project, scenario: default to this client's values.
            execution: defaults to the scenario's latest main execution.

        Returns:
            DataFrame in which every column is truncated to the shortest
            column's length.
        """
        if tags is None:
            tags = []
        if project is None:
            project = self._project
        if scenario is None:
            scenario = self._scenario

        if execution is None:
            execution = self.get_latest_main_execution(project, scenario)
            log.debug('Exe: ' + execution)

        # datetimes previously crashed on the `> 0` comparisons below.
        if start is None:
            start = 0
        elif isinstance(start, datetime):
            start = start.timestamp()
        if end is None:
            end = 0
        elif isinstance(end, datetime):
            end = end.timestamp()

        time_range = (start > 0 or end > 0)

        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(project=project, scenario=scenario, execution=execution, tags=tags, start=start, end=end, time_range=time_range,
                                 listen=False))

        data = {}
        for batch in read_data:
            for block in batch.data:
                data.setdefault(block.tag, []).extend(block.values)

        # Truncate to the shortest column so the frame is rectangular.
        min_len = min((len(v) for v in data.values()), default=0)
        return pd.DataFrame({k: v[:min_len] for k, v in data.items()})

    def read_data_stream(self, tags, start:datetime=None, end:datetime=None, scenario=None, execution=None):
        """Yield ``(tag, values)`` for every data block the server returns.

        A time range is requested when ``start`` or ``end`` is a datetime
        (converted to epoch seconds) or a float. None bounds are sent as 0.
        A falsy ``scenario`` falls back to this client's; ``execution`` defaults
        to this client's execution.
        """
        def as_bound(value):
            # -> (value to send, whether it marks a time range)
            if isinstance(value, datetime):
                return value.timestamp(), True
            return value, isinstance(value, float)

        start, start_is_time = as_bound(start)
        end, end_is_time = as_bound(end)
        time_range = start_is_time or end_is_time

        start = 0 if start is None else start
        end = 0 if end is None else end

        scenario = scenario if scenario else self._scenario
        execution = self._execution_id if execution is None else execution

        request = spm_pb2.ReadScenario(scenario=scenario, execution=execution, tags=tags, start=start, end=end, time_range = time_range)
        for batch in self._base_client.ReadData(request):
            for block in batch.data:
                yield block.tag, block.values

    def read_data_as_row(self, tags=None, start:int=0, end:int=0, scenario=None, execution=None, subscribe=False, offset=0, project=None, time_range=True):
        """Yield rows (dict tag -> value) from the server's data stream.

        Values are accumulated per tag until a message has ``row_complete`` or
        ``block_complete`` set; the accumulated block is then emitted row by row
        with ``'_index'`` shifted by ``offset``.

        Filtering of completed blocks:
        - ``time_range`` True: only rows whose ``'_index'`` lies in
          ``[start, end)``. ``end <= 0`` is replaced by a far-future bound.
        - ``time_range`` False: the row position ``i`` *within each block*
          (restarting at 0 per block) must lie in ``[start, end)``.
          NOTE(review): with the default ``end=0`` this yields nothing from
          completed blocks — confirm intended.

        Rows left over when the stream ends without a completion flag are
        yielded unfiltered.

        Raises:
            ValueError: if a completed block's tags have differing lengths.
            KeyError: if a completed block has no ``'_index'`` tag.
        """
        if tags is None:
            tags = []
        if project is None:
            project = self._project

        log.debug('Reading data as row')
        log.debug('Subscribed: ' + str(subscribe) + ', starting from: '+str(start))
        if not scenario:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        end = end if end > 0 or not time_range else 3600*24*365*10000

        # These were stdout prints; library code should log instead.
        log.debug(f'Reading project={project} scenario={scenario} execution={execution} tags={tags} '
                  f'start={start} end={end} timerange={time_range} listen={subscribe}')

        read_data = self._base_client.ReadData(
            spm_pb2.ReadScenario(project=project, scenario=scenario, execution=execution, tags=tags, start=start, end=end, time_range=time_range, listen=subscribe))
        current_row = {}

        for b in read_data:
            for d in b.data:
                if d.tag != "":
                    current_row.setdefault(d.tag, []).extend(d.values)

            if b.row_complete or b.block_complete:
                # The previous per-block GetScenarioMetaData RPC existed only to
                # feed a debug log line; it was dropped to avoid a network round
                # trip per block.
                log.debug(f'Completed block with tags: {list(current_row.keys())}')
                len_index = len(current_row['_index'])

                wrong_len_tags = [(k, len(v)) for k, v in current_row.items() if len(v) != len_index]
                if len(wrong_len_tags)>0:
                    raise ValueError(f'The following tags have a wrong number of data points returned (_index has {len_index}: '+"\n".join([tag[0] +': '+str(tag[1]) for tag in wrong_len_tags]))

                for i in range(len_index):
                    current_ix = current_row['_index'][i] if time_range else i

                    if current_ix >= start and current_ix < end:
                        r = {k: v[i] for k, v in current_row.items()}
                        if not time_range or (r['_index']>=start and r['_index']<end):
                            r['_index'] += offset
                            yield r

                current_row = {}

        if current_row:
            for i in range(len(current_row['_index'])):
                r = {k: v[i] for k, v in current_row.items()}
                r['_index'] += offset
                yield r

        log.debug('Reading as row complete')

    def clear_timeseries_data(self, scenario=None, execution=None):
        """Delete all time-series data of an execution (defaults: this client's scenario/execution)."""
        log.warning('Clearing data!')
        request = spm_pb2.Scenario(
            scenario=self._scenario if scenario is None else scenario,
            execution=self._execution_id if execution is None else execution,
        )
        self._base_client.ClearData(request)

    def data_write_df(self, df, scenario=None, execution=None, clear=True):
        """
        Upload all columns of a DataFrame as timeseries data via ``WriteDataList``.

        Every column becomes a tag. The frame must contain an ``'_index'`` column, which is
        uploaded like any other column. Rows are cut into chunks of 10000. Within a chunk, the
        per-column blocks are batched into a ``DataList`` message, which is sent once it holds
        more than 2e5 values. The last message of every chunk is sent with
        ``block_complete=True``.

        :param df: DataFrame with an ``'_index'`` column plus data columns.
        :param scenario: target scenario; defaults to the client's scenario.
        :param execution: target execution; defaults to the client's execution.
        :param clear: if True, existing data is cleared first via ``clear_timeseries_data``.
        """
        if scenario is None:
            scenario = self._scenario
        if execution is None:
            execution = self._execution_id

        if clear:
            self.clear_timeseries_data(scenario=scenario, execution=execution)

        chunk_rows = 10000
        max_values = 2e5
        n_rows = len(df['_index'])
        columns = list(df)

        # Slice all columns up front: one list of {'tag', 'values'} dicts per row chunk.
        chunks = [
            [{'tag': c, 'values': df[c].values[start:start + chunk_rows]} for c in columns]
            for start in range(0, n_rows, chunk_rows)
        ]

        def data_writer():
            for chunk in chunks:
                # Batch state is per chunk, so blocks that were already sent are never re-sent
                # together with the next chunk.
                blocks = []
                size = 0
                for block in chunk:
                    blocks.append(spm_pb2.DataBlock(**block))
                    # Count the values actually carried by this block, not the whole column.
                    size += len(block['values'])
                    if size > max_values:
                        yield spm_pb2.DataList(scenario=scenario, execution=execution, data=blocks, clear=False,
                                               block_complete=False, row_complete=False)
                        blocks = []
                        size = 0
                if blocks:
                    yield spm_pb2.DataList(scenario=scenario, execution=execution, data=blocks, clear=False,
                                           block_complete=True, row_complete=False)

        self._base_client.WriteDataList(data_writer())

    def _rows_to_blocks(self, rows):
        """
        Regroup row dicts into column-oriented groups of ``spm_pb2.DataBlock``.

        Values are collected per tag in the order the rows arrive. The collected columns are
        emitted as one group, and collection starts over, whenever the running value count
        plus the length of the row just added exceeds ``3e6/8``.

        :param rows: iterable of dicts mapping tag -> value.
        :return: list of groups; each group is a list of ``DataBlock(tag=..., values=[...])``.
        """
        groups = []
        columns = {}
        n_values = 0

        def emit(cols):
            groups.append([spm_pb2.DataBlock(tag=tag, values=vals) for tag, vals in cols.items()])

        for row in rows:
            for tag, value in row.items():
                columns.setdefault(tag, []).append(value)
            row_len = len(row)
            n_values += row_len
            # The current row's length is counted a second time here, acting as a
            # look-ahead margin before the limit is reached.
            if n_values + row_len > 3e6 / 8:
                emit(columns)
                columns = {}
                n_values = 0

        if n_values > 0:
            emit(columns)

        return groups

    def write_with_row_generator(self, data_generator,  scenario:str=None, clear = False, must_flush=None, execution=None):
        """
        Stream rows produced by ``data_generator`` to the backend via ``WriteDataList``.

        :param data_generator: iterable of row dicts (tag -> value). ``None`` items are not
            stored, but they still give ``must_flush`` a chance to flush the pending rows.
        :param scenario: target scenario; defaults to the client's scenario.
        :param clear: sent with the first flushed message(s); reset to False after each flush.
        :param must_flush: callable ``must_flush(n_pending_rows, current_row_len)`` that decides
            whether pending rows are sent now. ``current_row_len`` is 0 for a ``None`` item.
            By default every row is sent immediately.
        :param execution: target execution; defaults to the client's execution.
        """
        if must_flush is None:
            # Must accept the two arguments passed below.
            def must_flush(*_):
                return True

        if scenario is None:
            scenario = self._scenario

        if execution is None:
            execution = self._execution_id

        def gen(clear):
            rows = []
            for row in data_generator:
                if row is not None:
                    rows.append(row)
                row_len = len(row) if row is not None else 0
                if rows and must_flush(len(rows), row_len):
                    blocks = self._rows_to_blocks(rows)
                    rows = []
                    for block in blocks:
                        yield spm_pb2.DataList(scenario=scenario, execution=execution, data=block, clear=clear,
                                               block_complete=False, row_complete=True)
                    clear = False

            # Final flush of whatever is still pending; these messages complete the block.
            for block in self._rows_to_blocks(rows):
                yield spm_pb2.DataList(scenario=scenario, execution=execution, data=block, clear=clear,
                                       block_complete=True, row_complete=True)

        self._base_client.WriteDataList(gen(clear))

    def new_writer(self, buffer_size=100, clear=False):
        """
        Create a ``NumerousBufferedWriter`` for the client's scenario.

        Unlike ``open_writer``, the writer is not added to ``self.writers``, so the caller is
        responsible for closing it.

        :param buffer_size: buffer size handed to the writer.
        :param clear: clear the scenario's existing data before creating the writer.
        """
        if clear:
            self.clear_timeseries_data(self._scenario)

        return NumerousBufferedWriter(
            self.write_with_row_generator,
            self.writer_closed,
            self._scenario,
            buffer_size,
        )

    def writer_closed(self, scenario):
        """Callback for writers: tell the backend the data of ``scenario`` is finalized."""
        log.debug('Writer Closed')
        completion = spm_pb2.DataCompleted(scenario=scenario, eta=-1, finalized=True)
        self._base_client.CloseData(completion)

    def read_configuration(self, name):
        """Fetch the stored configuration called ``name`` from the backend."""
        request = spm_pb2.ConfigurationRequest(name=name)
        return self._base_client.ReadConfiguration(request)

    def store_configuration(self, name, description, datetime_, user, configuration, comment, tags):
        """Store a named configuration on the backend and return the backend's reply."""
        configuration_msg = spm_pb2.Configuration(
            name=name,
            description=description,
            datetime=datetime_,
            user=user,
            configuration=configuration,
            comment=comment,
            tags=tags,
        )
        return self._base_client.StoreConfiguration(configuration_msg)

    @contextmanager
    def open_writer(self, aliases:dict=None, clear=False, buffer_size=24*7) -> NumerousBufferedWriter:
        """
        Context manager yielding a ``NumerousBufferedWriter`` for the client's scenario.

        The writer is registered in ``self.writers`` and closed when the ``with`` block exits,
        including on error. ``aliases`` is accepted but not used here.

        :param clear: clear the scenario's existing data before creating the writer.
        :param buffer_size: buffer size handed to the writer.
        """
        if clear:
            self.clear_timeseries_data(self._scenario)

        buffered_writer = NumerousBufferedWriter(
            self.write_with_row_generator, self.writer_closed, self._scenario, buffer_size
        )
        self.writers.append(buffered_writer)

        try:
            yield buffered_writer
        finally:
            buffered_writer.close()

    def get_inputs(self, scenario_data, t0=0, te=0, dt=3600, tag_prefix='Stub', tag_seperator='.', timeout=10):
        """
        Build an ``InputManager`` that feeds the simulation with the scenario's input variables.

        Input variables of enabled components are considered only for main components and
        subcomponents. Static variables are collected into a single ``StaticSource``. Every
        referenced input scenario becomes a ``DynamicSource``.

        :param scenario_data: scenario document; reads ``'simComponents'`` and
            ``'inputScenarios'``. Components listed as subcomponents get
            ``'isSubcomponent' = True``. Referenced input-scenario dicts get a ``'tags'`` entry.
        :param t0: simulation start time, passed to every source.
        :param te: simulation end time, passed to the dynamic sources.
        :param dt: step passed to the ``StaticSource``.
        :param tag_prefix: first element of generated tag names; omitted when falsy.
        :param tag_seperator: separator joining prefix, component name and variable id.
        :param timeout: read timeout passed to the dynamic sources.
        :return: ``InputManager`` over all sources.
        :raises ValueError: for an input variable with an unsupported ``'dataSourceType'``.
        """
        input_sources = {}
        static_data = {}
        input_source_types = {}
        only_static = True
        input_scenario_map_spec = {s['scenarioID']: s for s in scenario_data['inputScenarios']}

        # Flag every component that some other component lists as a subcomponent.
        subcomponent_uuids = {
            sub['uuid']
            for component in scenario_data['simComponents']
            if 'subcomponents' in component
            for sub in component['subcomponents']
        }
        if subcomponent_uuids:
            for component in scenario_data['simComponents']:
                if component['uuid'] in subcomponent_uuids:
                    component['isSubcomponent'] = True

        # Collect the tags to update and their sources.
        for component in scenario_data['simComponents']:
            if component['disabled']:
                continue
            # 'isSubcomponent' is only present on flagged components.
            if not (component['isMainComponent'] or component.get('isSubcomponent', False)):
                continue

            for iv in component['inputVariables']:
                tag_ = tag_seperator.join(filter(None, [tag_prefix, component['name'], iv['id']]))
                source_type = iv['dataSourceType']

                if source_type == 'static':
                    static_data[tag_] = iv['value']

                elif source_type == 'scenario':
                    only_static = False
                    source_id = iv['dataSourceID']

                    if source_id not in input_sources:
                        # The input-scenario dict from scenario_data is reused (and extended).
                        input_sources[source_id] = input_scenario_map_spec[source_id]
                        input_sources[source_id]['tags'] = {}

                    input_sources[source_id]['tags'][tag_] = {'tag': iv['tagSource']['tag'],
                                                              'scale': iv['scaling'],
                                                              'offset': iv['offset']}
                    input_source_types[source_id] = 'scenario'

                else:
                    # Only 'static' and 'scenario' sources are supported.
                    raise ValueError(f'Unknown data source type: {source_type}')

        # First add the static source.
        log.info('static only: ' + str(only_static))
        inputs = [StaticSource(static_data, t0=t0, dt=dt, generate_index=only_static)]
        log.debug('input source types: ' + str(input_source_types))

        # Then add the dynamic sources; the execution id may be missing (None).
        input_scenario_map = {s['scenarioID']: s.get('executionID') for s in scenario_data['inputScenarios']}

        for input_source, spec in input_sources.items():
            execution = input_scenario_map[input_source]
            log.debug(spec)
            log.debug(f'Setting up input sources, scenario: {input_source} exe:{execution}')

            inputs.append(DynamicSource(client=self, scenario=input_source, execution=execution, spec=spec,
                                        t0=t0, te=te, timeout=timeout,
                                        source_type=input_source_types[input_source]))

        return InputManager(inputs, self._t, t0)



    def publish_message(self, message):
        """
        Publish ``message`` (JSON-serialised) on the channel ``EVENT.<project>.<scenario>``.

        NOTE(review): ``project_id``/``scenario_id`` are passed as extra keyword arguments to
        the call on ``_base_client``; confirm the client wrapper accepts them.
        """
        channel = ".".join(['EVENT', self._project, self._scenario])
        payload = json.dumps(message)
        subscription_message = spm_pb2.SubscriptionMessage(channel=channel, message=payload)

        self._base_client.PublishSubcsriptionMessage(
            subscription_message,
            project_id=self._project,
            scenario_id=self._scenario,
        )

    def subscribe_messages(self, channel_patterns):
        """Yield ``(channel, decoded_json_message)`` for each update matching ``channel_patterns``."""
        request = spm_pb2.Subscription(channel_patterns=channel_patterns, scenario=self._scenario)
        for update in self._base_client.SubscribeForUpdates(request):
            yield update.channel, json.loads(update.message)

    def refresh_token(self, refresh_token):
        """Exchange ``refresh_token`` for a new access token through the token manager."""
        return self._token_manager.GetAccessToken(spm_pb2.Token(val=refresh_token))

    def was_resumed(self) -> bool:
        """Tell whether this job was resumed (True) rather than freshly started (False)."""
        resumed = self._was_resumed
        return resumed

    @staticmethod
    def _default_hibernation_callback(*args, **kwargs) -> None:
        """Fallback hibernation callback: only prints the arguments it received."""
        text = "Default Hibernation Callback with arguments. args: '{}' kwargs: '{}'.".format(args, kwargs)
        print(text)

    def set_hibernation_callback(self, func, *args, **kwargs) -> None:
        """
        Register the function that ``close`` calls when the job has been hibernated.

        :param func: callable invoked as ``func(*args, **kwargs)``.
        :param args: positional arguments stored for the call.
        :param kwargs: keyword arguments stored for the call.
        """
        stored_arguments = (args, kwargs)
        self._hibernate_callback = func
        self._hibernate_callback_args = stored_arguments

    def hibernate(self, message: str = 'Hibernating') -> None:
        """
        Hibernate the client so its state can be reloaded after a later restart.

        The client is marked as hibernating and complete, the scenario status is set to
        ``'hibernating'``, and SIGTERM is raised in this process. The handler installed at
        module import logs and calls ``sys.exit()``.

        :param message: progress message reported with the status.
        """
        log.info('Client received hibernation signal.')
        self._hibernating = self._complete = True
        self.set_scenario_progress(message=message, status='hibernating', force=True)
        signal.raise_signal(signal.SIGTERM)

    def close(self):
        """
        Shut the client down, in this order:

        1. close all registered writers;
        2. stop the access-token refresher;
        3. call the hibernation callback if the job was hibernated;
        4. if complete, report 'finished' (unless hibernated or in error) and complete the
           execution;
        5. close the log handler and stop the log listener, when present.
        """
        for writer in self.writers:
            writer.close()

        # Presumably a RepeatedFunction refreshing the access token.
        self._access_token_refresher.stop()

        if self._hibernating:
            callback_args = self._hibernate_callback_args[0]
            callback_kwargs = self._hibernate_callback_args[1]
            self._hibernate_callback(*callback_args, **callback_kwargs)

        if self._complete:
            if not (self._hibernating or self._error):
                print("setting status...")
                self.set_scenario_progress(message='Finished', status='finished', force=True)
            print("complete...")
            self.complete_execution()

        if hasattr(self, 'log_handler'):
            log.debug('flushing log')
            self.log_handler.close()

        log.debug('Client closing')
        if hasattr(self, 'log_listener'):
            self.log_listener.stop()


@contextmanager
def open_client(job_id=None, project=None, scenario=None, clear_data=None, no_log=False, instance_id=None, refresh_token=None) -> NumerousClient:
    """
    Context manager yielding a ``NumerousClient``.

    Any ``Exception`` raised inside the ``with`` block is logged, printed and pushed to the
    backend as a scenario error. It is then NOT re-raised. The client is always closed on
    exit.
    """
    client = NumerousClient(job_id, project, scenario, clear_data=clear_data, no_log=no_log,
                            instance_id=instance_id, refresh_token=refresh_token)
    try:
        yield client
    except Exception as e:
        tb_text = traceback.format_exc()
        logging.warning(f"CLIENT FAILED WITH ERROR CODE - w: {e} | {tb_text}")
        log.warning(f"CLIENT FAILED WITH ERROR CODE - wl: {e} | {tb_text}")
        print(f"CLIENT FAILED WITH ERROR CODE: {e} | {tb_text}")
        client.push_scenario_error(error=tb_text)
    finally:
        logging.warning("Closing!")
        client.close()


class StaticSource:
    """
    Input source that returns the same values at every time.

    ``get_at_time`` stamps the shared row dict with the requested time and returns that same
    dict object (it is mutated on every call). ``t_end`` and ``dt`` are accepted but unused.
    """

    def __init__(self, row_data, t0=0, t_end=None, dt=3600, generate_index=True):
        self.source_type = 'static'
        self.row = row_data
        self.t0 = t0
        self.t = np.inf
        self.subscribe = False
        self.generate_index = generate_index

    def get_at_time(self, t):
        """Return the static row with ``'_index'`` = t and ``'_index_relative'`` = t - t0."""
        self.t = t
        current = self.row
        current['_index'] = t
        current['_index_relative'] = t - self.t0
        return current


class DynamicSource:
    """
    Input source backed by the timeseries data of another scenario execution.

    Rows are read with ``client.read_data_as_row`` on a background daemon thread and queued
    in ``self.row_queue``. ``__next__`` rescales them according to ``spec`` and shifts their
    time index into the simulation time frame. ``get_at_time`` returns the buffered row in
    effect at a requested time.
    """
    def __init__(self, client, scenario, execution, spec, t0=0, te=0, timeout=None, source_type=None):
        """
        :param client: client used for ``get_timeseries_meta_data``, ``get_dataset_closed``,
            ``get_timeseries_stats`` and ``read_data_as_row``.
        :param scenario: id of the scenario to read from.
        :param execution: id of the execution to read from (may be None).
        :param spec: source specification. ``spec['tags']`` maps each output tag to a dict with
            the source ``'tag'``, ``'scale'`` and ``'offset'``. An optional ``spec['offset']``
            shifts the source in time. NOTE: ``spec`` is mutated; ``'offset'`` is set to 0 when
            missing.
        :param t0: simulation start time.
        :param te: simulation end time.
        :param timeout: seconds ``__next__`` waits for a queued row before raising TimeoutError.
        :param source_type: stored as ``self.source_type``.
        :raises Exception: if ``stats.min + offset > t0``, i.e. the source data (as shifted)
            starts after the simulation start.
        """
        self.timeout = timeout

        self.t0 = t0

        self.meta = client.get_timeseries_meta_data(scenario=scenario, execution=execution)
        self.closed = client.get_dataset_closed(scenario=scenario, execution=execution)

        self.name = scenario

        self.spec = spec

        self.spec['offset'] = self.spec['offset'] if 'offset' in self.spec else 0

        #Offset which must be added to this datasource in order for the query to be in the correct time frame.
        offset = self.meta.offset + t0 - self.spec['offset']

        #Offset which must be applied to the queried data in order for it to match into the simulation time frame.
        self.offset_after_read = - self.meta.offset + self.spec['offset']

        # Query window in the source's own time frame, covering the simulated span te - t0.
        start = offset
        end = offset + te-t0
        #else:
            #TODO Get info from scenario
        #    offset = self.spec['offset'] - self.meta.offset
        stats = client.get_timeseries_stats(None, scenario, execution)

        #If data set is open, we will try to listen for data as its being generated if t end not reached.
        self.subscribe = not self.closed
        # Hard-coded: the producer thread below never stops and re-queries on exhaustion.
        continuous = True#not self.subscribe

        self.offset = offset
        #self.meta.offset + self.spec['offset']



        if stats.min+offset > t0:
            raise Exception(f'dynamic input scenario {scenario} with tags {self.meta.tags} is offset to the future {stats.min+offset}>{t0}')


        self.start = start
        self.end = end
        #print('name: ', self.name)
        #print('Source type: ', spec['type'])
        #print('spec off: ',
        #      datetime.fromtimestamp(self.spec['offset']) if self.spec['offset'] > 0 else self.spec['offset'])

        # Debug output of the computed time window, printed to stdout on every construction.
        print('meta off: ', datetime.fromtimestamp(self.meta.offset))

        print('Combined offset: ', offset)
        print('Combined offset: ', datetime.fromtimestamp(offset))
        print('t0: ', datetime.fromtimestamp(t0))
        print('te: ', datetime.fromtimestamp(te))
        print('start: ', datetime.fromtimestamp(start))
        print('end: ', datetime.fromtimestamp(end))
        print('start: ', start)
        print('end: ', end)

        # Distinct source tags to read; the output tags are the keys of spec['tags'].
        tags=set([s['tag'] for s in spec['tags'].values()])
        log.debug('Datasource Tags: '+str(tags))
        self.generator = client.read_data_as_row(scenario=scenario, tags=tags,
                                                 execution=execution, start=start, end=end,
                                                 subscribe=self.subscribe)
        self.source_type = source_type
        # Bounded hand-off between the producer thread and __next__.
        self.row_queue = Queue(100)
        # Simulation-frame index of the last row returned by __next__.
        self.t = -1
        self.row_buffer = RowBuffer(tags=[s for s in spec['tags'].keys()], offset=offset)

        def _next():
            # Producer loop, run on the daemon thread started below. Rows are pushed onto
            # self.row_queue; Queue.put blocks while the queue is full (maxsize 100).
            # torg: last read index relative to self.meta.offset, plus the shift accumulated
            # over earlier passes (last_torg).
            # NOTE(review): ``continuous`` is hard-coded True, so the inner loop never exits,
            # the ``if not continuous`` branch is unreachable, and 'STOP' is never queued.
            # On exhaustion the range is re-read from self.meta.offset with ``offset=torg``,
            # so the data is presumably replayed shifted in time. A generator set up by the
            # subscribe branch is immediately replaced by that re-query.
            # NOTE(review): an empty result makes this loop re-query without pausing, unless
            # the subscribe branch sleeps.

            torg = 0
            last_torg = 0
            while True:
                if not self.row_queue.full():
                    first = True
                    while continuous or first:
                        first = False
                        try:

                            row_ = self.generator.__next__()

                            torg = row_['_index'] - self.meta.offset + last_torg
                            self.row_queue.put(row_)


                        except StopIteration:
                            print(f"all data read: {torg+self.offset}")
                            if self.subscribe:
                                # Open dataset: extend the window if newer data exists, else wait.
                                stats = client.get_timeseries_stats(None, scenario, execution)
                                if stats.max > self.t:
                                    self.start = self.t
                                    self.end = stats.max
                                    self.generator = client.read_data_as_row(tags=tags,
                                                                         scenario=scenario, execution=execution, start=self.start, end=self.end,
                                                                         subscribe=self.subscribe)
                                else:
                                    sleep(10)

                            if not continuous:
                                self.row_queue.put('STOP')

                                return

                            elif continuous:
                                #self.offset += t0
                                log.debug('Requerying data')
                                #log.debug('Offset: '+str(self.offset))
                                self.generator = client.read_data_as_row(tags=tags,
                                                                         scenario=scenario, execution=execution, start=self.meta.offset, end=self.end,
                                                                         subscribe=self.subscribe, offset=torg)
                                last_torg = torg


                else:
                    sleep(.01)

        threading.Thread(target=_next, daemon=True).start()

    def __next__(self):
        """
        Return the next row, mapped into the simulation frame.

        Each output tag is ``source_value * scale + offset`` per ``self.spec['tags']``.
        ``'_index_relative'`` is the raw index minus ``self.offset``. ``'_index'`` is the raw
        index plus ``self.offset_after_read``. ``'_datetime_utc'`` is the naive UTC datetime of
        ``'_index'``. Updates ``self.t`` to the new ``'_index'``.

        :raises TimeoutError: if no row is queued within ``self.timeout`` seconds.
        :raises StopIteration: when the producer queued the 'STOP' sentinel.
        """

        try:
            row_ = self.row_queue.get(timeout=self.timeout)
        except _queue.Empty:
            log.debug('timeout')
            raise TimeoutError('Timed out!')

        if row_ == 'STOP':
            log.debug('Iter done')
            raise StopIteration('Iterator consumed')

        row = {}
        for tag, s in self.spec['tags'].items():
            try:
                row[tag] = row_[s['tag']]*s['scale'] + s['offset']
            except:
                # Log the offending raw row, then propagate the original error.
                log.debug(str(row_))
                raise

        #print('offset: ', self.offset)

        row['_index_relative'] = row_['_index'] - self.offset

        row['_index'] = row_['_index'] + self.offset_after_read

        #print(self.name)
        #print('_index: ', row['_index'])

        #print('_ix_rel: ', row['_index_relative'])

        row['_datetime_utc'] = datetime.utcfromtimestamp(row['_index'])

        #print('_datetime_utc: ', row['_datetime_utc'])


        self.t = row['_index']
        return row

    def get_at_time(self, t):
        """
        Return the tag values in effect at simulation time ``t``.

        If ``t`` is at or past the last row read, the buffer is first trimmed to its latest row.
        Rows are then consumed until one lies after ``t``. The value is looked up in
        ``self.row_buffer`` with method 0.
        """
        if self.t <= t:
            self.row_buffer.trim()

        while self.t <= t:
            #print(t)
            #print(self.t)
            part_row = self.__next__()

            #print('prd: ', part_row)
            self.row_buffer.add(part_row)
        part_row = self.row_buffer.get(t, method=0)
        #print('ddd: ', part_row)
        return part_row

import numpy as np
class RowBuffer:
    """
    Numeric buffer of consecutive input rows, used to look up the row in effect at a time
    (zero-order hold).

    Layout of ``self._buffer`` (shape ``(n_fields + 1, buffer_len)``):
      * row 0 holds the column positions ``0, 1, 2, ...`` (the interpolation target);
      * rows 1.. hold the row values in the key order of the first row added, after
        ``'_datetime_utc'`` is removed.
    The buffer doubles its capacity when full instead of failing.
    """

    def __init__(self, tags=None, offset=None, buffer_len=100):
        """
        :param tags: output tags returned by ``get``. They are paired, in order, with the
            first value rows of the buffer.
        :param offset: offset used by ``get_`` to compute ``'_index_relative'``.
        :param buffer_len: initial capacity in rows.
        """
        self.buffer_len = buffer_len
        self._buffer = None
        self._pos = 0
        self.tags = tags
        self.offset = offset

        # Row of self._buffer holding '_index'; set when the first row is added.
        self.ix_t = None

    def trim(self):
        """Drop all buffered rows except the most recent one, which becomes column 0."""
        if self._buffer is None or self._pos == 0:
            return

        self._buffer[:, 0] = self._buffer[:, self._pos - 1]
        self._buffer[0, 0] = 0  # restore the position marker overwritten by the copy above
        self._pos = 1

    def _grow(self):
        """Double the capacity, keeping the stored rows and extending the position row."""
        new_len = max(1, self.buffer_len) * 2
        grown = np.empty((self._buffer.shape[0], new_len), dtype=np.float64)
        grown[:, :self._buffer.shape[1]] = self._buffer
        grown[0, :] = np.arange(new_len)
        self._buffer = grown
        self.buffer_len = new_len

    def add(self, row):
        """
        Append one row.

        :param row: dict of numeric values that must contain ``'_index'``. Any
            ``'_datetime_utc'`` key is removed from it in place. All rows must have the same
            keys in the same order.
        """
        row.pop('_datetime_utc', None)

        if self.ix_t is None:
            self._buffer = np.empty((len(row) + 1, self.buffer_len), dtype=np.float64)
            self._buffer[0, :] = np.arange(self.buffer_len)
            self.ix_t = list(row.keys()).index('_index') + 1

        if self._pos >= self._buffer.shape[1]:
            # Many input rows can fall between two lookups; grow instead of failing.
            self._grow()

        self._buffer[1:, self._pos] = list(row.values())
        self._pos += 1

    def get_(self, t, method):
        # NOTE(review): legacy implementation that treats self._buffer as a list of dicts;
        # it does not work with the numpy buffer filled by add(). Use get() instead.
        y = {tag: None for tag in self.tags}
        x = [d['_index_relative'] for d in self._buffer]
        for tag in self.tags:
            y[tag] = [d[tag] for d in self._buffer]
        irow = {}
        if method == 0:
            for tag in self.tags:
                ifun = interp1d(x, y[tag], kind='zero')
                irow[tag] = ifun(t).tolist()
        else:
            raise ValueError('Unknown interpolation type!')

        t_rel = t - self.offset
        irow.update({'_index': t, '_index_relative': t_rel, '_datetime_utc': datetime.utcfromtimestamp(t)})

        return irow

    def get(self, t, method):
        """
        Return the buffered tag values in effect at time ``t``.

        :param t: time on the same scale as the rows' ``'_index'`` values.
        :param method: only ``0`` (zero-order hold: the last row at or before ``t``) is supported.
        :return: dict mapping each tag in ``self.tags`` to its buffered value.
        :raises ValueError: for an unknown method, or (from scipy) when ``t`` lies outside the
            buffered time range.
        """
        if method != 0:
            raise ValueError('Unknown interploation type')

        times = self._buffer[self.ix_t, :self._pos]
        positions = self._buffer[0, :self._pos]
        ix = int(interp1d(times, positions, kind='zero')(t))
        return {tag: val for tag, val in zip(self.tags, self._buffer[1:, ix])}


class InputManager:
    """Combines several input sources into a single row per simulation time step."""

    def __init__(self, sources, t, t0):
        """
        :param sources: objects exposing ``get_at_time(t)`` (returning a dict) and a
            ``subscribe`` flag.
        :param t: stored as ``self.t``; not used by this class.
        :param t0: simulation start time, used to compute ``'_index_relative'``.
        """
        self.sources = sources
        self.t = t
        self.t0 = t0

    def get_at_time(self, t):
        """
        Merge the rows of all sources at time ``t``.

        Later sources overwrite keys from earlier ones. ``'_index'`` and
        ``'_index_relative'`` are always set from ``t`` and ``self.t0``.

        :return: merged row dict, or ``None`` if a subscribing source timed out.
        :raises TimeoutError: if a non-subscribing source timed out.
        """
        row = {}
        for source in self.sources:
            try:
                part_row = source.get_at_time(t)
            except TimeoutError:
                # A subscribing source may simply not have produced the data yet.
                if source.subscribe:
                    return None
                raise
            row.update(part_row)

        row.update({'_index': t, '_index_relative': t - self.t0})
        return row


class RepeatedFunction:
    """
    Call ``function(*args, **kwargs)`` every ``interval`` seconds using ``threading.Timer``.

    Each call is scheduled relative to the previous scheduled time (``next_call``), not the
    previous completion. ``stop()`` is safe to call before ``start()`` and prevents a timer
    that is already firing from rescheduling itself.
    """

    def __init__(self, interval, function, run_initially=False, *args, **kwargs):
        """
        :param interval: seconds between calls.
        :param function: callable to run.
        :param run_initially: call ``function`` once synchronously during construction.
        :param args: positional arguments for ``function``.
        :param kwargs: keyword arguments for ``function``.
        """
        self._timer = None
        self.interval = interval
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self.is_running = False
        self._stopped = False
        self.next_call = time()

        if run_initially:
            self.function(*self.args, **self.kwargs)

    def _schedule(self):
        """Arm a timer for the next scheduled call."""
        self.next_call += self.interval
        self._timer = threading.Timer(self.next_call - time(), self._run)
        self._timer.start()
        self.is_running = True

    def _run(self):
        self.is_running = False
        # stop() may race with a timer that has already fired. Do not reschedule (or call)
        # after an explicit stop, otherwise a new timer thread keeps running.
        if self._stopped:
            return
        self._schedule()
        self.function(*self.args, **self.kwargs)

    def start(self):
        """Start (or restart) periodic calls; no-op if already running."""
        if not self.is_running:
            self._stopped = False
            self._schedule()

    def stop(self):
        """Cancel pending calls; safe to call when never started."""
        self._stopped = True
        if self._timer is not None:
            self._timer.cancel()
        self.is_running = False


class Test(NumerousClient):
    """Minimal ``NumerousClient`` subclass constructed with all default arguments."""

    def __init__(self):
        super().__init__()


class NumerousAdminClient(NumerousClient):
    def __init__(self, server=None, port=None, secure=None, refresh_token=None, instance_id=None, url=None):
        """
        Create an administrative client.

        The client is not bound to a project, scenario or job (empty ids are passed to
        ``NumerousClient``). It adds job, build and token manager stubs on the same channel.
        """
        super().__init__(
            server=server,
            port=port,
            secure=secure,
            refresh_token=refresh_token,
            instance_id=instance_id,
            project="",
            scenario="",
            job_id="",
            url=url,
        )
        channel = self.channel
        self._job_manager = spm_pb2_grpc.JobManagerStub(channel)
        self._build_manager = spm_pb2_grpc.BuildManagerStub(channel)
        self._token_manager = spm_pb2_grpc.TokenManagerStub(channel)

        self._terminated = False
        self._complete = False
        # Fixed identity used for job requests issued by this client.
        self._user = 'test_user'
        self._organization = 'EnergyMachines'

        log.info('init admin client')

    def launch_job(self, project, scenario, job):
        """Ask the job manager to start ``job`` on behalf of this client's user/organization."""
        request = spm_pb2.Job(project_id=project, scenario_id=scenario, job_id=job,
                              user_id=self._user, organization_id=self._organization)
        self._job_manager.StartJob(request)

    def terminate_job(self, project, scenario, job):
        """Ask the job manager to terminate ``job``; the reply is discarded."""
        request = spm_pb2.Job(project_id=project, scenario_id=scenario, job_id=job,
                              user_id=self._user, organization_id=self._organization)
        self._job_manager.TerminateJob(request)

    def reset_job(self, project, scenario, job):
        """Ask the job manager to reset ``job``; the reply is discarded."""
        request = spm_pb2.Job(project_id=project, scenario_id=scenario, job_id=job,
                              user_id=self._user, organization_id=self._organization)
        self._job_manager.ResetJob(request)

    def listen_executions(self):
        """Yield each execution document streamed by ``ListenExecutions``, decoded from JSON."""
        stream = self._base_client.ListenExecutions(spm_pb2.Empty())
        for document in stream:
            yield json.loads(document.json)

    def update_execution(self, update_):
        """Send ``update_`` as JSON (datetimes rendered with ``str``) to ``UpdateExecution``."""
        payload = json.dumps(update_, default=datetime_to_json_converter)
        self._base_client.UpdateExecution(spm_pb2.Json(json=payload))

    def get_execution_status(self, exe_id):
        """Return the ``json`` field of the execution status reply for ``exe_id``."""
        reply = self._job_manager.GetExecutionStatus(spm_pb2.ExecutionId(execution_id=exe_id))
        return reply.json

    def clear_execution(self, scenario, execution):
        """Call ``ClearExecutionMemory`` for the given scenario execution."""
        target = spm_pb2.Scenario(scenario=scenario, execution=execution)
        self._base_client.ClearExecutionMemory(target)

    def remove_data_execution(self, scenario, execution):
        """Call ``RemoveExecutionData`` for the given scenario execution."""
        target = spm_pb2.Scenario(scenario=scenario, execution=execution)
        self._base_client.RemoveExecutionData(target)

    def remove_execution(self, scenario, execution):
        """Remove an execution's data first, then its schedule."""
        self.remove_data_execution(scenario, execution)
        self.remove_execution_schedule(execution)

    def delete_execution(self, exe_id):
        """Ask the job manager to delete the execution ``exe_id``."""
        execution = spm_pb2.ExecutionId(execution_id=exe_id)
        self._job_manager.DeleteExecution(execution)

    def update_job_by_backend(self, project, scenario, job, exe, message, status, log=None, complete=False):
        """
        Report job progress on behalf of the backend.

        Sets the scenario progress (progress 0). If ``log`` is given, it is pushed as a single
        log entry of execution ``exe``. If ``complete`` is true, the execution is completed
        regardless of instance.

        :param log: optional log entry text (shadows the module logger inside this method).
        :return: ``spm_pb2.Empty()``.
        """
        progress = spm_pb2.ScenarioProgress(project=project, scenario=scenario, job_id=job,
                                            message=message, status=status, progress=0)
        self._base_client.SetScenarioProgress(progress)

        if log is not None:
            # NOTE(review): utcnow() is naive, so .timestamp() interprets it as local time;
            # the epoch is only correct on hosts running in UTC.
            entries = spm_pb2.LogEntries(scenario=scenario, execution_id=exe, log_entries=[log],
                                         timestamps=[datetime.utcnow().timestamp()])
            self._base_client.PushExecutionLogEntries(entries)

        if complete:
            completion = spm_pb2.ExecutionMessage(project_id=project, scenario_id=scenario, job_id=job,
                                                  execution_id=exe)
            self._base_client.CompleteExecutionIgnoreInstance(completion)

        return spm_pb2.Empty()

    def read_execution_logs(self, execution=None):
        """
        Yield ``(log_entry, local naive datetime)`` for every log entry of an execution.

        :param execution: execution id; defaults to the client's execution.
        """
        execution_id = self._execution_id if execution is None else execution
        request = spm_pb2.ExecutionReadLogs(execution_id=execution_id, start=0, end=0)
        for entry in self._base_client.ReadExecutionLogEntries(request):
            yield entry.log_entry, datetime.fromtimestamp(entry.timestamp)

    def hibernate_job(self, project_id, scenario_id, job_id, organization_id, user_id):
        """Ask the job manager to hibernate the given job."""
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        self._job_manager.HibernateJob(job)

    def resume_job(self, project_id, scenario_id, job_id, organization_id, user_id):
        """Ask the job manager to resume the given (hibernated) job."""
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        self._job_manager.ResumeJob(job)

    def add_job_schedule(self, project_id, scenario_id, job_id, organization_id, user_id, sleep_after, sleep_for, enable_scheduling):
        """Attach a sleep schedule to the given job via ``AddJobSchedule``."""
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        schedule = spm_pb2.Schedule(sleep_after=sleep_after, sleep_for=sleep_for,
                                    enable_scheduling=enable_scheduling)
        self._job_manager.AddJobSchedule(spm_pb2.JobSchedule(job=job, schedule=schedule))

    def update_job_schedule(self, project_id, scenario_id, job_id, organization_id, user_id, sleep_after, sleep_for, enable_scheduling):
        """
        Replace the sleep schedule of the given job.

        NOTE(review): this sends ``AddJobSchedule``, not an ``UpdateJobSchedule`` RPC
        (compare ``update_execution_schedule``). Confirm that AddJobSchedule upserts.
        """
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        schedule = spm_pb2.Schedule(sleep_after=sleep_after, sleep_for=sleep_for,
                                    enable_scheduling=enable_scheduling)
        self._job_manager.AddJobSchedule(spm_pb2.JobSchedule(job=job, schedule=schedule))

    def remove_job_schedule(self, project_id, scenario_id, job_id, organization_id, user_id):
        """Remove the sleep schedule of the given job."""
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        self._job_manager.RemoveJobSchedule(job)

    def get_job_schedule(self, project_id, scenario_id, job_id, organization_id, user_id):
        """Return the job manager's schedule reply for the given job."""
        job = spm_pb2.Job(project_id=project_id, scenario_id=scenario_id, job_id=job_id,
                          organization_id=organization_id, user_id=user_id)
        return self._job_manager.GetJobSchedule(job)

    def add_execution_schedule(self, execution_id, sleep_after, sleep_for, enable_scheduling):
        """Attach a sleep schedule to execution ``execution_id``."""
        schedule = spm_pb2.Schedule(sleep_after=sleep_after, sleep_for=sleep_for,
                                    enable_scheduling=enable_scheduling)
        request = spm_pb2.ExecutionAndSchedule(execution=spm_pb2.ExecutionId(execution_id=execution_id),
                                               schedule=schedule)
        self._job_manager.AddExecutionSchedule(request)

    def update_execution_schedule(self, execution_id, sleep_after, sleep_for, enable_scheduling):
        """Replace the sleep schedule of execution ``execution_id``."""
        schedule = spm_pb2.Schedule(sleep_after=sleep_after, sleep_for=sleep_for,
                                    enable_scheduling=enable_scheduling)
        request = spm_pb2.ExecutionAndSchedule(execution=spm_pb2.ExecutionId(execution_id=execution_id),
                                               schedule=schedule)
        self._job_manager.UpdateExecutionSchedule(request)

    def remove_execution_schedule(self, execution_id):
        """Remove the sleep schedule of execution ``execution_id``."""
        execution = spm_pb2.ExecutionId(execution_id=execution_id)
        self._job_manager.RemoveExecutionSchedule(execution)

    def get_execution_schedule(self, execution_id):
        """Return the job manager's schedule reply for execution ``execution_id``."""
        execution = spm_pb2.ExecutionId(execution_id=execution_id)
        return self._job_manager.GetExecutionSchedule(execution)

    def _create_gzipped_tarball(self, source: Dir) -> Path:
        """
        Pack the files of ``source`` into ``<uuid4>.tar.gz`` in the current working directory.

        Member names are the files' relative names in POSIX form. The caller is responsible
        for deleting the returned archive.

        :return: relative path of the created archive.
        """
        members = source.files()
        archive = Path(f"{uuid4()}.tar.gz")
        root = Path(source.path)
        with tarfile.open(str(archive), mode="w:gz") as tar:
            for member_name in members:
                member_path = root / str(member_name)
                info = tarfile.TarInfo(str(PurePosixPath(Path(member_name))))
                info.size = member_path.stat().st_size
                with member_path.open("rb") as handle:
                    tar.addfile(info, fileobj=handle)
        return archive
    
    def _extract_gzipped_tarball(self, file, path: Path) -> None:
        """
        Extract a gzip-compressed tar archive into ``path``.

        The archive comes from a remote download, so members that would land outside ``path``
        are rejected. Where available, tarfile's ``'data'`` extraction filter is used.
        Otherwise a basic path check is applied; note that it does not inspect link targets.

        :param file: binary file object containing the archive.
        :param path: destination directory.
        :raises ValueError: if a member would be written outside ``path`` (fallback check).
        """
        with tarfile.open(fileobj=file, mode="r:gz") as tar:
            # getmembers() reads the whole index; tar.members only holds what was read so far.
            members = tar.getmembers()
            log.debug("Pulled tar containing:")
            for member in members:
                log.debug(f" - {member.name}")

            if hasattr(tarfile, "data_filter"):
                tar.extractall(str(path), filter="data")
            else:
                target = Path(path).resolve()
                for member in members:
                    destination = (target / member.name).resolve()
                    if destination != target and target not in destination.parents:
                        raise ValueError(f"Refusing to extract {member.name!r} outside {target}")
                tar.extractall(str(path))

    def make_commit(self, path: str, repository: str = 'test', branch: str = 'master', comment: str = None, parent_commit_id: str = None, dry_run: bool = False):
        """
        Snapshot the directory at ``path``, upload it and record it as a commit on ``branch``.

        :param path: directory to snapshot.
        :param repository: target repository name.
        :param branch: branch the history update is recorded on.
        :param comment: commit message (required).
        :param parent_commit_id: id of the parent commit, if any.
        :param dry_run: only compute the snapshot id and directory state; nothing is sent.
        :return: ``(commit_id, snapshot_id, state)``. For a dry run, ``commit_id`` is None and
            ``state`` is the directory state; otherwise ``state`` is None.
        :raises ValueError: if ``comment`` is None.
        :raises grpc.RpcError: if the snapshot cannot be created and is not in the history.
        :raises requests.HTTPError: if uploading the snapshot archive fails.
        """
        if comment is None:
            raise ValueError("comment cannot be null.")

        directory = Dir(path)
        snapshot_id = directory.hash(index_func=filehash)

        dir_state = DirState(directory, index_cmp=filehash)
        state_ = dir_state.state.copy()
        del state_["directory"]

        if dry_run:
            return None, snapshot_id, state_

        try:
            snapshot = self._build_manager.CreateSnapshot(spm_pb2.Snapshot(
                repository=repository, snapshot_id=snapshot_id,
                file_structure=json.dumps(state_), only_get_uploads=False))
        except grpc.RpcError as error:
            # ``code`` is a method on gRPC errors; it must be called to get the status.
            code = error.code() if callable(getattr(error, "code", None)) else None
            if code == grpc.StatusCode.ALREADY_EXISTS:
                log.warning(f"Snapshot {snapshot_id} aleady exists. Ignoring.")
            # The snapshot may already have been committed; reuse that commit if so.
            history = self._build_manager.GetHistory(spm_pb2.HistoryRequest(repository=repository))
            for update in history.updates:
                if update.snapshot_id == snapshot_id:
                    return update.id, snapshot_id, None
            raise

        tar_path = self._create_gzipped_tarball(directory)
        log.debug(f"Created tarball {tar_path} for {path}")
        try:
            with open(tar_path, "rb") as tar_file:
                upload_response = requests.post(
                    snapshot.upload_url,
                    data=tar_file,
                    headers={"Content-Type": "multipart/related"},
                    params={
                        "name": snapshot_id,
                        "mimeType": "application/octet-stream"
                    },
                )
            log.debug(f"Uploaded snapshot {snapshot_id} got response: {upload_response.status_code} {upload_response.reason}")
            # Never record a commit for a snapshot whose archive did not reach storage.
            upload_response.raise_for_status()
        finally:
            tar_path.unlink()

        # NOTE(review): utcnow() is naive, so .timestamp() interprets it as local time;
        # the epoch is only correct on hosts running in UTC.
        history_reply = self._build_manager.CreateHistoryUpdate(spm_pb2.HistoryUpdate(
            repository=repository,
            branch=branch,
            snapshot_id=snapshot_id,
            parent_commit_id=parent_commit_id,
            comment=comment,
            timestamp=datetime.utcnow().timestamp()
        ))
        log.debug(f'Commit complete for snapshot {snapshot_id}')

        return history_reply.commit_id, snapshot_id, None



    def pull_snapshot(self, repository: str, snapshot_id: str, path: Path):
        """Download snapshot ``snapshot_id`` of ``repository`` and extract it into ``path``.

        The build manager's ``PullSnapshot`` reply supplies the archive URL.
        The archive is fetched over HTTP (following redirects), held in memory
        and extracted as a gzipped tarball.

        Raises:
            requests.HTTPError: If the archive download returns an error status.
        """
        snapshot = self._build_manager.PullSnapshot(spm_pb2.PullRequest(repository=repository, snapshot_id=snapshot_id))
        response = requests.get(snapshot.archive_url.url, allow_redirects=True)
        # Surface the HTTP failure instead of handing an error page to tarfile.
        response.raise_for_status()
        self._extract_gzipped_tarball(io.BytesIO(response.content), path)


    def generate_scenario_token(self, project, scenario, job, user_id, organization_id, agent, purpose, admin=False):
        """Ask the token manager for a refresh token scoped to the given context.

        Builds a ``TokenRequest`` from the arguments, sends it with
        ``CreateRefreshToken`` and returns the ``val`` field of the reply.
        """
        token_request = spm_pb2.TokenRequest(
            project_id=project,
            scenario_id=scenario,
            job_id=job,
            user_id=user_id,
            organization_id=organization_id,
            agent=agent,
            purpose=purpose,
            admin=admin,
        )
        reply = self._token_manager.CreateRefreshToken(token_request)
        return reply.val

    def enqueue_build(self, repository, commit_id, launch_after_build=False):
        """Request a build of ``commit_id`` in ``repository`` and return its ``build_id``."""
        build_request = spm_pb2.BuildRequest(
            repository=repository,
            commit_id=commit_id,
            launch_after_build=launch_after_build,
        )
        return self._build_manager.LaunchBuild(build_request).build_id

    def update_scenario_image(self, name, path, project, scenario, job, repository, commit, build, internal=True):
        """Send a job image reference update for ``project``/``scenario``/``job``.

        The reference (image name/path, source repository and commit, build and
        ``internal`` flag) goes to ``UpdateJobImageReference``. The reply is discarded.
        """
        image_reference = spm_pb2.JobImageReferenceUpdate(
            project=project,
            scenario=scenario,
            job=job,
            name=name,
            path=path,
            repository=repository,
            commit_id=commit,
            internal=internal,
            build=build,
        )
        self._build_manager.UpdateJobImageReference(image_reference)

    def build_worker(self, work_function, *args):
        """Process work items streamed from the build manager's ``Builder`` call.

        For each incoming work item, the status ``'ACK'`` is queued first. Then
        ``work_function(self, work.json, *args)`` runs and its return value is
        queued. Queued statuses are streamed back to the server as
        ``BuildResponse`` messages.
        """
        outgoing_statuses = Queue()

        def _status_stream():
            # Blocks on the queue between items; this generator never finishes on its own.
            while True:
                next_status = outgoing_statuses.get()
                yield spm_pb2.BuildResponse(status=next_status)

        for work in self._build_manager.Builder(_status_stream()):
            outgoing_statuses.put('ACK')
            outcome = work_function(self, work.json, *args)
            outgoing_statuses.put(outcome)

    def update_build(self, repository, build_id, **update):
        """Send the keyword fields in ``update`` (JSON-encoded) for ``build_id`` in ``repository``."""
        build_update = spm_pb2.BuildUpdate(
            repository=repository,
            build_id=build_id,
            update=json.dumps(update),
        )
        self._build_manager.UpdateBuild(build_update)

    def get_build_info(self, repository, build_id):
        """Return the ``info`` of build ``build_id`` in ``repository``, decoded from JSON."""
        info_request = spm_pb2.BuildInfoRequest(repository=repository, build_id=build_id)
        raw_info = self._build_manager.GetBuildInfo(info_request).info
        return json.loads(raw_info)

    def get_history(self, repository, branch=None, history_updates=None):
        """Return the ``updates`` field of the build manager's history for ``repository``.

        ``branch`` and ``history_updates`` are passed through unchanged in the
        ``HistoryRequest``.
        """
        history_request = spm_pb2.HistoryRequest(
            repository=repository,
            branch=branch,
            history_updates=history_updates,
        )
        return self._build_manager.GetHistory(history_request).updates

    def push_scenario_logs_admin(self, logs, exe_id):
        """Push log entries for execution ``exe_id`` with an empty scenario id.

        ``logs`` is indexed per item: ``item[0]`` must provide ``.timestamp()``
        (presumably a ``datetime``) and ``item[1]`` is the log entry. ``logs`` is
        traversed twice, so it must be re-iterable (e.g. a list, not a generator).
        """
        stamps = [record[0].timestamp() for record in logs]
        entries = [record[1] for record in logs]
        self._base_client.PushExecutionLogEntries(spm_pb2.LogEntries(
            scenario='',
            execution_id=exe_id,
            log_entries=entries,
            timestamps=stamps,
        ))

    def delete_system(self, system):
        """Send a ``DeleteSystem`` request for ``system``; the reply is discarded."""
        system_request = spm_pb2.SystemRequest(system=system)
        self._base_client.DeleteSystem(system_request)

    def delete_user(self, user):
        """Send a ``DeleteUser`` request for ``user``; the reply is discarded."""
        user_request = spm_pb2.UserRequest(user=user)
        self._base_client.DeleteUser(user_request)

    def complete_signup(self, invitationID, firstName, lastName, organization):
        """Complete the signup for ``invitationID`` with the user's name and organization.

        Sends ``SignupData`` via ``CompleteUserSignup``; the reply is discarded.
        """
        signup_data = spm_pb2.SignupData(
            invitationID=invitationID,
            firstName=firstName,
            lastName=lastName,
            organization=organization,
        )
        self._base_client.CompleteUserSignup(signup_data)

    def login_get_refresh(self, email, password):
        """Log in with ``email``/``password`` and return the ``LoginGetRefreshToken`` reply.

        NOTE(review): the embedded ``TokenRequest`` carries hard-coded,
        placeholder-looking ids and asks for ``admin=True`` / ``access_level=4``.
        Presumably the server ignores or overrides these on login; confirm.
        """
        embedded_request = spm_pb2.TokenRequest(
            project_id="ssfsdf",
            scenario_id="sdffd",
            admin=True,
            execution_id="sdsfd",
            job_id="sfdsdf",
            user_id="sfsdfsdfsdf",
            organization_id="sdfsfd",
            access_level=4,
        )
        credentials = spm_pb2.LoginCredentials(
            email=email,
            password=password,
            token_request=embedded_request,
        )
        return self._token_manager.LoginGetRefreshToken(credentials)

    def get_cluster_info(self):
        """Return the job manager's reply to an empty ``ClusterInfoRequest``."""
        cluster_request = spm_pb2.ClusterInfoRequest()
        return self._job_manager.GetClusterInfo(cluster_request)

@contextmanager
def open_admin_client(server=None, port=None, secure=None, refresh_token=None) -> NumerousAdminClient:
    """Context manager that yields a ``NumerousAdminClient`` and closes it on exit.

    All arguments are forwarded to the ``NumerousAdminClient`` constructor.
    ``close()`` runs even if the ``with`` body raises.
    """
    admin_client = NumerousAdminClient(
        server=server,
        port=port,
        secure=secure,
        refresh_token=refresh_token,
    )
    try:
        yield admin_client
    finally:
        admin_client.close()
