diff options
author | Christian Pointner <equinox@helsinki.at> | 2015-12-14 15:29:05 (GMT) |
---|---|---|
committer | Christian Pointner <equinox@helsinki.at> | 2015-12-14 15:29:05 (GMT) |
commit | 764c1939b3cf995fd98e4dc27b5709d3b494702b (patch) | |
tree | 93bfa22ca5f6aecfad703c05ff8f45cd716ec017 /src | |
parent | 97b905e99b95b6611d8d87bda90b56e47f2a5040 (diff) |
switched to mymysql go driver and added escape for log table name
Diffstat (limited to 'src')
-rw-r--r-- | src/helsinki.at/rhimport/rddb.go | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/src/helsinki.at/rhimport/rddb.go b/src/helsinki.at/rhimport/rddb.go index 8b72f58..34019e8 100644 --- a/src/helsinki.at/rhimport/rddb.go +++ b/src/helsinki.at/rhimport/rddb.go @@ -27,13 +27,14 @@ package rhimport import ( "database/sql" "fmt" - _ "github.com/go-sql-driver/mysql" + "github.com/ziutek/mymysql/godrv" "regexp" "strings" ) var ( showMacroRe = regexp.MustCompile(`^LL 1 ([^ ]+) 0\!$`) + mysqlTableNameRe = regexp.MustCompile(`^[_0-9a-zA-Z-]+$`) ) const ( @@ -114,8 +115,10 @@ type RdDb struct { } func (self *RdDb) init(conf *Config) (err error) { - dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8", conf.db_user, conf.db_passwd, conf.db_host, conf.db_db) - if self.dbh, err = sql.Open("mysql", dsn); err != nil { + godrv.Register("SET CHARACTER SET utf8;") + + dsn := fmt.Sprintf("tcp:%s:3306*%s/%s/%s", conf.db_host, conf.db_db, conf.db_user, conf.db_passwd) + if self.dbh, err = sql.Open("mymysql", dsn); err != nil { return } @@ -196,12 +199,20 @@ func (self *RdDb) getGroupOfCart(cart uint) (result getGroupOfCartResult) { return } -func (self *RdDb) getLogTableName(log string) string { - return strings.Replace(log, " ", "_", -1) + "_LOG" // TODO: this should get escaped for mySQL but golang doesn't support it!!! +func (self *RdDb) getLogTableName(log string) (logtable string, err error) { + logtable = strings.Replace(log, " ", "_", -1) + "_LOG" + if !mysqlTableNameRe.MatchString(logtable) { + return "", fmt.Errorf("the log table name contains illegal charecters: %s", logtable) + } + return } func (self *RdDb) getShowCarts(log string, low_cart, high_cart int) (carts []uint, err error) { - q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", self.getLogTableName(log), low_cart, high_cart) + var logtable string + if logtable, err = self.getLogTableName(log); err != nil { + return + } + q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", logtable, low_cart, high_cart) var rows *sql.Rows if rows, err = self.dbh.Query(q); err != nil { return |