web/lib/django/contrib/gis/gdal/tests/test_ds.py
changeset 38 77b6da96e6f1
equal deleted inserted replaced
37:8d941af65caf 38:77b6da96e6f1
       
     1 import os, os.path, unittest
       
     2 from django.contrib.gis.gdal import DataSource, Envelope, OGRGeometry, OGRException, OGRIndexError
       
     3 from django.contrib.gis.gdal.field import OFTReal, OFTInteger, OFTString
       
     4 from django.contrib import gis
       
     5 
       
     6 # Path for SHP files
       
     7 data_path = os.path.join(os.path.dirname(gis.__file__), 'tests' + os.sep + 'data')
       
     8 def get_ds_file(name, ext):
       
     9     return os.sep.join([data_path, name, name + '.%s' % ext])
       
    10 
       
    11 # Test SHP data source object
       
    12 class TestDS:
       
    13     def __init__(self, name, **kwargs):
       
    14         ext = kwargs.pop('ext', 'shp')
       
    15         self.ds = get_ds_file(name, ext)
       
    16         for key, value in kwargs.items():
       
    17             setattr(self, key, value)
       
    18 
       
    19 # List of acceptable data sources.
       
    20 ds_list = (TestDS('test_point', nfeat=5, nfld=3, geom='POINT', gtype=1, driver='ESRI Shapefile',
       
    21                   fields={'dbl' : OFTReal, 'int' : OFTInteger, 'str' : OFTString,},
       
    22                   extent=(-1.35011,0.166623,-0.524093,0.824508), # Got extent from QGIS
       
    23                   srs_wkt='GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]',
       
    24                   field_values={'dbl' : [float(i) for i in range(1, 6)], 'int' : range(1, 6), 'str' : [str(i) for i in range(1, 6)]},
       
    25                   fids=range(5)),
       
    26            TestDS('test_vrt', ext='vrt', nfeat=3, nfld=3, geom='POINT', gtype='Point25D', driver='VRT',
       
    27                   fields={'POINT_X' : OFTString, 'POINT_Y' : OFTString, 'NUM' : OFTString}, # VRT uses CSV, which all types are OFTString.
       
    28                   extent=(1.0, 2.0, 100.0, 523.5), # Min/Max from CSV
       
    29                   field_values={'POINT_X' : ['1.0', '5.0', '100.0'], 'POINT_Y' : ['2.0', '23.0', '523.5'], 'NUM' : ['5', '17', '23']},
       
    30                   fids=range(1,4)),
       
    31            TestDS('test_poly', nfeat=3, nfld=3, geom='POLYGON', gtype=3, 
       
    32                   driver='ESRI Shapefile',
       
    33                   fields={'float' : OFTReal, 'int' : OFTInteger, 'str' : OFTString,},
       
    34                   extent=(-1.01513,-0.558245,0.161876,0.839637), # Got extent from QGIS
       
    35                   srs_wkt='GEOGCS["GCS_WGS_1984",DATUM["WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]'),
       
    36            )
       
    37 
       
    38 bad_ds = (TestDS('foo'),
       
    39           )
       
    40 
       
    41 class DataSourceTest(unittest.TestCase):
       
    42 
       
    43     def test01_valid_shp(self):
       
    44         "Testing valid SHP Data Source files."
       
    45 
       
    46         for source in ds_list:
       
    47             # Loading up the data source
       
    48             ds = DataSource(source.ds)
       
    49 
       
    50             # Making sure the layer count is what's expected (only 1 layer in a SHP file)
       
    51             self.assertEqual(1, len(ds))
       
    52 
       
    53             # Making sure GetName works
       
    54             self.assertEqual(source.ds, ds.name)
       
    55 
       
    56             # Making sure the driver name matches up
       
    57             self.assertEqual(source.driver, str(ds.driver))
       
    58 
       
    59             # Making sure indexing works
       
    60             try:
       
    61                 ds[len(ds)]
       
    62             except OGRIndexError:
       
    63                 pass
       
    64             else:
       
    65                 self.fail('Expected an IndexError!')
       
    66                         
       
    67     def test02_invalid_shp(self):
       
    68         "Testing invalid SHP files for the Data Source."
       
    69         for source in bad_ds:
       
    70             self.assertRaises(OGRException, DataSource, source.ds)
       
    71 
       
    72     def test03a_layers(self):
       
    73         "Testing Data Source Layers."
       
    74         print "\nBEGIN - expecting out of range feature id error; safe to ignore.\n"
       
    75         for source in ds_list:
       
    76             ds = DataSource(source.ds)
       
    77 
       
    78             # Incrementing through each layer, this tests DataSource.__iter__
       
    79             for layer in ds:                
       
    80                 # Making sure we get the number of features we expect
       
    81                 self.assertEqual(len(layer), source.nfeat)
       
    82 
       
    83                 # Making sure we get the number of fields we expect
       
    84                 self.assertEqual(source.nfld, layer.num_fields)
       
    85                 self.assertEqual(source.nfld, len(layer.fields))
       
    86 
       
    87                 # Testing the layer's extent (an Envelope), and it's properties
       
    88                 self.assertEqual(True, isinstance(layer.extent, Envelope))
       
    89                 self.assertAlmostEqual(source.extent[0], layer.extent.min_x, 5)
       
    90                 self.assertAlmostEqual(source.extent[1], layer.extent.min_y, 5)
       
    91                 self.assertAlmostEqual(source.extent[2], layer.extent.max_x, 5)
       
    92                 self.assertAlmostEqual(source.extent[3], layer.extent.max_y, 5)
       
    93 
       
    94                 # Now checking the field names.
       
    95                 flds = layer.fields
       
    96                 for f in flds: self.assertEqual(True, f in source.fields)
       
    97                 
       
    98                 # Negative FIDs are not allowed.
       
    99                 self.assertRaises(OGRIndexError, layer.__getitem__, -1)
       
   100                 self.assertRaises(OGRIndexError, layer.__getitem__, 50000)
       
   101 
       
   102                 if hasattr(source, 'field_values'):
       
   103                     fld_names = source.field_values.keys()
       
   104 
       
   105                     # Testing `Layer.get_fields` (which uses Layer.__iter__)
       
   106                     for fld_name in fld_names:
       
   107                         self.assertEqual(source.field_values[fld_name], layer.get_fields(fld_name))
       
   108 
       
   109                     # Testing `Layer.__getitem__`.
       
   110                     for i, fid in enumerate(source.fids):
       
   111                         feat = layer[fid]
       
   112                         self.assertEqual(fid, feat.fid)
       
   113                         # Maybe this should be in the test below, but we might as well test
       
   114                         # the feature values here while in this loop.
       
   115                         for fld_name in fld_names:
       
   116                             self.assertEqual(source.field_values[fld_name][i], feat.get(fld_name))
       
   117         print "\nEND - expecting out of range feature id error; safe to ignore."
       
   118                         
       
   119     def test03b_layer_slice(self):
       
   120         "Test indexing and slicing on Layers."
       
   121         # Using the first data-source because the same slice
       
   122         # can be used for both the layer and the control values.
       
   123         source = ds_list[0]
       
   124         ds = DataSource(source.ds)
       
   125 
       
   126         sl = slice(1, 3)
       
   127         feats = ds[0][sl]
       
   128 
       
   129         for fld_name in ds[0].fields:
       
   130             test_vals = [feat.get(fld_name) for feat in feats]
       
   131             control_vals = source.field_values[fld_name][sl]
       
   132             self.assertEqual(control_vals, test_vals)
       
   133 
       
   134     def test03c_layer_references(self):
       
   135         "Test to make sure Layer access is still available without the DataSource."
       
   136         source = ds_list[0]
       
   137 
       
   138         # See ticket #9448.
       
   139         def get_layer():
       
   140             # This DataSource object is not accessible outside this
       
   141             # scope.  However, a reference should still be kept alive
       
   142             # on the `Layer` returned.
       
   143             ds = DataSource(source.ds)
       
   144             return ds[0]
       
   145 
       
   146         # Making sure we can call OGR routines on the Layer returned.
       
   147         lyr = get_layer()
       
   148         self.assertEqual(source.nfeat, len(lyr))
       
   149         self.assertEqual(source.gtype, lyr.geom_type.num)        
       
   150 
       
   151     def test04_features(self):
       
   152         "Testing Data Source Features."
       
   153         for source in ds_list:
       
   154             ds = DataSource(source.ds)
       
   155 
       
   156             # Incrementing through each layer
       
   157             for layer in ds:
       
   158                 # Incrementing through each feature in the layer
       
   159                 for feat in layer:
       
   160                     # Making sure the number of fields, and the geometry type
       
   161                     # are what's expected.
       
   162                     self.assertEqual(source.nfld, len(list(feat)))
       
   163                     self.assertEqual(source.gtype, feat.geom_type)
       
   164 
       
   165                     # Making sure the fields match to an appropriate OFT type.
       
   166                     for k, v in source.fields.items():
       
   167                         # Making sure we get the proper OGR Field instance, using
       
   168                         # a string value index for the feature.
       
   169                         self.assertEqual(True, isinstance(feat[k], v))
       
   170 
       
   171                     # Testing Feature.__iter__
       
   172                     for fld in feat: self.assertEqual(True, fld.name in source.fields.keys())
       
   173                         
       
   174     def test05_geometries(self):
       
   175         "Testing Geometries from Data Source Features."
       
   176         for source in ds_list:
       
   177             ds = DataSource(source.ds)
       
   178 
       
   179             # Incrementing through each layer and feature.
       
   180             for layer in ds:
       
   181                 for feat in layer:
       
   182                     g = feat.geom
       
   183 
       
   184                     # Making sure we get the right Geometry name & type
       
   185                     self.assertEqual(source.geom, g.geom_name)
       
   186                     self.assertEqual(source.gtype, g.geom_type)
       
   187 
       
   188                     # Making sure the SpatialReference is as expected.
       
   189                     if hasattr(source, 'srs_wkt'):
       
   190                         self.assertEqual(source.srs_wkt, g.srs.wkt)
       
   191 
       
   192     def test06_spatial_filter(self):
       
   193         "Testing the Layer.spatial_filter property."
       
   194         ds = DataSource(get_ds_file('cities', 'shp'))
       
   195         lyr = ds[0]
       
   196 
       
   197         # When not set, it should be None.
       
   198         self.assertEqual(None, lyr.spatial_filter)
       
   199 
       
   200         # Must be set a/an OGRGeometry or 4-tuple.
       
   201         self.assertRaises(TypeError, lyr._set_spatial_filter, 'foo')
       
   202 
       
   203         # Setting the spatial filter with a tuple/list with the extent of
       
   204         # a buffer centering around Pueblo.
       
   205         self.assertRaises(ValueError, lyr._set_spatial_filter, range(5))
       
   206         filter_extent = (-105.609252, 37.255001, -103.609252, 39.255001)
       
   207         lyr.spatial_filter = (-105.609252, 37.255001, -103.609252, 39.255001)
       
   208         self.assertEqual(OGRGeometry.from_bbox(filter_extent), lyr.spatial_filter)
       
   209         feats = [feat for feat in lyr]
       
   210         self.assertEqual(1, len(feats))
       
   211         self.assertEqual('Pueblo', feats[0].get('Name'))
       
   212 
       
   213         # Setting the spatial filter with an OGRGeometry for buffer centering
       
   214         # around Houston.
       
   215         filter_geom = OGRGeometry('POLYGON((-96.363151 28.763374,-94.363151 28.763374,-94.363151 30.763374,-96.363151 30.763374,-96.363151 28.763374))')
       
   216         lyr.spatial_filter = filter_geom
       
   217         self.assertEqual(filter_geom, lyr.spatial_filter)
       
   218         feats = [feat for feat in lyr]
       
   219         self.assertEqual(1, len(feats))
       
   220         self.assertEqual('Houston', feats[0].get('Name'))
       
   221 
       
   222         # Clearing the spatial filter by setting it to None.  Now
       
   223         # should indicate that there are 3 features in the Layer.
       
   224         lyr.spatial_filter = None
       
   225         self.assertEqual(3, len(lyr))
       
   226         
       
   227 def suite():
       
   228     s = unittest.TestSuite()
       
   229     s.addTest(unittest.makeSuite(DataSourceTest))
       
   230     return s
       
   231 
       
   232 def run(verbosity=2):
       
   233     unittest.TextTestRunner(verbosity=verbosity).run(suite())