Source code for deepfinder.protocols.protocol_train

# -*- coding: utf-8 -*-
# **************************************************************************
# *
# * Authors: Emmanuel Moebel (emmanuel.moebel@inria.fr)
# *
# * Inria - Centre de Rennes Bretagne Atlantique, France
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'you@yourinstitution.email'
# *
# **************************************************************************
from os.path import abspath

from pwem.protocols import EMProtocol
from pyworkflow.protocol import params, PointerParam
from pyworkflow.utils.properties import Message

from deepfinder import Plugin
import deepfinder.convert as cv
from deepfinder.objects import DeepFinderNet

from deepfinder.protocols import ProtDeepFinderBase
from tomo.protocols import ProtTomoBase
from tomo.objects import SetOfTomoMasks

"""
Describe your python module here:
This module will provide the traditional Hello world example
"""

# Allowed patch sizes: 40, 44, ..., 64 (step of 4).
# DeepFinderTrain._decodeContValue() indexes into this list, so it must stay
# in sync with the choices offered by the 'psize' form parameter.
PSIZE_CHOICES = list(range(40, 65, 4))

class DeepFinderTrain(EMProtocol, ProtDeepFinderBase, ProtTomoBase):
    """ This protocol launches the training procedure.

    Inputs are two sets of TomoMasks (training and validation) and a set of
    3D coordinates. The protocol writes the DeepFinder object lists and a
    parameter file into the extra folder, runs the DeepFinder 'train' program
    and registers the resulting network weights as output.
    """
    _label = 'train'

    def __init__(self, **args):
        EMProtocol.__init__(self, **args)
        # Number of classes (background included); set by the training step
        # and read by createOutputStep.
        self.nClass = None

    # -------------------------- DEFINE param functions ----------------------
    def _defineParams(self, form):
        """ Define the input parameters that will be used.

        Params:
            form: this is the form to be populated with sections and params.
        """
        # You need a params to belong to a section:
        form.addSection(label=Message.LABEL_INPUT)

        form.addParam('tomoMasksTrain', PointerParam,
                      pointerClass='SetOfTomoMasks',
                      label="Training TomoMasks",
                      important=True,
                      help='Training dataset. Please select here your TomoMasks. '
                           'The corresponding tomograms will be loaded automatically.')

        form.addParam('tomoMasksValid', PointerParam,
                      pointerClass='SetOfTomoMasks',
                      label="Validation TomoMasks",
                      important=True,
                      help='Validation dataset. Please select here your TomoMasks. '
                           'The corresponding tomograms will be loaded automatically.')

        form.addParam('coord', params.PointerParam,
                      label="Coordinates",
                      pointerClass='SetOfCoordinates3D',
                      help='Select coordinate set.')

        form.addSection(label='Training Parameters')

        form.addParam('psize', params.EnumParam,
                      default=0,  # 40: 1st element in [40, 44, 48, 52, 56, 60, 64]
                      # Same list that _decodeContValue() indexes into; a copy
                      # is passed so the module constant cannot be mutated.
                      choices=list(PSIZE_CHOICES),
                      label='Patch size',
                      important=True,
                      help='Size of patches loaded into memory for training.')

        form.addParam('bsize', params.IntParam,
                      default=25,
                      label='Batch size',
                      important=True,
                      help='Number of patches used to compute average loss.')

        form.addParam('epochs', params.IntParam,
                      default=100,
                      label='Number of epochs',
                      important=True,
                      help='At the end of each epoch, evaluation on validation set '
                           'is performed (useful to check if network overfits).')

        form.addParam('stepsPerE', params.IntParam,
                      default=100,
                      label='Steps per epoch',
                      important=True,
                      help='Number of batches trained on per epoch.')

        form.addParam('stepsPerV', params.IntParam,
                      default=10,
                      label='Steps per validation',
                      important=True,
                      help='Number of batches used for validation.')

        form.addParam('bootstrap', params.BooleanParam,
                      default=True,
                      label='Bootstrap',
                      important=True,
                      help='Can remain checked. Usefull when in presence of unbalanced classes.')

        form.addParam('rndShift', params.IntParam,
                      default=13,
                      label='Random shift',
                      important=True,
                      help='(in voxels) Applied to positions in object list when '
                           'sampling patches. Enhances network robustness. Make sure '
                           'that objects are still contained in patches when '
                           'applying shift.')

    # --------------------------- STEPS functions ------------------------------
    def _insertAllSteps(self):
        """ Register the processing steps: training, then output creation. """
        self._insertFunctionStep('trainingStep')
        self._insertFunctionStep('createOutputStep')

    def trainingStepOLD(self):
        """ Legacy training step, kept for backward compatibility only.

        It is not inserted by _insertAllSteps(). It reads the inputs
        'targets', 'coordTrain' and 'coordValid', which _defineParams() does
        not declare, so it can only work if those attributes are provided
        elsewhere. Use trainingStep() instead.
        """
        # Get paths to tomograms and corresponding targets:
        path_tomo = []
        path_segm = []
        for segm in self.targets.get().iterItems():
            path_tomo.append(segm.getVolName())
            path_segm.append(segm.getFileName())

        # Get objl_train and objl_valid from the input coordinates:
        objl_train = self._getObjlFromInputCoordinates(self.targets.get(), self.coordTrain.get())
        objl_valid = self._getObjlFromInputCoordinates(self.targets.get(), self.coordValid.get())

        self._writeParamsAndTrain(path_tomo, path_segm, objl_train, objl_valid)

    def trainingStep(self):
        """ Convert the protocol inputs to DeepFinder inputs and run the training. """
        # Get tomo paths, target paths, train objl and valid objl for DeepFinder:
        path_tomos, path_targets, objl_train, objl_valid = self._getDeepFinderObjectsFromInput(
            self.tomoMasksTrain.get(), self.tomoMasksValid.get(), self.coord.get())

        self._writeParamsAndTrain(path_tomos, path_targets, objl_train, objl_valid)

    def createOutputStep(self):
        """ Register the final network weights file as the 'netWeights' output. """
        netWeights = DeepFinderNet()
        fname = abspath(self._getExtraPath('net_weights_FINAL.h5'))
        netWeights.setPath(fname)
        # NOTE(review): self.nClass is a plain attribute set by the training
        # step; it is still None here if that step did not run on this very
        # instance. Confirm how the step executor handles this.
        netWeights.setNbOfClasses(self.nClass)

        self._defineOutputs(netWeights=netWeights)

    # --------------------------- UTILITY functions --------------------------------
    def _writeParamsAndTrain(self, path_tomos, path_targets, objl_train, objl_valid):
        """ Write the object lists and the parameter file, then launch the training.

        Also stores the number of classes in self.nClass as a side effect.

        Args:
            path_tomos (list of strings): tomogram file paths
            path_targets (list of strings): target (TomoMask) file paths
            objl_train (list of dict): object list used for training
            objl_valid (list of dict): object list used for validation
        """
        # Save objl to extra folder:
        fname_objl_train = abspath(self._getExtraPath('objl_train.xml'))
        cv.objl_write(objl_train, fname_objl_train)
        fname_objl_valid = abspath(self._getExtraPath('objl_valid.xml'))
        cv.objl_write(objl_valid, fname_objl_valid)

        # Get number of classes from objl, and store as attribute (useful for output step):
        self.nClass = len(cv.objl_get_labels(objl_train)) + 1  # (+1 for background class)

        # Save parameters to xml file (named trainParams so that the imported
        # 'params' module is not shadowed):
        trainParams = cv.ParamsTrain()
        trainParams.path_out = abspath(self._getExtraPath()) + '/'
        trainParams.path_tomo = path_tomos
        trainParams.path_target = path_targets
        trainParams.path_objl_train = fname_objl_train
        trainParams.path_objl_valid = fname_objl_valid
        trainParams.Ncl = self.nClass
        trainParams.psize = self._decodeContValue(self.psize.get())
        trainParams.bsize = self.bsize.get()
        trainParams.nepochs = self.epochs.get()
        trainParams.steps_per_e = self.stepsPerE.get()
        trainParams.steps_per_v = self.stepsPerV.get()
        # in current deepfinder version only works with tomos/targets stored as h5
        trainParams.flag_direct_read = False
        trainParams.flag_bootstrap = self.bootstrap.get()
        trainParams.rnd_shift = self.rndShift.get()

        fname_params = abspath(self._getExtraPath('params_train.xml'))
        trainParams.write(fname_params)

        # Launch DeepFinder training:
        deepfinder_args = '-p ' + fname_params
        Plugin.runDeepFinder(self, 'train', deepfinder_args)

    @staticmethod
    def _decodeContValue(idx):
        """Decode the psize value and represent it as expected by DeepFinder"""
        return PSIZE_CHOICES[idx]

    def _getDeepFinderObjectsFromInput(self, tomoMaskSetTrain, tomoMaskSetValid, coord3DSet):
        """ Build the path lists and object lists DeepFinder needs from the inputs.

        Args:
            tomoMaskSetTrain (SetOfTomoMasks)
            tomoMaskSetValid (SetOfTomoMasks)
            coord3DSet (SetOfCoordinates3D)

        Returns:
            list of strings : path_tomos[]
            list of strings : path_targets[]
            list of dict : objl_train
            list of dict : objl_valid
        """
        # Join the tomoMaskSets. 1st valid, then train. Order is important:
        # the tomo indices used below to split objl_all rely on it.
        tomoMaskSetAll = self._joinSetsOfTomoMasks(tomoMaskSetValid, tomoMaskSetTrain, self._getPath())

        # Get the file paths for tomos and targets (=tomoMasks)
        path_tomos, path_targets = self._getPathListsFromTomoMaskSet(tomoMaskSetAll)

        # Get deepfinder objl from coord3DSet:
        objl_all = self._getObjlFromInputCoordinatesV2(tomoMaskSetAll, coord3DSet)

        # Separate objl into objl_valid and objl_train. Tomo indices
        # [0, Nvalid) belong to validation, [Nvalid, Nvalid+Ntrain) to training.
        Nvalid = len(tomoMaskSetValid)
        Ntrain = len(tomoMaskSetTrain)

        objl_valid = []
        for tidx in range(Nvalid):
            objl_valid.extend(cv.objl_get_tomo(objl_all, tidx))

        objl_train = []
        for tidx in range(Nvalid, Nvalid + Ntrain):
            objl_train.extend(cv.objl_get_tomo(objl_all, tidx))

        return path_tomos, path_targets, objl_train, objl_valid

    @staticmethod
    def _joinSetsOfTomoMasks(tomoMaskSet1, tomoMaskSet2, path):
        """ Joins two tomoMaskSets: items of the 1st set come first, then the 2nd.

        Args:
            tomoMaskSet1 (SetOfTomoMasks)
            tomoMaskSet2 (SetOfTomoMasks)
            path : location passed to SetOfTomoMasks.create() for the joined set

        Returns:
            SetOfTomoMasks
        """
        tomoMaskSet = SetOfTomoMasks.create(path, template='setOfTomoMasks%s.sqlite')
        tomoMaskSet.copyInfo(tomoMaskSet1)
        tomoMaskSet.setName('target set')

        for sourceSet in (tomoMaskSet1, tomoMaskSet2):
            for tomoMask in sourceSet:
                tomoMaskSet.append(tomoMask)

        return tomoMaskSet

    @staticmethod
    def _getPathListsFromTomoMaskSet(tomoMaskSet):
        """ Gets the path lists needed by DeepFinder from protocol input.

        Args:
            tomoMaskSet (SetOfTomoMasks)

        Returns:
            list of strings : path_tomos[]
            list of strings : path_targets[]
        """
        path_tomos = []
        path_targets = []
        for tomoMask in tomoMaskSet:
            path_tomos.append(abspath(tomoMask.getVolName()))
            path_targets.append(abspath(tomoMask.getFileName()))

        return path_tomos, path_targets

    # --------------------------- INFO functions -----------------------------------
    def _summary(self):
        """ Summarize what the protocol has done. """
        summary = []

        if self.isFinished():
            summary.append("DeepFinder network trained for *%d* epochs "
                           "(%d steps per epoch, batch size %d, patch size %d)."
                           % (self.epochs.get(), self.stepsPerE.get(), self.bsize.get(),
                              self._decodeContValue(self.psize.get())))
        return summary

    def _methods(self):
        """ Describe the training settings used by this run. """
        methods = []

        if self.isFinished():
            methods.append("A DeepFinder network was trained for %d epochs with %d "
                           "steps per epoch and %d steps per validation, using "
                           "batches of %d patches of size %d and a random shift of "
                           "%d voxels."
                           % (self.epochs.get(), self.stepsPerE.get(),
                              self.stepsPerV.get(), self.bsize.get(),
                              self._decodeContValue(self.psize.get()),
                              self.rndShift.get()))
        return methods