
import pandas as pd
import statsmodels.api as sm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LassoCV, LassoLarsCV, LassoLarsIC, LassoLars
from multiprocessing import Pool
import os
from functools import partial
import pywt

def get_dummies(first_day, n_days=None):
    """Build a one-hot day-of-week matrix for consecutive days.

    Parameters
    ----------
    first_day : int
        Day-of-week column index (0-6) of the first row. The script passes
        ``pandas.Timestamp.dayofweek`` of the first date in the data.
    n_days : int, optional
        Number of rows (days). Defaults to the module-level ``last_day``,
        which is what the original implementation always used.

    Returns
    -------
    numpy.ndarray, shape (n_days, 7), float
        Row ``d`` holds a single 1 in column ``(d + first_day) % 7``.
    """
    if n_days is None:
        n_days = last_day
    day_idx = np.arange(n_days)
    dummies = np.zeros((n_days, 7))
    # Fancy indexing sets exactly one cell per row.
    dummies[day_idx, (day_idx + first_day) % 7] = 1
    return dummies

def remove_LTSC(x, param, wavelet, level=14, mode='symmetric'):
    """Split a series into a wavelet long-term seasonal component and a remainder.

    The series is decomposed with ``pywt.wavedec`` to ``level`` levels. This
    yields ``[cA_level, cD_level, ..., cD_1]``. The ``param`` finest detail
    arrays are replaced by ``None``, which ``pywt.waverec`` treats as zeros.
    The remaining coarse coefficients are reconstructed to form the LTSC.

    Parameters
    ----------
    x : array_like, 1-D
        Input series.
    param : int
        Number of finest detail levels dropped from the LTSC, ``0 <= param <= level``.
        Larger values give a smoother LTSC.
    wavelet : str or pywt.Wavelet
        Wavelet passed to pywt (e.g. ``'db4'``).
    level : int, default 14
        Decomposition depth. The default reproduces the original behaviour.
    mode : str, default 'symmetric'
        pywt signal-extension mode.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``(x - ltsc, ltsc)``, both of length ``len(x)``.

    Raises
    ------
    ValueError
        If ``param`` is outside ``[0, level]``.

    NOTE(review): for the series lengths this script uses (~8736 samples) with
    db4, level 14 likely exceeds ``pywt.dwt_max_level``, so pywt warns about
    boundary effects. Confirm this is intended.
    """
    if not 0 <= param <= level:
        raise ValueError(f"param must be between 0 and level={level}, got {param}")
    coeffs = pywt.wavedec(x, wavelet, level=level, mode=mode)
    n_keep = len(coeffs) - param  # == level + 1 - param
    ltsc = pywt.waverec(coeffs[:n_keep] + [None] * param, wavelet, mode=mode)
    # waverec can return one extra sample for odd-length inputs.
    ltsc = ltsc[:len(x)]
    return x - ltsc, ltsc

