from __future__ import division, print_function, absolute_import


import pygfunction as gt
from decimal import Decimal
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
#import requests
from scipy.interpolate import interp1d
from scipy.optimize import differential_evolution, fmin

#from sim_components.generic.items import SubsystemPrescribed
#from sim_tools.store.simulation_store_mongodb import Item_Factory
from sim_tools.assemble_solve.model_assemble import Model

from time import time
from collections import OrderedDict

import numpy as np
from sim_components.generic.items import Subsystem, SubsystemPrescribed
from sim_tools.store.simulation_store_mongodb import Item_Factory, Simulation_Store

#import sim_components.thermodynamics.Fluid_Flow as FF
#import sim_components.thermodynamics.Heat_Transfer as HT

import time
import pymongo
import datetime
import os
import sim_components.generic.items as B
import getpass

# NOTE(security): these URIs embed plaintext database credentials. They are
# kept here for backward compatibility, but the passwords should be rotated
# and supplied through the environment instead of being committed to source.
Sandbox = "mongodb://LasseThomsen:ziQPn556jjRaxY6x@cluster0-shard-00-00-9vgq3.mongodb.net:27017,cluster0-shard-00-01-9vgq3.mongodb.net:27017,cluster0-shard-00-02-9vgq3.mongodb.net:27017/test?ssl=true&replicaSet=Cluster0-shard-0&authSource=admin"
ReliableServer = "mongodb://sim:m9Yx8ip6P88i@209.222.108.162:28017/sim_config?authSource=admin"
GCE_Kubernetes = "mongodb://sim:m9Yx8ip6P88i@35.195.195.235:27017/sim_config?authSource=admin"

# Connection string used by Gfunc_Store by default. Set the environment
# variable GFUNC_MONGODB_URI to target another server without editing code.
Default_MongoDB_URI = os.environ.get('GFUNC_MONGODB_URI', GCE_Kubernetes)


class Gfunc_Store():
    """Thin wrapper around the ``gfuncs`` collection of the ``sim_config``
    MongoDB database, used to persist g-functions and their node-model fits.
    """

    def __init__(self, MongoDB_URI=Default_MongoDB_URI):
        """Create a client for *MongoDB_URI* (defaults to ``Default_MongoDB_URI``)."""
        # Use the caller-supplied URI; previously the argument was ignored
        # and the module default was always used.
        self.pmClient = pymongo.MongoClient(MongoDB_URI)
        self.db = self.pmClient.sim_config

    def list_configurations(self, filter=None, print_=False):
        """Query stored configurations.

        Parameters
        ----------
        filter : dict, optional
            MongoDB query document; ``None`` (default) matches everything.
        print_ : bool
            If True, print one summary line per document and return None.
            Otherwise return the cursor.
        """
        # Avoid a shared mutable default argument.
        if filter is None:
            filter = {}
        res = self.db.gfuncs.find(filter)
        if print_:
            for r in res:
                print(
                    r['Name'] + '---' + r['Description'] + '-' + r['Comment'] + ' - ' + str(r['TimeStamp']) + ' - ' + r[
                        'User'])
        else:
            return res

    def store_Config(self, conf, overwrite=False):
        """Store *conf* in the collection.

        ``conf`` is updated in place with 'TimeStamp', 'DateTime' and 'User'.

        Parameters
        ----------
        conf : dict
            Document to store; must contain 'Name' when *overwrite* is True.
        overwrite : bool
            If True, upsert by 'Name' (``$set`` the fields) and return 0.
            Otherwise insert a new document and return its ``_id``.
        """
        conf.update({'TimeStamp': time.time(), 'DateTime': datetime.datetime.today(), 'User': getpass.getuser()})
        if overwrite:
            # find_and_modify was deprecated in PyMongo 3 and removed in 4.0;
            # find_one_and_update is available in both.
            resp = self.db.gfuncs.find_one_and_update({'Name': conf['Name']}, {'$set': conf}, upsert=True)
            print(resp)
            # Callers treat 0 as "stored successfully".
            conf_id = 0
        else:
            conf_id = self.db.gfuncs.insert_one(conf).inserted_id
        return conf_id

    def get_Config_by_Name(self, name):
        """Return a cursor over documents whose 'Name' equals *name*."""
        return self.db.gfuncs.find({'Name': name})

    def get_Config_by_id(self, id):
        """Return a cursor over documents whose '_id' equals *id*."""
        return self.db.gfuncs.find({'_id': id})

    def close(self):
        """Close the underlying MongoDB client."""
        self.pmClient.close()
# Shared module-level store used by the generate/fit/lookup helpers below.
ss = Gfunc_Store(MongoDB_URI=Default_MongoDB_URI)

