import logging
import datetime
from importlib.metadata import version
import os
import abc
import traceback

from time import sleep
from threading import Thread, Event

from .system import NumerousSystem
from .job import NumerousBaseJob
from numerous_cli.client import get_client
from numerous_api_client.client import NumerousClient
from numerous_api_client.client.numerous_client import ScenarioStatus
from numerous_api_client.client.data_source import DataSourceHibernating, DataSourceEmptyDataset, \
    DataSourceStreamClosed, DataSourceCompleted

# Installed version of the 'numerous_image_tools' distribution (looked up at import time).
VERSION = version('numerous_image_tools')
# Application version and name come from the environment; APPNAME may be None.
APPVERSION = os.getenv('APPVERSION', "?")
APPNAME = os.getenv('APPNAME')
# Either the raw LOG_LEVEL string from the environment or the logging.INFO constant.
LOG_LEVEL = os.getenv('LOG_LEVEL', logging.INFO)
# Model folder handed to run_job when this module is executed as a script.
MODEL = os.getenv('MODEL')
# Number of times a crashed worker thread may be restarted (0 = never).
MAX_RESTARTS = int(os.getenv('MAX_RESTARTS', 0))


# True when KUBERNETES_SERVICE_HOST is unset or empty; run_job then uses the
# CLI client instead of constructing a NumerousClient directly.
local = os.getenv('KUBERNETES_SERVICE_HOST',"") == ""
# Module-level logger used by run_job.
logger = logging.getLogger('numerous-image-tools')
logger.setLevel(level=LOG_LEVEL)

def tb_str(e):
    """Return the formatted traceback of exception *e* as a single string.

    Each chunk produced by ``traceback.format_exception`` is stripped of
    surrounding whitespace and the chunks are joined with newlines.

    Args:
        e: the exception instance to format.

    Returns:
        str: the traceback text, ending with the ``Type: message`` line.
    """
    # Pass the arguments positionally: the ``etype`` keyword no longer exists
    # on Python 3.10+, where the first parameter is positional-only.
    chunks = traceback.format_exception(type(e), e, e.__traceback__)
    return "\n".join(chunk.strip() for chunk in chunks)

class SimulationError(Exception):
    """Raised when a step of the simulation job fails."""