def get_cal_dataset(cal_day_index, df, dummies, param, wavelet):
    """Assemble the regression design for one rolling calibration window.

    The window spans ``cal_window_len_d`` days of prices starting at day
    ``cal_day_index`` (rows of ``df`` are hourly, 24 per day). ``exog1`` and
    ``exog2`` additionally cover the day right after the window, which is the
    forecast day.

    All series are passed through ``var_stab_trans`` and then through
    ``remove_LTSC(..., param, wavelet)`` before the design is built.

    Returns
    -------
    X : ndarray, shape (cal_window_len_d - 7, 129)
        One row per target day 7..D-1 of the window. Columns:

        - 0-23: prices at lag 1
        - 24-47: prices at lag 2
        - 48-71: prices at lag 7
        - 72-95: same-day exog1
        - 96-119: same-day exog2
        - 120 / 121: lag-1 daily min / max of prices
        - 122-128: same-day day-of-week dummies
    Xr : ndarray, shape (1, 129)
        The same features for the forecast day.
    Y : ndarray, shape (cal_window_len_d - 7, 24)
        Transformed, de-seasonalised hourly prices of the target days.
    prices_median, prices_mad : float
        Transformation parameters of the price series.
    ndarray, shape (24,)
        The last 24 values of the price LTSC.
    """
    start_h = cal_day_index * 24
    price_end_h = start_h + cal_window_len_d * 24
    exog_end_h = start_h + (cal_window_len_d + 1) * 24  # includes the forecast day

    raw_prices = df['prices'][start_h:price_end_h].to_numpy()
    raw_exog1 = df['exog1'][start_h:exog_end_h].to_numpy()
    raw_exog2 = df['exog2'][start_h:exog_end_h].to_numpy()
    day_dummies = dummies[cal_day_index:cal_day_index + cal_window_len_d + 1, :]

    # Variance stabilisation; only the price parameters are needed later.
    prices, prices_median, prices_mad = var_stab_trans(raw_prices)
    exog1 = var_stab_trans(raw_exog1)[0]
    exog2 = var_stab_trans(raw_exog2)[0]

    # Strip the long-term seasonal component; keep the price LTSC for re-adding.
    prices, prices_ltsc = remove_LTSC(prices, param, wavelet)
    exog1 = remove_LTSC(exog1, param, wavelet)[0]
    exog2 = remove_LTSC(exog2, param, wavelet)[0]

    # Day x hour layout.
    prices = prices.reshape(-1, 24)   # (cal_window_len_d, 24)
    exog1 = exog1.reshape(-1, 24)     # (cal_window_len_d + 1, 24)
    exog2 = exog2.reshape(-1, 24)     # (cal_window_len_d + 1, 24)

    daily_min = prices.min(axis=1)
    daily_max = prices.max(axis=1)

    X = np.hstack([
        prices[6:-1],                  # lag 1
        prices[5:-2],                  # lag 2
        prices[:-7],                   # lag 7
        exog1[7:-1],                   # same-day exog1
        exog2[7:-1],                   # same-day exog2
        daily_min[6:-1, np.newaxis],   # lag-1 min
        daily_max[6:-1, np.newaxis],   # lag-1 max
        day_dummies[7:-1],             # same-day weekday dummies
    ])

    Xr = np.concatenate([
        prices[-1],
        prices[-2],
        prices[-7],
        exog1[-1],
        exog2[-1],
        [daily_min[-1], daily_max[-1]],
        day_dummies[-1],
    ])[np.newaxis, :]

    Y = prices[7:, :]

    return X, Xr, Y, prices_median, prices_mad, prices_ltsc[-24:]

def run_model(X, Xr, Y, prices_median, prices_mad, dataset, cal_day_index, prices_ltsc, param_index, wavelet, *, cv=7):
    """Fit one LassoLarsCV model per hour and return 24 hourly forecasts.

    Parameters
    ----------
    X, Y : ndarray
        Calibration design matrix and targets. Column ``hour`` of ``Y`` is the
        target for that hour's model.
    Xr : ndarray, shape (1, n_features)
        Feature row of the day being forecast.
    prices_median, prices_mad : float
        Parameters passed to ``inv_var_stab_trans`` to return to the price scale.
    prices_ltsc : array_like, length >= 24
        Per-hour seasonal component added back to each prediction before the
        inverse transform.
    dataset, cal_day_index, param_index, wavelet
        Not used here. They are kept so existing callers keep working.
    cv : int, default 7
        Number of cross-validation folds for ``LassoLarsCV``.

    Returns
    -------
    list of 24 floats
        Forecasts on the original price scale, ordered by hour.
    """
    forecasts = []
    for hour in range(24):
        model = LassoLarsCV(cv=cv, fit_intercept=False).fit(X, Y[:, hour])
        # Re-add the seasonal component in transformed space, then invert.
        transformed = model.predict(Xr)[0] + prices_ltsc[hour]
        forecasts.append(inv_var_stab_trans(transformed, prices_median, prices_mad))
    return forecasts

