Source code for xmipp3.protocols.protocol_movie_alignment_consensus

# **************************************************************************
# *
# * Authors:    Carlos Oscar Sorzano (coss@cnb.csic.es)
# *             Daniel Marchán Torres (da.marchan@cnb.csic.es)  -- streaming version
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'coss@cnb.csic.es'
# *
# **************************************************************************
"""
Consensus alignment protocol
"""

import os
from datetime import datetime
from pyworkflow.gui.plotter import Plotter
import numpy as np
from math import ceil
try:
    from itertools import izip
except ImportError:
    izip = zip
from pwem.objects import SetOfMovies, SetOfMicrographs, MovieAlignment, Image
from pyworkflow.object import Set
import pyworkflow.protocol.params as params
from pyworkflow.protocol import STEPS_PARALLEL, Protocol
import pyworkflow.utils as pwutils
from pwem.protocols import ProtAlignMovies
from pyworkflow.protocol.constants import (STATUS_NEW)
from xmipp3.convert import getScipionObj
from pwem.constants import ALIGN_NONE
from pyworkflow import BETA, UPDATED, NEW, PROD

ACCEPTED = 'Accepted'
DISCARDED = 'Discarded'


[docs]class XmippProtConsensusMovieAlignment(ProtAlignMovies, Protocol): """ The protocol compares two sets of aligned movies (reference and secondary) to evaluate their alignment consistency. It calculates the correlation between shift trajectories, allowing for a minimum correlation threshold to be set. Movies with correlations below this threshold can be discarded. The protocol can also generate plots showing the trajectories and correlations for each movie. This helps in identifying and validating the quality of movie alignments based on consensus among different alignment runs. """ _label = 'movie alignment consensus' outputName = 'consensusAlignments' _devStatus = PROD def __init__(self, **args): ProtAlignMovies.__init__(self, **args) self.stepsExecutionMode = STEPS_PARALLEL def _defineParams(self, form): form.addSection(label='Input Consensus') form.addParam('inputMovies1', params.PointerParam, pointerClass='SetOfMovies', label="Reference Aligned Movies", important=True, help='Select the aligned movies to evaluate (this first set will give the global shifts)') form.addParam('inputMovies2', params.PointerParam, pointerClass='SetOfMovies', label="Secondary Aligned Movies", help='Shift to be compared with reference alignment') form.addParam('minConsCorrelation', params.FloatParam, default=-1, label='Minimum consensus shifts correlation', help="Minimum value for the consensus correlations between shifts trajectories." "\nIf there are noticeable discrepancies between the two estimations below this correlation," " it will be discarded. If this value is set to -1 no movies will be discarded." "\n Values near 1 will indicate that there are a clear correlation between shifts trajectories.") form.addParam('trajectoryPlot', params.BooleanParam, default=False, label='Global Alignment Trajectory Plot', help="This will generate a plot for each movie where the reference and the secondary trajectory" "will be plot in the same graph with its correlation value.") form.addParallelSection(threads=4) # --------------------------- INSERT steps functions ------------------------- def _insertAllSteps(self): self.initializeParams() self._insertFunctionStep('createOutputStep', prerequisites=[], wait=True)
[docs] def createOutputStep(self): self._closeOutputSet()
[docs] def initializeParams(self): self.finished = False self.insertedDict = {} self.processedDict = [] self.movieFn1 = self.inputMovies1.get().getFileName() self.movieFn2 = self.inputMovies2.get().getFileName() self.micsFn = self._getMicsPath() self.stats = {} self.isStreamClosed = self.inputMovies1.get().isStreamClosed() and \ self.inputMovies2.get().isStreamClosed() self.samplingRate = self.inputMovies1.get().getSamplingRate() self.acquisition = self.inputMovies1.get().getAcquisition() self.allMovies1 = {movie.getObjId(): movie.clone() for movie in self._loadInputMovieSet(self.movieFn1).iterItems()} self.allMovies2 = {movie.getObjId(): movie.clone() for movie in self._loadInputMovieSet(self.movieFn2).iterItems()} pwutils.makePath(self._getExtraPath('DONE'))
def _getFirstJoinStepName(self): # This function will be used for streaming, to check which is # the first function that need to wait for all movies # to have completed, this can be overriden in subclasses # (e.g., in Xmipp 'sortPSDStep') return 'createOutputStep' def _getFirstJoinStep(self): for s in self._steps: if s.funcName == self._getFirstJoinStepName(): return s return None def _stepsCheck(self): self._checkNewInput() self._checkNewOutput() def _checkNewInput(self): # Check if there are new movies to process from the input set self.lastCheck = getattr(self, 'lastCheck', datetime.now()) mTime = max(datetime.fromtimestamp(os.path.getmtime(self.movieFn1)), datetime.fromtimestamp(os.path.getmtime(self.movieFn2))) self.debug('Last check: %s, modification: %s' % (pwutils.prettyTime(self.lastCheck), pwutils.prettyTime(mTime))) # If the input movies.sqlite have not changed since our last check, # it does not make sense to check for new input data if self.lastCheck > mTime and self.processedDict: # If this is empty it is due to a static "continue" action or it is the first round return None movieSet1 = self._loadInputMovieSet(self.movieFn1) movieSet2 = self._loadInputMovieSet(self.movieFn2) movieDict1 = {movie.getObjId(): movie.clone() for movie in movieSet1.iterItems()} movieDict2 = {movie.getObjId(): movie.clone() for movie in movieSet2.iterItems()} newIds1 = [idMovie for idMovie in movieDict1.keys() if idMovie not in self.processedDict] self.allMovies1.update(movieDict1) newIds2 = [idMovie for idMovie in movieDict2.keys() if idMovie not in self.processedDict] self.allMovies2.update(movieDict2) self.lastCheck = datetime.now() self.isStreamClosed = movieSet1.isStreamClosed() and \ movieSet2.isStreamClosed() movieSet1.close() movieSet2.close() outputStep = self._getFirstJoinStep() if len(set(self.allMovies1)) > len(set(self.processedDict)) and \ len(set(self.allMovies2)) > len(set(self.processedDict)): fDeps = self._insertNewMovieSteps(newIds1, newIds2, self.insertedDict) if outputStep is not None: outputStep.addPrerequisites(*fDeps) self.updateSteps() def _insertNewMovieSteps(self, movies1Dict, movies2Dict, insDict): deps = [] newIDs = list(set(movies1Dict).intersection(set(movies2Dict))) for movieID in newIDs: if movieID not in insDict: stepId = self._insertFunctionStep('alignmentCorrelationMovieStep', movieID, prerequisites=[]) deps.append(stepId) insDict[movieID] = stepId self.processedDict.append(movieID) return deps
[docs] def alignmentCorrelationMovieStep(self, movieId): movie1 = self.allMovies1.get(movieId) movie2 = self.allMovies2.get(movieId) doneFn = self._getMovieDone(movieId) if self.isContinued() and self._isMovieDone(movieId): self.info("Skipping movie with ID: %s, seems to be done" % movieId) return # Clean old finished files pwutils.cleanPath(doneFn) if (movie1 is None) or (movie2 is None): self.info('AlignmentCorrelationMovieStep movie1 or movie2 are None') return alignment1 = movie1.getAlignment() alignment2 = movie2.getAlignment() shiftX_1, shiftY_1 = alignment1.getShifts() shiftX_2, shiftY_2 = alignment2.getShifts() # Transformation of the shifts to calculate the shifts trajectory correlation S1 = np.ones([3, len(shiftX_1)]) S2 = np.ones([3, len(shiftX_2)]) S1[0, :] = shiftX_1 S1[1, :] = shiftY_1 S2[0, :] = shiftX_2 S2[2, :] = shiftY_2 A = np.dot(np.dot(S1, S2.T), np.linalg.inv(np.dot(S2, S2.T))) S2_p = np.dot(A, S2) S1_cart = np.array([S1[0, :]/S1[2, :], S1[1, :]/S1[2, :]]) print("S1cart= ", S1_cart) S2_p_cart = np.array([S2_p[0, :] / S2_p[2, :], S2_p[1, :] / S2_p[2, :]]) print("S2pcart= ", S2_p_cart) rmse_cart = np.sqrt((np.square(S1_cart - S2_p_cart)).mean()) maxe_cart = np.max(S1_cart - S2_p_cart) corrX_cart = np.corrcoef(S1_cart[0, :], S2_p_cart[0, :])[0, 1] corrY_cart = np.corrcoef(S1_cart[1, :], S2_p_cart[1, :])[0, 1] corr_cart = np.min([corrY_cart, corrX_cart]) self.info('Root Mean Squared Error %f' % rmse_cart) self.info('General Corr min(corrX, corrY) %f' % corr_cart) if corr_cart >= self.minConsCorrelation.get(): self.info('Movie with id %d has a correlated alignment shift trajectory' %movieId) fn = self._getMovieSelecFileAccepted() with open(fn, 'a') as f: f.write('%d T\n' % movieId) elif corr_cart < self.minConsCorrelation.get(): self.info('Movie with id %d has discrepancy in the alignment with correlation %f' % (movieId, corr_cart)) fn = self._getMovieSelecFileDiscarded() with open(fn, 'a') as f: f.write('%d F\n' % movieId) stats_loc = {'shift_corr': corr_cart, 'shift_corr_X': corrX_cart, 'shift_corr_Y': corrY_cart, 'max_error': maxe_cart, 'rmse_error': rmse_cart, 'S1_cart': S1_cart, 'S2_p_cart': S2_p_cart} self.stats[movieId] = stats_loc self._store() # Mark this movie as finished open(doneFn, 'w').close()
def _checkNewOutput(self): """ Check for already selected movies and update the output set. """ # Load previously done items (from text file) doneListDiscarded = self._readCertainDoneList(DISCARDED) doneListAccepted = self._readCertainDoneList(ACCEPTED) # Check for newly done items movieListIdAccepted = self._readtMovieId(True) movieListIdDiscarded = self._readtMovieId(False) newDoneAccepted = [movieId for movieId in movieListIdAccepted if movieId not in doneListAccepted] newDoneDiscarded = [movieId for movieId in movieListIdDiscarded if movieId not in doneListDiscarded] firstTimeAccepted = len(doneListAccepted) == 0 firstTimeDiscarded = len(doneListDiscarded) == 0 allDone = len(doneListAccepted) + len(doneListDiscarded) +\ len(newDoneAccepted) + len(newDoneDiscarded) # We have finished when there is not more input movies (stream closed) # and the number of processed movies is equal to the number of inputs maxMovieSize = len(set(self.allMovies1).intersection(set(self.allMovies2))) self.finished = (self.isStreamClosed and allDone == maxMovieSize) streamMode = Set.STREAM_CLOSED if self.finished else Set.STREAM_OPEN def readOrCreateOutputs(doneList, newDone, label=''): if len(doneList) > 0 or len(newDone) > 0: with self._lock: movSet = self._loadOutputSet(SetOfMovies, 'movies'+label+'.sqlite') micSet = self._loadOutputSet(SetOfMicrographs, 'micrographs'+label+'.sqlite') label = ACCEPTED if label == '' else DISCARDED self.fillOutput(movSet, micSet, newDone, label) movSet.setSamplingRate(self.samplingRate) micSet.setSamplingRate(self.samplingRate) micSet.setAcquisition(self.acquisition.clone()) movSet.setAcquisition(self.acquisition.clone()) return movSet, micSet return None, None movieSet, micSet = readOrCreateOutputs(doneListAccepted, newDoneAccepted) movieSetDiscarded, micSetDiscarded = readOrCreateOutputs(doneListDiscarded, newDoneDiscarded, DISCARDED) if not self.finished and not newDoneDiscarded and not newDoneAccepted: # If we are not finished and no new output have been produced # it does not make sense to proceed and updated the outputs # so we exit from the function here return def updateRelationsAndClose(movieSet, micSet, first, label=''): if os.path.exists(self._getPath('movies'+label+'.sqlite')): micsAttrName = 'outputMicrographs'+label self._updateOutputSet(micsAttrName, micSet, streamMode) self._updateOutputSet('outputMovies'+label, movieSet, streamMode) if first: # We consider that Movies are 'transformed' into the Micrographs # This will allow to extend the micrograph associated to a set of # movies to another set of micrographs generated from a # different movie alignment self._defineTransformRelation(self.inputMovies1, micSet) micSet.close() movieSet.close() updateRelationsAndClose(movieSet, micSet, firstTimeAccepted) updateRelationsAndClose(movieSetDiscarded, micSetDiscarded, firstTimeDiscarded, DISCARDED) if self.finished: # Unlock createOutputStep if finished all jobs outputStep = self._getFirstJoinStep() if outputStep and outputStep.isWaiting(): outputStep.setStatus(STATUS_NEW)
[docs] def fillOutput(self, movieSet, micSet, newDone, label): if newDone: inputMovieSet = self._loadInputMovieSet(self.movieFn1) inputMicSet = self._loadInputMicrographSet(self.micsFn) for movieId in newDone: movie = inputMovieSet[movieId].clone() mic = inputMicSet[movieId].clone() movie.setEnabled(self._getEnable(movieId)) mic.setEnabled(self._getEnable(movieId)) alignment1 = movie.getAlignment() shiftX_1, shiftY_1 = alignment1.getShifts() setAttribute(mic, '_alignment_corr', self.stats[movieId]['shift_corr']) setAttribute(mic, '_alignment_rmse_error', self.stats[movieId]['rmse_error']) setAttribute(mic, '_alignment_max_error', self.stats[movieId]['max_error']) alignment = MovieAlignment(xshifts=shiftX_1, yshifts=shiftY_1) movie.setAlignment(alignment) self._writeCertainDoneList(movieId, label) if self.trajectoryPlot.get(): firstFrame, _, _ = self.inputMovies1.get().getFramesRange() self._createAndSaveTrajectoriesPlot(movieId, firstFrame, self.samplingRate) mic.plotCart = Image() mic.plotCart.setFileName(self._getTrajectoriesPlot(movieId)) movieSet.append(movie) micSet.append(mic) inputMovieSet.close() inputMicSet.close()
def _loadOutputSet(self, SetClass, baseName, fixSampling=True): """ Load the output set if it exists or create a new one. """ setFile = self._getPath(baseName) if os.path.exists(setFile) and os.path.getsize(setFile) > 0: outputSet = SetClass(filename=setFile) outputSet.loadAllProperties() outputSet.enableAppend() else: outputSet = SetClass(filename=setFile) outputSet.setStreamState(outputSet.STREAM_OPEN) inputMovies = self.inputMovies1.get() outputSet.copyInfo(inputMovies) if fixSampling: newSampling = inputMovies.getSamplingRate() * self._getBinFactor() outputSet.setSamplingRate(newSampling) return outputSet def _loadInputMovieSet(self, moviesFn): self.debug("Loading input db: %s" % moviesFn) movieSet = SetOfMovies(filename=moviesFn) movieSet.loadAllProperties() movieSet.close() self.debug("Closed db.") return movieSet def _loadInputMicrographSet(self, micsFn): self.debug("Loading input db: %s" % micsFn) micSet = SetOfMicrographs(filename=micsFn) micSet.loadAllProperties() micSet.close() self.debug("Closed db.") return micSet def _summary(self): # return message pass def _validate(self): """ The function of this hook is to add some validation before the protocol is launched to be executed. It should return a list of errors. If the list is empty the protocol can be executed. """ errors = [] if (self.inputMovies1.get().hasAlignment() == ALIGN_NONE) or \ (self.inputMovies2.get().hasAlignment() == ALIGN_NONE): errors.append("The inputs ( _Input Movies 1_ or _Input Movies 2_ must be aligned before") return errors # ------------------------------------ Utils functions ------------------------------------ def _isMovieDone(self, id): """ A movie is done if the marker file exists. """ return os.path.exists(self._getMovieDone(id)) def _getMovieDone(self, id): """ Return the file that is used as a flag of termination. """ return self._getExtraPath('DONE', 'movie_%06d.TXT' % id) def _readDoneList(self): """ Read from a file the id's of the items that have been done. """ doneFile = self._getAllDone() doneList = [] # Check what items have been previously done if os.path.exists(doneFile): with open(doneFile) as f: doneList += [int(line.strip()) for line in f] return doneList def _getAllDone(self): return self._getExtraPath('DONE_all.TXT') def _writeDoneList(self, partList): """ Write to a text file the items that have been done. """ with open(self._getAllDone(), 'a') as f: for part in partList: f.write('%d\n' % part.getObjId()) def _getMicsPath(self): prot1 = self.inputMovies1.getObjValue() # pointer to previous protocol if hasattr(prot1, 'outputMicrographs'): path1 = prot1.outputMicrographs.getFileName() if os.path.getsize(path1) > 0: return path1 elif hasattr(prot1, 'outputMicrographsDoseWeighted'): path2 = prot1.outputMicrographsDoseWeighted.getFileName() if os.path.getsize(path2) > 0: return path2 else: return None def _readCertainDoneList(self, label): """ Read from a text file the id's of the items that have been done. """ doneFile = self._getCertainDone(label) doneList = [] # Check what items have been previously done if os.path.exists(doneFile): with open(doneFile) as f: doneList += [int(line.strip()) for line in f] return doneList def _writeCertainDoneList(self, movieId, label): """ Write to a text file the items that have been done. """ doneFile = self._getCertainDone(label) with open(doneFile, 'a') as f: f.write('%d\n' % movieId) def _createAndSaveTrajectoriesPlot(self, movieId, first, pixSize): """ Write to a text file the items that have been done. """ stats = self.stats[movieId] fn = self._getExtraPath('global_trajectories_%d' %movieId+'_plot_cart.png') shift_X1 = stats['S1_cart'][0, :] shift_Y1 = stats['S1_cart'][1, :] shift_X2 = stats['S2_p_cart'][0, :] shift_Y2 = stats['S2_p_cart'][1, :] # ---------------- PLOT ----------------------- sumMeanX1 = [] sumMeanY1= [] sumMeanX2 = [] sumMeanY2 = [] def px_to_ang(px): y1, y2 = px.get_ylim() x1, x2 = px.get_xlim() ax_ang2.set_ylim(y1 * pixSize, y2 * pixSize) ax_ang.set_xlim(x1 * pixSize, x2 * pixSize) ax_ang.figure.canvas.draw() ax_ang2.figure.canvas.draw() figureSize = (6, 4) plotter = Plotter(*figureSize) figure = plotter.getFigure() ax_px = figure.add_subplot(111) ax_px.grid() ax_px.set_xlabel('Shift x (px)') ax_px.set_ylabel('Shift y (px)') ax_px.set_xlabel('Shift x (px) (CorrX:%.3f)' % stats['shift_corr_X']) ax_px.set_ylabel('Shift y (px) (CorrX:%.3f)' % stats['shift_corr_Y']) ax_ang = ax_px.twiny() ax_ang.set_xlabel('Shift x (A)') ax_ang2 = ax_px.twinx() ax_ang2.set_ylabel('Shift y (A)') i = first # The output and log files list the shifts relative to the first frame. # ROB unit seems to be pixels since sampling rate is only asked # by the program if dose filtering is required skipLabels = ceil(len(shift_X1) / 10.0) labelTick = 1 for x1, y1, x2, y2 in zip(shift_X1, shift_Y1, shift_X2, shift_Y2): sumMeanX1.append(x1) sumMeanY1.append(y1) sumMeanX2.append(x2) sumMeanY2.append(y2) if labelTick == 1: ax_px.text(x1 - 0.02, y1 + 0.02, str(i)) labelTick = skipLabels else: labelTick -= 1 i += 1 # automatically update lim of ax_ang when lim of ax_px changes. ax_px.callbacks.connect("ylim_changed", px_to_ang) ax_px.callbacks.connect("xlim_changed", px_to_ang) ax_px.plot(sumMeanX1, sumMeanY1, color='b', label='reference shifts') ax_px.plot(sumMeanX2, sumMeanY2, color='r', label='target shifts') ax_px.plot(sumMeanX1, sumMeanY1, 'yo') ax_px.plot(sumMeanX1[0], sumMeanY1[0], 'ro', markersize=10, linewidth=0.5) ax_px.set_title('Global frame alignment') ax_px.legend() plotter.tightLayout() plotter.savefig(fn) plotter.close() def _getTrajectoriesPlot(self, movieId): """ Write to a text file the items that have been done. """ return self._getExtraPath('global_trajectories_%d' %movieId+'_plot_cart.png') def _getCertainDone(self, label): return self._getExtraPath('DONE_'+label+'.TXT') def _getMovieSelecFileAccepted(self): return self._getExtraPath('selection-movie-accepted.txt') def _getMovieSelecFileDiscarded(self): return self._getExtraPath('selection-movie-discarded.txt') def _readtMovieId(self, accepted): if accepted: fn = self._getMovieSelecFileAccepted() else: fn = self._getMovieSelecFileDiscarded() moviesList = [] # Check what items have been previously done if os.path.exists(fn): with open(fn) as f: moviesList += [int(line.strip().split()[0]) for line in f] return moviesList def _getEnable(self, movieId): fn = self._getMovieSelecFileAccepted() # Check what items have been previously done if os.path.exists(fn): with open(fn) as f: for line in f: if movieId == int(line.strip().split()[0]): if line.strip().split()[1] == 'T': return True else: return False
[docs]def setAttribute(obj, label, value): if value is None: return setattr(obj, label, getScipionObj(value))