--- a/web/lib/django/test/testcases.py Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/test/testcases.py Tue May 25 02:43:45 2010 +0200
@@ -7,13 +7,18 @@
from django.core import mail
from django.core.management import call_command
from django.core.urlresolvers import clear_url_caches
-from django.db import transaction, connection
+from django.db import transaction, connections, DEFAULT_DB_ALIAS
from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
from django.utils import simplejson
from django.utils.encoding import smart_str
+try:
+ all
+except NameError:
+ from django.utils.itercompat import all
+
normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
@@ -201,7 +206,8 @@
example, exc_info)
# Rollback, in case of database errors. Otherwise they'd have
# side effects on other tests.
- transaction.rollback_unless_managed()
+ for conn in connections:
+ transaction.rollback_unless_managed(using=conn)
class TransactionTestCase(unittest.TestCase):
def _pre_setup(self):
@@ -219,11 +225,19 @@
mail.outbox = []
def _fixture_setup(self):
- call_command('flush', verbosity=0, interactive=False)
- if hasattr(self, 'fixtures'):
- # We have to use this slightly awkward syntax due to the fact
- # that we're using *args and **kwargs together.
- call_command('loaddata', *self.fixtures, **{'verbosity': 0})
+ # If the test case has a multi_db=True flag, flush all databases.
+ # Otherwise, just flush default.
+ if getattr(self, 'multi_db', False):
+ databases = connections
+ else:
+ databases = [DEFAULT_DB_ALIAS]
+ for db in databases:
+ call_command('flush', verbosity=0, interactive=False, database=db)
+
+ if hasattr(self, 'fixtures'):
+ # We have to use this slightly awkward syntax due to the fact
+ # that we're using *args and **kwargs together.
+ call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db})
def _urlconf_setup(self):
if hasattr(self, 'urls'):
@@ -273,34 +287,41 @@
clear_url_caches()
def assertRedirects(self, response, expected_url, status_code=302,
- target_status_code=200, host=None):
+ target_status_code=200, host=None, msg_prefix=''):
"""Asserts that a response redirected to a specific URL, and that the
redirect URL can be loaded.
Note that assertRedirects won't work for external links since it uses
TestClient to do a request.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
if hasattr(response, 'redirect_chain'):
# The request was a followed redirect
self.failUnless(len(response.redirect_chain) > 0,
- ("Response didn't redirect as expected: Response code was %d"
- " (expected %d)" % (response.status_code, status_code)))
+ msg_prefix + "Response didn't redirect as expected: Response"
+ " code was %d (expected %d)" %
+ (response.status_code, status_code))
self.assertEqual(response.redirect_chain[0][1], status_code,
- ("Initial response didn't redirect as expected: Response code was %d"
- " (expected %d)" % (response.redirect_chain[0][1], status_code)))
+ msg_prefix + "Initial response didn't redirect as expected:"
+ " Response code was %d (expected %d)" %
+ (response.redirect_chain[0][1], status_code))
url, status_code = response.redirect_chain[-1]
self.assertEqual(response.status_code, target_status_code,
- ("Response didn't redirect as expected: Final Response code was %d"
- " (expected %d)" % (response.status_code, target_status_code)))
+ msg_prefix + "Response didn't redirect as expected: Final"
+ " Response code was %d (expected %d)" %
+ (response.status_code, target_status_code))
else:
# Not a followed redirect
self.assertEqual(response.status_code, status_code,
- ("Response didn't redirect as expected: Response code was %d"
- " (expected %d)" % (response.status_code, status_code)))
+ msg_prefix + "Response didn't redirect as expected: Response"
+ " code was %d (expected %d)" %
+ (response.status_code, status_code))
url = response['Location']
scheme, netloc, path, query, fragment = urlsplit(url)
@@ -310,9 +331,9 @@
# Get the redirection page, using the same client that was used
# to obtain the original response.
self.assertEqual(redirect_response.status_code, target_status_code,
- ("Couldn't retrieve redirection page '%s': response code was %d"
- " (expected %d)") %
- (path, redirect_response.status_code, target_status_code))
+ msg_prefix + "Couldn't retrieve redirection page '%s':"
+ " response code was %d (expected %d)" %
+ (path, redirect_response.status_code, target_status_code))
e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
if not (e_scheme or e_netloc):
@@ -320,10 +341,11 @@
e_query, e_fragment))
self.assertEqual(url, expected_url,
- "Response redirected to '%s', expected '%s'" % (url, expected_url))
+ msg_prefix + "Response redirected to '%s', expected '%s'" %
+ (url, expected_url))
-
- def assertContains(self, response, text, count=None, status_code=200):
+ def assertContains(self, response, text, count=None, status_code=200,
+ msg_prefix=''):
"""
Asserts that a response indicates that a page was retrieved
successfully, (i.e., the HTTP status code was as expected), and that
@@ -331,42 +353,52 @@
If ``count`` is None, the count doesn't matter - the assertion is true
if the text occurs at least once in the response.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
self.assertEqual(response.status_code, status_code,
- "Couldn't retrieve page: Response code was %d (expected %d)'" %
- (response.status_code, status_code))
+ msg_prefix + "Couldn't retrieve page: Response code was %d"
+ " (expected %d)" % (response.status_code, status_code))
text = smart_str(text, response._charset)
real_count = response.content.count(text)
if count is not None:
self.assertEqual(real_count, count,
- "Found %d instances of '%s' in response (expected %d)" %
- (real_count, text, count))
+ msg_prefix + "Found %d instances of '%s' in response"
+ " (expected %d)" % (real_count, text, count))
else:
self.failUnless(real_count != 0,
- "Couldn't find '%s' in response" % text)
+ msg_prefix + "Couldn't find '%s' in response" % text)
- def assertNotContains(self, response, text, status_code=200):
+ def assertNotContains(self, response, text, status_code=200,
+ msg_prefix=''):
"""
Asserts that a response indicates that a page was retrieved
successfully, (i.e., the HTTP status code was as expected), and that
``text`` doesn't occurs in the content of the response.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
self.assertEqual(response.status_code, status_code,
- "Couldn't retrieve page: Response code was %d (expected %d)'" %
- (response.status_code, status_code))
+ msg_prefix + "Couldn't retrieve page: Response code was %d"
+ " (expected %d)" % (response.status_code, status_code))
text = smart_str(text, response._charset)
- self.assertEqual(response.content.count(text),
- 0, "Response should not contain '%s'" % text)
+ self.assertEqual(response.content.count(text), 0,
+ msg_prefix + "Response should not contain '%s'" % text)
- def assertFormError(self, response, form, field, errors):
+ def assertFormError(self, response, form, field, errors, msg_prefix=''):
"""
Asserts that a form used to render the response has a specific field
error.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
# Put context(s) into a list to simplify processing.
contexts = to_list(response.context)
if not contexts:
- self.fail('Response did not use any contexts to render the'
- ' response')
+ self.fail(msg_prefix + "Response did not use any contexts to "
+ "render the response")
# Put error(s) into a list to simplify processing.
errors = to_list(errors)
@@ -382,50 +414,65 @@
if field in context[form].errors:
field_errors = context[form].errors[field]
self.failUnless(err in field_errors,
- "The field '%s' on form '%s' in"
- " context %d does not contain the"
- " error '%s' (actual errors: %s)" %
- (field, form, i, err,
- repr(field_errors)))
+ msg_prefix + "The field '%s' on form '%s' in"
+ " context %d does not contain the error '%s'"
+ " (actual errors: %s)" %
+ (field, form, i, err, repr(field_errors)))
elif field in context[form].fields:
- self.fail("The field '%s' on form '%s' in context %d"
- " contains no errors" % (field, form, i))
+ self.fail(msg_prefix + "The field '%s' on form '%s'"
+ " in context %d contains no errors" %
+ (field, form, i))
else:
- self.fail("The form '%s' in context %d does not"
- " contain the field '%s'" %
+ self.fail(msg_prefix + "The form '%s' in context %d"
+ " does not contain the field '%s'" %
(form, i, field))
else:
non_field_errors = context[form].non_field_errors()
self.failUnless(err in non_field_errors,
- "The form '%s' in context %d does not contain the"
- " non-field error '%s' (actual errors: %s)" %
+ msg_prefix + "The form '%s' in context %d does not"
+ " contain the non-field error '%s'"
+ " (actual errors: %s)" %
(form, i, err, non_field_errors))
if not found_form:
- self.fail("The form '%s' was not used to render the response" %
- form)
+ self.fail(msg_prefix + "The form '%s' was not used to render the"
+ " response" % form)
- def assertTemplateUsed(self, response, template_name):
+ def assertTemplateUsed(self, response, template_name, msg_prefix=''):
"""
Asserts that the template with the provided name was used in rendering
the response.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
template_names = [t.name for t in to_list(response.template)]
if not template_names:
- self.fail('No templates used to render the response')
+ self.fail(msg_prefix + "No templates used to render the response")
self.failUnless(template_name in template_names,
- (u"Template '%s' was not a template used to render the response."
- u" Actual template(s) used: %s") % (template_name,
- u', '.join(template_names)))
+ msg_prefix + "Template '%s' was not a template used to render"
+ " the response. Actual template(s) used: %s" %
+ (template_name, u', '.join(template_names)))
- def assertTemplateNotUsed(self, response, template_name):
+ def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
"""
Asserts that the template with the provided name was NOT used in
rendering the response.
"""
+ if msg_prefix:
+ msg_prefix += ": "
+
template_names = [t.name for t in to_list(response.template)]
self.failIf(template_name in template_names,
- (u"Template '%s' was used unexpectedly in rendering the"
- u" response") % template_name)
+ msg_prefix + "Template '%s' was used unexpectedly in rendering"
+ " the response" % template_name)
+
+def connections_support_transactions():
+ """
+ Returns True if all connections support transactions. This is messy
+ because 2.4 doesn't support any or all.
+ """
+ return all(conn.settings_dict['SUPPORTS_TRANSACTIONS']
+ for conn in connections.all())
class TestCase(TransactionTestCase):
"""
@@ -436,27 +483,47 @@
"""
def _fixture_setup(self):
- if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
+ if not connections_support_transactions():
return super(TestCase, self)._fixture_setup()
- transaction.enter_transaction_management()
- transaction.managed(True)
+ # If the test case has a multi_db=True flag, setup all databases.
+ # Otherwise, just use default.
+ if getattr(self, 'multi_db', False):
+ databases = connections
+ else:
+ databases = [DEFAULT_DB_ALIAS]
+
+ for db in databases:
+ transaction.enter_transaction_management(using=db)
+ transaction.managed(True, using=db)
disable_transaction_methods()
from django.contrib.sites.models import Site
Site.objects.clear_cache()
- if hasattr(self, 'fixtures'):
- call_command('loaddata', *self.fixtures, **{
- 'verbosity': 0,
- 'commit': False
- })
+ for db in databases:
+ if hasattr(self, 'fixtures'):
+ call_command('loaddata', *self.fixtures, **{
+ 'verbosity': 0,
+ 'commit': False,
+ 'database': db
+ })
def _fixture_teardown(self):
- if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
+ if not connections_support_transactions():
return super(TestCase, self)._fixture_teardown()
+ # If the test case has a multi_db=True flag, teardown all databases.
+ # Otherwise, just teardown default.
+ if getattr(self, 'multi_db', False):
+ databases = connections
+ else:
+ databases = [DEFAULT_DB_ALIAS]
+
restore_transaction_methods()
- transaction.rollback()
- transaction.leave_transaction_management()
- connection.close()
+ for db in databases:
+ transaction.rollback(using=db)
+ transaction.leave_transaction_management(using=db)
+
+ for connection in connections.all():
+ connection.close()