import refprop as rp
import numpy as np
import copy
import math
#import matplotlib.pyplot as plt
import pickle
import os, numpy as np
from ctREFPROP.ctREFPROP import REFPROPFunctionLibrary

print('REFPROP')

# Full path of the REFPROP shared library inside the RPPREFIX directory.
# The backslash is escaped explicitly: '\R' is an invalid escape sequence
# (SyntaxWarning since Python 3.12) even though it happens to yield the same text.
_RP_DLL_PATH = os.environ['RPPREFIX'] + '\\REFPRP64.DLL'
print('Location: ', _RP_DLL_PATH)

# Setup REFPROP.  Note that this rebinds the name 'rp': from here on it is the
# ctREFPROP function library, not the 'refprop' module imported at the top.
rp = REFPROPFunctionLibrary(_RP_DLL_PATH)
rp.SETPATHdll(os.environ['RPPREFIX'])



def generate_property_tables(flash, comb, N, xmin, xmax, ymin, ymax, logx=False, logy=False):
    """Evaluate ``flash`` on a regular 2-D grid and pack the results as lookup tables.

    The grid has ``N[0]`` points spanning ``[xmin, xmax]`` on the first axis and
    ``N[1]`` points spanning ``[ymin, ymax]`` on the second.  When ``logx`` /
    ``logy`` is set, the grid coordinate on that axis is exponentiated before
    being passed to ``flash``.  The stored ``x``/``y`` axes stay in the
    (log-space) grid coordinates.

    Parameters
    ----------
    flash : callable
        ``flash(comb, x, y)`` returning a mapping with keys ``'D'``, ``'s'``,
        ``'t'``, ``'q'``, ``'p'`` and ``'h'``.
    comb : str
        Input-pair code forwarded unchanged to ``flash`` (e.g. ``'Tp'``, ``'ph'``).
    N : sequence of int
        ``(Nx, Ny)`` grid resolution.
    xmin, xmax, ymin, ymax : float
        Grid bounds, in log space for an axis whose ``log*`` flag is set.
    logx, logy : bool
        Exponentiate the corresponding coordinate before calling ``flash``.

    Returns
    -------
    dict
        ``'x'`` / ``'y'`` are the 1-D grid coordinates.  ``'xv'`` / ``'Yv'`` are
        the ``(Nx, Ny)`` coordinate meshes; the mixed capitalisation is part of
        the interface.

        For each property ``k`` in ``T, p, h, q, s, D, rho`` there are two
        entries:

        * ``k`` is a flat list ``[Nx, Ny, xmin, xmax, ymin, ymax, *X, *Y,
          *values]``, with the x index varying fastest in ``values``.
        * ``k + 'v'`` is the ``(Nx, Ny)`` mesh as nested lists.

        Every mesh is passed through ``removeNaN``.
    """
    Nx = N[0]
    Ny = N[1]

    X = np.linspace(xmin, xmax, Nx)
    Y = np.linspace(ymin, ymax, Ny)

    # Values actually handed to flash(): exp() of the grid coordinate on a
    # log axis, the coordinate itself otherwise.  Computed once per axis rather
    # than once per grid point.
    x_values = [float(np.exp(float(v))) if logx else float(v) for v in X]
    y_values = [float(np.exp(float(v))) if logy else float(v) for v in Y]

    # Flat buffers indexed as i * Ny + j.  np.nan is used because the np.NAN
    # alias was removed in NumPy 2.0.
    size = Nx * Ny
    T_1D = np.full(size, np.nan)
    q_1D = np.full(size, np.nan)
    p_1D = np.full(size, np.nan)
    h_1D = np.full(size, np.nan)
    D_1D = np.full(size, np.nan)
    s_1D = np.full(size, np.nan)

    for i, x in enumerate(x_values):
        for j, y in enumerate(y_values):
            prop = flash(comb, x, y)
            k = i * Ny + j
            D_1D[k] = prop['D']
            s_1D[k] = prop['s']
            T_1D[k] = prop['t']
            q_1D[k] = prop['q']
            p_1D[k] = prop['p']
            h_1D[k] = prop['h']

    shape = (Nx, Ny)
    print('Nan Tv')
    Tv = removeNaN(np.reshape(T_1D, shape))
    print('Nan Dv')
    Dv = removeNaN(np.reshape(D_1D, shape))
    print('Nan qv')
    qv = removeNaN(np.reshape(q_1D, shape))
    print('Nan pv')
    pv = removeNaN(np.reshape(p_1D, shape))
    print('Nan hv')
    hv = removeNaN(np.reshape(h_1D, shape))
    print('Nan rhov')
    # 'rho' tabulates the same density data as 'D' under a second key.
    rhov = removeNaN(np.reshape(D_1D, shape))
    print('Nan sv')
    sv = removeNaN(np.reshape(s_1D, shape))

    # Coordinate meshes: Xv[i, j] == X[i] and Yv[i, j] == Y[j].
    Xv = removeNaN(np.reshape(np.repeat(X, Ny), shape))
    Yv = removeNaN(np.reshape(np.tile(Y, Nx), shape))

    header = [Nx, Ny, xmin, xmax, ymin, ymax]

    def flat_table(mesh):
        # Header, X axis, Y axis, then the mesh values with the x index varying fastest.
        return np.concatenate([header, X, Y, mesh.T.flatten()]).tolist()

    return {
        'x': X.tolist(),
        'xv': Xv.tolist(),
        'y': Y.tolist(),
        'Yv': Yv.tolist(),
        'T': flat_table(Tv),
        'Tv': Tv.tolist(),
        'p': flat_table(pv),
        'pv': pv.tolist(),
        'h': flat_table(hv),
        'hv': hv.tolist(),
        'q': flat_table(qv),
        'qv': qv.tolist(),
        's': flat_table(sv),
        'sv': sv.tolist(),
        'D': flat_table(Dv),
        'Dv': Dv.tolist(),
        'rho': flat_table(rhov),
        'rhov': rhov.tolist()
    }


