
import ast
import functools
import numpy as np
# ---------------------------------------------------------------------------
# Operation codes.  Encoder.__init__ builds a list indexed by these values
# holding each operation's name, and inverts it into Encoder.op_codes
# (name -> code).  ERR (0) is the only slot left with the empty name ''.
# ---------------------------------------------------------------------------
ERR = 0
ADD = 1
SUB = 2
MULT = 3
DIV = 4
POW = 5
SQRT = 6
ABS = 7
SIGN = 8
INTERP_BILIN = 9
INTERP_LIN = 10
NEG = 11
EQ = 12
GT = 13
LT = 14
GE = 15
LE = 16
EXP = 17
LOG = 18
CALC_POLY2X3 = 19
ASS = 20
PRINT = 21
MAX3 = 22
MIN3 = 23
MAX2 = 24
MIN2 = 25
SWITCH = 26
MOD = 27
SIN = 28
RAND = 29
ASSERT = 30
IFEXP = 31
NOZDIV = 32
POSITIVE = 33
NEGATIVE = 34
INRANGE = 35
MULT_SIGN = 36
INTERP_NN = 37
ARCTAN = 38
MIX = 39  # highest op code; sizes the op-name table in Encoder.__init__

# ---------------------------------------------------------------------------
# Field positions, exposed by name through Encoder.pos.  Presumably these
# are slot indices in an encoded instruction (TODO confirm against the
# consumer).  LEFT/ARG0 and RIGHT/ARG1 deliberately share the same slots.
# ---------------------------------------------------------------------------
FUNC = 0
TARGET = 1
INC = 2
LEFT = 3
RIGHT = 4
ARG0 = 3
ARG1 = 4
ARG2 = 5

# Unique marker used by Encoder.rgetattr to detect "no default supplied".
sentinel = object()


