from models import (Tweet, User, Hashtag, EntityHashtag, APPLICATION_NAME, ACCESS_TOKEN_SECRET, adapt_date, adapt_json, 
    ACCESS_TOKEN_KEY)
from sqlalchemy.sql import select, or_
import Queue
import codecs
import datetime
import email.utils
import logging
import math
import os.path
import socket
import sys
import twitter.oauth
import twitter.oauth_dance


# Token pairs already resolved in this process, keyed by application name.
CACHE_ACCESS_TOKEN = {}


def get_oauth_token(consumer_key, consumer_secret, token_file_path=None, check_access_token=True, application_name=APPLICATION_NAME):
    """Return the OAuth access token pair to use for ``application_name``.

    Resolution order:

    1. ``ACCESS_TOKEN_KEY`` / ``ACCESS_TOKEN_SECRET`` when both are set and
       truthy; they are returned immediately, without any check or caching.
    2. The entry cached in ``CACHE_ACCESS_TOKEN`` for ``application_name``.
    3. The token stored in ``token_file_path``, when that file exists.
    4. ``twitter.oauth_dance`` when none of the above produced a usable token.

    When ``check_access_token`` is true, a token found in step 2 or 3 is
    checked with a rate-limit-status call; it is discarded if the call fails
    or reports no remaining ``/account/verify_credentials`` calls.

    The resulting token is stored in ``CACHE_ACCESS_TOKEN`` before returning.
    """
    module_names = globals()
    has_static_token = (
        'ACCESS_TOKEN_KEY' in module_names
        and 'ACCESS_TOKEN_SECRET' in module_names
        and ACCESS_TOKEN_KEY
        and ACCESS_TOKEN_SECRET
    )
    if has_static_token:
        return ACCESS_TOKEN_KEY, ACCESS_TOKEN_SECRET

    log = get_logger() #@UndefinedVariable

    token = CACHE_ACCESS_TOKEN.get(application_name, None)

    # Fall back to the token file only when nothing is cached yet.
    if token is None and token_file_path and os.path.exists(token_file_path):
        log.debug("get_oauth_token : reading token from file %s" % token_file_path)
        token = twitter.oauth.read_token_file(token_file_path)

    if token is not None and check_access_token:
        log.debug("get_oauth_token : Check oauth tokens")
        client = twitter.Twitter(auth=twitter.OAuth(token[0], token[1], consumer_key, consumer_secret))
        try:
            status = client.application.rate_limit_status(resources="account")
        except Exception as e:
            log.debug("get_oauth_token : error getting rate limit status %s " % repr(e))
            log.debug("get_oauth_token : error getting rate limit status %s " % str(e))
            status = None
        log.debug("get_oauth_token : Check oauth tokens : status %s" % repr(status))

        # A missing status counts the same as "no calls remaining".
        if status is None:
            remaining = 0
        else:
            account_limits = status.get("resources", {}).get("account", {})
            remaining = account_limits.get('/account/verify_credentials', {}).get('remaining', 0)
        if remaining == 0:
            log.debug("get_oauth_token : Problem with status %s" % repr(status))
            token = None

    if token is None:
        log.debug("get_oauth_token : doing the oauth dance")
        token = twitter.oauth_dance(application_name, consumer_key, consumer_secret, token_file_path)

    CACHE_ACCESS_TOKEN[application_name] = token

    log.debug("get_oauth_token : done got %s" % repr(token))
    return token

def parse_date(date_str):
    """Parse an RFC 2822 style date string into a naive ``datetime``.

    The string is parsed with ``email.utils.parsedate_tz`` and the
    year/month/day/hour/minute/second fields are used as written.

    NOTE: the timezone offset found in the string is not applied, so the
    returned value carries the wall-clock fields of ``date_str`` unchanged.

    Raises ``TypeError`` when ``date_str`` cannot be parsed, because
    ``parsedate_tz`` then returns ``None``.
    """
    ts = email.utils.parsedate_tz(date_str) #@UndefinedVariable
    # Only the first six items are meaningful. Items 6-8 of the tuple are
    # documented as unusable, so item 6 must not be passed as microseconds.
    return datetime.datetime(*ts[:6])

def clean_keys(dict_val):
    """Return a new dict with every key of ``dict_val`` passed through ``str``.

    Values are carried over unchanged; ``dict_val`` itself is not modified.
    """
    return {str(name): item for name, item in dict_val.items()}

