whatcanGOwrong
This commit is contained in:
@@ -0,0 +1,56 @@
|
||||
# MySQL
|
||||
|
||||
`mysql://user:password@tcp(host:port)/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-no-lock` | `NoLock` | Set to `true` to skip `GET_LOCK`/`RELEASE_LOCK` statements. Useful for [multi-master MySQL flavors](https://www.percona.com/doc/percona-xtradb-cluster/LATEST/features/pxc-strict-mode.html#explicit-table-locking). Only run migrations from one host when this is enabled. |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds, functionally similar to [Server-side SELECT statement timeouts](https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/) but enforced by the client. Available for all versions of MySQL, not just >=5.7. |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. |
|
||||
| `port` | | The port to bind to. |
|
||||
| `tls` | | TLS / SSL encrypted connection parameter; see [go-sql-driver](https://github.com/go-sql-driver/mysql#tls). Use any name (e.g. `migrate`) if you want to use a custom TLS config (`x-tls-` queries). |
|
||||
| `x-tls-ca` | | The location of the CA (certificate authority) file. |
|
||||
| `x-tls-cert` | | The location of the client certificate file. Must be used with `x-tls-key`. |
|
||||
| `x-tls-key` | | The location of the private key file. Must be used with `x-tls-cert`. |
|
||||
| `x-tls-insecure-skip-verify` | | Whether or not to use SSL (true\|false) |
|
||||
|
||||
## Use with existing client
|
||||
|
||||
If you use the MySQL driver with existing database client, you must create the client with parameter `multiStatements=true`:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := sql.Open("mysql", "user:password@tcp(host:port)/dbname?multiStatements=true")
|
||||
driver, _ := mysql.WithInstance(db, &mysql.Config{})
|
||||
m, _ := migrate.NewWithDatabaseInstance(
|
||||
"file:///migrations",
|
||||
"mysql",
|
||||
driver,
|
||||
)
|
||||
|
||||
m.Steps(2)
|
||||
}
|
||||
```
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://dev.mysql.com/doc/refman/5.7/en/commit.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS test;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE TABLE IF NOT EXISTS test (
|
||||
firstname VARCHAR(16)
|
||||
);
|
||||
@@ -0,0 +1,514 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
|
||||
|
||||
func init() {
|
||||
database.Register("mysql", &Mysql{})
|
||||
}
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
|
||||
var (
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrAppendPEM = fmt.Errorf("failed to append PEM")
|
||||
ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
NoLock bool
|
||||
StatementTimeout time.Duration
|
||||
}
|
||||
|
||||
type Mysql struct {
|
||||
// mysql RELEASE_LOCK must be called from the same conn, so
|
||||
// just do everything over a single conn anyway.
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
config *Config
|
||||
}
|
||||
|
||||
// connection instance must have `multiStatements` set to true
|
||||
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx := &Mysql{
|
||||
conn: conn,
|
||||
db: nil,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT DATABASE()`
|
||||
var databaseName sql.NullString
|
||||
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName.String) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName.String
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if err := mx.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
// instance must have `multiStatements` set to true
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx, err := WithConnection(ctx, conn, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx.db = instance
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
|
||||
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
|
||||
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
|
||||
if c == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
customQueryParams := map[string]string{}
|
||||
|
||||
for k, v := range c.Params {
|
||||
if strings.HasPrefix(k, "x-") {
|
||||
customQueryParams[k] = v
|
||||
delete(c.Params, k)
|
||||
}
|
||||
}
|
||||
return customQueryParams, nil
|
||||
}
|
||||
|
||||
func urlToMySQLConfig(url string) (*mysql.Config, error) {
|
||||
// Need to parse out custom TLS parameters and call
|
||||
// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
|
||||
// which consumes the registered tls.Config
|
||||
// Fixes: https://github.com/golang-migrate/migrate/issues/411
|
||||
//
|
||||
// Can't use url.Parse() since it fails to parse MySQL DSNs
|
||||
// mysql.ParseDSN() also searches for "?" to find query parameters:
|
||||
// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
|
||||
if idx := strings.LastIndex(url, "?"); idx > 0 {
|
||||
rawParams := url[idx+1:]
|
||||
parsedParams, err := nurl.ParseQuery(rawParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctls := parsedParams.Get("tls")
|
||||
if len(ctls) > 0 {
|
||||
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
pem, err := os.ReadFile(parsedParams.Get("x-tls-ca"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
return nil, ErrAppendPEM
|
||||
}
|
||||
|
||||
clientCert := make([]tls.Certificate, 0, 1)
|
||||
if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
|
||||
if ccert == "" || ckey == "" {
|
||||
return nil, ErrTLSCertKeyConfig
|
||||
}
|
||||
certs, err := tls.LoadX509KeyPair(ccert, ckey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientCert = append(clientCert, certs)
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
|
||||
if len(insecureSkipVerifyStr) > 0 {
|
||||
x, err := strconv.ParseBool(insecureSkipVerifyStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
insecureSkipVerify = x
|
||||
}
|
||||
|
||||
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
|
||||
RootCAs: rootCertPool,
|
||||
Certificates: clientCert,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.MultiStatements = true
|
||||
|
||||
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
|
||||
// net/url.Parse() would automatically unescape it for us.
|
||||
// See: https://play.golang.org/p/q9j1io-YICQ
|
||||
user, err := nurl.QueryUnescape(config.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.User = user
|
||||
|
||||
password, err := nurl.QueryUnescape(config.Passwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Passwd = password
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Open(url string) (database.Driver, error) {
|
||||
config, err := urlToMySQLConfig(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
customParams, err := extractCustomQueryParams(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
noLockParam, noLock := customParams["x-no-lock"], false
|
||||
if noLockParam != "" {
|
||||
noLock, err = strconv.ParseBool(noLockParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
statementTimeoutParam := customParams["x-statement-timeout"]
|
||||
statementTimeout := 0
|
||||
if statementTimeoutParam != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("mysql", config.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx, err := WithInstance(db, &Config{
|
||||
DatabaseName: config.DBName,
|
||||
MigrationsTable: customParams["x-migrations-table"],
|
||||
NoLock: noLock,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Close() error {
|
||||
connErr := m.conn.Close()
|
||||
var dbErr error
|
||||
if m.db != nil {
|
||||
dbErr = m.db.Close()
|
||||
}
|
||||
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Lock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
|
||||
if m.config.NoLock {
|
||||
return nil
|
||||
}
|
||||
aid, err := database.GenerateAdvisoryLockId(
|
||||
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "SELECT GET_LOCK(?, 10)"
|
||||
var success bool
|
||||
if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
|
||||
if !success {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mysql) Unlock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
if m.config.NoLock {
|
||||
return nil
|
||||
}
|
||||
|
||||
aid, err := database.GenerateAdvisoryLockId(
|
||||
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT RELEASE_LOCK(?)`
|
||||
if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
|
||||
// in which case isLocked should be true until the timeout expires -- synchronizing
|
||||
// these states is likely not worth trying to do; reconsider the necessity of isLocked.
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mysql) Run(migration io.Reader) error {
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if m.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
query := string(migr[:])
|
||||
if _, err := m.conn.ExecContext(ctx, query); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) SetVersion(version int, dirty bool) error {
|
||||
tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := "DELETE FROM `" + m.config.MigrationsTable + "`"
|
||||
if _, err := tx.ExecContext(context.Background(), query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
|
||||
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Version() (version int, dirty bool, err error) {
|
||||
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
|
||||
err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*mysql.MySQLError); ok {
|
||||
if e.Number == 0 {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mysql) Drop() (err error) {
|
||||
// select all tables
|
||||
query := `SHOW TABLES LIKE '%'`
|
||||
tables, err := m.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// disable checking foreign key constraints until finished
|
||||
query = `SET foreign_key_checks = 0`
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// enable foreign key checks
|
||||
_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
|
||||
}()
|
||||
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = "DROP TABLE IF EXISTS `" + t + "`"
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Mysql type.
|
||||
func (m *Mysql) ensureVersionTable() (err error) {
|
||||
if err = m.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := m.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// check if migration table exists
|
||||
var result string
|
||||
query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
|
||||
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, create the empty migration table
|
||||
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the bool value of the input.
|
||||
// The 2nd return value indicates if the input was a valid bool value
|
||||
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
|
||||
func readBool(input string) (value bool, valid bool) {
|
||||
switch input {
|
||||
case "1", "true", "TRUE", "True":
|
||||
return true, true
|
||||
case "0", "false", "FALSE", "False":
|
||||
return false, true
|
||||
}
|
||||
|
||||
// Not a valid bool value
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,420 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const defaultPort = 3306
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
}
|
||||
optsAnsiQuotes = dktest.Options{
|
||||
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
Cmd: []string{"--sql-mode=ANSI_QUOTES"},
|
||||
}
|
||||
// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "mysql:5.5", Options: opts},
|
||||
{ImageName: "mysql:5.6", Options: opts},
|
||||
{ImageName: "mysql:5.7", Options: opts},
|
||||
{ImageName: "mysql:8", Options: opts},
|
||||
}
|
||||
specsAnsiQuotes = []dktesting.ContainerSpec{
|
||||
{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:8", Options: optsAnsiQuotes},
|
||||
}
|
||||
)
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
|
||||
return false
|
||||
default:
|
||||
fmt.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateAnsiQuotes(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLockWorks(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ms := d.(*Mysql)
|
||||
|
||||
err = ms.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ms.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
|
||||
err = ms.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ms.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoLockParamValidation(t *testing.T) {
|
||||
ip := "127.0.0.1"
|
||||
port := 3306
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
_, err := p.Open(addr + "?x-no-lock=not-a-bool")
|
||||
if !errors.Is(err, strconv.ErrSyntax) {
|
||||
t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoLockWorks(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lock := d.(*Mysql)
|
||||
|
||||
p = &Mysql{}
|
||||
d, err = p.Open(addr + "?x-no-lock=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
noLock := d.(*Mysql)
|
||||
|
||||
// Should be possible to take real lock and no-lock at the same time
|
||||
if err = lock.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = noLock.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = lock.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = noLock.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractCustomQueryParams(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
config *mysql.Config
|
||||
expectedParams map[string]string
|
||||
expectedCustomParams map[string]string
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "nil config", expectedErr: ErrNilConfig},
|
||||
{
|
||||
name: "no params",
|
||||
config: mysql.NewConfig(),
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "no custom params",
|
||||
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "one param, one custom param",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{"hello": "world", "x-foo": "bar"},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "multiple params, multiple custom params",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{
|
||||
"hello": "world",
|
||||
"x-foo": "bar",
|
||||
"dead": "beef",
|
||||
"x-cat": "hat",
|
||||
},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
customParams, err := extractCustomQueryParams(tc.config)
|
||||
if tc.config != nil {
|
||||
assert.Equal(t, tc.expectedParams, tc.config.Params,
|
||||
"Expected config params have custom params properly removed")
|
||||
}
|
||||
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
|
||||
assert.Equal(t, tc.expectedCustomParams, customParams,
|
||||
"Expected custom params to be properly extracted")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTmpCert(t *testing.T) string {
|
||||
tmpCertFile, err := os.CreateTemp("", "migrate_test_cert")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create temp cert file:", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := os.Remove(tmpCertFile.Name()); err != nil {
|
||||
t.Log("Failed to cleanup temp cert file:", err)
|
||||
}
|
||||
})
|
||||
|
||||
r := rand.New(rand.NewSource(0))
|
||||
pub, priv, err := ed25519.GenerateKey(r)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
|
||||
}
|
||||
tmpl := x509.Certificate{
|
||||
SerialNumber: big.NewInt(0),
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to generate temp cert file:", err)
|
||||
}
|
||||
if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
t.Fatal("Failed to encode ")
|
||||
}
|
||||
if err := tmpCertFile.Close(); err != nil {
|
||||
t.Fatal("Failed to close temp cert file:", err)
|
||||
}
|
||||
return tmpCertFile.Name()
|
||||
}
|
||||
|
||||
func TestURLToMySQLConfig(t *testing.T) {
|
||||
tmpCertFilename := createTmpCert(t)
|
||||
tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectedDSN string // empty string signifies that an error is expected
|
||||
}{
|
||||
{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user - with encoded :",
|
||||
urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user - with encoded @",
|
||||
urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
|
||||
// {name: "user/password - user with encoded :",
|
||||
// urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
// expectedDSN: "username::password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - user with encoded @",
|
||||
urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - password with encoded :",
|
||||
urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - password with encoded @",
|
||||
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "custom tls",
|
||||
urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
|
||||
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config, err := urlToMySQLConfig(tc.urlStr)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
|
||||
}
|
||||
dsn := config.FormatDSN()
|
||||
if dsn != tc.expectedDSN {
|
||||
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user