# import necessary packages
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy as cart
import cartopy.feature as cfeature
import matplotlib.tri as tri
from matplotlib import colors
import os
from glob import glob
from pandas import read_csv
import h5py
import yaml
import sys

# Optional pretty-printer for the config/runtime dictionaries.
# `sort_dicts` is only accepted by newer Python versions; if the constructor
# rejects it (TypeError) the script falls back to plain print().
try:
    import pprint
    pp = pprint.PrettyPrinter(indent=4,sort_dicts=False)
    pp_import = True
except (ImportError, TypeError):
    pp_import = False
    
# Colormap used for the source maps: CMasher's "cosmic" when the package is
# usable, otherwise matplotlib's built-in "Blues_r".
try:
    import cmasher as cmr
    cmap = plt.get_cmap('cmr.cosmic')   # looked up through matplotlib by name
    cmash_import = True
except Exception:
    # best-effort: any problem with the optional package (missing, broken,
    # colormap name unknown) must not stop the plotting script
    cmash_import = False
    cmap=plt.get_cmap('Blues_r')
    
    
# print all the info?
print_info = False


def _exit_with_usage():
    """Print the usage banner and stop the script with a non-zero exit status."""
    print()
    print("!!!!!"*15+"\n")
    print("Please provide the path to the inversion folder as argument \n")
    print("!!!!!"*15+"\n")

    sys.exit(1)


# the path to the project is the first command-line argument
try:
    project_path = sys.argv[1]
except IndexError:
    # script was called without any argument
    _exit_with_usage()

# check path
if not os.path.isdir(project_path):
    _exit_with_usage()
    

# project name = last component of the given folder
project_name = os.path.basename(os.path.normpath(project_path))
print(f"Project name: {project_name}")

# the inversion output is read from the project folder itself,
# all figures are written to a sub-folder of it
output_path = project_path
plot_path = os.path.join(project_path,'notebook_plots')
if not os.path.exists(plot_path):
    os.makedirs(plot_path)

_separator = "="*100
print(_separator+'\n')
print(f"All plots will be saved here: {plot_path} \n")
print(_separator)

# input files: config file, runtime file, stationlist, sourcegrid (+ voronoi version)
(inv_config_file,
 runtime_file,
 station_file,
 sourcegrid_file,
 sourcegrid_voronoi_file) = [
    os.path.join(output_path,_file_name)
    for _file_name in (
        "inversion_config.yml",
        "runtime.txt",
        "stationlist.csv",
        "sourcegrid.npy",
        "sourcegrid_voronoi.npy",
    )
]




# load the parameters that were used for the inversion
with open(inv_config_file) as config_stream:
    inv_config = yaml.safe_load(config_stream)

# optionally show them, pretty-printed when pprint is available
if print_info:
    if pp_import:
        pp.pprint(inv_config)
    else:
        print(inv_config)

# flag read from the grid configuration; the map plots switch on it
only_ocean = inv_config['svp_grid_config']['svp_only_ocean']



# get information from runtime file
# Every usable line has the form "<key>:<value>". Only the FIRST colon splits
# key and value, so values that contain colons themselves (e.g. clock times)
# are kept intact. Spaces in the key are replaced by underscores.
runtime_dict = {}

with open(runtime_file) as runtime_fh:
    for line in runtime_fh:
        key,sep,value = line.partition(':')
        if not sep:
            # blank line or line without a key/value separator
            continue
        runtime_dict[key.replace(' ','_')] = value.strip()

if pp_import and print_info:    
    pp.pprint(runtime_dict)
elif print_info:
    print(runtime_dict)
    
    
    
print("Plotting stationlist and sourcegrid..")

# station table and source grid
stationlist = read_csv(station_file,keep_default_na=False)
stat_lat,stat_lon = stationlist['lat'],stationlist['lon']
sourcegrid = np.load(sourcegrid_file)

# global map in Robinson projection
plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=0))
ax.set_global()

