summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorChristian Pointner <equinox@helsinki.at>2015-12-14 15:29:05 (GMT)
committerChristian Pointner <equinox@helsinki.at>2015-12-14 15:29:05 (GMT)
commit764c1939b3cf995fd98e4dc27b5709d3b494702b (patch)
tree93bfa22ca5f6aecfad703c05ff8f45cd716ec017 /src
parent97b905e99b95b6611d8d87bda90b56e47f2a5040 (diff)
switched to mymysql go driver and added escape for log table name
Diffstat (limited to 'src')
-rw-r--r--src/helsinki.at/rhimport/rddb.go23
1 files changed, 17 insertions, 6 deletions
diff --git a/src/helsinki.at/rhimport/rddb.go b/src/helsinki.at/rhimport/rddb.go
index 8b72f58..34019e8 100644
--- a/src/helsinki.at/rhimport/rddb.go
+++ b/src/helsinki.at/rhimport/rddb.go
@@ -27,13 +27,14 @@ package rhimport
import (
"database/sql"
"fmt"
- _ "github.com/go-sql-driver/mysql"
+ "github.com/ziutek/mymysql/godrv"
"regexp"
"strings"
)
var (
showMacroRe = regexp.MustCompile(`^LL 1 ([^ ]+) 0\!$`)
+ mysqlTableNameRe = regexp.MustCompile(`^[_0-9a-zA-Z-]+$`)
)
const (
@@ -114,8 +115,10 @@ type RdDb struct {
}
func (self *RdDb) init(conf *Config) (err error) {
- dsn := fmt.Sprintf("%s:%s@tcp(%s:3306)/%s?charset=utf8", conf.db_user, conf.db_passwd, conf.db_host, conf.db_db)
- if self.dbh, err = sql.Open("mysql", dsn); err != nil {
+ godrv.Register("SET CHARACTER SET utf8;")
+
+ dsn := fmt.Sprintf("tcp:%s:3306*%s/%s/%s", conf.db_host, conf.db_db, conf.db_user, conf.db_passwd)
+ if self.dbh, err = sql.Open("mymysql", dsn); err != nil {
return
}
@@ -196,12 +199,20 @@ func (self *RdDb) getGroupOfCart(cart uint) (result getGroupOfCartResult) {
return
}
-func (self *RdDb) getLogTableName(log string) string {
- return strings.Replace(log, " ", "_", -1) + "_LOG" // TODO: this should get escaped for mySQL but golang doesn't support it!!!
+func (self *RdDb) getLogTableName(log string) (logtable string, err error) {
+ logtable = strings.Replace(log, " ", "_", -1) + "_LOG"
+ if !mysqlTableNameRe.MatchString(logtable) {
+ return "", fmt.Errorf("the log table name contains illegal charecters: %s", logtable)
+ }
+ return
}
func (self *RdDb) getShowCarts(log string, low_cart, high_cart int) (carts []uint, err error) {
- q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", self.getLogTableName(log), low_cart, high_cart)
+ var logtable string
+ if logtable, err = self.getLogTableName(log); err != nil {
+ return
+ }
+ q := fmt.Sprintf("select CART_NUMBER from %s where CART_NUMBER >= %d and CART_NUMBER <= %d order by COUNT;", logtable, low_cart, high_cart)
var rows *sql.Rows
if rows, err = self.dbh.Query(q); err != nil {
return