import datetime
import os
import uuid
from copy import deepcopy
from decimal import Decimal

from mongoengine import *

# class model_assemble(Document):
#     name
#     filename
#     author
#     version
#     verified



    
class User(Document):
    """A user account; referenced as owner, collaborator or uploader elsewhere."""
    email = StringField(required=False)
    first_name = StringField(max_length=50)
    last_name = StringField(max_length=50)
    login = StringField(max_length=50,required=True)

    def __str__(self):
        """Return the user's first name, or '' when it is not set.

        ``first_name`` is optional; ``__str__`` must return a ``str``, so a
        missing name is mapped to '' instead of returning None (which would
        make ``str(user)`` raise ``TypeError``).
        """
        if not self:
            return ''
        return self.first_name or ''

class Config(Document):
    """A named configuration dictionary with description, tags and upload metadata."""
    name = StringField(required=True, unique=True)
    description = StringField()
    comment = StringField()
    configuration = DictField()
    tags = ListField(StringField())
    # Was misspelled ``defalt=``, which mongoengine accepted silently and ignored,
    # so uploaded_date never received its default timestamp.
    uploaded_date = DateTimeField(default=datetime.datetime.now)
    uploaded_by = ReferenceField(User)

    def __str__(self):
        """Return the config name, or '' when it is not set yet.

        ``name`` is required only at validation time, so an unsaved instance
        can have ``name`` None; ``__str__`` must still return a ``str``.
        """
        if not self:
            return ''
        return self.name or ''

class Inputs(EmbeddedDocument):
    """Base class for the input sets embedded in groups and scenarios.

    ``allow_inheritance`` is enabled so subclasses (the concrete input sets
    below) can be stored in the same embedded lists.
    """
    name = StringField(required=True)
    meta = {'allow_inheritance': True}
    # Names listed here are "hidden" (consumer of this list not visible here).
    # A callable default gives every instance its own fresh list.
    _hidden_fields = ListField(StringField(), default=lambda: ['_cls', '_hidden_fields', '_override_fields', 'name', 'uuid'])
    _override_fields = ListField(StringField())
    uuid = UUIDField(default=lambda: uuid.uuid4(), primary_key=True)

    def clone(self):
        """Return a deep copy of this input set that gets a freshly generated ``uuid``."""
        self_copy = {k: deepcopy(getattr(self, k)) for k in self._fields.keys()}
        # Drop the identifier so the copy draws a new one from the field default.
        self_copy.pop('uuid')
        return self.__class__(**self_copy)

    def hide_fields(self, fields_to_hide):
        """Add ``fields_to_hide`` to ``_hidden_fields`` without duplicates; return ``self``.

        Args:
            fields_to_hide: a single field name, or any iterable of field names.

        Existing entries keep their position and new names are appended in the
        order given, so the resulting list is deterministic.
        """
        if isinstance(fields_to_hide, str):
            fields_to_hide = [fields_to_hide]
        self._hidden_fields = list(dict.fromkeys([*self._hidden_fields, *fields_to_hide]))
        return self
     
    
class Task_Inputs(Inputs):
    """Input set carrying a ``sim_module`` name and a ``stub`` flag (default True)."""
    # NOTE(review): how sim_module and stub are consumed is not visible in this file.
    sim_module = StringField()
    stub = BooleanField(default=True)
    
    
class Generic_Inputs(Inputs):
    """Input set holding an arbitrary dictionary of inputs."""
    # Callable default: each instance gets its own empty dict rather than a
    # single shared literal.
    Input_Dict = DictField(default=dict)

class EMA_Inputs(Inputs):
    """Input set for the EMA model, linked to a ``Config`` via the ``EMA`` reference.

    NOTE(review): the ``*_T_In`` / ``*_T_Out`` fields look like inlet/outlet
    temperatures; their units are not stated here — TODO confirm.
    """
    Evaporator_T_In = FloatField(required=True)
    Evaporator_T_Out = FloatField(required=True)
    Condenser_T_In = FloatField(required=True)
    Condenser_T_Out = FloatField(required=True)
    Subcooler_T_In = FloatField(required=True)
    Subcooler_T_Out = FloatField(required=True)
    Compressors_Running = IntField(default=4)
    EMA = ReferenceField(Config)

