/*******************************************************************************

    Copyright 2014 Whole Foods Co-op

    This file is part of IT CORE.

    IT CORE is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    IT CORE is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    in the file license.txt along with IT CORE; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

*********************************************************************************/

using System;
using System.Net;
using System.Net.Sockets;
using System.Collections.Generic;
using System.Threading;

namespace WebSockets
{

    /**
      Simple server for handling WebSocket connections.
      Currently push-to-clients only: messages passed to Push()
      are sent to every connected client. Data frames sent by
      clients are decoded but otherwise ignored; close and ping
      frames are answered.
    */
    public class WebSocketServer
    {
        // RFC 6455 GUID appended to Sec-WebSocket-Key before hashing
        private const string WsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        // upper bound on how much of the HTTP upgrade request is buffered
        private const int MaxHandshakeBytes = 8192;

        private TcpListener _tcp;
        private List<TcpClient> _clients;
        private Queue<string> _msg_queue;
        // static, so shared by every instance; reset by Run() before each
        // accept and set by ConnectClient() once that accept completes
        public static ManualResetEvent _connect_event;
        private int _verbose;
        // ConnectClient/ClientDataCallback run on thread-pool threads while
        // Push runs on the caller's thread. This guards _clients, _msg_queue
        // and every write to a client stream so frames never interleave.
        private readonly object _clients_lock = new object();

        /**
          Create a server that will listen on ip once Run() is called.
          @param ip local endpoint to listen on
        */
        public WebSocketServer(IPEndPoint ip)
        {
            _tcp = new TcpListener(ip);
            _clients = new List<TcpClient>();
            _msg_queue = new Queue<string>();
            _connect_event = new ManualResetEvent(false);
            _verbose = 0;
        }

        /**
          Set console logging level; any value > 0 enables logging.
          @param v verbosity level
        */
        public void SetVerbose(int v)
        {
            _verbose = v;
        }

        /**
          Main loop. Starts listening and accepts clients forever.
          Exactly one accept is outstanding at a time: the next
          BeginAcceptTcpClient is issued only after ConnectClient
          signals _connect_event.
        */
        public void Run()
        {
            // returns immediately if the listener is already active
            _tcp.Start();

            while (true) {
                _connect_event.Reset();
                _tcp.BeginAcceptTcpClient(new AsyncCallback(ConnectClient), _tcp);
                // wake once per second while waiting; re-issuing the accept
                // on every timeout would pile up pending accepts
                while (!_connect_event.WaitOne(1000, false)) {
                    /* test sending
                    if (_clients.Count > 0) {
                        string ticks = Environment.TickCount.ToString();
                        Push("{\"ticks\":\"" + ticks + "\", \"server\":\"CORE\"}");
                    }
                    */
                }
            }
        }

        /**
          Convert message to WebSocket data framing: a single
          final (FIN) text frame, unmasked as required for
          server-to-client data.
          @param msg message text, sent as UTF-8
          @return header + payload bytes
        */
        private byte[] MessageToFrames(string msg)
        {
            byte[] payload = System.Text.Encoding.UTF8.GetBytes(msg);

            // payload length encoding: 7-bit inline, 16-bit (marker 126)
            // or 64-bit (marker 127). An int length always fits the 64-bit
            // form, so a single frame is always enough.
            int header_length;
            if (payload.Length <= 125) {
                header_length = 2;
            } else if (payload.Length <= 0xffff) {
                header_length = 4;
            } else {
                header_length = 10;
            }

            byte[] resp = new byte[header_length + payload.Length];
            resp[0] = 0x81; // FIN + text opcode
            if (header_length == 2) {
                resp[1] = (byte)payload.Length;
            } else if (header_length == 4) {
                resp[1] = 126;
                resp[2] = (byte)((payload.Length >> 8) & 0xff);
                resp[3] = (byte)(payload.Length & 0xff);
            } else {
                resp[1] = 127;
                // widen to long first: an int shifted by 32+ bits wraps around
                long len = payload.Length;
                for (int i = 0; i < 8; i++) {
                    resp[2 + i] = (byte)((len >> (56 - 8 * i)) & 0xff);
                }
            }
            Array.Copy(payload, 0, resp, header_length, payload.Length);

            return resp;
        }