if not only_ocean:
    # black coastline drawn above a wider white one
    ax.coastlines(color='k',linewidth=1,zorder=2)
    ax.coastlines(color='w',linewidth=2,zorder=1)
else:
    # filled country polygons
    _countries = cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land'])
    ax.add_feature(_countries,zorder=2)

# grid points (row 0 is used as x, row 1 as y) and stations on top of the map
plt.scatter(sourcegrid[0],sourcegrid[1],s=20,c='k',transform=ccrs.PlateCarree(),zorder=3)
plt.title(f'Sourcegrid for {project_name}',fontsize=50,pad=30)
plt.scatter(
    stat_lon,
    stat_lat,
    s=150,
    c='lawngreen',
    marker='^',
    edgecolor='k',
    linewidth=1,
    label='Stations',
    transform=ccrs.PlateCarree(),
    zorder=3,
)
plt.legend(fontsize=25,loc=3)

_sourcegrid_png = os.path.join(plot_path,'sourcegrid.png')
plt.savefig(_sourcegrid_png,bbox_inches='tight',facecolor='white')
plt.close()



print("Plotting voronoi cells..")

# The voronoi cells are used to scale the noise sources.
# For plotting, the grid (rows 0 and 1) is triangulated and row 2 is
# drawn as colour on the triangles.
sourcegrid_voronoi = np.load(sourcegrid_voronoi_file)
grid_tri = tri.Triangulation(sourcegrid_voronoi[0],sourcegrid_voronoi[1])

# global map in Robinson projection
plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=0))
ax.set_global()

if not only_ocean:
    # black coastline drawn above a wider white one
    ax.coastlines(color='k',linewidth=1,zorder=2)
    ax.coastlines(color='w',linewidth=2,zorder=1)
else:
    # filled country polygons
    _countries = cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land'])
    ax.add_feature(_countries,zorder=2)

plt.tripcolor(
    grid_tri,
    sourcegrid_voronoi[2],
    cmap=cmap,
    vmin=0,
    linewidth=0.0,
    edgecolor='none',
    zorder=1,
    transform=ccrs.Geodetic(),
)

# colourbar belonging to the tripcolor plot
cbar = plt.colorbar(pad=0.01)
cbar.ax.tick_params(labelsize=30)
cbar.set_label('Voronoi cell size',rotation=270,labelpad=60,fontsize=40)

plt.title(f'Voronoi cells for {project_name}',fontsize=50,pad=30)

# stations on top
plt.scatter(
    stat_lon,
    stat_lat,
    s=150,
    c='lawngreen',
    marker='^',
    edgecolor='k',
    linewidth=1,
    label='Stations',
    transform=ccrs.PlateCarree(),
    zorder=3,
)
plt.legend(fontsize=25,loc=3)

_voronoi_png = os.path.join(plot_path,'sourcegrid_voronoi.png')
plt.savefig(_voronoi_png,bbox_inches='tight',facecolor='white')
plt.close()



print("Plotting station ray coverage..")
# create dictionary with stations and their antipoles
# both map "NET.STA" -> [lat, lon]
station_dict = dict()
station_anti_dict = dict()

for _, sta in stationlist.iterrows():
    sta_id = f"{sta['net']}.{sta['sta']}"
    station_dict[sta_id] = [sta['lat'],sta['lon']]

    # antipole: latitude mirrored, longitude shifted by 180 degrees towards zero
    if sta['lon'] < 0:
        station_anti_dict[sta_id] = [-sta['lat'],sta['lon']+180]
    elif sta['lon'] >= 0:
        station_anti_dict[sta_id] = [-sta['lat'],sta['lon']-180]


# read in list of station pairs used
used_obs_corr = read_csv(os.path.join(output_path,'used_obs_corr_list.csv'),header=None)

stat_pair_list = []

for _, corr_row in used_obs_corr.iterrows():
    # first column holds "<id_1>--<id_2>"; only the first two dot-separated
    # fields (network and station) of each id are kept
    id_1 = corr_row[0].split('--')[0].split('.')
    id_2 = corr_row[0].split('--')[1].split('.')

    stat_pair_list.append(f"{id_1[0]}.{id_1[1]}--{id_2[0]}.{id_2[1]}")