class EMB_Inputs(Inputs):
    """Input set for the EMB model, linked to a ``Config`` via the ``EMB`` reference.

    NOTE(review): ``*_T_In`` / ``*_Flow`` look like inlet temperatures and flow
    rates; units are not stated here — TODO confirm.
    """
    Evaporator1_T_In = FloatField(required=True)
    Evaporator1_Flow = FloatField(required=True)
    Evaporator2_T_In = FloatField(required=True)
    Evaporator2_Flow = FloatField(required=True)
    Condenser1_T_In = FloatField(required=True)
    Condenser1_Flow = FloatField(required=True)
    Condenser2_T_In = FloatField(required=True)
    Condenser2_Flow = FloatField(required=True)
    Subcooler_T_In = FloatField(required=True)
    Subcooler_Flow = FloatField(required=True)
    # Boolean switches, all enabled by default.
    Flow = BooleanField(default=True)
    Compressor1_Running = BooleanField(default=True)
    Compressor2_Running = BooleanField(default=True)
    EMB = ReferenceField(Config)

class EM_Static(Inputs):
    """Input set of required system temperatures and return temperature differences.

    NOTE(review): units are not stated here — TODO confirm.
    """
    Chilled_System_Temperature = FloatField(required=True)
    Chilled_System_Return_dT = FloatField(required=True)
    Heating_System_Temperature = FloatField(required=True)
    Heating_System_Return_dT = FloatField(required=True)
    Hot_Water_Temperature = FloatField(required=True)
    Hot_Water_Return_dT = FloatField(required=True)

class EM_Static_COP(Inputs):
    """Input set of system temperatures, heating power and hot-water flow.

    It is linked to a ``Config`` via the ``EM`` reference.
    ``Rating`` is a boolean flag defaulting to False; its meaning is not visible
    in this file. NOTE(review): units are not stated here — TODO confirm.
    """
    Chilled_System_Temperature = FloatField(required=True)
    Chilled_System_Return_dT = FloatField(required=True)
    Heating_System_Temperature = FloatField(required=True)
    Heating_System_Return_dT = FloatField(required=True)
    Hot_Water_Temperature = FloatField(required=True)
    Cold_Water_Temperature = FloatField(required=True)
    Heating_System_Power = FloatField(required=True)
    Hot_Water_Flow = FloatField(required=True)
    EM = ReferenceField(Config)
    Rating = BooleanField(default=False)

class Interp_Generic(Inputs):
    """Input set holding two free-form dictionaries: interpolation data and static data."""
    # Callable defaults: each instance gets its own empty dicts rather than
    # sharing the same literal objects.
    Interp_Data = DictField(default=dict)
    Static_Data = DictField(default=dict)

class Solver_Inputs(Inputs):
    """Solver time window: start, end and step, defaulting to 0, 1000 and 60.

    NOTE(review): the time unit is not stated here (presumably seconds) — TODO confirm.
    """
    t_Start = FloatField(default=0,required=True)
    t_End = FloatField(default=1000,required=True)
    t_Step= FloatField(default=60,required=False)    
# class Input_Config(Input):
#     config=EmbeddedDocumentField(Config)
#     tags=ListField(StringField(max_length=20))
    
class Outputs(EmbeddedDocument):
    """A list of output field names; used for a scenario's ``outputs`` and ``results``."""
    fields = ListField(StringField())

    def clone(self):
        """Return a new instance of the same class with a copy of ``fields``.

        Uses ``self.__class__`` (like the other ``clone`` methods) so subclasses
        clone to themselves, and copies the list so the clone never shares the
        original's list object.
        """
        return self.__class__(fields=list(self.fields))
    
class Graph(EmbeddedDocument):
    """Graph description: chart type, x/y series names, per-series modes and titles."""
    type = StringField()
    x = StringField()
    y = ListField(StringField())
    y_modes = ListField(StringField())
    x_title = StringField()
    y_title = StringField()
    title = StringField()

    def clone(self):
        """Build a new instance of this class carrying the same value for every field."""
        cls = self.__class__
        field_values = {}
        for field_name in self._fields:
            field_values[field_name] = getattr(self, field_name)
        return cls(**field_values)

