diff -r b758351d191f -r cc9b7e14412b web/lib/django/test/testcases.py --- a/web/lib/django/test/testcases.py Wed May 19 17:43:59 2010 +0200 +++ b/web/lib/django/test/testcases.py Tue May 25 02:43:45 2010 +0200 @@ -7,13 +7,18 @@ from django.core import mail from django.core.management import call_command from django.core.urlresolvers import clear_url_caches -from django.db import transaction, connection +from django.db import transaction, connections, DEFAULT_DB_ALIAS from django.http import QueryDict from django.test import _doctest as doctest from django.test.client import Client from django.utils import simplejson from django.utils.encoding import smart_str +try: + all +except NameError: + from django.utils.itercompat import all + normalize_long_ints = lambda s: re.sub(r'(? 0, - ("Response didn't redirect as expected: Response code was %d" - " (expected %d)" % (response.status_code, status_code))) + msg_prefix + "Response didn't redirect as expected: Response" + " code was %d (expected %d)" % + (response.status_code, status_code)) self.assertEqual(response.redirect_chain[0][1], status_code, - ("Initial response didn't redirect as expected: Response code was %d" - " (expected %d)" % (response.redirect_chain[0][1], status_code))) + msg_prefix + "Initial response didn't redirect as expected:" + " Response code was %d (expected %d)" % + (response.redirect_chain[0][1], status_code)) url, status_code = response.redirect_chain[-1] self.assertEqual(response.status_code, target_status_code, - ("Response didn't redirect as expected: Final Response code was %d" - " (expected %d)" % (response.status_code, target_status_code))) + msg_prefix + "Response didn't redirect as expected: Final" + " Response code was %d (expected %d)" % + (response.status_code, target_status_code)) else: # Not a followed redirect self.assertEqual(response.status_code, status_code, - ("Response didn't redirect as expected: Response code was %d" - " (expected %d)" % (response.status_code, status_code))) + msg_prefix + "Response didn't redirect as expected: Response" + " code was %d (expected %d)" % + (response.status_code, status_code)) url = response['Location'] scheme, netloc, path, query, fragment = urlsplit(url) @@ -310,9 +331,9 @@ # Get the redirection page, using the same client that was used # to obtain the original response. self.assertEqual(redirect_response.status_code, target_status_code, - ("Couldn't retrieve redirection page '%s': response code was %d" - " (expected %d)") % - (path, redirect_response.status_code, target_status_code)) + msg_prefix + "Couldn't retrieve redirection page '%s':" + " response code was %d (expected %d)" % + (path, redirect_response.status_code, target_status_code)) e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url) if not (e_scheme or e_netloc): @@ -320,10 +341,11 @@ e_query, e_fragment)) self.assertEqual(url, expected_url, - "Response redirected to '%s', expected '%s'" % (url, expected_url)) + msg_prefix + "Response redirected to '%s', expected '%s'" % + (url, expected_url)) - - def assertContains(self, response, text, count=None, status_code=200): + def assertContains(self, response, text, count=None, status_code=200, + msg_prefix=''): """ Asserts that a response indicates that a page was retrieved successfully, (i.e., the HTTP status code was as expected), and that @@ -331,42 +353,52 @@ If ``count`` is None, the count doesn't matter - the assertion is true if the text occurs at least once in the response. """ + if msg_prefix: + msg_prefix += ": " + self.assertEqual(response.status_code, status_code, - "Couldn't retrieve page: Response code was %d (expected %d)'" % - (response.status_code, status_code)) + msg_prefix + "Couldn't retrieve page: Response code was %d" + " (expected %d)" % (response.status_code, status_code)) text = smart_str(text, response._charset) real_count = response.content.count(text) if count is not None: self.assertEqual(real_count, count, - "Found %d instances of '%s' in response (expected %d)" % - (real_count, text, count)) + msg_prefix + "Found %d instances of '%s' in response" + " (expected %d)" % (real_count, text, count)) else: self.failUnless(real_count != 0, - "Couldn't find '%s' in response" % text) + msg_prefix + "Couldn't find '%s' in response" % text) - def assertNotContains(self, response, text, status_code=200): + def assertNotContains(self, response, text, status_code=200, + msg_prefix=''): """ Asserts that a response indicates that a page was retrieved successfully, (i.e., the HTTP status code was as expected), and that ``text`` doesn't occurs in the content of the response. """ + if msg_prefix: + msg_prefix += ": " + self.assertEqual(response.status_code, status_code, - "Couldn't retrieve page: Response code was %d (expected %d)'" % - (response.status_code, status_code)) + msg_prefix + "Couldn't retrieve page: Response code was %d" + " (expected %d)" % (response.status_code, status_code)) text = smart_str(text, response._charset) - self.assertEqual(response.content.count(text), - 0, "Response should not contain '%s'" % text) + self.assertEqual(response.content.count(text), 0, + msg_prefix + "Response should not contain '%s'" % text) - def assertFormError(self, response, form, field, errors): + def assertFormError(self, response, form, field, errors, msg_prefix=''): """ Asserts that a form used to render the response has a specific field error. """ + if msg_prefix: + msg_prefix += ": " + # Put context(s) into a list to simplify processing. contexts = to_list(response.context) if not contexts: - self.fail('Response did not use any contexts to render the' - ' response') + self.fail(msg_prefix + "Response did not use any contexts to " + "render the response") # Put error(s) into a list to simplify processing. errors = to_list(errors) @@ -382,50 +414,65 @@ if field in context[form].errors: field_errors = context[form].errors[field] self.failUnless(err in field_errors, - "The field '%s' on form '%s' in" - " context %d does not contain the" - " error '%s' (actual errors: %s)" % - (field, form, i, err, - repr(field_errors))) + msg_prefix + "The field '%s' on form '%s' in" + " context %d does not contain the error '%s'" + " (actual errors: %s)" % + (field, form, i, err, repr(field_errors))) elif field in context[form].fields: - self.fail("The field '%s' on form '%s' in context %d" - " contains no errors" % (field, form, i)) + self.fail(msg_prefix + "The field '%s' on form '%s'" + " in context %d contains no errors" % + (field, form, i)) else: - self.fail("The form '%s' in context %d does not" - " contain the field '%s'" % + self.fail(msg_prefix + "The form '%s' in context %d" + " does not contain the field '%s'" % (form, i, field)) else: non_field_errors = context[form].non_field_errors() self.failUnless(err in non_field_errors, - "The form '%s' in context %d does not contain the" - " non-field error '%s' (actual errors: %s)" % + msg_prefix + "The form '%s' in context %d does not" + " contain the non-field error '%s'" + " (actual errors: %s)" % (form, i, err, non_field_errors)) if not found_form: - self.fail("The form '%s' was not used to render the response" % - form) + self.fail(msg_prefix + "The form '%s' was not used to render the" + " response" % form) - def assertTemplateUsed(self, response, template_name): + def assertTemplateUsed(self, response, template_name, msg_prefix=''): """ Asserts that the template with the provided name was used in rendering the response. """ + if msg_prefix: + msg_prefix += ": " + template_names = [t.name for t in to_list(response.template)] if not template_names: - self.fail('No templates used to render the response') + self.fail(msg_prefix + "No templates used to render the response") self.failUnless(template_name in template_names, - (u"Template '%s' was not a template used to render the response." - u" Actual template(s) used: %s") % (template_name, - u', '.join(template_names))) + msg_prefix + "Template '%s' was not a template used to render" + " the response. Actual template(s) used: %s" % + (template_name, u', '.join(template_names))) - def assertTemplateNotUsed(self, response, template_name): + def assertTemplateNotUsed(self, response, template_name, msg_prefix=''): """ Asserts that the template with the provided name was NOT used in rendering the response. """ + if msg_prefix: + msg_prefix += ": " + template_names = [t.name for t in to_list(response.template)] self.failIf(template_name in template_names, - (u"Template '%s' was used unexpectedly in rendering the" - u" response") % template_name) + msg_prefix + "Template '%s' was used unexpectedly in rendering" + " the response" % template_name) + +def connections_support_transactions(): + """ + Returns True if all connections support transactions. This is messy + because 2.4 doesn't support any or all. + """ + return all(conn.settings_dict['SUPPORTS_TRANSACTIONS'] + for conn in connections.all()) class TestCase(TransactionTestCase): """ @@ -436,27 +483,47 @@ """ def _fixture_setup(self): - if not settings.DATABASE_SUPPORTS_TRANSACTIONS: + if not connections_support_transactions(): return super(TestCase, self)._fixture_setup() - transaction.enter_transaction_management() - transaction.managed(True) + # If the test case has a multi_db=True flag, setup all databases. + # Otherwise, just use default. + if getattr(self, 'multi_db', False): + databases = connections + else: + databases = [DEFAULT_DB_ALIAS] + + for db in databases: + transaction.enter_transaction_management(using=db) + transaction.managed(True, using=db) disable_transaction_methods() from django.contrib.sites.models import Site Site.objects.clear_cache() - if hasattr(self, 'fixtures'): - call_command('loaddata', *self.fixtures, **{ - 'verbosity': 0, - 'commit': False - }) + for db in databases: + if hasattr(self, 'fixtures'): + call_command('loaddata', *self.fixtures, **{ + 'verbosity': 0, + 'commit': False, + 'database': db + }) def _fixture_teardown(self): - if not settings.DATABASE_SUPPORTS_TRANSACTIONS: + if not connections_support_transactions(): return super(TestCase, self)._fixture_teardown() + # If the test case has a multi_db=True flag, teardown all databases. + # Otherwise, just teardown default. + if getattr(self, 'multi_db', False): + databases = connections + else: + databases = [DEFAULT_DB_ALIAS] + restore_transaction_methods() - transaction.rollback() - transaction.leave_transaction_management() - connection.close() + for db in databases: + transaction.rollback(using=db) + transaction.leave_transaction_management(using=db) + + for connection in connections.all(): + connection.close()