From 638d17b5051f80b36ffa641366440b8266eaac4c Mon Sep 17 00:00:00 2001 From: Christian Pointner Date: Mon, 14 Dec 2015 16:29:05 +0100 Subject: switched to mymysql go driver and added escape for log table name diff --git a/rddb.go b/rddb.go index 8b72f58..34019e8 100644 --- a/rddb.go +++ b/rddb.go @@ -27,13 +27,14 @@ package rhimport import ( "database/sql" "fmt" - _ "github.com/go-sql-driver/mysql" + "github.com/ziutek/mymysql/godrv" "regexp" "strings" ) var ( showMacroRe = regexp.MustCompile(`^LL 1 ([^ ]+) 0\!$`) + mysqlTableNameRe = regexp.MustCompile(`^[_0-9a-zA-Z-]+$`) ) const ( @@ -114,8 +115,10 @@ type RdDb struct { } func (self *RdDb) init(conf *Config) (err error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8", conf.db_user, conf.db_passwd, conf.db_host, conf.db_db) - if self.dbh, err = sql.Open("mysql", dsn); err != nil { + godrv.Register("SET CHARACTER SET utf8;") + + dsn := fmt.Sprintf("tcp:%s:3306*%s/%s/%s", conf.db_host, conf.db_db, conf.db_user, conf.db_passwd) + if self.dbh, err = sql.Open("mymysql", dsn); err != nil { return } @@ -196,12 +199,20 @@ func (self *RdDb) getGroupOfCart(cart uint) (result getGroupOfCartResult) { return } -func (self *RdDb) getLogTableName(log string) string { - return strings.Replace(log, " ", "_", -1) + "_LOG" // TODO: this should get escaped for mySQL but golang doesn't support it!!! +func (self *RdDb) getLogTableName(log string) (logtable string, err error) { + logtable = strings.Replace(log, " ", "_", -1) + "_LOG" + if !mysqlTableNameRe.MatchString(logtable) { + return "", fmt.Errorf("the log table name contains illegal charecters: %s", logtable) + } + return } func (self *RdDb) getShowCarts(log string, low_cart, high_cart int) (carts []uint, err error) { - q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", self.getLogTableName(log), low_cart, high_cart) + var logtable string + if logtable, err = self.getLogTableName(log); err != nil { + return + } + q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", logtable, low_cart, high_cart) var rows *sql.Rows if rows, err = self.dbh.Query(q); err != nil { return -- cgit v0.10.2