from java.lang import Double
from visad import DisplayListener, DisplayEvent, CellImpl, Real, ConstantMap, Display, Gridded1DSet, FlatField, RealTupleType, FunctionType
from visad.java3d import DirectManipulationRendererJ3D
import subs
from subs import *
import jarray
from visad.python.JPythonMethods import *
from javax.swing import JTextField, JPanel
from java.awt import BorderLayout, FlowLayout, Color
from javax.swing.border import LineBorder
import graph

import hydra_properties

#selector_colors = ["magenta", Color(0.5, 0.5, 0.0), Color(0.0, 0.5, 0.5)]


class linkHistLocation(DisplayListener):
  """Set a ScalarMap's range from pairs of mouse clicks on a display.

  Each mouse release whose direct-axis x value is not NaN is recorded.
  The first click of a pair starts a new selection; the second completes
  it and calls ``scalar.setRange(min, max)`` with the two x values.
  """

  def __init__(self, display, scalar):
    """Register this listener on ``display``.

    display -- passed to subs.getDisplayMaps; the display it returns is
               the one listened to.
    scalar  -- object whose setRange(lo, hi) receives the selected range
               (presumably a visad ScalarMap -- confirm against callers).
    """
    # NOTE(review): assumes subs.getDisplayMaps returns the x/y/z axis
    # types usable with getDirectAxisValue, plus the display itself.
    self.x, self.y, self.z, self.display = subs.getDisplayMaps(display)
    self.dr = self.display.getDisplayRenderer()
    self.display.addDisplayListener(self)
    self.scalar = scalar
    # 0 -> next click starts a selection, 1 -> next click completes it.
    self.called = 0

  def displayChanged(self, event):
    """Record the x position of a mouse release; every second one sets the range."""
    if event.getId() != DisplayEvent.MOUSE_RELEASED:
      return
    self.xx = self.dr.getDirectAxisValue(self.x)
    self.yy = self.dr.getDirectAxisValue(self.y)
    if Double(self.xx).isNaN():
      # The click did not resolve to a value on the x axis; ignore it.
      return

    if self.called == 0:
      self.display.disableAction()
      self.val = [self.xx]
      self.display.enableAction()
      self.called = 1
    else:
      self.display.disableAction()
      try:
        self.val.append(self.xx)
        self.called = 0
        self.scalar.setRange(min(self.val), max(self.val))
      finally:
        # Always re-enable the display, even if setRange raises;
        # otherwise the display would stay frozen.
        self.display.enableAction()
     
     
