import logging
import uuid
import grpc
import json
import os
import threading
from time import sleep
from uuid import uuid4
from copy import deepcopy
from datetime import datetime
from queue import Queue, Empty

import _queue
import redis_namespace

from .bigtable import Bigtable
from .utilities import NumerousBaseException
from .deployment import Deploy
from .token_validation.validation import AccessLevel

from . import tokens as token_manager
from . import firebase


# Shared clients used by every function in this module.
bigtable = Bigtable()
deployment_client = Deploy()


redis_host = os.getenv('REDIS_HOST')
redis_port = os.getenv('REDIS_PORT')


# Firestore key under which job/execution schedules are stored.
SCHEDULE_NAME = "schedule"
# Passed as ``maxlen`` to every ``xadd`` on execution streams to bound their length.
maxlen_stream = 1000


log = logging.getLogger('numerous_api.spm')
log.setLevel(logging.DEBUG)


# Every Redis key is namespaced by organization, so the organization is mandatory.
org = os.getenv('NUMEROUS_ORGANIZATION')
if org is None:
    raise ValueError('No organization specified!')


log.info(f"SPM connecting to redis at: redis://{redis_host}:{redis_port}")
redis = None
while redis is None:
    try:
        redis = redis_namespace.Redis.from_url(f"redis://{redis_host}:{redis_port}", namespace=org+":")
    except Exception as e:
        # Keep retrying, but leave a trace and do not swallow
        # KeyboardInterrupt/SystemExit (the old bare ``except:`` did).
        log.warning(f"Could not create redis client, retrying in 5s: {e}")
        sleep(5)


class RedisQueue:
    """Simple FIFO queue backed by a Redis list.

    Items are appended with RPUSH and removed with BLPOP/LPOP on the key
    ``"<namespace>:<name>"``.
    """

    def __init__(self, name, namespace='queue', db=None):
        """Create a queue handle.

        :param name: queue name, combined with ``namespace`` into the Redis key.
        :param namespace: key prefix (default ``'queue'``).
        :param db: Redis client to use; defaults to this module's shared ``redis`` client.
        """
        self.__db = redis if db is None else db
        self.key = '%s:%s' % (namespace, name)

    def qsize(self):
        """Return the approximate size of the queue."""
        return self.__db.llen(self.key)

    def empty(self):
        """Return True if the queue is empty, False otherwise."""
        return self.qsize() == 0

    def put(self, item):
        """Put item into the queue."""
        self.__db.rpush(self.key, item)

    def get(self, block=True, timeout=None):
        """Remove and return an item from the queue.

        If ``block`` is true, wait (up to ``timeout`` seconds, as interpreted by
        the client's BLPOP) for an item. Otherwise pop without waiting.
        A falsy client reply (nothing available / timed out) is returned as-is.
        """
        if block:
            # BLPOP replies with a (key, value) pair, or a falsy value on timeout.
            reply = self.__db.blpop(self.key, timeout=timeout)
            return reply[1] if reply else reply
        # LPOP replies with the bare value; indexing it would return one byte.
        return self.__db.lpop(self.key)

    def get_nowait(self):
        """Equivalent to get(False)."""
        return self.get(False)


# Shared queue of pending build requests (list key 'numerous:builds' on the
# module's redis client). enqueue_build puts requests; dequeue_build_request pops them.
build_queue = RedisQueue('builds', 'numerous')


def d8(bytes_):
    """Decode a UTF-8 encoded bytes value (e.g. a Redis reply) into ``str``."""
    decoded = bytes_.decode(encoding='UTF-8')
    return decoded

def stream_exists(scenario, execution):
    """Return True when the stream ``<scenario>_<execution>_stream`` holds at least one entry."""
    stream_name = "_".join([scenario, execution]) + "_stream"
    length = redis.xlen(stream_name)
    log.debug(stream_name + " exist " + str(length))
    return length > 0

def set_scenario_execution_data_closed(scenario, execution, closed=True):
    """Store the execution's 'data_closed' custom metadata as the JSON text ``{"closed": <closed>}``."""
    payload = json.dumps({'closed': closed})
    bigtable.set_custom_meta_data(scenario, execution, 'data_closed', payload)

def get_scenario_execution_data_closed(scenario, execution):
    """Return whether the execution's data is closed.

    Uses the stored 'data_closed' custom metadata when it has a 'closed' entry;
    otherwise the data counts as closed when the execution has no stream.
    NOTE(review): the setter stores a JSON string; this assumes the Bigtable
    getter hands back a decoded mapping — confirm.
    """
    closed = bigtable.get_custom_meta_data(scenario, execution, 'data_closed')
    if closed is not None and 'closed' in closed:
        log.debug(execution + str(" is closed ") + str(closed['closed']))
        return closed['closed']

    return not stream_exists(scenario, execution)

def set_scenario_metadata(scenario: str, execution, tags: list, aliases: dict, offset: float, timezone:str, epoch_type: str):
    """Store execution metadata twice: flattened in a Redis hash and structured in Bigtable.

    Each item of ``tags`` must expose name, displayName, unit, description, type,
    scaling and offset. ``aliases`` is iterated (despite the ``dict`` annotation);
    each item must expose ``tag`` and ``alias`` (an iterable of strings).
    Redis gets tag names joined by ';' and aliases as ``tag::a;b`` joined by '__'.
    """
    log.debug(f'setting scenario meta data for {scenario}')

    tag_names = ';'.join([tag.name for tag in tags])
    alias_text = "__".join([alias.tag + "::" + ";".join(alias.alias) for alias in aliases])
    flat_meta = dict(offset=offset, tags=tag_names, timezone=timezone,
                     epoch_type=epoch_type, aliases=alias_text)
    redis.hmset(scenario + '_' + execution, flat_meta)

    tag_docs = []
    for tag in tags:
        tag_docs.append({
            'name': tag.name,
            'displayName': tag.displayName,
            'unit': tag.unit,
            'description': tag.description,
            'type': tag.type,
            'scaling': tag.scaling,
            'offset': tag.offset,
        })
    alias_map = {alias.tag: alias.alias for alias in aliases}

    bigtable.set_meta_data(scenario, execution, offset=offset, tags=tag_docs,
                           timezone=timezone, epoch_type=epoch_type, aliases=alias_map)