        /**
          Decode one client-to-server message.
          @param frames raw bytes read from the client
          @return opcode of the first frame and its unmasked payload.
            When the first frame is not final, the frames that follow
            in the same buffer are decoded recursively and their
            payloads appended.
          @throws WsFatalClientException if the data is not masked
          @throws WsProtocolException if the buffer is truncated or the
            64-bit length has its most significant bit set
        */
        private WsDataFrame FramesToMessage(byte[] frames)
        {
            if (frames.Length < 2) {
                throw new WsProtocolException("Invalid frame: too short (header)");
            }

            int opcode = frames[0] & 0xf;
            bool last_fragment = (frames[0] & 0x80) != 0;
            bool masked = (frames[1] & 0x80) != 0;
            long length = frames[1] & 0x7f;

            if (!masked) {
                throw new WsFatalClientException("Client data not masked");
            }

            // header = 2 bytes + extended length (0, 2 or 8 bytes) + 4-byte mask key
            int data_starts;
            if (length <= 125) {
                data_starts = 6;
            } else if (length == 126) {
                data_starts = 8;
            } else {
                data_starts = 14;
            }
            if (frames.Length < data_starts) {
                throw new WsProtocolException("Invalid frame: too short (mask)");
            }

            if (length == 126) {
                length = (frames[2] << 8) | frames[3];
            } else if (length == 127) {
                // accumulate in a long; int shifts by 32+ bits wrap around
                length = 0;
                for (int i = 2; i < 10; i++) {
                    length = (length << 8) | frames[i];
                }
                // the most significant bit must be 0; a negative length
                // would otherwise slip past the size check below
                if (length < 0) {
                    throw new WsProtocolException("Invalid frame: payload length out of range");
                }
            }

            if (frames.Length - data_starts < length) {
                throw new WsProtocolException("Invalid frame: too short (payload)");
            }
            // safe narrowing: bounded by frames.Length by the check above
            int payload_length = (int)length;

            byte[] mask_key = new byte[4];
            Array.Copy(frames, data_starts - 4, mask_key, 0, 4);
            byte[] payload = new byte[payload_length];
            for (int i = 0; i < payload_length; i++) {
                payload[i] = (byte)(frames[i + data_starts] ^ mask_key[i % 4]);
            }

            if (last_fragment) {
                return new WsDataFrame(opcode, payload);
            }

            // not the final fragment: continuation frame(s) follow in this buffer
            int consumed = data_starts + payload_length;
            byte[] remainder = new byte[frames.Length - consumed];
            Array.Copy(frames, consumed, remainder, 0, remainder.Length);
            WsDataFrame others = FramesToMessage(remainder);
            byte[] full_payload = new byte[payload.Length + others.payload.Length];
            Array.Copy(payload, 0, full_payload, 0, payload.Length);
            Array.Copy(others.payload, 0, full_payload, payload.Length, others.payload.Length);

            return new WsDataFrame(opcode, full_payload);
        }

        /**
          Queue message msg and send queued messages to all
          connected clients. A message remains queued until it
          has been written successfully to at least one client.
          Clients whose write fails are closed and removed.
          @param msg text message
        */
        public void Push(string msg)
        {
            lock (_clients_lock) {
                _msg_queue.Enqueue(msg);

                while (_msg_queue.Count > 0 && _clients.Count > 0) {
                    byte[] encoded = MessageToFrames(_msg_queue.Peek());
                    bool delivered = false;
                    // walk backwards so RemoveAt does not shift clients
                    // that have not been visited yet
                    for (int i = _clients.Count - 1; i >= 0; i--) {
                        try {
                            NetworkStream s = _clients[i].GetStream();
                            s.Write(encoded, 0, encoded.Length);
                            delivered = true;
                        } catch (Exception ex) {
                            if (_verbose > 0) {
                                System.Console.WriteLine(ex);
                            }
                            // disconnected?
                            CloseQuietly(_clients[i]);
                            _clients.RemoveAt(i);
                        }
                    }
                    // each pass either dequeues or empties _clients, so the loop ends
                    if (delivered) {
                        _msg_queue.Dequeue();
                    }
                }
            }
        }

