import os
import xarray as xr
import glob
import numpy as np
import matplotlib.colors
from matplotlib import pyplot as plt


# Lookup table for the per-particle output variables.
# Each key maps to a tuple (index, label):
#   index -- number used to build the dataset variable name
#            f'addVar{index:04d}'; None for 'm_tot', which is not read
#            directly but summed from m_f, m_w, m_i and m_r.
#   label -- human-readable text used for the colorbar label.
varentry = {
    # masses
    'm_tot': (None, 'total mass [kg]'),
    'm_w': (1, 'liq. mass [kg]'),
    'm_i': (2, 'ice mass [kg]'),
    'm_r': (3, 'rimed mass [kg]'),
    'm_f': (13, 'frozen mass [kg]'),
    # remaining particle properties
    'v_r': (4, 'volume rime [m3]'),
    'T': (14, 'particle Temperature [K]'),
    'd': (5, 'diameter [m]'),
    'A': (6, 'projected area [m2]'),
    'xi': (7, 'multiplicity'),
    'mm': (8, 'monomer multiplicity'),
    'vt': (10, 'vt'),
    'gblCellId': (11, 'gblCellId'),
    'jk': (12, 'jk'),
    'atmoT': (15, 'atmospheric temperature'),
    'dQdt': (16, 'dQdt'),
}


def grid_info(gridfile='Torus_Triangles_1024x4_150m.nc'):
    """Read the grid spacing and domain extent from a grid file.

    The length of the first edge (``edge_length`` at ``edge=0``) is used for
    both ``dx`` and ``dz``. ``vol`` is ``dx * dz * domain_height``.

    Parameters
    ----------
    gridfile : str
        Path of the grid file to open with xarray.

    Returns
    -------
    dict
        Keys ``dx``, ``dz``, ``vol`` and ``domain_length``.
    """
    print(f"Trying to load grid info from {gridfile=}")
    with xr.open_dataset(gridfile) as grid:
        first_edge = grid.edge_length.isel(edge=0)
        dx = float(first_edge)
        dz = dx
        vol = dx * dz * grid.domain_height
        domain_length = grid.domain_length
        # The f-string below prints the variable names, so they must stay as is.
        print(f"{dx=} {dz=} {domain_length=} {vol=}")
    info = {'dx': dx, 'dz': dz, 'vol': vol, 'domain_length': domain_length}
    return info

def ds_get_var(ds, varname, multiplicity=True):
    """Return one particle variable from ``ds``, optionally weighted by multiplicity.

    Variables are stored in the dataset as ``addVarNNNN``; the number comes
    from the first tuple element of ``varentry[varname]``.

    Parameters
    ----------
    ds : mapping-like dataset
        Object indexable by the ``addVarNNNN`` names (an xarray Dataset in
        this file).
    varname : str
        Key of ``varentry``. ``'m_tot'`` is special: it is the sum of
        ``m_f``, ``m_w``, ``m_i`` and ``m_r``.
    multiplicity : bool
        If true, multiply the result by the ``xi`` variable.

    Returns
    -------
    The selected (or summed) variable, multiplied by ``xi`` when requested.
    """
    def get_var(vname):
        # Translate the short name into the numbered dataset variable.
        return ds[f'addVar{varentry[vname][0]:04d}']

    if varname == 'm_tot':
        parts = [get_var(vname) for vname in ('m_f', 'm_w', 'm_i', 'm_r')]
        vardata = xr.concat(parts, dim='tracer').sum('tracer')
    else:
        vardata = get_var(varname)

    if multiplicity:
        # Multiply out-of-place on purpose: an in-place ``*=`` would operate
        # on the array taken from ``ds`` and could overwrite the dataset's
        # stored values, so a second call would apply ``xi`` twice.
        vardata = vardata * get_var('xi')
    return vardata


