diff options
Diffstat (limited to 'src/rhrdtime')
-rw-r--r-- | src/rhrdtime/rhrdtime.go | 51 |
1 files changed, 46 insertions, 5 deletions
diff --git a/src/rhrdtime/rhrdtime.go b/src/rhrdtime/rhrdtime.go index 4bdf680..d4d24b3 100644 --- a/src/rhrdtime/rhrdtime.go +++ b/src/rhrdtime/rhrdtime.go @@ -28,10 +28,13 @@ import ( "encoding/json" "flag" "fmt" + "net" "net/http" - _ "net/http/pprof" + "os" + "sync" "time" + "github.com/coreos/go-systemd/activation" "github.com/gorilla/websocket" ) @@ -63,9 +66,11 @@ type ntpMessage struct { func handleNTPClient(w http.ResponseWriter, r *http.Request) { ws, err := websocket.Upgrade(w, r, nil, 1024, 1024) if _, ok := err.(websocket.HandshakeError); ok { + fmt.Println("Invalid request from", r.RemoteAddr) http.Error(w, "Not a websocket handshake", 400) return } else if err != nil { + fmt.Println("Invalid request from", r.RemoteAddr) fmt.Println(err) return } @@ -82,7 +87,7 @@ func handleNTPClient(w http.ResponseWriter, r *http.Request) { var n ntpMessage if err := json.Unmarshal(msg, &n); err != nil { - fmt.Println("Ignoring malformed (", err, ") NTP message from:", ws.RemoteAddr()) + fmt.Println("Ignoring malformed (", err, ") NTP message from", ws.RemoteAddr()) continue } @@ -106,7 +111,7 @@ func handleNTPClient(w http.ResponseWriter, r *http.Request) { } func main() { - addr_s := flag.String("addr", ":3000", "addr:port to listen on, default: ':3000'") + addr := flag.String("addr", "", "addr:port to listen on, default: use systemd socket activation or ':3000' when no sockets are passed") help := flag.Bool("help", false, "show usage") flag.Parse() @@ -115,7 +120,43 @@ func main() { return } - http.HandleFunc("/ntp", handleNTPClient) + var listeners []net.Listener + if *addr == "" { + lns, err := activation.Listeners(true) + if err != nil { + fmt.Println("error while getting socket from systemd:", err) + os.Exit(1) + } + listeners = lns + } - http.ListenAndServe(*addr_s, nil) + if len(listeners) == 0 { + if *addr == "" { + *addr = ":3000" + } + ln, err := net.Listen("tcp", *addr) + if err != nil { + fmt.Println("error listening on socket:", err) + os.Exit(1) + } + listeners = append(listeners, ln) + } + + for _, ln := range listeners { + fmt.Println("listening on:", ln.Addr().String()) + } + + var wg sync.WaitGroup + for _, l := range listeners { + ln := l + wg.Add(1) + go func() { + defer wg.Done() + mux := http.NewServeMux() + mux.HandleFunc("/ntp", handleNTPClient) + server := &http.Server{Handler: mux, ReadTimeout: 12 * time.Hour, WriteTimeout: 12 * time.Hour} + server.Serve(ln) + }() + } + wg.Wait() } |