# Source code for xmipp3.protocols.protocol_classification_gpuCorr_semi

# ******************************************************************************
# *
# * Authors:    Josue Gomez Blanco (
# *             Amaya Jimenez Moreno (
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address ''
# *
# ******************************************************************************

import os
from datetime import datetime
from os.path import getmtime
from os.path import exists, splitext

from pyworkflow import VERSION_2_0
import pyworkflow.protocol.params as params
from pyworkflow.object import Set, Float, String
from pyworkflow.protocol.constants import STATUS_NEW
from pyworkflow.utils import prettyTime
import pyworkflow.protocol.constants as const

from pwem.objects import SetOfParticles, SetOfClasses2D, Class2D
from pwem.constants import ALIGN_2D, ALIGN_NONE
from pwem.protocols import ProtAlign2D
import pwem.emlib.metadata as md

from xmipp3.constants import CUDA_ALIGN_SIGNIFICANT
from xmipp3.convert import (writeSetOfParticles, rowToAlignment,
                            writeSetOfClasses2D)
from xmipp3.base import isXmippCudaPresent


class HashTableDict:
    """Set of integer ids spread over ``Ndict`` buckets by ``idx % Ndict``.

    The protocol uses it to remember which particle ids have already been
    sent to classification.
    """

    def __init__(self, Ndict=HASH_SIZE):
        """
        :param Ndict: number of buckets; an id goes to bucket ``idx % Ndict``.
            The default ``HASH_SIZE`` is presumably a module-level constant
            defined elsewhere in this file (not visible here).
        """
        self.Ndict = Ndict
        # One independent dict per bucket. ``[{}] * Ndict`` would store Ndict
        # references to the *same* dict, collapsing every bucket into one.
        self.dict = [{} for _ in range(Ndict)]

    def isItemPresent(self, idx):
        """Return True if ``idx`` was previously stored with pushItem."""
        return idx in self.dict[idx % self.Ndict]

    def pushItem(self, idx):
        """Store ``idx`` in its bucket (no-op if already present)."""
        self.dict[idx % self.Ndict].setdefault(idx, 1)
