import ipywidgets as widgets
import math

from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.graph_objs as go
import json
import pandas
import numpy as np
import base64
from IPython.display import HTML

# strftime/strptime-style pattern: date, time, numeric UTC offset (%z)
# followed by the literal text "UTC".
date_fmt = '%Y-%m-%d %H:%M:%S %z UTC'
# NOTE(review): presumably passed as `style=` to ipywidgets widgets so that
# the description label keeps its natural width -- confirm at the call sites.
style = {'description_width': 'initial'}


def save_json(data, file):
    """Write *data* to the path *file* as JSON, replacing existing content.

    Args:
        data: Any object the default JSON encoder can serialize.
        file: Path of the file to (over)write.
    """
    with open(file, 'w') as sink:
        # Stream the encoder output chunk by chunk, as ``json.dump`` does.
        for chunk in json.JSONEncoder().iterencode(data):
            sink.write(chunk)


def load_json(file):
    """Read the file at path *file* and return its parsed JSON content.

    Args:
        file: Path of a JSON document.

    Returns:
        The decoded Python object (dict, list, str, number, bool or None).
    """
    with open(file) as source:
        return json.loads(source.read())


def create_download_link(df, title="Download CSV file", filename="data.csv"):
    """Build an HTML anchor that downloads *df* as a CSV file.

    The frame is rendered with ``;`` as field separator and ``,`` as decimal
    mark, base64-encoded and embedded in a ``data:`` URI, so no file is
    written to disk.

    Args:
        df: Object exposing a pandas-style ``to_csv(sep=..., decimal=...)``.
        title: Text shown for the link.
        filename: Value of the anchor's ``download`` attribute.

    Returns:
        An ``IPython.display.HTML`` object wrapping the anchor markup.
        ``title`` and ``filename`` are inserted into the markup unescaped.
    """
    csv_text = df.to_csv(sep=';', decimal=',')
    encoded = base64.b64encode(csv_text.encode()).decode()
    anchor = '<a download="{filename}" href="data:text/csv;base64,{payload}" target="_blank">{title}</a>'
    return HTML(anchor.format(filename=filename, title=title, payload=encoded))


def agg(df, atype='mean', freq='D', tags=None, df_to_update=None):
    """Aggregate *df* (or a subset of its columns) via ``agg_``.

    Args:
        df: Source frame handed to ``agg_``.
        atype: Aggregator name forwarded to ``agg_``.
        freq: Frequency forwarded to ``agg_``.
        tags: ``None`` to aggregate the whole frame; otherwise the column
            selection to aggregate.
        df_to_update: Frame that receives the aggregated ``tags`` columns.
            Defaults to *df* itself, which is then modified in place.

    Returns:
        With ``tags is None``: the new aggregated frame from ``agg_``.
        Otherwise: the target frame, with ``tags`` columns overwritten when
        ``tags`` is non-empty and left untouched when it is empty.
    """
    target = df if df_to_update is None else df_to_update

    if tags is None:
        return agg_(df, atype, freq)

    if len(tags) > 0:
        target[tags] = agg_(df[tags], atype, freq)
    return target


def agg_(df, atype='mean', freq='D'):
    """Group *df* by time and reduce every group with one aggregator.

    Args:
        df: Frame with a datetime-like index (``.year`` / ``.month`` are read
            from it for monthly grouping; other frequencies go through
            ``pandas.Grouper``).
        atype: One of ``'mean'``, ``'stddev'``, ``'sum'``, ``'min'``,
            ``'max'``.
        freq: ``'M'`` groups by (year, month) of the index, giving a
            two-level integer index; any other value is passed to
            ``pandas.Grouper(freq=freq)``.

    Returns:
        The aggregated frame.

    Raises:
        ValueError: If *atype* is not a supported aggregator.
    """
    # Public aggregator name -> pandas GroupBy method name. pandas has no
    # ``stddev`` method; the standard deviation reducer is ``std``.
    reducers = {
        'mean': 'mean',
        'stddev': 'std',
        'sum': 'sum',
        'min': 'min',
        'max': 'max',
    }
    if atype not in reducers:
        raise ValueError('Unsupported aggregator!')

    if freq == 'M':
        keys = [df.index.year, df.index.month]
    else:
        # Applies to every aggregator, including 'stddev', so the requested
        # frequency is always honoured.
        keys = pandas.Grouper(freq=freq)

    return getattr(df.groupby(keys), reducers[atype])()


