from javax.swing import *
from javax.swing.border import LineBorder
from java.awt import GridLayout, BorderLayout, FlowLayout, Color, Dimension, Canvas
from java.awt.event import InputEvent
from visad import DisplayListener, DisplayEvent, MouseHelper, Display, CellImpl
from visad.java3d import MouseBehaviorJ3D
import subs
import jarray

import hydra_properties


#- This is not a GUI component: it listens to one display on behalf of a
#- DisplayViewSelectToolBar and forwards every display event to that toolbar.
class DisplayViewSelect(DisplayListener):
  """Per-display event relay and state holder used by the view toolbar.

  Registers itself as a listener on one display, caches the display's
  projection control, the projection matrix read at construction time and
  the renderer's mouse helper, and hands every display event to the toolbar.
  """

  def __init__(self, display, toolbar):
    self.display = display
    self.toolbar = toolbar
    display.addDisplayListener(self)

    # Projection state; init_matrix is what Reset later writes back.
    control = display.getProjectionControl()
    self.proj_cntrl  = control
    self.init_matrix = control.getMatrix()

    renderer = display.getDisplayRenderer()
    self.dspRender   = renderer
    self.mouseHelper = renderer.getMouseBehavior().getMouseHelper()

    # Bookkeeping read and updated by the rubber-band actions' doit().
    self.count_map_removed = 0
    self.count_map_added   = 0
    self.spatial_removed   = 0

  def displayChanged(self, event):
    """DisplayListener callback: pass the event on, tagged with this listener."""
    self.toolbar.displayChanged(self, event)

        
   
#--------------------------------------------------------------------------------
#  This extends JToolBar and provides a bar of toggle buttons for control
#  of the display view.  Takes a sequence of VisAD displays and, optionally,
#  a per-display list of range-zoom flags and a pick mode (None or "grab").
#  One DisplayViewSelect listener is created internally for each display.
#  Can be used like any other GUI component.
#---------------------------------------------------------------------------------
class DisplayViewSelectToolBar(JToolBar):
  """Toolbar of mutually exclusive toggle buttons that choose the mouse mode.

  One DisplayViewSelect listener is created per entry of display_s.  Every
  button carries an action command; pressing it selects the action(s) with
  that command and deselects all the others.

  display_s  -- indexable sequence of displays to control
  rangeZoom  -- optional per-display flags; 0 picks RubberBandZoom for that
                display, anything else RubberBandRangeZoom (default: all 0)
  pickAction -- None adds a "pick image" button, "grab" adds a "grab"
                button, any other value adds no pick button
  """

  def __init__(self, display_s, rangeZoom=None, pickAction=None):
    JToolBar.__init__(self)
    self.setFloatable(0)

    self.actions        = []
    self.listeners      = []
    self.buttonGroup    = ButtonGroup()
    self.selectedAction = None

    numDisplays = len(display_s)
    if rangeZoom is None:
      rangeZoom = [0] * numDisplays

    for display in display_s:
      self.listeners.append(DisplayViewSelect(display, self))

    #--- default actions, each paired with its toggle button:
    self.actions.append(Reset(self.listeners))
    self.reset = self._addToggle("./ui/icons/Home16.gif", "reset", "reset")

    self.actions.append(ZoomUp(self.listeners))
    self.zoom_plus = self._addToggle("./ui/icons/ZoomIn16.gif", "zoomup",
                                     "zoom in (Shift+Left+Drag)")

    self.actions.append(ZoomDown(self.listeners))
    self.zoom_minus = self._addToggle("./ui/icons/ZoomOut16.gif", "zoomdown",
                                      "zoom out (Shift+Left+Drag)")

    self.actions.append(Translate(self.listeners))
    self.translate = self._addToggle("./ui/icons/four_arrows.gif", "translate",
                                     "translate: Right+Drag (Cntr+Left+Drag)")

    #--- one rubber-band action per display, all sharing a single button:
    for idx in xrange(numDisplays):
      single = [self.listeners[idx]]
      if rangeZoom[idx] == 0:
        self.actions.append(RubberBandZoom(single))
      else:
        self.actions.append(RubberBandRangeZoom(single))
    self.zoom_rbbox = self._addToggle("./ui/icons/rubber_band.gif",
                                      "rubberbandzoom", "rubber band zoom")

    #--- optional pick button; the two modes differ only in Pick's func and
    #    in the tooltip text:
    if pickAction is None or pickAction == "grab":
      if pickAction is None:
        self.pickAction = Pick(self.listeners)
        tip = "pick image"
      else:
        self.pickAction = Pick(self.listeners, func="grab")
        tip = "grab"
      self.actions.append(self.pickAction)
      self.click = self._addToggle("./ui/icons/cursor16.jpg", "pick", tip)

  def _addToggle(self, iconPath, command, tip):
    """Create a toggle button wired to select(), add it to the bar and group."""
    button = JToggleButton(ImageIcon(iconPath), actionPerformed=self.select,
                           actionCommand=command)
    button.setToolTipText(tip)
    self.add(button)
    self.buttonGroup.add(button)
    return button

  def addHAction(self, action):
    """Register an extra action and give it a toggle button of its own.

    The action is rebound to this toolbar's listeners; its icon, action
    command and tooltip come from the action's getters.
    """
    self.actions.append(action)
    # NOTE(review): only .listeners is replaced here; the action's
    # num_listeners keeps whatever its constructor computed.
    action.listeners = self.listeners
    icon   = ImageIcon(action.getIcon())
    button = JToggleButton(icon, actionPerformed=self.select,
                           actionCommand=action.getActionCommand())
    button.setToolTipText(action.getToolTipText())
    self.add(button)
    self.buttonGroup.add(button)

  def select(self, event, leftFunc=None):
    """Button callback: select actions matching the command, deselect the rest.

    When several actions share the command, selectedAction ends up being the
    last match in self.actions.
    """
    cmd = event.getActionCommand()
    for action in self.actions:
      if cmd == action.actionCommand:
        action.selected(leftFunc)
        self.selectedAction = action
      else:
        action.deselected()

  def displayChanged(self, listener, event):
    """Offer a display event to every action; run doReset while "reset" is selected."""
    for action in self.actions:
      action.doit(listener, event)

    current = self.selectedAction
    if current is not None and current.actionCommand == "reset":
      for action in self.actions:
        action.doReset(listener, event)
    
      
