web/ldt/test/client.py
author ymh <ymh.work@gmail.com>
Mon, 13 Dec 2010 23:55:19 +0100
changeset 22 83b28fc0d731
permissions -rw-r--r--
improve on ldt test framework start migration for text test

#from django.test.client import Client as DClient
from django.conf import settings
from django.core.urlresolvers import reverse
from django.http import HttpResponse, SimpleCookie
from django.test.client import encode_multipart, encode_file, Client, BOUNDARY, \
    MULTIPART_CONTENT, CONTENT_TYPE_RE
from django.utils.encoding import smart_str
from django.utils.http import urlencode
from ldt.utils import Property
from oauth2 import Request, Consumer, Token, SignatureMethod_HMAC_SHA1, \
    generate_nonce, SignatureMethod_PLAINTEXT
from oauth_provider.consts import OUT_OF_BAND
from urlparse import urlsplit, urlunsplit, urlparse, urlunparse, parse_qs
import httplib2
import logging
import re
try:
    from cStringIO import StringIO
except ImportError:
    from StringIO import StringIO


class WebClient(object):
    """
    A test client that talks to a running server over real HTTP.

    Requests are sent with ``httplib2.Http`` to URLs built from ``baseurl``
    and the requested path. Each server reply is converted into a Django
    ``HttpResponse`` (see ``_process_response``), so tests can make the same
    kind of assertions as with Django's in-process test client. The raw
    httplib2 response stays available as ``raw_response``.

    Client objects are stateful: cookies set by the server are stored in
    ``self.cookies`` and sent back with every later request. A session
    therefore survives for the lifetime of the instance.
    """
    def __init__(self, **defaults):
        # ``defaults`` is accepted but currently ignored.
        self.handler = httplib2.Http()
        self.cookies = SimpleCookie()
        self.__baseurltuple = ()
        self.__login_url = None

    @Property
    def baseurl():
        # Base URL stored as a urlsplit 5-tuple. It may be assigned either a
        # tuple or anything convertible to a URL string. Its non-empty
        # components take precedence over those of each request path (see
        # _mergeurl).

        def fget(self):
            return self.__baseurltuple

        def fset(self, value):
            if isinstance(value, tuple):
                self.__baseurltuple = value
            else:
                self.__baseurltuple = urlsplit(unicode(value))

        return locals()

    @Property
    def login_url():
        # Path that login() posts the credentials to.

        def fget(self):
            return self.__login_url

        def fset(self, value):
            self.__login_url = value

        return locals()

    def _mergeurl(self, urltuple):
        """Build the request URL from ``urltuple`` and ``baseurl``.

        Each of the five ``urlsplit`` components is taken from ``baseurl``
        when it is non-empty there; otherwise it comes from ``urltuple``.
        Components missing from either side (e.g. ``baseurl`` never set)
        count as empty instead of truncating the result.
        """
        base = tuple(self.baseurl)
        parts = []
        for i in range(5):
            base_part = base[i] if i < len(base) else ""
            request_part = urltuple[i] if i < len(urltuple) else ""
            parts.append(base_part or request_part)
        return urlunsplit(parts)

    def _cookie_header(self):
        """Return the stored cookies formatted as a ``Cookie`` request header value."""
        return "; ".join("%s=%s" % (name, morsel.coded_value)
                         for name, morsel in sorted(self.cookies.items()))

    def _build_headers(self, extra, content_type=None):
        """Return the request headers.

        The headers are built in order: the optional Content-type, then the
        stored cookies, then ``extra`` (which may override either).
        """
        headers = {}
        if content_type is not None:
            headers['Content-type'] = content_type
        if self.cookies:
            headers['Cookie'] = self._cookie_header()
        if extra:
            headers.update(extra)
        return headers

    def _process_response(self, response, content):
        """Convert an httplib2 ``(response, content)`` pair into an HttpResponse.

        Cookies set by the server are stored on the client. Every response
        header is copied onto the result. The client and the raw httplib2
        response are attached as ``client`` and ``raw_response``.
        """
        # Not every reply carries a Content-Type header; passing None lets
        # HttpResponse use its default instead of raising KeyError here.
        resp = HttpResponse(content=content, status=response.status,
                            content_type=response.get('content-type'))
        if 'set-cookie' in response:
            self.cookies.load(response['set-cookie'])
            resp.cookies.load(response['set-cookie'])

        resp.client = self
        resp.raw_response = response
        for key, value in response.items():
            resp[key] = value

        return resp

    def _handle_redirects(self, response):
        """Attach ``redirect_chain`` to a response returned by ``_process_response``.

        The chain lists ``(content-location, status)`` for every earlier
        response httplib2 recorded in its ``previous`` links, most recent
        first.
        """
        response.redirect_chain = []

        previous = response.raw_response.previous
        while previous is not None:
            response.redirect_chain.append((previous['content-location'], previous.status))
            previous = previous.previous

        return response

    def get(self, path, data=None, follow=False, **extra):
        """
        Requests a response from the server using GET.

        ``data`` is url-encoded (sequence values become repeated fields) and,
        when non-empty, replaces the query string of ``path``. ``extra``
        entries are sent as HTTP headers. When ``follow`` is true, httplib2
        follows redirects and the response gets a ``redirect_chain``.
        """
        parsed = list(urlsplit(path))
        parsed[3] = urlencode(data or {}, doseq=True) or parsed[3]
        fullpath = self._mergeurl(parsed)
        self.handler.follow_redirects = follow

        headers = self._build_headers(extra)
        response, content = self.handler.request(fullpath, method="GET", headers=headers)

        resp = self._process_response(response, content)
        if follow:
            resp = self._handle_redirects(resp)
        return resp

    def post(self, path, data=None, content_type="application/x-www-form-urlencoded",
             follow=False, **extra):
        """
        Requests a response from the server using POST.

        The body is encoded according to ``content_type``:
        - multipart: ``encode_multipart``;
        - form-urlencoded: ``urlencode``;
        - anything else: ``smart_str`` in the charset named by the content
          type, or ``settings.DEFAULT_CHARSET``.

        ``extra`` entries are sent as HTTP headers.
        """
        if data is None:
            data = {}
        if content_type == MULTIPART_CONTENT:
            post_data = encode_multipart(BOUNDARY, data)
        elif content_type == "application/x-www-form-urlencoded":
            # doseq=True (as in get) so sequence values become repeated fields
            # instead of their repr.
            post_data = urlencode(data, doseq=True)
        else:
            # Encode the content so that the byte representation is correct.
            match = CONTENT_TYPE_RE.match(content_type)
            charset = match.group(1) if match else settings.DEFAULT_CHARSET
            post_data = smart_str(data, encoding=charset)

        fullpath = self._mergeurl(list(urlsplit(path)))
        self.handler.follow_redirects = follow

        headers = self._build_headers(extra, content_type)
        response, content = self.handler.request(fullpath, method="POST", headers=headers, body=post_data)

        resp = self._process_response(response, content)
        if follow:
            # _handle_redirects reads ``raw_response``, so it needs the
            # converted response, not the raw httplib2 one.
            resp = self._handle_redirects(resp)
        return resp

    def login(self, **credentials):
        """
        Log in by posting ``credentials`` to ``login_url`` as an AJAX request.

        Returns True when the server answers with a 302 redirect and False
        otherwise. The session cookie the server sets is kept in
        ``self.cookies``.
        """
        resp = self.post(path=self.login_url, data=credentials, follow=False, **{"X-Requested-With" : "XMLHttpRequest"})
        return resp.status_code == 302

