# **************************************************************************
# *
# * Authors: Airen Zaldivar Peraza (azaldivar@cnb.csic.es) [1]
# * J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [2]
# *
# * [1] Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# * [2] SciLifeLab, Stockholm University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import enum
import os
from datetime import datetime
from collections import OrderedDict
import pyworkflow.object as pwobj
import pyworkflow.protocol as pwprot
import pyworkflow.protocol.params as params
import pyworkflow.utils as pwutils
import pwem.constants as emcts
import pwem.objects as emobj
from pwem.protocols import EMProtocol
class ProtParticles(EMProtocol):
    """ Common root of the particle-related protocol base classes
    declared in this module. It adds nothing to EMProtocol.
    """
    pass
class ProtProcessParticles(ProtParticles):
    """ Base class for protocols that apply some operation to particles
    (i.e. filters, mask, resize, etc).
    Its form is mainly defined by the 'inputParticles' parameter; the
    result is expected to be an 'outputParticles' set.
    """

    def _defineParams(self, form):
        """ Build the form: input section, the parameters contributed by
        the subclass and, finally, the parallel section.
        """
        form.addSection(label=pwutils.Message.LABEL_INPUT)
        form.addParam('inputParticles', params.PointerParam,
                      pointerClass='SetOfParticles',
                      label=pwutils.Message.LABEL_INPUT_PART,
                      important=True)

        # Subclasses plug their specific parameters in here
        self._defineProcessParams(form)

        defaultThreads, defaultMpi = self._getDefaultParallel()
        form.addParallelSection(threads=defaultThreads, mpi=defaultMpi)

    def _defineProcessParams(self, form):
        """ Hook for subclasses: add to the form the parameters that are
        specific to the operation. Does nothing by default.
        """
        return None

    def _getDefaultParallel(self):
        """ Return the (threads, mpi) defaults used when the parallel
        section of the form is defined.
        """
        return 0, 0
class ProtFilterParticles(ProtProcessParticles):
    """ Base class for protocols that filter particles; it derives from
    ProtProcessParticles and adds nothing to it.
    """
    pass
class ProtOperateParticles(ProtProcessParticles):
    """ Base class for protocols that operate on particles; it derives
    from ProtProcessParticles.
    """

    def __init__(self, **kwargs):
        # Forward every keyword argument, untouched, to the parent class
        ProtProcessParticles.__init__(self, **kwargs)
class ProtMaskParticles(ProtProcessParticles):
    """ Base class for the branch of protocols that mask particles; it
    derives from ProtProcessParticles and adds nothing to it.
    """
    pass
# Micrograph source constants for particle extraction. They are compared
# against the 'downsampleType' enum param, whose choices are
# ['same as picking', 'other'] and whose default is 0.
SAME_AS_PICKING = 0  # extract from the micrographs used in the picking step
OTHER = 1  # presumably the index of the 'other' choice (not used in this view)
class ProtExtractParticlesOutput(enum.Enum):
    """ Predefined outputs for particle extraction protocols.
    The member name is used as the output attribute name and the
    member value is the class of that output.
    """
    outputParticles = emobj.SetOfParticles
