--- a/web/lib/django/db/models/query.py Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/db/models/query.py Tue May 25 02:43:45 2010 +0200
@@ -2,19 +2,15 @@
The main QuerySet implementation. This provides the public API for the ORM.
"""
-try:
- set
-except NameError:
- from sets import Set as set # Python 2.3 fallback
+from copy import deepcopy
+from itertools import izip
-from copy import deepcopy
-
-from django.db import connection, transaction, IntegrityError
+from django.db import connections, router, transaction, IntegrityError
from django.db.models.aggregates import Aggregate
from django.db.models.fields import DateField
-from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory
+from django.db.models.query_utils import Q, select_related_descend, CollectedObjects, CyclicDependency, deferred_class_factory, InvalidQuery
from django.db.models import signals, sql
-
+from django.utils.copycompat import deepcopy
# Used to control how many objects are worked with at once in some cases (e.g.
# when deleting objects).
@@ -31,12 +27,15 @@
"""
Represents a lazy database lookup for a set of objects.
"""
- def __init__(self, model=None, query=None):
+ def __init__(self, model=None, query=None, using=None):
self.model = model
- self.query = query or sql.Query(self.model, connection)
+ # EmptyQuerySet instantiates QuerySet with model as None
+ self._db = using
+ self.query = query or sql.Query(self.model)
self._result_cache = None
self._iter = None
self._sticky_filter = False
+ self._for_write = False
########################
# PYTHON MAGIC METHODS #
@@ -46,11 +45,12 @@
"""
Deep copy of a QuerySet doesn't populate the cache
"""
- obj_dict = deepcopy(self.__dict__, memo)
- obj_dict['_iter'] = None
-
obj = self.__class__()
- obj.__dict__.update(obj_dict)
+ for k,v in self.__dict__.items():
+ if k in ('_iter','_result_cache'):
+ obj.__dict__[k] = None
+ else:
+ obj.__dict__[k] = deepcopy(v, memo)
return obj
def __getstate__(self):
@@ -114,6 +114,36 @@
return False
return True
+ def __contains__(self, val):
+ # The 'in' operator works without this method, due to __iter__. This
+ # implementation exists only to shortcut the creation of Model
+ # instances, by bailing out early if we find a matching element.
+ pos = 0
+ if self._result_cache is not None:
+ if val in self._result_cache:
+ return True
+ elif self._iter is None:
+ # iterator is exhausted, so we have our answer
+ return False
+ # remember not to check these again:
+ pos = len(self._result_cache)
+ else:
+ # We need to start filling the result cache out. The following
+ # ensures that self._iter is not None and self._result_cache is not
+ # None
+ it = iter(self)
+
+ # Carry on, one result at a time.
+ while True:
+ if len(self._result_cache) <= pos:
+ self._fill_cache(num=1)
+ if self._iter is None:
+ # we ran out of items
+ return False
+ if self._result_cache[pos] == val:
+ return True
+ pos += 1
+
def __getitem__(self, k):
"""
Retrieves an item or slice from the set of results.
@@ -158,7 +188,7 @@
qs.query.set_limits(k, k + 1)
return list(qs)[0]
except self.model.DoesNotExist, e:
- raise IndexError, e.args
+ raise IndexError(e.args)
def __and__(self, other):
self._merge_sanity_check(other)
@@ -235,10 +265,11 @@
init_list.append(field.attname)
model_cls = deferred_class_factory(self.model, skip)
- for row in self.query.results_iter():
+ compiler = self.query.get_compiler(using=self.db)
+ for row in compiler.results_iter():
if fill_cache:
obj, _ = get_cached_row(self.model, row,
- index_start, max_depth,
+ index_start, using=self.db, max_depth=max_depth,
requested=requested, offset=len(aggregate_select),
only_load=only_load)
else:
@@ -250,6 +281,9 @@
# Omit aggregates in object creation.
obj = self.model(*row[index_start:aggregate_start])
+ # Store the source database of the object
+ obj._state.db = self.db
+
for i, k in enumerate(extra_select):
setattr(obj, k, row[i])
@@ -264,7 +298,7 @@
Returns a dictionary containing the calculations (aggregation)
over the current queryset
- If args is present the expression is passed as a kwarg ussing
+ If args is present the expression is passed as a kwarg using
the Aggregate object's default alias.
"""
for arg in args:
@@ -276,7 +310,7 @@
query.add_aggregate(aggregate_expr, self.model, alias,
is_summary=True)
- return query.get_aggregation()
+ return query.get_aggregation(using=self.db)
def count(self):
"""
@@ -289,7 +323,7 @@
if self._result_cache is not None and not self._iter:
return len(self._result_cache)
- return self.query.get_count()
+ return self.query.get_count(using=self.db)
def get(self, *args, **kwargs):
"""
@@ -297,6 +331,8 @@
keyword arguments.
"""
clone = self.filter(*args, **kwargs)
+ if self.query.can_filter():
+ clone = clone.order_by()
num = len(clone)
if num == 1:
return clone._result_cache[0]
@@ -312,7 +348,8 @@
and returning the created object.
"""
obj = self.model(**kwargs)
- obj.save(force_insert=True)
+ self._for_write = True
+ obj.save(force_insert=True, using=self.db)
return obj
def get_or_create(self, **kwargs):
@@ -325,18 +362,19 @@
'get_or_create() must be passed at least one keyword argument'
defaults = kwargs.pop('defaults', {})
try:
+ self._for_write = True
return self.get(**kwargs), False
except self.model.DoesNotExist:
try:
params = dict([(k, v) for k, v in kwargs.items() if '__' not in k])
params.update(defaults)
obj = self.model(**params)
- sid = transaction.savepoint()
- obj.save(force_insert=True)
- transaction.savepoint_commit(sid)
+ sid = transaction.savepoint(using=self.db)
+ obj.save(force_insert=True, using=self.db)
+ transaction.savepoint_commit(sid, using=self.db)
return obj, True
except IntegrityError, e:
- transaction.savepoint_rollback(sid)
+ transaction.savepoint_rollback(sid, using=self.db)
try:
return self.get(**kwargs), False
except self.model.DoesNotExist:
@@ -363,7 +401,7 @@
"""
assert self.query.can_filter(), \
"Cannot use 'limit' or 'offset' with in_bulk"
- assert isinstance(id_list, (tuple, list)), \
+ assert isinstance(id_list, (tuple, list, set, frozenset)), \
"in_bulk() must be provided with a list of IDs."
if not id_list:
return {}
@@ -380,6 +418,11 @@
del_query = self._clone()
+ # The delete is actually 2 queries - one to find related objects,
+ # and one to delete. Make sure that the discovery of related
+ # objects is performed on the same database as the deletion.
+ del_query._for_write = True
+
# Disable non-supported fields.
del_query.query.select_related = False
del_query.query.clear_ordering()
@@ -387,16 +430,19 @@
# Delete objects in chunks to prevent the list of related objects from
# becoming too long.
seen_objs = None
+ del_itr = iter(del_query)
while 1:
- # Collect all the objects to be deleted in this chunk, and all the
+ # Collect a chunk of objects to be deleted, and then all the
# objects that are related to the objects that are to be deleted.
+ # The chunking *isn't* done by slicing the del_query because we
+ # need to maintain the query cache on del_query (see #12328)
seen_objs = CollectedObjects(seen_objs)
- for object in del_query[:CHUNK_SIZE]:
- object._collect_sub_objects(seen_objs)
+ for i, obj in izip(xrange(CHUNK_SIZE), del_itr):
+ obj._collect_sub_objects(seen_objs)
if not seen_objs:
break
- delete_objects(seen_objs)
+ delete_objects(seen_objs, del_query.db)
# Clear the result cache, in case this QuerySet gets reused.
self._result_cache = None
@@ -409,22 +455,23 @@
"""
assert self.query.can_filter(), \
"Cannot update a query once a slice has been taken."
+ self._for_write = True
query = self.query.clone(sql.UpdateQuery)
query.add_update_values(kwargs)
- if not transaction.is_managed():
- transaction.enter_transaction_management()
+ if not transaction.is_managed(using=self.db):
+ transaction.enter_transaction_management(using=self.db)
forced_managed = True
else:
forced_managed = False
try:
- rows = query.execute_sql(None)
+ rows = query.get_compiler(self.db).execute_sql(None)
if forced_managed:
- transaction.commit()
+ transaction.commit(using=self.db)
else:
- transaction.commit_unless_managed()
+ transaction.commit_unless_managed(using=self.db)
finally:
if forced_managed:
- transaction.leave_transaction_management()
+ transaction.leave_transaction_management(using=self.db)
self._result_cache = None
return rows
update.alters_data = True
@@ -441,9 +488,14 @@
query = self.query.clone(sql.UpdateQuery)
query.add_update_fields(values)
self._result_cache = None
- return query.execute_sql(None)
+ return query.get_compiler(self.db).execute_sql(None)
_update.alters_data = True
+ def exists(self):
+ if self._result_cache is None:
+ return self.query.has_results(using=self.db)
+ return bool(self._result_cache)
+
##################################################
# PUBLIC METHODS THAT RETURN A QUERYSET SUBCLASS #
##################################################
@@ -648,6 +700,14 @@
clone.query.add_immediate_loading(fields)
return clone
+ def using(self, alias):
+ """
+ Selects which database this QuerySet should excecute it's query against.
+ """
+ clone = self._clone()
+ clone._db = alias
+ return clone
+
###################################
# PUBLIC INTROSPECTION ATTRIBUTES #
###################################
@@ -665,6 +725,13 @@
return False
ordered = property(ordered)
+ @property
+ def db(self):
+ "Return the database that will be used if this query is executed now"
+ if self._for_write:
+ return self._db or router.db_for_write(self.model)
+ return self._db or router.db_for_read(self.model)
+
###################
# PRIVATE METHODS #
###################
@@ -675,7 +742,8 @@
query = self.query.clone()
if self._sticky_filter:
query.filter_is_sticky = True
- c = klass(model=self.model, query=query)
+ c = klass(model=self.model, query=query, using=self._db)
+ c._for_write = self._for_write
c.__dict__.update(kwargs)
if setup and hasattr(c, '_setup_query'):
c._setup_query()
@@ -725,12 +793,17 @@
self.query.add_fields(field_names, False)
self.query.set_group_by()
- def _as_sql(self):
+ def _prepare(self):
+ return self
+
+ def _as_sql(self, connection):
"""
Returns the internal query's SQL and parameters (as a tuple).
"""
obj = self.values("pk")
- return obj.query.as_nested_sql()
+ if obj._db is None or connection == connections[obj._db]:
+ return obj.query.get_compiler(connection=connection).as_nested_sql()
+ raise ValueError("Can't do subqueries with queries on different DBs.")
# When used as part of a nested query, a queryset will never be an "always
# empty" result.
@@ -753,7 +826,7 @@
names = extra_names + field_names + aggregate_names
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield dict(zip(names, row))
def _setup_query(self):
@@ -836,7 +909,7 @@
super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
- def _as_sql(self):
+ def _as_sql(self, connection):
"""
For ValueQuerySet (and subclasses like ValuesListQuerySet), they can
only be used as nested queries if they're already set up to select only
@@ -848,15 +921,30 @@
(not self._fields and len(self.model._meta.fields) > 1)):
raise TypeError('Cannot use a multi-field %s as a filter value.'
% self.__class__.__name__)
- return self._clone().query.as_nested_sql()
+
+ obj = self._clone()
+ if obj._db is None or connection == connections[obj._db]:
+ return obj.query.get_compiler(connection=connection).as_nested_sql()
+ raise ValueError("Can't do subqueries with queries on different DBs.")
+
+ def _prepare(self):
+ """
+ Validates that we aren't trying to do a query like
+ value__in=qs.values('value1', 'value2'), which isn't valid.
+ """
+ if ((self._fields and len(self._fields) > 1) or
+ (not self._fields and len(self.model._meta.fields) > 1)):
+ raise TypeError('Cannot use a multi-field %s as a filter value.'
+ % self.__class__.__name__)
+ return self
class ValuesListQuerySet(ValuesQuerySet):
def iterator(self):
if self.flat and len(self._fields) == 1:
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield row[0]
elif not self.query.extra_select and not self.query.aggregate_select:
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
yield tuple(row)
else:
# When extra(select=...) or an annotation is involved, the extra
@@ -871,11 +959,12 @@
# If a field list has been specified, use it. Otherwise, use the
# full list of fields, including extras and aggregates.
if self._fields:
- fields = self._fields
+ fields = list(self._fields) + filter(lambda f: f not in self._fields,
+ aggregate_names)
else:
fields = names
- for row in self.query.results_iter():
+ for row in self.query.get_compiler(self.db).results_iter():
data = dict(zip(names, row))
yield tuple([data[f] for f in fields])
@@ -887,7 +976,7 @@
class DateQuerySet(QuerySet):
def iterator(self):
- return self.query.results_iter()
+ return self.query.get_compiler(self.db).results_iter()
def _setup_query(self):
"""
@@ -916,8 +1005,8 @@
class EmptyQuerySet(QuerySet):
- def __init__(self, model=None, query=None):
- super(EmptyQuerySet, self).__init__(model, query)
+ def __init__(self, model=None, query=None, using=None):
+ super(EmptyQuerySet, self).__init__(model, query, using)
self._result_cache = []
def __and__(self, other):
@@ -942,35 +1031,158 @@
# (it raises StopIteration immediately).
yield iter([]).next()
+ def all(self):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def filter(self, *args, **kwargs):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def exclude(self, *args, **kwargs):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def complex_filter(self, filter_obj):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def select_related(self, *fields, **kwargs):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def annotate(self, *args, **kwargs):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def order_by(self, *field_names):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def distinct(self, true_or_false=True):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def extra(self, select=None, where=None, params=None, tables=None,
+ order_by=None, select_params=None):
+ """
+ Always returns EmptyQuerySet.
+ """
+ assert self.query.can_filter(), \
+ "Cannot change a query once a slice has been taken"
+ return self
+
+ def reverse(self):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def defer(self, *fields):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def only(self, *fields):
+ """
+ Always returns EmptyQuerySet.
+ """
+ return self
+
+ def update(self, **kwargs):
+ """
+ Don't update anything.
+ """
+ return 0
+
# EmptyQuerySet is always an empty result in where-clauses (and similar
# situations).
value_annotation = False
-def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
- requested=None, offset=0, only_load=None):
+def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
+ requested=None, offset=0, only_load=None, local_only=False):
"""
Helper function that recursively returns an object with the specified
related attributes already populated.
+
+ This method may be called recursively to populate deep select_related()
+ clauses.
+
+ Arguments:
+ * klass - the class to retrieve (and instantiate)
+ * row - the row of data returned by the database cursor
+ * index_start - the index of the row at which data for this
+ object is known to start
+ * using - the database alias on which the query is being executed.
+ * max_depth - the maximum depth to which a select_related()
+ relationship should be explored.
+ * cur_depth - the current depth in the select_related() tree.
+ Used in recursive calls to determin if we should dig deeper.
+ * requested - A dictionary describing the select_related() tree
+ that is to be retrieved. keys are field names; values are
+ dictionaries describing the keys on that related object that
+ are themselves to be select_related().
+ * offset - the number of additional fields that are known to
+ exist in `row` for `klass`. This usually means the number of
+ annotated results on `klass`.
+ * only_load - if the query has had only() or defer() applied,
+ this is the list of field names that will be returned. If None,
+ the full field list for `klass` can be assumed.
+ * local_only - Only populate local fields. This is used when building
+ following reverse select-related relations
"""
if max_depth and requested is None and cur_depth > max_depth:
# We've recursed deeply enough; stop now.
return None
restricted = requested is not None
- load_fields = only_load and only_load.get(klass) or None
+ if only_load:
+ load_fields = only_load.get(klass)
+ # When we create the object, we will also be creating populating
+ # all the parent classes, so traverse the parent classes looking
+ # for fields that must be included on load.
+ for parent in klass._meta.get_parent_list():
+ fields = only_load.get(parent)
+ if fields:
+ load_fields.update(fields)
+ else:
+ load_fields = None
if load_fields:
# Handle deferred fields.
skip = set()
init_list = []
- pk_val = row[index_start + klass._meta.pk_index()]
- for field in klass._meta.fields:
+ # Build the list of fields that *haven't* been requested
+ for field, model in klass._meta.get_fields_with_model():
if field.name not in load_fields:
skip.add(field.name)
+ elif local_only and model is not None:
+ continue
else:
init_list.append(field.attname)
+ # Retrieve all the requested fields
field_count = len(init_list)
fields = row[index_start : index_start + field_count]
+ # If all the select_related columns are None, then the related
+ # object must be non-existent - set the relation to None.
+ # Otherwise, construct the related object.
if fields == (None,) * field_count:
obj = None
elif skip:
@@ -978,15 +1190,30 @@
obj = klass(**dict(zip(init_list, fields)))
else:
obj = klass(*fields)
+
else:
- field_count = len(klass._meta.fields)
+ # Load all fields on klass
+ if local_only:
+ field_names = [f.attname for f in klass._meta.local_fields]
+ else:
+ field_names = [f.attname for f in klass._meta.fields]
+ field_count = len(field_names)
fields = row[index_start : index_start + field_count]
+ # If all the select_related columns are None, then the related
+ # object must be non-existent - set the relation to None.
+ # Otherwise, construct the related object.
if fields == (None,) * field_count:
obj = None
else:
- obj = klass(*fields)
+ obj = klass(**dict(zip(field_names, fields)))
+
+ # If an object was retrieved, set the database state.
+ if obj:
+ obj._state.db = using
index_end = index_start + field_count + offset
+ # Iterate over each related object, populating any
+ # select_related() fields
for f in klass._meta.fields:
if not select_related_descend(f, restricted, requested):
continue
@@ -994,21 +1221,74 @@
next = requested[f.name]
else:
next = None
- cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
- cur_depth+1, next)
+ # Recursively retrieve the data for the related object
+ cached_row = get_cached_row(f.rel.to, row, index_end, using,
+ max_depth, cur_depth+1, next, only_load=only_load)
+ # If the recursive descent found an object, populate the
+ # descriptor caches relevant to the object
if cached_row:
rel_obj, index_end = cached_row
if obj is not None:
+ # If the base object exists, populate the
+ # descriptor cache
setattr(obj, f.get_cache_name(), rel_obj)
+ if f.unique and rel_obj is not None:
+ # If the field is unique, populate the
+ # reverse descriptor cache on the related object
+ setattr(rel_obj, f.related.get_cache_name(), obj)
+
+ # Now do the same, but for reverse related objects.
+ # Only handle the restricted case - i.e., don't do a depth
+ # descent into reverse relations unless explicitly requested
+ if restricted:
+ related_fields = [
+ (o.field, o.model)
+ for o in klass._meta.get_all_related_objects()
+ if o.field.unique
+ ]
+ for f, model in related_fields:
+ if not select_related_descend(f, restricted, requested, reverse=True):
+ continue
+ next = requested[f.related_query_name()]
+ # Recursively retrieve the data for the related object
+ cached_row = get_cached_row(model, row, index_end, using,
+ max_depth, cur_depth+1, next, only_load=only_load, local_only=True)
+ # If the recursive descent found an object, populate the
+ # descriptor caches relevant to the object
+ if cached_row:
+ rel_obj, index_end = cached_row
+ if obj is not None:
+ # If the field is unique, populate the
+ # reverse descriptor cache
+ setattr(obj, f.related.get_cache_name(), rel_obj)
+ if rel_obj is not None:
+ # If the related object exists, populate
+ # the descriptor cache.
+ setattr(rel_obj, f.get_cache_name(), obj)
+ # Now populate all the non-local field values
+ # on the related object
+ for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
+ if rel_model is not None:
+ setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
+ # populate the field cache for any related object
+ # that has already been retrieved
+ if rel_field.rel:
+ try:
+ cached_obj = getattr(obj, rel_field.get_cache_name())
+ setattr(rel_obj, rel_field.get_cache_name(), cached_obj)
+ except AttributeError:
+ # Related object hasn't been cached yet
+ pass
return obj, index_end
-def delete_objects(seen_objs):
+def delete_objects(seen_objs, using):
"""
Iterate through a list of seen classes, and remove any instances that are
referred to.
"""
- if not transaction.is_managed():
- transaction.enter_transaction_management()
+ connection = connections[using]
+ if not transaction.is_managed(using=using):
+ transaction.enter_transaction_management(using=using)
forced_managed = True
else:
forced_managed = False
@@ -1030,22 +1310,20 @@
# Pre-notify all instances to be deleted.
for pk_val, instance in items:
- signals.pre_delete.send(sender=cls, instance=instance)
+ if not cls._meta.auto_created:
+ signals.pre_delete.send(sender=cls, instance=instance)
pk_list = [pk for pk,instance in items]
- del_query = sql.DeleteQuery(cls, connection)
- del_query.delete_batch_related(pk_list)
- update_query = sql.UpdateQuery(cls, connection)
+ update_query = sql.UpdateQuery(cls)
for field, model in cls._meta.get_fields_with_model():
if (field.rel and field.null and field.rel.to in seen_objs and
filter(lambda f: f.column == field.rel.get_related_field().column,
field.rel.to._meta.fields)):
if model:
- sql.UpdateQuery(model, connection).clear_related(field,
- pk_list)
+ sql.UpdateQuery(model).clear_related(field, pk_list, using=using)
else:
- update_query.clear_related(field, pk_list)
+ update_query.clear_related(field, pk_list, using=using)
# Now delete the actual data.
for cls in ordered_classes:
@@ -1053,8 +1331,8 @@
items.reverse()
pk_list = [pk for pk,instance in items]
- del_query = sql.DeleteQuery(cls, connection)
- del_query.delete_batch(pk_list)
+ del_query = sql.DeleteQuery(cls)
+ del_query.delete_batch(pk_list, using=using)
# Last cleanup; set NULLs where there once was a reference to the
# object, NULL the primary key of the found objects, and perform
@@ -1064,24 +1342,138 @@
if field.rel and field.null and field.rel.to in seen_objs:
setattr(instance, field.attname, None)
- signals.post_delete.send(sender=cls, instance=instance)
+ if not cls._meta.auto_created:
+ signals.post_delete.send(sender=cls, instance=instance)
setattr(instance, cls._meta.pk.attname, None)
if forced_managed:
- transaction.commit()
+ transaction.commit(using=using)
else:
- transaction.commit_unless_managed()
+ transaction.commit_unless_managed(using=using)
finally:
if forced_managed:
- transaction.leave_transaction_management()
+ transaction.leave_transaction_management(using=using)
+
+class RawQuerySet(object):
+ """
+ Provides an iterator which converts the results of raw SQL queries into
+ annotated model instances.
+ """
+ def __init__(self, raw_query, model=None, query=None, params=None,
+ translations=None, using=None):
+ self.raw_query = raw_query
+ self.model = model
+ self._db = using
+ self.query = query or sql.RawQuery(sql=raw_query, using=self.db, params=params)
+ self.params = params or ()
+ self.translations = translations or {}
+
+ def __iter__(self):
+ for row in self.query:
+ yield self.transform_results(row)
+
+ def __repr__(self):
+ return "<RawQuerySet: %r>" % (self.raw_query % self.params)
+
+ def __getitem__(self, k):
+ return list(self)[k]
+
+ @property
+ def db(self):
+ "Return the database that will be used if this query is executed now"
+ return self._db or router.db_for_read(self.model)
+
+ def using(self, alias):
+ """
+ Selects which database this Raw QuerySet should excecute it's query against.
+ """
+ return RawQuerySet(self.raw_query, model=self.model,
+ query=self.query.clone(using=alias),
+ params=self.params, translations=self.translations,
+ using=alias)
+
+ @property
+ def columns(self):
+ """
+ A list of model field names in the order they'll appear in the
+ query results.
+ """
+ if not hasattr(self, '_columns'):
+ self._columns = self.query.get_columns()
+
+ # Adjust any column names which don't match field names
+ for (query_name, model_name) in self.translations.items():
+ try:
+ index = self._columns.index(query_name)
+ self._columns[index] = model_name
+ except ValueError:
+ # Ignore translations for non-existant column names
+ pass
+ return self._columns
-def insert_query(model, values, return_id=False, raw_values=False):
+ @property
+ def model_fields(self):
+ """
+ A dict mapping column names to model field names.
+ """
+ if not hasattr(self, '_model_fields'):
+ converter = connections[self.db].introspection.table_name_converter
+ self._model_fields = {}
+ for field in self.model._meta.fields:
+ name, column = field.get_attname_column()
+ self._model_fields[converter(column)] = field
+ return self._model_fields
+
+ def transform_results(self, values):
+ model_init_kwargs = {}
+ annotations = ()
+
+ # Perform database backend type resolution
+ connection = connections[self.db]
+ compiler = connection.ops.compiler('SQLCompiler')(self.query, connection, self.db)
+ if hasattr(compiler, 'resolve_columns'):
+ fields = [self.model_fields.get(c,None) for c in self.columns]
+ values = compiler.resolve_columns(values, fields)
+
+ # Associate fields to values
+ for pos, value in enumerate(values):
+ column = self.columns[pos]
+
+ # Separate properties from annotations
+ if column in self.model_fields.keys():
+ model_init_kwargs[self.model_fields[column].attname] = value
+ else:
+ annotations += (column, value),
+
+ # Construct model instance and apply annotations
+ skip = set()
+ for field in self.model._meta.fields:
+ if field.attname not in model_init_kwargs.keys():
+ skip.add(field.attname)
+
+ if skip:
+ if self.model._meta.pk.attname in skip:
+ raise InvalidQuery('Raw query must include the primary key')
+ model_cls = deferred_class_factory(self.model, skip)
+ else:
+ model_cls = self.model
+
+ instance = model_cls(**model_init_kwargs)
+
+ for field, value in annotations:
+ setattr(instance, field, value)
+
+ instance._state.db = self.query.using
+
+ return instance
+
+def insert_query(model, values, return_id=False, raw_values=False, using=None):
"""
Inserts a new record for the given model. This provides an interface to
the InsertQuery class and is how Model.save() is implemented. It is not
part of the public API.
"""
- query = sql.InsertQuery(model, connection)
+ query = sql.InsertQuery(model)
query.insert_values(values, raw_values)
- return query.execute_sql(return_id)
+ return query.get_compiler(using=using).execute_sql(return_id)