#!/usr/bin/env python3

################################################################################
##                                                                            ##
##  This file is part of NCrystal (see https://mctools.github.io/ncrystal/)   ##
##                                                                            ##
##  Copyright 2015-2020 NCrystal developers                                   ##
##                                                                            ##
##  Licensed under the Apache License, Version 2.0 (the "License");           ##
##  you may not use this file except in compliance with the License.          ##
##  You may obtain a copy of the License at                                   ##
##                                                                            ##
##      http://www.apache.org/licenses/LICENSE-2.0                            ##
##                                                                            ##
##  Unless required by applicable law or agreed to in writing, software       ##
##  distributed under the License is distributed on an "AS IS" BASIS,         ##
##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  ##
##  See the License for the specific language governing permissions and       ##
##  limitations under the License.                                            ##
##                                                                            ##
################################################################################

"""

Script which can read vdos curves from either .ncmat files or simple
two-column ascii text files (the two columns being energy grid and density) and
help users prepare output suitable for inclusion in .ncmat files

"""

import sys
import os
import numpy as np
import argparse
import pathlib

#We handle the NCrystal import separately, since we know where it might be, even
#if the user did not set up PYTHONPATH correctly.
#NB: This NCrystal import code copied from inspectfile script. Be
#sure to update everywhere if performing changes!!!
#
#Flow: (1) try a plain import, (2) on failure look for an NCrystal package in
#the directory ../python relative to this script and retry, (3) if NCrystal is
#still unavailable, print an error and exit with status 1.
try:
    import NCrystal
except ImportError:
    NCrystal = None
    #Perhaps PYTHONPATH was not setup correctly. If we are running out of a
    #standard NCrystal installation, we can try to amend the search path
    #manually for the current script:
    op=os.path
    #Candidate directory: <directory of this script>/../python (symlinks resolved):
    test=op.abspath(op.realpath(op.join(op.dirname(__file__),'../python')))
    #Only extend the module search path if the candidate actually holds the package:
    if op.exists(op.join(test,'NCrystal/__init__.py')):
        sys.path += [test]
    try:
        import NCrystal
    except ImportError:
        NCrystal = None
if not NCrystal:
    #Both import attempts failed, so the script can not do anything useful:
    print("ERROR: Could not import the NCrystal python module.")
    print("       If it is installed, make sure your PYTHONPATH is setup correctly.")
    sys.exit(1)

#Short alias used throughout the rest of the script:
NC = NCrystal

def getVDOSFromNCMat(fn):
    """Extract a VDOS curve from an .ncmat file.

    The file is loaded with NC.createInfo and the first dynamic-info entry of
    type NC.Info.DI_VDOS is used (a warning is printed if there are several).

    Parameters:
      fn: name of the .ncmat file to load.

    Returns:
      Tuple (egrid,density) of arrays. The density is normalised so that its
      largest value is 1. If the egrid provided by NCrystal has just two
      entries, these are taken as end points and expanded into a linearly
      spaced grid with one point per density value.

    Raises:
      SystemExit: if the file does not contain any VDOS curve.
    """
    info = NC.createInfo(fn)
    di_vdos = [di for di in info.dyninfos if isinstance(di,NC.Info.DI_VDOS)]
    if not di_vdos:
        raise SystemExit(f"ERROR: No vdos found in file {fn}")
    if len(di_vdos)>1:
        print(f"WARNING: Using first of multiple vdos curves from file {fn}")
    vdos = di_vdos[0]
    eg,ds = vdos.vdosOrigEgrid(),vdos.vdosOrigDensity()
    #Normalise into a new array rather than in-place, since the array returned
    #by vdosOrigDensity() might be owned (e.g. cached) by the NCrystal object:
    ds = ds / ds.max()
    if len(eg)==2:
        #Egrid given as just (first,last): expand to the full grid.
        eg = np.linspace(eg[0],eg[1],len(ds))
    return eg,ds

#Map from unit name to the factor which the script multiplies onto egrid values
#given in that unit. The 'eV' entry is 1.0, so the converted values are in eV.
#Frequency units are converted by additionally multiplying with
#NC.constant_planck (presumably the Planck constant in units of eV*s -- confirm
#against the NCrystal python module).
units_2_fact = {
    'eV' : 1.0,
    'meV' : 1e-3,
    'keV' : 1e3,
    'MeV' : 1e6,
    'THz' : NC.constant_planck*1e12,
    'GHz' : NC.constant_planck*1e9,
    'MHz' : NC.constant_planck*1e6,
    'kHz' : NC.constant_planck*1e3,
    'Hz' : NC.constant_planck
}

