src/hdalab/fields.py
changeset 353 91c44b3fd11f
child 359 46ad324f6fe4
equal deleted inserted replaced
352:205804d9f142 353:91c44b3fd11f
       
     1 # -*- coding: utf-8 -*-
       
     2 '''
       
     3 Created on Nov 14, 2014
       
     4 
       
     5 from https://gist.github.com/gsakkis/601977
       
     6 to correct https://code.djangoproject.com/ticket/10227
       
     7 
       
     8 @author: ymh
       
     9 '''
       
    10 from django.db import models
       
    11 from django.db.models import fields as django_fields
       
    12 from django.core.exceptions import ObjectDoesNotExist
       
    13 
       
    14 
       
    15 class OneToOneField(models.OneToOneField):    
       
    16     def __init__(self, to, **kwargs):
       
    17         self.related_default = kwargs.pop('related_default', None)
       
    18         super(OneToOneField, self).__init__(to, **kwargs)
       
    19 
       
    20     def contribute_to_related_class(self, cls, related):
       
    21         setattr(cls, related.get_accessor_name(),
       
    22                 SingleRelatedObjectDescriptor(related, self.related_default))
       
    23 
       
    24 
       
    25 class SingleRelatedObjectDescriptor(django_fields.related.SingleRelatedObjectDescriptor):
       
    26     def __init__(self, related, default):
       
    27         super(SingleRelatedObjectDescriptor, self).__init__(related)
       
    28         self.default = default
       
    29         
       
    30     def __get__(self, instance, instance_type=None):
       
    31         try:
       
    32             return super(SingleRelatedObjectDescriptor, self).__get__(instance,
       
    33                                                                       instance_type)
       
    34         except ObjectDoesNotExist:
       
    35             if self.default is None:
       
    36                 raise
       
    37             value = self.default(instance)
       
    38             setattr(instance, self.cache_name, value)
       
    39             if value is not None:
       
    40                 setattr(value, self.related.field.get_cache_name(), instance)
       
    41             return value