web/lib/django/test/client.py
changeset 0 0d40e90630ef
child 29 cc9b7e14412b
equal deleted inserted replaced
-1:000000000000 0:0d40e90630ef
       
     1 import urllib
       
     2 from urlparse import urlparse, urlunparse, urlsplit
       
     3 import sys
       
     4 import os
       
     5 import re
       
     6 try:
       
     7     from cStringIO import StringIO
       
     8 except ImportError:
       
     9     from StringIO import StringIO
       
    10 
       
    11 from django.conf import settings
       
    12 from django.contrib.auth import authenticate, login
       
    13 from django.core.handlers.base import BaseHandler
       
    14 from django.core.handlers.wsgi import WSGIRequest
       
    15 from django.core.signals import got_request_exception
       
    16 from django.http import SimpleCookie, HttpRequest, QueryDict
       
    17 from django.template import TemplateDoesNotExist
       
    18 from django.test import signals
       
    19 from django.utils.functional import curry
       
    20 from django.utils.encoding import smart_str
       
    21 from django.utils.http import urlencode
       
    22 from django.utils.importlib import import_module
       
    23 from django.utils.itercompat import is_iterable
       
    24 from django.db import transaction, close_connection
       
    25 from django.test.utils import ContextList
       
    26 
       
    27 BOUNDARY = 'BoUnDaRyStRiNg'
       
    28 MULTIPART_CONTENT = 'multipart/form-data; boundary=%s' % BOUNDARY
       
    29 CONTENT_TYPE_RE = re.compile('.*; charset=([\w\d-]+);?')
       
    30 
       
    31 class FakePayload(object):
       
    32     """
       
    33     A wrapper around StringIO that restricts what can be read since data from
       
    34     the network can't be seeked and cannot be read outside of its content
       
    35     length. This makes sure that views can't do anything under the test client
       
    36     that wouldn't work in Real Life.
       
    37     """
       
    38     def __init__(self, content):
       
    39         self.__content = StringIO(content)
       
    40         self.__len = len(content)
       
    41 
       
    42     def read(self, num_bytes=None):
       
    43         if num_bytes is None:
       
    44             num_bytes = self.__len or 1
       
    45         assert self.__len >= num_bytes, "Cannot read more than the available bytes from the HTTP incoming data."
       
    46         content = self.__content.read(num_bytes)
       
    47         self.__len -= num_bytes
       
    48         return content
       
    49 
       
    50 
       
    51 class ClientHandler(BaseHandler):
       
    52     """
       
    53     A HTTP Handler that can be used for testing purposes.
       
    54     Uses the WSGI interface to compose requests, but returns
       
    55     the raw HttpResponse object
       
    56     """
       
    57     def __call__(self, environ):
       
    58         from django.conf import settings
       
    59         from django.core import signals
       
    60 
       
    61         # Set up middleware if needed. We couldn't do this earlier, because
       
    62         # settings weren't available.
       
    63         if self._request_middleware is None:
       
    64             self.load_middleware()
       
    65 
       
    66         signals.request_started.send(sender=self.__class__)
       
    67         try:
       
    68             request = WSGIRequest(environ)
       
    69             response = self.get_response(request)
       
    70 
       
    71             # Apply response middleware.
       
    72             for middleware_method in self._response_middleware:
       
    73                 response = middleware_method(request, response)
       
    74             response = self.apply_response_fixes(request, response)
       
    75         finally:
       
    76             signals.request_finished.disconnect(close_connection)
       
    77             signals.request_finished.send(sender=self.__class__)
       
    78             signals.request_finished.connect(close_connection)
       
    79 
       
    80         return response
       
    81 
       
    82 def store_rendered_templates(store, signal, sender, template, context, **kwargs):
       
    83     """
       
    84     Stores templates and contexts that are rendered.
       
    85     """
       
    86     store.setdefault('template', []).append(template)
       
    87     store.setdefault('context', ContextList()).append(context)
       
    88 
       
    89 def encode_multipart(boundary, data):
       
    90     """
       
    91     Encodes multipart POST data from a dictionary of form values.
       
    92 
       
    93     The key will be used as the form data name; the value will be transmitted
       
    94     as content. If the value is a file, the contents of the file will be sent
       
    95     as an application/octet-stream; otherwise, str(value) will be sent.
       
    96     """
       
    97     lines = []
       
    98     to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
       
    99 
       
   100     # Not by any means perfect, but good enough for our purposes.
       
   101     is_file = lambda thing: hasattr(thing, "read") and callable(thing.read)
       
   102 
       
   103     # Each bit of the multipart form data could be either a form value or a
       
   104     # file, or a *list* of form values and/or files. Remember that HTTP field
       
   105     # names can be duplicated!
       
   106     for (key, value) in data.items():
       
   107         if is_file(value):
       
   108             lines.extend(encode_file(boundary, key, value))
       
   109         elif not isinstance(value, basestring) and is_iterable(value):
       
   110             for item in value:
       
   111                 if is_file(item):
       
   112                     lines.extend(encode_file(boundary, key, item))
       
   113                 else:
       
   114                     lines.extend([
       
   115                         '--' + boundary,
       
   116                         'Content-Disposition: form-data; name="%s"' % to_str(key),
       
   117                         '',
       
   118                         to_str(item)
       
   119                     ])
       
   120         else:
       
   121             lines.extend([
       
   122                 '--' + boundary,
       
   123                 'Content-Disposition: form-data; name="%s"' % to_str(key),
       
   124                 '',
       
   125                 to_str(value)
       
   126             ])
       
   127 
       
   128     lines.extend([
       
   129         '--' + boundary + '--',
       
   130         '',
       
   131     ])
       
   132     return '\r\n'.join(lines)
       
   133 
       
   134 def encode_file(boundary, key, file):
       
   135     to_str = lambda s: smart_str(s, settings.DEFAULT_CHARSET)
       
   136     return [
       
   137         '--' + boundary,
       
   138         'Content-Disposition: form-data; name="%s"; filename="%s"' \
       
   139             % (to_str(key), to_str(os.path.basename(file.name))),
       
   140         'Content-Type: application/octet-stream',
       
   141         '',
       
   142         file.read()
       
   143     ]
       
   144 
       
   145 class Client(object):
       
   146     """
       
   147     A class that can act as a client for testing purposes.
       
   148 
       
   149     It allows the user to compose GET and POST requests, and
       
   150     obtain the response that the server gave to those requests.
       
   151     The server Response objects are annotated with the details
       
   152     of the contexts and templates that were rendered during the
       
   153     process of serving the request.
       
   154 
       
   155     Client objects are stateful - they will retain cookie (and
       
   156     thus session) details for the lifetime of the Client instance.
       
   157 
       
   158     This is not intended as a replacement for Twill/Selenium or
       
   159     the like - it is here to allow testing against the
       
   160     contexts and templates produced by a view, rather than the
       
   161     HTML rendered to the end-user.
       
   162     """
       
   163     def __init__(self, **defaults):
       
   164         self.handler = ClientHandler()
       
   165         self.defaults = defaults
       
   166         self.cookies = SimpleCookie()
       
   167         self.exc_info = None
       
   168         self.errors = StringIO()
       
   169 
       
   170     def store_exc_info(self, **kwargs):
       
   171         """
       
   172         Stores exceptions when they are generated by a view.
       
   173         """
       
   174         self.exc_info = sys.exc_info()
       
   175 
       
   176     def _session(self):
       
   177         """
       
   178         Obtains the current session variables.
       
   179         """
       
   180         if 'django.contrib.sessions' in settings.INSTALLED_APPS:
       
   181             engine = import_module(settings.SESSION_ENGINE)
       
   182             cookie = self.cookies.get(settings.SESSION_COOKIE_NAME, None)
       
   183             if cookie:
       
   184                 return engine.SessionStore(cookie.value)
       
   185         return {}
       
   186     session = property(_session)
       
   187 
       
   188     def request(self, **request):
       
   189         """
       
   190         The master request method. Composes the environment dictionary
       
   191         and passes to the handler, returning the result of the handler.
       
   192         Assumes defaults for the query environment, which can be overridden
       
   193         using the arguments to the request.
       
   194         """
       
   195         environ = {
       
   196             'HTTP_COOKIE':      self.cookies,
       
   197             'PATH_INFO':         '/',
       
   198             'QUERY_STRING':      '',
       
   199             'REMOTE_ADDR':       '127.0.0.1',
       
   200             'REQUEST_METHOD':    'GET',
       
   201             'SCRIPT_NAME':       '',
       
   202             'SERVER_NAME':       'testserver',
       
   203             'SERVER_PORT':       '80',
       
   204             'SERVER_PROTOCOL':   'HTTP/1.1',
       
   205             'wsgi.version':      (1,0),
       
   206             'wsgi.url_scheme':   'http',
       
   207             'wsgi.errors':       self.errors,
       
   208             'wsgi.multiprocess': True,
       
   209             'wsgi.multithread':  False,
       
   210             'wsgi.run_once':     False,
       
   211         }
       
   212         environ.update(self.defaults)
       
   213         environ.update(request)
       
   214 
       
   215         # Curry a data dictionary into an instance of the template renderer
       
   216         # callback function.
       
   217         data = {}
       
   218         on_template_render = curry(store_rendered_templates, data)
       
   219         signals.template_rendered.connect(on_template_render)
       
   220 
       
   221         # Capture exceptions created by the handler.
       
   222         got_request_exception.connect(self.store_exc_info)
       
   223 
       
   224         try:
       
   225             response = self.handler(environ)
       
   226         except TemplateDoesNotExist, e:
       
   227             # If the view raises an exception, Django will attempt to show
       
   228             # the 500.html template. If that template is not available,
       
   229             # we should ignore the error in favor of re-raising the
       
   230             # underlying exception that caused the 500 error. Any other
       
   231             # template found to be missing during view error handling
       
   232             # should be reported as-is.
       
   233             if e.args != ('500.html',):
       
   234                 raise
       
   235 
       
   236         # Look for a signalled exception, clear the current context
       
   237         # exception data, then re-raise the signalled exception.
       
   238         # Also make sure that the signalled exception is cleared from
       
   239         # the local cache!
       
   240         if self.exc_info:
       
   241             exc_info = self.exc_info
       
   242             self.exc_info = None
       
   243             raise exc_info[1], None, exc_info[2]
       
   244 
       
   245         # Save the client and request that stimulated the response.
       
   246         response.client = self
       
   247         response.request = request
       
   248 
       
   249         # Add any rendered template detail to the response.
       
   250         # If there was only one template rendered (the most likely case),
       
   251         # flatten the list to a single element.
       
   252         for detail in ('template', 'context'):
       
   253             if data.get(detail):
       
   254                 if len(data[detail]) == 1:
       
   255                     setattr(response, detail, data[detail][0]);
       
   256                 else:
       
   257                     setattr(response, detail, data[detail])
       
   258             else:
       
   259                 setattr(response, detail, None)
       
   260 
       
   261         # Update persistent cookie data.
       
   262         if response.cookies:
       
   263             self.cookies.update(response.cookies)
       
   264 
       
   265         return response
       
   266 
       
   267     def get(self, path, data={}, follow=False, **extra):
       
   268         """
       
   269         Requests a response from the server using GET.
       
   270         """
       
   271         parsed = urlparse(path)
       
   272         r = {
       
   273             'CONTENT_TYPE':    'text/html; charset=utf-8',
       
   274             'PATH_INFO':       urllib.unquote(parsed[2]),
       
   275             'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
       
   276             'REQUEST_METHOD': 'GET',
       
   277             'wsgi.input':      FakePayload('')
       
   278         }
       
   279         r.update(extra)
       
   280 
       
   281         response = self.request(**r)
       
   282         if follow:
       
   283             response = self._handle_redirects(response)
       
   284         return response
       
   285 
       
   286     def post(self, path, data={}, content_type=MULTIPART_CONTENT,
       
   287              follow=False, **extra):
       
   288         """
       
   289         Requests a response from the server using POST.
       
   290         """
       
   291         if content_type is MULTIPART_CONTENT:
       
   292             post_data = encode_multipart(BOUNDARY, data)
       
   293         else:
       
   294             # Encode the content so that the byte representation is correct.
       
   295             match = CONTENT_TYPE_RE.match(content_type)
       
   296             if match:
       
   297                 charset = match.group(1)
       
   298             else:
       
   299                 charset = settings.DEFAULT_CHARSET
       
   300             post_data = smart_str(data, encoding=charset)
       
   301 
       
   302         parsed = urlparse(path)
       
   303         r = {
       
   304             'CONTENT_LENGTH': len(post_data),
       
   305             'CONTENT_TYPE':   content_type,
       
   306             'PATH_INFO':      urllib.unquote(parsed[2]),
       
   307             'QUERY_STRING':   parsed[4],
       
   308             'REQUEST_METHOD': 'POST',
       
   309             'wsgi.input':     FakePayload(post_data),
       
   310         }
       
   311         r.update(extra)
       
   312 
       
   313         response = self.request(**r)
       
   314         if follow:
       
   315             response = self._handle_redirects(response)
       
   316         return response
       
   317 
       
   318     def head(self, path, data={}, follow=False, **extra):
       
   319         """
       
   320         Request a response from the server using HEAD.
       
   321         """
       
   322         parsed = urlparse(path)
       
   323         r = {
       
   324             'CONTENT_TYPE':    'text/html; charset=utf-8',
       
   325             'PATH_INFO':       urllib.unquote(parsed[2]),
       
   326             'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
       
   327             'REQUEST_METHOD': 'HEAD',
       
   328             'wsgi.input':      FakePayload('')
       
   329         }
       
   330         r.update(extra)
       
   331 
       
   332         response = self.request(**r)
       
   333         if follow:
       
   334             response = self._handle_redirects(response)
       
   335         return response
       
   336 
       
   337     def options(self, path, data={}, follow=False, **extra):
       
   338         """
       
   339         Request a response from the server using OPTIONS.
       
   340         """
       
   341         parsed = urlparse(path)
       
   342         r = {
       
   343             'PATH_INFO':       urllib.unquote(parsed[2]),
       
   344             'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
       
   345             'REQUEST_METHOD': 'OPTIONS',
       
   346             'wsgi.input':      FakePayload('')
       
   347         }
       
   348         r.update(extra)
       
   349 
       
   350         response = self.request(**r)
       
   351         if follow:
       
   352             response = self._handle_redirects(response)
       
   353         return response
       
   354 
       
   355     def put(self, path, data={}, content_type=MULTIPART_CONTENT,
       
   356             follow=False, **extra):
       
   357         """
       
   358         Send a resource to the server using PUT.
       
   359         """
       
   360         if content_type is MULTIPART_CONTENT:
       
   361             post_data = encode_multipart(BOUNDARY, data)
       
   362         else:
       
   363             post_data = data
       
   364 
       
   365         parsed = urlparse(path)
       
   366         r = {
       
   367             'CONTENT_LENGTH': len(post_data),
       
   368             'CONTENT_TYPE':   content_type,
       
   369             'PATH_INFO':      urllib.unquote(parsed[2]),
       
   370             'QUERY_STRING':   urlencode(data, doseq=True) or parsed[4],
       
   371             'REQUEST_METHOD': 'PUT',
       
   372             'wsgi.input':     FakePayload(post_data),
       
   373         }
       
   374         r.update(extra)
       
   375 
       
   376         response = self.request(**r)
       
   377         if follow:
       
   378             response = self._handle_redirects(response)
       
   379         return response
       
   380 
       
   381     def delete(self, path, data={}, follow=False, **extra):
       
   382         """
       
   383         Send a DELETE request to the server.
       
   384         """
       
   385         parsed = urlparse(path)
       
   386         r = {
       
   387             'PATH_INFO':       urllib.unquote(parsed[2]),
       
   388             'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
       
   389             'REQUEST_METHOD': 'DELETE',
       
   390             'wsgi.input':      FakePayload('')
       
   391         }
       
   392         r.update(extra)
       
   393 
       
   394         response = self.request(**r)
       
   395         if follow:
       
   396             response = self._handle_redirects(response)
       
   397         return response
       
   398 
       
   399     def login(self, **credentials):
       
   400         """
       
   401         Sets the Client to appear as if it has successfully logged into a site.
       
   402 
       
   403         Returns True if login is possible; False if the provided credentials
       
   404         are incorrect, or the user is inactive, or if the sessions framework is
       
   405         not available.
       
   406         """
       
   407         user = authenticate(**credentials)
       
   408         if user and user.is_active \
       
   409                 and 'django.contrib.sessions' in settings.INSTALLED_APPS:
       
   410             engine = import_module(settings.SESSION_ENGINE)
       
   411 
       
   412             # Create a fake request to store login details.
       
   413             request = HttpRequest()
       
   414             if self.session:
       
   415                 request.session = self.session
       
   416             else:
       
   417                 request.session = engine.SessionStore()
       
   418             login(request, user)
       
   419 
       
   420             # Set the cookie to represent the session.
       
   421             session_cookie = settings.SESSION_COOKIE_NAME
       
   422             self.cookies[session_cookie] = request.session.session_key
       
   423             cookie_data = {
       
   424                 'max-age': None,
       
   425                 'path': '/',
       
   426                 'domain': settings.SESSION_COOKIE_DOMAIN,
       
   427                 'secure': settings.SESSION_COOKIE_SECURE or None,
       
   428                 'expires': None,
       
   429             }
       
   430             self.cookies[session_cookie].update(cookie_data)
       
   431 
       
   432             # Save the session values.
       
   433             request.session.save()
       
   434 
       
   435             return True
       
   436         else:
       
   437             return False
       
   438 
       
   439     def logout(self):
       
   440         """
       
   441         Removes the authenticated user's cookies and session object.
       
   442 
       
   443         Causes the authenticated user to be logged out.
       
   444         """
       
   445         session = import_module(settings.SESSION_ENGINE).SessionStore()
       
   446         session_cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
       
   447         if session_cookie:
       
   448             session.delete(session_key=session_cookie.value)
       
   449         self.cookies = SimpleCookie()
       
   450 
       
   451     def _handle_redirects(self, response):
       
   452         "Follows any redirects by requesting responses from the server using GET."
       
   453 
       
   454         response.redirect_chain = []
       
   455         while response.status_code in (301, 302, 303, 307):
       
   456             url = response['Location']
       
   457             scheme, netloc, path, query, fragment = urlsplit(url)
       
   458 
       
   459             redirect_chain = response.redirect_chain
       
   460             redirect_chain.append((url, response.status_code))
       
   461 
       
   462             # The test client doesn't handle external links,
       
   463             # but since the situation is simulated in test_client,
       
   464             # we fake things here by ignoring the netloc portion of the
       
   465             # redirected URL.
       
   466             response = self.get(path, QueryDict(query), follow=False)
       
   467             response.redirect_chain = redirect_chain
       
   468 
       
   469             # Prevent loops
       
   470             if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
       
   471                 break
       
   472         return response
       
   473