script/lib/iri_tweet/utils.py
changeset 254 2209e66bb50b
parent 253 e9335ee3cf71
child 255 500cd0405c7a
--- a/script/lib/iri_tweet/utils.py	Tue Aug 09 13:07:23 2011 +0200
+++ b/script/lib/iri_tweet/utils.py	Fri Aug 12 18:17:27 2011 +0200
@@ -1,6 +1,7 @@
 from models import (Tweet, User, Hashtag, EntityHashtag, EntityUser, Url, 
     EntityUrl, CONSUMER_KEY, CONSUMER_SECRET, APPLICATION_NAME, ACCESS_TOKEN_KEY, 
-    ACCESS_TOKEN_SECRET, adapt_date, adapt_json, TweetSource, TweetLog)
+    ACCESS_TOKEN_SECRET, adapt_date, adapt_json, TweetSource, TweetLog, MediaType, 
+    Media, EntityMedia, Entity, EntityType)
 from sqlalchemy.sql import select, or_ #@UnresolvedImport
 import anyjson #@UnresolvedImport
 import datetime
@@ -16,24 +17,40 @@
 
 CACHE_ACCESS_TOKEN = {}
 
-def get_oauth_token(token_file_path=None, application_name=APPLICATION_NAME, consumer_key=CONSUMER_KEY, consumer_secret=CONSUMER_SECRET):
+def get_oauth_token(token_file_path=None, check_access_token=True, application_name=APPLICATION_NAME, consumer_key=CONSUMER_KEY, consumer_secret=CONSUMER_SECRET):
     
     global CACHE_ACCESS_TOKEN
-    
-    if CACHE_ACCESS_TOKEN is not None and application_name in CACHE_ACCESS_TOKEN:
-        return CACHE_ACCESS_TOKEN[application_name]
-    
-    if token_file_path and os.path.exists(token_file_path):
-        logging.debug("reading token from file %s" % token_file_path) #@UndefinedVariable
-        CACHE_ACCESS_TOKEN[application_name] = twitter.oauth.read_token_file(token_file_path)
-        return CACHE_ACCESS_TOKEN[application_name]
-        #read access token info from path
-    
-    if 'ACCESS_TOKEN_KEY' in dict() and 'ACCESS_TOKEN_SECRET' in dict() and ACCESS_TOKEN_KEY and ACCESS_TOKEN_SECRET:
+
+    if 'ACCESS_TOKEN_KEY' in globals() and 'ACCESS_TOKEN_SECRET' in globals() and ACCESS_TOKEN_KEY and ACCESS_TOKEN_SECRET:
         return ACCESS_TOKEN_KEY,ACCESS_TOKEN_SECRET
     
-    CACHE_ACCESS_TOKEN[application_name] = twitter.oauth_dance.oauth_dance(application_name, consumer_key, consumer_secret, token_file_path)
-    return CACHE_ACCESS_TOKEN[application_name]
+    res = CACHE_ACCESS_TOKEN.get(application_name, None)
+    
+    if res is None and token_file_path and os.path.exists(token_file_path):
+        get_logger().debug("get_oauth_token : reading token from file %s" % token_file_path) #@UndefinedVariable
+        res = twitter.oauth.read_token_file(token_file_path)
+    
+    if res is not None and check_access_token:
+        get_logger().debug("get_oauth_token : Check oauth tokens") #@UndefinedVariable
+        t = twitter.Twitter(auth=twitter.OAuth(res[0], res[1], CONSUMER_KEY, CONSUMER_SECRET))
+        status = None
+        try:
+            status = t.account.rate_limit_status()
+        except Exception as e:
+            get_logger().debug("get_oauth_token : error getting rate limit status %s" % repr(e))
+            status = None
+        get_logger().debug("get_oauth_token : Check oauth tokens : status %s" % repr(status)) #@UndefinedVariable
+        if status is None or status['remaining_hits'] == 0:
+            get_logger().debug("get_oauth_token : Problem with status %s" % repr(status))
+            res = None
+
+    if res is None:
+        get_logger().debug("get_oauth_token : doing the oauth dance")
+        res = twitter.oauth_dance.oauth_dance(application_name, consumer_key, consumer_secret, token_file_path)
+    
+    CACHE_ACCESS_TOKEN[application_name] = res
+    
+    return res
 
 def parse_date(date_str):
     ts = email.utils.parsedate_tz(date_str) #@UndefinedVariable
@@ -54,6 +71,13 @@
         "user": {
             "created_at"  : adapt_date,
         },