def parseArgs():
    """Parse, validate and post-process the command line arguments.

    All validation failures are reported via parser.error(..), which prints
    usage information and the message on stderr and then ends the process with
    exit code 2.

    Returns:
      The argparse namespace, with the following adjustments applied:
        * args.dpi: a value of -1 (the default) is replaced with the value of
          the NCRYSTAL_DPI environment variable if that is set and non-empty,
          and otherwise with 200.
        * args.unit: set to 'eV' when FILE ends with '.ncmat'.
        * args.debye: if provided, converted from a string such as "300K" or
          "25meV" into a float by multiplying the numerical part with
          NC.constant_boltzmann (for K) or the matching units_2_fact entry.
    """
    descr="""

Script which can read vdos curves from either .ncmat files or simple two-column
ascii text files (the two columns being energy grid and density) and help users
prepare output suitable for inclusion in .ncmat files. When the input is not an
.ncmat file, the user must specify the energy grid units with --unit.

Thus it is possible to plot the curve, compare against vdos curves from other
.ncmat files, apply low-E truncation (cutoff), regularise binning, and overlay
with an ideal Debye spectrum for a given Debye energy or Debye
temperature. Finally, when running without --plot, it will output the resulting
spectrum into format which is ready for inclusion in .ncmat files.

"""
    parser = argparse.ArgumentParser(description=descr)
    parser.add_argument("FILE",help="Either .ncmat file with VDOS or a text file with two colums of data: egrid and density")
    units_opts = ', '.join(sorted(units_2_fact.keys()))
    parser.add_argument("--unit",'-u',default='',type=str,help=f"Unit of egrid values. Possible options are {units_opts}.")
    parser.add_argument('--plot','-p', action='store_true',help='Plot extracted spectrum')
    parser.add_argument("--cutoff",'-c',nargs='+',default=None,type=float,
                        help="""Emin cutoff points in eV (more than one can be provided for simultaneous
                        inspection with --plot).""")
    parser.add_argument("--ref",'-r',nargs='+',default='',type=str,
                        help="""Optionally provide list of .ncmat files with
                        vdos data, to superimpose on plots.""")
    parser.add_argument("--forceregular",'-f',type=int,nargs='?',default=0,
                        help="""Optionally provide this argument to
                        reparameterise with that amount of linearly spaced
                        points in [emin,emax+epsilon], where epsilon is chosen
                        so the grid can be extended to 0 with a whole number of
                        bins. This format will be directly used by NCrystal
                        without on-the-fly reparameterisation upon loading.""")
    parser.add_argument("--debye",'-d',nargs='?',default='',type=str,
                        help="""Set to debye temperature (unit K) or egrid point
                        (units like meV, eV, THz, ...) in order to plot Debye
                        spectrum with that parameter on top.""")
    dpi_default=200
    parser.add_argument('--dpi', default=-1,type=int,
                        help="""Change plot resolution. Set to 0 to leave matplotlib defaults alone.
                        (default value is %i, or whatever the NCRYSTAL_DPI env var is set to)."""%dpi_default)

    args=parser.parse_args()

    #Resolve the DPI value. The value -1 is the "not specified" sentinel, any
    #other negative value is a user error:
    if args.dpi>3000:
        parser.error('Too high DPI value requested.')
    if args.dpi<-1:
        parser.error('Negative DPI value requested.')
    if args.dpi==-1:
        env_dpi=os.environ.get('NCRYSTAL_DPI',None)
        if env_dpi:
            try:
                env_dpi=int(env_dpi)
            except ValueError:
                env_dpi=-1
            if env_dpi<0:
                #Reported via parser.error, so the process ends with a non-zero
                #exit code like all other argument errors:
                parser.error('NCRYSTAL_DPI environment variable must be set to integer >=0')
            if env_dpi>3000:
                parser.error('Too high DPI value requested via NCRYSTAL_DPI environment variable.')
            args.dpi=env_dpi
        else:
            args.dpi=dpi_default

    #Options declared with nargs='?' end up as None when the flag is given
    #without a value. Catch this here, rather than silently ignoring the flag:
    if args.forceregular is None or args.forceregular<0:
        parser.error('Option --forceregular requires a non-negative integer argument')
    if args.debye is None:
        parser.error('Option --debye requires an argument (see --help)')

    if args.FILE.endswith('.ncmat'):
        if args.unit:
            parser.error('Do not supply --unit when input is .ncmat file')
        args.unit='eV'
    if args.unit and args.unit not in units_2_fact:
        parser.error(f'Unknown unit {args.unit}. Valid options are {units_opts}')
    if not args.unit and not args.plot:
        parser.error('Unless --plot is specified, must supply --unit of input data')
    if args.ref and not args.plot:
        parser.error('Option --ref requires --plot')
    if args.debye and not args.plot:
        parser.error('Option --debye requires --plot')
    if args.cutoff and len(args.cutoff)>1 and not args.plot:
        parser.error('Option --cutoff can only have one argument when not using --plot')
    if args.cutoff and len(args.cutoff)>1 and args.forceregular:
        parser.error('Option --cutoff can only have one argument when using --forceregular')
    if args.forceregular and not args.unit:
        parser.error('Do not use --forceregular without --unit')

    if args.debye:
        #Split the string into a numerical part and a conversion factor:
        if args.debye.endswith('K'):
            debye_value_str,debye_factor = args.debye[0:-1],NC.constant_boltzmann
        else:
            #find (longest, so "meV" does not trigger "eV") fitting unit:
            fitting_units = [ u for u in units_2_fact if args.debye.endswith(u) ]
            if not fitting_units:
                parser.error("Option --debye requires unit (see --help)")
            unit = max(fitting_units,key=len)
            debye_value_str,debye_factor = args.debye[0:-len(unit)],units_2_fact[unit]
        try:
            debye_value = float(debye_value_str)
        except ValueError:
            parser.error(f'Invalid numerical value "{debye_value_str}" provided to --debye')
        if not debye_value>0.0:
            parser.error('Value provided to --debye must be positive')
        args.debye = debye_factor * debye_value
    return args

