
import ipywidgets as widgets
import numpy as np
import math
from scipy import interpolate
from job_worker.job_models import *
from plotly import __version__
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.plotly as py
import plotly.graph_objs as go 
import numpy as np
import json
import csv
init_notebook_mode(connected=True)
from datetime import datetime, timedelta
import pandas
import qgrid
import numpy as np
date_fmt = '%Y-%m-%d %H:%M:%S %z UTC'
style = {'description_width': 'initial'}
#from sim_tools.data.plotting import get_df_job_alias
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from plotly import __version__
import plotly.plotly as py
import plotly.graph_objs as go 
from job_worker.job_models import *
from copy import deepcopy

import base64
import pandas as pd
from IPython.display import HTML
import jsonpickle
from dateutil.relativedelta import *

init_notebook_mode(connected=True)

def num_fig(figure_number, captions=('', '')):
    """Display a numbered HTML figure caption and advance the figure counter.

    Args:
        figure_number: Mutable mapping holding the current figure number under
            key ``'n'``. It is incremented in place.
        captions: Sequence ``(title, note)``. The title is shown in bold after
            ``"Figure N: "``. The optional note is shown in italics on its own
            line, and only when it is non-empty. Missing entries are treated as
            empty strings. Text is inserted as raw HTML, so markup is allowed.

    Returns:
        The same ``figure_number`` mapping, after incrementing ``'n'``.
    """
    # Import explicitly rather than relying on IPython injecting ``display``
    # into builtins, which only happens inside an interactive IPython session.
    from IPython.display import HTML, display

    title = captions[0] if len(captions) > 0 else ''
    note = captions[1] if len(captions) > 1 else ''

    caption = '<div><b> Figure ' + str(figure_number['n']) + ': ' + title + '</b></div>'
    if note:
        caption += '<div><i>Note: ' + note + '</i></div>'

    figure_number['n'] += 1
    display(HTML(caption))
    return figure_number



def create_download_link(df, title="Download CSV file", filename="data.csv", sep=';', decimal=','):
    """Return an HTML link that downloads *df* as a CSV file in the browser.

    The CSV text is embedded in the link itself as a base64 ``data:`` URI, so
    no server round-trip is needed.

    Args:
        df: DataFrame to export via ``DataFrame.to_csv``.
        title: Link text shown to the user.
        filename: Suggested file name for the download.
        sep: CSV field delimiter. Defaults to ``';'``, the previously hard-coded value.
        decimal: CSV decimal separator. Defaults to ``','``, the previously
            hard-coded value.

    Returns:
        IPython ``HTML`` object wrapping the ``<a>`` element.
    """
    import html as html_lib

    csv_text = df.to_csv(sep=sep, decimal=decimal)
    payload = base64.b64encode(csv_text.encode('utf-8')).decode('ascii')
    # Escape the caller-supplied text: a quote in the file name or a '<' in the
    # title would otherwise break the generated markup.
    link = '<a download="{filename}" href="data:text/csv;base64,{payload}" target="_blank">{title}</a>'.format(
        filename=html_lib.escape(filename, quote=True),
        payload=payload,
        title=html_lib.escape(title),
    )
    return HTML(link)

def agg(df, atype='mean',freq='D',tags=None, df_to_update=None):
    """Aggregate *df* into ``freq`` buckets using ``agg_``.

    Without *tags*, the whole frame is aggregated and returned, and *df* is
    left untouched.

    With *tags*, only ``df[tags]`` is aggregated. The result is assigned into
    ``target[tags]``, where ``target`` is *df_to_update* when it is an instance
    of ``type(df)``, and *df* itself otherwise (for example when
    *df_to_update* is None). The target is modified in place and returned.
    pandas aligns the assigned values on the target's index, so the target is
    presumably expected to already be indexed by bucket labels (review:
    confirm against callers).

    Args:
        df: Time-indexed DataFrame or Series to aggregate.
        atype: Aggregator name forwarded to ``agg_``, e.g. ``'mean'`` or ``'sum'``.
        freq: pandas offset alias forwarded to ``agg_``.
        tags: Optional column selection to aggregate.
        df_to_update: Optional frame that receives the aggregated columns.
    """
    if not tags:
        return agg_(df, atype, freq)

    target = df_to_update if isinstance(df_to_update, type(df)) else df
    target[tags] = agg_(df[tags], atype, freq)
    return target