class HAction:
  """Base class for a toolbar action bound to a list of display listeners.

  listeners     -- listener objects the action operates on
  actionCommand -- string matched against a button's action command
  icon          -- optional icon value returned by getIcon()
  tooltipText   -- optional text returned by getToolTipText()

  Subclasses override selected()/deselected() to switch the mouse mode and
  doit()/doReset() to react to display events; the defaults only track the
  enabled flag and ignore events.
  """

  def __init__(self, listeners, actionCommand, icon=None, tooltipText=None):
    self.listeners     = listeners
    self.actionCommand = actionCommand
    self.num_listeners = len(listeners)
    self.enabled       = 0
    self.icon          = icon
    self.tooltipText   = tooltipText

  def selected(self, leftFunc=None):
    """Mark the action as active; leftFunc is unused in the base class."""
    self.enabled = 1

  def deselected(self):
    """Mark the action as inactive."""
    self.enabled = 0

  def doit(self, listener, event):
    """Handle a display event from listener; no-op by default."""
    pass

  def getIcon(self):
    return self.icon

  def getActionCommand(self):
    return self.actionCommand

  def getToolTipText(self):
    return self.tooltipText

  def setListeners(self, listeners):
    """Replace the listener list and keep num_listeners in step with it."""
    self.listeners = listeners
    # Subclasses loop over xrange(self.num_listeners); a stale count would
    # skip new listeners or index past the end of a shorter list.
    self.num_listeners = len(listeners)

  def doReset(self, listener, event):
    """Handle a display event while the reset action is selected; no-op by default."""
    pass
    
    
class RubberBandSubset(HAction):
  """Action that shows/hides a caller-supplied rubber-band subset reference.

  subsetRef -- indexable holder whose element 0 is the reference toggled on
               every listener's display; its .enabled attribute mirrors the
               action's own enabled flag.
  """

  def __init__(self, listeners, subsetRef, actionCommand="imagesubset", icon=None, tooltipText=None):
    HAction.__init__(self, listeners, actionCommand, icon, tooltipText)
    self.subsetRef = subsetRef

  def selected(self, leftFunc=None):
    """Install the rubber-band mouse map and show the subset box everywhere."""
    # Same table for every listener, so build it once.
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.DIRECT, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for i in xrange(self.num_listeners):
      listener = self.listeners[i]
      listener.mouseHelper.setFunctionMap(functionMap)
      listener.display.toggle(self.subsetRef[0], 1)
    # Set the flags once, after the loop, exactly as deselected() clears
    # them; inside the loop they were never set when there were no listeners.
    self.enabled = 1
    self.subsetRef[0].enabled = 1

  def deselected(self):
    """Hide the subset box on every display and clear both enabled flags."""
    for i in xrange(self.num_listeners):
      self.listeners[i].display.toggle(self.subsetRef[0], 0)
    self.enabled = 0
    self.subsetRef[0].enabled = 0
    
    
