web/lib/django/db/models/sql/where.py
changeset 38 77b6da96e6f1
equal deleted inserted replaced
37:8d941af65caf 38:77b6da96e6f1
       
     1 """
       
     2 Code to manage the creation and SQL rendering of 'where' constraints.
       
     3 """
       
     4 import datetime
       
     5 
       
     6 from django.utils import tree
       
     7 from django.db.models.fields import Field
       
     8 from django.db.models.query_utils import QueryWrapper
       
     9 from datastructures import EmptyResultSet, FullResultSet
       
    10 
       
    11 # Connection types
       
    12 AND = 'AND'
       
    13 OR = 'OR'
       
    14 
       
    15 class EmptyShortCircuit(Exception):
       
    16     """
       
    17     Internal exception used to indicate that a "matches nothing" node should be
       
    18     added to the where-clause.
       
    19     """
       
    20     pass
       
    21 
       
    22 class WhereNode(tree.Node):
       
    23     """
       
    24     Used to represent the SQL where-clause.
       
    25 
       
    26     The class is tied to the Query class that created it (in order to create
       
    27     the correct SQL).
       
    28 
       
    29     The children in this tree are usually either Q-like objects or lists of
       
    30     [table_alias, field_name, db_type, lookup_type, value_annotation,
       
    31     params]. However, a child could also be any class with as_sql() and
       
    32     relabel_aliases() methods.
       
    33     """
       
    34     default = AND
       
    35 
       
    36     def add(self, data, connector):
       
    37         """
       
    38         Add a node to the where-tree. If the data is a list or tuple, it is
       
    39         expected to be of the form (alias, col_name, field_obj, lookup_type,
       
    40         value), which is then slightly munged before being stored (to avoid
       
    41         storing any reference to field objects). Otherwise, the 'data' is
       
    42         stored unchanged and can be anything with an 'as_sql()' method.
       
    43         """
       
    44         if not isinstance(data, (list, tuple)):
       
    45             super(WhereNode, self).add(data, connector)
       
    46             return
       
    47 
       
    48         obj, lookup_type, value = data
       
    49         if hasattr(value, '__iter__') and hasattr(value, 'next'):
       
    50             # Consume any generators immediately, so that we can determine
       
    51             # emptiness and transform any non-empty values correctly.
       
    52             value = list(value)
       
    53 
       
    54         # The "annotation" parameter is used to pass auxilliary information
       
    55         # about the value(s) to the query construction. Specifically, datetime
       
    56         # and empty values need special handling. Other types could be used
       
    57         # here in the future (using Python types is suggested for consistency).
       
    58         if isinstance(value, datetime.datetime):
       
    59             annotation = datetime.datetime
       
    60         elif hasattr(value, 'value_annotation'):
       
    61             annotation = value.value_annotation
       
    62         else:
       
    63             annotation = bool(value)
       
    64 
       
    65         if hasattr(obj, "prepare"):
       
    66             value = obj.prepare(lookup_type, value)
       
    67             super(WhereNode, self).add((obj, lookup_type, annotation, value),
       
    68                 connector)
       
    69             return
       
    70 
       
    71         super(WhereNode, self).add((obj, lookup_type, annotation, value),
       
    72                 connector)
       
    73 
       
    74     def as_sql(self, qn, connection):
       
    75         """
       
    76         Returns the SQL version of the where clause and the value to be
       
    77         substituted in. Returns None, None if this node is empty.
       
    78 
       
    79         If 'node' is provided, that is the root of the SQL generation
       
    80         (generally not needed except by the internal implementation for
       
    81         recursion).
       
    82         """
       
    83         if not self.children:
       
    84             return None, []
       
    85         result = []
       
    86         result_params = []
       
    87         empty = True
       
    88         for child in self.children:
       
    89             try:
       
    90                 if hasattr(child, 'as_sql'):
       
    91                     sql, params = child.as_sql(qn=qn, connection=connection)
       
    92                 else:
       
    93                     # A leaf node in the tree.
       
    94                     sql, params = self.make_atom(child, qn, connection)
       
    95 
       
    96             except EmptyResultSet:
       
    97                 if self.connector == AND and not self.negated:
       
    98                     # We can bail out early in this particular case (only).
       
    99                     raise
       
   100                 elif self.negated:
       
   101                     empty = False
       
   102                 continue
       
   103             except FullResultSet:
       
   104                 if self.connector == OR:
       
   105                     if self.negated:
       
   106                         empty = True
       
   107                         break
       
   108                     # We match everything. No need for any constraints.
       
   109                     return '', []
       
   110                 if self.negated:
       
   111                     empty = True
       
   112                 continue
       
   113 
       
   114             empty = False
       
   115             if sql:
       
   116                 result.append(sql)
       
   117                 result_params.extend(params)
       
   118         if empty:
       
   119             raise EmptyResultSet
       
   120 
       
   121         conn = ' %s ' % self.connector
       
   122         sql_string = conn.join(result)
       
   123         if sql_string:
       
   124             if self.negated:
       
   125                 sql_string = 'NOT (%s)' % sql_string
       
   126             elif len(self.children) != 1:
       
   127                 sql_string = '(%s)' % sql_string
       
   128         return sql_string, result_params
       
   129 
       
   130     def make_atom(self, child, qn, connection):
       
   131         """
       
   132         Turn a tuple (table_alias, column_name, db_type, lookup_type,
       
   133         value_annot, params) into valid SQL.
       
   134 
       
   135         Returns the string for the SQL fragment and the parameters to use for
       
   136         it.
       
   137         """
       
   138         lvalue, lookup_type, value_annot, params_or_value = child
       
   139         if hasattr(lvalue, 'process'):
       
   140             try:
       
   141                 lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
       
   142             except EmptyShortCircuit:
       
   143                 raise EmptyResultSet
       
   144         else:
       
   145             params = Field().get_db_prep_lookup(lookup_type, params_or_value,
       
   146                 connection=connection, prepared=True)
       
   147         if isinstance(lvalue, tuple):
       
   148             # A direct database column lookup.
       
   149             field_sql = self.sql_for_columns(lvalue, qn, connection)
       
   150         else:
       
   151             # A smart object with an as_sql() method.
       
   152             field_sql = lvalue.as_sql(qn, connection)
       
   153 
       
   154         if value_annot is datetime.datetime:
       
   155             cast_sql = connection.ops.datetime_cast_sql()
       
   156         else:
       
   157             cast_sql = '%s'
       
   158 
       
   159         if hasattr(params, 'as_sql'):
       
   160             extra, params = params.as_sql(qn, connection)
       
   161             cast_sql = ''
       
   162         else:
       
   163             extra = ''
       
   164 
       
   165         if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
       
   166             and connection.features.interprets_empty_strings_as_nulls):
       
   167             lookup_type = 'isnull'
       
   168             value_annot = True
       
   169 
       
   170         if lookup_type in connection.operators:
       
   171             format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
       
   172             return (format % (field_sql,
       
   173                               connection.operators[lookup_type] % cast_sql,
       
   174                               extra), params)
       
   175 
       
   176         if lookup_type == 'in':
       
   177             if not value_annot:
       
   178                 raise EmptyResultSet
       
   179             if extra:
       
   180                 return ('%s IN %s' % (field_sql, extra), params)
       
   181             return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(params))),
       
   182                     params)
       
   183         elif lookup_type in ('range', 'year'):
       
   184             return ('%s BETWEEN %%s and %%s' % field_sql, params)
       
   185         elif lookup_type in ('month', 'day', 'week_day'):
       
   186             return ('%s = %%s' % connection.ops.date_extract_sql(lookup_type, field_sql),
       
   187                     params)
       
   188         elif lookup_type == 'isnull':
       
   189             return ('%s IS %sNULL' % (field_sql,
       
   190                 (not value_annot and 'NOT ' or '')), ())
       
   191         elif lookup_type == 'search':
       
   192             return (connection.ops.fulltext_search_sql(field_sql), params)
       
   193         elif lookup_type in ('regex', 'iregex'):
       
   194             return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
       
   195 
       
   196         raise TypeError('Invalid lookup_type: %r' % lookup_type)
       
   197 
       
   198     def sql_for_columns(self, data, qn, connection):
       
   199         """
       
   200         Returns the SQL fragment used for the left-hand side of a column
       
   201         constraint (for example, the "T1.foo" portion in the clause
       
   202         "WHERE ... T1.foo = 6").
       
   203         """
       
   204         table_alias, name, db_type = data
       
   205         if table_alias:
       
   206             lhs = '%s.%s' % (qn(table_alias), qn(name))
       
   207         else:
       
   208             lhs = qn(name)
       
   209         return connection.ops.field_cast_sql(db_type) % lhs
       
   210 
       
   211     def relabel_aliases(self, change_map, node=None):
       
   212         """
       
   213         Relabels the alias values of any children. 'change_map' is a dictionary
       
   214         mapping old (current) alias values to the new values.
       
   215         """
       
   216         if not node:
       
   217             node = self
       
   218         for pos, child in enumerate(node.children):
       
   219             if hasattr(child, 'relabel_aliases'):
       
   220                 child.relabel_aliases(change_map)
       
   221             elif isinstance(child, tree.Node):
       
   222                 self.relabel_aliases(change_map, child)
       
   223             elif isinstance(child, (list, tuple)):
       
   224                 if isinstance(child[0], (list, tuple)):
       
   225                     elt = list(child[0])
       
   226                     if elt[0] in change_map:
       
   227                         elt[0] = change_map[elt[0]]
       
   228                         node.children[pos] = (tuple(elt),) + child[1:]
       
   229                 else:
       
   230                     child[0].relabel_aliases(change_map)
       
   231 
       
   232                 # Check if the query value also requires relabelling
       
   233                 if hasattr(child[3], 'relabel_aliases'):
       
   234                     child[3].relabel_aliases(change_map)
       
   235 
       
   236 class EverythingNode(object):
       
   237     """
       
   238     A node that matches everything.
       
   239     """
       
   240 
       
   241     def as_sql(self, qn=None, connection=None):
       
   242         raise FullResultSet
       
   243 
       
   244     def relabel_aliases(self, change_map, node=None):
       
   245         return
       
   246 
       
   247 class NothingNode(object):
       
   248     """
       
   249     A node that matches nothing.
       
   250     """
       
   251     def as_sql(self, qn=None, connection=None):
       
   252         raise EmptyResultSet
       
   253 
       
   254     def relabel_aliases(self, change_map, node=None):
       
   255         return
       
   256 
       
   257 class ExtraWhere(object):
       
   258     def __init__(self, sqls, params):
       
   259         self.sqls = sqls
       
   260         self.params = params
       
   261 
       
   262     def as_sql(self, qn=None, connection=None):
       
   263         return " AND ".join(self.sqls), tuple(self.params or ())
       
   264 
       
   265 class Constraint(object):
       
   266     """
       
   267     An object that can be passed to WhereNode.add() and knows how to
       
   268     pre-process itself prior to including in the WhereNode.
       
   269     """
       
   270     def __init__(self, alias, col, field):
       
   271         self.alias, self.col, self.field = alias, col, field
       
   272 
       
   273     def __getstate__(self):
       
   274         """Save the state of the Constraint for pickling.
       
   275 
       
   276         Fields aren't necessarily pickleable, because they can have
       
   277         callable default values. So, instead of pickling the field
       
   278         store a reference so we can restore it manually
       
   279         """
       
   280         obj_dict = self.__dict__.copy()
       
   281         if self.field:
       
   282             obj_dict['model'] = self.field.model
       
   283             obj_dict['field_name'] = self.field.name
       
   284         del obj_dict['field']
       
   285         return obj_dict
       
   286 
       
   287     def __setstate__(self, data):
       
   288         """Restore the constraint """
       
   289         model = data.pop('model', None)
       
   290         field_name = data.pop('field_name', None)
       
   291         self.__dict__.update(data)
       
   292         if model is not None:
       
   293             self.field = model._meta.get_field(field_name)
       
   294         else:
       
   295             self.field = None
       
   296 
       
   297     def prepare(self, lookup_type, value):
       
   298         if self.field:
       
   299             return self.field.get_prep_lookup(lookup_type, value)
       
   300         return value
       
   301 
       
   302     def process(self, lookup_type, value, connection):
       
   303         """
       
   304         Returns a tuple of data suitable for inclusion in a WhereNode
       
   305         instance.
       
   306         """
       
   307         # Because of circular imports, we need to import this here.
       
   308         from django.db.models.base import ObjectDoesNotExist
       
   309         try:
       
   310             if self.field:
       
   311                 params = self.field.get_db_prep_lookup(lookup_type, value,
       
   312                     connection=connection, prepared=True)
       
   313                 db_type = self.field.db_type(connection=connection)
       
   314             else:
       
   315                 # This branch is used at times when we add a comparison to NULL
       
   316                 # (we don't really want to waste time looking up the associated
       
   317                 # field object at the calling location).
       
   318                 params = Field().get_db_prep_lookup(lookup_type, value,
       
   319                     connection=connection, prepared=True)
       
   320                 db_type = None
       
   321         except ObjectDoesNotExist:
       
   322             raise EmptyShortCircuit
       
   323 
       
   324         return (self.alias, self.col, db_type), params
       
   325 
       
   326     def relabel_aliases(self, change_map):
       
   327         if self.alias in change_map:
       
   328             self.alias = change_map[self.alias]