class Data_View_Field(EmbeddedDocument):
    """One entry of a data view: a field name plus format template, scaling and precision."""
    name = StringField(required=True)
    text_rep = StringField(default="{}: {}")
    scaling = FloatField(default=1)
    decimal_spaces = IntField(default=1)

    def clone(self):
        """Return a copy constructed from the current value of every declared field."""
        return type(self)(**{field_name: getattr(self, field_name) for field_name in self._fields})

class Data_View(EmbeddedDocument):
    """A titled list of ``Data_View_Field`` entries."""
    title=StringField()
    fields=EmbeddedDocumentListField(Data_View_Field)

    def clone(self):
        """Return a copy whose ``fields`` entries are cloned one by one."""
        scalar_values = {name: getattr(self, name) for name in self._fields if name != 'fields'}
        duplicate = type(self)(**scalar_values)
        # Append each cloned entry individually rather than passing the list
        # to the constructor.
        for view_field in self.fields:
            duplicate.fields.append(view_field.clone())
        return duplicate

class Model_Structure(Document):
    """Model metadata referenced by ``Raw_Data.model``."""
    # Free-form structure dict; Raw_Data.get_t_data reads a 'variables' key from it.
    structure = DictField()
    # Tag -> index into each stored row. Keys are stored with '.' replaced by
    # '+' (see Raw_Data.put_sys_ix / get_sys_arr).
    sys_ix = DictField()


class Raw_Data(Document):
    """Raw simulation output stored as rows split across chunk documents.

    Each ``put_sys_arr`` call appends one row ("block") to the current document
    of the ``model_sys_arr_chunks`` collection. Chunk documents carry:

    - ``parent_id``: this document's id
    - ``n``: the chunk number
    - ``first_block`` / ``last_block``: the range of block indices held
    - ``data``: the rows themselves

    Column positions inside a row are looked up through ``model.sys_ix``.
    """
    chunk_size = IntField(default=100000)
    # Was misspelled ``defalt=`` and therefore silently ignored by mongoengine.
    upload_date = DateTimeField(default=datetime.datetime.now)
    num_chunks = IntField(default=0)
    count = IntField(default=0)
    block_count = IntField(default=0)
    model = ReferenceField(Model_Structure, default=lambda: Model_Structure(structure={}, sys_ix={}).save())
    t_from = FloatField(default=0)
    t_to = FloatField(default=1)
    t_step = FloatField(default=1)
    t_solved = FloatField(default=0)

    def put_model_structure(self, model_struct):
        """Store ``model_struct`` on the linked Model_Structure and save it."""
        self.model.structure = model_struct
        self.model.save()

    def get_model_structure(self):
        """Return the structure dict of the linked Model_Structure."""
        return self.model.structure

    def put_sys_ix(self, sys_ix):
        """Store the tag -> row-index mapping on the linked Model_Structure.

        Dots in tag names are stored as '+' (MongoDB keys may not contain '.');
        readers convert them back.
        """
        self.model.sys_ix = {k.replace('.', '+'): v for k, v in sys_ix.items()}
        self.model.save()

    def put_sys_arr(self, sys_arr):
        """Append the row ``sys_arr`` to the current chunk, opening a new chunk when needed.

        A new chunk is started in three cases:

        - for the very first row;
        - once ``count`` reaches ``chunk_size`` (``count`` sums ``len(sys_arr)``,
          i.e. values rather than rows);
        - when the current chunk's id cannot be determined.
        """
        db = Raw_Data._get_db()
        chunks = db.model_sys_arr_chunks

        # chunk_id is a plain attribute, not a persisted field, so it is lost
        # when this document is reloaded; recover it from the newest chunk.
        chunk_id = getattr(self, 'chunk_id', None)
        if chunk_id is None and self.num_chunks > 0:
            last = chunks.find_one({'parent_id': self.id, 'n': self.num_chunks}, {'_id': 1})
            if last is not None:
                chunk_id = self.chunk_id = last['_id']

        if self.count >= self.chunk_size or self.num_chunks == 0 or chunk_id is None:
            self.count = 0
            self.num_chunks += 1
            # Persist the counters (and obtain self.id for a new document)
            # before creating the chunk that points back to it.
            self.save()
            res = chunks.insert_one({'n': self.num_chunks, 'parent_id': self.id, 'first_block': self.block_count})
            self.chunk_id = res.inserted_id

        # Collection.update was removed in PyMongo 4; update_one is the equivalent.
        chunks.update_one({'_id': self.chunk_id}, {'$push': {'data': sys_arr}, '$set': {'last_block': self.block_count}})
        self.count += len(sys_arr)
        self.block_count += 1

    def get_sys_arr(self, ix, tags):
        """Return the values of ``tags`` in block (row) ``ix``.

        Tags missing from ``sys_ix`` yield 0.
        NOTE(review): if no chunk holds block ``ix``, ``find_one`` returns None
        and this fails with ``TypeError``.
        """
        db = Raw_Data._get_db()
        chunk = db.model_sys_arr_chunks.find_one({'parent_id': self.id, 'first_block': {'$lte': ix}, 'last_block': {'$gte': ix}})

        # TODO: select only row ix inside the MongoDB query instead of fetching the whole chunk.
        row = chunk['data'][ix - int(chunk['first_block'])]

        sys_ix = {k.replace('+', '.'): v for k, v in self.model.sys_ix.items()}
        return [row[sys_ix[t]] if t in sys_ix else 0 for t in tags]

    def get_t_data(self, tag):
        """Return ``{'name', 'data', 'avg'}`` for ``tag`` across all stored rows.

        ``data`` is a list of ``{'t': <value of column 't'>, 'y': <tag value>}``
        points and ``avg`` is the mean of the tag values. When the tag is
        unknown or no rows are stored, ``data`` is empty and ``avg`` is 0.
        """
        db = Raw_Data._get_db()
        s_ix = self.model.sys_ix
        data = []
        avg = 0
        ntag = tag.replace('.', '+')
        if ntag in s_ix:
            ix = s_ix[ntag]
            tix = s_ix['t']
            acc = 0
            count = 0
            for c in db.model_sys_arr_chunks.find({'parent_id': self.id}):
                # A chunk can exist without 'data' if nothing was pushed into it yet.
                for r in c.get('data', []):
                    data.append({'t': r[tix], 'y': r[ix]})
                    acc += r[ix]
                    count += 1
            # Guard against ZeroDivisionError when no rows are stored.
            if count:
                avg = acc / count
        return {'name': tag, 'data': data, 'avg': avg}

    def get_series(self, tag):
        """Return the list of values of ``tag`` across all stored rows ([] if unknown)."""
        db = Raw_Data._get_db()
        s_ix = self.model.sys_ix
        ntag = tag.replace('.', '+')
        if ntag not in s_ix:
            return []
        ix = s_ix[ntag]
        chunks = db.model_sys_arr_chunks.find({'parent_id': self.id})
        return [r[ix] for c in chunks for r in c.get('data', [])]

    def delete_clean(self):
        """Delete all row chunks, the linked Model_Structure, and this document."""
        db = Raw_Data._get_db()
        # delete_many returns a DeleteResult, which exposes deleted_count (it
        # has no count() method); the old code crashed here and never deleted
        # the model or this document.
        res = db.model_sys_arr_chunks.delete_many({'parent_id': self.id})
        print('Deleted: ' + str(res.deleted_count) + ' result sys arr chunks')
        if self.model is not None:
            self.model.delete()
        self.delete()

