import numpy as np
from numba import jit, prange
from sim_tools.assemble_solve import model_encoder as mes

@jit(nopython=True, nogil=True, cache=True)
def state_vals(sys, states):
    '''Gather the current state values out of the sys vector.

    sys - the system vector (parameters, states and other variables)
    states - indices of the states' positions in sys

    Returns a new float64 array `out` with out[k] = sys[states[k]].
    '''
    n_states = len(states)
    out = np.zeros(n_states, dtype=np.float64)
    # A plain serial loop: this function is not compiled with parallel=True.
    for k in range(n_states):
        out[k] = sys[states[k]]
    return out


@jit(nopython=True, nogil=True, cache=True, parallel=False)
def run_ops(t, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data, first):
    '''Run the operations on the sys array, which computes the time derivatives of the states.

    t - simulation time in seconds; written to sys[0]
    y - vector of state values of the differential-equation system
    sys - the system parameters, states and other variables; updated in place
    sys_out - not used by this function (kept for call compatibility)
    states - indices of the states' positions in sys
    states_dot - indices of the states' time-derivative positions in sys
    ops - 2D array with one operation per row; its columns are addressed via the
          mes.FUNC / mes.TARGET / mes.INC / mes.LEFT / mes.RIGHT / mes.ARG0..ARG2 constants
    error_op - error report output: on error, error_op[0] = error code,
               error_op[1] = row index of the failing op and error_op[2:] = that op's row
               (assumes len(error_op) == 2 + ops.shape[1] - TODO confirm with caller)
    data_def - vector of start indices for data blocks in data
    data - vector of all static lookup data needed by the equations
    first - not used by this function

    Returns the vector of state time derivatives (same shape as y). When an error
    occurs, processing stops at the failing op and an all-zero vector is returned.

    Error codes set here: 1 unknown op, 2 division by zero, 5-11 NaN/inf result
    (5, plus 1/2/3 when the ARG0/ARG1/ARG2 operand is NaN), 10 MIX fraction
    outside [0, 1], 45 modulo by zero or failed ASSERT. Codes are not unique
    (10 and 45 each have two possible causes).
    '''

    # Time-derivative vector with the same shape and dtype as y
    y_dot = y * 0
    # 0 means "no error"; any other value aborts the run
    error_code = 0

    # Write the simulation time into sys so it is available during post analysis
    sys[0] = t

    # Zero all time derivatives - they are filled (or accumulated into) by the ops
    for i in range(0, len(states_dot)):
        sys[states_dot[i]] = 0

    # Move the supplied state values into sys
    for i in range(0, len(states)):
        sys[states[i]] = y[i]

    # Execute all ops in order; every op reads from and writes to sys
    for i in range(0, ops.shape[0]):
        func = ops[i, mes.FUNC]

        # Assign value
        if func == mes.ASS:
            res = sys[ops[i, mes.LEFT]]

        # Add values
        elif func == mes.ADD:
            res = sys[ops[i, mes.LEFT]] + sys[ops[i, mes.RIGHT]]

        # Subtract values
        elif func == mes.SUB:
            res = sys[ops[i, mes.LEFT]] - sys[ops[i, mes.RIGHT]]

        # Multiply values
        elif func == mes.MULT:
            res = sys[ops[i, mes.LEFT]] * sys[ops[i, mes.RIGHT]]

        # Divide values
        elif func == mes.DIV:

            # raise error code if div by 0
            if sys[ops[i, mes.RIGHT]] == 0:
                error_code = 2
                res = 0
            else:
                res = sys[ops[i, mes.LEFT]] / sys[ops[i, mes.RIGHT]]

        # Divide ARG0 by ARG1, falling back to the ARG2 value (no error) when ARG1 is 0
        elif func == mes.NOZDIV:

            if sys[ops[i, mes.ARG1]] == 0:

                res = sys[ops[i, mes.ARG2]]
            else:
                res = sys[ops[i, mes.ARG0]]/sys[ops[i, mes.ARG1]]

        elif func == mes.POSITIVE:
            a = sys[ops[i, mes.ARG0]]
            res = max(a,0)

        elif func == mes.NEGATIVE:
            a = sys[ops[i, mes.ARG0]]
            res = min(a,0)

        # 1 if ARG0 <= ARG1 < ARG2, else 0
        elif func == mes.INRANGE:
            a = sys[ops[i, mes.ARG1]]
            res = (a>=sys[ops[i, mes.ARG0]])*(a<sys[ops[i, mes.ARG2]])

        # Multiply ARG0 by ARG1 when ARG0 >= 0, otherwise by ARG2
        elif func == mes.MULT_SIGN:
            a = sys[ops[i, mes.ARG0]]
            if a >= 0:
                res = a * sys[ops[i, mes.ARG1]]
            else:
                res = a * sys[ops[i, mes.ARG2]]

        # Modulus values
        elif func == mes.MOD:

            # raise error code if div by 0
            if sys[ops[i, mes.RIGHT]] == 0:
                error_code = 45
                res = 0
            else:
                res = sys[ops[i, mes.LEFT]] % sys[ops[i, mes.RIGHT]]

        # Comparisons produce 1 (true) or 0 (false)
        elif func == mes.EQ:
            res = 1 * (sys[ops[i, mes.LEFT]] == sys[ops[i, mes.RIGHT]])

        elif func == mes.GT:
            res = 1 * (sys[ops[i, mes.LEFT]] > sys[ops[i, mes.RIGHT]])

        elif func == mes.LT:
            res = 1 * (sys[ops[i, mes.LEFT]] < sys[ops[i, mes.RIGHT]])

        elif func == mes.GE:
            res = 1 * (sys[ops[i, mes.LEFT]] >= sys[ops[i, mes.RIGHT]])

        elif func == mes.LE:
            res = 1 * (sys[ops[i, mes.LEFT]] <= sys[ops[i, mes.RIGHT]])

        elif func == mes.NEG:
            res = -sys[ops[i, mes.LEFT]]

        elif func == mes.SQRT:
            res = np.sqrt(sys[ops[i, mes.ARG0]])

        elif func == mes.ABS:
            res = np.abs(sys[ops[i, mes.ARG0]])

        elif func == mes.SIGN:
            res = np.sign(sys[ops[i, mes.ARG0]])

        elif func == mes.EXP:
            res = np.exp(sys[ops[i, mes.ARG0]])

        elif func == mes.LOG:
            res = np.log(sys[ops[i, mes.ARG0]])

        elif func == mes.POW:
            res = sys[ops[i, mes.LEFT]] ** sys[ops[i, mes.RIGHT]]

        elif func == mes.SIN:
            res = np.sin(sys[ops[i, mes.ARG0]])

        elif func == mes.ARCTAN:
            res = np.arctan(sys[ops[i, mes.ARG0]])

        # Blend ARG1 and ARG2 with fraction ARG0, which must lie in [0, 1]
        elif func == mes.MIX:
            a0 = sys[ops[i, mes.ARG0]]
            if a0 > 1 or a0 < 0:
                error_code = 10
                res = 0
            else:
                res = a0*sys[ops[i, mes.ARG1]] + (1- a0)*sys[ops[i, mes.ARG2]]

        elif func == mes.RAND:
            res = np.random.rayleigh(sys[ops[i, mes.ARG0]])

        # Do bilinear interpolation (ARG0 is a data-block id, not a sys index)
        elif func == mes.INTERP_BILIN:
            res = bilinear_interpolation(ops[i, mes.ARG0], sys[ops[i, mes.ARG1]], sys[ops[i, mes.ARG2]], data_def, data)

        # Do linear interpolation
        elif func == mes.INTERP_LIN:
            res = linear_interpolation(ops[i, mes.ARG0], sys[ops[i, mes.ARG1]], data_def, data)

        # Do nearest-neighbour interpolation
        elif func == mes.INTERP_NN:
            res = nn_interpolation(ops[i, mes.ARG0], sys[ops[i, mes.ARG1]], data_def, data)

        # Calc polynomial value
        elif func == mes.CALC_POLY2X3:
            res = calc_Polynom2x3(ops[i, mes.ARG0], sys[ops[i, mes.ARG1]], sys[ops[i, mes.ARG2]], data_def, data)

        elif func == mes.PRINT:

            res = sys[ops[i, mes.ARG0]]
            print(res)

        elif func == mes.MAX3:
            res = max(sys[ops[i, mes.ARG0]], sys[ops[i, mes.ARG1]], sys[ops[i, mes.ARG2]])

        elif func == mes.MAX2:
            res = max(sys[ops[i, mes.ARG0]], sys[ops[i, mes.ARG1]])

        elif func == mes.MIN3:
            res = min(sys[ops[i, mes.ARG0]], sys[ops[i, mes.ARG1]], sys[ops[i, mes.ARG2]])

        elif func == mes.MIN2:
            res = min(sys[ops[i, mes.ARG0]], sys[ops[i, mes.ARG1]])

        # Fail (code 45) when the ARG0 value is below 0.9; report the op's
        # ARG1/ARG2 references (not the column constants) to identify the assert
        elif func == mes.ASSERT:
            res = sys[ops[i, mes.ARG0]]
            if res < 0.9:
                error_code = 45
                print(ops[i, mes.ARG1])
                print(ops[i, mes.ARG2])

        elif func == mes.IFEXP:
            if sys[ops[i, mes.ARG0]] > 0.1:
                res= sys[ops[i, mes.ARG1]]
            else:
                res=sys[ops[i, mes.ARG2]]

        elif func == mes.SWITCH:
            if sys[ops[i, mes.ARG0]]>0:
                res = sys[ops[i, mes.ARG1]]
            else:
                res = sys[ops[i, mes.ARG2]]

        # No matching function for op: raise error code and give `res` a defined
        # value so it is never unbound (or stale from the previous op)
        else:
            error_code = 1
            res = 0.0

        # Flag NaN/inf results; add hints about which operands are NaN
        if np.isnan(res) or np.isinf(res):
            error_code = 5
            if np.isnan(sys[ops[i, mes.ARG0]]):
                error_code += 1

            if np.isnan(sys[ops[i, mes.ARG1]]):
                error_code += 2

            if np.isnan(sys[ops[i, mes.ARG2]]):
                error_code += 3

        # If assign with incrementation (used for assigning time derivative of states)
        if ops[i, mes.INC] > 0:
            sys[ops[i, mes.TARGET]] += res
        else:
            sys[ops[i, mes.TARGET]] = res

        # Record the error and stop running ops
        if error_code > 0:
            error_op[2:] = ops[i, :]
            error_op[0] = error_code
            error_op[1] = i
            return y_dot

    # Pick out derivatives of states and place in return vector
    for i in range(0, len(states_dot)):
        y_dot[i] = sys[states_dot[i]]

    return y_dot