class RubberBandZoom(HAction):
  """Rubber-band box zoom: one box-zoomer reference per listener's display.

  rbbzRef_s[i] holds the reference returned by enableRubberBandBoxZoomer()
  for listeners[i]; it is toggled on while the action is selected and is
  re-created by doit() after the display's spatial maps have been replaced.
  """

  def __init__(self, listeners, actionCommand="rubberbandzoom"):
    HAction.__init__(self, listeners, actionCommand)
    self.rbbzRef_s = []
    for i in xrange(self.num_listeners):
      display = listeners[i].display
      ref = display.enableRubberBandBoxZoomer(0, color=hydra_properties.foreground_color)
      display.toggle(ref, 0)   # hidden until the action is selected
      self.rbbzRef_s.append(ref)

  def selected(self, leftFunc=None):
    """Install the rubber-band mouse map and show each display's zoom box."""
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.DIRECT, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for i in xrange(self.num_listeners):
      listener = self.listeners[i]
      listener.mouseHelper.setFunctionMap(functionMap)
      listener.display.toggle(self.rbbzRef_s[i], 1)
    # Set once after the loop (as deselected() does) so the flag is also
    # correct when there are no listeners.
    self.enabled = 1

  def deselected(self):
    """Hide every zoom box and mark the action inactive."""
    for i in xrange(self.num_listeners):
      self.listeners[i].display.toggle(self.rbbzRef_s[i], 0)
    self.enabled = 0

  def doit(self, listener, event):
    """Track map removal/addition on listener's display and rebuild the zoomer.

    Events from listeners this action does not own are ignored.
    """
    if listener not in self.listeners:
      return
    index = self.listeners.index(listener)
    eventId = event.getId()

    if eventId == DisplayEvent.MAP_REMOVED:
      # Every removed map is counted; a removed X or Y axis map also marks
      # the spatial maps as gone.
      ds = event.map.getDisplayScalar()
      listener.count_map_removed += 1
      if ds == Display.YAxis or ds == Display.XAxis:
        listener.spatial_removed = 1

    elif eventId == DisplayEvent.MAP_ADDED:
      # Only X/Y axis maps are counted on the way back in.
      ds = event.map.getDisplayScalar()
      if ds == Display.YAxis or ds == Display.XAxis:
        listener.count_map_added += 1
      if (listener.count_map_added == listener.count_map_removed) and listener.spatial_removed == 1:
        listener.count_map_added   = 0
        listener.count_map_removed = 0
        listener.spatial_removed   = 0
        # Build a fresh zoomer and restore the visibility implied by the
        # action's current state.
        ref = listener.display.enableRubberBandBoxZoomer(0, color=hydra_properties.foreground_color)
        self.rbbzRef_s[index] = ref
        if self.enabled == 1:
          listener.display.toggle(ref, 1)
        else:
          listener.display.toggle(ref, 0)

    elif eventId == DisplayEvent.MAPS_CLEARED:
      # Drop the old zoomer and prime the counter so that two added X/Y
      # maps trigger the rebuild above.
      listener.spatial_removed = 1
      listener.display.removeReference(self.rbbzRef_s[index])
      listener.count_map_removed = 2

      
      
