# -*- coding: utf-8 -*-
from tomo.constants import BOTTOM_LEFT_CORNER
from tomo.objects import SetOfTomograms
from tomo.protocols import ProtTomoBase
import deepfinder.objects
import deepfinder.convert as cv

class ProtDeepFinderBase(ProtTomoBase):
    """Base class for DeepFinder protocols.

    Provides helpers to create the protocol's output sets and to convert
    3D coordinate sets into DeepFinder object lists (objl).
    """

    def _createSetOfDeepFinderSegmentations(self, suffix=''):
        """Create an empty SetOfDeepFinderSegmentations.

        Args:
            suffix (str): optional suffix for the 'segmentations%s.sqlite'
                file name.

        Returns:
            SetOfDeepFinderSegmentations: the newly created set.
        """
        return self._createSet(deepfinder.objects.SetOfDeepFinderSegmentations,
                               'segmentations%s.sqlite', suffix)

    def _createSetOfCoordinates3DWithScore(self, volSet, suffix=''):
        """Create an empty SetOfCoordinates3DWithScore tied to volSet.

        Args:
            volSet: set of volumes registered as the precedents of the new
                coordinate set.
            suffix (str): optional suffix for the 'coordinates%s.sqlite'
                file name.

        Returns:
            SetOfCoordinates3DWithScore: the newly created set, indexed on
            '_volId', with its precedents set to volSet.
        """
        coord3DSet = self._createSet(deepfinder.objects.SetOfCoordinates3DWithScore,
                                     'coordinates%s.sqlite', suffix,
                                     indexes=['_volId'])
        coord3DSet.setPrecedents(volSet)
        return coord3DSet

    @staticmethod
    def _appendTomoCoordsToObjl(objl, coord3DSet, tomoId, tomoIdx):
        """Append every coordinate of one tomogram to a DeepFinder object list.

        Args:
            objl (list of dict): object list, extended in place.
            coord3DSet (SetOfCoordinates3D): set the coordinates are read from.
            tomoId (int): objId of the tomogram, used to select its
                coordinates in coord3DSet.
            tomoIdx (int): value stored as ``tomo_idx`` in each new entry.
        """
        for coord in coord3DSet.iterCoordinates(volume=tomoId):
            x = coord.getX(BOTTOM_LEFT_CORNER)
            y = coord.getY(BOTTOM_LEFT_CORNER)
            z = coord.getZ(BOTTOM_LEFT_CORNER)
            # NOTE(review): _dfLabel looks like a Scipion scalar attribute;
            # str() then int() extracts its integer class label — confirm.
            lbl = int(str(coord._dfLabel))
            # Coordinates are handed to DeepFinder in (z, y, x) order,
            # presumably its array convention — confirm against cv.objl_add.
            cv.objl_add(objl, label=lbl, coord=[z, y, x], tomo_idx=tomoIdx)

    @staticmethod
    def _getObjlFromInputCoordinates(coord3DSet):
        """Build a DeepFinder object list from all coordinates of a set.

        Tomograms are taken from the precedents of coord3DSet. Each entry's
        ``tomo_idx`` is the tomogram's objId (not its position in the set).

        Args:
            coord3DSet (SetOfCoordinates3D): coordinates to convert; each
                coordinate must carry a ``_dfLabel`` attribute.

        Returns:
            list of dict: deep finder object list (contains particle infos)
        """
        objl = []
        # Materialize the precedents first (cloned, since set iteration may
        # reuse the same object) before iterating coordinates per tomogram.
        tomoList = [tomo.clone() for tomo in coord3DSet.getPrecedents()]
        for tomo in tomoList:
            tomoId = tomo.getObjId()
            ProtDeepFinderBase._appendTomoCoordsToObjl(objl, coord3DSet,
                                                       tomoId, tomoId)
        return objl

    @staticmethod
    def _getObjlFromInputCoordinatesV2(tomoSet, coord3DSet):
        # emoebel: modified to suit my needs
        """Get all Coord objects related to the given Tomogram objects.

        The output is an objl as needed by DeepFinder. The tomo_idx in the
        objl respects the order in tomoSet, which is important for the Train
        protocol.

        Args:
            tomoSet (SetOfTomograms): tomograms whose coordinates are
                collected; their iteration order defines ``tomo_idx``.
            coord3DSet (SetOfCoordinates3D): coordinates to convert; each
                coordinate must carry a ``_dfLabel`` attribute.

        Returns:
            list of dict: deep finder object list (contains particle infos)
        """
        # /!\ tidx is the tomo index for the object list, tomoId is the tomo
        # index for SetOfCoordinates3D. They are not the same!
        objl = []
        for tidx, tomo in enumerate(tomoSet):
            ProtDeepFinderBase._appendTomoCoordsToObjl(objl, coord3DSet,
                                                       tomo.getObjId(), tidx)
        return objl