stat_pair_dict = {}

plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson())
ax.add_feature(cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land']),zorder=1)
ax.set_global()

# Line transparency depends on the number of stations. It is computed once and
# clamped to 1: matplotlib rejects alpha values above 1, which the unclamped
# 4/n produced for fewer than four stations.
ray_alpha = min(1.0,4/max(len(station_dict),1))

for stat_pair in stat_pair_list:
    sta_a,sta_b = stat_pair.split('--')[0],stat_pair.split('--')[1]

    # remember the coordinates of every station that is part of a used pair
    stat_pair_dict[sta_a] = [station_dict[sta_a][0],station_dict[sta_a][1]]
    stat_pair_dict[sta_b] = [station_dict[sta_b][0],station_dict[sta_b][1]]

    # one line from each station to the antipole of the other one
    plt.plot([station_dict[sta_a][1],station_anti_dict[sta_b][1]],[station_dict[sta_a][0],station_anti_dict[sta_b][0]], color='k',zorder=2, transform=ccrs.Geodetic(),alpha=ray_alpha)
    plt.plot([station_anti_dict[sta_a][1],station_dict[sta_b][1]],[station_anti_dict[sta_a][0],station_dict[sta_b][0]], color='k',zorder=2, transform=ccrs.Geodetic(),alpha=ray_alpha)

# one ray is counted per station pair
n_rays = len(stat_pair_list)

# From here on stat_lat/stat_lon only hold the stations that were actually used.
# reshape keeps the (n, 2) layout even when no pair was found.
used_coords = np.asarray(list(stat_pair_dict.values())).reshape(-1,2)
stat_lat = used_coords.T[0]
stat_lon = used_coords.T[1]

plt.scatter(stat_lon,stat_lat,marker='^',s=250,c='k',edgecolors='w',linewidths=2,transform=ccrs.PlateCarree(),zorder=3)

plt.title(f"Ray coverage for {project_name} with {n_rays} rays",pad=30,fontsize=30)
plt.savefig(os.path.join(plot_path,'kernel_ray_coverage.png'),bbox_inches='tight',facecolor='white')

plt.close()



print("Plotting misfit..")
# plot misfit
# all entries of the output folder that start with "iteration" and are not files
steps_avail_path = [os.path.join(output_path,i) for i in os.listdir(output_path) if i.startswith('iteration') and not os.path.isfile(os.path.join(output_path,i))]

misfit_step = []
misfit_dict = dict()

for step_dir in steps_avail_path:

    # sorted so that the file picked below does not depend on the listing order
    measr_file_paths_var = sorted(os.path.join(step_dir,name) for name in os.listdir(step_dir) if name.endswith('measurement.csv'))

    if not measr_file_paths_var:
        print(f'No measurement for {os.path.basename(step_dir)}')
        continue

    # iteration number from the folder name "iteration_<nr>", taken with
    # os.path so it does not depend on the path separator of the platform
    step_nr_var = int(os.path.basename(os.path.normpath(step_dir)).split('_')[1])
    measr_step_var = read_csv(measr_file_paths_var[0],keep_default_na=False)

    # misfit of an iteration = mean of all l2 norms that are not NaN
    l2_norm_all = np.asarray(measr_step_var['l2_norm'])
    l2_norm = l2_norm_all[~np.isnan(l2_norm_all)]
    mf_step_var = np.mean(l2_norm)

    misfit_step.append([step_nr_var,mf_step_var])
    misfit_dict.update({step_nr_var:mf_step_var})


if not misfit_step:
    # nothing to plot; stop with a clear message instead of an unpacking error
    sys.exit(f"No measurement files found in the iteration folders of {output_path}")

# order by iteration number
step,misfit = zip(*sorted(misfit_step))

# misfit reduction in percent between first and last available iteration
mf_reduc = (1-(misfit[-1]/misfit[0]))*100

#### MISFIT PLOTS
plt.figure(figsize=(15,8))
plt.plot(step,misfit,'k',marker='o',markerfacecolor='b',markersize=10)
plt.grid()
plt.xlabel('Iterations',fontsize=15)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.title(f'Misfit for "{project_name}" reduced by {np.around(mf_reduc,2)}%',fontsize=25,pad=20)
plt.savefig(os.path.join(plot_path,'misfit_vs_iterations.png'),bbox_inches='tight',facecolor='white')
plt.close()




# pick an iteration
iteration_nr = 0

print(f"Plotting steplength test for iteration {iteration_nr}..")


def _by_iteration(path):
    """Sort key: iteration number of the folder containing *path*, then the path.

    Folders whose name does not follow "iteration_<nr>" are sorted last.
    """
    folder = os.path.basename(os.path.dirname(path))
    try:
        return (int(folder.split('_')[1]),path)
    except (IndexError, ValueError):
        return (float('inf'),path)


### plot steplength tests
# sorted numerically, a plain string sort would put iteration_10 before iteration_2
steplength_files = sorted([os.path.join(i,'misfit_step_test.npy') for i in steps_avail_path if os.path.isfile(os.path.join(i,'misfit_step_test.npy'))],key=_by_iteration)
steplength_fit_files = sorted([os.path.join(i,'misfit_step_test_fit.npy') for i in steps_avail_path if os.path.isfile(os.path.join(i,'misfit_step_test_fit.npy'))],key=_by_iteration)

# iteration number (as string) of the selected file, taken from its folder name
step = os.path.basename(os.path.dirname(steplength_files[iteration_nr])).split('_')[1]

# NOTE: allow_pickle=True unpickles objects while loading - only use this on
# files from a trusted inversion run.
step_test = np.load(steplength_files[iteration_nr],allow_pickle=True)
mf_step = np.asarray(step_test[0])      # tested (steplength, misfit) pairs
mf_final_step = step_test[1]            # chosen steplength and its predicted misfit
mf_step_fit = np.load(steplength_fit_files[iteration_nr],allow_pickle=True)

step_m = [i[0] for i in mf_step]
mf_m = [i[1] for i in mf_step]

# the misfit that was actually reached is the one of the following iteration;
# it is missing if that iteration has no measurement (yet)
actual_misfit = misfit_dict.get(int(step)+1)
actual_misfit_label = 'n/a' if actual_misfit is None else np.around(actual_misfit,2)

plt.figure(figsize=(15,8))
plt.scatter(step_m,mf_m,c='r')
plt.scatter(mf_final_step[0],mf_final_step[1],c='r',s=100,marker='x',label='Predicted misfit')
if actual_misfit is not None:
    plt.scatter(mf_final_step[0],actual_misfit,c='b',s=100,marker='x',label='Actual misfit')
plt.plot(mf_step_fit[0],mf_step_fit[1],c='b',label='Fitted exponential')
plt.grid()
plt.title(f'Steplength test for iteration {step} with final step {np.around(mf_final_step[0],2)}. Predicted misfit: {np.around(mf_final_step[1],2)}. Actual misfit: {actual_misfit_label}')
plt.legend()
plt.savefig(os.path.join(plot_path,f'iteration_{step}_0_slt.png'),bbox_inches='tight',facecolor='white')
plt.close()



import re


def _natural_key(path):
    """Sort key that compares numbers inside the file name numerically.

    A plain string sort would put "iteration_10" before "iteration_2".
    """
    # splitting on a captured digit group alternates text, number, text, ...
    parts = re.split(r'(\d+)',os.path.basename(path))
    return [int(part) if idx % 2 else part for idx,part in enumerate(parts)],path


# model files, ordered so that the list index follows the iteration number
inversionmodel_paths = sorted(glob(os.path.join(output_path,"models/iteration*.h5")),key=_natural_key)

print(f"Plotting frequency spectrum of noise sources..")

# Load the first file. The two datasets are read completely into memory
# ("[()]"), so the file can be closed again right away.
# The model and grid would be available as ['model'] and ['coordinates'].
with h5py.File(inversionmodel_paths[0],'r') as inversionmodel_0:
    # frequency spectrum used for the forward modelling of the sources
    inversionmodel_0_freq = inversionmodel_0['frequencies'][()]
    inversionmodel_0_spec = inversionmodel_0['spectral_basis'][()]

plt.figure(figsize=(20,10))
plt.plot(inversionmodel_0_freq,inversionmodel_0_spec[0],linewidth=2,c='k')
plt.xlabel("Frequency [Hz]",fontsize=20)
plt.ylabel("Normalised Power",fontsize=20)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.grid()
plt.title(f"Source frequency spectrum for {project_name}",fontsize=30,pad=20)
plt.savefig(os.path.join(plot_path,'iteration_0_frequency_spectrum.png'),bbox_inches='tight',facecolor='white')
plt.close()




# plot the sourcemodel
iteration_nr = 0
triangulation = True

print(f"Plotting noise distribution for iteration {iteration_nr}..")


# read grid and model into memory, then close the file again
with h5py.File(inversionmodel_paths[iteration_nr],'r') as source_distr_file:
    source_grid = np.asarray(source_distr_file['coordinates'])
    source_distr = np.asarray(source_distr_file['model']).T[0]

# scale to a maximum absolute value of 1
source_distr_norm = source_distr/np.max(np.abs(source_distr))

plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=0))
ax.set_global()

