# **************************************************************************
# *
# * Authors: COS Sorzano
# *
# * Unidad de Bioinformatica of Centro Nacional de Biotecnologia , CSIC
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from pyworkflow import VERSION_3_0
from pyworkflow.protocol.params import (PointerParam, StringParam, EnumParam, FileParam, IntParam, GPU_LIST)
from pyworkflow.protocol.constants import LEVEL_ADVANCED
from pyworkflow.utils import Message
from pwem.protocols import ProtAlign2D
from xmipp3.convert import readSetOfParticles, writeSetOfParticles
import os
import xmipp3
from pyworkflow import BETA, UPDATED, NEW, PROD
class XmippProtDeepCenterPredict(ProtAlign2D, xmipp3.XmippProtocol):
    """Center a set of particles in 2D using a neural network. The particles remain the same, but their alignment
    includes an approximate shift to place them in the center. This protocol only predicts, it does not train. """
    _lastUpdateVersion = VERSION_3_0
    _conda_env = 'xmipp_DLTK_v1.0'
    _label = 'deep center predict'
    _devStatus = BETA

    # Values of the 'modelSource' EnumParam (indices into its 'choices' list).
    PRETRAINED = 0
    PREVIOUS = 1
    LOCALFILE = 2

    # Small tolerance used to decide whether the particles must be rescaled to the model size.
    _SCALE_EPSILON = 1e-6

    def __init__(self, **args):
        ProtAlign2D.__init__(self, **args)

    # --------------------------- DEFINE param functions --------------------------------------------
    def _defineParams(self, form):
        form.addParallelSection(threads=1, mpi=16)
        form.addHidden(GPU_LIST, StringParam, default='0',
                       expertLevel=LEVEL_ADVANCED,
                       label="Choose GPU IDs",
                       help="GPU may have several cores. Set it to zero"
                            " if you do not know what we are talking about."
                            " First core index is 0, second 1 and so on.")
        form.addSection(label=Message.LABEL_INPUT)
        form.addParam('inputParticles', PointerParam, label="Input images",
                      pointerClass='SetOfParticles',
                      help='The set does not need to be centered or have alignment parameters')
        form.addParam('modelSource', EnumParam, label="Alignment source", default=self.PRETRAINED,
                      choices=["Pretrained model", "Previous protocol", "Local file"],
                      help="Source for the neural network")
        # Conditions are built from the class constants so they cannot drift from the enum values.
        form.addParam('protocolPointer', PointerParam, label="Protocol",
                      condition="modelSource==%d" % self.PREVIOUS,
                      pointerClass="XmippProtDeepCenter", allowsNull=True)
        form.addParam('modelFile', FileParam, label="Model",
                      condition="modelSource==%d" % self.LOCALFILE,
                      help="Provide a local .h5 file")
        form.addParam('modelXdim', IntParam, label="Image size of model",
                      condition="modelSource==%d" % self.LOCALFILE,
                      default=64, help="Image size on which the model was trained")

    # --------------------------- INSERT steps functions --------------------------------------------
    def _insertAllSteps(self):
        # Xmipp metadata describing the ORIGINAL particles; it also receives the final alignment.
        self.fnImgs = self._getExtraPath('imgs.xmd')
        if self.useQueueForSteps() or self.useQueue():
            # Under a queue, use the GPUs exported in CUDA_VISIBLE_DEVICES; fall back to the
            # form value instead of raising KeyError when the variable is not set.
            gpuStr = os.environ.get("CUDA_VISIBLE_DEVICES", self.gpuList.get())
        else:
            gpuStr = self.gpuList.get()
            os.environ["CUDA_VISIBLE_DEVICES"] = gpuStr
        # Prediction runs on a single GPU: the first one of the comma-separated list.
        gpuId = gpuStr.split(',')[0].strip()
        self._insertFunctionStep("convertInputStep", self.inputParticles.get())
        self._insertFunctionStep("predict", gpuId)
        self._insertFunctionStep("createOutputStep")

    # --------------------------- STEPS functions ---------------------------------------------------
    def convertInputStep(self, inputImgs):
        """Write the input particles to Xmipp metadata and, if needed, resize them to the model size.

        :param inputImgs: the SetOfParticles selected in 'inputParticles'.
        """
        writeSetOfParticles(inputImgs, self.fnImgs)
        if self._needsRescaling():
            # --keep_input_columns preserves itemId, which predict() uses to join the shifts back.
            self.runJob("xmipp_image_resize",
                        "-i %s -o %s --save_metadata_stack %s --keep_input_columns --dim %d"
                        % (self.fnImgs, self._getTmpPath('imgsResized.stk'),
                           self._getPredictionMd(), self._getModelXdim()),
                        numberOfMpi=1)

    def predict(self, gpuId):
        """Run the network on the particles and store the predicted alignment in self.fnImgs.

        :param gpuId: GPU identifier forwarded to xmipp_deep_center_predict.
        """
        fnModel = self._getModelFile()
        fnPredict = self._getPredictionMd()
        args = "-i %s --gpu %s --model %s -o %s --scale %f" % (fnPredict, gpuId, fnModel,
                                                               fnPredict, self._getScaleFactor())
        self.runJob("xmipp_deep_center_predict", args, numberOfMpi=1, env=self.getCondaEnv())
        if self._needsRescaling():
            # The prediction was written onto the resized copy: keep only its alignment columns and
            # join them (by itemId) onto the metadata that references the original particles.
            # NOTE(review): if the input set already carried shiftX/shiftY/psi, confirm how the
            # join resolves the duplicated labels.
            fnShifts = self._getTmpPath("shifts.xmd")
            self.runJob("xmipp_metadata_utilities",
                        '-i %s --operate keep_column "itemId shiftX shiftY psi" -o %s'
                        % (fnPredict, fnShifts), numberOfMpi=1)
            self.runJob("xmipp_metadata_utilities", '-i %s --set join %s itemId -o %s'
                        % (fnShifts, self.fnImgs, self.fnImgs), numberOfMpi=1)

    def createOutputStep(self):
        """Register the particles with their predicted 2D alignment as the protocol output."""
        outputSet = self._createSetOfParticles()
        readSetOfParticles(self.fnImgs, outputSet)
        outputSet.copyInfo(self.inputParticles.get())
        outputSet.setAlignment2D()
        self._defineOutputs(outputParticles=outputSet)
        self._store(outputSet)
        self._defineSourceRelation(self.inputParticles.get(), outputSet)

    # --------------------------- INFO functions ----------------------------------------------------
    def _validate(self):
        """Return a list of error messages for inconsistent input; empty when the form is valid."""
        errors = []
        source = self.modelSource.get()
        if source == self.PREVIOUS and not self.protocolPointer.hasValue():
            errors.append("Select the deep center protocol whose model should be used.")
        if source == self.LOCALFILE and not self.modelFile.get():
            errors.append("Provide the local .h5 model file.")
        xdim = self.modelXdim.get()
        if xdim is None or xdim <= 0:
            errors.append("The image size of the model must be a positive integer.")
        return errors

    # --------------------------- UTILS functions ---------------------------------------------------
    def _getModelFile(self):
        """Return the path of the .h5 network selected by 'modelSource'."""
        source = self.modelSource.get()
        if source == self.PRETRAINED:
            return self.getModel('deepCenter', 'deepCenterModel.h5')
        if source == self.PREVIOUS:
            return self.protocolPointer.get()._getExtraPath("model.h5")
        return self.modelFile.get()

    def _getModelXdim(self):
        """Return the image size (pixels) used as network input."""
        # NOTE(review): only the local-file source shows 'modelXdim' in the form; for the pretrained
        # and previous-protocol sources this is the parameter's default (64). Confirm it matches the
        # bundled deepCenterModel.h5 and the size used by XmippProtDeepCenter.
        return self.modelXdim.get()

    def _getScaleFactor(self):
        """Ratio between the input particle size and the model image size."""
        return float(self.inputParticles.get().getXDim()) / self._getModelXdim()

    def _needsRescaling(self):
        """True when the particles must be resized before being fed to the network."""
        return abs(self._getScaleFactor() - 1.0) > self._SCALE_EPSILON

    def _getPredictionMd(self):
        """Metadata the network runs on: the resized copy when rescaling, else self.fnImgs."""
        # Derived on demand (not stored between steps) so it stays valid when a run is resumed.
        if self._needsRescaling():
            return self._getTmpPath('imgsResized.xmd')
        return self.fnImgs