import time


def generate_g_function(
        # Borehole dimensions
        D=1,  # Borehole buried depth (m)
        H=250.0,  # Borehole length (m)
        r_b=0.114 / 2,  # Borehole radius (m)
        B=5,  # Borehole spacing (m)
        N_1=2,
        N_2=2,
        nSegments=13,
        silent=False
):
    """Compute, plot and (optionally) store the uniform-wall-temperature
    g-function of an N_1 x N_2 rectangular bore field.

    Parameters
    ----------
    D : float
        Borehole buried depth (m).
    H : float
        Borehole length (m).
    r_b : float
        Borehole radius (m).
    B : float
        Borehole spacing in both directions (m).
    N_1, N_2 : int
        Number of boreholes along each side of the field.
    nSegments : int
        Number of segments per borehole passed to pygfunction.
    silent : bool
        If True, skip the confirmation prompt and store unconditionally.

    Returns
    -------
    (str, dict)
        The answer given at the store prompt ('y' when silent) and the
        model-parameter dict (geometry id, t_s, ln(t/t_s) and G samples).

    Raises
    ------
    ValueError
        If the store does not report success.
    """
    # Ground properties (fit_g_function_node_model uses the same defaults).
    cV = 2300000  # Volumetric heat capacity (J/m3.K)
    k_s = 3.4  # Ground thermal conductivity (W/m.K)
    alpha = k_s / cV  # Thermal diffusivity (m2/s)
    boreField = gt.boreholes.rectangle_field(N_1, N_2, B, B, H, D, r_b)

    # Geometrically expanding time vector.
    dt = 100 * 3600.  # Initial time step (s)
    tmax = 1000. * 8760. * 3600.  # Maximum time (s)
    Nt = 50  # Number of time steps
    t_s = H ** 2 / (9. * alpha)  # Bore field characteristic time (s)

    # Named t_vec so it does not shadow the module-level `time` import.
    t_vec = gt.utilities.time_geometric(dt, tmax, Nt)

    # g-function for uniform borehole wall temperature
    gfunc_uniform_T = gt.gfunction.uniform_temperature(
        boreField, t_vec, alpha, nSegments=nSegments, disp=True)

    # Dimensionless geometry; these four values form the g-function id.
    model_params = {
        'B/H': str(round(Decimal(B / H), 7)),
        'rb/H': str(round(Decimal(r_b / H), 7)),
        'D/H': str(round(Decimal(D / H), 7)),
        'geometry': 'rect/' + str(N_1) + 'x' + str(N_2),
    }
    g_id = '-'.join(str(v) for v in model_params.values())
    print(g_id)

    ln_t_ts = np.log(t_vec / t_s)
    model_params.update({
        'id': g_id,
        'N_1': N_1,
        'N_2': N_2,
        'G_final': max(gfunc_uniform_T),
        't_ts_final': max(t_vec / t_s),
        't_s': t_s,
        'ver': '0.1.0',
        'n_requests': 0,
        # `datetime` is the module here (rebound by `import datetime`).
        'first_timestamp': datetime.datetime.now().isoformat(),
        'ln_t/t_s': list(ln_t_ts),
        'G': list(gfunc_uniform_T),
    })

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.plot(np.log10(t_vec / 3600 / 24), gfunc_uniform_T,
             label='Uniform borehole wall temperature')
    ax1.set_xlabel('log10(t)')
    ax1.set_ylabel('G-value')
    ax1.legend()
    ax1.grid()

    ax2.plot(ln_t_ts, gfunc_uniform_T,
             label='Uniform borehole wall temperature')
    ax2.set_xlabel('log(t/t_s)')
    ax2.set_ylabel('G-value')
    ax2.legend()
    ax2.grid()

    plt.show()

    shouldStore = 'y'
    if not silent:
        shouldStore = input('Store this g-function with id <' + g_id + '>? (y/N)')

    # Accept 'y', 'Y', 'yes' (and surrounding whitespace) as confirmation.
    if silent or shouldStore.strip().lower() in ('y', 'yes'):
        c = {'Name': g_id + '_model_params', 'Description': 'G-function', 'Tags': ['G-func'],
             'Comment': '', 'Configuration': model_params}

        # store_Config returns 0 for a successful overwrite/upsert.
        if ss.store_Config(c, True) == 0:
            print(g_id + ' stored successfully.')
        else:
            raise ValueError('Not saved')

    return shouldStore, model_params

