script/lib/iri_tweet/utils.py
author Yves-Marie Haussonne <1218002+ymph@users.noreply.github.com>
Wed, 07 Dec 2011 19:28:46 +0100
changeset 409 f7ceddf99d6d
parent 289 a5eff8f2b81d
child 411 0471e6eb8a1b
permissions -rw-r--r--
Improve user management: try to complete user information whenever possible.

from models import (Tweet, User, Hashtag, EntityHashtag, EntityUser, Url, 
    EntityUrl, CONSUMER_KEY, CONSUMER_SECRET, APPLICATION_NAME, ACCESS_TOKEN_KEY, 
    ACCESS_TOKEN_SECRET, adapt_date, adapt_json, TweetSource, TweetLog, MediaType, 
    Media, EntityMedia, Entity, EntityType)
from sqlalchemy.sql import select, or_ #@UnresolvedImport
import Queue #@UnresolvedImport
import anyjson #@UnresolvedImport
import datetime
import email.utils
import logging
import os.path
import sys
import twitter.oauth #@UnresolvedImport
import twitter.oauth_dance #@UnresolvedImport
import twitter_text #@UnresolvedImport


# Module-level cache of oauth token pairs, keyed by application name.
CACHE_ACCESS_TOKEN = {}

def get_oauth_token(token_file_path=None, check_access_token=True, application_name=APPLICATION_NAME, consumer_key=CONSUMER_KEY, consumer_secret=CONSUMER_SECRET):
    """Return an oauth access token pair ``(key, secret)`` for *application_name*.

    Resolution order:

    1. If ``ACCESS_TOKEN_KEY`` and ``ACCESS_TOKEN_SECRET`` (imported from
       ``models``) are both set, return them directly: no cache, no check.
    2. Otherwise take the token cached for *application_name*, falling back
       to reading *token_file_path* when that file exists.
    3. When *check_access_token* is true, validate the token by requesting
       the account rate-limit status; the token is discarded if that call
       fails or reports no remaining hits.
    4. If no usable token is left, run ``twitter.oauth_dance.oauth_dance``
       with the consumer credentials and *token_file_path*.

    The resulting pair is cached under *application_name* and returned.
    """
    if ACCESS_TOKEN_KEY and ACCESS_TOKEN_SECRET:
        return ACCESS_TOKEN_KEY, ACCESS_TOKEN_SECRET

    res = CACHE_ACCESS_TOKEN.get(application_name, None)

    if res is None and token_file_path and os.path.exists(token_file_path):
        get_logger().debug("get_oauth_token : reading token from file %s" % token_file_path) #@UndefinedVariable
        res = twitter.oauth.read_token_file(token_file_path)

    if res is not None and check_access_token:
        get_logger().debug("get_oauth_token : Check oauth tokens") #@UndefinedVariable
        # Validate with the same consumer credentials used to obtain the token
        # (the caller's, which are not necessarily the module defaults).
        t = twitter.Twitter(auth=twitter.OAuth(res[0], res[1], consumer_key, consumer_secret))
        status = None
        try:
            status = t.account.rate_limit_status()
        except Exception as e:
            get_logger().debug("get_oauth_token : error getting rate limit status %s" % repr(e))
            status = None
        get_logger().debug("get_oauth_token : Check oauth tokens : status %s" % repr(status)) #@UndefinedVariable
        if status is None or status['remaining_hits'] == 0:
            get_logger().debug("get_oauth_token : Problem with status %s" % repr(status))
            res = None

    if res is None:
        get_logger().debug("get_oauth_token : doing the oauth dance")
        res = twitter.oauth_dance.oauth_dance(application_name, consumer_key, consumer_secret, token_file_path)

    CACHE_ACCESS_TOKEN[application_name] = res

    return res