# noinspection SqlDialectInspection
[docs]class ProtExtractParticles(ProtParticles):
""" Base class for all extract-particles protocols.
This class takes care of the streaming functionality, so
derived classes should mainly override the '_extractMicrograph' function.
"""
_possibleOutputs = ProtExtractParticlesOutput
# --------------------------- DEFINE param functions ------------------------
def _defineParams(self, form):
    """ Define the 'Input' section of the form: coordinates, source of
    the micrographs, optional alternative micrographs and CTF relation.
    Subclass parameters are appended by '_definePreprocessParams'.
    """
    micsSourceHelp = ('By default the particles will be extracted '
                      'from the micrographs used in the picking '
                      'step ( _same as picking_ option ). \n'
                      'If you select _other_ option, you must provide '
                      'a different set of micrographs to extract from. \n'
                      '*Note*: In the _other_ case, ensure that provided '
                      'micrographs and coordinates are related '
                      'by micName or by micId. Difference in pixel size '
                      'will be handled automatically.')
    ctfHelp = ('Choose some CTF estimation related to input '
               'micrographs. \n CTF estimation is needed if you '
               'want to do phase flipping or you want to '
               'associate CTF information to the particles.')
    # The extra micrographs are only requested for the 'other' source
    otherMicsCondition = 'downsampleType != %s' % SAME_AS_PICKING

    form.addSection(label='Input')

    form.addParam('inputCoordinates', params.PointerParam,
                  pointerClass='SetOfCoordinates',
                  important=True,
                  label="Input coordinates",
                  help='Select the SetOfCoordinates ')

    # For historical reasons this param is called 'downsampleType';
    # 'micsSource' would describe it better, but renaming it could make
    # previous executions of these protocols inconsistent, so the name
    # is kept.
    form.addParam('downsampleType', params.EnumParam,
                  choices=['same as picking', 'other'],
                  default=0, important=True,
                  display=params.EnumParam.DISPLAY_HLIST,
                  label='Micrographs source',
                  help=micsSourceHelp)

    form.addParam('inputMicrographs', params.PointerParam, allowsNull=True,
                  pointerClass='SetOfMicrographs',
                  condition=otherMicsCondition,
                  important=True, label='Input micrographs',
                  help='Select the SetOfMicrographs from which to extract.')

    form.addParam('ctfRelations', params.RelationParam, allowsNull=True,
                  condition='inputCoordinates is not None',
                  relationName=emcts.RELATION_CTF,
                  attributeName='getInputMicrographs',
                  label='CTF estimation',
                  help=ctfHelp)

    self._definePreprocessParams(form)
def _definePreprocessParams(self, form):
""" Should be implemented in sub-classes to define some
specific parameters """
pass
# --------------------------- INSERT steps functions ----------------------
def _insertAllSteps(self):
    """ Insert the steps for the micrographs that are already available
    when the protocol starts, that is, before the streaming checks.
    """
    self.debug(">>> _insertAllSteps ")
    pwutils.makeFilePath(self._getAllDone())

    # State read later by the step functions and the streaming checks
    self.micDict = OrderedDict()
    self.coordDict = {}

    initialMics = self._loadInputList()
    self.initialIds = self._insertInitialSteps()
    extractStepIds = self._insertNewMicsSteps(initialMics.values())
    self._insertFinalSteps(extractStepIds)
def _insertInitialSteps(self):
""" Override this function to insert some steps before the
extract micrograph steps.
Should return a list of ids of the initial steps. """
return []
def _insertNewMicsSteps(self, inputMics):
""" Insert steps to process new mics (from streaming)
Params:
inputMics: input mics set to be check
"""
return self._insertNewMics(inputMics,
lambda mic: mic.getMicName(),
self._insertExtractMicrographStep,
self._insertExtractMicrographListStep,
*self._getExtractArgs())
def _insertFinalSteps(self, micSteps):
""" Override this function to insert some steps after the
extraction of micrograph steps.
Receive the list of step ids of the picking steps. """
self._insertFunctionStep('createOutputStep',
prerequisites=micSteps, wait=True, needsGPU=False)
def _insertExtractMicrographStep(self, mic, prerequisites, *args):
""" Basic method to insert a picking step for a given micrograph. """
micStepId = self._insertFunctionStep('extractMicrographStep',
mic.getMicName(), *args,
prerequisites=prerequisites)
return micStepId
# -------------------------- STEPS functions ------------------------------
def extractMicrographStep(self, micKey, *args):
    """ Step function shared by all extraction protocols.
    The micrograph stored in self.micDict under micKey is passed,
    together with *args, to '_extractMicrograph'. When the protocol is
    continued and the micrograph already has its 'done' marker file,
    nothing is done.
    """
    mic = self.micDict[micKey]
    doneMarker = self._getMicDone(mic)
    micPath = mic.getFileName()

    alreadyDone = self.isContinued() and os.path.exists(doneMarker)
    if alreadyDone:
        self.info("Skipping micrograph: %s, seems to be done" % micPath)
        return

    # Coordinates were stored by micrograph id when the input was loaded
    self._convertCoordinates(mic, self.coordDict[mic.getObjId()])

    # Remove any stale marker before extracting again
    pwutils.cleanPath(doneMarker)
    self.info("Extracting micrograph: %s " % micPath)
    self._extractMicrograph(mic, *args)

    # An empty marker file flags this micrograph as finished
    with open(doneMarker, 'w'):
        pass
