web/lib/django/db/models/sql/where.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
equal deleted inserted replaced
28:b758351d191f 29:cc9b7e14412b
     2 Code to manage the creation and SQL rendering of 'where' constraints.
     2 Code to manage the creation and SQL rendering of 'where' constraints.
     3 """
     3 """
     4 import datetime
     4 import datetime
     5 
     5 
     6 from django.utils import tree
     6 from django.utils import tree
     7 from django.db import connection
       
     8 from django.db.models.fields import Field
     7 from django.db.models.fields import Field
     9 from django.db.models.query_utils import QueryWrapper
     8 from django.db.models.query_utils import QueryWrapper
    10 from datastructures import EmptyResultSet, FullResultSet
     9 from datastructures import EmptyResultSet, FullResultSet
    11 
    10 
    12 # Connection types
    11 # Connection types
    49         obj, lookup_type, value = data
    48         obj, lookup_type, value = data
    50         if hasattr(value, '__iter__') and hasattr(value, 'next'):
    49         if hasattr(value, '__iter__') and hasattr(value, 'next'):
    51             # Consume any generators immediately, so that we can determine
    50             # Consume any generators immediately, so that we can determine
    52             # emptiness and transform any non-empty values correctly.
    51             # emptiness and transform any non-empty values correctly.
    53             value = list(value)
    52             value = list(value)
    54         if hasattr(obj, "process"):
       
    55             try:
       
    56                 obj, params = obj.process(lookup_type, value)
       
    57             except (EmptyShortCircuit, EmptyResultSet):
       
    58                 # There are situations where we want to short-circuit any
       
    59                 # comparisons and make sure that nothing is returned. One
       
    60                 # example is when checking for a NULL pk value, or the
       
    61                 # equivalent.
       
    62                 super(WhereNode, self).add(NothingNode(), connector)
       
    63                 return
       
    64         else:
       
    65             params = Field().get_db_prep_lookup(lookup_type, value)
       
    66 
    53 
    67         # The "annotation" parameter is used to pass auxilliary information
    54         # The "annotation" parameter is used to pass auxilliary information
    68         # about the value(s) to the query construction. Specifically, datetime
    55         # about the value(s) to the query construction. Specifically, datetime
    69         # and empty values need special handling. Other types could be used
    56         # and empty values need special handling. Other types could be used
    70         # here in the future (using Python types is suggested for consistency).
    57         # here in the future (using Python types is suggested for consistency).
    73         elif hasattr(value, 'value_annotation'):
    60         elif hasattr(value, 'value_annotation'):
    74             annotation = value.value_annotation
    61             annotation = value.value_annotation
    75         else:
    62         else:
    76             annotation = bool(value)
    63             annotation = bool(value)
    77 
    64 
    78         super(WhereNode, self).add((obj, lookup_type, annotation, params),
    65         if hasattr(obj, "prepare"):
       
    66             value = obj.prepare(lookup_type, value)
       
    67             super(WhereNode, self).add((obj, lookup_type, annotation, value),
    79                 connector)
    68                 connector)
    80 
    69             return
    81     def as_sql(self, qn=None):
    70 
       
    71         super(WhereNode, self).add((obj, lookup_type, annotation, value),
       
    72                 connector)
       
    73 
       
    74     def as_sql(self, qn, connection):
    82         """
    75         """
    83         Returns the SQL version of the where clause and the value to be
    76         Returns the SQL version of the where clause and the value to be
    84         substituted in. Returns None, None if this node is empty.
    77         substituted in. Returns None, None if this node is empty.
    85 
    78 
    86         If 'node' is provided, that is the root of the SQL generation
    79         If 'node' is provided, that is the root of the SQL generation
    87         (generally not needed except by the internal implementation for
    80         (generally not needed except by the internal implementation for
    88         recursion).
    81         recursion).
    89         """
    82         """
    90         if not qn:
       
    91             qn = connection.ops.quote_name
       
    92         if not self.children:
    83         if not self.children:
    93             return None, []
    84             return None, []
    94         result = []
    85         result = []
    95         result_params = []
    86         result_params = []
    96         empty = True
    87         empty = True
    97         for child in self.children:
    88         for child in self.children:
    98             try:
    89             try:
    99                 if hasattr(child, 'as_sql'):
    90                 if hasattr(child, 'as_sql'):
   100                     sql, params = child.as_sql(qn=qn)
    91                     sql, params = child.as_sql(qn=qn, connection=connection)
   101                 else:
    92                 else:
   102                     # A leaf node in the tree.
    93                     # A leaf node in the tree.
   103                     sql, params = self.make_atom(child, qn)
    94                     sql, params = self.make_atom(child, qn, connection)
   104 
    95 
   105             except EmptyResultSet:
    96             except EmptyResultSet:
   106                 if self.connector == AND and not self.negated:
    97                 if self.connector == AND and not self.negated:
   107                     # We can bail out early in this particular case (only).
    98                     # We can bail out early in this particular case (only).
   108                     raise
    99                     raise
   134                 sql_string = 'NOT (%s)' % sql_string
   125                 sql_string = 'NOT (%s)' % sql_string
   135             elif len(self.children) != 1:
   126             elif len(self.children) != 1:
   136                 sql_string = '(%s)' % sql_string
   127                 sql_string = '(%s)' % sql_string
   137         return sql_string, result_params
   128         return sql_string, result_params
   138 
   129 
   139     def make_atom(self, child, qn):
   130     def make_atom(self, child, qn, connection):
   140         """
   131         """
   141         Turn a tuple (table_alias, column_name, db_type, lookup_type,
   132         Turn a tuple (table_alias, column_name, db_type, lookup_type,
   142         value_annot, params) into valid SQL.
   133         value_annot, params) into valid SQL.
   143 
   134 
   144         Returns the string for the SQL fragment and the parameters to use for
   135         Returns the string for the SQL fragment and the parameters to use for
   145         it.
   136         it.
   146         """
   137         """
   147         lvalue, lookup_type, value_annot, params = child
   138         lvalue, lookup_type, value_annot, params_or_value = child
       
   139         if hasattr(lvalue, 'process'):
       
   140             try:
       
   141                 lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
       
   142             except EmptyShortCircuit:
       
   143                 raise EmptyResultSet
       
   144         else:
       
   145             params = Field().get_db_prep_lookup(lookup_type, params_or_value,
       
   146                 connection=connection, prepared=True)
   148         if isinstance(lvalue, tuple):
   147         if isinstance(lvalue, tuple):
   149             # A direct database column lookup.
   148             # A direct database column lookup.
   150             field_sql = self.sql_for_columns(lvalue, qn)
   149             field_sql = self.sql_for_columns(lvalue, qn, connection)
   151         else:
   150         else:
   152             # A smart object with an as_sql() method.
   151             # A smart object with an as_sql() method.
   153             field_sql = lvalue.as_sql(quote_func=qn)
   152             field_sql = lvalue.as_sql(qn, connection)
   154 
   153 
   155         if value_annot is datetime.datetime:
   154         if value_annot is datetime.datetime:
   156             cast_sql = connection.ops.datetime_cast_sql()
   155             cast_sql = connection.ops.datetime_cast_sql()
   157         else:
   156         else:
   158             cast_sql = '%s'
   157             cast_sql = '%s'
   159 
   158 
   160         if hasattr(params, 'as_sql'):
   159         if hasattr(params, 'as_sql'):
   161             extra, params = params.as_sql(qn)
   160             extra, params = params.as_sql(qn, connection)
   162             cast_sql = ''
   161             cast_sql = ''
   163         else:
   162         else:
   164             extra = ''
   163             extra = ''
       
   164 
       
   165         if (len(params) == 1 and params[0] == '' and lookup_type == 'exact'
       
   166             and connection.features.interprets_empty_strings_as_nulls):
       
   167             lookup_type = 'isnull'
       
   168             value_annot = True
   165 
   169 
   166         if lookup_type in connection.operators:
   170         if lookup_type in connection.operators:
   167             format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
   171             format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
   168             return (format % (field_sql,
   172             return (format % (field_sql,
   169                               connection.operators[lookup_type] % cast_sql,
   173                               connection.operators[lookup_type] % cast_sql,
   189         elif lookup_type in ('regex', 'iregex'):
   193         elif lookup_type in ('regex', 'iregex'):
   190             return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
   194             return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
   191 
   195 
   192         raise TypeError('Invalid lookup_type: %r' % lookup_type)
   196         raise TypeError('Invalid lookup_type: %r' % lookup_type)
   193 
   197 
   194     def sql_for_columns(self, data, qn):
   198     def sql_for_columns(self, data, qn, connection):
   195         """
   199         """
   196         Returns the SQL fragment used for the left-hand side of a column
   200         Returns the SQL fragment used for the left-hand side of a column
   197         constraint (for example, the "T1.foo" portion in the clause
   201         constraint (for example, the "T1.foo" portion in the clause
   198         "WHERE ... T1.foo = 6").
   202         "WHERE ... T1.foo = 6").
   199         """
   203         """
   214         for pos, child in enumerate(node.children):
   218         for pos, child in enumerate(node.children):
   215             if hasattr(child, 'relabel_aliases'):
   219             if hasattr(child, 'relabel_aliases'):
   216                 child.relabel_aliases(change_map)
   220                 child.relabel_aliases(change_map)
   217             elif isinstance(child, tree.Node):
   221             elif isinstance(child, tree.Node):
   218                 self.relabel_aliases(change_map, child)
   222                 self.relabel_aliases(change_map, child)
   219             else:
   223             elif isinstance(child, (list, tuple)):
   220                 if isinstance(child[0], (list, tuple)):
   224                 if isinstance(child[0], (list, tuple)):
   221                     elt = list(child[0])
   225                     elt = list(child[0])
   222                     if elt[0] in change_map:
   226                     if elt[0] in change_map:
   223                         elt[0] = change_map[elt[0]]
   227                         elt[0] = change_map[elt[0]]
   224                         node.children[pos] = (tuple(elt),) + child[1:]
   228                         node.children[pos] = (tuple(elt),) + child[1:]
   231 
   235 
   232 class EverythingNode(object):
   236 class EverythingNode(object):
   233     """
   237     """
   234     A node that matches everything.
   238     A node that matches everything.
   235     """
   239     """
   236     def as_sql(self, qn=None):
   240 
       
   241     def as_sql(self, qn=None, connection=None):
   237         raise FullResultSet
   242         raise FullResultSet
   238 
   243 
   239     def relabel_aliases(self, change_map, node=None):
   244     def relabel_aliases(self, change_map, node=None):
   240         return
   245         return
   241 
   246 
   242 class NothingNode(object):
   247 class NothingNode(object):
   243     """
   248     """
   244     A node that matches nothing.
   249     A node that matches nothing.
   245     """
   250     """
   246     def as_sql(self, qn=None):
   251     def as_sql(self, qn=None, connection=None):
   247         raise EmptyResultSet
   252         raise EmptyResultSet
   248 
   253 
   249     def relabel_aliases(self, change_map, node=None):
   254     def relabel_aliases(self, change_map, node=None):
   250         return
   255         return
   251 
   256 
       
   257 class ExtraWhere(object):
       
   258     def __init__(self, sqls, params):
       
   259         self.sqls = sqls
       
   260         self.params = params
       
   261 
       
   262     def as_sql(self, qn=None, connection=None):
       
   263         return " AND ".join(self.sqls), tuple(self.params or ())
       
   264 
   252 class Constraint(object):
   265 class Constraint(object):
   253     """
   266     """
   254     An object that can be passed to WhereNode.add() and knows how to
   267     An object that can be passed to WhereNode.add() and knows how to
   255     pre-process itself prior to including in the WhereNode.
   268     pre-process itself prior to including in the WhereNode.
   256     """
   269     """
   257     def __init__(self, alias, col, field):
   270     def __init__(self, alias, col, field):
   258         self.alias, self.col, self.field = alias, col, field
   271         self.alias, self.col, self.field = alias, col, field
   259 
   272 
   260     def process(self, lookup_type, value):
   273     def __getstate__(self):
       
   274         """Save the state of the Constraint for pickling.
       
   275 
       
   276         Fields aren't necessarily pickleable, because they can have
       
   277         callable default values. So, instead of pickling the field
       
   278         store a reference so we can restore it manually
       
   279         """
       
   280         obj_dict = self.__dict__.copy()
       
   281         if self.field:
       
   282             obj_dict['model'] = self.field.model
       
   283             obj_dict['field_name'] = self.field.name
       
   284         del obj_dict['field']
       
   285         return obj_dict
       
   286 
       
   287     def __setstate__(self, data):
       
   288         """Restore the constraint """
       
   289         model = data.pop('model', None)
       
   290         field_name = data.pop('field_name', None)
       
   291         self.__dict__.update(data)
       
   292         if model is not None:
       
   293             self.field = model._meta.get_field(field_name)
       
   294         else:
       
   295             self.field = None
       
   296 
       
   297     def prepare(self, lookup_type, value):
       
   298         if self.field:
       
   299             return self.field.get_prep_lookup(lookup_type, value)
       
   300         return value
       
   301 
       
   302     def process(self, lookup_type, value, connection):
   261         """
   303         """
   262         Returns a tuple of data suitable for inclusion in a WhereNode
   304         Returns a tuple of data suitable for inclusion in a WhereNode
   263         instance.
   305         instance.
   264         """
   306         """
   265         # Because of circular imports, we need to import this here.
   307         # Because of circular imports, we need to import this here.
   266         from django.db.models.base import ObjectDoesNotExist
   308         from django.db.models.base import ObjectDoesNotExist
   267         try:
   309         try:
   268             if self.field:
   310             if self.field:
   269                 params = self.field.get_db_prep_lookup(lookup_type, value)
   311                 params = self.field.get_db_prep_lookup(lookup_type, value,
   270                 db_type = self.field.db_type()
   312                     connection=connection, prepared=True)
       
   313                 db_type = self.field.db_type(connection=connection)
   271             else:
   314             else:
   272                 # This branch is used at times when we add a comparison to NULL
   315                 # This branch is used at times when we add a comparison to NULL
   273                 # (we don't really want to waste time looking up the associated
   316                 # (we don't really want to waste time looking up the associated
   274                 # field object at the calling location).
   317                 # field object at the calling location).
   275                 params = Field().get_db_prep_lookup(lookup_type, value)
   318                 params = Field().get_db_prep_lookup(lookup_type, value,
       
   319                     connection=connection, prepared=True)
   276                 db_type = None
   320                 db_type = None
   277         except ObjectDoesNotExist:
   321         except ObjectDoesNotExist:
   278             raise EmptyShortCircuit
   322             raise EmptyShortCircuit
   279 
   323 
   280         return (self.alias, self.col, db_type), params
   324         return (self.alias, self.col, db_type), params
   281 
   325 
       
   326     def relabel_aliases(self, change_map):
       
   327         if self.alias in change_map:
       
   328             self.alias = change_map[self.alias]