script/lib/iri_tweet/iri_tweet/utils.py
changeset 1507 1e7aa7dc444b
parent 1497 14a9bed2e3cd
--- a/script/lib/iri_tweet/iri_tweet/utils.py	Mon Jul 01 14:35:52 2019 +0200
+++ b/script/lib/iri_tweet/iri_tweet/utils.py	Tue Jul 02 17:41:28 2019 +0200
@@ -5,13 +5,14 @@
 import logging
 import math
 import os.path
-import Queue
+import queue
 import socket
 import sys
 
 import twitter.oauth
 import twitter.oauth_dance
 from sqlalchemy.sql import or_, select
+from sqlalchemy.orm import class_mapper
 
 from .models import (ACCESS_TOKEN_KEY, ACCESS_TOKEN_SECRET, APPLICATION_NAME,
                      EntityHashtag, Hashtag, Tweet, User, adapt_date,
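
The Queue -> queue change above is the Python 3 rename of the standard-library module; only the import line changes, since the class and exception names kept their spelling (queue.Full is caught below in QueueHandler.emit). A minimal sketch of the renamed module in use:

import queue  # Python 3 name of the Python 2 "Queue" module

q = queue.Queue(maxsize=10)
q.put_nowait("event")
assert q.get_nowait() == "event"
# queue.Full / queue.Empty replace Python 2's Queue.Full / Queue.Empty
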
@@ -20,18 +21,18 @@
 CACHE_ACCESS_TOKEN = {}
 
 def get_oauth_token(consumer_key, consumer_secret, token_file_path=None, check_access_token=True, application_name=APPLICATION_NAME):
-    
+
     global CACHE_ACCESS_TOKEN
 
     if 'ACCESS_TOKEN_KEY' in globals() and 'ACCESS_TOKEN_SECRET' in globals() and ACCESS_TOKEN_KEY and ACCESS_TOKEN_SECRET:
         return ACCESS_TOKEN_KEY,ACCESS_TOKEN_SECRET
-    
+
     res = CACHE_ACCESS_TOKEN.get(application_name, None)
-    
+
     if res is None and token_file_path and os.path.exists(token_file_path):
         get_logger().debug("get_oauth_token : reading token from file %s" % token_file_path) #@UndefinedVariable
         res = twitter.oauth.read_token_file(token_file_path)
-    
+
     if res is not None and check_access_token:
         get_logger().debug("get_oauth_token : Check oauth tokens") #@UndefinedVariable
         t = twitter.Twitter(auth=twitter.OAuth(res[0], res[1], consumer_key, consumer_secret))
@@ -39,7 +40,7 @@
         try:
             status = t.application.rate_limit_status(resources="account")
         except Exception as e:
-            get_logger().debug("get_oauth_token : error getting rate limit status %s " % repr(e))            
+            get_logger().debug("get_oauth_token : error getting rate limit status %s " % repr(e))
             get_logger().debug("get_oauth_token : error getting rate limit status %s " % str(e))
             status = None
         get_logger().debug("get_oauth_token : Check oauth tokens : status %s" % repr(status)) #@UndefinedVariable
@@ -50,27 +51,27 @@
     if res is None:
         get_logger().debug("get_oauth_token : doing the oauth dance")
         res = twitter.oauth_dance(application_name, consumer_key, consumer_secret, token_file_path)
-        
-    
+
+
     CACHE_ACCESS_TOKEN[application_name] = res
-    
+
     get_logger().debug("get_oauth_token : done got %s" % repr(res))
     return res
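
Both token helpers validate a cached credential the same way: build a twitter.Twitter client and probe application.rate_limit_status, treating any failure as an invalid token. A hedged sketch of that check in isolation (the helper name is ours; the calls mirror the code above):

import twitter

def token_is_valid(token, token_secret, consumer_key, consumer_secret):
    # hypothetical helper; the same probe get_oauth_token performs inline
    t = twitter.Twitter(auth=twitter.OAuth(token, token_secret,
                                           consumer_key, consumer_secret))
    try:
        return t.application.rate_limit_status(resources="account") is not None
    except Exception:
        return False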
 
 
 def get_oauth2_token(consumer_key, consumer_secret, token_file_path=None, check_access_token=True, application_name=APPLICATION_NAME):
-    
+
     global CACHE_ACCESS_TOKEN
 
     if 'ACCESS_TOKEN_KEY' in globals() and 'ACCESS_TOKEN_SECRET' in globals() and ACCESS_TOKEN_KEY and ACCESS_TOKEN_SECRET:
         return ACCESS_TOKEN_KEY,ACCESS_TOKEN_SECRET
-    
+
     res = CACHE_ACCESS_TOKEN.get(application_name, None)
-    
+
     if res is None and token_file_path and os.path.exists(token_file_path):
         get_logger().debug("get_oauth2_token : reading token from file %s" % token_file_path) #@UndefinedVariable
         res = twitter.oauth2.read_bearer_token_file(token_file_path)
-    
+
     if res is not None and check_access_token:
         get_logger().debug("get_oauth2_token : Check oauth tokens") #@UndefinedVariable
         t = twitter.Twitter(auth=twitter.OAuth2(consumer_key, consumer_secret, res))
@@ -78,7 +79,7 @@
         try:
             status = t.application.rate_limit_status()
         except Exception as e:
-            get_logger().debug("get_oauth2_token : error getting rate limit status %s " % repr(e))            
+            get_logger().debug("get_oauth2_token : error getting rate limit status %s " % repr(e))
             get_logger().debug("get_oauth2_token : error getting rate limit status %s " % str(e))
             status = None
         get_logger().debug("get_oauth2_token : Check oauth tokens : status %s" % repr(status)) #@UndefinedVariable
@@ -89,10 +90,10 @@
     if res is None:
         get_logger().debug("get_oauth2_token : doing the oauth dance")
         res = twitter.oauth2_dance(consumer_key, consumer_secret, token_file_path)
-        
-    
+
+
     CACHE_ACCESS_TOKEN[application_name] = res
-    
+
     get_logger().debug("get_oauth_token : done got %s" % repr(res))
     return res
 
@@ -102,7 +103,7 @@
     return datetime.datetime(*ts[0:7])
 
 def clean_keys(dict_val):
-    return dict([(str(key),value) for key,value in dict_val.iteritems()])
+    return dict([(str(key),value) for key,value in dict_val.items()])
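
clean_keys switches from iteritems(), which was removed in Python 3, to items(). The dict([...]) construction is kept from the original; a dict comprehension is the more idiomatic Python 3 spelling of the same function, sketched here:

def clean_keys_py3(dict_val):
    # hypothetical comprehension form of clean_keys above
    return {str(key): value for key, value in dict_val.items()}

assert clean_keys_py3({1: "a", "b": 2}) == {"1": "a", "b": 2}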
 
 fields_adapter = {
     'stream': {
@@ -115,22 +116,29 @@
         },
         "user": {
             "created_at"  : adapt_date,
+            "derived" : adapt_json,
+            "withheld_in_countries" : adapt_json
         },
 
     },
-                  
+
     'entities' : {
         "medias": {
             "sizes"  : adapt_json,
-        },                  
+        },
     },
     'rest': {
+        "user": {
+            "created_at"  : adapt_date,
+            "derived" : adapt_json,
+            "withheld_in_countries" : adapt_json
+        },
         "tweet" : {
             "place"         : adapt_json,
             "geo"           : adapt_json,
             "created_at"    : adapt_date,
 #            "original_json" : adapt_json,
-        }, 
+        },
     },
 }
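
The new "derived" and "withheld_in_countries" entries route those user fields through adapt_json, so nested payload structures are serialized before reaching the User columns, just as "created_at" goes through adapt_date. A hedged sketch of how the mapping is consumed, assuming adapt_fields below takes (fields_dict, adapter_mapping) and that adapt_json JSON-encodes its argument; the payload is invented for illustration:

user_payload = {
    "created_at": "Mon Jul 01 14:35:52 +0000 2019",     # -> adapt_date
    "derived": {"locations": [{"country": "France"}]},  # -> adapt_json
    "screen_name": "example_user",                      # no adapter, passed through
}
adapted = adapt_fields(user_payload, fields_adapter["rest"]["user"])
# adapted["derived"] is now a JSON string; adapted["screen_name"] is unchanged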
 
@@ -143,41 +151,44 @@
             return adapter_mapping[field](value)
         else:
             return value
-    return dict([(str(k),adapt_one_field(k,v)) for k,v in fields_dict.iteritems()])    
+    return dict([(str(k),adapt_one_field(k,v)) for k,v in fields_dict.items()])
 
 
 class ObjectBufferProxy(object):
     def __init__(self, klass, args, kwargs, must_flush, instance=None):
         self.klass= klass
+        self.mapper = class_mapper(klass)
         self.args = args
         self.kwargs = kwargs
         self.must_flush = must_flush
         self.instance = instance
-        
+
     def persists(self, session):
         new_args = [arg() if callable(arg) else arg for arg in self.args] if self.args is not None else []
         new_kwargs = dict([(k,v()) if callable(v) else (k,v) for k,v in self.kwargs.items()]) if self.kwargs is not None else {}
-        
-        self.instance = self.klass(*new_args, **new_kwargs)
+
+        self.instance = self.klass(*new_args, **{
+            k: v for k, v in new_kwargs.items() if k in self.mapper.attrs.keys()
+        })
         if self.instance is not None:
             self.instance = session.merge(self.instance)
 
         session.add(self.instance)
         if self.must_flush:
             session.flush()
-            
+
     def __getattr__(self, name):
         return lambda : getattr(self.instance, name) if self.instance else None
-        
-        
-    
+
+
+
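
With class_mapper in hand, persists now filters the constructor kwargs against the model's mapped attributes, so unexpected keys in a Twitter payload are dropped instead of raising TypeError inside self.klass(...). A self-contained sketch of the same technique against a hypothetical model (not part of iri_tweet):

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import class_mapper

Base = declarative_base()

class Person(Base):            # hypothetical model, for illustration only
    __tablename__ = "person"
    id = Column(Integer, primary_key=True)
    name = Column(String)

payload = {"name": "Ada", "followers_count": 42}  # extra, unmapped key
mapped = class_mapper(Person).attrs.keys()
person = Person(**{k: v for k, v in payload.items() if k in mapped})
# followers_count is silently discarded rather than raising TypeError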
 
 class ObjectsBuffer(object):
 
     def __init__(self):
         self.__bufferlist = []
         self.__bufferdict = {}
-    
+
     def __add_proxy_object(self, proxy):
         proxy_list =  self.__bufferdict.get(proxy.klass, None)
         if proxy_list is None:
@@ -185,16 +196,16 @@
             self.__bufferdict[proxy.klass] = proxy_list
         proxy_list.append(proxy)
         self.__bufferlist.append(proxy)
-        
+
     def persists(self, session):
         for object_proxy in self.__bufferlist:
             object_proxy.persists(session)
-                
+
     def add_object(self, klass, args, kwargs, must_flush, instance=None):
         new_proxy = ObjectBufferProxy(klass, args, kwargs, must_flush, instance)
         self.__add_proxy_object(new_proxy)
-        return new_proxy 
-    
+        return new_proxy
+
     def get(self, klass, **kwargs):
         if klass in self.__bufferdict:
             for proxy in self.__bufferdict[klass]:
@@ -208,27 +219,27 @@
                 if found:
                     return proxy
         return None
-                
+
 
 
 def set_logging(options, plogger=None, queue=None):
-    
+
     logging_config = {
         "format" : '%(asctime)s %(levelname)s:%(name)s:%(message)s',
         "level" : max(logging.NOTSET, min(logging.CRITICAL, logging.WARNING - 10 * options.verbose + 10 * options.quiet)), #@UndefinedVariable
     }
-    
+
     if options.logfile == "stdout":
         logging_config["stream"] = sys.stdout
     elif options.logfile == "stderr":
         logging_config["stream"] = sys.stderr
     else:
         logging_config["filename"] = options.logfile
-            
+
     logger = plogger
     if logger is None:
         logger = get_logger() #@UndefinedVariable
-    
+
     if len(logger.handlers) == 0:
         filename = logging_config.get("filename")
         if queue is not None:
@@ -239,7 +250,7 @@
         else:
             stream = logging_config.get("stream")
             hdlr = logging.StreamHandler(stream) #@UndefinedVariable
-            
+
         fs = logging_config.get("format", logging.BASIC_FORMAT) #@UndefinedVariable
         dfs = logging_config.get("datefmt", None)
         fmt = logging.Formatter(fs, dfs) #@UndefinedVariable
@@ -248,7 +259,7 @@
         level = logging_config.get("level")
         if level is not None:
             logger.setLevel(level)
-    
+
     options.debug = (options.verbose-options.quiet > 0)
     return logger
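
The level computation in set_logging turns stacked -v/-q counters into a standard logging level: start from logging.WARNING, subtract 10 per -v, add 10 per -q, and clamp the result to the NOTSET..CRITICAL range. A worked check of that arithmetic:

import logging

def level_for(verbose, quiet):
    # same clamped formula used in set_logging above
    return max(logging.NOTSET,
               min(logging.CRITICAL, logging.WARNING - 10 * verbose + 10 * quiet))

assert level_for(0, 0) == logging.WARNING  # default
assert level_for(2, 0) == logging.DEBUG    # -vv
assert level_for(0, 1) == logging.ERROR    # -q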
 
@@ -261,12 +272,12 @@
                       help="quiet", default=0)
 
 def get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
-    
+
     query = query.join(EntityHashtag).join(Hashtag)
-    
+
     if tweet_exclude_table is not None:
         query = query.filter(~Tweet.id.in_(select([tweet_exclude_table.c.id]))) #@UndefinedVariable
-    
+
     if start_date:
         query = query.filter(Tweet.created_at >=  start_date)
     if end_date:
@@ -275,32 +286,32 @@
     if user_whitelist:
         query = query.join(User).filter(User.screen_name.in_(user_whitelist))
 
-    
+
     if hashtags :
         def merge_hash(l,h):
             l.extend(h.split(","))
             return l
         htags = functools.reduce(merge_hash, hashtags, [])
-        
+
         query = query.filter(or_(*map(lambda h: Hashtag.text.contains(h), htags))) #@UndefinedVariable
-    
+
     return query
 
-    
-    
+
+
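get_base_query accepts hashtags both as repeated values and as comma-separated strings; the functools.reduce(merge_hash, hashtags, []) fold flattens the two forms into a single list before OR-ing the Hashtag.text.contains filters. A quick illustration of the flattening step:

import functools

def merge_hash(l, h):
    # same helper as in get_base_query above
    l.extend(h.split(","))
    return l

assert functools.reduce(merge_hash, ["iri,tweet", "archive"], []) == ["iri", "tweet", "archive"]
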
 def get_filter_query(session, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist):
-    
+
     query = session.query(Tweet)
-    query = get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist) 
+    query = get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, user_whitelist)
     return query.order_by(Tweet.created_at)
-    
+
 
 def get_user_query(session, start_date, end_date, hashtags, tweet_exclude_table):
-    
+
     query = session.query(User).join(Tweet)
-    
-    query = get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, None)    
-    
+
+    query = get_base_query(session, query, start_date, end_date, hashtags, tweet_exclude_table, None)
+
     return query.distinct()
 
 logger_name = "iri.tweet"
@@ -312,7 +323,7 @@
 
 class QueueHandler(logging.Handler):
     """
-    This is a logging handler which sends events to a multiprocessing queue.    
+    This is a logging handler which sends events to a multiprocessing queue.
     """
 
     def __init__(self, queue, ignore_full):
@@ -322,7 +333,7 @@
         logging.Handler.__init__(self) #@UndefinedVariable
         self.queue = queue
-        self.ignore_full = ignore_full
+        self.ignore_full = ignore_full
-        
+
     def emit(self, record):
         """
         Emit a record.
@@ -338,7 +349,7 @@
                 self.queue.put_nowait(record)
         except AssertionError:
             pass
-        except Queue.Full:
+        except queue.Full:
             if self.ignore_full:
                 pass
             else:
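
Under the renamed exception, emit keeps its old behavior: put_nowait on a bounded multiprocessing queue raises the stdlib queue.Full (multiprocessing reuses that exception class), which the handler swallows or re-raises depending on ignore_full. A minimal sketch of that failure path:

import multiprocessing
import queue

q = multiprocessing.Queue(maxsize=1)
q.put_nowait("record 1")
try:
    q.put_nowait("record 2")   # bounded queue is already full
except queue.Full:             # the case emit() handles above
    pass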
@@ -366,7 +377,7 @@
     if percent >= 100:
         writer.write("\n")
     writer.flush()
-    
+
     return writer
 
 def get_unused_port():