web/lib/django/test/testcases.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/test/testcases.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/test/testcases.py	Tue May 25 02:43:45 2010 +0200
@@ -7,13 +7,18 @@
 from django.core import mail
 from django.core.management import call_command
 from django.core.urlresolvers import clear_url_caches
-from django.db import transaction, connection
+from django.db import transaction, connections, DEFAULT_DB_ALIAS
 from django.http import QueryDict
 from django.test import _doctest as doctest
 from django.test.client import Client
 from django.utils import simplejson
 from django.utils.encoding import smart_str
 
+try:
+    all
+except NameError:
+    from django.utils.itercompat import all
+
 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
 
@@ -201,7 +206,8 @@
                                                           example, exc_info)
         # Rollback, in case of database errors. Otherwise they'd have
         # side effects on other tests.
-        transaction.rollback_unless_managed()
+        for conn in connections:
+            transaction.rollback_unless_managed(using=conn)
 
 class TransactionTestCase(unittest.TestCase):
     def _pre_setup(self):
@@ -219,11 +225,19 @@
         mail.outbox = []
 
     def _fixture_setup(self):
-        call_command('flush', verbosity=0, interactive=False)
-        if hasattr(self, 'fixtures'):
-            # We have to use this slightly awkward syntax due to the fact
-            # that we're using *args and **kwargs together.
-            call_command('loaddata', *self.fixtures, **{'verbosity': 0})
+        # If the test case has a multi_db=True flag, flush all databases.
+        # Otherwise, just flush default.
+        if getattr(self, 'multi_db', False):
+            databases = connections
+        else:
+            databases = [DEFAULT_DB_ALIAS]
+        for db in databases:
+            call_command('flush', verbosity=0, interactive=False, database=db)
+
+            if hasattr(self, 'fixtures'):
+                # We have to use this slightly awkward syntax due to the fact
+                # that we're using *args and **kwargs together.
+                call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db})
 
     def _urlconf_setup(self):
         if hasattr(self, 'urls'):
@@ -273,34 +287,41 @@
             clear_url_caches()
 
     def assertRedirects(self, response, expected_url, status_code=302,
-                        target_status_code=200, host=None):
+                        target_status_code=200, host=None, msg_prefix=''):
         """Asserts that a response redirected to a specific URL, and that the
         redirect URL can be loaded.
 
         Note that assertRedirects won't work for external links since it uses
         TestClient to do a request.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         if hasattr(response, 'redirect_chain'):
             # The request was a followed redirect
             self.failUnless(len(response.redirect_chain) > 0,
-                ("Response didn't redirect as expected: Response code was %d"
-                " (expected %d)" % (response.status_code, status_code)))
+                msg_prefix + "Response didn't redirect as expected: Response"
+                " code was %d (expected %d)" %
+                    (response.status_code, status_code))
 
             self.assertEqual(response.redirect_chain[0][1], status_code,
-                ("Initial response didn't redirect as expected: Response code was %d"
-                 " (expected %d)" % (response.redirect_chain[0][1], status_code)))
+                msg_prefix + "Initial response didn't redirect as expected:"
+                " Response code was %d (expected %d)" %
+                    (response.redirect_chain[0][1], status_code))
 
             url, status_code = response.redirect_chain[-1]
 
             self.assertEqual(response.status_code, target_status_code,
-                ("Response didn't redirect as expected: Final Response code was %d"
-                " (expected %d)" % (response.status_code, target_status_code)))
+                msg_prefix + "Response didn't redirect as expected: Final"
+                " Response code was %d (expected %d)" %
+                    (response.status_code, target_status_code))
 
         else:
             # Not a followed redirect
             self.assertEqual(response.status_code, status_code,
-                ("Response didn't redirect as expected: Response code was %d"
-                 " (expected %d)" % (response.status_code, status_code)))
+                msg_prefix + "Response didn't redirect as expected: Response"
+                " code was %d (expected %d)" %
+                    (response.status_code, status_code))
 
             url = response['Location']
             scheme, netloc, path, query, fragment = urlsplit(url)
@@ -310,9 +331,9 @@
             # Get the redirection page, using the same client that was used
             # to obtain the original response.
             self.assertEqual(redirect_response.status_code, target_status_code,
-                ("Couldn't retrieve redirection page '%s': response code was %d"
-                 " (expected %d)") %
-                     (path, redirect_response.status_code, target_status_code))
+                msg_prefix + "Couldn't retrieve redirection page '%s':"
+                " response code was %d (expected %d)" %
+                    (path, redirect_response.status_code, target_status_code))
 
         e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
         if not (e_scheme or e_netloc):
@@ -320,10 +341,11 @@
                 e_query, e_fragment))
 
         self.assertEqual(url, expected_url,
-            "Response redirected to '%s', expected '%s'" % (url, expected_url))
+            msg_prefix + "Response redirected to '%s', expected '%s'" %
+                (url, expected_url))
 
-
-    def assertContains(self, response, text, count=None, status_code=200):
+    def assertContains(self, response, text, count=None, status_code=200,
+                       msg_prefix=''):
         """
         Asserts that a response indicates that a page was retrieved
         successfully, (i.e., the HTTP status code was as expected), and that