def removeNaN(data):
    """Fill NaN cells of a 2-D array by iterative nearest-neighbour averaging.

    Each pass replaces every NaN that has at least one non-NaN 4-neighbour
    (up, down, left, right) with the mean of those neighbours.  It uses the
    values from the previous pass, so fills do not cascade within one pass.
    Passes repeat until a pass fills nothing.  The input array is not modified.

    Parameters
    ----------
    data : array_like
        2-D numeric array.

    Returns
    -------
    numpy.ndarray
        A copy of ``data`` in which every NaN connected to a non-NaN cell is filled.

    Raises
    ------
    AssertionError
        If 75 % or more of the cells are NaN on entry.  It is raised explicitly,
        not via ``assert``, so the check also runs under ``python -O``.  The
        exception type is kept for callers that already catch it.
    """
    data = np.asarray(data)
    if data.size == 0:
        # Nothing to fill; also avoids dividing by zero in the ratio check below.
        return data.copy()
    rows, cols = data.shape
    total = rows * cols

    first_pass = True
    while True:
        nan_mask = np.isnan(data)
        count_nan = int(nan_mask.sum())

        # Sum and count of the non-NaN 4-neighbours of every cell, built by
        # shifting the grid.  Terms are added in the fixed order up, down, left,
        # right.  Missing neighbours contribute 0.0, which leaves the partial
        # sum unchanged.
        values = np.where(nan_mask, 0.0, data)
        valid = (~nan_mask).astype(np.float64)
        nb_sum = np.zeros(data.shape)
        nb_cnt = np.zeros(data.shape)
        nb_sum[1:, :] += values[:-1, :]    # neighbour above
        nb_cnt[1:, :] += valid[:-1, :]
        nb_sum[:-1, :] += values[1:, :]    # neighbour below
        nb_cnt[:-1, :] += valid[1:, :]
        nb_sum[:, 1:] += values[:, :-1]    # neighbour to the left
        nb_cnt[:, 1:] += valid[:, :-1]
        nb_sum[:, :-1] += values[:, 1:]    # neighbour to the right
        nb_cnt[:, :-1] += valid[:, 1:]

        fill = nan_mask & (nb_cnt >= 1)
        count_nan_remove = int(fill.sum())

        cdata = data.copy()
        cdata[fill] = nb_sum[fill] / nb_cnt[fill]

        if first_pass:
            print('NaN found: ', count_nan)
            print('NaN removed: ', count_nan_remove)
            if not count_nan < total * 0.75:
                raise AssertionError(
                    "More NaN than allow - max is 75% this was: "
                    + str(count_nan / total * 100) + '%!!')
            first_pass = False

        data = cdata
        if count_nan_remove == 0:
            return cdata