def agg_(df, atype='mean',freq='D'):
    """Group a time-indexed frame into fixed-frequency buckets and aggregate.

    Args:
        df: DataFrame or Series whose index can be grouped by
            ``pandas.Grouper(freq=freq)``, typically a ``DatetimeIndex``.
        atype: Aggregator applied to each bucket. One of ``'mean'``, ``'sum'``,
            ``'min'``, ``'max'`` or ``'median'``.
        freq: pandas offset alias for the bucket size. Defaults to daily (``'D'``).

    Returns:
        The aggregated frame, indexed by bucket label.

    Raises:
        ValidationError: If *atype* is not a supported aggregator.
            ``ValidationError`` is not defined here; presumably it comes from
            the ``job_worker.job_models`` star import (TODO: confirm).
    """
    supported = ('mean', 'sum', 'min', 'max', 'median')
    if atype not in supported:
        raise ValidationError('Unsupported aggregator!')
    grouped = df.groupby(pandas.Grouper(freq=freq))
    # Dispatch by name; the allow-list above keeps arbitrary attributes out.
    return getattr(grouped, atype)()

def makeplot(layo,y,range_slider=False, fill=None, captions=None, y2=None):
    """Render a Plotly line (scatter) chart inline in the notebook.

    Args:
        layo: Sequence ``(title, x_title, y_title[, y2_title])``. ``y2_title``
            is only read when *y2* is truthy. It defaults to ``''`` when absent.
        y: Iterable of series specs ``(x, y, name[, _, visible[, yaxis]])``.
            Index 3 is not read here. ``visible`` defaults to ``True``.
            ``yaxis`` (e.g. ``'y2'``) defaults to ``None``, meaning the
            primary axis.
        range_slider: If true, give the x-axis date range-selector buttons
            (1m, 6m, YTD, 1y, all) and a range slider.
        fill: Optional per-series Plotly ``fill`` modes, indexed like *y*.
            Series beyond the end of *fill* are drawn unfilled.
        captions: Accepted for interface compatibility and currently unused.
        y2: If truthy, add a secondary y-axis on the right, overlaying the
            primary one.
    """
    data = []
    for i, spec in enumerate(y):
        trace_kwargs = dict(
            x=spec[0],
            y=spec[1],
            name=spec[2],
            visible=spec[4] if len(spec) > 4 else True,
            yaxis=spec[5] if len(spec) > 5 else None,
        )
        # Guard the index so a short ``fill`` list no longer raises IndexError.
        if fill and i < len(fill):
            trace_kwargs['fill'] = fill[i]
        data.append(go.Scatter(**trace_kwargs))

    y2_def = dict(
        title=layo[3] if len(layo) > 3 else '',
        overlaying='y',
        side='right',
    ) if y2 else None

    if range_slider:
        layout = dict(
            title=layo[0],
            yaxis=dict(title=layo[2]),
            xaxis=dict(
                title=layo[1],
                rangeselector=dict(
                    buttons=[
                        dict(count=1, label='1m', step='month', stepmode='backward'),
                        dict(count=6, label='6m', step='month', stepmode='backward'),
                        dict(count=1, label='YTD', step='year', stepmode='todate'),
                        dict(count=1, label='1y', step='year', stepmode='backward'),
                        dict(step='all'),
                    ]
                ),
                rangeslider=dict(),
                type='date',
            ),
        )
        # Previously the secondary axis was silently dropped in this mode.
        if y2_def:
            layout['yaxis2'] = y2_def
    else:
        layout = dict(
            title=layo[0],
            xaxis=dict(title=layo[1]),
            yaxis=dict(title=layo[2]),
            yaxis2=y2_def,
        )

    iplot(dict(data=data, layout=layout))
        
def makebar(layo,y,stack=False, captions=None, y2=None):
    """Render a Plotly bar chart inline in the notebook.

    Args:
        layo: Sequence ``(title, x_title, y_title[, y2_title])``. ``y2_title``
            is only read when *y2* is truthy. It defaults to ``''`` when absent.
        y: Iterable of series specs ``(x, y, name[, _, visible[, yaxis]])``.
            Index 3 is not read here. ``visible`` defaults to ``True``.
            ``yaxis`` (e.g. ``'y2'``) defaults to ``None``, meaning the
            primary axis.
        stack: If true, stack the bars (``barmode='stack'``).
        captions: Accepted for interface compatibility and currently unused.
        y2: If truthy, add a purple secondary y-axis on the right, overlaying
            the primary one.
    """
    data = [
        go.Bar(
            x=spec[0],
            y=spec[1],
            name=spec[2],
            visible=spec[4] if len(spec) > 4 else True,
            yaxis=spec[5] if len(spec) > 5 else None,
        )
        for spec in y
    ]

    y2_def = dict(
        title=layo[3] if len(layo) > 3 else '',
        titlefont=dict(color='rgb(148, 103, 189)'),
        tickfont=dict(color='rgb(148, 103, 189)'),
        overlaying='y',
        side='right',
    ) if y2 else None

    layout = dict(
        title=layo[0],
        xaxis=dict(title=layo[1]),
        yaxis=dict(title=layo[2]),
        yaxis2=y2_def,
    )
    if stack:
        layout['barmode'] = 'stack'

    iplot(dict(data=data, layout=layout))

