import os
import logging
from datetime import datetime
from time import sleep

from .errors import tb_str, SimulationError
from .system import NumerousSystem
from .report.report import Report
from numerous_api_client.client.data_source import DataSourceHibernating, DataSourceEmptyDataset, \
    DataSourceStreamClosed, DataSourceCompleted
from numerous_api_client.client.numerous_client import ScenarioStatus

# Log level applied to the job loggers in this module, read once at import time
# from the LOG_LEVEL environment variable (DEBUG when unset or blank).
# Logger.setLevel() only accepts ints or upper-case level names, so a numeric
# string ("10") is converted to int and a name is upper-cased ("info" -> "INFO").
# Unknown names are passed through unchanged and still fail at setLevel().
_raw_log_level = os.getenv('LOG_LEVEL', '').strip()
if not _raw_log_level:
    LOG_LEVEL = logging.DEBUG
elif _raw_log_level.isdigit():
    LOG_LEVEL = int(_raw_log_level)
else:
    LOG_LEVEL = _raw_log_level.upper()

class NumerousBaseJob:
    """
    Shared scaffolding for jobs started from an app object.

    Keeps references to the app, the system and the input source, and offers
    helpers for saving/loading states and for throttling repeated log output.
    Subclasses must implement ``run_job``.
    """

    def __init__(self):
        self.logger = logging.getLogger('numerous-base-job')
        self.app = None
        self.system = None
        self.input = None
        self._t0 = None
        # While True, _save_states refuses to store anything.
        self.init = True
        # key -> last time value for which _print_update answered True
        self.last_print = {}

    def _post_init(self):  # This is called before _run_job from the app.loop method
        """
        Create the system and the input source, starting from the time found in
        the stored states (or the app start time when nothing is stored).
        """
        app = self.app
        input_timeout = 120
        t_resume, stored_states = self._load_states()
        step_size = app.nc.params.get('dt_simulation', 60)

        # A stop time of 0 is used when the app is in subscribe mode.
        if app.subscribe:
            t_stop = 0
        else:
            t_stop = app.end_time
        self._t0 = t_resume

        self.system = NumerousSystem(
            app.nc, app.scenario, app.files, app.start_time, t_stop,
            app.model_folder, self, stored_states, step_size,
        )

        self.input = app.nc.get_inputs(
            app.scenario,
            t0=t_resume,
            te=t_stop,
            dt=self.system.dt,
            tag_prefix='',
            tag_seperator='.',
            timeout=input_timeout,
        )

    def serialize_states(self, t: float = None):
        """
        Hook called by ``_save_states`` to obtain the model states.
        Returns: a json serializable object; the base implementation has
        nothing to store and returns an empty dict.
        """
        return dict()

    def _save_states(self, t, cause, details):
        """
        Store time, serialized model states, cause and details under the
        'states' key of the client state. Does nothing while ``self.init`` is set.
        """
        if self.init:
            self.logger.info('model not initialized, could not save states')
            return

        payload = {
            't': t,
            'states': self.serialize_states(t),
            'cause': cause,
            'details': details,
        }
        self.app.nc.state.set('states', payload)

        self.logger.debug(f"states saved: reason {cause} ({details})")

    def _load_states(self):
        """
        Read the 'states' entry from the client state.
        Returns: tuple of the time to resume from (stored 't', else the app
        start time) and the stored states dict (empty when nothing is stored).
        """
        stored = self.app.nc.state.get('states', {})
        t_end = self.app.end_time
        t_resume = stored.get('t', self.app.start_time)

        # NOTE(review): this flag lands on the job, while _post_init reads
        # self.app.subscribe - confirm which object is meant to carry it.
        if t_resume - t_end == 0:
            self.subscribe = True

        if stored:
            self.logger.debug('states loaded')
        return t_resume, stored

    def _print_update(self, t, ix, update_interval=5):
        """
        Throttle helper: answers True the first time ``ix`` is seen and then
        only when at least ``update_interval`` has passed (measured in the
        units of ``t``) since the last True answer for that ``ix``.
        """
        previous = self.last_print.get(ix)
        if previous is None or t - previous >= update_interval:
            self.last_print[ix] = t
            return True
        return False

    def _start(self, app):
        """Attach the app, run ``_post_init`` and then ``run_job``."""
        self.app = app
        self._post_init()
        self.run_job()

    def run_job(self):
        """The actual job; must be provided by the subclass."""
        raise NotImplementedError


