web/lib/django/db/backends/postgresql_psycopg2/base.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/db/backends/postgresql_psycopg2/base.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/db/backends/postgresql_psycopg2/base.py	Tue May 25 02:43:45 2010 +0200
@@ -4,7 +4,9 @@
 Requires psycopg 2: http://initd.org/projects/psycopg2
 """
 
-from django.conf import settings
+import sys
+
+from django.db import utils
 from django.db.backends import *
 from django.db.backends.signals import connection_created
 from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations
@@ -28,6 +30,40 @@
 psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
 psycopg2.extensions.register_adapter(SafeUnicode, psycopg2.extensions.QuotedString)
 
+class CursorWrapper(object):
+    """
+    A thin wrapper around psycopg2's normal cursor class so that we can catch
+    particular exception instances and reraise them with the right types.
+    """
+
+    def __init__(self, cursor):
+        self.cursor = cursor
+
+    def execute(self, query, args=None):
+        try:
+            return self.cursor.execute(query, args)
+        except Database.IntegrityError, e:
+            raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+        except Database.DatabaseError, e:
+            raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2]
+
+    def executemany(self, query, args):
+        try:
+            return self.cursor.executemany(query, args)
+        except Database.IntegrityError, e:
+            raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+        except Database.DatabaseError, e:
+            raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2]
+
+    def __getattr__(self, attr):
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        else:
+            return getattr(self.cursor, attr)
+
+    def __iter__(self):
+        return iter(self.cursor)
+
 class DatabaseFeatures(BaseDatabaseFeatures):
     needs_datetime_string_cast = False
     can_return_id_from_insert = False
@@ -64,45 +100,48 @@
         super(DatabaseWrapper, self).__init__(*args, **kwargs)
 
         self.features = DatabaseFeatures()
-        autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False)
+        autocommit = self.settings_dict["OPTIONS"].get('autocommit', False)
         self.features.uses_autocommit = autocommit
         self._set_isolation_level(int(not autocommit))
-        self.ops = DatabaseOperations()
+        self.ops = DatabaseOperations(self)
         self.client = DatabaseClient(self)
         self.creation = DatabaseCreation(self)
         self.introspection = DatabaseIntrospection(self)
-        self.validation = BaseDatabaseValidation()
+        self.validation = BaseDatabaseValidation(self)
 
     def _cursor(self):
+        new_connection = False
         set_tz = False
         settings_dict = self.settings_dict
         if self.connection is None:
-            set_tz = True
-            if settings_dict['DATABASE_NAME'] == '':
+            new_connection = True
+            set_tz = settings_dict.get('TIME_ZONE')
+            if settings_dict['NAME'] == '':
                 from django.core.exceptions import ImproperlyConfigured
-                raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.")
+                raise ImproperlyConfigured("You need to specify NAME in your Django settings file.")
             conn_params = {
-                'database': settings_dict['DATABASE_NAME'],
+                'database': settings_dict['NAME'],
             }
-            conn_params.update(settings_dict['DATABASE_OPTIONS'])
+            conn_params.update(settings_dict['OPTIONS'])
             if 'autocommit' in conn_params:
                 del conn_params['autocommit']
-            if settings_dict['DATABASE_USER']:
-                conn_params['user'] = settings_dict['DATABASE_USER']
-            if settings_dict['DATABASE_PASSWORD']:
-                conn_params['password'] = settings_dict['DATABASE_PASSWORD']
-            if settings_dict['DATABASE_HOST']:
-                conn_params['host'] = settings_dict['DATABASE_HOST']
-            if settings_dict['DATABASE_PORT']:
-                conn_params['port'] = settings_dict['DATABASE_PORT']
+            if settings_dict['USER']:
+                conn_params['user'] = settings_dict['USER']
+            if settings_dict['PASSWORD']:
+                conn_params['password'] = settings_dict['PASSWORD']
+            if settings_dict['HOST']:
+                conn_params['host'] = settings_dict['HOST']
+            if settings_dict['PORT']:
+                conn_params['port'] = settings_dict['PORT']
             self.connection = Database.connect(**conn_params)
             self.connection.set_client_encoding('UTF8')
             self.connection.set_isolation_level(self.isolation_level)
             connection_created.send(sender=self.__class__)
         cursor = self.connection.cursor()
         cursor.tzinfo_factory = None
-        if set_tz:
-            cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']])
+        if new_connection:
+            if set_tz:
+                cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']])
             if not hasattr(self, '_version'):
                 self.__class__._version = get_version(cursor)
             if self._version[0:2] < (8, 0):
@@ -119,7 +158,7 @@
                     # versions that support it, but, right now, that's hard to
                     # do without breaking other things (#10509).
                     self.features.can_return_id_from_insert = True
-        return cursor
+        return CursorWrapper(cursor)
 
     def _enter_transaction_management(self, managed):
         """
@@ -150,4 +189,3 @@
         finally:
             self.isolation_level = level
             self.features.uses_savepoints = bool(level)
-