@jit(nopython=True, nogil=True, cache=True)
def calc_Polynom2x3(ix, x1, x2, data_def, data):
    '''Evaluate a full cubic polynomial in two variables from a coefficient block.

    The ten coefficients c0..c9 start at data[data_def[ix]] and are applied as
        c0 + c1*x1 + c2*x2
           + c3*x1^2 + c4*x1*x2 + c5*x2^2
           + c6*x1^3 + c7*x1^2*x2 + c8*x1*x2^2 + c9*x2^3
    The terms are summed left to right in exactly this order.
    '''
    k = data_def[ix]

    # constant and linear terms
    acc = data[k] + data[k + 1] * x1
    acc = acc + data[k + 2] * x2
    # quadratic terms
    acc = acc + data[k + 3] * x1 ** 2
    acc = acc + data[k + 4] * x1 * x2
    acc = acc + data[k + 5] * x2 ** 2
    # cubic terms
    acc = acc + data[k + 6] * x1 ** 3
    acc = acc + data[k + 7] * x1 ** 2 * x2
    acc = acc + data[k + 8] * x1 * x2 ** 2
    acc = acc + data[k + 9] * x2 ** 3
    return acc


@jit(nopython=True, nogil=True, cache=True)
def bilinear_interpolation(ix, x, y, data_def, data):
    '''Bilinearly interpolate a 2D lookup table stored in the flat ``data`` vector.

    The table block starts at ds_ix = data_def[ix] and is laid out as
        data[ds_ix + 0]       x_len (number of x grid points)
        data[ds_ix + 1]       y_len (number of y grid points)
        data[ds_ix + 2 : 4]   x_min, x_max
        data[ds_ix + 4 : 6]   y_min, y_max
        then x_len x-grid values, y_len y-grid values and x_len * y_len table
        values with x varying fastest (value(xi, yj) at q_start + xi + x_len * yj).

    The enclosing cell is located by assuming evenly spaced grid points between
    min and max (if min == max the raw coordinate is used as the normalised one).
    On a non-uniform grid the located cell may not contain the query, in which case
    the cell formula extrapolates. Indices are clamped to the table, so queries
    outside the table are held at the edge values.

        >>> data = np.array([2, 2, 10, 20, 4, 6, 10, 20, 4, 6,
        ...                  100, 200, 150, 300], dtype=np.float64)
        >>> float(bilinear_interpolation(0, 12.0, 5.5, np.array([0]), data))
        165.0
    '''
    ds_ix = data_def[ix]

    # Table header
    x_len = data[ds_ix + 0]
    y_len = data[ds_ix + 1]
    x_min = data[ds_ix + 2]
    x_max = data[ds_ix + 3]
    y_min = data[ds_ix + 4]
    y_max = data[ds_ix + 5]

    # Start offsets of the x grid, the y grid and the table values
    x_start_ix = ds_ix + 6
    y_start_ix = x_start_ix + np.int64(x_len)
    q_start_ix = y_start_ix + np.int64(y_len)

    # Normalise the query over the grid span (degenerate span: use the raw value)
    if x_max == x_min:
        x_norm = x
    else:
        x_norm = (x - x_min) / (x_max - x_min)

    if y_max == y_min:
        y_norm = y
    else:
        y_norm = (y - y_min) / (y_max - y_min)

    # Bracketing grid indices, clamped into the table
    x_frac_ix = x_norm * (x_len - 1)

    x1_ix = np.int64(np.floor(x_frac_ix))
    if x1_ix < 0:
        x1_ix = 0
    elif x1_ix >= x_len:
        x1_ix = np.int64(x_len) - 1

    x2_ix = np.int64(np.ceil(x_frac_ix))
    if x2_ix < 0:
        x2_ix = 0
    elif x2_ix >= x_len:
        x2_ix = np.int64(x_len) - 1

    y_frac_ix = y_norm * (y_len - 1)

    y1_ix = np.int64(np.floor(y_frac_ix))
    if y1_ix < 0:
        y1_ix = 0
    elif y1_ix >= y_len:
        y1_ix = np.int64(y_len) - 1

    y2_ix = np.int64(np.ceil(y_frac_ix))
    if y2_ix < 0:
        y2_ix = 0
    elif y2_ix >= y_len:
        y2_ix = np.int64(y_len) - 1

    # Cell corner coordinates and values
    x1 = data[x1_ix + x_start_ix]
    x2 = data[x2_ix + x_start_ix]
    y1 = data[y1_ix + y_start_ix]
    y2 = data[y2_ix + y_start_ix]

    q11 = data[q_start_ix + x1_ix + np.int64(x_len) * y1_ix]
    q12 = data[q_start_ix + x1_ix + np.int64(x_len) * y2_ix]
    q21 = data[q_start_ix + x2_ix + np.int64(x_len) * y1_ix]
    q22 = data[q_start_ix + x2_ix + np.int64(x_len) * y2_ix]

    # See formula at:  http://en.wikipedia.org/wiki/Bilinear_interpolation

    # Which grid line (if any) the query sits on: 1 -> the x1 (y1) line,
    # 2 -> the x2 (y2) line, 0 -> strictly between. The matching corner values
    # must be used; a query on the upper line must not read the lower corners.
    if x == x1 or x1 == x2:
        x_on = 1
    elif x == x2:
        x_on = 2
    else:
        x_on = 0

    if y == y1 or y1 == y2:
        y_on = 1
    elif y == y2:
        y_on = 2
    else:
        y_on = 0

    if x_on > 0 and y_on > 0:
        # Exactly on a grid node: return that node's value
        if x_on == 1:
            if y_on == 1:
                return q11
            return q12
        if y_on == 1:
            return q21
        return q22

    if x_on > 0:
        # On a constant-x grid line: interpolate linearly in y along it
        if x_on == 1:
            qa = q11
            qb = q12
        else:
            qa = q21
            qb = q22
        z = (qa * (y2 - y) + qb * (y - y1)) / (y2 - y1)
    elif y_on > 0:
        # On a constant-y grid line: interpolate linearly in x along it
        if y_on == 1:
            qa = q11
            qb = q21
        else:
            qa = q12
            qb = q22
        z = (qa * (x2 - x) + qb * (x - x1)) / (x2 - x1)
    else:
        z = (q11 * (x2 - x) * (y2 - y) +
             q21 * (x - x1) * (y2 - y) +
             q12 * (x2 - x) * (y - y1) +
             q22 * (x - x1) * (y - y1)
             ) / ((x2 - x1) * (y2 - y1))
    return z