class NumerousReportJob(NumerousBaseJob):
    """
    Job that fills a ``Report`` created from an HTML template and finalizes it.

    Subclasses provide the content through ``add_report_content``.
    """

    def __init__(self):
        super().__init__()
        self.report = None
        # Template shipped next to this module.
        self.template = os.path.dirname(__file__) + "/report/template/report_template_em.html"
        self.logger = logging.getLogger('numerous-report-job')
        self.logger.setLevel(level=LOG_LEVEL)

    def run_job(self):
        """
        Create the report (handing it the client's ``upload_file`` callable),
        let the subclass add content and finalize it. Any exception is logged
        and turned into app status -1 with a generic message.
        """
        try:
            self.report = Report(self.app.nc.upload_file, template=self.template)
            self.logger.info("adding report content")
            self.add_report_content(self.app)
            self.logger.info("finalizing report")
            self.report.finalize()
        except Exception as err:
            self.logger.error(f'report job crashed: {tb_str(err)}')
            self.app.status = -1
            self.app.message = "app error - see logs"
        finally:
            # Statuses 0, 1 and 2 all signal the app to terminate; other
            # statuses (e.g. -1 set above) leave the event untouched.
            if self.app.status in (0, 1, 2):
                self.app.terminate.set()
            self.logger.warning('job terminated')

    def add_report_content(self, app):
        """Hook for subclasses: add content to ``self.report`` using ``app``."""
        raise NotImplementedError

