import jarray
import ucar
from java.net import URL
# netcdfhelper -- generic helper class for reading NetCDF files (local files,
# remote DODS datasets, or URLs) using the ucar.nc2 version 2 API.
class netcdfhelper:
    """
    Convenience wrapper around a ucar.nc2 NetCDF file handle.

    The opened file is kept in self.file; the other methods are thin
    accessors over it.
    """

    def __init__(self, fname):
        """
        Open a NetCDF dataset and store the handle in self.file.

        fname is a local file name, a DODS dataset location, or a URL.
        The openers are tried in that order. If all three fail, the
        exception raised by the final (URL) attempt propagates.
        """
        # NOTE(review): the bare excepts are kept on purpose. Under Jython the
        # Java exceptions thrown by ucar.nc2 (e.g. java.io.IOException) may not
        # be subclasses of Python's Exception, so "except Exception" could skip
        # the fallbacks. Confirm on the target Jython version before narrowing.
        try:
            self.file = ucar.nc2.NetcdfFile(fname)
        except:
            print("local file: %s not found" % (fname,))
            print("trying to open as a remote DODS file...")
            try:
                self.file = ucar.nc2.dods.DODSNetcdfFile(fname)
            except:
                print("dods file: %s not found" % (fname,))
                print("trying to open as URL: %s" % (fname,))
                self.file = ucar.nc2.NetcdfFile(URL(fname))

    def close(self):
        """Close the underlying NetCDF file."""
        self.file.close()

    def _iterToList(self, iterator):
        """Drain a Java-style iterator (hasNext()/next()) into a Python list."""
        items = []
        while iterator.hasNext():
            items.append(iterator.next())
        return items

    def getDimensions(self):
        """
        Return a dictionary mapping each dimension name in the file to its
        length. The result is also stored in self.dim.
        """
        dims = {}
        for dim in self._iterToList(self.file.getDimensionIterator()):
            dims[dim.getName()] = dim.getLength()
        self.dim = dims
        return dims

    def getAttributes(self):
        """
        Return a list of the file's global attributes.
        The result is also stored in self.att.
        """
        self.att = self._iterToList(self.file.getGlobalAttributeIterator())
        return self.att

    def getAttribute(self, name):
        """Return the global attribute called `name` (as found by the file)."""
        return self.file.findGlobalAttribute(name)

    def getVariable(self, name):
        """Return the variable called `name` (as found by the file)."""
        return self.file.findVariable(name)

    def getVariables(self):
        """
        Return a list of all variables in the file.
        The result is also stored in self.var.
        """
        self.var = self._iterToList(self.file.getVariableIterator())
        return self.var

    def getFloat(self, variable, start):
        """
        Return the single element of `variable` at index `start` as a float.

        start holds one index per dimension of the variable, so it may contain
        more than one index. A section of extent 1 along every dimension is
        read at that position, and its only value is returned.
        """
        ones = [1] * len(start)
        values = variable.read(start, ones).copyTo1DJavaArray()
        return float(values[0])

    def getValues(self, variable, start=None, size=None, stride=None):
        """
        Read all or part of `variable` and return it as a flat 1-D Java array.

        start  -- first index along each dimension
        size   -- number of elements to read along each dimension
        stride -- optional step along each dimension

        If stride is given, start/size/stride are passed to variable.read
        unchanged. If neither start nor size is given, the whole variable is
        read. If only size is given, start defaults to the origin (all zeros).
        Giving a non-empty start without size raises TypeError.
        """
        if stride is not None:
            return variable.read(start, size, stride).copyTo1DJavaArray()
        if start is None and size is None:
            return variable.read().copyTo1DJavaArray()
        if start is None:
            start = [0] * len(size)
        if size is None:
            if len(start) > 0:
                raise TypeError("getValues: size is required when start is given")
            size = []

        # Dimensions read with extent 1 are sliced away (reducing the rank).
        # They are listed highest-first: slicing a higher dimension first
        # leaves the indices of the lower dimensions unchanged.
        sliced = [i for i in range(len(start) - 1, -1, -1) if size[i] == 1]
        # The remaining dimensions form a rectangular section of the result.
        kept = [i for i in range(len(start)) if size[i] != 1]

        view = variable
        for dim in sliced:
            view = ucar.ma2.MultiArraySlice(view, dim, start[dim])
        view = ucar.ma2.MultiArraySection(view,
                                          [start[i] for i in kept],
                                          [size[i] for i in kept])
        return view.read().copyTo1DJavaArray()
