import queue
from queue import Empty
from numerous.engine.simulation.simulation import Simulation
from numerous.engine.simulation.solvers import SolverType
from numerous.engine.model.model import Model
from numerous.engine_tools.numerous_system_initializer import SystemInitializer
from numerous.engine.system.external_mappings.interpolation_type import InterpolationType
from numerous.engine.system.external_mappings import ExternalMappingElement

from numerous.utils.data_loader import DataLoader, DataFrame
from numerous.utils.logger_levels import LoggerLevel
from numerous.utils.historian import Historian
from numerous.image_tools.job import NumerousSimulationJob

import numpy as np
import pandas as pd
import logging

# Module-level logger named after the dotted module path so it inherits package logging config
# (``__file__`` would name it after a filesystem path, outside the logger hierarchy).
logger = logging.getLogger(__name__)

class FileDWHistorian(Historian):
    """In-memory historian that simply holds on to the last DataFrame it was given.

    ``file_dw`` and ``_historians`` are stored but not used by anything in this class.
    """

    def __init__(self, file_dw, max_size=None):
        super().__init__(max_size)
        # The base class receives max_size too; it is also kept on this instance explicitly.
        self.max_size, self.file_dw = max_size, file_dw
        self._historians = list()
        self.df = pd.DataFrame()

    def store(self, df):
        """Overwrite the held DataFrame with ``df``; earlier contents are discarded."""
        self.df = df

class LineByLineDataLoader(DataLoader):
    """Pass-through data loader: ``load`` keeps the frame it receives and hands it back unchanged."""

    def __init__(self, external_mappings: list):
        # ``external_mappings`` is part of the constructor signature but is not retained.
        super().__init__()

    def load(self, df: DataFrame, t: int) -> DataFrame:
        """Remember ``df`` as the current frame and return it; ``t`` is not used."""
        self.df = df
        return self.df