def _extractMicrograph(self, mic, *args):
""" This function should be implemented by subclasses in order
to pick the given micrograph. """
pass
# ---------- Methods to extract many micrographs at once ------------------
def _insertExtractMicrographListStep(self, micList, prerequisites, *args):
""" Basic method to insert a picking step for a given micrograph. """
return self._insertFunctionStep('extractMicrographListStep',
[mic.getMicName() for mic in micList],
*args, prerequisites=prerequisites)
def extractMicrographListStep(self, micKeyList, *args):
    """ Step function that extracts several micrographs at once.
    Micrographs that already have their 'done' marker are skipped when
    the protocol is continued; the remaining ones are passed to
    '_extractMicrographList' and marked as done afterwards.
    """
    pending = []

    for micKey in micKeyList:
        mic = self.micDict[micKey]
        doneMarker = self._getMicDone(mic)
        micPath = mic.getFileName()

        if self.isContinued() and os.path.exists(doneMarker):
            self.info("Skipping micrograph: %s, seems to be done" % micPath)
            continue

        # Remove any stale marker before extracting again
        pwutils.cleanPath(doneMarker)
        self.info("Extracting micrograph: %s " % micPath)
        pending.append(mic)

    self._extractMicrographList(pending, *args)

    # An empty marker file flags each extracted micrograph as finished
    for mic in pending:
        with open(self._getMicDone(mic), 'w'):
            pass
