import logging
import pickle
import jsonpickle
import jsonpickle.ext.numpy as jsonpickle_numpy
jsonpickle_numpy.register_handlers()
import numpy as np
import time
from _datetime import datetime, timedelta
from queue import Empty
from multiprocessing import Process, Value, Queue, Manager
import sim_tools.assemble_solve.model_assemble as MU
import sim_tools.assemble_solve.model_solver as model_solver
import sim_tools.store.simulation_store_mongodb as simulation_store_mongodb
from sim_components.regulators import Static_Regulator as static_regulator
from job_worker.job_models import *



from time import time
class Debounce:
    """Rate-limit calls to arbitrary functions.

    A call is forwarded only when at least ``duration`` seconds have passed
    since the last forwarded call, or when ``force=True`` is given.
    Suppressed calls are dropped and return None.
    """

    def __init__(self, duration=10):
        # Minimum spacing between forwarded calls, in seconds.
        self.debounce_duration = duration
        # Back-date the "last forwarded call" so the first call is let through.
        self.last_not_debounced = datetime.now() - timedelta(seconds=2 * self.debounce_duration)

    def __call__(self, function, *args, force=False, **kwargs):
        """Invoke ``function(*args, **kwargs)`` unless inside the debounce window.

        Returns the function's result when it is invoked, otherwise None.
        """
        reopens_at = self.last_not_debounced + timedelta(seconds=self.debounce_duration)
        window_passed = reopens_at <= datetime.now()
        if not (window_passed or force):
            return None
        self.last_not_debounced = datetime.now()
        return function(*args, **kwargs)


def _load_stub_class(stub_definition, job_id):
    """Resolve the model ``Stub`` class described by *stub_definition*.

    The first matching key wins, checked in this order:

    * ``'class'``: a jsonpickle-encoded class.
    * ``'source'``: Python source defining ``Stub``. It is written to
      ``/sim/current_model<job_id>.py`` (prefixed with an import of
      ``sim_components.generic.items as BC``) and executed as a module.
    * ``'class_name'``: a dotted import path ``package.module.ClassName``.

    Raises:
        EnvironmentError: if none of the keys is present.
    """
    if 'class' in stub_definition:
        # NOTE(review): jsonpickle can instantiate arbitrary objects; only
        # safe when the job specification comes from a trusted source.
        return jsonpickle.loads(stub_definition['class'])

    if 'source' in stub_definition:
        import importlib.util
        import ntpath
        import os

        # Process-wide working-directory change; the module file below is
        # written relative to it.
        os.chdir('/sim')
        module_path = './current_model' + job_id + '.py'
        with open(module_path, 'w') as module_file:
            module_file.write("import sim_components.generic.items as BC\n" + stub_definition['source'])

        module_spec = importlib.util.spec_from_file_location(
            ntpath.basename(module_path), module_path)
        module = importlib.util.module_from_spec(module_spec)
        # NOTE(review): executes code supplied in the job specification.
        module_spec.loader.exec_module(module)
        return module.Stub

    if 'class_name' in stub_definition:
        import importlib

        module_name, _, class_name = stub_definition['class_name'].rpartition('.')
        return getattr(importlib.import_module(module_name), class_name)

    raise EnvironmentError('No model to solve!')


def _stitch_streams(data_streams, in_tags, t_sim):
    """Concatenate *data_streams* end to end until they reach ``t = t_sim``.

    Each appended stream has its ``'t'`` column shifted by the largest ``'t'``
    collected so far. The stream list is cycled again if one pass is not
    enough. The result is truncated to rows with ``t <= t_sim``.

    Raises:
        ValueError: if a full pass over the streams does not extend the
            covered time (e.g. there are no streams). Without this check
            the loop would never terminate.
    """
    import pandas as pd

    df_in = None
    t_stitched = 0
    while t_stitched < t_sim:
        t_before_pass = t_stitched
        for stream in data_streams:
            if t_stitched > 0:
                max_t = df_in['t'].max()
                this_df = stream.get_df(in_tags)
                this_df['t'] = this_df['t'] + max_t
                df_in = pd.concat([df_in, this_df], axis=0)
            else:
                df_in = stream.get_df(in_tags)
            t_stitched = df_in['t'].max()
        if t_stitched < t_sim and not t_stitched > t_before_pass:
            raise ValueError('The input stream data does not advance in time; cannot reach t = ' + str(t_sim))
    return df_in.loc[df_in['t'] <= t_sim]


