#!/usr/bin/env python
import os
import numpy as np

from raytraverse.lightfield import ZonalLightResult, LightResult
from raytraverse.utility import pool_call
from raytraverse import api


def checkDST(ts, year=2022, *, switch_hour=2.0):
    """returns DST hour offset

    Parameters
    ----------
    ts : array_like
        (N, >=3) timestamps; columns 0, 1 and 2 are read as month, day and
        hour (hour presumably in local standard time -- the caller adds the
        returned offset to it).
    year : int, optional
        key into the DST transition table below (only 2022 is defined).
    switch_hour : float, optional
        standard-time hour at which the clocks change on both transition
        days. The default of 2 matches the EU/Swiss rule (01:00 UTC, i.e.
        02:00 CET). Timesteps on the start day at or after this hour, and on
        the end day before it, are counted as DST.

    Returns
    -------
    np.ndarray
        float array of length N: 1.0 where DST applies, 0.0 otherwise.

    Raises
    ------
    KeyError
        if ``year`` has no entry in the DST table.
    """
    # USA (note: US clocks fall back at 01:00 standard time, so a single
    # switch_hour is only approximate for the end day there)
    # dst_table = {
    #         2020: {'s': (3, 8), 'e': (11, 1)},
    #         2021: {'s': (3, 14), 'e': (11, 7)},
    #         2022: {'s': (3, 13), 'e': (11, 6)},
    #         2023: {'s': (3, 12), 'e': (11, 5)},
    #         2024: {'s': (3, 10), 'e': (11, 3)},
    #         2025: {'s': (3, 9), 'e': (11, 2)},
    #         2026: {'s': (3, 8), 'e': (11, 1)},
    #         2027: {'s': (3, 14), 'e': (11, 7)},
    #         2028: {'s': (3, 12), 'e': (11, 5)},
    #         2029: {'s': (3, 11), 'e': (11, 4)}
    #         }
    # Switzerland: (month, day) of DST start 's' and end 'e'
    dst_table = {
        2022: {'s': (3, 27), 'e': (10, 30)}
    }
    try:
        transitions = dst_table[year]
    except KeyError:
        raise KeyError(f"no DST transition dates defined for year {year}") from None
    start_month, start_day = transitions['s']
    end_month, end_day = transitions['e']

    ts = np.asarray(ts)
    month, day, hour = ts[:, 0], ts[:, 1], ts[:, 2]

    # at/after the spring change: later months, later days of the start
    # month, or the start day itself once the clocks have moved forward
    # (the original excluded the whole start day, a day-level off-by-one)
    on_start_day = (month == start_month) & (day == start_day)
    after_start = ((month > start_month)
                   | ((month == start_month) & (day > start_day))
                   | (on_start_day & (hour >= switch_hour)))
    # before the autumn change: earlier months, earlier days of the end
    # month, or the end day itself before the clocks move back
    on_end_day = (month == end_month) & (day == end_day)
    before_end = ((month < end_month)
                  | ((month == end_month) & (day < end_day))
                  | (on_end_day & (hour < switch_hour)))
    return np.logical_and(after_start, before_end).astype(float)


def resample_and_check(zlr, pm, skyd):
    """Regrid a zonal result onto the plan point grid and render check images.

    The rebased result is written to ``sunresults_grid.npz``. Low-resolution
    plan HDR images for the March 21 timesteps are then generated from
    ``plane.rad`` under the name ``testg`` so the data can be inspected.

    Returns
    -------
    the rebased (gridded) result returned by ``zlr.rebase``
    """
    # a common point set for every timestep turns the ZonalLightResult into
    # a LightResult
    gridded = zlr.rebase(pm.point_grid(False))
    gridded.write("sunresults_grid.npz")

    # sky indices 1896..1919, i.e. hourly steps 79*24 .. 80*24-1 -- the
    # March 21 day, assuming an hourly, non-leap-year sky index (confirm)
    march21 = skyd.masked_idx(np.arange(79 * 24, 80 * 24))
    gridded.pull2planhdr("plane.rad", "testg", sky=march21)
    return gridded