def _extractMicrographList(self, micList, *args):
""" Extract more than one micrograph at once.
Here the default implementation is to iterate through the list and
call the single extract, but it could be re-implemented on each
subclass to provide a more efficient implementation.
"""
for mic in micList:
self._extractMicrograph(mic, *args)
# --------------------------- UTILS functions -----------------------------
def _convertCoordinates(self, mic, coordList):
""" This function should be implemented by subclasses. """
pass
def _getExtractArgs(self):
""" Should be implemented in sub-classes to define the argument
list that should be passed to the extract step function.
"""
return []
def _micsOther(self):
""" Return True if other micrographs are used for extract.
Should be implemented in derived classes.
"""
return False
def _useCTF(self):
""" Return True if a SetOfCTF is associated with the extracted
particles.
Should be implemented in derived classes.
"""
return False
# ------ Methods for Streaming extraction --------------
def _isStreamClosed(self):
return self.coordsClosed
def _areAllMicsProcessed(self):
"""
This condition determines if the processing is complete when all the micrographs associated
with the input coordinates have been processed.
"""
currentPicsMics = self.inputCoordinates.get().getUniqueValues("_micName")
return len(self.micDict) == len(currentPicsMics)
def _stepsCheck(self):
# To allow streaming picking we need to detect:
# 1) new micrographs ready to be picked
# 2) new output coordinates that have been produced and add then
# to the output set.
self._checkNewInput()
self._checkNewOutput()
def _loadInputList(self):
    """ Load the input sets and return a dictionary (micName ->
    micrograph) with the micrographs that are not yet in self.micDict
    and are ready to be extracted.
    For creating this dictionary, we need to inspect:
        1) Input micrographs coming from coordinates, to know that a
           given micrograph has been picked.
        2) Micrographs to be extracted (in case they differ from 1).
        3) New computed CTFs (in case they are associated with the
           particles).
    Only micrographs that have coordinates are returned.
    As a side effect self.micsClosed, self.ctfsClosed and
    self.streamClosed are updated (self.coordsClosed and self.coordDict
    are updated by '_loadInputCoords').
    """
    def _loadSet(inputSet, SetClass, getKeyFunc):
        """ Open the db of inputSet and return the clones of the items
        whose key is not in self.micDict, plus the stream-closed flag.
        """
        setFn = inputSet.getFileName()
        self.debug("Loading input db: %s" % setFn)
        updatedSet = SetClass(filename=setFn)
        try:
            updatedSet.loadAllProperties()
            newItemDict = OrderedDict()
            for item in updatedSet:
                micKey = getKeyFunc(item)
                if micKey not in self.micDict:
                    newItemDict[micKey] = item.clone()
            streamClosed = updatedSet.isStreamClosed()
        finally:
            # Release the db even if loading or iterating fails
            updatedSet.close()
        self.debug("Closed db.")
        return newItemDict, streamClosed

    def _loadMics(micSet):
        return _loadSet(micSet, emobj.SetOfMicrographs,
                        lambda mic: mic.getMicName())

    def _loadCTFs(ctfSet):
        return _loadSet(ctfSet, emobj.SetOfCTF,
                        lambda ctf: ctf.getMicrograph().getMicName())

    # Load new micrographs coming from the coordinates
    self.debug("Loading Mics from Coords.")
    coordMics = self.inputCoordinates.get().getMicrographs()
    micDict, self.micsClosed = _loadMics(coordMics)

    # If we are extracting from other micrographs, then we will use
    # the other micrographs and keep only those that also come from
    # the coordinates (matched by micName)
    if self._micsOther():
        self.debug("Loading other Mics.")
        oMicDict, oMicClosed = _loadMics(self.inputMicrographs.get())
        self.micsClosed = self.micsClosed and oMicClosed
        micDictNew = {}
        for micKey, mic in micDict.items():
            if micKey in oMicDict:
                oMic = oMicDict[micKey]
                # Let's fix the id in case it does not correspond
                # we want to have the id coming from the coordinates
                # to match each coordinate to its micrograph
                oMic.copyObjId(mic)
                micDictNew[micKey] = oMic
        micDict = micDictNew
    self.debug("Mics are closed? %s" % self.micsClosed)

    # If CTFs are not used there is nothing to wait for
    self.ctfsClosed = True
    if self._useCTF():
        self.debug("Loading CTFs.")
        ctfDict, ctfClosed = _loadCTFs(self.ctfRelations.get())
        # Keep only the micrographs whose CTF is already available
        micDictNew = {}
        for micKey, mic in micDict.items():
            if micKey in ctfDict:
                mic.setCTF(ctfDict[micKey])
                micDictNew[micKey] = mic
        micDict = micDictNew
        self.ctfsClosed = ctfClosed
    self.debug("CTFs are closed? %s" % self.ctfsClosed)

    self.debug("Loading Coords.")
    # Now load the coordinates for the newly detected micrographs. If
    # a micrograph does not have coordinates, it is not processed.
    micDict = self._loadInputCoords(micDict)
    # Store this value to be used when inserting new steps and batch mode
    self.streamClosed = self._isStreamClosed()
    return micDict
def _loadInputCoords(self, micDict):
    """ Load, from the input coordinates db, the coordinates of the
    given micrographs.
    Params:
        micDict: dictionary (key -> micrograph) of candidate mics.
    Returns a dictionary with the entries of micDict that have at
    least one coordinate. As a side effect, the coordinates found are
    stored in self.coordDict by micrograph id and self.coordsClosed is
    updated with the stream state of the coordinates set.
    """
    # TODO: this takes for ever if you are NOT
    # doing streaming and have several thousands of mics
    # so a progress message is printed to keep the user entertained
    startTime = datetime.now()

    coordsFn = self.getCoords().getFileName()
    self.debug("Loading input db: %s" % coordsFn)
    coordSet = emobj.SetOfCoordinates(filename=coordsFn)
    micsWithCoords = dict()

    try:
        # FIXME: Temporary to avoid loadAllPropertiesFail
        coordSet._xmippMd = pwobj.String()
        coordSet.loadAllProperties()

        for counter, (micKey, mic) in enumerate(micDict.items(), start=1):
            if counter % 50 == 0:
                # flush=True forces the buffer to be printed right away
                print(datetime.now() - startTime,
                      'reading coordinates for mic number',
                      "%06d" % counter, flush=True)

            micId = mic.getObjId()
            self.debug("Loading coords for mic: %s (%s)" % (micId, micKey))
            # TODO: Check performance penalty of using this clone
            coordList = [coord.clone() for coord
                         in coordSet.iterItems(where='_micId=%s' % micId)]
            self.debug("Coords found: %s" % len(coordList))

            if coordList:
                self.coordDict[micId] = coordList
                micsWithCoords[micKey] = mic

        self.coordsClosed = coordSet.isStreamClosed()
    finally:
        # Release the db even if reading the coordinates fails
        coordSet.close()

    self.debug("Coords are closed? %s" % self.coordsClosed)
    self.debug("Closed db.")
    return micsWithCoords
