diff -r b758351d191f -r cc9b7e14412b web/lib/django/db/models/sql/subqueries.py --- a/web/lib/django/db/models/sql/subqueries.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/db/models/sql/subqueries.py Tue May 25 02:43:45 2010 +0200 @@ -3,6 +3,7 @@ """ from django.core.exceptions import FieldError +from django.db import connections from django.db.models.sql.constants import * from django.db.models.sql.datastructures import Date from django.db.models.sql.expressions import SQLEvaluator @@ -17,65 +18,17 @@ Delete queries are done through this class, since they are more constrained than general queries. """ - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - assert len(self.tables) == 1, \ - "Can only delete from one table at a time." - result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] - where, params = self.where.as_sql() - result.append('WHERE %s' % where) - return ' '.join(result), tuple(params) - def do_query(self, table, where): + compiler = 'SQLDeleteCompiler' + + def do_query(self, table, where, using): self.tables = [table] self.where = where - self.execute_sql(None) + self.get_compiler(using).execute_sql(None) - def delete_batch_related(self, pk_list): - """ - Set up and execute delete queries for all the objects related to the - primary key values in pk_list. To delete the objects themselves, use - the delete_batch() method. - - More than one physical query may be executed if there are a - lot of values in pk_list. + def delete_batch(self, pk_list, using): """ - from django.contrib.contenttypes import generic - cls = self.model - for related in cls._meta.get_all_related_many_to_many_objects(): - if not isinstance(related.field, generic.GenericRelation): - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() - where.add((Constraint(None, - related.field.m2m_reverse_name(), related.field), - 'in', - pk_list[offset : offset+GET_ITERATOR_CHUNK_SIZE]), - AND) - self.do_query(related.field.m2m_db_table(), where) - - for f in cls._meta.many_to_many: - w1 = self.where_class() - if isinstance(f, generic.GenericRelation): - from django.contrib.contenttypes.models import ContentType - field = f.rel.to._meta.get_field(f.content_type_field_name) - w1.add((Constraint(None, field.column, field), 'exact', - ContentType.objects.get_for_model(cls).id), AND) - for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE): - where = self.where_class() - where.add((Constraint(None, f.m2m_column_name(), f), 'in', - pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), - AND) - if w1: - where.add(w1, AND) - self.do_query(f.m2m_db_table(), where) - - def delete_batch(self, pk_list): - """ - Set up and execute delete queries for all the objects in pk_list. This - should be called after delete_batch_related(), if necessary. + Set up and execute delete queries for all the objects in pk_list. More than one physical query may be executed if there are a lot of values in pk_list. @@ -85,12 +38,15 @@ field = self.model._meta.pk where.add((Constraint(None, field.column, field), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.do_query(self.model._meta.db_table, where) + self.do_query(self.model._meta.db_table, where, using=using) class UpdateQuery(Query): """ Represents an "update" SQL query. """ + + compiler = 'SQLUpdateCompiler' + def __init__(self, *args, **kwargs): super(UpdateQuery, self).__init__(*args, **kwargs) self._setup_query() @@ -110,98 +66,8 @@ return super(UpdateQuery, self).clone(klass, related_updates=self.related_updates.copy(), **kwargs) - def execute_sql(self, result_type=None): - """ - Execute the specified update. Returns the number of rows affected by - the primary update query. The "primary update query" is the first - non-empty query that is executed. Row counts for any subsequent, - related queries are not available. - """ - cursor = super(UpdateQuery, self).execute_sql(result_type) - rows = cursor and cursor.rowcount or 0 - is_empty = cursor is None - del cursor - for query in self.get_related_updates(): - aux_rows = query.execute_sql(result_type) - if is_empty: - rows = aux_rows - is_empty = False - return rows - def as_sql(self): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - self.pre_sql_setup() - if not self.values: - return '', () - table = self.tables[0] - qn = self.quote_name_unless_alias - result = ['UPDATE %s' % qn(table)] - result.append('SET') - values, update_params = [], [] - for name, val, placeholder in self.values: - if hasattr(val, 'as_sql'): - sql, params = val.as_sql(qn) - values.append('%s = %s' % (qn(name), sql)) - update_params.extend(params) - elif val is not None: - values.append('%s = %s' % (qn(name), placeholder)) - update_params.append(val) - else: - values.append('%s = NULL' % qn(name)) - result.append(', '.join(values)) - where, params = self.where.as_sql() - if where: - result.append('WHERE %s' % where) - return ' '.join(result), tuple(update_params + params) - - def pre_sql_setup(self): - """ - If the update depends on results from other tables, we need to do some - munging of the "where" conditions to match the format required for - (portable) SQL updates. That is done here. - - Further, if we are going to be running multiple updates, we pull out - the id values to update at this point so that they don't change as a - result of the progressive updates. - """ - self.select_related = False - self.clear_ordering(True) - super(UpdateQuery, self).pre_sql_setup() - count = self.count_active_tables() - if not self.related_updates and count == 1: - return - - # We need to use a sub-select in the where clause to filter on things - # from other tables. - query = self.clone(klass=Query) - query.bump_prefix() - query.extra = {} - query.select = [] - query.add_fields([query.model._meta.pk.name]) - must_pre_select = count > 1 and not self.connection.features.update_can_self_select - - # Now we adjust the current query: reset the where clause and get rid - # of all the tables we don't need (since they're in the sub-select). - self.where = self.where_class() - if self.related_updates or must_pre_select: - # Either we're using the idents in multiple update queries (so - # don't want them to change), or the db backend doesn't support - # selecting from the updating table (e.g. MySQL). - idents = [] - for rows in query.execute_sql(MULTI): - idents.extend([r[0] for r in rows]) - self.add_filter(('pk__in', idents)) - self.related_ids = idents - else: - # The fast path. Filters and updates in one query. - self.add_filter(('pk__in', query)) - for alias in self.tables[1:]: - self.alias_refcount[alias] = 0 - - def clear_related(self, related_field, pk_list): + def clear_related(self, related_field, pk_list, using): """ Set up and execute an update query that clears related entries for the keys in pk_list. @@ -214,8 +80,8 @@ self.where.add((Constraint(None, f.column, f), 'in', pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND) - self.values = [(related_field.column, None, '%s')] - self.execute_sql(None) + self.values = [(related_field, None, None)] + self.get_compiler(using).execute_sql(None) def add_update_values(self, values): """ @@ -228,6 +94,9 @@ field, model, direct, m2m = self.model._meta.get_field_by_name(name) if not direct or m2m: raise FieldError('Cannot update model field %r (only non-relations and foreign keys permitted).' % field) + if model: + self.add_related_update(model, field, val) + continue values_seq.append((field, model, val)) return self.add_update_fields(values_seq) @@ -237,36 +106,18 @@ Used by add_update_values() as well as the "fast" update path when saving models. """ - from django.db.models.base import Model - for field, model, val in values_seq: - if hasattr(val, 'prepare_database_save'): - val = val.prepare_database_save(field) - else: - val = field.get_db_prep_save(val) + self.values.extend(values_seq) - # Getting the placeholder for the field. - if hasattr(field, 'get_placeholder'): - placeholder = field.get_placeholder(val) - else: - placeholder = '%s' - - if hasattr(val, 'evaluate'): - val = SQLEvaluator(val, self, allow_joins=False) - if model: - self.add_related_update(model, field.column, val, placeholder) - else: - self.values.append((field.column, val, placeholder)) - - def add_related_update(self, model, column, value, placeholder): + def add_related_update(self, model, field, value): """ Adds (name, value) to an update query for an ancestor model. Updates are coalesced so that we only run one update query per ancestor. """ try: - self.related_updates[model].append((column, value, placeholder)) + self.related_updates[model].append((field, None, value)) except KeyError: - self.related_updates[model] = [(column, value, placeholder)] + self.related_updates[model] = [(field, None, value)] def get_related_updates(self): """ @@ -278,53 +129,31 @@ return [] result = [] for model, values in self.related_updates.iteritems(): - query = UpdateQuery(model, self.connection) + query = UpdateQuery(model) query.values = values - if self.related_ids: + if self.related_ids is not None: query.add_filter(('pk__in', self.related_ids)) result.append(query) return result class InsertQuery(Query): + compiler = 'SQLInsertCompiler' + def __init__(self, *args, **kwargs): super(InsertQuery, self).__init__(*args, **kwargs) self.columns = [] self.values = [] self.params = () - self.return_id = False def clone(self, klass=None, **kwargs): - extras = {'columns': self.columns[:], 'values': self.values[:], - 'params': self.params, 'return_id': self.return_id} + extras = { + 'columns': self.columns[:], + 'values': self.values[:], + 'params': self.params + } extras.update(kwargs) return super(InsertQuery, self).clone(klass, **extras) - def as_sql(self): - # We don't need quote_name_unless_alias() here, since these are all - # going to be column names (so we can avoid the extra overhead). - qn = self.connection.ops.quote_name - opts = self.model._meta - result = ['INSERT INTO %s' % qn(opts.db_table)] - result.append('(%s)' % ', '.join([qn(c) for c in self.columns])) - result.append('VALUES (%s)' % ', '.join(self.values)) - params = self.params - if self.return_id and self.connection.features.can_return_id_from_insert: - col = "%s.%s" % (qn(opts.db_table), qn(opts.pk.column)) - r_fmt, r_params = self.connection.ops.return_insert_id() - result.append(r_fmt % col) - params = params + r_params - return ' '.join(result), params - - def execute_sql(self, return_id=False): - self.return_id = return_id - cursor = super(InsertQuery, self).execute_sql(None) - if not (return_id and cursor): - return - if self.connection.features.can_return_id_from_insert: - return self.connection.ops.fetch_returned_insert_id(cursor) - return self.connection.ops.last_insert_id(cursor, - self.model._meta.db_table, self.model._meta.pk.column) - def insert_values(self, insert_values, raw_values=False): """ Set up the insert query from the 'insert_values' dictionary. The @@ -337,17 +166,11 @@ """ placeholders, values = [], [] for field, val in insert_values: - if hasattr(field, 'get_placeholder'): - # Some fields (e.g. geo fields) need special munging before - # they can be inserted. - placeholders.append(field.get_placeholder(val)) - else: - placeholders.append('%s') - + placeholders.append((field, val)) self.columns.append(field.column) values.append(val) if raw_values: - self.values.extend(values) + self.values.extend([(None, v) for v in values]) else: self.params += tuple(values) self.values.extend(placeholders) @@ -358,44 +181,8 @@ date field. This requires some special handling when converting the results back to Python objects, so we put it in a separate class. """ - def __getstate__(self): - """ - Special DateQuery-specific pickle handling. - """ - for elt in self.select: - if isinstance(elt, Date): - # Eliminate a method reference that can't be pickled. The - # __setstate__ method restores this. - elt.date_sql_func = None - return super(DateQuery, self).__getstate__() - def __setstate__(self, obj_dict): - super(DateQuery, self).__setstate__(obj_dict) - for elt in self.select: - if isinstance(elt, Date): - self.date_sql_func = self.connection.ops.date_trunc_sql - - def results_iter(self): - """ - Returns an iterator over the results from executing this query. - """ - resolve_columns = hasattr(self, 'resolve_columns') - if resolve_columns: - from django.db.models.fields import DateTimeField - fields = [DateTimeField()] - else: - from django.db.backends.util import typecast_timestamp - needs_string_cast = self.connection.features.needs_datetime_string_cast - - offset = len(self.extra_select) - for rows in self.execute_sql(MULTI): - for row in rows: - date = row[offset] - if resolve_columns: - date = self.resolve_columns(row, fields)[offset] - elif needs_string_cast: - date = typecast_timestamp(str(date)) - yield date + compiler = 'SQLDateCompiler' def add_date_select(self, field, lookup_type, order='ASC'): """ @@ -404,12 +191,11 @@ result = self.setup_joins([field.name], self.get_meta(), self.get_initial_alias(), False) alias = result[3][-1] - select = Date((alias, field.column), lookup_type, - self.connection.ops.date_trunc_sql) + select = Date((alias, field.column), lookup_type) self.select = [select] self.select_fields = [None] self.select_related = False # See #7097. - self.extra = {} + self.set_extra_mask([]) self.distinct = True self.order_by = order == 'ASC' and [1] or [-1] @@ -418,20 +204,8 @@ An AggregateQuery takes another query as a parameter to the FROM clause and only selects the elements in the provided list. """ - def add_subquery(self, query): - self.subquery, self.sub_params = query.as_sql(with_col_aliases=True) + + compiler = 'SQLAggregateCompiler' - def as_sql(self, quote_func=None): - """ - Creates the SQL for this query. Returns the SQL string and list of - parameters. - """ - sql = ('SELECT %s FROM (%s) subquery' % ( - ', '.join([ - aggregate.as_sql() - for aggregate in self.aggregate_select.values() - ]), - self.subquery) - ) - params = self.sub_params - return (sql, params) + def add_subquery(self, query, using): + self.subquery, self.sub_params = query.get_compiler(using).as_sql(with_col_aliases=True)