#@jit(nopython=True, nogil=True, cache=True)
def bilinear_interpolation_nonjit(ix, x, y, data_def, data):
    '''Bilinearly interpolate a 2D lookup table (pure-Python version, no numba).

    The table block starts at ds_ix = data_def[ix] and is laid out as
        data[ds_ix + 0]       x_len (number of x grid points)
        data[ds_ix + 1]       y_len (number of y grid points)
        data[ds_ix + 2 : 4]   x_min, x_max
        data[ds_ix + 4 : 6]   y_min, y_max
        then x_len x-grid values, y_len y-grid values and x_len * y_len table
        values with x varying fastest (value(xi, yj) at q_start + xi + x_len * yj).

    The enclosing cell is located by assuming evenly spaced grid points between
    min and max (if min == max the raw coordinate is used as the normalised one).
    On a non-uniform grid the located cell may not contain the query, in which case
    the cell formula extrapolates. Indices are clamped to the table, so queries
    outside the table are held at the edge values.

        >>> data = np.array([2, 2, 10, 20, 4, 6, 10, 20, 4, 6,
        ...                  100, 200, 150, 300], dtype=np.float64)
        >>> float(bilinear_interpolation_nonjit(0, 12.0, 5.5, np.array([0]), data))
        165.0
    '''
    ds_ix = data_def[ix]

    # Table header
    x_len = data[ds_ix + 0]
    y_len = data[ds_ix + 1]
    x_min = data[ds_ix + 2]
    x_max = data[ds_ix + 3]
    y_min = data[ds_ix + 4]
    y_max = data[ds_ix + 5]

    # Start offsets of the x grid, the y grid and the table values
    x_start_ix = ds_ix + 6
    y_start_ix = x_start_ix + np.int64(x_len)
    q_start_ix = y_start_ix + np.int64(y_len)

    # Normalise the query over the grid span (degenerate span: use the raw value)
    if x_max == x_min:
        x_norm = x
    else:
        x_norm = (x - x_min) / (x_max - x_min)

    if y_max == y_min:
        y_norm = y
    else:
        y_norm = (y - y_min) / (y_max - y_min)

    # Bracketing grid indices, clamped into the table
    x_frac_ix = x_norm * (x_len - 1)

    x1_ix = np.int64(np.floor(x_frac_ix))
    if x1_ix < 0:
        x1_ix = 0
    elif x1_ix >= x_len:
        x1_ix = np.int64(x_len) - 1

    x2_ix = np.int64(np.ceil(x_frac_ix))
    if x2_ix < 0:
        x2_ix = 0
    elif x2_ix >= x_len:
        x2_ix = np.int64(x_len) - 1

    y_frac_ix = y_norm * (y_len - 1)

    y1_ix = np.int64(np.floor(y_frac_ix))
    if y1_ix < 0:
        y1_ix = 0
    elif y1_ix >= y_len:
        y1_ix = np.int64(y_len) - 1

    y2_ix = np.int64(np.ceil(y_frac_ix))
    if y2_ix < 0:
        y2_ix = 0
    elif y2_ix >= y_len:
        y2_ix = np.int64(y_len) - 1

    # Cell corner coordinates and values
    x1 = data[x1_ix + x_start_ix]
    x2 = data[x2_ix + x_start_ix]
    y1 = data[y1_ix + y_start_ix]
    y2 = data[y2_ix + y_start_ix]

    q11 = data[q_start_ix + x1_ix + np.int64(x_len) * y1_ix]
    q12 = data[q_start_ix + x1_ix + np.int64(x_len) * y2_ix]
    q21 = data[q_start_ix + x2_ix + np.int64(x_len) * y1_ix]
    q22 = data[q_start_ix + x2_ix + np.int64(x_len) * y2_ix]

    # See formula at:  http://en.wikipedia.org/wiki/Bilinear_interpolation

    # Which grid line (if any) the query sits on: 1 -> the x1 (y1) line,
    # 2 -> the x2 (y2) line, 0 -> strictly between. The matching corner values
    # must be used; a query on the upper line must not read the lower corners.
    if x == x1 or x1 == x2:
        x_on = 1
    elif x == x2:
        x_on = 2
    else:
        x_on = 0

    if y == y1 or y1 == y2:
        y_on = 1
    elif y == y2:
        y_on = 2
    else:
        y_on = 0

    if x_on > 0 and y_on > 0:
        # Exactly on a grid node: return that node's value
        if x_on == 1:
            if y_on == 1:
                return q11
            return q12
        if y_on == 1:
            return q21
        return q22

    if x_on > 0:
        # On a constant-x grid line: interpolate linearly in y along it
        if x_on == 1:
            qa = q11
            qb = q12
        else:
            qa = q21
            qb = q22
        z = (qa * (y2 - y) + qb * (y - y1)) / (y2 - y1)
    elif y_on > 0:
        # On a constant-y grid line: interpolate linearly in x along it
        if y_on == 1:
            qa = q11
            qb = q21
        else:
            qa = q12
            qb = q22
        z = (qa * (x2 - x) + qb * (x - x1)) / (x2 - x1)
    else:
        z = (q11 * (x2 - x) * (y2 - y) +
             q21 * (x - x1) * (y2 - y) +
             q12 * (x2 - x) * (y - y1) +
             q22 * (x - x1) * (y - y1)
             ) / ((x2 - x1) * (y2 - y1))
    return z

