web/lib/django/test/simple.py
changeset 29 cc9b7e14412b
parent 0 0d40e90630ef
--- a/web/lib/django/test/simple.py	Wed May 19 17:43:59 2010 +0200
+++ b/web/lib/django/test/simple.py	Tue May 25 02:43:45 2010 +0200
@@ -1,4 +1,7 @@
+import sys
+import signal
 import unittest
+
 from django.conf import settings
 from django.db.models import get_app, get_apps
 from django.test import _doctest as doctest
@@ -10,6 +13,54 @@
 
 doctestOutputChecker = OutputChecker()
 
+class DjangoTestRunner(unittest.TextTestRunner):
+
+    def __init__(self, verbosity=0, failfast=False, **kwargs):
+        super(DjangoTestRunner, self).__init__(verbosity=verbosity, **kwargs)
+        self.failfast = failfast
+        self._keyboard_interrupt_intercepted = False
+
+    def run(self, *args, **kwargs):
+        """
+        Runs the test suite after registering a custom signal handler
+        that triggers a graceful exit when Ctrl-C is pressed.
+        """
+        self._default_keyboard_interrupt_handler = signal.signal(signal.SIGINT,
+            self._keyboard_interrupt_handler)
+        try:
+            result = super(DjangoTestRunner, self).run(*args, **kwargs)
+        finally:
+            signal.signal(signal.SIGINT, self._default_keyboard_interrupt_handler)
+        return result
+
+    def _keyboard_interrupt_handler(self, signal_number, stack_frame):
+        """
+        Handles Ctrl-C by setting a flag that will stop the test run when
+        the currently running test completes.
+        """
+        self._keyboard_interrupt_intercepted = True
+        sys.stderr.write(" <Test run halted by Ctrl-C> ")
+        # Set the interrupt handler back to the default handler, so that
+        # another Ctrl-C press will trigger immediate exit.
+        signal.signal(signal.SIGINT, self._default_keyboard_interrupt_handler)
+
+    def _makeResult(self):
+        result = super(DjangoTestRunner, self)._makeResult()
+        failfast = self.failfast
+
+        def stoptest_override(func):
+            def stoptest(test):
+                # If we were set to failfast and the unit test failed,
+                # or if the user has typed Ctrl-C, report and quit
+                if (failfast and not result.wasSuccessful()) or \
+                    self._keyboard_interrupt_intercepted:
+                    result.stop()
+                func(test)
+            return stoptest
+
+        setattr(result, 'stopTest', stoptest_override(result.stopTest))
+        return result
+
 def get_tests(app_module):
     try:
         app_path = app_module.__name__.split('.')[:-1]
@@ -73,41 +124,66 @@
     return suite
 
 def build_test(label):
-    """Construct a test case a test with the specified label. Label should
-    be of the form model.TestClass or model.TestClass.test_method. Returns
-    an instantiated test or test suite corresponding to the label provided.
+    """Construct a test case with the specified label. Label should be of the
+    form model.TestClass or model.TestClass.test_method. Returns an
+    instantiated test or test suite corresponding to the label provided.
 
     """
     parts = label.split('.')
     if len(parts) < 2 or len(parts) > 3:
         raise ValueError("Test label '%s' should be of the form app.TestCase or app.TestCase.test_method" % label)
 
+    #
+    # First, look for TestCase instances with a name that matches
+    #
     app_module = get_app(parts[0])
+    test_module = get_tests(app_module)
     TestClass = getattr(app_module, parts[1], None)
 
     # Couldn't find the test class in models.py; look in tests.py
     if TestClass is None:
-        test_module = get_tests(app_module)
         if test_module:
             TestClass = getattr(test_module, parts[1], None)
 
-    if len(parts) == 2: # label is app.TestClass
+    try:
+        if issubclass(TestClass, unittest.TestCase):
+            if len(parts) == 2: # label is app.TestClass
+                try:
+                    return unittest.TestLoader().loadTestsFromTestCase(TestClass)
+                except TypeError:
+                    raise ValueError("Test label '%s' does not refer to a test class" % label)
+            else: # label is app.TestClass.test_method
+                return TestClass(parts[2])
+    except TypeError:
+        # TestClass isn't a TestClass - it must be a method or normal class
+        pass
+
+    #
+    # If there isn't a TestCase, look for a doctest that matches
+    #
+    tests = []
+    for module in app_module, test_module:
         try:
-            return unittest.TestLoader().loadTestsFromTestCase(TestClass)
-        except TypeError:
-            raise ValueError("Test label '%s' does not refer to a test class" % label)
-    else: # label is app.TestClass.test_method
-        if not TestClass:
-            raise ValueError("Test label '%s' does not refer to a test class" % label)
-        return TestClass(parts[2])
+            doctests = doctest.DocTestSuite(module,
+                                            checker=doctestOutputChecker,
+                                            runner=DocTestRunner)
+            # Now iterate over the suite, looking for doctests whose name
+            # matches the pattern that was given
+            for test in doctests:
+                if test._dt_test.name in (
+                        '%s.%s' % (module.__name__, '.'.join(parts[1:])),
+                        '%s.__test__.%s' % (module.__name__, '.'.join(parts[1:]))):
+                    tests.append(test)
+        except ValueError:
+            # No doctests found.
+            pass
 
-# Python 2.3 compatibility: TestSuites were made iterable in 2.4.
-# We need to iterate over them, so we add the missing method when
-# necessary.
-try:
-    getattr(unittest.TestSuite, '__iter__')
-except AttributeError:
-    setattr(unittest.TestSuite, '__iter__', lambda s: iter(s._tests))
+    # If no tests were found, then we were given a bad test label.
+    if not tests:
+        raise ValueError("Test label '%s' does not refer to a test" % label)
+
+    # Construct a suite out of the tests that matched.
+    return unittest.TestSuite(tests)
 
 def partition_suite(suite, classes, bins):
     """
