# -*- Mode: Python -*-

import base64
import struct
import coro
import os
import sys
import hashlib

import coro.http

W = coro.write_stderr

def do_mask (data, mask):
    """Apply (or remove) a WebSocket masking key to *data*.

    Byte i of *data* is XORed with mask[i % 4] (RFC 6455 section 5.3).
    XOR is its own inverse, so the same call both masks and unmasks.

    data -- the payload bytes (a str on Python 2, bytes/bytearray on 3).
    mask -- the 4-byte masking key as a sequence of ints, e.g. the tuple
            produced by struct.unpack ('>BBBB', ...).
    Returns a new bytes object of the same length as *data*.
    """
    # bytearray copies the input and yields ints on both Python 2 and 3, so
    # the XOR can be done in place without any chr/ord round trips.
    r = bytearray (data)
    for i in range (len (r)):
        r[i] ^= mask[i % 4]
    return bytes (r)

class ws_packet:
    """One WebSocket frame as read off the wire.

    The fields mirror the frame header (RFC 6455 section 5.2):
    fin and mask are 0/1 flags, opcode is the 4-bit opcode, plen is the
    payload length, masking is the masking key as a tuple of 4 ints
    (None when the frame is unmasked), and payload holds the raw payload
    bytes, still masked if mask is set.
    """
    fin = 0
    opcode = 0
    mask = 0
    plen = 0
    # None rather than a shared mutable [] class default; None is also what
    # the reader stores for "no masking key".
    masking = None
    payload = ''

    def __repr__ (self):
        return '<fin=%r opcode=%r mask=%r plen=%r masking=%r payload=%d bytes>' % (
            self.fin,
            self.opcode,
            self.mask,
            self.plen,
            self.masking,
            len (self.payload),
            )

    def unpack (self):
        "Return the payload with the masking key removed (unchanged if unmasked)."
        if self.mask:
            return do_mask (self.payload, self.masking)
        else:
            return self.payload

class UnknownOpcode (Exception):
    "Raised by websocket_handler when a frame carries an unrecognized opcode."
    pass