def plot_bubble(pfile, varname, dx=None, dz=None, vol=None, domain_length=None, xdim=None, ydim=None, gridfile='Torus_Triangles_1024x4_150m.nc', **kwargs):
    """Plot a 2-D density map of one particle variable from a particle file.

    Particles are binned by horizontal position and altitude with
    ``np.histogram2d``, weighted by ``ds_get_var(ds, varname)``; the binned
    sums are divided by ``vol`` and drawn with ``plt.pcolormesh`` on a
    logarithmic color scale (1e-5 .. 1e-2) into the current figure.

    Parameters
    ----------
    pfile : str
        Particle file to open with xarray. The plot title is derived from
        the file name (see note in the body).
    varname : str
        Key of ``varentry`` selecting the variable to plot.
    dx, dz, vol, domain_length : float, optional
        Grid parameters. Any that are ``None`` are taken from
        ``grid_info(gridfile)``.
    xdim, ydim : array-like, optional
        Bin edges. Default: ``xdim`` spans the particle positions padded by
        ``2*dx`` with step ``dx``; ``ydim`` is ``0 .. 13500`` with step ``dz``.
    gridfile : str
        Grid file used only when a grid parameter is missing.
    **kwargs
        Ignored. Accepted so that a dict returned by a previous call can be
        passed back in with ``**``.

    Returns
    -------
    dict
        The resolved ``dx``, ``dz``, ``vol``, ``domain_length``, ``xdim``,
        ``ydim``, ``gridfile`` and the raw histogram ``Hmass``.
    """
    if any(value is None for value in (dx, dz, vol, domain_length)):
        # Fill only the gaps so explicitly passed values keep precedence.
        ginfo = grid_info(gridfile)
        dx = ginfo['dx'] if dx is None else dx
        dz = ginfo['dz'] if dz is None else dz
        vol = ginfo['vol'] if vol is None else vol
        domain_length = ginfo['domain_length'] if domain_length is None else domain_length

    # Open in a context manager so the file handle is released per frame.
    with xr.open_dataset(pfile) as ds:
        # Scale longitude by domain_length / (2*pi) to get a horizontal
        # distance (presumably longitude is in radians -- TODO confirm).
        lon = ds.longitude / (2*np.pi) * domain_length

        vardata = ds_get_var(ds, varname)

        if xdim is None:
            xdim = np.arange(float(lon.min()) - 2*dx, float(lon.max()) + 2*dx, dx)
        if ydim is None:
            ydim = np.arange(0, 13500, dz)

        # All data needed from the file is consumed here, before it closes.
        Hmass, xedges, yedges = np.histogram2d(x=lon, y=ds.altitude, bins=(xdim, ydim), weights=vardata)

    H = Hmass/vol
    print(f"{pfile=} {H.max()=}")
    X, Y = np.meshgrid(xedges, yedges, indexing='ij')

    # Axes are converted from m to km for display.
    plt.pcolormesh(X*1e-3, Y*1e-3, H, cmap=plt.cm.Spectral_r, norm=matplotlib.colors.LogNorm(vmin=1e-5, vmax=1e-2))
    plt.colorbar(orientation="horizontal", shrink=0.6, label=f'{varentry[varname][1]}/[m3]')
    plt.xlabel('horiz. distance [km]')
    plt.ylabel('height [km]')
    plt.gca().set_aspect('equal')
    # NOTE(review): assumes the file name is a 9-character prefix, then a
    # number, then '.nc', and that the number is a time in seconds -- confirm.
    # The basename is used so a directory part in pfile does not shift the slice.
    time = float(os.path.basename(pfile)[9:-3])/60
    plt.title(f"{time=:.0f}min")

    return dict(dx=dx, dz=dz, vol=vol, domain_length=domain_length, xdim=xdim, ydim=ydim, gridfile=gridfile, Hmass=Hmass)




def _main(varname='m_tot'):
    """Render every ``particles*.*.nc`` file in the working directory to a JPEG.

    The first and last file (in sorted order) are plotted once to obtain
    their horizontal bin ranges; all frames are then drawn with the union of
    those two ranges so the images share one x-axis. Files whose
    ``<name>.jpg`` already exists are skipped.

    Parameters
    ----------
    varname : str
        Key of ``varentry`` selecting the variable to plot.
    """
    ginfo = grid_info()
    pfiles = sorted(glob.glob('particles*.*.nc'))
    if not pfiles:
        # Without this guard pfiles[0] below raises IndexError.
        print("No files matching 'particles*.*.nc' found, nothing to plot")
        return

    # The two probe calls are only needed for their returned bin ranges; draw
    # them into a scratch figure and close it so it does not stay open.
    probe_fig = plt.figure()
    try:
        r0 = plot_bubble(pfiles[0], varname, **ginfo)
        r1 = plot_bubble(pfiles[-1], varname, **ginfo)
    finally:
        plt.close(probe_fig)

    # Common x bins covering both probe frames (presumably the first and
    # last frame bound the extent of all frames -- TODO confirm).
    x_start = np.minimum(r0['xdim'].min(), r1['xdim'].min())
    x_stop = np.maximum(r0['xdim'].max(), r1['xdim'].max())
    r1['xdim'] = np.arange(x_start, x_stop, ginfo['dx'])

    plt.figure(figsize=(14, 6), dpi=200)

    for pf in pfiles:
        out = f'{pf}.jpg'
        if os.path.exists(out):
            print(f"Skipping {out=}")
            continue
        # Reuse the single figure; clear it before drawing the next frame.
        plt.clf()
        plot_bubble(pf, varname, **r1)
        plt.savefig(out, bbox_inches='tight')


# Script entry point: render all particle files with the default variable.
if __name__ == '__main__':
    _main()
