#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os

import anyjson
from twisted.enterprise import adbapi
from twisted.internet import reactor, task
from twisted.internet.protocol import Protocol, Factory
from twisted.python import log
from autobahn.websocket import WebSocketServerFactory, WebSocketServerProtocol
|
9 |
|
# PostgreSQL DSN.  It can be overridden with the TWEETCAST_DB_DSN environment
# variable so that credentials do not have to live in the source; the
# original hard-coded DSN is kept as the fallback for backward compatibility.
connectstring = os.environ.get(
    "TWEETCAST_DB_DSN",
    "dbname='tweet_live' user='postgres' host='localhost' password='doiteshimashite'",
)

# Keys used for each tweet dict sent to clients; must stay in the same order
# as the column list of the SELECT in selectcommon.
columns = ['id', 'created_at', 'text', 'user_id', 'screen_name', 'profile_image_url']

# Shared SELECT (without WHERE/ORDER BY) joining each tweet with its author.
selectcommon = (
    "SELECT tweet_tweet.id, tweet_tweet.created_at, text, user_id, "
    "screen_name, profile_image_url "
    "FROM tweet_tweet JOIN tweet_user ON tweet_tweet.user_id = tweet_user.id"
)

# Twisted connection pool running psycopg2 queries in worker threads.
dbpool = adbapi.ConnectionPool("psycopg2", connectstring)
|
15 |
|
class TweetCast:
    """Poll the tweet tables and push newly inserted tweets to websocket clients.

    On construction it creates the websocket server factory and asks the
    database for the current highest tweet id.  Once that answer arrives,
    a polling loop runs once per second and broadcasts every tweet whose id
    is greater than the last one seen.
    """

    def __init__(self):
        # Highest tweet id already broadcast; stays 0 until callbackInit runs.
        self.lastid = 0
        self.serverfactory = TweetcastServerFactory(tweetcast=self)
        d = dbpool.runQuery("SELECT MAX(tweet_tweet.id) FROM tweet_tweet")
        d.addCallback(self.callbackInit)
        # Without this a failing startup query would be an unhandled error
        # and the polling loop would silently never start.
        d.addErrback(log.err)

    def callbackInit(self, result):
        """Record the starting tweet id and start the one-second polling loop.

        :param result: rows returned by the ``MAX(id)`` query.
        """
        maxid = result[0][0] if result else None
        # MAX() yields NULL (None) on an empty table; fall back to 0 so the
        # id comparison in scheduleTweets keeps working.
        self.lastid = maxid if maxid is not None else 0
        self.loop = task.LoopingCall(self.scheduleTweets)
        # The Deferred from start() errbacks if the loop dies; log it.
        self.loop.start(1).addErrback(log.err)

    def scheduleTweets(self):
        """Query tweets newer than ``self.lastid`` and broadcast them.

        Returns the query Deferred so that LoopingCall waits for it before
        scheduling the next tick; otherwise a slow query could overlap the
        next one and the same tweets would be broadcast twice.
        """
        query = selectcommon + " WHERE tweet_tweet.id > %s ORDER BY tweet_tweet.id ASC"
        # Let psycopg2 substitute the id instead of formatting it into SQL.
        d = dbpool.runQuery(query, (self.lastid,))
        d.addCallback(self.callbackTweets)
        # Log and swallow failures: a failed Deferred returned to LoopingCall
        # would stop the loop permanently.
        d.addErrback(log.err)
        return d

    def callbackTweets(self, result):
        """Advance ``lastid`` and broadcast the new rows as JSON.

        :param result: rows ordered by ascending tweet id, each in the
            column order of ``columns``.  When empty, JSON ``null`` is
            broadcast.
        """
        if result:
            # Rows are ordered ascending, so the last one carries the max id.
            self.lastid = result[-1][0]
            data = [dict(zip(columns, [str(value) for value in row])) for row in result]
        else:
            data = None
        self.serverfactory.broadcast(anyjson.serialize(data))
