// // rhctl // // Copyright (C) 2009-2016 Christian Pointner // // This file is part of rhctl. // // rhctl is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // any later version. // // rhctl is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with rhctl. If not, see . // package main import ( "encoding/json" "fmt" "io" "io/ioutil" "net/http" "strings" "github.com/gorilla/websocket" ) type webSocketRequestData struct { Command string `json:"COMMAND"` Args []string `json:"ARGS"` } type webSocketResponseBaseData struct { ResponseCode int `json:"RESPONSE_CODE"` Type string `json:"TYPE"` ErrorString string `json:"ERROR_STRING"` } type webSocketResponseStateData struct { webSocketResponseBaseData State interface{} `json:"STATE"` } type webSocketResponseUpdateData struct { webSocketResponseBaseData Update interface{} `json:"UPDATE"` } func sendWebSocketResponse(ws *websocket.Conn, rd interface{}) { if err := ws.WriteJSON(rd); err != nil { rhdl.Println("Web(socket) client", ws.RemoteAddr(), "write error:", err) } } func sendWebSocketErrorResponse(ws *websocket.Conn, code int, errStr string) { rd := &webSocketResponseBaseData{} rd.ResponseCode = code rd.Type = "error" rd.ErrorString = errStr sendWebSocketResponse(ws, rd) } func sendWebSocketAckResponse(ws *websocket.Conn) { rd := &webSocketResponseBaseData{} rd.ResponseCode = http.StatusOK rd.Type = "ack" rd.ErrorString = "OK" sendWebSocketResponse(ws, rd) } func sendWebSocketStateResponse(ws *websocket.Conn, state State) { rd := &webSocketResponseStateData{} rd.ResponseCode = http.StatusOK rd.Type = "state" rd.ErrorString = "OK" rd.State = state sendWebSocketResponse(ws, rd) } func sendWebSocketServerStateResponse(ws *websocket.Conn, state ServerState) { rd := &webSocketResponseStateData{} rd.ResponseCode = http.StatusOK rd.Type = "server-state" rd.ErrorString = "OK" rd.State = state sendWebSocketResponse(ws, rd) } func sendWebSocketSwitchStateResponse(ws *websocket.Conn, state SwitchState) { rd := &webSocketResponseStateData{} rd.ResponseCode = http.StatusOK rd.Type = "switch-state" rd.ErrorString = "OK" rd.State = state sendWebSocketResponse(ws, rd) } func sendWebSocketSwitchUpdateResponse(ws *websocket.Conn, update SwitchUpdate) { rd := &webSocketResponseUpdateData{} rd.ResponseCode = http.StatusOK rd.Type = "switch-update" rd.ErrorString = "OK" rd.Update = update sendWebSocketResponse(ws, rd) } func sendWebSocketUpdateData(ws *websocket.Conn, data interface{}) { switch data.(type) { case State: sendWebSocketStateResponse(ws, data.(State)) case ServerState: sendWebSocketServerStateResponse(ws, data.(ServerState)) case SwitchState: sendWebSocketSwitchStateResponse(ws, data.(SwitchState)) case SwitchUpdate: sendWebSocketSwitchUpdateResponse(ws, data.(SwitchUpdate)) default: sendWebSocketErrorResponse(ws, http.StatusInternalServerError, "got invalid data update") } } func webSocketSessionHandler(reqchan <-chan webSocketRequestData, ws *websocket.Conn, ctrl *SwitchControl) { defer ws.Close() updateC := ctrl.Updates.Sub() defer ctrl.Updates.Unsub(updateC) for { select { case reqdata, ok := <-reqchan: if !ok { return } switch reqdata.Command { case "state": resp := make(chan interface{}) ctrl.Commands <- &Command{Type: CmdState, Response: resp} result := <-resp switch result.(type) { case State: sendWebSocketStateResponse(ws, result.(State)) case error: sendWebSocketErrorResponse(ws, http.StatusInternalServerError, result.(error).Error()) default: sendWebSocketErrorResponse(ws, http.StatusInternalServerError, fmt.Sprintf("invalid response of type %T: %+v", result, result)) } case "subscribe": if len(reqdata.Args) == 1 { switch strings.ToLower(reqdata.Args[0]) { case "state": ctrl.Updates.AddSub(updateC, "state") sendWebSocketAckResponse(ws) case "server": ctrl.Updates.AddSub(updateC, "server:state") sendWebSocketAckResponse(ws) case "switch": ctrl.Updates.AddSub(updateC, "switch:state") sendWebSocketAckResponse(ws) case "audio": fallthrough case "gpi": fallthrough case "oc": fallthrough case "relay": fallthrough case "silence": ctrl.Updates.AddSub(updateC, "switch:"+reqdata.Args[0]) sendWebSocketAckResponse(ws) default: sendWebSocketErrorResponse(ws, http.StatusInternalServerError, fmt.Sprintf("unknown message type '%s'", reqdata.Args[0])) } } else { sendWebSocketErrorResponse(ws, http.StatusInternalServerError, "subscribe takes exactly one argument") } default: sendWebSocketErrorResponse(ws, http.StatusBadRequest, fmt.Sprintf("unknown command '%s'", reqdata.Command)) } case update := <-updateC: sendWebSocketUpdateData(ws, update) } } } func webSocketHandler(ctrl *SwitchControl, w http.ResponseWriter, r *http.Request) { ws, err := websocket.Upgrade(w, r, nil, 64*1024, 64*1024) if _, ok := err.(websocket.HandshakeError); ok { http.Error(w, "Not a websocket handshake", 400) return } else if err != nil { rhdl.Println("Web(socket) client", ws.RemoteAddr(), "error:", err) return } rhdl.Println("Web(socket) client", ws.RemoteAddr(), "connected") reqchan := make(chan webSocketRequestData) go webSocketSessionHandler(reqchan, ws, ctrl) defer close(reqchan) for { t, r, err := ws.NextReader() if err != nil { rhdl.Println("Web(socket) Client", ws.RemoteAddr(), "disconnected:", err) return } switch t { case websocket.TextMessage: var reqdata webSocketRequestData if err := json.NewDecoder(r).Decode(&reqdata); err != nil { if err == io.EOF { err = io.ErrUnexpectedEOF } rhdl.Println("Web(socket) client", ws.RemoteAddr(), "request error:", err) sendWebSocketErrorResponse(ws, http.StatusBadRequest, err.Error()) return } // rhdl.Printf("Web(socket) client %s got: %+v", ws.RemoteAddr(), reqdata) reqchan <- reqdata case websocket.BinaryMessage: sendWebSocketErrorResponse(ws, http.StatusBadRequest, "binary messages are not allowed") io.Copy(ioutil.Discard, r) // consume all the data } } }