args=parseArgs()
args_file_basename=os.path.basename(args.FILE)

#Load the curve as two arrays, egrid and density, with the density normalised
#so its largest value is 1:
if args.FILE.endswith('.ncmat'):
    egrid,density = (_.copy() for _ in getVDOSFromNCMat(args.FILE))
else:
    #np.atleast_1d is needed since genfromtxt returns a 0-dimensional array for
    #a file with just a single data row:
    _ = np.atleast_1d(np.genfromtxt(args.FILE,dtype=[('egrid','f8'),('density','f8')]))
    egrid=_['egrid'].copy()
    density=_['density'].copy()
    #Refuse to normalise with a vanishing, negative or NaN maximum, since that
    #would silently fill the density with NaN/inf or flip its sign:
    if len(density)==0 or not density.max()>0.0:
        raise SystemExit('ERROR: density values (second column) of input file'
                         +' must contain at least one positive value and no invalid entries.')
    density /= density.max()

assert len(egrid) == len(density)
if len(egrid)<2:
    raise SystemExit('ERROR: At least two grid points are needed to describe a VDOS curve.')
print (f"Loaded VDOS with {len(density)} grid points from {args_file_basename}")

def numpy_is_sorted(a):
    """Whether the values of array a are in ascending order (repeated values allowed)."""
    return np.all(a[:-1] <= a[1:])

def numpy_is_strongly_sorted(a):
    """Whether the values of array a are in strictly ascending order."""
    return np.all(a[:-1] < a[1:])

if not numpy_is_strongly_sorted(egrid):
    #Point out all offending neighbouring pairs before giving up:
    for i in range(len(egrid)-1):
        if not egrid[i] < egrid[i+1]:
            print("Problems detected in egrid points with values ",egrid[i],"and",egrid[i+1])
    raise SystemExit('ERROR: egrid values (first column) of input file are not in sorted'
                     +' (ascending) order, or there are identical elements.')

#Convert the egrid to eV (only possible when the unit is known):
if args.unit:
    egrid *= units_2_fact[args.unit]

#Map each requested cutoff value (in eV) to the first grid point which is not
#below it. Each entry in cutoffs is a tuple (index of grid point, value of grid point).
cutoffs=[]
if args.cutoff and not args.unit:
    #Cutoff values are in eV, so they can not be mapped to a grid with unknown unit:
    print("WARNING: Ignoring --cutoff since the unit of the egrid is unknown (set with --unit flag)")
if args.unit and args.cutoff:
    for c in args.cutoff:
        if c >= egrid[-1]:
            raise SystemExit(f'ERROR: Cutoff value {c} is higher than highest point in egrid')
        i=np.searchsorted(egrid,c)
        #Internal consistency checks of the searchsorted result:
        assert i==0 or egrid[i-1]<c
        assert egrid[i]>=c
        cutoffs.append( (i, egrid[i] ) )
        print(f" => Mapping cutoff value of {c} to grid point at {cutoffs[-1][1]}")