def get_scenario_metadata_internal(scenario: str, execution):
    """Return execution metadata, preferring the Redis hash cache over Bigtable.

    From Redis: keys/values are decoded to str, 'tags' is split on ';' (empty
    string -> []) and 'offset' is converted to float.
    On a cache miss the Bigtable metadata is returned unchanged and, if found,
    written into the Redis hash.
    NOTE(review): the Bigtable form (see set_scenario_metadata) holds a list of
    tag dicts and an alias dict, which differs from the flattened Redis form —
    confirm hmset accepts it and callers handle both shapes.
    """
    log.debug(f'getting scenario meta data for {scenario}, {execution}')
    cache_key = scenario + '_' + execution
    cached = redis.hgetall(cache_key)

    if cached is not None and cached != {}:
        decoded = {key.decode('UTF-8'): value.decode('UTF-8') for key, value in cached.items()}
        decoded['tags'] = decoded['tags'].split(';') if decoded['tags'] else []
        decoded['offset'] = float(decoded['offset'])
        return decoded

    # Cache miss: escalate to Bigtable and populate the Redis hash.
    meta = bigtable.get_meta_data(scenario, execution)
    if meta is not None:
        redis.hmset(cache_key, meta)
    return meta


def get_scenario_metadata(scenario: str, execution):
    """Return the execution's metadata read directly from Bigtable (Redis cache is not consulted)."""
    log.debug(f'getting scenario meta data for {scenario}, {execution}')
    return bigtable.get_meta_data(scenario, execution)


def set_scenario_custom_metadata(scenario: str, execution, key:str, meta:str):
    """Store the custom metadata string ``meta`` under ``key`` for the execution in Bigtable."""
    log.debug(f'setting custom scenario meta data for {scenario}')
    bigtable.set_custom_meta_data(scenario, execution, meta=meta, key=key)

def get_scenario_custom_metadata(scenario: str, execution, key:str):
    """Return the custom metadata stored under ``key`` for the execution in Bigtable."""
    log.debug(f'getting scenario custom meta data for {scenario}, {execution}')
    return bigtable.get_custom_meta_data(scenario, execution, key=key)

def listen_row(scenario, execution, tags, subscribe=False):
    """Yield rows from the Redis stream of a scenario execution.

    Reads ``<scenario>_<execution>_stream`` starting after the id stored in
    ``<scenario>_<execution>_row_flushed`` (from id 0 when unset). Each data
    entry is yielded as ``(rows, True, False)``, i.e. the same
    ``(rows, row_complete, block_complete)`` shape read_data yields, where
    ``rows`` is a list of ``{'tag': ..., 'values': [float, ...]}`` dicts.

    Control entries carry a ``_MESSAGE_`` field:

    * ``FINALIZED`` - stop the generator.
    * ``ROWS_MOVED`` - raise ``ValueError`` unless the last id in ``_ROWS_``
      equals the id of the entry processed just before it.
    * ``RESET`` / ``STARTED`` - ignored.
    * ``BLOCKS_NOT_STREAMED`` - read the ``_BLOCKS_`` range from Bigtable via
      read_data (block mode, stream skipped) and yield those items.
    * anything else - ``KeyError``.

    :param tags: forwarded to read_data when re-reading skipped blocks.
    :param subscribe: when True keep reading until ``FINALIZED``; otherwise read once.
    """
    log.debug('Listen row:')
    key = scenario +'_'+ execution
    stream = key + '_stream'

    # Start after the last stream entry already flushed to Bigtable.
    last_id = redis.get(key+'_row_flushed')
    if last_id is None:
        last_id = 0
    else:
        last_id = d8(last_id)

    log.debug('Last flushed: '+str(last_id))
    log.debug('Stream exists: '+str(redis.exists(stream)))

    while True:
        # NOTE(review): both branches of the conditional are 1000, so `subscribe`
        # has no effect on this call. Assuming redis-py's xread(streams, count,
        # block) signature, this is count=None, block=1000 ms: even a one-shot
        # read may wait up to ~1 s for new entries. Confirm the intended args.
        data = redis.xread({stream: last_id}, None, 1000 if subscribe else 1000)

        if len(data) > 0:
            for d in data:
                # NOTE(review): if the client returns stream names as bytes,
                # this comparison with a str never matches — confirm what
                # redis_namespace returns here.
                if d[0] == stream:
                    for m in d[1]:

                        m_id = d8(m[0])

                        payload = {d8(k): d8(v) for k, v in m[1].items()}

                        if '_MESSAGE_' in payload:
                            if payload['_MESSAGE_'] == 'FINALIZED':
                                return

                            elif payload['_MESSAGE_'] == 'ROWS_MOVED':
                                # Sanity check: the last moved row must be the entry read just before.
                                if payload['_ROWS_'].split(';')[-1]!=last_id:
                                    raise ValueError('Arg')

                            elif payload['_MESSAGE_'] == 'RESET':
                                pass
                            elif payload['_MESSAGE_'] == 'STARTED':
                                pass

                            elif payload['_MESSAGE_'] == 'BLOCKS_NOT_STREAMED':
                                # These blocks never went through the stream: read them from Bigtable.
                                blocks = payload['_BLOCKS_'].split(';')
                                skipped_rows = read_data(scenario, execution, tags, int(blocks[0]), int(blocks[-1])+1, False, False, True)
                                for sr in skipped_rows:
                                    yield sr

                            else:
                                raise KeyError('Unknown message: ', payload['_MESSAGE_'])

                        else:
                            # Data entry: each field is a tag whose value is ';'-joined numbers.
                            # Rebinding `data` here does not affect the outer `for d in data` loop,
                            # which keeps iterating the original xread result.
                            data = []
                            for k, v in payload.items():
                                data.append(dict(tag=k, values=[float(v_) for v_ in v.split(';')]))

                            yield (data, True, False)

                        last_id = m_id

        if not subscribe:
            break



