module websocket import net import net.urllib import encoding.base64 import eventbus import sync import net.websocket.logger const ( l = logger.new("ws") ) pub struct Client { retry int eb &eventbus.EventBus is_ssl bool lock sync.Mutex = sync.new_mutex() write_lock sync.Mutex = sync.new_mutex() //subprotocol_len int //cwebsocket_subprotocol *subprotocol; //cwebsocket_subprotocol *subprotocols[]; mut: state State socket net.Socket flags []Flag sslctx &C.SSL_CTX ssl &C.SSL fragments []Fragment pub mut: uri string subscriber &eventbus.Subscriber } struct Fragment { data voidptr len u64 code OPCode } pub struct Message { pub: opcode OPCode payload voidptr payload_len int } pub enum OPCode { continuation = 0x00 text_frame = 0x01 binary_frame = 0x02 close = 0x08 ping = 0x09 pong = 0x0A } enum State { connecting = 0 connected open closing closed } struct Uri { mut: hostname string port string resource string querystring string } enum Flag { has_accept has_connection has_upgrade } struct Frame { mut: fin bool rsv1 bool rsv2 bool rsv3 bool opcode OPCode mask bool payload_len u64 masking_key [4]byte } pub fn new(uri string) &Client { eb := eventbus.new() ws := &Client{ uri: uri state: .closed eb: eb, subscriber: eb.subscriber is_ssl: uri.starts_with("wss") ssl: C.NULL sslctx: C.NULL } return ws } fn C.sscanf() int fn (ws &Client) parse_uri() &Uri { u := urllib.parse(ws.uri) or {panic(err)} v := u.request_uri().split("?") querystring := if v.len > 1 {"?" + v[1]} else {""} return &Uri { hostname: u.hostname() port: u.port() resource: v[0] querystring: querystring } } pub fn (mut ws Client) connect() int { match ws.state { .connected { l.f("connect: websocket already connected") } .connecting { l.f("connect: websocket already connecting") } .open { l.f("connect: websocket already open") } else { // do nothing } } ws.lock.lock() ws.state = .connecting ws.lock.unlock() uri := ws.parse_uri() nonce := get_nonce() seckey := base64.encode(nonce) ai_family := C.AF_INET ai_socktype := C.SOCK_STREAM l.d("handshake header:") handshake := "GET ${uri.resource}${uri.querystring} HTTP/1.1\r\nHost: ${uri.hostname}:${uri.port}\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: ${seckey}\r\nSec-WebSocket-Version: 13\r\n\r\n" l.d(handshake) socket := net.new_socket(ai_family, ai_socktype, 0) or { l.f(err) return -1 } ws.socket = socket ws.socket.connect(uri.hostname, uri.port.int()) or { l.f(err) return -1 } optval := 1 ws.socket.setsockopt(C.SOL_SOCKET, C.SO_KEEPALIVE, &optval) or { l.f(err) return -1 } if ws.is_ssl { ws.connect_ssl() } ws.lock.lock() ws.state = .connected ws.lock.unlock() res := ws.write_to_server(handshake.str, handshake.len) if res <= 0 { l.f("Handshake failed.") } ws.read_handshake(seckey) ws.lock.lock() ws.state = .open ws.lock.unlock() ws.send_open_event() unsafe { handshake.free() nonce.free() free(uri) } return 0 } pub fn (mut ws Client) close(code int, message string){ if ws.state != .closed && ws.socket.sockfd > 1 { ws.lock.lock() ws.state = .closing ws.lock.unlock() mut code32 := 0 if code > 0 { _code := C.htons(code) message_len := message.len + 2 mut close_frame := [`0`].repeat(message_len) close_frame[0] = _code & 0xFF close_frame[1] = (_code >> 8) code32 = (close_frame[0] << 8) + close_frame[1] for i in 0..message.len { close_frame[i+2] = message[i] } ws.send_control_frame(.close, "CLOSE", close_frame) } else { ws.send_control_frame(.close, "CLOSE", []) } if ws.ssl != C.NULL { C.SSL_shutdown(ws.ssl) C.SSL_free(ws.ssl) if ws.sslctx != C.NULL { C.SSL_CTX_free(ws.sslctx) } } else { if C.shutdown(ws.socket.sockfd, C.SHUT_WR) == -1 { l.e("Unabled to shutdown websocket.") } mut buf := [`0`] for ws.read_from_server(buf.data, 1) > 0 { buf[0] = `\0` } unsafe { buf.free() } if C.close(ws.socket.sockfd) == -1 { //ws.send_close_event()(websocket, 1011, strerror(errno)); } } ws.fragments = [] ws.send_close_event() ws.lock.lock() ws.state = .closed ws.lock.unlock() unsafe { } //TODO impl autoreconnect } } pub fn (mut ws Client) write(payload byteptr, payload_len int, code OPCode) int { if ws.state != .open { ws.send_error_event("WebSocket closed. Cannot write.") goto free_data return -1 } header_len := 6 + if payload_len > 125 {2} else {0} + if payload_len > 0xffff {6} else {0} masking_key := create_masking_key() mut header := [`0`].repeat(header_len) mut bytes_written := -1 header[0] = (int(code) | 0x80) if payload_len <= 125 { header[1] = (payload_len | 0x80) header[2] = masking_key[0] header[3] = masking_key[1] header[4] = masking_key[2] header[5] = masking_key[3] } else if payload_len > 125 && payload_len <= 0xffff { len16 := C.htons(payload_len) header[1] = (126 | 0x80) C.memcpy(header.data+2, &len16, 2) header[4] = masking_key[0] header[5] = masking_key[1] header[6] = masking_key[2] header[7] = masking_key[3] } else if payload_len > 0xffff && payload_len <= 0xffffffffffffffff { // 65535 && 18446744073709551615 len64 := htonl64(u64(payload_len)) header[1] = (127 | 0x80) C.memcpy(header.data+2, len64, 8) header[10] = masking_key[0] header[11] = masking_key[1] header[12] = masking_key[2] header[13] = masking_key[3] } else { l.c("write: frame too large") ws.close(1009, "frame too large") goto free_data return -1 } frame_len := header_len + payload_len mut frame_buf := [`0`].repeat(frame_len) C.memcpy(frame_buf.data, header.data, header_len) C.memcpy(&frame_buf.data[header_len], payload, payload_len) for i in 0..payload_len { frame_buf[header_len+i] ^= masking_key[i % 4] & 0xff } bytes_written = ws.write_to_server(frame_buf.data, frame_len) if bytes_written == -1 { err := string(byteptr(C.strerror(C.errno))) l.e("write: there was an error writing data: ${err}") ws.send_error_event("Error writing data") goto free_data return -1 } l.d("write: ${bytes_written} bytes written.") free_data: unsafe { free(payload) frame_buf.free() header.free() masking_key.free() } return bytes_written } pub fn (mut ws Client) listen() { l.i("Starting listener...") for ws.state == .open { ws.read() } l.i("Listener stopped as websocket was closed.") } pub fn (mut ws Client) read() int { mut bytes_read := u64(0) initial_buffer := u64(256) mut header_len := 2 header_len_offset := 2 extended_payload16_end_byte := 4 extended_payload64_end_byte := 10 mut payload_len := u64(0) mut data := C.calloc(initial_buffer, 1)//[`0`].repeat(int(max_buffer)) mut frame := Frame{} mut frame_size := u64(header_len) for bytes_read < frame_size && ws.state == .open { byt := ws.read_from_server(data + int(bytes_read), 1) match byt { 0 { error := "server closed the connection." l.e("read: ${error}") ws.send_error_event(error) ws.close(1006, error) goto free_data return -1 } -1 { err := string(byteptr(C.strerror(C.errno))) l.e("read: error reading frame. ${err}") ws.send_error_event("error reading frame") goto free_data return -1 } else { bytes_read++ } } if bytes_read == u64(header_len_offset) { frame.fin = (data[0] & 0x80) == 0x80 frame.rsv1 = (data[0] & 0x40) == 0x40 frame.rsv2 = (data[0] & 0x20) == 0x20 frame.rsv3 = (data[0] & 0x10) == 0x10 frame.opcode = OPCode(int(data[0] & 0x7F)) frame.mask = (data[1] & 0x80) == 0x80 frame.payload_len = u64(data[1] & 0x7F) //masking key if frame.mask { frame.masking_key[0] = data[2] frame.masking_key[1] = data[3] frame.masking_key[2] = data[4] frame.masking_key[3] = data[5] } payload_len = frame.payload_len frame_size = u64(header_len) + payload_len } if frame.payload_len == u64(126) && bytes_read == u64(extended_payload16_end_byte) { header_len += 2 mut extended_payload_len := 0 extended_payload_len |= data[2] << 8 extended_payload_len |= data[3] << 0 //masking key if frame.mask { frame.masking_key[0] = data[4] frame.masking_key[1] = data[5] frame.masking_key[2] = data[6] frame.masking_key[3] = data[7] } payload_len = u64(extended_payload_len) frame_size = u64(header_len) + payload_len if frame_size > initial_buffer { l.d("reallocating: ${frame_size}") data = C.realloc(data, frame_size) } } else if frame.payload_len == u64(127) && bytes_read == u64(extended_payload64_end_byte) { header_len += 8 //TODO Not sure... mut extended_payload_len := u64(0) extended_payload_len |= u64(data[2]) << 56 extended_payload_len |= u64(data[3]) << 48 extended_payload_len |= u64(data[4]) << 40 extended_payload_len |= u64(data[5]) << 32 extended_payload_len |= u64(data[6]) << 24 extended_payload_len |= u64(data[7]) << 16 extended_payload_len |= u64(data[8]) << 8 extended_payload_len |= u64(data[9]) << 0 //masking key if frame.mask { frame.masking_key[0] = data[10] frame.masking_key[1] = data[11] frame.masking_key[2] = data[12] frame.masking_key[3] = data[13] } payload_len = extended_payload_len frame_size = u64(header_len) + payload_len if frame_size > initial_buffer { l.d("reallocating: ${frame_size}") data = C.realloc(data, frame_size) } } } // unmask the payload if frame.mask { for i in 0..payload_len { data[header_len+i] ^= frame.masking_key[i % 4] & 0xff } } if ws.fragments.len > 0 && frame.opcode in [.text_frame, .binary_frame] { ws.close(0, "") goto free_data return -1 } else if frame.opcode in [.text_frame, .binary_frame] { data_node: l.d("read: recieved text_frame or binary_frame") mut payload := malloc(sizeof(byte) * int(payload_len) + 1) if payload == C.NULL { l.f("out of memory") } C.memcpy(payload, &data[header_len], payload_len) if frame.fin { if ws.fragments.len > 0 { //join fragments ws.fragments << Fragment{ data: payload len: payload_len } mut frags := []Fragment{} mut size := u64(0) for f in ws.fragments { if f.len > 0 { frags << f size += f.len } } mut pl := malloc(sizeof(byte) * int(size)) if pl == C.NULL { l.f("out of memory") } mut by := 0 for i, f in frags { C.memcpy(pl + by, f.data, f.len) by += int(f.len) unsafe {free(f.data)} } payload = pl frame.opcode = ws.fragments[0].code payload_len = size //clear the fragments unsafe { ws.fragments.free() } ws.fragments = [] } payload[payload_len] = `\0` if frame.opcode == .text_frame && payload_len > 0 { if !utf8_validate(payload, int(payload_len)) { l.e("malformed utf8 payload") ws.send_error_event("Recieved malformed utf8.") ws.close(1007, "malformed utf8 payload") goto free_data return -1 } } message := Message { opcode: frame.opcode payload: payload payload_len: int(payload_len) } ws.send_message_event(message) } else { //fragment start. ws.fragments << Fragment{ data: payload len: payload_len code: frame.opcode } } unsafe { free(data) } return int(bytes_read) } else if frame.opcode == .continuation { l.d("read: continuation") if ws.fragments.len <= 0 { l.e("Nothing to continue.") ws.close(1002, "nothing to continue") goto free_data return -1 } goto data_node return 0 } else if frame.opcode == .ping { l.d("read: ping") if !frame.fin { ws.close(1002, "control message must not be fragmented") goto free_data return -1 } if frame.payload_len > 125 { ws.close(1002, "control frames must not exceed 125 bytes") goto free_data return -1 } mut payload := []byte{} if payload_len > 0 { payload = [`0`].repeat(int(payload_len)) C.memcpy(payload.data, &data[header_len], payload_len) } unsafe { free(data) } return ws.send_control_frame(.pong, "PONG", payload) } else if frame.opcode == .pong { if !frame.fin { ws.close(1002, "control message must not be fragmented") goto free_data return -1 } unsafe { free(data) } //got pong return 0 } else if frame.opcode == .close { l.d("read: close") if frame.payload_len > 125 { ws.close(1002, "control frames must not exceed 125 bytes") goto free_data return -1 } mut code := 0 mut reason := "" if payload_len > 2 { code = (int(data[header_len]) << 8) + int(data[header_len+1]) header_len += 2 payload_len -= 2 reason = string(&data[header_len]) l.i("Closing with reason: ${reason} & code: ${code}") if reason.len > 1 && !utf8_validate(reason.str, reason.len) { l.e("malformed utf8 payload") ws.send_error_event("Recieved malformed utf8.") ws.close(1007, "malformed utf8 payload") goto free_data return -1 } } unsafe{ free(data) } ws.close(code, reason) return 0 } l.e("read: Recieved unsupported opcode: ${frame.opcode} fin: ${frame.fin} uri: ${ws.uri}") ws.send_error_event("Recieved unsupported opcode: ${frame.opcode}") ws.close(1002, "Unsupported opcode") free_data: unsafe { free(data) } return -1 } fn (mut ws Client) send_control_frame(code OPCode, frame_typ string, payload []byte) int { if ws.socket.sockfd <= 0 { l.e("No socket opened.") goto free_data return -1 } header_len := 6 frame_len := header_len + payload.len mut control_frame := [`0`].repeat(frame_len) masking_key := create_masking_key() control_frame[0] = (int(code) | 0x80) control_frame[1] = (payload.len | 0x80) control_frame[2] = masking_key[0] control_frame[3] = masking_key[1] control_frame[4] = masking_key[2] control_frame[5] = masking_key[3] if code == .close { close_code := 1000 if payload.len > 2 { mut parsed_payload := [`0`].repeat(payload.len + 1) C.memcpy(parsed_payload.data, &payload[0], payload.len) parsed_payload[payload.len] = `\0` for i in 0..payload.len { control_frame[6+i] = (parsed_payload[i] ^ masking_key[i % 4]) & 0xff } unsafe { parsed_payload.free() } } } else { for i in 0..payload.len { control_frame[header_len + i] = (payload[i] ^ masking_key[i % 4]) & 0xff } } mut bytes_written := -1 bytes_written = ws.write_to_server(control_frame.data, frame_len) free_data: unsafe { control_frame.free() payload.free() masking_key.free() } match bytes_written { 0 { l.d("send_control_frame: remote host closed the connection.") return 0 } -1 { l.c("send_control_frame: error sending ${frame_typ} control frame.") return -1 } else { l.d("send_control_frame: wrote ${bytes_written} byte ${frame_typ} frame.") return bytes_written } } }