import sim_tools.assemble_solve.model_encoder as Model_Encoder
import sim_tools.assemble_solve.model_structure as Model_Structure
import sim_tools.assemble_solve.model_solver as Model_Solver
import inspect
import ast
import numpy as np
import collections
import copy
from copy import deepcopy
from time import time

class FixedDict(object):
    """Ordered mapping whose set of keys is frozen at construction time.

    Values can be replaced freely, but no key can be added after
    ``__init__``.  Each key also has a fixed integer position (its insertion
    order), exposed through ``indices`` / ``get_indices``.

    Attributes:
        allowed_keys: ``set`` of the keys given at construction.
        length: number of keys.
        indices: ``{key: position}`` following the insertion order.
    """

    def __init__(self, dictionary):
        """Copy ``dictionary`` (anything ``OrderedDict`` accepts) and freeze its keys."""
        self._dictionary = collections.OrderedDict(dictionary)
        self.allowed_keys = set(self._dictionary.keys())
        self.length = len(self.allowed_keys)
        self.indices = {k: i for i, k in enumerate(self._dictionary.keys())}

    def __contains__(self, key):
        return key in self.allowed_keys

    def __setitem__(self, key, item):
        """Replace the value of an existing key; raise ``KeyError`` for unknown keys."""
        if key not in self._dictionary:
            raise KeyError("The key {} is not defined.".format(key))
        self._dictionary[key] = item

    def __getitem__(self, key):
        return self._dictionary[key]

    def update(self, dict):
        """Update values from a mapping that contains only known keys.

        Raises:
            AssertionError: if ``dict`` holds a key that is not allowed.  The
                exception is raised explicitly (not via ``assert``) so the
                check also runs under ``python -O``.
        """
        unknown = set(dict.keys()).difference(self.allowed_keys)
        if unknown:
            # Report the offending *new* keys, not the allowed keys that
            # happen to be absent from the update.
            raise AssertionError(
                "No new keys allowed in update of FixedDict! - Differences found: "
                + str(sorted(unknown, key=str)))
        self._dictionary.update(dict)

    def update_intersection(self, dict):
        """Update every key that exists both here and in ``dict``; ignore the rest."""
        keys_to_update = self.allowed_keys.intersection(dict.keys())
        self.update_keys(dict, keys_to_update)

    def update_keys(self, dict, keys_to_update):
        """Copy ``dict[k]`` into this mapping for each ``k`` in ``keys_to_update``."""
        self.update({k: dict[k] for k in keys_to_update})

    def update_from_array(self, arr):
        """Assign ``arr[i]`` to the key at position ``i`` for every key.

        The underlying dictionary is updated in place, so references obtained
        from ``dict()`` stay valid and the key order is untouched.

        Raises:
            AssertionError: if ``len(arr)`` differs from the number of keys.
        """
        if self.length != len(arr):
            raise AssertionError("The length of the array must match the number of keys!")
        # Only values change, never the key set, so mutating while iterating is safe.
        for i, k in enumerate(self._dictionary):
            self._dictionary[k] = arr[i]

    def values(self, keys=None):
        """Return the values as a list.

        With ``keys=None`` all values are returned in insertion order;
        otherwise the values of ``keys`` in the given order (an empty
        ``keys`` yields an empty list).
        """
        if keys is not None:
            return [self[k] for k in keys]
        return list(self._dictionary.values())

    def items(self):
        return self._dictionary.items()

    def keys(self):
        """Return the set of allowed keys (unordered)."""
        return self.allowed_keys

    def get_indices(self, keys=None, ignore_missing=False):
        """Return ``{key: position}`` for ``keys`` (all keys when ``None``).

        Args:
            keys: keys to look up; an empty collection yields ``{}``.
            ignore_missing: drop unknown keys instead of raising ``KeyError``.
        """
        if keys is None:
            return self.indices
        if ignore_missing:
            keys = self.allowed_keys.intersection(keys)
        return {k: self.indices[k] for k in keys}

    def get_indices_list(self):
        """Return the keys as a list, ordered by their position."""
        return list(self.indices.keys())

    def dict(self):
        """Return the underlying ordered dictionary (not a copy)."""
        return self._dictionary