def store_g_function(
        # Borehole dimensions
        D=1,  # Borehole buried depth (m)
        H=250.0,  # Borehole length (m)
        r_b=0.114 / 2,  # Borehole radius (m)
        B=5,  # Borehole spacing (m)
        N_1=2,
        N_2=2,
        nSegments=13,
        silent=False,
        comments=''
):
    """Compute, plot and (optionally) store a rectangular-field g-function,
    attaching a free-text comment to the stored document.

    Same workflow as ``generate_g_function``, plus *comments*.

    Parameters
    ----------
    D, H, r_b, B : float
        Buried depth, length, radius and spacing of the boreholes (m).
    N_1, N_2 : int
        Number of boreholes along each side of the field.
    nSegments : int
        Number of segments per borehole passed to pygfunction.
    silent : bool
        If True, skip the confirmation prompt and store unconditionally.
    comments : str
        Stored in the document's 'Comment' field. Previously this name was
        referenced but never defined, so every store attempt crashed.

    Returns
    -------
    (str, dict)
        The answer given at the store prompt ('y' when silent) and the
        model-parameter dict.

    Raises
    ------
    ValueError
        If the store does not report success.
    """
    cV = 2300000  # Volumetric heat capacity (J/m3.K)
    k_s = 3.4  # Ground thermal conductivity (W/m.K)
    alpha = k_s / cV  # Thermal diffusivity (m2/s)
    boreField = gt.boreholes.rectangle_field(N_1, N_2, B, B, H, D, r_b)

    # Geometrically expanding time vector.
    dt = 100 * 3600.  # Initial time step (s)
    tmax = 1000. * 8760. * 3600.  # Maximum time (s)
    Nt = 50  # Number of time steps
    t_s = H ** 2 / (9. * alpha)  # Bore field characteristic time (s)

    # Named t_vec so it does not shadow the module-level `time` import.
    t_vec = gt.utilities.time_geometric(dt, tmax, Nt)

    # g-function for uniform borehole wall temperature
    gfunc_uniform_T = gt.gfunction.uniform_temperature(
        boreField, t_vec, alpha, nSegments=nSegments, disp=True)

    model_params = {
        'B/H': str(round(Decimal(B / H), 7)),
        'rb/H': str(round(Decimal(r_b / H), 7)),
        'D/H': str(round(Decimal(D / H), 7)),
        'geometry': 'rect/' + str(N_1) + 'x' + str(N_2),
    }
    g_id = '-'.join(str(v) for v in model_params.values())
    print(g_id)

    ln_t_ts = np.log(t_vec / t_s)
    model_params.update({
        'id': g_id,
        'N_1': N_1,
        'N_2': N_2,
        'G_final': max(gfunc_uniform_T),
        't_ts_final': max(t_vec / t_s),
        't_s': t_s,
        'ver': '0.1.0',
        'n_requests': 0,
        # `datetime` is the module at this point (`import datetime` rebinds
        # the earlier `from datetime import datetime`), so the class must be
        # reached explicitly; `datetime.now()` raised AttributeError.
        'first_timestamp': datetime.datetime.now().isoformat(),
        'ln_t/t_s': list(ln_t_ts),
        'G': list(gfunc_uniform_T),
    })

    fig, (ax1, ax2) = plt.subplots(2, 1)

    ax1.plot(np.log10(t_vec / 3600 / 24), gfunc_uniform_T,
             label='Uniform borehole wall temperature')
    ax1.set_xlabel('log10(t)')
    ax1.set_ylabel('G-value')
    ax1.legend()
    ax1.grid()

    ax2.plot(ln_t_ts, gfunc_uniform_T,
             label='Uniform borehole wall temperature')
    ax2.set_xlabel('log(t/t_s)')
    ax2.set_ylabel('G-value')
    ax2.legend()
    ax2.grid()

    plt.show()

    shouldStore = 'y'
    if not silent:
        shouldStore = input('Store this g-function with id <' + g_id + '>? (y/N)')

    if silent or shouldStore.strip().lower() in ('y', 'yes'):
        # Build a dict (the original trailing comma turned this into a tuple,
        # which store_Config could not update).
        c = {'Name': g_id + '_model_params', 'Description': 'G-function', 'Tags': ['G-func'],
             'Comment': comments, 'Configuration': model_params}

        # store_Config returns 0 for a successful overwrite/upsert.
        if ss.store_Config(c, True) == 0:
            print(g_id + ' stored successfully.')
        else:
            raise ValueError('Not saved')

    return shouldStore, model_params

