# **************************************************************************
# *
# * Authors: Grigory Sharov (gsharov@mrc-lmb.cam.ac.uk) [1]
# * J.M. De la Rosa Trevin (delarosatrevin@scilifelab.se) [2]
# *
# * [1] MRC Laboratory of Molecular Biology, MRC-LMB
# * [2] SciLifeLab, Stockholm University
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 3 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from emtable import Table
from pwem.constants import ALIGN_PROJ
from pwem.protocols import ProtInitialVolume
from pwem.objects import Volume
from pyworkflow.constants import PROD
from pyworkflow.protocol.params import (PointerParam, FloatParam,
LabelParam, IntParam,
EnumParam, StringParam,
BooleanParam,
LEVEL_ADVANCED)
import relion
import relion.convert as convert
from .protocol_base import ProtRelionBase
class ProtRelionInitialModel(ProtInitialVolume, ProtRelionBase):
    """ This protocol creates a 3D initial model using Relion.

    Generate a 3D initial model _de novo_ from 2D particles using
    Relion Stochastic Gradient Descent (SGD) algorithm.
    """
    _label = '3D initial model'
    _devStatus = PROD
    # Mode flags. Only IS_2D is read in this class (see _defineParams);
    # the others are presumably consumed by ProtRelionBase -- TODO confirm.
    IS_CLASSIFY = False
    IS_3D_INIT = True
    IS_2D = False
    # Relion metadata labels; not referenced in this class, presumably
    # used by the base class -- TODO confirm.
    CHANGE_LABELS = ['rlnChangesOptimalOrientations',
                     'rlnChangesOptimalOffsets']

    def __init__(self, **args):
        """ Forward all keyword arguments to ProtRelionBase. """
        ProtRelionBase.__init__(self, **args)

    def _initialize(self):
        """ This function is meant to be called after the
        working dir for the protocol has been set
        (maybe after recovery from mapper).

        It creates the file name and iteration templates, clears the
        'continueRun' pointer when not in continue mode and resets a few
        flags on the instance.
        """
        self._createFilenameTemplates()
        self._createIterTemplates()

        if not self.doContinue:
            # A stale pointer makes no sense outside of continue mode.
            self.continueRun.set(None)

        self.maskZero = False
        self.copyAlignment = False
        self.hasReferenceCTFCorrected = False

    # -------------------------- DEFINE param functions -----------------------
    def _defineParams(self, form):
        """ Declare all the parameters of the protocol on the given form.

        Sections are added in this order: Input, CTF, Optimisation, SGD,
        Compute, followed by the parallel section.
        """
        self._defineConstants()

        self.IS_3D = not self.IS_2D
        form.addSection(label='Input')
        form.addParam('doContinue', BooleanParam, default=False,
                      label='Continue from a previous run?',
                      help='If you set to *Yes*, you should select a previous '
                           'run of type *%s* class and most of the input '
                           'parameters will be taken from it.'
                           % self.getClassName())
        form.addParam('inputParticles', PointerParam,
                      pointerClass='SetOfParticles',
                      condition='not doContinue',
                      important=True,
                      label="Input particles",
                      help='Select the input images from the project.')
        form.addParam('maskDiameterA', IntParam, default=-1,
                      label='Particle mask diameter (A)',
                      help='The experimental images will be masked with a '
                           'soft circular mask with this <diameter>. '
                           'Make sure this diameter is not set too small '
                           'because that may mask away part of the signal! If '
                           'set to a value larger than the image size no '
                           'masking will be performed.\n\n'
                           'The same diameter will also be used for a '
                           'spherical mask of the reference structures if no '
                           'user-provided mask is specified.')
        form.addParam('continueRun', PointerParam,
                      pointerClass=self.getClassName(),
                      condition='doContinue', allowsNull=True,
                      label='Select previous run',
                      help='Select a previous run to continue from.')
        form.addParam('continueIter', StringParam, default='last',
                      condition='doContinue',
                      label='Continue from iteration',
                      help='Select from which iteration do you want to '
                           'continue. If you use *last*, then the last '
                           'iteration will be used. Otherwise, a valid '
                           'iteration number should be provided.')

        # No-op in this class (see addSymmetry): the symmetry parameter is
        # declared in the 'Optimisation' section below instead.
        self.addSymmetry(form)

        form.addSection(label='CTF')
        form.addParam('continueMsg', LabelParam, default=True,
                      condition='doContinue',
                      label='CTF parameters are not available in continue mode')
        form.addParam('doCTF', BooleanParam, default=True,
                      label='Do CTF-correction?', condition='not doContinue',
                      help='If set to Yes, CTFs will be corrected inside the '
                           'MAP refinement. The resulting algorithm '
                           'intrinsically implements the optimal linear, or '
                           'Wiener filter. Note that input particles should '
                           'contains CTF parameters.')
        form.addParam('haveDataBeenPhaseFlipped', LabelParam,
                      condition='not doContinue',
                      label='Have data been phase-flipped? '
                            '(Don\'t answer, see help)',
                      help='The phase-flip status is recorded and managed by '
                           'Scipion. \n In other words, when you import or '
                           'extract particles, \nScipion will record whether '
                           'or not phase flipping has been done.\n\n'
                           'Note that CTF-phase flipping is NOT a necessary '
                           'pre-processing step \nfor MAP-refinement in '
                           'RELION, as this can be done inside the internal\n'
                           'CTF-correction. However, if the phases have been '
                           'flipped, the program will handle it.')
        form.addParam('ignoreCTFUntilFirstPeak', BooleanParam, default=False,
                      label='Ignore CTFs until first peak?',
                      condition='not doContinue',
                      help='If set to Yes, then CTF-amplitude correction will '
                           'only be performed from the first peak '
                           'of each CTF onward. This can be useful if the CTF '
                           'model is inadequate at the lowest resolution. '
                           'Still, in general using higher amplitude contrast '
                           'on the CTFs (e.g. 10-20%) often yields better '
                           'results. Therefore, this option is not generally '
                           'recommended.')
        form.addParam('doCtfManualGroups', BooleanParam, default=False,
                      label='Do manual grouping ctfs?',
                      condition='not doContinue',
                      help='Set this to Yes the CTFs will grouping manually.')
        form.addParam('defocusRange', FloatParam, default=1000,
                      label='Defocus range for group creation (in Angstroms)',
                      condition='doCtfManualGroups and not doContinue',
                      help='Particles will be grouped by defocus. '
                           'This parameter is the bin for a histogram. '
                           'All particles assigned to a bin form a group')
        form.addParam('numParticles', FloatParam, default=10,
                      label='minimum size for defocus group',
                      condition='doCtfManualGroups and not doContinue',
                      help='If defocus group is smaller than this value, '
                           'it will be expanded until number of particles '
                           'per defocus group is reached')

        form.addSection('Optimisation')
        form.addParam('numberOfClasses', IntParam, default=1,
                      condition='not doContinue',
                      label='Number of classes',
                      help='The number of classes (K) for a multi-reference '
                           'ab initio SGD refinement. These classes will be '
                           'made in an unsupervised manner, starting from a '
                           'single reference in the initial iterations of '
                           'the SGD, and the references will become '
                           'increasingly dissimilar during the in between '
                           'iterations.')
        form.addParam('doFlattenSolvent', BooleanParam, default=True,
                      condition='not doContinue',
                      label='Flatten and enforce non-negative solvent?',
                      help='If set to Yes, the job will apply a spherical '
                           'mask and enforce all values in the reference '
                           'to be non-negative.')
        form.addParam('symmetryGroup', StringParam, default='c1',
                      condition='not doContinue',
                      label="Symmetry",
                      help='SGD sometimes works better in C1. If you make an '
                           'initial model in C1 but want to run Class3D/Refine3D '
                           'with a higher point group symmetry, the reference model '
                           'must be rotated to conform the symmetry convention. '
                           'You can do this by the relion_align_symmetry command.')

        group = form.addGroup('Sampling')
        group.addParam('angularSamplingDeg', EnumParam, default=1,
                       choices=relion.ANGULAR_SAMPLING_LIST,
                       label='Initial angular sampling (deg)',
                       help='There are only a few discrete angular samplings'
                            ' possible because we use the HealPix library to'
                            ' generate the sampling of the first two Euler '
                            'angles on the sphere. The samplings are '
                            'approximate numbers and vary slightly over '
                            'the sphere.')
        group.addParam('offsetSearchRangePix', FloatParam, default=6,
                       label='Offset search range (pix)',
                       help='Probabilities will be calculated only for '
                            'translations in a circle with this radius (in '
                            'pixels). The center of this circle changes at '
                            'every iteration and is placed at the optimal '
                            'translation for each image in the previous '
                            'iteration.')
        group.addParam('offsetSearchStepPix', FloatParam, default=2,
                       label='Offset search step (pix)',
                       help='Translations will be sampled with this step-size '
                            '(in pixels). Translational sampling is also done '
                            'using the adaptive approach. Therefore, if '
                            'adaptive=1, the translations will first be '
                            'evaluated on a 2x coarser grid.')

        form.addSection(label='SGD')
        self._defineSGD3(form)
        form.addParam('sgdNoiseVar', IntParam, default=-1,
                      condition='not doContinue',
                      expertLevel=LEVEL_ADVANCED,
                      label='Increased noise variance half-life',
                      help='When set to a positive value, the initial '
                           'estimates of the noise variance will internally '
                           'be multiplied by 8, and then be gradually '
                           'reduced, having 50% after this many particles '
                           'have been processed. By default, this option '
                           'is switched off by setting this value to a '
                           'negative number. In some difficult cases, '
                           'switching this option on helps. In such cases, '
                           'values around 1000 have found to be useful. '
                           'Change the factor of eight with the additional '
                           'argument *--sgd_sigma2fudge_ini*')

        form.addSection('Compute')
        self._defineComputeParams(form)

        form.addParam('extraParams', StringParam, default='',
                      label='Additional arguments',
                      help="In this box command-line arguments may be "
                           "provided that are not generated by the GUI. This "
                           "may be useful for testing developmental options "
                           "and/or expert use of the program, e.g: \n"
                           "--dont_combine_weights_via_disc\n"
                           "--verb 1\n"
                           "--pad 2\n")

        form.addParallelSection(threads=1, mpi=3)

    def _defineSGD3(self, form):
        """ Define SGD parameters for Relion version 3. """
        group = form.addGroup('Iterations')
        group.addParam('numberOfIterInitial', IntParam, default=50,
                       label='Number of initial iterations',
                       help='Number of initial SGD iterations, at which the '
                            'initial resolution cutoff and the initial subset '
                            'size will be used, and multiple references are '
                            'kept the same. 50 seems to work well in many '
                            'cases. Increase if the correct solution is not '
                            'found.')
        group.addParam('numberOfIterInBetween', IntParam, default=200,
                       label='Number of in-between iterations',
                       help='Number of SGD iterations between the initial and '
                            'final ones. During these in-between iterations, '
                            'the resolution is linearly increased, together '
                            'with the mini-batch or subset size. In case of a '
                            'multi-class refinement, the different references '
                            'are also increasingly left to become dissimilar. '
                            '200 seems to work well in many cases. Increase '
                            'if multiple references have trouble separating, '
                            'or the correct solution is not found.')
        group.addParam('numberOfIterFinal', IntParam, default=50,
                       label='Number of final iterations',
                       help='Number of final SGD iterations, at which the '
                            'final resolution cutoff and the final subset '
                            'size will be used, and multiple references are '
                            'left dissimilar. 50 seems to work well in many '
                            'cases. Perhaps increase when multiple reference '
                            'have trouble separating.')
        group.addParam('writeIter', IntParam, default=10,
                       expertLevel=LEVEL_ADVANCED,
                       label='Write-out frequency (iter)',
                       help='Every how many iterations do you want to write the '
                            'model to disk. Negative value means only write '
                            'out model after entire iteration.')

        line = form.addLine('Resolution (A)',
                            help='This is the resolution cutoff (in A) that '
                                 'will be applied during the initial and final '
                                 'SGD iterations. 35A and 15A respectively '
                                 'seems to work well in many cases.')
        line.addParam('initialRes', IntParam, default=35, label='Initial')
        line.addParam('finalRes', IntParam, default=15, label='Final')

        line = form.addLine('Mini-batch size',
                            help='The number of particles that will be processed '
                                 'during the initial and final iterations. \n\n'
                                 'For initial, 100 seems to work well in many '
                                 'cases. Lower values may result in wider '
                                 'searches of the energy landscape, but possibly '
                                 'at reduced resolutions. \n\n'
                                 'For final, 300-500 seems to work well in many '
                                 'cases. Higher values may result in increased '
                                 'resolutions, but at increased computational '
                                 'costs.')
        line.addParam('initialBatch', IntParam, default=100, label='Initial')
        line.addParam('finalBatch', IntParam, default=500, label='Final')

    def addSymmetry(self, container):
        """ Intentionally do nothing.

        The 'symmetryGroup' parameter of this protocol is declared
        directly in _defineParams ('Optimisation' section).
        """
        pass

    # -------------------------- INSERT steps functions -----------------------

    # -------------------------- STEPS functions ------------------------------
    def _getVolumes(self):
        """ Return the list of volumes generated.
        The number of volumes in the list will be equal to
        the number of classes requested by the user in the protocol. """
        # Provide 1 as default value for making it backward compatible
        k = self.getAttributeValue('numberOfClasses', 1)
        pixelSize = self._getInputParticles().getSamplingRate()
        lastIter = self._lastIter()
        volumes = []

        for i in range(1, k + 1):
            # Format the file name *before* building the path, so that a
            # literal '%' in the project directory cannot break formatting.
            volName = 'relion_it%03d_class%03d.mrc' % (lastIter, i)
            vol = Volume(self._getExtraPath(volName))
            vol.setSamplingRate(pixelSize)
            volumes.append(vol)

        return volumes

    def createOutputStep(self):
        """ Register the outputs of the protocol.

        Defines 'outputVolumes' (a set) when more than one volume was
        produced, otherwise 'outputVolume' (a single volume), plus
        'outputParticles' filled from the last iteration. Both outputs
        are related to the input particles.
        """
        imgSet = self._getInputParticles()
        volumes = self._getVolumes()

        outImgSet = self._createSetOfParticles()
        outImgSet.copyInfo(imgSet)
        self._fillDataFromIter(outImgSet, self._lastIter())

        if len(volumes) > 1:
            output = self._createSetOfVolumes()
            output.setSamplingRate(imgSet.getSamplingRate())
            for vol in volumes:
                output.append(vol)
            self._defineOutputs(outputVolumes=output)
        else:
            output = volumes[0]
            self._defineOutputs(outputVolume=output)

        self._defineSourceRelation(self.inputParticles, output)
        self._defineOutputs(outputParticles=outImgSet)
        self._defineTransformRelation(self.inputParticles, outImgSet)

    # -------------------------- INFO functions -------------------------------
    def _validateNormal(self):
        """ Return the validation errors for a non-continue run (none). """
        errors = []
        return errors

    def _validateContinue(self):
        """ Return the list of validation errors for continue mode.

        Every problem is reported as a message in the returned list
        instead of raising: a missing previous run, a previous run for
        which no last iteration is available, an iteration value that is
        neither 'last' nor an integer, or an iteration beyond the last
        one of the previous run.
        """
        errors = []
        continueRun = self.continueRun.get()

        # The pointer is declared with allowsNull=True, so it may be empty.
        if continueRun is None:
            errors.append("Please select a previous run to continue from.")
            return errors

        continueRun._initialize()
        lastIter = continueRun._lastIter()

        # _summaryNormal guards against a falsy _lastIter(); do the same
        # here so the comparison below cannot fail on None.
        if lastIter is None:
            errors.append("The selected run has no iterations to "
                          "continue from.")
            return errors

        iterValue = self.continueIter.get()
        if iterValue == 'last':
            continueIter = lastIter
        else:
            try:
                continueIter = int(iterValue)
            except (TypeError, ValueError):
                errors.append("The iteration to continue from must be "
                              "*last* or a valid iteration number.")
                return errors

        if continueIter > lastIter:
            errors += ["The iteration from you want to continue must be "
                       "%01d or less" % lastIter]

        return errors

    def _summaryNormal(self):
        """ Return summary lines with the resolution of the last iteration.

        The list is empty when no iteration (>= 1) is available yet.
        """
        summary = []
        it = self._lastIter() or -1

        if it >= 1:
            table = Table(fileName=self._getFileName('model', iter=it),
                          tableName='model_general')
            row = table[0]
            resol = float(row.rlnCurrentResolution)
            summary.append("Current resolution: *%0.2f*" % resol)

        return summary

    def _summaryContinue(self):
        """ Return a one-line summary with the iteration continued from. """
        summary = ["Continue from iteration %01d" % self._getContinueIter()]
        return summary

    # -------------------------- UTILS functions ------------------------------
    def _setBasicArgs(self, args):
        """ Fill the 'args' dictionary in place with the basic arguments.

        Nothing is returned; SGD and sampling arguments are added through
        _setSGDArgs and _setSamplingArgs.
        """
        args.update({'--o': self._getExtraPath('relion'),
                     '--oversampling': '1'
                     })

        if self.doFlattenSolvent:
            args['--flatten_solvent'] = ''
        if not self.doContinue:
            args.update({'--sym': self.symmetryGroup.get()})

        # skipPadding / skipGridding are not declared in this class;
        # presumably added by _defineComputeParams -- TODO confirm.
        args['--pad'] = 1 if self.skipPadding else 2
        if self.skipGridding:
            args['--skip_gridding'] = ''

        self._setSGDArgs(args)
        self._setSamplingArgs(args)

    def _setSGDArgs(self, args):
        """ Add the SGD related arguments to the 'args' dictionary. """
        args['--sgd'] = ''
        args['--sgd_ini_iter'] = self.numberOfIterInitial.get()
        args['--sgd_inbetween_iter'] = self.numberOfIterInBetween.get()
        args['--sgd_fin_iter'] = self.numberOfIterFinal.get()
        args['--sgd_write_iter'] = self.writeIter.get()
        args['--sgd_ini_resol'] = self.initialRes.get()
        args['--sgd_fin_resol'] = self.finalRes.get()
        args['--sgd_ini_subset'] = self.initialBatch.get()
        args['--sgd_fin_subset'] = self.finalBatch.get()
        args['--K'] = self.numberOfClasses.get()

        # NOTE(review): both arguments below are taken to be specific to a
        # fresh run ('sgdNoiseVar' is only shown when not continuing).
        if not self.doContinue:
            args['--denovo_3dref'] = ''
            args['--sgd_sigma2fudge_halflife'] = self.sgdNoiseVar.get()

    def _setSamplingArgs(self, args):
        """ Set sampling related params"""
        if not self.doContinue:
            args['--healpix_order'] = self.angularSamplingDeg.get()
            args['--offset_range'] = self.offsetSearchRangePix.get()
            # NOTE(review): the step is doubled, presumably to account for
            # '--oversampling 1' set in _setBasicArgs -- confirm.
            args['--offset_step'] = self.offsetSearchStepPix.get() * 2

    def _fillDataFromIter(self, imgSet, iteration):
        """ Populate 'imgSet' with the input particles, updating each item
        from the rows of the 'data' star file of the given iteration.

        A reader is stored in self.reader for use by _createItemMatrix.
        """
        outImgsFn = self._getFileName('data', iter=iteration)
        imgSet.setAlignmentProj()
        px = imgSet.getSamplingRate()
        self.reader = convert.createReader(alignType=ALIGN_PROJ,
                                           pixelSize=px)

        mdIter = convert.Table.iterRows('particles@' + outImgsFn,
                                        key='rlnImageId')
        imgSet.copyItems(self._getInputParticles(), doClone=False,
                         updateItemCallback=self._createItemMatrix,
                         itemDataIterator=mdIter)

    def _createItemMatrix(self, item, row):
        """ Callback for copyItems: set the transform of 'item' from 'row'
        using the reader created in _fillDataFromIter. """
        self.reader.setParticleTransform(item, row)