def parse_date(date_str):
    """Parse an RFC 2822 style date string into a naive ``datetime``.

    The wall-clock fields are used as written. The UTC offset returned by
    ``email.utils.parsedate_tz`` is ignored, so the result only denotes the
    correct UTC instant when the input is in UTC (``+0000``).

    :raises ValueError: if *date_str* cannot be parsed.
    """
    ts = email.utils.parsedate_tz(date_str) #@UndefinedVariable
    if ts is None:
        raise ValueError("unparseable date: %r" % (date_str,))
    # Only the first six fields are calendar/clock values; index 6 is the
    # weekday slot of the time-tuple and must not become microseconds.
    return datetime.datetime(*ts[0:6])

def clean_keys(dict_val):
    """Return a new dict with the same values as *dict_val* and ``str()``-converted keys."""
    cleaned = {}
    for raw_key, val in dict_val.items():
        cleaned[str(raw_key)] = val
    return cleaned

# Value adapters applied by adapt_fields(), grouped by payload origin
# ('stream', 'rest', 'entities') and then by object kind. Each entry maps a
# JSON field name to the callable used to convert its value. The raw
# 'original_json' field is intentionally left unadapted.
fields_adapter = dict(
    stream=dict(
        tweet=dict(
            created_at=adapt_date,
            coordinates=adapt_json,
            place=adapt_json,
            geo=adapt_json,
        ),
        user=dict(
            created_at=adapt_date,
        ),
    ),
    entities=dict(
        medias=dict(
            sizes=adapt_json,
        ),
    ),
    rest=dict(
        tweet=dict(
            place=adapt_json,
            geo=adapt_json,
            created_at=adapt_date,
        ),
    ),
)

def adapt_fields(fields_dict, adapter_mapping):
    """Return a copy of *fields_dict* with adapted values.

    Every key is converted with ``str()``. A value is passed through
    ``adapter_mapping[key]`` when that entry exists and is not ``None``;
    otherwise it is copied unchanged. *fields_dict* itself is not modified.
    """
    adapted = {}
    for key, value in fields_dict.items():
        adapter = adapter_mapping[key] if key in adapter_mapping else None
        adapted[str(key)] = value if adapter is None else adapter(value)
    return adapted


class ObjectBufferProxy(object):
    """Deferred construction of an object, buffered until ``persists`` runs.

    ``args``/``kwargs`` hold the constructor arguments for ``klass``. Any
    callable value among them is called at persist time, which lets one
    buffered object refer to an attribute (such as the id) of another proxy
    that has not been persisted yet. ``instance`` is either an already-loaded
    object to merge into, or the object produced by ``persists``.
    """

    def __init__(self, klass, args, kwargs, must_flush, instance=None):
        self.klass = klass
        self.args = args
        self.kwargs = kwargs
        # When true, persists() flushes the session right after adding the
        # object (presumably so generated ids are available to later proxies).
        self.must_flush = must_flush
        self.instance = instance

    def persists(self, session):
        """Build the object, attach it to *session* and flush if required.

        When ``instance`` was already set, the freshly built object is passed
        through ``session.merge`` and the merged result replaces it.
        """
        new_args = [arg() if callable(arg) else arg for arg in self.args] if self.args is not None else []
        new_kwargs = dict((k, v() if callable(v) else v) for k, v in self.kwargs.items()) if self.kwargs is not None else {}

        obj = self.klass(*new_args, **new_kwargs)
        if self.instance is not None:
            obj = session.merge(obj)
        self.instance = obj

        session.add(self.instance)
        if self.must_flush:
            session.flush()

    def __getattr__(self, name):
        """Return a zero-argument getter for attribute *name* of the instance.

        The lookup is deferred until the getter is called, so it can be stored
        (e.g. in another proxy's kwargs) before this proxy is persisted. The
        getter returns ``None`` while there is no instance.
        """
        # Dunder lookups (copy/pickle/introspection protocols) must not be
        # answered with a getter, otherwise those protocols misbehave.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        return lambda: getattr(self.instance, name) if self.instance is not None else None
        
        
    