def applyCutoff(egrid,density,cutoffs):
    """Discard the part of the curve below the cutoff point (if any).

    Parameters:
      egrid, density: sequences supporting slicing.
      cutoffs: empty (or None) for no cutoff, or a sequence with exactly one
               tuple (index, value), where index is the position of the first
               point to keep.

    Returns:
      Tuple (egrid,density), which are the untouched inputs when there is no
      cutoff, and otherwise the slices starting at the cutoff index.
    """
    if not cutoffs:
        return egrid,density
    #Only a single cutoff can be applied:
    assert len(cutoffs)==1
    first_kept,_cutoff_value = cutoffs[0]
    return egrid[first_kept:], density[first_kept:]

#When producing output (i.e. not plotting) or when regularising, the grid must
#not start at too low a value once any cutoff has been applied:
if args.forceregular or (not args.plot):
    first_egrid_value = applyCutoff(egrid,density,cutoffs)[0][0]
    if first_egrid_value<=1e-5:
        raise SystemExit(f"""
        ERROR: The first value in the loaded egrid (after applying any cutoff) is {first_egrid_value} which is not above 1e-5eV.
        This is not allowed when not using --plot or when using --forceregular.
        Please use the --cutoff parameter to remove lowest part of input spectrum (perhaps
        after investigating the cutoff value with --plot).
        """)

def regularise(egrid,density,n):
    """Reparameterise a curve onto a linearly spaced grid which can be extended to 0.

    The new grid starts at egrid[0] and uses a bin width of egrid[0]/k, where
    k is the smallest whole number for which the new grid covers the old range
    with at least n points. Hence the grid could be continued downwards to
    exactly 0 with k additional bins. The density is transferred to the new
    grid by linear interpolation, with 0 used outside the old range.

    Parameters:
      egrid: array of grid values in ascending order, with egrid[0]>0.
      density: array of density values on that grid.
      n: requested number of points in the new grid (the resulting number
         might differ slightly).

    Returns:
      Tuple (new_egrid,new_density).

    Raises:
      SystemExit: if all density values are zero, or if the grid does not
                  span a positive range starting above 0.
    """
    #first step back from any zeroes at the upper end (the scan is bounded, so
    #an all-zero density gives an error message rather than an IndexError):
    i=1
    while i<=len(density) and density[-i]==0.0:
        i=i+1
    if i>len(density):
        raise SystemExit('ERROR: Can not regularise curve in which all density values are zero.')
    #Keep one of the trailing zeroes, so the curve still ends at 0:
    safepeel = i-2
    if safepeel>=1:
        print (f"Ignoring {safepeel} last points while regularising since last {safepeel+1} points are 0.")
        egrid,density = egrid[0:-(safepeel)],density[0:-(safepeel)]
    emin,emax=egrid[0],egrid[-1]
    if not 0.0<emin<emax:
        #Otherwise the bin width emin/k below would be meaningless, and the
        #search for k might never terminate.
        raise SystemExit(f'ERROR: Can not regularise grid with range [{emin},{emax}]'
                         +' (must have 0 < emin < emax).')
    print('old range',emin,emax)
    #NB: A scalar (a trailing comma here would turn it into a tuple):
    THZ=NC.constant_planck*1e12
    print('old range [THZ]',emin/THZ,emax/THZ)

    for k in range(1,1000000000):
        #k is number of bins below emin, an integral number by definition in a regularised grid.
        binwidth = emin/k
        nbins=int(np.floor((emax-emin)/binwidth))+1
        eps = (emin+nbins*binwidth)-emax
        #eps is non-negative mathematically, but rounding errors can make it
        #marginally negative when the old grid is already regular. Hence the
        #tolerance, which is the same as in the check of new_emax below.
        assert eps>=-binwidth*1.001e-3
        if nbins+1 >= n:
            break
    n=nbins+1
    binwidth = emin/k
    new_emax = emin + (n-1) * binwidth
    #If the second to last point already coincides with the old emax (within
    #tolerance), the last bin is superfluous and is removed:
    if abs( (new_emax-binwidth) - emax ) < 1e-3*binwidth:
        nbins -= 1
        n -= 1
        new_emax -= binwidth
    print (f" ==> Choosing regular grid with n={n} pts from emin={emin} to emax={new_emax} ({new_emax-emax} beyond old emax)")
    assert new_emax >= emax-binwidth*1.001e-3
    new_egrid = np.linspace(emin,new_emax,n)
    #Verify that emin is indeed a whole number of bin widths above 0:
    test=new_egrid[0] / ( (new_egrid[-1]-new_egrid[0])/(len(new_egrid)-1) )
    assert abs(round(test)-test)<1e-6,f'{test}'
    new_density = np.interp(new_egrid,egrid,density, left=0.0, right=0.0)
    print('last density values in new grid: ',new_density[-5:])
    return new_egrid,new_density

