web/lib/django/db/backends/oracle/base.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/db/backends/oracle/base.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/db/backends/oracle/base.py	Tue May 25 02:43:45 2010 +0200
@@ -4,13 +4,12 @@
 Requires cx_Oracle: http://cx-oracle.sourceforge.net/
 """
 
+
+import datetime
 import os
-import datetime
+import sys
 import time
-try:
-    from decimal import Decimal
-except ImportError:
-    from django.utils._decimal import Decimal
+from decimal import Decimal
 
 # Oracle takes client-side character set encoding from the environment.
 os.environ['NLS_LANG'] = '.UTF8'
@@ -24,9 +23,9 @@
     from django.core.exceptions import ImproperlyConfigured
     raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e)
 
+from django.db import utils
 from django.db.backends import *
 from django.db.backends.signals import connection_created
-from django.db.backends.oracle import query
 from django.db.backends.oracle.client import DatabaseClient
 from django.db.backends.oracle.creation import DatabaseCreation
 from django.db.backends.oracle.introspection import DatabaseIntrospection
@@ -47,13 +46,14 @@
 class DatabaseFeatures(BaseDatabaseFeatures):
     empty_fetchmany_value = ()
     needs_datetime_string_cast = False
-    uses_custom_query_class = True
     interprets_empty_strings_as_nulls = True
     uses_savepoints = True
     can_return_id_from_insert = True
+    allow_sliced_subqueries = False
 
 
 class DatabaseOperations(BaseDatabaseOperations):
+    compiler_module = "django.db.backends.oracle.compiler"
 
     def autoinc_sql(self, table, column):
         # To simulate auto-incrementing primary keys in Oracle, we have to
@@ -102,6 +102,54 @@
             sql = "TRUNC(%s, '%s')" % (field_name, lookup_type)
         return sql
 
+    def convert_values(self, value, field):
+        if isinstance(value, Database.LOB):
+            value = value.read()
+            if field and field.get_internal_type() == 'TextField':
+                value = force_unicode(value)
+
+        # Oracle stores empty strings as null. We need to undo this in
+        # order to adhere to the Django convention of using the empty
+        # string instead of null, but only if the field accepts the
+        # empty string.
+        if value is None and field and field.empty_strings_allowed:
+            value = u''
+        # Convert 1 or 0 to True or False
+        elif value in (1, 0) and field and field.get_internal_type() in ('BooleanField', 'NullBooleanField'):
+            value = bool(value)
+        # Force floats to the correct type
+        elif value is not None and field and field.get_internal_type() == 'FloatField':
+            value = float(value)
+        # Convert floats to decimals
+        elif value is not None and field and field.get_internal_type() == 'DecimalField':
+            value = util.typecast_decimal(field.format_number(value))
+        # cx_Oracle always returns datetime.datetime objects for
+        # DATE and TIMESTAMP columns, but Django wants to see a
+        # python datetime.date, .time, or .datetime.  We use the type
+        # of the Field to determine which to cast to, but it's not
+        # always available.
+        # As a workaround, we cast to date if all the time-related
+        # values are 0, or to time if the date is 1/1/1900.
+        # This could be cleaned a bit by adding a method to the Field
+        # classes to normalize values from the database (the to_python
+        # method is used for validation and isn't what we want here).
+        elif isinstance(value, Database.Timestamp):
+            # In Python 2.3, the cx_Oracle driver returns its own
+            # Timestamp object that we must convert to a datetime class.
+            if not isinstance(value, datetime.datetime):
+                value = datetime.datetime(value.year, value.month,
+                        value.day, value.hour, value.minute, value.second,
+                        value.fsecond)
+            if field and field.get_internal_type() == 'DateTimeField':
+                pass
+            elif field and field.get_internal_type() == 'DateField':
+                value = value.date()
+            elif field and field.get_internal_type() == 'TimeField' or (value.year == 1900 and value.month == value.day == 1):
+                value = value.time()
+            elif value.hour == value.minute == value.second == value.microsecond == 0:
+                value = value.date()
+        return value
+
     def datetime_cast_sql(self):
         return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')"
 
@@ -141,9 +189,6 @@
             return u''
         return force_unicode(value.read())
 
-    def query_class(self, DefaultQueryClass):
-        return query.query_class(DefaultQueryClass, Database)
-
     def quote_name(self, name):
         # SQL92 requires delimited (quoted) names to be case-sensitive.  When
         # not quoted, Oracle has case-insensitive behavior for identifiers, but
@@ -270,16 +315,16 @@
     operators = {
         'exact': '= %s',
         'iexact': '= UPPER(%s)',
-        'contains': "LIKEC %s ESCAPE '\\'",
-        'icontains': "LIKEC UPPER(%s) ESCAPE '\\'",
+        'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+        'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
         'gt': '> %s',
         'gte': '>= %s',
         'lt': '< %s',
         'lte': '<= %s',
-        'startswith': "LIKEC %s ESCAPE '\\'",
-        'endswith': "LIKEC %s ESCAPE '\\'",
-        'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'",
-        'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'",
+        'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+        'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+        'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
+        'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)",
     }
     oracle_version = None
 
@@ -291,29 +336,29 @@
         self.client = DatabaseClient(self)
         self.creation = DatabaseCreation(self)
         self.introspection = DatabaseIntrospection(self)
-        self.validation = BaseDatabaseValidation()
+        self.validation = BaseDatabaseValidation(self)
 
     def _valid_connection(self):
         return self.connection is not None
 
     def _connect_string(self):
         settings_dict = self.settings_dict
-        if len(settings_dict['DATABASE_HOST'].strip()) == 0:
-            settings_dict['DATABASE_HOST'] = 'localhost'
-        if len(settings_dict['DATABASE_PORT'].strip()) != 0:
-            dsn = Database.makedsn(settings_dict['DATABASE_HOST'],
-                                   int(settings_dict['DATABASE_PORT']),
-                                   settings_dict['DATABASE_NAME'])
+        if len(settings_dict['HOST'].strip()) == 0:
+            settings_dict['HOST'] = 'localhost'
+        if len(settings_dict['PORT'].strip()) != 0:
+            dsn = Database.makedsn(settings_dict['HOST'],
+                                   int(settings_dict['PORT']),
+                                   settings_dict['NAME'])
         else:
-            dsn = settings_dict['DATABASE_NAME']
-        return "%s/%s@%s" % (settings_dict['DATABASE_USER'],
-                             settings_dict['DATABASE_PASSWORD'], dsn)
+            dsn = settings_dict['NAME']
+        return "%s/%s@%s" % (settings_dict['USER'],
+                             settings_dict['PASSWORD'], dsn)
 
     def _cursor(self):
         cursor = None
         if not self._valid_connection():
             conn_string = convert_unicode(self._connect_string())
-            self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS'])
+            self.connection = Database.connect(conn_string, **self.settings_dict['OPTIONS'])
             cursor = FormatStylePlaceholderCursor(self.connection)
             # Set oracle date to ansi date format.  This only needs to execute
             # once when we create a new connection. We also set the Territory
@@ -375,6 +420,30 @@
             self.input_size = None
 
 
+class VariableWrapper(object):
+    """
+    An adapter class for cursor variables that prevents the wrapped object
+    from being converted into a string when used to instanciate an OracleParam.
+    This can be used generally for any other object that should be passed into
+    Cursor.execute as-is.
+    """
+
+    def __init__(self, var):
+        self.var = var
+
+    def bind_parameter(self, cursor):
+        return self.var
+
+    def __getattr__(self, key):
+        return getattr(self.var, key)
+
+    def __setattr__(self, key, value):
+        if key == 'var':
+            self.__dict__[key] = value
+        else:
+            setattr(self.var, key, value)
+
+
 class InsertIdVar(object):
     """
     A late-binding cursor variable that can be passed to Cursor.execute