if only_ocean:
    ax.add_feature(cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land']),zorder=2)
else:
    ax.coastlines(color='k',linewidth=1,zorder=2)
    ax.coastlines(color='w',linewidth=2,zorder=1)

# `cmap` already is 'Blues_r' when cmasher could not be imported, so no
# separate branch on cmash_import is needed here.
if triangulation:
    # coloured triangles between the grid points
    triangles = tri.Triangulation(source_grid[0],source_grid[1])
    plt.tripcolor(triangles,source_distr_norm,cmap=cmap,linewidth=0.0,edgecolor='none',vmin=0,zorder=1,transform=ccrs.Geodetic())
else:
    # one coloured dot per grid point
    plt.scatter(source_grid[0],source_grid[1],s=20,c=source_distr_norm,vmin=0,transform=ccrs.PlateCarree(),cmap=cmap,zorder=3)


cbar = plt.colorbar(pad=0.01)
cbar.ax.tick_params(labelsize=30) 
cbar.set_label('Power Spectral Density',rotation=270,labelpad=40,fontsize=40)

plt.title(f'Noise distribution for {project_name}: Iteration {iteration_nr}',fontsize=50,pad=30)

# stations on top of everything else
plt.scatter(stat_lon,stat_lat,s=150,c='lawngreen',marker='^',edgecolor='k',linewidth=1,transform=ccrs.PlateCarree(),zorder=4)
plt.savefig(os.path.join(plot_path,f'iteration_{iteration_nr}_1_noise_distribution.png'),bbox_inches='tight',facecolor='white')
plt.close()




# Plot the gradient (smoothed and unsmoothed)
iteration_nr = 0
triangulation = True

print(f"Plotting gradient for iteration {iteration_nr}..")


sourcegrid = np.load(sourcegrid_file)
grid_tri = tri.Triangulation(sourcegrid_voronoi[0],sourcegrid_voronoi[1])

# files of the chosen iteration (the smoothed ones are used for the next plot)
grad_all_smooth = os.path.join(output_path,f'iteration_{iteration_nr}','grad_all_smooth.npy')
grad_all = os.path.join(output_path,f'iteration_{iteration_nr}','grad_all.npy')
smooth_param = os.path.join(output_path,f'iteration_{iteration_nr}',f'smoothing_iter_{iteration_nr}.npy')


# plot unsmoothed gradient
# NOTE: allow_pickle=True unpickles objects while loading - only use this on
# files from a trusted inversion run.
grad = np.load(grad_all,allow_pickle=True)[0]

# symmetric colour limits so that zero sits in the middle of the colormap
v = np.max(np.abs(grad))

plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=0))
ax.set_global()