@@ -146,52 +222,105 @@
         bins[0].addTests(bins[i+1])
     return bins[0]
 
-def run_tests(test_labels, verbosity=1, interactive=True, extra_tests=[]):
-    """
-    Run the unit tests for all the test labels in the provided list.
-    Labels must be of the form:
-     - app.TestClass.test_method
-        Run a single specific test method
-     - app.TestClass
-        Run all the test methods in a given class
-     - app
-        Search for doctests and unittests in the named application.
+
+class DjangoTestSuiteRunner(object):
+    def __init__(self, verbosity=1, interactive=True, failfast=True, **kwargs):
+        self.verbosity = verbosity
+        self.interactive = interactive
+        self.failfast = failfast
+
+    def setup_test_environment(self, **kwargs):
+        setup_test_environment()
+        settings.DEBUG = False
+
+    def build_suite(self, test_labels, extra_tests=None, **kwargs):
+        suite = unittest.TestSuite()
+
+        if test_labels:
+            for label in test_labels:
+                if '.' in label:
+                    suite.addTest(build_test(label))
+                else:
+                    app = get_app(label)
+                    suite.addTest(build_suite(app))
+        else:
+            for app in get_apps():
+                suite.addTest(build_suite(app))
 
-    When looking for tests, the test runner will look in the models and
-    tests modules for the application.
+        if extra_tests:
+            for test in extra_tests:
+                suite.addTest(test)
 
-    A list of 'extra' tests may also be provided; these tests
-    will be added to the test suite.
+        return reorder_suite(suite, (TestCase,))
 
-    Returns the number of tests that failed.
-    """
-    setup_test_environment()
+    def setup_databases(self, **kwargs):
+        from django.db import connections
+        old_names = []
+        mirrors = []
+        for alias in connections:
+            connection = connections[alias]
+            # If the database is a test mirror, redirect it's connection
+            # instead of creating a test database.
+            if connection.settings_dict['TEST_MIRROR']:
+                mirrors.append((alias, connection))
+                mirror_alias = connection.settings_dict['TEST_MIRROR']
+                connections._connections[alias] = connections[mirror_alias]
+            else:
+                old_names.append((connection, connection.settings_dict['NAME']))
+                connection.creation.create_test_db(self.verbosity, autoclobber=not self.interactive)
+        return old_names, mirrors
 
-    settings.DEBUG = False
-    suite = unittest.TestSuite()
+    def run_suite(self, suite, **kwargs):
+        return DjangoTestRunner(verbosity=self.verbosity, failfast=self.failfast).run(suite)
 
-    if test_labels:
-        for label in test_labels:
-            if '.' in label:
-                suite.addTest(build_test(label))
-            else:
-                app = get_app(label)
-                suite.addTest(build_suite(app))
-    else:
-        for app in get_apps():
-            suite.addTest(build_suite(app))
+    def teardown_databases(self, old_config, **kwargs):
+        from django.db import connections
+        old_names, mirrors = old_config
+        # Point all the mirrors back to the originals
+        for alias, connection in mirrors:
+            connections._connections[alias] = connection
+        # Destroy all the non-mirror databases
+        for connection, old_name in old_names:
+            connection.creation.destroy_test_db(old_name, self.verbosity)
+
+    def teardown_test_environment(self, **kwargs):
+        teardown_test_environment()
+
+    def suite_result(self, suite, result, **kwargs):
+        return len(result.failures) + len(result.errors)
 
-    for test in extra_tests:
-        suite.addTest(test)
+    def run_tests(self, test_labels, extra_tests=None, **kwargs):
+        """
+        Run the unit tests for all the test labels in the provided list.
+        Labels must be of the form:
+         - app.TestClass.test_method
+            Run a single specific test method
+         - app.TestClass
+            Run all the test methods in a given class
+         - app
+            Search for doctests and unittests in the named application.
 
-    suite = reorder_suite(suite, (TestCase,))
+        When looking for tests, the test runner will look in the models and
+        tests modules for the application.
+
+        A list of 'extra' tests may also be provided; these tests
+        will be added to the test suite.
 
-    old_name = settings.DATABASE_NAME
-    from django.db import connection
-    connection.creation.create_test_db(verbosity, autoclobber=not interactive)
-    result = unittest.TextTestRunner(verbosity=verbosity).run(suite)
-    connection.creation.destroy_test_db(old_name, verbosity)
+        Returns the number of tests that failed.
+        """
+        self.setup_test_environment()
+        suite = self.build_suite(test_labels, extra_tests)
+        old_config = self.setup_databases()
+        result = self.run_suite(suite)
+        self.teardown_databases(old_config)
+        self.teardown_test_environment()
+        return self.suite_result(suite, result)
 
-    teardown_test_environment()
-
-    return len(result.failures) + len(result.errors)
+def run_tests(test_labels, verbosity=1, interactive=True, failfast=False, extra_tests=None):
+    import warnings
+    warnings.warn(
+        'The run_tests() test runner has been deprecated in favor of DjangoTestSuiteRunner.',
+        PendingDeprecationWarning
+    )
+    test_runner = DjangoTestSuiteRunner(verbosity=verbosity, interactive=interactive, failfast=failfast)
+    return test_runner.run_tests(test_labels, extra_tests=extra_tests)