#FIT IT
def EvalExpFit(x, a, b):
    r"""Evaluate the sum-of-exponentials model used by ``ExpFitDiffEvol``.

    Computes ``y(x) = \sum_i a_i (1 - \exp(-b_i x))`` for every x.

    Parameters
    ----------
    x : array_like or float
        Points at which to evaluate the model.
    a : array_like
        Amplitudes, one per exponential term.
    b : array_like
        Rates, one per exponential term (same length as *a*).

    Returns
    -------
    ndarray
        Model values, one per element of *x*.
    """
    # Row i holds b_i * x for all x, so each term's saturation curve is a row.
    rate_x = np.outer(b, x)
    saturation = 1 - np.exp(-rate_x)
    # Weight each row by its amplitude and sum the rows.
    return np.dot(a, saturation)


def ExpFitDiffEvol(N, x, y, seed=None):
    r"""Fit a sum of N saturating exponentials to a data series (x, y) using
    differential evolution as implemented in scipy.optimize.

    The fitted model (the same form evaluated by ``EvalExpFit``) is

        y(x) = \sum_{i=1}^N a_i (1 - \exp(-b_i x))

    and the quantity minimised is the relative squared error
    ``sum((y_fit / y - 1)**2)``.

    Parameters
    ----------
    N : int
        Number of summed exponentials to fit.
    x : array_like
        x values.
    y : array_like
        y values; must be non-zero because the objective divides by y.
    seed : None, int or numpy random generator, optional
        Forwarded to ``scipy.optimize.differential_evolution`` to make the
        stochastic search reproducible. ``None`` (default) keeps the
        original non-deterministic behaviour.

    Returns
    -------
    a, b : ndarray
        Amplitudes and rates, each of length N.

    Notes
    -----
    The use of differential evolution was inspired by another genetic
    algorithm used for an exponential fit by Weizhong Zou in

        Zou, Weizhong. Larson, Ronald G.
        "A mesoscopic simulation method for predicting the rheology of
        semi-dilute wormlike micellar solutions." Journal of Rheology. 58,
        681 (2014).
    """
    x = np.array(x)
    y = np.array(y)

    # Search vector: N time constants tau_i in [0, 1.5*max(x)] followed by
    # N amplitudes a_i in [0, 1.5*max(y)].
    bounds = [[0, max(x) * 1.5]] * N + [[0, max(y) * 1.5]] * N

    def objective(s):
        taui, fi = np.split(s, 2)
        return np.sum((np.dot(fi, 1 - np.exp(-np.outer(1. / taui, x))) / y - 1) ** 2.)

    result = differential_evolution(objective, bounds, seed=seed)
    taui, fi = np.split(result['x'], 2)
    # Rates are the reciprocals of the fitted time constants.
    return fi, 1. / taui