#Regularise the curve (after applying any cutoff) if requested:
if args.forceregular:
    regularised_egrid,regularised_density = regularise(*applyCutoff(egrid,density,cutoffs),args.forceregular)

#Plotting mode: show the curve(s) and end the script without writing any file.
if args.plot:
    import matplotlib as mpl
    #A DPI value of 0 means that matplotlib defaults are left alone (as stated
    #in the help text of the --dpi option):
    if args.dpi:
        mpl.rcParams['figure.dpi']=args.dpi
    #ability to quit plot windows with Q:
    if 'keymap.quit' in mpl.rcParams and not 'q' in mpl.rcParams['keymap.quit']:
        mpl.rcParams['keymap.quit'] = tuple(list(mpl.rcParams['keymap.quit'])+['q','Q'])
    import matplotlib.pyplot as plt


    def common():
        """Apply title and grid to the current plot and show it."""
        plt.title(os.path.basename(args.FILE))
        plt.grid(ls=':')
        plt.show()

    if not args.unit:
        #Unknown unit, so just show the raw curve:
        plt.xlabel('Unknown unit (set with --unit flag)')
        plt.plot(egrid,density)
        common()
    else:
        #The egrid is in eV at this point, but is shown in meV:
        vis_unit = ('meV',1000.0)
        plt.xlabel(vis_unit[0])
        plt.plot(egrid*vis_unit[1],density,'o-',label=args_file_basename)
        if args.forceregular:
            plt.plot(regularised_egrid*vis_unit[1],regularised_density,'x-',label='regularised')
        #For each cutoff, show a parabola from 0 which meets the curve at the cutoff point:
        for c_idx, c_val in cutoffs:
            d=density[c_idx]
            # f(x)=k*x^2, f(c_val)=d<=> k*c_val^2 = d <=> k = d/c_val^2
            x=np.linspace(0.0,c_val,3000)
            plt.plot(x*vis_unit[1],(d/c_val**2)*(x**2),label=f'with cutoff {c_val}')
        if args.debye:
            #Parabola up to args.debye (and 0 above), scaled to reach the
            #highest density value at args.debye:
            x=np.linspace(0.0,max(egrid.max(),args.debye),1000)
            y = np.where(  x<=args.debye, x**2 * ( density.max() / args.debye**2 ), 0.0 )
            plt.plot(x*vis_unit[1],y,
                     label=f'Debye spectrum (E_Debye={1000*args.debye:.5}meV, T_Debye={args.debye/NC.constant_boltzmann:.5}K)')
        #Superimpose curves from any reference .ncmat files:
        for r in args.ref:
            eg,ds = getVDOSFromNCMat(r)
            plt.plot(eg*vis_unit[1],ds,label=os.path.basename(r))
        plt.legend()
        common()

    sys.exit(0)

#Select the curve which goes into the output file: the regularised one if
#requested, and otherwise the loaded one with any cutoff applied.
if not args.forceregular:
    egrid, density = applyCutoff(egrid,density,cutoffs)
else:
    egrid, density = regularised_egrid,regularised_density

#Check if egrid is linspace, i.e. if all spacings are within 1% of the average
#one (a regularised grid is linearly spaced by construction):
binwidth = (egrid[-1]-egrid[0])/(len(egrid)-1)
is_linspace=True
if not args.forceregular:
    spacings = egrid[1:]-egrid[:-1]
    is_linspace = not any( abs(binwidth-bw)>0.01*binwidth for bw in spacings )
    if is_linspace:
        print('NB: Detected linearly spaced input egrid')

#Write the result into a skeleton .ncmat file, which the user must subsequently
#edit (density and element name are just placeholders):
outfn=pathlib.Path('converted_output.ncmat')
with outfn.open('wt') as fh:
    fh.write(f"""NCMAT v2
#Autogenerated file from {args_file_basename}.
@DENSITY
  1.0 g_per_cm3 #FIX{'ME'}!! Please replace with proper value, or remove and provide crystal structure!
@DYNINFO
  element  <UNKNOWN-PLEASE-EDIT>
  fraction 1
  type     vdos\n""")
    #A linearly spaced egrid is written as just its first and last value:
    egrid_line = ( f'  vdos_egrid {egrid[0]:.14} {egrid[-1]:.14}'
                   if is_linspace
                   else NC.formatVectorForNCMAT('vdos_egrid',egrid) )
    fh.write(egrid_line)
    fh.write('\n')
    density_line = NC.formatVectorForNCMAT('vdos_density',density/density.max())
    fh.write(density_line)
    fh.write('\n')
    print(f"Wrote {outfn}")