Source code for xmipp3.protocols.protocol_deep_hand

# **************************************************************************
# *
# * Authors:     Jorge Garcia Condado (jorgegarciacondado@gmail.com)
# *
# * Unidad de  Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************

from pwem.protocols import EMProtocol
from pwem.objects import Volume

from pyworkflow.protocol.params import PointerParam, FloatParam
from pyworkflow.object import Float
from pyworkflow.utils.path import cleanPath

from xmipp3.base import XmippProtocol
from xmipp3.convert import getImageLocation
from pyworkflow import BETA, UPDATED, NEW, PROD

[docs]class XmippProtDeepHand(EMProtocol, XmippProtocol): """Predicts the handedness of a structure using a trained deep learning model. Determining correct handedness is essential in cryo-EM to ensure the accurate interpretation of 3D reconstructions. """ _label ="deep hand" _conda_env = "xmipp_pyTorch" _devStatus = UPDATED def __init__(self, *args, **kwargs): EMProtocol.__init__(self, *args, **kwargs) XmippProtocol.__init__(self) self.vResizedVolFile = 'resizedVol.mrc' self.vMaskFile = 'mask.mrc' self.vFilteredVolFile = 'filteredVol.mrc' def _defineParams(self, form): form.addSection('Input') form.addParam('inputVolume', PointerParam, pointerClass="Volume", label='Input Volume', allowsNull=False, important=True, help="Volume to process") form.addParam('threshold', FloatParam, label='Mask Threshold', allowsNull=False, important=True, help="Threshold for mask creation") form.addParam('thresholdAlpha', FloatParam, label='Alpha Threshold', default=0.7, help="Threshold for alpha helix determination") form.addParam('thresholdHand', FloatParam, label='Hand Threshold', default=0.6, help="Hand threshold to flip volume") # --------------------------- INSERT steps functions -------------------------------------------- def _insertAllSteps(self): self._insertFunctionStep('preprocessStep') self._insertFunctionStep('predictStep') self._insertFunctionStep('flipStep') self._insertFunctionStep('createOutputStep')
[docs] def preprocessStep(self): # Get volume information volume = self.inputVolume.get() fn_vol = getImageLocation(volume) T_s = volume.getSamplingRate() # Paths to new files created self.resizedVolFile = self._getPath(self.vResizedVolFile) self.maskFile = self._getPath(self.vMaskFile) self.filteredVolFile = self._getPath(self.vFilteredVolFile) # Resize to 1A/px self.runJob("xmipp_image_resize", "-i %s -o %s --factor %f" % (fn_vol, self.resizedVolFile, T_s)) # Threshold to obtain mask self.runJob("xmipp_transform_threshold", "-i %s -o %s --select below %f --substitute binarize" % (self.resizedVolFile, self.maskFile, self.threshold.get())) # Filter to 5A self.runJob("xmipp_transform_filter", "-i %s -o %s "\ "--fourier low_pass %f --sampling 1" % (self.resizedVolFile, self.filteredVolFile, 5.0))
[docs] def predictStep(self): # Get saved models DTlK alpha_model= self.getModel('deepHand', '5A_SSE_experimental.pth') hand_model = self.getModel('deepHand', '5A_TL_hand_alpha.pth') # Predict hand args = "--alphaModel %s --handModel %s -o %s " \ "--alphaThr %f --pathVf %s --pathVmask %s" % ( alpha_model, hand_model, self._getExtraPath(), self.thresholdAlpha.get(), self._getPath(self.vFilteredVolFile), self._getPath(self.vMaskFile)) env = self.removeCudaFromEnvPath(self.getCondaEnv()) self.runJob("xmipp_deep_hand", args, env=env) # Store hand value hand_file = self._getExtraPath('hand.txt') f = open(hand_file) self.hand = Float(float(f.read())) f.close()
[docs] def removeCudaFromEnvPath(self, env): paths = env['LD_LIBRARY_PATH'].split(':') finalPath = '' for p in paths: if p.find('cuda') != -1: pass else: finalPath+= p + ':' env['LD_LIBRARY_PATH'] = finalPath return env
[docs] def flipStep(self): if self.hand.get() > self.thresholdHand.get(): volume = self.inputVolume.get() fn_vol = getImageLocation(volume) pathFlipVol = self._getPath('flipVol.mrc') self.runJob("xmipp_transform_mirror", "-i %s -o %s --flipX" \ % (fn_vol, pathFlipVol))
[docs] def createOutputStep(self): self._defineOutputs(outputHand=self.hand) if self.hand.get() > self.thresholdHand.get(): vol = Volume() volFile = self._getPath('flipVol.mrc') vol.setFileName(volFile) Ts = self.inputVolume.get().getSamplingRate() vol.setSamplingRate(Ts) self.runJob("xmipp_image_header","-i %s --sampling_rate %f"%(volFile,Ts)) self._defineOutputs(outputVol=vol) else: self._defineOutputs(outputVol=self.inputVolume.get()) self._defineSourceRelation(self.inputVolume, self.outputVol) cleanPath(self._getPath(self.vResizedVolFile)) cleanPath(self._getPath(self.vMaskFile)) cleanPath(self._getPath(self.vFilteredVolFile))
# --------------------------- INFO functions ------------------------------- def _summary(self): summary = [] if hasattr(self, 'outputHand'): summary.append('Hand value is: %f' %self.outputHand.get()) summary.append('Hand values close to 1 mean the structure is predicted to be left handed') summary.append('Hand values close to 0 mean the structure is predicted to be right handed') if self.outputHand.get() > self.thresholdHand.get(): summary.append('Volume was flipped as it was deemed to be left handed') else: summary.append('Volume was not flipped as it was deemed to be right handed') else: summary.append("Output volume and handedness not ready yet.") return summary def _methods(self): methods = [] return methods def _validate(self): errors = [] errors += self.validateDLtoolkit() if self.thresholdAlpha.get() > 1.0 or self.thresholdAlpha.get() < 0.0: errors.append("Alpha threshold must be between 0.0 and 1.0") if self.thresholdHand.get() > 1.0 or self.thresholdHand.get() < 0.0: errors.append("Hand threshold must be between 0.0 and 1.0") return errors