#
#    def head(self, path, data={}, follow=False, **extra):
#        """
#        Request a response from the server using HEAD.
#        """
#        parsed = urlparse(path)
#        r = {
#            'CONTENT_TYPE':    'text/html; charset=utf-8',
#            'PATH_INFO':       urllib.unquote(parsed[2]),
#            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
#            'REQUEST_METHOD': 'HEAD',
#            'wsgi.input':      FakePayload('')
#        }
#        r.update(extra)
#
#        response = self.request(**r)
#        if follow:
#            response = self._handle_redirects(response)
#        return response
#
#    def options(self, path, data={}, follow=False, **extra):
#        """
#        Request a response from the server using OPTIONS.
#        """
#        parsed = urlparse(path)
#        
#        r = {
#            'PATH_INFO':       urllib.unquote(parsed[2]),
#            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
#            'REQUEST_METHOD': 'OPTIONS',
#            'wsgi.input':      FakePayload('')
#        }
#        r.update(extra)
#
#        response = self.request(**r)
#        if follow:
#            response = self._handle_redirects(response)
#        return response
#
#    def put(self, path, data={}, content_type=MULTIPART_CONTENT,
#            follow=False, **extra):
#        """
#        Send a resource to the server using PUT.
#        """
#        if content_type is MULTIPART_CONTENT:
#            post_data = encode_multipart(BOUNDARY, data)
#        else:
#            post_data = data
#
#        # Make `data` into a querystring only if it's not already a string. If
#        # it is a string, we'll assume that the caller has already encoded it.
#        query_string = None
#        if not isinstance(data, basestring):
#            query_string = urlencode(data, doseq=True)
#
#        parsed = urlparse(path)
#        r = {
#            'CONTENT_LENGTH': len(post_data),
#            'CONTENT_TYPE':   content_type,
#            'PATH_INFO':      urllib.unquote(parsed[2]),
#            'QUERY_STRING':   query_string or parsed[4],
#            'REQUEST_METHOD': 'PUT',
#            'wsgi.input':     FakePayload(post_data),
#        }
#        r.update(extra)
#
#        response = self.request(**r)
#        if follow:
#            response = self._handle_redirects(response)
#        return response
#
#    def delete(self, path, data={}, follow=False, **extra):
#        """
#        Send a DELETE request to the server.
#        """
#        parsed = urlparse(path)
#        r = {
#            'PATH_INFO':       urllib.unquote(parsed[2]),
#            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
#            'REQUEST_METHOD': 'DELETE',
#            'wsgi.input':      FakePayload('')
#        }
#        r.update(extra)
#
#        response = self.request(**r)
#        if follow:
#            response = self._handle_redirects(response)
#        return response
#
#    def login(self, **credentials):
#        """
#        Sets the Client to appear as if it has successfully logged into a site.
#
#        Returns True if login is possible; False if the provided credentials
#        are incorrect, or the user is inactive, or if the sessions framework is
#        not available.
#        """
#        user = authenticate(**credentials)
#        if user and user.is_active \
#                and 'django.contrib.sessions' in settings.INSTALLED_APPS:
#            engine = import_module(settings.SESSION_ENGINE)
#
#            # Create a fake request to store login details.
#            request = HttpRequest()
#            if self.session:
#                request.session = self.session
#            else:
#                request.session = engine.SessionStore()
#            login(request, user)
#
#            # Save the session values.
#            request.session.save()
#
#            # Set the cookie to represent the session.
#            session_cookie = settings.SESSION_COOKIE_NAME
#            self.cookies[session_cookie] = request.session.session_key
#            cookie_data = {
#                'max-age': None,
#                'path': '/',
#                'domain': settings.SESSION_COOKIE_DOMAIN,
#                'secure': settings.SESSION_COOKIE_SECURE or None,
#                'expires': None,
#            }
#            self.cookies[session_cookie].update(cookie_data)
#
#            return True
#        else:
#            return False
#
#    def logout(self):
#        """
#        Removes the authenticated user's cookies and session object.
#
#        Causes the authenticated user to be logged out.
#        """
#        session = import_module(settings.SESSION_ENGINE).SessionStore()
#        session_cookie = self.cookies.get(settings.SESSION_COOKIE_NAME)
#        if session_cookie:
#            session.delete(session_key=session_cookie.value)
#        self.cookies = SimpleCookie()

        
class OAuthPayload(object):
    """OAuth 1.0 client-side state used to sign test-client requests.

    Stores the consumer, the current token and the OAuth parameters merged
    into every signed request. It also runs the three-legged OAuth flow
    (request token, user authorization, access token) through a test client.
    """

    def __init__(self, servername="testserver"):
        self._token = None
        # Host placed in the signed URL when a request path has no netloc.
        self._servername = servername
        # Parameters merged into every signed request.
        self._oauth_parameters = {
            'oauth_version': '1.0'
        }
        # Parameters added to the signed requests only while login() runs.
        self._oauth_parameters_extra = {
            'oauth_callback': 'http://127.0.0.1/callback',
            'scope': 'all'
        }
        # Accumulates messages describing why login() failed.
        self.errors = StringIO()

    def _get_signed_request(self, method, path, params):
        """Return an oauth2 ``Request`` for ``method``/``path`` signed with HMAC-SHA1.

        ``params`` is copied before the OAuth parameters are merged in, so
        the caller's dict is not modified. ``set_consumer`` must have been
        called first.
        """
        parameters = params.copy()
        parameters.update(self._oauth_parameters)
        oauth_request = Request.from_consumer_and_token(
            consumer=self._consumer, token=self._token, http_method=method,
            http_url=path, parameters=parameters)
        oauth_request.sign_request(SignatureMethod_HMAC_SHA1(),
                                   consumer=self._consumer, token=self._token)
        return oauth_request

    def set_consumer(self, key, secret):
        """Set the consumer credentials used to sign every request."""
        self._consumer = Consumer(key, secret)
        self._oauth_parameters['oauth_consumer_key'] = key

    def set_scope(self, value):
        """Set the ``scope`` parameter sent during ``login``."""
        self._oauth_parameters_extra['scope'] = value

    def inject_oauth_data(self, path, method, data):
        """Return a copy of ``data`` extended with the signed OAuth parameters.

        :param path: request path or URL. A missing scheme or host is filled
            with ``http`` and the configured server name before signing.
        :param method: HTTP verb the request is signed for.
        :param data: request parameters (may be None). For a GET without
            data, the query string of ``path`` is signed instead.
        """
        path_parsed = urlparse(path)

        if method == 'GET' and (data is None or len(data) == 0):
            new_data = parse_qs(path_parsed[4])
        elif data is None:
            new_data = {}
        else:
            new_data = data.copy()

        # Signed URL keeps scheme, host, path and params. The query string
        # travels in the signed parameters instead.
        clean_path = ['http', self._servername, '', '', '', '']
        for i in range(4):
            clean_path[i] = path_parsed[i] or clean_path[i]
        path = urlunparse(clean_path)

        oauth_request = self._get_signed_request(method, path, new_data)
        new_data.update(oauth_request)
        return new_data

    @staticmethod
    def _verifier_from_location(location):
        """Return the single ``oauth_verifier`` value in ``location``'s query, or None."""
        values = parse_qs(urlsplit(location)[3]).get('oauth_verifier')
        return values[0] if values else None

    def login(self, client, login_method, **credential):
        """Run the three-legged OAuth flow through ``client``.

        :param client: test client used for the HTTP exchanges.
        :param login_method: callable ``login_method(client, **credential)``
            performing the site login. It must return a truthy value on
            success.
        :returns: True when an access token was obtained, False otherwise.
            On failure the reason is appended to ``self.errors``.
        """
        self._oauth_parameters.update(self._oauth_parameters_extra)
        try:
            return self._login_steps(client, login_method, credential)
        finally:
            # callback/scope belong only to the login exchange. Drop them
            # whether or not the flow succeeded, so later requests are not
            # signed with them.
            for key in self._oauth_parameters_extra:
                self._oauth_parameters.pop(key, None)

    def _login_steps(self, client, login_method, credential):
        """Perform the HTTP exchanges of ``login``; return True on success."""
        # Obtaining a Request Token
        resp = client.get(reverse('oauth_request_token'), follow=True)
        if resp.status_code != 200:
            self.errors.write("oauth_request_token response status code fail : " + repr(resp))
            return False
        self._token = Token.from_string(resp.content)

        # Requesting User Authorization
        if not login_method(client, **credential):
            self.errors.write("login failed : " + repr(credential))
            return False

        resp = client.get(reverse('oauth_user_authorization'))
        if resp.status_code != 200:
            self.errors.write("oauth_user_authorization get response status code fail : " + repr(resp))
            return False

        resp = client.post(reverse('oauth_user_authorization'), {'authorize_access': 1},
                           **{"X-Requested-With" : "XMLHttpRequest"})
        if resp.status_code != 302:
            self.errors.write("oauth_user_authorization post response status code fail : " + repr(resp))
            return False

        # parse_qs yields lists; the token needs the single verifier string.
        verifier = self._verifier_from_location(resp["Location"])
        if verifier is None:
            self.errors.write("oauth_user_authorization redirect has no oauth_verifier : " + repr(resp))
            return False
        self._token.verifier = verifier

        # Obtaining an Access Token
        resp = client.get(reverse('oauth_access_token'))
        if resp.status_code != 200:
            self.errors.write("oauth_access_token get response status code fail : " + repr(resp))
            return False
        self._token = Token.from_string(resp.content)
        return True

    def logout(self):
        """Forget the current token; later requests are signed with the consumer only."""
        self._token = None

