import sys
import emtable
import os
from datetime import datetime
import time
from pyworkflow.utils import prettyTime
from pyworkflow import VERSION_3_0
from pyworkflow.object import Set
from pyworkflow.protocol.params import IntParam
from pyworkflow.protocol import ProtStreamingBase, STEPS_PARALLEL
from pyworkflow.constants import BETA
from pwem.objects import SetOfClasses2D, SetOfAverages
from pwem.constants import ALIGN_NONE, ALIGN_2D
from xmipp3.convert import (readSetOfParticles, writeSetOfParticles,
from xmipp3.protocols.protocol_classify_pca import updateEnviron, XmippProtClassifyPca
# Attribute names under which the protocol registers its outputs
OUTPUT_CLASSES = 'outputClasses'
OUTPUT_AVERAGES = 'outputAverages'
# State/marker files, all created inside the protocol's extra directory
PCA_FILE = 'pca_done.txt'
CLASSIFICATION_FILE = 'classification_done.txt'
LAST_DONE_FILE = 'last_done.txt'
[docs]class XmippProtClassifyPcaStreaming(ProtStreamingBase, XmippProtClassifyPca):
""" Classifies a set of images in streaming mode.

Input particles are consumed as they arrive. Once enough particles are
available a PCA basis is trained; after that, successive batches of
particles are classified and the output SetOfClasses2D (and the set of
class averages) is updated batch by batch.
"""
_label = '2D classification pca streaming'
_lastUpdateVersion = VERSION_3_0
# Conda environment for this protocol's programs
# (presumably consumed by getCondaEnv — confirm).
_conda_env = 'xmipp_pyTorch'
_devStatus = BETA
# Output attribute name -> output set class.
# NOTE(review): this dict literal continues on a line that is not visible in
# this view (presumably an `OUTPUT_AVERAGES: SetOfAverages` entry) — confirm.
_possibleOutputs = {OUTPUT_CLASSES: SetOfClasses2D,
def __init__(self, **kwargs):
    """Initialise the protocol through XmippProtClassifyPca and enable parallel steps."""
    XmippProtClassifyPca.__init__(self, **kwargs)
    # Streaming steps (PCA, classification, output updates) may run concurrently.
    self.stepsExecutionMode = STEPS_PARALLEL
# --------------------------- DEFINE param functions ------------------------
def _defineParams(self, form):
    """Declare the form: the common PCA-classification parameters plus the
    streaming batch size, and a parallel section with 3 threads / 1 MPI.
    """
    form = self._defineCommonParams(form)
    batchHelp = 'Number of particles for an initial classification to compute the 2D references'
    # Shown only when `mode` is not set (condition "not mode").
    form.addParam('classificationBatch', IntParam, default=50000,
                  condition="not mode",
                  label="particles for initial classification",
                  help=batchHelp)
    form.addParallelSection(threads=3, mpi=1)
# --------------------------- INSERT steps functions ----------------------
[docs] def stepsGeneratorStep(self) -> None:
"""
This step should be implemented by any streaming protocol.
It should check its input and when ready conditions are met
call the self._insertFunctionStep method.

Here it loops until self.finish is set. On each pass it reloads the input
particle set under the protocol lock and walks the particles created after
self.lastCreationTime, in creation order. It then inserts a PCA-training
step (_doPcaTraining) and classification + output-update steps
(_doClassification) for the current batch. When the input stream is
closed, it inserts closeOutputStep with prerequisites=self.newDeps.
"""
# Step ids the final closeOutputStep must wait for.
self.newDeps = []
# Batch of particles collected since the last classification insertion.
newParticlesSet = self._loadEmptyParticleSet()
# NOTE(review): no call restoring the saved state (e.g. _updateVarsToContinue)
# is visible after this check; confirm it was not lost from this view.
if self.isContinued() and self._isPcaDone():
self.info('Continue protocol')
while not self.finish:
# Cheap mtime-based test before reloading the input set.
# NOTE(review): what follows the 'No new particles' message (presumably a
# wait and/or `continue`) is not visible in this view.
if not self._newParticlesToProcess():
self.info('No new particles')
with self._lock: # This lock needs to be here since the classifyItems access the inputSet
# and can create a race condition
particlesSet = self._loadInputParticleSet()
self.streamState = particlesSet.getStreamState()
# Only fetch particles created after the last one already registered.
where = None
if self.lastCreationTime:
where = 'creation>"' + str(self.lastCreationTime) + '"'
# NOTE(review): the visible loop body only records the creation time; nothing
# here adds `particle` to newParticlesSet, although its length is logged and
# used below — presumably an append line is missing from this view.
for particle in particlesSet.iterItems(orderBy='creation', direction='ASC', where=where):
tmp = particle.getObjCreation()
self.lastCreationTime = tmp
self.info('%d new particles' % len(newParticlesSet))
self.info('Last creation time REGISTER %s' % self.lastCreationTime)
# ------------------------------------ PCA TRAINING ----------------------------------------------
# The PCA step is inserted with no prerequisites.
if self._doPcaTraining(newParticlesSet):
self.pcaStep = self._insertFunctionStep(self.runPCASteps, newParticlesSet, prerequisites=[])
# Whole input in a single batch and the stream already closed.
if len(newParticlesSet) == len(self.inputParticles.get()) and self.streamState == Set.STREAM_CLOSED:
self.staticRun = True
self.info('Static Run')
# ------------------------------------ CLASSIFICATION ----------------------------------------------
if self._doClassification(newParticlesSet):
self._insertClassificationSteps(newParticlesSet, self.lastCreationTime)
# Start collecting a fresh batch.
newParticlesSet = self._loadEmptyParticleSet()
if self.streamState == Set.STREAM_CLOSED:
self.info('Stream closed')
# Finish everything and close output sets
# Leftover particles: flag the last round so that _doPcaTraining /
# _doClassification accept a batch smaller than their thresholds.
if len(newParticlesSet):
self.info('Finish processing with last batch %d' % len(newParticlesSet))
self.lastRound = True
self._insertFunctionStep(self.closeOutputStep, prerequisites=self.newDeps)
self.finish = True
# NOTE(review): the wait this `continue` skips is not visible in this view.
continue # To avoid waiting 1 min
sys.stdout.flush() # One last flush
# --------------------------- STEPS functions -------------------------------
def _initialStep(self):
    """Reset the streaming bookkeeping attributes to their starting values."""
    # Input stream tracking
    self.streamState = Set.STREAM_OPEN
    self.lastCreationTime = 0
    self.lastCreationTimeProcessed = 0
    # Loop / run state
    self.finish = False
    self.lastRound = False
    self.staticRun = False
    # Raised while the corresponding step is running
    self.pcaLaunch = False
    self.classificationLaunch = False
    # Classification progress
    self.classificationRound = 0
    self.firstTimeDone = False
# Initialize files
def _initFnStep(self):
self.imgsPcaXmd = self._getExtraPath('images_pca.xmd')
self.imgsPcaXmdOut = self._getTmpPath('images_pca.xmd') # Wiener
self.imgsPcaFn = self._getTmpPath('images_pca.mrc')
self.imgsOrigXmd = self._getExtraPath('imagesInput_.xmd')
self.imgsXmd = self._getTmpPath('images_.xmd')
self.imgsFn = self._getTmpPath('images_.mrc')
self.refXmd = self._getTmpPath('references.xmd')
self.ref = self._getExtraPath('classes.mrcs')
self.sigmaProt = self.sigma.get()
if self.sigmaProt == -1:
self.sigmaProt = self.inputParticles.get().getDimensions()[0] / 3
self.sampling = self.inputParticles.get().getSamplingRate()
resolution = self.resolution.get()
if resolution < 2 * self.sampling:
resolution = (2 * self.sampling) + 0.5
self.resolutionPca = resolution
if self.mode == self.UPDATE_CLASSES:
self.numberClasses = len(self.initialClasses.get())
self.numberClasses = self.numberOfClasses.get()
[docs] def runPCASteps(self, newParticlesSet):
"""Convert the given particle batch and train the PCA basis on it.

self.pcaLaunch is raised while the step runs, so _doPcaTraining does not
insert a second PCA step meanwhile. The number of training images is
min(len(newParticlesSet), self.training.get()).
"""
# Run PCA steps
self.pcaLaunch = True
# NOTE(review): as shown here, with correctCtf both conversions run and only
# imgsPcaFn is passed to pcaTraining; an `else:` may have been lost from this
# view — confirm against upstream.
if self.correctCtf:
self.convertInputStep(newParticlesSet, self.imgsPcaXmd, self.imgsPcaXmdOut) # Wiener filter
self.convertInputStep(newParticlesSet, self.imgsPcaXmd, self.imgsPcaFn)
numTrain = min(len(newParticlesSet), self.training.get())
self.pcaTraining(self.imgsPcaFn, self.resolutionPca, numTrain)
# NOTE(review): _setPcaDone() (which creates the file _isPcaDone checks) is not
# called anywhere in this view. Also, pcaLaunch stays True if training raises.
self.pcaLaunch = False
def _insertClassificationSteps(self, newParticlesSet, lastCreationTime):
    """Insert the classification step for a batch and the output update after it.

    The classification step waits for ``self.pcaStep``. The output update
    waits for the classification and is registered in ``self.newDeps``.
    ``stepsGeneratorStep`` passes ``self.newDeps`` as the prerequisites of the
    final ``closeOutputStep``.

    :param newParticlesSet: particles of the batch to classify.
    :param lastCreationTime: creation time registered for this batch; forwarded
        to ``updateOutputSetOfClasses``.
    """
    classStep = self._insertFunctionStep(self.runClassificationSteps,
                                         newParticlesSet, prerequisites=self.pcaStep)
    updateStep = self._insertFunctionStep(self.updateOutputSetOfClasses,
                                          lastCreationTime, Set.STREAM_OPEN,
                                          prerequisites=classStep)
    # Without this, closeOutputStep (prerequisites=self.newDeps) could run in
    # parallel mode before the pending output updates have finished.
    self.newDeps.append(updateStep)
[docs] def runClassificationSteps(self, newParticlesSet):
"""Convert a particle batch and classify it with xmipp_classify_pca.

self.classificationLaunch is raised while the step runs, so
_doClassification does not insert another classification meanwhile.
"""
# NOTE(review): _updateFnClassification, _setClassificationDone,
# _writeLastClassificationRound and _writeLastDone are not called anywhere in
# this view; confirm where the per-round files and markers are handled.
# Also, classificationLaunch stays True if the classification raises.
self.classificationLaunch = True
self.convertInputStep(newParticlesSet, self.imgsOrigXmd, self.imgsFn)
self.classification(self.imgsFn, self.numberClasses,
self.imgsOrigXmd, self.mask.get(), self.sigmaProt)
self.classificationLaunch = False
def pcaTraining(self, inputIm, resolutionTrain, numTrain):
    """Run xmipp_classify_pca_train on ``inputIm`` (single MPI process).

    :param inputIm: image file used for training.
    :param resolutionTrain: value passed as ``-hr``.
    :param numTrain: number of training images (``-t``).
    The trained basis is written with prefix ``<extra>/train_pca``.
    """
    args = (f' -i {inputIm} -s {self.sampling} -hr {resolutionTrain}'
            f' -lr 530 -p {self.coef.get()} -t {numTrain}'
            f' -o {self._getExtraPath()}/train_pca --batchPCA')
    runEnv = self._setEnvVariables(self.getCondaEnv())
    self.runJob("xmipp_classify_pca_train", args, numberOfMpi=1, env=runEnv)
[docs] def classification(self, inputIm, numClass, stfile, mask, sigma):
"""Run xmipp_classify_pca on ``inputIm`` with the trained PCA bands/vectors.

:param inputIm: image file to classify (``-i``).
:param numClass: number of classes (``-c``).
:param stfile: presumably the value for ``-stExp`` (see note below) — confirm.
:param mask: when truthy, ``--mask --sigma <sigma>`` is added.
:param sigma: mask sigma, only used when ``mask`` is truthy.

The existing references self.ref are passed with ``-r`` in UPDATE_CLASSES
mode or once _isClassificationDone() is true. Outputs go to ``<extra>/classes``.
"""
# NOTE(review): the value tuple below is truncated in this view: the format
# string has six %s placeholders but only five values are visible;
# presumably `stfile` closes the tuple — confirm upstream.
args = ' -i %s -c %s -b %s/train_pca_bands.pt -v %s/train_pca_vecs.pt -o %s/classes -stExp %s' % \
(inputIm, numClass, self._getExtraPath(), self._getExtraPath(), self._getExtraPath(),
if mask:
args += ' --mask --sigma %s ' % (sigma)
# Reuse the current references after the first round (or when updating given classes).
if self.mode == self.UPDATE_CLASSES or self._isClassificationDone():
args += ' -r %s ' % self.ref
env = self.getCondaEnv()
env = self._setEnvVariables(env)
self.runJob("xmipp_classify_pca", args, numberOfMpi=1, env=env)
[docs] def updateOutputSetOfClasses(self, lastCreationTime, streamMode):
"""Update the output SetOfClasses2D with the latest classification results.

:param lastCreationTime: creation time registered for the batch that was just
    classified; stored in self.lastCreationTimeProcessed after the update.
:param streamMode: stream state handed to _updateOutputSet (the caller in
    this view passes Set.STREAM_OPEN).

Also advances self.classificationRound by one.
"""
# NOTE(review): `outputName` is not defined anywhere in this view; presumably
# it is OUTPUT_CLASSES, assigned on a line missing here. _updateOutputAverages
# is also not called in this view — confirm upstream.
outputClasses, update = self._loadOutputSet(outputName)
self._fillClassesFromLevel(outputClasses, update)
self._updateOutputSet(outputName, outputClasses, streamMode)
if not update: # First time
self._defineSourceRelation(self._getInputPointer(), outputClasses)
self.numberClasses = len(outputClasses) # In case the original number of classes is not reached
self.lastCreationTimeProcessed = lastCreationTime
self.info(r'Last creation time processed UPDATED is %s' % str(self.lastCreationTimeProcessed))
self.info(r'Last classification round processed is %d' % self.classificationRound)
self.classificationRound += 1
def _updateOutputAverages(self, update):
    """Rebuild the output SetOfAverages from the class references in ``self.ref``.

    :param update: falsy on the first call, when the source relation with the
        input particles is also defined.
    """
    averages = self._loadOutputAverageSet()
    readSetOfParticles(self.ref, averages)
    self._updateOutputSet(OUTPUT_AVERAGES, averages, Set.STREAM_CLOSED)
    firstTime = not update
    if firstTime:
        self._defineSourceRelation(self._getInputPointer(), averages)
[docs] def closeOutputStep(self):
# NOTE(review): the body of this method is not visible in this view.
# stepsGeneratorStep inserts it once the input stream is reported closed,
# with prerequisites=self.newDeps, right before setting self.finish.
# --------------------------- UTILS functions -----------------------------
def _loadInputParticleSet(self):
""" Returns te input set of particles"""
partSet = self.inputParticles.get()
return partSet
def _getInputPointer(self):
return self.inputParticles
def _loadEmptyParticleSet(self):
partSet = self.inputParticles.get()
self.acquisition = partSet.getAcquisition()
copyPartSet = self._createSetOfParticles()
return copyPartSet
def _setEnvVariables(self, env):
""" Method to set all the environment variables needed to run PCA program """
env['LD_LIBRARY_PATH'] = ''
# Limit the number of threads
env['OMP_NUM_THREADS'] = '12'
env['MKL_NUM_THREADS'] = '12'
env['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
return env
def _updateFnClassification(self):
    """Rename the per-round classification files after the current round.

    ``imgsOrigXmd``, ``imgsXmd`` and ``imgsFn`` are rewritten with
    ``updateFileName`` using ``self.classificationRound``.
    """
    roundNum = self.classificationRound
    for attrName in ('imgsOrigXmd', 'imgsXmd', 'imgsFn'):
        setattr(self, attrName, updateFileName(getattr(self, attrName), roundNum))
    self.info('Starts classification round: %d' % roundNum)
def _newParticlesToProcess(self):
particlesFile = self.inputParticles.get().getFileName()
now = datetime.now()
self.lastCheck = getattr(self, 'lastCheck', now)
mTime = datetime.fromtimestamp(os.path.getmtime(particlesFile))
self.debug('Last check: %s, modification: %s'
% (self.lastCheck,
# If the input have not changed since our last check,
# it does not make sense to check for new input data
if (self.lastCheck > mTime and self.lastCreationTime) and not self.lastRound:
newParticlesBool = False
newParticlesBool = True
self.lastCheck = now
return newParticlesBool
def _fillClassesFromLevel(self, clsSet, update=False):
""" Create the SetOfClasses2D from a given iteration.

Iterates the 'particles' table of classes_images.star in the extra
directory as the per-item data source. When ``update`` is true, ``params``
gets a ``where`` filter on creation > self.lastCreationTimeProcessed.
NOTE(review): the call that consumes ``mdIter`` and ``params`` is only
partially visible in this view — see the note inside.
"""
mdIter = emtable.Table.iterRows('particles@' + self._getExtraPath('classes_images.star'))
# Extra query arguments; empty unless this is an update of an existing output.
params = {}
if update:
self.info(r'Last creation time processed is %s' % str(self.lastCreationTimeProcessed))
params = {"where": 'creation>"' + str(self.lastCreationTimeProcessed) + '"'}
with self._lock:
# NOTE(review): the line opening this call (presumably
# clsSet.classifyItems(..., **params)) is missing from this view; only its
# trailing keyword arguments are visible.
itemDataIterator=mdIter, # relion style
doClone=False, # So the creation time is maintained
raiseOnNextFailure=False) # So streaming can happen
def _loadOutputSet(self, outputName):
Load the output set if it exists or create a new one.
outputSet = getattr(self, outputName, None)
update = False
if outputSet is None:
outputSet = self._createSetOfClasses2D(self._getInputPointer())
update = True
return outputSet, update
def _loadOutputAverageSet(self):
Load an empty output setOfAverages
outputRefs = self._createSetOfAverages() # We need to create always an empty set since we need to rebuild it
return outputRefs
def _doPcaTraining(self, newParticlesSet):
""" Two cases for launching PCA steps:
- If there are enough particles and PCA has not been launched or finished.
- Launch if it's the last round of new particles and PCA has not been launched or finished. """
return (len(newParticlesSet) >= self.training.get() and not self._isPcaDone() and not self.pcaLaunch) or \
(self.lastRound and not self._isPcaDone() and not self.pcaLaunch)
def _doClassification(self, newParticlesSet):
""" Three cases for launching Classification
- First round of classification: enough particles and PCA done
- Update classification: classification done, PCA done and enough batch size
- First or update classification with a smaller batch: last round of particles and not classification launch
return (len(newParticlesSet) >= self.classificationBatch.get() and self._isPcaDone()
and not self._isClassificationDone() and not self.classificationLaunch) \
or (len(newParticlesSet) >= BATCH_UPDATE and self._isClassificationDone() and self._isPcaDone()
and not self.classificationLaunch) \
or (self.lastRound and not self.classificationLaunch)
def _isPcaDone(self):
    """Return True when the PCA-done marker file exists in the extra directory."""
    return os.path.exists(self._getExtraPath(PCA_FILE))
def _setPcaDone(self):
    """Write the PCA-done marker file, storing the PCA step id (``self.pcaStep``)."""
    markerPath = self._getExtraPath(PCA_FILE)
    with open(markerPath, "w") as fh:
        fh.write('%d' % self.pcaStep)
def _getPcaStep(self):
    """Return the PCA step id stored in the PCA-done marker (used when continuing)."""
    with open(self._getExtraPath(PCA_FILE), "r") as fh:
        stepText = fh.read()
    return int(stepText)
def _isClassificationDone(self):
    """Return True once the classification marker file exists, and always in
    UPDATE_CLASSES mode (the references are given from the start).
    """
    markerExists = os.path.exists(self._getExtraPath(CLASSIFICATION_FILE))
    updatingClasses = self.mode == self.UPDATE_CLASSES
    return bool(markerExists or updatingClasses)
def _setClassificationDone(self):
    """Create the classification marker file (empty).

    NOTE(review): opening in "w" mode truncates the file, which also holds the
    round number written by _writeLastClassificationRound.
    """
    markerPath = self._getExtraPath(CLASSIFICATION_FILE)
    with open(markerPath, "w"):
        self.debug("Creating Classification DONE file")
def _writeLastClassificationRound(self, classificationRound):
    """Store the last classification round in the classification marker file.

    This is the file _isClassificationDone checks, so writing it also marks
    classification as done.
    """
    roundPath = self._getExtraPath(CLASSIFICATION_FILE)
    with open(roundPath, "w") as fh:
        fh.write('%d' % classificationRound)
def _getLastClassificationRound(self):
    """Return the classification round stored in the classification marker file."""
    with open(self._getExtraPath(CLASSIFICATION_FILE), "r") as fh:
        roundText = fh.read()
    return int(roundText)
def _writeLastDone(self, creationTime):
    """Write to a text file the last item creation time done."""
    lastDonePath = self._getExtraPath(LAST_DONE_FILE)
    with open(lastDonePath, 'w') as fh:
        fh.write('%s' % creationTime)
def _getLastDone(self):
    """Return, as a string, the creation time stored by _writeLastDone."""
    lastDonePath = self._getExtraPath(LAST_DONE_FILE)
    with open(lastDonePath, "r") as fh:
        lastDone = fh.read()
    return str(lastDone)
def _updateVarsToContinue(self):
""" Method to if needed and the protocol is set to continue then it will see in which state it was stopped """
self.pcaStep = []
if self._isClassificationDone():
self.lastCreationTime = self._getLastDone()
self.classificationRound = self._getLastClassificationRound() + 1 # Since this is the last processed
self.lastCreationTime = 0
self.classificationRound = 0
self.lastCreationTimeProcessed = self.lastCreationTime
# Convert the string to a datetime object
self.lastCheck = datetime.strptime(self.lastCreationTime, '%Y-%m-%d %H:%M:%S')
# --------------------------- Static functions --------------------------------
def updateFileName(filepath, round):
    """Return ``filepath`` renamed to ``<prefix>_<round><ext>``.

    ``prefix`` is the part of the file name before its first underscore (the
    whole stem when there is none) and ``ext`` is the extension (empty when
    there is none). The directory is kept, e.g.
    ``/tmp/images_.mrc`` -> ``/tmp/images_3.mrc`` for ``round=3``.

    The parameter keeps the name ``round`` (shadowing the builtin) for
    keyword compatibility with existing callers.
    """
    directory, filename = os.path.split(filepath)
    stem, ext = os.path.splitext(filename)
    # str.find returned -1 when there was no '_' (or '.'), which silently
    # chopped/duplicated the last character; split/splitext handle that case.
    prefix = stem.split('_', 1)[0]
    return os.path.join(directory, f"{prefix}_{round}{ext}")