        /**
          Callback when client connects.
          Validates initial HTTP header from client
          and sends appropriate response.
          Adds client to list of connected clients.
        */
        private void ConnectClient(IAsyncResult state)
        {
            TcpClient client = null;
            try {
                TcpListener server = (TcpListener)state.AsyncState;
                client = server.EndAcceptTcpClient(state);
            } catch (Exception ex) {
                // listener stopped or accept failed; an exception escaping
                // an async callback would terminate the process
                if (_verbose > 0) {
                    System.Console.WriteLine(ex);
                }
                return;
            } finally {
                // let Run() issue the next accept
                _connect_event.Set();
            }

            try {
                NetworkStream stream = client.GetStream();
                stream.ReadTimeout = 1000;
                string headers = ReadHandshake(stream);

                if (_verbose > 0) {
                    System.Console.WriteLine("Handshake: " + headers);
                }

                string resp = BuildHandshakeResponse(headers);

                if (_verbose > 0) {
                    System.Console.WriteLine("Response: " + resp);
                }

                byte[] r = System.Text.Encoding.UTF8.GetBytes(resp);
                stream.Write(r, 0, r.Length);

                lock (_clients_lock) {
                    _clients.Add(client);
                }
                MonitorClient(client);
            } catch (Exception ex) {
                // client initialization failed; release the socket
                if (_verbose > 0) {
                    System.Console.WriteLine(ex);
                }
                CloseQuietly(client);
            }
        }

        /**
          Read the client's HTTP upgrade request.
          Stops at the blank line that ends the headers, at end of
          stream, or once MaxHandshakeBytes have been buffered.
          All bytes are decoded together so a multi-byte character
          split across reads is not mangled.
          @param stream client stream (ReadTimeout should be set)
          @return request text read so far
        */
        private static string ReadHandshake(NetworkStream stream)
        {
            byte[] buffer = new byte[256];
            string headers = "";
            using (System.IO.MemoryStream received = new System.IO.MemoryStream()) {
                while (received.Length < MaxHandshakeBytes) {
                    int bytes_read = stream.Read(buffer, 0, buffer.Length);
                    if (bytes_read == 0) {
                        break;
                    }
                    received.Write(buffer, 0, bytes_read);
                    headers = System.Text.Encoding.UTF8.GetString(received.ToArray());
                    if (headers.Contains("\r\n\r\n") || headers.Contains("\n\n")) {
                        break;
                    }
                }
            }

            return headers;
        }