def _checkNewInput(self):
    """ Streaming check for new input: when any of the input files has
    been modified since the last check, reload the input and insert
    the steps for the new micrographs.
    """
    self.debug(">>> _checkNewInput ")

    def _latestInputChange():
        """ Latest modification time among the (up to three) input
        files: coordinates, micrographs and, if used, CTFs. """
        def _fileTime(path):
            return datetime.fromtimestamp(os.path.getmtime(path))

        sources = [self.inputCoordinates.get()]
        if self._micsOther():
            sources.append(self.inputMicrographs.get())
        else:
            sources.append(self.inputCoordinates.get().getMicrographs())
        if self._useCTF():
            sources.append(self.ctfRelations.get())
        return max(_fileTime(src.getFileName()) for src in sources)

    mTime = _latestInputChange()
    now = datetime.now()
    # The very first time there is no previous check, so 'now' is used
    self.lastCheck = getattr(self, 'lastCheck', now)
    self.debug('Last check: %s, modification: %s'
               % (pwutils.prettyTime(self.lastCheck),
                  pwutils.prettyTime(mTime)))

    unchangedSinceCheck = self.lastCheck > mTime
    hasMicDict = hasattr(self, 'micDict')
    self.debug("self.lastCheck > mTime %s , hasattr(self, 'micDict') %s"
               % (unchangedSinceCheck, hasMicDict))
    # If no input file changed after our last check there is no point
    # in looking for new input data
    if unchangedSinceCheck and hasMicDict:
        return None

    # Open the input dbs and close them as soon as possible
    newMics = self._loadInputList()
    self.lastCheck = now
    outputStep = self._getFirstJoinStep()

    if not newMics:
        return

    # The join step, if any, must also wait for the new steps
    fDeps = self._insertNewMicsSteps(newMics.values())
    if outputStep is not None:
        outputStep.addPrerequisites(*fDeps)
    self.updateSteps()
def _checkNewOutput(self):
if getattr(self, 'finished', False):
return
# Load previously done items (from text file)
doneList = self._readDoneList()
# Check for newly done items
newDone = [m for m in self.micDict.values()
if m.getObjId() not in doneList and self._isMicDone(m)]
# Update the file with the newly done mics
# or exit from the function if no new done mics
inputLen = len(self.micDict)
self.debug('_checkNewOutput: ')
self.debug(' input: %s, doneList: %s, newDone: %s'
% (inputLen, len(doneList), len(newDone)))
firstTime = len(doneList) == 0
allDone = len(doneList) + len(newDone)
# We have finished when there is not more input mics (stream closed)
# and the number of processed mics is equal to the number of inputs
streamClosed = self._isStreamClosed()
allMicsProcessed = self._areAllMicsProcessed()
self.finished = streamClosed and allDone == inputLen and allMicsProcessed
self.debug(' is finished? %s ' % self.finished)
self.debug(' is stream closed? %s ' % streamClosed)
streamMode = pwobj.Set.STREAM_CLOSED if self.finished else pwobj.Set.STREAM_OPEN
if newDone:
self._updateOutputPartSet(newDone, streamMode)
self._writeDoneList(newDone)
elif not self.finished:
# If we are not finished and no new output have been produced
# it does not make sense to proceed and updated the outputs
# so we exit from the function here
# Maybe it would be good idea to take a snap to avoid
# so much IO if this protocol does not have much to do now
if allDone == len(self.micDict):
self._streamingSleepOnWait()
return
self.debug(' finished: %s ' % self.finished)
self.debug(' self.streamClosed (%s) AND' % streamClosed)
self.debug(' allDone (%s) == len(self.listOfMics (%s)'
% (allDone, inputLen))
self.debug(' streamMode: %s' % streamMode)
if self.finished: # Unlock createOutputStep if finished all jobs
# Close the output set
self._updateOutputPartSet([], pwobj.Set.STREAM_CLOSED)
outputStep = self._getFirstJoinStep()
if outputStep and outputStep.isWaiting():
outputStep.setStatus(pwprot.STATUS_NEW)
def readPartsFromMics(self, micDoneList, outputParts):
    """ Hook for subclasses: read the particles of the given list of
    micrographs; it is called with the output particles set as second
    argument. Does nothing by default.
    """
    return None
