# **************************************************************************
# *
# * Authors: J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [1]
# *
# * [1] SciLifeLab, Stockholm University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import os
import pyworkflow as pw
import pyworkflow.protocol.params as params
from pwem.convert.headers import getFileFormat, MRC
from pyworkflow.object import Set, Integer
from pyworkflow.protocol import STATUS_NEW
from pyworkflow.utils.properties import Message
from pwem import emlib
from .protocol_ts_base import ProtTsProcess
from ..objects import SetOfCTFTomoSeries, CTFTomoSeries
class ProtTsEstimateCTF(ProtTsProcess):
    """
    Base class for estimating the CTF on TiltSeries.

    Subclasses are expected to provide ``_defineProcessParams``,
    ``_estimateCtf`` and ``getCtf``. The base methods below also read the
    params ``ctfDownFactor``, ``windowSize``, ``lowRes``, ``highRes``,
    ``minDefocus`` and ``maxDefocus``, which presumably are defined by the
    subclass in ``_defineProcessParams`` (TODO confirm per subclass).
    """
    # -------------------------- DEFINE param functions -----------------------
    def _defineParams(self, form):
        """ Define input parameters from this program into the given form. """
        form.addSection(label='Input')
        # Either raw tilt-series or a previous CTF estimation can be used as
        # input; see _getInputTs for how the tilt-series are recovered.
        form.addParam('inputTiltSeries', params.PointerParam, important=True,
                      pointerClass='SetOfTiltSeries, SetOfCTFTomoSeries',
                      label='Input tilt series')
        self._defineProcessParams(form)
        self._defineStreamingParams(form)

    def _defineProcessParams(self, form):
        """ Should be implemented in subclasses. """
        pass

    # --------------------------- STEPS functions ----------------------------
    def processTiltImageStep(self, tsId, tiltImageId, *args):
        """ Estimate the CTF of a single tilt-image.

        The tilt-image is converted into a temporary working directory, the
        subclass estimation (_estimateCtf) is run there, the directory is
        removed unless SCIPION_DEBUG_NOCLEAN is on, and finally the CTF
        returned by getCtf is attached to the tilt-image.

        :param tsId: id of the tilt-series the image belongs to.
        :param tiltImageId: id of the tilt-image inside the tilt-series.
        :param args: extra arguments forwarded untouched to _estimateCtf.
        """
        ti = self._tsDict.getTi(tsId, tiltImageId)

        # Create working directory for a tilt-image
        workingDir = self.__getTiworkingDir(ti)
        tiFnMrc = os.path.join(workingDir, self.getTiPrefix(ti) + '.mrc')
        pw.utils.makePath(workingDir)
        self._convertInputTi(ti, tiFnMrc)

        # Call the current estimation of CTF that is implemented in subclasses
        self._estimateCtf(workingDir, tiFnMrc, ti, *args)

        if not pw.utils.envVarOn(pw.SCIPION_DEBUG_NOCLEAN):
            pw.utils.cleanPath(workingDir)

        ti.setCTF(self.getCtf(ti))

    def _convertInputTi(self, ti, tiFn):
        """ Convert the input tilt-image to ``tiFn``, applying ctfDownFactor.

        When the downsampling factor is not 1 the image is Fourier-scaled,
        otherwise it is converted to float. It can be overwritten in
        subclasses if another behaviour is required.

        :raises Exception: if the input tilt-image file does not exist.
        """
        downFactor = self.ctfDownFactor.get()
        ih = emlib.image.ImageHandler()

        if not ih.existsLocation(ti):
            raise Exception("Missing input file: %s" % ti)

        tiFName = ti.getFileName()
        # Make xmipp considers the input object as TS to work as expected:
        # an MRC stack is forced to be read as a stack (':mrcs') and the
        # tilt-image is then addressed by its index inside that stack.
        if getFileFormat(tiFName) == MRC:
            tiFName = tiFName.split(':')[0] + ':mrcs'
        tiFName = str(ti.getIndex()) + '@' + tiFName

        if downFactor != 1:
            ih.scaleFourier(tiFName, tiFn, downFactor)
        else:
            ih.convert(tiFName, tiFn, emlib.DT_FLOAT)

    def _estimateCtf(self, workingDir, tiFn, tiltImage, *args):
        """ Run the CTF estimation on one converted tilt-image.

        Must be implemented in subclasses.

        :raises NotImplementedError: always, in this base class.
        """
        # NotImplementedError derives from Exception, so callers catching
        # Exception keep working.
        raise NotImplementedError("_estimateCtf function should be implemented!")

    def processTiltSeriesStep(self, tsId):
        """ Step called after all CTF are estimated for a given tilt series. """
        self._tsDict.setFinished(tsId)

    def _updateOutput(self, tsIdList):
        """ Update the output set with the finished Tilt-series.

        Only the first id in ``tsIdList`` is turned into a CTFTomoSeries and
        appended to the output set. The set is created on the first call.
        When ``tsIdList`` is empty, an existing output set is only
        written/stored. If there is no output set yet, nothing is done.

        Params:
            :param tsIdList: list of ids of finished tasks.
        """
        # NOTE(review): ids after the first one are ignored here; presumably
        # the caller passes one finished id at a time - verify in the base.
        # Resolve the tilt-series before touching the output set (same order
        # as before) and only when there actually is a finished id.
        ts = self._getTiltSeries(tsIdList[0]) if tsIdList else None

        # Flag to check the first time we save output
        self._createOutput = getattr(self, '_createOutput', True)

        outputSet = self._getOutputSet()

        if outputSet is None:
            # Special case just to update the outputSet status
            # but it only makes sense when there is outputSet
            if ts is None:
                return
            outputSet = self._createOutputSet()
        else:
            outputSet.enableAppend()
            self._createOutput = False

        if ts is not None:
            newCTFTomoSeries = CTFTomoSeries()
            newCTFTomoSeries.copyInfo(ts)
            newCTFTomoSeries.setTiltSeries(ts)
            newCTFTomoSeries.setTsId(ts.getTsId())
            newCTFTomoSeries.setObjId(ts.getObjId())

            # The series must be in the set before items can be appended.
            outputSet.append(newCTFTomoSeries)

            # Tilt-image CTFs get a 1-based index following tilt-image order.
            for index, ti in enumerate(self._tsDict.getTiList(ts.getTsId()),
                                       start=1):
                newCTFTomo = ti._ctfModel
                newCTFTomo.setIndex(Integer(index))
                newCTFTomoSeries.append(newCTFTomo)

            newCTFTomoSeries.calculateDefocusUDeviation()
            newCTFTomoSeries.calculateDefocusVDeviation()

            # Disable series whose defocus deviation is out of range.
            if not (newCTFTomoSeries.getIsDefocusUDeviationInRange() and
                    newCTFTomoSeries.getIsDefocusVDeviationInRange()):
                newCTFTomoSeries.setEnabled(False)

            newCTFTomoSeries.write(properties=False)
            outputSet.update(newCTFTomoSeries)

        if self._createOutput:
            self._defineOutputs(**{self._getOutputName(): outputSet})
            self._defineSourceRelation(self._getInputTs(pointer=True),
                                       outputSet)
            self._createOutput = False
        else:
            outputSet.write()
            self._store(outputSet)

        outputSet.close()
        self._store()

        # Once every tilt-series is done, mark the waiting output step as
        # NEW again.
        if self._tsDict.allDone():
            self._coStep.setStatus(STATUS_NEW)

    def createOutputStep(self):
        """ Close the streaming state of the output set and store it. """
        outputSet = self._getOutputSet()
        outputSet.setStreamState(outputSet.STREAM_CLOSED)
        outputSet.write()
        self._store(outputSet)

    # --------------------------- INFO functions ------------------------------
    def _validate(self):
        """ No extra validation in the base class. """
        errors = []
        return errors

    def _summary(self):
        """ Report the input size and the number of estimated CTF series. """
        summary = []

        if hasattr(self, 'outputSetOfCTFTomoSeries'):
            inputLabel = 'Input Tilt-Series'
            if isinstance(self.inputTiltSeries.get(), SetOfCTFTomoSeries):
                inputLabel = 'Input CTFTomoSeries'
            summary.append(
                inputLabel + ": %d.\nnumber of CTF estimated: %d.\n"
                % (self._getInputTs().getSize(),
                   self.outputSetOfCTFTomoSeries.getSize()))
        else:
            summary.append("Output classes not ready yet.")
        return summary

    # --------------------------- UTILS functions ----------------------------
    def _createOutputSet(self, suffix=''):
        """ Create the output SetOfCTFTomoSeries in streaming (open) state.

        The set is linked to the input tilt-series and stored in the
        protocol path using the 'CTFmodels%s.sqlite' template. It can be
        re-defined in subclasses.

        :param suffix: currently unused; kept for interface compatibility.
        """
        outputSetOfCTFTomoSeries = SetOfCTFTomoSeries.create(
            self._getPath(), template='CTFmodels%s.sqlite')
        outputSetOfCTFTomoSeries.setSetOfTiltSeries(
            self._getInputTs(pointer=True))
        outputSetOfCTFTomoSeries.setStreamState(Set.STREAM_OPEN)
        return outputSetOfCTFTomoSeries

    def _getTiltSeries(self, itemId):
        """ Return the input tilt-series whose tsId equals ``itemId``.

        If the matching item is a CTFTomoSeries, its tilt-series is
        returned instead.

        :raises Exception: if no matching tilt-series is found.
        """
        obj = None
        inputSetOfTiltseries = self._getInputTs()
        for item in inputSetOfTiltseries.iterItems(iterate=False):
            if item.getTsId() == itemId:
                obj = item
                if isinstance(obj, CTFTomoSeries):
                    obj = item.getTiltSeries()
                break

        if obj is None:
            # Previously a bare string was raised, which in Python 3 is a
            # TypeError that hid this message.
            raise Exception("Could not find tilt-series with tsId = %s"
                            % itemId)

        return obj

    def _getOutputName(self):
        """ Return the output name, by default 'outputSetOfCTFTomoSeries'.

        This method can be re-implemented in subclasses that have a
        different output.
        """
        return 'outputSetOfCTFTomoSeries'

    def _getOutputSet(self):
        """ Return the output set, or None if it was not defined yet. """
        return getattr(self, self._getOutputName(), None)

    def _getInputTsPointer(self):
        """ Return the pointer to the protocol input. """
        return self.inputTiltSeries

    def _getInputTs(self, pointer=False):
        """ Return the input SetOfTiltSeries (or a pointer to it).

        When the input is a SetOfCTFTomoSeries, the tilt-series set it
        refers to is returned instead.
        """
        if isinstance(self.inputTiltSeries.get(), SetOfCTFTomoSeries):
            return self.inputTiltSeries.get().getSetOfTiltSeries(pointer=pointer)
        return self.inputTiltSeries.get() if not pointer else self.inputTiltSeries

    def _initialize(self):
        """ Define a dictionary with the parameters used for CTF estimation
        that are common for all tilt-images. """
        # Get the input tilt-series set
        inputTs = self._getInputTs()
        acq = inputTs.getAcquisition()
        downFactor = self.getAttributeValue('ctfDownFactor', 1.0)
        sampling = inputTs.getSamplingRate()
        # Images are downsampled before estimation, so the pixel size grows.
        if downFactor != 1.0:
            sampling *= downFactor
        self._params = {'voltage': acq.getVoltage(),
                        'sphericalAberration': acq.getSphericalAberration(),
                        'magnification': acq.getMagnification(),
                        'ampContrast': acq.getAmplitudeContrast(),
                        'samplingRate': sampling,
                        'scannedPixelSize': inputTs.getScannedPixelSize(),
                        'windowSize': self.windowSize.get(),
                        'lowRes': self.lowRes.get(),
                        'highRes': self.highRes.get(),
                        'minDefocus': self.minDefocus.get(),
                        'maxDefocus': self.maxDefocus.get()
                        }

    def getCtfParamsDict(self):
        """ Return a copy of the global params dict,
        to avoid overwriting values. """
        # Shallow copy: the values set in _initialize are plain scalars.
        return self._params.copy()

    def getCtf(self, ti):
        """ Return the CTF model for tilt-image ``ti``.

        Should be implemented in subclasses; returns None here.
        """
        pass

    # ----- Some internal functions ---------
    def getTiRoot(self, tim):
        """ Return '<tsId>_<objId>' (objId zero-padded to 2 digits). """
        return '%s_%02d' % (tim.getTsId(), tim.getObjId())

    def __getTiworkingDir(self, tiltImage):
        """ Return the temporary working directory for a tilt-image. """
        return self._getTmpPath(self.getTiRoot(tiltImage))

    def _getOutputTiPaths(self, tiltImageM):
        """ Return the expected extra-path output file names for a
        tilt-image: '<root>.mrc' and '<root>_DW.mrc'.
        """
        base = self._getExtraPath(self.getTiRoot(tiltImageM))
        return base + '.mrc', base + '_DW.mrc'

    def _getOutputTiltSeriesPath(self, ts, suffix=''):
        """ Return the extra-path '<tsId><suffix>.mrcs' file name. """
        return self._getExtraPath('%s%s.mrcs' % (ts.getTsId(), suffix))

    def _useAlignToSum(self):
        return True

    def getTiPrefix(self, ti):
        """ Return '<tsId>_<objId>' (objId zero-padded to 3 digits). """
        return '%s_%03d' % (ti.getTsId(), ti.getObjId())

    def allowsDelete(self, obj):
        return True