Source code for xmipp3.protocols.protocol_core_analysis

# ******************************************************************************
# *
# * Authors:     J.M. De la Rosa Trevin (jmdelarosa@cnb.csic.es)
# *              Carlos Oscar Sánchez Sorzano (coss@cnb.csic.es)
# *              Daniel Marchán Torres (da.marchan@cnb.csic.es)
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# ******************************************************************************

from os.path import join, exists
from glob import glob


import pyworkflow.protocol.params as param
import pyworkflow.protocol.constants as const
from pyworkflow.utils.path import cleanPath, makePath

import pwem.emlib.metadata as md
from pwem.protocols import ProtClassify2D
from pwem.constants import ALIGN_2D

from xmipp3.convert import (writeSetOfClasses2D, xmippToLocation,
                            rowToAlignment)


CLASSES_CORE = '_core'


[docs]class XmippProtCoreAnalysis(ProtClassify2D): """ Analyzes the core of a 2D classification. The core is calculated through the Mahalanobis distance from each image to the center of the class. """ _label = 'core analysis' def __init__(self, **args): ProtClassify2D.__init__(self, **args) if self.numberOfMpi.get() < 2: self.numberOfMpi.set(2) #--------------------------- DEFINE param functions ------------------------ def _defineParams(self, form): form.addSection(label='Input') form.addParam('inputClasses', param.PointerParam, label="Input classes", pointerClass='SetOfClasses2D', help='Set of input classes to be analyzed') form.addParam('thZscore', param.FloatParam, default=3, label='Junk Zscore', help='Which is the average Z-score to be considered as ' 'junk. Typical values go from 1.5 to 3. For the ' 'Gaussian distribution 99.5% of the data is ' 'within a Z-score of 3. Lower Z-scores reject more ' 'images. Higher Z-scores accept more images.') form.addParam('thPCAZscore', param.FloatParam, default=3, label='PCA Zscore', help='Which is the PCA Z-score to be considered as junk. ' 'Typical values go from 1.5 to 3. For the Gaussian ' 'distribution 99.5% of the data is within a ' 'Z-score of 3. Lower Z-scores reject more images. ' 'Higher Z-scores accept more images.') form.addParallelSection(threads=0, mpi=4) #--------------------------- INSERT steps functions ------------------------ def _insertAllSteps(self): self._defineFileNames() self._insertFunctionStep('analyzeCore') self._insertFunctionStep('createOutputStep')
[docs] def analyzeCore(self): # Put in a function convertInputStep fnLevel = self._getExtraPath('level_00') makePath(fnLevel) inputMdName = join(fnLevel, 'level_classes.xmd') writeSetOfClasses2D(self.inputClasses.get(), inputMdName, writeParticles=True) args = " --dir %s --root level --computeCore %f %f"%(self._getExtraPath(), self.thZscore, self.thPCAZscore) self.runJob('xmipp_classify_CL2D_core_analysis', args) self.runJob("xmipp_classify_evaluate_classes", "-i %s"%\ self._getExtraPath(join("level_00","level_classes_core.xmd")), numberOfMpi=1)
#--------------------------- STEPS functions ------------------------------- def _defineFileNames(self): """ Centralize how files are called within the protocol. """ self.levelPath = self._getExtraPath('level_%(level)02d/') myDict = { 'final_classes': self._getPath('classes2D%(sub)s.sqlite'), 'output_particles': self._getExtraPath('images.xmd'), 'level_classes' : self.levelPath + 'level_classes%(sub)s.xmd', 'level_images' : self.levelPath + 'level_images%(sub)s.xmd', 'classes_scipion': (self.levelPath + 'classes_scipion_level_' '%(level)02d%(sub)s.sqlite'), } self._updateFilenamesDict(myDict)
[docs] def createOutputStep(self): """ Store the SetOfClasses2D object resulting from the protocol execution. """ inputParticles = self.inputClasses.get().getImages() level = self._getLastLevel() subset = CLASSES_CORE subsetFn = self._getFileName("level_classes", level=level, sub=subset) if exists(subsetFn): classes2DSet = self._createSetOfClasses2D(inputParticles, subset) self._fillClassesFromLevel(classes2DSet, 'last', subset) result = {'outputClasses' + subset: classes2DSet} self._defineOutputs(**result) self._defineSourceRelation(inputParticles, classes2DSet)
#--------------------------- INFO functions -------------------------------- def _validate(self): validateMsgs = [] if self.numberOfMpi <= 1: validateMsgs.append('Mpi needs to be greater than 1.') return validateMsgs def _citations(self): citations=['Sorzano2014'] return citations def _methods(self): strline ='We calculated the class cores %s. [Sorzano2014]' % self.getObjectTag('outputClasses_core') return [strline] #--------------------------- UTILS functions ------------------------------- def _updateParticle(self, item, row): item.setClassId(row.getValue(md.MDL_REF)) item.setTransform(rowToAlignment(row, ALIGN_2D)) def _updateClass(self, item): classId = item.getObjId() if classId in self._classesInfo: index, fn, _ = self._classesInfo[classId] item.setAlignment2D() rep = item.getRepresentative() rep.setLocation(index, fn) rep.setSamplingRate(self.inputClasses.get().getImages().getSamplingRate()) def _loadClassesInfo(self, filename): """ Read some information about the produced 2D classes from the metadata file. """ self._classesInfo = {} # store classes info, indexed by class id mdClasses = md.MetaData(filename) for classNumber, row in enumerate(md.iterRows(mdClasses)): index, fn = xmippToLocation(row.getValue(md.MDL_IMAGE)) # Store info indexed by id, we need to store the row.clone() since # the same reference is used for iteration self._classesInfo[index] = (index, fn, row.clone()) def _fillClassesFromLevel(self, clsSet, level, subset): """ Create the SetOfClasses2D from a given iteration. """ classRf = '' self._loadClassesInfo(self._getLevelMdClasses(lev=level, subset=classRf)) if subset == '' and level == "last": xmpMd = self._getFileName('output_particles') if not exists(xmpMd): xmpMd = self._getLevelMdImages(level, subset) else: xmpMd = self._getLevelMdImages(level, subset) iterator = md.SetMdIterator(xmpMd, sortByLabel=md.MDL_ITEM_ID, updateItemCallback=self._updateParticle, skipDisabled=True) # itemDataIterator is not neccesary because, the class SetMdIterator # contain all the information about the metadata clsSet.classifyItems(updateItemCallback=iterator.updateItem, updateClassCallback=self._updateClass) def _getLevelMdClasses(self, lev=0, block="classes", subset=""): """ Return the classes metadata for this iteration. block parameter can be 'info' or 'classes'.""" if lev == "last": lev = self._getLastLevel() mdFile = self._getFileName('level_classes', level=lev, sub=subset) if block: mdFile = block + '@' + mdFile return mdFile def _getLevelMdImages(self, level, subset): if level == "last": level = self._getLastLevel() xmpMd = self._getFileName('level_images', level=level, sub=subset) if not exists(xmpMd): self._createLevelMdImages(level, subset) return xmpMd def _createLevelMdImages(self, level, sub): if level == "last": level = self._getLastLevel() mdClassesFn = self._getLevelMdClasses(lev=level, block="", subset=sub) mdImgs = md.joinBlocks(mdClassesFn, "class0") mdImgs.write(self._getFileName('level_images', level=level, sub=sub)) def _getLastLevel(self): """ Find the last Level number """ clsFn = self._getFileName('level_classes', level=0, sub="") levelTemplate = clsFn.replace('level_00', 'level_??') lev = len(glob(levelTemplate)) - 1 return lev