def read_data(scenario: str, execution:str, tags: list, start: float, end: float, time_range=True, listen=False, skip_stream=False):
    """Yield ``(rows, row_complete, block_complete)`` for an execution: Bigtable first, then the Redis stream.

    :param tags: tag names to read. When None/empty, the tags from the cached
        metadata are used; otherwise '_index' is prepended if missing.
    :param start: range start, a time (``time_range=True``) or a block number (truncated to int).
    :param end: range end, same interpretation as ``start``.
    :param listen: follow the stream until it is finalized (see listen_row).
    :param skip_stream: only read Bigtable.

    Yields nothing when no tags are given, no metadata exists and ``listen`` is False.
    """
    wanted = [] if tags is None else tags

    if len(wanted) != 0:
        if '_index' not in wanted:
            wanted = ['_index'] + wanted
    else:
        meta = get_scenario_metadata_internal(scenario, execution)
        if meta is None:
            if not listen:
                return
        else:
            wanted = meta['tags']

    if len(wanted) != 0:
        if not time_range:
            log.debug(f'reading block range, exe {execution} scenario: {scenario} start: {int(start)}, end: {int(end)}: '+str(wanted[:10]))
            stored_blocks = bigtable.read_block_range(scenario, execution, wanted, int(start), int(end))
        else:
            log.debug(f'reading time range.')
            stored_blocks = bigtable.read_time_range(scenario, execution, wanted, start, end)

        for block in stored_blocks:
            rows = [{'tag': column.tag, 'values': column.values} for column in block.data]
            yield rows, block.row_complete, block.block_complete

    if skip_stream:
        return

    for streamed in listen_row(scenario, execution, wanted, listen):
        yield streamed

def read_data_stats(scenario: str, execution:str, tag='_index'):
    """Return Bigtable data statistics for one tag of an execution (default '_index')."""
    return bigtable.read_data_stats(scenario=scenario, execution=execution, tag=tag)

def set_block_counter(scenario: str, execution, block_counter, only_redis=False):
    """Cache the execution's block counter in Redis as UTF-8 encoded text.

    ``only_redis`` is accepted for compatibility but unused: nothing is written to Bigtable here.
    """
    counter_key = scenario + '_' + execution + '_block_counter'
    encoded = str(block_counter).encode('UTF-8')
    redis.set(counter_key, encoded)

def get_block_counter(scenario: str, execution):
    """Return the execution's block counter as int.

    Reads the Redis cache first. On a miss, asks Bigtable and caches a found
    value in Redis; returns 0 when Bigtable has none either.
    """
    cached = redis.get(scenario+'_'+execution+'_block_counter')
    if cached is not None:
        return int(cached.decode('UTF-8'))

    stored = bigtable.get_block_counter(scenario, execution)
    if stored is None:
        return 0

    counter = int(stored)
    set_block_counter(scenario, execution, counter, only_redis=True)
    return counter

def clear_data(scenario:str, execution, only_in_memory=False):
    """Drop the execution's Redis state and seed its stream with a RESET marker.

    Deletes the metadata hash, the stream and the cached block counter, trims
    the stream, then appends ``{'_MESSAGE_': 'RESET'}``. ``only_in_memory`` is
    accepted for compatibility; Bigtable data is never cleared here.
    """
    base_key = scenario + '_' + execution
    stream = base_key + '_stream'

    for redis_key in (base_key, stream, base_key + '_block_counter'):
        redis.delete(redis_key)

    redis.xtrim(stream, 0)
    redis.xadd(stream, fields={'_MESSAGE_': 'RESET'})


def close_data(scenario, execution, eta, finalized):
    """Append a FINALIZED marker to the execution stream when ``finalized`` is truthy.

    ``eta`` is accepted for compatibility and unused.
    """
    stream = scenario + '_' + execution + '_stream'
    log.debug(str(stream)+' close data: '+str(finalized))
    if not finalized:
        return
    redis.xadd(stream, fields={"_MESSAGE_": "FINALIZED"})