def make_graphs(df, plot_defs,plot=True, overwrite=None, stack=False, y2=None):
    """Draw one chart per plot definition, as lines (``makeplot``) or bars (``makebar``).

    Args:
        df: Frame whose index is the x-axis. Its columns are looked up by the
            names given in each definition's ``y_values``.
        plot_defs: Indexable collection of plot definitions.
            - Required attributes: ``title``, ``x_title``, ``y_title`` and
              ``y_values``.
            - Optional attributes: ``y_title2``, ``caption`` and ``note``.
            - ``y_values`` entries are ``(column, label, factor[, visible[, yaxis]])``.
        plot: True for line charts, False for bar charts.
        overwrite: Optional selection list. Without it, every definition is
            drawn. Each entry is one of:
            - an ``int`` index, drawn as defined;
            - ``(index, xy)``, where ``xy = (x_title, y_title[, third[, note[, caption]]])``
              overrides the axis titles and captions.
        stack: Stack bars (bar charts only).
        y2: Forwarded to the chart helper to enable a secondary y-axis.
    """
    # Each selected entry is (plot_def, xy_override_or_None, scale_factor).
    if overwrite:
        selection = [
            (plot_defs[item], None, 1) if isinstance(item, int)
            else (plot_defs[item[0]], item[1], item[1][2] if len(item[1]) >= 3 else 1)
            for item in overwrite
        ]
    else:
        selection = [(plot_def, None, 1) for plot_def in plot_defs]

    for plot_def, xy, scale in selection:
        if not xy:
            xy = (plot_def.x_title, plot_def.y_title, getattr(plot_def, 'y_title2', ''))

        captions = [getattr(plot_def, 'caption', ''), getattr(plot_def, 'note', '')]
        if len(xy) > 3:
            captions[1] = xy[3]
        if len(xy) > 4:
            captions[0] = xy[4]
        if len(xy) < 3:
            xy = (xy[0], xy[1], '')

        # NOTE(review): for an ``overwrite`` tuple, xy[2] is both the scale
        # factor (see ``selection``) and, below, the secondary-axis title.
        # Confirm which meaning callers intend.
        layo = (plot_def.title, xy[0], xy[1], xy[2])
        series = [
            (
                df.index,
                df[spec[0]] * spec[2] * scale,
                spec[1],
                spec[2],
                spec[3] if len(spec) > 3 else True,
                spec[4] if len(spec) > 4 else None,
            )
            for spec in plot_def.y_values
        ]

        if plot:
            makeplot(layo, series, y2=y2, captions=captions)
        else:
            makebar(layo, series, stack=stack, y2=y2, captions=captions)
    
def make_tables(df, plot_defs, base_job_alias,overwrite=None):
    """Show each plot definition's series as a table with a CSV download link.

    Args:
        df: Source frame. Columns are looked up by the names given in each
            definition's ``y_values`` entries ``(column, label, factor, ...)``.
        plot_defs: Indexable collection of plot definitions. Each needs a
            ``title`` and ``y_values``.
        base_job_alias: Prefix for the download file name,
            ``<base_job_alias>_<title>.csv``.
        overwrite: Optional selection list. Without it, every definition is
            shown. Each entry is one of:
            - an ``int`` index;
            - ``(index, xy)``, where a third element of ``xy`` scales all
              values. Other ``xy`` fields are ignored here.
    """
    from IPython.display import display

    # Each selected entry is (plot_def, scale_factor).
    if overwrite:
        selection = [
            (plot_defs[item], 1) if isinstance(item, int)
            else (plot_defs[item[0]], item[1][2] if len(item[1]) >= 3 else 1)
            for item in overwrite
        ]
    else:
        selection = [(plot_def, 1) for plot_def in plot_defs]

    for plot_def, scale in selection:
        table = pandas.DataFrame()
        for spec in plot_def.y_values:
            table[spec[1]] = df[spec[0]] * spec[2] * scale

        # NOTE: this sets pandas' display format globally, so it also affects
        # frames rendered later in the session.
        pandas.options.display.float_format = '{:,.2f}'.format

        # A QgridWidget used to be built here and then discarded without being
        # displayed. It was removed because it only cost a widget construction.
        display(create_download_link(
            table,
            title=plot_def.title,
            filename=base_job_alias + '_' + plot_def.title + '.csv',
        ))
        display(table)


def get_df_job_alias(job_alias, tags):
    """Load *tags* from the first data stream of the job with *job_alias*.

    Looks up ``JobOutput`` by ``job_alias``, takes ``data_set.data_streams[0]``
    and returns ``get_df(tags, fill_nan=True)`` on it. ``JobOutput`` is not
    defined here; presumably it comes from the ``job_worker.job_models`` star
    import (TODO: confirm). The return value is presumably a pandas DataFrame;
    verify against ``get_df``.
    """
    first_stream = JobOutput.objects.get(job_alias=job_alias).data_set.data_streams[0]
    return first_stream.get_df(tags, fill_nan=True)
