web/lib/django/db/models/sql/subqueries.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/db/models/sql/subqueries.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/db/models/sql/subqueries.py	Tue May 25 02:43:45 2010 +0200
@@ -3,6 +3,7 @@
 """
 
 from django.core.exceptions import FieldError
+from django.db import connections
 from django.db.models.sql.constants import *
 from django.db.models.sql.datastructures import Date
 from django.db.models.sql.expressions import SQLEvaluator
@@ -17,65 +18,17 @@
     Delete queries are done through this class, since they are more constrained
     than general queries.
     """
-    def as_sql(self):
-        """
-        Creates the SQL for this query. Returns the SQL string and list of
-        parameters.
-        """
-        assert len(self.tables) == 1, \
-                "Can only delete from one table at a time."
-        result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])]
-        where, params = self.where.as_sql()
-        result.append('WHERE %s' % where)
-        return ' '.join(result), tuple(params)
 
-    def do_query(self, table, where):
+    compiler = 'SQLDeleteCompiler'
+
+    def do_query(self, table, where, using):
         self.tables = [table]
         self.where = where
-        self.execute_sql(None)
+        self.get_compiler(using).execute_sql(None)
 
-    def delete_batch_related(self, pk_list):
-        """
-        Set up and execute delete queries for all the objects related to the
-        primary key values in pk_list. To delete the objects themselves, use
-        the delete_batch() method.
-
-        More than one physical query may be executed if there are a
-        lot of values in pk_list.
+    def delete_batch(self, pk_list, using):
         """
