diff options
author | Christian Pointner <equinox@helsinki.at> | 2015-12-14 15:29:05 (GMT) |
---|---|---|
committer | Christian Pointner <equinox@helsinki.at> | 2015-12-14 15:29:05 (GMT) |
commit | 638d17b5051f80b36ffa641366440b8266eaac4c (patch) | |
tree | b0352d43032b27ef2342cb499c241a9823f45224 | |
parent | 090f1621563a2b5069dd54ba7e67b7c07cc79f13 (diff) |
switched to mymysql go driver and added escape for log table name
-rw-r--r-- | rddb.go | 23 |
1 files changed, 17 insertions, 6 deletions
@@ -27,13 +27,14 @@ package rhimport import ( "database/sql" "fmt" - _ "github.com/go-sql-driver/mysql" + "github.com/ziutek/mymysql/godrv" "regexp" "strings" ) var ( showMacroRe = regexp.MustCompile(`^LL 1 ([^ ]+) 0\!$`) + mysqlTableNameRe = regexp.MustCompile(`^[_0-9a-zA-Z-]+$`) ) const ( @@ -114,8 +115,10 @@ type RdDb struct { } func (self *RdDb) init(conf *Config) (err error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8", conf.db_user, conf.db_passwd, conf.db_host, conf.db_db) - if self.dbh, err = sql.Open("mysql", dsn); err != nil { + godrv.Register("SET CHARACTER SET utf8;") + + dsn := fmt.Sprintf("tcp:%s:3306*%s/%s/%s", conf.db_host, conf.db_db, conf.db_user, conf.db_passwd) + if self.dbh, err = sql.Open("mymysql", dsn); err != nil { return } @@ -196,12 +199,20 @@ func (self *RdDb) getGroupOfCart(cart uint) (result getGroupOfCartResult) { return } -func (self *RdDb) getLogTableName(log string) string { - return strings.Replace(log, " ", "_", -1) + "_LOG" // TODO: this should get escaped for mySQL but golang doesn't support it!!! +func (self *RdDb) getLogTableName(log string) (logtable string, err error) { + logtable = strings.Replace(log, " ", "_", -1) + "_LOG" + if !mysqlTableNameRe.MatchString(logtable) { + return "", fmt.Errorf("the log table name contains illegal charecters: %s", logtable) + } + return } func (self *RdDb) getShowCarts(log string, low_cart, high_cart int) (carts []uint, err error) { - q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", self.getLogTableName(log), low_cart, high_cart) + var logtable string + if logtable, err = self.getLogTableName(log); err != nil { + return + } + q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", logtable, low_cart, high_cart) var rows *sql.Rows if rows, err = self.dbh.Query(q); err != nil { return |