def compute_lr(lr, whours, grid="2'"):
    """Print ASE statistics for a gridded LightResult and return coverages.

    Parameters
    ----------
    lr : LightResult
        gridded result with an ``illum`` metric.
    whours : np.ndarray
        boolean mask over the first (timestep) axis of the pulled data;
        only these rows are counted.
    grid : str, optional
        label used in the printed report.

    Returns
    -------
    np.ndarray
        45th-99th percentile (1% steps) of the per-timestep fraction of
        points whose illuminance exceeds 1000.
    """
    illum_column = lr.axis("metric").index("illum")
    illum, _, _ = lr.pull("point", metric=[illum_column])

    sunlit = illum[whours] > 1000
    # number of working timesteps with direct sun, per point
    hours_per_point = sunlit.sum(axis=0)
    point_count = len(hours_per_point)

    ase250 = np.count_nonzero(hours_per_point > 250) / point_count
    qase250 = np.quantile(hours_per_point, 1 - ase250)
    qase10 = np.quantile(hours_per_point, .9)
    print(f"{grid} gridded data:")
    print(f"ASE_250: {ase250:.03f}")
    print(f"quantile hours @ASE_250: {qase250:.02f}")
    print(f"quantile hours @10%: {qase10:.02f}")

    # fraction of points in direct sun, per working timestep
    sun_fraction = sunlit.sum(axis=1) / point_count
    return np.quantile(sun_fraction, np.arange(.45, 1, .01))


def compute_zlr(zlr, whours):
    """Sun coverage for a non-gridded ZonalLightResult.

    For every timestep selected by ``whours``, coverage is the sum of column
    3 over entries whose column 4 exceeds 1000, divided by the total of
    column 3. Column 3 is presumably a per-point area/weight and column 4
    illuminance (mirrors the ``> 1000`` test in ``compute_lr``) -- confirm.

    Returns
    -------
    np.ndarray
        45th-99th percentile (1% steps) of the per-timestep coverage.
    """
    coverage = [
        np.sum(zone[..., 3] * (zone[..., 4] > 1000)) / np.sum(zone[..., 3])
        for zone, is_working in zip(zlr.data, whours)
        if is_working
    ]
    return np.quantile(coverage, np.arange(.45, 1, .01))


def main():
    """Compare sun-coverage statistics across zonal and gridded resolutions.

    Loads the ``office`` scene data and the zonal sun results
    (``sunresults_full.npz``). It then builds, or reloads, the 2' gridded
    result (``sunresults_grid.npz``) and masks sky timesteps to
    DST-adjusted working hours. Percentile coverages are computed for the
    zonal result and for 2', 1' and 6" grids.

    The columns saved to ``ase_interpolated.txt`` are:
    working hours at/above each percentile, zonal, 2', 1', 6".
    """
    # load with 2' grid for ASE calc
    _scn, pm, skyd = api.auto_reload("office", "plane.rad", ptres=0.6)
    print(skyd.daysteps)
    zlr = ZonalLightResult("sunresults_full.npz")

    # only compute LightResult and check images if file does not exist,
    # otherwise reload from file
    if os.path.isfile("sunresults_grid.npz"):
        lr = LightResult("sunresults_grid.npz")
    else:
        print(zlr.info())
        lr = resample_and_check(zlr, pm, skyd)
        print(lr.info())

    # pull to 2d to view density of sampling
    d, _, _ = zlr.pull("metric", metric=[zlr.axis("metric").index("illum")])
    print("size of zonal data:", d.shape)

    # working-hours mask in local clock time: add the DST offset to the
    # hour column (presumably standard time). A new array is built rather
    # than adding in place, so the array from value_array() is not modified.
    dates = lr.axis("sky").value_array()
    hours = dates[:, 2] + checkDST(dates)
    whours = np.logical_and(hours > 8, hours < 18)

    percentiles = np.arange(.45, 1, .01)
    # x-axis: working hours at/above each percentile. The percentiles are
    # taken over working-hour timesteps only, so scale by how many there are
    # (whours.size would count every sky timestep).
    results = [(1 - percentiles) * np.count_nonzero(whours)]

    # compute for zonal result
    results.append(compute_zlr(zlr, whours))
    # compute for 2' grid
    results.append(compute_lr(lr, whours, "2'"))
    d, _, _ = lr.pull("metric", metric=[lr.axis("metric").index("illum")])
    print("size of 2' data:", d.shape)
    print()

    # repeat with finer grids. The second point_grid argument is presumably
    # a subdivision level (1 -> 1', 2 -> 6") -- confirm against
    # point_grid's docs. Each grid is released before the next is built.
    for level, grid in ((1, "1'"), (2, '6"')):
        lr_fine = zlr.rebase(pm.point_grid(False, level))
        results.append(compute_lr(lr_fine, whours, grid))
        illum_idx = lr_fine.axis("metric").index("illum")
        d, _, _ = lr_fine.pull("metric", metric=[illum_idx])
        print(d.shape)
        print()

    # save results for plotting
    np.savetxt("ase_interpolated.txt", np.stack(results).T)
    
if __name__ == "__main__":
    # run the analysis only when executed as a script, not on import
    main()