# Field converters, organised as
#     fields_adapter[source][object kind][field name] -> converter
# where source is 'stream', 'entities' or 'rest'. Fields that are not listed
# have no converter registered here.
fields_adapter = {
    'stream': {
        "tweet": {
            "created_at": adapt_date,
            "coordinates": adapt_json,
            "place": adapt_json,
            "geo": adapt_json,
#            "original_json" : adapt_json,
        },
        "user": {
            "created_at": adapt_date,
        },
    },
    'entities': {
        "medias": {
            "sizes": adapt_json,
        },
    },
    'rest': {
        "tweet": {
            "place": adapt_json,
            "geo": adapt_json,
            "created_at": adapt_date,
#            "original_json" : adapt_json,
        },
    },
}

#
# adapt fields, return a copy of the field_dict with adapted fields
#
def adapt_fields(fields_dict, adapter_mapping):
    """Return a copy of ``fields_dict`` with adapted values and ``str`` keys.

    For each field, the converter registered under the same key in
    ``adapter_mapping`` is applied to the value. Fields with no converter,
    or whose converter is ``None``, keep their value unchanged.
    """
    adapted = {}
    for field, value in fields_dict.items():
        converter = adapter_mapping.get(field)
        if converter is not None:
            value = converter(value)
        adapted[str(field)] = value
    return adapted


class ObjectBufferProxy(object):
    """Deferred construction of an object that is later added to a session.

    The proxy records a class plus the positional and keyword arguments to
    build it with. Any argument may be a callable; it is then called when
    the object is actually built, so a value can depend on something that
    only exists later (for instance an attribute getter returned by another
    proxy).
    """

    def __init__(self, klass, args, kwargs, must_flush, instance=None):
        """Record what to build.

        klass      -- class (or factory) called to build the object.
        args       -- positional arguments, or None.
        kwargs     -- keyword arguments, or None.
        must_flush -- when true, ``persists`` flushes the session after adding.
        instance   -- optional existing object; when given, the freshly built
                      object is merged into the session in ``persists``.
        """
        self.klass = klass
        self.args = args
        self.kwargs = kwargs
        self.must_flush = must_flush
        self.instance = instance

    def persists(self, session):
        """Build the object, add it to ``session`` and flush if required."""
        # Resolve deferred (callable) arguments now.
        if self.args is None:
            new_args = []
        else:
            new_args = [arg() if callable(arg) else arg for arg in self.args]
        if self.kwargs is None:
            new_kwargs = {}
        else:
            new_kwargs = dict((k, v() if callable(v) else v) for k, v in self.kwargs.items())

        # A new object is always built from the recorded arguments. If an
        # instance was already attached, the new one is merged into the session.
        had_instance = self.instance is not None
        self.instance = self.klass(*new_args, **new_kwargs)
        if had_instance:
            self.instance = session.merge(self.instance)

        session.add(self.instance)
        if self.must_flush:
            session.flush()

    def __getattr__(self, name):
        """Return a zero-argument getter for ``name`` on the built instance.

        The getter returns None until an instance is attached. Special
        (dunder) names raise AttributeError as usual, so protocol probes such
        as ``__setstate__`` or ``__deepcopy__`` made by ``copy`` and ``pickle``
        are not answered with a getter.
        """
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Compare with None explicitly so a falsy instance is still read.
        return lambda: getattr(self.instance, name) if self.instance is not None else None
        
        
    

class ObjectsBuffer(object):
    """Ordered collection of object proxies, also indexed by their class."""

    def __init__(self):
        # Proxies in insertion order, and the same proxies grouped by klass.
        self.__bufferlist = []
        self.__bufferdict = {}

    def __add_proxy_object(self, proxy):
        """Register ``proxy`` in both the ordered list and the per-class index."""
        self.__bufferdict.setdefault(proxy.klass, []).append(proxy)
        self.__bufferlist.append(proxy)

    def persists(self, session):
        """Call ``persists(session)`` on every proxy, in insertion order."""
        for buffered in self.__bufferlist:
            buffered.persists(session)

    def add_object(self, klass, args, kwargs, must_flush, instance=None):
        """Create a proxy for the given construction data, buffer and return it."""
        created = ObjectBufferProxy(klass, args, kwargs, must_flush, instance)
        self.__add_proxy_object(created)
        return created

    def get(self, klass, **kwargs):
        """Return the first buffered proxy of ``klass`` matching ``kwargs``.

        A proxy matches when each given keyword is present in its own kwargs
        with an equal value. Proxies without kwargs never match. Returns None
        when nothing matches.
        """
        for candidate in self.__bufferdict.get(klass, ()):
            recorded = candidate.kwargs
            if not recorded or candidate.klass != klass:
                continue
            mismatch = any(
                (key not in recorded) or wanted != recorded[key]
                for key, wanted in kwargs.items()
            )
            if not mismatch:
                return candidate
        return None
                