class websocket_handler:

    """Serve RFC 6455 WebSocket connections on one URL path.

    match() claims GET requests for self.path that ask to upgrade to
    'websocket'. handle_request() performs the opening handshake and then
    reads frames from the raw connection until the client goes away.
    """

    # GUID defined by RFC 6455; appended to the client's key to compute
    # the Sec-WebSocket-Accept value.
    magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    # Frame opcodes (RFC 6455 section 5.2).
    OP_CONTINUATION = 0x0
    OP_TEXT         = 0x1
    OP_BINARY       = 0x2
    OP_CLOSE        = 0x8
    OP_PING         = 0x9
    OP_PONG         = 0xA

    def __init__ (self, path):
        # path: the exact request path this handler answers.
        self.path = path

    def match (self, request):
        "Return true if *request* is a WebSocket upgrade GET for self.path."
        return request.path == self.path and request.method == 'get' and request['upgrade'] == 'websocket'

    def handle_request (self, request):
        """Complete the opening handshake, then run the frame loop.

        Sends '101 Switching Protocols' on the request's raw connection and
        hands that connection to self.protocol(), which returns only once
        the connection is to be closed.
        """
        rh = request.request_headers
        key = rh.get_one ('sec-websocket-key')
        # NOTE(review): a request without Sec-WebSocket-Key makes the
        # concatenation below fail; replying 400 would be friendlier.
        d = hashlib.new ('sha1')
        d.update (key + self.magic)
        # b64encode adds no trailing newline (unlike the deprecated
        # encodestring, removed in Python 3.9), so no strip() is needed.
        reply = base64.b64encode (d.digest())
        r = [
            'HTTP/1.1 101 Switching Protocols',
            'Upgrade: websocket',
            'Connection: Upgrade',
            'Sec-WebSocket-Accept: %s' % (reply,),
            ]
        if rh.has_key ('sec-websocket-protocol'):
            # NOTE(review): RFC 6455 4.2.2 says the server should select ONE
            # offered subprotocol; this echoes the client's whole header value.
            r.append (
                'Sec-WebSocket-Protocol: %s' % (
                    rh.get_one ('sec-websocket-protocol')
                    )
                )
        conn = request.client.conn
        conn.send ('\r\n'.join (r) + '\r\n\r\n')
        self.protocol (conn)

    def recv_exact (self, conn, size):
        """Read up to *size* bytes from *conn*, looping over short reads.

        Returns fewer than *size* bytes (possibly none) if the peer closes
        first; callers must check the length.
        """
        left = size
        r = []
        while left:
            block = conn.recv (left)
            if not block:
                break
            r.append (block)
            left -= len (block)
        return b''.join (r)

    def _recv_field (self, conn, size):
        "Return exactly *size* bytes from *conn*; raise EOFError if the peer closed first."
        data = self.recv_exact (conn, size)
        if len (data) < size:
            raise EOFError
        return data

    def protocol (self, conn):
        "Read and dispatch frames until read_packet() says to close."
        while 1:
            close_it = self.read_packet (conn)
            if close_it:
                break

    def read_packet (self, conn):
        """Read one frame from *conn* and dispatch it.

        Returns True when the connection should be closed: EOF, a frame cut
        short by EOF, or a Close frame. Otherwise returns whatever
        handle_packet() returns for data frames, or False for control frames
        handled here.

        Raises NotImplementedError for continuation frames and UnknownOpcode
        for reserved opcodes.
        """
        try:
            p = self._read_frame (conn)
        except EOFError:
            # Peer closed, possibly in the middle of a frame.
            return True
        return self._dispatch (conn, p)

    def _read_frame (self, conn):
        "Read one complete frame and return it as a ws_packet (EOFError on short read)."
        head, = struct.unpack ('>H', self._recv_field (conn, 2))
        fin    = (head & 0x8000) >> 15
        opcode = (head & 0x0f00) >> 8
        mask   = (head & 0x0080) >> 7
        plen   = (head & 0x007f)
        # 126 and 127 are escapes for a 16-bit / 64-bit extended length.
        if plen == 126:
            plen, = struct.unpack ('>H', self._recv_field (conn, 2))
        elif plen == 127:
            plen, = struct.unpack ('>Q', self._recv_field (conn, 8))
        if mask:
            masking = struct.unpack ('>BBBB', self._recv_field (conn, 4))
        else:
            masking = None
        # NOTE(review): plen comes straight from the client and is not
        # capped, so a peer can make us buffer an arbitrarily large payload.
        payload = self._recv_field (conn, plen)
        p = ws_packet()
        p.fin = fin
        p.opcode = opcode
        p.mask = mask
        p.plen = plen
        p.masking = masking
        p.payload = payload
        return p

    def _dispatch (self, conn, p):
        "Route frame *p* by opcode; return True to close the connection."
        if p.opcode in (self.OP_TEXT, self.OP_BINARY):
            return self.handle_packet (conn, p)
        elif p.opcode == self.OP_CONTINUATION:
            raise NotImplementedError
        elif p.opcode == self.OP_CLOSE:
            # TODO: RFC 6455 5.5.1 asks us to send a Close frame in reply
            # before dropping the connection.
            return True
        elif p.opcode == self.OP_PING:
            # RFC 6455 5.5.2: answer a Ping with a Pong echoing its payload.
            self.send_frame (conn, self.OP_PONG, p.unpack())
            return False
        elif p.opcode == self.OP_PONG:
            # An unsolicited Pong needs no reply.
            return False
        else:
            raise UnknownOpcode (p)

    def handle_packet (self, conn, p):
        """Handle one text or binary data frame *p*.

        This demo ignores the payload (p.unpack() yields it) and replies
        with a fixed text message. Returns False to keep the connection open.
        """
        self.send_text (conn, "Howdy, WebSocket")
        return False

    def send_text (self, conn, data, final=True):
        """Send *data* as a single text frame.

        Non-bytes *data* is UTF-8 encoded first, so the length field counts
        bytes rather than characters. final=False sends a non-final fragment.
        """
        if not isinstance (data, bytes):
            data = data.encode ('utf-8')
        self.send_frame (conn, self.OP_TEXT, data, final)

    def send_frame (self, conn, opcode, data, final=True):
        """Write one frame with *opcode* and payload *data* to *conn*.

        Chooses the 7-bit, 16-bit or 64-bit length encoding of RFC 6455 5.2
        as the payload size requires.
        """
        b0 = opcode & 0x0f
        if final:
            b0 |= 0x80
        ld = len (data)
        if ld < 126:
            header = struct.pack ('>BB', b0, ld)
        elif ld <= 0xffff:
            header = struct.pack ('>BBH', b0, 126, ld)
        else:
            header = struct.pack ('>BBQ', b0, 127, ld)
        # RFC6455: A server MUST NOT mask any frames that it sends to the client.
        conn.writev ([header, data])

if __name__ == '__main__':
    # Demo server on 0.0.0.0:9001. It serves the WebSocket handler at /echo
    # plus the stock favicon and coro status handlers, and runs the coro
    # backdoor on the unix socket /tmp/ws.bd. (coro.http is already
    # imported at module level.)
    import coro.backdoor
    handlers = (
        coro.http.handlers.favicon_handler(),
        coro.http.handlers.coro_status_handler(),
        websocket_handler ('/echo'),
        )
    server = coro.http.server (('0.0.0.0', 9001))
    for handler in handlers:
        server.push_handler (handler)
    coro.spawn (server.start)
    coro.spawn (coro.backdoor.serve, unix_path='/tmp/ws.bd')
    coro.event_loop (30.0)
