|
1 import re |
|
2 import unittest |
|
3 from urlparse import urlsplit, urlunsplit |
|
4 from xml.dom.minidom import parseString, Node |
|
5 |
|
6 from django.conf import settings |
|
7 from django.core import mail |
|
8 from django.core.management import call_command |
|
9 from django.core.urlresolvers import clear_url_caches |
|
10 from django.db import transaction, connections, DEFAULT_DB_ALIAS |
|
11 from django.http import QueryDict |
|
12 from django.test import _doctest as doctest |
|
13 from django.test.client import Client |
|
14 from django.utils import simplejson |
|
15 from django.utils.encoding import smart_str |
|
16 |
|
17 try: |
|
18 all |
|
19 except NameError: |
|
20 from django.utils.itercompat import all |
|
21 |
|
def normalize_long_ints(s):
    """Strip the ``L`` suffix from Python 2 long-integer literals in ``s``
    (e.g. ``22L`` -> ``22``). Digits glued to a word character on either
    side (such as ``x22L``) are left untouched."""
    return re.sub(r'(?<![\w])(\d+)L(?![\w])', '\\1', s)


def normalize_decimals(s):
    """Rewrite every ``Decimal('...')`` repr in ``s`` to ``Decimal("...")``
    so output is comparable regardless of the change made to
    ``Decimal.__repr__`` in Python 2.6.

    Signed values and exponents (``Decimal('-1.5')``, ``Decimal('1E+2')``)
    are normalized as well as plain ones.
    """
    return re.sub(r"Decimal\('(-?\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)'\)",
                  lambda m: 'Decimal("%s")' % m.group(1), s)
|
24 |
|
def to_list(value):
    """
    Wrap ``value`` in a list unless it already is one.

    An existing list is returned as-is (the same object, not a copy),
    ``None`` becomes an empty list, and any other value becomes a
    one-element list.
    """
    if isinstance(value, list):
        return value
    if value is None:
        return []
    return [value]
|
35 |
|
# Keep references to the genuine transaction functions, captured at import
# time, so restore_transaction_methods() can put them back after
# disable_transaction_methods() has replaced them with nop.
(real_commit,
 real_rollback,
 real_enter_transaction_management,
 real_leave_transaction_management,
 real_savepoint_commit,
 real_savepoint_rollback,
 real_managed) = (transaction.commit,
                  transaction.rollback,
                  transaction.enter_transaction_management,
                  transaction.leave_transaction_management,
                  transaction.savepoint_commit,
                  transaction.savepoint_rollback,
                  transaction.managed)
|
43 |
|
def nop(*args, **kwargs):
    """Accept any positional/keyword arguments, do nothing, return None.

    Installed in place of the transaction functions while they are disabled.
    """
    return None
|
46 |
|
def disable_transaction_methods():
    """Replace the transaction module's commit/rollback/savepoint and
    management functions with ``nop`` so code under test cannot end the
    surrounding test transaction."""
    for attr_name in ('commit',
                      'rollback',
                      'savepoint_commit',
                      'savepoint_rollback',
                      'enter_transaction_management',
                      'leave_transaction_management',
                      'managed'):
        setattr(transaction, attr_name, nop)
|
55 |
|
def restore_transaction_methods():
    """Undo disable_transaction_methods(): reinstall the original
    transaction functions saved in the module-level ``real_*`` names."""
    originals = (('commit', real_commit),
                 ('rollback', real_rollback),
                 ('savepoint_commit', real_savepoint_commit),
                 ('savepoint_rollback', real_savepoint_rollback),
                 ('enter_transaction_management', real_enter_transaction_management),
                 ('leave_transaction_management', real_leave_transaction_management),
                 ('managed', real_managed))
    for attr_name, original in originals:
        setattr(transaction, attr_name, original)