class dragHistLocation(CellImpl):
  """Histogram with two draggable handles that set a ScalarMap's range.

  Builds a 20-bin histogram display of ``img`` (via graph.histogram).
  The display gets a magenta "left" and a green "rght"
  direct-manipulation point, with a matching text field for each
  underneath. Whenever either handle's data changes, doAction() copies
  both handle values into the text fields and calls
  ``scalarMap.setRange(left, right)``. Typing a number into a text
  field and pressing Enter moves the corresponding handle.
  """

  def __init__(self, scalarMap, img):
    """Create the histogram display, the two handles and their text fields.

    scalarMap -- ScalarMap whose range is controlled; its current range
                 supplies the initial handle positions.
    img       -- data handed to graph.histogram.
    """
    self.scalarMap = scalarMap
    # doAction() does nothing until construction is complete; it can be
    # triggered by addReference() below before the text fields exist.
    self.setup = 0

    self.botPanel = JPanel()
    self.display = graph.histogram(img, 20, bottom=self.botPanel, clip=0,
                                   color=hydra_properties.foreground_color)
    dsp_rdr = self.display.getDisplayRenderer()
    dsp_rdr.setBackgroundColor(hydra_properties.background_color)
    dsp_rdr.setForegroundColor(hydra_properties.foreground_color)
    xmap, ymap, zmap, d = subs.getDisplayScalarMaps(self.display)

    lowhi = self.scalarMap.getRange()

    self.leftSelectRef = self._addHandle(xmap, "left", lowhi[0], "magenta")
    self.rghtSelectRef = self._addHandle(xmap, "rght", lowhi[1], "green")

    self.addReference(self.leftSelectRef)
    self.addReference(self.rghtSelectRef)

    self.leftSelectTextField = self._makeTextField(self.leftSelect, Color.magenta, lowhi[0])
    self.rghtSelectTextField = self._makeTextField(self.rghtSelect, Color.green, lowhi[1])
    textPanel = JPanel()
    textPanel.setLayout(FlowLayout())
    textPanel.add(self.leftSelectTextField)
    textPanel.add(self.rghtSelectTextField)
    self.botPanel.add(textPanel)

    self.setup = 1   # from here on doAction() may update the range

  def _addHandle(self, xmap, name, value, color):
    """Add a draggable point at x=value to the display and return its data reference."""
    # ConstantMap(-1, YAxis) pins the point's y display coordinate at -1.
    return self.display.addData(name, Real(xmap.getScalar(), value),
      constantMaps=[ConstantMap(-1, Display.YAxis)] + [ConstantMap(8, Display.PointSize)] + makeColorMap(color),
      renderer=DirectManipulationRendererJ3D(), zlayer=0.1)

  def _makeTextField(self, callback, borderColor, value):
    """Return a JTextField showing ``value``, bordered in ``borderColor``, firing ``callback`` on Enter."""
    field = JTextField(actionPerformed=callback)
    field.setBorder(LineBorder(borderColor, 2))
    field.setText('%8.3f' % value)
    return field

  def doAction(self):
    """Echo both handle values into the text fields and apply them as the map range."""
    if self.setup != 1:
      return
    lval = self.leftSelectRef.getData().getValue()
    self.leftSelectTextField.setText('%8.3f' % lval)
    rval = self.rghtSelectRef.getData().getValue()
    self.rghtSelectTextField.setText('%8.3f' % rval)
    self.scalarMap.setRange(lval, rval)

  def leftSelect(self, event):
    """Move the left handle to the value typed in its text field."""
    self._applyText(self.leftSelectTextField, self.leftSelectRef)

  def rghtSelect(self, event):
    """Move the right handle to the value typed in its text field."""
    self._applyText(self.rghtSelectTextField, self.rghtSelectRef)

  def _applyText(self, textField, ref):
    """Parse ``textField`` and store it in ``ref``; on bad input restore the field's text."""
    try:
      val = float(textField.getText())
    except ValueError:
      # Non-numeric input: show the handle's current value again rather
      # than letting the exception escape the Swing action handler.
      textField.setText('%8.3f' % ref.getData().getValue())
      return
    realType = ref.getData().getType()
    ref.setData(Real(realType, val))
      
      