@jit(nopython=True, nogil=True)#, cache=True)
def linear_interpolation(ix, x, data_def, data):
    '''Linearly interpolate a 1D lookup table stored in the flat ``data`` vector.

    The table block starts at ds_ix = data_def[ix] and is laid out as
        data[ds_ix]           x_len (number of grid points)
        data[ds_ix + 1 : 3]   x_min, x_max
        then x_len x-grid values followed by x_len table values.

    The bracketing interval is located by assuming evenly spaced grid points
    between x_min and x_max (if x_min == x_max the first grid point is used).
    Both bracket indices are clamped to the table, and the result is clamped to
    [min(q1, q2), max(q1, q2)], so the function never extrapolates; queries
    outside the table return the edge value.
    '''
    ds_ix = data_def[ix]

    x_start_ix = ds_ix + 3
    x_len = int(data[ds_ix])
    x_min = data[ds_ix + 1]
    x_max = data[ds_ix + 2]

    if x_max == x_min:
        x_norm = 0
    else:
        x_norm = (x - x_min) / (x_max - x_min)

    x_frac_ix = x_norm * (x_len - 1)

    # Clamp both bracket indices *before* reading grid values, so out-of-range
    # queries never index past the x grid (into table values or beyond data).
    x1_ix = int(np.floor(x_frac_ix))
    if x1_ix < 0:
        x1_ix = 0
    elif x1_ix >= x_len:
        x1_ix = x_len - 1

    x2_ix = int(np.ceil(x_frac_ix))
    if x2_ix < 0:
        x2_ix = 0
    elif x2_ix >= x_len:
        x2_ix = x_len - 1

    x1 = data[x1_ix + x_start_ix]
    x2 = data[x2_ix + x_start_ix]

    # Table values follow the x_len grid values
    q_start_ix = x_start_ix + x_len
    q1 = data[q_start_ix + x1_ix]
    q2 = data[q_start_ix + x2_ix]

    if x == x1 or x1 == x2:
        z = q1
    elif x == x2:
        # On the upper grid point: use its own value
        z = q2
    else:
        z = (q1 * (x2 - x) + q2 * (x - x1)) / (x2 - x1)

    # Keep the result between the two bracketing table values (relevant when the
    # even-spacing assumption places x outside [x1, x2] on a non-uniform grid)
    if z > q1 and z > q2:
        z = max(q1, q2)
    elif z < q1 and z < q2:
        z = min(q1, q2)
    return z