|
64 |
|
class OutputChecker(doctest.OutputChecker):
    """
    Doctest output checker that is more lenient than doctest's own.

    ``want`` and ``got`` are accepted as equal if any of these succeed, in
    order: doctest's default comparison, a comparison after normalizing
    long-int and Decimal reprs, an XML comparison that ignores attribute
    order, or a comparison of the JSON-decoded values.
    """

    def check_output(self, want, got, optionflags):
        "The entry method for doctest output checking. Defers to a sequence of child checkers"
        checks = (self.check_output_default,
                  self.check_output_numeric,
                  self.check_output_xml,
                  self.check_output_json)
        for check in checks:
            if check(want, got, optionflags):
                return True
        return False

    def check_output_default(self, want, got, optionflags):
        "The default comparator provided by doctest - not perfect, but good for most purposes"
        return doctest.OutputChecker.check_output(self, want, got, optionflags)

    def check_output_numeric(self, want, got, optionflags):
        """Doctest does an exact string comparison of output, which means that
        some numerically equivalent values aren't equal. This check normalizes
         * long integers (22L) so that they equal normal integers. (22)
         * Decimals so that they are comparable, regardless of the change
           made to __repr__ in Python 2.6.
        """
        return doctest.OutputChecker.check_output(self,
            normalize_decimals(normalize_long_ints(want)),
            normalize_decimals(normalize_long_ints(got)),
            optionflags)

    def check_output_xml(self, want, got, optionsflags):
        """Tries to do a 'xml-comparison' of want and got. Plain string
        comparison doesn't always work because, for example, attribute
        ordering should not be important.

        Escaped newlines ('\\n') are unescaped and surrounding whitespace is
        ignored. Input without an XML declaration is wrapped in a synthetic
        <root> element so fragments such as "<foo/><bar/>" can be compared.
        Returns False when either side is not well-formed XML.
        ``optionsflags`` is accepted for interface compatibility but unused.

        Based on http://codespeak.net/svn/lxml/trunk/src/lxml/doctestcompare.py
        """
        _norm_whitespace_re = re.compile(r'[ \t\n][ \t\n]+')

        def norm_whitespace(v):
            # Collapse runs of two or more whitespace characters to one space.
            return _norm_whitespace_re.sub(' ', v)

        def child_text(element):
            # Concatenated data of the element's direct text-node children.
            return ''.join([c.data for c in element.childNodes
                            if c.nodeType == Node.TEXT_NODE])

        def children(element):
            # Direct element children only; text/comment nodes are skipped.
            return [c for c in element.childNodes
                    if c.nodeType == Node.ELEMENT_NODE]

        def norm_child_text(element):
            return norm_whitespace(child_text(element))

        def attrs_dict(element):
            return dict(element.attributes.items())

        def check_element(want_element, got_element):
            # Recursively compare tag name, normalized text, attributes
            # (order-insensitive) and child elements (order-sensitive).
            if want_element.tagName != got_element.tagName:
                return False
            if norm_child_text(want_element) != norm_child_text(got_element):
                return False
            if attrs_dict(want_element) != attrs_dict(got_element):
                return False
            want_children = children(want_element)
            got_children = children(got_element)
            if len(want_children) != len(got_children):
                return False
            for want_child, got_child in zip(want_children, got_children):
                if not check_element(want_child, got_child):
                    return False
            return True

        want, got = self._strip_quotes(want, got)
        want = want.replace('\\n', '\n').strip()
        got = got.replace('\\n', '\n').strip()

        # If the string is not a complete xml document, we may need to add a
        # root element. This allows us to compare fragments, like "<foo/><bar/>"
        if not want.startswith('<?xml'):
            wrapper = '<root>%s</root>'
            want = wrapper % want
            got = wrapper % got

        # Parse the want and got strings, and compare the parsings.
        # documentElement (not firstChild) skips any comment, processing
        # instruction or DOCTYPE that precedes the root element; those nodes
        # have no tagName and would make check_element raise.
        try:
            want_root = parseString(want).documentElement
            got_root = parseString(got).documentElement
        except Exception:
            # Not well-formed XML: this checker cannot decide.
            return False
        return check_element(want_root, got_root)

    def check_output_json(self, want, got, optionsflags):
        "Tries to compare want and got as if they were JSON-encoded data"
        want, got = self._strip_quotes(want, got)
        try:
            want_json = simplejson.loads(want)
            got_json = simplejson.loads(got)
        except Exception:
            # At least one side is not valid JSON: this checker cannot decide.
            return False
        return want_json == got_json

    def _strip_quotes(self, want, got):
        """
        Strip quotes of doctests output values.

        Both values are unquoted only when both are plain-quoted ('...' or
        "...") or both are u-prefixed ('u'...'' / 'u"..."'). In that case
        surrounding whitespace is dropped too. Otherwise both are returned
        unchanged.

        >>> o = OutputChecker()
        >>> o._strip_quotes("'foo'", "'bar'")
        ('foo', 'bar')
        >>> o._strip_quotes('u"foo"', "u'bar'")
        ('foo', 'bar')
        """
        def is_quoted_string(s):
            s = s.strip()
            return (len(s) >= 2
                    and s[0] == s[-1]
                    and s[0] in ('"', "'"))

        def is_quoted_unicode(s):
            s = s.strip()
            return (len(s) >= 3
                    and s[0] == 'u'
                    and s[1] == s[-1]
                    and s[1] in ('"', "'"))

        if is_quoted_string(want) and is_quoted_string(got):
            want = want.strip()[1:-1]
            got = got.strip()[1:-1]
        elif is_quoted_unicode(want) and is_quoted_unicode(got):
            want = want.strip()[2:-1]
            got = got.strip()[2:-1]
        return want, got