def fit_g_function_node_model(
    g_func_par, silent=False,
    T_init = 5.9,  # Not used by the active code (only referenced in commented-out lines)
    q = 20,  # W/m
    H = 250,  # Borehole length (m); used to rebuild t_s below
    cV = 2300000,  # Volumetric heat capacity (J/m3.K); alpha = k_s / cV
    k_s = 3.4  # Ground thermal conductivity (W/m.K)
):
    """Fit a node model (per-node parameters C and h) to a g-function and
    optionally store the fit.

    Steps:
      1. Rebuild G(t) from ``g_func_par['ln_t/t_s']`` and ``g_func_par['G']``
         using t_s = H**2 / (9 * k_s / cV).
      2. Fit an n=4 term sum of exponentials (``ExpFitDiffEvol``) and plot it.
      3. Turn that fit into an initial guess for C and h, then refine it with
         ``scipy.optimize.fmin`` (maxfun=250). Each cost evaluation simulates
         the ``Simple_BHS_G_func`` item and compares the resulting G with
         the reference.
      4. Plot the final fit, ask for confirmation (unless *silent*), and store
         ``{'id', 'n', 'C', 'h'}`` via ``ss.store_Config`` under the name
         ``'<id>_fit_params'``.

    Parameters
    ----------
    g_func_par : dict
        Must contain 'ln_t/t_s', 'G' and 'id' (as produced by
        generate_g_function).
    silent : bool
        If True, skip the confirmation prompt and store unconditionally.
    q : float
        Heat rate per metre (W/m) used to convert simulated temperature to G.
    H, cV, k_s : float
        Used to compute t_s. NOTE(review): these should match the values used
        to generate *g_func_par*, otherwise the rebuilt time axis differs
        from the one the g-function was computed on.

    Returns
    -------
    dict
        ``{'id': ..., 'n': n, 'C': [...], 'h': [...]}``.

    Raises
    ------
    ValueError
        If the store does not report success.
    """
    alpha = k_s / cV
    t_s = H ** 2 / (9. * alpha)  # Bore field characteristic time

    # Convert the stored dimensionless time axis back to absolute time (s).
    df_g = pd.DataFrame({'ln_t_t_s': g_func_par['ln_t/t_s'], 'G-function': g_func_par['G']})
    df_g['t_t_s'] = np.exp(df_g['ln_t_t_s'])
    df_g['t'] = df_g['t_t_s'] * t_s

    # df_g['dTb'] = df_g['T'] - T_init - q * Rb

    # df_g['ln_t_t_s'] = np.log(df_g['t_t_s']+0.0001)
    # df_g['G-function'] = 2*np.pi*ks*df_g['dTb']/q

    df_g.set_index('t', inplace=True)

    # Reference G as a function of t/t_s: linear interpolation, extrapolated
    # outside the sampled range.
    G = interp1d(df_g.index /t_s, df_g['G-function'], fill_value='extrapolate')
    t_max = max(df_g.index)
    t_min = min(df_g.index[df_g.index > 0])  # smallest strictly positive time

    #Fit using sum of exponentials
    n = 4  # number of exponential terms, and number of model nodes

    a, b = ExpFitDiffEvol(n, df_g.index.values, df_g['G-function'].values)
    G_fit = EvalExpFit(df_g.index.values, a, b)

    #Plot initial fit
    fig1, ax1 = plt.subplots(1, 1)

    ax1.plot(np.log10(df_g.index.values / 3600 / 24), G_fit, label='G fit')
    ax1.plot(np.log10(df_g.index.values / 3600 / 24), df_g['G-function'].values, label='G')
    # ax2_3.plot(sol['t'], , label='G-function ref')

    ax1.set_xlabel('log10(t_days)')
    ax1.set_ylabel('G_fit/G_ref-1')
    ax1.legend()

    ax1.grid()

    plt.show()

    #Define node fitting model
    item_fact = Item_Factory('')

    # Simulate one stage: a single Simple_BHS_G_func item with node
    # parameters C and h, starting from T0 at time t0, integrated with fixed
    # step dt_fix and reported at the times in t_out. Returns:
    #   (V0.T - V1.T at the end, final time, final node temperatures,
    #    V0 temperature history, per-node temperature histories incl. 't').
    def solve_stable(C, h, T0, dt_fix, t0, t_out):
        energy_system = {
            'items': OrderedDict({
                'BHS': {'item_class': 'sim_components.simple_components.estimated_energy_components.Simple_BHS_G_func', 'T0': T0, 'k_s': k_s, 'H': H,
                        'q': q, 'TG': 0,
                        'C': C, 'h': np.array(h)
                        },
            })
        }
        test = SubsystemPrescribed('Stub', '', item_fact, energy_system)

        model = Model('', test)
        model.assemble('[ALL]')

        dT_last = -1000  # unused
        model.t = t0
        model.apply_init_state({'t': t0})

        model.solve_t_out(t_out, dt_fix=dt_fix, history=True)

        sol = pd.DataFrame(model.history)
        dT_this = (model.sys_dict['Stub.BHS.V0.T'] - model.sys_dict['Stub.BHS.V1.T'])
        T = [model.sys_dict['Stub.BHS.V' + str(i) + '.T'] for i in range(len(C))]
        T_volumes = {k: model.history[k] for k in ['t'] + ['Stub.BHS.V' + str(i) + '.T' for i in range(len(C))]}
        T_hist = {'t': sol['t'], 'T': sol['Stub.BHS.V0.T']}

        return dT_this, model.t, T, T_hist, T_volumes


    # NOTE(review): unused; calc_T_resp builds its own local `keys`.
    keys = ['Stub.BHS.V' + str(i) + '.T' for i in range(n)]


    # Simulate the temperature response of the n-node model as a cascade of
    # stages. Stage i simulates nodes i..n-1 only, starting from the previous
    # stage's final temperatures (minus the first node). The histories are
    # stitched into one table with a column per node key. The stitched
    # histories of earlier nodes are offset by the accumulated end-of-stage
    # V0-V1 differences (dT). Returns a DataFrame with 't' and one column per
    # node, sampled on a log-spaced grid from t_min to t_max.
    def calc_T_resp(C, h, n):

        T0 = [0 for i in range(n + 1)]
        t_steps = np.logspace(np.log10(t_min), np.log10(t_max), num=100, base=10.0)

        # print('max t step: ', max(t_steps))
        keys = ['Stub.BHS.V' + str(i) + '.T' for i in range(n)]
        t_end = []
        dT = []
        t = 0

        for i in range(len(C) - 1):
            h_ = h[i:]
            C_ = C[i:]
            T0_ = T0[1:]  # drop the first node's temperature from the previous stage
            # print(T0_)
            # C * t_s / h of the first two nodes of this stage; presumably
            # their 63 % (1 - 1/e) response times.
            t063_0 = C_[0] * t_s / h_[0]
            t063_1 = C_[1] * t_s / h_[1]
            # Stage end time; presumably chosen so that exp(-14) ~ 1e-6.
            t_1ppm = 14 * t063_1
            # print('t_1ppm: ', t_1ppm)
            dt_fix = t063_0 / 100
            # dt= dt_fix*10
            # Output times of this stage: grid points after the current time
            # and not beyond t_1ppm.
            t_out = []
            for t_step in t_steps:
                if t_step > t_1ppm:
                    break
                if t_step > t:
                    t_out.append(t_step)

            # print(t_out)
            if len(t_out) == 0:
                # No grid point falls in this stage; advance a few fixed steps.
                t_out.append(min(t + dt_fix * 3, max(t_steps)))
            # print('max tout: ',max(t_out))
            dTt, t, T0, T_hist, T_volumes = solve_stable(C_, h_, T0_, dt_fix, t, t_out)

            if i > 0:
                # Append this stage's histories to the stitched table.
                T_full_hist['t'] += T_volumes['t']
                # print('min: ', min(T_volumes['t']))
                for ix, dT_ in enumerate(dT):
                    # print('here')
                    # print(len(T_volumes[keys[0]]))
                    # print(len(list(np.array(T_volumes[keys[0]])+np.sum(dT[ix:]))))
                    # Earlier node ix follows this stage's V0 plus the offsets
                    # accumulated from stage ix onwards.
                    T_full_hist[keys[ix]] += list(np.array(T_volumes[keys[0]]) + sum(dT[ix:]))
                # print(T_volumes[1:])
                # Remaining nodes map onto this stage's nodes; `ix` is the
                # last index from the loop above.
                for j, k in enumerate(keys[:-ix - 1]):
                    # print(ix+1+j)

                    T_full_hist[keys[ix + 1 + j]] += T_volumes[k]

            else:
                # print(len(T_volumes['t']))
                T_full_hist = T_volumes

            # print(T)
            T_hist = pd.DataFrame(T_hist)
            # print(T_hist)
            T_hist['T'] += sum(dT)
            # print('Tb final: ',T_hist['T'].values[-1], ' t final: ', T_hist['t'].values[-1])
            # print(T0)
            t_end.append(t)
            dT.append(dTt)

        # Pad the remaining grid points beyond the last simulated time by
        # holding each node's last temperature constant.
        extend = False
        for i, t_ in enumerate(t_steps):
            if t_ > t:
                extend = True
                break
        if extend:
            # print(t_steps[i:])
            extend_data = {'t': t_steps[i:]}
            for k in keys:
                extend_data[k] = list(np.ones(len(t_steps[i:])) * T_full_hist[k][-1])

            extend_data = pd.DataFrame(extend_data)
            # print(extend_data)

            T_full_hist = pd.DataFrame(T_full_hist)
            T_full_hist = pd.concat([T_full_hist, extend_data])
        T_full_hist = pd.DataFrame(T_full_hist)
        # for k, v in T_full_hist.items():
        # print(len(v))
        sol = T_full_hist  # pd.DataFrame(T_full_hist)
        return sol




    min_cost = {'cost': np.inf}

    # Figure reused by cost(..., plot=True): fit vs. reference and relative error.
    fig1, (ax1,  ax3, ) = plt.subplots(2,1)

    # Objective for fmin. X holds n C values followed by n h values (signs are
    # discarded). Returns the L2 norm of the relative error G/G_ref - 1 over
    # the simulated time points.
    def cost(X, n, plot=False):
        # C = C_start

        # dT_r = np.abs(np.array(X))# q/np.abs(np.array(X))/ks*27000/600
        # dT = dT_r#/sum(dT_r)*100/0.63
        # h= q/dT/ks*27000/600
        # h=[abs(x) for x in X]
        # C  = tau/14*h*ks*3600
        X = [abs(x) for x in X]
        h = []
        C = []
        for i in range(n):
            h.append(X[n + i])
            C.append(X[i])
            # h.append(X[i+n])
            # if i > 0:
            #    C.append(C[i-1]*(1+X[i]))
            #    h.append(h[i-1]/(1+X[n+i]))
            # else:
            #    h.append(X[n+i])
            #    C.append(X[i])
        # o.clear_output()
        #with o:
            #
        sol = calc_T_resp(C, h, n)

        # Plot status
        sol['ln_t_t_s'] = np.log(sol['t'] / t_s + 0.0001)
        # Convert the first node's temperature to a g-function value.
        sol['G'] = (sol['Stub.BHS.V0.T']) * (2 * np.pi * k_s) / q
        sol['dT'] = (sol['Stub.BHS.V0.T'])
        # calc reference
        # print(max(sol['t']))
        sol['G_ref'] = G(sol['t'] / t_s)
        sol['dT_ref'] = G(sol['t'] / t_s) / (2 * np.pi * k_s) * q

        # compare and calc cost
        cost = np.sqrt((((sol['G'] / sol['G_ref'] - 1)) ** 2).sum())

        data = {'X': list(X), 'C': list(C), 'h': list(h), 'cost': cost}
        #if cost < min_cost['cost']:
            #o.clear_output()
            #with o:
                #print('C: ', C, ' h: ', h)
        if plot:
                ax1.clear()

                ax1.plot(sol['ln_t_t_s'], sol['G'], label='G-function fit')
                ax1.plot(sol['ln_t_t_s'], sol['G_ref'], label='G-function ref')

                ax1.set_xlabel('ln(t/t_s)')
                ax1.set_ylabel('G-value')
                ax1.legend()

                ax1.grid()

                #ax2.clear()
                #ax2.plot(np.log10(sol['t'] / 3600 / 24), sol['G'], label='G-function fit')
                #ax2.plot(np.log10(sol['t'] / 3600 / 24), sol['G_ref'], label='G-function ref')

                #ax2.set_xlabel('log10(t_days)')
                #ax2.set_ylabel('G-value')
                #ax2.legend()

                #ax2.grid()

                ax3.clear()
                ax3.plot(np.log10(sol['t'] / 3600 / 24), sol['G'] / sol['G_ref'] - 1, label='G Relative Error')
                # ax2_3.plot(sol['t'], , label='G-function ref')

                ax3.set_xlabel('log10(t_days)')
                ax3.set_ylabel('G_fit/G_ref-1')
                ax3.legend()

                ax3.grid()

                #ax4.clear()
                #ax4.plot(np.log10(sol['t'] / 3600 / 24), sol['G'] - sol['G_ref'], label='G-function diff')
                # ax2_3.plot(sol['t'], , label='G-function ref')

                ##ax4.set_xlabel('log10(t_days)')
                #ax4.set_ylabel('G-fit - G-ref')
                #ax4.legend()

                #ax4.grid()


                print(cost)
                plt.show()

        # Track the best cost seen so far (not used to pick the result).
        if cost < min_cost['cost']:
            min_cost['cost'] = cost
            #with open(file, 'w') as outfile:
             #   json.dump(data, outfile)

        return cost

    #Recalculate a and b for node model params
    # Initial guess from the exponential fit: C_i = 1 / (b_i * t_s * a_i),
    # h_i = 1 / a_i, ordered by increasing C.
    C_ = list(1 / b / t_s * 1 / a)
    from copy import deepcopy
    C_s = deepcopy(C_)
    C_s.sort()
    # C_s = list(C_s)
    C_s  # no-op expression statement

    h_s = []
    for c in C_s:
        # print(C_.index(c))
        # NOTE(review): index() returns the first match, so duplicate C
        # values would reuse the same amplitude.
        h_s.append(1 / (a[C_.index(c)]))
    h_s = np.array(h_s)
    sum(1 / h_s)  # no-op expression statement

    min_cost = {'cost': np.inf}
    # The 12.5 scale factor on C is empirical; its origin is not documented here.
    x_guess_ = list(np.array(C_s) * 12.5) + list(h_s)
    #cost_out = cost(x_guess_, n)
    print('Now doing detailed fit... Be patient... (1-2 min)')
    xopt = fmin(func=cost, disp=True, x0=x_guess_, args=(n,), maxfun=250)

    cost(xopt,n, plot=True)
    shouldStore='y'
    if not silent:
        shouldStore = input('Store this g-function model fit with id <' + g_func_par['id'] + '>? (y/N)')

    # cost() uses absolute values of X, so store the absolute values too.
    fit_params = {'id': g_func_par['id'], 'n': n, 'C': list(np.abs(xopt[:n])), 'h': list(np.abs(xopt[n:]))}

    if shouldStore == 'y' or silent:
        #r = requests.post('https://gfunc.numerously.com/gfunc_node_model_store', json={'params': fit_params})
        #r = requests.post('http://127.0.0.1:5000/gfunc_node_model_store', json={'params': fit_params})
        #if r.status_code == 200:
        #    print(str(fit_params['id']) + ' stored successfully.')
        #else:
        #    raise ValueError('Not saved - code: ',r.status_code)

        c = {'Name': g_func_par['id'] + '_fit_params', 'Description': 'G-function-fit', 'Tags': ['G-func'],
             'Comment': '', 'Configuration':
                 fit_params
             }

        # store_Config returns 0 for a successful overwrite/upsert.
        if ss.store_Config(c, True) == 0:
            print(g_func_par['id'] + ' stored successfully.')
        else:
            raise ValueError('Not saved')

    return fit_params

