diff options
-rw-r--r-- | rhimport/conf.go | 51 | ||||
-rw-r--r-- | rhimport/core.go | 192 | ||||
-rw-r--r-- | rhimport/fetcher.go | 294 | ||||
-rw-r--r-- | rhimport/importer.go | 528 | ||||
-rw-r--r-- | rhimport/rdxport_responses.go | 150 | ||||
-rw-r--r-- | rhimport/session.go | 328 | ||||
-rw-r--r-- | rhimport/session_store.go | 310 |
7 files changed, 1853 insertions, 0 deletions
diff --git a/rhimport/conf.go b/rhimport/conf.go new file mode 100644 index 0000000..b7e8f88 --- /dev/null +++ b/rhimport/conf.go @@ -0,0 +1,51 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +type ImportParamDefaults struct { + Channels uint + NormalizationLevel int + AutotrimLevel int + UseMetaData bool +} + +type Config struct { + RDXportEndpoint string + TempDir string + LocalFetchDir string + ImportParamDefaults +} + +func NewConfig(rdxportEndpoint, tempDir, localFetchDir string) (conf *Config) { + conf = new(Config) + conf.RDXportEndpoint = rdxportEndpoint + conf.TempDir = tempDir + conf.LocalFetchDir = localFetchDir + conf.ImportParamDefaults.Channels = 2 + conf.ImportParamDefaults.NormalizationLevel = -12 + conf.ImportParamDefaults.AutotrimLevel = 0 + conf.ImportParamDefaults.UseMetaData = true + return +} diff --git a/rhimport/core.go b/rhimport/core.go new file mode 100644 index 0000000..98acd35 --- /dev/null +++ b/rhimport/core.go @@ -0,0 +1,192 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "fmt" + "github.com/andelf/go-curl" + "helsinki.at/rhrd-go/rddb" + "io/ioutil" + "log" + "os" +) + +const ( + CART_MAX = 999999 + CUT_MAX = 999 +) + +var ( + bool2str = map[bool]string{false: "0", true: "1"} + rhl = log.New(os.Stderr, "[rhimport]\t", log.LstdFlags) + rhdl = log.New(ioutil.Discard, "[rhimport-dbg]\t", log.LstdFlags) +) + +func init() { + if _, exists := os.LookupEnv("RHIMPORT_DEBUG"); exists { + rhdl.SetOutput(os.Stderr) + } + curl.GlobalInit(curl.GLOBAL_ALL) + fetcherInit() +} + +type ProgressCB func(step int, stepName string, progress float64, userdata interface{}) bool +type DoneCB func(Result, interface{}) bool + +type Result struct { + ResponseCode int + ErrorString string + Cart uint + Cut uint +} + +type Context struct { + conf *Config + db *rddb.DBChan + UserName string + Password string + Trusted bool + ShowId uint + ClearShowCarts bool + GroupName string + Cart uint + ClearCart bool + Cut uint + Channels uint + NormalizationLevel int + AutotrimLevel int + UseMetaData bool + SourceUri string + SourceFile string + DeleteSourceFile bool + DeleteSourceDir bool + ProgressCallBack ProgressCB + ProgressCallBackData interface{} + Cancel <-chan bool +} + +func NewContext(conf *Config, db *rddb.DBChan) *Context { + ctx := new(Context) + ctx.conf = conf + ctx.db = db + ctx.UserName = "" + ctx.Password = "" + ctx.Trusted = false + ctx.ShowId = 0 + ctx.ClearShowCarts = false + ctx.GroupName = "" + ctx.Cart = 0 + ctx.ClearCart = false + ctx.Cut = 0 + ctx.Channels = conf.ImportParamDefaults.Channels + ctx.NormalizationLevel = conf.ImportParamDefaults.NormalizationLevel + ctx.AutotrimLevel = conf.ImportParamDefaults.AutotrimLevel + ctx.UseMetaData = conf.ImportParamDefaults.UseMetaData + ctx.SourceFile = "" + ctx.DeleteSourceFile = false + ctx.DeleteSourceDir = false + ctx.ProgressCallBack = nil + ctx.Cancel = nil + + return ctx +} + +func (ctx *Context) SanityCheck() error { + if ctx.UserName == "" { + return fmt.Errorf("empty Username is not allowed") + } + if ctx.Password == "" && !ctx.Trusted { + return fmt.Errorf("empty Password on untrusted control interface is not allowed") + } + if ctx.ShowId != 0 { + if ctx.ShowId != 0 && ctx.ShowId > CART_MAX { + return fmt.Errorf("ShowId %d is outside of allowed range (0 < show-id < %d)", ctx.ShowId, CART_MAX) + } + if ctx.Cart != 0 && ctx.Cart > CART_MAX { + return fmt.Errorf("Cart %d is outside of allowed range (0 < cart < %d)", ctx.Cart, CART_MAX) + } + return nil + } + if ctx.GroupName != "" { + ismusic, err := ctx.checkMusicGroup() + if err != nil { + return err + } + if !ismusic { + return fmt.Errorf("supplied GroupName '%s' is not a music pool", ctx.GroupName) + } + if ctx.Cart != 0 || ctx.Cut != 0 { + return fmt.Errorf("Cart and Cut must not be supplied when importing into a music group") + } + return nil + } + if ctx.Cart == 0 { + return fmt.Errorf("either ShowId, PoolName or CartNumber must be supplied") + } + if ctx.Cart > CART_MAX { + return fmt.Errorf("Cart %d is outside of allowed range (0 < cart < %d)", ctx.Cart, CART_MAX) + } + if ctx.Cut != 0 && ctx.Cut > CUT_MAX { + return fmt.Errorf("Cut %d is outside of allowed range (0 < cart < %d)", ctx.Cut, CUT_MAX) + } + if ctx.Channels != 1 && ctx.Channels != 2 { + return fmt.Errorf("channles must be 1 or 2") + } + return nil +} + +func (ctx *Context) getPassword(cached bool) (err error) { + ctx.Password, err = ctx.db.GetPassword(ctx.UserName, cached) + return +} + +func (ctx *Context) CheckPassword() (bool, error) { + return ctx.db.CheckPassword(ctx.UserName, ctx.Password) +} + +func (ctx *Context) getGroupOfCart() (err error) { + ctx.GroupName, err = ctx.db.GetGroupOfCart(ctx.Cart) + return +} + +func (ctx *Context) getShowInfo() (carts []uint, err error) { + ctx.GroupName, ctx.NormalizationLevel, ctx.AutotrimLevel, carts, err = ctx.db.GetShowInfo(ctx.ShowId) + ctx.Channels = 2 + ctx.UseMetaData = true + return +} + +func (ctx *Context) checkMusicGroup() (bool, error) { + return ctx.db.CheckMusicGroup(ctx.GroupName) +} + +func (ctx *Context) getMusicInfo() (err error) { + ctx.NormalizationLevel, ctx.AutotrimLevel, err = ctx.db.GetMusicInfo(ctx.GroupName) + ctx.Channels = 2 + ctx.UseMetaData = true + ctx.Cart = 0 + ctx.Cut = 0 + return +} diff --git a/rhimport/fetcher.go b/rhimport/fetcher.go new file mode 100644 index 0000000..2d99be4 --- /dev/null +++ b/rhimport/fetcher.go @@ -0,0 +1,294 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "fmt" + "github.com/andelf/go-curl" + "io/ioutil" + "mime" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" +) + +type FetcherCurlCBData struct { + basepath string + filename string + remotename string + *os.File +} + +func (self *FetcherCurlCBData) Cleanup() { + if self.File != nil { + self.File.Close() + } +} + +func curlHeaderCallback(ptr []byte, userdata interface{}) bool { + hdr := fmt.Sprintf("%s", ptr) + data := userdata.(*FetcherCurlCBData) + + if strings.HasPrefix(hdr, "Content-Disposition:") { + if mediatype, params, err := mime.ParseMediaType(strings.TrimPrefix(hdr, "Content-Disposition:")); err == nil { + if mediatype == "attachment" { + data.filename = data.basepath + "/" + params["filename"] + } + } + } + return true +} + +func curlWriteCallback(ptr []byte, userdata interface{}) bool { + data := userdata.(*FetcherCurlCBData) + if data.File == nil { + if data.filename == "" { + data.filename = data.basepath + "/" + data.remotename + } + fp, err := os.OpenFile(data.filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0600) + if err != nil { + rhl.Printf("Unable to create file %s: %s", data.filename, err) + return false + } + data.File = fp + } + if _, err := data.File.Write(ptr); err != nil { + rhl.Printf("Unable to write file %s: %s", data.filename, err) + return false + } + return true +} + +func fetchFileCurl(ctx *Context, res *Result, uri *url.URL) (err error) { + rhl.Printf("curl-based fetcher called for '%s'", ctx.SourceUri) + + easy := curl.EasyInit() + if easy != nil { + defer easy.Cleanup() + + easy.Setopt(curl.OPT_FOLLOWLOCATION, true) + easy.Setopt(curl.OPT_URL, ctx.SourceUri) + + cbdata := &FetcherCurlCBData{remotename: path.Base(uri.Path)} + defer cbdata.Cleanup() + if cbdata.basepath, err = ioutil.TempDir(ctx.conf.TempDir, "rhimportd-"); err != nil { + return + } + + easy.Setopt(curl.OPT_HEADERFUNCTION, curlHeaderCallback) + easy.Setopt(curl.OPT_HEADERDATA, cbdata) + + easy.Setopt(curl.OPT_WRITEFUNCTION, curlWriteCallback) + easy.Setopt(curl.OPT_WRITEDATA, cbdata) + + easy.Setopt(curl.OPT_NOPROGRESS, false) + easy.Setopt(curl.OPT_PROGRESSFUNCTION, func(dltotal, dlnow, ultotal, ulnow float64, userdata interface{}) bool { + if ctx.Cancel != nil && len(ctx.Cancel) > 0 { + rhl.Printf("downloading '%s' got canceled", ctx.SourceUri) + res.ResponseCode = http.StatusNoContent + res.ErrorString = "canceled" + return false + } + + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(1, "downloading", dlnow/dltotal, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + return true + }) + easy.Setopt(curl.OPT_PROGRESSDATA, ctx) + + if err = easy.Perform(); err != nil { + if cbdata.File != nil { + rhdl.Printf("Removing stale file: %s", cbdata.filename) + os.Remove(cbdata.filename) + os.Remove(path.Dir(cbdata.filename)) + } + if res.ResponseCode == http.StatusNoContent { + err = nil + } else { + err = fmt.Errorf("curl-fetcher('%s'): %s", ctx.SourceUri, err) + } + return + } + + ctx.SourceFile = cbdata.filename + ctx.DeleteSourceFile = true + ctx.DeleteSourceDir = true + } else { + err = fmt.Errorf("Error initializing libcurl") + } + + return +} + +func fetchFileLocal(ctx *Context, res *Result, uri *url.URL) (err error) { + rhl.Printf("Local fetcher called for '%s'", ctx.SourceUri) + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(1, "fetching", 0.0, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + + ctx.SourceFile = filepath.Join(ctx.conf.LocalFetchDir, path.Clean("/"+uri.Path)) + var src *os.File + if src, err = os.Open(ctx.SourceFile); err != nil { + res.ResponseCode = http.StatusBadRequest + res.ErrorString = fmt.Sprintf("local-file open(): %s", err) + return nil + } + if info, err := src.Stat(); err != nil { + res.ResponseCode = http.StatusBadRequest + res.ErrorString = fmt.Sprintf("local-file stat(): %s", err) + return nil + } else { + if info.IsDir() { + res.ResponseCode = http.StatusBadRequest + res.ErrorString = fmt.Sprintf("'%s' is a directory", ctx.SourceFile) + return nil + } + } + src.Close() + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(1, "fetching", 1.0, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + ctx.DeleteSourceFile = false + ctx.DeleteSourceDir = false + return +} + +func fetchFileFake(ctx *Context, res *Result, uri *url.URL) error { + rhdl.Printf("Fake fetcher for '%s'", ctx.SourceUri) + + if duration, err := strconv.ParseUint(uri.Host, 10, 32); err != nil { + err = nil + res.ResponseCode = http.StatusBadRequest + res.ErrorString = "invalid duration (must be a positive integer)" + } else { + for i := uint(0); i < uint(duration); i++ { + if ctx.Cancel != nil && len(ctx.Cancel) > 0 { + rhl.Printf("faking got canceled") + res.ResponseCode = http.StatusNoContent + res.ErrorString = "canceled" + return nil + } + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(1, "faking", float64(i)/float64(duration), ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + time.Sleep(100 * time.Millisecond) + } + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(1, "faking", 1.0, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + ctx.SourceFile = "/nonexistend/fake.mp3" + ctx.DeleteSourceFile = false + ctx.DeleteSourceDir = false + } + return nil +} + +type FetchFunc func(*Context, *Result, *url.URL) (err error) + +// TODO: implement fetchers for: +// archiv:// +// public:// +// home:// ????? +var ( + fetchers = map[string]FetchFunc{ + "local": fetchFileLocal, + "fake": fetchFileFake, + } + curlProtos = map[string]bool{ + "http": false, "https": false, + "ftp": false, "ftps": false, + } +) + +func fetcherInit() { + info := curl.VersionInfo(curl.VERSION_FIRST) + protos := info.Protocols + for _, proto := range protos { + if _, ok := curlProtos[proto]; ok { + rhdl.Printf("curl: enabling protocol %s", proto) + fetchers[proto] = fetchFileCurl + curlProtos[proto] = true + } else { + rhdl.Printf("curl: ignoring protocol %s", proto) + } + } + for proto, enabled := range curlProtos { + if !enabled { + rhl.Printf("curl: protocol %s is disabled because the installed library version doesn't support it!", proto) + } + } +} + +func checkPassword(ctx *Context, result *Result) (err error) { + ok := false + if ok, err = ctx.CheckPassword(); err != nil { + return + } + if !ok { + result.ResponseCode = http.StatusUnauthorized + result.ErrorString = "invalid username and/or password" + } + return +} + +func FetchFile(ctx *Context) (res *Result, err error) { + res = &Result{ResponseCode: http.StatusOK} + + var uri *url.URL + if uri, err = url.Parse(ctx.SourceUri); err != nil { + res.ResponseCode = http.StatusBadRequest + res.ErrorString = fmt.Sprintf("parsing uri: %s", err) + return res, nil + } + + if !ctx.Trusted { + if err = checkPassword(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + } + + if fetcher, ok := fetchers[uri.Scheme]; ok { + err = fetcher(ctx, res, uri) + } else { + err = fmt.Errorf("No fetcher for uri scheme '%s' found.", uri.Scheme) + } + return +} diff --git a/rhimport/importer.go b/rhimport/importer.go new file mode 100644 index 0000000..0702db4 --- /dev/null +++ b/rhimport/importer.go @@ -0,0 +1,528 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "bufio" + "bytes" + "fmt" + "github.com/andelf/go-curl" + "mime/multipart" + "net/http" + "os" + "path" +) + +func (self *Result) fromRDWebResult(rdres *RDWebResult) { + self.ResponseCode = rdres.ResponseCode + self.ErrorString = rdres.ErrorString + if rdres.AudioConvertError != 0 { + self.ErrorString += fmt.Sprintf(", Audio Convert Error: %d", rdres.AudioConvertError) + } +} + +func addCart(ctx *Context, res *Result) (err error) { + rhdl.Printf("importer: addCart() called for cart: %d", ctx.Cart) + + if ctx.GroupName == "" { + if err = ctx.getGroupOfCart(); err != nil { + return + } + } + + var b bytes.Buffer + w := multipart.NewWriter(&b) + + if err = w.WriteField("COMMAND", "12"); err != nil { + return + } + if err = w.WriteField("LOGIN_NAME", ctx.UserName); err != nil { + return + } + if err = w.WriteField("PASSWORD", ctx.Password); err != nil { + return + } + if err = w.WriteField("GROUP_NAME", ctx.GroupName); err != nil { + return + } + if err = w.WriteField("TYPE", "audio"); err != nil { + return + } + if ctx.Cart != 0 { + if err = w.WriteField("CART_NUMBER", fmt.Sprintf("%d", ctx.Cart)); err != nil { + return + } + } + w.Close() + + var resp *http.Response + if resp, err = sendPostRequest(ctx.conf.RDXportEndpoint, &b, w.FormDataContentType()); err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var rdres *RDWebResult + if rdres, err = NewRDWebResultFromXML(resp.Body); err != nil { + return + } + res.fromRDWebResult(rdres) + res.Cart = ctx.Cart + return + } + var cartadd *RDCartAdd + if cartadd, err = NewRDCartAddFromXML(resp.Body); err != nil { + return + } + res.ResponseCode = resp.StatusCode + res.ErrorString = "OK" + res.Cart = cartadd.Carts[0].Number + ctx.Cart = res.Cart + return +} + +func addCut(ctx *Context, res *Result) (err error) { + rhdl.Printf("importer: addCut() called for cart/cut: %d/%d", ctx.Cart, ctx.Cut) + var b bytes.Buffer + w := multipart.NewWriter(&b) + + if err = w.WriteField("COMMAND", "10"); err != nil { + return + } + if err = w.WriteField("LOGIN_NAME", ctx.UserName); err != nil { + return + } + if err = w.WriteField("PASSWORD", ctx.Password); err != nil { + return + } + if err = w.WriteField("CART_NUMBER", fmt.Sprintf("%d", ctx.Cart)); err != nil { + return + } + w.Close() + + var resp *http.Response + if resp, err = sendPostRequest(ctx.conf.RDXportEndpoint, &b, w.FormDataContentType()); err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var rdres *RDWebResult + if rdres, err = NewRDWebResultFromXML(resp.Body); err != nil { + return + } + res.fromRDWebResult(rdres) + res.Cart = ctx.Cart + res.Cut = ctx.Cut + return + } + var cutadd *RDCutAdd + if cutadd, err = NewRDCutAddFromXML(resp.Body); err != nil { + return + } + res.ResponseCode = resp.StatusCode + res.ErrorString = "OK" + res.Cart = ctx.Cart + res.Cut = cutadd.Cuts[0].Number + ctx.Cut = cutadd.Cuts[0].Number + return +} + +func removeCart(ctx *Context, res *Result) (err error) { + rhdl.Printf("importer: removeCart() called for cart: %d", ctx.Cart) + var b bytes.Buffer + w := multipart.NewWriter(&b) + + if err = w.WriteField("COMMAND", "13"); err != nil { + return + } + if err = w.WriteField("LOGIN_NAME", ctx.UserName); err != nil { + return + } + if err = w.WriteField("PASSWORD", ctx.Password); err != nil { + return + } + if err = w.WriteField("CART_NUMBER", fmt.Sprintf("%d", ctx.Cart)); err != nil { + return + } + w.Close() + + var resp *http.Response + if resp, err = sendPostRequest(ctx.conf.RDXportEndpoint, &b, w.FormDataContentType()); err != nil { + return + } + defer resp.Body.Close() + + var rdres *RDWebResult + if rdres, err = NewRDWebResultFromXML(resp.Body); err != nil { + return + } + res.fromRDWebResult(rdres) + res.Cart = ctx.Cart + return +} + +func removeCut(ctx *Context, res *Result) (err error) { + rhdl.Printf("importer: removeCut() called for cart/cut: %d/%d", ctx.Cart, ctx.Cut) + var b bytes.Buffer + w := multipart.NewWriter(&b) + + if err = w.WriteField("COMMAND", "11"); err != nil { + return + } + if err = w.WriteField("LOGIN_NAME", ctx.UserName); err != nil { + return + } + if err = w.WriteField("PASSWORD", ctx.Password); err != nil { + return + } + if err = w.WriteField("CART_NUMBER", fmt.Sprintf("%d", ctx.Cart)); err != nil { + return + } + if err = w.WriteField("CUT_NUMBER", fmt.Sprintf("%d", ctx.Cut)); err != nil { + return + } + w.Close() + + var resp *http.Response + if resp, err = sendPostRequest(ctx.conf.RDXportEndpoint, &b, w.FormDataContentType()); err != nil { + return + } + defer resp.Body.Close() + + var rdres *RDWebResult + if rdres, err = NewRDWebResultFromXML(resp.Body); err != nil { + return + } + res.fromRDWebResult(rdres) + res.Cart = ctx.Cart + res.Cut = ctx.Cut + return +} + +func sendPostRequest(url string, b *bytes.Buffer, contenttype string) (resp *http.Response, err error) { + var req *http.Request + if req, err = http.NewRequest("POST", url, b); err != nil { + return + } + if contenttype != "" { + req.Header.Set("Content-Type", contenttype) + } + + client := &http.Client{} + if resp, err = client.Do(req); err != nil { + return + } + return +} + +func importAudioCreateRequest(ctx *Context, easy *curl.CURL) (form *curl.Form, err error) { + form = curl.NewForm() + + if err = form.Add("COMMAND", "2"); err != nil { + return + } + if err = form.Add("LOGIN_NAME", ctx.UserName); err != nil { + return + } + if err = form.Add("PASSWORD", ctx.Password); err != nil { + return + } + if err = form.Add("CART_NUMBER", fmt.Sprintf("%d", ctx.Cart)); err != nil { + return + } + if err = form.Add("CUT_NUMBER", fmt.Sprintf("%d", ctx.Cut)); err != nil { + return + } + if err = form.Add("CHANNELS", fmt.Sprintf("%d", ctx.Channels)); err != nil { + return + } + if err = form.Add("NORMALIZATION_LEVEL", fmt.Sprintf("%d", ctx.NormalizationLevel)); err != nil { + return + } + if err = form.Add("AUTOTRIM_LEVEL", fmt.Sprintf("%d", ctx.AutotrimLevel)); err != nil { + return + } + if err = form.Add("USE_METADATA", bool2str[ctx.UseMetaData]); err != nil { + return + } + if err = form.AddFile("FILENAME", ctx.SourceFile); err != nil { + return + } + + return +} + +func importAudio(ctx *Context, res *Result) (err error) { + rhdl.Printf("importer: importAudio() called for cart/cut: %d/%d", ctx.Cart, ctx.Cut) + easy := curl.EasyInit() + + if easy != nil { + defer easy.Cleanup() + + easy.Setopt(curl.OPT_URL, ctx.conf.RDXportEndpoint) + easy.Setopt(curl.OPT_POST, true) + + var form *curl.Form + if form, err = importAudioCreateRequest(ctx, easy); err != nil { + return + } + easy.Setopt(curl.OPT_HTTPPOST, form) + easy.Setopt(curl.OPT_HTTPHEADER, []string{"Expect:"}) + + var resbody bytes.Buffer + easy.Setopt(curl.OPT_WRITEFUNCTION, func(ptr []byte, userdata interface{}) bool { + b := userdata.(*bytes.Buffer) + b.Write(ptr) + return true + }) + easy.Setopt(curl.OPT_WRITEDATA, &resbody) + + easy.Setopt(curl.OPT_NOPROGRESS, false) + easy.Setopt(curl.OPT_PROGRESSFUNCTION, func(dltotal, dlnow, ultotal, ulnow float64, userdata interface{}) bool { + if ctx.Cancel != nil && len(ctx.Cancel) > 0 { + res.ResponseCode = http.StatusNoContent + res.ErrorString = "canceled" + return false + } + + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(2, "importing", ulnow/ultotal, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + return true + }) + easy.Setopt(curl.OPT_PROGRESSDATA, ctx) + + if err = easy.Perform(); err != nil { + if res.ResponseCode == http.StatusNoContent { + rhl.Printf("import to cart/cat %d/%d got canceled", ctx.Cart, ctx.Cut) + res.Cart = ctx.Cart + res.Cut = ctx.Cut + err = nil + } else { + err = fmt.Errorf("importer: %s", err) + } + return + } + + var rdres *RDWebResult + if rdres, err = NewRDWebResultFromXML(bufio.NewReader(&resbody)); err != nil { + return + } + res.fromRDWebResult(rdres) + res.Cart = ctx.Cart + res.Cut = ctx.Cut + } else { + err = fmt.Errorf("Error initializing libcurl") + } + + return +} + +func addCartCut(ctx *Context, res *Result) (err error) { + if err = addCart(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + if err = addCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return removeCart(ctx, &Result{ResponseCode: http.StatusOK}) + } + return +} + +func removeAddCartCut(ctx *Context, res *Result) (err error) { + if err = removeCart(ctx, res); err != nil || (res.ResponseCode != http.StatusOK && res.ResponseCode != http.StatusNotFound) { + return + } + return addCartCut(ctx, res) +} + +func isCartMemberOfShow(ctx *Context, res *Result, carts []uint) (found bool) { + if ctx.Cart == 0 { + return true + } + for _, cart := range carts { + if cart == ctx.Cart { + return true + } + } + res.ResponseCode = http.StatusBadRequest + res.ErrorString = fmt.Sprintf("Requested cart %d is not a member of show: %d", ctx.Cart, ctx.ShowId) + res.Cart = ctx.Cart + return false +} + +func clearShowCarts(ctx *Context, res *Result, carts []uint) (err error) { + if ctx.ClearShowCarts { + origCart := ctx.Cart + for _, cart := range carts { + ctx.Cart = cart + if err = removeCart(ctx, res); err != nil || (res.ResponseCode != http.StatusOK && res.ResponseCode != http.StatusNotFound) { + return + } + } + ctx.Cart = origCart + } + return +} + +func addShowCartCut(ctx *Context, res *Result, carts []uint) (err error) { + if err = addCart(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + for _, cart := range carts { + if cart == ctx.Cart { + if err = addCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return removeCart(ctx, &Result{ResponseCode: http.StatusOK}) + } + return + } + } + if err = removeCart(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + res.ResponseCode = http.StatusForbidden + res.ErrorString = fmt.Sprintf("Show %d has no free carts left", ctx.ShowId) + return +} + +func cleanupFiles(ctx *Context) { + if ctx.DeleteSourceFile { + rhdl.Printf("importer: removing file: %s", ctx.SourceFile) + if err := os.Remove(ctx.SourceFile); err != nil { + rhl.Printf("importer: error removing source file: %s", err) + return + } + if ctx.DeleteSourceDir { + dir := path.Dir(ctx.SourceFile) + rhdl.Printf("importer: also removing directory: %s", dir) + if err := os.Remove(dir); err != nil { + rhl.Printf("importer: error removing source directory: %s", err) + } + } + } + return +} + +func ImportFile(ctx *Context) (res *Result, err error) { + defer cleanupFiles(ctx) + + rhl.Printf("importer: ImportFile called with: show-id: %d, pool-name: '%s', cart/cut: %d/%d", ctx.ShowId, ctx.GroupName, ctx.Cart, ctx.Cut) + + if ctx.ProgressCallBack != nil { + if keep := ctx.ProgressCallBack(2, "importing", 0.0, ctx.ProgressCallBackData); !keep { + ctx.ProgressCallBack = nil + } + } + + // TODO: on trusted interfaces we should call getPassword again with cached=false after 401's + if ctx.Trusted { + if err = ctx.getPassword(true); err != nil { + return + } + } + + rmCartOnErr := false + rmCutOnErr := false + res = &Result{ResponseCode: http.StatusOK} + if ctx.ShowId != 0 { // Import to a show + var showCarts []uint + if showCarts, err = ctx.getShowInfo(); err != nil { + return + } + if !isCartMemberOfShow(ctx, res, showCarts) { + return + } + if err = clearShowCarts(ctx, res, showCarts); err != nil || (res.ResponseCode != http.StatusOK && res.ResponseCode != http.StatusNotFound) { + return + } + if ctx.ClearCart && !ctx.ClearShowCarts { + if err = removeAddCartCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + } else { + if err = addShowCartCut(ctx, res, showCarts); err != nil || res.ResponseCode != http.StatusOK { + return + } + } + rmCartOnErr = true + } else if ctx.GroupName != "" { // Import to music pool + if err = ctx.getMusicInfo(); err != nil { + return + } + if err = addCartCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + rmCartOnErr = true + } else if ctx.Cart != 0 && ctx.Cut == 0 { // Import to Cart + if ctx.ClearCart { + if err = removeAddCartCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + rmCartOnErr = true + } else { + if err = addCut(ctx, res); err != nil { + return + } + if res.ResponseCode != http.StatusOK { + if err = addCartCut(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + return + } + rmCartOnErr = true + } else { + rmCutOnErr = true + } + } + } + + if ctx.Cart != 0 && ctx.Cut != 0 { // Import to specific Cut within Cart + if err = importAudio(ctx, res); err != nil || res.ResponseCode != http.StatusOK { + if err != nil { + rhl.Printf("Fileimport has failed (Cart/Cut %d/%d): %s", ctx.Cart, ctx.Cut, err) + } else { + rhl.Printf("Fileimport has failed (Cart/Cut %d/%d): %s", res.Cart, res.Cut, res.ErrorString) + } + // Try to clean up after failed import + rmres := Result{ResponseCode: http.StatusOK} + if rmCartOnErr { + if rerr := removeCart(ctx, &rmres); rerr != nil { + return + } + } else if rmCutOnErr { + if rerr := removeCut(ctx, &rmres); rerr != nil { + return + } + } + } else { + rhl.Printf("File got succesfully imported into Cart/Cut %d/%d", res.Cart, res.Cut) + } + } else { + res.ResponseCode = http.StatusBadRequest + res.ErrorString = "importer: The request doesn't contain enough information to be processed" + } + + return +} diff --git a/rhimport/rdxport_responses.go b/rhimport/rdxport_responses.go new file mode 100644 index 0000000..2871408 --- /dev/null +++ b/rhimport/rdxport_responses.go @@ -0,0 +1,150 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "encoding/xml" + "fmt" + "io" +) + +type RDWebResult struct { + ResponseCode int `xml:"ResponseCode"` + ErrorString string `xml:"ErrorString"` + AudioConvertError int `xml:"AudioConvertError"` +} + +func NewRDWebResultFromXML(data io.Reader) (res *RDWebResult, err error) { + decoder := xml.NewDecoder(data) + res = &RDWebResult{} + if xmlerr := decoder.Decode(res); xmlerr != nil { + err = fmt.Errorf("Error parsing XML response: %s", xmlerr) + return + } + return +} + +type RDCartAdd struct { + Carts []RDCart `xml:"cart"` +} + +type RDCart struct { + Number uint `xml:"number"` + Type string `xml:"type"` + GroupName string `xml:"groupName"` + Title string `xml:"title"` + Artist string `xml:"artist"` + Album string `xml:"album"` + Year string `xml:"year"` + Label string `xml:"label"` + Client string `xml:"client"` + Agency string `xml:"agency"` + Publisher string `xml:"publisher"` + Composer string `xml:"composer"` + UserDefined string `xml:"userDefined"` + UsageCode uint `xml:"usageCode"` + ForcedLength string `xml:"forcedLength"` + AverageLength string `xml:"averageLength"` + LengthDeviation string `xml:"lengthDeviation"` + AverageSegueLength string `xml:"averageSegueLenth"` + AverageHookLength string `xml:"averageHookLength"` + CutQuantity uint `xml:"cutQuantity"` + LastCutPlayed uint `xml:"lastCutPlayed"` + Validity uint `xml:"validity"` + EnforceLength bool `xml:"enforceLength"` + Asynchronous bool `xml:"asyncronous"` + Owner string `xml:"owner"` + MetadataDatetime string `xml:"metadataDatetime"` +} + +func NewRDCartAddFromXML(data io.Reader) (cart *RDCartAdd, err error) { + decoder := xml.NewDecoder(data) + cart = &RDCartAdd{} + if xmlerr := decoder.Decode(cart); xmlerr != nil { + err = fmt.Errorf("Error parsing XML response: %s", xmlerr) + return + } + return +} + +type RDCutAdd struct { + Cuts []RDCut `xml:"cut"` +} + +type RDCut struct { + Name string `xml:"cutName"` + CartNumber uint `xml:"cartNumber"` + Number uint `xml:"cutNumber"` + EverGreen bool `xml:"evergreen"` + Description string `xml:"description"` + OutCue string `xml:"outcue"` + ISRC string `xml:"isrc"` + ISCI string `xml:"isci"` + Length int `xml:"length"` + OriginDateTime string `xml:"originDatetime"` + StartDateTime string `xml:"startDatetime"` + EndDateTime string `xml:"endDatetime"` + Sunday bool `xml:"sun"` + Monday bool `xml:"mon"` + Tuesday bool `xml:"tue"` + Wednesday bool `xml:"wed"` + Thursday bool `xml:"thu"` + Friday bool `xml:"fri"` + Saturday bool `xml:"sat"` + StartDaypart string `xml:"startDaypart"` + EndDayPart string `xml:"endDaypart"` + OriginName string `xml:"originName"` + Weight int `xml:"weight"` + LastPlayDateTime string `xml:"lastPlayDatetime"` + PlayCounter uint `xml:"playCounter"` + LocalCounter uint `xml:"localCounter"` + Validiy uint `xml:"validity"` + CondingFormat int `xml:"codingFormat"` + SampleRate int `xml:"sampleRate"` + BitRate int `xml:"bitRate"` + Channels int `xml:"channels"` + PlayGain int `xml:"playGain"` + StartPoint int `xml:"startPoint"` + EndPoint int `xml:"endPoint"` + FadeUpPoint int `xml:"fadeupPoint"` + FadeDownPoint int `xml:"fadedownPoint"` + SegueStartPoint int `xml:"segueStartPoint"` + SegueEndPoint int `xml:"segueEndPoint"` + SegueGain int `xml:"segueGain"` + HookStartPoint int `xml:"hookStartPoint"` + HookEndPoint int `xml:"hookEndPoint"` + TalkStartPoint int `xml:"talkStartPoint"` + TalkEndPoint int `xml:"talkEndPoint"` +} + +func NewRDCutAddFromXML(data io.Reader) (cut *RDCutAdd, err error) { + decoder := xml.NewDecoder(data) + cut = &RDCutAdd{} + if xmlerr := decoder.Decode(cut); xmlerr != nil { + err = fmt.Errorf("Error parsing XML response: %s", xmlerr) + return + } + return +} diff --git a/rhimport/session.go b/rhimport/session.go new file mode 100644 index 0000000..66705ec --- /dev/null +++ b/rhimport/session.go @@ -0,0 +1,328 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "fmt" + "net/http" + "time" +) + +const ( + SESSION_NEW = iota + SESSION_RUNNING + SESSION_CANCELED + SESSION_DONE + SESSION_TIMEOUT +) + +type Session struct { + ctx Context + state int + removeFunc func() + done chan bool + quit chan bool + timer *time.Timer + cancelIntChan chan bool + progressIntChan chan ProgressData + doneIntChan chan Result + runChan chan time.Duration + cancelChan chan bool + addProgressChan chan sessionAddProgressHandlerRequest + addDoneChan chan sessionAddDoneHandlerRequest + progressCBs []*SessionProgressCB + doneCBs []*SessionDoneCB +} + +type SessionProgressCB struct { + cb ProgressCB + userdata interface{} +} + +type SessionDoneCB struct { + cb DoneCB + userdata interface{} +} + +type ProgressData struct { + Step int + StepName string + Progress float64 +} + +type sessionAddProgressHandlerResponse struct { + err error +} + +type sessionAddProgressHandlerRequest struct { + userdata interface{} + callback ProgressCB + response chan<- sessionAddProgressHandlerResponse +} + +type sessionAddDoneHandlerResponse struct { + err error +} + +type sessionAddDoneHandlerRequest struct { + userdata interface{} + callback DoneCB + response chan<- sessionAddDoneHandlerResponse +} + +func sessionProgressCallback(step int, stepName string, progress float64, userdata interface{}) bool { + out := userdata.(chan<- ProgressData) + out <- ProgressData{step, stepName, progress} + return true +} + +func sessionRun(ctx Context, done chan<- Result) { + if err := ctx.SanityCheck(); err != nil { + done <- Result{http.StatusBadRequest, err.Error(), 0, 0} + return + } + + if res, err := FetchFile(&ctx); err != nil { + done <- Result{http.StatusInternalServerError, err.Error(), 0, 0} + return + } else if res.ResponseCode != http.StatusOK { + done <- *res + return + } + + if res, err := ImportFile(&ctx); err != nil { + done <- Result{http.StatusInternalServerError, err.Error(), 0, 0} + return + } else { + done <- *res + return + } +} + +func (self *Session) run(timeout time.Duration) { + self.ctx.ProgressCallBack = sessionProgressCallback + self.ctx.ProgressCallBackData = (chan<- ProgressData)(self.progressIntChan) + self.ctx.Cancel = self.cancelIntChan + go sessionRun(self.ctx, self.doneIntChan) + self.state = SESSION_RUNNING + self.timer.Reset(timeout) + return +} + +func (self *Session) cancel() { + rhdl.Println("Session: canceling running import") + select { + case self.cancelIntChan <- true: + default: // session got canceled already?? + } + self.state = SESSION_CANCELED +} + +func (self *Session) addProgressHandler(userdata interface{}, cb ProgressCB) (resp sessionAddProgressHandlerResponse) { + if self.state != SESSION_NEW && self.state != SESSION_RUNNING { + resp.err = fmt.Errorf("session is already done/canceled") + } + self.progressCBs = append(self.progressCBs, &SessionProgressCB{cb, userdata}) + return +} + +func (self *Session) addDoneHandler(userdata interface{}, cb DoneCB) (resp sessionAddDoneHandlerResponse) { + if self.state != SESSION_NEW && self.state != SESSION_RUNNING { + resp.err = fmt.Errorf("session is already done/canceled") + } + self.doneCBs = append(self.doneCBs, &SessionDoneCB{cb, userdata}) + return +} + +func (self *Session) callProgressHandler(p *ProgressData) { + for _, cb := range self.progressCBs { + if cb.cb != nil { + if keep := cb.cb(p.Step, p.StepName, p.Progress, cb.userdata); !keep { + cb.cb = nil + } + } + } +} + +func (self *Session) callDoneHandler(r *Result) { + for _, cb := range self.doneCBs { + if cb.cb != nil { + if keep := cb.cb(*r, cb.userdata); !keep { + cb.cb = nil + } + } + } +} + +func (self *Session) dispatchRequests() { + defer func() { self.done <- true }() + for { + select { + case <-self.quit: + if self.state == SESSION_RUNNING { + self.cancel() + } + return + case <-self.timer.C: + if self.state == SESSION_RUNNING { + self.cancel() + } + self.state = SESSION_TIMEOUT + r := &Result{500, "session timed out", 0, 0} + self.callDoneHandler(r) + if self.removeFunc != nil { + self.removeFunc() + } + case t := <-self.runChan: + if self.state == SESSION_NEW { + self.run(t) + } + case <-self.cancelChan: + if self.state == SESSION_RUNNING { + self.cancel() + } + case req := <-self.addProgressChan: + req.response <- self.addProgressHandler(req.userdata, req.callback) + case req := <-self.addDoneChan: + req.response <- self.addDoneHandler(req.userdata, req.callback) + case p := <-self.progressIntChan: + self.callProgressHandler(&p) + case r := <-self.doneIntChan: + if self.state != SESSION_TIMEOUT { + self.timer.Stop() + self.state = SESSION_DONE + self.callDoneHandler(&r) + if self.removeFunc != nil { + self.removeFunc() + } + } + } + } +} + +// ********************************************************* +// Public Interface + +type SessionChan struct { + runChan chan<- time.Duration + cancelChan chan<- bool + addProgressChan chan<- sessionAddProgressHandlerRequest + addDoneChan chan<- sessionAddDoneHandlerRequest +} + +func (self *SessionChan) Run(timeout time.Duration) { + select { + case self.runChan <- timeout: + default: // command is already pending or session is about to be closed/removed + } +} + +func (self *SessionChan) Cancel() { + select { + case self.cancelChan <- true: + default: // cancel is already pending or session is about to be closed/removed + } +} + +func (self *SessionChan) AddProgressHandler(userdata interface{}, cb ProgressCB) error { + resCh := make(chan sessionAddProgressHandlerResponse) + req := sessionAddProgressHandlerRequest{} + req.userdata = userdata + req.callback = cb + req.response = resCh + select { + case self.addProgressChan <- req: + default: + return fmt.Errorf("session is about to be closed/removed") + } + + res := <-resCh + return res.err +} + +func (self *SessionChan) AddDoneHandler(userdata interface{}, cb DoneCB) error { + resCh := make(chan sessionAddDoneHandlerResponse) + req := sessionAddDoneHandlerRequest{} + req.userdata = userdata + req.callback = cb + req.response = resCh + select { + case self.addDoneChan <- req: + default: + return fmt.Errorf("session is about to be closed/removed") + } + + res := <-resCh + return res.err +} + +// ********************************************************* +// Semi-Public Interface (only used by sessionStore) + +func (self *Session) getInterface() *SessionChan { + ch := &SessionChan{} + ch.runChan = self.runChan + ch.cancelChan = self.cancelChan + ch.addProgressChan = self.addProgressChan + ch.addDoneChan = self.addDoneChan + return ch +} + +func (self *Session) cleanup() { + self.quit <- true + rhdl.Printf("waiting for session to close") + <-self.done + close(self.quit) + close(self.done) + self.timer.Stop() + // don't close the channels we give out because this might lead to a panic if + // somebody wites to an already removed session + // close(self.cancelIntChan) + // close(self.progressIntChan) + // close(self.doneIntChan) + // close(self.runChan) + // close(self.cancelChan) + // close(self.addProgressChan) + // close(self.addDoneChan) + rhdl.Printf("session is now cleaned up") +} + +func newSession(ctx *Context, removeFunc func()) (session *Session) { + session = new(Session) + session.state = SESSION_NEW + session.removeFunc = removeFunc + session.ctx = *ctx + session.done = make(chan bool) + session.timer = time.NewTimer(10 * time.Second) + session.cancelIntChan = make(chan bool, 1) + session.progressIntChan = make(chan ProgressData, 10) + session.doneIntChan = make(chan Result, 1) + session.runChan = make(chan time.Duration, 1) + session.cancelChan = make(chan bool, 1) + session.addProgressChan = make(chan sessionAddProgressHandlerRequest, 10) + session.addDoneChan = make(chan sessionAddDoneHandlerRequest, 10) + go session.dispatchRequests() + return +} diff --git a/rhimport/session_store.go b/rhimport/session_store.go new file mode 100644 index 0000000..e47366e --- /dev/null +++ b/rhimport/session_store.go @@ -0,0 +1,310 @@ +// +// rhimportd +// +// The Radio Helsinki Rivendell Import Daemon +// +// +// Copyright (C) 2015-2016 Christian Pointner <equinox@helsinki.at> +// +// This file is part of rhimportd. +// +// rhimportd is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// any later version. +// +// rhimportd is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with rhimportd. If not, see <http://www.gnu.org/licenses/>. +// + +package rhimport + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "helsinki.at/rhrd-go/rddb" + "net/http" +) + +type newSessionResponse struct { + id string + session *SessionChan + responsecode int + errorstring string +} + +type newSessionRequest struct { + ctx *Context + refId string + response chan newSessionResponse +} + +type getSessionResponse struct { + session *SessionChan + refId string + responsecode int + errorstring string +} + +type getSessionRequest struct { + user string + id string + refId string + response chan getSessionResponse +} + +type listSessionsResponse struct { + sessions map[string]string + responsecode int + errorstring string +} + +type listSessionsRequest struct { + user string + password string + trusted bool + response chan listSessionsResponse +} + +type removeSessionResponse struct { + responsecode int + errorstring string +} + +type removeSessionRequest struct { + user string + id string + response chan removeSessionResponse +} + +type SessionStoreElement struct { + s *Session + refId string +} + +type SessionStore struct { + store map[string]map[string]*SessionStoreElement + conf *Config + db *rddb.DBChan + quit chan bool + done chan bool + newChan chan newSessionRequest + getChan chan getSessionRequest + listChan chan listSessionsRequest + removeChan chan removeSessionRequest +} + +func generateSessionId() (string, error) { + var b [32]byte + if _, err := rand.Read(b[:]); err != nil { + return "", err + } + return base64.RawStdEncoding.EncodeToString(b[:]), nil +} + +func (self *SessionStore) new(ctx *Context, refId string) (resp newSessionResponse) { + resp.responsecode = http.StatusOK + resp.errorstring = "OK" + if !ctx.Trusted { + if ok, err := self.db.CheckPassword(ctx.UserName, ctx.Password); err != nil { + resp.responsecode = http.StatusInternalServerError + resp.errorstring = err.Error() + return + } else if !ok { + resp.responsecode = http.StatusUnauthorized + resp.errorstring = "invalid username and/or password" + return + } + } + if id, err := generateSessionId(); err != nil { + resp.responsecode = http.StatusInternalServerError + resp.errorstring = err.Error() + } else { + resp.id = id + if _, exists := self.store[ctx.UserName]; !exists { + self.store[ctx.UserName] = make(map[string]*SessionStoreElement) + } + ctx.conf = self.conf + ctx.db = self.db + s := &SessionStoreElement{newSession(ctx, func() { self.GetInterface().Remove(ctx.UserName, resp.id) }), refId} + self.store[ctx.UserName][resp.id] = s + resp.session = self.store[ctx.UserName][resp.id].s.getInterface() + rhdl.Printf("SessionStore: created session for '%s' -> %s", ctx.UserName, resp.id) + } + return +} + +func (self *SessionStore) get(user, id string) (resp getSessionResponse) { + resp.responsecode = http.StatusOK + resp.errorstring = "OK" + if session, exists := self.store[user][id]; exists { + resp.session = session.s.getInterface() + resp.refId = session.refId + } else { + resp.responsecode = http.StatusNotFound + resp.errorstring = fmt.Sprintf("SessionStore: session '%s/%s' not found", user, id) + } + return +} + +func (self *SessionStore) list(user, password string, trusted bool) (resp listSessionsResponse) { + resp.responsecode = http.StatusOK + resp.errorstring = "OK" + if !trusted { + if ok, err := self.db.CheckPassword(user, password); err != nil { + resp.responsecode = http.StatusInternalServerError + resp.errorstring = err.Error() + return + } else if !ok { + resp.responsecode = http.StatusUnauthorized + resp.errorstring = "invalid username and/or password" + return + } + } + resp.sessions = make(map[string]string) + if sessions, exists := self.store[user]; exists { + for id, e := range sessions { + resp.sessions[id] = e.refId + } + } + return +} + +func (self *SessionStore) remove(user, id string) (resp removeSessionResponse) { + resp.responsecode = http.StatusOK + resp.errorstring = "OK" + if session, exists := self.store[user][id]; exists { + go session.s.cleanup() // cleanup could take a while -> don't block all the other stuff + delete(self.store[user], id) + rhdl.Printf("SessionStore: removed session '%s/%s'", user, id) + if userstore, exists := self.store[user]; exists { + if len(userstore) == 0 { + delete(self.store, user) + rhdl.Printf("SessionStore: removed user '%s'", user) + } + } + } else { + resp.responsecode = http.StatusNotFound + resp.errorstring = fmt.Sprintf("SessionStore: session '%s/%s' not found", user, id) + } + return +} + +func (self *SessionStore) dispatchRequests() { + defer func() { self.done <- true }() + for { + select { + case <-self.quit: + return + case req := <-self.newChan: + req.response <- self.new(req.ctx, req.refId) + case req := <-self.getChan: + req.response <- self.get(req.user, req.id) + case req := <-self.listChan: + req.response <- self.list(req.user, req.password, req.trusted) + case req := <-self.removeChan: + req.response <- self.remove(req.user, req.id) + } + } +} + +// ********************************************************* +// Public Interface + +type SessionStoreChan struct { + newChan chan<- newSessionRequest + getChan chan<- getSessionRequest + listChan chan listSessionsRequest + removeChan chan<- removeSessionRequest +} + +func (self *SessionStoreChan) New(ctx *Context, refId string) (string, *SessionChan, int, string) { + resCh := make(chan newSessionResponse) + req := newSessionRequest{} + req.ctx = ctx + req.refId = refId + req.response = resCh + self.newChan <- req + + res := <-resCh + return res.id, res.session, res.responsecode, res.errorstring +} + +func (self *SessionStoreChan) Get(user, id string) (*SessionChan, string, int, string) { + resCh := make(chan getSessionResponse) + req := getSessionRequest{} + req.user = user + req.id = id + req.response = resCh + self.getChan <- req + + res := <-resCh + return res.session, res.refId, res.responsecode, res.errorstring +} + +func (self *SessionStoreChan) List(user, password string, trusted bool) (map[string]string, int, string) { + resCh := make(chan listSessionsResponse) + req := listSessionsRequest{} + req.user = user + req.password = password + req.trusted = trusted + req.response = resCh + self.listChan <- req + + res := <-resCh + return res.sessions, res.responsecode, res.errorstring +} + +func (self *SessionStoreChan) Remove(user, id string) (int, string) { + resCh := make(chan removeSessionResponse) + req := removeSessionRequest{} + req.user = user + req.id = id + req.response = resCh + self.removeChan <- req + + res := <-resCh + return res.responsecode, res.errorstring +} + +func (self *SessionStore) GetInterface() *SessionStoreChan { + ch := &SessionStoreChan{} + ch.newChan = self.newChan + ch.getChan = self.getChan + ch.listChan = self.listChan + ch.removeChan = self.removeChan + return ch +} + +func (self *SessionStore) Cleanup() { + self.quit <- true + <-self.done + close(self.quit) + close(self.done) + close(self.newChan) + close(self.getChan) + close(self.listChan) + close(self.removeChan) +} + +func NewSessionStore(conf *Config, db *rddb.DBChan) (store *SessionStore, err error) { + store = new(SessionStore) + store.conf = conf + store.db = db + store.quit = make(chan bool) + store.done = make(chan bool) + store.store = make(map[string]map[string]*SessionStoreElement) + store.newChan = make(chan newSessionRequest, 10) + store.getChan = make(chan getSessionRequest, 10) + store.listChan = make(chan listSessionsRequest, 10) + store.removeChan = make(chan removeSessionRequest, 10) + + go store.dispatchRequests() + return +} |