#!/usr/bin/env python3

import xarray as xr
import numpy as np
import glob
import sys
import os
import matplotlib.pyplot as plt

# Prefix of the PALM experiment directories: every path matching f"{BASE}*"
# is treated as one experiment (e.g. f"{BASE}_3_10").  $SCRATCH is preferred;
# without it we fall back to the original job folder instead of dying with a
# bare KeyError.
_SCRATCH = os.environ.get('SCRATCH')
if _SCRATCH is not None:
    BASE = f"{_SCRATCH}/ONKELTOMS_06"
else:
    BASE = '/home/f/Fabian.Jakub/kcs/work/palm/JOBS/ONKELTOMS_06'
# Sorted so experiments are processed (and printed) in a reproducible order.
EXPERIMENTS = sorted(glob.glob(f"{BASE}*"))


def experiment_name(path):
    """Return the experiment label encoded in an experiment path.

    The label is the text following ``BASE`` in *path*, up to the first ".",
    with the single separator character directly after ``BASE`` dropped,
    e.g. ``f"{BASE}_3_10"`` -> ``"3_10"``.
    """
    # Strip BASE only as a prefix (str.replace would remove it anywhere).
    suffix = path[len(BASE):] if path.startswith(BASE) else path
    return suffix.split(".")[0][1:]


# Number of valid model levels above the lowest finite theta value per column
# that are treated as "near surface".
NSRFCLAYERS = 5

# Per-experiment near-surface potential temperature series:
#   time            : output time of every kept step, as stored in the file
#   tsrfc           : theta of the NSRFCLAYERS lowest finite levels per
#                     column, stacked along a new 'z' dimension
#   tmean/tmin/tmax : domain mean/min/max of tsrfc at that step
result = {}
for EXP in EXPERIMENTS:
    E = experiment_name(EXP)
    files3d = sorted(glob.glob(f"{EXP}/OUTPUT/*3d*.nc"))
    print(f"{E} : {files3d}")
    for f3d in files3d:
        # Use a context manager so every output file is closed again; the
        # extracted slices are loaded into memory below, so nothing stored
        # in `result` still depends on the open file.
        with xr.open_dataset(f3d, decode_times=False) as D:
            # Index of the first finite theta value along zu_3d in each
            # column.  Presumably levels inside terrain/buildings are NaN,
            # so this is the lowest atmospheric level -- assumption, verify.
            bot_layer = np.isfinite(D.theta.isel(time=0)).argmax(dim='zu_3d')
            series = result.setdefault(E, dict(
                    time = [],
                    tmean = [],
                    tmin = [],
                    tmax = [],
                    tsrfc = [],
                    ))
            for it, t in enumerate(D.time):
                # Skip time entries >= 1e6 (presumably fill values for
                # unwritten steps -- TODO confirm against the PALM output).
                if not t < 1e6:
                    continue
                series['time'].append(float(t.data))

                # Pointwise indexing with the (y, x) bot_layer array selects
                # the k-th level above the lowest finite one in each column.
                tsrfc = xr.concat(
                    [D.theta.isel(time=it, zu_3d=bot_layer + k) for k in range(NSRFCLAYERS)],
                    dim='z',
                ).load()
                series['tsrfc'].append(tsrfc)
                series['tmean'].append(float(tsrfc.mean().data))
                series['tmin'].append(float(tsrfc.min().data))
                series['tmax'].append(float(tsrfc.max().data))

        print(f"{E} : {result[E]}")

colors = ('blue', 'red', 'orange', 'black', 'green')

# Figure 1: near-surface temperature time series per experiment.  The solid
# line is the domain mean; the shaded band spans the domain min..max.
plt.figure(num=1)
plt.clf()
for iE, E in enumerate(sorted(result.keys())):
    # Cycle the palette so more than len(colors) experiments do not raise
    # IndexError.
    color = colors[iE % len(colors)]
    t = np.array(result[E]['time'])/3600.  # seconds -> hours
    # Kelvin -> degrees Celsius.
    tmean, tmin, tmax = [np.array(result[E][key]) - 273.15 for key in ('tmean', 'tmin', 'tmax')]
    plt.plot(t, tmin, linestyle=':', color=color)
    plt.plot(t, tmax, linestyle=':', color=color)
    plt.fill_between(t, tmin, tmax, alpha=.3, color=color)
    plt.plot(t, tmean, label=E, color=color)

# NOTE(review): the height label assumes 2 m vertical grid spacing -- confirm.
plt.title(f'Temperature near surface [0 - {NSRFCLAYERS*2} m]')
plt.xlabel('time [hrs]')
plt.ylabel('Temperature [C]')
plt.legend(loc='best')
plt.savefig('srfc_temp.timeseries.pdf', bbox_inches='tight')

# Figure 2: one PNG per output step with side-by-side maps of the
# near-surface (layer-averaged) temperature of each experiment.
plt.figure(num=2, figsize=(14,4))
imshow_args = dict(origin='lower', cmap='Spectral_r')

# (key in `result`, panel title), left to right; experiments that were not
# found are skipped instead of raising KeyError.
panels = [(key, title)
          for key, title in (('3_10', '3_10'), ('rtm', 'rtm'), ('1d', '1D'))
          if key in result]
# Only step through indices present in every shown experiment -- they may
# have written a different number of outputs.
nsteps = min((len(result[key]['time']) for key, _ in panels), default=0)

for it in range(nsteps):
    # Time stamp for the title is taken from the leftmost experiment.
    t = result[panels[0][0]]['time'][it]
    plt.clf()
    plt.suptitle(f'Temperature near surface [0 - {NSRFCLAYERS*2} m] time={t/3600.:.2f} hrs')

    # Mean over the stacked near-surface layers, Kelvin -> degrees Celsius.
    temps = [result[key]['tsrfc'][it].mean('z') - 273.15 for key, _ in panels]
    # Shared colour range so the panels are directly comparable.
    imshow_args['vmin'] = float(np.nanmin([float(temp.min()) for temp in temps]))
    imshow_args['vmax'] = float(np.nanmax([float(temp.max()) for temp in temps]))

    for ipanel, ((_, title), temp) in enumerate(zip(panels, temps)):
        plt.subplot(1, len(panels), ipanel + 1)
        plt.imshow(temp, **imshow_args)
        plt.title(title)
        plt.colorbar()

    plt.savefig(f'srfc_temp.{it}.png')