def check_exist_g_func_node_model(
        D=1,  # Borehole buried depth (m)
        H=250.0,  # Borehole length (m)
        r_b=0.114 / 2,  # Borehole radius (m)

        B=2,  # Borehole spacing (m)

        N_1=2,
        N_2=2,
):
    """Look up the stored node-model fit for a rectangular bore field.

    The geometry is reduced to the same dimensionless id used by
    generate_g_function (B/H, rb/H and D/H rounded to 7 decimals, plus
    'rect/<N_1>x<N_2>'). The document named '<id>_fit_params' is then fetched.

    Returns
    -------
    dict or None
        The first matching stored document, or None if none exists.
    """
    model_params = {
        'B/H': str(round(Decimal(B / H), 7)),
        'rb/H': str(round(Decimal(r_b / H), 7)),
        'D/H': str(round(Decimal(D / H), 7)),
        'geometry': 'rect/' + str(N_1) + 'x' + str(N_2),
    }
    print(model_params)
    g_id = '-'.join(str(v) for v in model_params.values())

    # next(cursor, None) works on PyMongo 3 and 4. Cursor.count() was
    # deprecated in 3.7, removed in 4.0, and cost an extra server round trip.
    return next(ss.get_Config_by_Name(g_id + '_fit_params'), None)