|
197 |
|
198 |
|
class DocTestRunner(doctest.DocTestRunner):
    """
    Doctest runner that always uses ELLIPSIS matching and, after an example
    raises unexpectedly, calls transaction.rollback_unless_managed() on every
    entry of ``connections``.
    """

    def __init__(self, *args, **kwargs):
        doctest.DocTestRunner.__init__(self, *args, **kwargs)
        # Deliberately overrides any optionflags passed by the caller.
        self.optionflags = doctest.ELLIPSIS

    def report_unexpected_exception(self, out, test, example, exc_info):
        base_report = doctest.DocTestRunner.report_unexpected_exception
        base_report(self, out, test, example, exc_info)
        # Rollback, in case of database errors. Otherwise they'd have
        # side effects on other tests.
        for db in connections:
            transaction.rollback_unless_managed(using=db)
|
211 |
|
class TransactionTestCase(unittest.TestCase):
    """
    Base Django test case that resets database state by flushing before each
    test (rather than by rolling back a surrounding transaction). It also
    loads ``fixtures``, optionally swaps ROOT_URLCONF for ``urls``, clears the
    mail outbox, and provides HTTP/template/form assertion helpers.
    """

    def _pre_setup(self):
        """Performs any pre-test setup. This includes:

            * Flushing the database (every configured database when the
              class sets ``multi_db = True``, otherwise only the default).
            * If the Test Case class has a 'fixtures' member, installing the
              named fixtures.
            * If the Test Case class has a 'urls' member, replace the
              ROOT_URLCONF with it.
            * Clearing the mail test outbox.
        """
        self._fixture_setup()
        self._urlconf_setup()
        mail.outbox = []

    def _fixture_setup(self):
        # If the test case has a multi_db=True flag, flush all databases.
        # Otherwise, just flush default.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]
        for db in databases:
            call_command('flush', verbosity=0, interactive=False, database=db)

            # Fixtures are (re)loaded into every database just flushed.
            if hasattr(self, 'fixtures'):
                # We have to use this slightly awkward syntax due to the fact
                # that we're using *args and **kwargs together.
                call_command('loaddata', *self.fixtures, **{'verbosity': 0, 'database': db})

    def _urlconf_setup(self):
        if hasattr(self, 'urls'):
            self._old_root_urlconf = settings.ROOT_URLCONF
            settings.ROOT_URLCONF = self.urls
            clear_url_caches()

    def __call__(self, result=None):
        """
        Wrapper around default __call__ method to perform common Django test
        set up. This means that user-defined Test Cases aren't required to
        include a call to super().setUp().

        Errors raised by _pre_setup/_post_teardown are recorded on ``result``
        as errors of this test. If _pre_setup fails the test itself is not run.
        """
        if result is None:
            # unittest would create a throwaway result itself; do it here so
            # setup/teardown errors are recorded instead of crashing with
            # "'NoneType' object has no attribute 'addError'".
            result = self.defaultTestResult()
        self.client = Client()
        if not self._run_recording_errors(self._pre_setup, result):
            return
        super(TransactionTestCase, self).__call__(result)
        self._run_recording_errors(self._post_teardown, result)

    def _run_recording_errors(self, func, result):
        """Call ``func()``; on an ordinary exception record it on ``result``
        via addError and return False, otherwise return True."""
        try:
            func()
        except (KeyboardInterrupt, SystemExit):
            # On old Pythons these subclass Exception; never swallow them.
            raise
        except Exception:
            import sys
            result.addError(self, sys.exc_info())
            return False
        return True

    def _post_teardown(self):
        """ Performs any post-test things. This includes:

            * Tearing down fixtures (a no-op for this class).
            * Putting back the original ROOT_URLCONF if it was changed.
        """
        self._fixture_teardown()
        self._urlconf_teardown()

    def _fixture_teardown(self):
        # Nothing to undo: the next test flushes the database anyway.
        pass

    def _urlconf_teardown(self):
        if hasattr(self, '_old_root_urlconf'):
            settings.ROOT_URLCONF = self._old_root_urlconf
            clear_url_caches()

    def assertRedirects(self, response, expected_url, status_code=302,
                        target_status_code=200, host=None, msg_prefix=''):
        """Asserts that a response redirected to a specific URL, and that the
        redirect URL can be loaded.

        Note that assertRedirects won't work for external links since it uses
        TestClient to do a request.
        """
        if msg_prefix:
            msg_prefix += ": "

        if hasattr(response, 'redirect_chain'):
            # The request was a followed redirect
            self.failUnless(len(response.redirect_chain) > 0,
                msg_prefix + "Response didn't redirect as expected: Response"
                " code was %d (expected %d)" %
                    (response.status_code, status_code))

            self.assertEqual(response.redirect_chain[0][1], status_code,
                msg_prefix + "Initial response didn't redirect as expected:"
                " Response code was %d (expected %d)" %
                    (response.redirect_chain[0][1], status_code))

            url, status_code = response.redirect_chain[-1]

            self.assertEqual(response.status_code, target_status_code,
                msg_prefix + "Response didn't redirect as expected: Final"
                " Response code was %d (expected %d)" %
                    (response.status_code, target_status_code))

        else:
            # Not a followed redirect
            self.assertEqual(response.status_code, status_code,
                msg_prefix + "Response didn't redirect as expected: Response"
                " code was %d (expected %d)" %
                    (response.status_code, status_code))

            url = response['Location']
            scheme, netloc, path, query, fragment = urlsplit(url)

            # Get the redirection page, using the same client that was used
            # to obtain the original response.
            redirect_response = response.client.get(path, QueryDict(query))

            self.assertEqual(redirect_response.status_code, target_status_code,
                msg_prefix + "Couldn't retrieve redirection page '%s':"
                " response code was %d (expected %d)" %
                    (path, redirect_response.status_code, target_status_code))

        # A relative expected_url is made absolute against the test host.
        e_scheme, e_netloc, e_path, e_query, e_fragment = urlsplit(expected_url)
        if not (e_scheme or e_netloc):
            expected_url = urlunsplit(('http', host or 'testserver', e_path,
                e_query, e_fragment))

        self.assertEqual(url, expected_url,
            msg_prefix + "Response redirected to '%s', expected '%s'" %
                (url, expected_url))

    def assertContains(self, response, text, count=None, status_code=200,
                       msg_prefix=''):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` occurs ``count`` times in the content of the response.
        If ``count`` is None, the count doesn't matter - the assertion is true
        if the text occurs at least once in the response.
        """
        if msg_prefix:
            msg_prefix += ": "

        self.assertEqual(response.status_code, status_code,
            msg_prefix + "Couldn't retrieve page: Response code was %d"
            " (expected %d)" % (response.status_code, status_code))
        # Encode with the response's charset so it is comparable to content.
        text = smart_str(text, response._charset)
        real_count = response.content.count(text)
        if count is not None:
            self.assertEqual(real_count, count,
                msg_prefix + "Found %d instances of '%s' in response"
                " (expected %d)" % (real_count, text, count))
        else:
            self.failUnless(real_count != 0,
                msg_prefix + "Couldn't find '%s' in response" % text)

    def assertNotContains(self, response, text, status_code=200,
                          msg_prefix=''):
        """
        Asserts that a response indicates that a page was retrieved
        successfully, (i.e., the HTTP status code was as expected), and that
        ``text`` doesn't occur in the content of the response.
        """
        if msg_prefix:
            msg_prefix += ": "

        self.assertEqual(response.status_code, status_code,
            msg_prefix + "Couldn't retrieve page: Response code was %d"
            " (expected %d)" % (response.status_code, status_code))
        text = smart_str(text, response._charset)
        self.assertEqual(response.content.count(text), 0,
            msg_prefix + "Response should not contain '%s'" % text)

    def assertFormError(self, response, form, field, errors, msg_prefix=''):
        """
        Asserts that a form used to render the response has a specific field
        error. A false ``field`` means the errors are non-field errors.
        """
        if msg_prefix:
            msg_prefix += ": "

        # Put context(s) into a list to simplify processing.
        contexts = to_list(response.context)
        if not contexts:
            self.fail(msg_prefix + "Response did not use any contexts to "
                      "render the response")

        # Put error(s) into a list to simplify processing.
        errors = to_list(errors)

        # Search all contexts for the error.
        found_form = False
        for i, context in enumerate(contexts):
            if form not in context:
                continue
            found_form = True
            for err in errors:
                if field:
                    if field in context[form].errors:
                        field_errors = context[form].errors[field]
                        self.failUnless(err in field_errors,
                            msg_prefix + "The field '%s' on form '%s' in"
                            " context %d does not contain the error '%s'"
                            " (actual errors: %s)" %
                                (field, form, i, err, repr(field_errors)))
                    elif field in context[form].fields:
                        self.fail(msg_prefix + "The field '%s' on form '%s'"
                                  " in context %d contains no errors" %
                                      (field, form, i))
                    else:
                        self.fail(msg_prefix + "The form '%s' in context %d"
                                  " does not contain the field '%s'" %
                                      (form, i, field))
                else:
                    non_field_errors = context[form].non_field_errors()
                    self.failUnless(err in non_field_errors,
                        msg_prefix + "The form '%s' in context %d does not"
                        " contain the non-field error '%s'"
                        " (actual errors: %s)" %
                            (form, i, err, non_field_errors))
        if not found_form:
            self.fail(msg_prefix + "The form '%s' was not used to render the"
                      " response" % form)

    def assertTemplateUsed(self, response, template_name, msg_prefix=''):
        """
        Asserts that the template with the provided name was used in rendering
        the response.
        """
        if msg_prefix:
            msg_prefix += ": "

        template_names = [t.name for t in to_list(response.template)]
        if not template_names:
            self.fail(msg_prefix + "No templates used to render the response")
        self.failUnless(template_name in template_names,
            msg_prefix + "Template '%s' was not a template used to render"
            " the response. Actual template(s) used: %s" %
                (template_name, u', '.join(template_names)))

    def assertTemplateNotUsed(self, response, template_name, msg_prefix=''):
        """
        Asserts that the template with the provided name was NOT used in
        rendering the response.
        """
        if msg_prefix:
            msg_prefix += ": "

        template_names = [t.name for t in to_list(response.template)]
        self.failIf(template_name in template_names,
            msg_prefix + "Template '%s' was used unexpectedly in rendering"
            " the response" % template_name)
|
468 |
|
def connections_support_transactions():
    """
    Returns True if every connection in ``connections.all()`` has a truthy
    ``SUPPORTS_TRANSACTIONS`` entry in its settings_dict, False as soon as
    one does not.
    """
    for conn in connections.all():
        if not conn.settings_dict['SUPPORTS_TRANSACTIONS']:
            return False
    return True
|
476 |
|
class TestCase(TransactionTestCase):
    """
    Does basically the same as TransactionTestCase, but surrounds every test
    with a transaction, monkey-patches the real transaction management
    routines to do nothing, and rolls back the test transaction at the end of
    the test. You have to use TransactionTestCase if you need transaction
    management inside a test.

    If any connection does not support transactions, setup and teardown fall
    back to TransactionTestCase's flush-based behaviour.
    """

    def _fixture_setup(self):
        if not connections_support_transactions():
            return super(TestCase, self)._fixture_setup()

        # If the test case has a multi_db=True flag, setup all databases.
        # Otherwise, just use default.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]

        for db in databases:
            transaction.enter_transaction_management(using=db)
            transaction.managed(True, using=db)
        disable_transaction_methods()

        try:
            from django.contrib.sites.models import Site
            Site.objects.clear_cache()

            for db in databases:
                if hasattr(self, 'fixtures'):
                    call_command('loaddata', *self.fixtures, **{
                        'verbosity': 0,
                        'commit': False,
                        'database': db
                    })
        except Exception:
            # When _pre_setup fails, __call__ skips _post_teardown, so undo
            # the transaction state here. Otherwise the no-op transaction
            # functions and open transaction management would leak into
            # every following test.
            self._fixture_teardown()
            raise

    def _fixture_teardown(self):
        if not connections_support_transactions():
            return super(TestCase, self)._fixture_teardown()

        # If the test case has a multi_db=True flag, teardown all databases.
        # Otherwise, just teardown default.
        if getattr(self, 'multi_db', False):
            databases = connections
        else:
            databases = [DEFAULT_DB_ALIAS]

        # Reinstall the real functions first so the rollback below is not
        # the nop installed by disable_transaction_methods().
        restore_transaction_methods()
        for db in databases:
            transaction.rollback(using=db)
            transaction.leave_transaction_management(using=db)

        for connection in connections.all():
            connection.close()