# HTTP verb used to sign the request of each client method name. Only GET
# and POST are used: put is signed as POST, head/options/delete as GET.
# NOTE(review): confirm the server verifies PUT/DELETE signatures with these verbs.
METHOD_MAPPING = dict(
    get='GET',
    post='POST',
    put='POST',
    head='GET',
    options='GET',
    delete='GET',
)

def _generate_request_wrapper(meth):
    def request_wrapper(inst, *args, **kwargs):
        path = args[0] if len(args) > 0 else kwargs.get('path','')
        data = args[1] if len(args) > 1 else kwargs.get('data',{})
        args = args[2:]
        if 'path' in kwargs:
            del(kwargs['path'])        
        if 'data' in kwargs:
            del(kwargs['data'])
        data = inst._oauth_data.inject_oauth_data(path, METHOD_MAPPING[meth.__name__], data)
        return meth(inst,path=path, data=data, *args, **kwargs)
    return request_wrapper

def _generate_login_wrapper(meth):
    def login_wrapper(inst, **credential):
        return inst._oauth_data.login(inst, meth, **credential)
    return login_wrapper
    
def _generate_logout_wrapper(meth):
    def logout_wrapper(inst):
        inst._oauth_data.logout()
        meth(inst)
    return logout_wrapper
        
class OAuthMetaclass(type):
    """Metaclass that turns a test-client class into an OAuth-signing client.

    It wraps these methods, whether defined in the class body or directly on
    a base class:
    - request methods (get/post/head/options/put/delete), so their data is
      signed;
    - ``login``, which runs the OAuth flow;
    - ``logout``, which drops the OAuth token.

    It also adds ``set_consumer``/``set_scope``, and replaces ``__init__`` so
    each instance gets an ``OAuthPayload`` in ``_oauth_data``.
    """

    def __new__(cls, name, bases, attrs):
        request_methods = ('get', 'post', 'head', 'options', 'put', 'delete')

        def _wrap_oauth(attrname, attrvalue):
            """Return the OAuth wrapper for ``attrname``, or None if it is not wrapped."""
            if attrname in request_methods:
                return _generate_request_wrapper(attrvalue)
            if attrname == 'login':
                return _generate_login_wrapper(attrvalue)
            if attrname == 'logout':
                return _generate_logout_wrapper(attrvalue)
            return None

        def set_consumer(inst, key, secret):
            inst._oauth_data.set_consumer(key, secret)

        def set_scope(inst, scope):
            inst._oauth_data.set_scope(scope)

        newattrs = {'set_consumer': set_consumer, 'set_scope': set_scope}

        # Attributes of the class body (which may override the two helpers).
        for attrname, attrvalue in attrs.items():
            wrapped = _wrap_oauth(attrname, attrvalue)
            newattrs[attrname] = attrvalue if wrapped is None else wrapped

        # Wrappable methods defined directly on the bases, unless overridden.
        for klass in bases:
            # A base built by this metaclass already exposes wrapped methods;
            # wrapping them again would sign every request twice.
            if isinstance(klass, OAuthMetaclass):
                continue
            for attrname, attrvalue in klass.__dict__.items():
                if attrname in newattrs:
                    continue
                wrapped = _wrap_oauth(attrname, attrvalue)
                if wrapped is not None:
                    newattrs[attrname] = wrapped

        init_method = newattrs.get("__init__")

        def new_init(inst, *args, **kwargs):
            # getattr walks the MRO, so subclasses inherit ``servername``.
            inst._oauth_data = OAuthPayload(getattr(inst, 'servername', 'testserver'))
            if init_method is not None:
                init_method(inst, *args, **kwargs)
            else:
                # Use the class built here, not inst.__class__: for a subclass
                # the latter would resolve back to this same __init__ forever.
                super(new_class, inst).__init__(*args, **kwargs)
        newattrs["__init__"] = new_init

        # new_init reads new_class through its closure at call time.
        new_class = super(OAuthMetaclass, cls).__new__(cls, name, bases, newattrs)
        return new_class
    

            
