web/lib/django/db/models/sql/where.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/db/models/sql/where.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/db/models/sql/where.py	Tue May 25 02:43:45 2010 +0200
@@ -4,7 +4,6 @@
 import datetime
 
 from django.utils import tree
-from django.db import connection
 from django.db.models.fields import Field
 from django.db.models.query_utils import QueryWrapper
 from datastructures import EmptyResultSet, FullResultSet
@@ -51,18 +50,6 @@
             # Consume any generators immediately, so that we can determine
             # emptiness and transform any non-empty values correctly.
             value = list(value)
-        if hasattr(obj, "process"):
-            try:
-                obj, params = obj.process(lookup_type, value)
-            except (EmptyShortCircuit, EmptyResultSet):
-                # There are situations where we want to short-circuit any
-                # comparisons and make sure that nothing is returned. One
-                # example is when checking for a NULL pk value, or the
-                # equivalent.
-                super(WhereNode, self).add(NothingNode(), connector)
-                return
-        else:
-            params = Field().get_db_prep_lookup(lookup_type, value)
 
         # The "annotation" parameter is used to pass auxilliary information
         # about the value(s) to the query construction. Specifically, datetime
@@ -75,10 +62,16 @@
         else:
             annotation = bool(value)
 
-        super(WhereNode, self).add((obj, lookup_type, annotation, params),
+        if hasattr(obj, "prepare"):
+            value = obj.prepare(lookup_type, value)
+            super(WhereNode, self).add((obj, lookup_type, annotation, value),
+                connector)
+            return
+
+        super(WhereNode, self).add((obj, lookup_type, annotation, value),
                 connector)
 
-    def as_sql(self, qn=None):
+    def as_sql(self, qn, connection):
         """
         Returns the SQL version of the where clause and the value to be
         substituted in. Returns None, None if this node is empty.
@@ -87,8 +80,6 @@
         (generally not needed except by the internal implementation for
         recursion).
         """
-        if not qn:
-            qn = connection.ops.quote_name
         if not self.children:
             return None, []
         result = []
@@ -97,10 +88,10 @@
         for child in self.children:
             try:
                 if hasattr(child, 'as_sql'):
-                    sql, params = child.as_sql(qn=qn)
+                    sql, params = child.as_sql(qn=qn, connection=connection)
                 else:
                     # A leaf node in the tree.
-                    sql, params = self.make_atom(child, qn)
+                    sql, params = self.make_atom(child, qn, connection)
 
             except EmptyResultSet:
                 if self.connector == AND and not self.negated:
@@ -136,7 +127,7 @@
                 sql_string = '(%s)' % sql_string
         return sql_string, result_params
 
-    def make_atom(self, child, qn):
+    def make_atom(self, child, qn, connection):
         """
         Turn a tuple (table_alias, column_name, db_type, lookup_type,
         value_annot, params) into valid SQL.
@@ -144,13 +135,21 @@
         Returns the string for the SQL fragment and the parameters to use for
         it.
         """
-        lvalue, lookup_type, value_annot, params = child
+        lvalue, lookup_type, value_annot, params_or_value = child
+        if hasattr(lvalue, 'process'):
+            try:
+                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
+            except EmptyShortCircuit:
+                raise EmptyResultSet
+        else:
+            params = Field().get_db_prep_lookup(lookup_type, params_or_value,
+                connection=connection, prepared=True)
         if isinstance(lvalue, tuple):
             # A direct database column lookup.
-            field_sql = self.sql_for_columns(lvalue, qn)
+            field_sql = self.sql_for_columns(lvalue, qn, connection)
         else:
             # A smart object with an as_sql() method.
-            field_sql = lvalue.as_sql(quote_func=qn)
+            field_sql = lvalue.as_sql(qn, connection)
 
         if value_annot is datetime.datetime:
             cast_sql = connection.ops.datetime_cast_sql()
@@ -158,11 +157,16 @@
             cast_sql = '%s'
 
         if hasattr(params, 'as_sql'):
-            extra, params = params.as_sql(qn)
+            extra, params = params.as_sql(qn, connection)
             cast_sql = ''
         else:
             extra = ''
 
+        if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
+            and connection.features.interprets_empty_strings_as_nulls):
+            lookup_type = 'isnull'
+            value_annot = True
+
         if lookup_type in connection.operators:
             format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
             return (format % (field_sql,
@@ -191,7 +195,7 @@
 
         raise TypeError('Invalid lookup_type: %r' % lookup_type)
 
-    def sql_for_columns(self, data, qn):
+    def sql_for_columns(self, data, qn, connection):
         """
         Returns the SQL fragment used for the left-hand side of a column
         constraint (for example, the "T1.foo" portion in the clause
@@ -216,7 +220,7 @@
                 child.relabel_aliases(change_map)
             elif isinstance(child, tree.Node):
                 self.relabel_aliases(change_map, child)
-            else:
+            elif isinstance(child, (list, tuple)):
                 if isinstance(child[0], (list, tuple)):
                     elt = list(child[0])
                     if elt[0] in change_map:
@@ -233,7 +237,8 @@
     """
     A node that matches everything.
     """
-    def as_sql(self, qn=None):
+
+    def as_sql(self, qn=None, connection=None):
         raise FullResultSet
 
     def relabel_aliases(self, change_map, node=None):
@@ -243,12 +248,20 @@
     """
     A node that matches nothing.
     """
-    def as_sql(self, qn=None):
+    def as_sql(self, qn=None, connection=None):
         raise EmptyResultSet
 
     def relabel_aliases(self, change_map, node=None):
         return
 
+class ExtraWhere(object):
+    def __init__(self, sqls, params):
+        self.sqls = sqls
+        self.params = params
+
+    def as_sql(self, qn=None, connection=None):
+        return " AND ".join(self.sqls), tuple(self.params or ())
+
 class Constraint(object):
     """
     An object that can be passed to WhereNode.add() and knows how to
@@ -257,7 +270,36 @@
     def __init__(self, alias, col, field):
         self.alias, self.col, self.field = alias, col, field
 
-    def process(self, lookup_type, value):
+    def __getstate__(self):
+        """Save the state of the Constraint for pickling.
+
+        Fields aren't necessarily pickleable, because they can have
+        callable default values. So, instead of pickling the field
+        store a reference so we can restore it manually
+        """
+        obj_dict = self.__dict__.copy()
+        if self.field:
+            obj_dict['model'] = self.field.model
+            obj_dict['field_name'] = self.field.name
+        del obj_dict['field']
+        return obj_dict
+
+    def __setstate__(self, data):
+        """Restore the constraint """
+        model = data.pop('model', None)
+        field_name = data.pop('field_name', None)
+        self.__dict__.update(data)
+        if model is not None:
+            self.field = model._meta.get_field(field_name)
+        else:
+            self.field = None
+
+    def prepare(self, lookup_type, value):
+        if self.field:
+            return self.field.get_prep_lookup(lookup_type, value)
+        return value
+
+    def process(self, lookup_type, value, connection):
         """
         Returns a tuple of data suitable for inclusion in a WhereNode
         instance.
@@ -266,16 +308,21 @@
         from django.db.models.base import ObjectDoesNotExist
         try:
             if self.field:
-                params = self.field.get_db_prep_lookup(lookup_type, value)
-                db_type = self.field.db_type()
+                params = self.field.get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
+                db_type = self.field.db_type(connection=connection)
             else:
                 # This branch is used at times when we add a comparison to NULL
                 # (we don't really want to waste time looking up the associated
                 # field object at the calling location).
-                params = Field().get_db_prep_lookup(lookup_type, value)
+                params = Field().get_db_prep_lookup(lookup_type, value,
+                    connection=connection, prepared=True)
                 db_type = None
         except ObjectDoesNotExist:
             raise EmptyShortCircuit
 
         return (self.alias, self.col, db_type), params
 
+    def relabel_aliases(self, change_map):
+        if self.alias in change_map:
+            self.alias = change_map[self.alias]