web/lib/django/test/testcases.py
changeset 0 0d40e90630ef
child 29 cc9b7e14412b
equal deleted inserted replaced
-1:000000000000 0:0d40e90630ef
       
     1 import re
       
     2 import unittest
       
     3 from urlparse import urlsplit, urlunsplit
       
     4 from xml.dom.minidom import parseString, Node
       
     5 
       
     6 from django.conf import settings
       
     7 from django.core import mail
       
     8 from django.core.management import call_command
       
     9 from django.core.urlresolvers import clear_url_caches
       
    10 from django.db import transaction, connection
       
    11 from django.http import QueryDict
       
    12 from django.test import _doctest as doctest
       
    13 from django.test.client import Client
       
    14 from django.utils import simplejson
       
    15 from django.utils.encoding import smart_str
       
    16 
       
    17 normalize_long_ints = lambda s: re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)
       
    18 normalize_decimals = lambda s: re.sub(r"Decimal\('(\d+(\.\d*)?)'\)", lambda m: "Decimal(\"%s\")" % m.groups()[0], s)
       
    19 
       
    20 def to_list(value):
       
    21     """
       
    22     Puts value into a list if it's not already one.
       
    23     Returns an empty list if value is None.
       
    24     """
       
    25     if value is None:
       
    26         value = []
       
    27     elif not isinstance(value, list):
       
    28         value = [value]
       
    29     return value
       
    30 
       
    31 real_commit = transaction.commit
       
    32 real_rollback = transaction.rollback
       
    33 real_enter_transaction_management = transaction.enter_transaction_management
       
    34 real_leave_transaction_management = transaction.leave_transaction_management
       
    35 real_savepoint_commit = transaction.savepoint_commit
       
    36 real_savepoint_rollback = transaction.savepoint_rollback
       
    37 real_managed = transaction.managed
       
    38 
       
    39 def nop(*args, **kwargs):
       
    40     return
       
    41 
       
    42 def disable_transaction_methods():
       
    43     transaction.commit = nop
       
    44     transaction.rollback = nop
       
    45     transaction.savepoint_commit = nop
       
    46     transaction.savepoint_rollback = nop
       
    47     transaction.enter_transaction_management = nop
       
    48     transaction.leave_transaction_management = nop
       
    49     transaction.managed = nop
       
    50 
       
    51 def restore_transaction_methods():
       
    52     transaction.commit = real_commit
       
    53     transaction.rollback = real_rollback
       
    54     transaction.savepoint_commit = real_savepoint_commit
       
    55     transaction.savepoint_rollback = real_savepoint_rollback
       
    56     transaction.enter_transaction_management = real_enter_transaction_management
       
    57     transaction.leave_transaction_management = real_leave_transaction_management
       
    58     transaction.managed = real_managed
       
    59 
       
    60 class OutputChecker(doctest.OutputChecker):
       
    61     def check_output(self, want, got, optionflags):
       
    62         "The entry method for doctest output checking. Defers to a sequence of child checkers"
       
    63         checks = (self.check_output_default,
       
    64                   self.check_output_numeric,
       
    65                   self.check_output_xml,
       
    66                   self.check_output_json)
       
    67         for check in checks:
       
    68             if check(want, got, optionflags):
       
    69                 return True
       
    70         return False
       
    71 
       
    72     def check_output_default(self, want, got, optionflags):
       
    73         "The default comparator provided by doctest - not perfect, but good for most purposes"
       
    74         return doctest.OutputChecker.check_output(self, want, got, optionflags)
       
    75 
       
    76     def check_output_numeric(self, want, got, optionflags):
       
    77         """Doctest does an exact string comparison of output, which means that
       
    78         some numerically equivalent values aren't equal. This check normalizes
       
    79          * long integers (22L) so that they equal normal integers. (22)
       
    80          * Decimals so that they are comparable, regardless of the change
       
    81            made to __repr__ in Python 2.6.
       
    82         """
       
    83         return doctest.OutputChecker.check_output(self,
       
    84             normalize_decimals(normalize_long_ints(want)),
       
    85             normalize_decimals(normalize_long_ints(got)),
       
    86             optionflags)
       
    87 
       
    88     def check_output_xml(self, want, got, optionsflags):
       
    89         """Tries to do a 'xml-comparision' of want and got.  Plain string
       
    90         comparision doesn't always work because, for example, attribute
       
    91         ordering should not be important.
       
    92 
       
    93         Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
       
    94         """
       
    95         _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')
       
    96         def norm_whitespace(v):
       
    97             return _norm_whitespace_re.sub(' ', v)
       
    98 
       
    99         def child_text(element):
       
   100             return ''.join([c.data for c in element.childNodes
       
   101                             if c.nodeType == Node.TEXT_NODE])
       
   102 
       
   103         def children(element):
       
   104             return [c for c in element.childNodes
       
   105                     if c.nodeType == Node.ELEMENT_NODE]
       
   106 
       
   107         def norm_child_text(element):
       
   108             return norm_whitespace(child_text(element))
       
   109 
       
   110         def attrs_dict(element):
       
   111             return dict(element.attributes.items())
       
   112 
       
   113         def check_element(want_element, got_element):
       
   114             if want_element.tagName != got_element.tagName:
       
   115                 return False
       
   116             if norm_child_text(want_element) != norm_child_text(got_element):
       
   117                 return False
       
   118             if attrs_dict(want_element) != attrs_dict(got_element):
       
   119                 return False
       
   120             want_children = children(want_element)
       
   121             got_children = children(got_element)
       
   122             if len(want_children) != len(got_children):
       
   123                 return False
       
   124             for want, got in zip(want_children, got_children):
       
   125                 if not check_element(want, got):
       
   126                     return False
       
   127             return True
       
   128 
       
   129         want, got = self._strip_quotes(want, got)
       
   130         want = want.replace('\\n','\n')
       
   131         got = got.replace('\\n','\n')
       
   132 
       
   133         # If the string is not a complete xml document, we may need to add a
       
   134         # root element. This allow us to compare fragments, like "<foo/><bar/>"
       
   135         if not want.startswith('<?xml'):
       
   136             wrapper = '<root>%s</root>'
       
   137             want = wrapper % want
       
   138             got = wrapper % got
       
   139 
       
   140         # Parse the want and got strings, and compare the parsings.
       
   141         try:
       
   142             want_root = parseString(want).firstChild
       
   143             got_root = parseString(got).firstChild
       
   144         except:
       
   145             return False
       
   146         return check_element(want_root, got_root)
       
   147 
       
   148     def check_output_json(self, want, got, optionsflags):
       
   149         "Tries to compare want and got as if they were JSON-encoded data"
       
   150         want, got = self._strip_quotes(want, got)
       
   151         try:
       
   152             want_json = simplejson.loads(want)
       
   153             got_json = simplejson.loads(got)
       
   154         except:
       
   155             return False
       
   156         return want_json == got_json
       
   157 
       
   158     def _strip_quotes(self, want, got):
       
   159         """
       
   160         Strip quotes of doctests output values:
       
   161 
       
   162         >>> o = OutputChecker()
       
   163         >>> o._strip_quotes("'foo'")
       
   164         "foo"
       
   165         >>> o._strip_quotes('"foo"')
       
   166         "foo"
       
   167         >>> o._strip_quotes("u'foo'")
       
   168         "foo"
       
   169         >>> o._strip_quotes('u"foo"')
       
   170         "foo"
       
   171         """
       
   172         def is_quoted_string(s):
       
   173             s = s.strip()
       
   174             return (len(s) >= 2
       
   175                     and s[0] == s[-1]
       
   176                     and s[0] in ('"', "'"))
       
   177 
       
   178         def is_quoted_unicode(s):
       
   179             s = s.strip()
       
   180             return (len(s) >= 3
       
   181                     and s[0] == 'u'
       
   182                     and s[1] == s[-1]
       
   183                     and s[1] in ('"', "'"))
       
   184 
       
   185         if is_quoted_string(want) and is_quoted_string(got):
       
   186             want = want.strip()[1:-1]
       
   187             got = got.strip()[1:-1]
       
   188         elif is_quoted_unicode(want) and is_quoted_unicode(got):
       
   189             want = want.strip()[2:-1]
       
   190             got = got.strip()[2:-1]
       
   191         return want, got
       
   192 
       
   193 
       
   194 class DocTestRunner(doctest.DocTestRunner):
       
   195     def __init__(self, *args, **kwargs):
       
   196         doctest.DocTestRunner.__init__(self, *args, **kwargs)
       
   197         self.optionflags = doctest.ELLIPSIS
       
   198 
       
   199     def report_unexpected_exception(self, out, test, example, exc_info):
       
   200         doctest.DocTestRunner.report_unexpected_exception(self, out, test,
       
   201                                                           example, exc_info)
       
   202         # Rollback, in case of database errors. Otherwise they'd have
       
   203         # side effects on other tests.
       
   204         transaction.rollback_unless_managed()
       
   205 
       
   206 class TransactionTestCase(unittest.TestCase):
       
   207     def _pre_setup(self):
       
   208         """Performs any pre-test setup. This includes:
       
   209 
       
   210             * Flushing the database.
       
   211             * If the Test Case class has a 'fixtures' member, installing the
       
   212               named fixtures.
       
   213             * If the Test Case class has a 'urls' member, replace the
       
   214               ROOT_URLCONF with it.
       
   215             * Clearing the mail test outbox.
       
   216         """
       
   217         self._fixture_setup()
       
   218         self._urlconf_setup()
       
   219         mail.outbox = []
       
   220 
       
   221     def _fixture_setup(self):
       
   222         call_command('flush', verbosity=0, interactive=False)
       
   223         if hasattr(self, 'fixtures'):
       
   224             # We have to use this slightly awkward syntax due to the fact
       
   225             # that we're using *args and **kwargs together.
       
   226             call_command('loaddata', *self.fixtures, **{'verbosity': 0})
       
   227 
       
   228     def _urlconf_setup(self):
       
   229         if hasattr(self, 'urls'):
       
   230             self._old_root_urlconf = settings.ROOT_URLCONF
       
   231             settings.ROOT_URLCONF = self.urls
       
   232             clear_url_caches()
       
   233 
       
   234     def __call__(self, result=None):
       
   235         """
       
   236         Wrapper around default __call__ method to perform common Django test
       
   237         set up. This means that user-defined Test Cases aren't required to
       
   238         include a call to super().setUp().
       
   239         """
       
   240         self.client = Client()
       
   241         try:
       
   242             self._pre_setup()
       
   243         except (KeyboardInterrupt, SystemExit):
       
   244             raise
       
   245         except Exception:
       
   246             import sys
       
   247             result.addError(self, sys.exc_info())
       
   248             return
       
   249         super(TransactionTestCase, self).__call__(result)
       
   250         try:
       
   251             self._post_teardown()
       
   252         except (KeyboardInterrupt, SystemExit):
       
   253             raise
       
   254         except Exception:
       
   255             import sys
       
   256             result.addError(self, sys.exc_info())
       
   257             return
       
   258 
       
   259     def _post_teardown(self):
       
   260         """ Performs any post-test things. This includes:
       
   261 
       
   262             * Putting back the original ROOT_URLCONF if it was changed.
       
   263         """
       
   264         self._fixture_teardown()
       
   265         self._urlconf_teardown()
       
   266 
       
   267     def _fixture_teardown(self):
       
   268         pass
       
   269 
       
   270     def _urlconf_teardown(self):
       
   271         if hasattr(self, '_old_root_urlconf'):
       
   272             settings.ROOT_URLCONF = self._old_root_urlconf
       
   273             clear_url_caches()
       
   274 
       
   275     def assertRedirects(self, response, expected_url, status_code=302,
       
   276                         target_status_code=200, host=None):
       
   277         """Asserts that a response redirected to a specific URL, and that the
       
   278         redirect URL can be loaded.
       
   279 
       
   280         Note that assertRedirects won't work for external links since it uses
       
   281         TestClient to do a request.
       
   282         """
       
   283         if hasattr(response, 'redirect_chain'):
       
   284             # The request was a followed redirect
       
   285             self.failUnless(len(response.redirect_chain) > 0,
       
   286                 ("Response didn't redirect as expected: Response code was %d"
       
   287                 " (expected %d)" % (response.status_code, status_code)))
       
   288 
       
   289             self.assertEqual(response.redirect_chain[0][1], status_code,
       
   290                 ("Initial response didn't redirect as expected: Response code was %d"
       
   291                  " (expected %d)" % (response.redirect_chain[0][1], status_code)))
       
   292 
       
   293             url, status_code = response.redirect_chain[-1]
       
   294 
       
   295             self.assertEqual(response.status_code, target_status_code,
       
   296                 ("Response didn't redirect as expected: Final Response code was %d"
       
   297                 " (expected %d)" % (response.status_code, target_status_code)))
       
   298 
       
   299         else:
       
   300             # Not a followed redirect
       
   301             self.assertEqual(response.status_code, status_code,
       
   302                 ("Response didn't redirect as expected: Response code was %d"
       
   303                  " (expected %d)" % (response.status_code, status_code)))
       
   304 
       
   305             url = response['Location']
       
   306             scheme, netloc, path, query, fragment = urlsplit(url)
       
   307 
       
   308             redirect_response = response.client.get(path, QueryDict(query))
       
   309 
       
   310             # Get the redirection page, using the same client that was used
       
   311             # to obtain the original response.
       
   312             self.assertEqual(redirect_response.status_code, target_status_code,
       
   313                 ("Couldn't retrieve redirection page '%s': response code was %d"
       
   314                  " (expected %d)") %
       
   315                      (path, redirect_response.status_code, target_status_code))
       
   316 
       
   317         e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
       
   318         if not (e_scheme or e_netloc):
       
   319             expected_url = urlunsplit(('http', host or 'testserver', e_path,
       
   320                 e_query, e_fragment))
       
   321 
       
   322         self.assertEqual(url, expected_url,
       
   323             "Response redirected to '%s', expected '%s'" % (url, expected_url))
       
   324 
       
   325 
       
   326     def assertContains(self, response, text, count=None, status_code=200):
       
   327         """
       
   328         Asserts that a response indicates that a page was retrieved
       
   329         successfully, (i.e., the HTTP status code was as expected), and that
       
   330         ``text`` occurs ``count`` times in the content of the response.
       
   331         If ``count`` is None, the count doesn't matter - the assertion is true
       
   332         if the text occurs at least once in the response.
       
   333         """
       
   334         self.assertEqual(response.status_code, status_code,
       
   335             "Couldn't retrieve page: Response code was %d (expected %d)'" %
       
   336                 (response.status_code, status_code))
       
   337         text = smart_str(text, response._charset)
       
   338         real_count = response.content.count(text)
       
   339         if count is not None:
       
   340             self.assertEqual(real_count, count,
       
   341                 "Found %d instances of '%s' in response (expected %d)" %
       
   342                     (real_count, text, count))
       
   343         else:
       
   344             self.failUnless(real_count != 0,
       
   345                             "Couldn't find '%s' in response" % text)
       
   346 
       
   347     def assertNotContains(self, response, text, status_code=200):
       
   348         """
       
   349         Asserts that a response indicates that a page was retrieved
       
   350         successfully, (i.e., the HTTP status code was as expected), and that
       
   351         ``text`` doesn't occurs in the content of the response.
       
   352         """
       
   353         self.assertEqual(response.status_code, status_code,
       
   354             "Couldn't retrieve page: Response code was %d (expected %d)'" %
       
   355                 (response.status_code, status_code))
       
   356         text = smart_str(text, response._charset)
       
   357         self.assertEqual(response.content.count(text),
       
   358              0, "Response should not contain '%s'" % text)
       
   359 
       
   360     def assertFormError(self, response, form, field, errors):
       
   361         """
       
   362         Asserts that a form used to render the response has a specific field
       
   363         error.
       
   364         """
       
   365         # Put context(s) into a list to simplify processing.
       
   366         contexts = to_list(response.context)
       
   367         if not contexts:
       
   368             self.fail('Response did not use any contexts to render the'
       
   369                       ' response')
       
   370 
       
   371         # Put error(s) into a list to simplify processing.
       
   372         errors = to_list(errors)
       
   373 
       
   374         # Search all contexts for the error.
       
   375         found_form = False
       
   376         for i,context in enumerate(contexts):
       
   377             if form not in context:
       
   378                 continue
       
   379             found_form = True
       
   380             for err in errors:
       
   381                 if field:
       
   382                     if field in context[form].errors:
       
   383                         field_errors = context[form].errors[field]
       
   384                         self.failUnless(err in field_errors,
       
   385                                         "The field '%s' on form '%s' in"
       
   386                                         " context %d does not contain the"
       
   387                                         " error '%s' (actual errors: %s)" %
       
   388                                             (field, form, i, err,
       
   389                                              repr(field_errors)))
       
   390                     elif field in context[form].fields:
       
   391                         self.fail("The field '%s' on form '%s' in context %d"
       
   392                                   " contains no errors" % (field, form, i))
       
   393                     else:
       
   394                         self.fail("The form '%s' in context %d does not"
       
   395                                   " contain the field '%s'" %
       
   396                                       (form, i, field))
       
   397                 else:
       
   398                     non_field_errors = context[form].non_field_errors()
       
   399                     self.failUnless(err in non_field_errors,
       
   400                         "The form '%s' in context %d does not contain the"
       
   401                         " non-field error '%s' (actual errors: %s)" %
       
   402                             (form, i, err, non_field_errors))
       
   403         if not found_form:
       
   404             self.fail("The form '%s' was not used to render the response" %
       
   405                           form)
       
   406 
       
   407     def assertTemplateUsed(self, response, template_name):
       
   408         """
       
   409         Asserts that the template with the provided name was used in rendering
       
   410         the response.
       
   411         """
       
   412         template_names = [t.name for t in to_list(response.template)]
       
   413         if not template_names:
       
   414             self.fail('No templates used to render the response')
       
   415         self.failUnless(template_name in template_names,
       
   416             (u"Template '%s' was not a template used to render the response."
       
   417              u" Actual template(s) used: %s") % (template_name,
       
   418                                                  u', '.join(template_names)))
       
   419 
       
   420     def assertTemplateNotUsed(self, response, template_name):
       
   421         """
       
   422         Asserts that the template with the provided name was NOT used in
       
   423         rendering the response.
       
   424         """
       
   425         template_names = [t.name for t in to_list(response.template)]
       
   426         self.failIf(template_name in template_names,
       
   427             (u"Template '%s' was used unexpectedly in rendering the"
       
   428              u" response") % template_name)
       
   429 
       
   430 class TestCase(TransactionTestCase):
       
   431     """
       
   432     Does basically the same as TransactionTestCase, but surrounds every test
       
   433     with a transaction, monkey-patches the real transaction management routines to
       
   434     do nothing, and rollsback the test transaction at the end of the test. You have
       
   435     to use TransactionTestCase, if you need transaction management inside a test.
       
   436     """
       
   437 
       
   438     def _fixture_setup(self):
       
   439         if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
       
   440             return super(TestCase, self)._fixture_setup()
       
   441 
       
   442         transaction.enter_transaction_management()
       
   443         transaction.managed(True)
       
   444         disable_transaction_methods()
       
   445 
       
   446         from django.contrib.sites.models import Site
       
   447         Site.objects.clear_cache()
       
   448 
       
   449         if hasattr(self, 'fixtures'):
       
   450             call_command('loaddata', *self.fixtures, **{
       
   451                                                         'verbosity': 0,
       
   452                                                         'commit': False
       
   453                                                         })
       
   454 
       
   455     def _fixture_teardown(self):
       
   456         if not settings.DATABASE_SUPPORTS_TRANSACTIONS:
       
   457             return super(TestCase, self)._fixture_teardown()
       
   458 
       
   459         restore_transaction_methods()
       
   460         transaction.rollback()
       
   461         transaction.leave_transaction_management()
       
   462         connection.close()