def makeplot(layo, y, range_slider=False, fill=None, captions=None, y2=None, show_iplot=True):
    """Build (and optionally display) a plotly scatter figure.

    Args:
        layo: Sequence ``(title, x_title, y_title[, y2_title])``. The fourth
            entry is only read when *y2* is truthy and *range_slider* is
            falsy.
        y: Iterable of series tuples ``(x, y, name, _, visible, yaxis)``.
            Index 3 is not read here; ``visible`` defaults to ``True`` and
            ``yaxis`` to ``None`` when the tuple is too short to hold them.
        range_slider: If truthy, use a date x-axis with range selector
            buttons and a range slider. No ``yaxis2`` entry is emitted in
            this mode.
        fill: Optional sequence indexed like *y*, giving the ``fill`` value
            of each trace.
        captions: Accepted for call compatibility; not used.
        y2: If truthy (and no range slider), define a right-hand secondary
            y-axis overlaying the primary one.
        show_iplot: If truthy, render the figure with ``iplot``.

    Returns:
        The figure as ``dict(data=[...], layout={...})``.
    """
    data = []
    for i, s in enumerate(y):
        trace_kwargs = dict(
            x=s[0],
            y=s[1],
            name=s[2],
            visible=s[4] if len(s) > 4 else True,
            yaxis=s[5] if len(s) > 5 else None,
        )
        if fill:
            trace_kwargs['fill'] = fill[i]
        data.append(go.Scatter(**trace_kwargs))

    if range_slider:
        selector_buttons = [
            dict(count=1, label='1m', step='month', stepmode='backward'),
            dict(count=6, label='6m', step='month', stepmode='backward'),
            dict(count=1, label='YTD', step='year', stepmode='todate'),
            dict(count=1, label='1y', step='year', stepmode='backward'),
            dict(step='all'),
        ]
        layout = dict(
            title=layo[0],
            yaxis=dict(
                title=layo[2],
                rangemode='tozero',
            ),
            xaxis=dict(
                title=layo[1],
                rangeselector=dict(buttons=selector_buttons),
                rangeslider=dict(),
                type='date',
            ),
        )
    else:
        # layo[3] is only touched when a secondary axis is requested, so
        # three-element `layo` sequences keep working without y2.
        y2_def = dict(
            title=layo[3],
            rangemode='tozero',
            overlaying='y',
            side='right',
        ) if y2 else None
        layout = dict(
            title=layo[0],
            xaxis=dict(
                title=layo[1],
            ),
            yaxis=dict(
                title=layo[2],
                rangemode='tozero',
            ),
            yaxis2=y2_def,
        )

    fig = dict(data=data, layout=layout)
    if show_iplot:
        iplot(fig)

    return fig

    # if len(captions)>0:

    # display(HTML(figure_number))
    # figure_number += 1

    # display(HTML(num_fig(captions)))


def makebar(layo, y, stack=False, captions=None, y2=None, show_iplot=True):
    """Build (and optionally display) a plotly bar figure.

    Args:
        layo: Sequence ``(title, x_title, y_title[, y2_title])``. The fourth
            entry is only read when *y2* is truthy.
        y: Iterable of series tuples ``(x, y, name, _, visible, yaxis)``.
            Index 3 is not read here; ``visible`` defaults to ``True`` and
            ``yaxis`` to ``None`` when the tuple is too short to hold them.
        stack: If truthy, set the layout's ``barmode`` to ``'stack'``.
        captions: Accepted for call compatibility; not used.
        y2: If truthy, define a right-hand secondary y-axis overlaying the
            primary one.
        show_iplot: If truthy, render the figure with ``iplot``.

    Returns:
        The figure as ``dict(data=[...], layout={...})``.
    """
    data = [
        go.Bar(
            x=s[0],
            y=s[1],
            name=s[2],
            visible=s[4] if len(s) > 4 else True,
            yaxis=s[5] if len(s) > 5 else None,
        )
        for s in y
    ]

    # layo[3] is only touched when a secondary axis is requested, so
    # three-element `layo` sequences keep working without y2.
    y2_def = dict(
        title=layo[3],
        titlefont=dict(
            color='rgb(148, 103, 189)'
        ),
        tickfont=dict(
            color='rgb(148, 103, 189)'
        ),
        overlaying='y',
        side='right',
    ) if y2 else None

    layout = dict(
        title=layo[0],
        xaxis=dict(
            title=layo[1],
        ),
        yaxis=dict(
            title=layo[2],
        ),
        yaxis2=y2_def,
    )
    if stack:
        layout['barmode'] = 'stack'

    fig = dict(data=data, layout=layout)
    if show_iplot:
        iplot(fig)

    return fig

    # if len(captions)>0:

    # display(HTML(num_fig(captions)))