class dragHistLocationColorFunction(dragHistLocation):
  """Histogram range selector that also rebuilds a piecewise color function.

  Besides setting ``scalarMap``'s range to the handle positions (as
  dragHistLocation does), every handle move rebuilds the color function
  of ``scalarMap.getControl()``. The active color table is stretched over
  normalized positions 0..1, corresponding to the [left, right] window
  via (v - left)/(right - left). Five fixed colors are then appended at
  the normalized positions of vminmaxbig[2..6].
  """

  # RGB colors assigned, in order, to the values vminmaxbig[2] .. vminmaxbig[6].
  _SPECIAL_COLORS = ((0.3, 0.0, 0.3), (0.5, 0.5, 0.0), (0.0, 0.5, 0.5),
                     (1.0, 0.0, 0.0), (0.0, 0.0, 1.0))

  def __init__(self, scalarMap, img, vminmaxbig, colorTable, colorscaleselect=None):
    """Create the selector.

    vminmaxbig       -- indexable of at least 7 numbers: [0]/[1] are the
                        color table's value range (vmin, vmax); [2]..[6]
                        are values given the fixed _SPECIAL_COLORS.
                        (Meaning of [2]..[6] not visible here -- confirm.)
    colorTable       -- [R, G, B] sequences; used when colorscaleselect is None.
    colorscaleselect -- optional object whose ``currentTable`` attribute
                        supplies the [R, G, B] table to use.
    """
    # Assign these BEFORE the base constructor: it registers the cell via
    # addReference() and sets self.setup = 1, after which doAction() may
    # run (asynchronously) and needs these attributes.
    self.vminmaxbig = vminmaxbig
    self.ct = colorTable
    self.colorscaleselect = colorscaleselect
    dragHistLocation.__init__(self, scalarMap, img)

  def _currentTable(self):
    """Return the [R, G, B] table to use: the selector's current one, else colorTable."""
    if self.colorscaleselect is None:
      return self.ct
    return self.colorscaleselect.currentTable

  def _colorFunctionSamples(self, lval, rval):
    """Return (Gridded1DSet of knot positions, [R, G, B] sample lists) for [lval, rval].

    Knots, in order:
      -10000.0                    -> first table color (below-range fill)
      ct_len evenly spaced 0..1   -> the color table itself
      (vmax-lval)/(rval-lval)     -> last table color, only if rval <= vmax
      (vminmaxbig[2..6]-lval)/(rval-lval) -> _SPECIAL_COLORS
    NOTE(review): Gridded1DSet presumably needs strictly increasing knots;
    that depends on the vminmaxbig values supplied by the caller.
    """
    ct = self._currentTable()
    ct_len = len(ct[0])
    vminmaxbig = self.vminmaxbig
    vmin = vminmaxbig[0]
    vmax = vminmaxbig[1]
    # float() prevents Python 2 integer division if vmin/vmax are ints.
    vdiff = float(vmax - vmin)
    vinc = vdiff/(ct_len - 1)
    step = vinc/(vmax - vmin)   # knot spacing of the color table on [0, 1]
    width = rval - lval
    clampTop = rval <= vmax

    nsmpls = ct_len + 6
    if clampTop:
      nsmpls = nsmpls + 1

    fsmpls = jarray.zeros(nsmpls, 'f')
    fsmpls[0] = -10000.0
    fsmpls[1] = 0.0
    for i in xrange(ct_len - 1):
      fsmpls[i+2] = fsmpls[i+1] + step
    k = ct_len + 1
    if clampTop:
      fsmpls[k] = (vmax - lval)/width
      # When rval == vmax this knot equals 1.0, the table's last knot;
      # nudge it up so the knots stay distinct.
      fsmpls[k] *= 1.0001
      k = k + 1
    for j in xrange(2, 7):
      fsmpls[k] = (vminmaxbig[j] - lval)/width
      k = k + 1
    clr_set = Gridded1DSet(RealType.Generic, [fsmpls], nsmpls)

    new_ct = [[], [], []]
    for c in xrange(3):
      row = new_ct[c]
      src = ct[c]
      row.append(src[0])
      for i in xrange(ct_len):
        row.append(src[i])
      if clampTop:
        row.append(src[ct_len-1])
      for rgb in self._SPECIAL_COLORS:
        row.append(rgb[c])
    return clr_set, new_ct

  def doAction(self):
    """Echo the handle values, rebuild the color function and apply the range."""
    if self.setup != 1:
      return
    lval = self.leftSelectRef.getData().getValue()
    self.leftSelectTextField.setText('%8.3f' % lval)
    rval = self.rghtSelectRef.getData().getValue()
    self.rghtSelectTextField.setText('%8.3f' % rval)

    if rval == lval:
      # Zero-width window: the knot positions would divide by zero.
      # Leave the current color function until the handles separate.
      return

    cc = self.scalarMap.getControl()
    clr_set, new_ct = self._colorFunctionSamples(lval, rval)
    # NOTE(review): RealType is not imported explicitly in the visible
    # imports; presumably it arrives via one of the star imports.
    clr_fnc = FlatField(FunctionType(RealType.Generic, RealTupleType([makeType("cR"), makeType("cG"), makeType("cB")])), clr_set)
    clr_fnc.setSamples(new_ct)

    self.display.disableAction()
    try:
      cc.setFunction(clr_fnc)
      self.scalarMap.setRange(lval, rval)
    finally:
      # Re-enable even if setFunction/setRange raise, so the display
      # does not stay frozen.
      self.display.enableAction()
      