class WriteBuffer:
    """Buffer tagged values for one execution.

    Completed rows are streamed to Redis and blocks are flushed to Bigtable.
    Values are appended per tag with add_data(). complete() pushes the rows
    added since the last push to ``<scenario>_<execution>_stream`` (unless
    streaming is disabled). On block completion, or when the buffer outgrows
    ``max_size_block``, flush() writes the buffer to Bigtable on a background
    thread.

    Two token locks guard the state. The *write* lock protects
    ``data``/``size``/``row_counter`` while they are mutated. The *stream* lock
    serializes the background push/flush workers: the caller acquires it and
    the worker thread releases it. The lock order is always write lock, then
    stream lock.
    """

    def __init__(self, scenario, execution, disable_stream=False, clear=False, blockcounter=None):
        self.data = {}                 # tag -> list of buffered values
        self.scenario = scenario
        self.execution = execution

        self._key = "_".join([scenario, execution])

        self.size = 0                  # number of buffered values over all tags
        self.max_size_block = 1000*100 # flush early once the buffer exceeds this
        self.row_size = 0
        self.row_counter = 0           # rows already pushed to the stream
        self.last_row_id = None        # id of the last stream entry written by this buffer
        self.rows_in_block = []        # stream entry ids pushed since the last flush

        self.max_stream_size = 1000000
        self.stream_disabled = disable_stream
        self.has_stream = False

        # Token of the current holder (None when free); the mutexes provide the exclusion.
        self.lock = None
        self.lock_write = None
        self._lock_mutex = threading.Lock()
        self._lock_write_mutex = threading.Lock()

        if clear:
            clear_data(self.scenario, self.execution)

        self.block_counter = get_block_counter(self.scenario, self.execution) if blockcounter is None else blockcounter

        stream = self._key + "_stream"
        log.debug('Init stream: '+stream)
        redis.xadd(stream, fields={'_MESSAGE_': 'STARTED'}, maxlen=maxlen_stream)

    def acquire_lock(self):
        """Block until the stream lock is free, take it and return its token."""
        self._lock_mutex.acquire()
        self.lock = str(uuid4())
        return self.lock

    def release_lock(self, lock):
        """Release the stream lock held under token ``lock`` (may be called from another thread).

        :raises KeyError: if ``lock`` is not the current token.
        """
        if lock is None or lock != self.lock:
            raise KeyError('releasing wrong lock!')
        self.lock = None
        self._lock_mutex.release()

    def acquire_lock_write(self):
        """Block until the write lock is free, take it and return its token."""
        self._lock_write_mutex.acquire()
        self.lock_write = str(uuid4())
        return self.lock_write

    def release_lock_write(self, lock):
        """Release the write lock held under token ``lock``.

        :raises KeyError: if ``lock`` is not the current token.
        """
        if lock is None or lock != self.lock_write:
            raise KeyError('releasing wrong lock!')
        self.lock_write = None
        self._lock_write_mutex.release()

    def add_data(self, tag, values, row_complete=False, block_complete=False):
        """Append ``values`` to the buffer of ``tag``.

        ``row_complete``/``block_complete`` are unused; call complete() instead.
        """
        lock = self.acquire_lock_write()
        try:
            self.data.setdefault(tag, [])
            self.data[tag] += values
            self.size += len(values)
        finally:
            self.release_lock_write(lock)

    def push_row(self, ignore_index):
        """Push the rows added since the last push to the Redis stream on a background thread.

        Rows are counted on the '_index' tag, or on the first buffered tag when
        ``ignore_index`` is true. Does nothing without such a tag or without new rows.
        """
        log.debug('push row -  ignoring ix: '+str(ignore_index))
        if '_index' not in self.data and not ignore_index:
            return

        # Same order as flush() (write lock first) so the two cannot deadlock.
        lock_write = self.acquire_lock_write()
        lock = self.acquire_lock()
        handed_off = False
        try:
            r_f = self.row_counter
            tag = '_index' if not ignore_index else list(self.data.keys())[0]
            r_t = len(self.data[tag])
            if r_t > r_f:
                log.debug('push data')
                self.row_counter = r_t
                self.row_size = r_f
                # Snapshot while the write lock is held so add_data() cannot mutate it mid-copy.
                snapshot = deepcopy(self.data)
                worker = threading.Thread(target=self._push_row_worker, args=(lock, snapshot, r_f, r_t))
                worker.start()
                handed_off = True  # the worker now owns (and releases) the stream lock
        finally:
            if not handed_off:
                self.release_lock(lock)
            self.release_lock_write(lock_write)

    def _push_row_worker(self, lock, data, r_f, r_t):
        """Worker: XADD rows [r_f, r_t) of every tag as one ';'-joined entry, then release the stream lock."""
        try:
            stream = self._key + "_stream"
            log.debug('Push to stream: '+ stream)
            entry_id = redis.xadd(stream, fields={k.encode("UTF-8"): ";".join([str(val) for val in v[r_f:r_t]]).encode('UTF-8') for k, v in data.items()}, maxlen=maxlen_stream)

            self.last_row_id = d8(entry_id)
            self.rows_in_block.append(self.last_row_id)
        finally:
            self.release_lock(lock)

    def complete(self, block_complete=False, row_complete=False, ignore_index=False):
        """Mark a row and/or block as complete.

        On row/block completion, stream the pending rows. Streaming is only
        started while at most ``max_stream_size`` values are buffered; otherwise
        it is disabled for good. Flush early when the buffer exceeds
        ``max_size_block``. On block completion always flush.
        """
        log.debug('Complete')
        # An empty buffer has nothing to stream; it must not disable streaming.
        if (row_complete or block_complete) and self.size > 0 and not self.stream_disabled:
            if self.size <= self.max_stream_size or self.has_stream:
                self.has_stream = True
                self.push_row(ignore_index)

                if not block_complete and self.size > self.max_size_block:
                    log.debug('Flushed due to buffer size reached!')
                    self.flush()
            else:
                self.stream_disabled = True

        if block_complete:
            self.flush()

    def flush(self, join_=False):
        """Hand the buffered data to a background Bigtable writer and reset the buffer.

        :param join_: wait for the background write to finish before returning.
        """
        write_lock = self.acquire_lock_write()
        lock = self.acquire_lock()
        handed_off = False
        try:
            # Snapshot data *and* size now: the buffer is reset below, possibly
            # before the worker runs, so the worker must not read self.size.
            worker = threading.Thread(target=self._flush_worker, args=(lock, deepcopy(self.data), self.size))
            worker.start()
            handed_off = True

            self.data = {k: [] for k in self.data}
            self.size = 0
            self.row_size = 0
            self.row_counter = 0
        finally:
            if not handed_off:
                self.release_lock(lock)
            self.release_lock_write(write_lock)

        if join_:
            worker.join()

    def _flush_worker(self, lock, data, size):
        """Worker: write ``data`` to Bigtable and publish where its rows went, then release the stream lock."""
        try:
            if size <= 0:
                return

            if self.last_row_id is None:
                self.last_row_id = 0
            redis.set(self._key + '_row_flushed', self.last_row_id)
            stream = self._key + '_stream'

            self.block_counter_last = self.block_counter
            log.debug('Flushing to bigtable! - '+str(self._key))

            self.block_counter = bigtable.push_data_version_dict(self.scenario, self.execution, data, block_length=1000)
            set_block_counter(self.scenario, self.execution, self.block_counter)

            blocks = ";".join([str(self.block_counter_last + 1), str(self.block_counter)])
            if self.stream_disabled or self.last_row_id == 0 or not self.rows_in_block:
                # Listeners never saw these rows on the stream: point them at Bigtable.
                redis.xadd(stream, fields={'_MESSAGE_': 'BLOCKS_NOT_STREAMED', '_BLOCKS_': blocks})
            else:
                redis.xadd(stream, fields={'_MESSAGE_': 'ROWS_MOVED', '_ROWS_': ";".join(self.rows_in_block),
                                           '_BLOCKS_': blocks})
                redis.xdel(stream, *self.rows_in_block)

                self.rows_in_block = []
        finally:
            self.release_lock(lock)

    def close(self):
        """Flush the remaining data and wait for the Bigtable write to finish."""
        self.flush(join_=True)