def _load_input_stream(input_stream):
    """Turn a previous job's output into an interpolation group.

    Args:
        input_stream: mapping with ``'job_alias'`` (alias of the JobOutput
            to read), ``'tag_map'`` (entries indexed as
            ``(source_tag, target_tag, scale, offset)``) and ``'stream_ix'``
            (index of the data stream to use; a negative value stitches
            streams together, see below).

    Returns:
        ``{'t': <values of column 't'>, 'tags': {tag: values}}``. Target tags
        that appear several times in ``tag_map`` are summed. Every other
        non-``'t'`` column is passed through unscaled under its own name,
        unless a target tag of that name already exists.

    Raises:
        DoesNotExist: if no JobOutput with the alias exists.
        ValueError: if the selected data contains null values or cannot be
            stitched.
    """
    job_alias = input_stream['job_alias']
    tag_map = input_stream['tag_map']
    stream_ix = input_stream['stream_ix']
    try:
        input_data_jo = JobOutput.objects.get(job_alias=job_alias)
    except DoesNotExist as err:
        raise DoesNotExist('The input stream for static data does not exist!') from err

    data_streams = input_data_jo.data_set.data_streams
    # 't' first, followed by every source tag of the mapping.
    in_tags = ['t'] + [tag_pair[0] for tag_pair in tag_map]

    if stream_ix < 0:
        # NOTE(review): -stream_ix is treated as a number of 365.25-day years,
        # which assumes 't' is in seconds.
        df_in = _stitch_streams(data_streams, in_tags, -stream_ix * 3600 * 24 * 365.25)
    else:
        df_in = data_streams[stream_ix].get_df(in_tags)

    # Explicit check (not assert) so validation also runs under ``python -O``.
    if df_in.isnull().values.any():
        raise ValueError('Invalid values in stream!')

    interp_tags = {}
    for tag_pair in tag_map:
        val = (df_in[tag_pair[0]] * tag_pair[2] + tag_pair[3]).values
        if tag_pair[1] in interp_tags:
            interp_tags[tag_pair[1]] += val
        else:
            interp_tags[tag_pair[1]] = val

    for in_tag in list(df_in):
        if in_tag != 't' and in_tag not in interp_tags:
            interp_tags[in_tag] = df_in[in_tag].values

    return {'t': df_in['t'].values, 'tags': interp_tags}


