diff -r b758351d191f -r cc9b7e14412b web/lib/django/db/backends/oracle/base.py --- a/web/lib/django/db/backends/oracle/base.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/db/backends/oracle/base.py Tue May 25 02:43:45 2010 +0200 @@ -4,13 +4,12 @@ Requires cx_Oracle: http://cx-oracle.sourceforge.net/ """ + +import datetime import os -import datetime +import sys import time -try: - from decimal import Decimal -except ImportError: - from django.utils._decimal import Decimal +from decimal import Decimal # Oracle takes client-side character set encoding from the environment. os.environ['NLS_LANG'] = '.UTF8' @@ -24,9 +23,9 @@ from django.core.exceptions import ImproperlyConfigured raise ImproperlyConfigured("Error loading cx_Oracle module: %s" % e) +from django.db import utils from django.db.backends import * from django.db.backends.signals import connection_created -from django.db.backends.oracle import query from django.db.backends.oracle.client import DatabaseClient from django.db.backends.oracle.creation import DatabaseCreation from django.db.backends.oracle.introspection import DatabaseIntrospection @@ -47,13 +46,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): empty_fetchmany_value = () needs_datetime_string_cast = False - uses_custom_query_class = True interprets_empty_strings_as_nulls = True uses_savepoints = True can_return_id_from_insert = True + allow_sliced_subqueries = False class DatabaseOperations(BaseDatabaseOperations): + compiler_module = "django.db.backends.oracle.compiler" def autoinc_sql(self, table, column): # To simulate auto-incrementing primary keys in Oracle, we have to @@ -102,6 +102,54 @@ sql = "TRUNC(%s, '%s')" % (field_name, lookup_type) return sql + def convert_values(self, value, field): + if isinstance(value, Database.LOB): + value = value.read() + if field and field.get_internal_type() == 'TextField': + value = force_unicode(value) + + # Oracle stores empty strings as null. We need to undo this in + # order to adhere to the Django convention of using the empty + # string instead of null, but only if the field accepts the + # empty string. + if value is None and field and field.empty_strings_allowed: + value = u'' + # Convert 1 or 0 to True or False + elif value in (1, 0) and field and field.get_internal_type() in ('BooleanField', 'NullBooleanField'): + value = bool(value) + # Force floats to the correct type + elif value is not None and field and field.get_internal_type() == 'FloatField': + value = float(value) + # Convert floats to decimals + elif value is not None and field and field.get_internal_type() == 'DecimalField': + value = util.typecast_decimal(field.format_number(value)) + # cx_Oracle always returns datetime.datetime objects for + # DATE and TIMESTAMP columns, but Django wants to see a + # python datetime.date, .time, or .datetime. We use the type + # of the Field to determine which to cast to, but it's not + # always available. + # As a workaround, we cast to date if all the time-related + # values are 0, or to time if the date is 1/1/1900. + # This could be cleaned a bit by adding a method to the Field + # classes to normalize values from the database (the to_python + # method is used for validation and isn't what we want here). + elif isinstance(value, Database.Timestamp): + # In Python 2.3, the cx_Oracle driver returns its own + # Timestamp object that we must convert to a datetime class. + if not isinstance(value, datetime.datetime): + value = datetime.datetime(value.year, value.month, + value.day, value.hour, value.minute, value.second, + value.fsecond) + if field and field.get_internal_type() == 'DateTimeField': + pass + elif field and field.get_internal_type() == 'DateField': + value = value.date() + elif field and field.get_internal_type() == 'TimeField' or (value.year == 1900 and value.month == value.day == 1): + value = value.time() + elif value.hour == value.minute == value.second == value.microsecond == 0: + value = value.date() + return value + def datetime_cast_sql(self): return "TO_TIMESTAMP(%s, 'YYYY-MM-DD HH24:MI:SS.FF')" @@ -141,9 +189,6 @@ return u'' return force_unicode(value.read()) - def query_class(self, DefaultQueryClass): - return query.query_class(DefaultQueryClass, Database) - def quote_name(self, name): # SQL92 requires delimited (quoted) names to be case-sensitive. When # not quoted, Oracle has case-insensitive behavior for identifiers, but @@ -270,16 +315,16 @@ operators = { 'exact': '= %s', 'iexact': '= UPPER(%s)', - 'contains': "LIKEC %s ESCAPE '\\'", - 'icontains': "LIKEC UPPER(%s) ESCAPE '\\'", + 'contains': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + 'icontains': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", 'gt': '> %s', 'gte': '>= %s', 'lt': '< %s', 'lte': '<= %s', - 'startswith': "LIKEC %s ESCAPE '\\'", - 'endswith': "LIKEC %s ESCAPE '\\'", - 'istartswith': "LIKEC UPPER(%s) ESCAPE '\\'", - 'iendswith': "LIKEC UPPER(%s) ESCAPE '\\'", + 'startswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + 'endswith': "LIKE TRANSLATE(%s USING NCHAR_CS) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + 'istartswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", + 'iendswith': "LIKE UPPER(TRANSLATE(%s USING NCHAR_CS)) ESCAPE TRANSLATE('\\' USING NCHAR_CS)", } oracle_version = None @@ -291,29 +336,29 @@ self.client = DatabaseClient(self) self.creation = DatabaseCreation(self) self.introspection = DatabaseIntrospection(self) - self.validation = BaseDatabaseValidation() + self.validation = BaseDatabaseValidation(self) def _valid_connection(self): return self.connection is not None def _connect_string(self): settings_dict = self.settings_dict - if len(settings_dict['DATABASE_HOST'].strip()) == 0: - settings_dict['DATABASE_HOST'] = 'localhost' - if len(settings_dict['DATABASE_PORT'].strip()) != 0: - dsn = Database.makedsn(settings_dict['DATABASE_HOST'], - int(settings_dict['DATABASE_PORT']), - settings_dict['DATABASE_NAME']) + if len(settings_dict['HOST'].strip()) == 0: + settings_dict['HOST'] = 'localhost' + if len(settings_dict['PORT'].strip()) != 0: + dsn = Database.makedsn(settings_dict['HOST'], + int(settings_dict['PORT']), + settings_dict['NAME']) else: - dsn = settings_dict['DATABASE_NAME'] - return "%s/%s@%s" % (settings_dict['DATABASE_USER'], - settings_dict['DATABASE_PASSWORD'], dsn) + dsn = settings_dict['NAME'] + return "%s/%s@%s" % (settings_dict['USER'], + settings_dict['PASSWORD'], dsn) def _cursor(self): cursor = None if not self._valid_connection(): conn_string = convert_unicode(self._connect_string()) - self.connection = Database.connect(conn_string, **self.settings_dict['DATABASE_OPTIONS']) + self.connection = Database.connect(conn_string, **self.settings_dict['OPTIONS']) cursor = FormatStylePlaceholderCursor(self.connection) # Set oracle date to ansi date format. This only needs to execute # once when we create a new connection. We also set the Territory @@ -375,6 +420,30 @@ self.input_size = None +class VariableWrapper(object): + """ + An adapter class for cursor variables that prevents the wrapped object + from being converted into a string when used to instanciate an OracleParam. + This can be used generally for any other object that should be passed into + Cursor.execute as-is. + """ + + def __init__(self, var): + self.var = var + + def bind_parameter(self, cursor): + return self.var + + def __getattr__(self, key): + return getattr(self.var, key) + + def __setattr__(self, key, value): + if key == 'var': + self.__dict__[key] = value + else: + setattr(self.var, key, value) + + class InsertIdVar(object): """ A late-binding cursor variable that can be passed to Cursor.execute @@ -383,7 +452,7 @@ """ def bind_parameter(self, cursor): - param = cursor.var(Database.NUMBER) + param = cursor.cursor.var(Database.NUMBER) cursor._insert_id_var = param return param @@ -436,11 +505,13 @@ self._guess_input_sizes([params]) try: return self.cursor.execute(query, self._param_generator(params)) - except DatabaseError, e: + except Database.IntegrityError, e: + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + except Database.DatabaseError, e: # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400. - if e.args[0].code == 1400 and not isinstance(e, IntegrityError): - e = IntegrityError(e.args[0]) - raise e + if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError): + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2] def executemany(self, query, params=None): try: @@ -460,67 +531,35 @@ try: return self.cursor.executemany(query, [self._param_generator(p) for p in formatted]) - except DatabaseError, e: + except Database.IntegrityError, e: + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + except Database.DatabaseError, e: # cx_Oracle <= 4.4.0 wrongly raises a DatabaseError for ORA-01400. - if e.args[0].code == 1400 and not isinstance(e, IntegrityError): - e = IntegrityError(e.args[0]) - raise e + if hasattr(e.args[0], 'code') and e.args[0].code == 1400 and not isinstance(e, IntegrityError): + raise utils.IntegrityError, utils.IntegrityError(*tuple(e)), sys.exc_info()[2] + raise utils.DatabaseError, utils.DatabaseError(*tuple(e)), sys.exc_info()[2] def fetchone(self): row = self.cursor.fetchone() if row is None: return row - return self._rowfactory(row) + return _rowfactory(row, self.cursor) def fetchmany(self, size=None): if size is None: size = self.arraysize - return tuple([self._rowfactory(r) + return tuple([_rowfactory(r, self.cursor) for r in self.cursor.fetchmany(size)]) def fetchall(self): - return tuple([self._rowfactory(r) + return tuple([_rowfactory(r, self.cursor) for r in self.cursor.fetchall()]) - def _rowfactory(self, row): - # Cast numeric values as the appropriate Python type based upon the - # cursor description, and convert strings to unicode. - casted = [] - for value, desc in zip(row, self.cursor.description): - if value is not None and desc[1] is Database.NUMBER: - precision, scale = desc[4:6] - if scale == -127: - if precision == 0: - # NUMBER column: decimal-precision floating point - # This will normally be an integer from a sequence, - # but it could be a decimal value. - if '.' in value: - value = Decimal(value) - else: - value = int(value) - else: - # FLOAT column: binary-precision floating point. - # This comes from FloatField columns. - value = float(value) - elif precision > 0: - # NUMBER(p,s) column: decimal-precision fixed point. - # This comes from IntField and DecimalField columns. - if scale == 0: - value = int(value) - else: - value = Decimal(value) - elif '.' in value: - # No type information. This normally comes from a - # mathematical expression in the SELECT list. Guess int - # or Decimal based on whether it has a decimal point. - value = Decimal(value) - else: - value = int(value) - elif desc[1] in (Database.STRING, Database.FIXED_CHAR, - Database.LONG_STRING): - value = to_unicode(value) - casted.append(value) - return tuple(casted) + def var(self, *args): + return VariableWrapper(self.cursor.var(*args)) + + def arrayvar(self, *args): + return VariableWrapper(self.cursor.arrayvar(*args)) def __getattr__(self, attr): if attr in self.__dict__: @@ -529,7 +568,63 @@ return getattr(self.cursor, attr) def __iter__(self): - return iter(self.cursor) + return CursorIterator(self.cursor) + + +class CursorIterator(object): + + """Cursor iterator wrapper that invokes our custom row factory.""" + + def __init__(self, cursor): + self.cursor = cursor + self.iter = iter(cursor) + + def __iter__(self): + return self + + def next(self): + return _rowfactory(self.iter.next(), self.cursor) + + +def _rowfactory(row, cursor): + # Cast numeric values as the appropriate Python type based upon the + # cursor description, and convert strings to unicode. + casted = [] + for value, desc in zip(row, cursor.description): + if value is not None and desc[1] is Database.NUMBER: + precision, scale = desc[4:6] + if scale == -127: + if precision == 0: + # NUMBER column: decimal-precision floating point + # This will normally be an integer from a sequence, + # but it could be a decimal value. + if '.' in value: + value = Decimal(value) + else: + value = int(value) + else: + # FLOAT column: binary-precision floating point. + # This comes from FloatField columns. + value = float(value) + elif precision > 0: + # NUMBER(p,s) column: decimal-precision fixed point. + # This comes from IntField and DecimalField columns. + if scale == 0: + value = int(value) + else: + value = Decimal(value) + elif '.' in value: + # No type information. This normally comes from a + # mathematical expression in the SELECT list. Guess int + # or Decimal based on whether it has a decimal point. + value = Decimal(value) + else: + value = int(value) + elif desc[1] in (Database.STRING, Database.FIXED_CHAR, + Database.LONG_STRING): + value = to_unicode(value) + casted.append(value) + return tuple(casted) def to_unicode(s):