class ObjectsBuffer(object):
    """Ordered buffer of ObjectBufferProxy instances awaiting persistence.

    Proxies are kept both in insertion order (used by ``persists``) and
    grouped by their ``klass`` (used by ``get``).
    """

    def __init__(self):
        self.__bufferlist = []  # every proxy, in insertion order
        self.__bufferdict = {}  # klass -> list of that class's proxies

    def __add_proxy_object(self, proxy):
        self.__bufferdict.setdefault(proxy.klass, []).append(proxy)
        self.__bufferlist.append(proxy)

    def persists(self, session):
        """Persist every buffered proxy into *session*, in insertion order."""
        for pending in self.__bufferlist:
            pending.persists(session)

    def add_object(self, klass, args, kwargs, must_flush, instance=None):
        """Wrap the arguments in a new proxy, buffer it and return it."""
        created = ObjectBufferProxy(klass, args, kwargs, must_flush, instance)
        self.__add_proxy_object(created)
        return created

    def get(self, klass, **kwargs):
        """Return the first buffered *klass* proxy whose kwargs match *kwargs*.

        Proxies with no constructor kwargs are skipped. Every given keyword
        must be present in the proxy's kwargs with an equal value. Returns
        ``None`` when nothing matches.
        """
        for candidate in self.__bufferdict.get(klass, ()):
            stored = candidate.kwargs
            if not stored or candidate.klass != klass:
                continue
            mismatch = any(key not in stored or value != stored[key] for key, value in kwargs.items())
            if not mismatch:
                return candidate
        return None
                
class TwitterProcessorException(Exception):
    """Raised by TwitterProcessor when its input JSON is missing or has no "id"."""

