// Package db wraps a MySQL *sql.DB connection pool, opened either over TCP
// (DB_HOST/DB_PORT) or through the Cloud SQL proxy mysql dialer, and provides
// map-based helpers for building INSERT/REPLACE/UPDATE statements and for
// reading query results into maps keyed by column name.
package db

import (
	"database/sql"
	"errors"
	"fmt"
	"genBrief/common/log"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/mysql"
)

// ErrDataNull is returned by GetData, GetRow and FetchOne when the query
// produced no rows. Its text stays "data null" so callers that compare the
// message keep working.
var ErrDataNull = errors.New("data null")

// connMaxLifetime caps how long a pooled connection is reused.
// SetConnMaxLifetime takes a time.Duration, so the value needs an explicit
// unit: a bare 1800 would mean 1800 nanoseconds.
const connMaxLifetime = 1800 * time.Second

// DB wraps a *sql.DB pool and exposes the helper methods below.
type DB struct {
	Db *sql.DB
}

// p1 is the process-wide pool; it is nil until Init runs.
var p1 *DB

// Init builds the process-wide pool returned by Pool and P1. It connects via
// the Cloud SQL proxy dialer when CLOUDSQL_CONNECTION_NAME is set, and over
// plain TCP otherwise. It panics if the pool cannot be created.
func Init() {
	var (
		pool *sql.DB
		err  error
		tag  string
	)
	if os.Getenv("CLOUDSQL_CONNECTION_NAME") != "" {
		tag = "init_sql_proxy"
		pool, err = initProxyPool()
	} else {
		tag = "init_sql"
		pool, err = initTCPConnectionPool()
	}
	if err != nil {
		log.Infof("%s %s", tag, err.Error())
		panic("init db error" + err.Error())
	}
	p1 = &DB{Db: pool}
}

// maskSecret hides a secret in log output while still showing whether it
// was set at all.
func maskSecret(s string) string {
	if s == "" {
		return ""
	}
	return "******"
}

// pingPool checks connectivity once at startup. A failed ping is only
// logged: the pool is still handed back to the caller, and sql.DB opens
// connections lazily, so later queries may still succeed.
func pingPool(dbPool *sql.DB) {
	if err := dbPool.Ping(); err != nil {
		log.Warn("dbPing err:", err)
		return
	}
	log.Infof("dbPing succ")
}

// initTCPConnectionPool opens a "mysql" pool from the DB_HOST, DB_PORT,
// DB_NAME, DB_USER and DB_PASS environment variables.
func initTCPConnectionPool() (*sql.DB, error) {
	var (
		dbTCPHost = os.Getenv("DB_HOST") // e.g. '127.0.0.1' ('172.17.0.1' if deployed to GAE Flex)
		dbPort    = os.Getenv("DB_PORT") // e.g. '3306'
		dbName    = os.Getenv("DB_NAME") // e.g. 'my-database'
		dbUser    = os.Getenv("DB_USER") // e.g. 'my-db-user'
		dbPwd     = os.Getenv("DB_PASS") // e.g. 'my-db-password'
	)
	// Never write the password itself to the log.
	log.Infof("dbTcpHost:%v\n dbPort:%v\n dbName:%v\n dbuser:%v\n dbpwd:%v\n",
		dbTCPHost, dbPort, dbName, dbUser, maskSecret(dbPwd))

	dbURI := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", dbUser, dbPwd, dbTCPHost, dbPort, dbName)
	// NOTE(review): the "mysql" driver is presumably registered by
	// go-sql-driver/mysql, pulled in via the cloudsql-proxy dialer import — confirm.
	dbPool, err := sql.Open("mysql", dbURI)
	if err != nil {
		return nil, fmt.Errorf("sql.Open: %w", err)
	}
	configureConnectionPool(dbPool)
	pingPool(dbPool)
	return dbPool, nil
}

// initProxyPool opens a pool through the Cloud SQL proxy mysql dialer using
// CLOUDSQL_CONNECTION_NAME, DB_NAME, DB_USER and DB_PASS.
func initProxyPool() (*sql.DB, error) {
	var (
		instanceConnectionName = os.Getenv("CLOUDSQL_CONNECTION_NAME")
		dbName                 = os.Getenv("DB_NAME")
		dbUser                 = os.Getenv("DB_USER")
		dbPwd                  = os.Getenv("DB_PASS")
	)
	// Never write the password itself to the log.
	log.Infof("connectionName:%v\n dbName:%v\n dbuser:%v\n dbpwd:%v\n",
		instanceConnectionName, dbName, dbUser, maskSecret(dbPwd))

	cfg := mysql.Cfg(instanceConnectionName, dbUser, dbPwd)
	cfg.DBName = dbName
	cfg.ParseTime = true
	dbPool, err := mysql.DialCfg(cfg)
	if err != nil {
		// Return the error; Init logs it and panics.
		return nil, fmt.Errorf("mysql.DialCfg: %w", err)
	}
	configureConnectionPool(dbPool)
	pingPool(dbPool)
	return dbPool, nil
}

