web/hdabo/fields.py
author ymh <ymh.work@gmail.com>
Tue, 12 Jul 2011 18:24:56 +0200
changeset 97 0d0b23401d14
parent 21 20d3375b6d28
permissions -rw-r--r--
version 00.05

# -*- coding: utf-8 -*-
from django.db import models, router
from django.db.models import signals
from django.db.models.fields.related import (create_many_related_manager, 
    ManyToManyField, ReverseManyRelatedObjectsDescriptor)
from hdabo.forms import SortedMultipleChoiceField
from hdabo.utils import OrderedSet


# Name of the integer field added to the generated intermediary model; the
# same name is used for that model's ``Meta.ordering`` and is exposed on the
# model as ``_sort_field_name``.
SORT_VALUE_FIELD_NAME = 'sort_value'


def create_sorted_many_related_manager(superclass, rel):
    '''
    Build a related-manager class that keeps many-to-many targets ordered.

    ``superclass`` and ``rel`` are handed unchanged to django's
    ``create_many_related_manager``. The class it returns is subclassed so
    that query sets are ordered by the sort field of ``rel.through`` and so
    that every row added through the manager is given a sort value.

    Returns the manager *class*, not an instance.
    '''
    RelatedManager = create_many_related_manager(superclass, rel)

    class SortedRelatedManager(RelatedManager):
        def get_query_set(self):
            # Order by the raw "<db_table>.<column>" name because we have no
            # other access to the extra sorting field of the intermediary
            # model. The fields are hidden for joins because we set
            # ``auto_created`` on the intermediary's meta options.
            return super(SortedRelatedManager, self).\
                get_query_set().\
                order_by('%s.%s' % (
                    rel.through._meta.db_table,
                    rel.through._sort_field_name,))

        def _add_items(self, source_field_name, target_field_name, *objs):
            # source_field_name: name of the through-model field pointing at
            #     the source object (the instance this manager is bound to)
            # target_field_name: name of the through-model field pointing at
            #     the target object
            # *objs: objects to add. Either object instances, or primary
            #     keys of object instances.

            # If there aren't any objects, there is nothing to do.
            if not objs:
                return

            # Collect the primary keys in the order they were passed in.
            new_ids = OrderedSet()
            for obj in objs:
                if isinstance(obj, self.model):
                    if not router.allow_relation(obj, self.instance):
                        raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % 
                                           (obj, self.instance._state.db, obj._state.db))
                    new_ids.add(obj.pk)
                elif isinstance(obj, models.Model):
                    raise TypeError("'%s' instance expected" % self.model._meta.object_name)
                else:
                    new_ids.add(obj)

            db = router.db_for_write(self.through, instance=self.instance)
            through_manager = self.through._default_manager

            # Drop the ids that are already related to this instance.
            vals = through_manager.using(db).values_list(target_field_name, flat=True)
            vals = vals.filter(**{
                source_field_name: self._pk_val,
                '%s__in' % target_field_name: new_ids,
            })
            new_ids = new_ids - OrderedSet(vals)

            # Don't send the signals when we are inserting the duplicate
            # data row for symmetrical reverse entries.
            send_signals = self.reverse or source_field_name == self.source_field_name

            if send_signals:
                signals.m2m_changed.send(sender=rel.through, action='pre_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids, using=db)

            # Add the ones that aren't there already. The sort value is the
            # current number of rows in the through table, counted on the
            # same database the rows are written to so that the value
            # matches the table actually being filled.
            # NOTE(review): the count covers the whole through table; after
            # rows have been deleted it may equal an existing sort value --
            # confirm whether such ties are acceptable.
            for obj_id in new_ids:
                through_manager.using(db).create(**{
                    '%s_id' % source_field_name: self._pk_val,
                    '%s_id' % target_field_name: obj_id,
                    self.through._sort_field_name: through_manager.using(db).count(),
                })

            if send_signals:
                signals.m2m_changed.send(sender=rel.through, action='post_add',
                    instance=self.instance, reverse=self.reverse,
                    model=self.model, pk_set=new_ids, using=db)

    return SortedRelatedManager


class ReverseSortedManyRelatedObjectsDescriptor(ReverseManyRelatedObjectsDescriptor):
    '''
    Descriptor installed on the model that declares a sorted many-to-many
    field. Reading the attribute on an instance yields a manager built with
    ``create_sorted_many_related_manager``; assigning to it replaces the
    related objects with the given ones.
    '''
    def __get__(self, instance, instance_type=None):
        '''
        Return the descriptor itself for class access, otherwise a sorted
        related manager bound to ``instance``.
        '''
        if instance is None:
            return self

        # Dynamically create a class that subclasses the related
        # model's default manager.
        rel_model = self.field.rel.to
        superclass = rel_model._default_manager.__class__
        RelatedManager = create_sorted_many_related_manager(superclass, self.field.rel)

        manager = RelatedManager(
            model=rel_model,
            core_filters={'%s__pk' % self.field.related_query_name(): instance._get_pk_val()},
            instance=instance,
            symmetrical=(self.field.rel.symmetrical and isinstance(instance, rel_model)),
            source_field_name=self.field.m2m_field_name(),
            target_field_name=self.field.m2m_reverse_field_name(),
            reverse=False
        )

        return manager

    def __set__(self, instance, value):
        '''
        Replace the related objects of ``instance`` with ``value``.

        ``value`` is unpacked into ``manager.add``, so it must be an iterable
        of objects or primary keys. Raises ``AttributeError`` when
        ``instance`` is ``None``.
        '''
        if instance is None:
            # Call form of ``raise``: valid on both Python 2 and Python 3,
            # unlike the ``raise Exc, "msg"`` statement form.
            raise AttributeError("Manager must be accessed via instance")

        # Clear first, then re-add, so the existing rows are removed before
        # the new ones are created.
        manager = self.__get__(instance)
        manager.clear()
        manager.add(*value)


class SortedManyToManyField(ManyToManyField):
    '''
    Many to many relation that remembers the order of the related objects.

    The boolean ``sorted`` argument (``True`` by default) switches the
    ordering support on or off. With ``sorted=False`` no intermediary model,
    descriptor or form field is installed by this class and the
    ``ManyToManyField`` implementations are used instead.
    '''
    def __init__(self, to, sorted=True, **kwargs):
        self.sorted = sorted
        if sorted:
            # This is very hacky and should be removed if a better solution
            # is found. The placeholder is replaced by the generated model
            # in ``contribute_to_class``.
            kwargs.setdefault('through', True)
        super(SortedManyToManyField, self).__init__(to, **kwargs)
        # Re-apply the help text exactly as the caller passed it (``None``
        # when it was not given) after the parent constructor has run.
        # NOTE(review): presumably this discards text the parent constructor
        # adds to help_text -- confirm.
        self.help_text = kwargs.get('help_text')

    def create_intermediary_model(self, cls, field_name):
        '''
        Build and return the intermediary model for this relation.

        ``cls`` is the model class the field is being added to and
        ``field_name`` the attribute name of the field on it. The generated
        model has a foreign key to ``cls``, a foreign key to the related
        model and an integer sort field named ``SORT_VALUE_FIELD_NAME``.
        '''
        # Make sure rel.to is a model class and not a string. A bare model
        # name is looked up in the app of ``cls``.
        if isinstance(self.rel.to, basestring):
            name_parts = self.rel.to.split('.')
            if len(name_parts) == 1:
                name_parts = cls._meta.app_label.lower(), name_parts[0]
            self.rel.to = models.get_model(*name_parts)

        owner_app = cls._meta.app_label
        owner_name = cls._meta.object_name
        model_name = '%s_%s_%s' % (owner_app, owner_name, field_name)
        owner_reference = '%s.%s' % (owner_app, owner_name)
        table_name = '%s_%s_%s' % (
            owner_app.lower(),
            owner_name.lower(),
            field_name.lower())

        def default_sort_value():
            # Default for the sort field: the current number of rows of the
            # intermediary model, looked up by name when the default is
            # evaluated.
            model = models.get_model(cls._meta.app_label, model_name)
            return model._default_manager.count()

        class Meta:
            db_table = table_name
            app_label = owner_app
            ordering = (SORT_VALUE_FIELD_NAME,)
            auto_created = cls

        attrs = {
            '__module__': '',
            'Meta': Meta,
            '_sort_field_name': SORT_VALUE_FIELD_NAME,
            '__unicode__': lambda obj: 'pk=%d' % obj.pk,
        }
        # Using from and to model's name as field names for relations. This
        # is also django default behaviour for m2m intermediary tables. The
        # fields are created in the order: source key, target key, sort
        # value.
        attrs.update({
            owner_name.lower():
                models.ForeignKey(owner_reference),
            self.rel.to._meta.object_name.lower():
                models.ForeignKey(self.rel.to),
            SORT_VALUE_FIELD_NAME:
                models.IntegerField(default=default_sort_value),
        })

        # Creating the class automatically triggers ModelBase processing.
        return type(model_name, (models.Model,), attrs)

    def contribute_to_class(self, cls, name):
        '''
        Attach the field to ``cls``; for sorted fields also generate the
        intermediary model and install the sorted descriptor.
        '''
        if not self.sorted:
            super(SortedManyToManyField, self).contribute_to_class(cls, name)
            return

        # The intermediary model has to be in place before the parent
        # implementation runs.
        self.rel.through = self.create_intermediary_model(cls, name)
        super(SortedManyToManyField, self).contribute_to_class(cls, name)
        # Overwrite default descriptor with reverse and sorted one.
        setattr(cls, self.name, ReverseSortedManyRelatedObjectsDescriptor(self))

    def formfield(self, **kwargs):
        '''
        Return the form field for this model field. Sorted fields default to
        ``SortedMultipleChoiceField`` unless ``form_class`` is passed in.
        '''
        options = dict(kwargs)
        if self.sorted:
            options.setdefault('form_class', SortedMultipleChoiceField)
        return super(SortedManyToManyField, self).formfield(**options)