class RubberBandRangeZoom(HAction):
  """Rubber-band zoom that rescales the X/Y map ranges instead of the view.

  A rubber-band box reference is created on the display and watched by a
  ZoomBox cell; when a non-degenerate box is drawn the cell sets the ranges
  of the first X map and first Y map (and of a second Y map, if present) to
  the box extents.  The ranges read at initialize() time are kept so that
  doReset() can restore them.
  """

  def __init__(self, listeners, actionCommand="rubberbandzoom"):
    HAction.__init__(self, listeners, actionCommand)
    self.rbbzRef_s    = []
    self.xmaps        = None
    self.ymaps        = None
    self.init_x_range = None
    self.init_y_range = None
    # One-element list shared with the ZoomBox cell, which sets [0] to 1
    # after it has changed the ranges; doReset() clears it again.
    self.reset_flag   = [0]

    self.rbbzRef_s.append(self.initialize())

  def initialize(self):
    """Create the rubber-band box and its ZoomBox cell; return the box reference.

    Returns None when there are no listeners.
    """
    # NOTE(review): only the first listener is ever set up here, and the
    # xmaps/ymaps/init ranges are stored once per action, so this class
    # effectively supports a single listener (the toolbar creates one
    # instance per display).
    if self.num_listeners < 1:
      return None
    display = self.listeners[0].display

    smapList = subs.getDisplayScalarMapLists(display)
    xx = smapList[0]
    yy = smapList[1]
    xmap = xx[0]
    ymap = yy[0]
    # A second Y map, when present, is kept in step with the first one.
    ymap_s = None
    if len(yy) > 1:
      ymap_s = [yy[1]]

    zref = display.enableRubberBandBox(0,x=xmap.getScalar(),y=ymap.getScalar(),color=hydra_properties.foreground_color)
    display.toggle(zref, 0)   # hidden until the action is selected

    # Remember the maps and their current ranges for doReset().
    self.xmaps = xx
    self.ymaps = yy
    self.init_x_range = [smap.getRange() for smap in xx]
    self.init_y_range = [smap.getRange() for smap in yy]

    class ZoomBox(CellImpl):
      """Cell that applies the drawn box to the map ranges."""

      def __init__(self, dsp, zoomBoxRef, reset, xmap, ymap, xmap_s=None, ymap_s=None):
        self.first      = 1
        self.zoomBoxRef = zoomBoxRef
        self.xmap       = xmap
        self.ymap       = ymap
        self.xmap_s     = xmap_s
        self.ymap_s     = ymap_s
        # Output buffers for ScalarMap.getScale().
        self.so_a       = jarray.zeros(2, 'd')
        self.so_b       = jarray.zeros(2, 'd')
        self.data_a     = jarray.zeros(2, 'd')
        self.data_b     = jarray.zeros(2, 'd')
        self.dsp_a      = jarray.zeros(2, 'd')
        self.dsp_b      = jarray.zeros(2, 'd')
        self.reset      = reset
        self.dsp        = dsp

      def doAction(self):
        # The first invocation is skipped (presumably the one caused by
        # attaching the reference -- confirm).
        if self.first != 0:
          self.first = 0
          return

        box = self.zoomBoxRef.getData()
        lo  = box.getLow()
        hi  = box.getHi()
        # A box with zero width or height is ignored.
        if (lo[0] == hi[0]) or (lo[1] == hi[1]):
          return

        self.reset[0] = 1
        self.dsp.disableAction()
        try:
          # Read the primary Y map's scale/offset before its range changes.
          self.ymap.getScale(self.so_a, self.data_a, self.dsp_a)
          self.xmap.setRange(lo[0], hi[0])
          self.ymap.setRange(lo[1], hi[1])
          if self.ymap_s is not None:
            # Box limits -> display coordinates via the primary Y map, then
            # back to data values via the secondary Y map's scale/offset.
            dsp_lo = lo[1]*self.so_a[0] + self.so_a[1]
            dsp_hi = hi[1]*self.so_a[0] + self.so_a[1]
            self.ymap_s[0].getScale(self.so_b, self.data_b, self.dsp_b)
            data_lo = (dsp_lo - self.so_b[1])/self.so_b[0]
            data_hi = (dsp_hi - self.so_b[1])/self.so_b[0]
            self.ymap_s[0].setRange(data_lo, data_hi)
        finally:
          # Never leave the display with actions disabled, even if a
          # setRange call fails.
          self.dsp.enableAction()

    zcell = ZoomBox(display, zref, self.reset_flag, xmap, ymap, None, ymap_s)
    zcell.addReference(zref)
    return zref

  def selected(self, leftFunc=None):
    """Install the rubber-band mouse map and show the range-zoom box."""
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.DIRECT, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for i in xrange(self.num_listeners):
      listener = self.listeners[i]
      listener.mouseHelper.setFunctionMap(functionMap)
      listener.display.toggle(self.rbbzRef_s[i], 1)
    self.enabled = 1

  def deselected(self):
    """Hide the range-zoom box and mark the action inactive."""
    for i in xrange(self.num_listeners):
      self.listeners[i].display.toggle(self.rbbzRef_s[i], 0)
    self.enabled = 0

  def doit(self, listener, event):
    """Track map removal/addition on listener's display and rebuild the box.

    Events from listeners this action does not own are ignored.
    """
    if listener not in self.listeners:
      return
    index = self.listeners.index(listener)
    eventId = event.getId()

    if eventId == DisplayEvent.MAP_REMOVED:
      # Every removed map is counted; a removed X or Y axis map also marks
      # the spatial maps as gone.
      ds = event.map.getDisplayScalar()
      listener.count_map_removed += 1
      if ds == Display.YAxis or ds == Display.XAxis:
        listener.spatial_removed = 1

    elif eventId == DisplayEvent.MAP_ADDED:
      # Only X/Y axis maps are counted on the way back in.
      ds = event.map.getDisplayScalar()
      if ds == Display.YAxis or ds == Display.XAxis:
        listener.count_map_added += 1
      if (listener.count_map_added == listener.count_map_removed) and listener.spatial_removed == 1:
        listener.count_map_added   = 0
        listener.count_map_removed = 0
        listener.spatial_removed   = 0

        # Rebuild against the new maps and restore the visibility implied
        # by the action's current state.
        ref = self.initialize()
        self.rbbzRef_s[index] = ref
        if self.enabled == 1:
          listener.display.toggle(ref, 1)
        else:
          listener.display.toggle(ref, 0)

    elif eventId == DisplayEvent.MAPS_CLEARED:
      # Drop the old box and prime the counter so that two added X/Y maps
      # trigger the rebuild above.
      listener.spatial_removed = 1
      listener.display.removeReference(self.rbbzRef_s[index])
      listener.count_map_removed = 2

  def doReset(self, listener, event):
    """Restore the initial map ranges on a mouse release after a range zoom.

    Does nothing unless the ZoomBox cell has flagged a range change and the
    event is a left or right mouse-button release from one of this action's
    listeners.
    """
    if listener not in self.listeners:
      return
    if self.reset_flag[0] != 1:
      return
    # The button test is grouped explicitly: written as "A and B or C" it
    # parsed as "(A and B) or C", so a left release bypassed the flag check.
    eventId = event.getId()
    if eventId != DisplayEvent.MOUSE_RELEASED_RIGHT and eventId != DisplayEvent.MOUSE_RELEASED_LEFT:
      return

    if self.xmaps is not None:
      for i in xrange(len(self.xmaps)):
        self.xmaps[i].setRange(self.init_x_range[i][0], self.init_x_range[i][1])
    if self.ymaps is not None:
      for i in xrange(len(self.ymaps)):
        self.ymaps[i].setRange(self.init_y_range[i][0], self.init_y_range[i][1])

    self.reset_flag[0] = 0
    