@jit(nopython=True, nogil=True)#, cache=True)
def nn_interpolation(ix, x, data_def, data):
    '''Nearest-neighbour lookup in a 1D table stored in the flat ``data`` vector.

    The table block starts at ds_ix = data_def[ix] and is laid out as
        data[ds_ix]           x_len (number of grid points)
        data[ds_ix + 1 : 3]   x_min, x_max
        then x_len x-grid values followed by x_len table values.

    The fractional grid index is computed assuming evenly spaced grid points
    between x_min and x_max (if x_min == x_max the first point is used). It is
    rounded with np.round (NumPy rounds exact ties to even) and clamped to
    [0, x_len - 1]. The x-grid values themselves are not read.

    Returns the table value at that index.
    '''
    ds_ix = data_def[ix]

    x_start_ix = ds_ix + 3
    x_len = int(data[ds_ix])
    x_min = data[ds_ix + 1]
    x_max = data[ds_ix + 2]

    if x_max == x_min:
        x_norm = 0
    else:
        x_norm = (x - x_min) / (x_max - x_min)

    x_frac_ix = x_norm * (x_len - 1)

    nearest_ix = int(np.round(x_frac_ix))
    if nearest_ix < 0:
        nearest_ix = 0
    elif nearest_ix >= x_len:
        nearest_ix = x_len - 1

    # Table values follow the x_len grid values
    q_start_ix = x_start_ix + x_len
    return data[q_start_ix + nearest_ix]

