diff -r b758351d191f -r cc9b7e14412b web/lib/django/db/backends/postgresql_psycopg2/base.py --- a/web/lib/django/db/backends/postgresql_psycopg2/base.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/db/backends/postgresql_psycopg2/base.py Tue May 25 02:43:45 2010 +0200 @@ -4,7 +4,9 @@ Requires psycopg 2: http://initd.org/projects/psycopg2 """ -from django.conf import settings +import sys + +from django.db import utils from django.db.backends import * from django.db.backends.signals import connection_created from django.db.backends.postgresql.operations import DatabaseOperations as PostgresqlDatabaseOperations @@ -28,6 +30,40 @@ psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString) psycopg2.extensions.register_adapter(SafeUnicode, psycopg2.extensions.QuotedString) +class CursorWrapper(object): + """ + A thin wrapper around psycopg2's normal cursor class so that we can catch + particular exception instances and reraise them with the right types. + """ + + def __init__(self, cursor): + self.cursor = cursor + + def execute(self, query, args=None): + try: + return self.cursor.execute(query, args) + except Database.IntegrityError, e: + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + except Database.DatabaseError, e: + raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2] + + def executemany(self, query, args): + try: + return self.cursor.executemany(query, args) + except Database.IntegrityError, e: + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + except Database.DatabaseError, e: + raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2] + + def __getattr__(self, attr): + if attr in self.__dict__: + return self.__dict__[attr] + else: + return getattr(self.cursor, attr) + + def __iter__(self): + return iter(self.cursor) + class DatabaseFeatures(BaseDatabaseFeatures): needs_datetime_string_cast = False can_return_id_from_insert = False @@ -64,45 +100,48 @@ super(DatabaseWrapper, self).__init__(*args, **kwargs) self.features = DatabaseFeatures() - autocommit = self.settings_dict["DATABASE_OPTIONS"].get('autocommit', False) + autocommit = self.settings_dict["OPTIONS"].get('autocommit', False) self.features.uses_autocommit = autocommit self._set_isolation_level(int(not autocommit)) - self.ops = DatabaseOperations() + self.ops = DatabaseOperations(self) self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation() + self.validation = BaseDatabaseValidation(self) def _cursor(self): + new_connection = False set_tz = False settings_dict = self.settings_dict if self.connection is None: - set_tz = True - if settings_dict['DATABASE_NAME'] == '': + new_connection = True + set_tz = settings_dict.get('TIME_ZONE') + if settings_dict['NAME'] == '': from django.core.exceptions import ImproperlyConfigured - raise ImproperlyConfigured("You need to specify DATABASE_NAME in your Django settings file.") + raise ImproperlyConfigured("You need to specify NAME in your Django settings file.") conn_params = { - 'database': settings_dict['DATABASE_NAME'], + 'database': settings_dict['NAME'], } - conn_params.update(settings_dict['DATABASE_OPTIONS']) + conn_params.update(settings_dict['OPTIONS']) if 'autocommit' in conn_params: del conn_params['autocommit'] - if settings_dict['DATABASE_USER']: - conn_params['user'] = settings_dict['DATABASE_USER'] - if settings_dict['DATABASE_PASSWORD']: - conn_params['password'] = settings_dict['DATABASE_PASSWORD'] - if settings_dict['DATABASE_HOST']: - conn_params['host'] = settings_dict['DATABASE_HOST'] - if settings_dict['DATABASE_PORT']: - conn_params['port'] = settings_dict['DATABASE_PORT'] + if settings_dict['USER']: + conn_params['user'] = settings_dict['USER'] + if settings_dict['PASSWORD']: + conn_params['password'] = settings_dict['PASSWORD'] + if settings_dict['HOST']: + conn_params['host'] = settings_dict['HOST'] + if settings_dict['PORT']: + conn_params['port'] = settings_dict['PORT'] self.connection = Database.connect(**conn_params) self.connection.set_client_encoding('UTF8') self.connection.set_isolation_level(self.isolation_level) connection_created.send(sender=self.__class__) cursor = self.connection.cursor() cursor.tzinfo_factory = None - if set_tz: - cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']]) + if new_connection: + if set_tz: + cursor.execute("SET TIME ZONE %s", [settings_dict['TIME_ZONE']]) if not hasattr(self, '_version'): self.__class__._version = get_version(cursor) if self._version[0:2] < (8, 0): @@ -119,7 +158,7 @@ # versions that support it, but, right now, that's hard to # do without breaking other things (#10509). self.features.can_return_id_from_insert = True - return cursor + return CursorWrapper(cursor) def _enter_transaction_management(self, managed): """ @@ -150,4 +189,3 @@ finally: self.isolation_level = level self.features.uses_savepoints = bool(level) -