def generate_Two_Phase_Fluid_Properties_from_Refprop(fluid_name, file_name, N):
    """Generate REFPROP property lookup tables for one fluid and pickle them to disk.

    The steps are:

    1. Load ``fluid_name`` into REFPROP.
    2. Read its TMIN/TMAX/DMAX/PMAX/XMASS/MM limits and tighten them towards a
       generic simulation range.
    3. Build one table per input pair with ``generate_property_tables``:
       'Tp', 'Tq', 'rh' ('Dh'), 'pq', 'ph' and 'ps'.

    Every table uses the ``N`` grid resolution.

    Parameters
    ----------
    fluid_name : str
        REFPROP fluid name.  A name ending in 'MIX' is loaded with SETMIXdll;
        anything else is loaded with SETFLUIDSdll.
    file_name : str
        Path of the pickle file that receives the full result, tables included.
    N : tuple of int
        ``(Nx, Ny)`` grid resolution for every table.

    Returns
    -------
    dict
        ``{'RF': fluid_name, 'w_g': ..., 'lim': ..., 'file': file_name}``.
        The tables themselves (key ``'data'``) are only written to the pickle file.

    Raises
    ------
    ValueError
        If REFPROP reports a non-zero error code while loading the fluid.
    """
    # Load the fluid.  SETMIXdll also returns the mixture composition z; a pure
    # fluid gets a one-component composition vector.
    if fluid_name.endswith('MIX'):
        _, _, z, ierr, herr = rp.SETMIXdll(fluid_name, 'HMX.BNC', 'DEF')
        if ierr != 0:
            raise ValueError(str(ierr) + ':' + herr)
    else:
        z = [1.0] + [0.0] * 19
        ierr = rp.SETFLUIDSdll(fluid_name)
        if ierr != 0:
            raise ValueError(str(ierr))

    print('Version: ', rp.RPVersion())

    MOLAR_SI = rp.GETENUMdll(0, "USER").iEnum

    # Configure the USER unit set for T, P, D, H, S, W, I, E, K, and N.
    # NOTE(review): the numeric unit codes are not decoded here.  sim_lim below
    # (T from -40 to 150, p up to 12000) suggests degC / kPa; confirm against
    # the REFPROP UNITUSER documentation.
    rp.REFPROP1dll("102;152;204;305;354;403;0;0;0;0", "UNITUSER", MOLAR_SI, 0, 0, 0, z)

    USER = rp.GETENUMdll(0, "USER").iEnum
    print('User: ', USER)

    print('Molar SI: ', MOLAR_SI)
    lim = {}

    # Fluid limits as reported by REFPROP in the USER units configured above.
    print('\nString Outputs, stored in hUnits')
    print('--------------------------------')
    for k, key in zip(["TMIN", "TMAX", "DMAX", "PMAX", "XMASS", "MM"],
                      ['tmin', 'tmax', 'Dmax', 'pmax', 'xmass', 'm']):
        r = rp.REFPROP1dll("", k, MOLAR_SI, 0, 0, 0, z)
        print('r: ', r)
        lim[key] = r.c

    print('Limits: ', lim)

    # Output order of flash(); must match the "D;S;T;QMOLE;P;H" request string.
    keys = ['D', 's', 't', 'q', 'p', 'h']

    # REFPROP 'MM' output (presumably the molar mass).
    w_g = lim['m']

    # Generic simulation limits.  The REFPROP limits are tightened towards
    # these but never to within 1 unit of the REFPROP bound.
    sim_lim = {'tmin': -40, 'tmax': 150, 'pmax': 12000}

    if sim_lim['tmin'] > lim['tmin']:
        print('tmin: refprop: ', lim['tmin'], ' sim: ', sim_lim['tmin'])
        lim['tmin'] = max(sim_lim['tmin'], lim['tmin'] + 1)

    if sim_lim['tmax'] < lim['tmax']:
        print('tmax: refprop: ', lim['tmax'], ' sim: ', sim_lim['tmax'])
        lim['tmax'] = min(sim_lim['tmax'], lim['tmax'] - 1)

    if sim_lim['pmax'] < lim['pmax']:
        print('pmax: refprop: ', lim['pmax'], ' sim: ', sim_lim['pmax'])
        lim['pmax'] = min(sim_lim['pmax'], lim['pmax'] - 1)

    def flash(inputs, a, b):
        """Return D, s, t, q, p, h at the state fixed by the input pair ``inputs`` = (a, b)."""
        r = rp.REFPROP2dll(fluid_name, inputs, "D;S;T;QMOLE;P;H", USER, 0, a, b, z)
        props = {k: r.Output[i] for i, k in enumerate(keys)}
        # Clamp the quality into [0, 1].  REFPROP presumably reports values
        # outside this range for single-phase states.
        props['q'] = min(1, max(props['q'], 0))
        return props

    # Generate the lookup tables from here on.
    print('GENERATING <Tp>:')
    props_data = {
        'Tp': generate_property_tables(flash, 'Tp', N, lim['tmin'], lim['tmax'], np.log(1), np.log(lim['pmax']),
                                       False, True)}

    props_data['Tq'] = generate_property_tables(flash, 'Tq', N, lim['tmin'], lim['tmax'], 0, 1, False)

    # The enthalpy/entropy ranges for the h- and s-based tables come from the Tp table.
    TP = props_data['Tp']
    hv = np.array(TP['hv'], np.float64)
    sv = np.array(TP['sv'], np.float64)
    print('Min hv')
    hmin = np.min(hv[~np.isnan(hv)])
    print(hmin)

    print('Max hv')
    hmax = np.max(hv[~np.isnan(hv)])
    print(hmax)

    print('GENERATING <rh>:')
    # NOTE(review): logx/logy are True here, but unlike the p bounds of the
    # other tables, the bounds (.1..Dmax, hmin..hmax) are not log-transformed.
    # generate_property_tables therefore calls flash() with D = exp(.1..Dmax)
    # and h = exp(hmin..hmax).  This looks unintended; confirm against the
    # table consumer before changing it.
    props_data['rh'] = generate_property_tables(flash, 'Dh', N, .1, lim['Dmax'], hmin, hmax, True, True)

    print('GENERATING <pq>:')
    props_data['pq'] = generate_property_tables(flash, 'pq', N, np.log(1), np.log(lim['pmax']), 0, 1, True)

    print('GENERATING <ph>:')
    props_data['ph'] = generate_property_tables(flash, 'ph', N, np.log(1), np.log(lim['pmax']), hmin, hmax, True)

    print('Min sv')
    smin = np.min(sv[~np.isnan(sv)])
    print(smin)
    print('Max sv')
    smax = np.max(sv[~np.isnan(sv)])
    print(smax)

    print('GENERATING <ps>:')
    props_data['ps'] = generate_property_tables(flash, 'ps', N, np.log(1), np.log(lim['pmax']), smin, smax, True)

    # Form the properties data structure.
    props = {'RF': fluid_name, 'w_g': w_g, 'lim': lim, 'data': props_data}

    # 'with' guarantees the pickle is flushed and the handle closed before
    # anyone reads the file back.
    with open(file_name, "wb") as fh:
        pickle.dump(props, fh)

    # Only the metadata is returned; the tables live in the pickle file.
    props.pop('data')
    props['file'] = file_name

    return props