def make_graphs(df, plot_defs, plot=True, overwrite=None, stack=False, y2=None, show_iplot=True):
    """Build one figure per selected plot definition.

    Args:
        df: Frame providing the x values (its index) and the data columns
            named by each plot definition.
        plot_defs: Sequence of plot definitions. Each must expose ``title``,
            ``x_title``, ``y_title`` and ``y_values``; ``y_title2``,
            ``caption`` and ``note`` are optional. Every ``y_values`` entry
            is ``(column, name, factor[, visible[, yaxis]])``.
        plot: Truthy -> scatter figures via ``makeplot``; falsy -> bar
            figures via ``makebar``.
        overwrite: ``None``/empty to draw every definition with its own
            titles. Otherwise a sequence whose entries are either an int
            (index into *plot_defs*, drawn unchanged) or ``(index, xy)``
            where ``xy`` is ``(x_title, y_title[, third[, note[, caption]]])``.
        stack: Forwarded to ``makebar``.
        y2: Forwarded to ``makeplot`` / ``makebar``.
        show_iplot: Forwarded to ``makeplot`` / ``makebar``.

    Returns:
        List with one figure dict per selected definition (empty when
        nothing is selected).
    """
    if not overwrite:
        select = [(plot_def, None, 1) for plot_def in plot_defs]
    else:
        select = []
        for o in overwrite:
            if isinstance(o, int):
                select.append((plot_defs[o], None, 1))
            else:
                # NOTE(review): the third element of the override tuple is
                # used as the data multiplier here AND, further down, as the
                # secondary-axis title (xy[2]). That looks unintended, but the
                # intended layout cannot be determined from this file, so the
                # existing behaviour is kept.
                select.append((plot_defs[o[0]], o[1], 1 if len(o[1]) < 3 else o[1][2]))

    figures = []
    for plot_def, xy, scale in select:
        if not xy:
            # No override: fall back to the titles stored on the definition.
            xy = (plot_def.x_title, plot_def.y_title,
                  plot_def.y_title2 if hasattr(plot_def, 'y_title2') else '')

        captions = [getattr(plot_def, 'caption', ''), getattr(plot_def, 'note', '')]
        if len(xy) > 3:
            captions[1] = xy[3]
        if len(xy) > 4:
            captions[0] = xy[4]

        if len(xy) < 3:
            xy = (xy[0], xy[1], '')

        layo = (plot_def.title, xy[0], xy[1], xy[2])
        # Series tuple layout expected by makeplot/makebar:
        # (x, y, name, factor, visible, yaxis).
        series = [
            (
                df.index,
                df[y[0]] * y[2] * scale,
                y[1],
                y[2],
                y[3] if len(y) > 3 else True,
                y[4] if len(y) > 4 else None,
            )
            for y in plot_def.y_values
        ]

        if plot:
            figures.append(makeplot(layo, series, y2=y2, captions=captions,
                                    show_iplot=show_iplot))
        else:
            figures.append(makebar(layo, series, stack=stack, y2=y2,
                                   captions=captions, show_iplot=show_iplot))

    # Return only after every selected definition has been drawn.
    return figures


def make_tables(df, plot_defs, overwrite=None):
    """Display each selected plot definition as a table with a CSV link.

    For every selected definition a frame is built with one column per
    ``y_values`` entry (source column times its factor times the override
    multiplier), then a download link and the frame itself are displayed.

    Args:
        df: Frame providing the data columns named by each definition.
        plot_defs: Sequence of plot definitions exposing ``title`` and
            ``y_values`` (entries ``(column, name, factor, ...)``).
        overwrite: ``None``/empty to show every definition. Otherwise a
            sequence whose entries are either an int (index into
            *plot_defs*) or ``(index, xy)``; when ``xy`` has at least three
            elements, ``xy[2]`` is used as an extra multiplier.

    Returns:
        None. As a side effect, pandas' global ``display.float_format``
        option is set to two decimals with thousands separators whenever at
        least one table is shown.

    NOTE(review): relies on ``display`` and ``base_job_alias`` being
    available as globals at call time; neither is defined in the visible
    part of this file -- confirm where they come from.
    """
    if not overwrite:
        select = [(plot_def, None, 1) for plot_def in plot_defs]
    else:
        select = []
        for o in overwrite:
            if isinstance(o, int):
                select.append((plot_defs[o], None, 1))
            else:
                select.append((plot_defs[o[0]], o[1], 1 if len(o[1]) < 3 else o[1][2]))

    if not select:
        return

    # Same value for every table, so set it once instead of per iteration.
    pandas.options.display.float_format = '{:,.2f}'.format

    for plot_def, _xy, scale in select:
        this_df = pandas.DataFrame()
        for y in plot_def.y_values:
            this_df[y[1]] = df[y[0]] * y[2] * scale

        display(create_download_link(this_df, title=plot_def.title,
                                     filename=base_job_alias + '_' + plot_def.title + '.csv'))
        display(this_df)


class dot_dict:
    """Expose the entries of a mapping as instance attributes."""

    def __init__(self, dictionary: dict):
        """Copy every ``key -> value`` pair of *dictionary* onto ``self``.

        Args:
            dictionary: Mapping whose keys become attribute names.
        """
        for attr_name, attr_value in dictionary.items():
            setattr(self, attr_name, attr_value)


def diff_tags(df, tags, difftags, scale=1):
    """Add scaled first-difference columns to *df* in place.

    Args:
        df: Frame to read from and write to.
        tags: Names of the source columns.
        difftags: Names of the target columns, paired positionally with
            *tags* (extra entries on either side are ignored).
        scale: Factor applied to every difference.

    Returns:
        The same *df* object, with the target columns added or overwritten.
    """
    for source_col, target_col in zip(tags, difftags):
        deltas = df[source_col].diff().values
        df[target_col] = list(deltas * scale)
    return df
