web/lib/django/test/testcases.py
changeset 38 77b6da96e6f1
equal deleted inserted replaced
37:8d941af65caf 38:77b6da96e6f1
       
     1 import re
       
     2 import unittest
       
     3 from urlparse import urlsplit, urlunsplit
       
     4 from xml.dom.minidom import parseString, Node
       
     5 
       
     6 from django.conf import settings
       
     7 from django.core import mail
       
     8 from django.core.management import call_command
       
     9 from django.core.urlresolvers import clear_url_caches
       
    10 from django.db import transaction, connections, DEFAULT_DB_ALIAS
       
    11 from django.http import QueryDict
       
    12 from django.test import _doctest as doctest
       
    13 from django.test.client import Client
       
    14 from django.utils import simplejson
       
    15 from django.utils.encoding import smart_str
       
    16 
       
    17 try:
       
    18     all
       
    19 except NameError:
       
    20     from django.utils.itercompat import all
       
    21 
       
    22 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
       
    23 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
       
    24 
       
    25 def to_list(value):
       
    26     """
       
    27     Puts value into a list if it's not already one.
       
    28     Returns an empty list if value is None.
       
    29     """
       
    30     if value is None:
       
    31         value = []
       
    32     elif not isinstance(value, list):
       
    33         value = [value]
       
    34     return value
       
    35 
       
    36 real_commit = transaction.commit
       
    37 real_rollback = transaction.rollback
       
    38 real_enter_transaction_management = transaction.enter_transaction_management
       
    39 real_leave_transaction_management = transaction.leave_transaction_management
       
    40 real_savepoint_commit = transaction.savepoint_commit
       
    41 real_savepoint_rollback = transaction.savepoint_rollback
       
    42 real_managed = transaction.managed
       
    43 
       
    44 def nop(*args, **kwargs):
       
    45     return
       
    46 
       
    47 def disable_transaction_methods():
       
    48     transaction.commit = nop
       
    49     transaction.rollback = nop
       
    50     transaction.savepoint_commit = nop
       
    51     transaction.savepoint_rollback = nop
       
    52     transaction.enter_transaction_management = nop
       
    53     transaction.leave_transaction_management = nop
       
    54     transaction.managed = nop
       
    55 
       
    56 def restore_transaction_methods():
       
    57     transaction.commit = real_commit
       
    58     transaction.rollback = real_rollback
       
    59     transaction.savepoint_commit = real_savepoint_commit
       
    60     transaction.savepoint_rollback = real_savepoint_rollback
       
    61     transaction.enter_transaction_management = real_enter_transaction_management
       
    62     transaction.leave_transaction_management = real_leave_transaction_management
       
    63     transaction.managed = real_managed
       
    64 
       
    65 class OutputChecker(doctest.OutputChecker):
       
    66     def check_output(self, want, got, optionflags):
       
    67         "The entry method for doctest output checking. Defers to a sequence of child checkers"
       
    68         checks = (self.check_output_default,
       
    69                   self.check_output_numeric,
       
    70                   self.check_output_xml,
       
    71                   self.check_output_json)
       
    72         for check in checks:
       
    73             if check(want, got, optionflags):
       
    74                 return True
       
    75         return False
       
    76 
       
    77     def check_output_default(self, want, got, optionflags):
       
    78         "The default comparator provided by doctest - not perfect, but good for most purposes"
       
    79         return doctest.OutputChecker.check_output(self, want, got, optionflags)
       
    80 
       
    81     def check_output_numeric(self, want, got, optionflags):
       
    82         """Doctest does an exact string comparison of output, which means that
       
    83         some numerically equivalent values aren't equal. This check normalizes
       
    84          * long integers (22L) so that they equal normal integers. (22)
       
    85          * Decimals so that they are comparable, regardless of the change
       
    86            made to __repr__ in Python 2.6.
       
    87         """
       
    88         return doctest.OutputChecker.check_output(self,
       
    89             normalize_decimals(normalize_long_ints(want)),
       
    90             normalize_decimals(normalize_long_ints(got)),
       
    91             optionflags)
       
    92 
       
    93     def check_output_xml(self, want, got, optionsflags):
       
    94         """Tries to do a 'xml-comparision' of want and got.  Plain string
       
    95         comparision doesn't always work because, for example, attribute
       
    96         ordering should not be important.
       
    97 
       
    98         Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
       
    99         """
       
   100         _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
       
   101         def norm_whitespace(v):
       
   102             return _norm_whitespace_re.sub(' ', v)
       
   103 
       
   104         def child_text(element):
       
   105             return ''.join([c.data for c in element.childNodes
       
   106                             if c.nodeType == Node.TEXT_NODE])
       
   107 
       
   108         def children(element):
       
   109             return [c for c in element.childNodes
       
   110                     if c.nodeType == Node.ELEMENT_NODE]
       
   111 
       
   112         def norm_child_text(element):
       
   113             return norm_whitespace(child_text(element))
       
   114 
       
   115         def attrs_dict(element):
       
   116             return dict(element.attributes.items())
       
   117 
       
   118         def check_element(want_element, got_element):
       
   119             if want_element.tagName != got_element.tagName:
       
   120                 return False
       
   121             if norm_child_text(want_element) != norm_child_text(got_element):
       
   122                 return False
       
   123             if attrs_dict(want_element) != attrs_dict(got_element):
       
   124                 return False
       
   125             want_children = children(want_element)
       
   126             got_children = children(got_element)
       
   127             if len(want_children) != len(got_children):
       
   128                 return False
       
   129             for want, got in zip(want_children, got_children):
       
   130                 if not check_element(want, got):
       
   131                     return False
       
   132             return True
       
   133 
       
   134         want, got = self._strip_quotes(want, got)
       
   135         want = want.replace('\\n','\n')
       
   136         got = got.replace('\\n','\n')
       
   137 
       
   138         # If the string is not a complete xml document, we may need to add a
       
   139         # root element. This allow us to compare fragments, like "<foo/><bar/>"
       
   140         if not want.startswith('<?xml'):
       
   141             wrapper = '<root>%s</root>'
       
   142             want = wrapper % want
       
   143             got = wrapper % got
       
   144 
       
   145         # Parse the want and got strings, and compare the parsings.
       
   146         try:
       
   147             want_root = parseString(want).firstChild
       
   148             got_root = parseString(got).firstChild
       
   149         except:
       
   150             return False
       
   151         return check_element(want_root, got_root)
       
   152 
       
   153     def check_output_json(self, want, got, optionsflags):
       
   154         "Tries to compare want and got as if they were JSON-encoded data"
       
   155         want, got = self._strip_quotes(want, got)
       
   156         try:
       
   157             want_json = simplejson.loads(want)
       
   158             got_json = simplejson.loads(got)
       
   159         except:
       
   160             return False
       
   161         return want_json == got_json
       
   162 
       
   163     def _strip_quotes(self, want, got):
       
   164         """
       
   165         Strip quotes of doctests output values:
       
   166 
       
   167         >>> o = OutputChecker()
       
   168         >>> o._strip_quotes("'foo'")
       
   169         "foo"
       
   170         >>> o._strip_quotes('"foo"')
       
   171         "foo"
       
   172         >>> o._strip_quotes("u'foo'")
       
   173         "foo"
       
   174         >>> o._strip_quotes('u"foo"')
       
   175         "foo"
       
   176         """
       
   177         def is_quoted_string(s):
       
   178             s = s.strip()
       
   179             return (len(s) >= 2
       
   180                     and s[0] == s[-1]
       
   181                     and s[0] in ('"', "'"))
       
   182 
       
   183         def is_quoted_unicode(s):
       
   184             s = s.strip()
       
   185             return (len(s) >= 3
       
   186                     and s[0] == 'u'
       
   187                     and s[1] == s[-1]
       
   188                     and s[1] in ('"', "'"))
       
   189 
       
   190         if is_quoted_string(want) and is_quoted_string(got):
       
   191             want = want.strip()[1:-1]
       
   192             got = got.strip()[1:-1]
       
   193         elif is_quoted_unicode(want) and is_quoted_unicode(got):
       
   194             want = want.strip()[2:-1]
       
   195             got = got.strip()[2:-1]
       
   196         return want, got
       
   197 
       
   198 
       
   199 class DocTestRunner(doctest.DocTestRunner):
       
   200     def __init__(self, *args, **kwargs):
       
   201         doctest.DocTestRunner.__init__(self, *args, **kwargs)
       
   202         self.optionflags = doctest.ELLIPSIS
       
   203 
       
   204     def report_unexpected_exception(self, out, test, example, exc_info):
       
   205         doctest.DocTestRunner.report_unexpected_exception(self, out, test,
       
   206                                                           example, exc_info)
       
   207         # Rollback, in case of database errors. Otherwise they'd have
       
   208         # side effects on other tests.
       
   209         for conn in connections:
       
   210             transaction.rollback_unless_managed(using=conn)
       
   211 
       
   212 class TransactionTestCase(unittest.TestCase):
       
   213     def _pre_setup(self):
       
   214         """Performs any pre-test setup. This includes:
       
   215 
       
   216             * Flushing the database.
       
   217             * If the Test Case class has a 'fixtures' member, installing the
       
   218               named fixtures.
       
   219             * If the Test Case class has a 'urls' member, replace the
       
   220               ROOT_URLCONF with it.
       
   221             * Clearing the mail test outbox.
       
   222         """
       
   223         self._fixture_setup()
       
   224         self._urlconf_setup()
       
   225         mail.outbox = []
       
   226 
       
   227     def _fixture_setup(self):
       
   228         # If the test case has a multi_db=True flag, flush all databases.
       
   229         # Otherwise, just flush default.
       
   230         if getattr(self, 'multi_db', False):
       
   231             databases = connections
       
   232         else:
       
   233             databases = [DEFAULT_DB_ALIAS]
       
   234         for db in databases:
       
   235             call_command('flush', verbosity=0, interactive=False, database=db)
       
   236 
       
   237             if hasattr(self, 'fixtures'):
       
   238                 # We have to use this slightly awkward syntax due to the fact
       
   239                 # that we're using *args and **kwargs together.
       
   240                 call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db})
       
   241 
       
   242     def _urlconf_setup(self):
       
   243         if hasattr(self, 'urls'):
       
   244             self._old_root_urlconf = settings.ROOT_URLCONF
       
   245             settings.ROOT_URLCONF = self.urls
       
   246             clear_url_caches()
       
   247 
       
   248     def __call__(self, result=None):
       
   249         """
       
   250         Wrapper around default __call__ method to perform common Django test
       
   251         set up. This means that user-defined Test Cases aren't required to
       
   252         include a call to super().setUp().
       
   253         """
       
   254         self.client = Client()
       
   255         try:
       
   256             self._pre_setup()
       
   257         except (KeyboardInterrupt, SystemExit):
       
   258             raise
       
   259         except Exception:
       
   260             import sys
       
   261             result.addError(self, sys.exc_info())
       
   262             return
       
   263         super(TransactionTestCase, self).__call__(result)
       
   264         try:
       
   265             self._post_teardown()
       
   266         except (KeyboardInterrupt, SystemExit):
       
   267             raise
       
   268         except Exception:
       
   269             import sys
       
   270             result.addError(self, sys.exc_info())
       
   271             return
       
   272 
       
   273     def _post_teardown(self):
       
   274         """ Performs any post-test things. This includes:
       
   275 
       
   276             * Putting back the original ROOT_URLCONF if it was changed.
       
   277         """
       
   278         self._fixture_teardown()
       
   279         self._urlconf_teardown()
       
   280 
       
   281     def _fixture_teardown(self):
       
   282         pass
       
   283 
       
   284     def _urlconf_teardown(self):
       
   285         if hasattr(self, '_old_root_urlconf'):
       
   286             settings.ROOT_URLCONF = self._old_root_urlconf
       
   287             clear_url_caches()
       
   288 
       
   289     def assertRedirects(self, response, expected_url, status_code=302,
       
   290                         target_status_code=200, host=None, msg_prefix=''):
       
   291         """Asserts that a response redirected to a specific URL, and that the
       
   292         redirect URL can be loaded.
       
   293 
       
   294         Note that assertRedirects won't work for external links since it uses
       
   295         TestClient to do a request.
       
   296         """
       
   297         if msg_prefix:
       
   298             msg_prefix += ": "
       
   299 
       
   300         if hasattr(response, 'redirect_chain'):
       
   301             # The request was a followed redirect
       
   302             self.failUnless(len(response.redirect_chain) > 0,
       
   303                 msg_prefix + "Response didn't redirect as expected: Response"
       
   304                 " code was %d (expected %d)" %
       
   305                     (response.status_code, status_code))
       
   306 
       
   307             self.assertEqual(response.redirect_chain[0][1], status_code,
       
   308                 msg_prefix + "Initial response didn't redirect as expected:"
       
   309                 " Response code was %d (expected %d)" %
       
   310                     (response.redirect_chain[0][1], status_code))
       
   311 
       
   312             url, status_code = response.redirect_chain[-1]
       
   313 
       
   314             self.assertEqual(response.status_code, target_status_code,
       
   315                 msg_prefix + "Response didn't redirect as expected: Final"
       
   316                 " Response code was %d (expected %d)" %
       
   317                     (response.status_code, target_status_code))
       
   318 
       
   319         else:
       
   320             # Not a followed redirect
       
   321             self.assertEqual(response.status_code, status_code,
       
   322                 msg_prefix + "Response didn't redirect as expected: Response"
       
   323                 " code was %d (expected %d)" %
       
   324                     (response.status_code, status_code))
       
   325 
       
   326             url = response['Location']
       
   327             scheme, netloc, path, query, fragment = urlsplit(url)
       
   328 
       
   329             redirect_response = response.client.get(path, QueryDict(query))
       
   330 
       
   331             # Get the redirection page, using the same client that was used
       
   332             # to obtain the original response.
       
   333             self.assertEqual(redirect_response.status_code, target_status_code,
       
   334                 msg_prefix + "Couldn't retrieve redirection page '%s':"
       
   335                 " response code was %d (expected %d)" %
       
   336                     (path, redirect_response.status_code, target_status_code))
       
   337 
       
   338         e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
       
   339         if not (e_scheme or e_netloc):
       
   340             expected_url = urlunsplit(('http', host or 'testserver', e_path,
       
   341                 e_query, e_fragment))
       
   342 
       
   343         self.assertEqual(url, expected_url,
       
   344             msg_prefix + "Response redirected to '%s', expected '%s'" %
       
   345                 (url, expected_url))
       
   346 
       
   347     def assertContains(self, response, text, count=None, status_code=200,
       
   348                        msg_prefix=''):
       
   349         """
       
   350         Asserts that a response indicates that a page was retrieved
       
   351         successfully, (i.e., the HTTP status code was as expected), and that
       
   352         ``text`` occurs ``count`` times in the content of the response.
       
   353         If ``count`` is None, the count doesn't matter - the assertion is true
       
   354         if the text occurs at least once in the response.
       
   355         """
       
   356         if msg_prefix:
       
   357             msg_prefix += ": "
       
   358 
       
   359         self.assertEqual(response.status_code, status_code,
       
   360             msg_prefix + "Couldn't retrieve page: Response code was %d"
       
   361             " (expected %d)" % (response.status_code, status_code))
       
   362         text = smart_str(text, response._charset)
       
   363         real_count = response.content.count(text)
       
   364         if count is not None:
       
   365             self.assertEqual(real_count, count,
       
   366                 msg_prefix + "Found %d instances of '%s' in response"
       
   367                 " (expected %d)" % (real_count, text, count))
       
   368         else:
       
   369             self.failUnless(real_count != 0,
       
   370                 msg_prefix + "Couldn't find '%s' in response" % text)
       
   371 
       
   372     def assertNotContains(self, response, text, status_code=200,
       
   373                           msg_prefix=''):
       
   374         """
       
   375         Asserts that a response indicates that a page was retrieved
       
   376         successfully, (i.e., the HTTP status code was as expected), and that
       
   377         ``text`` doesn't occurs in the content of the response.
       
   378         """
       
   379         if msg_prefix:
       
   380             msg_prefix += ": "
       
   381 
       
   382         self.assertEqual(response.status_code, status_code,
       
   383             msg_prefix + "Couldn't retrieve page: Response code was %d"
       
   384             " (expected %d)" % (response.status_code, status_code))
       
   385         text = smart_str(text, response._charset)
       
   386         self.assertEqual(response.content.count(text), 0,
       
   387             msg_prefix + "Response should not contain '%s'" % text)
       
   388 
       
   389     def assertFormError(self, response, form, field, errors, msg_prefix=''):
       
   390         """
       
   391         Asserts that a form used to render the response has a specific field
       
   392         error.
       
   393         """
       
   394         if msg_prefix:
       
   395             msg_prefix += ": "
       
   396 
       
   397         # Put context(s) into a list to simplify processing.
       
   398         contexts = to_list(response.context)
       
   399         if not contexts:
       
   400             self.fail(msg_prefix + "Response did not use any contexts to "
       
   401                       "render the response")
       
   402 
       
   403         # Put error(s) into a list to simplify processing.
       
   404         errors = to_list(errors)
       
   405 
       
   406         # Search all contexts for the error.
       
   407         found_form = False
       
   408         for i,context in enumerate(contexts):
       
   409             if form not in context:
       
   410                 continue
       
   411             found_form = True
       
   412             for err in errors:
       
   413                 if field:
       
   414                     if field in context[form].errors:
       
   415                         field_errors = context[form].errors[field]
       
   416                         self.failUnless(err in field_errors,
       
   417                             msg_prefix + "The field '%s' on form '%s' in"
       
   418                             " context %d does not contain the error '%s'"
       
   419                             " (actual errors: %s)" %
       
   420                                 (field, form, i, err, repr(field_errors)))
       
   421                     elif field in context[form].fields:
       
   422                         self.fail(msg_prefix + "The field '%s' on form '%s'"
       
   423                                   " in context %d contains no errors" %
       
   424                                       (field, form, i))
       
   425                     else:
       
   426                         self.fail(msg_prefix + "The form '%s' in context %d"
       
   427                                   " does not contain the field '%s'" %
       
   428                                       (form, i, field))
       
   429                 else:
       
   430                     non_field_errors = context[form].non_field_errors()
       
   431                     self.failUnless(err in non_field_errors,
       
   432                         msg_prefix + "The form '%s' in context %d does not"
       
   433                         " contain the non-field error '%s'"
       
   434                         " (actual errors: %s)" %
       
   435                             (form, i, err, non_field_errors))
       
   436         if not found_form:
       
   437             self.fail(msg_prefix + "The form '%s' was not used to render the"
       
   438                       " response" % form)
       
   439 
       
   440     def assertTemplateUsed(self, response, template_name, msg_prefix=''):
       
   441         """
       
   442         Asserts that the template with the provided name was used in rendering
       
   443         the response.
       
   444         """
       
   445         if msg_prefix:
       
   446             msg_prefix += ": "
       
   447 
       
   448         template_names = [t.name for t in to_list(response.template)]
       
   449         if not template_names:
       
   450             self.fail(msg_prefix + "No templates used to render the response")
       
   451         self.failUnless(template_name in template_names,
       
   452             msg_prefix + "Template '%s' was not a template used to render"
       
   453             " the response. Actual template(s) used: %s" %
       
   454                 (template_name, u', '.join(template_names)))
       
   455 
       
   456     def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
       
   457         """
       
   458         Asserts that the template with the provided name was NOT used in
       
   459         rendering the response.
       
   460         """
       
   461         if msg_prefix:
       
   462             msg_prefix += ": "
       
   463 
       
   464         template_names = [t.name for t in to_list(response.template)]
       
   465         self.failIf(template_name in template_names,
       
   466             msg_prefix + "Template '%s' was used unexpectedly in rendering"
       
   467             " the response" % template_name)
       
   468 
       
   469 def connections_support_transactions():
       
   470     """
       
   471     Returns True if all connections support transactions.  This is messy
       
   472     because 2.4 doesn't support any or all.
       
   473     """
       
   474     return all(conn.settings_dict['SUPPORTS_TRANSACTIONS']
       
   475         for conn in connections.all())
       
   476 
       
   477 class TestCase(TransactionTestCase):
       
   478     """
       
   479     Does basically the same as TransactionTestCase, but surrounds every test
       
   480     with a transaction, monkey-patches the real transaction management routines to
       
   481     do nothing, and rollsback the test transaction at the end of the test. You have
       
   482     to use TransactionTestCase, if you need transaction management inside a test.
       
   483     """
       
   484 
       
   485     def _fixture_setup(self):
       
   486         if not connections_support_transactions():
       
   487             return super(TestCase, self)._fixture_setup()
       
   488 
       
   489         # If the test case has a multi_db=True flag, setup all databases.
       
   490         # Otherwise, just use default.
       
   491         if getattr(self, 'multi_db', False):
       
   492             databases = connections
       
   493         else:
       
   494             databases = [DEFAULT_DB_ALIAS]
       
   495 
       
   496         for db in databases:
       
   497             transaction.enter_transaction_management(using=db)
       
   498             transaction.managed(True, using=db)
       
   499         disable_transaction_methods()
       
   500 
       
   501         from django.contrib.sites.models import Site
       
   502         Site.objects.clear_cache()
       
   503 
       
   504         for db in databases:
       
   505             if hasattr(self, 'fixtures'):
       
   506                 call_command('loaddata', *self.fixtures, **{
       
   507                                                             'verbosity': 0,
       
   508                                                             'commit': False,
       
   509                                                             'database': db
       
   510                                                             })
       
   511 
       
   512     def _fixture_teardown(self):
       
   513         if not connections_support_transactions():
       
   514             return super(TestCase, self)._fixture_teardown()
       
   515 
       
   516         # If the test case has a multi_db=True flag, teardown all databases.
       
   517         # Otherwise, just teardown default.
       
   518         if getattr(self, 'multi_db', False):
       
   519             databases = connections
       
   520         else:
       
   521             databases = [DEFAULT_DB_ALIAS]
       
   522 
       
   523         restore_transaction_methods()
       
   524         for db in databases:
       
   525             transaction.rollback(using=db)
       
   526             transaction.leave_transaction_management(using=db)
       
   527 
       
   528         for connection in connections.all():
       
   529             connection.close()