@@ -331,42 +353,52 @@
         If ``count`` is None, the count doesn't matter - the assertion is true
         if the text occurs at least once in the response.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         self.assertEqual(response.status_code, status_code,
-            "Couldn't retrieve page: Response code was %d (expected %d)'" %
-                (response.status_code, status_code))
+            msg_prefix + "Couldn't retrieve page: Response code was %d"
+            " (expected %d)" % (response.status_code, status_code))
         text = smart_str(text, response._charset)
         real_count = response.content.count(text)
         if count is not None:
             self.assertEqual(real_count, count,
-                "Found %d instances of '%s' in response (expected %d)" %
-                    (real_count, text, count))
+                msg_prefix + "Found %d instances of '%s' in response"
+                " (expected %d)" % (real_count, text, count))
         else:
             self.failUnless(real_count != 0,
-                            "Couldn't find '%s' in response" % text)
+                msg_prefix + "Couldn't find '%s' in response" % text)
 
-    def assertNotContains(self, response, text, status_code=200):
+    def assertNotContains(self, response, text, status_code=200,
+                          msg_prefix=''):
         """
         Asserts that a response indicates that a page was retrieved
         successfully, (i.e., the HTTP status code was as expected), and that
         ``text`` doesn't occurs in the content of the response.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         self.assertEqual(response.status_code, status_code,
-            "Couldn't retrieve page: Response code was %d (expected %d)'" %
-                (response.status_code, status_code))
+            msg_prefix + "Couldn't retrieve page: Response code was %d"
+            " (expected %d)" % (response.status_code, status_code))
         text = smart_str(text, response._charset)
-        self.assertEqual(response.content.count(text),
-             0, "Response should not contain '%s'" % text)
+        self.assertEqual(response.content.count(text), 0,
+            msg_prefix + "Response should not contain '%s'" % text)
 
-    def assertFormError(self, response, form, field, errors):
+    def assertFormError(self, response, form, field, errors, msg_prefix=''):
         """
         Asserts that a form used to render the response has a specific field
         error.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         # Put context(s) into a list to simplify processing.
         contexts = to_list(response.context)
         if not contexts:
-            self.fail('Response did not use any contexts to render the'
-                      ' response')
+            self.fail(msg_prefix + "Response did not use any contexts to "
+                      "render the response")
 
         # Put error(s) into a list to simplify processing.
         errors = to_list(errors)
@@ -382,50 +414,65 @@
                     if field in context[form].errors:
                         field_errors = context[form].errors[field]
                         self.failUnless(err in field_errors,
-                                        "The field '%s' on form '%s' in"
-                                        " context %d does not contain the"
-                                        " error '%s' (actual errors: %s)" %
-                                            (field, form, i, err,
-                                             repr(field_errors)))
+                            msg_prefix + "The field '%s' on form '%s' in"
+                            " context %d does not contain the error '%s'"
+                            " (actual errors: %s)" %
+                                (field, form, i, err, repr(field_errors)))
                     elif field in context[form].fields:
-                        self.fail("The field '%s' on form '%s' in context %d"
-                                  " contains no errors" % (field, form, i))
+                        self.fail(msg_prefix + "The field '%s' on form '%s'"
+                                  " in context %d contains no errors" %
+                                      (field, form, i))
                     else:
-                        self.fail("The form '%s' in context %d does not"
-                                  " contain the field '%s'" %
+                        self.fail(msg_prefix + "The form '%s' in context %d"
+                                  " does not contain the field '%s'" %
                                       (form, i, field))
                 else:
                     non_field_errors = context[form].non_field_errors()
                     self.failUnless(err in non_field_errors,
-                        "The form '%s' in context %d does not contain the"
-                        " non-field error '%s' (actual errors: %s)" %
+                        msg_prefix + "The form '%s' in context %d does not"
+                        " contain the non-field error '%s'"
+                        " (actual errors: %s)" %
                             (form, i, err, non_field_errors))
         if not found_form:
-            self.fail("The form '%s' was not used to render the response" %
-                          form)
+            self.fail(msg_prefix + "The form '%s' was not used to render the"
+                      " response" % form)
 
-    def assertTemplateUsed(self, response, template_name):
+    def assertTemplateUsed(self, response, template_name, msg_prefix=''):
         """
         Asserts that the template with the provided name was used in rendering
         the response.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         template_names = [t.name for t in to_list(response.template)]
         if not template_names:
-            self.fail('No templates used to render the response')
+            self.fail(msg_prefix + "No templates used to render the response")
         self.failUnless(template_name in template_names,
-            (u"Template '%s' was not a template used to render the response."
-             u" Actual template(s) used: %s") % (template_name,
-                                                 u', '.join(template_names)))
+            msg_prefix + "Template '%s' was not a template used to render"
+            " the response. Actual template(s) used: %s" %
+                (template_name, u', '.join(template_names)))
 
-    def assertTemplateNotUsed(self, response, template_name):
+    def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
         """
         Asserts that the template with the provided name was NOT used in
         rendering the response.
         """
+        if msg_prefix:
+            msg_prefix += ": "
+
         template_names = [t.name for t in to_list(response.template)]
         self.failIf(template_name in template_names,
-            (u"Template '%s' was used unexpectedly in rendering the"
-             u" response") % template_name)
+            msg_prefix + "Template '%s' was used unexpectedly in rendering"
+            " the response" % template_name)
+
+def connections_support_transactions():
+    """
+    Returns True if all connections support transactions.  This is messy
+    because 2.4 doesn't support any or all.
+    """
+    return all(conn.settings_dict['SUPPORTS_TRANSACTIONS']
+        for conn in connections.all())
 
 class TestCase(TransactionTestCase):
     """
@@ -436,27 +483,47 @@
     """
 
     def _fixture_setup(self):
-        if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
+        if not connections_support_transactions():
             return super(TestCase, self)._fixture_setup()
 
-        transaction.enter_transaction_management()
-        transaction.managed(True)
+        # If the test case has a multi_db=True flag, setup all databases.
+        # Otherwise, just use default.
+        if getattr(self, 'multi_db', False):
+            databases = connections
+        else:
+            databases = [DEFAULT_DB_ALIAS]
+
+        for db in databases:
+            transaction.enter_transaction_management(using=db)
+            transaction.managed(True, using=db)
         disable_transaction_methods()
 
         from django.contrib.sites.models import Site
         Site.objects.clear_cache()
 
-        if hasattr(self, 'fixtures'):
-            call_command('loaddata', *self.fixtures, **{
-                                                        'verbosity': 0,
-                                                        'commit': False
-                                                        })
+        for db in databases:
+            if hasattr(self, 'fixtures'):
+                call_command('loaddata', *self.fixtures, **{
+                                                            'verbosity': 0,
+                                                            'commit': False,
+                                                            'database': db
+                                                            })
 
     def _fixture_teardown(self):
-        if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
+        if not connections_support_transactions():
             return super(TestCase, self)._fixture_teardown()
 
+        # If the test case has a multi_db=True flag, teardown all databases.
+        # Otherwise, just teardown default.
+        if getattr(self, 'multi_db', False):
+            databases = connections
+        else:
+            databases = [DEFAULT_DB_ALIAS]
+
         restore_transaction_methods()
-        transaction.rollback()
-        transaction.leave_transaction_management()
-        connection.close()
+        for db in databases:
+            transaction.rollback(using=db)
+            transaction.leave_transaction_management(using=db)
+
+        for connection in connections.all():
+            connection.close()