web/lib/django/test/testcases.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
equal deleted inserted replaced
28:b758351d191f 29:cc9b7e14412b
     5 
     5 
     6 from django.conf import settings
     6 from django.conf import settings
     7 from django.core import mail
     7 from django.core import mail
     8 from django.core.management import call_command
     8 from django.core.management import call_command
     9 from django.core.urlresolvers import clear_url_caches
     9 from django.core.urlresolvers import clear_url_caches
    10 from django.db import transaction, connection
    10 from django.db import transaction, connections, DEFAULT_DB_ALIAS
    11 from django.http import QueryDict
    11 from django.http import QueryDict
    12 from django.test import _doctest as doctest
    12 from django.test import _doctest as doctest
    13 from django.test.client import Client
    13 from django.test.client import Client
    14 from django.utils import simplejson
    14 from django.utils import simplejson
    15 from django.utils.encoding import smart_str
    15 from django.utils.encoding import smart_str
       
    16 
       
    17 try:
       
    18     all
       
    19 except NameError:
       
    20     from django.utils.itercompat import all
    16 
    21 
    17 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
    22 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
    18 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
    23 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
    19 
    24 
    20 def to_list(value):
    25 def to_list(value):
   199     def report_unexpected_exception(self, out, test, example, exc_info):
   204     def report_unexpected_exception(self, out, test, example, exc_info):
   200         doctest.DocTestRunner.report_unexpected_exception(self, out, test,
   205         doctest.DocTestRunner.report_unexpected_exception(self, out, test,
   201                                                           example, exc_info)
   206                                                           example, exc_info)
   202         # Rollback, in case of database errors. Otherwise they'd have
   207         # Rollback, in case of database errors. Otherwise they'd have
   203         # side effects on other tests.
   208         # side effects on other tests.
   204         transaction.rollback_unless_managed()
   209         for conn in connections:
       
   210             transaction.rollback_unless_managed(using=conn)
   205 
   211 
   206 class TransactionTestCase(unittest.TestCase):
   212 class TransactionTestCase(unittest.TestCase):
   207     def _pre_setup(self):
   213     def _pre_setup(self):
   208         """Performs any pre-test setup. This includes:
   214         """Performs any pre-test setup. This includes:
   209 
   215 
   217         self._fixture_setup()
   223         self._fixture_setup()
   218         self._urlconf_setup()
   224         self._urlconf_setup()
   219         mail.outbox = []
   225         mail.outbox = []
   220 
   226 
   221     def _fixture_setup(self):
   227     def _fixture_setup(self):
   222         call_command('flush', verbosity=0, interactive=False)
   228         # If the test case has a multi_db=True flag, flush all databases.
   223         if hasattr(self, 'fixtures'):
   229         # Otherwise, just flush default.
   224             # We have to use this slightly awkward syntax due to the fact
   230         if getattr(self, 'multi_db', False):
   225             # that we're using *args and **kwargs together.
   231             databases = connections
   226             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
   232         else:
       
   233             databases = [DEFAULT_DB_ALIAS]
       
   234         for db in databases:
       
   235             call_command('flush', verbosity=0, interactive=False, database=db)
       
   236 
       
   237             if hasattr(self, 'fixtures'):
       
   238                 # We have to use this slightly awkward syntax due to the fact
       
   239                 # that we're using *args and **kwargs together.
       
   240                 call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db})
   227 
   241 
   228     def _urlconf_setup(self):
   242     def _urlconf_setup(self):
   229         if hasattr(self, 'urls'):
   243         if hasattr(self, 'urls'):
   230             self._old_root_urlconf = settings.ROOT_URLCONF
   244             self._old_root_urlconf = settings.ROOT_URLCONF
   231             settings.ROOT_URLCONF = self.urls
   245             settings.ROOT_URLCONF = self.urls
   271         if hasattr(self, '_old_root_urlconf'):
   285         if hasattr(self, '_old_root_urlconf'):
   272             settings.ROOT_URLCONF = self._old_root_urlconf
   286             settings.ROOT_URLCONF = self._old_root_urlconf
   273             clear_url_caches()
   287             clear_url_caches()
   274 
   288 
   275     def assertRedirects(self, response, expected_url, status_code=302,
   289     def assertRedirects(self, response, expected_url, status_code=302,
   276                         target_status_code=200, host=None):
   290                         target_status_code=200, host=None, msg_prefix=''):
   277         """Asserts that a response redirected to a specific URL, and that the
   291         """Asserts that a response redirected to a specific URL, and that the
   278         redirect URL can be loaded.
   292         redirect URL can be loaded.
   279 
   293 
   280         Note that assertRedirects won't work for external links since it uses
   294         Note that assertRedirects won't work for external links since it uses
   281         TestClient to do a request.
   295         TestClient to do a request.
   282         """
   296         """
       
   297         if msg_prefix:
       
   298             msg_prefix += ": "
       
   299 
   283         if hasattr(response, 'redirect_chain'):
   300         if hasattr(response, 'redirect_chain'):
   284             # The request was a followed redirect
   301             # The request was a followed redirect
   285             self.failUnless(len(response.redirect_chain) > 0,
   302             self.failUnless(len(response.redirect_chain) > 0,
   286                 ("Response didn't redirect as expected: Response code was %d"
   303                 msg_prefix + "Response didn't redirect as expected: Response"
   287                 " (expected %d)" % (response.status_code, status_code)))
   304                 " code was %d (expected %d)" %
       
   305                     (response.status_code, status_code))
   288 
   306 
   289             self.assertEqual(response.redirect_chain[0][1], status_code,
   307             self.assertEqual(response.redirect_chain[0][1], status_code,
   290                 ("Initial response didn't redirect as expected: Response code was %d"
   308                 msg_prefix + "Initial response didn't redirect as expected:"
   291                  " (expected %d)" % (response.redirect_chain[0][1], status_code)))
   309                 " Response code was %d (expected %d)" %
       
   310                     (response.redirect_chain[0][1], status_code))
   292 
   311 
   293             url, status_code = response.redirect_chain[-1]
   312             url, status_code = response.redirect_chain[-1]
   294 
   313 
   295             self.assertEqual(response.status_code, target_status_code,
   314             self.assertEqual(response.status_code, target_status_code,
   296                 ("Response didn't redirect as expected: Final Response code was %d"
   315                 msg_prefix + "Response didn't redirect as expected: Final"
   297                 " (expected %d)" % (response.status_code, target_status_code)))
   316                 " Response code was %d (expected %d)" %
       
   317                     (response.status_code, target_status_code))
   298 
   318 
   299         else:
   319         else:
   300             # Not a followed redirect
   320             # Not a followed redirect
   301             self.assertEqual(response.status_code, status_code,
   321             self.assertEqual(response.status_code, status_code,
   302                 ("Response didn't redirect as expected: Response code was %d"
   322                 msg_prefix + "Response didn't redirect as expected: Response"
   303                  " (expected %d)" % (response.status_code, status_code)))
   323                 " code was %d (expected %d)" %
       
   324                     (response.status_code, status_code))
   304 
   325 
   305             url = response['Location']
   326             url = response['Location']
   306             scheme, netloc, path, query, fragment = urlsplit(url)
   327             scheme, netloc, path, query, fragment = urlsplit(url)
   307 
   328 
   308             redirect_response = response.client.get(path, QueryDict(query))
   329             redirect_response = response.client.get(path, QueryDict(query))
   309 
   330 
   310             # Get the redirection page, using the same client that was used
   331             # Get the redirection page, using the same client that was used
   311             # to obtain the original response.
   332             # to obtain the original response.
   312             self.assertEqual(redirect_response.status_code, target_status_code,
   333             self.assertEqual(redirect_response.status_code, target_status_code,
   313                 ("Couldn't retrieve redirection page '%s': response code was %d"
   334                 msg_prefix + "Couldn't retrieve redirection page '%s':"
   314                  " (expected %d)") %
   335                 " response code was %d (expected %d)" %
   315                      (path, redirect_response.status_code, target_status_code))
   336                     (path, redirect_response.status_code, target_status_code))
   316 
   337 
   317         e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
   338         e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
   318         if not (e_scheme or e_netloc):
   339         if not (e_scheme or e_netloc):
   319             expected_url = urlunsplit(('http', host or 'testserver', e_path,
   340             expected_url = urlunsplit(('http', host or 'testserver', e_path,
   320                 e_query, e_fragment))
   341                 e_query, e_fragment))
   321 
   342 
   322         self.assertEqual(url, expected_url,
   343         self.assertEqual(url, expected_url,
   323             "Response redirected to '%s', expected '%s'" % (url, expected_url))
   344             msg_prefix + "Response redirected to '%s', expected '%s'" %
   324 
   345                 (url, expected_url))
   325 
   346 
   326     def assertContains(self, response, text, count=None, status_code=200):
   347     def assertContains(self, response, text, count=None, status_code=200,
       
   348                        msg_prefix=''):
   327         """
   349         """
   328         Asserts that a response indicates that a page was retrieved
   350         Asserts that a response indicates that a page was retrieved
   329         successfully, (i.e., the HTTP status code was as expected), and that
   351         successfully, (i.e., the HTTP status code was as expected), and that
   330         ``text`` occurs ``count`` times in the content of the response.
   352         ``text`` occurs ``count`` times in the content of the response.
   331         If ``count`` is None, the count doesn't matter - the assertion is true
   353         If ``count`` is None, the count doesn't matter - the assertion is true
   332         if the text occurs at least once in the response.
   354         if the text occurs at least once in the response.
   333         """
   355         """
       
   356         if msg_prefix:
       
   357             msg_prefix += ": "
       
   358 
   334         self.assertEqual(response.status_code, status_code,
   359         self.assertEqual(response.status_code, status_code,
   335             "Couldn't retrieve page: Response code was %d (expected %d)'" %
   360             msg_prefix + "Couldn't retrieve page: Response code was %d"
   336                 (response.status_code, status_code))
   361             " (expected %d)" % (response.status_code, status_code))
   337         text = smart_str(text, response._charset)
   362         text = smart_str(text, response._charset)
   338         real_count = response.content.count(text)
   363         real_count = response.content.count(text)
   339         if count is not None:
   364         if count is not None:
   340             self.assertEqual(real_count, count,
   365             self.assertEqual(real_count, count,
   341                 "Found %d instances of '%s' in response (expected %d)" %
   366                 msg_prefix + "Found %d instances of '%s' in response"
   342                     (real_count, text, count))
   367                 " (expected %d)" % (real_count, text, count))
   343         else:
   368         else:
   344             self.failUnless(real_count != 0,
   369             self.failUnless(real_count != 0,
   345                             "Couldn't find '%s' in response" % text)
   370                 msg_prefix + "Couldn't find '%s' in response" % text)
   346 
   371 
   347     def assertNotContains(self, response, text, status_code=200):
   372     def assertNotContains(self, response, text, status_code=200,
       
   373                           msg_prefix=''):
   348         """
   374         """
   349         Asserts that a response indicates that a page was retrieved
   375         Asserts that a response indicates that a page was retrieved
   350         successfully, (i.e., the HTTP status code was as expected), and that
   376         successfully, (i.e., the HTTP status code was as expected), and that
   351         ``text`` doesn't occurs in the content of the response.
   377         ``text`` doesn't occurs in the content of the response.
   352         """
   378         """
       
   379         if msg_prefix:
       
   380             msg_prefix += ": "
       
   381 
   353         self.assertEqual(response.status_code, status_code,
   382         self.assertEqual(response.status_code, status_code,
   354             "Couldn't retrieve page: Response code was %d (expected %d)'" %
   383             msg_prefix + "Couldn't retrieve page: Response code was %d"
   355                 (response.status_code, status_code))
   384             " (expected %d)" % (response.status_code, status_code))
   356         text = smart_str(text, response._charset)
   385         text = smart_str(text, response._charset)
   357         self.assertEqual(response.content.count(text),
   386         self.assertEqual(response.content.count(text), 0,
   358              0, "Response should not contain '%s'" % text)
   387             msg_prefix + "Response should not contain '%s'" % text)
   359 
   388 
   360     def assertFormError(self, response, form, field, errors):
   389     def assertFormError(self, response, form, field, errors, msg_prefix=''):
   361         """
   390         """
   362         Asserts that a form used to render the response has a specific field
   391         Asserts that a form used to render the response has a specific field
   363         error.
   392         error.
   364         """
   393         """
       
   394         if msg_prefix:
       
   395             msg_prefix += ": "
       
   396 
   365         # Put context(s) into a list to simplify processing.
   397         # Put context(s) into a list to simplify processing.
   366         contexts = to_list(response.context)
   398         contexts = to_list(response.context)
   367         if not contexts:
   399         if not contexts:
   368             self.fail('Response did not use any contexts to render the'
   400             self.fail(msg_prefix + "Response did not use any contexts to "
   369                       ' response')
   401                       "render the response")
   370 
   402 
   371         # Put error(s) into a list to simplify processing.
   403         # Put error(s) into a list to simplify processing.
   372         errors = to_list(errors)
   404         errors = to_list(errors)
   373 
   405 
   374         # Search all contexts for the error.
   406         # Search all contexts for the error.
   380             for err in errors:
   412             for err in errors:
   381                 if field:
   413                 if field:
   382                     if field in context[form].errors:
   414                     if field in context[form].errors:
   383                         field_errors = context[form].errors[field]
   415                         field_errors = context[form].errors[field]
   384                         self.failUnless(err in field_errors,
   416                         self.failUnless(err in field_errors,
   385                                         "The field '%s' on form '%s' in"
   417                             msg_prefix + "The field '%s' on form '%s' in"
   386                                         " context %d does not contain the"
   418                             " context %d does not contain the error '%s'"
   387                                         " error '%s' (actual errors: %s)" %
   419                             " (actual errors: %s)" %
   388                                             (field, form, i, err,
   420                                 (field, form, i, err, repr(field_errors)))
   389                                              repr(field_errors)))
       
   390                     elif field in context[form].fields:
   421                     elif field in context[form].fields:
   391                         self.fail("The field '%s' on form '%s' in context %d"
   422                         self.fail(msg_prefix + "The field '%s' on form '%s'"
   392                                   " contains no errors" % (field, form, i))
   423                                   " in context %d contains no errors" %
       
   424                                       (field, form, i))
   393                     else:
   425                     else:
   394                         self.fail("The form '%s' in context %d does not"
   426                         self.fail(msg_prefix + "The form '%s' in context %d"
   395                                   " contain the field '%s'" %
   427                                   " does not contain the field '%s'" %
   396                                       (form, i, field))
   428                                       (form, i, field))
   397                 else:
   429                 else:
   398                     non_field_errors = context[form].non_field_errors()
   430                     non_field_errors = context[form].non_field_errors()
   399                     self.failUnless(err in non_field_errors,
   431                     self.failUnless(err in non_field_errors,
   400                         "The form '%s' in context %d does not contain the"
   432                         msg_prefix + "The form '%s' in context %d does not"
   401                         " non-field error '%s' (actual errors: %s)" %
   433                         " contain the non-field error '%s'"
       
   434                         " (actual errors: %s)" %
   402                             (form, i, err, non_field_errors))
   435                             (form, i, err, non_field_errors))
   403         if not found_form:
   436         if not found_form:
   404             self.fail("The form '%s' was not used to render the response" %
   437             self.fail(msg_prefix + "The form '%s' was not used to render the"
   405                           form)
   438                       " response" % form)
   406 
   439 
   407     def assertTemplateUsed(self, response, template_name):
   440     def assertTemplateUsed(self, response, template_name, msg_prefix=''):
   408         """
   441         """
   409         Asserts that the template with the provided name was used in rendering
   442         Asserts that the template with the provided name was used in rendering
   410         the response.
   443         the response.
   411         """
   444         """
       
   445         if msg_prefix:
       
   446             msg_prefix += ": "
       
   447 
   412         template_names = [t.name for t in to_list(response.template)]
   448         template_names = [t.name for t in to_list(response.template)]
   413         if not template_names:
   449         if not template_names:
   414             self.fail('No templates used to render the response')
   450             self.fail(msg_prefix + "No templates used to render the response")
   415         self.failUnless(template_name in template_names,
   451         self.failUnless(template_name in template_names,
   416             (u"Template '%s' was not a template used to render the response."
   452             msg_prefix + "Template '%s' was not a template used to render"
   417              u" Actual template(s) used: %s") % (template_name,
   453             " the response. Actual template(s) used: %s" %
   418                                                  u', '.join(template_names)))
   454                 (template_name, u', '.join(template_names)))
   419 
   455 
   420     def assertTemplateNotUsed(self, response, template_name):
   456     def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
   421         """
   457         """
   422         Asserts that the template with the provided name was NOT used in
   458         Asserts that the template with the provided name was NOT used in
   423         rendering the response.
   459         rendering the response.
   424         """
   460         """
       
   461         if msg_prefix:
       
   462             msg_prefix += ": "
       
   463 
   425         template_names = [t.name for t in to_list(response.template)]
   464         template_names = [t.name for t in to_list(response.template)]
   426         self.failIf(template_name in template_names,
   465         self.failIf(template_name in template_names,
   427             (u"Template '%s' was used unexpectedly in rendering the"
   466             msg_prefix + "Template '%s' was used unexpectedly in rendering"
   428              u" response") % template_name)
   467             " the response" % template_name)
       
   468 
       
   469 def connections_support_transactions():
       
   470     """
       
   471     Returns True if all connections support transactions.  This is messy
       
   472     because 2.4 doesn't support any or all.
       
   473     """
       
   474     return all(conn.settings_dict['SUPPORTS_TRANSACTIONS']
       
   475         for conn in connections.all())
   429 
   476 
   430 class TestCase(TransactionTestCase):
   477 class TestCase(TransactionTestCase):
   431     """
   478     """
   432     Does basically the same as TransactionTestCase, but surrounds every test
   479     Does basically the same as TransactionTestCase, but surrounds every test
   433     with a transaction, monkey-patches the real transaction management routines to
   480     with a transaction, monkey-patches the real transaction management routines to
   434     do nothing, and rollsback the test transaction at the end of the test. You have
   481     do nothing, and rollsback the test transaction at the end of the test. You have
   435     to use TransactionTestCase, if you need transaction management inside a test.
   482     to use TransactionTestCase, if you need transaction management inside a test.
   436     """
   483     """
   437 
   484 
   438     def _fixture_setup(self):
   485     def _fixture_setup(self):
   439         if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
   486         if not connections_support_transactions():
   440             return super(TestCase, self)._fixture_setup()
   487             return super(TestCase, self)._fixture_setup()
   441 
   488 
   442         transaction.enter_transaction_management()
   489         # If the test case has a multi_db=True flag, setup all databases.
   443         transaction.managed(True)
   490         # Otherwise, just use default.
       
   491         if getattr(self, 'multi_db', False):
       
   492             databases = connections
       
   493         else:
       
   494             databases = [DEFAULT_DB_ALIAS]
       
   495 
       
   496         for db in databases:
       
   497             transaction.enter_transaction_management(using=db)
       
   498             transaction.managed(True, using=db)
   444         disable_transaction_methods()
   499         disable_transaction_methods()
   445 
   500 
   446         from django.contrib.sites.models import Site
   501         from django.contrib.sites.models import Site
   447         Site.objects.clear_cache()
   502         Site.objects.clear_cache()
   448 
   503 
   449         if hasattr(self, 'fixtures'):
   504         for db in databases:
   450             call_command('loaddata', *self.fixtures, **{
   505             if hasattr(self, 'fixtures'):
   451                                                         'verbosity': 0,
   506                 call_command('loaddata', *self.fixtures, **{
   452                                                         'commit': False
   507                                                             'verbosity': 0,
   453                                                         })
   508                                                             'commit': False,
       
   509                                                             'database': db
       
   510                                                             })
   454 
   511 
   455     def _fixture_teardown(self):
   512     def _fixture_teardown(self):
   456         if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
   513         if not connections_support_transactions():
   457             return super(TestCase, self)._fixture_teardown()
   514             return super(TestCase, self)._fixture_teardown()
   458 
   515 
       
   516         # If the test case has a multi_db=True flag, teardown all databases.
       
   517         # Otherwise, just teardown default.
       
   518         if getattr(self, 'multi_db', False):
       
   519             databases = connections
       
   520         else:
       
   521             databases = [DEFAULT_DB_ALIAS]
       
   522 
   459         restore_transaction_methods()
   523         restore_transaction_methods()
   460         transaction.rollback()
   524         for db in databases:
   461         transaction.leave_transaction_management()
   525             transaction.rollback(using=db)
   462         connection.close()
   526             transaction.leave_transaction_management(using=db)
       
   527 
       
   528         for connection in connections.all():
       
   529             connection.close()