@jit(nopython=True, nogil=True, cache=True)
def equilibrate_sys(y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank, max_iter=10000, tol=0.0):
    '''Re-run the ops at fixed states until the state derivatives stop changing.

    run_ops is called repeatedly with t = 0 and the same state vector y. sys is
    updated in place by every call, so values written in one pass can influence
    the next. The first result is compared against a zero vector.

    Iteration stops when the change sum(|y_dot - y_dot_last|) between two
    consecutive evaluations is <= tol, or after max_iter evaluations. The cap
    prevents an endless loop when the derivatives oscillate or drift at
    round-off level and never repeat exactly.

    tol - convergence threshold; the default 0.0 requires an exact repeat
    max_iter - maximum number of run_ops evaluations

    Returns the number of run_ops evaluations performed.
    '''
    y_dot_last = y * 0
    change = 1.0
    count = 0
    while change > tol and count < max_iter:
        y_dot = run_ops(0, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank, False)
        change = np.sum(np.abs(y_dot - y_dot_last))
        y_dot_last = y_dot
        count += 1

    return count


@jit(nopython=True, nogil=True, cache=True)
def solve_BE(t, tend, dt_solve, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank):
    '''Integrate from t towards tend with fixed-step backward (implicit) Euler.

    Each step solves y_new = y + dt * f(t + dt, y_new) by fixed-point iteration,
    starting from an explicit-Euler predictor evaluated at t. Iteration stops when
    the L2 norm of the update stops decreasing or reaches zero, after maxtries
    corrector evaluations, or when run_ops reports an error. Steps are taken
    while the step end time t + dt_solve <= tend.

    sys is restored to its pre-step contents before every corrector evaluation,
    so ops that accumulate into sys start from the same values each time.

    Returns the state vector after the last accepted step. If run_ops reports an
    error (error_op[0] != 0), the state from before the failing step is returned.
    '''
    maxtries = 100
    dt_solve_act = dt_solve
    t1 = t + dt_solve_act
    while t1 <= tend and error_op[0] == 0:
        sys_copy = np.copy(sys)

        # Explicit Euler predictor at the start of the step
        y_dot = run_ops(t, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank, False)
        if error_op[0] != 0:
            break
        y1 = y + dt_solve_act * y_dot
        y2 = y1

        error = 1.0
        last_error = 1e20
        cnv = -1.0
        tries = 0
        # Fixed-point corrector at the step end time t1
        while error_op[0] == 0 and cnv < 0 and error > 0 and tries < maxtries:
            tries += 1
            sys[:] = sys_copy[:]
            y2 = y + dt_solve_act * run_ops(t1, y1, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank,
                                            False)
            dy = y2 - y1
            y1 = y2
            error = np.sqrt(np.sum(dy ** 2))
            cnv = error - last_error
            last_error = error

        if error_op[0] != 0:
            break

        # Accept the step and advance time: t is the start of the next step,
        # t1 its end
        y = y2
        t = t1
        t1 = t + dt_solve_act

    return y