def get_g_func_model_ref(ref):
    """Fetch a stored configuration document by its exact 'Name'.

    Parameters
    ----------
    ref : str
        Document name, e.g. '<id>_fit_params' or '<id>_model_params'.

    Returns
    -------
    dict or None
        The first document with that name, or None if there is none.
    """
    # next(cursor, None) works on PyMongo 3 and 4. Cursor.count() was
    # deprecated in 3.7, removed in 4.0, and cost an extra server round trip.
    return next(ss.get_Config_by_Name(ref), None)

def check_update_g_func_fit(
        # Borehole dimensions
        D=1,  # Borehole buried depth (m)
        H=250,  # Borehole length (m)
        r_b=0.114 / 2,  # Borehole radius (m)

        B=25,  # Borehole spacing (m)

        N_1=5,
        N_2=4,
        nSegments=8,
        force_update=False,
        silent=False
):
    """Return the node-model fit for a rectangular bore field, creating it
    (g-function + fit) when it is missing or when *force_update* is set.

    Parameters
    ----------
    D, H, r_b, B : float
        Buried depth, length, radius and spacing of the boreholes (m).
    N_1, N_2 : int
        Number of boreholes along each side of the field.
    nSegments : int
        Segments per borehole for the g-function computation.
    force_update : bool
        Regenerate even if a stored fit exists.
    silent : bool
        Skip all confirmation prompts.

    Returns
    -------
    dict or None
        If a new fit was created: the generated g-function parameters merged
        with the fit parameters. Otherwise, if a fit is stored: its
        'Configuration'. Otherwise (none stored and the user declined to
        create one): None.
    """
    params = check_exist_g_func_node_model(
        D=D, H=H, r_b=r_b, B=B, N_1=N_1, N_2=N_2,
    )

    if not params or force_update:
        if params:
            print('Forcing update of g-func fit.')
        else:
            print('No fit for g-func.')
        create = 'y'
        if not silent:
            create = input('Would you like to create it? (y/N)')

        if silent or create.strip().lower() in ('y', 'yes'):
            _saved, g_params = generate_g_function(
                D=D, H=H, r_b=r_b, B=B, N_1=N_1, N_2=N_2,
                nSegments=nSegments, silent=silent,
            )
            # The fit must use the same borehole length as the g-function:
            # fit_g_function_node_model rebuilds absolute time from ln(t/t_s)
            # with t_s = H**2 / (9*alpha). Its default H=250 would distort the
            # fit for any other length.
            fit = fit_g_function_node_model(g_params, silent=silent, H=H)
            g_params.update(fit)
            return g_params

        if not params:
            # Nothing stored and nothing created; previously this crashed
            # with TypeError on None['Configuration'].
            return None

    return params['Configuration']



if __name__ == '__main__':
    # Print a stored fit fetched by name, then look up (or create) the fit
    # for a 7x7 field.
    ref_conf = get_g_func_model_ref('BIGHQ_fit_3')
    print(ref_conf)
    field_fit = check_update_g_func_fit(N_1=7, N_2=7, H=290, B=20, D=40, r_b=0.057)
    print(field_fit)