web/lib/django/core/management/commands/inspectdb.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/core/management/commands/inspectdb.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/core/management/commands/inspectdb.py	Tue May 25 02:43:45 2010 +0200
@@ -1,20 +1,31 @@
+import keyword
+from optparse import make_option
+
 from django.core.management.base import NoArgsCommand, CommandError
+from django.db import connections, DEFAULT_DB_ALIAS
 
 class Command(NoArgsCommand):
     help = "Introspects the database tables in the given database and outputs a Django model module."
 
+    option_list = NoArgsCommand.option_list + (
+        make_option('--database', action='store', dest='database',
+            default=DEFAULT_DB_ALIAS, help='Nominates a database to '
+                'introspect.  Defaults to using the "default" database.'),
+    )
+
     requires_model_validation = False
 
+    db_module = 'django.db'
+
     def handle_noargs(self, **options):
         try:
-            for line in self.handle_inspection():
+            for line in self.handle_inspection(options):
                 print line
         except NotImplementedError:
             raise CommandError("Database inspection isn't supported for the currently selected database backend.")
 
-    def handle_inspection(self):
-        from django.db import connection
-        import keyword
+    def handle_inspection(self, options):
+        connection = connections[options.get('database', DEFAULT_DB_ALIAS)]
 
         table2model = lambda table_name: table_name.title().replace('_', '').replace(' ', '').replace('-', '')
 
@@ -28,7 +39,7 @@
         yield "# Also note: You'll have to insert the output of 'django-admin.py sqlcustom [appname]'"
         yield "# into your database."
         yield ''
-        yield 'from django.db import models'
+        yield 'from %s import models' % self.db_module
         yield ''
         for table_name in connection.introspection.get_table_list(cursor):
             yield 'class %s(models.Model):' % table2model(table_name)
@@ -72,25 +83,11 @@
                     else:
                         extra_params['db_column'] = column_name
                 else:
-                    try:
-                        field_type = connection.introspection.get_field_type(row[1], row)
-                    except KeyError:
-                        field_type = 'TextField'
-                        comment_notes.append('This field type is a guess.')
-
-                    # This is a hook for DATA_TYPES_REVERSE to return a tuple of
-                    # (field_type, extra_params_dict).
-                    if type(field_type) is tuple:
-                        field_type, new_params = field_type
-                        extra_params.update(new_params)
-
-                    # Add max_length for all CharFields.
-                    if field_type == 'CharField' and row[3]:
-                        extra_params['max_length'] = row[3]
-
-                    if field_type == 'DecimalField':
-                        extra_params['max_digits'] = row[4]
-                        extra_params['decimal_places'] = row[5]
+                    # Calling `get_field_type` to get the field type string and any
+                    # additional paramters and notes.
+                    field_type, field_params, field_notes = self.get_field_type(connection, table_name, row)
+                    extra_params.update(field_params)
+                    comment_notes.extend(field_notes)
 
                     # Add primary_key and unique, if necessary.
                     if column_name in indexes:
@@ -122,6 +119,46 @@
                 if comment_notes:
                     field_desc += ' # ' + ' '.join(comment_notes)
                 yield '    %s' % field_desc
-            yield '    class Meta:'
-            yield '        db_table = %r' % table_name
-            yield ''
+            for meta_line in self.get_meta(table_name):
+                yield meta_line
+
+    def get_field_type(self, connection, table_name, row):
+        """
+        Given the database connection, the table name, and the cursor row
+        description, this routine will return the given field type name, as
+        well as any additional keyword parameters and notes for the field.
+        """
+        field_params = {}
+        field_notes = []
+
+        try:
+            field_type = connection.introspection.get_field_type(row[1], row)
+        except KeyError:
+            field_type = 'TextField'
+            field_notes.append('This field type is a guess.')
+
+        # This is a hook for DATA_TYPES_REVERSE to return a tuple of
+        # (field_type, field_params_dict).
+        if type(field_type) is tuple:
+            field_type, new_params = field_type
+            field_params.update(new_params)
+
+        # Add max_length for all CharFields.
+        if field_type == 'CharField' and row[3]:
+            field_params['max_length'] = row[3]
+
+        if field_type == 'DecimalField':
+            field_params['max_digits'] = row[4]
+            field_params['decimal_places'] = row[5]
+
+        return field_type, field_params, field_notes
+
+    def get_meta(self, table_name):
+        """
+        Return a sequence comprising the lines of code necessary
+        to construct the inner Meta class for the model corresponding
+        to the given database table name.
+        """
+        return ['    class Meta:',
+                '        db_table = %r' % table_name,
+                '']