# **************************************************************************
# *
# * Authors: J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [1]
# *
# * [1] SciLifeLab, Stockholm University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import os
from os.path import abspath
import numpy as np
import pyworkflow as pw
import pyworkflow.protocol.params as params
from pyworkflow.utils import removeBaseExt
from pyworkflow.utils.properties import Message
from pwem.emlib.image import ImageHandler
from ..objects import TiltSeries, TiltImage, SetOfTiltSeries
from .protocol_ts_base import ProtTsProcess
OUTPUT_TILT_SERIES_ODD = 'outputTiltSeriesOdd'
OUTPUT_TILT_SERIES_EVEN = 'outputTiltSeriesEven'
EVEN = 'even'
ODD = 'odd'
OUTPUT_TILT_SERIES_DW = 'outputTiltSeriesDW'
[docs]class ProtTsCorrectMotion(ProtTsProcess):
"""
Base class for movie alignment protocols such as:
motioncorr, crosscorrelation and optical flow
Alignment parameters are defined in common. For example,
the frames range used for alignment and final sum, the binning factor
or the cropping options (region of interest)
"""
# Attributes used for even/odd frames splitting if requested
evenAvgFrameList = []
oddAvgFrameList = []
tsMList = []
# Even / odd functionality
evenOddCapable = False
outputSetEven = None
outputSetOdd = None
# -------------------------- DEFINE param functions -----------------------
def _defineParams(self, form):
form.addSection(label=Message.LABEL_INPUT)
form.addParam('inputTiltSeriesM', params.PointerParam,
pointerClass='SetOfTiltSeriesM',
important=True,
label='Input Tilt-Series (movies)',
help='Select input tilt-series movies that you want'
'to correct for beam-induced motion. ')
group = form.addGroup('Alignment')
line = group.addLine('Frames to ALIGN',
help='Frames range to ALIGN on each movie. The '
'first frame is 1. If you set 0 in the final '
'frame to align, it means that you will '
'align until the last frame of the movie.')
line.addParam('alignFrame0', params.IntParam, default=1,
label='from')
line.addParam('alignFrameN', params.IntParam, default=0,
label='to')
group.addParam('useAlignToSum', params.BooleanParam, default=True,
label='Use ALIGN frames range to SUM?',
help="If *Yes*, the same frame range will be used to "
"ALIGN and to SUM. If *No*, you can selected a "
"different range for SUM (must be a subset).")
line = group.addLine('Frames to SUM', condition="not useAlignToSum",
help='Frames range to SUM on each movie. The '
'first frame is 1. If you set 0 in the final '
'frame to sum, it means that you will sum '
'until the last frame of the movie.')
line.addParam('sumFrame0', params.IntParam, default=1,
label='from')
line.addParam('sumFrameN', params.IntParam, default=0,
label='to')
group.addParam('binFactor', params.FloatParam, default=1.,
label='Binning factor',
help='1x or 2x. Bin stack before processing.')
line = group.addLine('Crop offsets (px)',
expertLevel=params.LEVEL_ADVANCED)
line.addParam('cropOffsetX', params.IntParam, default=0, label='X')
line.addParam('cropOffsetY', params.IntParam, default=0, label='Y')
line = group.addLine('Crop dimensions (px)',
expertLevel=params.LEVEL_ADVANCED,
help='How many pixels to crop from offset\n'
'If equal to 0, use maximum size.')
line.addParam('cropDimX', params.IntParam, default=0, label='X')
line.addParam('cropDimY', params.IntParam, default=0, label='Y')
if self.evenOddCapable:
form.addParam('splitEvenOdd', params.BooleanParam,
default=False,
label='Split & sum odd/even frames?',
expertLevel=params.LEVEL_ADVANCED,
help='(Used for denoising data preparation). If set to Yes, 2 additional movies/tilt '
'series will be generated, one generated from the even frames and the other from the '
'odd ones using the same alignment for the whole stack of frames.')
form.addParallelSection(threads=4, mpi=1)
# --------------------------- STEPS functions ----------------------------
def _doSplitEvenOdd(self):
""" Returns if even/odd stuff has to be done"""
if not self.evenOddCapable:
return False
else:
return self.splitEvenOdd.get()
[docs] def processTiltImageStep(self, tsId, tiltImageId, *args):
tiltImageM = self._tsDict.getTi(tsId, tiltImageId)
workingFolder = self.__getTiltImageMWorkingFolder(tiltImageM)
pw.utils.makePath(workingFolder)
self._processTiltImageM(workingFolder, tiltImageM, *args)
if self._doSplitEvenOdd():
baseName = removeBaseExt(tiltImageM.getFileName())
evenName = abspath(self._getExtraPath(baseName + '_avg_' + EVEN))
oddName = abspath(self._getExtraPath(baseName + '_avg_' + ODD))
alignedFrameStack = self._getExtraPath(baseName + '_aligned_movie.mrcs')
# Get even/odd xmd files
args = '--img %s ' % abspath(alignedFrameStack)
args += '--type frames '
args += '-o %s ' % (evenName + '.xmd')
args += '-e %s ' % (oddName + '.xmd')
args += '--sum_frames '
self.runJob('xmipp_image_odd_even', args)
# Store the corresponding tsImM to use its data later in the even/odd TS
self.tsMList.append(tiltImageM)
# Update even and odd average lists
self.evenAvgFrameList.append(evenName + '_aligned.mrc')
self.oddAvgFrameList.append(oddName + '_aligned.mrc')
pw.utils.cleanPath(alignedFrameStack)
tiFn, tiFnDW = self._getOutputTiltImagePaths(tiltImageM)
if not os.path.exists(tiFn):
raise Exception("Expected output file '%s' not produced!" % tiFn)
if not pw.utils.envVarOn('SCIPION_DEBUG_NOCLEAN'):
pw.utils.cleanPath(workingFolder)
[docs] def processTiltSeriesStep(self, tsId):
""" Create a single stack with the tiltseries. """
def addTiltImage(tiFile, tsObject, suffix, tsMov, tsIde, samplingRate, objId, index):
"""
:param tiFile: aligned tilt image file
:param tsObject: Tilt Series to which the new Ti Image will be added
:param suffix: used for the location. Admitted values are even or odd
:param tsMov: Tilt Series Movies object
:param tsIde: Tilt series identifier
:param samplingRate: current Tilt Series sampling rate
:param objId: location of the Tilt Image which will be added
:param index: position of the slice in the generated slack
"""
ta = tsMov.getTiltAngle()
to = tsMov.getAcquisitionOrder()
acq = tsMov.getAcquisition()
ti = TiltImage(tiltAngle=ta, tsId=tsIde, acquisitionOrder=to)
ti.setSamplingRate(samplingRate)
ti.setIndex(index)
ti.setAcquisition(acq)
newLocation = (objId, self._getExtraPath(tsIde + '_' + suffix + '.mrcs'))
ih.convert(tiFile, newLocation)
ti.setLocation(newLocation)
tsObject.append(ti)
pw.utils.cleanPath(tiFile)
ts = self._tsDict.getTs(tsId)
ts.setDim([])
tiList = self._tsDict.getTiList(tsId)
tiList.sort(key=lambda ti: ti.getTiltAngle())
ih = ImageHandler()
tsFn = self._getOutputTiltSeriesPath(ts)
tsFnDW = self._getOutputTiltSeriesPath(ts, '_DW')
# Merge all micrographs from the same tilt images in a single "mrcs" stack file
for i, ti in enumerate(tiList):
tiFn, tiFnDW = self._getOutputTiltImagePaths(ti)
newLocation = (i+1, tsFn)
ih.convert(tiFn, newLocation)
ti.setLocation(newLocation)
pw.utils.cleanPath(tiFn)
if os.path.exists(tiFnDW):
ih.convert(tiFnDW, (i+1, tsFnDW))
pw.utils.cleanPath(tiFnDW)
self._tsDict.setFinished(tsId)
# Even and odd stuff
if self._doSplitEvenOdd():
template = 'tiltseries%s.sqlite'
sRate = self._getOutputSampling()
# Even
if self.outputSetEven:
self.outputSetEven.enableAppend()
else:
self.outputSetEven = SetOfTiltSeries.create(self._getPath(), template=template, suffix=EVEN)
self.outputSetEven.setSamplingRate(sRate)
# Odd
if self.outputSetOdd:
self.outputSetOdd.enableAppend()
else:
self.outputSetOdd = SetOfTiltSeries.create(self._getPath(), template=template, suffix=ODD)
self.outputSetOdd.setSamplingRate(sRate)
tsObjEven = TiltSeries()
tsObjOdd = TiltSeries()
tsObjEven.copyInfo(ts, copyId=True)
tsObjOdd.copyInfo(ts, copyId=True)
self.outputSetEven.append(tsObjEven)
self.outputSetOdd.append(tsObjOdd)
# Merge all micrographs from the same tilt images in a sorted-by-tilt single "mrcs" stack file,
ind = np.argsort([ti.getTiltAngle() for ti in self.tsMList])
counter = 1
# for fileEven, fileOdd, tsImM in zip(self.evenAvgFrameList, self.oddAvgFrameList, self.tsMList):
for i in ind:
tiEvenFile = self.evenAvgFrameList[i]
tiOddFile = self.oddAvgFrameList[i]
tsM = self.tsMList[i]
# Even
addTiltImage(tiEvenFile, tsObjEven, EVEN, tsM, tsId, sRate, counter, counter)
# Odd
addTiltImage(tiOddFile, tsObjOdd, ODD, tsM, tsId, sRate, counter, counter)
counter += 1
# update items and size info
self.outputSetEven.update(tsObjEven)
self.outputSetOdd.update(tsObjOdd)
# Empty lists for next iteration
self.evenAvgFrameList = []
self.oddAvgFrameList = []
self.tsMList = []
# Dose weighting
if self._createOutputWeightedTS():
if getattr(self, 'outputSetDW', None) is None:
self.outputSetDW = self._createOutputSet(suffix='_dose-weighted')
self.outputSetDW.setSamplingRate(self._getOutputSampling())
else:
self.outputSetDW.enableAppend()
tsObjDW = TiltSeries()
tsObjDW.copyInfo(ts, copyId=True)
self.outputSetDW.append(tsObjDW)
for i, ti in enumerate(tiList):
tiOut = TiltImage(location=(i+1, tsFnDW))
tiOut.copyInfo(ti, copyId=True)
tiOut.setAcquisition(ti.getAcquisition())
tiOut.setSamplingRate(self._getOutputSampling())
tiOut.setIndex(i+1)
tiOut.setObjId(ti.getIndex())
tsObjDW.append(tiOut)
self.outputSetDW.update(tsObjDW)
[docs] def createOutputStep(self):
"""" Overwrites the parent method to allow the creation of odd and even outputs"""
ProtTsProcess.createOutputStep(self)
if self._doSplitEvenOdd():
# Even
self.outputSetEven.setStreamState(self.outputSetEven.STREAM_CLOSED)
self.outputSetEven.write()
self._store(self.outputSetEven)
# Odd
self.outputSetOdd.setStreamState(self.outputSetOdd.STREAM_CLOSED)
self.outputSetOdd.write()
self._store(self.outputSetOdd)
if self._createOutputWeightedTS():
self.outputSetDW.setStreamState(self.outputSetDW.STREAM_CLOSED)
self.outputSetDW.write()
self._store(self.outputSetDW)
def _updateOutput(self, tsIdList):
""" Update the output set with the finished Tilt-series.
Params:
:param tsIdList: list of ids of finished tasks.
"""
# Flag to check the first time we save output
self._createOutput = getattr(self, '_createOutput', True)
outputSet = self._getOutputSet()
if outputSet is None:
# Special case just to update the outputSet status
# but it only makes sense when there is outputSet
if not tsIdList:
return
outputSet = self._createOutputSet()
else:
outputSet.enableAppend()
self._createOutput = False
# Call the sub-class method to update the output
outputSet.setSamplingRate(self._getOutputSampling())
self._updateOutputSet(outputSet, tsIdList)
outputSet.setStreamState(outputSet.STREAM_OPEN)
if self._doSplitEvenOdd():
self.outputSetEven.setStreamState(self.outputSetEven.STREAM_OPEN)
self.outputSetOdd.setStreamState(self.outputSetOdd.STREAM_OPEN)
if self._createOutputWeightedTS():
self.outputSetDW.setStreamState(self.outputSetDW.STREAM_OPEN)
if self._createOutput:
outputSet.updateDim()
outputs = {self._getOutputName(): outputSet}
if self._createOutputWeightedTS():
self.outputSetDW.updateDim()
outputs.update({OUTPUT_TILT_SERIES_DW: self.outputSetDW})
if self._doSplitEvenOdd():
self.outputSetEven.updateDim()
self.outputSetOdd.updateDim()
outputs.update({OUTPUT_TILT_SERIES_EVEN: self.outputSetEven,
OUTPUT_TILT_SERIES_ODD: self.outputSetOdd
})
self._defineOutputs(**outputs)
self._defineSourceRelation(self._getInputTsPointer(), self.outputSetEven)
self._defineSourceRelation(self._getInputTsPointer(), self.outputSetOdd)
else:
self._defineOutputs(**outputs)
if self._createOutputWeightedTS():
self._defineSourceRelation(self._getInputTsPointer(), self.outputSetDW)
self._defineSourceRelation(self._getInputTsPointer(), outputSet)
self._createOutput = False
else:
outputSet.write()
self._store(outputSet)
if self._doSplitEvenOdd():
self.outputSetEven.write()
self._store(self.outputSetEven)
self.outputSetOdd.write()
self._store(self.outputSetOdd)
if self._createOutputWeightedTS():
self.outputSetDW.write()
self._store(self.outputSetDW)
outputSet.close()
if self._doSplitEvenOdd():
self.outputSetEven.close()
self.outputSetOdd.close()
if self._createOutputWeightedTS():
self.outputSetDW.close()
if self._tsDict.allDone():
self._coStep.setStatus(params.STATUS_NEW)
def _updateOutputSet(self, outputSet, tsIdList):
""" Override this method to convert the TiltSeriesM into TiltSeries.
"""
for tsId in tsIdList:
ts = TiltSeries()
ts.copyInfo(self._tsDict.getTs(tsId), copyId=True)
ts.setSamplingRate(self._getOutputSampling())
outputSet.append(ts)
tList = self._tsDict.getTiList(tsId)
ind = np.argsort([ti.getTiltAngle() for ti in tList])
counter = 1
for i in ind: # Make each row of the sqlite file be sorted by
# index after having been sorted by angle previously, in order to avoid tilt image mismatching in
# another operations, such as the fiducial alignment, which expects the sqlite to be sorted that way
ti = tList[i]
tiOut = TiltImage(location=(counter, ti.getFileName()))
tiOut.copyInfo(ti, copyId=True)
tiOut.setAcquisition(ti.getAcquisition())
tiOut.setSamplingRate(self._getOutputSampling())
tiOut.setIndex(counter)
tiOut.setObjId(ti.getIndex())
ts.append(tiOut)
counter += 1
outputSet.update(ts)
# --------------------------- INFO functions ------------------------------
def _validate(self):
errors = []
return errors
def _summary(self):
return [self.summaryVar.get('')]
# --------------------------- UTILS functions ----------------------------
def _initialize(self):
inputTs = self._getInputTs()
acq = inputTs.getAcquisition()
gain, dark = self.getGainAndDark()
self.__basicArgs = [
acq.getDoseInitial(), acq.getDosePerFrame(), gain, dark]
def _getArgs(self):
return self.__basicArgs
def _getInputTsPointer(self):
return self.inputTiltSeriesM
def _getOutputSampling(self):
return self.inputTiltSeriesM.get().getSamplingRate() * self._getBinFactor()
def _processTiltImageM(self, workingFolder, tiltImageM, *args):
""" This function should be implemented in subclasses to really provide
the processing step for this TiltSeries Movie.
Output corrected image (and DW one) should be copied to expected name.
"""
pass
[docs] def getGainAndDark(self):
""" Return temporary paths of gain and dark if relevant. """
inputTs = self.inputTiltSeriesM.get()
gain = os.path.abspath(self._getTmpPath('gain.mrc')) if inputTs.getGain() else None
dark = os.path.abspath(self._getTmpPath('dark.mrc')) if inputTs.getDark() else None
return gain, dark
def _getFrameRange(self, n, prefix):
"""
Params:
:param n: Number of frames of the movies
:param prefix: what range we want to consider, either 'align' or 'sum'
:return: (i, f) initial and last frame range
"""
# In case that the user select the same range for ALIGN and SUM
# we also use the 'align' prefix
if self._useAlignToSum():
prefix = 'align'
first = self.getAttributeValue('%sFrame0' % prefix)
last = self.getAttributeValue('%sFrameN' % prefix)
if first <= 1:
first = 1
if last <= 0:
last = n
return first, last
def _getBinFactor(self):
return self.getAttributeValue('binFactor', 1.0)
# ----- Some internal functions ---------
def _getTiltImageMRoot(self, tim):
return '%s_%02d' % (tim.getTsId(), tim.getObjId())
def __getTiltImageMWorkingFolder(self, tiltImageM):
return self._getTmpPath(self._getTiltImageMRoot(tiltImageM))
def _getOutputTiltImagePaths(self, tiltImageM):
""" Return expected output path for correct movie and DW one.
"""
base = self._getExtraPath(self._getTiltImageMRoot(tiltImageM))
return base + '.mrc', base + '_DW.mrc'
def _getOutputTiltSeriesPath(self, ts, suffix=''):
return self._getExtraPath('%s%s.mrcs' % (ts.getTsId(), suffix))
def _useAlignToSum(self):
return True
def _createOutputWeightedTS(self):
return False
[docs]class ProtTsAverage(ProtTsCorrectMotion):
"""
Simple protocol to average TiltSeries movies as basic
motion correction. It is used mainly for testing purposes.
"""
_label = 'average tilt-series movies'
_devStatus = pw.BETA
def _processTiltImageM(self, workingFolder, tiltImageM, *args):
""" Simple add all frames and divide by its number. """
ih = ImageHandler()
sumImg = ih.createImage()
img = ih.createImage()
n = tiltImageM.getNumberOfFrames()
fn = tiltImageM.getFileName()
sumImg.read((1, fn))
for frame in range(2, n + 1):
img.read((frame, fn))
sumImg.inplaceAdd(img)
# sumImg.inplaceDivide(float(n))
outputFn = self._getOutputTiltImagePaths(tiltImageM)[0]
sumImg.write(outputFn)