#!/usr/bin/env python

################################################################################
##                                                                            ##
##  This file is part of NCrystal (see https://mctools.github.io/ncrystal/)   ##
##                                                                            ##
##  Copyright 2015-2020 NCrystal developers                                   ##
##                                                                            ##
##  Licensed under the Apache License, Version 2.0 (the "License");           ##
##  you may not use this file except in compliance with the License.          ##
##  You may obtain a copy of the License at                                   ##
##                                                                            ##
##      http://www.apache.org/licenses/LICENSE-2.0                            ##
##                                                                            ##
##  Unless required by applicable law or agreed to in writing, software       ##
##  distributed under the License is distributed on an "AS IS" BASIS,         ##
##  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  ##
##  See the License for the specific language governing permissions and       ##
##  limitations under the License.                                            ##
##                                                                            ##
################################################################################

#########################################################################
########################### System setup ################################
#########################################################################

#enable py3 behaviour in py2.6+ (avoid unicode_literals on purpose!):
from __future__ import division, print_function, absolute_import
__metaclass__ = type  #classes are new-style without inheriting from "object"

import sys
import os

#TODO: Provide a --list file option

#Unit tests announce themselves by setting a special environment variable to
#any non-empty value. In that case no interactive images are displayed and the
#cpu consumption is reduced:
_unittest = bool(os.environ.get('NCRYSTAL_INSPECTFILE_UNITTESTS',0))

#Warn (except in unit tests) when running with a python older than 2.7.5 or,
#within the python3 series, older than 3.4.0:
pyversion = sys.version_info[0:3]
if not _unittest and ( pyversion < (2,7,5) or (3,0,0) <= pyversion < (3,4,0) ):
    print('WARNING: Unsupported python version %i.%i.%i detected (recommended is v2.7.5/v3.4.0 or latter).'%pyversion)

#Function for importing required python modules which may be missing, to provide
#a somewhat more helpful error to the user:
def import_optpymod(name):
    """Import and return the python module with the given name.

    If the module can not be imported, an error message is printed and the
    program is terminated via sys.exit(1).
    """
    try:
        import importlib
    except ImportError:
        #python 2.6 does not have importlib (not that we actually support 2.6...)
        importlib = None
    themod = None
    try:
        if importlib:
            themod = importlib.import_module(name)
        else:
            #Fall back to a plain import statement executed in a scratch namespace:
            scratch = {}
            exec('import %s as mod'%name,scratch)
            themod = scratch['mod']
    except ImportError:
        themod = None
    if themod:
        return themod
    print('ERROR: Could not import a required python module: %s'%name)
    print('       Please make sure it is installed on your system.')
    sys.exit(1)

#We handle the NCrystal import separately, since we know where it might be, even
#if the user did not set up PYTHONPATH correctly.
#NB: This NCrystal import code is also copied into other shipped scripts, so be
#sure to update it everywhere when making changes!!!
try:
    import NCrystal
except ImportError:
    NCrystal = None
    #Perhaps PYTHONPATH was not set up correctly. If we are running out of a
    #standard NCrystal installation, we can try to amend the search path
    #manually for the current script (looking for ../python relative to the
    #directory of this file):
    op=os.path
    test=op.abspath(op.realpath(op.join(op.dirname(__file__),'../python')))
    if op.exists(op.join(test,'NCrystal/__init__.py')):
        sys.path += [test]
    #Second attempt, possibly with the amended search path:
    try:
        import NCrystal
    except ImportError:
        NCrystal = None
#Give up with a helpful message if both attempts failed:
if not NCrystal:
    print("ERROR: Could not import the NCrystal Python module.")
    print("       If it is installed, make sure your PYTHONPATH is setup correctly.")
    sys.exit(1)

#########################################################################
#########################################################################
#########################################################################