class NumerousApp():
    def __init__(self, numerous_client: NumerousClient = None,
                 appname = "defaultnumerousApp",
                 max_restarts = 0, numerous_job: NumerousBaseJob = None, model_folder: str = None):
        """Store the collaborators, reset all run state and load the scenario.

        Args:
            numerous_client: client used for scenario data, progress and state.
            appname: name used for the instance logger (``setup_scenario``
                later replaces ``self.appname`` with the scenario name).
            max_restarts: how many times ``loop`` may restart a failed worker.
            numerous_job: the job object that is stepped in ``run``.
            model_folder: folder passed on to ``NumerousSystem`` in ``run``.
        """
        # Collaborators and configuration supplied by the caller.
        self.nc = numerous_client
        self.numerous_job = numerous_job
        self.model_folder = model_folder
        self.max_restarts = max_restarts
        self.appname = appname

        # Per-instance logger, named after the app.
        self.logger = logging.getLogger(self.appname)
        self.logger.setLevel(level=LOG_LEVEL)

        # Events shared between the supervising loop and the worker thread.
        self.backup = Event()
        self.terminate = Event()

        # Run bookkeeping.
        self.status = 0
        self.message = ""
        self.init = True
        self.subscribe = False
        self.last_print = {}

        # Placeholders filled in later (setup_scenario / run).
        self.output = None
        self.scenario = None
        self.model_definition = None
        self.files = None
        self.job_spec = None
        self.solver_settings = None
        self.sim_job = None
        self.allow_hibernation = False
        self.start_time = None
        self.end_time = None

        self.setup_scenario()


    def setup_scenario(self):
        """Fetch the scenario from the client and derive the run configuration.

        Populates ``scenario``, ``model_definition`` and ``files``, reports the
        'initializing' progress, reads the start/end timestamps from the
        client's run settings, renames the app after the scenario and reads the
        ``allow_hibernation`` parameter.
        """
        self.scenario, self.model_definition, self.files = self.nc.get_scenario()
        self.nc.set_scenario_progress('initializing', ScenarioStatus.ENVIRONMENT_INITIALIZING, 0.0, force=True)

        # Prepare simulation setup
        self.start_time = self.nc.run_settings.start.timestamp()
        self.end_time = self.nc.run_settings.end.timestamp()
        self.appname = self.scenario['scenarioName']

        # ``run`` tests this flag by truthiness, so it must be a real bool:
        # a string such as 'False' would otherwise count as enabled.
        allow = self.nc.params.get('allow_hibernation', False)
        if isinstance(allow, str):
            allow = allow.strip().lower() in ('true', '1', 'yes', 'on')
        self.allow_hibernation = bool(allow)
        self.logger.info(f'allow hibernation: {self.allow_hibernation}')

    def save_states(self, t, cause, details):
        if self.init:
            self.logger.info('model not initialized, could not save states')
            return

        states = self.numerous_job.serialize_states(t)

        self.nc.state.set('t', t)
        self.nc.state.set('states', states)
        self.nc.state.set('cause', cause)
        self.nc.state.set('details', details)

        self.logger.debug(f"states saved: reason {cause} ({details})")

    def load_states(self):

        t = self.start_time
        t_end = self.end_time
        t = self.nc.state.get('t', t)
        states = self.nc.state.get('states', None)

        if t-t_end == 0:
            self.subscribe = True

        if states:
            self.logger.debug('states loaded')
        return t, states


    def print_update(self, t, ix, update_interval=5):
        if self.last_print.get(ix) is None:
            self.last_print.update({ix: t})
            return True

        if t - self.last_print[ix] >= update_interval :
            self.last_print.update({ix: t})
            return True
        else:
            return False

    def loop(self):
        """Supervise the worker thread until the app terminates.

        Starts ``run`` in a thread and then polls every 0.5 s. While polling it:

        * restarts the worker when it has died with ``status == -1``, at most
          ``max_restarts`` times and only once 5 minutes have passed since the
          worker was last started;
        * sets the ``backup`` event whenever ``checkpoint`` says a scheduled
          state save is due.

        The loop ends when ``terminate`` is set or no restarts remain. The
        worker is then joined and the output writer, if any, is closed.
        """
        self.logger.info(f'starting numerous app for {self.appname}')
        thread = Thread(target=self.run, args=(), daemon=False)
        thread.start()
        starttime = datetime.datetime.now()
        last_checkpoint = starttime
        restarts = 0
        warned = False

        while not self.terminate.is_set():
            checkpoint_time = datetime.datetime.now()
            if not thread.is_alive() and self.status == -1:
                # Give up before announcing a retry that will never happen.
                if restarts >= self.max_restarts:
                    self.logger.error(f'{self.appname} is dead and reached max restarts')
                    break

                if not warned:
                    self.nc.set_scenario_progress(message=
                                                  f"{self.message}. Retrying after 5 mins "
                                                  f"({restarts + 1}/{self.max_restarts})",
                                                  status=ScenarioStatus.RUNNING, force=True)
                    warned = True

                # The 5 minute window is measured from when the worker was last started.
                if datetime.datetime.now() - starttime > datetime.timedelta(seconds=300):
                    self.logger.error(f'{self.appname} is dead.. restarting')
                    self.status = 0
                    self.message = ""
                    thread = Thread(target=self.run, args=(), daemon=False)
                    thread.start()
                    sleep(5)
                    restarts += 1
                    starttime = datetime.datetime.now()
                    warned = False

            if checkpoint(checkpoint_time, last_checkpoint):
                self.backup.set()
                last_checkpoint = checkpoint_time

            sleep(0.5)
        self.logger.debug(f"waiting for {self.appname} to shut down...")
        thread.join()
        self.logger.info(f"{self.appname} stopped")
        # The writer only exists if the worker got as far as creating it.
        if self.output is not None:
            self.output.close()
        self.logger.info("Main thread stopped gracefully")

    def run(self):
        """Worker-thread body: step the simulation job until it ends, fails or is interrupted.

        Reads input rows from the client, feeds them to the system, advances
        ``numerous_job`` by ``dt`` per step and writes every output row. The
        outcome is recorded in ``self.status``:

        * ``0``  - the client's terminate event was set ("forcefully terminated")
        * ``1``  - input exhausted or end time reached ("completed")
        * ``2``  - hibernating
        * ``-1`` - error; ``loop`` may restart the worker

        For statuses 0, 1 and 2 the ``terminate`` event is set on exit. On the
        way out the writer (if one was created) is closed and ``save_states``
        is called.
        """
        # This is the classic pipelines approach - well suited for digital twins
        # Run simulation in a loop
        self.logger.debug('entering loop')
        timeout = 120
        cause = "no cause"
        details = ""
        last_data = None
        data = None
        t = None
        try:
            # NOTE(review): `INITIALZING` is used exactly as spelled here; presumably
            # it matches the client's enum member - confirm against ScenarioStatus.
            self.nc.set_scenario_progress('bootup', ScenarioStatus.INITIALZING, 0.0, force=True)
            self.terminate.wait(timeout=5)
            self.nc.set_scenario_progress('waiting for initial data', ScenarioStatus.INITIALZING, 0.0, force=True)

            self.output = self.nc.new_writer(buffer_size=0)
            t, states = self.load_states()
            t0 = t
            dt = self.nc.params.get('dt_simulation', 60)
            # load_states sets `subscribe` when the saved time equals the end
            # time; t_stop == 0 then disables the end-time check below.
            t_stop = self.end_time if not self.subscribe else 0

            self.system = NumerousSystem(self.nc, self.scenario, self.files, self.start_time,
                                         self.model_folder, self.numerous_job, states, dt)

            inputs = self.nc.get_inputs(self.scenario, t0=t, te=t_stop, dt=dt, tag_prefix='', tag_seperator='.',
                                        timeout=timeout)
            while True:
                if self.nc.terminate_event.is_set():
                    self.status = 0
                    break
                if self.nc.hibernate_event.is_set():
                    self.status = 2
                    break

                try:
                    data = inputs.get_at_time(t)
                except (DataSourceHibernating, TimeoutError) as e:
                    if self.allow_hibernation:
                        self.status = 2
                        self.nc.hibernate(message="hibernating")
                    else:
                        self.status = -1
                    details = tb_str(e)
                    break
                except DataSourceCompleted as e:
                    self.status = 1
                    details = tb_str(e)
                    self.message = 'No more input data'
                    break
                except (DataSourceEmptyDataset, DataSourceStreamClosed) as e:
                    self.status = -1
                    details = tb_str(e)
                    break

                # Progress in percent of the [t0, t_stop] interval; stays 0 without an end time.
                completed = 0
                if t_stop > 0:
                    completed = (1 - (t_stop - t) / (t_stop - t0)) * 100

                # Scheduled checkpoint requested by the supervising loop.
                if self.backup.is_set():
                    self.save_states(t, 'backup', f"scheduled checkpoint @ {datetime.datetime.now()}")
                    self.backup.clear()

                if data is None:
                    self.nc.set_scenario_progress(f"waiting for data. Last update: "
                                                  f"{datetime.datetime.fromtimestamp(t)}",
                                                  ScenarioStatus.RUNNING if not self.init else ScenarioStatus.WAITING,
                                                  completed)
                    # NOTE(review): the sleep only runs when the log line is due
                    # (every 10 s); confirm get_at_time itself blocks, otherwise
                    # this branch spins between log lines.
                    if self.print_update(datetime.datetime.now().timestamp(), 0, 10):
                        self.logger.info(f"no data. Simulation time: {t}. ")
                        sleep(1)
                    continue

                self.system.update_inputs(data)

                if t_stop > 0:
                    self.nc.set_scenario_progress("running", ScenarioStatus.RUNNING, completed)

                # No start time yet: adopt the index of the first data row.
                if t==0 and data['_index'] > 0:
                    self.logger.warning(f'setting start time to {data["_index"]}')
                    t = data['_index']
                    t0 = t

                if self.init:
                    self.nc.set_scenario_progress("building model", ScenarioStatus.MODEL_INITIALIZING, 0, force=True)
                    initial_output = self.system.initialize_model()
                    if initial_output:
                        if "_index" not in initial_output:
                            initial_output.update({'_index': 0})
                        self.output.write_row(initial_output)

                try:
                    tnew, outputs = self.numerous_job.step(t, dt)
                    if self.init:
                        # First successful step: publish the output tags once.
                        self.nc.set_timeseries_meta_data([{"name": tag} for tag in outputs.keys()],
                                                         offset=self.start_time)
                        self.init = False
                    if not self.numerous_job.align_outputs_to_next_timestep:
                        outputs.update({'_index': t - self.start_time})
                    else:
                        outputs.update({'_index': tnew - self.start_time})

                except Exception as e:
                    details = tb_str(e)
                    # Chain explicitly so the original failure stays attached as the cause.
                    raise SimulationError(details) from e

                # Advance time
                t = tnew
                # save output
                self.output.write_row(outputs)

                self.logger.info(f'Calculation step. Time is now {t}. completed: {completed}')
                last_data = data

                if (t >= t_stop) and (t_stop > 0):
                    self.logger.warning('maximum time reached')
                    self.status = 1
                    self.message = "Simulation completed"
                    break

        except Exception as e:
            self.logger.error(f'numerous_app crashed: {tb_str(e)}')
            self.logger.debug(f'previous data: {last_data}')
            self.logger.debug(f'data at crash: {data}')
            cause = 'exception'
            details = {"error message": tb_str(e), "previous_data": last_data, "data at crash": data}
            self.status = -1
            self.message = "app error - see logs"
        finally:
            if self.status == 0:
                self.terminate.set()
                cause = "forcefully terminated"
            elif self.status == 1:
                self.terminate.set()
                cause = "completed"
            elif self.status == 2:
                self.terminate.set()
                cause = 'hibernating'
            # A failure before new_writer() leaves no writer to close; guarding
            # here keeps the state save below from being skipped.
            if self.output is not None:
                self.output.close()
            self.save_states(t, cause, details)
            self.logger.warning('job terminated')