class Model:
    def __init__(self, name, system):
        self.system=system
        self.items=collections.OrderedDict({})
        self.items.update(self.system.items)
        self.items.update({self.system.name: self.system})
        self.t = 0
        self.error_op=np.zeros(8,dtype=np.int32)
        self.stop_ix = 'stop'
        self.parsed_funcs = {}
        self.last_out = 0
        self.next_out = 0
        self.step_time = 0
        self.non_step_time = 0
        self.toc = 0
        
    def assemble(self, persistent_tags):
        """Build the flat system state and the encoded operation table.

        Walks ``self.items`` and gathers each item's operations: parameter
        differences, explicit ``ops``, and assignments parsed from
        ``calc_funcs`` / ``calc_funcs_str``.  Every op target is registered in
        ``self.sys_dict``; ops carrying a ``'constant'`` only set a value and
        are not kept as ops.  The items' ``Variables`` are then collected,
        after which everything is converted to numpy arrays
        (``enc_ops``, ``sys_arr``, ``states_ix``, ``states_dot_ix``).

        Finally the low-level ``sys_dict`` and the high-level
        ``sys_dict_high_level`` (keys without ``'tmp'``) are wrapped in
        ``FixedDict``, ``persistent_tags`` is resolved, and a deep copy of the
        low-level dict is stored in ``self.init_dict``.

        Args:
            persistent_tags: one of
                * ``'ALL'``  - every high-level key,
                * ``'NONE'`` - no keys,
                * a dict; with a ``'level'`` entry only keys containing at
                  most ``level + 1`` dots are kept, otherwise all keys,
                * any other collection - intersected with the high-level keys.
                ``'t'`` is always added; the result is de-duplicated and sorted.
        """
        #iterate over items to draw out functions and compile them!
            #find special variables
                #states
                #states_dot
                #reservoir states_dot

        #Create the items dict of the model and add the system item to it

        self.states=[]
        self.states_dot=[]
        self.variables=[]
        self.ops=[]

        self.enc=Model_Encoder.Encoder()
        # 't' and 'stop' are always the first two entries of the system dict
        self.sys_dict=collections.OrderedDict({'t':0, 'stop':0})
        self.print_list=[]
        self.exports={}
        for item in self.items.values():
            #print(item.name)
            ops=[]
            pop_list=[]

            # each parameter difference becomes a 'subtract' op: target = v[0] - v[1]
            if hasattr(item,'param_differences'):
                for k,v in item.param_differences.items():
                    ops+=[{'path':'abs', 'source': item.name, 'line': 0, 'target': k, 'func': 'subtract', 'left': v[0], 'right':v[1]}]

            if hasattr(item,'exports'):
                self.exports.update(item.exports)

            if hasattr(item,'ops'):

                ops+=item.ops

            for i,a in enumerate(ops):

                self.enc.add_data(a)

                # constants only initialise their target; they are removed from the op list below
                if 'constant' in a:
                    self.sys_dict[a['target']]=a['constant']
                    pop_list+=[i]
                else:
                    self.sys_dict[a['target']]=np.float64(0)

            # delete from the back so the remaining indices stay valid
            for i in sorted(pop_list,reverse=True):
                del ops[i]
            # NOTE(review): repeats the zero-initialisation already done in the loop above
            for a in ops:
                self.sys_dict[a['target']] = np.float64(0)

            self.ops+=ops

            ops=[]

            for cf in item.calc_funcs:
                #print(cf)
                # parse each function only once; the cache is keyed by __qualname__
                if cf.__qualname__ not in self.parsed_funcs:
                    sourcelines=inspect.getsourcelines(cf)
                    #print(sourcelines)
                    # every source line except the first is parsed on its own, so each
                    # stripped line has to be a complete statement
                    self.parsed_funcs[cf.__qualname__]={'code_lines': [(cl, ast.parse(cl.strip()) if i >0 else None) for i, cl in enumerate(sourcelines[0])], 'source': cf.__name__, 'line': sourcelines[1]}

                source_obj = self.parsed_funcs[cf.__qualname__]
                # running number used to build unique 'tmp_<source>_<count>' prefixes
                count = 0

                line =source_obj['line']
                source=source_obj['source']
                proc_code=''

                for i, cl in enumerate(source_obj['code_lines']):




                    if i>0:
                        call = cl[1]

                        for j,n in enumerate(call.body):

                            if isinstance(n,ast.Assign):


                                ops+=self.enc.parse_Assign(n,'tmp_'+source+'_'+str(count), line+i,source)

                                count+=1
                            else:

                                # a bare print(<name>) call registers <name> in print_list
                                # NOTE(review): n.value.func.id only exists for plain-name calls;
                                # an attribute call such as obj.f() would raise AttributeError here
                                if isinstance(n,ast.Expr):
                                    if isinstance(n.value, ast.Call):
                                        if n.value.func.id == 'print':
                                            self.print_list+=[{'target': self.enc.lookup_varname(item, n.value.args[0].id)}]




            # same parsing as above, for functions supplied as source strings
            # ({'source': <code>, 'name': <name>})
            if hasattr(item, 'calc_funcs_str'):
                for cf in item.calc_funcs_str:
                    count = 0

                    code = cf['source']

                    source = cf['name']
                    proc_code = ''

                    for i, cl in enumerate(code.splitlines()):


                        s = cl.strip()
                        #print(s)


                        call = ast.parse(s)
                        for j, n in enumerate(call.body):

                            if isinstance(n, ast.Assign):

                                ops += self.enc.parse_Assign(n, 'tmp_' + source + '_' + str(count), 0 + i, source)

                                count += 1
                            else:

                                if isinstance(n, ast.Expr):
                                    if isinstance(n.value, ast.Call):
                                        if n.value.func.id == 'print':
                                            self.print_list += [
                                                {'target': self.enc.lookup_varname(item, n.value.args[0].id)}]

            pop_list = []
            #print(ops)

            # preprocess the parsed ops, then split off constants exactly as for the explicit ops
            for i, a in enumerate(ops):

                self.enc.preprocess(item, a)

                if 'constant' in a:
                    self.sys_dict[a['target']] = a['constant']

                    pop_list += [i]
                else:
                    self.sys_dict[a['target']] = np.float64(0)

            for i in sorted(pop_list, reverse=True):
                del ops[i]

            self.ops += ops
            ops = []
        #print(self.ops)
        # NOTE: the string literal below is a bare expression statement, i.e. disabled
        # draft code that has no effect at runtime
        """"
        unaccounted = {}

        for o in self.ops:
            if not o['target'] in unaccounted:
                unaccounted[o['target']]=[]

            if 'left' in o:
                unaccounted[o['target']].append(o['left'])

            if 'right' in o:
                unaccounted[o['target']].append(o['right'])

            if 'args' in o:
                unaccounted[o['target']]+=o['args']



        constants = {}
        main_ops = {}

        for i, k in enumerate(self.sys_dict.keys()):
            if k not in unaccounted:
                constants[k] = (i, [])

        while len(unaccounted.keys())>0:
            unaccounted_copy = copy(unaccounted)
            for a, v in unaccounted_copy.items():

                #check if a is state_dot assignment
                if a in state_dot:
                    main_ops[a] = v
                    unaccounted.pop(a)

                for d in v[1]:
                    if d in unaccounted:
                        break

                    if d in state_derived_vars:
                        stated_derived_vars[a] = v
                        unaccounted.pop(a)
                        break

                constant_vars[a] = v
                unaccounted.pop(a)



        """




        #import sim_tools.utils.dependency_graph as dg

        #dg.make_dependency_graph(self.dependency_tree)

        # register every variable under its path (item.gp(name)) with its initial value
        for item in self.items.values():
            # NOTE(review): item.Variables is appended to while it is being iterated, so the
            # '<name>_dot' entries added below are visited by this same loop as well (they get
            # a 'path' and their copied 'val' is written to sys_dict) - confirm this is intended
            for v in item.Variables:
                v['path']=item.gp(v['name'])

                self.sys_dict[v['path']]=v['val']
                if v['type'] == 'state':
                    # each state gets a matching '<path>_dot' entry and a derived variable
                    self.states+=[v['path']]
                    self.states_dot+=[v['path']+'_dot']
                    self.sys_dict[v['path']+'_dot']=0
                    v_dot=copy.copy(v)
                    v_dot['name']=v['name']+'_dot'
                    v_dot['type']='diff_state'
                    v_dot['unit']='('+v['unit']+')/s'
                    item.Variables.append(v_dot)
                if v['type'] == 'reservoir':
                    self.sys_dict[v['path']+'_dot']=0
            self.variables+=item.Variables
        #print(self.variables)
        self.enc.make_data_bank()

        # position of every key in the flat system array
        self.sys_ix={k: i for i, k in enumerate(self.sys_dict.keys())}

        self.states_ix = np.array([self.sys_ix[s] for s in self.states], dtype=np.int32)
        self.states_dot_ix = np.array([self.sys_ix[s] for s in self.states_dot], dtype=np.int32)

        self.enc_ops=[]

        # encode each op as a row of 6 integers; the slot layout comes from self.enc.pos
        for i,a in enumerate(self.ops):
            try:
                #if i < 10:
                #    print(a)
                self.enc_ops_i=[0]*6
                self.enc_ops_i[self.enc.pos['target']]=self.sys_ix[a['target']]
                # ops that write to a state derivative get the 'increment' slot set to 1
                if self.sys_ix[a['target']] in self.states_dot_ix:
                    self.enc_ops_i[self.enc.pos['increment']]=1
                if 'op' in a:
                    self.enc_ops_i[self.enc.pos['func']]=self.enc.op_codes[a['op']]

                # 'func' is written to the same slot as 'op' and therefore wins if both exist
                if 'func' in a:
                    self.enc_ops_i[self.enc.pos['func']]=self.enc.op_codes[a['func']]
                    if a['func'] == 'print':
                        self.print_list+=[a]


                if 'left' in a:

                    self.enc_ops_i[self.enc.pos['left']]=self.sys_ix[a['left']]

                if 'right' in a:
                    self.enc_ops_i[self.enc.pos['right']]=self.sys_ix[a['right']]

                # 'comparators' shares the 'right' slot
                if 'comparators' in a:
                    self.enc_ops_i[self.enc.pos['right']]=self.sys_ix[a['comparators']]

                # string args are variable names (stored as their index); anything else is stored as-is
                if 'args' in a:
                    for j,l in enumerate(a['args']):
                        if isinstance(l,str):
                            self.enc_ops_i[self.enc.pos['arg'+str(j)]]=self.sys_ix[l]
                        else:
                            self.enc_ops_i[self.enc.pos['arg'+str(j)]]=l
                self.enc_ops+=[self.enc_ops_i]
            except:
                # show the op that could not be encoded, then propagate the original error
                print(a)
                print('!!!!')
                raise

        # convert to correct datatypes
        self.enc.data_def = np.array(self.enc.data_def, dtype=np.int32)
        self.enc.data_bank = np.array(self.enc.data_bank, dtype=np.float64)
        self.enc_ops = np.array(self.enc_ops, dtype=np.int32)

        # index map without temporaries: any key containing the substring 'tmp' is left out
        self.sys_ix_no_tmp = collections.OrderedDict({})
        for i, k in enumerate(self.sys_dict.keys()):
            if k.find('tmp') < 0:
                self.sys_ix_no_tmp[k] = i

        #self.sys_arr = [0] * len(self.sys_ix)
        #for k, i in self.sys_ix.items():
        #    self.sys_arr[i] = self.sys_dict[k]
        #print(self.sys_dict.values())
        # flat float array with the same ordering as sys_dict / sys_ix
        self.sys_arr = np.array(list(self.sys_dict.values()),dtype=np.float64)

        #Construction of sys_dict_high_level
        self.sys_dict_high_level= collections.OrderedDict()
        for k,v in self.sys_ix_no_tmp.items():
            self.sys_dict_high_level[k] = self.sys_arr[v]

        # called while the high-level dict is still a plain OrderedDict
        # (presumably so the system can add its own output keys - TODO confirm)
        self.system.output(self.sys_dict_high_level)

        self.common_keys = set(self.sys_dict.keys()).intersection(set(self.sys_dict_high_level.keys()))
        self.union_keys = set(self.sys_dict.keys()).union(set(self.sys_dict_high_level.keys()))

        #fix the dicts
        self.sys_dict=FixedDict(self.sys_dict)
        self.sys_dict_high_level = FixedDict(self.sys_dict_high_level)

        # resolve which high-level tags are persistent (see the docstring for the accepted forms)
        if persistent_tags == 'ALL':
            self.persistent_tags = list(self.sys_dict_high_level.keys())
        elif persistent_tags == 'NONE':
            self.persistent_tags = []
        elif isinstance(persistent_tags, dict):
            candidates = list(copy.copy(self.sys_dict_high_level.keys()))


            if 'level' in persistent_tags:
                level = persistent_tags['level']
                selected_candidates = []
                # keep only tags with at most level+1 dots in their name
                for tag in candidates:
                    if tag.count('.') <= level+1:
                        selected_candidates.append(tag)

                candidates = copy.copy(selected_candidates)
                #print('Len at level '+str(level)+': ', len(selected_candidates))
            #print(candidates)
            self.persistent_tags = candidates

        else:
            #print('N persistent tags requested: ', len(persistent_tags))
            self.persistent_tags = list(self.sys_dict_high_level.allowed_keys.intersection(persistent_tags))


        # 't' is always persistent; de-duplicate and sort
        self.persistent_tags.append('t')
        self.persistent_tags = list(set(self.persistent_tags))

        #print('N persistent tags: ', len(self.persistent_tags))
        self.persistent_tags.sort()

        self.indices_persistent_tags = self.sys_dict_high_level.get_indices(self.persistent_tags, ignore_missing=True)

        # snapshot of the low-level dict, used by reset()
        self.init_dict = copy.deepcopy(self.sys_dict)

        #self.structure=self.get_structure()

    def update_sys_dict_from_high_level(self):
        """Copy the values of all shared keys from the high-level dict into ``sys_dict``."""
        source = self.sys_dict_high_level
        shared_keys = self.common_keys
        self.sys_dict.update_keys(source, shared_keys)

    def update_high_level_from_sys_dict(self):
        """Copy the values of all shared keys from ``sys_dict`` into the high-level dict."""
        source = self.sys_dict
        shared_keys = self.common_keys
        self.sys_dict_high_level.update_keys(source, shared_keys)
    
    def apply_init_state(self,init_dict, allow_new=False):#, allow_new=True):
        if not False:
            try:
                assert set(init_dict.keys()).issubset(self.union_keys), 'Found keys that is neither in low nor high level sys dict!'
            except AssertionError:
                print('Missing keys starting')
                print('Init has number of keys: ',len(init_dict.keys()))
                print('Model has number of keys: ', len(self.union_keys))
                print('Number difs: ',len(set(init_dict.keys()).difference(self.union_keys)))
                print('Number difs2: ', len(self.union_keys.difference(set(init_dict.keys()))))
                dif_list=list(set(init_dict.keys()).difference(self.union_keys.intersection(set(init_dict.keys()))))
                dif_list.sort()
                dif_list2 = list(self.union_keys.difference(self.union_keys.intersection(set(init_dict.keys()))))
                dif_list2.sort()
                print('keys in init not in union',dif_list)
                print('keys in union not in init', dif_list2)
                #print('keys in union not in init',self.union_keys).difference(self.union_keys.intersection(set(init_dict.keys())))
                print('Missing keys ending')
                raise
        self.sys_dict.update_intersection(init_dict)
        self.sys_dict_high_level.update_intersection(init_dict)
        self.sys_arr = np.array(self.sys_dict.values(), dtype=np.float64)

        self.t = self.sys_dict_high_level['t']

        #for k,v in init_dict.items():
        #   self.sys_dict[k]=v
        
        #if allow_new:
        #    self.sys_ix={k: i for i, k in enumerate(self.sys_dict.keys())}
        
        #for k,v in init_dict.items():
        #    self.sys_arr[self.sys_ix[k]]=v

    #def update(self, dict):
    #    self.apply_init_state(dict, allow_new=False)

    def get_structure(self):
        """Return the result of ``Model_Structure.make_Items_Graph`` for this model."""
        graph = Model_Structure.make_Items_Graph(
            self.system,
            self.ops,
            self.sys_dict_high_level,
        )
        return graph

    def step(self,dt,dt_fix=0.05):
        """Advance the model from ``self.t`` to ``self.t + dt`` with the solver.

        Calls ``Model_Solver.solve_BE2`` on ``self.sys_arr`` (plus a deep copy
        of it), stores the result in ``self.y_last`` and increases ``self.t``
        by ``dt``.

        Args:
            dt: length of the interval to solve.
            dt_fix: step-size argument forwarded to the solver.

        Raises:
            ValueError: if ``self.error_op[0]`` is non-zero after the solve.
                Note that ``self.t`` has already been advanced at that point.
        """
        from copy import deepcopy

        y0 = Model_Solver.state_vals(self.sys_arr, self.states_ix)
        self.y_last = Model_Solver.solve_BE2(
            self.t, self.t + dt, dt_fix, y0,
            self.sys_arr, deepcopy(self.sys_arr),
            self.states_ix, self.states_dot_ix,
            self.enc_ops, self.error_op,
            self.enc.data_def, self.enc.data_bank)

        self.t += dt

        if self.error_op[0] != 0:
            # error_op[5..7] index sys_arr and error_op[1] indexes self.ops
            print(self.error_op)
            print('Arg0:', self.sys_arr[self.error_op[5]])
            print('Arg1:', self.sys_arr[self.error_op[6]])
            print('Arg2:', self.sys_arr[self.error_op[7]])

            flagged_op = self.ops[self.error_op[1]]
            print(flagged_op)
            # Carry the diagnostics in the exception instead of an empty message.
            raise ValueError(
                'Solver flagged an error: error_op={} (op: {})'.format(
                    list(self.error_op), flagged_op))


    def step_update(self,dt,dt_fix=0.05, history=False):
        """Run one ``step`` and sync the dicts with the new ``sys_arr``.

        Args:
            dt: interval passed to ``step``.
            dt_fix: step-size argument passed to ``step``.
            history: when true, also record the new values via ``update_history``.
        """
        self.step(dt, dt_fix)

        # push the solver result into sys_dict / sys_dict_high_level
        self.update(self.sys_arr)

        if not history:
            return
        self.update_history()

    def step_update_output(self,dt, dt_out, dt_fix=0.05, history=False, output=None):
        """Advance the model by ``dt`` and emit output every ``dt_out``.

        The interval is split so that a step always ends at ``self.next_out``.
        When that time is reached, the dicts are synced via ``update``, the
        next output time is scheduled, and the optional history / ``output``
        callback are served.  Wall-clock time spent inside and outside of
        ``step`` is accumulated in ``step_time`` / ``non_step_time``.

        Args:
            dt: total interval to advance.
            dt_out: spacing between outputs; must be positive when ``dt > 0``.
            dt_fix: step-size argument passed to ``step``.
            history: when true, record values via ``update_history`` at each output.
            output: optional callable receiving ``{tag: value}`` for the
                persistent tags at each output.

        Raises:
            ValueError: if ``dt > 0`` and ``dt_out`` is not positive (the loop
                below could never make progress).
        """
        if dt > 0 and not dt_out > 0:
            raise ValueError('dt_out must be positive, got {!r}'.format(dt_out))

        if self.toc == 0:
            self.toc = time()
        dt_rem = dt

        while dt_rem > 0:
            dt_step = min(dt_rem, self.next_out - self.t)
            # If the model is already past the scheduled output time, do a
            # zero-length step instead of stepping backwards in time.
            if dt_step < 0:
                dt_step = 0

            dt_rem -= dt_step
            tic = time()
            self.step(dt_step, dt_fix)
            toc = time()
            self.step_time += toc - tic
            self.non_step_time += tic - self.toc
            self.toc = toc

            if self.t >= self.next_out:
                self.update(self.sys_arr)
                self.last_out = self.t
                self.next_out = self.t + dt_out

                if history:
                    self.update_history()

                if output is not None:
                    output({pt: self.sys_dict_high_level[pt] for pt in self.persistent_tags})

    def solve_t_out(self, t_out, dt_fix=0.05, history=False):
        """Step the model to each time in ``t_out``, in the given order.

        Times that lie in the past are skipped.  A time equal to the current
        ``self.t`` only records history (when ``history`` is true).

        Args:
            t_out: iterable of target times.
            dt_fix: upper bound for the step-size argument; it is capped at
                the distance to the next target time.
            history: forwarded to ``step_update``.
        """
        for target_time in t_out:
            remaining = target_time - self.t
            if remaining > 0:
                capped_fix = min(dt_fix, remaining)
                self.step_update(remaining, dt_fix=capped_fix, history=history)
                continue
            if remaining == 0 and history:
                self.update_history()

    def clear_history(self):
        """Replace ``self.history`` with an empty list for every high-level key."""
        fresh = {}
        for tag in self.sys_dict_high_level.keys():
            fresh[tag] = []
        self.history = fresh


    def reset(self):
        """Return the model to the state captured in ``self.init_dict``.

        Clears the history, re-applies the initial values and sets the time
        back to 0.  The output schedule used by ``step_update_output``
        (``last_out`` / ``next_out``) is reset as well; otherwise a run after
        ``reset`` would produce no output until the old ``next_out`` is reached.
        """
        self.clear_history()
        self.apply_init_state(self.init_dict)
        self.t = 0
        # same start values as in __init__
        self.last_out = 0
        self.next_out = 0

    def update_history(self):
        """Append the current value of every high-level key to ``self.history``.

        The history dict is created on first use.
        """
        try:
            self.history
        except AttributeError:
            self.history = dict((tag, []) for tag in self.sys_dict_high_level.keys())

        for tag, value in self.sys_dict_high_level.items():
            self.history[tag].append(value)




    def set_stop_tag(self, tag):
        """Remember ``tag`` as the stop tag and store its index as ``stop_ix``.

        Raises ``KeyError`` if ``tag`` is not in ``self.sys_ix``.
        """
        # stop_tag is assigned first, so it is set even if the lookup fails
        self.stop_tag = tag
        index_of_tag = self.sys_ix[tag]
        self.stop_ix = index_of_tag

    def to_dict(self):
        return {
            'states_ix': self.states_ix, 'states_dot_ix': self.states_dot_ix,
            'enc_ops': self.enc_ops,
            'data_def': self.enc.data_def, 'data_bank': self.enc.data_bank,
            'sys_arr': self.sys_arr,
            'sys_dict': self.sys_dict, 'sys_ix': self.sys_ix, 'ops': self.ops,
            'stop_ix': self.stop_ix
        }


    def update(self, updated_sys_arr):

        #first update sys_dict
        self.sys_dict.update_from_array(updated_sys_arr)
        self.sys_dict_high_level.update_keys(self.sys_dict, self.common_keys)
        # run output function
        self.system.output(self.sys_dict_high_level)

        self.sys_dict.update_keys(self.sys_dict_high_level, self.common_keys)

        self.sys_arr=np.array(self.sys_dict.values(), dtype=np.float64)

    def get_combined(self):
        """Return ``(values, keys)`` of the combined low/high-level dict.

        Both lists follow the order of ``get_combined_dict``.
        """
        fixed = FixedDict(self.get_combined_dict())
        combined_values = fixed.values()
        combined_keys = fixed.get_indices_list()
        return combined_values, combined_keys

    def get_combined_dict(self):
        combined_dict = {}
        combined_dict.update(self.sys_dict.dict())
        combined_dict.update(self.sys_dict_high_level.dict())

        return combined_dict