from sim_tools.store.simulation_store_mongodb import Simulation_Store as Simulation_Store
# Module-level store used by store_Fluid_Properties, which closes it after a save.
# NOTE(review): this is instantiated at import time, so importing this module
# presumably opens a MongoDB connection; consider creating it lazily.
ss = Simulation_Store()
def store_Fluid_Properties(name, props, desc='', comm=''):
    """Interactively offer to save ``props`` as a named configuration in the store.

    Prompts on stdin.  'Y' or 'Yes' (case-insensitive, surrounding whitespace
    ignored) stores the configuration; any other answer skips it and prints
    'Config not saved!'.

    Parameters
    ----------
    name : str
        Stored as the configuration's 'Name'.
    props : dict
        Stored as the configuration's 'Configuration'.
    desc, comm : str
        Stored as 'Description' and 'Comment'.

    Notes
    -----
    Uses the module-level store ``ss`` and closes it after a successful save.
    NOTE(review): a second save in the same process presumably needs a fresh
    store; confirm the semantics of Simulation_Store.close().
    """
    print('Would you like to store as a configuration? Y(es)/N(o)')
    answer = input().strip().upper()
    if answer not in ('Y', 'YES'):
        print('Config not saved!')
        return

    config = {'Name': name, 'Description': desc, 'Comment': comm,
              'Configuration': props}
    config_id = ss.store_Config(config)

    # Read the stored config back.  The result is unused; presumably this is a
    # check that the write can be retrieved.
    ss.get_Config_by_id(config_id).next()

    ss.close()

if __name__ == '__main__':

    # Number of points (Nx, Ny) in each interpolation table.
    N = (333, 333)

    RF = input('RF (whole filename with ext): ')

    # Strip the extension (e.g. '.FLD' / '.MIX').  splitext handles extensions
    # of any length, and names without one, instead of blindly dropping 4 characters.
    RF_name = os.path.splitext(RF)[0]
    print('Generating data for: ', RF_name)

    tag = input('Tag: ')
    conf_name = RF_name + '_' + tag

    props = generate_Two_Phase_Fluid_Properties_from_Refprop(RF, conf_name + '.dat', N)

    # Store properties in DB (store_Fluid_Properties asks for confirmation first).
    store_Fluid_Properties(conf_name, props, desc='Testing', comm='v2')

    print('Stored as ', conf_name)