# Source code for module xmipp3.protocols.protocol_projmatch.protocol_projmatch

# **************************************************************************
# *
# * Authors:     Roberto Marabini (roberto@cnb.csic.es)
# *              J.M. De la Rosa Trevin (jmdelarosa@cnb.csic.es)
# *              Josue Gomez Blanco (josue.gomez-blanco@mcgill.ca)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

from pyworkflow.object import Integer
from pyworkflow.utils import (getFloatListFromValues, getBoolListFromValues,
                              getStringListFromValues)
from pwem.protocols import ProtRefine3D, ProtClassify3D
from pwem import emlib
from .projmatch_initialize import *
from .projmatch_form import _defineProjectionMatchingParams
from .projmatch_steps import *
from xmipp3.base import isXmippCudaPresent


class XmippProtProjMatch(ProtRefine3D, ProtClassify3D):
    """ 3D reconstruction and classification using multireference projection matching"""
    _label = 'projection matching'
    # Zero-padding width used when building numbered metadata block names
    # (see _getBlockFileName / _getRefBlockFileName).
    FILENAMENUMBERLENGTH = 6

    def __init__(self, **args):
        """ Initialize both protocol base classes and the counters that
        this protocol keeps as persistent attributes.
        """
        ProtRefine3D.__init__(self, **args)
        ProtClassify3D.__init__(self, **args)
        # Starts with a single CTF group. NOTE(review): presumably updated by
        # the CTF grouping step in projmatch_steps -- confirm there.
        self.numberOfCtfGroups = Integer(1)
        # Last iteration whose deviations were computed (see _setLastIter).
        self._lastIter = Integer(0)

    def _initialize(self):
        """ Prepare the per-run state used by the step functions.

        This is meant to be called after the protocol working directory
        has been set (e.g. after recovery from the mapper). It loads the
        input references, sets up the filename templates used by
        _getFileName, and builds the per-iteration/per-reference value lists.
        """
        self._loadInputInfo()
        # Setup the dictionary with filename templates to be used by _getFileName.
        createFilenameTemplates(self)
        # Load the values from several params, generating a list of values
        # per iteration or per reference.
        initializeLists(self)

    def _loadInputInfo(self):
        """ Read the input 3D reference(s) and cache their locations.

        Sets ``referenceFileNames`` (one image location per reference volume),
        ``numberOfReferences`` and ``resolSam`` (the references' sampling rate).
        """
        from ...convert import getImageLocation

        reference = self.input3DReferences.get()

        # Input can be either a single volume or a set of volumes.
        if isinstance(reference, Volume):
            self.referenceFileNames = [getImageLocation(reference)]
        else:
            self.referenceFileNames = [getImageLocation(vol) for vol in reference]

        self.numberOfReferences = len(self.referenceFileNames)
        self.resolSam = reference.getSamplingRate()

    # --------------------------- DEFINE param functions --------------------------------------------
    def _defineParams(self, form):
        """ Define the protocol form.

        The form definition is very large, so it lives in a separate
        function (_defineProjectionMatchingParams).
        """
        _defineProjectionMatchingParams(self, form)

    # --------------------------- INSERT steps functions --------------------------------------------
    def _insertAllSteps(self):
        """ Build the step list: input conversion, CTF grouping, angular
        reference file initialization, the per-iteration steps and the
        final output creation.
        """
        self._initialize()
        # Initial steps
        self._insertFunctionStep('convertInputStep')
        self._insertFunctionStep('executeCtfGroupsStep')
        self._insertFunctionStep('initAngularReferenceFileStep')
        # Steps per iteration
        self._insertItersSteps()
        # Final steps
        self._insertFunctionStep('createOutputStep')

    def _insertItersSteps(self):
        """ Insert the steps needed for every iteration.

        For each iteration: create the iteration directories; for each
        reference, mask it, build its projection library and run projection
        matching. Then assign images to references, compute class averages,
        and reconstruct (and optionally split-reconstruct, compute
        resolution for) and filter each reference. Finally, compute the
        angular/shift deviations of the iteration.
        """
        for iterN in self.allIters():
            dirsStep = self._insertFunctionStep('createIterDirsStep', iterN)

            # Insert some steps per reference volume.
            projMatchSteps = []
            for refN in self.allRefs():
                # Mask the references in the iteration.
                insertMaskReferenceStep(self, iterN, refN, prerequisites=[dirsStep])
                # Create the library of projections.
                insertAngularProjectLibraryStep(self, iterN, refN)
                # Projection matching steps.
                projMatchStep = self._insertProjectionMatchingStep(iterN, refN)
                projMatchSteps.append(projMatchStep)

            # Select the reference that best fits each image; it must wait
            # for every projection matching step of this iteration.
            self._insertFunctionStep('assignImagesToReferencesStep', iterN,
                                     prerequisites=projMatchSteps)

            # NOTE(review): refN here is the last reference left over from the
            # loop above (this call is made once per iteration). Presumably the
            # class-average step handles all references at once -- confirm in
            # projmatch_steps.insertAngularClassAverageStep.
            insertAngularClassAverageStep(self, iterN, refN)

            # Reconstruct each reference with the new averages.
            for refN in self.allRefs():
                # Create new class averages with the images assigned.
                insertReconstructionStep(self, iterN, refN)

                if self.doComputeResolution and self._doSplitReferenceImages[iterN]:
                    # Reconstruct two halves of the data...
                    insertReconstructionStep(self, iterN, refN, 'Split1')
                    insertReconstructionStep(self, iterN, refN, 'Split2')
                    # ...and compute the resolution from them.
                    insertComputeResolutionStep(self, iterN, refN)

                insertFilterVolumeStep(self, iterN, refN)

            # Calculate both angle and shift deviations for this iteration.
            self._insertFunctionStep('calculateDeviationsStep', iterN)

    def _insertProjectionMatchingStep(self, iterN, refN):
        """ Insert the projection matching step for (iterN, refN) and
        return the inserted step (used as a prerequisite later).
        """
        args = getProjectionMatchingArgs(self, iterN)
        return self._insertFunctionStep('projectionMatchingStep', iterN, refN, args)

    # --------------------------- STEPS functions --------------------------------------------
    def convertInputStep(self):
        """ Write the input particles metadata expected by projection
        matching to ``self.selFileName``, in the block named
        ``self.blockWithAllExpImages``.
        """
        from ...convert import writeSetOfParticles
        writeSetOfParticles(self.inputParticles.get(), self.selFileName,
                            blockName=self.blockWithAllExpImages)

    def createIterDirsStep(self, iterN):
        """ Create the directories needed by iteration ``iterN`` and
        return their paths.
        """
        iterDirs = [self._getFileName(k, iter=iterN)
                    for k in ['iterDir', 'projMatchDirs', 'libraryDirs']]
        for d in iterDirs:
            makePath(d)
        return iterDirs

    # The following steps are thin wrappers that delegate to the run*
    # functions of projmatch_steps, passing this protocol as first argument.
    def volumeConvertStep(self, reconstructedFilteredVolume, maskedFileName):
        runVolumeConvertStep(self, reconstructedFilteredVolume, maskedFileName)

    def executeCtfGroupsStep(self, **kwargs):
        runExecuteCtfGroupsStep(self, **kwargs)

    def transformMaskStep(self, program, args, **kwargs):
        runTransformMaskStep(self, program, args, **kwargs)

    def angularProjectLibraryStep(self, iterN, refN, args, stepParams, **kwargs):
        runAngularProjectLibraryStep(self, iterN, refN, args, stepParams, **kwargs)

    def initAngularReferenceFileStep(self):
        runInitAngularReferenceFileStep(self)

    def projectionMatchingStep(self, iterN, refN, args):
        runProjectionMatching(self, iterN, refN, args)

    def assignImagesToReferencesStep(self, iterN):
        runAssignImagesToReferences(self, iterN)

    def cleanVolumeStep(self, vol1, vol2):
        """ Remove the two given volume files. """
        cleanPath(vol1, vol2)

    def reconstructionStep(self, iterN, refN, program, method, args, suffix, **kwargs):
        runReconstructionStep(self, iterN, refN, program, method, args, suffix, **kwargs)

    def storeResolutionStep(self, resolIterMd, resolIterMaxMd, sampling):
        runStoreResolutionStep(self, resolIterMd, resolIterMaxMd, sampling)

    def calculateFscStep(self, iterN, refN, args, constantToAdd, **kwargs):
        runCalculateFscStep(self, iterN, refN, args, constantToAdd, **kwargs)

    def filterVolumeStep(self, iterN, refN, constantToAddToFiltration, **kwargs):
        runFilterVolumeStep(self, iterN, refN, constantToAddToFiltration, **kwargs)

    def createOutputStep(self):
        runCreateOutputStep(self)

    # --------------------------- INFO functions --------------------------------------------
    def _validate(self):
        """ Return a list of human-readable errors; empty when the
        protocol parameters are valid.
        """
        errors = []
        if self.doCTFCorrection:
            if not self.doAutoCTFGroup:
                defocusFn = self.setOfDefocus.get()
                # Check for an unset path before calling exists(): an
                # os.path.exists-style check raises TypeError on None
                # instead of returning False, which would crash validation.
                if not defocusFn or not exists(defocusFn):
                    errors.append("Error: for non-automated ctf grouping, "
                                  "please provide a docfile!")
            if not self.inputParticles.get().hasCTF():
                errors.append("Error: for doing CTF correction the input "
                              "particles should have CTF information.")

        if self.numberOfMpi <= 1:
            errors.append("The number of MPI processes has to be larger than 1")

        self._validateDim(self.inputParticles.get(), self.input3DReferences.get(),
                          errors, 'Input particles', 'Reference volume')

        if self.useGpu and not isXmippCudaPresent():
            errors.append("You have asked to use GPU, but I cannot find Xmipp GPU programs in the path")
        return errors

    def _citations(self):
        cites = []
        return cites

    def _summary(self):
        summary = []
        return summary

    def _methods(self):
        # The summary is explicit enough to also serve as the methods text.
        return self._summary()

    # --------------------------- UTILS functions --------------------------------------------
    def allIters(self):
        """ Iterate over all iterations (1-based, inclusive). """
        yield from range(1, self.numberOfIterations.get() + 1)

    def allRefs(self):
        """ Iterate over all references (1-based, inclusive). """
        yield from range(1, self.numberOfReferences + 1)

    def allCtfGroups(self):
        """ Iterate over all CTF groups (1-based, inclusive). """
        yield from range(1, self.numberOfCtfGroups.get() + 1)

    def _itersValues(self, attributeName, firstValue, parseValues):
        """ Shared implementation of the iters*Values helpers.

        Read the string value of ``attributeName``, parse it with
        ``parseValues`` into one value per iteration, and prepend
        ``firstValue`` as the special value for iteration 0.

        Raises:
            Exception: if the attribute value is None.
        """
        valuesStr = self.getAttributeValue(attributeName)
        if valuesStr is None:
            raise Exception('None value for attribute: %s' % attributeName)
        return [firstValue] + parseValues(valuesStr,
                                          length=self.numberOfIterations.get())

    def itersFloatValues(self, attributeName, firstValue=-1):
        """ Take the string of a given attribute and create a list of floats
        that will be used by the iterations. A special first value is added
        to the list for iteration 0.
        """
        return self._itersValues(attributeName, firstValue, getFloatListFromValues)

    def itersBoolValues(self, attributeName, firstValue=False):
        """ Take the string of a given attribute and create a list of booleans
        that will be used by the iterations. A special first value is added
        to the list for iteration 0.
        """
        return self._itersValues(attributeName, firstValue, getBoolListFromValues)

    def itersStringValues(self, attributeName, firstValue='c1'):
        """ Take the string of a given attribute and create a list of strings
        that will be used by the iterations. A special first value is added
        to the list for iteration 0.
        """
        return self._itersValues(attributeName, firstValue, getStringListFromValues)

    def _getBlockFileName(self, blockName, blockNumber, filename, length=None):
        """ Return '<blockName><zero-padded blockNumber>@<filename>'. """
        numLen = length or self.FILENAMENUMBERLENGTH
        return blockName + str(blockNumber).zfill(numLen) + '@' + filename

    def _getExpImagesFileName(self, filename):
        """ Return '<blockWithAllExpImages>@<filename>'. """
        return self.blockWithAllExpImages + '@' + filename

    def _getRefBlockFileName(self, ctfBlName, ctfBlNumber, refBlName, refBlNumber, filename, length=None):
        """ Return '<ctfBlName><ctfBlNumber>_<refBlName><refBlNumber>@<filename>'
        with both numbers zero-padded to the same width.
        """
        numLen = length or self.FILENAMENUMBERLENGTH
        return (ctfBlName + str(ctfBlNumber).zfill(numLen) + '_' +
                refBlName + str(refBlNumber).zfill(numLen) + '@' + filename)

    def _getFourierMaxFrequencyOfInterest(self, iterN, refN):
        """ Read the corresponding resolution metadata and return the
        desired resolution.

        Returns MDL_RESOLUTION_FREQREAL of the first row when it is positive;
        otherwise 2 * MDL_SAMPLINGRATE (presumably the Nyquist limit).
        """
        md = emlib.MetaData(self._getFileName('resolutionXmdMax', iter=iterN, ref=refN))
        firstId = md.firstObject()
        resol = md.getValue(emlib.MDL_RESOLUTION_FREQREAL, firstId)
        if resol > 0:
            return resol
        sampling = md.getValue(emlib.MDL_SAMPLINGRATE, firstId)
        return 2 * sampling

    def calculateDeviationsStep(self, it):
        """ Calculate the angle and shift deviations of iteration ``it``
        with respect to iteration ``it - 1``.

        Joins (on MDL_IMAGE) the enabled rows of both angular docfiles,
        computes angular distances with the iteration symmetry, and adds
        per-image X/Y/total shift differences. Appends the result to the
        deviations metadata (block given by _mdDevitationsFn) and records
        ``it`` as the last processed iteration.
        """
        symList = emlib.SymList()
        symList.readSymmetryFile(self._symmetry[it])

        mdCurrent = emlib.MetaData(self.docFileInputAngles[it])
        mdPrevious = emlib.MetaData(self.docFileInputAngles[it - 1])
        # Ignore disabled images.
        mdCurrent.removeDisabled()
        mdPrevious.removeDisabled()

        # The first metadata file may not have shiftX and shiftY.
        if not mdPrevious.containsLabel(emlib.MDL_SHIFT_X):
            mdPrevious.addLabel(emlib.MDL_SHIFT_X)
            mdPrevious.addLabel(emlib.MDL_SHIFT_Y)
            mdPrevious.fillConstant(emlib.MDL_SHIFT_X, 0.)
            mdPrevious.fillConstant(emlib.MDL_SHIFT_Y, 0.)

        # Rename the previous-iteration columns so they do not clash with the
        # current ones after the join.
        oldLabels = [emlib.MDL_ANGLE_ROT, emlib.MDL_ANGLE_TILT, emlib.MDL_ANGLE_PSI,
                     emlib.MDL_SHIFT_X, emlib.MDL_SHIFT_Y]
        newLabels = [emlib.MDL_ANGLE_ROT2, emlib.MDL_ANGLE_TILT2, emlib.MDL_ANGLE_PSI2,
                     emlib.MDL_SHIFT_X2, emlib.MDL_SHIFT_Y2]
        mdPrevious.renameColumn(oldLabels, newLabels)
        mdPrevious.addLabel(emlib.MDL_SHIFT_X_DIFF)
        mdPrevious.addLabel(emlib.MDL_SHIFT_Y_DIFF)
        mdPrevious.addLabel(emlib.MDL_SHIFT_DIFF)

        mdIter = emlib.MetaData()
        mdIter.join1(mdCurrent, mdPrevious, emlib.MDL_IMAGE, emlib.INNER_JOIN)
        symList.computeDistance(mdIter, False, False, False)

        # The shift differences are computed by operate() on the metadata
        # (sqrt needs the math extensions to be activated first).
        emlib.activateMathExtensions()
        shiftXLabel = emlib.label2Str(emlib.MDL_SHIFT_X)
        shiftX2Label = emlib.label2Str(emlib.MDL_SHIFT_X2)
        shiftXDiff = emlib.label2Str(emlib.MDL_SHIFT_X_DIFF)
        shiftYLabel = emlib.label2Str(emlib.MDL_SHIFT_Y)
        shiftY2Label = emlib.label2Str(emlib.MDL_SHIFT_Y2)
        shiftYDiff = emlib.label2Str(emlib.MDL_SHIFT_Y_DIFF)
        shiftDiff = emlib.label2Str(emlib.MDL_SHIFT_DIFF)

        mdIter.operate("%s=%s-%s,%s=%s-%s" % (shiftXDiff, shiftXLabel, shiftX2Label,
                                             shiftYDiff, shiftYLabel, shiftY2Label))
        mdIter.operate("%s=sqrt(%s*%s+%s*%s);" % (shiftDiff, shiftXDiff, shiftXDiff,
                                                 shiftYDiff, shiftYDiff))

        mdIter.write(self._mdDevitationsFn(it), emlib.MD_APPEND)

        self._setLastIter(it)

    def _mdDevitationsFn(self, it):
        """ Return the 'iter_XXX@<path>/deviations.xmd' block name for
        iteration ``it``.
        """
        mdFn = self._getPath('deviations.xmd')
        return "iter_%03d@" % it + mdFn

    def _setLastIter(self, iterN):
        """ Record ``iterN`` as the last processed iteration and persist it. """
        self._lastIter.set(iterN)
        self._store(self._lastIter)

    def getLastIter(self):
        """ Return the last iteration recorded by _setLastIter. """
        return self._lastIter.get()

    def _fillParticlesFromIter(self, partSet, iteration):
        """ Fill ``partSet`` with the input particles, updating each one with
        the alignment stored in the angular docfile of ``iteration``.

        Rows are read from the 'all_exp_images' block sorted by MDL_ITEM_ID.
        """
        import pwem.emlib.metadata as md

        imgSet = self.inputParticles.get()
        imgFn = "all_exp_images@" + self._getFileName('docfileInputAnglesIters',
                                                      iter=iteration, ref=1)
        partSet.copyInfo(imgSet)
        partSet.setAlignmentProj()
        partSet.copyItems(imgSet,
                          updateItemCallback=self._createItemMatrix,
                          itemDataIterator=md.iterRows(imgFn, sortByLabel=md.MDL_ITEM_ID))

    def _createItemMatrix(self, item, row):
        """ Set the projection alignment of ``item`` from metadata ``row``. """
        from ...convert import createItemMatrix
        from pwem.constants import ALIGN_PROJ

        createItemMatrix(item, row, align=ALIGN_PROJ)

    def _getIterParticles(self, it, clean=False):
        """ Return the SetOfParticles (with projection alignment) of iteration ``it``.

        The set is stored in the 'particlesScipion' file of the iteration. If
        that file does not exist (or ``clean`` is True, which removes it
        first), it is created from the iteration angular docfile; otherwise
        the existing file is reopened and its info refreshed from the input
        particles.
        """
        import pwem.objects as em

        dataParticles = self._getFileName('particlesScipion', iter=it)

        if clean:
            cleanPath(dataParticles)

        if not exists(dataParticles):
            partSet = em.SetOfParticles(filename=dataParticles)
            self._fillParticlesFromIter(partSet, it)
            partSet.write()
            partSet.close()
        else:
            partSet = em.SetOfParticles(filename=dataParticles)
            imgSet = self.inputParticles.get()
            partSet.copyInfo(imgSet)
            partSet.setAlignmentProj()

        return partSet