def set_logging(options, plogger=None, queue=None):
    """Configure a logger from parsed command-line options and return it.

    options -- object exposing ``verbose`` and ``quiet`` (counts) and
               ``logfile`` ("stdout", "stderr" or a file path). Its ``debug``
               attribute is set to True when verbose outweighs quiet.
    plogger -- logger to configure; defaults to ``get_logger()``.
    queue   -- when given, records are sent to it through a ``QueueHandler``
               instead of a stream or file handler.

    A handler, formatter and level are only installed when the logger has no
    handler yet; an already configured logger is returned untouched.
    """
    # Start from WARNING; each -v lowers the level by 10 and each -q raises it.
    # The floor is DEBUG, not NOTSET: a non-root logger set to NOTSET falls
    # back to its parent's level, which would make "-vvv" quieter than "-vv".
    level = max(logging.DEBUG, min(logging.CRITICAL, logging.WARNING - 10 * options.verbose + 10 * options.quiet)) #@UndefinedVariable

    logger = plogger
    if logger is None:
        logger = get_logger() #@UndefinedVariable

    if not logger.handlers:
        if queue is not None:
            hdlr = QueueHandler(queue, True)
        elif options.logfile == "stdout":
            hdlr = logging.StreamHandler(sys.stdout) #@UndefinedVariable
        elif options.logfile == "stderr":
            hdlr = logging.StreamHandler(sys.stderr) #@UndefinedVariable
        elif options.logfile:
            hdlr = logging.FileHandler(options.logfile, 'a') #@UndefinedVariable
        else:
            # No usable target: let StreamHandler pick its default stream.
            hdlr = logging.StreamHandler(None) #@UndefinedVariable

        hdlr.setFormatter(logging.Formatter('%(asctime)s %(levelname)s:%(name)s:%(message)s', None)) #@UndefinedVariable
        logger.addHandler(hdlr)
        logger.setLevel(level)

    options.debug = (options.verbose - options.quiet > 0)
    return logger

def set_logging_options(parser):
    """Register the logging options on an argparse-style ``parser``.

    Adds ``-l/--log`` (stored as ``logfile``, default "stderr") and the
    repeatable counters ``-v`` (``verbose``) and ``-q`` (``quiet``), both
    defaulting to 0.
    """
    parser.add_argument(
        "-l", "--log",
        dest="logfile",
        metavar="LOG",
        default="stderr",
        help="log to file",
    )

    # The two counters only differ by flag, destination and help text.
    counters = (
        ("-v", "verbose", "verbose"),
        ("-q", "quiet", "quiet"),
    )
    for flag, destination, description in counters:
        parser.add_argument(flag, dest=destination, action="count", help=description, default=0)

def get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
    """Add the common tweet filters to ``query`` and return the new query.

    The query is joined to ``EntityHashtag`` and ``Hashtag``; every other
    restriction is applied only when its argument is set:

    tweet_exclude_table -- tweets whose id appears in this table are excluded.
    start_date/end_date -- inclusive bounds on ``Tweet.created_at``.
    user_whitelist      -- joins ``User`` and keeps the listed screen names.
    hashtags            -- iterable of comma-separated hashtag strings; a
                           tweet is kept when a hashtag text contains any
                           of the individual values.

    ``session`` is accepted but not used here.
    """
    query = query.join(EntityHashtag).join(Hashtag)

    if tweet_exclude_table is not None:
        excluded_ids = select([tweet_exclude_table.c.id])
        query = query.filter(~Tweet.id.in_(excluded_ids)) #@UndefinedVariable

    if start_date:
        query = query.filter(Tweet.created_at >= start_date)
    if end_date:
        query = query.filter(Tweet.created_at <= end_date)

    if user_whitelist:
        query = query.join(User).filter(User.screen_name.in_(user_whitelist))

    if hashtags:
        # Flatten ["a,b", "c"] into ["a", "b", "c"].
        htags = [tag for group in hashtags for tag in group.split(",")]
        tag_clauses = [Hashtag.text.contains(tag) for tag in htags]
        query = query.filter(or_(*tag_clauses)) #@UndefinedVariable

    return query

    
    