if only_ocean:
    ax.add_feature(cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land']),zorder=2)
else:
    ax.coastlines()

if triangulation:
    plt.tripcolor(grid_tri,grad,cmap=plt.get_cmap('seismic'),linewidth=0.0,edgecolor='none',vmin=-v,vmax=v,zorder=1,transform=ccrs.PlateCarree())
else:
    plt.scatter(sourcegrid[0],sourcegrid[1],s=20,c=grad,transform=ccrs.PlateCarree(),cmap='seismic',vmin=-v,vmax=v,zorder=3)

cbar = plt.colorbar(pad=0.01)
cbar.formatter.set_powerlimits((0, 0))   # always use scientific notation on the colourbar
cbar.ax.tick_params(labelsize=30) 
cbar.set_label('Gradient',rotation=270,labelpad=40,fontsize=40)

# Title and file name are labelled with the iteration whose files were loaded
# above (iteration_nr), not with a value left over from an earlier plot.
plt.title(f'Gradient for {project_name}: Iteration {iteration_nr}',fontsize=50, pad=30)

plt.scatter(stat_lon,stat_lat,s=150,c='lawngreen',marker='^',edgecolor='k',linewidth=1,transform=ccrs.PlateCarree(),zorder=3)
plt.savefig(os.path.join(plot_path,f'iteration_{iteration_nr}_2_gradient.png'),bbox_inches='tight',facecolor='white')
plt.close() 
            
    
print(f"Plotting smoothed gradient for iteration {iteration_nr}..")


# plot smoothed gradient
# NOTE: allow_pickle=True unpickles objects while loading - only use this on
# files from a trusted inversion run.
grad = np.load(grad_all_smooth,allow_pickle=True)[0]

# symmetric colour limits so that zero sits in the middle of the colormap
v = np.max(np.abs(grad))

plt.figure(figsize=(50,20))
ax = plt.axes(projection=ccrs.Robinson(central_longitude=0))
ax.set_global()

if only_ocean:
    ax.add_feature(cfeature.NaturalEarthFeature('cultural', 'admin_0_countries', '50m', edgecolor='black', facecolor=cfeature.COLORS['land']),zorder=2)
else:
    ax.coastlines()

if triangulation:
    plt.tripcolor(grid_tri,grad,cmap=plt.get_cmap('seismic'),linewidth=0.0,edgecolor='none',vmin=-v,vmax=v,zorder=1,transform=ccrs.PlateCarree())
else:
    plt.scatter(sourcegrid[0],sourcegrid[1],s=20,c=grad,transform=ccrs.PlateCarree(),cmap='seismic',vmin=-v,vmax=v,zorder=3)

cbar = plt.colorbar(pad=0.01)
cbar.formatter.set_powerlimits((0, 0))   # always use scientific notation on the colourbar
cbar.ax.tick_params(labelsize=30) 
cbar.set_label('Smoothed gradient',rotation=270,labelpad=40,fontsize=40)

# Title and file name are labelled with the iteration whose files are plotted
# (iteration_nr), not with a value left over from an earlier plot.
# The smoothing value shown in the title is read from the smoothing file.
plt.title(f'Smoothed gradient for {project_name}: Iteration {iteration_nr} with {np.load(smooth_param)}° smoothing',fontsize=50, pad=30)

plt.scatter(stat_lon,stat_lat,s=150,c='lawngreen',marker='^',edgecolor='k',linewidth=1,transform=ccrs.PlateCarree(),zorder=3)
plt.savefig(os.path.join(plot_path,f'iteration_{iteration_nr}_3_gradient_smooth.png'),bbox_inches='tight',facecolor='white')
plt.close() 


print("Plotting finished.")