web/lib/django_extensions/management/commands/dumpscript.py
changeset 3 526ebd3988b0
equal deleted inserted replaced
1:ebaad720f88b 3:526ebd3988b0
       
     1 #!/usr/bin/env python
       
     2 # -*- coding: UTF-8 -*-
       
     3 """
       
     4       Title: Dumpscript management command
       
     5     Project: Hardytools (queryset-refactor version)
       
     6      Author: Will Hardy (http://willhardy.com.au)
       
     7        Date: June 2008
       
     8       Usage: python manage.py dumpscript appname > scripts/scriptname.py
       
     9   $Revision: 217 $
       
    10 
       
    11 Description: 
       
    12     Generates a Python script that will repopulate the database using objects.
       
    13     The advantage of this approach is that it is easy to understand, and more
       
    14     flexible than directly populating the database, or using XML.
       
    15 
       
    16     * It also allows for new defaults to take effect and only transfers what is
       
    17       needed.
       
    18     * If a new database schema has a NEW ATTRIBUTE, it is simply not
       
    19       populated (using a default value will make the transition smooth :)
       
    20     * If a new database schema REMOVES AN ATTRIBUTE, it is simply ignored
       
     21       and the data moves across safely (I'm assuming we don't want this
       
     22       attribute anymore).
       
    23     * Problems may only occur if there is a new model and is now a required
       
    24       ForeignKey for an existing model. But this is easy to fix by editing the
       
    25       populate script :)
       
    26 
       
    27 Improvements:
       
    28     See TODOs and FIXMEs scattered throughout :-)
       
    29 
       
    30 """
       
    31 
       
    32 import sys
       
    33 from django.db import models
       
    34 from django.core.exceptions import ObjectDoesNotExist
       
    35 from django.core.management.base import BaseCommand
       
    36 from django.utils.encoding import smart_unicode, force_unicode
       
    37 from django.contrib.contenttypes.models import ContentType
       
    38 
       
    39 class Command(BaseCommand):
       
    40     help = 'Dumps the data as a customised python script.'
       
    41     args = '[appname ...]'
       
    42 
       
    43     def handle(self, *app_labels, **options):
       
    44 
       
    45         # Get the models we want to export
       
    46         models = get_models(app_labels)
       
    47 
       
    48         # A dictionary is created to keep track of all the processed objects,
       
    49         # so that foreign key references can be made using python variable names.
       
    50         # This variable "context" will be passed around like the town bicycle.
       
    51         context = {}
       
    52 
       
    53         # Create a dumpscript object and let it format itself as a string
       
    54         print Script(models=models, context=context)
       
    55 
       
    56 
       
    57 def get_models(app_labels):
       
    58     """ Gets a list of models for the given app labels, with some exceptions. 
       
    59         TODO: If a required model is referenced, it should also be included.
       
    60         Or at least discovered with a get_or_create() call.
       
    61     """
       
    62 
       
    63     from django.db.models import get_app, get_apps, get_model
       
    64     from django.db.models import get_models as get_all_models
       
    65 
       
    66     # These models are not to be output, e.g. because they can be generated automatically
       
    67     # TODO: This should be "appname.modelname" string
       
    68     from django.contrib.contenttypes.models import ContentType
       
    69     EXCLUDED_MODELS = (ContentType, )
       
    70 
       
    71     models = []
       
    72 
       
    73     # If no app labels are given, return all
       
    74     if not app_labels:
       
    75         for app in get_apps():
       
    76             models += [ m for m in get_all_models(app) if m not in EXCLUDED_MODELS ]
       
    77 
       
    78     # Get all relevant apps
       
    79     for app_label in app_labels:
       
    80         # If a specific model is mentioned, get only that model
       
    81         if "." in app_label:
       
    82             app_label, model_name = app_label.split(".", 1)
       
    83             models.append(get_model(app_label, model_name))
       
    84         # Get all models for a given app
       
    85         else:
       
    86             models += [ m for m in get_all_models(get_app(app_label)) if m not in EXCLUDED_MODELS ]
       
    87 
       
    88     return models
       
    89 
       
    90 
       
    91 
       
    92 class Code(object):
       
    93     """ A snippet of python script. 
       
    94         This keeps track of import statements and can be output to a string.
       
    95         In the future, other features such as custom indentation might be included
       
    96         in this class.
       
    97     """
       
    98 
       
    99     def __init__(self):
       
   100         self.imports = {}
       
   101         self.indent = -1 
       
   102 
       
   103     def __str__(self):
       
   104         """ Returns a string representation of this script. 
       
   105         """
       
   106         if self.imports:
       
   107             sys.stderr.write(repr(self.import_lines))
       
   108             return flatten_blocks([""] + self.import_lines + [""] + self.lines, num_indents=self.indent)
       
   109         else:
       
   110             return flatten_blocks(self.lines, num_indents=self.indent)
       
   111 
       
   112     def get_import_lines(self):
       
   113         """ Takes the stored imports and converts them to lines
       
   114         """
       
   115         if self.imports:
       
   116             return [ "from %s import %s" % (value, key) for key, value in self.imports.items() ]
       
   117         else:
       
   118             return []
       
   119     import_lines = property(get_import_lines)
       
   120 
       
   121 
       
   122 class ModelCode(Code):
       
   123     " Produces a python script that can recreate data for a given model class. "
       
   124 
       
   125     def __init__(self, model, context={}):
       
   126         self.model = model
       
   127         self.context = context
       
   128         self.instances = []
       
   129         self.indent = 0
       
   130 
       
   131     def get_imports(self):
       
   132         """ Returns a dictionary of import statements, with the variable being
       
   133             defined as the key. 
       
   134         """
       
   135         return { self.model.__name__: smart_unicode(self.model.__module__) }
       
   136     imports = property(get_imports)
       
   137 
       
   138     def get_lines(self):
       
   139         """ Returns a list of lists or strings, representing the code body. 
       
   140             Each list is a block, each string is a statement.
       
   141         """
       
   142         code = []
       
   143 
       
   144         for counter, item in enumerate(self.model.objects.all()):
       
   145             instance = InstanceCode(instance=item, id=counter+1, context=self.context)
       
   146             self.instances.append(instance)
       
   147             if instance.waiting_list:
       
   148                 code += instance.lines
       
   149  
       
   150         # After each instance has been processed, try again.
       
   151         # This allows self referencing fields to work.
       
   152         for instance in self.instances:
       
   153             if instance.waiting_list:
       
   154                 code += instance.lines
       
   155 
       
   156         return code
       
   157 
       
   158     lines = property(get_lines)
       
   159 
       
   160 
       
   161 class InstanceCode(Code):
       
   162     " Produces a python script that can recreate data for a given model instance. "
       
   163 
       
   164     def __init__(self, instance, id, context={}):
       
   165         """ We need the instance in question and an id """
       
   166 
       
   167         self.instance = instance
       
   168         self.model = self.instance.__class__
       
   169         self.context = context
       
   170         self.variable_name = "%s_%s" % (self.instance._meta.db_table, id)
       
   171         self.skip_me = None
       
   172         self.instantiated = False
       
   173 
       
   174         self.indent  = 0 
       
   175         self.imports = {}
       
   176 
       
   177         self.waiting_list = list(self.model._meta.fields)
       
   178 
       
   179         self.many_to_many_waiting_list = {} 
       
   180         for field in self.model._meta.many_to_many:
       
   181             self.many_to_many_waiting_list[field] = list(getattr(self.instance, field.name).all())
       
   182 
       
   183     def get_lines(self, force=False):
       
   184         """ Returns a list of lists or strings, representing the code body. 
       
   185             Each list is a block, each string is a statement.
       
   186             
       
   187             force (True or False): if an attribute object cannot be included, 
       
   188             it is usually skipped to be processed later. With 'force' set, there
       
   189             will be no waiting: a get_or_create() call is written instead.
       
   190         """
       
   191         code_lines = []
       
   192 
       
   193         # Don't return anything if this is an instance that should be skipped
       
   194         if self.skip():
       
   195             return []
       
   196 
       
   197         # Initialise our new object
       
   198         # e.g. model_name_35 = Model()
       
   199         code_lines += self.instantiate()
       
   200 
       
   201         # Add each field
       
   202         # e.g. model_name_35.field_one = 1034.91
       
   203         #      model_name_35.field_two = "text"
       
   204         code_lines += self.get_waiting_list()
       
   205 
       
   206         if force:
       
   207             # TODO: Check that M2M are not affected
       
   208             code_lines += self.get_waiting_list(force=force)
       
   209 
       
   210         # Print the save command for our new object
       
   211         # e.g. model_name_35.save()
       
   212         if code_lines:
       
   213             code_lines.append("%s.save()\n" % (self.variable_name))
       
   214 
       
   215         code_lines += self.get_many_to_many_lines(force=force)
       
   216 
       
   217         return code_lines
       
   218     lines = property(get_lines)
       
   219 
       
   220     def skip(self):
       
   221         """ Determine whether or not this object should be skipped.
       
   222             If this model is a parent of a single subclassed instance, skip it.
       
   223             The subclassed instance will create this parent instance for us.
       
   224 
       
   225             TODO: Allow the user to force its creation?
       
   226         """
       
   227 
       
   228         if self.skip_me is not None:
       
   229             return self.skip_me
       
   230 
       
   231         try:
       
   232             # Django trunk since r7722 uses CollectedObjects instead of dict
       
   233             from django.db.models.query import CollectedObjects
       
   234             sub_objects = CollectedObjects()
       
   235         except ImportError:
       
   236             # previous versions don't have CollectedObjects
       
   237             sub_objects = {}
       
   238         self.instance._collect_sub_objects(sub_objects)
       
   239         if reduce(lambda x, y: x+y, [self.model in so._meta.parents for so in sub_objects.keys()]) == 1:
       
   240             pk_name = self.instance._meta.pk.name
       
   241             key = '%s_%s' % (self.model.__name__, getattr(self.instance, pk_name))
       
   242             self.context[key] = None
       
   243             self.skip_me = True
       
   244         else:
       
   245             self.skip_me = False
       
   246 
       
   247         return self.skip_me
       
   248 
       
   249     def instantiate(self):
       
   250         " Write lines for instantiation "
       
   251         # e.g. model_name_35 = Model()
       
   252         code_lines = []
       
   253 
       
   254         if not self.instantiated:
       
   255             code_lines.append("%s = %s()" % (self.variable_name, self.model.__name__))
       
   256             self.instantiated = True
       
   257 
       
   258             # Store our variable name for future foreign key references
       
   259             pk_name = self.instance._meta.pk.name
       
   260             key = '%s_%s' % (self.model.__name__, getattr(self.instance, pk_name))
       
   261             self.context[key] = self.variable_name
       
   262 
       
   263         return code_lines
       
   264 
       
   265 
       
   266     def get_waiting_list(self, force=False):
       
   267         " Add lines for any waiting fields that can be completed now. "
       
   268 
       
   269         code_lines = []
       
   270 
       
   271         # Process normal fields
       
   272         for field in list(self.waiting_list):
       
   273             try:
       
   274                 # Find the value, add the line, remove from waiting list and move on
       
   275                 value = get_attribute_value(self.instance, field, self.context, force=force)
       
   276                 code_lines.append('%s.%s = %s' % (self.variable_name, field.name, value))
       
   277                 self.waiting_list.remove(field)
       
   278             except SkipValue, e:
       
   279                 # Remove from the waiting list and move on
       
   280                 self.waiting_list.remove(field)
       
   281                 continue
       
   282             except DoLater, e:
       
   283                 # Move on, maybe next time
       
   284                 continue
       
   285 
       
   286 
       
   287         return code_lines
       
   288 
       
   289 
       
   290     def get_many_to_many_lines(self, force=False):
       
   291         """ Generates lines that define many to many relations for this instance. """
       
   292 
       
   293         lines = []
       
   294 
       
   295         for field, rel_items in self.many_to_many_waiting_list.items():
       
   296             for rel_item in list(rel_items):
       
   297                 try:
       
   298                     pk_name = rel_item._meta.pk.name
       
   299                     key = '%s_%s' % (rel_item.__class__.__name__, getattr(rel_item, pk_name))
       
   300                     value = "%s" % self.context[key]
       
   301                     lines.append('%s.%s.add(%s)' % (self.variable_name, field.name, value))
       
   302                     self.many_to_many_waiting_list[field].remove(rel_item)
       
   303                 except KeyError:
       
   304                     if force:
       
   305                         value = "%s.objects.get(%s=%s)" % (rel_item._meta.object_name, pk_name, getattr(rel_item, pk_name))
       
   306                         lines.append('%s.%s.add(%s)' % (self.variable_name, field.name, value))
       
   307                         self.many_to_many_waiting_list[field].remove(rel_item)
       
   308 
       
   309         if lines:
       
   310             lines.append("")
       
   311 
       
   312         return lines
       
   313 
       
   314 
       
   315 class Script(Code):
       
   316     " Produces a complete python script that can recreate data for the given apps. "
       
   317 
       
   318     def __init__(self, models, context={}):
       
   319         self.models = models
       
   320         self.context = context
       
   321 
       
   322         self.indent = -1 
       
   323         self.imports = {}
       
   324 
       
   325     def get_lines(self):
       
   326         """ Returns a list of lists or strings, representing the code body. 
       
   327             Each list is a block, each string is a statement.
       
   328         """
       
   329         code = [ self.FILE_HEADER.strip() ]
       
   330 
       
   331         # Queue and process the required models
       
   332         for model_class in queue_models(self.models, context=self.context):
       
   333             sys.stderr.write('Processing model: %s\n' % model_class.model.__name__)
       
   334             code.append(model_class.import_lines)
       
   335             code.append("")
       
   336             code.append(model_class.lines)
       
   337 
       
   338         # Process left over foreign keys from cyclic models
       
   339         for model in self.models:
       
   340             sys.stderr.write('Re-processing model: %s\n' % model.model.__name__)
       
   341             for instance in model.instances:
       
   342                 if instance.waiting_list or instance.many_to_many_waiting_list:
       
   343                     code.append(instance.get_lines(force=True))
       
   344 
       
   345         return code
       
   346 
       
   347     lines = property(get_lines)
       
   348 
       
   349     # A user-friendly file header
       
   350     FILE_HEADER = """
       
   351 
       
   352 #!/usr/bin/env python
       
   353 # -*- coding: utf-8 -*-
       
   354 
       
   355 # This file has been automatically generated, changes may be lost if you
       
   356 # go and generate it again. It was generated with the following command:
       
   357 # %s
       
   358 
       
   359 import datetime
       
   360 from decimal import Decimal
       
   361 from django.contrib.contenttypes.models import ContentType
       
   362 
       
   363 def run():
       
   364 
       
   365 """ % " ".join(sys.argv)
       
   366 
       
   367 
       
   368 
       
   369 # HELPER FUNCTIONS
       
   370 #-------------------------------------------------------------------------------
       
   371 
       
   372 def flatten_blocks(lines, num_indents=-1):
       
   373     """ Takes a list (block) or string (statement) and flattens it into a string
       
   374         with indentation. 
       
   375     """
       
   376 
       
   377     # The standard indent is four spaces
       
   378     INDENTATION = " " * 4
       
   379 
       
   380     if not lines:
       
   381         return ""
       
   382 
       
   383     # If this is a string, add the indentation and finish here
       
   384     if isinstance(lines, basestring):
       
   385         return INDENTATION * num_indents + lines
       
   386 
       
   387     # If this is not a string, join the lines and recurse
       
   388     return "\n".join([ flatten_blocks(line, num_indents+1) for line in lines ])
       
   389 
       
   390 
       
   391 
       
   392 
       
   393 def get_attribute_value(item, field, context, force=False):
       
   394     """ Gets a string version of the given attribute's value, like repr() might. """
       
   395 
       
   396     # Find the value of the field, catching any database issues
       
   397     try:
       
   398         value = getattr(item, field.name)
       
   399     except ObjectDoesNotExist:
       
   400         raise SkipValue('Could not find object for %s.%s, ignoring.\n' % (item.__class__.__name__, field.name))
       
   401 
       
   402     # AutoField: We don't include the auto fields, they'll be automatically recreated
       
   403     if isinstance(field, models.AutoField):
       
   404         raise SkipValue()
       
   405 
       
   406     # Some databases (eg MySQL) might store boolean values as 0/1, this needs to be cast as a bool
       
   407     elif isinstance(field, models.BooleanField) and value is not None:
       
   408         return repr(bool(value))
       
   409 
       
   410     # Post file-storage-refactor, repr() on File/ImageFields no longer returns the path
       
   411     elif isinstance(field, models.FileField):
       
   412         return repr(force_unicode(value))
       
   413 
       
   414     # ForeignKey fields, link directly using our stored python variable name
       
   415     elif isinstance(field, models.ForeignKey) and value is not None:
       
   416 
       
   417         # Special case for contenttype foreign keys: no need to output any
       
   418         # content types in this script, as they can be generated again 
       
   419         # automatically.
       
   420         # NB: Not sure if "is" will always work
       
   421         if field.rel.to is ContentType:
       
   422             return 'ContentType.objects.get(app_label="%s", model="%s")' % (value.app_label, value.model)
       
   423 
       
   424         # Generate an identifier (key) for this foreign object
       
   425         pk_name = value._meta.pk.name
       
   426         key = '%s_%s' % (value.__class__.__name__, getattr(value, pk_name))
       
   427 
       
   428         if key in context:
       
   429             variable_name = context[key]
       
   430             # If the context value is set to None, this should be skipped.
       
   431             # This identifies models that have been skipped (inheritance)
       
   432             if variable_name is None:
       
   433                 raise SkipValue()
       
   434             # Return the variable name listed in the context 
       
   435             return "%s" % variable_name
       
   436         elif force:
       
   437             return "%s.objects.get(%s=%s)" % (value._meta.object_name, pk_name, getattr(value, pk_name))
       
   438         else:
       
   439             raise DoLater('(FK) %s.%s\n' % (item.__class__.__name__, field.name))
       
   440 
       
   441 
       
   442     # A normal field (e.g. a python built-in)
       
   443     else:
       
   444         return repr(value)
       
   445 
       
   446 def queue_models(models, context):
       
   447     """ Works an an appropriate ordering for the models.
       
   448         This isn't essential, but makes the script look nicer because 
       
   449         more instances can be defined on their first try.
       
   450     """
       
   451 
       
   452     # Max number of cycles allowed before we call it an infinite loop.
       
   453     MAX_CYCLES = 5
       
   454 
       
   455     model_queue = []
       
   456     number_remaining_models = len(models)
       
   457     allowed_cycles = MAX_CYCLES
       
   458 
       
   459     while number_remaining_models > 0:
       
   460         previous_number_remaining_models = number_remaining_models
       
   461 
       
   462         model = models.pop(0)
       
   463         
       
   464         # If the model is ready to be processed, add it to the list
       
   465         if check_dependencies(model, model_queue):
       
   466             model_class = ModelCode(model=model, context=context)
       
   467             model_queue.append(model_class)
       
   468 
       
   469         # Otherwise put the model back at the end of the list
       
   470         else:
       
   471             models.append(model)
       
   472 
       
   473         # Check for infinite loops. 
       
   474         # This means there is a cyclic foreign key structure
       
   475         # That cannot be resolved by re-ordering
       
   476         number_remaining_models = len(models)
       
   477         if number_remaining_models == previous_number_remaining_models:
       
   478             allowed_cycles -= 1
       
   479             if allowed_cycles <= 0:
       
   480                 # Add the remaining models, but do not remove them from the model list
       
   481                 missing_models = [ ModelCode(model=m, context=context) for m in models ]
       
   482                 model_queue += missing_models
       
   483                 # Replace the models with the model class objects 
       
   484                 # (sure, this is a little bit of hackery)
       
   485                 models[:] = missing_models
       
   486                 break
       
   487         else:
       
   488             allowed_cycles = MAX_CYCLES
       
   489 
       
   490     return model_queue
       
   491 
       
   492 
       
   493 def check_dependencies(model, model_queue):
       
   494     " Check that all the depenedencies for this model are already in the queue. "
       
   495 
       
   496     # A list of allowed links: existing fields, itself and the special case ContentType
       
   497     allowed_links = [ m.model.__name__ for m in model_queue ] + [model.__name__, 'ContentType']
       
   498 
       
   499     # For each ForeignKey or ManyToMany field, check that a link is possible
       
   500     for field in model._meta.fields + model._meta.many_to_many:
       
   501         if field.rel and field.rel.to.__name__ not in allowed_links:
       
   502             return False
       
   503 
       
   504     return True
       
   505 
       
   506 
       
   507 
       
   508 # EXCEPTIONS
       
   509 #-------------------------------------------------------------------------------
       
   510 
       
   511 class SkipValue(Exception):
       
   512     """ Value could not be parsed or should simply be skipped. """
       
   513 
       
   514 class DoLater(Exception):
       
   515     """ Value could not be parsed or should simply be skipped. """