class Encoder:
    """Parse Python assignment statements into flat operation records.

    ``parse_Assign`` turns one assignment node into a list of op dicts with
    keys such as ``source``, ``line``, ``target``, ``op``/``func``,
    ``left``, ``right``, ``comparators``, ``args`` and ``constant``.
    Nested sub-expressions are hoisted into temporaries named
    ``<tmp>_<k>``, so every operand is a plain (possibly dotted) name.
    ``preprocess``/``lookup_varname`` then resolve those names against a
    caller object.  Data objects passed as the first argument of the
    functions in ``_DATA_FUNCS`` are collected in ``self.data`` and can be
    packed into one array with ``make_data_bank``.
    """

    # Functions whose first argument refers to a data object rather than a
    # variable.  Both add_data and preprocess replace that argument by an
    # index into self.data.
    _DATA_FUNCS = ('interp_bilin', 'interp_lin', 'calc_poly2x3', 'interp_nn')

    # Binary operators accepted by parse_Assign -> op names in op_codes.
    _BINOP_NAMES = {
        ast.Sub: 'subtract',
        ast.Mult: 'multiply',
        ast.Add: 'add',
        ast.Div: 'divide',
        ast.Pow: 'power',
        ast.Mod: 'modulus',
    }

    # Comparison operators accepted by parse_Assign -> op names in op_codes.
    _CMPOP_NAMES = {
        ast.Lt: 'less',
        ast.LtE: 'less_equal',
        ast.Gt: 'greater',
        ast.GtE: 'greater_equal',
        ast.Eq: 'equal',
    }

    def __init__(self):
        # Data objects referenced by data functions.  data_obs is kept
        # index-aligned with data: it holds id() of each object (preprocess)
        # or a placeholder 0 (add_data), so the same object is stored once.
        self.data = []
        self.data_obs = []
        self.pos = {'target': TARGET, 'increment': INC, 'func': FUNC,
                    'left': LEFT, 'right': RIGHT,
                    'arg0': ARG0, 'arg1': ARG1, 'arg2': ARG2}

        op_names = {
            ADD: 'add', SUB: 'subtract', MULT: 'multiply', DIV: 'divide',
            POW: 'power', SQRT: 'sqrt', ABS: 'abs', SIGN: 'sign',
            INTERP_BILIN: 'interp_bilin', INTERP_LIN: 'interp_lin',
            NEG: 'negate', EQ: 'equal', GT: 'greater', LT: 'less',
            GE: 'greater_equal', LE: 'less_equal', EXP: 'exp', LOG: 'ln',
            CALC_POLY2X3: 'calc_poly2x3', ASS: 'assign', PRINT: 'print',
            MAX2: 'max2', MIN2: 'min2', MAX3: 'max', MIN3: 'min',
            SWITCH: 'switch', MOD: 'modulus', SIN: 'sin', RAND: 'rand',
            ASSERT: 'assertion', IFEXP: 'ifexp', NOZDIV: 'no_zero_div',
            POSITIVE: 'positive', NEGATIVE: 'negative', INRANGE: 'inrange',
            MULT_SIGN: 'mult_sign', INTERP_NN: 'interp_nn',
            ARCTAN: 'arctan', MIX: 'mix',
        }
        op_codes_list = [''] * (MIX + 1)
        for code, name in op_names.items():
            op_codes_list[code] = name
        # Name -> numeric op code.  The only unnamed slot (ERR) maps from ''.
        self.op_codes = {o: i for i, o in enumerate(op_codes_list)}
        self.sentinel = object()

    def recurse_Attribute(self, attr):
        """Return the dotted name of an ``ast.Attribute`` chain, e.g. ``'self.a.b'``.

        Raises:
            NameError: the chain is not rooted at a plain name (for example
                ``f(x).a`` or ``a[0].b``).  Previously this silently
                returned None.
        """
        base = attr.value
        if isinstance(base, ast.Name):
            return base.id + '.' + attr.attr
        if isinstance(base, ast.Attribute):
            return self.recurse_Attribute(base) + '.' + attr.attr
        raise NameError('Cannot parse attribute access on: ' + ast.dump(base))

    @staticmethod
    def _numeric_constant(node):
        """Return ``(True, value)`` if *node* is a numeric literal, else ``(False, None)``.

        Uses ``ast.Constant``, because ``ast.Num`` is deprecated since
        Python 3.8 and removed in 3.14.  Booleans are excluded, as the old
        ``ast.Num`` check did.
        """
        if isinstance(node, ast.Constant):
            value = node.value
            if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                return True, value
        return False, None

    def parse_Assign(self, assign, tmp, assign_no, source):
        """Convert one assignment node into a list of op dicts.

        Args:
            assign: an ``ast.Assign`` (``targets[0]`` and ``value`` are read).
            tmp: prefix for temporaries created for nested sub-expressions
                (``<tmp>_0``, ``<tmp>_1``, ...).
            assign_no: value stored under ``'line'`` in the op.
            source: value stored under ``'source'`` in the op.

        Returns:
            List of op dicts: ops for hoisted sub-expressions first, then
            the op for this assignment.

        Raises:
            NameError: unsupported expression, operator or operand.
            AttributeError: a call whose callee is not a plain name.
        """
        a = {'source': source, 'line': assign_no}
        target = assign.targets[0]
        if isinstance(target, ast.Name):
            a['target'] = target.id
        elif isinstance(target, ast.Attribute):
            a['target'] = self.recurse_Attribute(target)
        # NOTE(review): other target kinds (tuple, subscript, ...) still give
        # an op without a 'target' key, as before -- confirm this is intended.

        ops = []
        tmp_counter = 0

        def operand(node):
            # Parse a sub-expression and collect its ops.  Return the name
            # holding its value (a new temporary for compound expressions).
            nonlocal tmp_counter
            sub_ops, tmp_counter, var = self.recursive_parse(
                node, tmp, tmp_counter, assign_no, source)
            ops.extend(sub_ops)
            return var

        value = assign.value
        is_number, number = self._numeric_constant(value)
        if is_number:
            a['constant'] = np.float64(number)

        elif isinstance(value, ast.UnaryOp):
            if not isinstance(value.op, ast.USub):
                print(ast.dump(assign))
                raise NameError('Unknown operator type: ' + str(value.op))
            a['op'] = 'negate'
            a['left'] = operand(value.operand)

        elif isinstance(value, ast.BinOp):
            op_name = self._BINOP_NAMES.get(type(value.op))
            if op_name is None:
                raise NameError('Unknown operator type: ' + str(value.op) + ' in ' + str(assign))
            a['op'] = op_name
            # The right operand is parsed first.  This order fixes how
            # temporaries are numbered and the order of the emitted ops.
            a['right'] = operand(value.right)
            a['left'] = operand(value.left)

        elif isinstance(value, ast.Compare):
            if len(value.ops) != 1:
                # Previously `a < b < c` was silently truncated to `a < b`.
                raise NameError('Chained comparisons are not supported: ' + ast.dump(value))
            op_name = self._CMPOP_NAMES.get(type(value.ops[0]))
            if op_name is None:
                # Previously e.g. `!=` produced an op without an 'op' key.
                raise NameError('Unknown operator type: ' + str(value.ops[0]) + ' in ' + str(assign))
            a['op'] = op_name
            a['comparators'] = operand(value.comparators[0])
            a['left'] = operand(value.left)

        elif isinstance(value, ast.Call):
            # Only the callee check is guarded, so errors raised while
            # parsing the arguments are no longer relabelled.
            if not isinstance(value.func, ast.Name):
                raise AttributeError('Cannot parse functions calls with attibute in name.\n' + ast.dump(value))
            a['func'] = value.func.id
            # NOTE(review): keyword arguments are ignored, as before.
            a['args'] = [operand(arg) for arg in value.args]

        elif isinstance(value, ast.Attribute):
            a['left'] = self.recurse_Attribute(value)
            a['op'] = 'assign'

        elif isinstance(value, ast.Name):
            a['left'] = value.id
            a['op'] = 'assign'

        elif isinstance(value, ast.IfExp):
            a['func'] = 'ifexp'
            # Parsed in the order test, body, orelse.
            a['args'] = [operand(value.test), operand(value.body), operand(value.orelse)]

        else:
            raise NameError('Unknown function: ' + str(value))

        ops.append(a)
        return ops

    def recursive_parse(self, arg, tmp, tmp_counter, assign_no, source):
        """Parse one operand of an expression.

        Compound expressions and numeric literals are assigned to a new
        temporary ``<tmp>_<tmp_counter>`` and parsed recursively.  Names and
        attribute chains are returned as-is.

        Returns:
            ``(ops, tmp_counter, var)``: ops needed to compute the operand,
            the updated counter, and the name holding the operand's value.

        Raises:
            NameError: the operand kind is not supported.
        """
        sub_ops = []
        compound = (ast.Call, ast.BinOp, ast.UnaryOp, ast.Compare, ast.IfExp)
        if isinstance(arg, compound) or self._numeric_constant(arg)[0]:
            var = tmp + '_' + str(tmp_counter)
            tmp_counter += 1
            # Wrap the sub-expression in a synthetic assignment to the
            # temporary, instead of mutating the caller's AST node.
            synthetic = ast.Assign(targets=[ast.Name(id=var, ctx=ast.Store())], value=arg)
            sub_ops = self.parse_Assign(synthetic, var, assign_no, source)
        elif isinstance(arg, ast.Attribute):
            var = self.recurse_Attribute(arg)
        elif isinstance(arg, ast.Name):
            var = arg.id
        else:
            raise NameError('Unknown object in ' + source + ' line: ' + str(assign_no)
                            + ' - parsing of this type of object not implemented: ' + str(arg))
        return sub_ops, tmp_counter, var

    def rgetattr(self, obj, attr, default=sentinel):
        """Resolve a dotted attribute path such as ``'a.b.c'`` on *obj*.

        If *default* is given, it is returned for the first missing
        attribute, and the remaining path components are looked up on
        *default*.
        """
        if default is sentinel:
            _getattr = getattr
        else:
            def _getattr(obj, name):
                return getattr(obj, name, default)
        return functools.reduce(_getattr, [obj] + attr.split('.'))

    def add_data(self, assign):
        """Register the data argument of a data-function op, and fix up max/min.

        For functions in ``_DATA_FUNCS``, ``args[0]`` is appended to
        ``self.data`` (with a placeholder 0 in ``self.data_obs``) and
        replaced in place by its index.  A two-argument ``max``/``min`` is
        renamed to ``max2``/``min2``.  With three arguments the name is
        kept, because ``'max'``/``'min'`` is the op_codes name for
        MAX3/MIN3.  Ops without a ``'func'`` key are left untouched.

        Raises:
            ValueError: ``max``/``min`` called with other than 2 or 3 arguments.
        """
        func = assign.get('func')
        if func in self._DATA_FUNCS:
            self.data.append(assign['args'][0])
            self.data_obs.append(0)
            assign['args'][0] = len(self.data_obs) - 1
        elif func in ('max', 'min'):
            n_args = len(assign['args'])
            if n_args == 2:
                assign['func'] = func + '2'
            elif n_args != 3:
                raise ValueError(f'{func} needs 2 or 3 arguments not: {n_args}')

    def _register_data(self, caller, path):
        """Resolve dotted *path* (ignoring ``self`` parts) on *caller*.

        The resolved object is stored once in ``self.data``, keyed by
        ``id()``.  Returns its index.
        """
        parts = path.split('.')
        attrs = '.'.join(p for p in parts if p != 'self')
        if attrs:
            ob = self.rgetattr(caller, attrs)
        else:
            ob = getattr(caller, parts[-1])
        did = id(ob)
        if did not in self.data_obs:
            self.data_obs.append(did)
            self.data.append(ob)
        return self.data_obs.index(did)

    def preprocess(self, caller, assign):
        """Resolve, in place, the names in op dict *assign* against *caller*.

        For functions in ``_DATA_FUNCS``, ``args[0]`` becomes an index into
        ``self.data`` and the remaining args go through ``lookup_varname``.
        For other functions, every arg goes through ``lookup_varname``.
        The ``target``, ``left``, ``right`` and ``comparators`` entries,
        when present, are resolved the same way.
        """
        if 'func' in assign:
            args = assign['args']
            if assign['func'] in self._DATA_FUNCS:
                args[0] = self._register_data(caller, args[0])
                args[1:] = [self.lookup_varname(caller, arg) for arg in args[1:]]
            else:
                assign['args'] = [self.lookup_varname(caller, arg) for arg in args]

        for tag in ('target', 'left', 'right', 'comparators'):
            if tag in assign:
                assign[tag] = self.lookup_varname(caller, assign[tag])

    def lookup_varname(self, caller, var):
        """Resolve variable name *var* through ``gp()`` on the owning object.

        Names starting with ``'Stub'`` or ``'Input'`` are returned
        unchanged; they are assumed to be absolute paths.  Otherwise the
        ``self`` components of the dotted prefix are dropped, the rest of
        the prefix is resolved as an attribute path on *caller*, and the
        last component is passed to that object's ``gp()``.  (Presumably
        ``gp`` means "get parameter"; TODO confirm.)

        Raises:
            ValueError: *caller* itself has no ``gp`` (single-level name).
        """
        if var.startswith(('Stub', 'Input')):
            return var

        parts = var.split('.')
        # Joining only non-'self' parts avoids the old leading-dot bug
        # ('self.a.x' used to give '.a').
        attrs = '.'.join(p for p in parts[:-1] if p != 'self')
        if attrs:
            try:
                return self.rgetattr(caller, attrs).gp(parts[-1])
            except Exception:
                print(caller, ' ', var)
                raise
        try:
            return caller.gp(parts[-1])
        except AttributeError as ae:
            raise ValueError(f'The caller for {var} does not exist') from ae

    def make_data_bank(self):
        """Concatenate ``self.data`` along the first axis into ``self.data_bank``.

        ``self.data_def[k]`` is the start offset of ``self.data[k]`` inside
        the bank; each entry advances the offset by ``len(entry)``.
        """
        self.data_def = []
        cum_ix = 0
        for d in self.data:
            self.data_def.append(cum_ix)
            cum_ix += len(d)
        if self.data:
            self.data_bank = np.concatenate(self.data)
        else:
            # NOTE(review): the empty case stays a plain list rather than an
            # empty ndarray, to keep the previous behaviour.
            self.data_bank = []

