whatcanGOwrong

This commit is contained in:
2024-09-19 21:38:24 -04:00
commit d0ae4d841d
17908 changed files with 4096831 additions and 0 deletions
@@ -0,0 +1,56 @@
# MySQL
`mysql://user:password@tcp(host:port)/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-no-lock` | `NoLock` | Set to `true` to skip `GET_LOCK`/`RELEASE_LOCK` statements. Useful for [multi-master MySQL flavors](https://www.percona.com/doc/percona-xtradb-cluster/LATEST/features/pxc-strict-mode.html#explicit-table-locking). Only run migrations from one host when this is enabled. |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds, functionally similar to [Server-side SELECT statement timeouts](https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/) but enforced by the client. Available for all versions of MySQL, not just >=5.7. |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. |
| `port` | | The port to bind to. |
| `tls` | | TLS / SSL encrypted connection parameter; see [go-sql-driver](https://github.com/go-sql-driver/mysql#tls). Use any name (e.g. `migrate`) if you want to use a custom TLS config (`x-tls-` queries). |
| `x-tls-ca` | | The location of the CA (certificate authority) file. |
| `x-tls-cert` | | The location of the client certificate file. Must be used with `x-tls-key`. |
| `x-tls-key` | | The location of the private key file. Must be used with `x-tls-cert`. |
| `x-tls-insecure-skip-verify` | | Whether or not to use SSL (true\|false) |
## Use with existing client
If you use the MySQL driver with existing database client, you must create the client with parameter `multiStatements=true`:
```go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
func main() {
db, _ := sql.Open("mysql", "user:password@tcp(host:port)/dbname?multiStatements=true")
driver, _ := mysql.WithInstance(db, &mysql.Config{})
m, _ := migrate.NewWithDatabaseInstance(
"file:///migrations",
"mysql",
driver,
)
m.Steps(2)
}
```
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://dev.mysql.com/doc/refman/5.7/en/commit.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
@@ -0,0 +1,3 @@
CREATE TABLE IF NOT EXISTS test (
firstname VARCHAR(16)
);
@@ -0,0 +1,514 @@
//go:build go1.9
// +build go1.9
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
nurl "net/url"
"os"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
)
var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
func init() {
database.Register("mysql", &Mysql{})
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrAppendPEM = fmt.Errorf("failed to append PEM")
ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
)
type Config struct {
MigrationsTable string
DatabaseName string
NoLock bool
StatementTimeout time.Duration
}
type Mysql struct {
// mysql RELEASE_LOCK must be called from the same conn, so
// just do everything over a single conn anyway.
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
config *Config
}
// connection instance must have `multiStatements` set to true
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := conn.PingContext(ctx); err != nil {
return nil, err
}
mx := &Mysql{
conn: conn,
db: nil,
config: config,
}
if config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName.String) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName.String
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}
return mx, nil
}
// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()
if err := instance.Ping(); err != nil {
return nil, err
}
conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}
mx, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}
mx.db = instance
return mx, nil
}
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
if c == nil {
return nil, ErrNilConfig
}
customQueryParams := map[string]string{}
for k, v := range c.Params {
if strings.HasPrefix(k, "x-") {
customQueryParams[k] = v
delete(c.Params, k)
}
}
return customQueryParams, nil
}
func urlToMySQLConfig(url string) (*mysql.Config, error) {
// Need to parse out custom TLS parameters and call
// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
// which consumes the registered tls.Config
// Fixes: https://github.com/golang-migrate/migrate/issues/411
//
// Can't use url.Parse() since it fails to parse MySQL DSNs
// mysql.ParseDSN() also searches for "?" to find query parameters:
// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
if idx := strings.LastIndex(url, "?"); idx > 0 {
rawParams := url[idx+1:]
parsedParams, err := nurl.ParseQuery(rawParams)
if err != nil {
return nil, err
}
ctls := parsedParams.Get("tls")
if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(parsedParams.Get("x-tls-ca"))
if err != nil {
return nil, err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return nil, ErrAppendPEM
}
clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig
}
certs, err := tls.LoadX509KeyPair(ccert, ckey)
if err != nil {
return nil, err
}
clientCert = append(clientCert, certs)
}
insecureSkipVerify := false
insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
if len(insecureSkipVerifyStr) > 0 {
x, err := strconv.ParseBool(insecureSkipVerifyStr)
if err != nil {
return nil, err
}
insecureSkipVerify = x
}
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
InsecureSkipVerify: insecureSkipVerify,
})
if err != nil {
return nil, err
}
}
}
}
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
return nil, err
}
config.MultiStatements = true
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
// net/url.Parse() would automatically unescape it for us.
// See: https://play.golang.org/p/q9j1io-YICQ
user, err := nurl.QueryUnescape(config.User)
if err != nil {
return nil, err
}
config.User = user
password, err := nurl.QueryUnescape(config.Passwd)
if err != nil {
return nil, err
}
config.Passwd = password
return config, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
config, err := urlToMySQLConfig(url)
if err != nil {
return nil, err
}
customParams, err := extractCustomQueryParams(config)
if err != nil {
return nil, err
}
noLockParam, noLock := customParams["x-no-lock"], false
if noLockParam != "" {
noLock, err = strconv.ParseBool(noLockParam)
if err != nil {
return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
}
}
statementTimeoutParam := customParams["x-statement-timeout"]
statementTimeout := 0
if statementTimeoutParam != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutParam)
if err != nil {
return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err)
}
}
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}
mx, err := WithInstance(db, &Config{
DatabaseName: config.DBName,
MigrationsTable: customParams["x-migrations-table"],
NoLock: noLock,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
})
if err != nil {
return nil, err
}
return mx, nil
}
func (m *Mysql) Close() error {
connErr := m.conn.Close()
var dbErr error
if m.db != nil {
dbErr = m.db.Close()
}
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
func (m *Mysql) Lock() error {
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
if m.config.NoLock {
return nil
}
aid, err := database.GenerateAdvisoryLockId(
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
if err != nil {
return err
}
query := "SELECT GET_LOCK(?, 10)"
var success bool
if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
if !success {
return database.ErrLocked
}
return nil
})
}
func (m *Mysql) Unlock() error {
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
if m.config.NoLock {
return nil
}
aid, err := database.GenerateAdvisoryLockId(
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
if err != nil {
return err
}
query := `SELECT RELEASE_LOCK(?)`
if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
// in which case isLocked should be true until the timeout expires -- synchronizing
// these states is likely not worth trying to do; reconsider the necessity of isLocked.
return nil
})
}
func (m *Mysql) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
ctx := context.Background()
if m.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout)
defer cancel()
}
query := string(migr[:])
if _, err := m.conn.ExecContext(ctx, query); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (m *Mysql) SetVersion(version int, dirty bool) error {
tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := "DELETE FROM `" + m.config.MigrationsTable + "`"
if _, err := tx.ExecContext(context.Background(), query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (m *Mysql) Version() (version int, dirty bool, err error) {
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*mysql.MySQLError); ok {
if e.Number == 0 {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (m *Mysql) Drop() (err error) {
// select all tables
query := `SHOW TABLES LIKE '%'`
tables, err := m.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// disable checking foreign key constraints until finished
query = `SET foreign_key_checks = 0`
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
// enable foreign key checks
_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
}()
// delete one by one ...
for _, t := range tableNames {
query = "DROP TABLE IF EXISTS `" + t + "`"
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Mysql type.
func (m *Mysql) ensureVersionTable() (err error) {
if err = m.Lock(); err != nil {
return err
}
defer func() {
if e := m.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// check if migration table exists
var result string
query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
if err != sql.ErrNoRows {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
} else {
return nil
}
// if not, create the empty migration table
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
@@ -0,0 +1,420 @@
package mysql
import (
"context"
"crypto/ed25519"
"crypto/x509"
"database/sql"
sqldriver "database/sql/driver"
"encoding/pem"
"errors"
"fmt"
"log"
"math/big"
"math/rand"
"net/url"
"os"
"strconv"
"testing"
"github.com/dhui/dktest"
"github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/stretchr/testify/assert"
)
const defaultPort = 3306
var (
opts = dktest.Options{
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
PortRequired: true, ReadyFunc: isReady,
}
optsAnsiQuotes = dktest.Options{
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
PortRequired: true, ReadyFunc: isReady,
Cmd: []string{"--sql-mode=ANSI_QUOTES"},
}
// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
specs = []dktesting.ContainerSpec{
{ImageName: "mysql:5.5", Options: opts},
{ImageName: "mysql:5.6", Options: opts},
{ImageName: "mysql:5.7", Options: opts},
{ImageName: "mysql:8", Options: opts},
}
specsAnsiQuotes = []dktesting.ContainerSpec{
{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
{ImageName: "mysql:8", Options: optsAnsiQuotes},
}
)
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(defaultPort)
if err != nil {
return false
}
db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
return false
default:
fmt.Println(err)
}
return false
}
return true
}
func Test(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestMigrate(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestMigrateAnsiQuotes(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ms := d.(*Mysql)
err = ms.Lock()
if err != nil {
t.Fatal(err)
}
err = ms.Unlock()
if err != nil {
t.Fatal(err)
}
// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
err = ms.Lock()
if err != nil {
t.Fatal(err)
}
err = ms.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestNoLockParamValidation(t *testing.T) {
ip := "127.0.0.1"
port := 3306
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
_, err := p.Open(addr + "?x-no-lock=not-a-bool")
if !errors.Is(err, strconv.ErrSyntax) {
t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
}
}
func TestNoLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
lock := d.(*Mysql)
p = &Mysql{}
d, err = p.Open(addr + "?x-no-lock=true")
if err != nil {
t.Fatal(err)
}
noLock := d.(*Mysql)
// Should be possible to take real lock and no-lock at the same time
if err = lock.Lock(); err != nil {
t.Fatal(err)
}
if err = noLock.Lock(); err != nil {
t.Fatal(err)
}
if err = lock.Unlock(); err != nil {
t.Fatal(err)
}
if err = noLock.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestExtractCustomQueryParams(t *testing.T) {
testcases := []struct {
name string
config *mysql.Config
expectedParams map[string]string
expectedCustomParams map[string]string
expectedErr error
}{
{name: "nil config", expectedErr: ErrNilConfig},
{
name: "no params",
config: mysql.NewConfig(),
expectedCustomParams: map[string]string{},
},
{
name: "no custom params",
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{},
},
{
name: "one param, one custom param",
config: &mysql.Config{
Params: map[string]string{"hello": "world", "x-foo": "bar"},
},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{"x-foo": "bar"},
},
{
name: "multiple params, multiple custom params",
config: &mysql.Config{
Params: map[string]string{
"hello": "world",
"x-foo": "bar",
"dead": "beef",
"x-cat": "hat",
},
},
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
customParams, err := extractCustomQueryParams(tc.config)
if tc.config != nil {
assert.Equal(t, tc.expectedParams, tc.config.Params,
"Expected config params have custom params properly removed")
}
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
assert.Equal(t, tc.expectedCustomParams, customParams,
"Expected custom params to be properly extracted")
})
}
}
func createTmpCert(t *testing.T) string {
tmpCertFile, err := os.CreateTemp("", "migrate_test_cert")
if err != nil {
t.Fatal("Failed to create temp cert file:", err)
}
t.Cleanup(func() {
if err := os.Remove(tmpCertFile.Name()); err != nil {
t.Log("Failed to cleanup temp cert file:", err)
}
})
r := rand.New(rand.NewSource(0))
pub, priv, err := ed25519.GenerateKey(r)
if err != nil {
t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
}
tmpl := x509.Certificate{
SerialNumber: big.NewInt(0),
}
derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
if err != nil {
t.Fatal("Failed to generate temp cert file:", err)
}
if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
t.Fatal("Failed to encode ")
}
if err := tmpCertFile.Close(); err != nil {
t.Fatal("Failed to close temp cert file:", err)
}
return tmpCertFile.Name()
}
func TestURLToMySQLConfig(t *testing.T) {
tmpCertFilename := createTmpCert(t)
tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
testcases := []struct {
name string
urlStr string
expectedDSN string // empty string signifies that an error is expected
}{
{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded :",
urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded @",
urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
// {name: "user/password - user with encoded :",
// urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
// expectedDSN: "username::password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - user with encoded @",
urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded :",
urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded @",
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "custom tls",
urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
config, err := urlToMySQLConfig(tc.urlStr)
if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
}
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
})
}
}