script/lib/iri_tweet/utils.py
changeset 253 e9335ee3cf71
parent 250 6334869ab06d
parent 244 d4b7d6e2633f
child 254 2209e66bb50b
--- a/script/lib/iri_tweet/utils.py	Wed Jul 27 18:32:56 2011 +0200
+++ b/script/lib/iri_tweet/utils.py	Tue Aug 09 13:07:23 2011 +0200
@@ -1,6 +1,6 @@
-from models import Tweet, User, Hashtag, EntityHashtag, EntityUser, Url, \
-    EntityUrl, CONSUMER_KEY, CONSUMER_SECRET, APPLICATION_NAME, ACCESS_TOKEN_KEY, \
-    ACCESS_TOKEN_SECRET, adapt_date, adapt_json
+from models import (Tweet, User, Hashtag, EntityHashtag, EntityUser, Url, 
+    EntityUrl, CONSUMER_KEY, CONSUMER_SECRET, APPLICATION_NAME, ACCESS_TOKEN_KEY, 
+    ACCESS_TOKEN_SECRET, adapt_date, adapt_json, TweetSource, TweetLog)
 from sqlalchemy.sql import select, or_ #@UnresolvedImport
 import anyjson #@UnresolvedImport
 import datetime
@@ -77,13 +77,67 @@
     return dict([(str(k),adapt_one_field(k,v)) for k,v in fields_dict.items()])    
 
 
+class ObjectBufferProxy(object):
+    def __init__(self, klass, args, kwargs, must_flush, instance=None):
+        self.klass= klass
+        self.args = args
+        self.kwargs = kwargs
+        self.must_flush = must_flush
+        self.instance = instance
+        
+    def persists(self, session):
+        new_args = [arg() if callable(arg) else arg for arg in self.args] if self.args is not None else []
+        new_kwargs = dict([(k,v()) if callable(v) else (k,v) for k,v in self.kwargs.items()]) if self.kwargs is not None else {}
+        
+        self.instance = self.klass(*new_args, **new_kwargs)
+        session.add(self.instance)
+        if self.must_flush:
+            session.flush()
+            
+    def __getattr__(self, name):
+        return lambda : getattr(self.instance, name) if self.instance else None
+        
+        
+    
+
+class ObjectsBuffer(object):
+
+    def __init__(self):
+        self.__bufferlist = []
+        
+    def persists(self, session):
+        for object_proxy in self.__bufferlist:
+            object_proxy.persists(session)
+            
+    def add_object(self, klass, args, kwargs, must_flush):
+        new_proxy = ObjectBufferProxy(klass, args, kwargs, must_flush)
+        self.__bufferlist.append(new_proxy)
+        return new_proxy 
+    
+    def get(self, klass, **kwargs):
+        for proxy in self.__bufferlist:
+            if proxy.kwargs is None or len(proxy.kwargs) == 0 or proxy.klass != klass:
+                continue
+            found = True
+            for k,v in kwargs.items():
+                if (k not in proxy.kwargs) or v != proxy.kwargs[k]:
+                    found = False
+                    break
+            if found:
+                return proxy
+        
+        return None
+                
+                    
+        
+
 
 class TwitterProcessorException(Exception):
     pass
 
 class TwitterProcessor(object):
     
-    def __init__(self, json_dict, json_txt, session, token_filename=None):
+    def __init__(self, json_dict, json_txt, source_id, session, token_filename=None):
 
         if json_dict is None and json_txt is None:
             raise TwitterProcessorException("No json")
@@ -101,24 +155,39 @@
         if "id" not in self.json_dict:
             raise TwitterProcessorException("No id in json")
         
+        self.source_id = source_id
         self.session = session
         self.token_filename = token_filename
+        self.obj_buffer = ObjectsBuffer()
+
 
     def __get_user(self, user_dict):
-        logging.debug("Get user : " + repr(user_dict)) #@UndefinedVariable
+        logger.debug("Get user : " + repr(user_dict)) #@UndefinedVariable
     
         user_id = user_dict.get("id",None)    
         user_name = user_dict.get("screen_name", user_dict.get("name", None))
         
         if user_id is None and user_name is None:
             return None
-    
+
+        user = None
         if user_id:
-            user = self.session.query(User).filter(User.id == user_id).first()
+            user = self.obj_buffer.get(User, id=user_id)
         else:
-            user = self.session.query(User).filter(User.screen_name.ilike(user_name)).first()
+            user = self.obj_buffer.get(User, screen_name=user_name)
+            
+        if user is not None:
+            return user
+
+        #TODO: add methods to ObjectsBuffer to look up a buffered user
+        user_obj = None
+        if user_id:
+            user_obj = self.session.query(User).filter(User.id == user_id).first()
+        else:
+            user_obj = self.session.query(User).filter(User.screen_name.ilike(user_name)).first()
     
-        if user is not None:
+        if user_obj is not None:
+            user = ObjectBufferProxy(User, None, None, False, user_obj)
             return user
     
         user_created_at = user_dict.get("created_at", None)