class NumerousEngineJob(NumerousSimulationJob):
    """Drive a numerous engine ``Simulation`` one ``dt`` at a time.

    Configuration is read from ``self.system``. That attribute is not assigned in this
    class; presumably ``NumerousSimulationJob`` provides it (TODO confirm).
    ``initialize_simulation_system`` builds an engine system, a ``Model`` and a
    ``Simulation``. Each solve's historian rows are pushed onto an internal FIFO queue,
    and ``step`` hands them out one row at a time.
    """

    def __init__(self):
        super().__init__()
        # Prefix for mapped variable paths: "<tag>.<component name>.<alias>".
        self.tag = 'system'
        self.len_input_data = None
        self.df_local_path = None
        self.historian = FileDWHistorian(max_size=2, file_dw=None)
        self.simulation = None
        self.dt = None
        self.t_simulation = None
        self.t_stop = None
        self.t_start = None
        self.t_offset = None
        self.y0 = None
        self._skip_one_step = False
        self._initial_output = None
        # Output rows produced by a solve; consumed one per ``step`` call.
        self._output_queue = queue.Queue()

    def setup_mappings(self, t0: float, time_multiplier: float, index_to_timestep_mapping: str,
                       index_to_timestep_mapping_start: int):
        """Build ``self.external_mappings`` from the components' declared parameters.

        For each component parameter ``alias -> column_name``:
        - ``"<tag>.<component>.<alias>"`` is aliased to ``column_name`` with
          ``InterpolationType.PIESEWISE``.
        - The column is seeded with the component's current input for ``alias``.
          A missing input is ``None``, which becomes NaN after the float64 conversion.

        The result is a single-row DataFrame whose ``index_to_timestep_mapping`` column
        holds ``t0``. It is wrapped in one ``ExternalMappingElement``. When no component
        declares any parameters, ``self.external_mappings`` is left empty.

        :param t0: value written to the time/index column of the single row.
        :param time_multiplier: forwarded to ``ExternalMappingElement``.
        :param index_to_timestep_mapping: name of the time/index column.
        :param index_to_timestep_mapping_start: forwarded to ``ExternalMappingElement``.
        """
        self.dataframe_aliases = {}
        self.external_mappings = []

        df_dict = {}
        for component in self.system.components.values():
            name = component.name
            for alias, column_name in component.parameters.items():
                mapped_variable = f"{self.tag}.{name}.{alias}"
                self.dataframe_aliases[mapped_variable] = (column_name, InterpolationType.PIESEWISE)
                df_dict[column_name] = component.inputs.get(alias, None)

        if not df_dict:
            return

        # The one-element list fixes the row count; scalar input values broadcast to it.
        df_dict[index_to_timestep_mapping] = [t0]
        df = pd.DataFrame(df_dict, dtype=np.float64)

        # Add external mappings so data can be read
        self.external_mappings.append(
            ExternalMappingElement(
                df,
                index_to_timestep_mapping,
                index_to_timestep_mapping_start,
                time_multiplier,
                self.dataframe_aliases
            )
        )

    def initialize_simulation_system(self):
        """Build the engine system, model and simulation from ``self.system``.

        All times are made relative to ``self.system.start_time``, which is kept in
        ``self.t_offset``.

        If ``self.y0`` (the system's ``states``) is falsy, one initial step is solved and
        the first queued output row is returned. Otherwise the stored states are written
        into the numba model, the historian is re-initialised, and ``None`` is returned.
        """
        self.dt = self.system.dt
        self.t_offset = self.system.start_time

        self.y0 = self.system.states
        self.t_start = self.system.states.get('t', self.system.start_time) - self.system.start_time
        self.len_input_data = 1

        self.t_simulation = self.system.end_time - self.system.start_time
        self.t_stop = self.t_start + self.dt

        self.setup_mappings(0, 1, 't', 0)

        # Create a system based on specified components; external data is only wired in
        # when at least one component declares parameters.
        self.enginesystem = SystemInitializer(
            self.tag,
            system=self.system,
            external_mappings=self.external_mappings if self.external_mappings else None,
            data_loader=LineByLineDataLoader(self.external_mappings) if self.external_mappings else None,
        )

        # Create model based on system
        model = Model(
            self.enginesystem, historian=self.historian,
            logger_level=LoggerLevel.INFO,
            use_llvm=True, **self.system.parameters
        )

        # A single-interval simulation; the job advances it manually via step_solve.
        simulation = Simulation(
            model, t_start=self.t_start, t_stop=self.t_stop,
            num=1, num_inner=1,
            max_step=self.dt, solver_type=SolverType.NUMEROUS
        )

        if self.y0 is not None:
            # NOTE(review): if ``states`` is a dict this enumerates its keys, not its values
            # — confirm the intended type of ``self.system.states``.
            for i, y_ in enumerate(self.y0):
                simulation.model.numba_model.write_variables(y_, i)

        self.simulation = simulation

        output = None
        if not self.y0:
            self.initial_step(self.t_start, self.dt)
            output = self._return_output_from_queue()
        else:
            self.simulation.numba_model.historian_reinit()

        return output

    def _map_inputs_to_numpy_array(self, t):
        """Pack the current component inputs into the arrays fed to the numba model.

        Inputs are ordered by component and then by parameter, matching the order used
        in :meth:`setup_mappings`.

        :return: ``(values, times, t_max)``.
            - ``values``: float64, shape ``(1, 1, n_inputs)``.
            - ``times``: float64, shape ``(1, 1)``, holding ``t``.
            - ``t_max``: the maximum of ``times``.
        """
        row = [
            component.inputs.get(alias, None)
            for component in self.system.components.values()
            for alias in component.parameters
        ]
        np_res = np.array([[row]], dtype=np.float64)
        np_time = np.array([[t]], dtype=np.float64)
        t_max = np.max(np_time)

        return np_res, np_time, t_max

    def _update_input(self, t):
        """Push the current component inputs into the numba model and map them at time ``t``."""
        external_mappings_numpy, external_mappings_time, _ = self._map_inputs_to_numpy_array(t)
        # We do not wish to update external data from within the solver. The external-data
        # horizon is therefore set one ``dt`` past the end of the simulation, presumably so
        # the solver never triggers a reload itself.
        self.simulation.model.numba_model.max_external_t = self.system.end_time - self.t_offset + self.dt
        self.simulation.model.numba_model.update_external_data(external_mappings_numpy, external_mappings_time)
        self.simulation.model.numba_model.map_external_data(t)

    def _add_output_to_queue(self, output):
        """Append one output row to the FIFO output queue."""
        self._output_queue.put(output)

    def _return_output_from_queue(self):
        """Pop the next queued output row without blocking; return ``None`` if the queue is empty."""
        try:
            return self._output_queue.get(block=False)
        except Empty:
            return None

    def queue_outputs(self):
        """Move the historian's rows onto the output queue and reset the historian.

        Each queued row is a dict mapping every logged variable to its value. The extra
        key ``"_index"`` holds the row's ``time``, which is relative to ``self.t_offset``.
        """
        df = self.simulation.model.historian.df
        if df.empty:
            self.simulation.model.create_historian_df()
            df = self.simulation.model.historian.df
        output_dict = {col: df[col].values for col in self.simulation.model.logged_variables}

        # Reset the historian so the next solve starts from an empty buffer; ``df`` above
        # still references the data just read.
        self.simulation.model.historian_df = None
        self.simulation.model.historian.df = pd.DataFrame()
        self.simulation.numba_model.historian_ix = 0
        output_dict["_index"] = df['time'].values  # best practice is to save relative time

        # Transpose the column arrays into one dict per historian row.
        for values in zip(*output_dict.values()):
            self._add_output_to_queue(dict(zip(output_dict, values)))

    def serialize_states(self, t: float = None):
        """Return the numba model's current variable values as a list; ``t`` is not used."""
        return list(self.simulation.model.numba_model.read_variables())

    def initial_step(self, t_: float, dt: float = None):
        """Solve the first interval starting at ``t_`` and queue its outputs.

        :param t_: start time, relative to ``self.t_offset``.
        :param dt: step length. ``None`` falls back to ``self.dt``. The step is clipped so
            it does not run past ``self.t_simulation``.
        """
        if dt is None:
            dt = self.dt
        self._update_input(t_)
        self.simulation.step_solve(t_, min(dt, self.t_simulation - t_))
        self.queue_outputs()

    def step(self, t: float = None, dt: float = None, init=False):
        """Return the next output row, solving one more interval only when none is queued.

        ``t`` and ``init`` are accepted for interface compatibility but are not used. The
        solve starts from the solver's current time (0 if it has no info yet).

        :param dt: step length. ``None`` falls back to ``self.dt``. The step is clipped to
            end at ``self.t_simulation``.
        :return: ``(absolute_time, output_row)``, where ``absolute_time`` is the row's
            ``"_index"`` plus ``self.t_offset``.
        :raises RuntimeError: if the solve produced no output rows.
        """
        next_output = self._return_output_from_queue()
        if next_output is not None:
            return next_output['_index'] + self.t_offset, next_output

        if dt is None:
            dt = self.dt
        info = self.simulation.solver.info
        t_ = info.t if info else 0
        self._update_input(t_)
        # Watch out for round-off error: accumulated float times may leave
        # ``t_simulation - t_`` slightly off from an exact multiple of ``dt``.
        self.simulation.step_solve(t_, min(dt, self.t_simulation - t_))
        self.queue_outputs()

        next_output = self._return_output_from_queue()
        if next_output is None:
            raise RuntimeError(f"step_solve from t={t_} produced no output rows")
        return next_output['_index'] + self.t_offset, next_output