def parse_cmdline():
    """Parse and validate the command line arguments of the script.

    Returns the argparse namespace with a few post-processed attributes:

      * dpi    : final plot resolution. The value -1 (meaning "not specified")
                 is replaced by the NCRYSTAL_DPI environment variable when that
                 is set, otherwise by the built-in default.
      * comp   : name of the selected component, one of 'coh_elas',
                 'incoh_elas', 'elastic', 'inelastic' or 'all'.
      * common : all --common items joined into one ';'-separated string.

    Invalid usage ends the program, either through parser.error(..) or (for an
    invalid NCRYSTAL_DPI value) through sys.exit(1).
    """
    argparse = import_optpymod('argparse')

    descr="""
Load input crystal files (usually .ncmat, .nxs, .laz or .lau) files with
NCrystal (v%s) and plot resulting isotropic cross sections for thermal
neutrons. Specifying more than one file results in a single comparison plot of
the total scattering cross section based on the different materials, whereas
specifying just a single file, results in a more detailed cross section plot as
well as a 2D plot of generated scatter angles."""%NCrystal.__version__

    descr=descr.strip()

    epilog="""
examples:
  $ %(prog)s Al_sg225.ncmat
  plot aluminium cross sections and scatter-angles versus neutron wavelength.
  $ %(prog)s Al_sg225.ncmat Ge_sg227.ncmat --common temp=200
  cross sections for aluminium and germanium at T=200K
  $ %(prog)s "Al_sg225.ncmat;dcutoff=0.1" "Al_sg225.ncmat;dcutoff=0.4" "Al_sg225.ncmat;dcutoff=0.8"
  effect of d-spacing cut-off on aluminium cross sections
  $ %(prog)s "Al_sg225.ncmat;temp=20" "Al_sg225.ncmat;temp=293.15" "Al_sg225.ncmat;temp=600"
  effect of temperature on aluminium cross sections"""

    parser = argparse.ArgumentParser(description=descr,
                                     epilog=epilog,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('input_files', metavar='FILE', type=str, nargs='*',
                        help="""Input crystallographic file in NXS, Laz or Lau format. It is possible to
                        specify additional parameters by appending them to each
                        file by using the usual format accepted by NCMatCfg (see
                        also examples below).""")
    parser.add_argument('-d','--dump', action='store_true',
                        help='Print derived crystallographic information rather than displaying plots')
    parser.add_argument('--common','-c', metavar='CFG', type=str, default=[],
                        help='Common configuration items that will be applied to all files',action='append')
    parser.add_argument('--coh_elas','--bragg', action='store_true',
                        help="""Only generate coherent-elastic (Bragg diffraction) component""")
    parser.add_argument('--incoh_elas', action='store_true',
                        help="""Only generate incoherent-elastic component""")
    parser.add_argument('--elastic', action='store_true',
                        help="""Only generate elastic component""")
    parser.add_argument('--inelastic', action='store_true',
                        help="""Only generate inelastic components""")
    parser.add_argument('-a','--absorption', action='store_true',
                        help="""Include absorption in cross section plots""")
    parser.add_argument('-p','--pdf', action='store_true',
                        help="""Generate PDF file rather than launching an interactive plot viewer.""")
    parser.add_argument('--test', action='store_true',
                        help="""Perform quick validation of NCrystal installation.""")
    dpi_default=200
    dpi_max=3000
    parser.add_argument('--dpi', default=-1,type=int,
                        help="""Change plot resolution. Set to 0 to leave matplotlib defaults alone.
                        (default value is %i, or whatever the NCRYSTAL_DPI env var is set to)."""%dpi_default)

    args=parser.parse_args()

    if args.dpi>dpi_max:
        parser.error('Too high DPI value requested.')
    if args.dpi<-1:
        #Only -1 is meaningful among the negative values (the "not specified"
        #marker), anything below would end up as a negative figure dpi:
        parser.error('Negative DPI value requested.')

    if args.dpi==-1:
        envdpi=os.environ.get('NCRYSTAL_DPI',None)
        if envdpi:
            try:
                envdpi=int(envdpi)
                if envdpi<0:
                    raise ValueError
            except ValueError:
                print("ERROR: NCRYSTAL_DPI environment variable must be set to integer >=0")
                #Exit with a non-zero status, so the failure is visible to the caller:
                sys.exit(1)
            if envdpi>dpi_max:
                parser.error('Too high DPI value requested via NCRYSTAL_DPI environment variable.')
            args.dpi=envdpi
        else:
            args.dpi=dpi_default

    if args.test:
        if any((args.input_files,args.dump,args.coh_elas,args.incoh_elas,args.elastic,args.inelastic,args.absorption,args.pdf)):
            parser.error('Do not specify other arguments with --test.')
    elif not args.input_files:
        parser.error('Missing input file arguments')

    #At most one of the component selection flags may be given:
    ncomp_select = sum(1 for flag in (args.coh_elas,args.incoh_elas,args.elastic,args.inelastic) if flag)
    if ncomp_select > 1:
        parser.error('Do not specify more than one of: --coh_elas/--bragg, --incoh_elas, --elastic or --inelastic.')

    if args.coh_elas: args.comp = 'coh_elas'
    elif args.incoh_elas: args.comp = 'incoh_elas'
    elif args.elastic: args.comp = 'elastic'
    elif args.inelastic: args.comp = 'inelastic'
    else: args.comp = 'all'

    if args.absorption and ncomp_select>0:
        parser.error('Do not specify --absorption with either of: --coh_elas/--bragg, --incoh_elas, --elastic or --inelastic.')

    if args.dump and (ncomp_select>0 or args.absorption):
        parser.error('Do not specify --dump with either of: --coh_elas/--bragg, --incoh_elas, --elastic, --inelastic or --absorption.')

    if args.dump and len(args.input_files)>1:
        parser.error('Do not specify more than one input file with --dump [-d].')

    args.common=';'.join(args.common)
    return args

def create_wavelengths(np,cfgs,npoints,logspace=False):
    """Create the array of wavelengths at which to evaluate the passed cfgs.

    Parameters:
      np       : the numpy module.
      cfgs     : objects providing get_scatter('coh_elas',allowfail=True).
      npoints  : number of wavelength points in the returned array.
      logspace : if True the points are distributed evenly on a logarithmic
                 scale rather than on a linear scale.

    The grid starts at 0.001 and its upper edge is derived from the largest
    wavelength obtained with NCrystal.ekin2wl(..) from the first domain() entry
    of the (non-null) coherent-elastic processes. A value of 10.0 is used for
    materials with no such threshold.
    """
    fallback = 10.0#materials with no bragg threshold
    pcbraggs = [_f for _f in (c.get_scatter('coh_elas',allowfail=True) for c in cfgs) if _f]
    bragg_thresholds = [(NCrystal.ekin2wl(c.domain()[0]) or fallback) for c in pcbraggs if not c._nullprocess]
    max_bragg_threshold = max(bragg_thresholds) if bragg_thresholds else fallback
    wls_min = 0.001
    #Go slightly beyond the threshold, round to a whole number and add 20%:
    wls_max = float(int(max_bragg_threshold*1.01+1.0))*1.2
    if logspace:
        #np.logspace interprets its limits as exponents (of base 10), so the
        #wavelength limits must be converted accordingly:
        return np.logspace(np.log10(wls_min),np.log10(wls_max),npoints)
    return np.linspace(wls_min,wls_max,npoints)

#Requested figure dpi (set by main before the first import_npplt call). A false
#value (None or 0) leaves the matplotlib default untouched:
_mpldpi=[None]
#Name of the file written in pdf mode:
_pdffilename='ncrystal.pdf'
#Cached (numpy module, pyplot module, PdfPages object or None) tuple:
_npplt = None

def _mpl_version_atleast(mpl,minversion):
    """Whether the version of the matplotlib module mpl is at least minversion.

    Only the leading purely numerical components of the version strings are
    compared (e.g. '3.5.0rc1' is treated as 3.5.0). This replaces the function
    matplotlib.compare_versions, which is absent from newer matplotlib releases.
    """
    import re
    def numeric_parts(versionstr):
        parts = []
        for component in versionstr.split('.'):
            m = re.match(r'[0-9]+',component)
            if not m:
                break
            parts.append(int(m.group(0)))
        return tuple(parts)
    return numeric_parts(mpl.__version__) >= numeric_parts(minversion)

def import_npplt(pdf=False):
    """Import numpy and matplotlib and configure the latter.

    Returns a tuple (np,plt,pdfpages) where np is the numpy module, plt is the
    matplotlib.pyplot module and pdfpages is a PdfPages object writing to the
    file named by _pdffilename when pdf is true, otherwise None. The result is
    cached, and subsequent calls must be made with the same value of pdf.

    The program is terminated with sys.exit(1) if matplotlib is too old or
    lacks the PDF backend.
    """
    global _npplt
    if _npplt:
        #pdf par must be same as last call:
        assert bool(pdf)==bool(_npplt[2] is not None)
        return _npplt
    np = import_optpymod('numpy')
    mpl = import_optpymod('matplotlib')
    if not _mpl_version_atleast(mpl,'0.99.1.1'):
        print("ERROR: Your version of matplotlib (%s) is too ancient to work - aborting plotting!"%mpl.__version__)
        sys.exit(1)
    if not _mpl_version_atleast(mpl,'1.3'):
        if not _unittest:
            print("WARNING: Your version of matplotlib (%s) is unsupported - expect trouble! (needs at least version 1.3)."%mpl.__version__)

    if _mpldpi[0]:
        mpl.rcParams['figure.dpi']=_mpldpi[0]

    #ability to quit plot windows with Q:
    if 'keymap.quit' in mpl.rcParams and not 'q' in mpl.rcParams['keymap.quit']:
        mpl.rcParams['keymap.quit'] = tuple(list(mpl.rcParams['keymap.quit'])+['q','Q'])

    #No interactive windows are needed in unit tests or when writing pdf files,
    #so the backend must be selected before pyplot is imported:
    if _unittest or pdf:
        mpl.use('agg')
    if pdf:
        try:
            from matplotlib.backends.backend_pdf import PdfPages
        except ImportError:
            print("ERROR: Your installation of matplotlib does not have required support for PDF output.")
            sys.exit(1)
    plt = import_optpymod('matplotlib.pyplot')

    _npplt = (np,plt,PdfPages(_pdffilename) if pdf else None)
    return _npplt

def getfields(cfgstr):
    """Return the set of parameter names which are assigned in a cfg-string.

    The string is split at ';' and each part containing '=' contributes the
    text in front of the first '='. Names are stripped of surrounding
    whitespace (so "coh_elas =1" is reported as "coh_elas"), and parts with an
    empty name are ignored. Parts without '=' (such as the file name) do not
    contribute.
    """
    fields = set()
    for entry in cfgstr.split(';'):
        if '=' not in entry:
            continue
        key = entry.split('=',1)[0].strip()
        if key:
            fields.add(key)
    return fields

def combine_cfgs(cfg1,cfg2,prefercfg1=True):
    """Merge the two ';'-separated cfg-strings cfg1 and cfg2 into one.

    It is assumed that cfg1 starts with "filename;", possibly followed by the
    ignorefilecfg keyword. When prefercfg1 is true, the items of cfg2 are put
    right after the file name (& ignorefilecfg keyword) and before the other
    items of cfg1, which gives preference to the cfg1 items. All items are then
    stripped and empty items are discarded. When prefercfg1 is false, the items
    of cfg2 are simply appended to those of cfg1, with no cleanup.
    """
    items1 = cfg1.split(';')
    items2 = cfg2.split(';')
    if not prefercfg1:
        return ';'.join(items1+items2)
    #Number of leading cfg1 items which must stay in front (file name and
    #possibly the ignorefilecfg keyword):
    nhead = 1
    if len(items1)>1 and items1[1].strip()=='ignorefilecfg':
        nhead = 2
    ordered = items1[:nhead] + items2 + items1[nhead:]
    stripped = [item.strip() for item in ordered]
    return ';'.join(item for item in stripped if item)

#functions for creating labels and title:

def _remove_common_keyvals(dicts):
    """remove any key from the passed dicts which exists with identical value in all
    the dicts. Returns a single dictionary with entries thus removed."""
    sets=[set((k,v) for k,v in list(d.items())) for d in dicts]
    common = dict(set.intersection(*sets))
    for k in list(common.keys()):
        for d in dicts:
            d.pop(k,None)
    return common

def _classify_differences(parts):
    """Split the settings of the passed parts into common and unique ones.

    For each part a dictionary is built from the items of the string
    "FILENAME=<cfgstr of the part>;COMPNAME=<component name of the part>" (the
    keyword ignorefilecfg is treated as ignorefilecfg=1). Returns the tuple
    (common,unique), where common holds the entries found with identical value
    for all parts, and unique is a list with the remaining entries of each part.
    """
    def part_to_dict(part):
        settings = {}
        fullcfg = 'FILENAME=%s;COMPNAME=%s' % (part._cfg.cfgstr,part._compname)
        for item in fullcfg.split(';'):
            if item.strip()=='ignorefilecfg':
                item = 'ignorefilecfg=1'
            if '=' in item:
                key,value = item.split('=',1)
                settings[key.strip()] = value.strip()
        return settings
    unique = [part_to_dict(p) for p in parts]
    common = _remove_common_keyvals(unique)
    return common,unique

def _cfgdict_to_str(cfgdict):
    fn = cfgdict.pop('FILENAME','')
    o = [fn] if fn else []
    cn = cfgdict.pop('COMPNAME','')
    if cfgdict:
        o += [', '.join(('%s=%s'%(k,v) if k!='ignorefilecfg' else k) for k,v in sorted(cfgdict.items()))]
    if cn:
        o += [ { 'coh_elas':'Coherent elastic',
                 'incoh_elas':'Incoherent elastic',
                 'elastic':'Elastic',
                 'inelastic':'Inelastic',
                 'absorption':'Absorption',
                 'all':'Total scattering',
                 'elastic+inelastic+absorption':'Total scattering+Absorption'}[cn] ]
    return ' '.join(b%a for a,b in zip(o,['%s','[%s]','(%s)']))

def create_title_and_labels(parts):
    """Returns (title,labels) for a plot of the passed parts. The title describes
    the settings shared by all parts, while the list of labels has one entry per
    part describing its particular settings ('default' if there are none)."""
    common,unique = _classify_differences(parts)
    title = _cfgdict_to_str(common)
    labels = []
    for settings in unique:
        labels.append(_cfgdict_to_str(settings) or 'default')
    return title,labels

def _end_plot(plt,pdf):
    """Finish the current figure of the pyplot module plt.

    In unit tests the figure is rendered at low resolution into os.devnull and
    closed. Otherwise the figure is saved through pdf.savefig() and closed when
    pdf is true, or else shown with plt.show().
    """
    if _unittest:
        #Use a with-statement to be sure the file handle is closed again:
        with open(os.devnull,'wb') as devnull:
            plt.savefig(devnull,format='raw',dpi=10)
        plt.close()
    elif pdf:
        pdf.savefig()
        plt.close()
    else:
        plt.show()

def plot_xsect(cfgs,comp,absorption,pdf=False):
    """Plot cross sections versus neutron wavelength for the passed Cfg objects.

    Parameters:
      cfgs       : list of Cfg objects.
      comp       : component to plot, one of 'coh_elas', 'incoh_elas',
                   'elastic', 'inelastic' or 'all'.
      absorption : whether to include absorption (requires comp=='all').
      pdf        : passed on to import_npplt.

    For a single cfg and comp being 'all' or 'elastic', the individual
    components are plotted together with a curve showing their sum. Otherwise
    one curve is plotted for each cfg.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','all')
    assert not absorption or comp=='all'
    np,plt,pdf = import_npplt(pdf)
    plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Cross section [barn]')

    if len(cfgs)==1 and comp in ('all','elastic'):
        if comp=='all':
            parts=[cfgs[0].get_scatter('coh_elas'),cfgs[0].get_scatter('incoh_elas'),cfgs[0].get_scatter('inelastic')]
        else:
            parts=[cfgs[0].get_scatter('coh_elas'),cfgs[0].get_scatter('incoh_elas')]
        if absorption:
            assert comp=='all'
            parts += [cfgs[0].get_absorption()]
    else:
        if absorption:
            assert comp=='all'
            parts=[c.get_totalxsect() for c in cfgs]
        else:
            parts=[c.get_scatter(comp) for c in cfgs]
    #Processes flagged as null processes are not plotted:
    parts = [p for p in parts if not p._nullprocess]
    title,labels = create_title_and_labels(parts)
    if len(set(labels))!=len(labels):
        print("WARNING: Comparing identical setups?")

    wavelengths = create_wavelengths(np,cfgs,3000)
    ekins = NCrystal.wl2ekin(wavelengths)
    need_tot = (len(cfgs)==1 and len(parts)>1)
    xsects_tot = None
    max_len_label = 0

    #colors inspired by http://www.mulinblog.com/a-color-palette-optimized-for-data-visualization/
    col_red = "#F15854"
    partcols = [
        "#5DA5DA", # (blue)
        "#FAA43A", # (orange)
        "#60BD68", # (green)
        #"#B2912F", # (brown)
        "#B276B2", # (purple)
        #"#DECF3F", # (yellow)
        #"#F17CB0", # (pink)
        "#4D4D4D", # (gray)
        ]
    if not need_tot:
        #red is reserved for the curve with the sum, when there is one.
        partcols = [col_red]+partcols

    linewidth = 2.0

    for ipart,part in enumerate(parts):
        xsects = part.crossSectionNonOriented(ekins)
        if need_tot:
            if xsects_tot is None:
                xsects_tot = np.zeros(len(xsects))
            xsects_tot += xsects
        label=labels[ipart]
        max_len_label = max(max_len_label,len(label))
        #Switch to another line style each time the colours have all been used:
        ls={0:'-',1:'--',2:':'}.get(ipart//len(partcols),'-.')
        plt.plot(wavelengths,xsects,label=label,color=partcols[ipart%len(partcols)],lw=linewidth,ls=ls)
    if need_tot:
        plt.plot(wavelengths,xsects_tot,
                 label={'all':'Total','elastic':'Total elastic'}[comp],
                 color=col_red,lw=linewidth)
    #Longer labels get a smaller font in the legend:
    leg_fsize = 'large'
    if max_len_label > 40: leg_fsize = 'medium'
    if max_len_label > 60: leg_fsize = 'small'
    if max_len_label > 80: leg_fsize = 'smaller'
    try:
        if len(parts)>1:
            leg=plt.legend(loc='best',fontsize=leg_fsize)
            if hasattr(leg,'set_draggable'):
                leg.set_draggable(True)
            else:
                leg.draggable(True)
    except TypeError:
        #Fall back to a plain legend (presumably needed when the matplotlib
        #version does not accept the fontsize argument - to be confirmed):
        plt.legend(loc='best')
    plt.grid()
    plt.gca().set_xlim(0.0,wavelengths[-1])
    plt.gca().set_ylim(0.0,None)
    if title:
        plt.title(title)
    try:
        plt.tight_layout()
    except (ValueError, AttributeError):
        plt.subplots_adjust(bottom=0.15, right=0.9, top=0.9, left = 0.15)
    _end_plot(plt,pdf)

def plot_2d_scatangle(cfg,comp,pdf=False):
    """Make a 2D scatter plot of generated scattering angles versus wavelength.

    Parameters:
      cfg  : Cfg object.
      comp : component to sample, one of 'coh_elas', 'incoh_elas', 'elastic',
             'inelastic' or 'all'.
      pdf  : passed on to import_npplt.

    The number of points sampled at each wavelength is Poisson distributed with
    a mean proportional to the cross section at that wavelength.
    """
    assert comp in ('coh_elas','incoh_elas','elastic','inelastic','all')
    np,plt,pdf = import_npplt(pdf)

    #reproducible plots. The number of points at each wavelength is drawn with
    #numpy, which is not affected by the seed of the random module, so a
    #dedicated numpy generator with a fixed seed is used for that (the seeding
    #of the random module is kept, in case the sampling of scatterings below
    #depends on it):
    import random
    random.seed(123456)
    rng = np.random.RandomState(123456)

    #higher granularity wavelengths than for 1D plot to avoid artifacts:
    wavelengths = create_wavelengths(np,[cfg],100 if _unittest else 30000 )

    part=cfg.get_scatter(comp)

    #get title (label should be uninteresting for a single part):
    title,_ = create_title_and_labels([part])

    #First figure out how many points to put at each wavelength
    xsects = part.crossSectionNonOriented(NCrystal.wl2ekin(wavelengths))

    n2d = 100 if _unittest else 25000
    sumxs = np.sum(xsects)
    if sumxs:
        n_at_wavelength = rng.poisson(xsects*n2d/sumxs)
    else:
        n_at_wavelength = np.zeros(len(xsects))

    n2d=int(np.sum(n_at_wavelength))#correction for random fluctuations
    plot_angles = np.zeros(n2d)
    plot_wls = np.zeros(n2d)

    j = 0
    for wl,n in zip(wavelengths,n_at_wavelength):
        n = int(n)
        if not n:
            continue#nothing to sample at this wavelength
        ang,_ = part.generateScatteringNonOriented(NCrystal.wl2ekin(wl),repeat=n)
        plot_angles[j:j+n] = ang
        plot_wls[j:j+n].fill(wl)
        j+=n

    try:
        plt.scatter(plot_wls,plot_angles*(180./np.pi),alpha=0.2,marker='.',edgecolor=None,color='black',s=2,zorder=1)
    except ValueError:
        #Try again without specifying the marker:
        plt.scatter(plot_wls,plot_angles*(180./np.pi),alpha=0.2,edgecolor=None,color='black',s=2,zorder=1)

    plt.gca().set_xlim(0.0,wavelengths[-1])
    plt.gca().set_ylim(0,180)
    plt.xlabel('Neutron wavelength [angstrom]')
    plt.ylabel('Scattering angle [degrees]')
    if title:
        plt.title(title)
    plt.grid()
    try:
        plt.tight_layout()
    except (ValueError, AttributeError):
        plt.subplots_adjust(bottom=0.12, right=0.95, top=0.9, left = 0.12)
    _end_plot(plt,pdf)

def dump_info(info):
    """Pass the object info on to NCrystal.dump."""
    #NOTE(review): main() instead uses the dump() method of the object returned
    #by Cfg.get_info(), and nothing in the visible part of this file calls
    #dump_info - confirm that NCrystal.dump exists.
    NCrystal.dump(info)


def create_scatter(cfg):
    """Create and return a scatter object from the cfg-string cfg, using
    NCrystal.createScatter. An error message is printed if the created object
    is oriented, but the object is returned in any case."""
    scatter = NCrystal.createScatter(cfg)
    oriented = scatter.isOriented()
    if oriented:
        print("Error: Can not create plots for oriented crystals")
    return scatter

class XSSum:
    """Sum of processes: provides a crossSectionNonOriented method which adds up
    the results of the same method of all the processes passed at construction."""

    def __init__(self,*processes):
        self._p = tuple(processes)

    def crossSectionNonOriented(self,ekin):
        #Like the sum(..) builtin, start from 0 and add the contributions
        #(without in-place additions) in the order of the processes:
        total = 0
        for process in self._p:
            total = total + process.crossSectionNonOriented(ekin)
        return total

class Cfg:
    """A cfg-string for one input file combined with the common cfg items.

    Provides access to the NCrystal objects created from the combined
    cfg-string. All objects are created on demand and cached. The created
    process objects are decorated with the attributes _cfg (the Cfg object),
    _compname (name of the component) and _nullprocess (whether the first
    domain() entry of the process exceeds 1e99).
    """

    def __init__(self,cfgstr, commoncfgstr):
        """Combine cfgstr ("filename;par1=val1;...") with the ';'-separated items of
        commoncfgstr. An AssertionError is raised if the resulting cfg-string
        sets parameters which are reserved for the component selection flags
        of the script."""
        #very crude normalisation, we should really let the C++ code NCMatCfg.cc do this splitting (todo):
        while '= ' in cfgstr:
            cfgstr=cfgstr.replace('= ','=')
        while ' ;' in cfgstr:
            cfgstr=cfgstr.replace(' ;',';')
        while cfgstr.endswith(' '):
            cfgstr=cfgstr[0:-1]
        self._cfgstr = combine_cfgs(cfgstr,commoncfgstr)
        f=getfields(self._cfgstr)
        forbidden=['bragg','coh_elas','incoh_elas','elas','bkgd']
        em="when using this script, use --coh_elas/--bragg/--incoh_elas/--elastic/--inelastic to enable/disable components rather than putting these variables in the cfg string: %s, inelas. Only exception is that inelas can be used to select alternative inelastic models."%(', '.join(forbidden))
        #The exceptions are raised explicitly rather than through assert
        #statements, since the latter are skipped when python runs with -O:
        if any(name in f for name in forbidden):
            raise AssertionError(em)
        #inelas is allowed, but not with a value disabling the component. The
        #value is compared as a whole, so longer values which merely start with
        #one of the disabling values are not rejected by mistake:
        for item in self._cfgstr.split(';'):
            if '=' not in item:
                continue
            key,value = item.split('=',1)
            if key.strip()=='inelas' and value.strip() in ('none','0','sterile','false'):
                raise AssertionError(em)
        self._sc = {}
        self._abs = None
        self._totxs = None
        self._info = None

    def get_scatter(self,comp = 'all', allowfail = False):
        """Scatter object for the component comp, which must be one of 'all',
        'coh_elas', 'incoh_elas', 'elastic' or 'inelastic'. If the creation
        fails with NCrystal.NCException, None is returned if allowfail is true,
        otherwise the exception is passed on."""
        assert comp in ('all','coh_elas','incoh_elas','elastic','inelastic')
        if comp not in self._sc:
            #Parameters selecting the component are appended to the cfg-string:
            cfgstr = combine_cfgs(self._cfgstr,{'coh_elas':'coh_elas=1;incoh_elas=0;inelas=0',
                                                'incoh_elas':'incoh_elas=1;coh_elas=0;inelas=0',
                                                'elastic':'elas=1;inelas=0',
                                                'inelastic':'elas=0'}.get(comp,''),False)
            try:
                sc = NCrystal.createScatter(cfgstr)
            except NCrystal.NCException:
                if allowfail:
                    return None
                else:
                    raise
            sc._nullprocess = bool(sc.domain()[0]>1e99)
            sc._cfg = self
            sc._compname = comp
            self._sc[comp] = sc
        return self._sc[comp]

    def get_absorption(self):
        """Absorption object created with NCrystal.createAbsorption."""
        if self._abs is None:
            a = NCrystal.createAbsorption(self._cfgstr)
            a._nullprocess = bool(a.domain()[0]>1e99)
            a._cfg = self
            a._compname = 'absorption'
            self._abs = a
        return self._abs

    def get_totalxsect(self):
        """XSSum object combining the absorption and the 'all' scatter objects."""
        if self._totxs is None:
            a,s = self.get_absorption(),self.get_scatter('all')
            t = XSSum(a,s)
            t._nullprocess = a._nullprocess and s._nullprocess
            t._cfg = self
            t._compname = 'elastic+inelastic+absorption'
            self._totxs = t
        return self._totxs

    def get_info(self):
        """Info object created with NCrystal.createInfo."""
        if self._info is None:
            self._info = NCrystal.createInfo(self._cfgstr)
        return self._info

    @property
    def cfgstr(self):
        """The combined cfg-string."""
        return self._cfgstr

def main():
    """Entry point of the script: parse the command line and then, depending on
    the arguments, either run NCrystal.test(), dump information about the input
    file, or produce the plots (interactively or in a PDF file)."""
    args = parse_cmdline()
    _mpldpi[0] = args.dpi

    if args.test:
        NCrystal.test()
        sys.exit(0)

    cfgs = []
    for inputfile in args.input_files:
        cfgs.append(Cfg(inputfile,args.common))

    if args.dump:
        assert len(cfgs)==1
        cfgs[0].get_info().dump()
        return

    plot_xsect( cfgs, comp  = args.comp, absorption = args.absorption, pdf=args.pdf )

    #The 2D plot is only made for a single input file, and it can be disabled
    #with an environment variable:
    if len(cfgs)==1:
        skip2d = bool(os.environ.get('NCRYSTAL_INSPECTFILE_NO2DSCATTER',0))
        if not skip2d:
            plot_2d_scatangle( cfgs[0], comp = args.comp, pdf=args.pdf )

    if not args.pdf:
        return

    #Add meta-data to the PDF file, if supported, and finish it:
    _,_,pdf = import_npplt(True)
    import datetime
    try:
        d = pdf.infodict()
    except AttributeError:
        d={}
    plural = '' if len(args.input_files)==1 else 's'
    basenames = ','.join(os.path.basename(f) for f in args.input_files)
    metadata = [ ('Title','Plots made with NCrystal-inspectfile from file%s %s'%(plural,basenames)),
                 ('Author','NCrystal %s (via inspectfile)'%NCrystal.__version__),
                 ('Subject','NCrystal plots'),
                 ('Keywords','NCrystal'),
                 ('CreationDate',datetime.datetime.today()),
                 ('ModDate',datetime.datetime.today()) ]
    for key,value in metadata:
        d[key] = value
    pdf.close()
    print("created %s"%_pdffilename)

if __name__=='__main__':
    main()

# TODO:
#   - allow tuning of n2d and alpha pars for 2D plot?
#   - show hkl values in 1d and 2d plot?
#   - option to show energy & mfp rather than wavelength & xsect?
#   - deltaE/wl_out plots as well (todo). Perhaps --wl=... option, resulting in plots for that wl.
