Source code for xmipp3.protocols.protocol_align_volume

# **************************************************************************
# *
# * Authors:     Javier Vargas and Adrian Quintana (jvargas@cnb.csic.es aquintana@cnb.csic.es)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

import numpy as np
import pyworkflow.protocol.params as params
from pyworkflow.protocol import STEPS_PARALLEL

from pwem.protocols import ProtAlignVolume
from pwem.objects import Volume, Transform, SetOfVolumes

from xmipp3.convert import getImageLocation
from pyworkflow import BETA, UPDATED, NEW, PROD


# Mask kinds; each value is the index of the matching entry in the
# 'maskType' enum choices (['circular', 'binary file']).
ALIGN_MASK_CIRCULAR, ALIGN_MASK_BINARY_FILE = 0, 1

# Search strategies; each value is the index of the matching entry in the
# 'alignmentAlgorithm' enum choices
# (['exhaustive', 'local', 'exhaustive + local', 'fast fourier']).
(ALIGN_ALGORITHM_EXHAUSTIVE,
 ALIGN_ALGORITHM_LOCAL,
 ALIGN_ALGORITHM_EXHAUSTIVE_LOCAL,
 ALIGN_ALGORITHM_FAST_FOURIER) = range(4)


class XmippProtAlignVolume(ProtAlignVolume):
    """
    Aligns a set of volumes using cross correlation
    or a Fast Fourier method.
    """
    _label = 'align volume'
    nVols = 0
    _devStatus = PROD

    def __init__(self, **args):
        ProtAlignVolume.__init__(self, **args)
        # One alignment step is inserted per input volume with no
        # prerequisites, so the steps are declared as parallel.
        self.stepsExecutionMode = STEPS_PARALLEL

    #--------------------------- DEFINE param functions --------------------------------------------
    def _defineParams(self, form):
        """ Declare the protocol form: input volumes, optional mask and
        the search strategy with its ranges / initial values.
        """
        form.addSection(label='Volume parameters')
        form.addParam('inputReference', params.PointerParam,
                      pointerClass='Volume',
                      label="Reference volume", important=True,
                      help='Reference volume to be used for the alignment.')
        form.addParam('inputVolumes', params.MultiPointerParam,
                      pointerClass='SetOfVolumes,Volume',
                      label="Input volume(s)", important=True,
                      help='Select one or more volumes (Volume or SetOfVolumes)\n'
                           'to be aligned against the reference volume.')

        group1 = form.addGroup('Mask')
        group1.addParam('applyMask', params.BooleanParam, default=False,
                        label='Apply mask?',
                        help='Apply a 3D Binary mask to the volumes')
        group1.addParam('maskType', params.EnumParam,
                        choices=['circular', 'binary file'],
                        default=ALIGN_MASK_CIRCULAR,
                        label='Mask type',
                        display=params.EnumParam.DISPLAY_COMBO,
                        condition='applyMask',
                        help='Select the type of mask you want to apply')
        group1.addParam('maskRadius', params.IntParam, default=-1,
                        condition='applyMask and maskType==%d' % ALIGN_MASK_CIRCULAR,
                        label='Mask radius',
                        help='Insert the radius for the mask')
        group1.addParam('maskFile', params.PointerParam,
                        condition='applyMask and maskType==%d' % ALIGN_MASK_BINARY_FILE,
                        pointerClass='VolumeMask', label='Mask file',
                        help='Select the volume mask object')

        form.addSection(label='Search strategy')
        # Adjacent literals are concatenated verbatim, so every fragment
        # but the last ends with an explicit space.
        form.addParam('alignmentAlgorithm', params.EnumParam,
                      default=ALIGN_ALGORITHM_FAST_FOURIER,
                      choices=['exhaustive',
                               'local',
                               'exhaustive + local',
                               'fast fourier'],
                      label='Alignment algorithm',
                      display=params.EnumParam.DISPLAY_COMBO,
                      help='Exhaustive searches all possible combinations within a search space. '
                           'Local searches around a given position. '
                           'Be aware that the Fast Fourier algorithm requires a special compilation '
                           'of Xmipp (--cltomo flag). It performs the same job as the '
                           'exhaustive method but much faster.')

        # Ranges are shown for every algorithm except the local search,
        # which uses the 'Initial values' group instead.
        anglesCond = 'alignmentAlgorithm!=%d' % ALIGN_ALGORITHM_LOCAL

        group = form.addGroup('Angles range', condition=anglesCond,
                              expertLevel=params.LEVEL_ADVANCED)
        line = group.addLine('Rotational angle (deg)')
        line.addParam('minRotationalAngle', params.FloatParam, default=0, label='Min')
        line.addParam('maxRotationalAngle', params.FloatParam, default=360, label='Max')
        line.addParam('stepRotationalAngle', params.FloatParam, default=5, label='Step')

        line = group.addLine('Tilt angle (deg)', expertLevel=params.LEVEL_ADVANCED)
        line.addParam('minTiltAngle', params.FloatParam, default=0, label='Min')
        line.addParam('maxTiltAngle', params.FloatParam, default=180, label='Max')
        line.addParam('stepTiltAngle', params.FloatParam, default=5, label='Step')

        line = group.addLine('Inplane angle (deg)', expertLevel=params.LEVEL_ADVANCED)
        line.addParam('minInplaneAngle', params.FloatParam, default=0, label='Min')
        line.addParam('maxInplaneAngle', params.FloatParam, default=360, label='Max')
        line.addParam('stepInplaneAngle', params.FloatParam, default=5, label='Step')

        group = form.addGroup('Shifts range', condition=anglesCond,
                              expertLevel=params.LEVEL_ADVANCED)
        line = group.addLine('Shift X (px)')
        line.addParam('minimumShiftX', params.FloatParam, default=0, label='Min')
        line.addParam('maximumShiftX', params.FloatParam, default=0, label='Max')
        line.addParam('stepShiftX', params.FloatParam, default=1, label='Step')

        line = group.addLine('Shift Y (px)', expertLevel=params.LEVEL_ADVANCED)
        line.addParam('minimumShiftY', params.FloatParam, default=0, label='Min')
        line.addParam('maximumShiftY', params.FloatParam, default=0, label='Max')
        line.addParam('stepShiftY', params.FloatParam, default=1, label='Step')

        line = group.addLine('Shift Z (px)', expertLevel=params.LEVEL_ADVANCED)
        line.addParam('minimumShiftZ', params.FloatParam, default=0, label='Min')
        line.addParam('maximumShiftZ', params.FloatParam, default=0, label='Max')
        line.addParam('stepShiftZ', params.FloatParam, default=1, label='Step')

        line = form.addLine('Scale ', expertLevel=params.LEVEL_ADVANCED,
                            condition=anglesCond)
        line.addParam('minimumScale', params.FloatParam, default=1, label='Min')
        line.addParam('maximumScale', params.FloatParam, default=1, label='Max')
        line.addParam('stepScale', params.FloatParam, default=0.005, label='Step')

        group = form.addGroup('Initial values',
                              condition='alignmentAlgorithm==%d' % ALIGN_ALGORITHM_LOCAL,
                              expertLevel=params.LEVEL_ADVANCED)
        line = group.addLine('Initial angles')
        line.addParam('initialRotAngle', params.FloatParam, default=0, label='Rot')
        line.addParam('initialTiltAngle', params.FloatParam, default=0, label='Tilt')
        line.addParam('initialInplaneAngle', params.FloatParam, default=0, label='Psi')

        line = group.addLine('Initial shifts ', expertLevel=params.LEVEL_ADVANCED)
        line.addParam('initialShiftX', params.FloatParam, default=0, label='X')
        line.addParam('initialShiftY', params.FloatParam, default=0, label='Y')
        line.addParam('initialShiftZ', params.FloatParam, default=0, label='Z')

        group.addParam('optimizeScale', params.BooleanParam, default=False,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='Optimize scale',
                       help='Choose YES if you want to optimize the scale of input volume/s based on the reference')
        group.addParam('initialScale', params.FloatParam, default=1,
                       expertLevel=params.LEVEL_ADVANCED,
                       condition='optimizeScale',
                       label='Initial scale')

        form.addParallelSection(threads=8, mpi=1)

    #--------------------------- INSERT steps functions --------------------------------------------
    def _insertAllSteps(self):
        """ Insert one independent alignment step per input volume and a
        final output step that waits for all of them.
        """
        refFn = getImageLocation(self.inputReference.get())
        maskArgs = self._getMaskArgs()
        alignArgs = self._getAlignArgs()

        alignSteps = []
        # Volumes are numbered 1..N in iteration order; the same number
        # names both the aligned volume and its transformation matrix,
        # and createOutputStep relies on the same numbering.
        for idx, vol in enumerate(self._iterInputVolumes(), start=1):
            stepId = self._insertFunctionStep('alignVolumeStep',
                                              refFn,
                                              getImageLocation(vol),
                                              self._getAlignedVolFn(idx),
                                              maskArgs, alignArgs, idx,
                                              prerequisites=[])
            alignSteps.append(stepId)

        self._insertFunctionStep('createOutputStep', prerequisites=alignSteps)

    #--------------------------- STEPS functions --------------------------------------------
    def alignVolumeStep(self, refFn, inVolFn, outVolFn, maskArgs, alignArgs, volId):
        """ Run xmipp_volume_align for one input volume.

        :param refFn: location of the reference volume (--i1).
        :param inVolFn: location of the volume to align (--i2).
        :param outVolFn: file where the aligned volume is written (--apply).
        :param maskArgs: mask arguments, as built by _getMaskArgs.
        :param alignArgs: search arguments, as built by _getAlignArgs.
        :param volId: 1-based index of the volume; it names the file
            receiving the transformation matrix (--copyGeo).
        """
        matrixFn = self._getTransformMatrixFn(volId)

        args = "--i1 %s --i2 %s --apply %s" % (refFn, inVolFn, outVolFn)
        args += maskArgs
        args += alignArgs
        args += " --copyGeo %s" % matrixFn
        self.runJob("xmipp_volume_align", args)

        if self.alignmentAlgorithm == ALIGN_ALGORITHM_EXHAUSTIVE_LOCAL:
            # Second pass: local refinement starting from the volume
            # produced by the exhaustive search above.
            # NOTE(review): --apply is given without a file name here
            # (presumably the program then updates --i2 in place), and
            # --copyGeo rewrites the same matrix file, so the stored
            # matrix is whatever this second run reports -- confirm that
            # this is the intended transformation.
            args = "--i1 %s --i2 %s --apply --local" % (refFn, outVolFn)
            args += " --copyGeo %s" % matrixFn
            self.runJob("xmipp_volume_align", args)

    def createOutputStep(self):
        """ Register the aligned volumes as protocol output.

        A single aligned volume is registered as 'outputVolume'; several
        are gathered in a set registered as 'outputVolumes'. Every output
        carries the 4x4 matrix read from its transformation-matrix file.

        :raises ValueError: if there is no input volume to register.
        """
        Ts = self.inputReference.get().getSamplingRate()

        vols = []
        # Same 1-based numbering as in _insertAllSteps.
        for idx, vol in enumerate(self._iterInputVolumes(), start=1):
            fnOutVol = self._getAlignedVolFn(idx)

            outVol = Volume()
            outVol.setLocation(fnOutVol)
            outVol.setObjComment(vol.getObjComment())
            outVol.setObjLabel(vol.getObjLabel())

            # Set transformation matrix
            transMatFromFile = np.loadtxt(self._getTransformMatrixFn(idx))
            transform = Transform()
            transform.setMatrix(np.reshape(transMatFromFile, (4, 4)))
            outVol.setTransform(transform)
            vols.append(outVol)

            # Set the sampling rate in the mrc header
            self.runJob("xmipp_image_header",
                        "-i %s --sampling_rate %f" % (fnOutVol, Ts))

        if not vols:
            # Previously this surfaced as a bare IndexError on vols[0].
            raise ValueError('There are no input volumes to register '
                             'as output.')

        if len(vols) > 1:
            output = self._createSetOfVolumes()
            output.setSamplingRate(Ts)
            for vol in vols:
                output.append(vol)
            outputName = 'outputVolumes'
        else:
            output = vols[0]
            output.setSamplingRate(Ts)
            outputName = 'outputVolume'

        self._defineOutputs(**{outputName: output})
        for pointer in self.inputVolumes:
            self._defineSourceRelation(pointer, output)

    #--------------------------- INFO functions --------------------------------------------
    def _validate(self):
        """ Return a list of error messages; an empty list means the
        parameters are acceptable.
        """
        errors = []
        for pointer in self.inputVolumes:
            if pointer.pointsNone():
                errors.append('Invalid input, pointer: %s' % pointer.getObjValue())
                errors.append('              extended: %s' % pointer.getExtended())

        # _getMaskArgs builds "--mask circular -<radius>"; with the default
        # radius (-1) that would become the malformed "--mask circular --1".
        if self.applyMask and self.maskType == ALIGN_MASK_CIRCULAR:
            radius = self.maskRadius.get()
            if radius is None or radius <= 0:
                errors.append('The radius of the circular mask must be '
                              'a positive number.')
        return errors

    def _summary(self):
        """ Return the summary lines shown for this protocol. """
        summary = []
        nVols = self._getNumberOfInputs()

        if nVols > 0:
            summary.append("Volumes to align: *%d* " % nVols)
        else:
            summary.append("No volumes selected.")
        summary.append("Alignment method: %s" % self.getEnumText('alignmentAlgorithm'))
        return summary

    def _methods(self):
        """ Return a one-element list with the methods description. """
        nVols = self._getNumberOfInputs()

        if nVols > 0:
            methods = 'We aligned %d volumes against a reference volume using ' % nVols
            #TODO: Check a more descriptive way to add the reference and
            #      all aligned volumes to the methods (such as obj.getNameId())
            #      also to show the number of volumes from each set in the input.
            #      This approach implies to consistently include also the outputs
            #      ids to be tracked in all the workflow's methods.
            if self.alignmentAlgorithm == ALIGN_ALGORITHM_FAST_FOURIER:
                methods += ' the Fast Fourier alignment described in [Chen2013].'
            elif self.alignmentAlgorithm == ALIGN_ALGORITHM_LOCAL:
                methods += ' a local search of the alignment parameters.'
            elif self.alignmentAlgorithm == ALIGN_ALGORITHM_EXHAUSTIVE:
                methods += ' an exhaustive search.'
            elif self.alignmentAlgorithm == ALIGN_ALGORITHM_EXHAUSTIVE_LOCAL:
                methods += ' an exhaustive search followed by a local search.'
        else:
            methods = 'No methods available yet.'

        return [methods]

    def _citations(self):
        """ Return the citation keys that apply to the chosen algorithm.
        Always a list (possibly empty), never None.
        """
        if self.alignmentAlgorithm == ALIGN_ALGORITHM_FAST_FOURIER:
            return ['Chen2013']
        return []

    #--------------------------- UTILS functions --------------------------------------------
    def _getAlignedVolFn(self, idx):
        """ Return the extra-path file name of the idx-th aligned volume. """
        return self._getExtraPath("vol%02d.mrc" % idx)

    def _getTransformMatrixFn(self, idx):
        """ Return the extra-path file name holding the transformation
        matrix of the idx-th volume.
        """
        return self._getExtraPath('transformation-matrix_vol%06d.txt' % idx)

    def _iterInputVolumes(self):
        """ Iterate over all the input volumes.

        Single Volume inputs are yielded as they are and SetOfVolumes
        inputs are expanded into their items. Each yielded volume gets an
        'outputName' attribute and, when its label is empty, a generated
        label.
        """
        for pointer in self.inputVolumes:
            item = pointer.get()
            if item is None:
                # Iteration stops at the first empty pointer; _validate
                # reports such pointers as errors.
                break
            itemId = item.getObjId()

            if isinstance(item, Volume):
                item.outputName = self._getExtraPath('output_vol%06d.mrc' % itemId)
                # If the volume has no label, build one.
                if not item.getObjLabel():
                    if item.getObjParentId() is None:
                        # No parent id: name it after the pointer itself.
                        item.setObjLabel("%s.%s" % (pointer.getObjValue(),
                                                    pointer.getExtended()))
                    else:
                        # Otherwise name it after the run owning the volume.
                        item.setObjLabel('%s.%s' % (self.getMapper().getParent(item).getRunName(),
                                                    item.getClassName()))
                yield item

            elif isinstance(item, SetOfVolumes):
                for vol in item:
                    vol.outputName = self._getExtraPath('output_vol%06d_%03d.mrc'
                                                        % (itemId, vol.getObjId()))
                    # If the set item has no label, build one.
                    if not vol.getObjLabel():
                        if item.getObjLabel():
                            # Prefer the label of the set when there is one.
                            vol.setObjLabel("%s - %s%s" % (item.getObjLabel(),
                                                           vol.getClassName(),
                                                           vol.getObjId()))
                        else:
                            vol.setObjLabel("%s - %s%s" % (self.getMapper().getParent(item).getRunName(),
                                                           vol.getClassName(),
                                                           vol.getObjId()))
                    yield vol

    def _getNumberOfInputs(self):
        """ Return the total number of input volumes. """
        return sum(1 for _ in self._iterInputVolumes())

    def _getMaskArgs(self):
        """ Return the mask part of the xmipp_volume_align command line
        (an empty string when no mask is applied).
        """
        maskArgs = ''
        if self.applyMask:
            if self.maskType == ALIGN_MASK_CIRCULAR:
                # The radius is passed with a leading minus sign, so it
                # must be positive here (enforced by _validate).
                maskArgs += " --mask circular -%d" % self.maskRadius
            else:
                maskArgs += " --mask binary_file %s" % self.maskFile.get().getFileName()
        return maskArgs

    def _getAlignArgs(self):
        """ Return the search part of the xmipp_volume_align command line
        for the selected alignment algorithm.
        """
        alignArgs = ' --dontWrap'

        if self.alignmentAlgorithm == ALIGN_ALGORITHM_FAST_FOURIER:
            alignArgs += " --frm"

        elif self.alignmentAlgorithm == ALIGN_ALGORITHM_LOCAL:
            # Each range collapses to its initial value (min == max).
            alignArgs += " --local --rot %f %f 1 --tilt %f %f 1 --psi %f %f 1 -x %f %f 1 -y %f %f 1 -z %f %f 1" %\
                (self.initialRotAngle, self.initialRotAngle,
                 self.initialTiltAngle, self.initialTiltAngle,
                 self.initialInplaneAngle, self.initialInplaneAngle,
                 self.initialShiftX, self.initialShiftX,
                 self.initialShiftY, self.initialShiftY,
                 self.initialShiftZ, self.initialShiftZ)
            if self.optimizeScale:
                alignArgs += " --scale %f %f 0.005" % (self.initialScale,
                                                       self.initialScale)
            else:
                alignArgs += " --dontScale"

        else:  # Exhaustive or Exhaustive+Local
            alignArgs += " --rot %f %f %f --tilt %f %f %f --psi %f %f %f -x %f %f %f -y %f %f %f -z %f %f %f --scale %f %f %f" %\
                (self.minRotationalAngle, self.maxRotationalAngle, self.stepRotationalAngle,
                 self.minTiltAngle, self.maxTiltAngle, self.stepTiltAngle,
                 self.minInplaneAngle, self.maxInplaneAngle, self.stepInplaneAngle,
                 self.minimumShiftX, self.maximumShiftX, self.stepShiftX,
                 self.minimumShiftY, self.maximumShiftY, self.stepShiftY,
                 self.minimumShiftZ, self.maximumShiftZ, self.stepShiftZ,
                 self.minimumScale, self.maximumScale, self.stepScale)

        return alignArgs
class XmippProtAlignVolumeForWeb(XmippProtAlignVolume):
    """ Aligns a set of volumes using cross correlation.
    Based on Xmipp protocol for aligning volumes, but
    the parameters are restricted for ease of use.
    """
    _label = 'align volume web'

    def _defineParams(self, form):
        """ Build the parent form, then disable the 'Mask' group and
        set fast fourier as the default alignment algorithm.
        """
        super()._defineParams(form)

        # A condition that is always false disables the whole mask group.
        form.getParam('Mask').config(condition='False')

        # Fast fourier as the default alignment method; this requires
        # Xmipp to be compiled with the corresponding flag.
        algorithmParam = form.getParam('alignmentAlgorithm')
        algorithmParam.config(default=ALIGN_ALGORITHM_FAST_FOURIER)