# -*- coding: utf-8 -*-
'''
Created on Mar 22, 2012

@author: ymh

Module directly inspired by tweetstream

'''
import time
import requests
from requests.utils import stream_untransfer, stream_decode_response_unicode
import anyjson
import select

from . import  USER_AGENT, ConnectionError, AuthenticationError


def iter_content_non_blocking(req, max_chunk_size=4096, decode_unicode=False, timeout=-1):
    """Iterate over a streamed ``requests`` response body using non-blocking reads.

    The underlying socket is switched to non-blocking mode and ``select`` is
    used to wait for data. The read size adapts to the traffic: it starts at
    1 byte, doubles (up to ``max_chunk_size``) when a read fills the requested
    size, and halves (down to 1) when a read returns less than half of it.

    :param req: a ``requests`` response opened without prefetching. This
      function reaches into private attributes (``req._content_consumed``,
      ``req.raw._fp.fp._sock``) of the requests/httplib versions this module
      was written against.
    :param max_chunk_size: upper bound, in bytes, for a single read.
    :param decode_unicode: if True, chunks are passed through
      ``stream_decode_response_unicode``.
    :param timeout: seconds passed to ``select``; a negative value or None
      waits indefinitely. NOTE: when ``select`` times out the loop simply
      waits again, so a timeout never ends the iteration.
    :raises RuntimeError: if the response content was already consumed.
    :returns: a generator of chunks, passed through ``stream_untransfer``.
    """
    if req._content_consumed:
        raise RuntimeError(
            'The content for this response was already consumed'
        )

    req.raw._fp.fp._sock.setblocking(False)

    # select() interprets a None timeout as "block until readable".
    select_timeout = None if (timeout is None or timeout < 0) else timeout

    def generate():
        chunk_size = 1
        while True:
            # The socket is looked up on every pass rather than hoisted: once
            # httplib finishes the response it drops ``fp``, and we must not
            # keep selecting on a socket the response no longer owns.
            rlist, _, _ = select.select([req.raw._fp.fp._sock], [], [], select_timeout)

            if not rlist:
                continue

            try:
                chunk = req.raw.read(chunk_size)
                if not chunk:
                    break
                if len(chunk) >= chunk_size and chunk_size < max_chunk_size:
                    chunk_size = min(chunk_size * 2, max_chunk_size)
                elif len(chunk) < chunk_size // 2 and chunk_size > 1:
                    # Shrinking must also be possible from max_chunk_size;
                    # `//` keeps chunk_size an int, as read() requires.
                    chunk_size = max(chunk_size // 2, 1)
                yield chunk
            except requests.exceptions.SSLError as e:
                if e.errno == 2:
                    # Apparently this means there was nothing in the socket buf
                    pass
                else:
                    raise

        req._content_consumed = True

    gen = stream_untransfer(generate(), req)

    if decode_unicode:
        gen = stream_decode_response_unicode(gen, req)

    return gen

    
    

class BaseStream(object):

    """A network connection to Twitter's streaming API

    :param auth: tweepy auth object (anything providing
      ``apply_auth(url, method, headers, postdata)``); may be None.
    :keyword catchup: Number of tweets from the past to get before switching to
      live stream. Sent as the ``count`` POST parameter.
    :keyword raw: If True, return each tweet's raw data direct from the socket,
      without UTF8 decoding or parsing, rather than a parsed object. The
      default is False.
    :keyword timeout: Timeout in seconds for the receiving socket; a negative
      value (the default, -1) or None means no timeout. Certain types of
      network problems (e.g., disconnecting a VPN) can cause the connection to
      hang, leading to indefinite blocking. NOTE: the reader currently used by
      ``_iter_object`` (``iter_content``) does not apply this value; it is only
      stored on the instance.
    :keyword url: Endpoint URL for the object. Note: you should not
      need to edit this. It's present to make testing easier. When omitted,
      the class attribute ``url`` (defined by subclasses) is used.
    :keyword compressed: If True, send ``Accept-Encoding: deflate, gzip``.
    :keyword chunk_size: Chunk size, in bytes, passed to ``iter_content``.
    :keyword logger: Optional logger receiving debug messages.

    .. attribute:: connected

        True if the object is currently connected to the stream.

    .. attribute:: url

        The URL to which the object is connected

    .. attribute:: starttime

        The timestamp, in seconds since the epoch, the object connected to the
        streaming api.

    .. attribute:: count

        The number of tweets that have been returned by the object.

    .. attribute:: rate

        The rate at which tweets have been returned from the object as a
        float. see also :attr: `rate_period`.

    .. attribute:: rate_period

        The amount of time to sample tweets to calculate tweet rate. By
        default 10 seconds. Changes to this attribute will not be reflected
        until the next time the rate is calculated. The rate of tweets varies
        with time of day etc. so it's useful to set this to something
        sensible.

    .. attribute:: user_agent

        User agent string that will be included in the request. NOTE: This can
        not be changed after the connection has been made. This property must
        thus be set before accessing the iterator. The default is set in
        :attr: `USER_AGENT`.

    .. attribute:: muststop

        A boolean, or a callable returning one; when true, iteration stops.
    """

    def __init__(self, auth,
                 catchup=None, raw=False, timeout=-1, url=None, compressed=False, chunk_size=4096, logger=None):
        self._conn = None
        self._rate_ts = None
        self._rate_cnt = 0
        self._auth = auth
        self._catchup_count = catchup
        self.raw_mode = raw
        self.timeout = timeout
        self._compressed = compressed

        self.rate_period = 10  # in seconds
        self.connected = False
        self.starttime = None
        self.count = 0
        self.rate = 0
        self.user_agent = USER_AGENT
        self.chunk_size = chunk_size
        if url: self.url = url

        self.muststop = False
        self._logger = logger

        # Generator used by next(); creating it does not connect yet, the
        # connection is opened on the first next() call.
        self._iter = self.__iter__()

    def __enter__(self):
        return self

    def __exit__(self, *params):
        self.close()
        return False

    def _init_conn(self):
        """Open the connection to the twitter server.

        POSTs ``_get_post_data()`` (plus ``stall_warnings`` and optionally
        ``count``) to ``self.url`` and keeps the streamed response in
        ``self._resp``.

        :raises requests.exceptions.HTTPError: on a non-success status
          (via ``raise_for_status``).
        """

        if self._logger : self._logger.debug("BaseStream Open the connection to the twitter server")

        headers = {'User-Agent': self.user_agent}

        if self._compressed:
            headers['Accept-Encoding'] = "deflate, gzip"

        postdata = self._get_post_data() or {}
        postdata['stall_warnings'] = 'true'
        if self._catchup_count:
            postdata["count"] = self._catchup_count

        if self._auth:
            self._auth.apply_auth(self.url, "POST", headers, postdata)

        if self._logger : self._logger.debug("BaseStream init connection url " + repr(self.url))
        if self._logger : self._logger.debug("BaseStream init connection headers " + repr(headers))
        if self._logger : self._logger.debug("BaseStream init connection data " + repr(postdata))

        # prefetch=False: keep the body unread so it can be streamed.
        self._resp = requests.post(self.url, headers=headers, data=postdata, prefetch=False)
        if self._logger : self._logger.debug("BaseStream init connection " + repr(self._resp))

        self._resp.raise_for_status()
        self.connected = True

        if not self._rate_ts:
            self._rate_ts = time.time()
        if not self.starttime:
            self.starttime = time.time()


    def _get_post_data(self):
        """Subclasses that need to add post data to the request can override
        this method and return it as a dict of parameter names to values
        (the caller adds further keys to it). Returning None means no extra
        data."""
        return None

    def testmuststop(self):
        """Return the value of ``muststop``, calling it first if it is callable."""
        if callable(self.muststop):
            return self.muststop()
        else:
            return self.muststop

    def _update_rate(self):
        """Recompute ``rate`` once more than ``rate_period`` seconds have elapsed.

        If no sampling window has been started yet (``_rate_ts`` unset), one
        is started now and ``rate`` is left unchanged.
        """
        now = time.time()
        if not self._rate_ts:
            self._rate_ts = now
            return
        elapsed = now - self._rate_ts
        if elapsed > self.rate_period:
            self.rate = self._rate_cnt / elapsed
            self._rate_cnt = 0
            self._rate_ts = now

    def _iter_object(self):
        """Yield the lines of the response body, split on carriage returns.

        Chunks from ``self._resp.iter_content`` are concatenated and split on
        carriage-return characters; each line is yielded with leading and
        trailing newlines stripped. An incomplete trailing fragment is carried
        over to the next chunk. Iteration ends when the response is exhausted
        or ``testmuststop()`` returns true; any carried-over fragment is then
        yielded as well.
        """
        pending = None

#        for chunk in iter_content_non_blocking(self._resp,
#            max_chunk_size=self.chunk_size,
#            decode_unicode=False,
#            timeout=self.timeout):
        for chunk in self._resp.iter_content(
            chunk_size=self.chunk_size,
            decode_unicode=False):

            if self.testmuststop():
                break

            if pending is not None:
                chunk = pending + chunk
            lines = chunk.split('\r')

            # The last piece has no terminating carriage return yet (unless it
            # is empty): keep it back for the next chunk.
            pending = lines.pop() if lines[-1] else None

            for line in lines:
                yield line.strip('\n')

            if self.testmuststop():
                break

        if pending is not None:
            # Strip like every other line, so the lone newline left over from
            # a final CRLF is not handed to the JSON parser.
            yield pending.strip('\n')
        # Plain fall-through ends the generator; `raise StopIteration` here
        # would turn into RuntimeError under PEP 479.

    def __iter__(self):
        """Yield tweets from the stream, connecting first if necessary.

        Empty (keep-alive) lines are skipped. In raw mode each line is yielded
        as read; otherwise it is decoded as UTF-8 and parsed as JSON.

        When the stream ends or raises, the object is marked disconnected so
        that iterating again opens a fresh connection.

        :raises ConnectionError: if a line cannot be parsed as JSON.
        """

        if self._logger : self._logger.debug("BaseStream __iter__")
        if not self.connected:
            if self._logger : self._logger.debug("BaseStream __iter__ not connected, connecting")
            self._init_conn()

        if self._logger : self._logger.debug("BaseStream __iter__ connected")
        try:
            for line in self._iter_object():

                if not line:
                    continue

                if self.raw_mode:
                    tweet = line
                else:
                    line = line.decode("utf8")
                    try:
                        tweet = anyjson.deserialize(line)
                    except ValueError:
                        # the outer handler closes the stream before re-raising
                        raise ConnectionError("Got invalid data from twitter", details=line)
                if 'text' in tweet:
                    self.count += 1
                    self._rate_cnt += 1
                self._update_rate()
                yield tweet
        except Exception:
            # The connection is in an unknown state: force a reconnect on the
            # next iteration. (GeneratorExit is not an Exception subclass, so
            # a consumer abandoning the generator does not get here.)
            self.close()
            raise
        # The server ended the stream (or muststop was set).
        self.close()


    def next(self):
        """Return the next available tweet. This call is blocking!"""
        return next(self._iter)

    # Python 3 iterator protocol name.
    __next__ = next


    def close(self):
        """
        Close the connection to the streaming server.

        NOTE: this only marks the stream as disconnected (so the next
        iteration reconnects); it does not close ``self._resp``.
        """
        self.connected = False


class FilterStream(BaseStream):
    """Stream of tweets matching user ids, bounding boxes and/or keywords.

    :param follow: iterable of user ids, sent as the ``follow`` parameter.
    :param locations: iterable of coordinates (numbers or strings), sent as
      the ``locations`` parameter.
    :param track: iterable of keywords, sent as the ``track`` parameter.

    The remaining arguments are forwarded to ``BaseStream``.
    """
    url = "https://stream.twitter.com/1.1/statuses/filter.json"

    def __init__(self, auth, follow=None, locations=None,
                 track=None, catchup=None, url=None, raw=False, timeout=None, compressed=False, chunk_size=4096, logger=None):
        # Filter criteria, read by _get_post_data when the connection opens.
        self._follow = follow
        self._locations = locations
        self._track = track
        BaseStream.__init__(self, auth, url=url, raw=raw, catchup=catchup, timeout=timeout, compressed=compressed, chunk_size=chunk_size, logger=logger)

    @staticmethod
    def _join(values):
        """Return ``values`` as a comma-separated string."""
        # "%s" formatting (rather than str()) accepts numbers such as
        # coordinates and keeps unicode items intact under Python 2.
        return ",".join(["%s" % (value,) for value in values])

    def _get_post_data(self):
        """Build the filter parameters; empty or None criteria are omitted."""
        postdata = {}
        if self._follow: postdata["follow"] = self._join(self._follow)
        if self._locations: postdata["locations"] = self._join(self._locations)
        if self._track: postdata["track"] = self._join(self._track)
        return postdata


class SafeStreamWrapper(object):
    """Iterate over a stream, transparently reconnecting on errors.

    Wraps a ``BaseStream``-like object (iterable, with a ``testmuststop()``
    method) and keeps re-iterating it until ``testmuststop()`` returns true.
    Between attempts it sleeps with a back-off capped at ``max_wait`` ms:

    * transport errors (connection errors, timeouts, other request errors)
      grow the delay linearly by ``initial_tcp_wait`` ms;
    * HTTP errors start at ``initial_http_wait`` ms and double each time.

    Receiving any item from the stream resets the delay and the reconnect
    counter. Items containing ``"warning"`` and blank raw items are logged and
    not yielded.

    :param base_stream: the stream to wrap.
    :param logger: optional logger.
    :param error_cb: optional callable invoked with each recoverable error
      before sleeping.
    :param max_reconnects: maximum number of consecutive transport-error
      reconnects before giving up with ``ConnectionError``; negative means
      unlimited. HTTP errors do not count towards it.
    :param initial_tcp_wait: linear back-off step for transport errors, in ms.
    :param initial_http_wait: first back-off delay for HTTP errors, in ms.
    :param max_wait: upper bound for the back-off delay, in ms.
    """

    def __init__(self, base_stream, logger=None, error_cb=None, max_reconnects=-1, initial_tcp_wait=250, initial_http_wait=5000, max_wait=240000):
        self._stream = base_stream
        self._logger = logger
        self._error_cb = error_cb
        self._max_reconnects = max_reconnects
        self._initial_tcp_wait = initial_tcp_wait
        self._initial_http_wait = initial_http_wait
        self._max_wait = max_wait
        self._retry_wait = 0  # current back-off delay, in ms
        self._retry_nb = 0    # number of connection attempts made by __iter__
        # Consecutive transport-error reconnects; must exist before the first
        # error, which can happen before any item has been received.
        self._reconnects = 0

    def __post_process_error(self, e):
        """Pass ``e`` to ``error_cb`` (if callable), then sleep for the current delay."""
        # Note: error_cb is not called on the last error since we
        # raise a ConnectionError instead
        if callable(self._error_cb):
            self._error_cb(e)
        if self._logger: self._logger.info("stream sleeping for %d ms " % self._retry_wait)
        time.sleep(float(self._retry_wait)/1000.0)

    def __process_tcp_error(self, e):
        """Handle a transport-level error: count it, grow the delay linearly, wait.

        :raises ConnectionError: when more than ``max_reconnects`` consecutive
          transport errors occurred (only if ``max_reconnects`` >= 0).
        """
        if self._logger: self._logger.debug("connection error :" + str(e))
        self._reconnects += 1
        if self._max_reconnects >= 0 and self._reconnects > self._max_reconnects:
            raise ConnectionError("Too many retries")
        if self._retry_wait < self._max_wait:
            self._retry_wait = min(self._retry_wait + self._initial_tcp_wait, self._max_wait)

        self.__post_process_error(e)

    def __process_http_error(self, e):
        """Handle an HTTP error: double the delay (starting at initial_http_wait), wait."""
        if self._logger: self._logger.debug("http error on %s : %s" % (e.response.url, str(e)))
        if self._retry_wait < self._max_wait:
            next_wait = 2*self._retry_wait if self._retry_wait > 0 else self._initial_http_wait
            self._retry_wait = min(next_wait, self._max_wait)

        self.__post_process_error(e)

    def __iter__(self):
        while not self._stream.testmuststop():
            self._retry_nb += 1
            try:
                if self._logger: self._logger.debug("inner loop")
                for tweet in self._stream:
                    if self._logger: self._logger.debug("tweet : " + repr(tweet))
                    self._reconnects = 0
                    self._retry_wait = 0
                    if "warning" in tweet:
                        if self._logger: self._logger.warning("Tweet warning received : %s" % repr(tweet))
                        continue
                    # Only raw-mode (string) items can be blank keep-alives;
                    # parsed tweets are dicts and have no strip().
                    if hasattr(tweet, "strip") and not tweet.strip():
                        if self._logger: self._logger.debug("Empty Tweet received : PING")
                        continue
                    yield tweet
            # HTTPError subclasses RequestException, so it must be handled
            # before the generic clause or it would never reach this branch.
            except requests.exceptions.HTTPError as e:
                if e.response.status_code == 401 and self._retry_nb <= 1:
                    raise AuthenticationError("Error connecting to %s : %s" % (e.response.url, str(e)))
                if e.response.status_code > 200:
                    self.__process_http_error(e)
                else:
                    self.__process_tcp_error(e)
            except (ConnectionError, requests.exceptions.ConnectionError, requests.exceptions.Timeout, requests.exceptions.RequestException) as e:
                self.__process_tcp_error(e)

        
    
    