Source code for xmipp3.protocols.protocol_classify_kmeans2d

# *****************************************************************************
# *
# * Authors:     Tomas Majtner         tmajtner@cnb.csic.es (2017)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# *****************************************************************************

import os

import pyworkflow.protocol.params as param

from pyworkflow import VERSION_2_0

import pwem.emlib.metadata as md
import pyworkflow.protocol.constants as cons
from pyworkflow.object import Set
from pyworkflow.utils.path import cleanPath
from pyworkflow.utils.properties import Message

from pwem.protocols import ProtClassify2D
from pwem.objects import SetOfParticles, SetOfClasses2D
from pwem.constants import ALIGN_NONE

from xmipp3.convert import (writeSetOfParticles, xmippToLocation,
                            readSetOfParticles)


class XmippProtKmeansClassif2D(ProtClassify2D):
    """ Classifies a set of particles using a clustering algorithm to
    subdivide the original dataset into a given number of classes.
    A minimal launch sketch is included at the end of this module. """

    _label = '2D kmeans clustering'
    _lastUpdateVersion = VERSION_2_0

    def __init__(self, **args):
        ProtClassify2D.__init__(self, **args)
        self.stepsExecutionMode = param.STEPS_PARALLEL

    # --------------------------- DEFINE param functions -----------------------
    def _defineParams(self, form):
        form.addSection(label=Message.LABEL_INPUT)
        form.addParam('inputParticles', param.PointerParam,
                      label="Input images", important=True,
                      pointerClass='SetOfParticles',
                      help='Select the input images to be classified.')
        form.addParam('numberOfClasses', param.IntParam, default=10,
                      label='Number of 2D classes:',
                      help='Number of 2D classes to be created.')
        form.addParam('maxObjects', param.IntParam, default=-1,
                      expertLevel=param.LEVEL_ADVANCED,
                      label='Threshold for number of particles:',
                      help='Threshold for number of particles after which the '
                           'position of clusters will be fixed.')
        form.addParallelSection(threads=1, mpi=1)

    # --------------------------- INSERT steps functions -----------------------
    def _insertAllSteps(self):
        self.finished = False
        self.insertedDict = {}
        self.SetOfParticles = [m.clone() for m in self.inputParticles.get()]
        partsSteps = self._insertNewPartsSteps(self.insertedDict,
                                               self.SetOfParticles)
        self._insertFunctionStep('createOutputStep',
                                 prerequisites=partsSteps, wait=True)
    def createOutputStep(self):
        # Placeholder step inserted with wait=True in _insertAllSteps; it is
        # released from _checkNewOutput once the input stream is closed and
        # all particles have been classified.
        pass
    def _getFirstJoinStepName(self):
        # This function will be used for streaming, to check which is
        # the first function that needs to wait for all particles
        # to have completed; this can be overridden in subclasses
        # (e.g., in Xmipp 'sortPSDStep')
        return 'createOutputStep'

    def _getFirstJoinStep(self):
        for s in self._steps:
            if s.funcName == self._getFirstJoinStepName():
                return s
        return None

    def _insertNewPartsSteps(self, insertedDict, inputParticles):
        deps = []
        writeSetOfParticles([m.clone() for m in inputParticles],
                            self._getExtraPath("allDone.xmd"),
                            alignType=ALIGN_NONE)
        writeSetOfParticles([m.clone() for m in inputParticles
                             if int(m.getObjId()) not in insertedDict],
                            self._getExtraPath("newDone.xmd"),
                            alignType=ALIGN_NONE)
        stepId = self._insertFunctionStep('kmeansClassifyStep',
                                          self._getExtraPath("newDone.xmd"),
                                          prerequisites=[])
        deps.append(stepId)
        for part in inputParticles:
            if part.getObjId() not in insertedDict:
                insertedDict[part.getObjId()] = stepId
        return deps

    def _stepsCheck(self):
        # In streaming mode, check for new input particles and for
        # newly produced output on every protocol update.
        self._checkNewInput()
        self._checkNewOutput()

    def _checkNewInput(self):
        # Check if there are new particles to process from the input set
        partsFile = self.inputParticles.get().getFileName()
        partsSet = SetOfParticles(filename=partsFile)
        partsSet.loadAllProperties()
        self.SetOfParticles = [m.clone() for m in partsSet]
        self.streamClosed = partsSet.isStreamClosed()
        partsSet.close()
        partsSet = self._createSetOfParticles()
        readSetOfParticles(self._getExtraPath("allDone.xmd"), partsSet)
        newParts = any(m.getObjId() not in partsSet
                       for m in self.SetOfParticles)
        outputStep = self._getFirstJoinStep()
        if newParts:
            fDeps = self._insertNewPartsSteps(self.insertedDict,
                                              self.SetOfParticles)
            if outputStep is not None:
                outputStep.addPrerequisites(*fDeps)
            self.updateSteps()

    def _checkNewOutput(self):
        if getattr(self, 'finished', False):
            return
        # Load previously done items (from text file)
        doneList = self._readDoneList()
        # Check for newly done items
        partsSet = self._createSetOfParticles()
        readSetOfParticles(self._getExtraPath("allDone.xmd"), partsSet)
        newDone = [m.clone() for m in self.SetOfParticles
                   if int(m.getObjId()) not in doneList]
        self.finished = self.streamClosed and (len(doneList) == len(partsSet))
        if newDone:
            self._writeDoneList(newDone)
        elif not self.finished:
            # If we are not finished and no new output has been produced,
            # it does not make sense to proceed and update the outputs,
            # so we exit from the function here
            return

        if self.finished:
            # Unlock createOutputStep once all jobs are finished
            outputStep = self._getFirstJoinStep()
            if outputStep and outputStep.isWaiting():
                outputStep.setStatus(cons.STATUS_NEW)

    def _loadOutputSet(self, SetClass, baseName):
        setFile = self._getPath(baseName)
        if os.path.exists(setFile):
            outputSet = SetClass(filename=setFile)
            outputSet.loadAllProperties()
            outputSet.enableAppend()
        else:
            outputSet = SetClass(filename=setFile)
            outputSet.setStreamState(outputSet.STREAM_OPEN)
        inputs = self.inputParticles.get()
        outputSet.copyInfo(inputs)
        return outputSet

    def _updateOutputSet(self, outputName, outputSet, state=Set.STREAM_OPEN):
        outputSet.setStreamState(state)
        if self.hasAttribute(outputName):
            outputSet.write()  # Write to commit changes
            outputAttr = getattr(self, outputName)
            # Copy the properties to the object contained in the protocol
            outputAttr.copy(outputSet, copyId=False)
            # Persist changes
            self._store(outputAttr)
        else:
            set2D = self._createSetOfClasses2D(self.inputParticles.get())
            self._fillClassesFromLevel(set2D)
            outputSet = {'outputClasses': set2D}
            self._defineOutputs(**outputSet)
            self._defineSourceRelation(self.inputParticles, set2D)
    def kmeansClassifyStep(self, fnInputMd):
        """ Cluster the particles with xmipp_classify_kmeans_2d and then
        compute a reference image for each class block with
        xmipp_classify_CL2D. """
        iteration = 0
        args = "-i %s -k %d -m %d" % (fnInputMd, self.numberOfClasses.get(),
                                      self.maxObjects.get())
        self.runJob("xmipp_classify_kmeans_2d", args)
        cleanPath(self._getExtraPath("level_00"))

        blocks = md.getBlocksInMetaDataFile(self._getExtraPath("output.xmd"))
        fnDir = self._getExtraPath()
        # Gather all images in block
        for b in blocks:
            if b.startswith('class0'):
                args = "-i %s@%s --iter 5 --distance correlation " \
                       "--classicalMultiref --nref 1 --odir %s --oroot %s" % \
                       (b, self._getExtraPath("output.xmd"), fnDir, b)
                if iteration == 0:
                    args += " --nref0 1"
                else:
                    args += " --ref0 %s" % \
                            self._getExtraPath("level_00/%s_classes.stk" % b)
                self.runJob("xmipp_classify_CL2D", args,
                            numberOfMpi=max(2, self.numberOfMpi.get()))
                cleanPath(self._getExtraPath("level_00/%s_classes.xmd" % b))

        streamMode = Set.STREAM_CLOSED if self.finished else Set.STREAM_OPEN
        outSet = self._loadOutputSet(SetOfClasses2D, 'classes2D.sqlite')
        self._updateOutputSet('outputParticles', outSet, streamMode)
    def _readDoneList(self):
        """ Read from a file the id's of the items that have been done. """
        doneFile = self._getAllDone()
        doneList = []
        # Check what items have been previously done
        if os.path.exists(doneFile):
            with open(doneFile) as f:
                doneList += [int(line.strip()) for line in f]
        return doneList

    def _getAllDone(self):
        return self._getExtraPath('DONE_all.TXT')

    def _writeDoneList(self, partList):
        """ Write to a text file the items that have been done. """
        with open(self._getAllDone(), 'a') as f:
            for part in partList:
                f.write('%d\n' % part.getObjId())

    def _updateParticle(self, item, row):
        item.setClassId(row.getValue(md.MDL_REF))

    def _updateClass(self, item):
        classId = item.getObjId()
        if classId in self._classesInfo:
            index, fn, _ = self._classesInfo[classId]
            item.setAlignment2D()
            rep = item.getRepresentative()
            rep.setLocation(index, fn)
            rep.setSamplingRate(self.inputParticles.get().getSamplingRate())

    def _loadClassesInfo(self, filename, blockId):
        """ Read some information about the produced 2D classes
        from the metadata file. """
        self._classesInfo = {}  # store classes info, indexed by class id
        mdClasses = md.MetaData(filename)
        for classNumber, row in enumerate(md.iterRows(mdClasses)):
            index, fn = xmippToLocation(row.getValue(md.MDL_IMAGE))
            self._classesInfo[blockId] = (index, fn, row.clone())

    def _fillClassesFromLevel(self, clsSet):
        """ Create the SetOfClasses2D from a given iteration. """
        blocks = md.getBlocksInMetaDataFile(self._getExtraPath('output.xmd'))
        for bId, b in enumerate(blocks):
            if b.startswith('class0'):
                self._loadClassesInfo(
                    self._getExtraPath("level_00/%s_classes.stk" % b), bId)
                xmpMd = b + "@" + self._getExtraPath('output.xmd')
                iterator = md.SetMdIterator(xmpMd, sortByLabel=md.MDL_ITEM_ID,
                                            updateItemCallback=self._updateParticle,
                                            skipDisabled=True)
                clsSet.classifyItems(updateItemCallback=iterator.updateItem,
                                     updateClassCallback=self._updateClass)
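

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original protocol source). It
# shows one way a protocol of this class is typically created and launched
# from a Scipion project script. The project name 'MyProject', the protocol
# id 1 and the output attribute name 'outputParticles' of the previous
# protocol are placeholders chosen only for this example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from pyworkflow.project import Manager

    manager = Manager()
    project = manager.loadProject('MyProject')

    # An earlier, finished protocol (e.g. a particle import or extraction)
    # whose output is the SetOfParticles to classify.
    protPrev = project.getProtocol(1)

    # Create the protocol, point its input to the previous output, set the
    # requested number of classes and launch it.
    protKmeans = project.newProtocol(XmippProtKmeansClassif2D)
    protKmeans.inputParticles.set(protPrev)
    protKmeans.inputParticles.setExtended('outputParticles')
    protKmeans.numberOfClasses.set(10)
    project.launchProtocol(protKmeans, wait=True)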