# **************************************************************************
# *
# * Authors: Grigory Sharov (gsharov@mrc-lmb.cam.ac.uk)
# *
# * MRC Laboratory Of Molecular Biology (MRC-LMB)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
import re
from glob import glob
from pyworkflow.protocol.params import Float
import pyworkflow.protocol.params as params
import pyworkflow.utils as pwutils
from pyworkflow.protocol.constants import LEVEL_ADVANCED
from pwem.protocols import ProtClassify3D
from pwem.constants import ALIGN_PROJ
from pwem.objects import SetOfVolumes, SetOfClassesVol
from pwem.constants import ALIGN_NONE
import pwem.emlib.metadata as md
from xmipp3 import Plugin
from xmipp3.convert import (readSetOfClassesVol, getImageLocation,
writeSetOfVolumes, rowToAlignment)
# Missing data region types, numbered in the same order as the
# 'missingDataType' enum choices (wedge_y, wedge_x, pyramid, cone).
MISSING_WEDGE_Y, MISSING_WEDGE_X, MISSING_PYRAMID, MISSING_CONE = range(4)
class XmippProtMLTomo(ProtClassify3D):
    """ Align and classify 3D images with missing data regions in Fourier
    space, e.g. subtomograms or RCT reconstructions, by a 3D
    multi-reference refinement based on a maximum-likelihood (ML) target
    function.

    See http://xmipp.cnb.csic.es/twiki/bin/view/Xmipp/Ml_tomo_v31
    for further documentation
    """
    _label = 'mltomo'

    # Names of the missing data region types. The list index matches the
    # MISSING_* module constants and the 'missingDataType' enum value.
    _missingTypes = ['wedge_y', 'wedge_x', 'pyramid', 'cone']

    # wedges.xmd describes exactly one missing region, so every input
    # volume row refers to that region by this number.
    _missingRegionNr = 1

    def _initialize(self):
        """ Set up the file-name and iteration templates.

        This function is meant to be called after the working dir of the
        protocol has been set (maybe after recovery from the mapper).
        """
        self._createFilenameTemplates()
        self._createIterTemplates()

    def _createFilenameTemplates(self):
        """ Centralize how files are called for iterations and references. """
        self.extraIter = self._getExtraPath('results/mltomo_it%(iter)06d')
        myDict = {
            'data_it': self.extraIter + '_img.xmd',
            'classes_scipion': self.extraIter + '_classes_scipion.sqlite',
            'log_it': self.extraIter + '_log.xmd',
            'ref_it': self.extraIter + '_ref.xmd',
            'ref': self._getExtraPath('results/mltomo_ref.xmd'),
            'fsc_it': self.extraIter + '.fsc',
            'volume': self.extraIter + '_ref%(ref3d)06d.mrc'
        }
        self._updateFilenamesDict(myDict)

    def _createIterTemplates(self):
        """ Set up the glob pattern and the regex used to find iterations. """
        self._iterTemplate = self._getFileName('ref_it', iter=0).replace('000000', '??????')
        # Iterations are identified by _itXXXXXX_ where XXXXXX is the
        # iteration number, restricted to exactly 6 digits.
        self._iterRegex = re.compile(r'_it(\d{6,6})_')

    # --------------------------- DEFINE param functions -----------------------
    def _defineParams(self, form):
        form.addSection(label='General params')
        form.addParam('inputVols', params.PointerParam,
                      pointerClass="SetOfVolumes",
                      label='Set of volumes',
                      help="Set of input volumes")
        form.addParam('copyAlignment', params.BooleanParam, default=True,
                      label='Consider previous alignment?',
                      help='If set to Yes, then alignment information '
                           'from input volumes will be considered.')
        form.addParam('generateRefs', params.BooleanParam, default=True,
                      label='Automatically generate references',
                      help="If set to true, 3D classes will be generated "
                           "automatically. Otherwise you can provide initial "
                           "reference volumes yourself.")
        form.addParam('numberOfReferences', params.IntParam,
                      label='Number of references', default=3,
                      condition="generateRefs",
                      help="Number of references to generate automatically")
        form.addParam('inputRefVols', params.PointerParam,
                      pointerClass="SetOfVolumes, Volume",
                      condition="not generateRefs",
                      label='Input reference volume(s)',
                      help="Provide a set of initial reference volumes")
        form.addParam('numberOfIterations', params.IntParam,
                      label='Number of iterations', default=25,
                      help="Maximum number of iterations to perform")
        form.addParam('symmetry', params.StringParam, default='c1',
                      label='Symmetry group',
                      help="See http://xmipp.cnb.csic.es/twiki/bin/view/Xmipp/Symmetry "
                           "for a description of the symmetry groups format. If no "
                           "symmetry is present, give c1")
        form.addParam('missingDataType', params.EnumParam,
                      # same list (and order) used when writing wedges.xmd
                      choices=list(self._missingTypes),
                      default=0,
                      display=params.EnumParam.DISPLAY_COMBO,
                      label='Missing data regions',
                      help="Provide missing data region type:\n\n"
                           "a) wedge_y for a missing wedge where the tilt "
                           "axis is along Y\n"
                           "b) wedge_x for a missing wedge where the tilt "
                           "axis is along X\n"
                           "c) pyramid for a missing pyramid where the tilt "
                           "axes are along Y and X\n"
                           "d) cone for a missing cone (pointing along Z)")
        form.addParam('missingAng', params.StringParam, default='-60 60',
                      label='Angles of missing data',
                      help='Provide angles for missing data area in the '
                           'following format:\n\n'
                           'for wedge_y or wedge_x: -60 60\n'
                           'for pyramid: -60 60 -60 60 (for y and x, respectively)\n'
                           'for cone: 45')
        form.addParam('maxCC', params.BooleanParam, default=False,
                      label='Use CC instead of ML',
                      help='Use constrained cross-correlation and weighted '
                           'averaging instead of ML')

        form.addSection(label='Sampling')
        form.addParam('angSampling', params.FloatParam, default=10.0,
                      label='Angular sampling (deg)',
                      help="Angular sampling rate (in degrees)")
        form.addParam('angSearch', params.FloatParam, default=-1.0,
                      label='Angular search range (deg)',
                      help="Angular search range around orientations of "
                           "input particles (by default [-1.0], exhaustive "
                           "searches are performed)")
        form.addParam('globalPsi', params.BooleanParam, default=False,
                      expertLevel=LEVEL_ADVANCED,
                      label='Exhaustive psi search',
                      help="Exhaustive psi searches (only for c1 symmetry)")
        form.addParam('limitTrans', params.FloatParam, default=-1.0,
                      expertLevel=LEVEL_ADVANCED,
                      label='Max shift (px)',
                      help="Maximum allowed shifts "
                           "(negative value means no restriction)")
        line = form.addLine('Tilt angle limits (deg)',
                            expertLevel=LEVEL_ADVANCED,
                            help='Limits for tilt angle search (in degrees)')
        line.addParam('tiltMin', params.FloatParam, default=0.0, label='Min')
        line.addParam('tiltMax', params.FloatParam, default=180.0, label='Max')
        form.addParam('psiSampling', params.FloatParam, default=-1.0,
                      expertLevel=LEVEL_ADVANCED,
                      label='Psi angle sampling (deg)',
                      help="Angular sampling rate for the in-plane "
                           "rotations (in degrees)")

        form.addSection(label='Restrictions')
        form.addParam('dim', params.IntParam, default=-1,
                      label='Downscale input to (px)',
                      help="Use downscaled (in fourier space) images of this "
                           "size (in pixels)")
        form.addParam('maxRes', params.FloatParam, default=0.5,
                      label='Maximum resolution (px^-1)',
                      help="Maximum resolution (in pixel^-1) to use")
        form.addParam('doPerturb', params.BooleanParam, default=False,
                      label='Perturb',
                      help="Apply random perturbations to angular sampling "
                           "in each iteration")
        form.addParam('dontRotate', params.BooleanParam, default=False,
                      label='Do not rotate',
                      help="Keep orientations fixed, only translate and classify")
        form.addParam('dontAlign', params.BooleanParam, default=False,
                      label='Do not align',
                      help="Keep angles and shifts fixed "
                           "(otherwise start from random)")
        form.addParam('maskFile', params.PointerParam,
                      pointerClass='VolumeMask',
                      allowsNull=True,
                      expertLevel=LEVEL_ADVANCED,
                      label='Mask', condition='dontAlign',
                      help='Mask input volumes; only valid in combination '
                           'with --dont_align')
        form.addParam('onlyAvg', params.BooleanParam, default=False,
                      label='Only average',
                      help="Keep orientations and classes, "
                           "only output weighted averages")
        form.addParam('dontImpute', params.BooleanParam, default=False,
                      expertLevel=LEVEL_ADVANCED,
                      label='Do not impute',
                      help='Use weighted averaging, rather than imputation')
        form.addParam('noImpThresh', params.FloatParam, default=1.0,
                      expertLevel=LEVEL_ADVANCED,
                      condition='dontImpute',
                      label='Threshold for averaging',
                      help='Threshold to avoid division by zero '
                           'for weighted averaging')
        form.addParam('fixSigmaNoise', params.BooleanParam, default=False,
                      expertLevel=LEVEL_ADVANCED,
                      label='Fix sigma noise',
                      help='Do not re-estimate the standard deviation in '
                           'the pixel noise')
        form.addParam('fixSigmaOffset', params.BooleanParam, default=False,
                      expertLevel=LEVEL_ADVANCED,
                      label='Fix sigma offsets',
                      help='Do not re-estimate the standard deviation in the '
                           'origin offsets')
        form.addParam('fixFrac', params.BooleanParam, default=False,
                      expertLevel=LEVEL_ADVANCED,
                      label='Fix model fractions',
                      help='Do not re-estimate the model fractions. '
                           'Calculations start with even distribution.')

        form.addSection(label='Advanced')
        group = form.addGroup('Regularization parameters')
        line = group.addLine('Regularization',
                             help='Regularization parameters (in N/K^2)')
        line.addParam('regIni', params.FloatParam, default=0.0,
                      label='Initial')
        line.addParam('regFinal', params.FloatParam, default=0.0,
                      label='Final')
        group.addParam('regSteps', params.IntParam, default=5,
                       label='Steps',
                       help='Number of iterations in which the regularization '
                            'is changed from reg0 to regF')
        form.addParam('numberOfImpIterations', params.IntParam,
                      expertLevel=LEVEL_ADVANCED,
                      label='Iterations in inner imputation loop', default=1,
                      help="Number of iterations for inner imputation loop")
        # FIXME: add an 'iterStart' IntParam (xmipp_ml_tomo --istart) to
        # allow continuing from a given iteration.
        form.addParam('eps', params.FloatParam, default=5e-5,
                      expertLevel=LEVEL_ADVANCED,
                      label='Stopping criterium',
                      help="Stopping criterium")
        form.addParam('stdNoise', params.FloatParam, default=1.0,
                      expertLevel=LEVEL_ADVANCED,
                      label='Expected noise std',
                      help="Expected standard deviation for pixel noise")
        form.addParam('stdOrig', params.FloatParam, default=3.0,
                      expertLevel=LEVEL_ADVANCED,
                      label='Expected origin offset std (px)',
                      help="Expected standard deviation for origin offset "
                           "(in pixels)")

        form.addParallelSection(threads=1, mpi=3)

    # --------------------------- INSERT steps functions -----------------------
    def _insertAllSteps(self):
        self._initialize()
        self._insertFunctionStep('convertInputs')
        self._insertFunctionStep('runMLTomo')
        self._insertFunctionStep('createOutput')

    # --------------------------- STEPS functions ------------------------------
    def convertInputs(self):
        """ Write the metadata files that runMLTomo passes to xmipp_ml_tomo.

        Files written in the extra dir:
          - input_volumes.xmd: the input volumes (``-i``)
          - ref_volumes.xmd: user references (``--ref``), only when
            references are not auto-generated
          - wedges.xmd: the missing data region (``--missing``)
        """
        inputVols = self.inputVols.get()
        fnVols = self._getExtraPath('input_volumes.xmd')
        # Only export the input alignment when the user asked to keep it;
        # otherwise ml_tomo starts from random angles.
        if self.copyAlignment:
            alignType = inputVols.getAlignment()
        else:
            alignType = ALIGN_NONE

        writeSetOfVolumes(inputVols, fnVols,
                          postprocessImageRow=self._postprocessVolumeRow,
                          alignType=alignType)

        if not self.generateRefs:
            self.createRefMd(self.inputRefVols.get())

        self._writeMissingDataMd()

    def _writeMissingDataMd(self):
        """ Write wedges.xmd with the single missing region chosen in the form.

        Angles come from 'missingAng' (already checked by _validate); the
        slots that do not apply to the chosen type are left at 0.
        """
        missNum = self.missingDataType.get()
        angs = [float(a) for a in str(self.missingAng.get()).split()]
        thetaY0 = thetaYF = thetaX0 = thetaXF = 0.0

        if missNum == MISSING_WEDGE_Y:
            thetaY0, thetaYF = angs
        elif missNum == MISSING_WEDGE_X:
            thetaX0, thetaXF = angs
        elif missNum == MISSING_PYRAMID:
            # help text: "-60 60 -60 60 (for y and x, respectively)"
            thetaY0, thetaYF, thetaX0, thetaXF = angs
        else:  # MISSING_CONE
            # NOTE(review): assumes xmipp_ml_tomo reads the cone angle from
            # the ThetaY0 slot - confirm against the ml_tomo documentation.
            thetaY0 = angs[0]

        row = md.Row()
        row.setValue(md.MDL_MISSINGREGION_NR, self._missingRegionNr)
        row.setValue(md.MDL_MISSINGREGION_TYPE, self._missingTypes[missNum])
        row.setValue(md.MDL_MISSINGREGION_THY0, thetaY0)
        row.setValue(md.MDL_MISSINGREGION_THYF, thetaYF)
        row.setValue(md.MDL_MISSINGREGION_THX0, thetaX0)
        row.setValue(md.MDL_MISSINGREGION_THXF, thetaXF)

        mdFn = md.MetaData()
        row.addToMd(mdFn)
        mdFn.write(self._getExtraPath('wedges.xmd'))

    def createRefMd(self, vols):
        """ Write the user-provided reference volume(s) to ref_volumes.xmd.

        :param vols: a SetOfVolumes or a single Volume (both are accepted
            by the 'inputRefVols' pointer).
        """
        refVols = self._getExtraPath('ref_volumes.xmd')
        # Decide from the argument itself, not from the form, so the method
        # writes exactly what it was given.
        volList = vols if isinstance(vols, SetOfVolumes) else [vols]

        mdFn = md.MetaData()
        for vol in volList:
            row = md.Row()
            row.setValue(md.MDL_ITEM_ID, int(vol.getObjId()))
            row.setValue(md.MDL_IMAGE, getImageLocation(vol))
            row.setValue(md.MDL_ENABLED, 1)
            row.addToMd(mdFn)
        mdFn.write(refVols, md.MD_APPEND)

    def runMLTomo(self):
        """ Run xmipp_ml_tomo on the files written by convertInputs. """
        pwutils.makePath(self._getExtraPath("results"))
        self.runJob('xmipp_ml_tomo', self._getMLTomoArgs(),
                    env=self.getMLTomoEnviron(),
                    numberOfMpi=self.numberOfMpi.get(),
                    numberOfThreads=self.numberOfThreads.get())

    def _getMLTomoArgs(self):
        """ Build the xmipp_ml_tomo argument string from the form values. """
        fnVols = self._getExtraPath('input_volumes.xmd')
        refVols = self._getExtraPath('ref_volumes.xmd')
        missDataFn = self._getExtraPath('wedges.xmd')
        outDir = self._getExtraPath("results")

        # %g is used for values that are often small or fractional: fixed
        # %0.Nf formatting rounded e.g. the default eps=5e-5 to "0.00".
        args = [
            '-i %s' % fnVols,
            '--oroot %s' % (outDir + '/mltomo'),
            '--iter %d' % self.numberOfIterations.get(),
            '--sym %s' % self.symmetry.get(),
            '--missing %s' % missDataFn,
            '--maxres %g' % self.maxRes.get(),
            '--dim %d' % self.dim.get(),
            '--ang %0.1f' % self.angSampling.get(),
            '--ang_search %0.1f' % self.angSearch.get(),
            '--limit_trans %0.1f' % self.limitTrans.get(),
            '--tilt0 %0.1f --tiltF %0.1f' % (self.tiltMin.get(),
                                              self.tiltMax.get()),
            '--psi_sampling %0.1f' % self.psiSampling.get(),
            '--reg0 %g --regF %g --reg_steps %d' % (self.regIni.get(),
                                                    self.regFinal.get(),
                                                    self.regSteps.get()),
            '--impute_iter %d' % self.numberOfImpIterations.get(),
            '--eps %g' % self.eps.get(),
            '--pixel_size %0.2f' % self.inputVols.get().getSamplingRate(),
            '--noise %0.1f --offset %0.1f' % (self.stdNoise.get(),
                                              self.stdOrig.get()),
            '--thr %d' % self.numberOfThreads.get(),
        ]

        if self.generateRefs:
            args.append('--nref %d' % self.numberOfReferences.get())
        else:
            args.append('--ref %s' % refVols)

        # Boolean form options that map 1:1 onto ml_tomo switches.
        # --keep_angles: when considering alignment, use it,
        # otherwise ml_tomo starts from random angles.
        switches = [
            (self.copyAlignment, '--keep_angles'),
            (self.doPerturb, '--perturb'),
            (self.dontRotate, '--dont_rotate'),
            (self.dontAlign, '--dont_align'),
            (self.onlyAvg, '--only_average'),
            (self.fixSigmaNoise, '--fix_sigma_noise'),
            (self.fixSigmaOffset, '--fix_sigma_offset'),
            (self.fixFrac, '--fix_fractions'),
            (self.globalPsi, '--dont_limit_psirange'),
            (self.maxCC, '--maxCC'),
        ]
        args.extend(flag for enabled, flag in switches if enabled)

        if self.dontImpute:
            args.append('--dont_impute --noimp_threshold %0.1f'
                        % self.noImpThresh.get())

        # Check the pointed value, not the Pointer object: with no mask
        # selected .get() is None and has no getFileName().
        mask = self.maskFile.get()
        if self.dontAlign and mask is not None:
            args.append('--mask %s' % mask.getFileName())

        return ' '.join(args)

    def createOutput(self):
        """ Register the final 3D classes read from mltomo_ref.xmd. """
        # output files:
        # mltomo_ref.xmd contains all info for output 3D classes
        # mltomo_refXXXXXX.mrc output volume - 3D class
        # mltomo_img.xmd contains alignment metadata for all vols
        # mltomo.fsc
        outputGlobalMdFn = self._getFileName('ref')
        setOfClasses = self._createSetOfClassesVol()
        setOfClasses.setImages(self.inputVols.get())
        readSetOfClassesVol(setOfClasses, outputGlobalMdFn)
        self._defineOutputs(outputClasses=setOfClasses)
        self._defineSourceRelation(self.inputVols, self.outputClasses)

    # --------------------------- INFO functions -------------------------------
    def _summary(self):
        messages = []
        if not hasattr(self, 'outputClasses'):
            messages.append('Output is not ready')
        else:
            messages.append('Number of input volumes: %d' %
                            self.inputVols.get().getSize())
            if self.generateRefs:
                messages.append('References were auto-generated')
            else:
                messages.append('References were provided by user')
            messages.append('Number of output classes: %d' %
                            self.outputClasses.getSize())
        return messages

    def _validate(self):
        """ Check the missing-data angles and the alignment options. """
        errors = []
        missNum = self.missingDataType.get()
        angs = str(self.missingAng.get()).split()

        if missNum in (MISSING_WEDGE_Y, MISSING_WEDGE_X) and len(angs) != 2:
            errors.append('Wrong angles of missing data! Provide two values '
                          'for a missing wedge')
        elif missNum == MISSING_PYRAMID and len(angs) != 4:
            errors.append('Wrong angles of missing data! Provide four values '
                          'for a missing pyramid')
        elif missNum == MISSING_CONE and len(angs) != 1:
            errors.append('Wrong angles of missing data! Provide one value '
                          'for a missing cone')

        # The angles are converted to floats when writing wedges.xmd.
        for ang in angs:
            try:
                float(ang)
            except ValueError:
                errors.append('Wrong angles of missing data! '
                              '"%s" is not a number' % ang)
                break

        if not self.copyAlignment and self.angSearch.get() != -1:
            errors.append('You cannot do local searches when '
                          'ignoring input alignments.')
        return errors

    def _citations(self):
        return ['Scheres2009c']

    # --------------------------- UTILS functions ------------------------------
    def getMLTomoEnviron(self):
        """ Return the environment used to run xmipp_ml_tomo. """
        return Plugin.getEnviron()

    def isSetOfVolumes(self):
        """ Return True if the user-provided references are a SetOfVolumes. """
        return isinstance(self.inputRefVols.get(), SetOfVolumes)

    def _postprocessVolumeRow(self, img, imgRow):
        """ Complete each input volume row written by writeSetOfVolumes. """
        # Explicitly set the missing region from the protocol input (the
        # single region written to wedges.xmd) to avoid conflicts with any
        # value carried by the input metadata.
        imgRow.setValue(md.MDL_MISSINGREGION_NR, self._missingRegionNr)
        # Provide defaults only when the input metadata lacks them.
        if not imgRow.hasLabel(md.MDL_REF):
            imgRow.setValue(md.MDL_REF, 1)
        if not imgRow.hasLabel(md.MDL_LL):
            imgRow.setValue(md.MDL_LL, 1.)

    def _firstIter(self):
        # 'or 1' also maps an iteration number of 0 to 1
        return self._getIterNumber(0) or 1

    def _lastIter(self):
        return self._getIterNumber(-1)

    def _getIterNumber(self, index):
        """ Return the iteration number of the index-th iteration file.

        Files matching the iteration template are sorted by name; returns
        None when there are no such files or the name does not match the
        iteration regex.
        """
        result = None
        files = sorted(glob(self._iterTemplate))
        if files:
            f = files[index]
            s = self._iterRegex.search(f)
            if s:
                result = int(s.group(1))  # group 1 is the 6-digit iteration
        return result

    def _getIterClasses(self, it, clean=False):
        """ Return a classes .sqlite file for this iteration.

        If the file doesn't exist, it is created by converting this
        iteration's mltomo_it??????_img.xmd file. A partially written file
        is removed if the conversion fails, so it is not reused later.
        """
        data_classes = self._getFileName('classes_scipion', iter=it)

        if clean:
            pwutils.cleanPath(data_classes)

        if not pwutils.exists(data_classes):
            clsSet = SetOfClassesVol(filename=data_classes)
            try:
                clsSet.setImages(self.inputVols.get())
                self._fillClassesFromIter(clsSet, it)
                clsSet.write()
            except Exception:
                clsSet.close()
                pwutils.cleanPath(data_classes)
                raise
            clsSet.close()

        return data_classes

    def _loadClassesInfo(self, iteration):
        """ Read some information about the produced MLtomo 3D classes
        from the *ref.xmd file.
        """
        self._classesInfo = {}  # store classes info, indexed by class id

        modelFn = md.MetaData('classes@' +
                              self._getFileName('ref_it', iter=iteration))

        for classNumber, row in enumerate(md.iterRows(modelFn), start=1):
            fn = str(row.getValue(md.MDL_IMAGE))
            # Store row.clone(): the same row object is reused while
            # iterating.
            self._classesInfo[classNumber] = (fn, row.clone())

    def _fillClassesFromIter(self, clsSet, iteration):
        """ Create the SetOfClassesVol from a given iteration. """
        self._loadClassesInfo(iteration)
        dataFn = self._getFileName('data_it', iter=iteration)
        clsSet.classifyItems(updateItemCallback=self._updateParticle,
                             updateClassCallback=self._updateClass,
                             itemDataIterator=md.iterRows(dataFn,
                                                          sortByLabel=md.MDL_ITEM_ID))

    def _updateParticle(self, item, row):
        """ Copy class id, alignment and log-likelihood from an image row. """
        item.setClassId(row.getValue(md.MDL_REF))
        item.setTransform(rowToAlignment(row, ALIGN_PROJ))
        item._xmippLogLikeliContribution = Float(row.getValue(md.MDL_LL))

    def _updateClass(self, item):
        """ Set representative and class metadata from the loaded ref rows. """
        classId = item.getObjId()

        if classId in self._classesInfo:
            fn, row = self._classesInfo[classId]
            item.setAlignmentProj()
            item.getRepresentative().setLocation(1, fn)
            item._xmippWeight = Float(row.getValue(md.MDL_WEIGHT))
            item._xmippSignalChange = Float(row.getValue(md.MDL_SIGNALCHANGE))