// configureConnectionPool applies the shared pool limits used by Init.
func configureConnectionPool(dbPool *sql.DB) {
	dbPool.SetMaxIdleConns(50)
	dbPool.SetMaxOpenConns(50)
	dbPool.SetConnMaxLifetime(connMaxLifetime)
}

// P1 returns the *sql.DB of the pool built by Init. Init must have run
// first; otherwise this dereferences a nil pointer.
func P1() *sql.DB {
	return p1.Db
}

// Pool returns the process-wide *DB built by Init (nil before Init runs).
func Pool() *DB {
	return p1
}

// NewPool opens a separate pool for the DSN pstr. It panics if sql.Open fails.
func NewPool(pstr string) *DB {
	mdb := &DB{}
	mdb.init(pstr)
	return mdb
}

// init opens a pool for pstr and stores it on p. A failed ping is logged,
// not fatal. The local is named pool so it does not shadow the package-level p1.
func (p *DB) init(pstr string) {
	pool, err := sql.Open("mysql", pstr)
	if err != nil {
		panic("open mysql error " + err.Error())
	}
	p.Db = pool
	pool.SetMaxOpenConns(200)
	pool.SetMaxIdleConns(100)
	if err := pool.Ping(); err != nil {
		log.Warn("dbPing err:", err)
	}
}

// scanTargets allocates n value slots plus a parallel slice of pointers to
// them, suitable for passing to rows.Scan.
func scanTargets(n int) (values, ptrs []interface{}) {
	values = make([]interface{}, n)
	ptrs = make([]interface{}, n)
	for i := range values {
		ptrs[i] = &values[i]
	}
	return values, ptrs
}