+
+    },
+                  
+    'entities' : {
+        "medias": {
+            "sizes"  : adapt_json,
+        },                  
     },
     'rest': {
         "tweet" : {
@@ -89,7 +113,12 @@
         new_args = [arg() if callable(arg) else arg for arg in self.args] if self.args is not None else []
         new_kwargs = dict([(k,v()) if callable(v) else (k,v) for k,v in self.kwargs.items()]) if self.kwargs is not None else {}
         
-        self.instance = self.klass(*new_args, **new_kwargs)
+        if self.instance is None:
+            self.instance = self.klass(*new_args, **new_kwargs)
+        else:
+            self.instance = self.klass(*new_args, **new_kwargs)
+            self.instance = session.merge(self.instance)
+
         session.add(self.instance)
         if self.must_flush:
             session.flush()
@@ -104,40 +133,45 @@
 
     def __init__(self):
         self.__bufferlist = []
+        self.__bufferdict = {}
+    
+    def __add_proxy_object(self, proxy):
+        proxy_list =  self.__bufferdict.get(proxy.klass, None)
+        if proxy_list is None:
+            proxy_list = []
+            self.__bufferdict[proxy.klass] = proxy_list
+        proxy_list.append(proxy)
+        self.__bufferlist.append(proxy)
         
     def persists(self, session):
         for object_proxy in self.__bufferlist:
             object_proxy.persists(session)
-            
-    def add_object(self, klass, args, kwargs, must_flush):
-        new_proxy = ObjectBufferProxy(klass, args, kwargs, must_flush)
-        self.__bufferlist.append(new_proxy)
+                
+    def add_object(self, klass, args, kwargs, must_flush, instance=None):
+        new_proxy = ObjectBufferProxy(klass, args, kwargs, must_flush, instance)
+        self.__add_proxy_object(new_proxy)
         return new_proxy 
     
     def get(self, klass, **kwargs):
-        for proxy in self.__bufferlist:
-            if proxy.kwargs is None or len(proxy.kwargs) == 0 or proxy.klass != klass:
-                continue
-            found = True
-            for k,v in kwargs.items():
-                if (k not in proxy.kwargs) or v != proxy.kwargs[k]:
-                    found = False
-                    break
-            if found:
-                return proxy
-        
+        if klass in self.__bufferdict:
+            for proxy in self.__bufferdict[klass]:
+                if proxy.kwargs is None or len(proxy.kwargs) == 0 or proxy.klass != klass:
+                    continue
+                found = True
+                for k,v in kwargs.items():
+                    if (k not in proxy.kwargs) or v != proxy.kwargs[k]:
+                        found = False
+                        break
+                if found:
+                    return proxy        
         return None
                 
-                    
-        
-
-
 class TwitterProcessorException(Exception):
     pass
 
 class TwitterProcessor(object):
     
-    def __init__(self, json_dict, json_txt, source_id, session, token_filename=None):
+    def __init__(self, json_dict, json_txt, source_id, session, access_token=None, token_filename=None):
 
         if json_dict is None and json_txt is None:
             raise TwitterProcessorException("No json")
@@ -158,11 +192,13 @@
         self.source_id = source_id
         self.session = session
         self.token_filename = token_filename
+        self.access_token = access_token
         self.obj_buffer = ObjectsBuffer()
+        
 
 
-    def __get_user(self, user_dict):
-        logger.debug("Get user : " + repr(user_dict)) #@UndefinedVariable
+    def __get_user(self, user_dict, do_merge, query_twitter = False):
+        get_logger().debug("Get user : " + repr(user_dict)) #@UndefinedVariable
     
         user_id = user_dict.get("id",None)    
         user_name = user_dict.get("screen_name", user_dict.get("name", None))
@@ -192,8 +228,12 @@
     
         user_created_at = user_dict.get("created_at", None)
         
-        if user_created_at is None:
-            acess_token_key, access_token_secret = get_oauth_token(self.token_filename)
+        if user_created_at is None and query_twitter:
+            
+            if self.access_token is not None:
+                acess_token_key, access_token_secret = self.access_token
+            else:
+                acess_token_key, access_token_secret = get_oauth_token(self.token_filename)
             t = twitter.Twitter(auth=twitter.OAuth(acess_token_key, access_token_secret, CONSUMER_KEY, CONSUMER_SECRET))
             try:
                 if user_id:
@@ -201,8 +241,8 @@
                 else:
                     user_dict = t.users.show(screen_name=user_name)            
             except Exception as e:
-                logger.info("get_user : TWITTER ERROR : " + repr(e)) #@UndefinedVariable
-                logger.info("get_user : TWITTER ERROR : " + str(e)) #@UndefinedVariable
+                get_logger().info("get_user : TWITTER ERROR : " + repr(e)) #@UndefinedVariable
+                get_logger().info("get_user : TWITTER ERROR : " + str(e)) #@UndefinedVariable
                 return None
     
         user_dict = adapt_fields(user_dict, fields_adapter["stream"]["user"])
@@ -210,45 +250,79 @@
             return None
         
         #TODO filter get, wrap in proxy
-        user = self.session.query(User).filter(User.id == user_dict["id"]).first()
+        user_obj = self.session.query(User).filter(User.id == user_dict["id"]).first()
         
-        if user is not None:
-            return user
+        if user_obj is not None:
+            if not do_merge:
+                return ObjectBufferProxy(User, None, None, False, user_obj)
         
         user = self.obj_buffer.add_object(User, None, user_dict, True)
         
         return user
 
+    def __get_or_create_object(self, klass, filter_by_kwargs, filter, creation_kwargs, must_flush, do_merge):
+        
+        obj_proxy = self.obj_buffer.get(klass, **filter_by_kwargs)
+        if obj_proxy is None:
+            query = self.session.query(klass)
+            if filter is not None:
+                query = query.filter(filter)
+            else:
+                query = query.filter_by(**filter_by_kwargs)
+            obj_instance = query.first()
+            if obj_instance is not None:
+                if not do_merge:
+                    obj_proxy = ObjectBufferProxy(klass, None, None, False, obj_instance)
+                else:
+                    obj_proxy = self.obj_buffer.add_object(klass, None, creation_kwargs, must_flush, obj_instance)
+        if obj_proxy is None:
+            obj_proxy = self.obj_buffer.add_object(klass, None, creation_kwargs, must_flush)
+        return obj_proxy
+
 
     def __process_entity(self, ind, ind_type):
-        logger.debug("Process_entity : " + repr(ind) + " : " + repr(ind_type)) #@UndefinedVariable
+        get_logger().debug("Process_entity : " + repr(ind) + " : " + repr(ind_type)) #@UndefinedVariable
         
         ind = clean_keys(ind)
         
+        entity_type = self.__get_or_create_object(EntityType, {'label':ind_type}, None, {'label':ind_type}, True, False)
+        
         entity_dict = {
-           "indice_start": ind["indices"][0],
-           "indice_end"  : ind["indices"][1],
-           "tweet_id"    : self.tweet.id,
+           "indice_start"   : ind["indices"][0],
+           "indice_end"     : ind["indices"][1],
+           "tweet_id"       : self.tweet.id,
+           "entity_type_id" : entity_type.id,
+           "source"         : adapt_json(ind)
         }
-    
+
+        def process_medias():
+            
+            media_id = ind.get('id', None)
+            if media_id is None:
+                return None, None
+            
+            type_str = ind.get("type", "photo")
+            media_type = self.__get_or_create_object(MediaType, {'label': type_str}, None, {'label':type_str}, True, False)
+            media_ind = adapt_fields(ind, fields_adapter["entities"]["medias"])
+            if "type" in media_ind:
+                del(media_ind["type"])
+            media_ind['type_id'] = media_type.id            
+            media = self.__get_or_create_object(Media, {'id':media_id}, None, media_ind, True, False)
+            
+            entity_dict['media_id'] = media.id
+            return EntityMedia, entity_dict
+
         def process_hashtags():
             text = ind.get("text", ind.get("hashtag", None))
             if text is None:
-                return None
-            hashtag = self.obj_buffer.get(Hashtag, text=text)
-            if hashtag is None: 
-                hashtag_obj = self.session.query(Hashtag).filter(Hashtag.text.ilike(text)).first()
-                if hashtag_obj is not None:
-                    hashtag = ObjectBufferProxy(Hashtag, None, None, False, hashtag_obj)
-                    
-            if hashtag is None:
-                ind["text"] = text
-                hashtag = self.obj_buffer.add_object(Hashtag, None, ind, True)
+                return None, None
+            ind['text'] = text
+            hashtag = self.__get_or_create_object(Hashtag, {'text':text}, Hashtag.text.ilike(text), ind, True, False)
             entity_dict['hashtag_id'] = hashtag.id
             return EntityHashtag, entity_dict             
         
         def process_user_mentions():
-            user_mention = self.__get_user(ind)
+            user_mention = self.__get_user(ind, False, False)
             if user_mention is None:
                 entity_dict['user_id'] = None
             else:
@@ -256,24 +330,19 @@
             return EntityUser, entity_dict
         
         def process_urls():
-            url = self.obj_buffer.get(Url, url=ind["url"])
-            if url is None:
-                url_obj = self.session.query(Url).filter(Url.url == ind["url"]).first()
-                if url_obj is not None:
-                    url = ObjectBufferProxy(Url, None, None, False, url_obj)
-            if url is None:
-                url = self.obj_buffer.add_object(Url, None, ind, True)
+            url = self.__get_or_create_object(Url, {'url':ind["url"]}, None, ind, True, False)
             entity_dict['url_id'] = url.id
             return EntityUrl, entity_dict
-        
+                
         #{'': lambda }
         entity_klass, entity_dict =  { 
             'hashtags': process_hashtags,
             'user_mentions' : process_user_mentions,
-            'urls' : process_urls
-            }[ind_type]()
+            'urls' : process_urls,
+            'media': process_medias,
+            }.get(ind_type, lambda: (Entity, entity_dict))()
             
-        logger.debug("Process_entity entity_dict: " + repr(entity_dict)) #@UndefinedVariable
+        get_logger().debug("Process_entity entity_dict: " + repr(entity_dict)) #@UndefinedVariable
         if entity_klass:
             self.obj_buffer.add_object(entity_klass, None, entity_dict, False)
 
@@ -287,9 +356,9 @@
         ts_copy = adapt_fields(self.json_dict, fields_adapter["stream"]["tweet"])
         
         # get or create user
-        user = self.__get_user(self.json_dict["user"])
+        user = self.__get_user(self.json_dict["user"], True)
         if user is None:
-            logger.warning("USER not found " + repr(self.json_dict["user"])) #@UndefinedVariable
+            get_logger().warning("USER not found " + repr(self.json_dict["user"])) #@UndefinedVariable
             ts_copy["user_id"] = None
         else:
             ts_copy["user_id"] = user.id
@@ -299,25 +368,26 @@
         
         self.tweet = self.obj_buffer.add_object(Tweet, None, ts_copy, True)
             
-        # get entities
+        self.__process_entities()
+
+
+    def __process_entities(self):
         if "entities" in self.json_dict:
             for ind_type, entity_list in self.json_dict["entities"].items():
                 for ind in entity_list:
                     self.__process_entity(ind, ind_type)
         else:
-            extractor = twitter_text.Extractor(self.tweet.text)
-    
+            
+            text = self.tweet.text
+            extractor = twitter_text.Extractor(text)
             for ind in extractor.extract_hashtags_with_indices():
                 self.__process_entity(ind, "hashtags")
-    
+            
+            for ind in extractor.extract_urls_with_indices():
+                self.__process_entity(ind, "urls")
+            
             for ind in extractor.extract_mentioned_screen_names_with_indices():
                 self.__process_entity(ind, "user_mentions")
-    
-            for ind in extractor.extract_urls_with_indices():
-                self.__process_entity(ind, "urls")
-
-        self.session.flush()
-
 
     def __process_twitter_rest(self):
         tweet_nb = self.session.query(Tweet).filter(Tweet.id == self.json_dict["id"]).count()
@@ -348,28 +418,17 @@
             'screen_name' : self.json_dict["from_user"],                   
         }
         
