# **************************************************************************
# * Authors:  Mohamad Harastani          (mohamad.harastani@upmc.fr)
# *           Rémi Vuillemot             (remi.vuillemot@upmc.fr)
# * IMPMC, UPMC Sorbonne University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# **************************************************************************
from os.path import basename
from pwem.emlib import MetaData, MDL_ORDER
from pyworkflow.protocol.params import StringParam, LabelParam
from pyworkflow.viewer import (ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO)
from pyworkflow.utils import replaceBaseExt, replaceExt
from continuousflex.protocols.data import Point, Data
from continuousflex.viewers.nma_plotter import FlexNmaPlotter
from continuousflex.protocols import FlexProtSynthesizeImages
import xmipp3
import pwem.emlib.metadata as md
from pwem.viewers import ObjectView
from continuousflex.protocols.protocol_image_synthesize import NMA_YES
import matplotlib.pyplot as plt
[docs]class FlexProtSynthesizeImageViewer(ProtocolViewer):
    """ Visualization of results from synthesized images
    """
    _label = 'viewer synthetic images'
    _targets = [FlexProtSynthesizeImages]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]
    def __init__(self, **kwargs):
        ProtocolViewer.__init__(self, **kwargs)
        self._data = None
[docs]    def getData(self):
        if self._data is None:
            self._data = self.loadData()
        return self._data 
           
    def _defineParams(self, form):
        form.addSection(label='Visualization')
        form.addParam('displayRawDeformation', StringParam, default='7 8',
                      condition=self.protocol.confVar.get()==NMA_YES,
                      label='Display the computed normal-mode amplitudes',
                      help='Type 7 to see the histogram of amplitudes along mode 7; \n'
                           'type 8 to see the histogram of amplitudes along mode 8, etc.\n'
                           'Type 7 8 to see the 2D plot of amplitudes along modes 7 and 8.\n'
                           'Type 7 8 9 to see the 3D plot of amplitudes along modes 7, 8 and 9; etc.'
                           )
        form.addParam('displayHists', LabelParam,
                      label="Display shift and angle histograms",
                      help="Display shift and angle histograms")
        form.addParam('displayVolumes', LabelParam,
                      label="Display images",
                      help="Display the volumes that are generated")
    def _getVisualizeDict(self):
        return {'displayRawDeformation': self._viewRawDeformation,
                'displayHists': self._viewHists,
                'displayVolumes': self._viewVolumes,
                } 
                        
    def _viewVolumes(self, paramName):
        volumes = self.protocol.outputImages
        return [ObjectView(self._project, volumes.strId(), volumes.getFileName())]
    def _viewHists(self, paramName):
        mdVolumes = md.MetaData(self.protocol._getExtraPath('GroundTruth.xmd'))
        X = []
        Y = []
        Rot = []
        Tilt= []
        Psi = []
        for objId in mdVolumes:
            X.append(mdVolumes.getValue(md.MDL_SHIFT_X,objId))
            Y.append(mdVolumes.getValue(md.MDL_SHIFT_Y, objId))
            Rot.append(mdVolumes.getValue(md.MDL_ANGLE_ROT, objId))
            Tilt.append(mdVolumes.getValue(md.MDL_ANGLE_TILT, objId))
            Psi.append(mdVolumes.getValue(md.MDL_ANGLE_PSI, objId))
        fig, ax = plt.subplots(2, 3)
        fig.suptitle('Histogram of generated rigid-body paramerers')
        ax[0,0].hist(X, bins=25)
        ax[0,0].set_title('Shift X')
        ax[0,1].hist(Y, bins=25)
        ax[0,1].set_title('Shift Y')
        ax[1,0].hist(Rot, bins=25)
        ax[1,0].set_title('Rot')
        ax[1,1].hist(Tilt, bins=25)
        ax[1,1].set_title('Tilt')
        ax[1,2].hist(Psi, bins=25)
        ax[1,2].set_title('Psi')
        plt.show()
    def _viewRawDeformation(self, paramName):
        components = self.displayRawDeformation.get()
        return self._doViewRawDeformation(components)
        
    def _doViewRawDeformation(self, components):
        components = list(map(int, components.split()))
        dim = len(components)
        views = []
        
        if dim > 0:
            modeList = []
            modeNameList = []
            missingList = []
            
            for modeNumber in components:
                found = False
                modes_md = MetaData(replaceExt(self.protocol.inputModes.get().getFileName(),'xmd'))
                for i, objId in enumerate(modes_md):
                    modeId = modes_md.getValue(MDL_ORDER, objId)
                    if modeNumber == modeId:
                        modeNameList.append('Mode %d' % modeNumber)
                        modeList.append(i)
                        found = True
                        break
                if not found:
                    missingList.append(str(modeNumber))
                    
            if missingList:
                return [self.errorMessage("Invalid mode(s) *%s*\n." % (', '.join(missingList)), 
                              title="Invalid input")]
            
            # Actually plot
            plotter = FlexNmaPlotter(data=self.getData())
            baseList = [basename(n) for n in modeNameList]
            
            self.getData().XIND = modeList[0]
            if dim == 1:
                plotter.plotArray1D("Histogram of normal-mode amplitudes: %s" % baseList[0], 
                                    "Amplitude", "Number of images")
            else:
                self.getData().YIND = modeList[1]
                if dim == 2:
                    plotter.plotArray2D_xy("Normal-mode amplitudes: %s vs %s" % tuple(baseList), *baseList)
                elif dim == 3:
                    self.getData().ZIND = modeList[2]
                    plotter.plotArray3D_xyz("Normal-mode amplitudes: %s %s %s" % tuple(baseList), *baseList)
            views.append(plotter)
            
        return views
    
[docs]    def loadData(self):
        mdVolumes = md.MetaData(self.protocol._getExtraPath('GroundTruth.xmd'))
        data = Data()
        for objId in mdVolumes:
            # pointData = list(map(float, particle._xmipp_nmaDisplacements))
            pointData = list(mdVolumes.getValue(md.MDL_NMA,objId))
            # inserting 6 zeros for the first 6 never used modes
            for j in range(6):
                pointData.insert(0, 0)
            # print(pointData)
            data.addPoint(Point(pointId=objId,
                                data=pointData,
                                weight=0.0))
        return data