class TwitterProcessor(object):
    """Turn one tweet JSON document into buffered ORM objects and persist them.

    Documents carrying a ``metadata`` key are handled as REST-API tweets,
    all others as streaming-API tweets. Every object created along the way
    (tweet source, tweet, user, entities, log entry) is queued in an
    ObjectsBuffer and written to the session at the end of ``process``.
    """

    def __init__(self, json_dict, json_txt, source_id, session, access_token=None, token_filename=None):
        """
        :param json_dict: decoded tweet, or None to decode *json_txt*.
        :param json_txt: raw JSON text, or a falsy value to serialize *json_dict*.
        :param source_id: id of an existing TweetSource, or None so that
            ``process`` buffers a new one.
        :param session: SQLAlchemy session used for lookups and persistence.
        :param access_token: optional ``(key, secret)`` pair for Twitter API calls.
        :param token_filename: token file passed to ``get_oauth_token`` when
            *access_token* is None.
        :raises TwitterProcessorException: when no JSON is given or it has no "id".
        """
        if json_dict is None and json_txt is None:
            raise TwitterProcessorException("No json")

        if json_dict is None:
            self.json_dict = anyjson.deserialize(json_txt)
        else:
            self.json_dict = json_dict

        if not json_txt:
            self.json_txt = anyjson.serialize(json_dict)
        else:
            self.json_txt = json_txt

        if "id" not in self.json_dict:
            raise TwitterProcessorException("No id in json")

        self.source_id = source_id
        self.session = session
        self.token_filename = token_filename
        self.access_token = access_token
        self.obj_buffer = ObjectsBuffer()

    def __get_user(self, user_dict, do_merge, query_twitter = False):
        """Return a proxy for the user described by *user_dict*, or None.

        Lookup order: this processor's object buffer, then the database (by
        id, else case-insensitively by screen name), then, only when
        *query_twitter* is true and no ``created_at`` is known, the Twitter
        ``users/show`` API. When *do_merge* is true, data carrying a
        ``created_at`` is merged into a known user that lacks one.
        """
        get_logger().debug("Get user : " + repr(user_dict)) #@UndefinedVariable

        user_dict = adapt_fields(user_dict, fields_adapter["stream"]["user"])

        user_id = user_dict.get("id",None)
        user_name = user_dict.get("screen_name", user_dict.get("name", None))

        if user_id is None and user_name is None:
            return None

        if user_id:
            user = self.obj_buffer.get(User, id=user_id)
        else:
            user = self.obj_buffer.get(User, screen_name=user_name)

        if user is not None:
            # Buffered users hold their fields in the proxy's kwargs (they are
            # added with add_object(User, None, user_dict, ...)); complete
            # them when the new data has a created_at they are missing.
            buffered_fields = user.kwargs
            user_created_at = buffered_fields.get('created_at', None) if buffered_fields is not None else None
            if user_created_at is None and user_dict.get('created_at', None) is not None and do_merge:
                if buffered_fields is None:
                    user.kwargs = user_dict
                else:
                    buffered_fields.update(user_dict)
            return user

        user_obj = None
        if user_id:
            user_obj = self.session.query(User).filter(User.id == user_id).first()
        else:
            user_obj = self.session.query(User).filter(User.screen_name.ilike(user_name)).first()

        if user_obj is not None:
            if user_obj.created_at is not None or user_dict.get('created_at', None) is None or not do_merge :
                # Nothing to add: wrap the row without buffering it.
                user = ObjectBufferProxy(User, None, None, False, user_obj)
            else:
                user = self.obj_buffer.add_object(User, None, user_dict, True, user_obj)
            return user

        user_created_at = user_dict.get("created_at", None)

        if user_created_at is None and query_twitter:

            if self.access_token is not None:
                access_token_key, access_token_secret = self.access_token
            else:
                access_token_key, access_token_secret = get_oauth_token(self.token_filename)
            t = twitter.Twitter(auth=twitter.OAuth(access_token_key, access_token_secret, CONSUMER_KEY, CONSUMER_SECRET))
            try:
                if user_id:
                    user_dict = t.users.show(user_id=user_id)
                else:
                    user_dict = t.users.show(screen_name=user_name)
            except Exception as e:
                get_logger().info("get_user : TWITTER ERROR : " + repr(e)) #@UndefinedVariable
                get_logger().info("get_user : TWITTER ERROR : " + str(e)) #@UndefinedVariable
                return None
            # The API answer is raw JSON: adapt it like the input dict was.
            user_dict = adapt_fields(user_dict, fields_adapter["stream"]["user"])

        if "id" not in user_dict:
            return None

        user_obj = self.session.query(User).filter(User.id == user_dict["id"]).first()

        if user_obj is not None and not do_merge:
            return ObjectBufferProxy(User, None, None, False, user_obj)
        # Pass the existing row (None when absent) so persists() merges into
        # it instead of inserting a second row with the same id.
        return self.obj_buffer.add_object(User, None, user_dict, True, user_obj)

    def __get_or_create_object(self, klass, filter_by_kwargs, filter_clause, creation_kwargs, must_flush, do_merge):
        """Return a proxy for the *klass* object matching the lookup, creating one if needed.

        The buffer is searched with *filter_by_kwargs* first, then the
        database with *filter_clause* (or ``filter_by(**filter_by_kwargs)``
        when it is None). A database hit is wrapped without buffering unless
        *do_merge* is true, in which case *creation_kwargs* are buffered for
        merging into it. Otherwise a new object built from
        *creation_kwargs* is buffered.
        """
        obj_proxy = self.obj_buffer.get(klass, **filter_by_kwargs)
        if obj_proxy is None:
            query = self.session.query(klass)
            if filter_clause is not None:
                query = query.filter(filter_clause)
            else:
                query = query.filter_by(**filter_by_kwargs)
            obj_instance = query.first()
            if obj_instance is not None:
                if not do_merge:
                    obj_proxy = ObjectBufferProxy(klass, None, None, False, obj_instance)
                else:
                    obj_proxy = self.obj_buffer.add_object(klass, None, creation_kwargs, must_flush, obj_instance)
        if obj_proxy is None:
            obj_proxy = self.obj_buffer.add_object(klass, None, creation_kwargs, must_flush)
        return obj_proxy

    def __process_entity(self, ind, ind_type):
        """Buffer one entity of the current tweet.

        *ind* is an entity dict (it must carry ``indices``); *ind_type* is
        its category ("hashtags", "user_mentions", "urls", "media"; anything
        else is stored as a generic Entity).
        """
        get_logger().debug("Process_entity : " + repr(ind) + " : " + repr(ind_type)) #@UndefinedVariable

        ind = clean_keys(ind)

        entity_type = self.__get_or_create_object(EntityType, {'label':ind_type}, None, {'label':ind_type}, True, False)

        # Values that are proxy getters (tweet id, type id, ...) are resolved
        # when the buffer is persisted.
        entity_dict = {
           "indice_start"   : ind["indices"][0],
           "indice_end"     : ind["indices"][1],
           "tweet_id"       : self.tweet.id,
           "entity_type_id" : entity_type.id,
           "source"         : adapt_json(ind)
        }

        def process_medias():

            media_id = ind.get('id', None)
            if media_id is None:
                return None, None

            type_str = ind.get("type", "photo")
            media_type = self.__get_or_create_object(MediaType, {'label': type_str}, None, {'label':type_str}, True, False)
            media_ind = adapt_fields(ind, fields_adapter["entities"]["medias"])
            if "type" in media_ind:
                del media_ind["type"]
            media_ind['type_id'] = media_type.id
            media = self.__get_or_create_object(Media, {'id':media_id}, None, media_ind, True, False)

            entity_dict['media_id'] = media.id
            return EntityMedia, entity_dict

        def process_hashtags():
            text = ind.get("text", ind.get("hashtag", None))
            if text is None:
                return None, None
            ind['text'] = text
            hashtag = self.__get_or_create_object(Hashtag, {'text':text}, Hashtag.text.ilike(text), ind, True, False)
            entity_dict['hashtag_id'] = hashtag.id
            return EntityHashtag, entity_dict

        def process_user_mentions():
            user_mention = self.__get_user(ind, False, False)
            if user_mention is None:
                entity_dict['user_id'] = None
            else:
                entity_dict['user_id'] = user_mention.id
            return EntityUser, entity_dict

        def process_urls():
            url = self.__get_or_create_object(Url, {'url':ind["url"]}, None, ind, True, False)
            entity_dict['url_id'] = url.id
            return EntityUrl, entity_dict

        entity_klass, entity_dict =  {
            'hashtags': process_hashtags,
            'user_mentions' : process_user_mentions,
            'urls' : process_urls,
            'media': process_medias,
            }.get(ind_type, lambda: (Entity, entity_dict))()

        get_logger().debug("Process_entity entity_dict: " + repr(entity_dict)) #@UndefinedVariable
        if entity_klass:
            self.obj_buffer.add_object(entity_klass, None, entity_dict, False)

    def __process_twitter_stream(self):
        """Buffer a streaming-API tweet, its author and its entities (skipped if already stored)."""
        tweet_nb = self.session.query(Tweet).filter(Tweet.id == self.json_dict["id"]).count()
        if tweet_nb > 0:
            return

        ts_copy = adapt_fields(self.json_dict, fields_adapter["stream"]["tweet"])

        # get or create user
        user = self.__get_user(self.json_dict["user"], True)
        if user is None:
            get_logger().warning("USER not found " + repr(self.json_dict["user"])) #@UndefinedVariable
            ts_copy["user_id"] = None
        else:
            ts_copy["user_id"] = user.id

        del ts_copy['user']
        ts_copy["tweet_source_id"] = self.source_id

        self.tweet = self.obj_buffer.add_object(Tweet, None, ts_copy, True)

        self.__process_entities()

    def __process_entities(self):
        """Buffer the tweet's entities.

        Uses the JSON ``entities`` section when present; otherwise hashtags,
        urls and mentions are extracted from the tweet text with twitter_text.
        """
        if "entities" in self.json_dict:
            for ind_type, entity_list in self.json_dict["entities"].items():
                for ind in entity_list:
                    self.__process_entity(ind, ind_type)
        else:
            # self.tweet is a buffer proxy: attribute access on it returns a
            # deferred getter (and the tweet is not persisted yet), so take
            # the text from the source JSON instead.
            text = self.json_dict.get("text") or ""
            extractor = twitter_text.Extractor(text)
            for ind in extractor.extract_hashtags_with_indices():
                self.__process_entity(ind, "hashtags")

            for ind in extractor.extract_urls_with_indices():
                self.__process_entity(ind, "urls")

            for ind in extractor.extract_mentioned_screen_names_with_indices():
                self.__process_entity(ind, "user_mentions")

    def __process_twitter_rest(self):
        """Buffer a REST-API tweet, its author (if found) and its entities (skipped if already stored)."""
        tweet_nb = self.session.query(Tweet).filter(Tweet.id == self.json_dict["id"]).count()
        if tweet_nb > 0:
            return

        tweet_fields = {
            'created_at': self.json_dict["created_at"],
            'favorited': False,
            'id': self.json_dict["id"],
            'id_str': self.json_dict["id_str"],
            #'in_reply_to_screen_name': ts["to_user"],
            'in_reply_to_user_id': self.json_dict["to_user_id"],
            'in_reply_to_user_id_str': self.json_dict["to_user_id_str"],
            #'place': ts["place"],
            'source': self.json_dict["source"],
            'text': self.json_dict["text"],
            'truncated': False,
            'tweet_source_id' : self.source_id,
        }

        user_fields = {
            'lang' : self.json_dict.get('iso_language_code',None),
            'profile_image_url' : self.json_dict["profile_image_url"],
            'screen_name' : self.json_dict["from_user"],
        }

        user = self.__get_user(user_fields, do_merge=False)
        if user is None:
            get_logger().warning("USER not found " + repr(user_fields)) #@UndefinedVariable
            tweet_fields["user_id"] = None
        else:
            tweet_fields["user_id"] = user.id

        tweet_fields = adapt_fields(tweet_fields, fields_adapter["rest"]["tweet"])
        self.tweet = self.obj_buffer.add_object(Tweet, None, tweet_fields, True)

        self.__process_entities()

    def process(self):
        """Buffer everything derived from the JSON, plus an OK TweetLog, then persist it."""
        if self.source_id is None:
            tweet_source = self.obj_buffer.add_object(TweetSource, None, {'original_json':self.json_txt}, True)
            self.source_id = tweet_source.id

        if "metadata" in self.json_dict:
            self.__process_twitter_rest()
        else:
            self.__process_twitter_stream()

        self.obj_buffer.add_object(TweetLog, None, {'tweet_source_id':self.source_id, 'status':TweetLog.TWEET_STATUS['OK']}, True)

        self.obj_buffer.persists(self.session)