class OAuthClient(Client):
    """Django test ``Client`` whose requests are signed with OAuth.

    ``OAuthMetaclass`` wraps any get/post/head/options/put/delete, login and
    logout methods defined directly on ``Client``. It also adds
    ``set_consumer``/``set_scope`` to configure the signing.
    """
    # Python 2 metaclass hook; under Python 3 this is only a plain class
    # attribute and no wrapping takes place.
    __metaclass__ = OAuthMetaclass
#    def __init__(self, **default):
#        super(OAuthClient,self).__init__(**default)
#        self.__token = None
#        self.oauth_parameters = {
#            'oauth_version': '1.0',
#            'oauth_callback': 'http://127.0.0.1/callback',
#            'scope':'all'
#        }
#        
#    def __get_signed_request(self, method, path):
#        
#        oauth_request = Request.from_consumer_and_token(consumer=self.__consumer, token=self.__token, http_method=method, http_url=path, parameters=self.oauth_parameters)
#        oauth_request.sign_request(SignatureMethod_HMAC_SHA1(), consumer=self.__consumer, token=self.__token)
#        
#        return oauth_request
#    
#        
#    def set_consumer(self, key, secret):
#        self.__consumer = Consumer(key, secret)
#        self.oauth_parameters['oauth_consumer_key'] = key
#        
#    def set_scope(self, value):
#        self.oauth_parameters['scope'] = value
#
#    def __inject_oauth_data(self, path, method, data):
#        
#        path_parsed = urlparse(path)
#        
#        if method=='GET' and len(data) == 0:
#            data= parse_qs(path_parsed[4])
#            
#        clean_path = ['']*6
#        clean_path[0] = 'http'
#        clean_path[1] = 'testserver'
#        for i in range(0,4):
#            clean_path[i] = path_parsed[i] or clean_path[i]
#        path = urlunparse(clean_path)
#        
#        oauth_request = self.__get_signed_request(method, path)
#                
#        data.update(oauth_request)
#
#
#    def get(self, path, data={}, follow=False, **extra):
#        
#        self.__inject_oauth_data(path, 'GET', data)
#        return super(OAuthClient, self).get(path, data, follow, **extra)
#
#    
#    def post(self, path, data={}, content_type=MULTIPART_CONTENT,
#             follow=False, **extra):
#        self.__inject_oauth_data(path, 'POST', data)
#        return super(OAuthClient,self).post(path, data, content_type, follow, **extra)    
#    
#    def head(self, path, data={}, follow=False, **extra):
#        self.__inject_oauth_data(path, 'GET', data)
#        return super(OAuthClient, self).head(path, data, follow, **extra)
#    
#    def options(self, path, data={}, follow=False, **extra):
#        self.__inject_oauth_data(path, 'GET', data)
#        return options(OAuthClient, self).options(path, data, follow, **extra)
#    
#    def put(self, path, data={}, content_type=MULTIPART_CONTENT,
#            follow=False, **extra):
#        self.__inject_oauth_data(path, 'POST', data)
#        return super(OAuthClient,self).put(path, data, content_type, follow, **extra)
#    
#    def delete(self, path, data={}, follow=False, **extra):
#        self.__inject_oauth_data(path, 'GET', data)
#        return super(OAuthClient, self).delete(path, data, follow, **extra)
#    
#    ### TODO: better document errors
#    def login(self, **credential):
#        
#        #Obtaining a Request Token
#        resp = self.get(reverse('oauth_request_token'), follow=True)
#        if resp.status_code == 200:
#            self.__token = Token.from_string(resp.content)
#        else:
#            self.errors.write("oauth_request_token response status code fail : " + repr(resp))
#            return False
#                
#        #Requesting User Authorization
#        res = super(OAuthClient, self).login(**credential)
#        if not res:
#            self.errors.write("login failed : " + repr(credential))
#            return False
#
#        resp = self.get(reverse('oauth_user_authorization'))
#        if resp.status_code != 200:
#            self.errors.write("oauth_user_authorization get response status code fail : " + repr(resp))
#            return False
#        
#        resp = self.post(reverse('oauth_user_authorization'), {'authorize_access':1})
#        if resp.status_code != 302:
#            self.errors.write("oauth_user_authorization post response status code fail : " + repr(resp))
#            return False
#        
#        location_splitted = urlsplit(resp["Location"])
#        location_query_dict = parse_qs(location_splitted[3])
#        self.__token.verifier = location_query_dict['oauth_verifier']
#        
#                
#        #Obtaining an Access Token
#        resp = self.get(reverse('oauth_access_token'))
#        if resp.status_code == 200:            
#            self.__token = Token.from_string(resp.content)
#            return True
#        else:
#            self.errors.write("oauth_access_token get response status code fail : " + repr(resp))
#            return False
#        
#    def logout(self):
#        super(OAuthClient,self).logout()
#        self._token = None

class OAuthWebClient(WebClient):
    """``WebClient`` whose get/post/login methods are wrapped by ``OAuthMetaclass``
    so that requests are signed with OAuth."""
    # Python 2 metaclass hook; under Python 3 this is only a plain class
    # attribute and no wrapping takes place.
    __metaclass__ = OAuthMetaclass
    # Host placed in the signed URL when a request path has no netloc of its
    # own; OAuthMetaclass passes it to OAuthPayload.
    servername = '127.0.0.1:8000'