class Result(Document):
    """Results stored as chunked points plus separate ``result_data`` documents.

    ``write_point`` appends points to documents in the ``chunks`` collection
    keyed by ``res_id``. ``write_result_data`` stores whole payloads in the
    ``result_data`` collection.
    """
    fields=ListField(StringField())
    chunk_size=IntField(default=10000)
    # Was misspelled ``defalt=`` and therefore silently ignored by mongoengine.
    upload_date=DateTimeField(default=datetime.datetime.now)

    num_chunks=IntField(default=0)
    count=IntField(default=0)

    def write_point(self, data_point):
        """Append ``data_point`` to the current chunk, opening a new chunk when needed.

        A new chunk is started for the first point, once ``count`` reaches
        ``chunk_size``, or when the current chunk's id cannot be determined.
        """
        db = Result._get_db()

        # chunk_id is a plain attribute, not a persisted field, so it is lost
        # when this document is reloaded; recover it from the newest chunk.
        chunk_id = getattr(self, 'chunk_id', None)
        if chunk_id is None and self.num_chunks > 0:
            last = db.chunks.find_one({'res_id': self.id, 'n': self.num_chunks}, {'_id': 1})
            if last is not None:
                chunk_id = self.chunk_id = last['_id']

        if self.count >= self.chunk_size or self.num_chunks == 0 or chunk_id is None:
            self.count = 0
            self.num_chunks += 1
            # Persist the counters (and obtain self.id for a new document)
            # before creating the chunk that points back to it.
            self.save()
            res = db.chunks.insert_one({'n': self.num_chunks, 'res_id': self.id})
            self.chunk_id = res.inserted_id

        # Collection.update was removed in PyMongo 4; update_one is the equivalent.
        db.chunks.update_one({'_id': self.chunk_id}, {'$push': {'data': data_point}})
        self.count += 1

    def read_results(self):
        """Save this document, then return every stored point as one flat list.

        Points are concatenated chunk by chunk, in the order MongoDB returns the chunks.
        """
        self.save()
        db = Result._get_db()
        data = []
        for c in db.chunks.find({'res_id': self.id}):
            # A chunk can exist without 'data' if nothing was pushed into it yet.
            data += c.get('data', [])
        return data

    def write_result_data(self, result_data):
        """Store ``result_data`` as a new document in the ``result_data`` collection."""
        db = Result._get_db()
        db.result_data.insert_one({'res_id': self.id, 'data': result_data})

    def get_result_data(self):
        """Return the payloads of all ``result_data`` documents belonging to this result."""
        db = Result._get_db()
        return [r['data'] for r in db.result_data.find({'res_id': self.id})]

    def close(self):
        """Persist the current state of this document."""
        self.save()

    def delete_clean(self):
        """Delete this result's point chunks, its result_data documents, and itself."""
        db = Result._get_db()
        # DeleteResult exposes deleted_count (there is no count() method), and
        # it must be converted to str before concatenation.
        res = db.chunks.delete_many({'res_id': self.id})
        print('Deleted: ' + str(res.deleted_count) + ' result chunks')
        # Previously these were left behind as orphans.
        db.result_data.delete_many({'res_id': self.id})
        self.delete()
  