class NumerousSimulationJob(NumerousBaseJob):
    """
    Job that steps a simulation over incoming input data.

    ``run_job`` repeatedly reads input at the current simulation time, passes
    it to the system, calls ``step`` and writes the returned outputs. It ends
    when the client signals termination or hibernation, the input source ends
    or fails, the stop time is reached, or an exception occurs.

    App status values set by this class: 0 = terminated, 1 = completed,
    2 = hibernating, -1 = error.
    """

    def __init__(self):
        super(NumerousSimulationJob, self).__init__()
        self.system = None
        # Output writer, created in run_job. Declared here so that cleanup can
        # tell whether a writer was ever opened.
        self.output = None
        # When True, outputs are indexed at the time returned by step();
        # otherwise at the time the step started from.
        self.align_outputs_to_next_timestep = True
        self.logger = logging.getLogger('numerous-simulation-job')
        self.logger.setLevel(level=LOG_LEVEL)

    def initialize_simulation_system(self):
        """
        A method that is called once the first data is read. All initialization should be done here.
        Can return the initial output to be saved.

        NOTE(review): run_job calls self.system.initialize_model(); presumably
        that in turn calls this hook - confirm in NumerousSystem.
        """
        return

    def step(self, t: float = None, dt: float = None):
        """
        A method that is called after each data read. Could be a step solver, or some other data manipulating function.
        Returns: tuple of next timestamp and outputs as a dict with tags to be saved.
        When the outputs are falsy, run_job writes nothing for this iteration.
        """
        return t+dt, {"no_job_defined": None}

    def run_job(self):
        """
        Run the simulation loop until a stop condition is met.

        Whatever the exit path, the output writer is closed (if it was opened)
        and the states are saved together with the cause and details of the
        exit. Exceptions are logged and reported as app status -1; they are
        not propagated to the caller.
        """

        # This is the classic pipelines approach - well suited for digital twins
        # Run simulation in a loop
        self.logger.debug('entering loop')

        cause = "no cause"
        details = ""
        last_data = None
        data = None
        t = None
        # Reset so the finally block never touches a writer that was not
        # opened by this run (e.g. when the bootup calls below fail).
        self.output = None
        try:
            self.app.nc.set_scenario_progress('bootup', ScenarioStatus.INITIALZING, 0.0, force=True)
            self.app.terminate.wait(timeout=5)
            self.app.nc.set_scenario_progress('waiting for initial data', ScenarioStatus.INITIALZING, 0.0, force=True)

            self.output = self.app.nc.new_writer(buffer_size=0)
            t_stop = self.system.end_time
            t = self._t0

            while True:
                # External stop requests are checked first in every iteration.
                if self.app.nc.terminate_event.is_set():
                    self.app.status = 0
                    break
                if self.app.nc.hibernate_event.is_set():
                    self.app.status = 2
                    break

                try:
                    data = self.input.get_at_time(t)
                except (DataSourceHibernating, TimeoutError) as e:
                    if self.app.allow_hibernation:
                        self.app.status = 2
                        self.app.nc.hibernate(message="hibernating")
                    else:
                        self.app.status = -1
                    details = tb_str(e)
                    break
                except DataSourceCompleted as e:
                    self.app.status = 1
                    details = tb_str(e)
                    self.app.message = 'No more input data'
                    break
                except (DataSourceEmptyDataset, DataSourceStreamClosed) as e:
                    self.app.status = -1
                    details = tb_str(e)
                    break

                # Progress in percent; only meaningful with a positive stop time.
                completed = 0
                if t_stop > 0:
                    completed = (1 - (t_stop - t) / (t_stop - self.app.start_time)) * 100

                if self.app.backup.is_set():
                    self._save_states(t, 'backup', f"scheduled checkpoint @ {datetime.now()}")
                    self.app.backup.clear()

                if data is None:
                    self.app.nc.set_scenario_progress(f"waiting for data. Last update: "
                                                  f"{datetime.fromtimestamp(t)}",
                                                  ScenarioStatus.RUNNING if not self.init else ScenarioStatus.WAITING,
                                                  completed)
                    # Log at most once per 10 seconds of wall-clock time.
                    # NOTE(review): the sleep only happens when the message is
                    # logged, so the loop polls without pause in between -
                    # confirm this is intended.
                    if self._print_update(datetime.now().timestamp(), 0, 10):
                        self.logger.info(f"no data. Simulation time: {t}. ")
                        sleep(1)
                    continue

                self.system.update_inputs(data)

                if t_stop > 0:
                    self.app.nc.set_scenario_progress("running", ScenarioStatus.RUNNING, completed)

                # Starting at time zero with data indexed later: jump to the
                # index of the first row.
                if t==0 and data['_index'] > 0:
                    self.logger.warning(f'setting start time to {data["_index"]}')
                    t = data['_index']

                if self.init:
                    self.app.nc.set_scenario_progress("building model", ScenarioStatus.MODEL_INITIALIZING, 0, force=True)
                    initial_output = self.system.initialize_model()
                    if initial_output:
                        if "_index" not in initial_output:
                            initial_output.update({'_index': 0})
                        self.output.write_row(initial_output)

                try:
                    tnew, outputs = self.step(t, self.system.dt)
                    # NOTE(review): with falsy outputs the loop continues
                    # without advancing t and, while self.init is still set,
                    # initialize_model() runs again next iteration - confirm.
                    if not outputs:
                        continue
                    if self.init:
                        # First real outputs: register the tags once.
                        self.app.nc.set_timeseries_meta_data([{"name": tag} for tag in outputs.keys()],
                                                         offset=self.app.start_time)
                        self.init = False
                    if not self.align_outputs_to_next_timestep:
                        outputs.update({'_index': t - self.app.start_time})
                    else:
                        outputs.update({'_index': tnew - self.app.start_time})

                except Exception as e:
                    details = tb_str(e)
                    # Chain the original exception so its traceback is kept.
                    raise SimulationError(details) from e

                # Advance time
                t = tnew
                # save output
                self.output.write_row(outputs)

                self.logger.info(f'Calculation step. Time is now {t}. completed: {completed}')
                last_data = data

                if (t >= t_stop) and (t_stop > 0):
                    self.logger.warning('maximum time reached')
                    self.app.status = 1
                    # Reported on the app, like the other status messages.
                    self.app.message = "Simulation completed"
                    break

        except Exception as e:
            self.logger.error(f'numerous_app crashed: {tb_str(e)}')
            self.logger.debug(f'previous data: {last_data}')
            self.logger.debug(f'data at crash: {data}')
            cause = 'exception'
            details = {"error message": tb_str(e), "previous_data": last_data, "data at crash": data}
            self.app.status = -1
            self.app.message = "app error - see logs"
        finally:
            if self.app.status == 0:
                self.app.terminate.set()
                cause = "forcefully terminated"
            elif self.app.status == 1:
                self.app.terminate.set()
                cause = "completed"
            elif self.app.status == 2:
                self.app.terminate.set()
                cause = 'hibernating'
            # The writer does not exist when the failure happened before it
            # was created; closing unconditionally would raise here and skip
            # saving the states.
            if self.output is not None:
                self.output.close()
            self._save_states(t, cause, details)
            self.logger.warning('job terminated')