diff -r b758351d191f -r cc9b7e14412b web/lib/django/db/models/sql/where.py --- a/web/lib/django/db/models/sql/where.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/db/models/sql/where.py Tue May 25 02:43:45 2010 +0200 @@ -4,7 +4,6 @@ import datetime from django.utils import tree -from django.db import connection from django.db.models.fields import Field from django.db.models.query_utils import QueryWrapper from datastructures import EmptyResultSet, FullResultSet @@ -51,18 +50,6 @@ # Consume any generators immediately, so that we can determine # emptiness and transform any non-empty values correctly. value = list(value) - if hasattr(obj, "process"): - try: - obj, params = obj.process(lookup_type, value) - except (EmptyShortCircuit, EmptyResultSet): - # There are situations where we want to short-circuit any - # comparisons and make sure that nothing is returned. One - # example is when checking for a NULL pk value, or the - # equivalent. - super(WhereNode, self).add(NothingNode(), connector) - return - else: - params = Field().get_db_prep_lookup(lookup_type, value) # The "annotation" parameter is used to pass auxilliary information # about the value(s) to the query construction. Specifically, datetime @@ -75,10 +62,16 @@ else: annotation = bool(value) - super(WhereNode, self).add((obj, lookup_type, annotation, params), + if hasattr(obj, "prepare"): + value = obj.prepare(lookup_type, value) + super(WhereNode, self).add((obj, lookup_type, annotation, value), + connector) + return + + super(WhereNode, self).add((obj, lookup_type, annotation, value), connector) - def as_sql(self, qn=None): + def as_sql(self, qn, connection): """ Returns the SQL version of the where clause and the value to be substituted in. Returns None, None if this node is empty. @@ -87,8 +80,6 @@ (generally not needed except by the internal implementation for recursion). """ - if not qn: - qn = connection.ops.quote_name if not self.children: return None, [] result = [] @@ -97,10 +88,10 @@ for child in self.children: try: if hasattr(child, 'as_sql'): - sql, params = child.as_sql(qn=qn) + sql, params = child.as_sql(qn=qn, connection=connection) else: # A leaf node in the tree. - sql, params = self.make_atom(child, qn) + sql, params = self.make_atom(child, qn, connection) except EmptyResultSet: if self.connector == AND and not self.negated: @@ -136,7 +127,7 @@ sql_string = '(%s)' % sql_string return sql_string, result_params - def make_atom(self, child, qn): + def make_atom(self, child, qn, connection): """ Turn a tuple (table_alias, column_name, db_type, lookup_type, value_annot, params) into valid SQL. @@ -144,13 +135,21 @@ Returns the string for the SQL fragment and the parameters to use for it. """ - lvalue, lookup_type, value_annot, params = child + lvalue, lookup_type, value_annot, params_or_value = child + if hasattr(lvalue, 'process'): + try: + lvalue, params = lvalue.process(lookup_type, params_or_value, connection) + except EmptyShortCircuit: + raise EmptyResultSet + else: + params = Field().get_db_prep_lookup(lookup_type, params_or_value, + connection=connection, prepared=True) if isinstance(lvalue, tuple): # A direct database column lookup. - field_sql = self.sql_for_columns(lvalue, qn) + field_sql = self.sql_for_columns(lvalue, qn, connection) else: # A smart object with an as_sql() method. - field_sql = lvalue.as_sql(quote_func=qn) + field_sql = lvalue.as_sql(qn, connection) if value_annot is datetime.datetime: cast_sql = connection.ops.datetime_cast_sql() @@ -158,11 +157,16 @@ cast_sql = '%s' if hasattr(params, 'as_sql'): - extra, params = params.as_sql(qn) + extra, params = params.as_sql(qn, connection) cast_sql = '' else: extra = '' + if (len(params) == 1 and params[0] == '' and lookup_type == 'exact' + and connection.features.interprets_empty_strings_as_nulls): + lookup_type = 'isnull' + value_annot = True + if lookup_type in connection.operators: format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) return (format % (field_sql, @@ -191,7 +195,7 @@ raise TypeError('Invalid lookup_type: %r' % lookup_type) - def sql_for_columns(self, data, qn): + def sql_for_columns(self, data, qn, connection): """ Returns the SQL fragment used for the left-hand side of a column constraint (for example, the "T1.foo" portion in the clause @@ -216,7 +220,7 @@ child.relabel_aliases(change_map) elif isinstance(child, tree.Node): self.relabel_aliases(change_map, child) - else: + elif isinstance(child, (list, tuple)): if isinstance(child[0], (list, tuple)): elt = list(child[0]) if elt[0] in change_map: @@ -233,7 +237,8 @@ """ A node that matches everything. """ - def as_sql(self, qn=None): + + def as_sql(self, qn=None, connection=None): raise FullResultSet def relabel_aliases(self, change_map, node=None): @@ -243,12 +248,20 @@ """ A node that matches nothing. """ - def as_sql(self, qn=None): + def as_sql(self, qn=None, connection=None): raise EmptyResultSet def relabel_aliases(self, change_map, node=None): return +class ExtraWhere(object): + def __init__(self, sqls, params): + self.sqls = sqls + self.params = params + + def as_sql(self, qn=None, connection=None): + return " AND ".join(self.sqls), tuple(self.params or ()) + class Constraint(object): """ An object that can be passed to WhereNode.add() and knows how to @@ -257,7 +270,36 @@ def __init__(self, alias, col, field): self.alias, self.col, self.field = alias, col, field - def process(self, lookup_type, value): + def __getstate__(self): + """Save the state of the Constraint for pickling. + + Fields aren't necessarily pickleable, because they can have + callable default values. So, instead of pickling the field + store a reference so we can restore it manually + """ + obj_dict = self.__dict__.copy() + if self.field: + obj_dict['model'] = self.field.model + obj_dict['field_name'] = self.field.name + del obj_dict['field'] + return obj_dict + + def __setstate__(self, data): + """Restore the constraint """ + model = data.pop('model', None) + field_name = data.pop('field_name', None) + self.__dict__.update(data) + if model is not None: + self.field = model._meta.get_field(field_name) + else: + self.field = None + + def prepare(self, lookup_type, value): + if self.field: + return self.field.get_prep_lookup(lookup_type, value) + return value + + def process(self, lookup_type, value, connection): """ Returns a tuple of data suitable for inclusion in a WhereNode instance. @@ -266,16 +308,21 @@ from django.db.models.base import ObjectDoesNotExist try: if self.field: - params = self.field.get_db_prep_lookup(lookup_type, value) - db_type = self.field.db_type() + params = self.field.get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) + db_type = self.field.db_type(connection=connection) else: # This branch is used at times when we add a comparison to NULL # (we don't really want to waste time looking up the associated # field object at the calling location). - params = Field().get_db_prep_lookup(lookup_type, value) + params = Field().get_db_prep_lookup(lookup_type, value, + connection=connection, prepared=True) db_type = None except ObjectDoesNotExist: raise EmptyShortCircuit return (self.alias, self.col, db_type), params + def relabel_aliases(self, change_map): + if self.alias in change_map: + self.alias = change_map[self.alias]