// sortedKeys returns m's keys in ascending order, so the generated SQL text
// does not depend on map iteration order.
func sortedKeys(m map[string]interface{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}

// buildInsert renders "<verb> <tab> (c1,c2) values (?, ?)" and the matching
// argument list for a single row. tab and the keys of data are spliced into
// the SQL verbatim, so they must come from trusted code, never user input.
func buildInsert(verb, tab string, data map[string]interface{}) (string, []interface{}) {
	cols := sortedKeys(data)
	marks := make([]string, len(cols))
	args := make([]interface{}, len(cols))
	for i, col := range cols {
		marks[i] = "?"
		args[i] = data[col]
	}
	query := verb + " " + tab + " (" + strings.Join(cols, ",") + ") values (" + strings.Join(marks, ", ") + ")"
	return query, args
}

// buildWhere renders where as " and "-joined conditions, appending each value
// to args in the same order. A key that contains "?" is used verbatim as a
// condition (e.g. "age > ?"); any other key becomes "key = ?". Keys are not
// escaped.
func buildWhere(where map[string]interface{}, args []interface{}) (string, []interface{}) {
	conds := make([]string, 0, len(where))
	for _, key := range sortedKeys(where) {
		if strings.Contains(key, "?") {
			conds = append(conds, key)
		} else {
			conds = append(conds, key+" = ?")
		}
		args = append(args, where[key])
	}
	return strings.Join(conds, " and "), args
}

// Insert runs "insert into tab" with one row built from data
// (column -> value) and returns LastInsertId.
func (p *DB) Insert(tab string, data map[string]interface{}) (int, error) {
	query, args := buildInsert("insert into", tab, data)
	res, err := p.Db.Exec(query, args...)
	if err != nil {
		log.Infof(" insert err=%v", err)
		return 0, err
	}
	id, err := res.LastInsertId()
	return int(id), err
}

// InsertIgnore is like Insert but uses "insert ignore into", and it also
// logs the resulting id.
func (p *DB) InsertIgnore(tab string, data map[string]interface{}) (int, error) {
	query, args := buildInsert("insert ignore into", tab, data)
	res, err := p.Db.Exec(query, args...)
	if err != nil {
		log.Infof(" insert err=%v", err)
		return 0, err
	}
	id, err := res.LastInsertId()
	log.Infof(" insert id=%v err=%v", id, err)
	return int(id), err
}

// Replace runs "replace into tab" with one row built from data and returns
// LastInsertId.
func (p *DB) Replace(tab string, data map[string]interface{}) (int64, error) {
	query, args := buildInsert("replace into", tab, data)
	res, err := p.Db.Exec(query, args...)
	if err != nil {
		log.Infof(" replace err=%v", err)
		return 0, err
	}
	log.Infof("replace sql=%v %v %v", query, res, err)
	return res.LastInsertId()
}

// Excute runs an arbitrary statement and returns RowsAffected.
// The misspelled name is kept for existing callers.
func (p *DB) Excute(newsql string, args ...interface{}) (int, error) {
	res, err := p.Db.Exec(newsql, args...)
	if err != nil {
		return 0, err
	}
	n, err := res.RowsAffected()
	return int(n), err
}

// ExcuteInsert runs an arbitrary statement and returns LastInsertId.
// The misspelled name is kept for existing callers.
func (p *DB) ExcuteInsert(newsql string, args ...interface{}) (int, error) {
	res, err := p.Db.Exec(newsql, args...)
	if err != nil {
		return 0, err
	}
	id, err := res.LastInsertId()
	return int(id), err
}

// Update runs "update tab set col = ? ... where ..." and returns
// RowsAffected. The where map follows buildWhere's rules. An empty where is
// rejected up front, so the statement can never target every row.
func (p *DB) Update(tab string, update map[string]interface{}, where map[string]interface{}) (int, error) {
	if len(where) == 0 {
		return 0, errors.New("db: update without where condition")
	}
	setCols := sortedKeys(update)
	sets := make([]string, len(setCols))
	args := make([]interface{}, 0, len(update)+len(where))
	for i, col := range setCols {
		sets[i] = col + " = ?"
		args = append(args, update[col])
	}
	cond, args := buildWhere(where, args)
	query := "update " + tab + " set " + strings.Join(sets, ",") + " where " + cond
	res, err := p.Db.Exec(query, args...)
	if err != nil {
		return 0, err
	}
	n, err := res.RowsAffected()
	return int(n), err
}

// GetData runs a query and returns every row as a map (see FetchData).
// It returns ErrDataNull when the query yields no rows, and the underlying
// error when the query or row reading fails.
func (p *DB) GetData(sqlString string, params ...interface{}) ([]map[string]interface{}, error) {
	rows, err := p.Db.Query(sqlString, params...)
	if err != nil {
		log.Infof("dberr = %v", err)
		return nil, err
	}
	defer rows.Close()
	tableData, err := FetchData(rows)
	if err != nil {
		return nil, err
	}
	if len(tableData) == 0 {
		return nil, ErrDataNull
	}
	return tableData, nil
}

// GetRow runs a query and returns its first row, with the same error
// behavior as GetData.
func (p *DB) GetRow(sqlString string, params ...interface{}) (map[string]interface{}, error) {
	tableData, err := p.GetData(sqlString, params...)
	if err != nil {
		return nil, err
	}
	return tableData[0], nil
}

// GetWithTotal selects attr from tab with "limit from,to" and also counts all
// rows matching where (see buildWhere). In MySQL's "limit from,to", the
// second value is a row count, not an end index. tab and attr are spliced
// in unescaped. On any error it logs and returns an empty slice and 0.
func (p *DB) GetWithTotal(tab string, attr string, where map[string]interface{}, from int, to int) ([]map[string]interface{}, int) {
	empty := make([]map[string]interface{}, 0)
	cond, params := buildWhere(where, nil)
	whereSQL := ""
	if cond != "" {
		whereSQL = " where " + cond
	}

	sqlTotal := "select count(1) from " + tab + whereSQL
	total := 0
	if err := p.Db.QueryRow(sqlTotal, params...).Scan(&total); err != nil {
		log.Infof("dberr = %v %v", sqlTotal, err)
		return empty, 0
	}

	sqlString := "select " + attr + " from " + tab + whereSQL + " limit ?,?"
	params = append(params, from, to)
	rows, err := p.Db.Query(sqlString, params...)
	if err != nil {
		log.Infof("dberr = %v %v", sqlString, err)
		return empty, 0
	}
	defer rows.Close()
	tableData, err := FetchData(rows)
	if err != nil {
		log.Infof("dberr = %v %v", sqlString, err)
		return empty, 0
	}
	return tableData, total
}

// rowToMap copies one scanned row into a new map keyed by column name,
// converting []byte values to string and keeping other values as scanned.
func rowToMap(columns []string, values []interface{}) map[string]interface{} {
	entry := make(map[string]interface{}, len(columns))
	for i, col := range columns {
		if b, ok := values[i].([]byte); ok {
			entry[col] = string(b)
		} else {
			entry[col] = values[i]
		}
	}
	return entry
}

// FetchData reads all remaining rows into maps keyed by column name
// ([]byte values become string). It does not close rows. With no rows it
// returns an empty, non-nil slice. A Columns, Scan or iteration error is
// returned with nil data instead of being silently dropped.
func FetchData(rows *sql.Rows) (tableData []map[string]interface{}, err error) {
	columns, err := rows.Columns()
	if err != nil {
		return nil, err
	}
	// The scan slots are reused for every row; rowToMap copies them out.
	values, valuePtrs := scanTargets(len(columns))
	tableData = make([]map[string]interface{}, 0)
	for rows.Next() {
		if err = rows.Scan(valuePtrs...); err != nil {
			return nil, err
		}
		tableData = append(tableData, rowToMap(columns, values))
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return tableData, nil
}

// FetchOne reads the first row into a map keyed by column name ([]byte
// values become string) and always closes rows. It returns ErrDataNull when
// there is no row. On any error the returned map is empty but non-nil.
func FetchOne(rows *sql.Rows) (tableData map[string]interface{}, err error) {
	defer rows.Close()
	data := map[string]interface{}{}
	columns, err := rows.Columns()
	if err != nil {
		return data, err
	}
	if !rows.Next() {
		// Distinguish "no rows" from a failure while advancing the cursor.
		if err = rows.Err(); err != nil {
			return data, err
		}
		return data, ErrDataNull
	}
	values, valuePtrs := scanTargets(len(columns))
	if err = rows.Scan(valuePtrs...); err != nil {
		return data, err
	}
	return rowToMap(columns, values), nil
}