Source code for cryocare.protocols.protocol_predict

import json
from os.path import abspath, join

from pwem.protocols import EMProtocol
from pyworkflow import BETA
from pyworkflow.protocol import params, StringParam
from pyworkflow.utils import Message, removeBaseExt, makePath
from scipion.constants import PYTHON

from cryocare import Plugin
from tomo.objects import Tomogram
from tomo.protocols import ProtTomoBase

from cryocare.constants import PREDICT_CONFIG, CRYOCARE_MODEL
from cryocare.utils import CryocareUtils as ccutils


class ProtCryoCAREPrediction(EMProtocol, ProtTomoBase):
    """Generate the final restored tomogram by applying the cryoCARE trained
    network to both tomograms followed by per-pixel averaging.

    One config file is written and one cryoCARE prediction run is launched per
    (even, odd) tomogram pair; the denoised volumes are then gathered into a
    single output SetOfTomograms.
    """

    _label = 'CryoCARE Prediction'
    _devStatus = BETA
    # NOTE: config/output paths are intentionally NOT accumulated in class-level
    # lists. Such lists are shared by every instance of the class and only get
    # filled while earlier steps run in the same process, so a later step (e.g.
    # after continuing a protocol) could see stale or missing entries. Paths
    # are derived deterministically by _getConfigPath/_getOutputName instead.

    # -------------------------- DEFINE param functions ----------------------
    def _defineParams(self, form):
        """ Define the input parameters that will be used.
        Params:
            form: this is the form to be populated with sections and params.
        """
        # You need a params to belong to a section:
        form.addSection(label=Message.LABEL_INPUT)
        form.addParam('even', params.PointerParam,
                      pointerClass='SetOfTomograms',
                      label='Even tomograms',
                      important=True,
                      allowsNull=False,
                      help='Set of tomograms reconstructed from the even frames of the '
                           'tilt-series movies.')
        form.addParam('odd', params.PointerParam,
                      pointerClass='SetOfTomograms',
                      label='Odd tomograms',
                      important=True,
                      allowsNull=False,
                      help='Set of tomograms reconstructed from the odd frames of the '
                           'tilt-series movies.')
        form.addParam('model', params.PointerParam,
                      pointerClass='CryocareModel',
                      label="cryoCARE Model",
                      important=True,
                      allowsNull=False,
                      help='Select a trained cryoCARE model.')
        form.addParam('n_tiles', StringParam,
                      label="Number of tiles",
                      default='1 1 1',
                      important=True,
                      allowsNull=False,
                      help='Normally the GPU cannot handle the whole size of the tomograms, so they '
                           'can be split into n tiles per axis to process smaller volumes instead of '
                           'one big volume at once. Enter space-separated integers, one per axis '
                           '(e.g. 1 1 1).')
        form.addHidden(params.GPU_LIST, params.StringParam, default='0',
                       expertLevel=params.LEVEL_ADVANCED,
                       label="Choose GPU IDs",
                       help="GPU ID, normally it is 0.")

    # --------------------------- STEPS functions ------------------------------
    def _insertAllSteps(self):
        """Insert a prepare + predict step per (even, odd) pair, then the output step."""
        makePath(self._getPredictConfDir())
        # Pairs are matched by iteration order of the two sets; _validate only
        # checks that both sets have the same size.
        pairs = zip(self.even.get(), self.odd.get())
        for numTomo, (evenTomo, oddTomo) in enumerate(pairs):
            self._insertFunctionStep(self.preparePredictStep,
                                     evenTomo.getFileName(),
                                     oddTomo.getFileName(),
                                     numTomo)
            self._insertFunctionStep(self.predictStep, numTomo)
        self._insertFunctionStep(self.createOutputStep)

    def preparePredictStep(self, evenTomo, oddTomo, numTomo):
        """Write the cryoCARE prediction JSON config for one tomogram pair.

        Params:
            evenTomo: file name of the even tomogram.
            oddTomo: file name of the odd tomogram.
            numTomo: index of the pair; selects the config file name.
        """
        config = {
            'model_name': CRYOCARE_MODEL,
            'path': self.model.get().getPath(),
            'even': evenTomo,
            'odd': oddTomo,
            'output_name': self._getOutputName(evenTomo),
            'n_tiles': self._getNTiles()
        }
        with open(self._getConfigPath(numTomo), 'w') as f:
            json.dump(config, f, indent=2)

    def predictStep(self, numTomo):
        """Run cryoCARE prediction using the config written for pair numTomo."""
        # Run cryoCARE
        Plugin.runCryocare(self, PYTHON,
                           '$(which cryoCARE_predict.py) --conf %s' % self._getConfigPath(numTomo),
                           gpuId=getattr(self, params.GPU_LIST).get())

    def createOutputStep(self):
        """Register the denoised tomograms as the 'outputTomograms' set."""
        outputSetOfTomo = self._createSetOfTomograms(suffix='_denoised')
        outputSetOfTomo.copyInfo(self.even.get())
        for inTomo in self.even.get():
            tomo = Tomogram()
            # Same name preparePredictStep wrote into the config's 'output_name'.
            tomo.setLocation(self._getOutputName(inTomo.getFileName()))
            tomo.setSamplingRate(inTomo.getSamplingRate())
            outputSetOfTomo.append(tomo)

        self._defineOutputs(outputTomograms=outputSetOfTomo)

    # --------------------------- INFO functions -----------------------------------
    def _summary(self):
        """ Summarize what the protocol has done"""
        summary = []

        if self.isFinished():
            summary.append(
                "Tomogram denoising finished.")
        return summary

    def _validate(self):
        """Return a list of error messages; an empty list means the inputs are valid."""
        validateMsgs = []
        msg = ccutils.checkInputTomoSetsSize(self.even.get(), self.odd.get())
        if msg:
            validateMsgs.append(msg)

        # Catch malformed tiling before launch instead of crashing inside a step.
        try:
            nTiles = self._getNTiles()
        except ValueError:
            nTiles = None
        if not nTiles or any(n < 1 for n in nTiles):
            validateMsgs.append("Number of tiles must be space-separated positive "
                                "integers, e.g. '1 1 1'.")
        return validateMsgs

    # --------------------------- UTIL functions -----------------------------------
    def _getOutputName(self, inTomoName):
        """Return the absolute path of the denoised volume for an input tomogram.

        The '_Even'/'_Odd' tags are stripped so both halves map to the same name.
        """
        outputName = removeBaseExt(inTomoName) + '_denoised.mrc'
        return abspath(self._getExtraPath(outputName.replace('_Even', '').replace('_Odd', '')))

    def _getPredictConfDir(self):
        """Return the directory holding the per-pair prediction config files."""
        return self._getExtraPath(PREDICT_CONFIG)

    def _getConfigPath(self, numTomo):
        """Return the JSON config path for pair numTomo (e.g. '<PREDICT_CONFIG>_000.json')."""
        return join(self._getPredictConfDir(), '{}_{:03d}.json'.format(PREDICT_CONFIG, numTomo))

    def _getNTiles(self):
        """Parse the 'Number of tiles' parameter into a list of ints.

        Raises:
            ValueError: if any whitespace-separated token is not an integer.
        """
        return [int(i) for i in (self.n_tiles.get() or '').split()]