def checkpoint(tnow, tlast, interval=60):
    """Return True when more than *interval* minutes lie between *tlast* and *tnow*.

    Args:
        tnow: current ``datetime``.
        tlast: ``datetime`` of the previous checkpoint.
        interval: checkpoint spacing in minutes.
    """
    elapsed = tnow - tlast
    threshold = datetime.timedelta(minutes=interval)
    return bool(elapsed > threshold)

def run_job(numerous_job=None, appname=None, max_restarts=0, model_folder="models"):
    """
    Run *numerous_job* as a NumerousApp and report the final outcome.

    Creates the client (CLI client when running locally, otherwise a
    NumerousClient), runs the app loop, then publishes the final scenario
    progress and stores the app status under '_status'. Any exception is
    logged and reported as a failure with '_status' -2. The client is always
    closed.

    Args:
        numerous_job: the job to run; required.
        appname: application name handed to NumerousApp (a warning is logged if missing).
        max_restarts: how many times a failed worker may be restarted.
        model_folder: model folder handed to NumerousApp.
    """
    nc = None
    try:
        reset_requested = os.getenv('RESET_JOB', 'False') == "True"
        nc = get_client(reset_job=reset_requested) if local else NumerousClient()

        nc.logger.setLevel(LOG_LEVEL)

        if numerous_job is None:
            raise KeyError("no simulation job specified")
        if appname is None:
            logger.warning("no appname specified")

        logger.info(f"Welcome to numerous image tools. Base version {VERSION} - app {APPNAME} version {APPVERSION}")

        app = NumerousApp(numerous_client=nc, appname=appname,
                          max_restarts=max_restarts, numerous_job=numerous_job, model_folder=model_folder)

        app.loop()

        # Positional arguments of the final progress report, keyed by app status.
        final_reports = {
            -1: (app.message, ScenarioStatus.FAILED, 0),
            0: ("interrupted", ScenarioStatus.FINISHED),
            1: (app.message, ScenarioStatus.FINISHED, 100),
            2: ("hibernating", ScenarioStatus.HIBERNATING),
        }
        report = final_reports.get(app.status, ("unknown", ScenarioStatus.FINISHED))
        nc.set_scenario_progress(*report, force=True)

        logger.debug(f'setting state {app.status}')
        nc.state.set('_status', app.status)

    except Exception as e:
        logger.error(f"unhandled exception: {tb_str(e)}")
        if nc is not None:
            nc.set_scenario_progress("unhandled error", ScenarioStatus.FAILED, 0.0, force=True)
            logger.debug('setting state -2')
            nc.state.set('_status', -2)
    finally:
        if nc is not None:
            nc.close()


if __name__ == "__main__":
    # Script entry point: APPNAME and MODEL must be provided through the environment.
    if APPNAME is None:
        raise KeyError("no appname specified")
    if MODEL is None:
        raise KeyError("no model folder specified")

    # run_job takes the folder as `model_folder`; any other keyword raises TypeError.
    # NOTE(review): no `numerous_job` is passed here, so run_job will log
    # "no simulation job specified" and report a failure - confirm how the job
    # is meant to be supplied when running this module directly.
    run_job(model_folder=MODEL, appname=APPNAME, max_restarts=MAX_RESTARTS)