def set_logging(options, plogger=None, queue=None):
    """Configure and return *plogger* (default: ``get_logger()``).

    The level starts at WARNING, drops by 10 per ``options.verbose`` and
    rises by 10 per ``options.quiet``, clamped to [NOTSET, CRITICAL].
    Records go to *queue* through a QueueHandler when one is given, else to
    ``options.logfile``: "stdout", "stderr", or a file path opened in
    append mode. A logger that already has handlers is left untouched
    (no new handler, level unchanged). ``options.debug`` is always set to
    whether verbose exceeds quiet.
    """
    level = max(logging.NOTSET, min(logging.CRITICAL, logging.WARNING - 10 * options.verbose + 10 * options.quiet)) #@UndefinedVariable
    log_format = '%(asctime)s %(levelname)s:%(name)s:%(message)s'

    stream, filename = None, None
    if options.logfile == "stdout":
        stream = sys.stdout
    elif options.logfile == "stderr":
        stream = sys.stderr
    else:
        filename = options.logfile

    logger = get_logger() if plogger is None else plogger #@UndefinedVariable

    if not logger.handlers:
        if queue is not None:
            handler = QueueHandler(queue, True)
        elif filename:
            handler = logging.FileHandler(filename, 'a') #@UndefinedVariable
        else:
            handler = logging.StreamHandler(stream) #@UndefinedVariable
        handler.setFormatter(logging.Formatter(log_format, None)) #@UndefinedVariable
        logger.addHandler(handler)
        logger.setLevel(level)

    options.debug = (options.verbose - options.quiet > 0)
    return logger

