diff -r b758351d191f -r cc9b7e14412b web/lib/django/db/models/fields/related.py --- a/web/lib/django/db/models/fields/related.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/db/models/fields/related.py Tue May 25 02:43:45 2010 +0200 @@ -1,20 +1,18 @@ -from django.db import connection, transaction +from django.conf import settings +from django.db import connection, router, transaction from django.db.backends import util from django.db.models import signals, get_model -from django.db.models.fields import AutoField, Field, IntegerField, PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist +from django.db.models.fields import (AutoField, Field, IntegerField, + PositiveIntegerField, PositiveSmallIntegerField, FieldDoesNotExist) from django.db.models.related import RelatedObject from django.db.models.query import QuerySet from django.db.models.query_utils import QueryWrapper from django.utils.encoding import smart_unicode -from django.utils.translation import ugettext_lazy, string_concat, ungettext, ugettext as _ +from django.utils.translation import ugettext_lazy as _, string_concat, ungettext, ugettext from django.utils.functional import curry from django.core import exceptions from django import forms -try: - set -except NameError: - from sets import Set as set # Python 2.3 fallback RECURSIVE_RELATIONSHIP_CONSTANT = 'self' @@ -58,6 +56,10 @@ # If we can't split, assume a model in current app app_label = cls._meta.app_label model_name = relation + except AttributeError: + # If it doesn't have a split it's actually a model class + app_label = relation._meta.app_label + model_name = relation._meta.object_name # Try to look up the related model, and if it's already loaded resolve the # string right away. If get_model returns None, it means that the related @@ -86,17 +88,20 @@ def contribute_to_class(self, cls, name): sup = super(RelatedField, self) - # Add an accessor to allow easy determination of the related query path for this field - self.related_query_name = curry(self._get_related_query_name, cls._meta) + # Store the opts for related_query_name() + self.opts = cls._meta if hasattr(sup, 'contribute_to_class'): sup.contribute_to_class(cls, name) if not cls._meta.abstract and self.rel.related_name: - self.rel.related_name = self.rel.related_name % {'class': cls.__name__.lower()} + self.rel.related_name = self.rel.related_name % { + 'class': cls.__name__.lower(), + 'app_label': cls._meta.app_label.lower(), + } other = self.rel.to - if isinstance(other, basestring): + if isinstance(other, basestring) or other._meta.pk is None: def resolve_related_class(field, model, cls): field.rel.to = model field.do_related_class(model, cls) @@ -116,31 +121,28 @@ if not cls._meta.abstract: self.contribute_to_related_class(other, self.related) - def get_db_prep_lookup(self, lookup_type, value): - # If we are doing a lookup on a Related Field, we must be - # comparing object instances. The value should be the PK of value, - # not value itself. - def pk_trace(value): - # Value may be a primary key, or an object held in a relation. - # If it is an object, then we need to get the primary key value for - # that object. In certain conditions (especially one-to-one relations), - # the primary key may itself be an object - so we need to keep drilling - # down until we hit a value that can be used for a comparison. - v, field = value, None - try: - while True: - v, field = getattr(v, v._meta.pk.name), v._meta.pk - except AttributeError: - pass + def get_prep_lookup(self, lookup_type, value): + if hasattr(value, 'prepare'): + return value.prepare() + if hasattr(value, '_prepare'): + return value._prepare() + # FIXME: lt and gt are explicitly allowed to make + # get_(next/prev)_by_date work; other lookups are not allowed since that + # gets messy pretty quick. This is a good candidate for some refactoring + # in the future. + if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']: + return self._pk_trace(value, 'get_prep_lookup', lookup_type) + if lookup_type in ('range', 'in'): + return [self._pk_trace(v, 'get_prep_lookup', lookup_type) for v in value] + elif lookup_type == 'isnull': + return [] + raise TypeError("Related Field has invalid lookup: %s" % lookup_type) - if field: - if lookup_type in ('range', 'in'): - v = [v] - v = field.get_db_prep_lookup(lookup_type, v) - if isinstance(v, list): - v = v[0] - return v - + def get_db_prep_lookup(self, lookup_type, value, connection, prepared=False): + if not prepared: + value = self.get_prep_lookup(lookup_type, value) + if hasattr(value, 'get_compiler'): + value = value.get_compiler(connection=connection) if hasattr(value, 'as_sql') or hasattr(value, '_as_sql'): # If the value has a relabel_aliases method, it will need to # be invoked before the final SQL is evaluated @@ -149,27 +151,59 @@ if hasattr(value, 'as_sql'): sql, params = value.as_sql() else: - sql, params = value._as_sql() + sql, params = value._as_sql(connection=connection) return QueryWrapper(('(%s)' % sql), params) - # FIXME: lt and gt are explicitally allowed to make + # FIXME: lt and gt are explicitly allowed to make # get_(next/prev)_by_date work; other lookups are not allowed since that # gets messy pretty quick. This is a good candidate for some refactoring # in the future. if lookup_type in ['exact', 'gt', 'lt', 'gte', 'lte']: - return [pk_trace(value)] + return [self._pk_trace(value, 'get_db_prep_lookup', lookup_type, + connection=connection, prepared=prepared)] if lookup_type in ('range', 'in'): - return [pk_trace(v) for v in value] + return [self._pk_trace(v, 'get_db_prep_lookup', lookup_type, + connection=connection, prepared=prepared) + for v in value] elif lookup_type == 'isnull': return [] - raise TypeError, "Related Field has invalid lookup: %s" % lookup_type + raise TypeError("Related Field has invalid lookup: %s" % lookup_type) - def _get_related_query_name(self, opts): + def _pk_trace(self, value, prep_func, lookup_type, **kwargs): + # Value may be a primary key, or an object held in a relation. + # If it is an object, then we need to get the primary key value for + # that object. In certain conditions (especially one-to-one relations), + # the primary key may itself be an object - so we need to keep drilling + # down until we hit a value that can be used for a comparison. + v = value + try: + while True: + v = getattr(v, v._meta.pk.name) + except AttributeError: + pass + except exceptions.ObjectDoesNotExist: + v = None + + field = self + while field.rel: + if hasattr(field.rel, 'field_name'): + field = field.rel.to._meta.get_field(field.rel.field_name) + else: + field = field.rel.to._meta.pk + + if lookup_type in ('range', 'in'): + v = [v] + v = getattr(field, prep_func)(lookup_type, v, **kwargs) + if isinstance(v, list): + v = v[0] + return v + + def related_query_name(self): # This method defines the name that can be used to identify this # related object in a table-spanning query. It uses the lower-cased # object_name by default, but this can be overridden with the # "related_name" option. - return self.rel.related_name or opts.object_name.lower() + return self.rel.related_name or self.opts.object_name.lower() class SingleRelatedObjectDescriptor(object): # This class provides the functionality that makes the related-object @@ -179,7 +213,7 @@ # SingleRelatedObjectDescriptor instance. def __init__(self, related): self.related = related - self.cache_name = '_%s_cache' % related.get_accessor_name() + self.cache_name = related.get_cache_name() def __get__(self, instance, instance_type=None): if instance is None: @@ -188,13 +222,14 @@ return getattr(instance, self.cache_name) except AttributeError: params = {'%s__pk' % self.related.field.name: instance._get_pk_val()} - rel_obj = self.related.model._base_manager.get(**params) + db = router.db_for_read(self.related.model, instance=instance) + rel_obj = self.related.model._base_manager.using(db).get(**params) setattr(instance, self.cache_name, rel_obj) return rel_obj def __set__(self, instance, value): if instance is None: - raise AttributeError, "%s must be accessed via instance" % self.related.opts.object_name + raise AttributeError("%s must be accessed via instance" % self.related.opts.object_name) # The similarity of the code below to the code in # ReverseSingleRelatedObjectDescriptor is annoying, but there's a bunch @@ -209,6 +244,15 @@ raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (value, instance._meta.object_name, self.related.get_accessor_name(), self.related.opts.object_name)) + elif value is not None: + if instance._state.db is None: + instance._state.db = router.db_for_write(instance.__class__, instance=value) + elif value._state.db is None: + value._state.db = router.db_for_write(value.__class__, instance=instance) + elif value._state.db is not None and instance._state.db is not None: + if not router.allow_relation(value, instance): + raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' % + (value, instance._state.db, value._state.db)) # Set the value of the related field to the value of the related object's related field setattr(value, self.related.field.attname, getattr(instance, self.related.field.rel.get_related_field().attname)) @@ -251,16 +295,17 @@ # If the related manager indicates that it should be used for # related fields, respect that. rel_mgr = self.field.rel.to._default_manager + db = router.db_for_read(self.field.rel.to, instance=instance) if getattr(rel_mgr, 'use_for_related_fields', False): - rel_obj = rel_mgr.get(**params) + rel_obj = rel_mgr.using(db).get(**params) else: - rel_obj = QuerySet(self.field.rel.to).get(**params) + rel_obj = QuerySet(self.field.rel.to).using(db).get(**params) setattr(instance, cache_name, rel_obj) return rel_obj def __set__(self, instance, value): if instance is None: - raise AttributeError, "%s must be accessed via instance" % self._field.name + raise AttributeError("%s must be accessed via instance" % self._field.name) # If null=True, we can assign null here, but otherwise the value needs # to be an instance of the related class. @@ -271,6 +316,15 @@ raise ValueError('Cannot assign "%r": "%s.%s" must be a "%s" instance.' % (value, instance._meta.object_name, self.field.name, self.field.rel.to._meta.object_name)) + elif value is not None: + if instance._state.db is None: + instance._state.db = router.db_for_write(instance.__class__, instance=value) + elif value._state.db is None: + value._state.db = router.db_for_write(value.__class__, instance=instance) + elif value._state.db is not None and instance._state.db is not None: + if not router.allow_relation(value, instance): + raise ValueError('Cannot assign "%r": instance is on database "%s", value is on database "%s"' % + (value, instance._state.db, value._state.db)) # If we're setting the value of a OneToOneField to None, we need to clear # out the cache on any old related object. Otherwise, deleting the @@ -289,7 +343,7 @@ # cache. This cache also might not exist if the related object # hasn't been accessed yet. if related: - cache_name = '_%s_cache' % self.field.related.get_accessor_name() + cache_name = self.field.related.get_cache_name() try: delattr(related, cache_name) except AttributeError: @@ -325,7 +379,7 @@ def __set__(self, instance, value): if instance is None: - raise AttributeError, "Manager must be accessed via instance" + raise AttributeError("Manager must be accessed via instance") manager = self.__get__(instance) # If the foreign key can support nulls, then completely clear the related set. @@ -352,26 +406,29 @@ class RelatedManager(superclass): def get_query_set(self): - return superclass.get_query_set(self).filter(**(self.core_filters)) + db = self._db or router.db_for_read(rel_model, instance=instance) + return superclass.get_query_set(self).using(db).filter(**(self.core_filters)) def add(self, *objs): for obj in objs: if not isinstance(obj, self.model): - raise TypeError, "'%s' instance expected" % self.model._meta.object_name + raise TypeError("'%s' instance expected" % self.model._meta.object_name) setattr(obj, rel_field.name, instance) obj.save() add.alters_data = True def create(self, **kwargs): kwargs.update({rel_field.name: instance}) - return super(RelatedManager, self).create(**kwargs) + db = router.db_for_write(rel_model, instance=instance) + return super(RelatedManager, self).using(db).create(**kwargs) create.alters_data = True def get_or_create(self, **kwargs): # Update kwargs with the related object that this # ForeignRelatedObjectsDescriptor knows about. kwargs.update({rel_field.name: instance}) - return super(RelatedManager, self).get_or_create(**kwargs) + db = router.db_for_write(rel_model, instance=instance) + return super(RelatedManager, self).using(db).get_or_create(**kwargs) get_or_create.alters_data = True # remove() and clear() are only provided if the ForeignKey can have a value of null. @@ -384,7 +441,7 @@ setattr(obj, rel_field.name, None) obj.save() else: - raise rel_field.rel.to.DoesNotExist, "%r is not related to %r." % (obj, instance) + raise rel_field.rel.to.DoesNotExist("%r is not related to %r." % (obj, instance)) remove.alters_data = True def clear(self): @@ -401,68 +458,74 @@ return manager -def create_many_related_manager(superclass, through=False): +def create_many_related_manager(superclass, rel=False): """Creates a manager that subclasses 'superclass' (which is a Manager) and adds behavior for many-to-many related objects.""" + through = rel.through class ManyRelatedManager(superclass): def __init__(self, model=None, core_filters=None, instance=None, symmetrical=None, - join_table=None, source_col_name=None, target_col_name=None): + join_table=None, source_field_name=None, target_field_name=None, + reverse=False): super(ManyRelatedManager, self).__init__() self.core_filters = core_filters self.model = model self.symmetrical = symmetrical self.instance = instance - self.join_table = join_table - self.source_col_name = source_col_name - self.target_col_name = target_col_name + self.source_field_name = source_field_name + self.target_field_name = target_field_name self.through = through - self._pk_val = self.instance._get_pk_val() + self._pk_val = self.instance.pk + self.reverse = reverse if self._pk_val is None: raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) def get_query_set(self): - return superclass.get_query_set(self)._next_is_sticky().filter(**(self.core_filters)) + db = self._db or router.db_for_read(self.instance.__class__, instance=self.instance) + return superclass.get_query_set(self).using(db)._next_is_sticky().filter(**(self.core_filters)) # If the ManyToMany relation has an intermediary model, # the add and remove methods do not exist. - if through is None: + if rel.through._meta.auto_created: def add(self, *objs): - self._add_items(self.source_col_name, self.target_col_name, *objs) + self._add_items(self.source_field_name, self.target_field_name, *objs) # If this is a symmetrical m2m relation to self, add the mirror entry in the m2m table if self.symmetrical: - self._add_items(self.target_col_name, self.source_col_name, *objs) + self._add_items(self.target_field_name, self.source_field_name, *objs) add.alters_data = True def remove(self, *objs): - self._remove_items(self.source_col_name, self.target_col_name, *objs) + self._remove_items(self.source_field_name, self.target_field_name, *objs) # If this is a symmetrical m2m relation to self, remove the mirror entry in the m2m table if self.symmetrical: - self._remove_items(self.target_col_name, self.source_col_name, *objs) + self._remove_items(self.target_field_name, self.source_field_name, *objs) remove.alters_data = True def clear(self): - self._clear_items(self.source_col_name) + self._clear_items(self.source_field_name) # If this is a symmetrical m2m relation to self, clear the mirror entry in the m2m table if self.symmetrical: - self._clear_items(self.target_col_name) + self._clear_items(self.target_field_name) clear.alters_data = True def create(self, **kwargs): # This check needs to be done here, since we can't later remove this # from the method lookup table, as we do with add and remove. - if through is not None: - raise AttributeError, "Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through - new_obj = super(ManyRelatedManager, self).create(**kwargs) + if not rel.through._meta.auto_created: + opts = through._meta + raise AttributeError("Cannot use create() on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)) + db = router.db_for_write(self.instance.__class__, instance=self.instance) + new_obj = super(ManyRelatedManager, self).using(db).create(**kwargs) self.add(new_obj) return new_obj create.alters_data = True def get_or_create(self, **kwargs): + db = router.db_for_write(self.instance.__class__, instance=self.instance) obj, created = \ - super(ManyRelatedManager, self).get_or_create(**kwargs) + super(ManyRelatedManager, self).using(db).get_or_create(**kwargs) # We only need to add() if created because if we got an object back # from get() then the relationship already exists. if created: @@ -470,41 +533,54 @@ return obj, created get_or_create.alters_data = True - def _add_items(self, source_col_name, target_col_name, *objs): + def _add_items(self, source_field_name, target_field_name, *objs): # join_table: name of the m2m link table - # source_col_name: the PK colname in join_table for the source object - # target_col_name: the PK colname in join_table for the target object + # source_field_name: the PK fieldname in join_table for the source object + # target_field_name: the PK fieldname in join_table for the target object # *objs - objects to add. Either object instances, or primary keys of object instances. # If there aren't any objects, there is nothing to do. + from django.db.models import Model if objs: - from django.db.models.base import Model - # Check that all the objects are of the right type new_ids = set() for obj in objs: if isinstance(obj, self.model): - new_ids.add(obj._get_pk_val()) + if not router.allow_relation(obj, self.instance): + raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % + (obj, self.instance._state.db, obj._state.db)) + new_ids.add(obj.pk) elif isinstance(obj, Model): - raise TypeError, "'%s' instance expected" % self.model._meta.object_name + raise TypeError("'%s' instance expected" % self.model._meta.object_name) else: new_ids.add(obj) - # Add the newly created or already existing objects to the join table. - # First find out which items are already added, to avoid adding them twice - cursor = connection.cursor() - cursor.execute("SELECT %s FROM %s WHERE %s = %%s AND %s IN (%s)" % \ - (target_col_name, self.join_table, source_col_name, - target_col_name, ",".join(['%s'] * len(new_ids))), - [self._pk_val] + list(new_ids)) - existing_ids = set([row[0] for row in cursor.fetchall()]) + db = router.db_for_write(self.through.__class__, instance=self.instance) + vals = self.through._default_manager.using(db).values_list(target_field_name, flat=True) + vals = vals.filter(**{ + source_field_name: self._pk_val, + '%s__in' % target_field_name: new_ids, + }) + new_ids = new_ids - set(vals) + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are inserting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action='pre_add', + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=new_ids) # Add the ones that aren't there already - for obj_id in (new_ids - existing_ids): - cursor.execute("INSERT INTO %s (%s, %s) VALUES (%%s, %%s)" % \ - (self.join_table, source_col_name, target_col_name), - [self._pk_val, obj_id]) - transaction.commit_unless_managed() + for obj_id in new_ids: + self.through._default_manager.using(db).create(**{ + '%s_id' % source_field_name: self._pk_val, + '%s_id' % target_field_name: obj_id, + }) + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are inserting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action='post_add', + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=new_ids) - def _remove_items(self, source_col_name, target_col_name, *objs): + def _remove_items(self, source_field_name, target_field_name, *objs): # source_col_name: the PK colname in join_table for the source object # target_col_name: the PK colname in join_table for the target object # *objs - objects to remove @@ -515,24 +591,46 @@ old_ids = set() for obj in objs: if isinstance(obj, self.model): - old_ids.add(obj._get_pk_val()) + old_ids.add(obj.pk) else: old_ids.add(obj) + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are deleting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action="pre_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids) # Remove the specified objects from the join table - cursor = connection.cursor() - cursor.execute("DELETE FROM %s WHERE %s = %%s AND %s IN (%s)" % \ - (self.join_table, source_col_name, - target_col_name, ",".join(['%s'] * len(old_ids))), - [self._pk_val] + list(old_ids)) - transaction.commit_unless_managed() + db = router.db_for_write(self.through.__class__, instance=self.instance) + self.through._default_manager.using(db).filter(**{ + source_field_name: self._pk_val, + '%s__in' % target_field_name: old_ids + }).delete() + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are deleting the + # duplicate data row for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action="post_remove", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=old_ids) - def _clear_items(self, source_col_name): + def _clear_items(self, source_field_name): # source_col_name: the PK colname in join_table for the source object - cursor = connection.cursor() - cursor.execute("DELETE FROM %s WHERE %s = %%s" % \ - (self.join_table, source_col_name), - [self._pk_val]) - transaction.commit_unless_managed() + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are clearing the + # duplicate data rows for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action="pre_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None) + db = router.db_for_write(self.through.__class__, instance=self.instance) + self.through._default_manager.using(db).filter(**{ + source_field_name: self._pk_val + }).delete() + if self.reverse or source_field_name == self.source_field_name: + # Don't send the signal when we are clearing the + # duplicate data rows for symmetrical reverse entries. + signals.m2m_changed.send(sender=rel.through, action="post_clear", + instance=self.instance, reverse=self.reverse, + model=self.model, pk_set=None) return ManyRelatedManager @@ -554,33 +652,33 @@ # model's default manager. rel_model = self.related.model superclass = rel_model._default_manager.__class__ - RelatedManager = create_many_related_manager(superclass, self.related.field.rel.through) + RelatedManager = create_many_related_manager(superclass, self.related.field.rel) - qn = connection.ops.quote_name manager = RelatedManager( model=rel_model, core_filters={'%s__pk' % self.related.field.name: instance._get_pk_val()}, instance=instance, symmetrical=False, - join_table=qn(self.related.field.m2m_db_table()), - source_col_name=qn(self.related.field.m2m_reverse_name()), - target_col_name=qn(self.related.field.m2m_column_name()) + source_field_name=self.related.field.m2m_reverse_field_name(), + target_field_name=self.related.field.m2m_field_name(), + reverse=True ) return manager def __set__(self, instance, value): if instance is None: - raise AttributeError, "Manager must be accessed via instance" + raise AttributeError("Manager must be accessed via instance") - through = getattr(self.related.field.rel, 'through', None) - if through is not None: - raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through + if not self.related.field.rel.through._meta.auto_created: + opts = self.related.field.rel.through._meta + raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)) manager = self.__get__(instance) manager.clear() manager.add(*value) + class ReverseManyRelatedObjectsDescriptor(object): # This class provides the functionality that makes the related-object # managers available as attributes on a model class, for fields that have @@ -591,6 +689,13 @@ def __init__(self, m2m_field): self.field = m2m_field + def _through(self): + # through is provided so that you have easy access to the through + # model (Book.authors.through) for inlines, etc. This is done as + # a property to ensure that the fully resolved value is returned. + return self.field.rel.through + through = property(_through) + def __get__(self, instance, instance_type=None): if instance is None: return self @@ -599,28 +704,27 @@ # model's default manager. rel_model=self.field.rel.to superclass = rel_model._default_manager.__class__ - RelatedManager = create_many_related_manager(superclass, self.field.rel.through) + RelatedManager = create_many_related_manager(superclass, self.field.rel) - qn = connection.ops.quote_name manager = RelatedManager( model=rel_model, core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()}, instance=instance, - symmetrical=(self.field.rel.symmetrical and isinstance(instance, rel_model)), - join_table=qn(self.field.m2m_db_table()), - source_col_name=qn(self.field.m2m_column_name()), - target_col_name=qn(self.field.m2m_reverse_name()) + symmetrical=self.field.rel.symmetrical, + source_field_name=self.field.m2m_field_name(), + target_field_name=self.field.m2m_reverse_field_name(), + reverse=False ) return manager def __set__(self, instance, value): if instance is None: - raise AttributeError, "Manager must be accessed via instance" + raise AttributeError("Manager must be accessed via instance") - through = getattr(self.field.rel, 'through', None) - if through is not None: - raise AttributeError, "Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s's Manager instead." % through + if not self.field.rel.through._meta.auto_created: + opts = self.field.rel.through._meta + raise AttributeError("Cannot set values on a ManyToManyField which specifies an intermediary model. Use %s.%s's Manager instead." % (opts.app_label, opts.object_name)) manager = self.__get__(instance) manager.clear() @@ -642,6 +746,10 @@ self.multiple = True self.parent_link = parent_link + def is_hidden(self): + "Should the related object be hidden?" + return self.related_name and self.related_name[-1] == '+' + def get_related_field(self): """ Returns the Field in the 'to' object to which this relationship is @@ -673,6 +781,10 @@ self.multiple = True self.through = through + def is_hidden(self): + "Should the related object be hidden?" + return self.related_name and self.related_name[-1] == '+' + def get_related_field(self): """ Returns the field in the to' object to which this relationship is tied @@ -683,6 +795,10 @@ class ForeignKey(RelatedField, Field): empty_strings_allowed = False + default_error_messages = { + 'invalid': _('Model %(model)s with pk %(pk)r does not exist.') + } + description = _("Foreign Key (type determined by related field)") def __init__(self, to, to_field=None, rel_class=ManyToOneRel, **kwargs): try: to_name = to._meta.object_name.lower() @@ -690,7 +806,10 @@ assert isinstance(to, basestring), "%s(%r) is invalid. First parameter to ForeignKey must be either a model, a model name, or the string %r" % (self.__class__.__name__, to, RECURSIVE_RELATIONSHIP_CONSTANT) else: assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name) - to_field = to_field or to._meta.pk.name + # For backwards compatibility purposes, we need to *try* and set + # the to_field during FK construction. It won't be guaranteed to + # be correct until contribute_to_class is called. Refs #12190. + to_field = to_field or (to._meta.pk and to._meta.pk.name) kwargs['verbose_name'] = kwargs.get('verbose_name', None) kwargs['rel'] = rel_class(to, to_field, @@ -702,6 +821,19 @@ self.db_index = True + def validate(self, value, model_instance): + if self.rel.parent_link: + return + super(ForeignKey, self).validate(value, model_instance) + if value is None: + return + + qs = self.rel.to._default_manager.filter(**{self.rel.field_name:value}) + qs = qs.complex_filter(self.rel.limit_choices_to) + if not qs.exists(): + raise exceptions.ValidationError(self.error_messages['invalid'] % { + 'model': self.rel.to._meta.verbose_name, 'pk': value}) + def get_attname(self): return '%s_id' % self.name @@ -715,11 +847,12 @@ return getattr(field_default, self.rel.get_related_field().attname) return field_default - def get_db_prep_save(self, value): + def get_db_prep_save(self, value, connection): if value == '' or value == None: return None else: - return self.rel.get_related_field().get_db_prep_save(value) + return self.rel.get_related_field().get_db_prep_save(value, + connection=connection) def value_to_string(self, obj): if not obj: @@ -743,19 +876,24 @@ cls._meta.duplicate_targets[self.column] = (target, "o2m") def contribute_to_related_class(self, cls, related): - setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + # Internal FK's - i.e., those with a related name ending with '+' - + # don't get a related descriptor. + if not self.rel.is_hidden(): + setattr(cls, related.get_accessor_name(), ForeignRelatedObjectsDescriptor(related)) + if self.rel.field_name is None: + self.rel.field_name = cls._meta.pk.name def formfield(self, **kwargs): + db = kwargs.pop('using', None) defaults = { 'form_class': forms.ModelChoiceField, - 'queryset': self.rel.to._default_manager.complex_filter( - self.rel.limit_choices_to), + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to), 'to_field_name': self.rel.field_name, } defaults.update(kwargs) return super(ForeignKey, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # The database column type of a ForeignKey is the column type # of the field to which it points. An exception is if the ForeignKey # points to an AutoField/PositiveIntegerField/PositiveSmallIntegerField, @@ -767,8 +905,8 @@ (not connection.features.related_fields_match_type and isinstance(rel_field, (PositiveIntegerField, PositiveSmallIntegerField)))): - return IntegerField().db_type() - return rel_field.db_type() + return IntegerField().db_type(connection=connection) + return rel_field.db_type(connection=connection) class OneToOneField(ForeignKey): """ @@ -777,6 +915,7 @@ always returns the object pointed to (since there will only ever be one), rather than returning a list. """ + description = _("One-to-one relationship") def __init__(self, to, to_field=None, **kwargs): kwargs['unique'] = True super(OneToOneField, self).__init__(to, to_field, OneToOneRel, **kwargs) @@ -790,7 +929,55 @@ return None return super(OneToOneField, self).formfield(**kwargs) + def save_form_data(self, instance, data): + if isinstance(data, self.rel.to): + setattr(instance, self.name, data) + else: + setattr(instance, self.attname, data) + +def create_many_to_many_intermediary_model(field, klass): + from django.db import models + managed = True + if isinstance(field.rel.to, basestring) and field.rel.to != RECURSIVE_RELATIONSHIP_CONSTANT: + to_model = field.rel.to + to = to_model.split('.')[-1] + def set_managed(field, model, cls): + field.rel.through._meta.managed = model._meta.managed or cls._meta.managed + add_lazy_relation(klass, field, to_model, set_managed) + elif isinstance(field.rel.to, basestring): + to = klass._meta.object_name + to_model = klass + managed = klass._meta.managed + else: + to = field.rel.to._meta.object_name + to_model = field.rel.to + managed = klass._meta.managed or to_model._meta.managed + name = '%s_%s' % (klass._meta.object_name, field.name) + if field.rel.to == RECURSIVE_RELATIONSHIP_CONSTANT or to == klass._meta.object_name: + from_ = 'from_%s' % to.lower() + to = 'to_%s' % to.lower() + else: + from_ = klass._meta.object_name.lower() + to = to.lower() + meta = type('Meta', (object,), { + 'db_table': field._get_m2m_db_table(klass._meta), + 'managed': managed, + 'auto_created': klass, + 'app_label': klass._meta.app_label, + 'unique_together': (from_, to), + 'verbose_name': '%(from)s-%(to)s relationship' % {'from': from_, 'to': to}, + 'verbose_name_plural': '%(from)s-%(to)s relationships' % {'from': from_, 'to': to}, + }) + # Construct and return the new class. + return type(name, (models.Model,), { + 'Meta': meta, + '__module__': klass.__module__, + from_: models.ForeignKey(klass, related_name='%s+' % name), + to: models.ForeignKey(to_model, related_name='%s+' % name) + }) + class ManyToManyField(RelatedField, Field): + description = _("Many-to-many relationship") def __init__(self, to, **kwargs): try: assert not to._meta.abstract, "%s cannot define a relation with abstract class %s" % (self.__class__.__name__, to._meta.object_name) @@ -801,19 +988,16 @@ kwargs['rel'] = ManyToManyRel(to, related_name=kwargs.pop('related_name', None), limit_choices_to=kwargs.pop('limit_choices_to', None), - symmetrical=kwargs.pop('symmetrical', True), + symmetrical=kwargs.pop('symmetrical', to==RECURSIVE_RELATIONSHIP_CONSTANT), through=kwargs.pop('through', None)) self.db_table = kwargs.pop('db_table', None) if kwargs['rel'].through is not None: - self.creates_table = False assert self.db_table is None, "Cannot specify a db_table if an intermediary model is used." - else: - self.creates_table = True Field.__init__(self, **kwargs) - msg = ugettext_lazy('Hold down "Control", or "Command" on a Mac, to select more than one.') + msg = _('Hold down "Control", or "Command" on a Mac, to select more than one.') self.help_text = string_concat(self.help_text, ' ', msg) def get_choices_default(self): @@ -822,62 +1006,45 @@ def _get_m2m_db_table(self, opts): "Function that can be curried to provide the m2m table name for this relation" if self.rel.through is not None: - return self.rel.through_model._meta.db_table + return self.rel.through._meta.db_table elif self.db_table: return self.db_table else: return util.truncate_name('%s_%s' % (opts.db_table, self.name), connection.ops.max_name_length()) - def _get_m2m_column_name(self, related): - "Function that can be curried to provide the source column name for the m2m table" - try: - return self._m2m_column_name_cache - except: - if self.rel.through is not None: - for f in self.rel.through_model._meta.fields: - if hasattr(f,'rel') and f.rel and f.rel.to == related.model: - self._m2m_column_name_cache = f.column - break - # If this is an m2m relation to self, avoid the inevitable name clash - elif related.model == related.parent_model: - self._m2m_column_name_cache = 'from_' + related.model._meta.object_name.lower() + '_id' - else: - self._m2m_column_name_cache = related.model._meta.object_name.lower() + '_id' - - # Return the newly cached value - return self._m2m_column_name_cache + def _get_m2m_attr(self, related, attr): + "Function that can be curried to provide the source accessor or DB column name for the m2m table" + cache_attr = '_m2m_%s_cache' % attr + if hasattr(self, cache_attr): + return getattr(self, cache_attr) + for f in self.rel.through._meta.fields: + if hasattr(f,'rel') and f.rel and f.rel.to == related.model: + setattr(self, cache_attr, getattr(f, attr)) + return getattr(self, cache_attr) - def _get_m2m_reverse_name(self, related): - "Function that can be curried to provide the related column name for the m2m table" - try: - return self._m2m_reverse_name_cache - except: - if self.rel.through is not None: - found = False - for f in self.rel.through_model._meta.fields: - if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model: - if related.model == related.parent_model: - # If this is an m2m-intermediate to self, - # the first foreign key you find will be - # the source column. Keep searching for - # the second foreign key. - if found: - self._m2m_reverse_name_cache = f.column - break - else: - found = True - else: - self._m2m_reverse_name_cache = f.column - break - # If this is an m2m relation to self, avoid the inevitable name clash - elif related.model == related.parent_model: - self._m2m_reverse_name_cache = 'to_' + related.parent_model._meta.object_name.lower() + '_id' - else: - self._m2m_reverse_name_cache = related.parent_model._meta.object_name.lower() + '_id' - - # Return the newly cached value - return self._m2m_reverse_name_cache + def _get_m2m_reverse_attr(self, related, attr): + "Function that can be curried to provide the related accessor or DB column name for the m2m table" + cache_attr = '_m2m_reverse_%s_cache' % attr + if hasattr(self, cache_attr): + return getattr(self, cache_attr) + found = False + for f in self.rel.through._meta.fields: + if hasattr(f,'rel') and f.rel and f.rel.to == related.parent_model: + if related.model == related.parent_model: + # If this is an m2m-intermediate to self, + # the first foreign key you find will be + # the source column. Keep searching for + # the second foreign key. + if found: + setattr(self, cache_attr, getattr(f, attr)) + break + else: + found = True + else: + setattr(self, cache_attr, getattr(f, attr)) + break + return getattr(self, cache_attr) def isValidIDList(self, field_data, all_data): "Validates that the value is a valid list of foreign keys" @@ -919,10 +1086,17 @@ # specify *what* on my non-reversible relation?!"), so we set it up # automatically. The funky name reduces the chance of an accidental # clash. - if self.rel.symmetrical and self.rel.to == "self" and self.rel.related_name is None: + if self.rel.symmetrical and (self.rel.to == "self" or self.rel.to == cls._meta.object_name): self.rel.related_name = "%s_rel_+" % name super(ManyToManyField, self).contribute_to_class(cls, name) + + # The intermediate m2m model is not auto created if: + # 1) There is a manually specified intermediate, or + # 2) The class owning the m2m field is abstract. + if not self.rel.through and not cls._meta.abstract: + self.rel.through = create_many_to_many_intermediary_model(self, cls) + # Add the descriptor for the m2m relation setattr(cls, self.name, ReverseManyRelatedObjectsDescriptor(self)) @@ -933,11 +1107,8 @@ # work correctly. if isinstance(self.rel.through, basestring): def resolve_through_model(field, model, cls): - field.rel.through_model = model + field.rel.through = model add_lazy_relation(cls, self, self.rel.through, resolve_through_model) - elif self.rel.through: - self.rel.through_model = self.rel.through - self.rel.through = self.rel.through._meta.object_name if isinstance(self.rel.to, basestring): target = self.rel.to @@ -946,15 +1117,17 @@ cls._meta.duplicate_targets[self.column] = (target, "m2m") def contribute_to_related_class(self, cls, related): - # m2m relations to self do not have a ManyRelatedObjectsDescriptor, - # as it would be redundant - unless the field is non-symmetrical. - if related.model != related.parent_model or not self.rel.symmetrical: - # Add the descriptor for the m2m relation + # Internal M2Ms (i.e., those with a related name ending with '+') + # don't get a related descriptor. + if not self.rel.is_hidden(): setattr(cls, related.get_accessor_name(), ManyRelatedObjectsDescriptor(related)) # Set up the accessors for the column names on the m2m table - self.m2m_column_name = curry(self._get_m2m_column_name, related) - self.m2m_reverse_name = curry(self._get_m2m_reverse_name, related) + self.m2m_column_name = curry(self._get_m2m_attr, related, 'column') + self.m2m_reverse_name = curry(self._get_m2m_reverse_attr, related, 'column') + + self.m2m_field_name = curry(self._get_m2m_attr, related, 'name') + self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name') def set_attributes_from_rel(self): pass @@ -967,7 +1140,11 @@ setattr(instance, self.attname, data) def formfield(self, **kwargs): - defaults = {'form_class': forms.ModelMultipleChoiceField, 'queryset': self.rel.to._default_manager.complex_filter(self.rel.limit_choices_to)} + db = kwargs.pop('using', None) + defaults = { + 'form_class': forms.ModelMultipleChoiceField, + 'queryset': self.rel.to._default_manager.using(db).complex_filter(self.rel.limit_choices_to) + } defaults.update(kwargs) # If initial is passed in, it's a list of related objects, but the # MultipleChoiceField takes a list of IDs. @@ -978,7 +1155,7 @@ defaults['initial'] = [i._get_pk_val() for i in initial] return super(ManyToManyField, self).formfield(**defaults) - def db_type(self): + def db_type(self, connection): # A ManyToManyField is not represented by a single column, # so return None. return None