@@ -383,7 +452,7 @@
     """
 
     def bind_parameter(self, cursor):
-        param = cursor.var(Database.NUMBER)
+        param = cursor.cursor.var(Database.NUMBER)
         cursor._insert_id_var = param
         return param
 
@@ -436,11 +505,13 @@
         self._guess_input_sizes([params])
         try:
             return self.cursor.execute(query, self._param_generator(params))
-        except DatabaseError, e:
+        except Database.IntegrityError, e:
+            raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+        except Database.DatabaseError, e:
             # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
-            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
-                e = IntegrityError(e.args[0])
-            raise e
+            if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError):
+                raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+            raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2]
 
     def executemany(self, query, params=None):
         try:
@@ -460,67 +531,35 @@
         try:
             return self.cursor.executemany(query,
                                 [self._param_generator(p) for p in formatted])
-        except DatabaseError, e:
+        except Database.IntegrityError, e:
+            raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+        except Database.DatabaseError, e:
             # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400.
-            if e.args[0].code == 1400 and not isinstance(e, IntegrityError):
-                e = IntegrityError(e.args[0])
-            raise e
+            if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError):
+                raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2]
+            raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2]
 
     def fetchone(self):
         row = self.cursor.fetchone()
         if row is None:
             return row
-        return self._rowfactory(row)
+        return _rowfactory(row, self.cursor)
 
     def fetchmany(self, size=None):
         if size is None:
             size = self.arraysize
-        return tuple([self._rowfactory(r)
+        return tuple([_rowfactory(r, self.cursor)
                       for r in self.cursor.fetchmany(size)])
 
     def fetchall(self):
-        return tuple([self._rowfactory(r)
+        return tuple([_rowfactory(r, self.cursor)
                       for r in self.cursor.fetchall()])
 
-    def _rowfactory(self, row):
-        # Cast numeric values as the appropriate Python type based upon the
-        # cursor description, and convert strings to unicode.
-        casted = []
-        for value, desc in zip(row, self.cursor.description):
-            if value is not None and desc[1] is Database.NUMBER:
-                precision, scale = desc[4:6]
-                if scale == -127:
-                    if precision == 0:
-                        # NUMBER column: decimal-precision floating point
-                        # This will normally be an integer from a sequence,
-                        # but it could be a decimal value.
-                        if '.' in value:
-                            value = Decimal(value)
-                        else:
-                            value = int(value)
-                    else:
-                        # FLOAT column: binary-precision floating point.
-                        # This comes from FloatField columns.
-                        value = float(value)
-                elif precision > 0:
-                    # NUMBER(p,s) column: decimal-precision fixed point.
-                    # This comes from IntField and DecimalField columns.
-                    if scale == 0:
-                        value = int(value)
-                    else:
-                        value = Decimal(value)
-                elif '.' in value:
-                    # No type information. This normally comes from a
-                    # mathematical expression in the SELECT list. Guess int
-                    # or Decimal based on whether it has a decimal point.
-                    value = Decimal(value)
-                else:
-                    value = int(value)
-            elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
-                             Database.LONG_STRING):
-                value = to_unicode(value)
-            casted.append(value)
-        return tuple(casted)
+    def var(self, *args):
+        return VariableWrapper(self.cursor.var(*args))
+
+    def arrayvar(self, *args):
+        return VariableWrapper(self.cursor.arrayvar(*args))
 
     def __getattr__(self, attr):
         if attr in self.__dict__:
@@ -529,7 +568,63 @@
             return getattr(self.cursor, attr)
 
     def __iter__(self):
-        return iter(self.cursor)
+        return CursorIterator(self.cursor)
+
+
+class CursorIterator(object):
+
+    """Cursor iterator wrapper that invokes our custom row factory."""
+
+    def __init__(self, cursor):
+        self.cursor = cursor
+        self.iter = iter(cursor)
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        return _rowfactory(self.iter.next(), self.cursor)
+
+
+def _rowfactory(row, cursor):
+    # Cast numeric values as the appropriate Python type based upon the
+    # cursor description, and convert strings to unicode.
+    casted = []
+    for value, desc in zip(row, cursor.description):
+        if value is not None and desc[1] is Database.NUMBER:
+            precision, scale = desc[4:6]
+            if scale == -127:
+                if precision == 0:
+                    # NUMBER column: decimal-precision floating point
+                    # This will normally be an integer from a sequence,
+                    # but it could be a decimal value.
+                    if '.' in value:
+                        value = Decimal(value)
+                    else:
+                        value = int(value)
+                else:
+                    # FLOAT column: binary-precision floating point.
+                    # This comes from FloatField columns.
+                    value = float(value)
+            elif precision > 0:
+                # NUMBER(p,s) column: decimal-precision fixed point.
+                # This comes from IntField and DecimalField columns.
+                if scale == 0:
+                    value = int(value)
+                else:
+                    value = Decimal(value)
+            elif '.' in value:
+                # No type information. This normally comes from a
+                # mathematical expression in the SELECT list. Guess int
+                # or Decimal based on whether it has a decimal point.
+                value = Decimal(value)
+            else:
+                value = int(value)
+        elif desc[1] in (Database.STRING, Database.FIXED_CHAR,
+                         Database.LONG_STRING):
+            value = to_unicode(value)
+        casted.append(value)
+    return tuple(casted)
 
 
 def to_unicode(s):