# Source code for deepfinder.protocols.protocol_train

# -*- coding: utf-8 -*-
# **************************************************************************
# *
# * Authors: Emmanuel Moebel (emmanuel.moebel@inria.fr)
# *
# * Inria - Centre de Rennes Bretagne Atlantique, France
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'you@yourinstitution.email'
# *
# **************************************************************************
from os.path import abspath

from pwem.protocols import EMProtocol
from pyworkflow import BETA
from pyworkflow.protocol import params, PointerParam, GPU_LIST, LEVEL_ADVANCED
from pyworkflow.utils.properties import Message

from deepfinder import Plugin
import deepfinder.convert as cv
from deepfinder.objects import DeepFinderNet

from deepfinder.protocols import ProtDeepFinderBase
from tomo.protocols import ProtTomoBase
from tomo.objects import SetOfTomoMasks

# Patch sizes offered for training: [40, 44, 48, 52, 56, 60, 64].
# The 'psize' EnumParam stores an index into this list, which
# DeepFinderTrain._decodeContValue maps back to the actual patch size.
PSIZE_CHOICES = list(range(40, 65, 4))


class DeepFinderTrain(EMProtocol, ProtDeepFinderBase, ProtTomoBase):
    """ This protocol launches the training procedure """

    _label = 'train'
    _devStatus = BETA

    def __init__(self, **args):
        EMProtocol.__init__(self, **args)
        # Number of classes (including background). Computed in trainingStep and
        # read by createOutputStep.
        # NOTE(review): this is a plain Python attribute, not a persisted
        # pyworkflow Object, so it is only available when both steps run on the
        # same in-memory protocol instance. Confirm this holds when a run is resumed.
        self.nClass = None

    # -------------------------- DEFINE param functions ----------------------
    def _defineParams(self, form):
        """ Define the input parameters that will be used.

        Params:
            form: this is the form to be populated with sections and params.
        """
        # You need a params to belong to a section:
        form.addSection(label=Message.LABEL_INPUT)
        form.addParam('tomoMasksTrain', PointerParam,
                      pointerClass='SetOfTomoMasks',
                      label="Training TomoMasks",
                      important=True,
                      help='Training dataset. Please select here your TomoMasks. '
                           'The corresponding tomograms will be loaded automatically.')
        form.addParam('tomoMasksValid', PointerParam,
                      pointerClass='SetOfTomoMasks',
                      label="Validation TomoMasks",
                      important=True,
                      help='Validation dataset. Please select here your TomoMasks. '
                           'The corresponding tomograms will be loaded automatically.')
        form.addParam('coord', params.PointerParam,
                      label="Coordinates",
                      pointerClass='SetOfCoordinates3D',
                      help='Select coordinate set.')

        form.addSection(label='Training Parameters')
        form.addParam('psize', params.EnumParam,
                      default=0,  # index 0 -> 40, the first entry of PSIZE_CHOICES
                      # Reuse the module constant so the choices shown in the form
                      # always match what _decodeContValue maps the index back to.
                      choices=list(PSIZE_CHOICES),
                      label='Patch size',
                      important=True,
                      help='Size of patches loaded into memory for training.')
        form.addParam('bsize', params.IntParam,
                      default=25,
                      label='Batch size',
                      important=True,
                      help='Number of patches used to compute average loss.')
        form.addParam('epochs', params.IntParam,
                      default=100,
                      label='Number of epochs',
                      important=True,
                      help='At the end of each epoch, evaluation on validation set is performed (useful to check if '
                           'network overfits).')
        form.addParam('stepsPerE', params.IntParam,
                      default=100,
                      label='Steps per epoch',
                      important=True,
                      help='Number of batches trained on per epoch.')
        form.addParam('stepsPerV', params.IntParam,
                      default=10,
                      label='Steps per validation',
                      important=True,
                      help='Number of batches used for validation.')
        form.addParam('bootstrap', params.BooleanParam,
                      default=True,
                      label='Bootstrap',
                      important=True,
                      help='Can remain checked. Useful when in presence of unbalanced classes.')
        form.addParam('rndShift', params.IntParam,
                      default=13,
                      label='Random shift',
                      important=True,
                      help='(in voxels) Applied to positions in object list when sampling patches. Enhances network '
                           'robustness. Make sure that objects are still contained in patches when applying shift.')
        form.addHidden(GPU_LIST, params.StringParam,
                       default='0',
                       expertLevel=LEVEL_ADVANCED,
                       label="Choose GPU IDs",
                       help="GPU ID, normally it is 0.")

    # --------------------------- STEPS functions ------------------------------
    def _insertAllSteps(self):
        # Training must finish before its weights file can be registered as output.
        self._insertFunctionStep('trainingStep')
        self._insertFunctionStep('createOutputStep')

    def trainingStepOLD(self):
        """Legacy training step, kept for reference only.

        NOTE(review): this reads ``self.targets``, ``self.coordTrain`` and
        ``self.coordValid``. None of them are declared in ``_defineParams``, and
        this step is not inserted by ``_insertAllSteps``. Calling it on this
        protocol will presumably fail with AttributeError. Use ``trainingStep``.
        """
        targets = self.targets.get()
        # Get paths to tomograms and corresponding targets:
        path_tomo = []
        path_segm = []
        for segm in targets.iterItems():
            path_segm.append(segm.getFileName())
            path_tomo.append(segm.getVolName())

        # Get objl_train and objl_valid and save to extra folder:
        objl_train = self._getObjlFromInputCoordinates(targets, self.coordTrain.get())
        objl_valid = self._getObjlFromInputCoordinates(targets, self.coordValid.get())
        fname_objl_train, fname_objl_valid = self._writeTrainObjlFiles(objl_train, objl_valid)

        # Number of classes from objl (+1 for background class), used by the output step:
        self.nClass = len(cv.objl_get_labels(objl_train)) + 1

        fname_params = self._writeTrainParamsFile(path_tomo, path_segm,
                                                  fname_objl_train, fname_objl_valid)
        # Launch DeepFinder training:
        Plugin.runDeepFinder(self, 'train', '-p ' + fname_params)

    def trainingStep(self):
        """Write the DeepFinder training inputs to the extra folder and launch training.

        The objl XML files and params_train.xml are written under the extra path.
        DeepFinder is then run in 'train' mode on the GPU(s) selected in the form.
        """
        # Get tomo paths, target paths, train objl and valid objl for DeepFinder:
        path_tomos, path_targets, objl_train, objl_valid = self._getDeepFinderObjectsFromInput(
            self.tomoMasksTrain.get(), self.tomoMasksValid.get(), self.coord.get())

        fname_objl_train, fname_objl_valid = self._writeTrainObjlFiles(objl_train, objl_valid)

        # Number of classes from objl (+1 for background class), used by the output step:
        self.nClass = len(cv.objl_get_labels(objl_train)) + 1

        fname_params = self._writeTrainParamsFile(path_tomos, path_targets,
                                                  fname_objl_train, fname_objl_valid)
        # Launch DeepFinder training:
        Plugin.runDeepFinder(self, 'train', '-p ' + fname_params,
                             gpuId=getattr(self, GPU_LIST).get())

    def createOutputStep(self):
        """Register the final network weights file as the 'netWeights' output.

        The weights are expected at extra/net_weights_FINAL.h5, which is the
        DeepFinder output folder configured by trainingStep.
        """
        netWeights = DeepFinderNet()
        netWeights.setPath(abspath(self._getExtraPath('net_weights_FINAL.h5')))
        netWeights.setNbOfClasses(self.nClass)
        self._defineOutputs(netWeights=netWeights)

    # --------------------------- UTILITY functions --------------------------------
    @staticmethod
    def _decodeContValue(idx):
        """Decode the psize value and represent it as expected by DeepFinder.

        Args:
            idx (int): index selected in the 'psize' EnumParam.
        Returns:
            int: the corresponding patch size from PSIZE_CHOICES.
        """
        return PSIZE_CHOICES[idx]

    def _writeTrainObjlFiles(self, objl_train, objl_valid):
        """Save the training and validation object lists as XML in the extra folder.

        Returns:
            tuple of str: absolute paths of objl_train.xml and objl_valid.xml.
        """
        fname_objl_train = abspath(self._getExtraPath('objl_train.xml'))
        cv.objl_write(objl_train, fname_objl_train)
        fname_objl_valid = abspath(self._getExtraPath('objl_valid.xml'))
        cv.objl_write(objl_valid, fname_objl_valid)
        return fname_objl_train, fname_objl_valid

    def _writeTrainParamsFile(self, path_tomo, path_target, fname_objl_train, fname_objl_valid):
        """Fill a cv.ParamsTrain from the form values and write it to params_train.xml.

        Reads ``self.nClass``, so it must be called after the number of classes
        has been computed.

        Returns:
            str: absolute path of the written params_train.xml.
        """
        # Named trainParams so the imported pyworkflow 'params' module is not shadowed.
        trainParams = cv.ParamsTrain()
        trainParams.path_out = abspath(self._getExtraPath()) + '/'
        trainParams.path_tomo = path_tomo
        trainParams.path_target = path_target
        trainParams.path_objl_train = fname_objl_train
        trainParams.path_objl_valid = fname_objl_valid
        trainParams.Ncl = self.nClass
        trainParams.psize = self._decodeContValue(self.psize.get())
        trainParams.bsize = self.bsize.get()
        trainParams.nepochs = self.epochs.get()
        trainParams.steps_per_e = self.stepsPerE.get()
        trainParams.steps_per_v = self.stepsPerV.get()
        # in current deepfinder version only works with tomos/targets stored as h5
        trainParams.flag_direct_read = False
        trainParams.flag_bootstrap = self.bootstrap.get()
        trainParams.rnd_shift = self.rndShift.get()
        fname_params = abspath(self._getExtraPath('params_train.xml'))
        trainParams.write(fname_params)
        return fname_params

    def _getDeepFinderObjectsFromInput(self, tomoMaskSetTrain, tomoMaskSetValid, coord3DSet):
        """Build the DeepFinder path lists and train/valid object lists from the inputs.

        The validation and training masks are joined into one set, validation first.
        The coordinates are converted into a single object list, which is then split
        by tomogram index. Indices [0, Nvalid) go to validation and
        [Nvalid, Nvalid + Ntrain) go to training. This presumes that
        _getObjlFromInputCoordinatesV2 indexes tomograms by their position in the
        joined set.

        Args:
            tomoMaskSetTrain (SetOfTomoMasks)
            tomoMaskSetValid (SetOfTomoMasks)
            coord3DSet (SetOfCoordinates3D)
        Returns:
            list of strings : path_tomos[]
            list of strings : path_targets[]
            list of dict : objl_train
            list of dict : objl_valid
        """
        # Join the tomoMaskSets. 1st valid, then train. Order is important!
        tomoMaskSetAll = self._joinSetsOfTomoMasks(tomoMaskSetValid, tomoMaskSetTrain, self._getPath())

        # Get the file paths for tomos and targets (=tomoMasks)
        path_tomos, path_targets = self._getPathListsFromTomoMaskSet(tomoMaskSetAll)

        # Get deepfinder objl from coord3DSet:
        objl_all = self._getObjlFromInputCoordinatesV2(tomoMaskSetAll, coord3DSet)

        # Separate objl into objl_valid and objl_train by tomogram index:
        nValid = len(tomoMaskSetValid)
        nTrain = len(tomoMaskSetTrain)

        objl_valid = []
        for tidx in range(nValid):
            objl_valid.extend(cv.objl_get_tomo(objl_all, tidx))

        objl_train = []
        for tidx in range(nValid, nValid + nTrain):
            objl_train.extend(cv.objl_get_tomo(objl_all, tidx))

        return path_tomos, path_targets, objl_train, objl_valid

    @staticmethod
    def _joinSetsOfTomoMasks(tomoMaskSet1, tomoMaskSet2, path):
        """ Joins two tomoMaskSets into a new set stored under ``path``.

        The items of tomoMaskSet1 are appended first, then those of tomoMaskSet2.
        Set info is copied from tomoMaskSet1.

        Args:
            tomoMaskSet1 (SetOfTomoMasks): first set (its items come first).
            tomoMaskSet2 (SetOfTomoMasks): second set (its items come after).
            path (str): folder where the joined set's sqlite file is created.
        Returns:
            SetOfTomoMasks
        """
        tomoMaskSet = SetOfTomoMasks.create(path, template='setOfTomoMasks%s.sqlite')
        tomoMaskSet.copyInfo(tomoMaskSet1)
        tomoMaskSet.setName('target set')
        for tomoMask in tomoMaskSet1:
            tomoMaskSet.append(tomoMask)
        for tomoMask in tomoMaskSet2:
            tomoMaskSet.append(tomoMask)
        return tomoMaskSet

    @staticmethod
    def _getPathListsFromTomoMaskSet(tomoMaskSet):
        """ Gets the path lists needed by DeepFinder from protocol input.

        Args:
            tomoMaskSet (SetOfTomoMasks)
        Returns:
            list of strings : path_tomos[] (absolute tomogram paths, getVolName)
            list of strings : path_targets[] (absolute mask paths, getFileName)
        """
        path_tomos = []
        path_targets = []
        for tomoMask in tomoMaskSet:
            path_tomos.append(abspath(tomoMask.getVolName()))
            path_targets.append(abspath(tomoMask.getFileName()))
        return path_tomos, path_targets

    # --------------------------- INFO functions -----------------------------------
    # TODO
    def _summary(self):
        """ Summarize what the protocol has done"""
        summary = []
        if self.isFinished():
            summary.append("Training finished.")
        return summary

    def _methods(self):
        """Describe the training run using the parameters set in the form."""
        methods = []
        if self.isFinished():
            methods.append("DeepFinder network trained for %i epochs (%i steps per epoch, "
                           "batch size %i, patch size %i)."
                           % (self.epochs.get(), self.stepsPerE.get(), self.bsize.get(),
                              self._decodeContValue(self.psize.get())))
        return methods