def set_logging_options(parser):
    """Register the ``-l/--log``, ``-v`` and ``-q`` logging options on *parser*.

    *parser* must provide an optparse-style ``add_option`` method.
    """
    option_specs = (
        (("-l", "--log"), dict(dest="logfile", help="log to file", metavar="LOG", default="stderr")),
        (("-v",), dict(dest="verbose", action="count", help="verbose", metavar="VERBOSE", default=0)),
        (("-q",), dict(dest="quiet", action="count", help="quiet", metavar="QUIET", default=0)),
    )
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)

def get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
    """Join *query* to its hashtags and apply the common tweet filters.

    :param session: unused here; kept for interface compatibility.
    :param query: query that can be joined to EntityHashtag and Hashtag
        (as built by get_filter_query / get_user_query).
    :param start_date: if truthy, keep tweets created at or after it.
    :param end_date: if truthy, keep tweets created at or before it.
    :param hashtags: iterable of strings, each possibly a comma-separated
        list. Surrounding whitespace and empty entries are ignored. A row is
        kept when any of its hashtags contains one of the resulting fragments.
    :param tweet_exclude_table: optional table whose ``id`` column lists
        tweet ids to exclude.
    :param user_whitelist: if truthy, keep only tweets by these screen names.
    :returns: the filtered query.
    """
    query = query.join(EntityHashtag).join(Hashtag)

    if tweet_exclude_table is not None:
        query = query.filter(~Tweet.id.in_(select([tweet_exclude_table.c.id]))) #@UndefinedVariable

    if start_date:
        query = query.filter(Tweet.created_at >= start_date)
    if end_date:
        query = query.filter(Tweet.created_at <= end_date)

    if user_whitelist:
        query = query.join(User).filter(User.screen_name.in_(user_whitelist))

    if hashtags:
        # Flatten "a,b" style arguments and drop blank fragments: a blank
        # fragment would become contains(""), which matches every hashtag.
        htags = [tag.strip() for group in hashtags for tag in group.split(",") if tag.strip()]
        if htags:
            query = query.filter(or_(*[Hashtag.text.contains(tag) for tag in htags])) #@UndefinedVariable

    return query

    
    
