# **************************************************************************
# * Authors: Mohamad Harastani (mohamad.harastani@upmc.fr)
# *
# * This program is free software; you can redistribute it and/or modify
# * it under the terms of the GNU General Public License as published by
# * the Free Software Foundation; either version 2 of the License, or
# * (at your option) any later version.
# *
# * This program is distributed in the hope that it will be useful,
# * but WITHOUT ANY WARRANTY; without even the implied warranty of
# * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# * GNU General Public License for more details.
# *
# * You should have received a copy of the GNU General Public License
# * along with this program; if not, write to the Free Software
# * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA
# * 02111-1307 USA
# *
# * All comments concerning this program package may be sent to the
# * e-mail address 'scipion@cnb.csic.es'
# *
# **************************************************************************
"""
This module implements the viewer (visualization wrapper) for the results of
the HeteroFlow protocol.
"""
from continuousflex.protocols.protocol_heteroflow import FlexProtHeteroFlow
from pwem.viewers import EmProtocolViewer
from pyworkflow.protocol.params import LabelParam, IntParam, EnumParam, StringParam
from pyworkflow.viewer import ProtocolViewer, DESKTOP_TKINTER, WEB_DJANGO
from pwem.viewers import ObjectView
import numpy as np
import matplotlib.pyplot as plt
from continuousflex.protocols.utilities.OF_plots import plot_quiver_3d, plot_quiver_2d
from continuousflex.protocols.utilities.spider_files3 import open_volume, open_image
import pyworkflow.protocol.params as params
from pyworkflow.utils.process import runJob
from pyworkflow.utils.path import makePath
# Indices of the 'displayFlow2' EnumParam choices, i.e. the plane onto which
# the 3D optical flow is projected for the 2D display. The order must match
# the ``choices`` list given to that parameter in the viewer form.
XY = 0
XZ = 1
ZY = 2
class FlexHeteroFlowViewer(EmProtocolViewer):
    """ Visualization of results from the HeteroFlow protocol

    Offers a 3D and a projected 2D view of the optical flow estimated for a
    selected input volume, the set of reference volumes warped by the flows,
    and histograms of the similarity measures stored in the protocol's
    ``cc_msd_mad.txt`` extra file.
    """
    _label = 'viewer heteroflow'
    _targets = [FlexProtHeteroFlow]
    _environments = [DESKTOP_TKINTER, WEB_DJANGO]

    def __init__(self, **kwargs):
        # Go through the MRO so EmProtocolViewer's initializer (if it defines
        # one) is not skipped by calling ProtocolViewer directly.
        super().__init__(**kwargs)
        self._data = None

    def _defineParams(self, form):
        form.addSection(label='Visualization')
        group = form.addGroup('Optical flows')
        group.addParam('FlowNumber', IntParam, default=1,
                       label='Optical flow for volume number')
        group.addParam('DownSample', IntParam, default=2,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='Downsample the 3D quiver plot',
                       help='Loading the 3D quiver plot can take time, we can downsample it by this number')
        group.addParam('displayFlow', LabelParam,
                       label="Display 3D optical flow",
                       help="Display the calculated optical flow of volume of specified number vs the reference")
        # The Euler angles shown here must be the ones _projectionAngles uses:
        # x-z -> (90, 90, 0) and z-y -> (0, 90, 0). Under the Xmipp ZYZ
        # convention the projection direction is
        # (cos(rot)sin(tilt), sin(rot)sin(tilt), cos(tilt)), so tilt=90 with
        # rot=0 projects along x (z-y plane) and rot=90 along y (x-z plane).
        group.addParam('displayFlow2', EnumParam,
                       choices=[' x-y plane: Euler( 0, 0, 0)', ' x-z plane: Euler(90,90, 0)',
                                ' z-y plane: Euler( 0,90, 0)'],
                       default=XY,
                       display=params.EnumParam.DISPLAY_COMBO,
                       label='Display projected 3D-to-2D optical flow',
                       help="Display the projected 3D-2D optical flow on one of the planes")
        group.addParam('RotTiltPsi', StringParam, default=None,
                       expertLevel=params.LEVEL_ADVANCED,
                       label='rot tilt psi',
                       help='Project the 3D optical flow using specific Euler angles')
        form.addParam('displayVolumes', LabelParam,
                      label="Display warped volumes",
                      help="Display the volumes that are generated by applying the calculated optical flow of each of "
                           "the input volumes on the reference")
        form.addParam('displayHistCC', LabelParam,
                      label="Histogram of normalized cross correlation",
                      help="Histogram of the normalized cross correlation between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmsd', LabelParam,
                      label="Histogram of mean square distance",
                      help="Histogram of the mean square distance between the input volumes and "
                           "the warped reference")
        form.addParam('displayHistmad', LabelParam,
                      label="Histogram of normalized mean absolute distance",
                      help="Histogram of the mean absolute distance between the input volumes and "
                           "the warped reference")

    def _getVisualizeDict(self):
        """ Map each form parameter to the method that visualizes it. """
        return {'displayVolumes': self._viewVolumes,
                'displayHistCC': self._viewParam,
                'displayHistmsd': self._viewParam,
                'displayHistmad': self._viewParam,
                'displayFlow': self._viewFlow,
                'displayFlow2': self._viewFlow2,
                'RotTiltPsi': self._viewFlow2,
                }

    def _viewVolumes(self, paramName):
        """ Open the set of reference volumes warped by the optical flows
        (``protocol.WarpedRefByFlows``) in an object view. """
        volumes = self.protocol.WarpedRefByFlows
        return [ObjectView(self._project, volumes.strId(), volumes.getFileName())]

    def _viewParam(self, paramName):
        """ Plot the histogram of one similarity measure between the input
        volumes and the reference warped by their optical flows.

        The measures are read from the protocol's extra file
        ``cc_msd_mad.txt``: column 0 is the normalized cross correlation,
        column 1 the normalized mean square distance and column 2 the
        normalized mean absolute distance.

        :param paramName: 'displayHistCC', 'displayHistmsd' or
            'displayHistmad'; any other name plots nothing.
        """
        # paramName -> (column, measure named in the title, x-axis label)
        settings = {
            'displayHistCC': (0, 'normalized cross correlation',
                              'Cross correlation'),
            'displayHistmsd': (1, 'normalized mean square distance',
                               'normalized mean square distance'),
            'displayHistmad': (2, 'normalized mean absolute distance',
                               'normalized mean absolute distance'),
        }
        if paramName not in settings:
            return
        column, measure, xlabel = settings[paramName]
        datamat_fn = self.protocol._getExtraPath('cc_msd_mad.txt')
        # ndmin=2 keeps a one-row file (a single volume) indexable as [:, k].
        datamat = np.loadtxt(datamat_fn, delimiter=' ', ndmin=2)
        plt.figure()
        plt.hist(datamat[:, column])
        plt.title('Histogram of %s between warped\n reference by '
                  'optical flows (estimated volumes) and the input volumes' % measure)
        plt.xlabel(xlabel)
        plt.ylabel('Number of volumes')
        plt.show()

    def _viewFlow(self, paramName):
        """ Show a 3D quiver plot of the optical flow of the volume selected
        in 'FlowNumber', downsampled by 'DownSample'. """
        number = self.FlowNumber.get()
        flow = self.read_optical_flow_by_number(number)
        title = '3D optical flow for input volume number %d' % number
        plot_quiver_3d(flow, downsample=self.DownSample.get(), title=title)

    def _projectionAngles(self, paramName, number):
        """ Return (rot, tilt, psi, plot title) for the 2D flow projection.

        When triggered from 'RotTiltPsi' the user-given angles are parsed;
        otherwise the angles of the plane chosen in 'displayFlow2' are used.

        :raises ValueError: if 'RotTiltPsi' is not three numbers.
        """
        if paramName == 'RotTiltPsi':
            text = self.RotTiltPsi.get() or ''
            try:
                rot, tilt, psi = map(float, text.split())
            except ValueError as err:
                raise ValueError("'rot tilt psi' must be three numbers separated "
                                 "by spaces, got %r" % text) from err
            title = 'Projected optical flow of input volume number %d \n using Euler angles' \
                    ' (%.1f, %.1f, %.1f)' % (number, rot, tilt, psi)
            return rot, tilt, psi, title
        plane = self.displayFlow2.get()
        if plane == XZ:
            return 90, 90, 0, 'Projected optical flow of input volume number %d on XZ plane' % number
        if plane == ZY:
            return 0, 90, 0, 'Projected optical flow of input volume number %d on ZY plane' % number
        return 0, 0, 0, 'Projected optical flow of input volume number %d on XY plane' % number

    def _viewFlow2(self, paramName):
        """ Show the 3D optical flow of the selected volume projected to 2D.

        Each flow component volume (x, y, z) is projected with
        ``xmipp_phantom_project`` into the protocol's tmp directory. The
        projected vectors are then rotated with :meth:`euler_matrix`, and the
        first two rotated components form the 2D field. That field is
        rescaled so its longest arrow equals the longest 3D flow vector.

        :param paramName: 'displayFlow2' or 'RotTiltPsi' (selects how the
            projection angles are chosen).
        :raises ValueError: if 'RotTiltPsi' is not three numbers.
        """
        number = self.FlowNumber.get()
        flow3D = self.read_optical_flow_by_number(number)
        rot, tilt, psi, title = self._projectionAngles(paramName, number)

        makePath(self.protocol._getTmpPath())
        projections = []
        for flow_path, axis in zip(self._opticalFlowPaths(number), 'xyz'):
            proj_path = self.protocol._getTmpPath('proj_%s.spi' % axis)
            args = '-i %s -o %s --angles %s %s %s' % (flow_path, proj_path, rot, tilt, psi)
            runJob(None, 'xmipp_phantom_project', args)
            projections.append(open_image(proj_path))
        # (3, H, W): projected x, y and z flow components
        p = np.array(projections, dtype=float)

        # pn[i] = sum_j T[i, j] * p[j] for every pixel
        T = self.euler_matrix(rot, tilt, psi)
        pn = np.tensordot(T, p, axes=1)
        flow2D = np.stack([pn[0], pn[1]], axis=-1)

        # Scale flow2D by the magnitude of flow3D; skip an all-zero field to
        # avoid 0/0 producing NaN arrows.
        max_3D = np.max(np.sqrt(np.sum(flow3D ** 2, axis=0)))
        max_2D = np.max(np.sqrt(np.sum(flow2D ** 2, axis=-1)))
        if max_2D > 0:
            flow2D *= max_3D / max_2D
        plot_quiver_2d(flow2D, title=title)

    def _opticalFlowPaths(self, num):
        """ Return the (x, y, z) optical-flow component file paths of volume
        ``num`` (zero-padded to 6 digits) under ``extra/optical_flows``. """
        prefix = self.protocol._getExtraPath() + '/optical_flows/' + str(num).zfill(6)
        return tuple(prefix + '_opflow%s.spi' % axis for axis in 'xyz')

    def read_optical_flow_by_number(self, num):
        """ Read the optical flow stored for input volume number ``num``.

        :param num: volume number (int or string); zero-padded to 6 digits
            to build the file names.
        :returns: array stacking the x, y and z flow components along
            axis 0 (see :meth:`read_optical_flow`).
        """
        return self.read_optical_flow(*self._opticalFlowPaths(num))

    def read_optical_flow(self, path_flowx, path_flowy, path_flowz):
        """ Load the three flow-component volumes and stack them.

        :returns: float array of shape (3,) + volume shape, where index 0, 1
            and 2 hold the volumes read from ``path_flowx``, ``path_flowy``
            and ``path_flowz`` respectively.
        """
        components = [open_volume(path) for path in (path_flowx, path_flowy, path_flowz)]
        return np.array(components, dtype=float)

    def euler_matrix(self, rot, tilt, psi):
        """ Return the 3x3 rotation matrix for Euler angles given in degrees.

        The last row equals
        (cos(rot)sin(tilt), sin(rot)sin(tilt), cos(tilt)). This presumably
        matches the Xmipp ZYZ convention used by ``xmipp_phantom_project``;
        confirm against the Xmipp documentation.
        """
        from math import sin, cos, radians
        # Angles are negated; with this sign the rows come out as the Xmipp
        # Euler matrix rows.
        t1 = -radians(psi)
        t2 = -radians(tilt)
        t3 = -radians(rot)
        a11 = cos(t1) * cos(t2) * cos(t3) - sin(t1) * sin(t3)
        a12 = -cos(t3) * sin(t1) - cos(t1) * cos(t2) * sin(t3)
        a13 = cos(t1) * sin(t2)
        a21 = cos(t1) * sin(t3) + cos(t2) * cos(t3) * sin(t1)
        a22 = cos(t1) * cos(t3) - cos(t2) * sin(t1) * sin(t3)
        a23 = sin(t1) * sin(t2)
        a31 = -cos(t3) * sin(t2)
        a32 = sin(t2) * sin(t3)
        a33 = cos(t2)
        return np.array([[a11, a12, a13], [a21, a22, a23], [a31, a32, a33]])