@@ -132,28 +201,27 @@
                 else:
                     user_dict = t.users.show(screen_name=user_name)            
             except Exception as e:
-                logging.info("get_user : TWITTER ERROR : " + repr(e)) #@UndefinedVariable
-                logging.info("get_user : TWITTER ERROR : " + str(e)) #@UndefinedVariable
+                logger.info("get_user : TWITTER ERROR : " + repr(e)) #@UndefinedVariable
+                logger.info("get_user : TWITTER ERROR : " + str(e)) #@UndefinedVariable
                 return None
     
         user_dict = adapt_fields(user_dict, fields_adapter["stream"]["user"])
         if "id" not in user_dict:
             return None
         
+        #TODO: look in obj_buffer first, and wrap the queried user in an ObjectBufferProxy
         user = self.session.query(User).filter(User.id == user_dict["id"]).first()
         
         if user is not None:
             return user
         
-        user = User(**user_dict)
+        user = self.obj_buffer.add_object(User, None, user_dict, True)
         
-        self.session.add(user)
-        self.session.flush()
-        
-        return user 
+        return user
+
 
     def __process_entity(self, ind, ind_type):
-        logging.debug("Process_entity : " + repr(ind) + " : " + repr(ind_type)) #@UndefinedVariable
+        logger.debug("Process_entity : " + repr(ind) + " : " + repr(ind_type)) #@UndefinedVariable
         
         ind = clean_keys(ind)
         
@@ -161,57 +229,53 @@
            "indice_start": ind["indices"][0],
            "indice_end"  : ind["indices"][1],
            "tweet_id"    : self.tweet.id,
-           "tweet"       : self.tweet
         }
     
         def process_hashtags():
             text = ind.get("text", ind.get("hashtag", None))
             if text is None:
-                return None 
-            hashtag = self.session.query(Hashtag).filter(Hashtag.text.ilike(text)).first()
+                return None
+            hashtag = self.obj_buffer.get(Hashtag, text=text)
+            if hashtag is None: 
+                hashtag_obj = self.session.query(Hashtag).filter(Hashtag.text.ilike(text)).first()
+                if hashtag_obj is not None:
+                    hashtag = ObjectBufferProxy(Hashtag, None, None, False, hashtag_obj)
+                    
             if hashtag is None:
                 ind["text"] = text
-                hashtag = Hashtag(**ind)
-                self.session.add(hashtag)
-                self.session.flush()
-            entity_dict['hashtag'] = hashtag
+                hashtag = self.obj_buffer.add_object(Hashtag, None, ind, True)
             entity_dict['hashtag_id'] = hashtag.id
-            entity = EntityHashtag(**entity_dict)
-            return entity
+            return EntityHashtag, entity_dict             
         
         def process_user_mentions():
             user_mention = self.__get_user(ind)
             if user_mention is None:
-                entity_dict['user'] = None
                 entity_dict['user_id'] = None
             else:
-                entity_dict['user'] = user_mention
                 entity_dict['user_id'] = user_mention.id
-            entity = EntityUser(**entity_dict)
-            return entity
+            return EntityUser, entity_dict
         
         def process_urls():
-            url = self.session.query(Url).filter(Url.url == ind["url"]).first()
+            url = self.obj_buffer.get(Url, url=ind["url"])
             if url is None:
-                url = Url(**ind)
-                self.session.add(url)
-                self.session.flush()
-            entity_dict['url'] = url
+                url_obj = self.session.query(Url).filter(Url.url == ind["url"]).first()
+                if url_obj is not None:
+                    url = ObjectBufferProxy(Url, None, None, False, url_obj)
+            if url is None:
+                url = self.obj_buffer.add_object(Url, None, ind, True)
             entity_dict['url_id'] = url.id
-            entity = EntityUrl(**entity_dict)
-            return entity
+            return EntityUrl, entity_dict
         
         #{'': lambda }