def get_filter_query(session, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
    """Return a Tweet query filtered by get_base_query and ordered by creation date."""
    filtered = get_base_query(session, session.query(Tweet), start_date, end_date,
                              hashtags, tweet_exclude_table, user_whitelist)
    return filtered.order_by(Tweet.created_at)
    

def get_user_query(session, start_date, end_date, hashtags, tweet_exclude_table):
    """Return the distinct Users whose tweets pass get_base_query's filters (no user whitelist)."""
    users_with_tweets = session.query(User).join(Tweet)
    filtered = get_base_query(session, users_with_tweets, start_date, end_date,
                              hashtags, tweet_exclude_table, None)
    return filtered.distinct()

logger_name = "iri.tweet"

def get_logger():
    global logger_name
    return logging.getLogger(logger_name) #@UndefinedVariable


# Logging handler that set_logging() installs when it is given a queue.

class QueueHandler(logging.Handler): #@UndefinedVariable
    """
    This is a logging handler which sends events to a multiprocessing queue.
    """

    def __init__(self, queue, ignore_full):
        """
        Initialise an instance, using the passed queue.

        :param queue: object providing ``put_nowait`` and ``full``.
        :param ignore_full: when true, records are silently dropped while the
            queue is full; when false, ``Queue.Full`` propagates from emit.
        """
        logging.Handler.__init__(self) #@UndefinedVariable
        self.queue = queue
        self.ignore_full = ignore_full

    def emit(self, record):
        """
        Emit a record.

        Writes the LogRecord to the queue.
        """
        try:
            if record.exc_info:
                # Format once so the traceback text lands in record.exc_text,
                # then drop exc_info (presumably because traceback objects
                # cannot be pickled for a multiprocessing queue).
                self.format(record)
                record.exc_info = None
            if not self.ignore_full or not self.queue.full():
                self.queue.put_nowait(record)
        except Queue.Full:
            # The queue may still fill between full() and put_nowait().
            if not self.ignore_full:
                raise
        except (KeyboardInterrupt, SystemExit):
            raise
        except Exception:
            self.handleError(record)