#endpoints implementations for document data in firebase

def myconverter(o):
    """``json.dumps`` ``default=`` hook: datetimes become their ``str()`` form; anything else becomes ``None`` (JSON null)."""
    if not isinstance(o, datetime):
        return None
    return str(o)

def read_scenario(project, scenario):
    """Return ``(scenario_json, files)`` for a scenario; datetimes are serialized via myconverter."""
    document, files = firebase.read_scenario(project, scenario)
    document_json = json.dumps(document, default=myconverter)
    return document_json, files

def read_group(project, group):
    """Return the group document as JSON text (datetimes via myconverter)."""
    return json.dumps(firebase.read_group(project, group), default=myconverter)

def read_project(project):
    """Return the project document as JSON text (datetimes via myconverter)."""
    return json.dumps(firebase.read_project(project), default=myconverter)


def delete_scenario(project, scenario):
    """Delete a scenario document via the firebase module."""
    return firebase.delete_scenario(project, scenario) and None


def delete_system(system):
    """Delete a system document via the firebase module."""
    firebase.delete_system(system)
    return None


def delete_user(user):
    """Delete a user via the firebase module."""
    firebase.delete_user(user)
    return None


def delete_scenario_data(scenario, execution, columns):
    """Delete the given data columns of an execution from Bigtable."""
    bigtable.delete_columns(scenario, execution, columns)
    return None


def listen_scenario(project, scenario):
    """Yield each scenario document update from firebase as JSON text (datetimes via myconverter)."""
    updates = firebase.listen_scenario(project, scenario)
    for document in updates:
        yield json.dumps(document, default=myconverter)


def set_data_tags(project: str, scenario: str, tags: list):
    """Store the scenario's data tags via the firebase module."""
    firebase.set_data_tags(project, scenario, tags)
    return None


def push_formatted_error(project:str, scenario:str,message=None, hint=None, category=None, exception_object_type=None,
        exception_object_message=None, full_traceback=None, initializing=False):
    """Forward a structured error report for a scenario to firebase (arguments passed through positionally)."""
    error_args = (
        project, scenario, message, hint, category, exception_object_type,
        exception_object_message, full_traceback, initializing,
    )
    firebase.push_formatted_error(*error_args)


def complete_exe_ignore_instance(project_id, scenario_id, job_id, execution_id):
    """Endpoint for an instance to complete its execution.

    If ``execution_id`` is still the job's active execution, mark it complete
    and clear it from the job. In every case a FINALIZED marker is appended to
    the execution's data stream.
    """
    job = firebase.get_job(project_id, scenario_id, job_id)
    if job is not None:
        active_exe = job.get('active_execution')
        if active_exe == execution_id:
            firebase.complete_execution(active_exe)
            firebase.clear_active_exe(project_id, scenario_id, job_id)

    close_data(scenario_id, execution_id, None, True)


def delete_execution(execution_id):
    """Delete the Kubernetes job backing an execution.

    Does nothing when the execution does not exist or has no ``launch_details``
    (previously a missing execution crashed with TypeError).
    """
    execution = firebase.get_execution(execution_id)

    if execution is None or "launch_details" not in execution:
        return

    details = execution['launch_details']
    deployment_client.delete_job(details['name'], details['namespace'], details['cluster'])


def push_logs(execution: str, log_entries, timestamps):
    """Publish log entries on the execution's Redis channel and persist them in Bigtable.

    :raises IndexError: if ``log_entries`` and ``timestamps`` differ in length.
    """
    if len(timestamps) != len(log_entries):
        raise IndexError("logs and timestamps must have same length.")

    message = json.dumps({'logs': list(log_entries), 'timestamps': list(timestamps)})
    redis.publish(execution, message)
    bigtable.push_log_entries(execution, log_entries=log_entries, timestamps=timestamps)


def read_logs_timerange(execution, start, end):
    """Yield ``(log_entry, timestamp)`` pairs of an execution between ``start`` and ``end`` from Bigtable."""
    log.debug(f'Reading log for {execution}, start {start}, end {end}')
    entries = bigtable.read_logs_time_range(execution, start=start, end=end)
    for entry, stamp in entries:
        yield entry, stamp


def get_files_signed_urls(paths):
    """Return signed URLs for the given storage paths, as produced by firebase."""
    return firebase.get_files_signed_urls(paths)


def generate_resumable_upload_url(project, scenario, file, file_id, content_type):
    """Return ``{'scenario': scenario, 'url': <resumable upload url>}`` for a file upload."""
    upload_url = firebase.generate_resumable_upload_url(project, scenario, file, file_id, content_type)
    return dict(scenario=scenario, url=upload_url)


def get_model(model_id, project_id, scenario_id):
    """Return ``(model_json, model_file_urls)``; datetimes in the model are serialized via myconverter."""
    model_doc, file_urls = firebase.get_model(model_id, project_id, scenario_id)
    return json.dumps(model_doc, default=myconverter), file_urls


def set_results(project_id, scenario_id, result_names, values, units):
    """Store scenario results (parallel name/value/unit sequences) via firebase."""
    firebase.set_results(project_id, scenario_id, result_names, values, units)
    return None


def clear_results(project_id, scenario_id):
    """Remove all stored results of a scenario via firebase."""
    firebase.clear_results(project_id, scenario_id)
    return None


