diff options
-rw-r--r-- | conf.go | 70 |
1 files changed, 53 insertions, 17 deletions
@@ -25,6 +25,9 @@ package rhimport import ( + "database/sql" + "fmt" + _ "github.com/go-sql-driver/mysql" "github.com/vaughan0/go-ini" ) @@ -40,17 +43,18 @@ type getPasswordRequest struct { } type Config struct { - configfile string - RDXportEndpoint string - db_host string - db_user string - db_passwd string - db_db string - // TODO: reference to sql connection - password_cache map[string]string - getPasswordChan chan getPasswordRequest - quit chan bool - done chan bool + configfile string + RDXportEndpoint string + db_host string + db_user string + db_passwd string + db_db string + dbh *sql.DB + password_cache map[string]string + getPasswordChan chan getPasswordRequest + dbGetPasswordStmt *sql.Stmt + quit chan bool + done chan bool } func get_ini_value(file ini.File, section string, key string, dflt string) string { @@ -74,9 +78,35 @@ func (self *Config) read_config_file() error { return nil } +func (self *Config) init_database() (err error) { + dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s", self.db_user, self.db_passwd, self.db_host, self.db_db) + if self.dbh, err = sql.Open("mysql", dsn); err != nil { + return + } + if self.dbGetPasswordStmt, err = self.dbh.Prepare("select PASSWORD from USERS where LOGIN_NAME = ?"); err != nil { + return + } + + return +} + func (self *Config) getPassword(username string, cached bool) (pwd string, err error) { - //TODO: actually fetch password from cache or DB - pwd = "12345" + + if cached { + pwd = self.password_cache[username] + } + + if pwd == "" { + err = self.dbGetPasswordStmt.QueryRow(username).Scan(&pwd) + if err != nil { + if err == sql.ErrNoRows { + err = fmt.Errorf("user '%s' not known by rivendell", username) + } + return + } + self.password_cache[username] = pwd + } + return } @@ -104,14 +134,18 @@ func (self *Config) Cleanup() { close(self.quit) close(self.done) close(self.getPasswordChan) - //TODO : close db connection + if self.dbh != nil { + self.dbh.Close() + } + if self.dbGetPasswordStmt != nil { + self.dbGetPasswordStmt.Close() + } } func NewConfig(configfile, rdxport_endpoint *string) (conf *Config, err error) { conf = new(Config) conf.configfile = *configfile - err = conf.read_config_file() - if err != nil { + if err = conf.read_config_file(); err != nil { return } conf.quit = make(chan bool) @@ -120,7 +154,9 @@ func NewConfig(configfile, rdxport_endpoint *string) (conf *Config, err error) { conf.password_cache = make(map[string]string) conf.getPasswordChan = make(chan getPasswordRequest) - //TODO : init db connection + if err = conf.init_database(); err != nil { + return + } go conf.dispatchRequests() return |