class Zoom(HAction):
  """Click-to-zoom base action: each mouse release scales the projection.

  Subclasses must set self.scale; doit() reads it but this class never
  assigns it.
  """

  def __init__(self, listeners, actionCommand):
    HAction.__init__(self, listeners, actionCommand)

  def selected(self, leftFunc=None):
    """Install this mode's mouse map on every listener and enable the action."""
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.NONE, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for idx in xrange(self.num_listeners):
      self.listeners[idx].mouseHelper.setFunctionMap(functionMap)
    self.enabled = 1

  def deselected(self, leftFunc=None):
    self.enabled = 0

  def zoom(self, scale, listener):
    """Pre-multiply listener's projection matrix by a uniform scale matrix."""
    control = listener.proj_cntrl
    scaling = MouseBehaviorJ3D.static_make_matrix(0,0,0,scale,0,0,0)
    current = control.getMatrix()
    control.setMatrix(MouseBehaviorJ3D.static_multiply_matrix(scaling, current))

  def doit(self, listener, event):
    """While enabled, zoom by self.scale on a left or right button release."""
    if self.enabled != 1:
      return
    eventId = event.getId()
    if eventId == DisplayEvent.MOUSE_RELEASED_RIGHT or eventId == DisplayEvent.MOUSE_RELEASED_LEFT:
      self.zoom(self.scale, listener)
    
class ZoomUp(Zoom):
  """Zoom action that enlarges the view by a factor of 1.1 per mouse release."""

  def __init__(self, listeners, actionCommand="zoomup"):
    Zoom.__init__(self, listeners, actionCommand)
    # Factor handed to Zoom.zoom() by Zoom.doit().
    self.scale = 1.1
    