class XmippProtStrGpuCrrSimple(ProtAlign2D):
    """ 2D alignment in semi streaming using Xmipp GPU Correlation.
    A previous set of classes must be provided to include the new images
    in the corresponding class although the representatives will be
    maintained."""
    _label = 'gl2d static'
    _lastUpdateVersion = VERSION_2_0

    # --------------------------- DEFINE param functions -----------------------
    def _defineAlignParams(self, form):
        """Add the GPU, reference and alignment parameters to ``form``."""
        form.addHidden(params.GPU_LIST, params.StringParam, default='0',
                       expertLevel=const.LEVEL_ADVANCED,
                       label="Choose GPU IDs",
                       help="Add a list of GPU devices that can be used")
        form.addParam('inputRefs', params.PointerParam,
                      pointerClass='SetOfClasses2D, SetOfAverages',
                      important=True,
                      label="Set of references",
                      help='Set of references that will serve as reference for '
                           'the classification. This can be a set of classes '
                           'or set of averages')
        # The value is a percentage: _insertAllSteps converts it to pixels
        # as maxShift * Xdim / 100.
        form.addParam('maxShift', params.IntParam, default=10,
                      label='Maximum shift (%):',
                      help='Maximum shift allowed during the alignment as '
                           'percentage of the input image size',
                      expertLevel=const.LEVEL_ADVANCED)
        form.addParam('keepBest', params.IntParam, default=1,
                      label='Number of best images:',
                      help='Number of the best images to keep for every class',
                      expertLevel=const.LEVEL_ADVANCED)

    # --------------------------- INSERT steps functions -----------------------
    def _insertAllSteps(self):
        """ Insert the steps to call cuda correlation program"""
        self.listInFn = []     # input metadata files, one per classifyStep
        self.listOutFn = []    # output image metadata files, one per classifyStep
        self.doneListFn = []   # output files already folded into outputClasses
        self.lastDate = 0
        self.flag_relion = False
        self.imgsRef = self._getExtraPath('imagesRef.xmd')
        self.htAlreadyProcessed = HashTableDict()
        xOrig = self.inputParticles.get().getXDim()
        # maxShift is a percentage of the particle X dimension
        self.maximumShift = int(self.maxShift.get() * xOrig / 100)
        self._loadInputList()
        if isinstance(self.inputRefs.get(), SetOfClasses2D):
            self.useAsRef = REF_CLASSES
        else:
            self.useAsRef = REF_AVERAGES

        deps = []
        self._insertFunctionStep('convertAveragesStep')
        deps = self._insertStepsForParticles(deps)
        self._insertFunctionStep('createOutputStep', prerequisites=deps,
                                 wait=True)

    def _insertStepsForParticles(self, deps):
        """Insert one classifyStep depending on ``deps``.

        The new step id is appended to ``deps``, which is returned.
        """
        stepIdClassify = self._insertFunctionStep('classifyStep',
                                                  prerequisites=deps)
        deps.append(stepIdClassify)
        return deps

    # --------------------------- STEPS functions ------------------------------
    def convertAveragesStep(self):
        """Write the references (classes or averages) to ``self.imgsRef``."""
        if self.useAsRef == REF_CLASSES:
            writeSetOfClasses2D(self.inputRefs.get(), self.imgsRef,
                                writeParticles=True)
        else:
            writeSetOfParticles(self.inputRefs.get(), self.imgsRef)

    def classifyStep(self):
        """Align the pending particles (``self.listOfParticles``) against
        the references with the Xmipp CUDA alignment program."""
        if not self.listOfParticles:
            # Nothing new (e.g. the input stream was only closed). Do not
            # launch the GPU program on an empty metadata.
            return

        inputImgs = self._getInputFn()
        writeSetOfParticles(self.listOfParticles, inputImgs,
                            alignType=ALIGN_NONE)
        # Remember what has been sent, so _loadInputList does not queue it again
        for p in self.listOfParticles:
            self.htAlreadyProcessed.pushItem(p.getObjId())
            self.lastDate = p.getObjCreation()
        self._saveCreationTimeFile(self.lastDate)

        metadataRef = md.MetaData(self.imgsRef)
        if not metadataRef.containsLabel(md.MDL_REF):
            # Add a 'ref' column to the references, filled linearly from 1
            args = ('-i %s --fill ref lineal 1 1 -o %s') % (self.imgsRef,
                                                            self.imgsRef)
            self.runJob("xmipp_metadata_utilities", args, numberOfMpi=1)

        # Calling the Xmipp CUDA alignment program (CUDA_ALIGN_SIGNIFICANT).
        # Devices are passed to it as positions 0..n-1 inside the list
        # exported in CUDA_VISIBLE_DEVICES.
        if self.useQueueForSteps() or self.useQueue():
            # NOTE(review): assumes the queue system has set
            # CUDA_VISIBLE_DEVICES; a missing variable raises KeyError.
            nGpus = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
        else:
            gpuIds = self.getGpuList()
            nGpus = len(gpuIds)
            os.environ["CUDA_VISIBLE_DEVICES"] = ''.join(str(g) + ','
                                                         for g in gpuIds)
        GpuListCuda = ''.join(str(i) + ' ' for i in range(nGpus))

        outImgs, clasesOut = self._getOutputsFn()
        self._params = {'imgsRef': self.imgsRef,
                        'imgsExp': inputImgs,
                        'outputFile': outImgs,
                        'keepBest': self.keepBest.get(),
                        'maxshift': self.maximumShift,
                        'outputClassesFile': clasesOut,
                        'device': GpuListCuda,
                        'outputClassesFileNoExt': splitext(clasesOut)[0],
                        }
        # NOTE(review): 'keepBest' and 'maxshift' are computed but not used
        # in the command line (--keepBestN is fixed to 1). _fillClassesFromMd
        # assumes one row per particle, so keep 1 unless that is revisited.
        args = '-i %(imgsExp)s -r %(imgsRef)s -o %(outputFile)s ' \
               '--keepBestN 1 --oUpdatedRefs %(outputClassesFileNoExt)s --dev %(device)s '
        self.runJob(CUDA_ALIGN_SIGNIFICANT, args % self._params, numberOfMpi=1)

    # ------ Methods for Streaming 2D Classification --------------
    def _stepsCheck(self):
        """Streaming hook: queue new input and publish finished output."""
        self._checkNewInput()
        self._checkNewOutput()

    def _checkNewInput(self):
        """ Check if there are new particles to be processed and add the
        necessary steps."""
        particlesFile = self.inputParticles.get().getFileName()
        now = datetime.now()
        self.lastCheck = getattr(self, 'lastCheck', now)
        mTime = datetime.fromtimestamp(getmtime(particlesFile))
        self.debug('Last check: %s, modification: %s'
                   % (prettyTime(self.lastCheck),
                      prettyTime(mTime)))

        # If the input have not changed since our last check,
        # it does not make sense to check for new input data
        if self.lastCheck > mTime and hasattr(self, 'listOfParticles'):
            return None

        self.lastCheck = now
        outputStep = self._getFirstJoinStep()

        # Open input and close it as soon as possible
        self._loadInputList()
        fDeps = self._insertStepsForParticles([])
        if outputStep is not None:
            outputStep.addPrerequisites(*fDeps)
        self.updateSteps()

    def _checkNewOutput(self):
        """ Check for already done files and update the output set. """
        # Check for newly done items
        newDone = self._readDoneList()

        # We have finished when there is not more inputs (stream closed)
        # and there are no new output files to fold in
        self.finished = (self.isStreamClosed == Set.STREAM_CLOSED
                         and not newDone)
        streamMode = Set.STREAM_CLOSED if self.finished else Set.STREAM_OPEN

        if newDone:
            self._updateOutputSetOfClasses(newDone, streamMode)
            self.doneListFn += newDone
        elif not self.finished:
            # If we are not finished and no new output have been produced
            # it does not make sense to proceed and updated the outputs
            # so we exit from the function here
            return

        # NOTE(review): finished implies no newDone, so the output set is
        # never re-saved here with STREAM_CLOSED; confirm how it gets closed.
        if self.finished:  # Unlock createOutputStep if finished all jobs
            outputStep = self._getFirstJoinStep()
            if outputStep and outputStep.isWaiting():
                outputStep.setStatus(STATUS_NEW)

    def createOutputStep(self):
        """Join step; outputs are produced incrementally in _checkNewOutput."""
        pass

    # --------------------------- INFO functions -------------------------------
    def _validate(self):
        """Check that references and particles share the X/Y size and that
        the Xmipp CUDA programs are available."""
        errors = []
        refImage = self.inputRefs.get()
        [x1, y1, _] = refImage.getDimensions()
        [x2, y2, _] = self.inputParticles.get().getDim()
        if x1 != x2 or y1 != y2:
            # (x2, y2) are the particles, (x1, y1) the references
            errors.append('The input images (%s, %s) and the reference images (%s, %s) '
                          'have different sizes' % (x2, y2, x1, y1))
        # NOTE(review): the program actually run is CUDA_ALIGN_SIGNIFICANT;
        # confirm this is the right executable to probe.
        if not isXmippCudaPresent("xmipp_cuda_correlation"):
            errors.append("I cannot find the Xmipp GPU programs in the path")
        return errors

    def _summary(self):
        """Short summary: number of input particles and of references."""
        summary = []
        if not hasattr(self, 'outputClasses'):
            summary.append("Output alignment not ready yet.")
        else:
            summary.append("Input Particles: %s"
                           % self.inputParticles.get().getSize())
            if isinstance(self.inputRefs.get(), SetOfClasses2D):
                summary.append("Aligned with reference classes: %s"
                               % self.inputRefs.get().getSize())
            else:
                # getDimensions() is an (x, y, z) tuple, which cannot be
                # formatted into a single %s; report the count instead.
                summary.append("Aligned with reference averages: %s"
                               % self.inputRefs.get().getSize())
        return summary

    def _methods(self):
        """Methods text for the protocol."""
        methods = []
        if not hasattr(self, 'outputClasses'):
            methods.append("Output alignment not ready yet.")
        else:
            methods.append(
                "We aligned images %s with respect to the reference image set "
                "%s using Xmipp CUDA correlation"
                % (self.getObjectTag('inputParticles'),
                   self.getObjectTag('inputRefs')))
        return methods

    # --------------------------- UTILS functions ------------------------------
    def _loadInputList(self):
        """ Load the input set of particles and keep, in
        ``self.listOfParticles``, the ones created after the last processed
        date that have not been classified yet.
        """
        particlesSet = self._loadInputParticleSet()
        self.isStreamClosed = particlesSet.getStreamState()
        self.listOfParticles = []
        lastDate = self._readCreationTimeFile()
        for p in particlesSet.iterItems(orderBy='creation',
                                        where="creation>'%s'" % lastDate):
            if self.htAlreadyProcessed.isItemPresent(p.getObjId()):
                continue
            newPart = p.clone()
            newPart.setObjCreation(p.getObjCreation())
            self.listOfParticles.append(newPart)
        particlesSet.close()
        self.debug("Closed db.")

    def _loadInputParticleSet(self):
        """Return an in-memory copy of the current input particle set,
        re-read from its file so streamed additions are seen."""
        partSetFn = self.inputParticles.get().getFileName()
        updatedSet = SetOfParticles(filename=partSetFn)
        copyPartSet = SetOfParticles()
        updatedSet.loadAllProperties()
        copyPartSet.copy(updatedSet)
        updatedSet.close()
        return copyPartSet

    def _getFirstJoinStep(self):
        """Return the createOutputStep step, or None if not inserted."""
        return next((s for s in self._steps
                     if s.funcName == 'createOutputStep'), None)

    def _readDoneList(self):
        """Output files registered by classifyStep but not yet published."""
        return [fn for fn in self.listOutFn if fn not in self.doneListFn]

    def _updateOutputSetOfClasses(self, outFnDone, streamMode):
        """Create or re-open 'outputClasses', add the particles from the
        metadata files in ``outFnDone`` and save it with ``streamMode``."""
        outputName = 'outputClasses'
        outputClasses = getattr(self, outputName, None)
        firstTime = outputClasses is None

        if firstTime:
            outputClasses = self._createSetOfClasses2D(
                self.inputParticles.get())
        else:
            outputClasses = SetOfClasses2D(
                filename=outputClasses.getFileName())
            outputClasses.setStreamState(streamMode)
            outputClasses.setImages(self.inputParticles.get())

        self._fillClassesFromMd(outFnDone, outputClasses, firstTime,
                                streamMode)
        self._updateOutputSet(outputName, outputClasses, streamMode)

        if firstTime:
            self._defineSourceRelation(self.inputParticles, outputClasses)

    def _updateParticle(self, item, row):
        """Copy class id and 2D alignment from ``row`` into ``item``."""
        item.setClassId(row.getValue(md.MDL_REF))
        item.setTransform(rowToAlignment(row, ALIGN_2D))
        if self.flag_relion:
            # Add the Relion attributes, set to None, when the reference
            # particles carry them (see _fillClassesFromMd).
            item._rlnLogLikeliContribution = Float(None)
            item._rlnMaxValueProbDistribution = Float(None)
            item._rlnGroupName = String(None)
            item._rlnNormCorrection = Float(None)

    def _appendNewClass(self, outputClasses, repId, representative,
                        inputSet, streamMode):
        """Create an empty Class2D with id ``repId`` and append it."""
        newClass = Class2D(objId=repId)
        newClass.setAlignment2D()
        newClass.copyInfo(inputSet)
        newClass.setAcquisition(inputSet.getAcquisition())
        newClass.setRepresentative(representative)
        newClass.setStreamState(streamMode)
        outputClasses.append(newClass)
        return newClass

    def _fillClassesFromMd(self, outFnDone, outputClasses, firstTime,
                           streamMode):
        """Add the particles classified in each metadata file of
        ``outFnDone`` to ``outputClasses``.

        On the first call the classes are created from the references.
        When the references are classes, the new classes are also
        pre-filled with the particles of those reference classes.
        """
        for outFn in outFnDone:
            mdImages = md.MetaData(outFn)
            inputSet = self._loadInputParticleSet()
            seenClassIds = set()

            if self.useAsRef == REF_CLASSES:
                cls2d = self.inputRefs.get()
                # Probe the first particle of the first class for Relion
                # attributes (assumes the set is homogeneous).
                for cls in cls2d:
                    for img in cls:
                        if img.hasAttribute('_rlnGroupName'):
                            self.flag_relion = True
                        break
                    break

            if firstTime:
                self.lastId = 0
                if self.useAsRef == REF_AVERAGES:
                    for rep in self.inputRefs.get():
                        self._appendNewClass(outputClasses, rep.getObjId(),
                                             rep, inputSet, streamMode)
                else:
                    for cls in cls2d:
                        self._appendNewClass(outputClasses, cls.getObjId(),
                                             cls.getRepresentative(),
                                             inputSet, streamMode)
                    # Fill the output set with the previous particles
                    # of the classes
                    for cls in cls2d:
                        newClass = outputClasses[cls.getObjId()]
                        for img in cls:
                            newClass.append(img)
                        outputClasses.update(newClass)

            # Rows are sorted by class, so each class is flushed once when
            # the next class id appears.
            for imgRow in md.iterRows(mdImages, sortByLabel=md.MDL_REF):
                imgClassId = imgRow.getValue(md.MDL_REF)
                imgId = imgRow.getValue(md.MDL_ITEM_ID)

                if imgClassId not in seenClassIds:
                    if seenClassIds:
                        newClass.setAlignment2D()
                        outputClasses.update(newClass)
                    newClass = outputClasses[imgClassId]
                    newClass.enableAppend()
                    seenClassIds.add(imgClassId)

                part = inputSet[imgId]
                self._updateParticle(part, imgRow)
                newClass.append(part)

            # Update the last class into the set (only if this file had rows;
            # otherwise newClass may not be bound).
            if seenClassIds:
                newClass.setAlignment2D()
                outputClasses.update(newClass)

            # Classes are created only when processing the first file
            firstTime = False

    def _getUniqueFn(self, basename, list):
        """Return ``basename_<N>.xmd`` with N one past the last entry of
        ``list`` (1 when empty), and append it to ``list``."""
        if list == []:
            fn = basename + "_1.xmd"
        else:
            number = int(list[-1].split("_")[-1].split(".")[0]) + 1
            fn = basename + "_%s.xmd" % number
        list.append(fn)
        return fn

    def _getInputFn(self):
        """New input metadata filename (extra/imagesExp_<N>.xmd)."""
        basename = self._getExtraPath('imagesExp')
        return self._getUniqueFn(basename, self.listInFn)

    def _getOutputsFn(self):
        """Return a new (images, classes) pair of output metadata names that
        share the same numeric suffix."""
        nameImages = self._getExtraPath('general_images')
        imagesFn = self._getUniqueFn(nameImages, self.listOutFn)
        # Swap only the basename prefix: replacing 'images' in the whole path
        # would also rewrite any directory whose name contains that word.
        suffix = imagesFn[len(nameImages):]
        classesFn = self._getExtraPath('general_classes') + suffix
        return imagesFn, classesFn

    def _saveCreationTimeFile(self, cTime):
        """Persist the creation date of the last processed particle."""
        with open(self._getExtraPath('creation.txt'), 'w') as fn:
            # str(): cTime may still be the initial integer 0
            fn.write(str(cTime))

    def _readCreationTimeFile(self):
        """Return the saved creation date string, or 0 if none was saved."""
        creationFn = self._getExtraPath('creation.txt')
        if not exists(creationFn):
            return 0
        with open(creationFn, 'r') as fn:
            return fn.readline()