def _updateOutputPartSet(self, micList, streamMode):
    """ Add the particles of the given micrographs to the output set,
    creating the set the first time, and update it with streamMode.
    The source relations are defined only when the set is created.
    """
    outputName = ProtExtractParticlesOutput.outputParticles.name
    outputParts = getattr(self, outputName, None)
    firstTime = outputParts is None

    if firstTime:
        inputMics = self.getInputMicrographs()
        outputParts = self._createSetOfParticles()
        outputParts.copyInfo(inputMics)
        outputParts.setCoordinates(self.inputCoordinates)
        if self.getAttributeValue('doFlip', False):
            outputParts.setIsPhaseFlipped(not inputMics.isPhaseFlipped())
        outputParts.setSamplingRate(self._getNewSampling())
        outputParts.setHasCTF(self._useCTF())
    else:
        outputParts.enableAppend()

    self.readPartsFromMics(micList, outputParts)
    self._updateOutputSet(outputName, outputParts, streamMode)

    if not firstTime:
        return

    # Relations are registered just once, when the output is created
    self._defineSourceRelation(self.inputCoordinates, outputParts)
    if self._useCTF():
        self._defineSourceRelation(self.ctfRelations, outputParts)
    if self._micsOther():
        self._defineSourceRelation(self.inputMicrographs, outputParts)
def _getMicDone(self, mic):
return self._getExtraPath('DONE', 'mic_%06d.TXT' % mic.getObjId())
def _isMicDone(self, mic):
""" A mic is done if the marker file exists. """
return os.path.exists(self._getMicDone(mic))
def _getAllDone(self):
return self._getExtraPath('DONE', 'all.TXT')
def _readDoneList(self):
""" Read from a text file the id's of the items that have been done. """
doneFile = self._getAllDone()
doneList = []
# Check what items have been previously done
if os.path.exists(doneFile):
with open(doneFile) as f:
doneList += [int(line.strip()) for line in f]
return doneList
def _writeDoneList(self, micList):
""" Write to a text file the items that have been done. """
doneFile = self._getAllDone()
if not os.path.exists(doneFile):
pwutils.makeFilePath(doneFile)
with open(doneFile, 'a') as f:
for mic in micList:
f.write('%d\n' % mic.getObjId())
def _getFirstJoinStepName(self):
# This function will be used for streaming, to check which is
# the first function that need to wait for all micrographs
# to have completed, this can be overwritten in subclasses
# (eg in Xmipp 'sortPSDStep')
return 'createOutputStep'
def _getFirstJoinStep(self):
for s in self._steps:
if s.funcName == self._getFirstJoinStepName():
return s
return None
def createOutputStep(self):
    """ Final step of the protocol. It has nothing to do now: the
    output particles are updated by '_updateOutputPartSet', called
    from '_checkNewOutput'.
    """
    return None
[docs]class ProtExtractParticlesPair(ProtParticles):
""" Base class for all extract-particles pairs protocols. Until now,
this protocols is not in streaming mode.
"""