-        user = self.__get_user(user_fields)
+        user = self.__get_user(user_fields, do_merge=False)
         if user is None:
-            logger.warning("USER not found " + repr(user_fields)) #@UndefinedVariable
+            get_logger().warning("USER not found " + repr(user_fields)) #@UndefinedVariable
             tweet_fields["user_id"] = None
         else:
             tweet_fields["user_id"] = user.id
         
         tweet_fields = adapt_fields(tweet_fields, fields_adapter["rest"]["tweet"])
         self.tweet = self.obj_buffer.add_object(Tweet, None, tweet_fields, True)
-        
-        text = self.tweet.text
-        
-        extractor = twitter_text.Extractor(text)
-        
-        for ind in extractor.extract_hashtags_with_indices():
-            self.__process_entity(ind, "hashtags")
-                    
-        for ind in extractor.extract_urls_with_indices():
-            self.__process_entity(ind, "urls")
-        
-        for ind in extractor.extract_mentioned_screen_names_with_indices():
-            self.__process_entity(ind, "user_mentions")
+                
+        self.__process_entities()
 
 
 
@@ -384,7 +443,7 @@
         else:
             self.__process_twitter_stream()
 
-        self.obj_buffer.add_object(TweetLog, None, {'tweet_source_id':self.source_id, 'status':TweetLog.TWEET_STATUS['OK']}, False)
+        self.obj_buffer.add_object(TweetLog, None, {'tweet_source_id':self.source_id, 'status':TweetLog.TWEET_STATUS['OK']}, True)
         
         self.obj_buffer.persists(self.session)
 
@@ -405,7 +464,7 @@
             
     logger = plogger
     if logger is None:
-        logger = logging.getLogger() #@UndefinedVariable
+        logger = get_logger() #@UndefinedVariable
     
     if len(logger.handlers) == 0:    
         filename = logging_config.get("filename")
@@ -477,4 +536,5 @@
     
     return query.distinct()
 
-logger = logging.getLogger() #@UndefinedVariable
+def get_logger():
+    return logging.getLogger("iri_tweet") #@UndefinedVariable