class Task(Document):
    """A simulation task: status, progress figures, log/output files and result references."""
    name = StringField(max_length=120, required=True) 
    status = StringField(max_length=120, required=True)
    # Was misspelled ``deafult=`` and therefore silently ignored by mongoengine.
    update_num = IntField(default=0)
    progress_string =  StringField(max_length=500)
    current_simulation_time = FloatField()
    progress_percentage = FloatField()
    time_elapsed = FloatField()
    ETA = DateTimeField()
    duration_expected = FloatField()
    remaining_expected = FloatField()
    rate_avg = FloatField()
    rate_last = FloatField()
    log_file = FileField()
    outputs_file = FileField()
    outputs_result = ReferenceField(Result)
    outputs_raw_data = ReferenceField(Raw_Data)

    def delete_clean(self):
        """Delete the stored files, clean up linked outputs, then delete the task.

        Cleanup of ``outputs_result`` and ``outputs_raw_data`` is best-effort.
        Any failure (including dereferencing the reference) is reported rather
        than silently swallowed, and the task itself is still deleted.
        """
        self.log_file.delete()
        self.outputs_file.delete()
        for attr in ('outputs_result', 'outputs_raw_data'):
            try:
                linked = getattr(self, attr)
                if linked is not None:
                    linked.delete_clean()
            except Exception as exc:
                print('Could not clean up ' + attr + ' of task ' + str(self.id) + ': ' + repr(exc))

        self.delete()
        
    
    
