from visad.georef import MapProjection
from visad.data.hdfeos import PolarStereographic, LambertAzimuthalEqualArea
from visad import RealTupleType, Data, CoordinateSystem, Display, Gridded2DSet, CommonUnit
from util import Helper
from java.awt.geom import Rectangle2D
import jarray

# Integer stand-ins for boolean literals (e.g. passed to getSamples()).
false, true = 0, 1

# Row positions of latitude and longitude inside a set's getSamples() array.
lat_idx, lon_idx = 0, 1

  
def _inValidRange(samples, idx, valid_range):
  """Return true if sample ``idx`` lies inside ``valid_range`` (bounds inclusive).

  ``valid_range`` is ``[[lonMin, lonMax], [latMin, latMax]]``. NaN samples
  compare false and are therefore treated as invalid.
  """
  lon = samples[lon_idx][idx]
  lat = samples[lat_idx][idx]
  return (valid_range[0][0] <= lon <= valid_range[0][1] and
          valid_range[1][0] <= lat <= valid_range[1][1])


def getValidCorners(set, valid_range):
  """Find the outermost in-range samples along the left and right grid edges.

  The grid is ``lens[0]`` samples wide and ``lens[1]`` rows tall, stored
  row-major. The four corners are found as follows:
    a -- first column, first valid sample scanning down from the top row
    b -- last column,  first valid sample scanning down from the top row
    c -- first column, first valid sample scanning up from the bottom row
    d -- last column,  first valid sample scanning up from the bottom row

  Returns ``[[lon_a, lon_b, lon_c, lon_d], [lat_a, lat_b, lat_c, lat_d]]``.

  Raises ValueError if any of the four edges has no sample inside
  ``valid_range`` (including an empty grid).
  """
  samples = set.getSamples(false)
  lens    = set.getLengths()
  nx = lens[0]
  ny = lens[1]

  a = None
  b = None
  c = None
  d = None
  for j in range(ny):
    top    = j * nx              # first sample of row j from the top
    bottom = (ny - 1 - j) * nx   # first sample of row j from the bottom
    if a is None and _inValidRange(samples, top, valid_range):
      a = top
    if b is None and _inValidRange(samples, top + nx - 1, valid_range):
      b = top + nx - 1
    if c is None and _inValidRange(samples, bottom, valid_range):
      c = bottom
    if d is None and _inValidRange(samples, bottom + nx - 1, valid_range):
      d = bottom + nx - 1
    # All four corners settled; further rows cannot change them.
    if a is not None and b is not None and c is not None and d is not None:
      break

  if a is None or b is None or c is None or d is None:
    raise ValueError("getValidCorners: no sample inside valid_range on one or more grid edges")

  return [[samples[lon_idx][a], samples[lon_idx][b], samples[lon_idx][c], samples[lon_idx][d]],
          [samples[lat_idx][a], samples[lat_idx][b], samples[lat_idx][c], samples[lat_idx][d]]]
      


class LambertEA(MapProjection):
  """MapProjection wrapping VisAD's LambertAzimuthalEqualArea transform.

  The projection is centred on the midpoint of the longitude/latitude extent
  of ``corners``. The default map area is the bounding box of those corners
  in projected coordinates.
  """

  def __init__(self, corners, earthRadius=6367470, false_easting=0, false_northing=0):
    """Build the projection around a set of lon/lat points.

    corners         -- ``[[lon, ...], [lat, ...]]`` in degrees (the layout
                       getValidCorners returns).
    earthRadius     -- sphere radius forwarded to LambertAzimuthalEqualArea
                       (presumably metres -- confirm against the VisAD docs).
    false_easting,
    false_northing  -- offsets forwarded to LambertAzimuthalEqualArea.
                       Previously these were accepted but replaced by 0.
    """
    MapProjection.__init__(self, RealTupleType.SpatialEarth2DTuple, None)

    minLon = min(corners[0])
    minLat = min(corners[1])
    maxLon = max(corners[0])
    maxLat = max(corners[1])

    # Divide by 2.0 so integer corner values are not floor-divided under
    # Python 2 division semantics.
    londiff   = maxLon - minLon
    lonCenter = minLon + londiff/2.0
    if londiff > 180:
      # NOTE(review): presumably a span this wide means the points straddle
      # the antimeridian, so the true centre is on the opposite meridian.
      lonCenter += 180
    latCenter = minLat + (maxLat - minLat)/2.0

    self.cs = LambertAzimuthalEqualArea(self.getReference(),
                                        earthRadius,
                                        lonCenter*Data.DEGREES_TO_RADIANS,
                                        latCenter*Data.DEGREES_TO_RADIANS,
                                        false_easting, false_northing)

    # Default map area: bounding box of the corners in projected space.
    cs_corners = self.cs.fromReference(corners)

    min_x     = min(cs_corners[0])
    min_y     = min(cs_corners[1])
    max_x     = max(cs_corners[0])
    max_y     = max(cs_corners[1])
    self.rect = Rectangle2D.Float(min_x, min_y, (max_x - min_x), (max_y - min_y))

  def getDefaultMapArea(self):
    """Return the projected bounding box computed in the constructor."""
    return self.rect

  def fromReference(self, values):
    """Project reference (lon, lat) values via the wrapped transform."""
    return self.cs.fromReference(values)

  def toReference(self, values):
    """Invert projected values back to reference coordinates via the wrapped transform."""
    return self.cs.toReference(values)
    
    
class MapProjectionAdapter(CoordinateSystem):
  """Expose a MapProjection as a CoordinateSystem over the display cube.

  The reference is Display.DisplaySpatialCartesianTuple. This system's first
  two coordinates (units: degrees) are handed directly to the projection
  (lon/lat for LambertEA). The projection's default map area is mapped
  linearly onto [-1, 1] in display x and y. The third component passes
  through unchanged.
  """

  def __init__(self, mapProjection):
    self.mapProjection = mapProjection
    CoordinateSystem.__init__(self, Display.DisplaySpatialCartesianTuple, [CommonUnit.degree, CommonUnit.degree, None])

    # Half-extents become scale factors and the area centre becomes the
    # offset, so the map-area edges land on -1 and +1.
    area  = mapProjection.getDefaultMapArea()
    halfW = area.getWidth()/2.0
    halfH = area.getHeight()/2.0
    self.scaleX  = halfW
    self.scaleY  = halfH
    self.offsetX = area.getX() + halfW
    self.offsetY = area.getY() + halfH

  def toReference(self, values):
    """Map (coord0, coord1, coord2) to normalized display (x, y, z).

    The x/y arrays returned by the projection are rescaled in place and
    returned together with values[2].
    """
    projected = self.mapProjection.fromReference([values[0], values[1]])
    xs = projected[0]
    ys = projected[1]
    for k in xrange(len(xs)):
      xs[k] = (xs[k] - self.offsetX)/self.scaleX
      ys[k] = (ys[k] - self.offsetY)/self.scaleY
    return [xs, ys, values[2]]

  def fromReference(self, values):
    """Map normalized display (x, y, z) back to (coord0, coord1, coord2).

    NOTE: values[0] and values[1] are overwritten in place with the
    de-normalized projection coordinates before they are inverted.
    """
    xs = values[0]
    ys = values[1]
    for k in xrange(len(xs)):
      xs[k] = xs[k]*self.scaleX + self.offsetX
      ys[k] = ys[k]*self.scaleY + self.offsetY
    unprojected = self.mapProjection.toReference([xs, ys])
    return [unprojected[0], unprojected[1], values[2]]
