tweetcast/server/tweetcast.py
changeset 311 13702105c5ee
parent 310 526d3e411736
child 312 f8336354d107
child 314 0f1e6ce19b6d
equal deleted inserted replaced
310:526d3e411736 311:13702105c5ee
     1 #!/usr/bin/env python
       
     2 # -*- coding: utf-8 -*-
       
     3 
       
     4 import anyjson
       
     5 from twisted.enterprise import adbapi
       
     6 from twisted.internet import reactor, task
       
     7 from twisted.internet.protocol import Protocol, Factory
       
     8 from autobahn.websocket import WebSocketServerFactory, WebSocketServerProtocol
       
     9 
       
    10 connectstring = "dbname='tweet_live' user='postgres' host='localhost' password='doiteshimashite'"
       
    11 columns = [ 'id', 'created_at', 'text', 'user_id', 'screen_name', 'profile_image_url' ]
       
    12 selectcommon = "SELECT tweet_tweet.id, tweet_tweet.created_at, text, user_id, screen_name, profile_image_url FROM tweet_tweet JOIN tweet_user ON tweet_tweet.user_id = tweet_user.id"
       
    13 
       
    14 dbpool = adbapi.ConnectionPool("psycopg2",connectstring)
       
    15 
       
    16 class TweetCast:
       
    17 	
       
    18 	def __init__(self):
       
    19 		self.lastid = 0L
       
    20 		self.serverfactory = TweetcastServerFactory(tweetcast=self)
       
    21 		dbpool.runQuery("SELECT MAX(tweet_tweet.id) FROM tweet_tweet").addCallback(self.callbackInit)
       
    22 	
       
    23 	def callbackInit(self, result):
       
    24 		self.lastid = result[0][0]
       
    25 		task.LoopingCall(self.scheduleTweets).start(1)
       
    26 
       
    27 	def scheduleTweets(self):
       
    28 		dbpool.runQuery("%s WHERE tweet_tweet.id > %d ORDER BY tweet_tweet.id ASC"%(selectcommon, self.lastid)).addCallback(self.callbackTweets)
       
    29 	
       
    30 	def callbackTweets(self, result):
       
    31 		if result:
       
    32 			self.lastid = result[len(result)-1][0]
       
    33 			data = [dict((columns[i], str(ligne[i])) for i in range(len(columns))) for ligne in result]
       
    34 		else:
       
    35 			data = None
       
    36 		self.serverfactory.broadcast(anyjson.serialize(data))
       
    37 
       
    38 class TweetcastServerProtocol(WebSocketServerProtocol):
       
    39 
       
    40 	def onOpen(self):
       
    41 		self.factory.register(self)
       
    42 		dbpool.runQuery("%s WHERE tweet_tweet.id <= %d ORDER BY tweet_tweet.id DESC LIMIT 100"%(selectcommon, self.factory.tweetcast.lastid)).addCallback(self.callbackOldTweets)
       
    43 
       
    44 	def callbackOldTweets(self, result):
       
    45 		if result:
       
    46 			data = [dict((columns[i], str(ligne[i])) for i in range(len(columns))) for ligne in result]
       
    47 			data.reverse()
       
    48 		else:
       
    49 			data = None
       
    50 		self.sendMessage(anyjson.serialize(data))
       
    51 		print "sending old tweets to new client"
       
    52 
       
    53 	def connectionLost(self, reason):
       
    54 		WebSocketServerProtocol.connectionLost(self, reason)
       
    55 		self.factory.unregister(self)
       
    56 
       
    57 	def onMessage(self, msg, binary):
       
    58 		print "Got message: " + msg
       
    59 
       
    60 class TweetcastServerFactory(WebSocketServerFactory):
       
    61  
       
    62    protocol = TweetcastServerProtocol
       
    63  
       
    64    def __init__(self, tweetcast=None):
       
    65       WebSocketServerFactory.__init__(self)
       
    66       self.clients = []
       
    67       self.tweetcast = tweetcast
       
    68  
       
    69    def register(self, client):
       
    70       if not client in self.clients:
       
    71          print "registered client " + client.peerstr
       
    72          self.clients.append(client)
       
    73  
       
    74    def unregister(self, client):
       
    75       if client in self.clients:
       
    76          print "unregistered client " + client.peerstr
       
    77          self.clients.remove(client)
       
    78  
       
    79    def broadcast(self, msg):
       
    80       print "broadcasting ids up to %d" % self.tweetcast.lastid
       
    81       for c in self.clients:
       
    82          print "send to " + c.peerstr
       
    83          c.sendMessage(msg)
       
    84 
       
    85 class FlashPolicySocketProtocol(Protocol):
       
    86     def dataReceived(self, data):
       
    87         self.transport.write("<?xml version=\"1.0\"?><!DOCTYPE cross-domain-policy SYSTEM \"/xml/dtds/cross-domain-policy.dtd\"><cross-domain-policy><site-control permitted-cross-domain-policies=\"master-only\"/><allow-access-from domain=\"*\" to-ports=\"*\" /></cross-domain-policy>\0") 
       
    88  
       
    89 class FlashPolicyFactory(Factory):
       
    90     def __init__(self):
       
    91         self.protocol = FlashPolicySocketProtocol;
       
    92 
       
    93 tc = TweetCast()
       
    94 reactor.listenTCP(843, FlashPolicyFactory())
       
    95 reactor.listenTCP(9000, tc.serverfactory)
       
    96 reactor.run()