@jit(nopython=True, nogil=True, cache=True)
def solve_BE2(t, tend, dt_solve, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank):
    '''Integrate from t towards tend with fixed-step backward (implicit) Euler (L1 norm).

    Same scheme as the L2 variant, but corrector convergence is measured with
    sum(|dy|). Each step solves y_new = y + dt * f(t + dt, y_new) by fixed-point
    iteration, starting from an explicit-Euler predictor evaluated at t. Iteration
    stops when the update norm stops decreasing or reaches zero, after maxtries
    corrector evaluations, or when run_ops reports an error. Steps are taken
    while the step end time t + dt_solve <= tend.

    sys is restored to its pre-step contents before every corrector evaluation,
    so ops that accumulate into sys start from the same values each time.

    Returns the state vector after the last accepted step. If run_ops reports an
    error (error_op[0] != 0), the state from before the failing step is returned.
    '''
    maxtries = 100
    dt_solve_act = dt_solve
    t1 = t + dt_solve_act
    while t1 <= tend and error_op[0] == 0:
        sys_copy = np.copy(sys)

        # Explicit Euler predictor at the start of the step
        y_dot = run_ops(t, y, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank, False)
        if error_op[0] != 0:
            break
        y1 = y + dt_solve_act * y_dot
        y2 = y1

        error = 1.0
        last_error = 1e20
        cnv = -1.0
        tries = 0
        # Fixed-point corrector at the step end time t1
        while error_op[0] == 0 and cnv < 0 and error > 0 and tries < maxtries:
            tries += 1
            sys[:] = sys_copy[:]
            y2 = y + dt_solve_act * run_ops(t1, y1, sys, sys_out, states, states_dot, ops, error_op, data_def, data_bank,
                                            False)
            dy = y2 - y1
            y1 = y2
            error = np.sum(np.abs(dy))
            cnv = error - last_error
            last_error = error

        if error_op[0] != 0:
            break

        # Accept the step and advance time: t is the start of the next step,
        # t1 its end
        y = y2
        t = t1
        t1 = t + dt_solve_act

    return y