def get_results(project_id, scenario_id):
    """Return scenario results as parallel lists: ``{'names': [...], 'values': [...], 'units': [...]}``."""
    names, values, units = [], [], []
    for row in firebase.get_results(project_id, scenario_id):
        names.append(row['colName'])
        values.append(row['value'])
        units.append(row['unit'])

    return {'names': names, 'values': values, 'units': units}


def subscribe_channels(channels, terminate_=None):
    """Yield Redis pub/sub messages for the given channel patterns.

    A background thread pattern-subscribes to ``channels`` and forwards every
    message to this generator. Iteration ends when ``terminate_`` (an object
    with ``is_set()``, e.g. ``threading.Event``) is set. Without one, it runs
    until the consumer closes the generator. In both cases the background
    thread is stopped and the pub/sub connection is closed on exit; previously
    an abandoned generator leaked both forever.
    """
    pubsub = redis.pubsub()
    sub_queue = Queue()
    stop = threading.Event()

    def terminated():
        return terminate_ is not None and terminate_.is_set()

    def sub_thread():
        try:
            pubsub.psubscribe(**{c: sub_queue.put for c in channels})
            # get_message() dispatches to the handlers above; sleep when idle.
            while not (stop.is_set() or terminated()):
                if pubsub.get_message() is None:
                    sleep(0.1)
        finally:
            pubsub.close()

    t = threading.Thread(target=sub_thread)
    t.start()

    try:
        while not terminated():
            try:
                message = sub_queue.get(timeout=1)
            except Empty:
                continue
            yield message
    finally:
        # Also reached on GeneratorExit when the consumer stops iterating.
        stop.set()
        t.join()
        log.debug('Subscription closed.')

def publish_messages(channels, messages):
    """Publish each message on the channel at the same position.

    Non-str messages are JSON-encoded. Surplus items of the longer sequence are ignored.
    """
    for channel, message in zip(channels, messages):
        payload = message if isinstance(message, str) else json.dumps(message)
        redis.publish(channel, payload)

def enqueue_build(repository, commit_id, project_id, scenario_id, job_id, launch_after_build=False):
    """Create a build for a repository commit and queue a build request for a worker.

    The history entry's timestamp is converted to epoch seconds so the request
    is JSON-serializable. The new build id is recorded on the history entry and returned.
    """
    history_doc = firebase.get_history_entry('default', repository, commit_id)
    history_doc['timestamp'] = history_doc['timestamp'].timestamp()

    build_id = firebase.create_build('default', repository, history_doc['snapshot'], commit_id)

    request = dict(
        repository=repository,
        commit_id=commit_id,
        history_doc=history_doc,
        project_id=project_id,
        scenario_id=scenario_id,
        job_id=job_id,
        build_id=build_id,
        launch_after_build=launch_after_build,
    )
    build_queue.put(json.dumps(request).encode('utf-8'))

    firebase.update_history_update('default', repository, commit_id, {'build_id': build_id})
    return build_id


def dequeue_build_request(replies):
    """Hand queued build requests to a build worker and track its replies.

    Generator for a bidirectional stream. Each yielded item is a raw request
    popped from ``build_queue``; ``replies`` is the worker's reply stream.
    After each request the worker must reply with status 'ACK' within 5 s
    (otherwise the request is put back on the queue) and then with 'OK'.
    Any other reply, or the end/failure of the reply stream, stops the generator.
    """
    log.debug('Build worker init')

    # Queued when the reply stream ends so a pending get() wakes up instead of
    # blocking forever. It has no ``status`` attribute, so it fails every check.
    stream_closed = object()
    work_reply_queue = Queue()
    status = {'continue': True}

    def work_messages():
        try:
            for reply in replies:
                work_reply_queue.put(reply)
                log.debug('reply: %s', reply)
        except grpc.RpcError:
            pass
        finally:
            status['continue'] = False
            work_reply_queue.put(stream_closed)

    wm = threading.Thread(target=work_messages, daemon=True)
    wm.start()

    def check_build_reply(build, checks):
        """Consume one reply per check; return True only if each has the expected status in time."""
        for expected_status, timeout, on_failure in checks:
            try:
                reply = work_reply_queue.get(timeout=timeout)
            except Empty:
                if on_failure == 'resubmit':
                    build_queue.put(build)
                return False

            if not hasattr(reply, 'status') or reply.status != expected_status:
                log.debug('build not acked')
                if on_failure == 'resubmit':
                    build_queue.put(build)
                return False

        return True

    while status['continue']:
        build = None
        while status['continue'] and build is None:
            try:
                build = build_queue.get(timeout=1)
            except Empty:
                build = None
                sleep(1)

        if not status['continue']:
            break

        yield build

        # Only ever clear the flag here so a concurrent "stream closed" is not overwritten.
        if not check_build_reply(build, [('ACK', 5, 'resubmit'), ('OK', None, None)]):
            status['continue'] = False

    log.debug('Closed builder worker')