def main(job_id, resume_dict, sim_specification, progress_callback=None, dw=None, continue_callback=None, **kwargs):
    """Assemble the model described by *sim_specification* and simulate it.

    Args:
        job_id: string used in the generated module file name when the
            model is supplied as source code.
        resume_dict: optional state applied to the model last, after
            ``'initial_cond'`` (falsy to skip).
        sim_specification: dict with two entries.
            ``'model_spec'`` says how to obtain the stub: see
            ``_load_stub_class``, plus optional ``'args'``/``'kwargs'``
            (jsonpickle), ``'kwargs_loaded'``, ``'interp'`` and
            ``'input_stream'`` (see ``_load_input_stream``; with a truthy
            ``'solve_all'`` it overwrites ``sim_spec['t_0']`` and
            ``sim_spec['t_end']`` in place).
            ``'sim_spec'`` holds solver settings: ``'persistent_tags'``,
            ``'dt'``, ``'t_end'``, ``'dt_internal'`` and the optional
            ``'n_dt_update'``, ``'t_0'``, ``'start'``, ``'stop_tag'`` and
            ``'initial_cond'``.
        progress_callback: ``callback(progress, meta)``. A printing default
            is used when omitted. NOTE(review): it receives a fraction of the
            time span from the debounced update, but a percentage from
            solver progress messages. Confirm which scale consumers expect.
        dw: data writer that receives the results. Required despite the
            ``None`` default.
        continue_callback: polled while solving; returning a falsy value
            cancels the run. Defaults to always continuing.
        **kwargs: accepted for compatibility and ignored.

    Returns:
        The solver status: ``'finished'``, ``'terminated'`` or, if no solver
        chunk ran, ``'running'``.

    Raises:
        ValueError: if *dw* is missing or the input stream data is invalid.
        EnvironmentError: if the model spec names no model.
        ChildProcessError: if the solver reports an error or dies without
            reporting completion.
    """
    if dw is None:
        raise ValueError('main() requires a data writer (dw) to store results')

    if not continue_callback:
        def continue_callback():
            return True

    if not progress_callback:
        # Extra keywords are accepted because the debounced progress update
        # below passes ``prog_meta=`` by keyword.
        def progress_callback(progress, meta=None, **_extra):
            print()
            print('***Progress update***')
            print('progress: ', progress)
            if meta:
                print('meta: ', meta)

    debounce_prog = Debounce(duration=3)

    stub_definition = sim_specification['model_spec']
    sim_spec = sim_specification['sim_spec']

    Stub = _load_stub_class(stub_definition, job_id)

    stub_args = jsonpickle.loads(stub_definition['args']) if 'args' in stub_definition else []
    stub_kwargs = jsonpickle.loads(stub_definition['kwargs']) if 'kwargs' in stub_definition else {}
    if 'kwargs_loaded' in stub_definition:
        stub_kwargs = stub_definition['kwargs_loaded']

    item_fact = simulation_store_mongodb.Item_Factory('')
    stub = item_fact.get_Item('Stub', '', Stub, *stub_args, **stub_kwargs)

    interp_data = {}
    if 'interp' in stub_definition:
        interp_data.update(jsonpickle.loads(stub_definition['interp']))

    if 'input_stream' in stub_definition:
        input_stream = stub_definition['input_stream']
        group = _load_input_stream(input_stream)
        interp_data.update({'g1': group})
        if 'solve_all' in input_stream and input_stream['solve_all']:
            sim_spec['t_0'] = min(group['t'])
            sim_spec['t_end'] = max(group['t'])

    if interp_data:
        stub.add('Interp_Data', static_regulator.Static_Regulator, interp_data)

    model = MU.Model('model', stub)
    model.assemble(sim_spec['persistent_tags'])

    # A positive n_dt_update makes every solver child process cover one dt.
    n_dt_update = sim_spec.get('n_dt_update', 0)
    dt_update = sim_spec['dt'] if n_dt_update > 0 else 0

    if 't_0' in sim_spec:
        model.apply_init_state({'t': sim_spec['t_0']})
        t0 = sim_spec['t_0']
    else:
        t0 = 0

    if 'start' in sim_spec:
        import dateutil.parser

        start = dateutil.parser.parse(sim_spec['start'])
    else:
        start = datetime(2016, 1, 1, 0, 0, 0, 0)

    if 'stop_tag' in sim_spec:
        model.set_stop_tag(sim_spec['stop_tag'])

    if 'initial_cond' in sim_spec:
        model.apply_init_state(sim_spec['initial_cond'])

    if resume_dict:
        model.apply_init_state(resume_dict)

    # Record the initial state before solving.
    dw.data_map = model.persistent_tags
    dw.complete_map = list(model.get_combined()[1])
    dw.start = start
    dw.data_blocks.append(model.sys_dict_high_level.values(model.persistent_tags))
    dw.complete_data_block = model.get_combined()[0]
    dw.flush(force=True)

    def run_simulation(dt, t_end, dt_internal, dt_update, n_keep):
        """Drive ``run_sim_step`` child processes until t_end, stop or cancel.

        With ``dt_update > 0`` a new child is started for every chunk of
        ``dt_update``; otherwise a single child solves to ``t_end``. Every
        ``n_keep``-th data message is written to ``dw`` (every message when
        ``n_keep <= 1``). Returns the status string.
        """
        data_queue = Queue()
        manager = Manager()
        try:
            # [0]: pickled model assembly; [1]: pickled state array, refreshed
            # before every chunk so the next child resumes from it.
            pickled_model_assy = manager.list()
            pickled_model_assy.append(pickle.dumps(model.to_dict()))
            pickled_model_assy.append(pickle.dumps(model.sys_arr))
            continue_flag = Value('i', 1)

            t = model.sys_dict['t']
            last_t_solved = t
            status = 'running'
            cont = True
            keep_counter = 0
            tic = time()
            last_tic_solve = time()

            while t < t_end and cont:
                debounce_prog(progress_callback, (t - t0) / (t_end - t0), prog_meta='')
                pickled_model_assy[1] = pickle.dumps(model.sys_arr)

                chunk_end = t + dt_update if dt_update > 0 else t_end
                run_sim_step_process = Process(
                    target=run_sim_step,
                    args=(dt, chunk_end, dt_internal, data_queue, pickled_model_assy, continue_flag, tic))
                tic = time()
                run_sim_step_process.start()
                try:
                    stepping_in_progress = True
                    child_exit_seen = False
                    while stepping_in_progress:
                        try:
                            raw_message = data_queue.get(timeout=.1)
                        except Empty:
                            raw_message = None
                            # A child that dies without sending FINAL/error would
                            # otherwise leave this loop polling forever. One more
                            # empty poll after its death makes sure messages it
                            # flushed just before exiting are consumed first.
                            if not run_sim_step_process.is_alive():
                                if child_exit_seen:
                                    raise ChildProcessError(
                                        'Solver process exited (code ' + str(run_sim_step_process.exitcode)
                                        + ') without reporting completion')
                                child_exit_seen = True

                        if raw_message:
                            message = pickle.loads(raw_message)

                            if message['type'] == 'data':
                                model.update(message['payload'])
                                keep_counter += 1
                                if keep_counter >= n_keep or n_keep < 0:
                                    dw.data_blocks.append(model.sys_dict_high_level.values(model.persistent_tags))
                                    dw.complete_data_block = model.get_combined()[0]
                                    dw.flush()
                                    keep_counter = 0

                            elif message['type'] == 'FINAL':
                                stepping_in_progress = False
                                if message['payload'] == "Stop":
                                    cont = False
                                dw.flush(force=True)
                                if continue_flag.value > 0:
                                    status = 'finished'
                                else:
                                    status = 'terminated'

                            elif message['type'] == 'progress':
                                solver_t = message['payload']['t_solved']
                                dt_solved = solver_t - last_t_solved
                                tic_solve = time()
                                dt_real_solve = tic_solve - last_tic_solve
                                last_tic_solve = tic_solve
                                last_t_solved = solver_t
                                # Back-to-back messages can share a clock tick.
                                rate = dt_solved / dt_real_solve if dt_real_solve > 0 else float('inf')
                                overall_progress = (solver_t - t0) / (t_end - t0) * 100
                                prog_str = 't simulation: ' + str(solver_t) + 's, rate: ' + str(rate) + ', progress: ' + str(round(overall_progress * 10) / 10) + '%'
                                progress_callback(overall_progress, prog_str)

                            elif message['type'] == 'error':
                                error = message['payload']
                                run_sim_step_process.join()
                                stepping_in_progress = False
                                dw.flush(force=True)
                                # error[1] indexes the failing operation.
                                raise ChildProcessError(model.ops[error[1]])

                            else:
                                raise ValueError('Unknow message: ', message['type'])

                            t = model.sys_dict['t']

                        if not continue_callback():
                            continue_flag.value = 0
                            stepping_in_progress = False
                            cont = False
                            status = 'terminated'

                    run_sim_step_process.join(1)
                    if run_sim_step_process.is_alive():
                        print('solver terminated!')
                        run_sim_step_process.terminate()

                except Exception:
                    print('caught exception here')
                    import sys
                    import traceback
                    traceback.print_exc(file=sys.stdout)

                    continue_flag.value = 0
                    # Process.join() returns None, so liveness must be checked
                    # explicitly to terminate a child that did not stop in time.
                    run_sim_step_process.join(5)
                    if run_sim_step_process.is_alive():
                        run_sim_step_process.terminate()
                    raise

            return status
        finally:
            # Stop the manager's server process; all children have been
            # joined or terminated by now.
            manager.shutdown()

    # n_dt_update doubles as the keep-every-n-th-sample setting. When it is
    # absent, 0 keeps every sample.
    return run_simulation(sim_spec['dt'], sim_spec['t_end'], sim_spec['dt_internal'], dt_update, n_dt_update)

