Source code for deepfinder.protocols.protocol_segment

from os.path import abspath

from pyworkflow import BETA
from pyworkflow.protocol import params, PointerParam, GPU_LIST, LEVEL_ADVANCED
from pyworkflow.utils import removeBaseExt
from pyworkflow.utils.properties import Message
from tomo.objects import Tomogram, TomoMask, SetOfTomoMasks
from tomo.protocols import ProtTomoPicking

from deepfinder import Plugin
from deepfinder.protocols import ProtDeepFinderBase

class DeepFinderSegment(ProtTomoPicking, ProtDeepFinderBase):
    """This protocol segments tomograms, using a trained neural network."""

    _label = 'segment'
    _devStatus = BETA
    # NOTE: these lists live on the class, so they are shared by every
    # instance in the same process. They are kept for backward compatibility
    # only; createOutputStep no longer relies on them and instead derives
    # each output file name from its tomogram.
    _outputFiles = []
    _outputFilesBinned = []

    # --------------------------- DEFINE param functions ----------------------
    def _defineParams(self, form):
        """Add the network model, patch size and GPU parameters to the form.

        The parameters declared by ProtTomoPicking are added first.
        """
        ProtTomoPicking._defineParams(self, form)

        form.addParam('weights', PointerParam,
                      pointerClass='DeepFinderNet',
                      label="Neural network model",
                      important=True,
                      help='Select a trained DeepFinder neural network.')

        form.addParam('psize', params.IntParam,
                      default=100,
                      label='Patch size',
                      important=True,
                      help='It must be a multiple of 4, due to the network architecture.')

        form.addHidden(GPU_LIST, params.StringParam,
                       default='0',
                       expertLevel=LEVEL_ADVANCED,
                       label="Choose GPU IDs",
                       help="GPU ID, normally it is 0.")

    # --------------------------- INSERT steps functions ----------------------
    def _insertAllSteps(self):
        """Register the segmentation step followed by the output step."""
        self._insertFunctionStep('launchSegmentationStep')
        self._insertFunctionStep('createOutputStep')

    # --------------------------- STEPS functions -----------------------------
    def launchSegmentationStep(self):
        """Run the DeepFinder 'segment' program once per input tomogram.

        Each result is written to the protocol extra folder under the name
        returned by _genOutputFileName.
        """
        # These values do not change between tomograms, so resolve the
        # pointers once instead of on every iteration.
        network = self.weights.get()
        weightsPath = network.getPath()
        nbOfClasses = network.getNbOfClasses()
        patchSize = self.psize.get()
        gpuId = getattr(self, GPU_LIST).get()

        for tomo in self.inputTomograms.get().iterItems():
            outputFileName = self._genOutputFileName(tomo, binned=False)
            # The list is class-level (shared); avoid piling up duplicates
            # when the step is executed more than once in the same process.
            if outputFileName not in self._outputFiles:
                self._outputFiles.append(outputFileName)

            # Command-line arguments of the segmentation program.
            deepfinder_args = '-t %s -w %s -c %s -p %s -o %s' % (
                tomo.getFileName(),
                weightsPath,
                nbOfClasses,
                patchSize,
                abspath(self._getExtraPath(outputFileName)))

            Plugin.runDeepFinder(self, 'segment', deepfinder_args, gpuId=gpuId)

    def createOutputStep(self):
        """Register one TomoMask per input tomogram as 'outputTargetSet'."""
        inputTomos = self.inputTomograms.get()

        tomoMaskSet = SetOfTomoMasks.create(self._getPath(),
                                            template='setOfTomoMasks%s.sqlite')
        tomoMaskSet.copyInfo(inputTomos)
        tomoMaskSet.setDim(inputTomos.getDimensions())
        tomoMaskSet.setName('segmented tomogram set')

        for tomo in inputTomos:
            # Recompute the file name from the tomogram itself rather than
            # reading the shared _outputFiles list: that list may hold names
            # from another instance, or be empty if this step runs in a
            # different process than launchSegmentationStep.
            tomoMaskName = self._genOutputFileName(tomo, binned=False)

            # Import generated target from extra folder and store into
            # TomoMask object:
            tomoMask = TomoMask()
            tomoMask.cleanObjId()
            tomoMask.copyInfo(tomo)
            tomoMask.setFileName(self._getExtraPath(tomoMaskName))

            # Link to origin tomogram:
            tomoMask.setVolName(tomo.getFileName())

            tomoMaskSet.append(tomoMask)

        # Link to output:
        self._defineOutputs(outputTargetSet=tomoMaskSet)

    # --------------------------- INFO functions ----------------------
    def _summary(self):
        """ Summarize what the protocol has done"""
        summary = []
        if self.isFinished():
            summary.append("Segmentation finished.")
        return summary

    def getMethods(self, output):
        """Return a one-line description of the protocol result.

        `output` is accepted for interface compatibility; the size reported
        is always the one of `outputTargetSet`.
        """
        # This protocol produces segmentation masks, not picked particles.
        msg = 'Segmented %d tomograms ' % self.outputTargetSet.getSize()
        return msg

    def _methods(self):
        """Return the list of 'methods' messages describing input and output."""
        # Test the pointed object, not the pointer attribute itself.
        inputPointer = getattr(self, 'inputTomograms', None)
        inputTomos = inputPointer.get() if inputPointer is not None else None
        if inputTomos is None:
            return ['Input tomogram not available yet.']

        methodsMsgs = ["Input tomograms imported of dims %s." % (
            str(inputTomos.getDim()))]

        # The output attribute only exists once createOutputStep has run.
        outputSet = getattr(self, 'outputTargetSet', None)
        if outputSet is not None and outputSet.getSize() >= 1:
            for key, output in self.iterOutputAttributes():
                msg = self.getMethods(output)
                methodsMsgs.append("%s: %s" % (self.getObjectTag(output), msg))
        else:
            methodsMsgs.append(Message.TEXT_NO_OUTPUT_CO)

        return methodsMsgs

    # --------------------------- UTILS functions ----------------------
    def _genOutputData(self, fileList, suffix):
        """Build a set of tomograms from files located in the extra folder.

        :param fileList: file names (relative to the extra folder), indexed
            in the same order as the input tomograms.
        :param suffix: suffix passed to _createSetOfTomograms.
        :return: the new set, carrying the info of the input set.
        """
        outputSetOfTomo = self._createSetOfTomograms(suffix=suffix)
        outputSetOfTomo.copyInfo(self.inputTomograms.get())

        for i, inTomo in enumerate(self.inputTomograms.get()):
            tomo = Tomogram()
            tomo.setLocation(self._getExtraPath(fileList[i]))
            tomo.setSamplingRate(inTomo.getSamplingRate())
            outputSetOfTomo.append(tomo)

        return outputSetOfTomo

    @staticmethod
    def _genOutputFileName(tomo, binned=False):
        """Return 'segmentation_<tomogram base name>[_binned].mrc'."""
        binStr = '_binned' if binned else ''
        return 'segmentation_' + removeBaseExt(tomo.getFileName()) + '%s.mrc' % binStr