|
37 |
|
class TweetcastServerProtocol(WebSocketServerProtocol):
    """One websocket client connection.

    On open it registers with the factory (to receive broadcasts) and is
    sent up to 100 of the most recent tweets already known.
    """

    def onOpen(self):
        """Register for broadcasts and query the recent-tweet backlog."""
        self.factory.register(self)
        query = (selectcommon +
                 " WHERE tweet_tweet.id <= %s ORDER BY tweet_tweet.id DESC LIMIT 100")
        # Parameterized: psycopg2 substitutes the id.  If lastid were None
        # this yields no rows instead of raising on "%d" formatting.
        d = dbpool.runQuery(query, (self.factory.tweetcast.lastid,))
        d.addCallback(self.callbackOldTweets)
        d.addErrback(log.err)

    def callbackOldTweets(self, result):
        """Send the backlog to this client as JSON, oldest tweet first.

        :param result: rows ordered by descending tweet id; JSON ``null``
            is sent when empty.
        """
        if result:
            # The query returns newest first; clients get them oldest first.
            data = [dict(zip(columns, [str(value) for value in row]))
                    for row in reversed(result)]
        else:
            data = None
        self.sendMessage(anyjson.serialize(data))
        print("sending old tweets to new client")

    def connectionLost(self, reason):
        """Stop receiving broadcasts once the transport is gone."""
        WebSocketServerProtocol.connectionLost(self, reason)
        self.factory.unregister(self)

    def onMessage(self, msg, binary):
        """Incoming client messages are only logged."""
        print("Got message: " + msg)
|
59 |
|
class TweetcastServerFactory(WebSocketServerFactory):
    """Websocket factory that tracks connected clients and fans out messages."""

    protocol = TweetcastServerProtocol

    def __init__(self, tweetcast=None):
        """:param tweetcast: owning TweetCast; its ``lastid`` is read by
        connecting protocols and by ``broadcast`` for logging."""
        WebSocketServerFactory.__init__(self)
        self.clients = []
        self.tweetcast = tweetcast

    def register(self, client):
        """Add ``client`` to the broadcast list (no-op if already present)."""
        if client not in self.clients:
            print("registered client " + client.peerstr)
            self.clients.append(client)

    def unregister(self, client):
        """Remove ``client`` from the broadcast list (no-op if absent)."""
        if client in self.clients:
            print("unregistered client " + client.peerstr)
            self.clients.remove(client)

    def broadcast(self, msg):
        """Send ``msg`` to every registered client."""
        lastid = self.tweetcast.lastid if self.tweetcast is not None else None
        # %s prints ints exactly like %d but does not crash on None.
        print("broadcasting ids up to %s" % (lastid,))
        # Iterate over a snapshot: sending may drop a connection, which calls
        # unregister() and would otherwise mutate the list mid-iteration
        # and skip the next client.
        for client in list(self.clients):
            print("send to " + client.peerstr)
            client.sendMessage(msg)
|
84 |
|
class FlashPolicySocketProtocol(Protocol):
    """Reply to any received data with a permissive Flash cross-domain policy."""

    # Policy document written back verbatim, including the trailing NUL
    # byte ('\0') at the end of the string.
    POLICY = (
        '<?xml version="1.0"?>'
        '<!DOCTYPE cross-domain-policy SYSTEM "/xml/dtds/cross-domain-policy.dtd">'
        '<cross-domain-policy>'
        '<site-control permitted-cross-domain-policies="master-only"/>'
        '<allow-access-from domain="*" to-ports="*" />'
        '</cross-domain-policy>'
        '\0'
    )

    def dataReceived(self, data):
        # The request content is ignored; every chunk gets the full policy.
        self.transport.write(self.POLICY)
|
88 |
|
class FlashPolicyFactory(Factory):
    """Twisted factory whose connections are handled by FlashPolicySocketProtocol."""

    # Declared as a class attribute (the usual Twisted idiom) rather than in
    # an __init__ that bypassed the base-class initializer.
    protocol = FlashPolicySocketProtocol
|
92 |
|
# Script entry: build the poller (which owns the websocket factory), expose
# the Flash policy server on 843 and the websocket server on 9000, then hand
# control to the Twisted reactor.
tc = TweetCast()
policy_factory = FlashPolicyFactory()
reactor.listenTCP(843, policy_factory)
reactor.listenTCP(9000, tc.serverfactory)
reactor.run()
|