def make_forecasts(cal_day_index, df, dummies, dataset, param, param_index, wavelet):
    """Produce the 24 hourly forecasts for one rolling calibration window.

    ``cal_day_index`` is the 0-based index (relative to the start of ``df``)
    of the first calibration day. The window covers days ``cal_day_index``
    through ``cal_day_index + cal_window_len_d - 1``. The forecasts are for
    the next day, 0-based index ``cal_day_index + cal_window_len_d``.
    """
    design = get_cal_dataset(cal_day_index, df, dummies, param, wavelet)
    X, Xr, Y, median, scale, ltsc = design
    return run_model(
        X, Xr, Y, median, scale,
        dataset=dataset,
        cal_day_index=cal_day_index,
        prices_ltsc=ltsc,
        param_index=param_index,
        wavelet=wavelet,
    )

def var_stab_trans(x):
    """Apply a median/scale-standardised arcsinh variance-stabilising transform.

    Returns ``(arcsinh((x - median) / scale), median, scale)``. The scale is
    chosen as follows:

    1. statsmodels' ``robust.scale.mad`` of ``x``.
    2. If that is 0, the sample standard deviation (``ddof=1``).
    3. If that is also 0 or not finite (e.g. a constant or single-element
       series), 1.0. Without this fallback the division yields NaN/inf and
       ``inv_var_stab_trans`` could not recover the data.
    """
    x_median = np.median(x)
    x_mad = sm.robust.scale.mad(x)
    if x_mad == 0:
        x_mad = np.std(x, ddof = 1)
        if not np.isfinite(x_mad) or x_mad == 0:
            x_mad = 1.0
    x = np.arcsinh((x - x_median) / x_mad)
    return x, x_median, x_mad

def inv_var_stab_trans(x, x_median, x_mad):
    """Map arcsinh-space values back to the original scale.

    This inverts ``arcsinh((v - x_median) / x_mad)`` by computing
    ``x_median + x_mad * sinh(x)``. It works element-wise on arrays.
    """
    rescaled = x_mad * np.sinh(x)
    return x_median + rescaled

# Input CSV per market, read from DATA/. Keys also appear in output names.
file_names = dict(
    PJM='PJM.csv',
    NP='NP.csv',
)

# Calibration window length, in days.
cal_window_len_d = 364
# First day kept from the CSV. The original note says 1092 for the second part of the sample.
first_cal_day = 0
# End of the analysed span, in days, exclusive: the script keeps hourly rows
# first_cal_day*24 .. last_day*24 - 1. The original note says 2184 for the second part as well.
last_day = 2184

# Start offsets of every rolling window, one parallel task each.
days = range(last_day - first_cal_day - cal_window_len_d)
# LTSC settings to try: number of finest wavelet detail levels dropped.
params = np.arange(6, 15)
wavelets = ['db4']

if __name__ == '__main__':
    # Rolling-window experiment. For every dataset, wavelet and LTSC
    # parameter, forecast each day after the first calibration window in
    # parallel and save the flattened hourly forecasts to CSV.
    # One worker pool is reused for all runs instead of being rebuilt per parameter.
    with Pool(14) as pool:
        for dataset, file_name in file_names.items():
            df = pd.read_csv(f"DATA/{file_name}", header=0, names=['dates', 'prices', 'exog1', 'exog2'])
            # Keep only the analysed hourly span and re-index it from 0.
            df = df.iloc[first_cal_day*24:last_day*24, :].reset_index(drop=True)
            dummies = get_dummies(pd.to_datetime(df['dates'][0]).dayofweek)
            for wavelet in wavelets:
                for i, param in enumerate(params):
                    out_dir = f'LASSO2_{wavelet}_{i}_cv7_{dataset}'
                    # exist_ok avoids the check-then-create race of exists()+mkdir().
                    os.makedirs(out_dir, exist_ok=True)

                    forecast_window = partial(make_forecasts, df=df, dummies=dummies, dataset=dataset,
                                              param=param, param_index=i, wavelet=wavelet)
                    sol = pool.map(forecast_window, days)

                    # Flatten [[24 forecasts] per window] into one hourly sequence.
                    forecast_list = [value for day_forecasts in sol for value in day_forecasts]
                    np.savetxt(f'{out_dir}/{dataset}_forecasts.csv', forecast_list)
                