        /**
          Validate the client's upgrade request and build the
          "101 Switching Protocols" response.
          Header names and the Upgrade/Connection values are
          compared case-insensitively.
          @param headers raw HTTP request text
          @return response text, including the terminating blank line
          @throws WsFatalClientException if Upgrade, Connection or
            Sec-WebSocket-Key is missing or invalid
        */
        private static string BuildHandshakeResponse(string headers)
        {
            string protocol = null;
            string key = null;
            string upgrade = null;
            string connection = null;
            string version = null;
            foreach (string line in headers.Split('\n')) {
                string[] pair = line.Split(new char[]{':'}, 2);
                if (pair.Length != 2) {
                    continue;
                }
                string name = pair[0].Trim();
                string value = pair[1].Trim();
                if (string.Equals(name, "Upgrade", StringComparison.OrdinalIgnoreCase)) {
                    upgrade = value;
                } else if (string.Equals(name, "Connection", StringComparison.OrdinalIgnoreCase)) {
                    connection = value;
                } else if (string.Equals(name, "Sec-WebSocket-Version", StringComparison.OrdinalIgnoreCase)) {
                    version = value;
                } else if (string.Equals(name, "Sec-WebSocket-Protocol", StringComparison.OrdinalIgnoreCase)) {
                    protocol = value;
                } else if (string.Equals(name, "Sec-WebSocket-Key", StringComparison.OrdinalIgnoreCase)) {
                    key = value;
                }
            }

            if (!string.Equals(upgrade, "websocket", StringComparison.OrdinalIgnoreCase)) {
                throw new WsFatalClientException("Invalid header \"Upgrade\"");
            } else if (connection == null || connection.IndexOf("Upgrade", StringComparison.OrdinalIgnoreCase) < 0) {
                throw new WsFatalClientException("Invalid header \"Connection\"");
            } else if (key == null) {
                throw new WsFatalClientException("Invalid header \"Sec-WebSocket-Key\"");
            }

            byte[] hashed;
            using (System.Security.Cryptography.SHA1 sha1 = System.Security.Cryptography.SHA1.Create()) {
                hashed = sha1.ComputeHash(System.Text.Encoding.UTF8.GetBytes(key + WsMagic));
            }

            string resp = "HTTP/1.1 101 Switching Protocols\r\n"
                + "Connection: Upgrade\r\n"
                + "Upgrade: websocket\r\n"
                + "Sec-WebSocket-Accept: " + Convert.ToBase64String(hashed) + "\r\n";
            int version_number;
            if (version != null
                && int.TryParse(version, System.Globalization.NumberStyles.Integer,
                                System.Globalization.CultureInfo.InvariantCulture, out version_number)
                && version_number >= 13) {
                resp += "Sec-WebSocket-Version: 13\r\n";
            }
            if (protocol != null) {
                // the client may offer a comma-separated list; the server
                // must answer with exactly one of the offered values
                string chosen = protocol.Split(',')[0].Trim();
                if (chosen.Length > 0) {
                    resp += "Sec-WebSocket-Protocol: " + chosen + "\r\n";
                }
            }
            resp += "\r\n";

            return resp;
        }

        /**
          Do async read on client; ClientDataCallback handles the result.
          NOTE: data is decoded one read at a time, so a frame larger
          than the 512-byte buffer (or split across TCP reads) is not
          reassembled and will fail to decode.
        */
        private void MonitorClient(TcpClient client)
        {
            try {
                byte[] buffer = new byte[512];
                WsCallbackState state = new WsCallbackState(client, buffer);
                client.GetStream().BeginRead(buffer, 0, buffer.Length, new AsyncCallback(ClientDataCallback), state);
            } catch (Exception ex) {
                // could not start the read; treat the client as disconnected
                if (_verbose > 0) {
                    System.Console.WriteLine(ex);
                }
                DropClient(client);
            }
        }

        /**
          Async callback for data sent by client.
          Answers close and ping frames; other frames are ignored.
          Re-arms the read via MonitorClient unless the connection
          was closed, in which case the client is dropped.
        */
        private void ClientDataCallback(IAsyncResult state)
        {
            WsCallbackState cs = (WsCallbackState)state.AsyncState;
            try {
                NetworkStream stream = cs.client.GetStream();
                int bytes = stream.EndRead(state);
                if (bytes == 0) {
                    // zero-byte read: the client closed its end
                    DropClient(cs.client);
                    return;
                }

                byte[] frames = new byte[bytes];
                Array.Copy(cs.buffer, 0, frames, 0, bytes);
                bool closed = false;
                try {
                    // decode client message
                    WsDataFrame frame = FramesToMessage(frames);

                    if (frame.opcode == 0x8) { // close frame
                        // status code is the first two payload bytes; a reason may follow
                        int close_code = 0;
                        if (frame.payload.Length >= 2) {
                            close_code = (frame.payload[0] << 8) | frame.payload[1];
                        }
                        WriteFrame(stream, WsDataFrame.CloseFrame(close_code));
                        closed = true;
                    } else if (frame.opcode == 0x9) { // ping frame
                        WriteFrame(stream, WsDataFrame.PongFrame(frame.payload));
                    }
                } catch (WsFatalClientException ex) {
                    // client did something wrong. kill connection
                    if (_verbose > 0) {
                        System.Console.WriteLine(ex);
                    }
                    closed = true;
                } catch (Exception ex) {
                    if (_verbose > 0) {
                        System.Console.WriteLine(ex);
                    }
                }

                if (closed) {
                    DropClient(cs.client);
                } else {
                    MonitorClient(cs.client);
                }
            } catch (Exception ex) {
                // EndRead/GetStream failed; the connection is unusable
                if (_verbose > 0) {
                    System.Console.WriteLine(ex);
                }
                DropClient(cs.client);
            }
        }

