Source code for continuousflex.viewers.viewer_nma_dimred

# **************************************************************************
# * Authors:     J.M. De la Rosa Trevin (jmdelarosa@cnb.csic.es)
# *              Slavica Jonic  (slavica.jonic@upmc.fr)
# *              Mohamad Harastani (mohamad.harastani@upmc.fr)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307  USA
# *
# *  All comments concerning this program package may be sent to the
# *  e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
from continuousflex.protocols.data import PathData

"""
This module implements the viewer for the results of the NMA
dimensionality-reduction protocol (FlexProtDimredNMA).
"""

from os.path import basename, join, exists
import numpy as np

from pwem.convert.atom_struct import cifToPdb
from pyworkflow.utils import replaceBaseExt

from pyworkflow.utils.path import cleanPath, makePath, cleanPattern
from pyworkflow.viewer import (ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO)
from pyworkflow.protocol.params import StringParam, LabelParam
from pwem.objects import SetOfParticles
from pwem.viewers import VmdView
from pyworkflow.gui.browser import FileBrowserWindow

from continuousflex.protocols.protocol_nma_dimred import FlexProtDimredNMA

from continuousflex.protocols.data import Point, Data

from continuousflex.viewers.nma_plotter import FlexNmaPlotter

from continuousflex.viewers.nma_gui import ClusteringWindow, TrajectoriesWindow
from pwem.utils import runProgram
from pyworkflow.protocol import params

# Choice indices of the 'limits_modes' enum parameter (color-bar "Error limits"):
# 0 = automatic limits, 1 = the user-supplied 'LimitLow'/'LimitHigh' fields are shown.
FIGURE_LIMIT_NONE = 0
FIGURE_LIMITS = 1

# Choice indices of the 'xlimits_mode', 'ylimits_mode' and 'zlimits_mode' enum
# parameters: 0 = automatic axis limits, 1 = the matching '*lim_low'/'*lim_high'
# fields are shown in the form.
X_LIMITS_NONE = 0
X_LIMITS = 1
Y_LIMITS_NONE = 0
Y_LIMITS = 1
Z_LIMITS_NONE = 0
Z_LIMITS = 1