class Scenario(EmbeddedDocument):
    """A runnable scenario: inputs, requested outputs, results, views and a linked task."""
    name = StringField(max_length=120, required=True)
    inputs = EmbeddedDocumentListField(Inputs)
    outputs = EmbeddedDocumentListField(Outputs)
    results = EmbeddedDocumentListField(Outputs)
    task = ReferenceField(Task, default=None)
    status = StringField(default='Not run')
    action = StringField(default='None')
    graphs = EmbeddedDocumentListField(Graph)
    result_views = EmbeddedDocumentListField(Data_View)
    progress = FloatField(default=0)
    uuid = UUIDField(default=lambda: uuid.uuid4(), primary_key=True)
    visualization = StringField(default='EM_Rosendal')
    persistant_tags = ListField(StringField(),default=[])

    def update_status(self):
        """Refresh ``status`` and ``progress`` from the linked task.

        Returns:
            tuple: ``(status, prog)``. ``prog`` is the task's progress
            percentage capped at 100 and rounded to a whole-number ``Decimal``.
            It is ``0`` when there is no task or the task reports no
            (or zero) progress.
        """
        self.progress = 0
        prog = 0
        task = self.task
        if task:
            self.status = task.status.strip()
            percentage = task.progress_percentage
            if percentage:
                self.progress = percentage
                prog = round(Decimal(min(percentage, 100)), 0)
        return self.status, prog

    def is_running(self):
        """Refresh the status from the task; return False for finished/idle statuses."""
        self.update_status()
        return self.status not in {'Successful', 'Failed', 'Not run', 'Ready'}

    def clone(self):
        """Return an unsaved copy of this scenario's definition.

        Inputs, outputs, results, graphs and result views are cloned element
        by element, and the persistent tags are copied. The clone gets a fresh
        ``uuid`` and default ``task``/``status``/``action``/``progress``/
        ``visualization`` values.
        """
        s_clone = Scenario(name=self.name)
        for attr in ('inputs', 'outputs', 'results', 'graphs', 'result_views'):
            target = getattr(s_clone, attr)
            for item in getattr(self, attr):
                target.append(item.clone())

        # Copy the list: assigning self.persistant_tags directly would make the
        # clone hold (and mutate) the very same list object as the original.
        s_clone.persistant_tags = list(self.persistant_tags)
        # NOTE(review): visualization is not carried over and resets to its
        # default; confirm this is intended.
        return s_clone
        
    
class Group(EmbeddedDocument):
    """A named group of scenarios together with shared inputs, outputs, results and graphs."""
    name = StringField(max_length=120, required=True)
    inputs = EmbeddedDocumentListField(Inputs)
    outputs = EmbeddedDocumentListField(Outputs)
    results= EmbeddedDocumentListField(Outputs)
    scenarios=EmbeddedDocumentListField(Scenario)
    graphs = EmbeddedDocumentListField(Graph)
    uuid = UUIDField(default=lambda: uuid.uuid4(), primary_key=True)

    def clone(self):
        """Return a new group (fresh uuid) whose embedded lists hold clones of this group's entries."""
        g_clone = Group(name=self.name)
        # Same copy order as before: inputs, outputs, results, graphs, scenarios.
        for attr in ('inputs', 'outputs', 'results', 'graphs', 'scenarios'):
            destination = getattr(g_clone, attr)
            for element in getattr(self, attr):
                destination.append(element.clone())
        return g_clone

class Project(Document):
    """A titled project owned by a user, holding groups of scenarios."""
    title = StringField(max_length=120, required=True, unique=True, verbose_name='title')
    owner = ReferenceField(User)
    tags = ListField(StringField(max_length=30),verbose_name="tags")
    collaborators=ListField(ReferenceField(User))
    groups=EmbeddedDocumentListField(Group)
    uuid = UUIDField(default=lambda: uuid.uuid4(), primary_key=True)
    version = IntField(default=0)
    results_version = IntField(default=0)
    write_lock = IntField(default=0)
    private=BooleanField(default=True)

    def clone(self, newtitle):
        """Return an unsaved copy of this project titled ``newtitle``.

        Owner, tags and collaborators are carried over and every group is
        cloned. ``uuid``, the version/lock counters and ``private`` take their
        field defaults.
        """
        copied = Project(
            title=newtitle,
            owner=self.owner,
            tags=self.tags,
            collaborators=self.collaborators,
        )
        copied.groups.extend([group.clone() for group in self.groups])
        return copied

# SECURITY(review): these URIs embed database credentials in source code.
# The credentials should be rotated and supplied through the environment.
# The constants are kept only as fallbacks so existing deployments keep working.
ReliableServer="mongodb://sim:m9Yx8ip6P88i@209.222.108.162:28017/sim_data?authSource=admin"
GCE_Kubernetes="mongodb://sim:m9Yx8ip6P88i@35.195.195.235:27017/sim_data?authSource=admin"

# Connection target: the MONGODB_URI environment variable when set and
# non-empty, otherwise GCE_Kubernetes (ReliableServer is the alternative).
Default_MongoDB_URI = os.environ.get('MONGODB_URI') or GCE_Kubernetes
# Connects at import time.
connect(host=Default_MongoDB_URI)

    