        /** Write a frame, serialized with Push so frames never interleave */
        private void WriteFrame(NetworkStream stream, byte[] frame)
        {
            lock (_clients_lock) {
                stream.Write(frame, 0, frame.Length);
            }
        }

        /** Remove client from the connected list (if present) and close it */
        private void DropClient(TcpClient client)
        {
            lock (_clients_lock) {
                _clients.Remove(client);
            }
            CloseQuietly(client);
        }

        /** Close client; errors from an already-broken socket are deliberately ignored */
        private static void CloseQuietly(TcpClient client)
        {
            if (client == null) {
                return;
            }
            try {
                client.Close();
            } catch (Exception) { }
        }

        // testing stub
        public static void Main(string[] args)
        {
            WebSocketServer ws = new WebSocketServer(new IPEndPoint(IPAddress.Any, 8888));
            ws.SetVerbose(1);
            ws.Run();
        }
    }

    /**
      State object for async delegate.
      Needs access to buffer that contains
      actual bytes read but also needs
      access to the client to manage disconnects
      or additional reads.
    */
    class WsCallbackState
    {
        public TcpClient client;
        public byte[] buffer;

        public WsCallbackState(TcpClient c, byte[] b)
        {
            client = c;
            buffer = b;
        }
    }

    /**
      Decoded message: frame opcode plus unmasked payload bytes.
    */
    class WsDataFrame
    {
        public int opcode;
        public byte[] payload;

        public WsDataFrame(int o, byte[] p)
        {
            opcode = o;
            payload = p;
        }

        /**
          Factory: get bytes for a close frame.
          @param reason status code received from the client
          @return a close frame echoing reason when it is a code an
            endpoint may send (1000-1003, 1007-1011, 3000-4999);
            otherwise a close frame with an empty body
        */
        public static byte[] CloseFrame(int reason)
        {
            bool sendable = (reason >= 1000 && reason <= 1003)
                || (reason >= 1007 && reason <= 1011)
                || (reason >= 3000 && reason <= 4999);
            if (!sendable) {
                return new byte[2]{ 0x88, 0x0 };
            }

            return new byte[4]{ 0x88, 0x2, (byte)((reason >> 8) & 0xff), (byte)(reason & 0xff) };
        }

        /**
          Factory: get bytes for a pong frame echoing payload.
          Control frames carry at most 125 bytes, so a larger
          ping payload gets an empty pong instead.
        */
        public static byte[] PongFrame(byte[] payload)
        {
            if (payload.Length > 125) {
                // technically wrong. stupid client
                // sending GIGANTIC pings can decide
                // how to deal with it
                return new byte[2]{ 0x8a, 0x0 };
            }
            byte[] frame = new byte[2 + payload.Length];
            frame[0] = 0x8a;
            frame[1] = (byte)payload.Length;
            Array.Copy(payload, 0, frame, 2, payload.Length);

            return frame;
        }
    }

    /** Base class for WebSocket errors */
    class WsException : Exception
    {
        public WsException() { }
        public WsException(string message) : base(message) { }
        public WsException(string message, Exception inner) : base(message, inner) { }
    }

    /** Malformed or truncated frame data; the connection may continue */
    class WsProtocolException : WsException
    {
        public WsProtocolException() { }
        public WsProtocolException(string message) : base(message) { }
        public WsProtocolException(string message, Exception inner) : base(message, inner) { }
    }

    /** Client violated the protocol badly enough that the connection must close */
    class WsFatalClientException : WsException
    {
        public WsFatalClientException() { }
        public WsFatalClientException(string message) : base(message) { }
        public WsFatalClientException(string message, Exception inner) : base(message, inner) { }
    }
}