-        from django.contrib.contenttypes import generic
-        cls = self.model
-        for related in cls._meta.get_all_related_many_to_many_objects():
-            if not isinstance(related.field, generic.GenericRelation):
-                for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-                    where = self.where_class()
-                    where.add((Constraint(None,
-                            related.field.m2m_reverse_name(), related.field),
-                            'in',
-                            pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]),
-                            AND)
-                    self.do_query(related.field.m2m_db_table(), where)
-
-        for f in cls._meta.many_to_many:
-            w1 = self.where_class()
-            if isinstance(f, generic.GenericRelation):
-                from django.contrib.contenttypes.models import ContentType
-                field = f.rel.to._meta.get_field(f.content_type_field_name)
-                w1.add((Constraint(None, field.column, field), 'exact',
-                        ContentType.objects.get_for_model(cls).id), AND)
-            for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
-                where = self.where_class()
-                where.add((Constraint(None, f.m2m_column_name(), f), 'in',
-                        pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
-                        AND)
-                if w1:
-                    where.add(w1, AND)
-                self.do_query(f.m2m_db_table(), where)
-
-    def delete_batch(self, pk_list):
-        """
-        Set up and execute delete queries for all the objects in pk_list. This
-        should be called after delete_batch_related(), if necessary.
+        Set up and execute delete queries for all the objects in pk_list.
 
         More than one physical query may be executed if there are a
         lot of values in pk_list.
@@ -85,12 +38,15 @@
             field = self.model._meta.pk
             where.add((Constraint(None, field.column, field), 'in',
                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
-            self.do_query(self.model._meta.db_table, where)
+            self.do_query(self.model._meta.db_table, where, using=using)
 
 class UpdateQuery(Query):
     """
     Represents an "update" SQL query.
     """
+
+    compiler = 'SQLUpdateCompiler'
+
     def __init__(self, *args, **kwargs):
         super(UpdateQuery, self).__init__(*args, **kwargs)
         self._setup_query()
@@ -110,98 +66,8 @@
         return super(UpdateQuery, self).clone(klass,
                 related_updates=self.related_updates.copy(), **kwargs)
 
-    def execute_sql(self, result_type=None):
-        """
-        Execute the specified update. Returns the number of rows affected by
-        the primary update query. The "primary update query" is the first
-        non-empty query that is executed. Row counts for any subsequent,
-        related queries are not available.
-        """
-        cursor = super(UpdateQuery, self).execute_sql(result_type)
-        rows = cursor and cursor.rowcount or 0
-        is_empty = cursor is None
-        del cursor
-        for query in self.get_related_updates():
-            aux_rows = query.execute_sql(result_type)
-            if is_empty:
-                rows = aux_rows
-                is_empty = False
-        return rows
 
-    def as_sql(self):
-        """
-        Creates the SQL for this query. Returns the SQL string and list of
-        parameters.
-        """
-        self.pre_sql_setup()
-        if not self.values:
-            return '', ()
-        table = self.tables[0]
-        qn = self.quote_name_unless_alias
-        result = ['UPDATE %s' % qn(table)]
-        result.append('SET')
-        values, update_params = [], []
-        for name, val, placeholder in self.values:
-            if hasattr(val, 'as_sql'):
-                sql, params = val.as_sql(qn)
-                values.append('%s = %s' % (qn(name), sql))
-                update_params.extend(params)
-            elif val is not None:
-                values.append('%s = %s' % (qn(name), placeholder))
-                update_params.append(val)
-            else:
-                values.append('%s = NULL' % qn(name))
-        result.append(', '.join(values))
-        where, params = self.where.as_sql()
-        if where:
-            result.append('WHERE %s' % where)
-        return ' '.join(result), tuple(update_params + params)
-
-    def pre_sql_setup(self):
-        """
-        If the update depends on results from other tables, we need to do some
-        munging of the "where" conditions to match the format required for
-        (portable) SQL updates. That is done here.
-
-        Further, if we are going to be running multiple updates, we pull out
-        the id values to update at this point so that they don't change as a
-        result of the progressive updates.
-        """
-        self.select_related = False
-        self.clear_ordering(True)
-        super(UpdateQuery, self).pre_sql_setup()
-        count = self.count_active_tables()
-        if not self.related_updates and count == 1:
-            return
-
-        # We need to use a sub-select in the where clause to filter on things
-        # from other tables.
-        query = self.clone(klass=Query)
-        query.bump_prefix()
-        query.extra = {}
-        query.select = []
-        query.add_fields([query.model._meta.pk.name])
-        must_pre_select = count > 1 and not self.connection.features.update_can_self_select
-
-        # Now we adjust the current query: reset the where clause and get rid
-        # of all the tables we don't need (since they're in the sub-select).
-        self.where = self.where_class()
-        if self.related_updates or must_pre_select:
-            # Either we're using the idents in multiple update queries (so
-            # don't want them to change), or the db backend doesn't support
-            # selecting from the updating table (e.g. MySQL).
-            idents = []
-            for rows in query.execute_sql(MULTI):
-                idents.extend([r[0] for r in rows])
-            self.add_filter(('pk__in', idents))
-            self.related_ids = idents
-        else:
-            # The fast path. Filters and updates in one query.
-            self.add_filter(('pk__in', query))
-        for alias in self.tables[1:]:
-            self.alias_refcount[alias] = 0
-
-    def clear_related(self, related_field, pk_list):
+    def clear_related(self, related_field, pk_list, using):
         """
         Set up and execute an update query that clears related entries for the
         keys in pk_list.
@@ -214,8 +80,8 @@
             self.where.add((Constraint(None, f.column, f), 'in',
                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
                     AND)
-            self.values = [(related_field.column, None, '%s')]
-            self.execute_sql(None)
+            self.values = [(related_field, None, None)]
+            self.get_compiler(using).execute_sql(None)
 
     def add_update_values(self, values):
         """
@@ -228,6 +94,9 @@
             field, model, direct, m2m = self.model._meta.get_field_by_name(name)
             if not direct or m2m:
                 raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field)
+            if model:
+                self.add_related_update(model, field, val)
+                continue
             values_seq.append((field, model, val))
         return self.add_update_fields(values_seq)
 
@@ -237,36 +106,18 @@
         Used by add_update_values() as well as the "fast" update path when
         saving models.
         """
-        from django.db.models.base import Model
-        for field, model, val in values_seq:
-            if hasattr(val, 'prepare_database_save'):
-                val = val.prepare_database_save(field)
-            else:
-                val = field.get_db_prep_save(val)
+        self.values.extend(values_seq)
 
-            # Getting the placeholder for the field.
-            if hasattr(field, 'get_placeholder'):
-                placeholder = field.get_placeholder(val)
-            else:
-                placeholder = '%s'
-
-            if hasattr(val, 'evaluate'):
-                val = SQLEvaluator(val, self, allow_joins=False)
-            if model:
-                self.add_related_update(model, field.column, val, placeholder)
-            else:
-                self.values.append((field.column, val, placeholder))
-
-    def add_related_update(self, model, column, value, placeholder):
+    def add_related_update(self, model, field, value):
         """
         Adds (name, value) to an update query for an ancestor model.
 
         Updates are coalesced so that we only run one update query per ancestor.
         """
         try:
-            self.related_updates[model].append((column, value, placeholder))
+            self.related_updates[model].append((field, None, value))
         except KeyError:
-            self.related_updates[model] = [(column, value, placeholder)]
+            self.related_updates[model] = [(field, None, value)]
 
     def get_related_updates(self):
         """
@@ -278,53 +129,31 @@
             return []
         result = []
         for model, values in self.related_updates.iteritems():
-            query = UpdateQuery(model, self.connection)
+            query = UpdateQuery(model)
             query.values = values
-            if self.related_ids:
+            if self.related_ids is not None:
                 query.add_filter(('pk__in', self.related_ids))
             result.append(query)
         return result
 
 class InsertQuery(Query):
+    compiler = 'SQLInsertCompiler'
+
     def __init__(self, *args, **kwargs):
         super(InsertQuery, self).__init__(*args, **kwargs)
         self.columns = []
         self.values = []
         self.params = ()
-        self.return_id = False
 
     def clone(self, klass=None, **kwargs):
-        extras = {'columns': self.columns[:], 'values': self.values[:],
-                  'params': self.params, 'return_id': self.return_id}
+        extras = {
+            'columns': self.columns[:],
+            'values': self.values[:],
+            'params': self.params
+        }
         extras.update(kwargs)
         return super(InsertQuery, self).clone(klass, **extras)
 
-    def as_sql(self):
-        # We don't need quote_name_unless_alias() here, since these are all
-        # going to be column names (so we can avoid the extra overhead).
-        qn = self.connection.ops.quote_name
-        opts = self.model._meta
-        result = ['INSERT INTO %s' % qn(opts.db_table)]
-        result.append('(%s)' % ', '.join([qn(c) for c in self.columns]))
-        result.append('VALUES (%s)' % ', '.join(self.values))
-        params = self.params
-        if self.return_id and self.connection.features.can_return_id_from_insert:
-            col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column))
-            r_fmt, r_params = self.connection.ops.return_insert_id()
-            result.append(r_fmt % col)
-            params = params + r_params
-        return ' '.join(result), params
-
-    def execute_sql(self, return_id=False):
-        self.return_id = return_id
-        cursor = super(InsertQuery, self).execute_sql(None)
-        if not (return_id and cursor):
-            return
-        if self.connection.features.can_return_id_from_insert:
-            return self.connection.ops.fetch_returned_insert_id(cursor)
-        return self.connection.ops.last_insert_id(cursor,
-                self.model._meta.db_table, self.model._meta.pk.column)
-
     def insert_values(self, insert_values, raw_values=False):
         """
         Set up the insert query from the 'insert_values' dictionary. The
@@ -337,17 +166,11 @@
         """
         placeholders, values = [], []
         for field, val in insert_values:
-            if hasattr(field, 'get_placeholder'):
-                # Some fields (e.g. geo fields) need special munging before
-                # they can be inserted.
-                placeholders.append(field.get_placeholder(val))
-            else:
-                placeholders.append('%s')
-
+            placeholders.append((field, val))
             self.columns.append(field.column)
             values.append(val)
         if raw_values:
-            self.values.extend(values)
+            self.values.extend([(None, v) for v in values])
         else:
             self.params += tuple(values)
             self.values.extend(placeholders)
@@ -358,44 +181,8 @@
     date field. This requires some special handling when converting the results
     back to Python objects, so we put it in a separate class.
     """
-    def __getstate__(self):
-        """
-        Special DateQuery-specific pickle handling.
-        """
-        for elt in self.select:
-            if isinstance(elt, Date):
-                # Eliminate a method reference that can't be pickled. The
-                # __setstate__ method restores this.
-                elt.date_sql_func = None
-        return super(DateQuery, self).__getstate__()
 
-    def __setstate__(self, obj_dict):
-        super(DateQuery, self).__setstate__(obj_dict)
-        for elt in self.select:
-            if isinstance(elt, Date):
-                self.date_sql_func = self.connection.ops.date_trunc_sql
-
-    def results_iter(self):
-        """
-        Returns an iterator over the results from executing this query.
-        """
-        resolve_columns = hasattr(self, 'resolve_columns')
-        if resolve_columns:
-            from django.db.models.fields import DateTimeField
-            fields = [DateTimeField()]
-        else:
-            from django.db.backends.util import typecast_timestamp
-            needs_string_cast = self.connection.features.needs_datetime_string_cast
-
-        offset = len(self.extra_select)
-        for rows in self.execute_sql(MULTI):
-            for row in rows:
-                date = row[offset]
-                if resolve_columns:
-                    date = self.resolve_columns(row, fields)[offset]
-                elif needs_string_cast:
-                    date = typecast_timestamp(str(date))
-                yield date
+    compiler = 'SQLDateCompiler'
 
     def add_date_select(self, field, lookup_type, order='ASC'):
         """
@@ -404,12 +191,11 @@
         result = self.setup_joins([field.name], self.get_meta(),
                 self.get_initial_alias(), False)
         alias = result[3][-1]
-        select = Date((alias, field.column), lookup_type,
-                self.connection.ops.date_trunc_sql)
+        select = Date((alias, field.column), lookup_type)
         self.select = [select]
         self.select_fields = [None]
         self.select_related = False # See #7097.
-        self.extra = {}
+        self.set_extra_mask([])
         self.distinct = True
         self.order_by = order == 'ASC' and [1] or [-1]
 
@@ -418,20 +204,8 @@
     An AggregateQuery takes another query as a parameter to the FROM
     clause and only selects the elements in the provided list.
     """
-    def add_subquery(self, query):
-        self.subquery, self.sub_params = query.as_sql(with_col_aliases=True)
+
+    compiler = 'SQLAggregateCompiler'
 
-    def as_sql(self, quote_func=None):
-        """
-        Creates the SQL for this query. Returns the SQL string and list of
-        parameters.
-        """
-        sql = ('SELECT %s FROM (%s) subquery' % (
-            ', '.join([
-                aggregate.as_sql()
-                for aggregate in self.aggregate_select.values()
-            ]),
-            self.subquery)
-        )
-        params = self.sub_params
-        return (sql, params)
+    def add_subquery(self, query, using):
+        self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)