# **************************************************************************
# *
# * Authors:     Roberto Marabini (roberto@cnb.csic.es)
# *              J.M. De la Rosa Trevin (jmdelarosa@cnb.csic.es)
# *              Josue Gomez Blanco (josue.gomez-blanco@mcgill.ca)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from pyworkflow.object import Integer
from pyworkflow.utils import (getBoolListFromValues,
                              getStringListFromValues, getListFromValues)
from pwem.protocols import ProtRefine3D, ProtClassify3D
from pwem import emlib
from .projmatch_initialize import *
from .projmatch_form import _defineProjectionMatchingParams
from .projmatch_steps import *
from xmipp3.base import isXmippCudaPresent
class XmippProtProjMatch(ProtRefine3D, ProtClassify3D):
    """ Performs 3D reconstruction and classification using multireference projection matching. This method assigns images to references and reconstructs volumes, allowing for the separation of structural classes."""
    _label = 'projection matching'

    # Zero-padding width used when building numbered metadata block names
    # (see _getBlockFileName and _getRefBlockFileName).
    FILENAMENUMBERLENGTH = 6

    def __init__(self, **args):
        # Initialise both parent protocol classes explicitly.
        ProtRefine3D.__init__(self, **args)
        ProtClassify3D.__init__(self, **args)
        # A single CTF group by default.
        self.numberOfCtfGroups = Integer(1)
        # Last iteration whose deviations were computed; persisted by
        # _setLastIter and read back by getLastIter.
        self._lastIter = Integer(0)

    def _initialize(self):
        """ Prepare the protocol state that depends on the working dir.

        This is meant to be called after the working dir for the protocol
        has been set (maybe after recovery from the mapper).
        """
        self._loadInputInfo()
        # Setup the dictionary with filename templates to
        # be used by _getFileName
        createFilenameTemplates(self)
        # Load the values from several params generating a list
        # of values per iteration or references
        initializeLists(self)

    def _loadInputInfo(self):
        """ Read the input reference(s) and cache their locations.

        Sets ``referenceFileNames`` (one image location per reference
        volume), ``numberOfReferences`` and ``resolSam`` (the sampling rate
        of the input reference object).
        """
        from ...convert import getImageLocation

        # Input can be either a single volume or a set of volumes.
        reference = self.input3DReferences.get()

        if isinstance(reference, Volume):  # Treat the case of a single volume
            self.referenceFileNames = [getImageLocation(reference)]
        else:
            self.referenceFileNames = [getImageLocation(vol) for vol in reference]

        self.numberOfReferences = len(self.referenceFileNames)
        self.resolSam = reference.getSamplingRate()

    # --------------------------- DEFINE param functions ----------------------

    def _defineParams(self, form):
        """ Define the protocol form.

        Since the form definition is very large, it is done in a separate
        function (_defineProjectionMatchingParams).
        """
        _defineProjectionMatchingParams(self, form)

    # --------------------------- INSERT steps functions ----------------------

    def _insertAllSteps(self):
        """ Insert the initial, per-iteration and final steps. """
        self._initialize()
        # Insert initial steps
        self._insertFunctionStep('convertInputStep')
        self._insertFunctionStep('executeCtfGroupsStep')
        self._insertFunctionStep('initAngularReferenceFileStep')
        # Steps per iteration
        self._insertItersSteps()
        # Final steps
        self._insertFunctionStep('createOutputStep')

    def _insertItersSteps(self):
        """ Insert the steps needed for every iteration.

        For each iteration: create the iteration directories; for each
        reference, mask it, build its projection library and run projection
        matching; assign every image to its best reference; insert the
        angular class average step; reconstruct (optionally also two halves
        plus a resolution estimate) and filter every reference; finally
        compute the angle/shift deviations of the iteration.
        """
        for iterN in self.allIters():
            dirsStep = self._insertFunctionStep('createIterDirsStep', iterN)
            # Insert some steps per reference volume
            projMatchSteps = []
            for refN in self.allRefs():
                # Mask the references in the iteration
                insertMaskReferenceStep(self, iterN, refN, prerequisites=[dirsStep])
                # Create the library of projections
                insertAngularProjectLibraryStep(self, iterN, refN)
                # Projection matching steps
                projMatchSteps.append(self._insertProjectionMatchingStep(iterN, refN))

            # Select the reference that best fits each image
            self._insertFunctionStep('assignImagesToReferencesStep', iterN,
                                     prerequisites=projMatchSteps)

            # Inserted once per iteration. The previous code passed the loop
            # variable leaked from the loop above, i.e. the last reference
            # number; pass that value explicitly instead.
            # NOTE(review): confirm insertAngularClassAverageStep really
            # expects the last reference here (and not one call per ref).
            insertAngularClassAverageStep(self, iterN, self.numberOfReferences)

            # Reconstruct each reference with the new image assignments
            for refN in self.allRefs():
                insertReconstructionStep(self, iterN, refN)

                if self.doComputeResolution and self._doSplitReferenceImages[iterN]:
                    # Reconstruct two halves of the data
                    insertReconstructionStep(self, iterN, refN, 'Split1')
                    insertReconstructionStep(self, iterN, refN, 'Split2')
                    # Compute the resolution
                    insertComputeResolutionStep(self, iterN, refN)

                insertFilterVolumeStep(self, iterN, refN)
            # Calculate both angle and shift deviations for this iteration
            self._insertFunctionStep('calculateDeviationsStep', iterN)

    def _insertProjectionMatchingStep(self, iterN, refN):
        """ Insert the projection matching step of (iterN, refN) and return
        it so it can be used as a prerequisite. """
        args = getProjectionMatchingArgs(self, iterN)
        return self._insertFunctionStep('projectionMatchingStep', iterN, refN, args)

    # --------------------------- STEPS functions -----------------------------
    # NOTE(review): _insertAllSteps and _insertItersSteps schedule the steps
    # 'convertInputStep' and 'createIterDirsStep', but no methods with those
    # names are defined in this class as shown here; confirm they are
    # provided elsewhere (a base class or another part of this file).

    def volumeConvertStep(self, reconstructedFilteredVolume, maskedFileName):
        """ Step wrapper: delegate to runVolumeConvertStep. """
        runVolumeConvertStep(self, reconstructedFilteredVolume, maskedFileName)

    def executeCtfGroupsStep(self, **kwargs):
        """ Step wrapper: delegate to runExecuteCtfGroupsStep. """
        runExecuteCtfGroupsStep(self, **kwargs)

    def angularProjectLibraryStep(self, iterN, refN, args, stepParams, **kwargs):
        """ Step wrapper: delegate to runAngularProjectLibraryStep. """
        runAngularProjectLibraryStep(self, iterN, refN, args, stepParams, **kwargs)

    def initAngularReferenceFileStep(self):
        """ Step wrapper: delegate to runInitAngularReferenceFileStep. """
        runInitAngularReferenceFileStep(self)

    def projectionMatchingStep(self, iterN, refN, args):
        """ Step wrapper: delegate to runProjectionMatching. """
        runProjectionMatching(self, iterN, refN, args)

    def assignImagesToReferencesStep(self, iterN):
        """ Step wrapper: delegate to runAssignImagesToReferences. """
        runAssignImagesToReferences(self, iterN)

    def cleanVolumeStep(self, vol1, vol2):
        """ Step wrapper: remove the two given volume paths. """
        cleanPath(vol1, vol2)

    def reconstructionStep(self, iterN, refN, program, method, args, suffix, **kwargs):
        """ Step wrapper: delegate to runReconstructionStep. """
        runReconstructionStep(self, iterN, refN, program, method, args, suffix, **kwargs)

    def storeResolutionStep(self, resolIterMd, resolIterMaxMd, sampling):
        """ Step wrapper: delegate to runStoreResolutionStep. """
        runStoreResolutionStep(self, resolIterMd, resolIterMaxMd, sampling)

    def calculateFscStep(self, iterN, refN, args, constantToAdd, **kwargs):
        """ Step wrapper: delegate to runCalculateFscStep. """
        runCalculateFscStep(self, iterN, refN, args, constantToAdd, **kwargs)

    def filterVolumeStep(self, iterN, refN, constantToAddToFiltration, **kwargs):
        """ Step wrapper: delegate to runFilterVolumeStep. """
        runFilterVolumeStep(self, iterN, refN, constantToAddToFiltration, **kwargs)

    def createOutputStep(self):
        """ Step wrapper: delegate to runCreateOutputStep. """
        runCreateOutputStep(self)

    # --------------------------- INFO functions ------------------------------

    def _validate(self):
        """ Return a list of error messages; an empty list means valid. """
        errors = []
        if self.doCTFCorrection:
            defocusFn = self.setOfDefocus.get()
            # Treat an unset docfile as missing instead of passing None
            # to exists().
            if not self.doAutoCTFGroup and not (defocusFn and exists(defocusFn)):
                errors.append("Error: for non-automated ctf grouping, "
                              "please provide a docfile!")
            if not self.inputParticles.get().hasCTF():
                errors.append("Error: for doing CTF correction the input "
                              "particles should have CTF information.")
        if self.numberOfMpi <= 1:
            errors.append("The number of MPI processes has to be larger than 1")

        self._validateDim(self.inputParticles.get(),
                          self.input3DReferences.get(),
                          errors, 'Input particles', 'Reference volume')
        if self.useGpu and not isXmippCudaPresent():
            errors.append("You have asked to use GPU, but I cannot find Xmipp GPU programs in the path")
        return errors

    def _citations(self):
        """ No citations are reported. """
        cites = []
        return cites

    def _summary(self):
        """ No summary lines are reported. """
        summary = []
        return summary

    def _methods(self):
        return self._summary()  # summary is quite explicit and serves as methods

    # --------------------------- UTILS functions -----------------------------

    def allIters(self):
        """ Iterate over all iteration numbers (1..numberOfIterations). """
        yield from range(1, self.numberOfIterations.get() + 1)

    def allRefs(self):
        """ Iterate over all reference numbers (1..numberOfReferences). """
        yield from range(1, self.numberOfReferences + 1)

    def allCtfGroups(self):
        """ Iterate over all CTF group numbers (1..numberOfCtfGroups). """
        yield from range(1, self.numberOfCtfGroups.get() + 1)

    def itersFloatValues(self, attributeName, firstValue=-1):
        """ Build the per-iteration list of floats of an attribute.

        The attribute string is expanded to one float per iteration and
        ``firstValue`` is prepended as the entry for iteration 0.

        :raises Exception: if the attribute value is None.
        """
        valuesStr = self.getAttributeValue(attributeName)
        if valuesStr is None:
            raise Exception('None value for attribute: %s' % attributeName)
        return [firstValue] + getListFromValues(valuesStr, length=self.numberOfIterations.get(), caster=float)

    def itersBoolValues(self, attributeName, firstValue=False):
        """ Build the per-iteration list of booleans of an attribute.

        The attribute string is expanded to one boolean per iteration and
        ``firstValue`` is prepended as the entry for iteration 0.

        :raises Exception: if the attribute value is None.
        """
        valuesStr = self.getAttributeValue(attributeName)
        if valuesStr is None:
            raise Exception('None value for attribute: %s' % attributeName)
        return [firstValue] + getBoolListFromValues(valuesStr, length=self.numberOfIterations.get())

    def _getBlockFileName(self, blockName, blockNumber, filename, length=None):
        """ Return ``<blockName><zero-padded blockNumber>@<filename>``.

        ``length`` is the padding width; a falsy value (None or 0) falls back
        to FILENAMENUMBERLENGTH.
        """
        l = length or self.FILENAMENUMBERLENGTH
        return blockName + str(blockNumber).zfill(l) + '@' + filename

    def _getExpImagesFileName(self, filename):
        """ Return the block of all experimental images inside ``filename``. """
        return self.blockWithAllExpImages + '@' + filename

    def _getRefBlockFileName(self, ctfBlName, ctfBlNumber, refBlName, refBlNumber, filename, length=None):
        """ Return ``<ctfBlName><NNN>_<refBlName><NNN>@<filename>`` where both
        numbers are zero-padded to ``length`` (or FILENAMENUMBERLENGTH when
        ``length`` is falsy).
        """
        l = length or self.FILENAMENUMBERLENGTH
        return (ctfBlName + str(ctfBlNumber).zfill(l) + '_'
                + refBlName + str(refBlNumber).zfill(l) + '@' + filename)

    def _getFourierMaxFrequencyOfInterest(self, iterN, refN):
        """ Read the 'resolutionXmdMax' metadata of (iterN, refN) and return
        its MDL_RESOLUTION_FREQREAL value when it is positive; otherwise
        return twice the MDL_SAMPLINGRATE stored in the same metadata
        (presumably the Nyquist limit).
        """
        md = emlib.MetaData(self._getFileName('resolutionXmdMax', iter=iterN, ref=refN))
        resol = md.getValue(emlib.MDL_RESOLUTION_FREQREAL, md.firstObject())
        if resol > 0:
            return resol
        sampling = md.getValue(emlib.MDL_SAMPLINGRATE, md.firstObject())
        return 2 * sampling

    def calculateDeviationsStep(self, it):
        """ Compute angle and shift deviations between two iterations.

        Joins (by image) the angles docfile of iteration ``it`` with the one
        of iteration ``it - 1``, computes the angular distance and the
        per-image shift differences, and appends the result to the
        deviations metadata file in block ``iter_%03d``. Finally ``it`` is
        stored as the last processed iteration.

        :param it: iteration number; ``self.docFileInputAngles`` must hold
            entries for both ``it`` and ``it - 1``.
        """
        SL = emlib.SymList()
        mdIter = emlib.MetaData()
        SL.readSymmetryFile(self._symmetry[it])
        md1 = emlib.MetaData(self.docFileInputAngles[it])
        md2 = emlib.MetaData(self.docFileInputAngles[it - 1])
        # Ignore disabled images
        md1.removeDisabled()
        md2.removeDisabled()
        # The first metadata file may not have shiftx and shifty
        if not md2.containsLabel(emlib.MDL_SHIFT_X):
            md2.addLabel(emlib.MDL_SHIFT_X)
            md2.addLabel(emlib.MDL_SHIFT_Y)
            md2.fillConstant(emlib.MDL_SHIFT_X, 0.)
            md2.fillConstant(emlib.MDL_SHIFT_Y, 0.)
        # Rename the previous iteration's alignment columns to their "*2"
        # counterparts so both alignments coexist after the join.
        oldLabels = [emlib.MDL_ANGLE_ROT,
                     emlib.MDL_ANGLE_TILT,
                     emlib.MDL_ANGLE_PSI,
                     emlib.MDL_SHIFT_X,
                     emlib.MDL_SHIFT_Y]
        newLabels = [emlib.MDL_ANGLE_ROT2,
                     emlib.MDL_ANGLE_TILT2,
                     emlib.MDL_ANGLE_PSI2,
                     emlib.MDL_SHIFT_X2,
                     emlib.MDL_SHIFT_Y2]
        md2.renameColumn(oldLabels, newLabels)
        md2.addLabel(emlib.MDL_SHIFT_X_DIFF)
        md2.addLabel(emlib.MDL_SHIFT_Y_DIFF)
        md2.addLabel(emlib.MDL_SHIFT_DIFF)
        mdIter.join1(md1, md2, emlib.MDL_IMAGE, emlib.INNER_JOIN)
        # Angular distance between both orientations, presumably using the
        # symmetry read above. NOTE(review): the meaning of the three False
        # flags is defined by emlib.SymList.computeDistance.
        SL.computeDistance(mdIter, False, False, False)
        # Presumably required for sqrt() in the sqlite expressions below.
        emlib.activateMathExtensions()
        # Operate in sqlite
        shiftXLabel = emlib.label2Str(emlib.MDL_SHIFT_X)
        shiftX2Label = emlib.label2Str(emlib.MDL_SHIFT_X2)
        shiftXDiff = emlib.label2Str(emlib.MDL_SHIFT_X_DIFF)
        shiftYLabel = emlib.label2Str(emlib.MDL_SHIFT_Y)
        shiftY2Label = emlib.label2Str(emlib.MDL_SHIFT_Y2)
        shiftYDiff = emlib.label2Str(emlib.MDL_SHIFT_Y_DIFF)
        shiftDiff = emlib.label2Str(emlib.MDL_SHIFT_DIFF)
        # Per-axis shift differences (current minus previous iteration)
        operateString = (shiftXDiff + "=" + shiftXLabel + "-" + shiftX2Label
                         + "," + shiftYDiff + "=" + shiftYLabel + "-" + shiftY2Label)
        mdIter.operate(operateString)
        # Euclidean norm of the shift difference
        operateString = (shiftDiff + "=sqrt("
                         + shiftXDiff + "*" + shiftXDiff + "+"
                         + shiftYDiff + "*" + shiftYDiff + ");")
        mdIter.operate(operateString)
        iterFile = self._mdDevitationsFn(it)
        mdIter.write(iterFile, emlib.MD_APPEND)
        self._setLastIter(it)

    def _mdDevitationsFn(self, it):
        """ Return ``iter_NNN@<protocol path>/deviations.xmd`` for ``it``. """
        mdFn = self._getPath('deviations.xmd')
        return "iter_%03d@" % it + mdFn

    def _setLastIter(self, iterN):
        """ Record ``iterN`` as the last processed iteration and store it. """
        self._lastIter.set(iterN)
        self._store(self._lastIter)

    def getLastIter(self):
        """ Return the last processed iteration (0 if none yet). """
        return self._lastIter.get()

    def _fillParticlesFromIter(self, partSet, iteration):
        """ Populate ``partSet`` with the input particles, updating each one
        with the projection alignment read from the ``all_exp_images``
        block of the iteration's angles docfile (reference 1).
        """
        import pwem.emlib.metadata as md

        imgSet = self.inputParticles.get()
        imgFn = "all_exp_images@" + self._getFileName('docfileInputAnglesIters',
                                                       iter=iteration, ref=1)
        partSet.copyInfo(imgSet)
        partSet.setAlignmentProj()
        # Rows are sorted by item id, presumably so they line up with the
        # input items iterated by copyItems.
        partSet.copyItems(imgSet,
                          updateItemCallback=self._createItemMatrix,
                          itemDataIterator=md.iterRows(imgFn, sortByLabel=md.MDL_ITEM_ID))

    def _createItemMatrix(self, item, row):
        """ copyItems callback: set the projection alignment of ``item``
        from the metadata ``row``. """
        from ...convert import createItemMatrix
        from pwem.constants import ALIGN_PROJ

        createItemMatrix(item, row, align=ALIGN_PROJ)

    def _getIterParticles(self, it, clean=False):
        """ Return the SetOfParticles holding the alignment of iteration ``it``.

        The set lives in the iteration's 'particlesScipion' file. If that
        file does not exist (``clean=True`` deletes it first) it is created
        from the iteration's angles docfile; otherwise the existing file is
        opened and its info and alignment type are refreshed from the input
        particles.
        """
        import pwem.objects as em

        dataParticles = self._getFileName('particlesScipion', iter=it)

        if clean:
            cleanPath(dataParticles)

        if not exists(dataParticles):
            partSet = em.SetOfParticles(filename=dataParticles)
            self._fillParticlesFromIter(partSet, it)
            partSet.write()
            partSet.close()
        else:
            partSet = em.SetOfParticles(filename=dataParticles)
            imgSet = self.inputParticles.get()
            partSet.copyInfo(imgSet)
            partSet.setAlignmentProj()
        return partSet