-        entity =  { 
+        entity_klass, entity_dict =  { 
             'hashtags': process_hashtags,
             'user_mentions' : process_user_mentions,
             'urls' : process_urls
             }[ind_type]()
             
-        logging.debug("Process_entity entity_dict: " + repr(entity_dict)) #@UndefinedVariable
-        if entity:
-            self.session.add(entity)
-            self.session.flush()
+        logger.debug("Process_entity entity_dict: " + repr(entity_dict)) #@UndefinedVariable
+        if entity_klass:
+            self.obj_buffer.add_object(entity_klass, None, entity_dict, False)
 
 
     def __process_twitter_stream(self):
@@ -225,16 +289,15 @@
         # get or create user
         user = self.__get_user(self.json_dict["user"])
         if user is None:
-            logging.warning("USER not found " + repr(self.json_dict["user"])) #@UndefinedVariable
-            ts_copy["user"] = None
+            logger.warning("USER not found " + repr(self.json_dict["user"])) #@UndefinedVariable
             ts_copy["user_id"] = None
         else:
-            ts_copy["user"] = user
-            ts_copy["user_id"] = ts_copy["user"].id
-        ts_copy["original_json"] = self.json_txt
+            ts_copy["user_id"] = user.id
+            
+        del(ts_copy['user'])
+        ts_copy["tweet_source_id"] = self.source_id
         
-        self.tweet = Tweet(**ts_copy)
-        self.session.add(self.tweet)
+        self.tweet = self.obj_buffer.add_object(Tweet, None, ts_copy, True)
             
         # get entities
         if "entities" in self.json_dict:
@@ -260,7 +323,8 @@
         tweet_nb = self.session.query(Tweet).filter(Tweet.id == self.json_dict["id"]).count()
         if tweet_nb > 0:
             return
-            
+        
+        
         tweet_fields = {
             'created_at': self.json_dict["created_at"], 
             'favorited': False,
@@ -272,8 +336,8 @@
             #'place': ts["place"],
             'source': self.json_dict["source"],
             'text': self.json_dict["text"],
-            'truncated': False,
-            'original_json' : self.json_txt,
+            'truncated': False,            
+            'tweet_source_id' : self.source_id,
         }
         
         #user
@@ -286,16 +350,13 @@
         
         user = self.__get_user(user_fields)
         if user is None:
-            logging.warning("USER not found " + repr(user_fields)) #@UndefinedVariable
-            tweet_fields["user"] = None
+            logger.warning("USER not found " + repr(user_fields)) #@UndefinedVariable
             tweet_fields["user_id"] = None
         else:
-            tweet_fields["user"] = user
             tweet_fields["user_id"] = user.id
         
         tweet_fields = adapt_fields(tweet_fields, fields_adapter["rest"]["tweet"])
-        self.tweet = Tweet(**tweet_fields)
-        self.session.add(self.tweet)
+        self.tweet = self.obj_buffer.add_object(Tweet, None, tweet_fields, True)
         
         text = self.tweet.text
         
@@ -303,26 +364,37 @@
         
         for ind in extractor.extract_hashtags_with_indices():
             self.__process_entity(ind, "hashtags")
-            
-        for ind in extractor.extract_mentioned_screen_names_with_indices():
-            self.__process_entity(ind, "user_mentions")
-        
+                    
         for ind in extractor.extract_urls_with_indices():
             self.__process_entity(ind, "urls")
         
-        self.session.flush()
+        for ind in extractor.extract_mentioned_screen_names_with_indices():
+            self.__process_entity(ind, "user_mentions")
+
 
 
     def process(self):
+        
+        if self.source_id is None:
+            tweet_source = self.obj_buffer.add_object(TweetSource, None, {'original_json':self.json_txt}, True)
+            self.source_id = tweet_source.id
+        
         if "metadata" in self.json_dict:
             self.__process_twitter_rest()
         else:
             self.__process_twitter_stream()
+
+        self.obj_buffer.add_object(TweetLog, None, {'tweet_source_id':self.source_id, 'status':TweetLog.TWEET_STATUS['OK']}, False)
         
+        self.obj_buffer.persists(self.session)
+
 
-def set_logging(options):
+def set_logging(options, plogger=None):
     
-    logging_config = {}
+    logging_config = {
+        "format" : '%(asctime)s %(levelname)s:%(name)s:%(message)s',
+        "level" : max(logging.NOTSET, min(logging.CRITICAL, logging.WARNING - 10 * options.verbose + 10 * options.quiet)), #@UndefinedVariable
+    }
     
     if options.logfile == "stdout":
         logging_config["stream"] = sys.stdout
@@ -330,9 +402,27 @@
         logging_config["stream"] = sys.stderr
     else:
         logging_config["filename"] = options.logfile
-        
-    logging_config["level"] = max(logging.NOTSET, min(logging.CRITICAL, logging.WARNING - 10 * options.verbose + 10 * options.quiet)) #@UndefinedVariable
-    logging.basicConfig(**logging_config) #@UndefinedVariable
+            
+    logger = plogger
+    if logger is None:
+        logger = logging.getLogger() #@UndefinedVariable
+    
+    if len(logger.handlers) == 0:    
+        filename = logging_config.get("filename")
+        if filename:
+            mode = logging_config.get("filemode", 'a')
+            hdlr = logging.FileHandler(filename, mode) #@UndefinedVariable
+        else:
+            stream = logging_config.get("stream")
+            hdlr = logging.StreamHandler(stream) #@UndefinedVariable
+        fs = logging_config.get("format", logging.BASIC_FORMAT) #@UndefinedVariable
+        dfs = logging_config.get("datefmt", None)
+        fmt = logging.Formatter(fs, dfs) #@UndefinedVariable
+        hdlr.setFormatter(fmt)
+        logger.addHandler(hdlr)
+        level = logging_config.get("level")
+        if level is not None:
+            logger.setLevel(level)
     
     options.debug = (options.verbose-options.quiet > 0)
 
@@ -387,4 +477,4 @@
     
     return query.distinct()
 
-    
+logger = logging.getLogger() #@UndefinedVariable