class FlexDimredNMAViewer(ProtocolViewer):
    """ Visualization of results from the NMA dimensionality-reduction
    protocol.

    The viewer offers three actions: plotting the low-dimensional
    normal-mode amplitudes (1D histogram, 2D or 3D scatter), opening the
    clustering tool and opening the trajectories (animation) tool.
    """
    _label = 'viewer nma dimred'
    _targets = [FlexProtDimredNMA]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    def __init__(self, **kwargs):
        ProtocolViewer.__init__(self, **kwargs)
        # Lazily-built Data object, see getData().
        self._data = None

    def getData(self):
        """ Return the Data object of the protocol, loading it on first use. """
        if self._data is None:
            self._data = self.loadData()
        return self._data

    def _defineParams(self, form):
        """ Declare the parameters shown in the viewer form. """
        form.addSection(label='Visualization')
        form.addParam('displayRawDeformation', StringParam, default='1 2',
                      label='Display normal-mode amplitudes in the low-dimensional space',
                      help='Type 1 to see the histogram of normal-mode amplitudes in the low-dimensional space, '
                           'using axis 1; \n '
                           'Type 2 to see the histogram of normal-mode amplitudes in the low-dimensional space, '
                           'using axis 2; etc. \n '
                           'Type 1 2 to see normal-mode amplitudes in the low-dimensional space, using axes 1 and 2; \n'
                           'Type 1 2 3 to see normal-mode amplitudes in the low-dimensional space, using axes 1, 2, '
                           'and 3; etc. '
                      )
        form.addParam('displayClustering', LabelParam,
                      label='Open clustering tool?',
                      help='Open a GUI to visualize the images as points '
                           'and select some of them to create clusters, and compute the 3D reconstructions from the '
                           'clusters.')
        form.addParam('displayTrajectories', LabelParam,
                      label='Open trajectories tool?',
                      help='Open a GUI to visualize the images as points, '
                           'draw and adjust trajectories, and animate them.')

        # Color-bar ("Error") limits.
        form.addParam('limits_modes', params.EnumParam,
                      choices=['Automatic (Recommended)', 'Set manually Use upper and lower values'],
                      default=FIGURE_LIMIT_NONE,
                      label='Error limits', display=params.EnumParam.DISPLAY_COMBO,
                      help='If you want to use a range of Error in the color bar choose to set it manually.')
        form.addParam('LimitLow', params.FloatParam, default=None,
                      condition='limits_modes==%d' % FIGURE_LIMITS,
                      label='Lower Error value',
                      help='The lower Error used in the graph')
        form.addParam('LimitHigh', params.FloatParam, default=None,
                      condition='limits_modes==%d' % FIGURE_LIMITS,
                      label='Upper Error value',
                      help='The upper Error used in the graph')

        # Per-axis plot limits.
        form.addParam('xlimits_mode', params.EnumParam,
                      choices=['Automatic (Recommended)', 'Set manually x-axis limits'],
                      default=X_LIMITS_NONE,
                      label='x-axis limits', display=params.EnumParam.DISPLAY_COMBO,
                      help='This allows you to use a specific range of x-axis limits')
        form.addParam('xlim_low', params.FloatParam, default=None,
                      condition='xlimits_mode==%d' % X_LIMITS,
                      label='Lower x-axis limit')
        form.addParam('xlim_high', params.FloatParam, default=None,
                      condition='xlimits_mode==%d' % X_LIMITS,
                      label='Upper x-axis limit')
        form.addParam('ylimits_mode', params.EnumParam,
                      choices=['Automatic (Recommended)', 'Set manually y-axis limits'],
                      default=Y_LIMITS_NONE,
                      label='y-axis limits', display=params.EnumParam.DISPLAY_COMBO,
                      help='This allows you to use a specific range of y-axis limits')
        form.addParam('ylim_low', params.FloatParam, default=None,
                      condition='ylimits_mode==%d' % Y_LIMITS,
                      label='Lower y-axis limit')
        form.addParam('ylim_high', params.FloatParam, default=None,
                      condition='ylimits_mode==%d' % Y_LIMITS,
                      label='Upper y-axis limit')
        form.addParam('zlimits_mode', params.EnumParam,
                      choices=['Automatic (Recommended)', 'Set manually z-axis limits'],
                      default=Z_LIMITS_NONE,
                      label='z-axis limits', display=params.EnumParam.DISPLAY_COMBO,
                      help='This allows you to use a specific range of z-axis limits')
        form.addParam('zlim_low', params.FloatParam, default=None,
                      condition='zlimits_mode==%d' % Z_LIMITS,
                      label='Lower z-axis limit')
        form.addParam('zlim_high', params.FloatParam, default=None,
                      condition='zlimits_mode==%d' % Z_LIMITS,
                      label='Upper z-axis limit')

    def _getVisualizeDict(self):
        """ Map each form parameter name to the method that displays it. """
        return {'displayRawDeformation': self._viewRawDeformation,
                'displayClustering': self._displayClustering,
                'displayTrajectories': self._displayTrajectories,
                }

    def _axisLimitArgs(self):
        """ Return the x/y/z axis-limit form parameters as keyword arguments
        shared by the plotter and the clustering/trajectories windows. """
        return dict(xlim_low=self.xlim_low, xlim_high=self.xlim_high,
                    ylim_low=self.ylim_low, ylim_high=self.ylim_high,
                    zlim_low=self.zlim_low, zlim_high=self.zlim_high)

    def _viewRawDeformation(self, paramName):
        """ Plot the axes typed in the 'displayRawDeformation' field. """
        components = self.displayRawDeformation.get()
        # An empty field is treated like an empty axis list (no view).
        return self._doViewRawDeformation(components or '')

    def _doViewRawDeformation(self, components):
        """ Build the plot for the requested axes.

        :param components: string with whitespace-separated, 1-based axis
            numbers (e.g. ``'1 2'``).
        :return: list of views; empty when no axis is given, a single error
            message view when some axis is not an integer between 1 and the
            reduced dimension of the protocol, otherwise a single plotter.
        """
        maxAxis = self.protocol.reducedDim.get()
        axes = []
        missingList = []
        for token in components.split():
            try:
                axis = int(token)
            except ValueError:
                missingList.append(token)
                continue
            # Axis 0 or negative values would silently wrap around to the
            # last columns once converted to 0-based indexes.
            if 1 <= axis <= maxAxis:
                axes.append(axis)
            else:
                missingList.append(token)

        if missingList:
            return [self.errorMessage("Invalid mode(s) *%s*\n."
                                      % (', '.join(missingList)),
                                      title="Invalid input")]

        dim = len(axes)
        views = []
        if dim > 0:
            modeList = [m - 1 for m in axes]  # 0-based column indexes
            baseList = ['Axis %d' % m for m in axes]

            # Actually plot
            plotArgs = self._axisLimitArgs()
            automaticLimits = self.limits_modes == FIGURE_LIMIT_NONE
            if not automaticLimits:
                plotArgs.update(LimitL=self.LimitLow, LimitH=self.LimitHigh)
            plotter = FlexNmaPlotter(data=self.getData(), **plotArgs)

            self.getData().XIND = modeList[0]
            if dim == 1:
                plotter.plotArray1D("Histogram of normal-mode amplitudes in low-dimensional space: %s" % baseList[0],
                                    "Amplitude", "Number of images")
            else:
                self.getData().YIND = modeList[1]
                if dim == 2:
                    plotter.plotArray2D("Normal-mode amplitudes in low-dimensional space: %s vs %s" % tuple(baseList),
                                        *baseList)
                elif dim == 3:
                    self.getData().ZIND = modeList[2]
                    plotter.plotArray3D("Normal-mode amplitudes in low-dimensional space: %s %s %s" % tuple(baseList),
                                        *baseList)
                # With more than three axes no plot call is issued; the
                # plotter is still returned as a view.
            views.append(plotter)

        return views

    def _displayClustering(self, paramName):
        """ Open the clustering tool window and return it as the only view. """
        self.clusterWindow = self.tkWindow(ClusteringWindow,
                                           title='Clustering Tool',
                                           dim=self.protocol.reducedDim.get(),
                                           data=self.getData(),
                                           callback=self._createCluster,
                                           limits_mode=self.limits_modes,
                                           LimitL=self.LimitLow,
                                           LimitH=self.LimitHigh,
                                           **self._axisLimitArgs())
        return [self.clusterWindow]

    def _displayTrajectories(self, paramName):
        """ Open the trajectories tool window and return it as the only view. """
        self.trajectoriesWindow = self.tkWindow(TrajectoriesWindow,
                                                title='Trajectories Tool',
                                                dim=self.protocol.reducedDim.get(),
                                                data=self.getData(),
                                                callback=self._generateAnimation,
                                                loadCallback=self._loadAnimation,
                                                numberOfPoints=10,
                                                limits_mode=self.limits_modes,
                                                LimitL=self.LimitLow,
                                                LimitH=self.LimitHigh,
                                                **self._axisLimitArgs())
        return [self.trajectoriesWindow]

    def _createCluster(self):
        """ Create the cluster with the selected particles from the cluster.
        This method will be called when the button 'Create Cluster' is pressed.
        """
        # Write the particles
        prot = self.protocol
        project = prot.getProject()
        inputSet = prot.getInputParticles()

        makePath(prot._getTmpPath())
        fnSqlite = prot._getTmpPath('cluster_particles.sqlite')
        cleanPath(fnSqlite)

        partSet = SetOfParticles(filename=fnSqlite)
        partSet.copyInfo(inputSet)
        for point in self.getData():
            if point.getState() == Point.SELECTED:
                partSet.append(inputSet[point.getId()])
        partSet.write()
        partSet.close()

        # Imported here, as in the original code, rather than at module level.
        from continuousflex.protocols.protocol_batch_cluster import FlexBatchProtNMACluster

        newProt = project.newProtocol(FlexBatchProtNMACluster)
        clusterName = self.clusterWindow.getClusterName()
        if clusterName:
            newProt.setObjLabel(clusterName)
        newProt.inputNmaDimred.set(prot)
        newProt.sqliteFile.set(fnSqlite)

        project.launchProtocol(newProt)
        project.getRunsGraph()

    def _loadAnimationData(self, obj):
        """ Load a previously generated animation into the trajectories
        window and show it in VMD.

        :param obj: browser item whose getFileName() is taken as the name of
            the animation folder inside the protocol extra path.
        """
        prot = self.protocol
        animationName = obj.getFileName()  # assumes that obj.getFileName is the folder of animation
        animationPath = prot._getExtraPath(animationName)
        animationFiles = [animationName + '.vmd',
                          animationName + '.pdb',
                          'trajectory.txt']
        for s in animationFiles:
            f = join(animationPath, s)
            if not exists(f):
                # NOTE(review): the value returned by errorMessage is discarded
                # here, whereas _doViewRawDeformation returns it as a view --
                # confirm that the message is actually displayed.
                self.errorMessage('Animation file "%s" not found. ' % f)
                return

        # Load animation trajectory points; ndmin=2 keeps a (points, dim)
        # array even when the file has a single row or a single column.
        trajectoryPoints = np.loadtxt(join(animationPath, 'trajectory.txt'),
                                      ndmin=2)
        data = PathData(dim=trajectoryPoints.shape[1])

        for i, row in enumerate(trajectoryPoints):
            data.addPoint(Point(pointId=i + 1, data=list(row), weight=1))

        self.trajectoriesWindow.setPathData(data)
        self.trajectoriesWindow.setAnimationName(animationName)
        self.trajectoriesWindow._onUpdateClick()

        def _showVmd():
            vmdFn = join(animationPath, animationName + '.vmd')
            VmdView(' -e %s' % vmdFn).show()

        # VMD is launched through the Tk event loop after a 500 ms delay.
        self.getTkRoot().after(500, _showVmd)

    def _loadAnimation(self):
        """ Open a file browser on the protocol extra path to pick an
        animation folder; the selection is handled by _loadAnimationData. """
        prot = self.protocol
        browser = FileBrowserWindow("Select the animation folder (animation_NAME)",
                                    self.getWindow(), prot._getExtraPath(),
                                    onSelect=self._loadAnimationData)
        browser.show()

    def _generateAnimation(self):
        """ Generate and show in VMD the animation of the trajectory drawn
        in the trajectories window.

        The trajectory points are converted to normal-mode deformations,
        one deformed PDB is computed per point with 'xmipp_pdb_nma_deform',
        and the models are joined in a single multi-model PDB file.
        """
        prot = self.protocol
        projectorFile = prot.getProjectorFile()

        animation = self.trajectoriesWindow.getAnimationName()
        animationPath = prot._getExtraPath('animation_%s' % animation)

        cleanPath(animationPath)
        makePath(animationPath)
        animationRoot = join(animationPath, 'animation_%s' % animation)

        trajectoryPoints = np.array([p.getData() for p in self.trajectoriesWindow.pathData])
        np.savetxt(join(animationPath, 'trajectory.txt'), trajectoryPoints)

        # ndmin=2 keeps the matrices 2-D when a file has one row or column.
        if projectorFile:
            M = np.loadtxt(projectorFile, ndmin=2)
            deformations = np.dot(trajectoryPoints, np.linalg.pinv(M))
        else:
            Y = np.loadtxt(prot.getOutputMatrixFile(), ndmin=2)
            X = np.loadtxt(prot.getDeformationFile(), ndmin=2)
            # Find closest points in deformations
            deformations = [X[np.argmin(np.sum((Y - p) ** 2, axis=1))]
                            for p in trajectoryPoints]

        pdb = prot.getInputPdb()
        pdbFile = pdb.getFileName()
        structureEM = pdb.getPseudoAtoms()
        if not structureEM:
            localFn = replaceBaseExt(basename(pdbFile), 'pdb')
            cifToPdb(pdbFile, localFn)
            pdbFile = basename(localFn)

        modesFn = prot.inputNMA.get()._getExtraPath('modes.xmd')

        for i, d in enumerate(deformations):
            atomsFn = animationRoot + 'atomsDeformed_%02d.pdb' % (i + 1)
            # Join the values explicitly: the printed form of a numpy array
            # wraps long vectors over several lines, which would break the
            # command line.
            deformationStr = ' '.join(str(v) for v in np.ravel(d))
            cmd = '-o %s --pdb %s --nma %s --deformations %s' % (
                atomsFn, pdbFile, modesFn, deformationStr)
            runProgram('xmipp_pdb_nma_deform', cmd)

        # Join all deformations in a single pdb
        # iterating going up and down through all points
        # 1 2 3 ... n-2 n-1 n n-1 n-2 ... 3, 2
        n = len(deformations)
        loop = list(range(1, n + 1)) + list(range(n - 1, 1, -1))  # Skip 1 at the end

        trajFn = animationRoot + '.pdb'
        with open(trajFn, 'w') as trajFile:
            for i in loop:
                atomsFn = animationRoot + 'atomsDeformed_%02d.pdb' % i
                with open(atomsFn) as atomsFile:
                    for line in atomsFile:
                        trajFile.write(line)
                trajFile.write('TER\nENDMDL\n')

        # Delete temporary atom files
        cleanPattern(animationRoot + 'atomsDeformed_??.pdb')

        # Generate the vmd script
        vmdFn = animationRoot + '.vmd'
        with open(vmdFn, 'w') as vmdFile:
            vmdFile.write("""
        mol new %s
        animate style Loop
        display projection Orthographic
        mol modcolor 0 0 Index
        mol modstyle 0 0 Beads 1.000000 8.000000
        animate speed 0.5
        animate forward
        """ % trajFn)

        VmdView(' -e ' + vmdFn).show()

    def loadData(self):
        """ Iterate over the images and the output matrix txt file
        and create a Data object with theirs Points.
        """
        # ndmin=2 keeps one row per particle even for a single-row or
        # single-column matrix file.
        matrix = np.loadtxt(self.protocol.getOutputMatrixFile(), ndmin=2)
        particles = self.protocol.getInputParticles()

        data = Data()
        for i, particle in enumerate(particles):
            data.addPoint(Point(pointId=particle.getObjId(),
                                data=matrix[i, :],
                                weight=particle._xmipp_cost.get()))
        return data