class ZoomDown(Zoom):
  """Zoom action that shrinks the view by a factor of 0.9 per mouse release."""

  def __init__(self, listeners, actionCommand="zoomdown"):
    Zoom.__init__(self, listeners, actionCommand)
    # Factor handed to Zoom.zoom() by Zoom.doit().
    self.scale = 0.9
   
    
class Translate(HAction):
  """Action that switches every listener's mouse map to the translate layout.

  Only the mouse map is changed; the enabled flag is left untouched by both
  selected() and deselected().
  """

  def __init__(self, listeners, actionCommand="translate"):
    HAction.__init__(self, listeners, actionCommand)

  def selected(self, leftFunc=None):
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.TRANSLATE, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.TRANSLATE, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for idx in xrange(self.num_listeners):
      self.listeners[idx].mouseHelper.setFunctionMap(functionMap)

  def deselected(self, leftFunc=None):
    # Nothing to undo: the next selected action installs its own mouse map.
    pass
    
    
class Reset(HAction):
  """Action that restores a display's initial projection on mouse release."""

  def __init__(self, listeners, actionCommand="reset"):
    HAction.__init__(self, listeners, actionCommand)

  def selected(self, leftFunc=None):
    """Install this mode's mouse map on every listener and enable the action."""
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = [[[MouseHelper.NONE, MouseHelper.ZOOM],
                    [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                   [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                    [MouseHelper.NONE, MouseHelper.NONE]],
                   [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                    [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    for idx in xrange(self.num_listeners):
      self.listeners[idx].mouseHelper.setFunctionMap(functionMap)
    self.enabled = 1

  def deselected(self, leftFunc=None):
    self.enabled = 0

  def doit(self, listener, event):
    """While enabled, write the listener's stored init_matrix back on release."""
    if self.enabled != 1:
      return
    eventId = event.getId()
    if eventId == DisplayEvent.MOUSE_RELEASED_RIGHT or eventId == DisplayEvent.MOUSE_RELEASED_LEFT:
      listener.proj_cntrl.setMatrix(listener.init_matrix)

     
class Pick(HAction):
  """Pick/grab action: sets a cursor-oriented mouse map and drag events.

  func -- None for the pick layout (its first entry is chosen by the
          leftFunc argument of selected()), "grab" for the grab layout;
          with any other value no mouse map is installed.
  """

  def __init__(self, listeners, actionCommand="pick", func=None):
    HAction.__init__(self, listeners, actionCommand)
    self.func = func

  def selected(self, leftFunc=None):
    """Install the mouse map (if any) and enable MOUSE_DRAGGED on every display."""
    # The table depends only on self.func and leftFunc, so choose it once.
    # NOTE(review): presumably indexed [button][ctrl][shift] as in VisAD's
    # MouseHelper.setFunctionMap -- confirm.
    functionMap = None
    if self.func is None:
      if leftFunc == "DIRECT":
        leftButton = MouseHelper.DIRECT
      else:
        leftButton = MouseHelper.CURSOR_TRANSLATE
      functionMap = [[[leftButton, MouseHelper.ZOOM],
                      [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                     [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                      [MouseHelper.NONE, MouseHelper.NONE]],
                     [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.DIRECT],
                      [MouseHelper.DIRECT, MouseHelper.DIRECT]]]
    elif self.func == "grab":
      functionMap = [[[MouseHelper.DIRECT, MouseHelper.ZOOM],
                      [MouseHelper.TRANSLATE, MouseHelper.NONE]],
                     [[MouseHelper.CURSOR_TRANSLATE, MouseHelper.CURSOR_ZOOM],
                      [MouseHelper.NONE, MouseHelper.NONE]],
                     [[MouseHelper.DIRECT, MouseHelper.DIRECT],
                      [MouseHelper.DIRECT, MouseHelper.DIRECT]]]

    for idx in xrange(self.num_listeners):
      listener = self.listeners[idx]
      if functionMap is not None:
        listener.mouseHelper.setFunctionMap(functionMap)
      listener.display.enableEvent(DisplayEvent.MOUSE_DRAGGED)
    self.enabled = 1

  def deselected(self):
    """Stop MOUSE_DRAGGED events on every display and disable the action."""
    for idx in xrange(self.num_listeners):
      self.listeners[idx].display.disableEvent(DisplayEvent.MOUSE_DRAGGED)
    self.enabled = 0