def run_sim_step(dt, t_end, dt_int, data_queue, pickled_model_assy,continue_flag, tic):
    """Solve the model step by step and stream the results to *data_queue*.

    ``main`` runs this as the target of a ``multiprocessing.Process``. Each
    step advances the solution to ``t0 + step * dt`` with
    ``model_solver.solve_BE``, using the internal step *dt_int*. After each
    step it posts a ``'data'`` message (the state array) followed by a
    ``'progress'`` message (``{'rate', 't_solved'}``). Every message is a
    pickled ``{'type': ..., 'payload': ...}`` dict.

    Args:
        dt: time advanced per reported step.
        t_end: stepping continues while the current time is ``<= t_end``.
        dt_int: internal step forwarded to ``solve_BE``.
        data_queue: queue read by the parent process.
        pickled_model_assy: indexable whose ``[0]`` is the pickled model
            dict and ``[1]`` the pickled state array (time is element 0).
        continue_flag: shared int; stepping stops once it is ``<= 0``.
        tic: wall-clock reference for the first reported rate.

    Finishes by posting ``'FINAL'``: the payload is ``'Stop'`` when the stop
    tag fires or *continue_flag* was cleared, and ``''`` otherwise. If the
    solver sets ``error_op[0]``, it instead posts ``'error'`` and raises
    ``RuntimeError``.
    """
    assembly = pickle.loads(pickled_model_assy[0])
    sys_arr = pickle.loads(pickled_model_assy[1])

    states_ix = assembly['states_ix']
    states_dot_ix = assembly['states_dot_ix']
    enc_ops = assembly['enc_ops']
    data_def = assembly['data_def']
    data_bank = assembly['data_bank']
    stop_ix = assembly['stop_ix']
    # Filled in by solve_BE; a non-zero [0] signals a solver error.
    error_op = np.zeros(8, dtype=np.int32)

    def post(msg_type, payload):
        # Pickle eagerly so the message is a snapshot: sys_arr is mutated in
        # place by the next solve step.
        data_queue.put(pickle.dumps({'type': msg_type, 'payload': payload}))

    t = sys_arr[0]
    t0 = sys_arr[0]
    step = 0
    # NOTE(review): with '<=' one extra step runs when t lands exactly on
    # t_end, so output can extend one dt past t_end. Confirm this is intended.
    while t <= t_end and continue_flag.value > 0:
        step += 1
        y0 = model_solver.state_vals(sys_arr, states_ix)
        # The target is computed from t0 rather than from the current t.
        model_solver.solve_BE(t, t0 + step * dt, dt_int, y0, sys_arr, sys_arr,
                              states_ix, states_dot_ix, enc_ops, error_op,
                              data_def, data_bank)

        t = sys_arr[0]
        post('data', sys_arr)

        # Rate of the step just completed. A coarse clock can report zero
        # elapsed time, which previously raised ZeroDivisionError and killed
        # the child without a FINAL message.
        toc = time()
        elapsed = toc - tic
        rate = dt / elapsed if elapsed > 0 else float('inf')
        post('progress', {'rate': rate, 't_solved': t})
        tic = toc

        if error_op[0] != 0:
            # error_op[5:8] index state-array entries relevant to the error.
            print(sys_arr[error_op[5]])
            print(sys_arr[error_op[6]])
            print(sys_arr[error_op[7]])
            post('error', error_op)
            raise RuntimeError(error_op)

        if sys_arr[stop_ix] > 0:
            post('FINAL', 'Stop')
            return

    post('FINAL', 'Stop' if continue_flag.value <= 0 else '')