web/lib/django/db/models/sql/where.py
changeset 0 0d40e90630ef
child 29 cc9b7e14412b
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/web/lib/django/db/models/sql/where.py	Wed Jan 20 00:34:04 2010 +0100
@@ -0,0 +1,281 @@
+"""
+Code to manage the creation and SQL rendering of 'where' constraints.
+"""
+import datetime
+
+from django.utils import tree
+from django.db import connection
+from django.db.models.fields import Field
+from django.db.models.query_utils import QueryWrapper
+from datastructures import EmptyResultSet, FullResultSet
+
+# Connection types
+AND = 'AND'
+OR = 'OR'
+
+class EmptyShortCircuit(Exception):
+    """
+    Internal exception used to indicate that a "matches nothing" node should be
+    added to the where-clause.
+    """
+    pass
+
+class WhereNode(tree.Node):
+    """
+    Used to represent the SQL where-clause.
+
+    The class is tied to the Query class that created it (in order to create
+    the correct SQL).
+
+    The children in this tree are usually either Q-like objects or lists of
+    [table_alias, field_name, db_type, lookup_type, value_annotation,
+    params]. However, a child could also be any class with as_sql() and
+    relabel_aliases() methods.
+    """
+    default = AND
+
+    def add(self, data, connector):
+        """
+        Add a node to the where-tree. If the data is a list or tuple, it is
+        expected to be of the form (alias, col_name, field_obj, lookup_type,
+        value), which is then slightly munged before being stored (to avoid
+        storing any reference to field objects). Otherwise, the 'data' is
+        stored unchanged and can be anything with an 'as_sql()' method.
+        """
+        if not isinstance(data, (list, tuple)):
+            super(WhereNode, self).add(data, connector)
+            return
+
+        obj, lookup_type, value = data
+        if hasattr(value, '__iter__') and hasattr(value, 'next'):
+            # Consume any generators immediately, so that we can determine
+            # emptiness and transform any non-empty values correctly.
+            value = list(value)
+        if hasattr(obj, "process"):
+            try:
+                obj, params = obj.process(lookup_type, value)
+            except (EmptyShortCircuit, EmptyResultSet):
+                # There are situations where we want to short-circuit any
+                # comparisons and make sure that nothing is returned. One
+                # example is when checking for a NULL pk value, or the
+                # equivalent.
+                super(WhereNode, self).add(NothingNode(), connector)
+                return
+        else:
+            params = Field().get_db_prep_lookup(lookup_type, value)
+
+        # The "annotation" parameter is used to pass auxilliary information
+        # about the value(s) to the query construction. Specifically, datetime
+        # and empty values need special handling. Other types could be used
+        # here in the future (using Python types is suggested for consistency).
+        if isinstance(value, datetime.datetime):
+            annotation = datetime.datetime
+        elif hasattr(value, 'value_annotation'):
+            annotation = value.value_annotation
+        else:
+            annotation = bool(value)
+
+        super(WhereNode, self).add((obj, lookup_type, annotation, params),
+                connector)
+
+    def as_sql(self, qn=None):
+        """
+        Returns the SQL version of the where clause and the value to be
+        substituted in. Returns None, None if this node is empty.
+
+        If 'node' is provided, that is the root of the SQL generation
+        (generally not needed except by the internal implementation for
+        recursion).
+        """
+        if not qn:
+            qn = connection.ops.quote_name
+        if not self.children:
+            return None, []
+        result = []
+        result_params = []
+        empty = True
+        for child in self.children:
+            try:
+                if hasattr(child, 'as_sql'):
+                    sql, params = child.as_sql(qn=qn)
+                else:
+                    # A leaf node in the tree.
+                    sql, params = self.make_atom(child, qn)
+
+            except EmptyResultSet:
+                if self.connector == AND and not self.negated:
+                    # We can bail out early in this particular case (only).
+                    raise
+                elif self.negated:
+                    empty = False
+                continue
+            except FullResultSet:
+                if self.connector == OR:
+                    if self.negated:
+                        empty = True
+                        break
+                    # We match everything. No need for any constraints.
+                    return '', []
+                if self.negated:
+                    empty = True
+                continue
+
+            empty = False
+            if sql:
+                result.append(sql)
+                result_params.extend(params)
+        if empty:
+            raise EmptyResultSet
+
+        conn = ' %s ' % self.connector
+        sql_string = conn.join(result)
+        if sql_string:
+            if self.negated:
+                sql_string = 'NOT (%s)' % sql_string
+            elif len(self.children) != 1:
+                sql_string = '(%s)' % sql_string
+        return sql_string, result_params
+
+    def make_atom(self, child, qn):
+        """
+        Turn a tuple (table_alias, column_name, db_type, lookup_type,
+        value_annot, params) into valid SQL.
+
+        Returns the string for the SQL fragment and the parameters to use for
+        it.
+        """
+        lvalue, lookup_type, value_annot, params = child
+        if isinstance(lvalue, tuple):
+            # A direct database column lookup.
+            field_sql = self.sql_for_columns(lvalue, qn)
+        else:
+            # A smart object with an as_sql() method.
+            field_sql = lvalue.as_sql(quote_func=qn)
+
+        if value_annot is datetime.datetime:
+            cast_sql = connection.ops.datetime_cast_sql()
+        else:
+            cast_sql = '%s'
+
+        if hasattr(params, 'as_sql'):
+            extra, params = params.as_sql(qn)
+            cast_sql = ''
+        else:
+            extra = ''
+
+        if lookup_type in connection.operators:
+            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
+            return (format % (field_sql,
+                              connection.operators[lookup_type] % cast_sql,
+                              extra), params)
+
+        if lookup_type == 'in':
+            if not value_annot:
+                raise EmptyResultSet
+            if extra:
+                return ('%s IN %s' % (field_sql, extra), params)
+            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))),
+                    params)
+        elif lookup_type in ('range', 'year'):
+            return ('%s BETWEEN %%s and %%s' % field_sql, params)
+        elif lookup_type in ('month', 'day', 'week_day'):
+            return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
+                    params)
+        elif lookup_type == 'isnull':
+            return ('%s IS %sNULL' % (field_sql,
+                (not value_annot and 'NOT ' or '')), ())
+        elif lookup_type == 'search':
+            return (connection.ops.fulltext_search_sql(field_sql), params)
+        elif lookup_type in ('regex', 'iregex'):
+            return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
+
+        raise TypeError('Invalid lookup_type: %r' % lookup_type)
+
+    def sql_for_columns(self, data, qn):
+        """
+        Returns the SQL fragment used for the left-hand side of a column
+        constraint (for example, the "T1.foo" portion in the clause
+        "WHERE ... T1.foo = 6").
+        """
+        table_alias, name, db_type = data
+        if table_alias:
+            lhs = '%s.%s' % (qn(table_alias), qn(name))
+        else:
+            lhs = qn(name)
+        return connection.ops.field_cast_sql(db_type) % lhs
+
+    def relabel_aliases(self, change_map, node=None):
+        """
+        Relabels the alias values of any children. 'change_map' is a dictionary
+        mapping old (current) alias values to the new values.
+        """
+        if not node:
+            node = self
+        for pos, child in enumerate(node.children):
+            if hasattr(child, 'relabel_aliases'):
+                child.relabel_aliases(change_map)
+            elif isinstance(child, tree.Node):
+                self.relabel_aliases(change_map, child)
+            else:
+                if isinstance(child[0], (list, tuple)):
+                    elt = list(child[0])
+                    if elt[0] in change_map:
+                        elt[0] = change_map[elt[0]]
+                        node.children[pos] = (tuple(elt),) + child[1:]
+                else:
+                    child[0].relabel_aliases(change_map)
+
+                # Check if the query value also requires relabelling
+                if hasattr(child[3], 'relabel_aliases'):
+                    child[3].relabel_aliases(change_map)
+
+class EverythingNode(object):
+    """
+    A node that matches everything.
+    """
+    def as_sql(self, qn=None):
+        raise FullResultSet
+
+    def relabel_aliases(self, change_map, node=None):
+        return
+
+class NothingNode(object):
+    """
+    A node that matches nothing.
+    """
+    def as_sql(self, qn=None):
+        raise EmptyResultSet
+
+    def relabel_aliases(self, change_map, node=None):
+        return
+
+class Constraint(object):
+    """
+    An object that can be passed to WhereNode.add() and knows how to
+    pre-process itself prior to including in the WhereNode.
+    """
+    def __init__(self, alias, col, field):
+        self.alias, self.col, self.field = alias, col, field
+
+    def process(self, lookup_type, value):
+        """
+        Returns a tuple of data suitable for inclusion in a WhereNode
+        instance.
+        """
+        # Because of circular imports, we need to import this here.
+        from django.db.models.base import ObjectDoesNotExist
+        try:
+            if self.field:
+                params = self.field.get_db_prep_lookup(lookup_type, value)
+                db_type = self.field.db_type()
+            else:
+                # This branch is used at times when we add a comparison to NULL
+                # (we don't really want to waste time looking up the associated
+                # field object at the calling location).
+                params = Field().get_db_prep_lookup(lookup_type, value)
+                db_type = None
+        except ObjectDoesNotExist:
+            raise EmptyShortCircuit
+
+        return (self.alias, self.col, db_type), params
+