web/lib/django/contrib/gis/utils/layermapping.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/contrib/gis/utils/layermapping.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/contrib/gis/utils/layermapping.py	Tue May 25 02:43:45 2010 +0200
@@ -3,115 +3,15 @@
  The LayerMapping class provides a way to map the contents of OGR
  vector files (e.g. SHP files) to Geographic-enabled Django models.
 
- This grew out of my personal needs, specifically the code repetition
- that went into pulling geometries and fields out of an OGR layer,
- converting to another coordinate system (e.g. WGS84), and then inserting
- into a GeoDjango model.
-
- Please report any bugs encountered using this utility.
-
- Requirements:  OGR C Library (from GDAL) required.
-
- Usage: 
-  lm = LayerMapping(model, source_file, mapping) where,
-
-  model:
-   GeoDjango model (not an instance)
-
-  data:
-   OGR-supported data source file (e.g. a shapefile) or
-    gdal.DataSource instance
-
-  mapping:
-   A python dictionary, keys are strings corresponding
-   to the GeoDjango model field, and values correspond to
-   string field names for the OGR feature, or if the model field
-   is a geographic then it should correspond to the OGR
-   geometry type, e.g. 'POINT', 'LINESTRING', 'POLYGON'.
-
- Keyword Args:
-  layer:
-   The index of the layer to use from the Data Source (defaults to 0)
-
-  source_srs:
-   Use this to specify the source SRS manually (for example, 
-   some shapefiles don't come with a '.prj' file).  An integer SRID,
-   a string WKT, and SpatialReference objects are valid parameters.
-
-  encoding:
-   Specifies the encoding of the string in the OGR data source.
-   For example, 'latin-1', 'utf-8', and 'cp437' are all valid
-   encoding parameters.
-
-  transaction_mode:
-   May be 'commit_on_success' (default) or 'autocommit'.
-
-  transform:
-   Setting this to False will disable all coordinate transformations.  
-
-  unique:
-   Setting this to the name, or a tuple of names, from the given
-   model will create models unique only to the given name(s).
-   Geometries will from each feature will be added into the collection
-   associated with the unique model.  Forces transaction mode to
-   be 'autocommit'.
-
-Example:
-
- 1. You need a GDAL-supported data source, like a shapefile.
-
-  Assume we're using the test_poly SHP file:
-  >>> from django.contrib.gis.gdal import DataSource
-  >>> ds = DataSource('test_poly.shp')
-  >>> layer = ds[0]
-  >>> print layer.fields # Exploring the fields in the layer, we only want the 'str' field.
-  ['float', 'int', 'str']
-  >>> print len(layer) # getting the number of features in the layer (should be 3)
-  3
-  >>> print layer.geom_type # Should be 3 (a Polygon)
-  3
-  >>> print layer.srs # WGS84
-  GEOGCS["GCS_WGS_1984",
-      DATUM["WGS_1984",
-          SPHEROID["WGS_1984",6378137,298.257223563]],
-      PRIMEM["Greenwich",0],
-      UNIT["Degree",0.017453292519943295]]
-
- 2. Now we define our corresponding Django model (make sure to use syncdb):
-
-  from django.contrib.gis.db import models
-  class TestGeo(models.Model, models.GeoMixin):
-      name = models.CharField(maxlength=25) # corresponds to the 'str' field
-      poly = models.PolygonField(srid=4269) # we want our model in a different SRID
-      objects = models.GeoManager()
-      def __str__(self):
-          return 'Name: %s' % self.name
-
- 3. Use LayerMapping to extract all the features and place them in the database:
-
-  >>> from django.contrib.gis.utils import LayerMapping
-  >>> from geoapp.models import TestGeo
-  >>> mapping = {'name' : 'str', # The 'name' model field maps to the 'str' layer field.
-                 'poly' : 'POLYGON', # For geometry fields use OGC name.
-                 } # The mapping is a dictionary
-  >>> lm = LayerMapping(TestGeo, 'test_poly.shp', mapping) 
-  >>> lm.save(verbose=True) # Save the layermap, imports the data. 
-  Saved: Name: 1
-  Saved: Name: 2
-  Saved: Name: 3
-
- LayerMapping just transformed the three geometries from the SHP file from their
- source spatial reference system (WGS84) to the spatial reference system of
- the GeoDjango model (NAD83).  If no spatial reference system is defined for
- the layer, use the `source_srs` keyword with a SpatialReference object to
- specify one.
+ For more information, please consult the GeoDjango documentation:
+   http://geodjango.org/docs/layermapping.html
 """
 import sys
 from datetime import date, datetime
 from decimal import Decimal
 from django.core.exceptions import ObjectDoesNotExist
+from django.db import connections, DEFAULT_DB_ALIAS
 from django.contrib.gis.db.models import GeometryField
-from django.contrib.gis.db.backend import SpatialBackend
 from django.contrib.gis.gdal import CoordTransform, DataSource, \
     OGRException, OGRGeometry, OGRGeomType, SpatialReference
 from django.contrib.gis.gdal.field import \
@@ -128,11 +28,14 @@
 
 class LayerMapping(object):
     "A class that maps OGR Layers to GeoDjango Models."
-    
+
     # Acceptable 'base' types for a multi-geometry type.
     MULTI_TYPES = {1 : OGRGeomType('MultiPoint'),
                    2 : OGRGeomType('MultiLineString'),
                    3 : OGRGeomType('MultiPolygon'),
+                   OGRGeomType('Point25D').num : OGRGeomType('MultiPoint25D'),
+                   OGRGeomType('LineString25D').num : OGRGeomType('MultiLineString25D'),
+                   OGRGeomType('Polygon25D').num : OGRGeomType('MultiPolygon25D'),
                    }
 
     # Acceptable Django field types and corresponding acceptable OGR
@@ -161,10 +64,10 @@
                          'commit_on_success' : transaction.commit_on_success,
                          }
 
-    def __init__(self, model, data, mapping, layer=0, 
+    def __init__(self, model, data, mapping, layer=0,
                  source_srs=None, encoding=None,
-                 transaction_mode='commit_on_success', 
-                 transform=True, unique=None):
+                 transaction_mode='commit_on_success',
+                 transform=True, unique=None, using=DEFAULT_DB_ALIAS):
         """
         A LayerMapping object is initialized using the given Model (not an instance),
         a DataSource (or string path to an OGR-supported data file), and a mapping
@@ -178,20 +81,23 @@
             self.ds = data
         self.layer = self.ds[layer]
 
+        self.using = using
+        self.spatial_backend = connections[using].ops
+
         # Setting the mapping & model attributes.
         self.mapping = mapping
         self.model = model
- 
+
         # Checking the layer -- intitialization of the object will fail if
         # things don't check out before hand.
         self.check_layer()
 
-        # Getting the geometry column associated with the model (an 
+        # Getting the geometry column associated with the model (an
         # exception will be raised if there is no geometry column).
-        if SpatialBackend.mysql:
+        if self.spatial_backend.mysql:
             transform = False
         else:
-            self.geo_col = self.geometry_column()
+            self.geo_field = self.geometry_field()
 
         # Checking the source spatial reference system, and getting
         # the coordinate transformation object (unless the `transform`
@@ -219,14 +125,17 @@
         else:
             self.unique = None
 
-        # Setting the transaction decorator with the function in the 
+        # Setting the transaction decorator with the function in the
         # transaction modes dictionary.
         if transaction_mode in self.TRANSACTION_MODES:
             self.transaction_decorator = self.TRANSACTION_MODES[transaction_mode]
             self.transaction_mode = transaction_mode
         else:
             raise LayerMapError('Unrecognized transaction mode: %s' % transaction_mode)
-    
+
+        if using is None:
+            pass
+
     #### Checking routines used during initialization ####
     def check_fid_range(self, fid_range):
         "This checks the `fid_range` keyword."
@@ -282,30 +191,40 @@
                 if self.geom_field:
                     raise LayerMapError('LayerMapping does not support more than one GeometryField per model.')
 
+                # Getting the coordinate dimension of the geometry field.
+                coord_dim = model_field.dim
+
                 try:
-                    gtype = OGRGeomType(ogr_name)
+                    if coord_dim == 3:
+                        gtype = OGRGeomType(ogr_name + '25D')
+                    else:
+                        gtype = OGRGeomType(ogr_name)
                 except OGRException:
                     raise LayerMapError('Invalid mapping for GeometryField "%s".' % field_name)
 
                 # Making sure that the OGR Layer's Geometry is compatible.
                 ltype = self.layer.geom_type
-                if not (gtype == ltype or self.make_multi(ltype, model_field)):
-                    raise LayerMapError('Invalid mapping geometry; model has %s, feature has %s.' % (fld_name, gtype))
+                if not (ltype.name.startswith(gtype.name) or self.make_multi(ltype, model_field)):
+                    raise LayerMapError('Invalid mapping geometry; model has %s%s, '
+                                        'layer geometry type is %s.' %
+                                        (fld_name, (coord_dim == 3 and '(dim=3)') or '', ltype))
 
                 # Setting the `geom_field` attribute w/the name of the model field
-                # that is a Geometry.
+                # that is a Geometry.  Also setting the coordinate dimension
+                # attribute.
                 self.geom_field = field_name
+                self.coord_dim = coord_dim
                 fields_val = model_field
             elif isinstance(model_field, models.ForeignKey):
                 if isinstance(ogr_name, dict):
                     # Is every given related model mapping field in the Layer?
                     rel_model = model_field.rel.to
-                    for rel_name, ogr_field in ogr_name.items(): 
+                    for rel_name, ogr_field in ogr_name.items():
                         idx = check_ogr_fld(ogr_field)
                         try:
                             rel_field = rel_model._meta.get_field(rel_name)
                         except models.fields.FieldDoesNotExist:
-                            raise LayerMapError('ForeignKey mapping field "%s" not in %s fields.' % 
+                            raise LayerMapError('ForeignKey mapping field "%s" not in %s fields.' %
                                                 (rel_name, rel_model.__class__.__name__))
                     fields_val = rel_model
                 else:
@@ -321,25 +240,25 @@
 
                 # Can the OGR field type be mapped to the Django field type?
                 if not issubclass(ogr_field, self.FIELD_TYPES[model_field.__class__]):
-                    raise LayerMapError('OGR field "%s" (of type %s) cannot be mapped to Django %s.' % 
+                    raise LayerMapError('OGR field "%s" (of type %s) cannot be mapped to Django %s.' %
                                         (ogr_field, ogr_field.__name__, fld_name))
                 fields_val = model_field
-        
+
             self.fields[field_name] = fields_val
 
     def check_srs(self, source_srs):
         "Checks the compatibility of the given spatial reference object."
-        from django.contrib.gis.models import SpatialRefSys
+
         if isinstance(source_srs, SpatialReference):
             sr = source_srs
-        elif isinstance(source_srs, SpatialRefSys):
+        elif isinstance(source_srs, self.spatial_backend.spatial_ref_sys()):
             sr = source_srs.srs
         elif isinstance(source_srs, (int, basestring)):
             sr = SpatialReference(source_srs)
         else:
             # Otherwise just pulling the SpatialReference from the layer
             sr = self.layer.srs
-        
+
         if not sr:
             raise LayerMapError('No source reference system defined.')
         else:
@@ -349,7 +268,7 @@
         "Checks the `unique` keyword parameter -- may be a sequence or string."
         if isinstance(unique, (list, tuple)):
             # List of fields to determine uniqueness with
-            for attr in unique: 
+            for attr in unique:
                 if not attr in self.mapping: raise ValueError
         elif isinstance(unique, basestring):
             # Only a single field passed in.
@@ -370,7 +289,7 @@
         # dictionary mapping.
         for field_name, ogr_name in self.mapping.items():
             model_field = self.fields[field_name]
-            
+
             if isinstance(model_field, GeometryField):
                 # Verify OGR geometry.
                 val = self.verify_geom(feat.geom, model_field)
@@ -385,7 +304,7 @@
             # Setting the keyword arguments for the field name with the
             # value obtained above.
             kwargs[field_name] = val
-            
+
         return kwargs
 
     def unique_kwargs(self, kwargs):
@@ -403,11 +322,11 @@
     def verify_ogr_field(self, ogr_field, model_field):
         """
         Verifies if the OGR Field contents are acceptable to the Django
-        model field.  If they are, the verified value is returned, 
+        model field.  If they are, the verified value is returned,
         otherwise the proper exception is raised.
         """
-        if (isinstance(ogr_field, OFTString) and 
-            isinstance(model_field, (models.CharField, models.TextField))): 
+        if (isinstance(ogr_field, OFTString) and
+            isinstance(model_field, (models.CharField, models.TextField))):
             if self.encoding:
                 # The encoding for OGR data sources may be specified here
                 # (e.g., 'cp437' for Census Bureau boundary files).
@@ -432,14 +351,14 @@
             # Maximum amount of precision, or digits to the left of the decimal.
             max_prec = model_field.max_digits - model_field.decimal_places
 
-            # Getting the digits to the left of the decimal place for the 
+            # Getting the digits to the left of the decimal place for the
             # given decimal.
             if d_idx < 0:
                 n_prec = len(digits[:d_idx])
             else:
                 n_prec = len(digits) + d_idx
 
-            # If we have more than the maximum digits allowed, then throw an 
+            # If we have more than the maximum digits allowed, then throw an
             # InvalidDecimal exception.
             if n_prec > max_prec:
                 raise InvalidDecimal('A DecimalField with max_digits %d, decimal_places %d must round to an absolute value less than 10^%d.' %
@@ -462,7 +381,7 @@
         mapping.
         """
         # TODO: It is expensive to retrieve a model for every record --
-        #  explore if an efficient mechanism exists for caching related 
+        #  explore if an efficient mechanism exists for caching related
         #  ForeignKey models.
 
         # Constructing and verifying the related model keyword arguments.
@@ -475,13 +394,17 @@
             return rel_model.objects.get(**fk_kwargs)
         except ObjectDoesNotExist:
             raise MissingForeignKey('No ForeignKey %s model found with keyword arguments: %s' % (rel_model.__name__, fk_kwargs))
-            
+
     def verify_geom(self, geom, model_field):
         """
         Verifies the geometry -- will construct and return a GeometryCollection
         if necessary (for example if the model field is MultiPolygonField while
         the mapped shapefile only contains Polygons).
         """
+        # Downgrade a 3D geom to a 2D one, if necessary.
+        if self.coord_dim != geom.coord_dim:
+            geom.coord_dim = self.coord_dim
+
         if self.make_multi(geom.geom_type, model_field):
             # Constructing a multi-geometry type to contain the single geometry
             multi_type = self.MULTI_TYPES[geom.geom_type.num]
@@ -491,61 +414,51 @@
             g = geom
 
         # Transforming the geometry with our Coordinate Transformation object,
-        # but only if the class variable `transform` is set w/a CoordTransform 
+        # but only if the class variable `transform` is set w/a CoordTransform
         # object.
         if self.transform: g.transform(self.transform)
-        
+
         # Returning the WKT of the geometry.
         return g.wkt
 
     #### Other model methods ####
     def coord_transform(self):
         "Returns the coordinate transformation object."
-        from django.contrib.gis.models import SpatialRefSys
+        SpatialRefSys = self.spatial_backend.spatial_ref_sys()
         try:
             # Getting the target spatial reference system
-            target_srs = SpatialRefSys.objects.get(srid=self.geo_col.srid).srs
+            target_srs = SpatialRefSys.objects.get(srid=self.geo_field.srid).srs
 
             # Creating the CoordTransform object
             return CoordTransform(self.source_srs, target_srs)
         except Exception, msg:
             raise LayerMapError('Could not translate between the data source and model geometry: %s' % msg)
 
-    def geometry_column(self):
-        "Returns the GeometryColumn model associated with the geographic column."
-        from django.contrib.gis.models import GeometryColumns
-        # Getting the GeometryColumn object.
-        try:
-            db_table = self.model._meta.db_table
-            geo_col = self.geom_field
-            if SpatialBackend.oracle:
-                # Making upper case for Oracle.
-                db_table = db_table.upper()
-                geo_col = geo_col.upper()
-            gc_kwargs = {GeometryColumns.table_name_col() : db_table,
-                         GeometryColumns.geom_col_name() : geo_col,
-                         }
-            return GeometryColumns.objects.get(**gc_kwargs)
-        except Exception, msg:
-            raise LayerMapError('Geometry column does not exist for model. (did you run syncdb?):\n %s' % msg)
+    def geometry_field(self):
+        "Returns the GeometryField instance associated with the geographic column."
+        # Use the `get_field_by_name` on the model's options so that we
+        # get the correct field instance if there's model inheritance.
+        opts = self.model._meta
+        fld, model, direct, m2m = opts.get_field_by_name(self.geom_field)
+        return fld
 
     def make_multi(self, geom_type, model_field):
         """
-        Given the OGRGeomType for a geometry and its associated GeometryField, 
+        Given the OGRGeomType for a geometry and its associated GeometryField,
         determine whether the geometry should be turned into a GeometryCollection.
         """
-        return (geom_type.num in self.MULTI_TYPES and 
+        return (geom_type.num in self.MULTI_TYPES and
                 model_field.__class__.__name__ == 'Multi%s' % geom_type.django)
 
-    def save(self, verbose=False, fid_range=False, step=False, 
+    def save(self, verbose=False, fid_range=False, step=False,
              progress=False, silent=False, stream=sys.stdout, strict=False):
         """
         Saves the contents from the OGR DataSource Layer into the database
-        according to the mapping dictionary given at initialization. 
-        
+        according to the mapping dictionary given at initialization.
+
         Keyword Parameters:
          verbose:
-           If set, information will be printed subsequent to each model save 
+           If set, information will be printed subsequent to each model save
            executed on the database.
 
          fid_range:
@@ -555,32 +468,32 @@
            data source.
 
          step:
-           If set with an integer, transactions will occur at every step 
-           interval. For example, if step=1000, a commit would occur after 
+           If set with an integer, transactions will occur at every step
+           interval. For example, if step=1000, a commit would occur after
            the 1,000th feature, the 2,000th feature etc.
 
          progress:
-           When this keyword is set, status information will be printed giving 
-           the number of features processed and sucessfully saved.  By default, 
-           progress information will pe printed every 1000 features processed, 
-           however, this default may be overridden by setting this keyword with an 
+           When this keyword is set, status information will be printed giving
+           the number of features processed and sucessfully saved.  By default,
+           progress information will pe printed every 1000 features processed,
+           however, this default may be overridden by setting this keyword with an
            integer for the desired interval.
 
          stream:
-           Status information will be written to this file handle.  Defaults to 
+           Status information will be written to this file handle.  Defaults to
            using `sys.stdout`, but any object with a `write` method is supported.
 
          silent:
-           By default, non-fatal error notifications are printed to stdout, but 
+           By default, non-fatal error notifications are printed to stdout, but
            this keyword may be set to disable these notifications.
 
          strict:
-           Execution of the model mapping will cease upon the first error 
+           Execution of the model mapping will cease upon the first error
            encountered.  The default behavior is to attempt to continue.
         """
         # Getting the default Feature ID range.
         default_range = self.check_fid_range(fid_range)
-    
+
         # Setting the progress interval, if requested.
         if progress:
             if progress is True or not isinstance(progress, int):
@@ -588,7 +501,7 @@
             else:
                 progress_interval = progress
 
-        # Defining the 'real' save method, utilizing the transaction 
+        # Defining the 'real' save method, utilizing the transaction
         # decorator created during initialization.
         @self.transaction_decorator
         def _save(feat_range=default_range, num_feat=0, num_saved=0):
@@ -605,7 +518,7 @@
                 except LayerMapError, msg:
                     # Something borked the validation
                     if strict: raise
-                    elif not silent: 
+                    elif not silent:
                         stream.write('Ignoring Feature ID %s because: %s\n' % (feat.fid, msg))
                 else:
                     # Constructing the model using the keyword args
@@ -617,16 +530,16 @@
                             # Getting the keyword arguments and retrieving
                             # the unique model.
                             u_kwargs = self.unique_kwargs(kwargs)
-                            m = self.model.objects.get(**u_kwargs)
+                            m = self.model.objects.using(self.using).get(**u_kwargs)
                             is_update = True
-                                
-                            # Getting the geometry (in OGR form), creating 
-                            # one from the kwargs WKT, adding in additional 
-                            # geometries, and update the attribute with the 
+
+                            # Getting the geometry (in OGR form), creating
+                            # one from the kwargs WKT, adding in additional
+                            # geometries, and update the attribute with the
                             # just-updated geometry WKT.
                             geom = getattr(m, self.geom_field).ogr
                             new = OGRGeometry(kwargs[self.geom_field])
-                            for g in new: geom.add(g) 
+                            for g in new: geom.add(g)
                             setattr(m, self.geom_field, geom.wkt)
                         except ObjectDoesNotExist:
                             # No unique model exists yet, create.
@@ -636,7 +549,7 @@
 
                     try:
                         # Attempting to save.
-                        m.save()
+                        m.save(using=self.using)
                         num_saved += 1
                         if verbose: stream.write('%s: %s\n' % (is_update and 'Updated' or 'Saved', m))
                     except SystemExit:
@@ -646,7 +559,7 @@
                             # Rolling back the transaction so that other model saves
                             # will work.
                             transaction.rollback_unless_managed()
-                        if strict: 
+                        if strict:
                             # Bailing out if the `strict` keyword is set.
                             if not silent:
                                 stream.write('Failed to save the feature (id: %s) into the model with the keyword arguments:\n' % feat.fid)
@@ -658,15 +571,15 @@
                 # Printing progress information, if requested.
                 if progress and num_feat % progress_interval == 0:
                     stream.write('Processed %d features, saved %d ...\n' % (num_feat, num_saved))
-        
+
             # Only used for status output purposes -- incremental saving uses the
             # values returned here.
             return num_saved, num_feat
 
         nfeat = self.layer.num_feat
         if step and isinstance(step, int) and step < nfeat:
-            # Incremental saving is requested at the given interval (step) 
-            if default_range: 
+            # Incremental saving is requested at the given interval (step)
+            if default_range:
                 raise LayerMapError('The `step` keyword may not be used in conjunction with the `fid_range` keyword.')
             beg, num_feat, num_saved = (0, 0, 0)
             indices = range(step, nfeat, step)
@@ -677,7 +590,7 @@
                 # special (e.g, [100:] instead of [90:100]).
                 if i+1 == n_i: step_slice = slice(beg, None)
                 else: step_slice = slice(beg, end)
-            
+
                 try:
                     num_feat, num_saved = _save(step_slice, num_feat, num_saved)
                     beg = end