import os
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits import mplot3d
import morton

def latlon_2_cartesian(lon, lat, R):
    """Convert longitude/latitude on a sphere of radius ``R`` to Cartesian x, y, z.

    ``lon`` and ``lat`` are fed directly to the trigonometric functions, so
    they must already be given in radians (no degree conversion is applied).
    Scalars and array-likes are both accepted; the outputs follow the usual
    NumPy broadcasting of the inputs.

    Returns
    -------
    tuple
        ``(x, y, z)`` with ``x = R cos(lat) cos(lon)``,
        ``y = R cos(lat) sin(lon)`` and ``z = R sin(lat)``.
    """
    cos_lat = np.cos(lat)
    x = cos_lat * np.cos(lon) * R
    y = cos_lat * np.sin(lon) * R
    z = np.sin(lat) * R
    return x, y, z

# Morton (z-curve) encoder for 3 coordinates with 32 bits each.
m = morton.Morton(dimensions=3, bits=32)

# Grid file whose vertices are reordered.  A previously used alternative was
#   /archive/meteo/external-models/dwd/grids/icon_grid_0016_R02B06_G.nc
gridfile = "/project/meteo/work/Fabian.Jakub/lib/icon-examples/08-offline-nest-in-03-with-circle-LAM_lowres/lam_R0.8_1244m/grid.lam_0.8deg_R1244m.nc"

grid = xr.open_dataset(gridfile)

# Stack the vertex positions as x, y, z along a new leading dimension "v".
# Prefer Cartesian coordinates stored in the file; otherwise derive them from
# longitude/latitude on a unit sphere.
if 'cartesian_x_vertices' in grid:
    cartesian_v = xr.concat(
        [
            grid.cartesian_x_vertices,
            grid.cartesian_y_vertices,
            grid.cartesian_z_vertices,
        ],
        dim="v",
    )
elif 'longitude_vertices' in grid:
    cartesian_v = xr.concat(
        latlon_2_cartesian(grid.longitude_vertices, grid.latitude_vertices, 1.),
        dim='v',
    )
else:
    raise KeyError("Dont know how to compute cartesian coordinates of vertices")

# Remove the lon/lat coordinates if they are attached.  errors="ignore" drops
# whichever of the two exist instead of keeping both when only one is missing.
cartesian_v = cartesian_v.drop_vars(["vlon", "vlat"], errors="ignore")

# Shift and scale the coordinates so that values in [-1, 1] (assumed: vertices
# lie on a unit sphere) land in [0, int32 max - 1], i.e. non-negative values
# that fit the 32 bits per dimension of the encoders.
limit = (np.iinfo(np.int32).max - 1) / 2

coords = cartesian_v * limit + limit

# One Morton code per vertex (coords.data.T iterates over (x, y, z) triples);
# sorting the codes yields the vertex permutation along the z-curve.
mc = [m.pack(*map(int, xyz)) for xyz in coords.data.T]
order = np.argsort(mc)

# Colormap that encodes each vertex's rank along the plotted ordering.
cm = "viridis"


def _plot_vertex_order(fig_num, points, title):
    """Draw a 3D scatter plot of vertices coloured by their position in ``points``.

    Parameters
    ----------
    fig_num : int
        Matplotlib figure number; the figure is created if needed and cleared.
    points : array-like
        Three rows (x, y, z) holding the vertex coordinates, already arranged
        in the ordering that should be visualised.
    title : str
        Text used as the figure's suptitle.

    Returns
    -------
    tuple
        The ``(figure, axes)`` pair that was drawn on.
    """
    fig = plt.figure(fig_num)
    fig.clf()
    ax = plt.axes(projection='3d')
    x, y, z = points
    # Colour value == index along the ordering, so neighbouring colours mean
    # neighbouring positions in the numbering.
    p = ax.scatter3D(x, y, z, c=range(len(x)), cmap=cm)
    fig.colorbar(p)
    plt.suptitle(title)
    return fig, ax


_plot_vertex_order(1, cartesian_v.data, 'original mesh numbering')
_plot_vertex_order(2, cartesian_v.data[:, order], 'morton z-curve mesh')

from hilbertcurve.hilbertcurve import HilbertCurve

# Hilbert curve built with arguments (32, 3), matching the 32 bits and
# 3 dimensions used for the Morton encoder.
hilbertcurve = HilbertCurve(32, 3)
# One Hilbert distance per vertex; sorting gives the order along the curve.
hc = [hilbertcurve.distance_from_point(list(map(int, xyz))) for xyz in coords.data.T]
hc_order = np.argsort(hc)

_plot_vertex_order(3, cartesian_v.data[:, hc_order], 'hilbert-curve mesh')

# Save each figure as "sfc.<figure number>.<grid file name>.pdf" in the current
# working directory.  Only the file extension is swapped for "pdf"; a plain
# str.replace('nc', 'pdf') would also rewrite any "nc" inside the name itself.
for i in (1, 2, 3):
    plt.figure(i)
    grid_stem = os.path.splitext(os.path.basename(gridfile))[0]
    fname = f"sfc.{i}.{grid_stem}.pdf"
    plt.savefig(fname, bbox_inches='tight')