@jit(nopython=True, nogil=True, cache=True)
def solve_BDF5(t, tend, dt_solve, y, sys, sys_out1, states, states_dot, ops, error_op, data_def, data_bank):
    '''Integrate from t towards tend with fixed-step BDF, ramping the order up.

    The first step uses BDF1 (backward Euler); each further step raises the order
    by one, and from the sixth step onward the 6-step BDF formula is used. The
    history y_1..y_5 holds y_{n-1}..y_{n-5}; y is y_n.

    Each step solves the implicit BDF equation for y_{n+1} by fixed-point
    iteration. The iterate starts at y_n, and the derivative is evaluated at the
    step end time t1 on the current iterate. Iteration stops after maxtries
    evaluations, when the L2 norm of the update stops decreasing, or when
    run_ops reports an error. sys is restored to its pre-step contents before
    every evaluation.

    Steps are taken while the step end time t1 <= tend. sys_out1 is not used.

    Returns the state vector after the last accepted step. If run_ops reports an
    error (error_op[0] != 0), the state from before the failing step is returned.
    '''
    # Solution history y_{n-1} .. y_{n-5}
    y_1 = (y * 0).copy()
    y_2 = (y * 0).copy()
    y_3 = (y * 0).copy()
    y_4 = (y * 0).copy()
    y_5 = (y * 0).copy()
    maxtries = 3
    t1 = t + dt_solve
    order = 0

    while t1 <= tend and error_op[0] == 0:
        sys_copy = sys.copy()

        # Iterate for y_{n+1}, initialised with y_n. Kept across tries so the
        # fixed-point iteration actually progresses.
        y1 = y.copy()
        y2 = y1

        last_error = 1e20
        cnv = -1.0
        tries = 0
        while error_op[0] == 0 and cnv < 0 and tries < maxtries:
            tries += 1
            sys[:] = sys_copy
            derivatives = run_ops(t1, y1, sys, sys_out1, states, states_dot, ops, error_op, data_def, data_bank,
                                  False)

            # BDF formulas; y is y_n and y_k is y_{n-k}
            if order < 1:
                y2 = y + dt_solve * derivatives
            elif order < 2:
                y2 = 2 / 3 * dt_solve * derivatives - 1 / 3 * y_1 + 4 / 3 * y
            elif order < 3:
                y2 = 6 / 11 * dt_solve * derivatives + 18 / 11 * y - 9 / 11 * y_1 + 2 / 11 * y_2
            elif order < 4:
                y2 = 12 / 25 * dt_solve * derivatives + 48 / 25 * y - 36 / 25 * y_1 + 16 / 25 * y_2 - 3 / 25 * y_3
            elif order < 5:
                y2 = (60 * dt_solve * derivatives + 300 * y - 300 * y_1 + 200 * y_2 - 75 * y_3 + 12 * y_4) / 137
            else:
                y2 = (60 * dt_solve * derivatives + 360 * y - 450 * y_1 + 400 * y_2 - 225 * y_3 + 72 * y_4 - 10 * y_5) / 147

            dy = y2 - y1
            y1 = y2.copy()
            error = np.sqrt(np.sum(dy ** 2))
            cnv = error - last_error
            last_error = error

        if error_op[0] != 0:
            break

        # Shift the history and accept the step
        y_5 = y_4.copy()
        y_4 = y_3.copy()
        y_3 = y_2.copy()
        y_2 = y_1.copy()
        y_1 = y.copy()
        y = y2.copy()
        order += 1

        t1 += dt_solve

    return y