def get_filter_query(session, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
    """Return a ``Tweet`` query with the common filters, ordered by creation date.

    All filtering arguments are forwarded unchanged to ``get_base_query``.
    """
    filtered = get_base_query(
        session,
        session.query(Tweet),
        start_date,
        end_date,
        hashtags,
        tweet_exclude_table,
        user_whitelist,
    )
    return filtered.order_by(Tweet.created_at)
    

def get_user_query(session, start_date, end_date, hashtags, tweet_exclude_table):
    """Return a distinct ``User`` query restricted by the common tweet filters.

    Users are joined to their tweets and the filters of ``get_base_query``
    are applied; no user whitelist is passed.
    """
    users_with_tweets = session.query(User).join(Tweet)
    filtered = get_base_query(
        session,
        users_with_tweets,
        start_date,
        end_date,
        hashtags,
        tweet_exclude_table,
        None,
    )
    return filtered.distinct()

# Name under which this module's logger is registered.
logger_name = "iri.tweet"


def get_logger():
    """Return the logger registered under the current value of ``logger_name``."""
    # logger_name is only read here, so no ``global`` declaration is needed.
    return logging.getLogger(logger_name) #@UndefinedVariable


# NOTE: no demo-only imports follow; QueueHandler below relies on the module-level `logging` and `Queue` imports.

class QueueHandler(logging.Handler): #@UndefinedVariable
    """
    This is a logging handler which sends events to a queue
    (for instance a multiprocessing queue).

    Records are put on the queue without blocking. What happens when the
    queue is full is controlled by ``ignore_full``.
    """

    def __init__(self, queue, ignore_full):
        """
        Initialise an instance, using the passed queue.

        queue       -- object providing ``full()`` and ``put_nowait(record)``.
        ignore_full -- when true, records that do not fit in a full queue are
                       silently dropped; when false, ``Queue.Full`` propagates
                       to the caller of the logging method.
        """
        logging.Handler.__init__(self) #@UndefinedVariable
        self.queue = queue
        # Honour the caller's choice instead of a hard-coded True.
        self.ignore_full = ignore_full

    def emit(self, record):
        """
        Emit a record.

        Writes the LogRecord to the queue.
        """
        try:
            if record.exc_info:
                # Formatting stores the traceback text in record.exc_text,
                # after which the exc_info tuple is no longer needed.
                self.format(record)
                record.exc_info = None
            if not self.ignore_full or (not self.queue.full()):
                self.queue.put_nowait(record)
        except AssertionError:
            # NOTE(review): presumably raised by a queue that is already
            # closed; kept as a silent drop, as before -- confirm.
            pass
        except Queue.Full:
            # The queue filled up between the full() check and the put.
            if not self.ignore_full:
                raise
        except Exception:
            # KeyboardInterrupt and SystemExit are not Exception subclasses,
            # so they still propagate.
            self.handleError(record)

def show_progress(current_line, total_line, label, width, writer=None):
    """Write a one-line text progress bar and return the writer used.

    current_line -- number of items processed so far.
    total_line   -- total number of items; 0 is reported as 100% complete.
    label        -- text shown after the counters, truncated to ``width``.
    width        -- width of the bar (between the brackets) and of the label.
    writer       -- object with ``write`` and ``flush``. When None,
                    ``sys.stdout`` is used, wrapped in a codec writer when
                    its encoding is known.

    The line ends with a carriage return so the next call overwrites it; a
    newline is added once 100% is reached. The writer is returned so callers
    can pass it back in on the next call.
    """
    if writer is None:
        writer = sys.stdout
        if sys.stdout.encoding is not None:
            writer = codecs.getwriter(sys.stdout.encoding)(sys.stdout)

    total = float(total_line)
    if total:
        percent = (float(current_line) / total) * 100.0
    else:
        # Nothing to process: report completion instead of dividing by zero.
        percent = 100.0

    marks = math.floor(width * (percent / 100.0))
    spaces = math.floor(width - marks)

    loader = u'[' + (u'=' * int(marks)) + (u' ' * int(spaces)) + u']'

    s = u"%s %3d%% %*d/%d - %*s\r" % (loader, percent, len(str(total_line)), current_line, total_line, width, label[:width])

    writer.write(s)
    if percent >= 100:
        writer.write("\n")
    writer.flush()

    return writer

def get_unused_port():
    """Return a TCP port number that was free on localhost when probed.

    A socket is bound to port 0 on 'localhost' so that the system picks a
    free port; that port number is read back and the socket is closed.

    NOTE: the socket is closed before returning, so nothing reserves the
    port afterwards and it may be taken again before the caller binds it.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(('localhost', 0))
        return s.getsockname()[1]
    finally:
        # Always release the socket, even when bind() or getsockname() fails.
        s.close()