def start_job(
    project_id: str, scenario_id: str, job_id: str, server_address: str, port: str, user_id: str, organization_id: str,
    execution_id_override: str or None = None, clear_data_flag: bool = True, image_url: str or None = None,
    admin: bool = False, resumed=False, secure_channel=True
) -> None:
    """
    Starts a job in kubernetes and updates firestore appropriately.
    :param project_id: ID of job project
    :param scenario_id: ID of job scenario
    :param job_id: Job ID
    :param server_address: The address the job should use to contact the API
    :param port: The API port the job should use (passed as NUMEROUS_API_PORT)
    :param user_id: ID of the user that launched the job
    :param organization_id: ID of organization that launched the job (lower-cased as kubernetes namespace)
    :param execution_id_override: Manually set execution ID for job; may equal the job's active execution
    :param clear_data_flag: True to erase data from last time job ran. False otherwise
    :param image_url: Url of the image to be assigned to the pod; defaults to the job's image.path
    :param admin: Whether or not to generate admin token for job
    :param resumed: Whether this launch resumes an existing execution (NUMEROUS_JOB_RESUMED)
    :param secure_channel: Whether the job should use a secure channel to the API (SECURE_CHANNEL)
    :raises ValueError: if the job does not exist
    :raises KeyError: if no image_url is given and the job has no image.path
    :raises NumerousBaseException: if the job already has a different active execution
    """
    from datetime import timezone  # local: the module only imports the ``datetime`` class

    job = firebase.get_job(project_id, scenario_id, job_id)
    if job is None:
        raise ValueError(f"Job does not exist for project_id={project_id}, scenario_id={scenario_id}, job_id={job_id}")

    log.debug('Job: %s', job)

    if image_url is None:
        image_url = _get_or_def(job, ['image', 'path'], None, True)

    # Check if job is already active - if so cannot launch
    active_exe = job['active_execution'] if 'active_execution' in job else None
    if active_exe is not None and execution_id_override != active_exe:
        log.debug('Skipping launch - job has active execution!')
        raise NumerousBaseException(f"There is already an active exe for this job.")

    # Generate a unique id for this execution
    exe_id = str(uuid.uuid4()) if execution_id_override is None else execution_id_override
    kubernetes_job_name = f"execution-{exe_id}"

    # Generate refresh token
    refresh_token = token_manager.generate_refresh_token(
        admin=admin, project_id=project_id, scenario_id=scenario_id, execution_id=exe_id,
        job_id=job_id, user_id='api', organization_id='api', agent='api', purpose='api',
        access_level=AccessLevel.WRITE
    )

    log.debug('Exe id: %s', exe_id)
    env_variables = {
        'NUMEROUS_API_REFRESH_TOKEN': refresh_token,
        'NUMEROUS_API_PORT': str(port),
        'NUMEROUS_API_SERVER': server_address,
        'NUMEROUS_PROJECT': project_id,
        'NUMEROUS_SCENARIO': scenario_id,
        'NUMEROUS_EXECUTION_ID': exe_id,
        'NUMEROUS_JOB_RESUMED': "True" if resumed else "False",
        'JOB_ID': job_id,
        # Environment values are strings; stringify like CLEAR_DATA instead of passing a bool.
        'SECURE_CHANNEL': str(secure_channel),
        'CLEAR_DATA': str(clear_data_flag)
    }

    log.debug('Starting image: %s', image_url)
    kubernetes_job_spec = dict(
        cluster = _get_or_def(job, ['cluster', 'id'], 'default'),
        name=kubernetes_job_name,
        image=image_url,
        namespace=organization_id.lower(),
        nodepool=_get_or_def(job, ['cluster', 'nodepool'], 'default-workers'),
        env_variables=env_variables,
        cpu_limit=_get_or_def(job, ['cluster', 'cpu_limit'], "1"),
        cpu_request=_get_or_def(job, ['cluster', 'cpu_request'], "1"),
        memory_limit=_get_or_def(job, ['cluster', 'memory_limit'], "16000000000"),
        memory_request=_get_or_def(job, ['cluster', 'memory_request'], "100000000"),
        user_id=user_id,
        organization_id=organization_id,
    )

    firebase.submit_progress(project_id, scenario_id, job_id, 'requested', 'initializing', True)
    response = deployment_client.create_job(**kubernetes_job_spec)

    kubernetes_job_spec['name'] = response[0]
    kubernetes_job_spec['namespace'] = response[1]
    # An aware "now": naive utcnow().timestamp() is interpreted as *local* time
    # and yields a skewed epoch on hosts not running in UTC.
    kubernetes_job_spec['start_time'] = datetime.now(timezone.utc).timestamp()

    # Get schedule data from job and check if an execution already exists with set ID
    schedule_data = firebase.get_job_schedule(project_id=project_id, scenario_id=scenario_id, job_id=job_id, schedule_key=SCHEDULE_NAME)
    current_execution = firebase.get_execution(execution_id=exe_id)

    # If an execution exists and it has a schedule, use the schedule instead of the one from job
    if current_execution is not None and SCHEDULE_NAME in current_execution.keys():
        log.debug("Got schedule from execution - using it instead of schedule from job.")
        schedule_data = current_execution[SCHEDULE_NAME]

    log.debug("Starting job - schedule: %s", schedule_data)
    firebase.submit_execution(
        project_id, scenario_id, job_id, exe_id, launch_details=kubernetes_job_spec,
        schedule_data=schedule_data, schedule_key=SCHEDULE_NAME
    )


def resume_job(project_id, scenario_id, job_id, user_id, server_address, server_port, secure_channel):
    """Resume the job's hibernating active execution by starting a new pod without clearing data.

    An execution stays the job's active execution while hibernating. It is
    resumed only if it exists and is both ``active`` and ``hibernating``;
    otherwise a warning is logged.

    :raises KeyError: if the job does not exist.
    :raises ValueError: if the job has no active execution.
    """
    job = firebase.get_job(project_id, scenario_id, job_id)
    if job is None:
        raise KeyError(f"Job {job_id} not found")
    if 'active_execution' not in job:
        raise ValueError('No active execution.')

    execution = job['active_execution']
    log.debug(f'Requested to resume job with active execution: {execution}')
    exe = firebase.get_execution(execution)
    if exe is None:
        return

    if not (exe.get("active") and exe.get("hibernating")):
        log.warning(f"Could not resume job with active={exe.get('active')} and hibernating={exe.get('hibernating')}")
        return

    firebase.update_execution({
        'execution': execution,
        'hibernating': False,
        'active': True,
        'timed_out_epoch': None
    })
    start_job(
        project_id=project_id, scenario_id=scenario_id, job_id=job_id,
        server_address=server_address, user_id=user_id, port=server_port,
        organization_id=org, execution_id_override=execution, resumed=True,
        secure_channel=secure_channel
    )


def _get_job_channel(project_id, scenario_id, job_id):
    """Return the job's command channel name, ``COMMAND.<project>.<scenario>.<job>``."""
    parts = ('COMMAND', project_id, scenario_id, job_id)
    return '.'.join(parts)

def reset_job(project_id, scenario_id, job_id):
    """Forcefully reset a job.

    Steps:
    1. Publish a ``terminate`` command on the job's command channel.
    2. For an existing job with an active execution:
       a. Best effort: set the Kubernetes job deadline to 0 and mark the
          execution complete. Failures are logged, not raised.
       b. Clear the job's active execution.
    3. Record progress 'forcefully reset!' with status 'failed'.
    """
    # Publish command for job to terminate
    channel = _get_job_channel(project_id, scenario_id, job_id)

    log.debug(f'Terminating channel {channel}')

    publish_messages([channel], [{'command': 'terminate'}])

    job = firebase.get_job(project_id, scenario_id, job_id)

    if job is not None:
        if 'active_execution' in job:
            execution = job['active_execution']
            log.debug('act exe: ' + str(execution))
            if execution is not None:
                try:
                    exe = firebase.get_execution(execution)
                    launch_details = exe.get('launch_details') if exe is not None else None

                    if launch_details is not None:
                        log.debug('Suspending job')
                        try:
                            deployment_client.set_deadline_job(
                                launch_details['name'], launch_details['namespace'],
                                launch_details['cluster'], 0
                            )
                        except Exception as e:
                            log.warning('Error in job set deadline: ' + str(e))
                    else:
                        log.debug('No launch details for execution - nothing to suspend')

                    log.debug('Completing execution')
                    firebase.complete_execution(execution)
                except Exception:
                    # Best effort: the reset must still clear the active execution below.
                    log.exception(f'Failed to stop execution {execution} during reset')

            log.debug('Clearing active exe')
            firebase.clear_active_exe(project_id, scenario_id, job_id)

        firebase.submit_progress(project_id, scenario_id, job_id, 'forcefully reset!', 'failed', True)

    log.debug('Completed reset.')


def get_execution_status(execution_id):
    """Return the deployment status of the execution's job, or None if it does not exist or was never launched."""
    execution = firebase.get_execution(execution_id)
    if execution is None or 'launch_details' not in execution:
        return None

    details = execution['launch_details']
    return deployment_client.get_job_status(details['name'], details['namespace'], details['cluster'])


def set_scenario_progress(project, scenario, job_id, message, status, clean, progress):
    """Record a progress update for a job via firebase (all arguments forwarded positionally)."""
    firebase.submit_progress(project, scenario, job_id, message, status, clean, progress)
    return None


def update_execution(execution):
    """Apply an execution update document via firebase."""
    firebase.update_execution(execution)
    return None

def get_active_exe(project_id: str, scenario_id: str, job_id: str) -> dict or None:
    """Get the active execution for a job. Returns None if no execution is active.

    :raises KeyError: if the job does not exist.
    """
    # Positional call, matching every other firebase.get_job call site in this
    # module; the former ``spm_job_id=`` keyword did not match those calls.
    job = firebase.get_job(project_id, scenario_id, job_id)

    if job is None:
        raise KeyError("Job was not found!")

    return job.get('active_execution', None)


def add_schedule_to_execution(execution_id, schedule) -> None:
    """Store ``schedule`` (enable_scheduling, hibernate_after, resume_schedule) on the execution document."""
    schedule_doc = dict(
        enable_scheduling=schedule.enable_scheduling,
        hibernate_after_s=schedule.hibernate_after,
        resume_schedule=schedule.resume_schedule,
    )
    firebase.update_execution({'execution': execution_id, SCHEDULE_NAME: schedule_doc})


def update_execution_schedule(execution_id: str, schedule) -> None:
    """Replace the schedule of an execution that already has one.

    Currently a full overwrite via add_schedule_to_execution. Partial updates
    are not supported, and defaults coming from the proto still need handling.

    :raises KeyError: if the execution does not exist (consistent with
        delete/get_execution_schedule) or has no schedule.
    """
    execution = firebase.get_execution(execution_id=execution_id)
    if execution is None:
        raise KeyError(f"Execution with ID: '{execution_id}' was not found!")

    if SCHEDULE_NAME not in execution.keys():
        raise KeyError("No schedule in execution! Cannot update.")

    add_schedule_to_execution(execution_id, schedule)


def delete_execution_schedule(execution_id: str) -> None:
    """Remove the schedule from an execution by setting it to None (a no-op if it had none).

    :raises KeyError: if the execution does not exist.
    """
    if firebase.get_execution(execution_id=execution_id) is None:
        raise KeyError(f"Execution with ID: '{execution_id}' was not found!")

    firebase.update_execution({'execution': execution_id, SCHEDULE_NAME: None})


def get_execution_schedule(execution_id: str) -> dict:
    """Return the execution's schedule, or None when it has none.

    :raises KeyError: if the execution does not exist.
    """
    execution = firebase.get_execution(execution_id=execution_id)
    if execution is not None:
        return execution.get(SCHEDULE_NAME, None)
    raise KeyError(f"Execution with ID: '{execution_id}' was not found!")

def get_cluster_info():
    """Return cluster information as reported by the deployment client."""
    info = deployment_client.get_cluster_info()
    return info

def _get_or_def(obj, path, default_val, always_raise=False):
    """Walk ``obj`` along ``path`` with successive ``[]`` lookups and return the value found.

    :param obj: nested mapping/sequence to read from.
    :param path: iterable of keys/indices to apply in order.
    :param default_val: returned when a lookup fails.
    :param always_raise: re-raise the lookup error instead of returning ``default_val``.

    A lookup fails on a missing key or index, or on an intermediate value that
    cannot be indexed (e.g. ``None``); that last case used to escape as a
    TypeError even when a default was given.
    """
    try:
        for p in path:
            obj = obj[p]
    except (KeyError, IndexError, TypeError):
        if always_raise:
            raise
        return default_val

    return obj
