whatcanGOwrong

This commit is contained in:
2024-09-19 21:38:24 -04:00
commit d0ae4d841d
17908 changed files with 4096831 additions and 0 deletions
@@ -0,0 +1,45 @@
# Cassandra / ScyllaDB
* `Drop()` method will not work on Cassandra 2.X because it rely on
system_schema table which comes with 3.X
* Other methods should work properly but are **not tested**
* The Cassandra driver (gocql) does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param. There are two important caveats:
* This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
* The queries are not executed in any sort of transaction/batch, meaning you are responsible for fixing partial migrations.
**ScyllaDB**
* No additional configuration is required since it is a drop-in replacement for Cassandra.
* The `Drop()` method` works for ScyllaDB 5.1
## Usage
`cassandra://host:port/keyspace?param1=value&param2=value2`
| URL Query | Default value | Description |
|------------|-------------|-----------|
| `x-migrations-table` | schema_migrations | Name of the migrations table |
| `x-multi-statement` | false | Enable multiple statements to be ran in a single migration (See note above) |
| `port` | 9042 | The port to bind to |
| `consistency` | ALL | Migration consistency
| `protocol` | | Cassandra protocol version (3 or 4)
| `timeout` | 1 minute | Migration timeout
| `connect-timeout` | 600ms | Initial connection timeout to the cluster |
| `username` | nil | Username to use when authenticating. |
| `password` | nil | Password to use when authenticating. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
| `disable-host-lookup`| false | Disable initial host lookup. |
`timeout` is parsed using [time.ParseDuration(s string)](https://golang.org/pkg/time/#ParseDuration)
## Upgrading from v1
1. Write down the current migration version from schema_migrations
2. `DROP TABLE schema_migrations`
4. Download and install the latest migrate version.
5. Force the current migration version with `migrate force <current_version>`.
@@ -0,0 +1,352 @@
package cassandra
import (
"errors"
"fmt"
"io"
nurl "net/url"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/gocql/gocql"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
)
func init() {
db := new(Cassandra)
database.Register("cassandra", db)
}
var (
multiStmtDelimiter = []byte(";")
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = errors.New("no config")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrDatabaseDirty = errors.New("database is dirty")
ErrClosedSession = errors.New("session is closed")
)
type Config struct {
MigrationsTable string
KeyspaceName string
MultiStatementEnabled bool
MultiStatementMaxSize int
}
type Cassandra struct {
session *gocql.Session
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if len(config.KeyspaceName) == 0 {
return nil, ErrNoKeyspace
}
if session.Closed() {
return nil, ErrClosedSession
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if config.MultiStatementMaxSize <= 0 {
config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
}
c := &Cassandra{
session: session,
config: config,
}
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return c, nil
}
func (c *Cassandra) Open(url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// Check for missing mandatory attributes
if len(u.Path) == 0 {
return nil, ErrNoKeyspace
}
cluster := gocql.NewCluster(u.Host)
cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
cluster.Consistency = gocql.All
cluster.Timeout = 1 * time.Minute
if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
authenticator := gocql.PasswordAuthenticator{
Username: u.Query().Get("username"),
Password: u.Query().Get("password"),
}
cluster.Authenticator = authenticator
}
// Retrieve query string configuration
if len(u.Query().Get("consistency")) > 0 {
var consistency gocql.Consistency
consistency, err = parseConsistency(u.Query().Get("consistency"))
if err != nil {
return nil, err
}
cluster.Consistency = consistency
}
if len(u.Query().Get("protocol")) > 0 {
var protoversion int
protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
if err != nil {
return nil, err
}
cluster.ProtoVersion = protoversion
}
if len(u.Query().Get("timeout")) > 0 {
var timeout time.Duration
timeout, err = time.ParseDuration(u.Query().Get("timeout"))
if err != nil {
return nil, err
}
cluster.Timeout = timeout
}
if len(u.Query().Get("connect-timeout")) > 0 {
var connectTimeout time.Duration
connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout"))
if err != nil {
return nil, err
}
cluster.ConnectTimeout = connectTimeout
}
if len(u.Query().Get("sslmode")) > 0 {
if u.Query().Get("sslmode") != "disable" {
sslOpts := &gocql.SslOptions{}
if len(u.Query().Get("sslrootcert")) > 0 {
sslOpts.CaPath = u.Query().Get("sslrootcert")
}
if len(u.Query().Get("sslcert")) > 0 {
sslOpts.CertPath = u.Query().Get("sslcert")
}
if len(u.Query().Get("sslkey")) > 0 {
sslOpts.KeyPath = u.Query().Get("sslkey")
}
if u.Query().Get("sslmode") == "verify-full" {
sslOpts.EnableHostVerification = true
}
cluster.SslOpts = sslOpts
}
}
if len(u.Query().Get("disable-host-lookup")) > 0 {
if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
cluster.DisableInitialHostLookup = true
} else if err != nil {
return nil, err
}
}
session, err := cluster.CreateSession()
if err != nil {
return nil, err
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
}
return WithInstance(session, &Config{
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
MigrationsTable: u.Query().Get("x-migrations-table"),
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
})
}
func (c *Cassandra) Close() error {
c.session.Close()
return nil
}
func (c *Cassandra) Lock() error {
if !c.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}
func (c *Cassandra) Unlock() error {
if !c.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}
func (c *Cassandra) Run(migration io.Reader) error {
if c.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
tq := strings.TrimSpace(string(m))
if tq == "" {
return true
}
if e := c.session.Query(tq).Exec(); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
// run migration
if err := c.session.Query(string(migr)).Exec(); err != nil {
// TODO: cast to Cassandra error and get line number
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (c *Cassandra) SetVersion(version int, dirty bool) error {
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
iter := c.session.Query(squery).Iter()
var previous int
for iter.Scan(&previous) {
if err := c.session.Query(dquery, previous).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(dquery)}
}
}
if err := iter.Close(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(squery)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
return nil
}
// Return current keyspace version
func (c *Cassandra) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
case err == gocql.ErrNotFound:
return database.NilVersion, false, nil
case err != nil:
if _, ok := err.(*gocql.Error); ok {
return database.NilVersion, false, nil
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (c *Cassandra) Drop() error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
iter := c.session.Query(query).Iter()
var tableName string
for iter.Scan(&tableName) {
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
if err != nil {
return err
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Cassandra type.
func (c *Cassandra) ensureVersionTable() (err error) {
if err = c.Lock(); err != nil {
return err
}
defer func() {
if e := c.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
if err != nil {
return err
}
if _, _, err = c.Version(); err != nil {
return err
}
return nil
}
// ParseConsistency wraps gocql.ParseConsistency
// to return an error instead of a panicking.
func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
err, ok = r.(error)
if !ok {
err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
}
}
}()
consistency = gocql.ParseConsistency(consistencyStr)
return consistency, nil
}
@@ -0,0 +1,122 @@
package cassandra
import (
"context"
"fmt"
"github.com/golang-migrate/migrate/v4"
"strconv"
"testing"
)
import (
"github.com/dhui/dktest"
"github.com/gocql/gocql"
)
import (
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
var (
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady}
// Supported versions: http://cassandra.apache.org/download/
// Although Cassandra 2.x is supported by the Apache Foundation,
// the migrate db driver only supports Cassandra 3.x since it uses
// the system_schema keyspace.
// last ScyllaDB version tested is 5.1.11
specs = []dktesting.ContainerSpec{
{ImageName: "cassandra:3.0", Options: opts},
{ImageName: "cassandra:3.11", Options: opts},
{ImageName: "scylladb/scylla:5.1.11", Options: opts},
}
)
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
// Cassandra exposes 5 ports (7000, 7001, 7199, 9042 & 9160)
// We only need the port bound to 9042
ip, portStr, err := c.Port(9042)
if err != nil {
return false
}
port, err := strconv.Atoi(portStr)
if err != nil {
return false
}
cluster := gocql.NewCluster(ip)
cluster.Port = port
cluster.Consistency = gocql.All
p, err := cluster.CreateSession()
if err != nil {
return false
}
defer p.Close()
// Create keyspace for tests
if err = p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec(); err != nil {
return false
}
return true
}
func Test(t *testing.T) {
t.Run("test", test)
t.Run("testMigrate", testMigrate)
t.Cleanup(func() {
for _, spec := range specs {
t.Log("Cleaning up ", spec.ImageName)
if err := spec.Cleanup(); err != nil {
t.Error("Error removing ", spec.ImageName, "error:", err)
}
}
})
}
func test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT table_name from system_schema.tables"))
})
}
func testMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(9042)
if err != nil {
t.Fatal("Unable to get mapped port:", err)
}
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
p := &Cassandra{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
@@ -0,0 +1 @@
SELECT table_name from system_schema.tables
@@ -0,0 +1 @@
SELECT table_name from system_schema.tables
@@ -0,0 +1,26 @@
# ClickHouse
`clickhouse://host:port?username=user&password=password&database=clicks&x-multi-statement=true`
| URL Query | Description |
|------------|-------------|
| `x-migrations-table`| Name of the migrations table |
| `x-migrations-table-engine`| Engine to use for the migrations table, defaults to TinyLog |
| `x-cluster-name` | Name of cluster for creating `schema_migrations` table cluster wide |
| `database` | The name of the database to connect to |
| `username` | The user to sign in as |
| `password` | The user's password |
| `host` | The host to connect to. |
| `port` | The port to bind to. |
| `x-multi-statement` | false | Enable multiple statements to be ran in a single migration (See note below) |
## Notes
* The Clickhouse driver does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param. There are two important caveats:
* This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
* The queries are not executed in any sort of transaction/batch, meaning you are responsible for fixing partial migrations.
* Using the default TinyLog table engine for the schema_versions table prevents backing up the table if using the [clickhouse-backup](https://github.com/AlexAkulov/clickhouse-backup) tool. If backing up the database with make sure the migrations are run with `x-migrations-table-engine=MergeTree`.
* Clickhouse cluster mode is not officially supported, since it's not tested right now, but you can try enabling `schema_migrations` table replication by specifying a `x-cluster-name`:
* When `x-cluster-name` is specified, `x-migrations-table-engine` also should be specified. See the docs regarding [replicated table engines](https://clickhouse.tech/docs/en/engines/table-engines/mergetree-family/replication/#table_engines-replication).
* When `x-cluster-name` is specified, only the `schema_migrations` table is replicated across the cluster. You still need to write your migrations so that the application tables are replicated within the cluster.
* If you want to create database inside the migration, you should know, that table which will manage migrations `schema-migrations table` will be in `default` table, so you can't use `USE <database_name>` inside migration. In this case you may not specify the database in the connection string (example you can find [here](examples/migrations/003_create_database.up.sql))
@@ -0,0 +1,316 @@
package clickhouse
import (
"database/sql"
"fmt"
"io"
"net/url"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
)
var (
multiStmtDelimiter = []byte(";")
DefaultMigrationsTable = "schema_migrations"
DefaultMigrationsTableEngine = "TinyLog"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
ErrNilConfig = fmt.Errorf("no config")
)
type Config struct {
DatabaseName string
ClusterName string
MigrationsTable string
MigrationsTableEngine string
MultiStatementEnabled bool
MultiStatementMaxSize int
}
func init() {
database.Register("clickhouse", &ClickHouse{})
}
func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := conn.Ping(); err != nil {
return nil, err
}
ch := &ClickHouse{
conn: conn,
config: config,
}
if err := ch.init(); err != nil {
return nil, err
}
return ch, nil
}
type ClickHouse struct {
conn *sql.DB
config *Config
isLocked atomic.Bool
}
func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
purl, err := url.Parse(dsn)
if err != nil {
return nil, err
}
q := migrate.FilterCustomQuery(purl)
q.Scheme = "tcp"
conn, err := sql.Open("clickhouse", q.String())
if err != nil {
return nil, err
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
}
migrationsTableEngine := DefaultMigrationsTableEngine
if s := purl.Query().Get("x-migrations-table-engine"); len(s) > 0 {
migrationsTableEngine = s
}
ch = &ClickHouse{
conn: conn,
config: &Config{
MigrationsTable: purl.Query().Get("x-migrations-table"),
MigrationsTableEngine: migrationsTableEngine,
DatabaseName: purl.Query().Get("database"),
ClusterName: purl.Query().Get("x-cluster-name"),
MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
MultiStatementMaxSize: multiStatementMaxSize,
},
}
if err := ch.init(); err != nil {
return nil, err
}
return ch, nil
}
func (ch *ClickHouse) init() error {
if len(ch.config.DatabaseName) == 0 {
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
return err
}
}
if len(ch.config.MigrationsTable) == 0 {
ch.config.MigrationsTable = DefaultMigrationsTable
}
if ch.config.MultiStatementMaxSize <= 0 {
ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
}
if len(ch.config.MigrationsTableEngine) == 0 {
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
}
return ch.ensureVersionTable()
}
func (ch *ClickHouse) Run(r io.Reader) error {
if ch.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
tq := strings.TrimSpace(string(m))
if tq == "" {
return true
}
if _, e := ch.conn.Exec(string(m)); e != nil {
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
return false
}
return true
}); e != nil {
return e
}
return err
}
migration, err := io.ReadAll(r)
if err != nil {
return err
}
if _, err := ch.conn.Exec(string(migration)); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
}
return nil
}
func (ch *ClickHouse) Version() (int, bool, error) {
var (
version int
dirty uint8
query = "SELECT version, dirty FROM `" + ch.config.MigrationsTable + "` ORDER BY sequence DESC LIMIT 1"
)
if err := ch.conn.QueryRow(query).Scan(&version, &dirty); err != nil {
if err == sql.ErrNoRows {
return database.NilVersion, false, nil
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
}
return version, dirty == 1, nil
}
func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
var (
bool = func(v bool) uint8 {
if v {
return 1
}
return 0
}
tx, err = ch.conn.Begin()
)
if err != nil {
return err
}
query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return tx.Commit()
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the ClickHouse type.
func (ch *ClickHouse) ensureVersionTable() (err error) {
if err = ch.Lock(); err != nil {
return err
}
defer func() {
if e := ch.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
var (
table string
query = "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) + " LIKE '" + ch.config.MigrationsTable + "'"
)
// check if migration table exists
if err := ch.conn.QueryRow(query).Scan(&table); err != nil {
if err != sql.ErrNoRows {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
} else {
return nil
}
// if not, create the empty migration table
if len(ch.config.ClusterName) > 0 {
query = fmt.Sprintf(`
CREATE TABLE %s ON CLUSTER %s (
version Int64,
dirty UInt8,
sequence UInt64
) Engine=%s`, ch.config.MigrationsTable, ch.config.ClusterName, ch.config.MigrationsTableEngine)
} else {
query = fmt.Sprintf(`
CREATE TABLE %s (
version Int64,
dirty UInt8,
sequence UInt64
) Engine=%s`, ch.config.MigrationsTable, ch.config.MigrationsTableEngine)
}
if strings.HasSuffix(ch.config.MigrationsTableEngine, "Tree") {
query = fmt.Sprintf(`%s ORDER BY sequence`, query)
}
if _, err := ch.conn.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (ch *ClickHouse) Drop() (err error) {
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
tables, err := ch.conn.Query(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
for tables.Next() {
var table string
if err := tables.Scan(&table); err != nil {
return err
}
query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table)
if _, err := ch.conn.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (ch *ClickHouse) Lock() error {
if !ch.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}
func (ch *ClickHouse) Unlock() error {
if !ch.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
@@ -0,0 +1,224 @@
package clickhouse_test
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"log"
"testing"
_ "github.com/ClickHouse/clickhouse-go"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/clickhouse"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const defaultPort = 9000
var (
tableEngines = []string{"TinyLog", "MergeTree"}
opts = dktest.Options{
Env: map[string]string{"CLICKHOUSE_USER": "user", "CLICKHOUSE_PASSWORD": "password", "CLICKHOUSE_DB": "db"},
PortRequired: true, ReadyFunc: isReady,
}
specs = []dktesting.ContainerSpec{
{ImageName: "yandex/clickhouse-server:21.3", Options: opts},
}
)
func clickhouseConnectionString(host, port, engine string) string {
if engine != "" {
return fmt.Sprintf(
"clickhouse://%v:%v?username=user&password=password&database=db&x-multi-statement=true&x-migrations-table-engine=%v&debug=false",
host, port, engine)
}
return fmt.Sprintf(
"clickhouse://%v:%v?username=user&password=password&database=db&x-multi-statement=true&debug=false",
host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(defaultPort)
if err != nil {
return false
}
db, err := sql.Open("clickhouse", clickhouseConnectionString(ip, port, ""))
if err != nil {
log.Println("open error", err)
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn:
return false
default:
fmt.Println(err)
}
return false
}
return true
}
func TestCases(t *testing.T) {
for _, engine := range tableEngines {
t.Run("Test_"+engine, func(t *testing.T) { testSimple(t, engine) })
t.Run("Migrate_"+engine, func(t *testing.T) { testMigrate(t, engine) })
t.Run("Version_"+engine, func(t *testing.T) { testVersion(t, engine) })
t.Run("Drop_"+engine, func(t *testing.T) { testDrop(t, engine) })
}
t.Run("WithInstanceDefaultConfigValues", func(t *testing.T) { testSimpleWithInstanceDefaultConfigValues(t) })
}
func testSimple(t *testing.T, engine string) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := clickhouseConnectionString(ip, port, engine)
p := &clickhouse.ClickHouse{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func testSimpleWithInstanceDefaultConfigValues(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := clickhouseConnectionString(ip, port, "")
conn, err := sql.Open("clickhouse", addr)
if err != nil {
t.Fatal(err)
}
d, err := clickhouse.WithInstance(conn, &clickhouse.Config{})
if err != nil {
_ = conn.Close()
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func testMigrate(t *testing.T, engine string) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := clickhouseConnectionString(ip, port, engine)
p := &clickhouse.ClickHouse{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "db", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func testVersion(t *testing.T, engine string) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
expectedVersion := 1
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := clickhouseConnectionString(ip, port, engine)
p := &clickhouse.ClickHouse{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
err = d.SetVersion(expectedVersion, false)
if err != nil {
t.Fatal(err)
}
version, _, err := d.Version()
if err != nil {
t.Fatal(err)
}
if version != expectedVersion {
t.Fatal("Version mismatch")
}
})
}
func testDrop(t *testing.T, engine string) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := clickhouseConnectionString(ip, port, engine)
p := &clickhouse.ClickHouse{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
err = d.Drop()
if err != nil {
t.Fatal(err)
}
})
}
@@ -0,0 +1,3 @@
CREATE TABLE test_1 (
Date Date
) Engine=Memory;
@@ -0,0 +1,3 @@
CREATE TABLE test_2 (
Date Date
) Engine=Memory;
@@ -0,0 +1,10 @@
DROP TABLE IF EXISTS driver_ratings;
DROP TABLE IF EXISTS user_ratings;
DROP TABLE IF EXISTS orders;
DROP TABLE IF EXISTS driver_ratings_queue;
DROP TABLE IF EXISTS user_ratings_queue;
DROP TABLE IF EXISTS orders_queue;
DROP VIEW IF EXISTS user_ratings_queue_mv;
DROP VIEW IF EXISTS driver_ratings_queue_mv;
DROP VIEW IF EXISTS orders_queue_mv;
DROP DATABASE IF EXISTS analytics;
@@ -0,0 +1,81 @@
CREATE DATABASE IF NOT EXISTS analytics;
CREATE TABLE IF NOT EXISTS analytics.driver_ratings(
rate UInt8,
userID Int64,
driverID String,
orderID String,
inserted_time DateTime DEFAULT now()
) ENGINE = MergeTree
PARTITION BY driverID
ORDER BY (inserted_time);
CREATE TABLE analytics.driver_ratings_queue(
rate UInt8,
userID Int64,
driverID String,
orderID String
) ENGINE = Kafka
SETTINGS kafka_broker_list = 'broker:9092',
kafka_topic_list = 'driver-ratings',
kafka_group_name = 'rating_readers',
kafka_format = 'Avro',
kafka_max_block_size = 1048576;
CREATE MATERIALIZED VIEW analytics.driver_ratings_queue_mv TO analytics.driver_ratings AS
SELECT rate, userID, driverID, orderID
FROM analytics.driver_ratings_queue;
CREATE TABLE IF NOT EXISTS analytics.user_ratings(
rate UInt8,
userID Int64,
driverID String,
orderID String,
inserted_time DateTime DEFAULT now()
) ENGINE = MergeTree
PARTITION BY userID
ORDER BY (inserted_time);
CREATE TABLE analytics.user_ratings_queue(
rate UInt8,
userID Int64,
driverID String,
orderID String
) ENGINE = Kafka
SETTINGS kafka_broker_list = 'broker:9092',
kafka_topic_list = 'user-ratings',
kafka_group_name = 'rating_readers',
kafka_format = 'JSON',
kafka_max_block_size = 1048576;
CREATE MATERIALIZED VIEW analytics.user_ratings_queue_mv TO analytics.user_ratings AS
SELECT rate, userID, driverID, orderID
FROM analytics.user_ratings_queue;
CREATE TABLE IF NOT EXISTS analytics.orders(
from_place String,
to_place String,
userID Int64,
driverID String,
orderID String,
inserted_time DateTime DEFAULT now()
) ENGINE = MergeTree
PARTITION BY driverID
ORDER BY (inserted_time);
CREATE TABLE analytics.orders_queue(
from_place String,
to_place String,
userID Int64,
driverID String,
orderID String
) ENGINE = Kafka
SETTINGS kafka_broker_list = 'broker:9092',
kafka_topic_list = 'orders',
kafka_group_name = 'order_readers',
kafka_format = 'Avro',
kafka_max_block_size = 1048576;
CREATE MATERIALIZED VIEW analytics.orders_queue_mv TO orders AS
SELECT from_place, to_place, userID, driverID, orderID
FROM analytics.orders_queue;
@@ -0,0 +1,19 @@
# cockroachdb
`cockroachdb://user:password@host:port/dbname?query` (`cockroach://`, and `crdb-postgres://` work, too)
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock |
| `x-force-lock` | `ForceLock` | Force lock acquisition to fix faulty migrations which may not have released the schema lock (Boolean, default is `false`) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
@@ -0,0 +1,142 @@
# CockroachDB tutorial for beginners (insecure cluster)
## Create/configure database
First, let's start a local cluster - follow step 1. and 2. from [the docs](https://www.cockroachlabs.com/docs/stable/start-a-local-cluster.html#step-1-start-the-first-node).
Once you have it, create a database. Here I am going to create a database called `example`.
Our user here is `cockroach`. We are not going to use a password, since it's not supported for insecure cluster.
```
cockroach sql --insecure --host=localhost:26257
```
```
CREATE DATABASE example;
CREATE USER IF NOT EXISTS cockroach;
GRANT ALL ON DATABASE example TO cockroach;
```
When using Migrate CLI we need to pass to database URL. Let's export it to a variable for convenience:
```
export COCKROACHDB_URL='cockroachdb://cockroach:@localhost:26257/example?sslmode=disable'
```
`sslmode=disable` means that the connection with our database will not be encrypted. This is needed to connect to an insecure node.
**NOTE:** Do not use COCKROACH_URL as a variable name here, it's already in use for discrete parameters and you may run into connection problems. For more info check out [docs](https://www.cockroachlabs.com/docs/stable/connection-parameters.html#connect-using-discrete-parameters).
You can find further description of database URLs [here](README.md#database-urls).
## Create migrations
Let's create a table called `users`:
```
migrate create -ext sql -dir db/migrations -seq create_users_table
```
If there were no errors, we should have two files available under `db/migrations` folder:
- 000001_create_users_table.down.sql
- 000001_create_users_table.up.sql
Note the `sql` extension that we provided.
In the `.up.sql` file let's create the table:
```
CREATE TABLE IF NOT EXISTS example.users
(
user_id INT PRIMARY KEY,
username VARCHAR (50) UNIQUE NOT NULL,
password VARCHAR (50) NOT NULL,
email VARCHAR (300) UNIQUE NOT NULL
);
```
And in the `.down.sql` let's delete it:
```
DROP TABLE IF EXISTS example.users;
```
By adding `IF EXISTS/IF NOT EXISTS` we are making migrations idempotent - you can read more about idempotency in [getting started](/GETTING_STARTED.md#create-migrations)
## Run migrations
```
migrate -database ${COCKROACHDB_URL} -path db/migrations up
```
Let's check if the table was created properly by running `cockroach sql --insecure --host=localhost:26257 -e "show columns from example.users;"`.
The output you are supposed to see:
```
column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden
+-------------+--------------+-------------+----------------+-----------------------+----------------------------------------------+-----------+
user_id | INT8 | false | NULL | | {primary,users_username_key,users_email_key} | false
username | VARCHAR(50) | false | NULL | | {users_username_key} | false
password | VARCHAR(50) | false | NULL | | {} | false
email | VARCHAR(300) | false | NULL | | {users_email_key} | false
(4 rows)
```
Now let's check if running reverse migration also works:
```
migrate -database ${COCKROACHDB_URL} -path db/migrations down
```
Make sure to check if your database changed as expected in this case as well.
## Database transactions
To show database transactions usage, let's create another set of migrations by running:
```
migrate create -ext sql -dir db/migrations -seq add_mood_to_users
```
Again, it should create for us two migrations files:
- 000002_add_mood_to_users.down.sql
- 000002_add_mood_to_users.up.sql
In Cockroach, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands, similar to PostgreSQL.
In our example, we are going to add a column to our database that can only accept enumerable values or NULL.
Migration up:
```
BEGIN;
ALTER TABLE example.users ADD COLUMN mood STRING;
ALTER TABLE example.users ADD CONSTRAINT check_mood CHECK (mood IN ('happy', 'sad', 'neutral'));
COMMIT;
```
Migration down:
```
ALTER TABLE example.users DROP COLUMN mood;
```
Now we can run our new migration and check the database:
```
migrate -database ${COCKROACHDB_URL} -path db/migrations up
cockroach sql --insecure --host=localhost:26257 -e "show columns from example.users;"
```
Expected output:
```
column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden
+-------------+--------------+-------------+----------------+-----------------------+----------------------------------------------+-----------+
user_id | INT8 | false | NULL | | {primary,users_username_key,users_email_key} | false
username | VARCHAR(50) | false | NULL | | {users_username_key} | false
password | VARCHAR(50) | false | NULL | | {} | false
email | VARCHAR(300) | false | NULL | | {users_email_key} | false
mood | STRING | true | NULL | | {} | false
(5 rows)
```
## Optional: Run migrations within your Go app
Here is a very simple app running migrations for the above configuration:
```
import (
"log"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/cockroachdb"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
func main() {
m, err := migrate.New(
"file://db/migrations",
"cockroachdb://cockroach:@localhost:26257/example?sslmode=disable")
if err != nil {
log.Fatal(err)
}
if err := m.Up(); err != nil {
log.Fatal(err)
}
}
```
You can find details [here](README.md#use-in-your-go-project)
@@ -0,0 +1,365 @@
package cockroachdb
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"regexp"
"strconv"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"github.com/lib/pq"
"go.uber.org/atomic"
)
func init() {
db := CockroachDb{}
database.Register("cockroach", &db)
database.Register("cockroachdb", &db)
database.Register("crdb-postgres", &db)
}
var DefaultMigrationsTable = "schema_migrations"
var DefaultLockTable = "schema_lock"
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
)
type Config struct {
MigrationsTable string
LockTable string
ForceLock bool
DatabaseName string
}
type CockroachDb struct {
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if config.DatabaseName == "" {
query := `SELECT current_database()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if len(config.LockTable) == 0 {
config.LockTable = DefaultLockTable
}
px := &CockroachDb{
db: instance,
config: config,
}
// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
if err := px.ensureLockTable(); err != nil {
return nil, err
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (c *CockroachDb) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
db, err := sql.Open("postgres", connectString)
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
lockTable := purl.Query().Get("x-lock-table")
if len(lockTable) == 0 {
lockTable = DefaultLockTable
}
forceLockQuery := purl.Query().Get("x-force-lock")
forceLock, err := strconv.ParseBool(forceLockQuery)
if err != nil {
forceLock = false
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
LockTable: lockTable,
ForceLock: forceLock,
})
if err != nil {
return nil, err
}
return px, nil
}
func (c *CockroachDb) Close() error {
return c.db.Close()
}
// Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed
// See: https://github.com/cockroachdb/cockroach/issues/13546
func (c *CockroachDb) Lock() error {
return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
if err != nil {
return err
}
query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
rows, err := tx.Query(query, aid)
if err != nil {
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
}
defer func() {
if errClose := rows.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// If row exists at all, lock is present
locked := rows.Next()
if locked && !c.config.ForceLock {
return database.ErrLocked
}
query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
}
return nil
})
})
}
// Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed
// See: https://github.com/cockroachdb/cockroach/issues/13546
func (c *CockroachDb) Unlock() error {
return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
if err != nil {
return err
}
// In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until
// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
if _, err := c.db.Exec(query, aid); err != nil {
if e, ok := err.(*pq.Error); ok {
// 42P01 is "UndefinedTableError" in CockroachDB
// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
if e.Code == "42P01" {
// On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema
return nil
}
}
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
}
return nil
})
}
func (c *CockroachDb) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if _, err := c.db.Exec(query); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (c *CockroachDb) SetVersion(version int, dirty bool) error {
return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
return err
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
return err
}
}
return nil
})
}
func (c *CockroachDb) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.db.QueryRow(query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pq.Error); ok {
// 42P01 is "UndefinedTableError" in CockroachDB
// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
if e.Code == "42P01" {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (c *CockroachDb) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
tables, err := c.db.Query(query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
if _, err := c.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the CockroachDb type.
func (c *CockroachDb) ensureVersionTable() (err error) {
if err = c.Lock(); err != nil {
return err
}
defer func() {
if e := c.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// check if migration table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
// if not, create the empty migration table
query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
if _, err := c.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (c *CockroachDb) ensureLockTable() error {
// check if lock table exists
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
// if not, create the empty lock table
query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
if _, err := c.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
@@ -0,0 +1,174 @@
package cockroachdb
// error codes https://github.com/lib/pq/blob/master/error.go
import (
"context"
"database/sql"
"fmt"
"github.com/golang-migrate/migrate/v4"
"log"
"strings"
"testing"
)
import (
"github.com/dhui/dktest"
_ "github.com/lib/pq"
)
import (
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const defaultPort = 26257
var (
opts = dktest.Options{Cmd: []string{"start", "--insecure"}, PortRequired: true, ReadyFunc: isReady}
// Released versions: https://www.cockroachlabs.com/docs/releases/
specs = []dktesting.ContainerSpec{
{ImageName: "cockroachdb/cockroach:v1.0.7", Options: opts},
{ImageName: "cockroachdb/cockroach:v1.1.9", Options: opts},
{ImageName: "cockroachdb/cockroach:v2.0.7", Options: opts},
{ImageName: "cockroachdb/cockroach:v2.1.3", Options: opts},
}
)
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(defaultPort)
if err != nil {
log.Println("port error:", err)
return false
}
db, err := sql.Open("postgres", fmt.Sprintf("postgres://root@%v:%v?sslmode=disable", ip, port))
if err != nil {
log.Println("open error:", err)
return false
}
if err := db.PingContext(ctx); err != nil {
log.Println("ping error:", err)
return false
}
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
return true
}
func createDB(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
db, err := sql.Open("postgres", fmt.Sprintf("postgres://root@%v:%v?sslmode=disable", ip, port))
if err != nil {
t.Fatal(err)
}
if err = db.Ping(); err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
if _, err = db.Exec("CREATE DATABASE migrate"); err != nil {
t.Fatal(err)
}
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)
ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
c := &CockroachDb{}
d, err := c.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)
ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
c := &CockroachDb{}
d, err := c.Open(addr)
if err != nil {
t.Fatal(err)
}
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "migrate", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMultiStatement(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)
ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
c := &CockroachDb{}
d, err := c.Open(addr)
if err != nil {
t.Fatal(err)
}
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.(*CockroachDb).db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
createDB(t, ci)
ip, port, err := ci.Port(26257)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable&x-custom=foobar", ip, port)
c := &CockroachDb{}
_, err = c.Open(addr)
if err != nil {
t.Fatal(err)
}
})
}
@@ -0,0 +1,5 @@
CREATE TABLE users (
user_id INT UNIQUE,
name STRING(40),
email STRING(40)
);
@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX IF NOT EXISTS users_email_index ON users (email);
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,5 @@
CREATE TABLE books (
user_id INT,
name STRING(40),
author STRING(40)
);
@@ -0,0 +1,5 @@
CREATE TABLE movies (
user_id INT,
name STRING(40),
director STRING(40)
);
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,123 @@
// Package database provides the Driver interface.
// All database drivers must implement this interface, register themselves,
// optionally provide a `WithInstance` function and pass the tests
// in package database/testing.
package database
import (
"fmt"
"io"
"sync"
iurl "github.com/golang-migrate/migrate/v4/internal/url"
)
var (
ErrLocked = fmt.Errorf("can't acquire lock")
ErrNotLocked = fmt.Errorf("can't unlock, as not currently locked")
)
const NilVersion int = -1
var driversMu sync.RWMutex
var drivers = make(map[string]Driver)
// Driver is the interface every database driver must implement.
//
// How to implement a database driver?
// 1. Implement this interface.
// 2. Optionally, add a function named `WithInstance`.
// This function should accept an existing DB instance and a Config{} struct
// and return a driver instance.
// 3. Add a test that calls database/testing.go:Test()
// 4. Add own tests for Open(), WithInstance() (when provided) and Close().
// All other functions are tested by tests in database/testing.
// Saves you some time and makes sure all database drivers behave the same way.
// 5. Call Register in init().
// 6. Create a internal/cli/build_<driver-name>.go file
// 7. Add driver name in 'DATABASE' variable in Makefile
//
// Guidelines:
// - Don't try to correct user input. Don't assume things.
// When in doubt, return an error and explain the situation to the user.
// - All configuration input must come from the URL string in func Open()
// or the Config{} struct in WithInstance. Don't os.Getenv().
type Driver interface {
// Open returns a new driver instance configured with parameters
// coming from the URL string. Migrate will call this function
// only once per instance.
Open(url string) (Driver, error)
// Close closes the underlying database instance managed by the driver.
// Migrate will call this function only once per instance.
Close() error
// Lock should acquire a database lock so that only one migration process
// can run at a time. Migrate will call this function before Run is called.
// If the implementation can't provide this functionality, return nil.
// Return database.ErrLocked if database is already locked.
Lock() error
// Unlock should release the lock. Migrate will call this function after
// all migrations have been run.
Unlock() error
// Run applies a migration to the database. migration is guaranteed to be not nil.
Run(migration io.Reader) error
// SetVersion saves version and dirty state.
// Migrate will call this function before and after each call to Run.
// version must be >= -1. -1 means NilVersion.
SetVersion(version int, dirty bool) error
// Version returns the currently active version and if the database is dirty.
// When no migration has been applied, it must return version -1.
// Dirty means, a previous migration failed and user interaction is required.
Version() (version int, dirty bool, err error)
// Drop deletes everything in the database.
// Note that this is a breaking action, a new call to Open() is necessary to
// ensure subsequent calls work as expected.
Drop() error
}
// Open returns a new driver instance.
func Open(url string) (Driver, error) {
scheme, err := iurl.SchemeFromURL(url)
if err != nil {
return nil, err
}
driversMu.RLock()
d, ok := drivers[scheme]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
}
return d.Open(url)
}
// Register globally registers a driver.
func Register(name string, driver Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if driver == nil {
panic("Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("Register called twice for driver " + name)
}
drivers[name] = driver
}
// List lists the registered drivers
func List() []string {
driversMu.RLock()
defer driversMu.RUnlock()
names := make([]string, 0, len(drivers))
for n := range drivers {
names = append(names, n)
}
return names
}
@@ -0,0 +1,115 @@
package database
import (
"io"
"testing"
)
func ExampleDriver() {
// see database/stub for an example
// database/stub/stub.go has the driver implementation
// database/stub/stub_test.go runs database/testing/test.go:Test
}
// Using database/stub here is not possible as it
// results in an import cycle.
type mockDriver struct {
url string
}
func (m *mockDriver) Open(url string) (Driver, error) {
return &mockDriver{
url: url,
}, nil
}
func (m *mockDriver) Close() error {
return nil
}
func (m *mockDriver) Lock() error {
return nil
}
func (m *mockDriver) Unlock() error {
return nil
}
func (m *mockDriver) Run(migration io.Reader) error {
return nil
}
func (m *mockDriver) SetVersion(version int, dirty bool) error {
return nil
}
func (m *mockDriver) Version() (version int, dirty bool, err error) {
return 0, false, nil
}
func (m *mockDriver) Drop() error {
return nil
}
func TestRegisterTwice(t *testing.T) {
Register("mock", &mockDriver{})
var err interface{}
func() {
defer func() {
err = recover()
}()
Register("mock", &mockDriver{})
}()
if err == nil {
t.Fatal("expected a panic when calling Register twice")
}
}
func TestOpen(t *testing.T) {
// Make sure the driver is registered.
// But if the previous test already registered it just ignore the panic.
// If we don't do this it will be impossible to run this test standalone.
func() {
defer func() {
_ = recover()
}()
Register("mock", &mockDriver{})
}()
cases := []struct {
url string
err bool
}{
{
"mock://user:pass@tcp(host:1337)/db",
false,
},
{
"unknown://bla",
true,
},
}
for _, c := range cases {
t.Run(c.url, func(t *testing.T) {
d, err := Open(c.url)
if err == nil {
if c.err {
t.Fatal("expected an error for an unknown driver")
} else {
if md, ok := d.(*mockDriver); !ok {
t.Fatalf("expected *mockDriver got %T", d)
} else if md.url != c.url {
t.Fatalf("expected %q got %q", c.url, md.url)
}
}
} else if !c.err {
t.Fatalf("did not expect %q", err)
}
})
}
}
@@ -0,0 +1,27 @@
package database
import (
"fmt"
)
// Error should be used for errors involving queries ran against the database
type Error struct {
// Optional: the line number
Line uint
// Query is a query excerpt
Query []byte
// Err is a useful/helping error message for humans
Err string
// OrigErr is the underlying error
OrigErr error
}
func (e Error) Error() string {
if len(e.Err) == 0 {
return fmt.Sprintf("%v in line %v: %s", e.OrigErr, e.Line, e.Query)
}
return fmt.Sprintf("%v in line %v: %s (details: %v)", e.Err, e.Line, e.Query, e.OrigErr)
}
@@ -0,0 +1,12 @@
# firebird
`firebirdsql://user:password@servername[:port_number]/database_name_or_file[?params1=value1[&param2=value2]...]`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `auth_plugin_name` | | Authentication plugin name. Srp256/Srp/Legacy_Auth are available. (default is Srp) |
| `column_name_to_lower` | | Force column name to lower. (default is false) |
| `role` | | Role name |
| `tzname` | | Time Zone name. (For Firebird 4.0+) |
| `wire_crypt` | | Enable wire data encryption or not. For Firebird 3.0+ (default is true) |
@@ -0,0 +1,5 @@
CREATE TABLE users (
user_id integer unique,
name varchar(40),
email varchar(40)
);
@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX users_email_index ON users (email);
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,5 @@
CREATE TABLE books (
user_id integer,
name varchar(40),
author varchar(40)
);
@@ -0,0 +1,5 @@
CREATE TABLE movies (
user_id integer,
name varchar(40),
director varchar(40)
);
@@ -0,0 +1,259 @@
//go:build go1.9
// +build go1.9
package firebird
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
_ "github.com/nakagami/firebirdsql"
"go.uber.org/atomic"
)
func init() {
db := Firebird{}
database.Register("firebird", &db)
database.Register("firebirdsql", &db)
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = fmt.Errorf("no config")
)
type Config struct {
DatabaseName string
MigrationsTable string
}
type Firebird struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
fb := &Firebird{
conn: conn,
db: instance,
config: config,
}
if err := fb.ensureVersionTable(); err != nil {
return nil, err
}
return fb, nil
}
func (f *Firebird) Open(dsn string) (database.Driver, error) {
purl, err := nurl.Parse(dsn)
if err != nil {
return nil, err
}
db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
px, err := WithInstance(db, &Config{
MigrationsTable: purl.Query().Get("x-migrations-table"),
DatabaseName: purl.Path,
})
if err != nil {
return nil, err
}
return px, nil
}
func (f *Firebird) Close() error {
connErr := f.conn.Close()
dbErr := f.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
func (f *Firebird) Lock() error {
if !f.isLocked.CAS(false, true) {
return database.ErrLocked
}
return nil
}
func (f *Firebird) Unlock() error {
if !f.isLocked.CAS(true, false) {
return database.ErrNotLocked
}
return nil
}
func (f *Firebird) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (f *Firebird) SetVersion(version int, dirty bool) error {
// Always re-write the schema version to prevent empty schema version
// for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
// TODO: parameterize this SQL statement
// https://firebirdsql.org/refdocs/langrefupd20-execblock.html
// VALUES (?, ?) doesn't work
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
DELETE FROM "%v";
INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
END;`,
f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (f *Firebird) Version() (version int, dirty bool, err error) {
var d int
query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, itob(d), nil
}
}
func (f *Firebird) Drop() (err error) {
// select all tables
query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
tables, err := f.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// delete one by one ...
for _, t := range tableNames {
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
execute statement 'drop table "%v"';
END;`,
t, t)
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
func (f *Firebird) ensureVersionTable() (err error) {
if err = f.Lock(); err != nil {
return err
}
defer func() {
if e := f.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
END;`,
f.config.MigrationsTable, f.config.MigrationsTable)
if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// btoi converts bool to int
func btoi(v bool) int {
if v {
return 1
}
return 0
}
// itob converts int to bool
func itob(v int) bool {
return v != 0
}
@@ -0,0 +1,226 @@
package firebird
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"fmt"
"log"
"github.com/golang-migrate/migrate/v4"
"io"
"strings"
"testing"
"github.com/dhui/dktest"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
_ "github.com/nakagami/firebirdsql"
)
const (
user = "test_user"
password = "123456"
dbName = "test.fdb"
)
var (
opts = dktest.Options{
PortRequired: true,
ReadyFunc: isReady,
Env: map[string]string{
"FIREBIRD_DATABASE": dbName,
"FIREBIRD_USER": user,
"FIREBIRD_PASSWORD": password,
},
}
specs = []dktesting.ContainerSpec{
{ImageName: "jacobalberty/firebird:2.5-ss", Options: opts},
{ImageName: "jacobalberty/firebird:3.0", Options: opts},
}
)
func fbConnectionString(host, port string) string {
//firebird://user:password@servername[:port_number]/database_name_or_file[?params1=value1[&param2=value2]...]
return fmt.Sprintf("firebird://%s:%s@%s:%s//firebird/data/%s", user, password, host, port, dbName)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("firebirdsql", fbConnectionString(ip, port))
if err != nil {
log.Println("open error:", err)
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := fbConnectionString(ip, port)
p := &Firebird{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT Count(*) FROM rdb$relations"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := fbConnectionString(ip, port)
p := &Firebird{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "firebirdsql", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := fbConnectionString(ip, port)
p := &Firebird{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed in line 0: CREATE TABLEE foo (foo varchar(40)); (details: Dynamic SQL Error
SQL error code = -104
Token unknown - line 1, column 8
TABLEE
)`
if err := d.Run(strings.NewReader("CREATE TABLEE foo (foo varchar(40));")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
msg := err.Error()
t.Fatalf("expected '%s' but got '%s'", wantErr, msg)
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := fbConnectionString(ip, port) + "?sslmode=disable&x-custom=foobar"
p := &Firebird{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
})
}
func Test_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := fbConnectionString(ip, port)
p := &Firebird{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT Count(*) FROM rdb$relations"))
ps := d.(*Firebird)
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
@@ -0,0 +1,24 @@
# MongoDB
* Driver work with mongo through [db.runCommands](https://docs.mongodb.com/manual/reference/command/)
* Migrations support json format. It contains array of commands for `db.runCommand`. Every command is executed in separate request to database
* All keys have to be in quotes `"`
* [Examples](./examples)
# Usage
`mongodb://user:password@host:port/dbname?query` (`mongodb+srv://` also works, but behaves a bit differently. See [docs](https://docs.mongodb.com/manual/reference/connection-string/#dns-seedlist-connection-format) for more information)
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-collection` | `MigrationsCollection` | Name of the migrations collection |
| `x-transaction-mode` | `TransactionMode` | If set to `true` wrap commands in [transaction](https://docs.mongodb.com/manual/core/transactions). Available only for replica set. Driver is using [strconv.ParseBool](https://golang.org/pkg/strconv/#ParseBool) for parsing|
| `x-advisory-locking` | `true` | Feature flag for advisory locking, if set to false, disable advisory locking |
| `x-advisory-lock-collection` | `migrate_advisory_lock` | The name of the collection to use for advisory locking.|
| `x-advisory-lock-timeout` | `15` | The max time in seconds that migrate will wait to acquire a lock before failing. |
| `x-advisory-lock-timeout-interval` | `10` | The max time in seconds between attempts to acquire the advisory lock, the lock is attempted to be acquired using an exponential backoff algorithm. |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `user` | | The user to sign in as. Can be omitted |
| `password` | | The user's password. Can be omitted |
| `host` | | The host to connect to |
| `port` | | The port to bind to |
@@ -0,0 +1,12 @@
[
{
"createUser": "deminem",
"pwd": "gogo",
"roles": [
{
"role": "readWrite",
"db": "testMigration"
}
]
}
]
@@ -0,0 +1,10 @@
[
{
"dropIndexes": "mycollection",
"index": "username_sort_by_asc_created"
},
{
"dropIndexes": "mycollection",
"index": "unique_email"
}
]
@@ -0,0 +1,21 @@
[{
"createIndexes": "mycollection",
"indexes": [
{
"key": {
"username": 1,
"created": -1
},
"name": "username_sort_by_asc_created",
"background": true
},
{
"key": {
"email": 1
},
"name": "unique_email",
"unique": true,
"background": true
}
]
}]
@@ -0,0 +1,16 @@
[
{
"update": "users",
"updates": [
{
"q": {},
"u": {
"$unset": {
"status": ""
}
},
"multi": true
}
]
}
]
@@ -0,0 +1,16 @@
[
{
"update": "users",
"updates": [
{
"q": {},
"u": {
"$set": {
"status": "active"
}
},
"multi": true
}
]
}
]
@@ -0,0 +1,14 @@
[
{
"update": "users",
"updates": [
{
"q": {},
"u": {
"fullname": ""
},
"multi": true
}
]
}
]
@@ -0,0 +1,23 @@
[
{
"aggregate": "users",
"pipeline": [
{
"$project": {
"_id": 1,
"firstname": 1,
"lastname": 1,
"username": 1,
"password": 1,
"email": 1,
"active": 1,
"fullname": { "$concat": ["$firstname", " ", "$lastname"] }
}
},
{
"$out": "users"
}
],
"cursor": {}
}
]
@@ -0,0 +1,404 @@
package mongodb
import (
"context"
"fmt"
"io"
"net/url"
"os"
"strconv"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
"go.uber.org/atomic"
)
func init() {
db := Mongo{}
database.Register("mongodb", &db)
database.Register("mongodb+srv", &db)
}
var DefaultMigrationsCollection = "schema_migrations"
const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released.
const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout.
const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true.
const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field.
const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for.
var (
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNilConfig = fmt.Errorf("no config")
ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified")
)
type Mongo struct {
client *mongo.Client
db *mongo.Database
config *Config
isLocked atomic.Bool
}
type Locking struct {
CollectionName string
Timeout int
Enabled bool
Interval int
}
type Config struct {
DatabaseName string
MigrationsCollection string
TransactionMode bool
Locking Locking
}
type versionInfo struct {
Version int `bson:"version"`
Dirty bool `bson:"dirty"`
}
type lockObj struct {
Key int `bson:"locking_key"`
Pid int `bson:"pid"`
Hostname string `bson:"hostname"`
CreatedAt time.Time `bson:"created_at"`
}
type findFilter struct {
Key int `bson:"locking_key"`
}
func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if len(config.DatabaseName) == 0 {
return nil, ErrNoDatabaseName
}
if len(config.MigrationsCollection) == 0 {
config.MigrationsCollection = DefaultMigrationsCollection
}
if len(config.Locking.CollectionName) == 0 {
config.Locking.CollectionName = DefaultLockingCollection
}
if config.Locking.Timeout <= 0 {
config.Locking.Timeout = DefaultLockTimeout
}
if config.Locking.Interval <= 0 {
config.Locking.Interval = DefaultLockTimeoutInterval
}
mc := &Mongo{
client: instance,
db: instance.Database(config.DatabaseName),
config: config,
}
if mc.config.Locking.Enabled {
if err := mc.ensureLockTable(); err != nil {
return nil, err
}
}
if err := mc.ensureVersionTable(); err != nil {
return nil, err
}
return mc, nil
}
func (m *Mongo) Open(dsn string) (database.Driver, error) {
// connstring is experimental package, but it used for parse connection string in mongo.Connect function
uri, err := connstring.Parse(dsn)
if err != nil {
return nil, err
}
if len(uri.Database) == 0 {
return nil, ErrNoDatabaseName
}
unknown := url.Values(uri.UnknownOptions)
migrationsCollection := unknown.Get("x-migrations-collection")
lockCollection := unknown.Get("x-advisory-lock-collection")
transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
if err != nil {
return nil, err
}
advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
if err != nil {
return nil, err
}
lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
if err != nil {
return nil, err
}
lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval")
// The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it
// and we will error out if both values are set.
lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval")
lockTimeout := lockTimeoutIntervalValue
if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" {
return nil, ErrLockTimeoutConfigConflict
} else if lockTimeoutIntervalValueFromTypo != "" {
lockTimeout = lockTimeoutIntervalValueFromTypo
}
maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval)
if err != nil {
return nil, err
}
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
if err != nil {
return nil, err
}
if err = client.Ping(context.TODO(), nil); err != nil {
return nil, err
}
mc, err := WithInstance(client, &Config{
DatabaseName: uri.Database,
MigrationsCollection: migrationsCollection,
TransactionMode: transactionMode,
Locking: Locking{
CollectionName: lockCollection,
Timeout: lockingTimout,
Enabled: advisoryLockingFlag,
Interval: maxLockCheckInterval,
},
})
if err != nil {
return nil, err
}
return mc, nil
}
// Parse the url param, convert it to boolean
// returns error if param invalid. returns defaultValue if param not present
func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
// if parameter passed, parse it (otherwise return default value)
if urlParam != "" {
result, err := strconv.ParseBool(urlParam)
if err != nil {
return false, err
}
return result, nil
}
// if no url Param passed, return default value
return defaultValue, nil
}
// Parse the url param, convert it to int
// returns error if param invalid. returns defaultValue if param not present
func parseInt(urlParam string, defaultValue int) (int, error) {
// if parameter passed, parse it (otherwise return default value)
if urlParam != "" {
result, err := strconv.Atoi(urlParam)
if err != nil {
return -1, err
}
return result, nil
}
// if no url Param passed, return default value
return defaultValue, nil
}
func (m *Mongo) SetVersion(version int, dirty bool) error {
migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
if err := migrationsCollection.Drop(context.TODO()); err != nil {
return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
}
_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
if err != nil {
return &database.Error{OrigErr: err, Err: "save version failed"}
}
return nil
}
func (m *Mongo) Version() (version int, dirty bool, err error) {
var versionInfo versionInfo
err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
switch {
case err == mongo.ErrNoDocuments:
return database.NilVersion, false, nil
case err != nil:
return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
default:
return versionInfo.Version, versionInfo.Dirty, nil
}
}
func (m *Mongo) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
var cmds []bson.D
err = bson.UnmarshalExtJSON(migr, true, &cmds)
if err != nil {
return fmt.Errorf("unmarshaling json error: %s", err)
}
if m.config.TransactionMode {
if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
return err
}
} else {
if err := m.executeCommands(context.TODO(), cmds); err != nil {
return err
}
}
return nil
}
func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
if err := sessionContext.StartTransaction(); err != nil {
return &database.Error{OrigErr: err, Err: "failed to start transaction"}
}
if err := m.executeCommands(sessionContext, cmds); err != nil {
// When command execution is failed, it's aborting transaction
// If you tried to call abortTransaction, it`s return error that transaction already aborted
return err
}
if err := sessionContext.CommitTransaction(sessionContext); err != nil {
return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
}
return nil
})
if err != nil {
return err
}
return nil
}
func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
for _, cmd := range cmds {
err := m.db.RunCommand(ctx, cmd).Err()
if err != nil {
return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
}
}
return nil
}
func (m *Mongo) Close() error {
return m.client.Disconnect(context.TODO())
}
func (m *Mongo) Drop() error {
return m.db.Drop(context.TODO())
}
func (m *Mongo) ensureLockTable() error {
indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
Options: indexOptions,
Keys: findFilter{Key: -1},
})
if err != nil {
return err
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the MongoDb type.
func (m *Mongo) ensureVersionTable() (err error) {
if err = m.Lock(); err != nil {
return err
}
defer func() {
if e := m.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
if err != nil {
return err
}
if _, _, err = m.Version(); err != nil {
return err
}
return nil
}
// Utilizes advisory locking on the config.LockingCollection collection
// This uses a unique index on the `locking_key` field.
func (m *Mongo) Lock() error {
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
if !m.config.Locking.Enabled {
return nil
}
pid := os.Getpid()
hostname, err := os.Hostname()
if err != nil {
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
}
newLockObj := lockObj{
Key: lockKeyUniqueValue,
Pid: pid,
Hostname: hostname,
CreatedAt: time.Now(),
}
operation := func() error {
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
defer cancelFunc()
return err
}
exponentialBackOff := backoff.NewExponentialBackOff()
duration := time.Duration(m.config.Locking.Timeout) * time.Second
exponentialBackOff.MaxElapsedTime = duration
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
err = backoff.Retry(operation, exponentialBackOff)
if err != nil {
return database.ErrLocked
}
return nil
})
}
func (m *Mongo) Unlock() error {
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
if !m.config.Locking.Enabled {
return nil
}
filter := findFilter{
Key: lockKeyUniqueValue,
}
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
defer cancel()
if err != nil {
return err
}
return nil
})
}
@@ -0,0 +1,430 @@
package mongodb
import (
"bytes"
"context"
"fmt"
"log"
"github.com/golang-migrate/migrate/v4"
"io"
"os"
"strconv"
"testing"
"time"
)
import (
"github.com/dhui/dktest"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
import (
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
var (
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady}
// Supported versions: https://www.mongodb.com/support-policy
specs = []dktesting.ContainerSpec{
{ImageName: "mongo:3.4", Options: opts},
{ImageName: "mongo:3.6", Options: opts},
{ImageName: "mongo:4.0", Options: opts},
{ImageName: "mongo:4.2", Options: opts},
}
)
func mongoConnectionString(host, port string) string {
// there is connect option for excluding serverConnection algorithm
// it's let avoid errors with mongo replica set connection in docker container
return fmt.Sprintf("mongodb://%s:%s/testMigration?connect=direct", host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoConnectionString(ip, port)))
if err != nil {
return false
}
defer func() {
if err := client.Disconnect(ctx); err != nil {
log.Println("close error:", err)
}
}()
if err = client.Ping(ctx, nil); err != nil {
switch err {
case io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func Test(t *testing.T) {
t.Run("test", test)
t.Run("testMigrate", testMigrate)
t.Run("testWithAuth", testWithAuth)
t.Run("testLockWorks", testLockWorks)
t.Cleanup(func() {
for _, spec := range specs {
t.Log("Cleaning up ", spec.ImageName)
if err := spec.Cleanup(); err != nil {
t.Error("Error removing ", spec.ImageName, "error:", err)
}
}
})
}
func test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.TestNilVersion(t, d)
dt.TestLockAndUnlock(t, d)
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
dt.TestSetVersion(t, d)
dt.TestDrop(t, d)
})
}
func testMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func testWithAuth(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
createUserCMD := []byte(`[{"createUser":"deminem","pwd":"gogo","roles":[{"role":"readWrite","db":"testMigration"}]}]`)
err = d.Run(bytes.NewReader(createUserCMD))
if err != nil {
t.Fatal(err)
}
testcases := []struct {
name string
connectUri string
isErrorExpected bool
}{
{"right auth data", "mongodb://deminem:gogo@%s:%v/testMigration", false},
{"wrong auth data", "mongodb://wrong:auth@%s:%v/testMigration", true},
}
for _, tcase := range testcases {
t.Run(tcase.name, func(t *testing.T) {
mc := &Mongo{}
d, err := mc.Open(fmt.Sprintf(tcase.connectUri, ip, port))
if err == nil {
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
}
switch {
case tcase.isErrorExpected && err == nil:
t.Fatalf("no error when expected")
case !tcase.isErrorExpected && err != nil:
t.Fatalf("unexpected error: %v", err)
}
})
}
})
}
func testLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := mongoConnectionString(ip, port)
p := &Mongo{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
mc := d.(*Mongo)
err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Unlock()
if err != nil {
t.Fatal(err)
}
err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Unlock()
if err != nil {
t.Fatal(err)
}
// enable locking,
//try to hit a lock conflict
mc.config.Locking.Enabled = true
mc.config.Locking.Timeout = 1
err = mc.Lock()
if err != nil {
t.Fatal(err)
}
err = mc.Lock()
if err == nil {
t.Fatal("should have failed, mongo should be locked already")
}
})
}
func TestTransaction(t *testing.T) {
transactionSpecs := []dktesting.ContainerSpec{
{ImageName: "mongo:4", Options: dktest.Options{PortRequired: true, ReadyFunc: isReady,
Cmd: []string{"mongod", "--bind_ip_all", "--replSet", "rs0"}}},
}
t.Cleanup(func() {
for _, spec := range transactionSpecs {
t.Log("Cleaning up ", spec.ImageName)
if err := spec.Cleanup(); err != nil {
t.Error("Error removing ", spec.ImageName, "error:", err)
}
}
})
dktesting.ParallelTest(t, transactionSpecs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(mongoConnectionString(ip, port)))
if err != nil {
t.Fatal(err)
}
err = client.Ping(context.TODO(), nil)
if err != nil {
t.Fatal(err)
}
//rs.initiate()
err = client.Database("admin").RunCommand(context.TODO(), bson.D{bson.E{Key: "replSetInitiate", Value: bson.D{}}}).Err()
if err != nil {
t.Fatal(err)
}
err = waitForReplicaInit(client)
if err != nil {
t.Fatal(err)
}
d, err := WithInstance(client, &Config{
DatabaseName: "testMigration",
})
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
//We have to create collection
//transactions don't support operations with creating new dbs, collections
//Unique index need for checking transaction aborting
insertCMD := []byte(`[
{"create":"hello"},
{"createIndexes": "hello",
"indexes": [{
"key": {
"wild": 1
},
"name": "unique_wild",
"unique": true,
"background": true
}]
}]`)
err = d.Run(bytes.NewReader(insertCMD))
if err != nil {
t.Fatal(err)
}
testcases := []struct {
name string
cmds []byte
documentsCount int64
isErrorExpected bool
}{
{
name: "success transaction",
cmds: []byte(`[{"insert":"hello","documents":[
{"wild":"world"},
{"wild":"west"},
{"wild":"natural"}
]
}]`),
documentsCount: 3,
isErrorExpected: false,
},
{
name: "failure transaction",
//transaction have to be failure - duplicate unique key wild:west
//none of the documents should be added
cmds: []byte(`[{"insert":"hello","documents":[{"wild":"flower"}]},
{"insert":"hello","documents":[
{"wild":"cat"},
{"wild":"west"}
]
}]`),
documentsCount: 3,
isErrorExpected: true,
},
}
for _, tcase := range testcases {
t.Run(tcase.name, func(t *testing.T) {
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(mongoConnectionString(ip, port)))
if err != nil {
t.Fatal(err)
}
err = client.Ping(context.TODO(), nil)
if err != nil {
t.Fatal(err)
}
d, err := WithInstance(client, &Config{
DatabaseName: "testMigration",
TransactionMode: true,
})
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
runErr := d.Run(bytes.NewReader(tcase.cmds))
if runErr != nil {
if !tcase.isErrorExpected {
t.Fatal(runErr)
}
}
documentsCount, err := client.Database("testMigration").Collection("hello").CountDocuments(context.TODO(), bson.M{})
if err != nil {
t.Fatal(err)
}
if tcase.documentsCount != documentsCount {
t.Fatalf("expected %d and actual %d documents count not equal. run migration error:%s", tcase.documentsCount, documentsCount, runErr)
}
})
}
})
}
type isMaster struct {
IsMaster bool `bson:"ismaster"`
}
func waitForReplicaInit(client *mongo.Client) error {
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
timeout, err := strconv.Atoi(os.Getenv("MIGRATE_TEST_MONGO_REPLICA_SET_INIT_TIMEOUT"))
if err != nil {
timeout = 30
}
timeoutTimer := time.NewTimer(time.Duration(timeout) * time.Second)
defer timeoutTimer.Stop()
for {
select {
case <-ticker.C:
var status isMaster
//Check that node is primary because
//during replica set initialization, the first node first becomes a secondary and then becomes the primary
//should consider that initialization is completed only after the node has become the primary
result := client.Database("admin").RunCommand(context.TODO(), bson.D{bson.E{Key: "isMaster", Value: 1}})
r, err := result.DecodeBytes()
if err != nil {
return err
}
err = bson.Unmarshal(r, &status)
if err != nil {
return err
}
if status.IsMaster {
return nil
}
case <-timeoutTimer.C:
return fmt.Errorf("replica init timeout")
}
}
}
@@ -0,0 +1,46 @@
// Package multistmt provides methods for parsing multi-statement database migrations
package multistmt
import (
"bufio"
"bytes"
"io"
)
// StartBufSize is the default starting size of the buffer used to scan and parse multi-statement migrations
var StartBufSize = 4096
// Handler handles a single migration parsed from a multi-statement migration.
// It's given the single migration to handle and returns whether or not further statements
// from the multi-statement migration should be parsed and handled.
type Handler func(migration []byte) bool
func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) {
return func(d []byte, atEOF bool) (int, []byte, error) {
// SplitFunc inspired by bufio.ScanLines() implementation
if atEOF {
if len(d) == 0 {
return 0, nil, nil
}
return len(d), d, nil
}
if i := bytes.Index(d, delimiter); i >= 0 {
return i + len(delimiter), d[:i+len(delimiter)], nil
}
return 0, nil, nil
}
}
// Parse parses the given multi-statement migration
func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
scanner := bufio.NewScanner(reader)
scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize)
scanner.Split(splitWithDelimiter(delimiter))
for scanner.Scan() {
cont := h(scanner.Bytes())
if !cont {
break
}
}
return scanner.Err()
}
@@ -0,0 +1,57 @@
package multistmt_test
import (
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/golang-migrate/migrate/v4/database/multistmt"
)
const maxMigrationSize = 1024
func TestParse(t *testing.T) {
testCases := []struct {
name string
multiStmt string
delimiter string
expected []string
expectedErr error
}{
{name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
expected: []string{"single statement, no delimiter"}, expectedErr: nil},
{name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
{name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
expected: []string{"statement one;", " statement two"}, expectedErr: nil},
{name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
stmts := make([]string, 0, len(tc.expected))
err := multistmt.Parse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
stmts = append(stmts, string(b))
return true
})
assert.Equal(t, tc.expectedErr, err)
assert.Equal(t, tc.expected, stmts)
})
}
}
func TestParseDiscontinue(t *testing.T) {
multiStmt := "statement one; statement two"
delimiter := ";"
expected := []string{"statement one;"}
stmts := make([]string, 0, len(expected))
err := multistmt.Parse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool {
stmts = append(stmts, string(b))
return false
})
assert.Nil(t, err)
assert.Equal(t, expected, stmts)
}
@@ -0,0 +1,56 @@
# MySQL
`mysql://user:password@tcp(host:port)/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-no-lock` | `NoLock` | Set to `true` to skip `GET_LOCK`/`RELEASE_LOCK` statements. Useful for [multi-master MySQL flavors](https://www.percona.com/doc/percona-xtradb-cluster/LATEST/features/pxc-strict-mode.html#explicit-table-locking). Only run migrations from one host when this is enabled. |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds, functionally similar to [Server-side SELECT statement timeouts](https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/) but enforced by the client. Available for all versions of MySQL, not just >=5.7. |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. |
| `port` | | The port to bind to. |
| `tls` | | TLS / SSL encrypted connection parameter; see [go-sql-driver](https://github.com/go-sql-driver/mysql#tls). Use any name (e.g. `migrate`) if you want to use a custom TLS config (`x-tls-` queries). |
| `x-tls-ca` | | The location of the CA (certificate authority) file. |
| `x-tls-cert` | | The location of the client certificate file. Must be used with `x-tls-key`. |
| `x-tls-key` | | The location of the private key file. Must be used with `x-tls-cert`. |
| `x-tls-insecure-skip-verify` | | Whether or not to use SSL (true\|false) |
## Use with existing client
If you use the MySQL driver with existing database client, you must create the client with parameter `multiStatements=true`:
```go
package main
import (
"database/sql"
_ "github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/mysql"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
func main() {
db, _ := sql.Open("mysql", "user:password@tcp(host:port)/dbname?multiStatements=true")
driver, _ := mysql.WithInstance(db, &mysql.Config{})
m, _ := migrate.NewWithDatabaseInstance(
"file:///migrations",
"mysql",
driver,
)
m.Steps(2)
}
```
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://dev.mysql.com/doc/refman/5.7/en/commit.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
@@ -0,0 +1,3 @@
CREATE TABLE IF NOT EXISTS test (
firstname VARCHAR(16)
);
@@ -0,0 +1,514 @@
//go:build go1.9
// +build go1.9
package mysql
import (
"context"
"crypto/tls"
"crypto/x509"
"database/sql"
"fmt"
"io"
nurl "net/url"
"os"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4/database"
"github.com/hashicorp/go-multierror"
)
var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
func init() {
database.Register("mysql", &Mysql{})
}
var DefaultMigrationsTable = "schema_migrations"
var (
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrAppendPEM = fmt.Errorf("failed to append PEM")
ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
)
type Config struct {
MigrationsTable string
DatabaseName string
NoLock bool
StatementTimeout time.Duration
}
type Mysql struct {
// mysql RELEASE_LOCK must be called from the same conn, so
// just do everything over a single conn anyway.
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
config *Config
}
// connection instance must have `multiStatements` set to true
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := conn.PingContext(ctx); err != nil {
return nil, err
}
mx := &Mysql{
conn: conn,
db: nil,
config: config,
}
if config.DatabaseName == "" {
query := `SELECT DATABASE()`
var databaseName sql.NullString
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName.String) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName.String
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if err := mx.ensureVersionTable(); err != nil {
return nil, err
}
return mx, nil
}
// instance must have `multiStatements` set to true
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
ctx := context.Background()
if err := instance.Ping(); err != nil {
return nil, err
}
conn, err := instance.Conn(ctx)
if err != nil {
return nil, err
}
mx, err := WithConnection(ctx, conn, config)
if err != nil {
return nil, err
}
mx.db = instance
return mx, nil
}
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
if c == nil {
return nil, ErrNilConfig
}
customQueryParams := map[string]string{}
for k, v := range c.Params {
if strings.HasPrefix(k, "x-") {
customQueryParams[k] = v
delete(c.Params, k)
}
}
return customQueryParams, nil
}
func urlToMySQLConfig(url string) (*mysql.Config, error) {
// Need to parse out custom TLS parameters and call
// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
// which consumes the registered tls.Config
// Fixes: https://github.com/golang-migrate/migrate/issues/411
//
// Can't use url.Parse() since it fails to parse MySQL DSNs
// mysql.ParseDSN() also searches for "?" to find query parameters:
// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
if idx := strings.LastIndex(url, "?"); idx > 0 {
rawParams := url[idx+1:]
parsedParams, err := nurl.ParseQuery(rawParams)
if err != nil {
return nil, err
}
ctls := parsedParams.Get("tls")
if len(ctls) > 0 {
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
rootCertPool := x509.NewCertPool()
pem, err := os.ReadFile(parsedParams.Get("x-tls-ca"))
if err != nil {
return nil, err
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return nil, ErrAppendPEM
}
clientCert := make([]tls.Certificate, 0, 1)
if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
if ccert == "" || ckey == "" {
return nil, ErrTLSCertKeyConfig
}
certs, err := tls.LoadX509KeyPair(ccert, ckey)
if err != nil {
return nil, err
}
clientCert = append(clientCert, certs)
}
insecureSkipVerify := false
insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
if len(insecureSkipVerifyStr) > 0 {
x, err := strconv.ParseBool(insecureSkipVerifyStr)
if err != nil {
return nil, err
}
insecureSkipVerify = x
}
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
InsecureSkipVerify: insecureSkipVerify,
})
if err != nil {
return nil, err
}
}
}
}
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
if err != nil {
return nil, err
}
config.MultiStatements = true
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
// net/url.Parse() would automatically unescape it for us.
// See: https://play.golang.org/p/q9j1io-YICQ
user, err := nurl.QueryUnescape(config.User)
if err != nil {
return nil, err
}
config.User = user
password, err := nurl.QueryUnescape(config.Passwd)
if err != nil {
return nil, err
}
config.Passwd = password
return config, nil
}
func (m *Mysql) Open(url string) (database.Driver, error) {
config, err := urlToMySQLConfig(url)
if err != nil {
return nil, err
}
customParams, err := extractCustomQueryParams(config)
if err != nil {
return nil, err
}
noLockParam, noLock := customParams["x-no-lock"], false
if noLockParam != "" {
noLock, err = strconv.ParseBool(noLockParam)
if err != nil {
return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
}
}
statementTimeoutParam := customParams["x-statement-timeout"]
statementTimeout := 0
if statementTimeoutParam != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutParam)
if err != nil {
return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err)
}
}
db, err := sql.Open("mysql", config.FormatDSN())
if err != nil {
return nil, err
}
mx, err := WithInstance(db, &Config{
DatabaseName: config.DBName,
MigrationsTable: customParams["x-migrations-table"],
NoLock: noLock,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
})
if err != nil {
return nil, err
}
return mx, nil
}
func (m *Mysql) Close() error {
connErr := m.conn.Close()
var dbErr error
if m.db != nil {
dbErr = m.db.Close()
}
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
func (m *Mysql) Lock() error {
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
if m.config.NoLock {
return nil
}
aid, err := database.GenerateAdvisoryLockId(
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
if err != nil {
return err
}
query := "SELECT GET_LOCK(?, 10)"
var success bool
if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
if !success {
return database.ErrLocked
}
return nil
})
}
func (m *Mysql) Unlock() error {
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
if m.config.NoLock {
return nil
}
aid, err := database.GenerateAdvisoryLockId(
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
if err != nil {
return err
}
query := `SELECT RELEASE_LOCK(?)`
if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
// in which case isLocked should be true until the timeout expires -- synchronizing
// these states is likely not worth trying to do; reconsider the necessity of isLocked.
return nil
})
}
func (m *Mysql) Run(migration io.Reader) error {
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
ctx := context.Background()
if m.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout)
defer cancel()
}
query := string(migr[:])
if _, err := m.conn.ExecContext(ctx, query); err != nil {
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
}
func (m *Mysql) SetVersion(version int, dirty bool) error {
tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := "DELETE FROM `" + m.config.MigrationsTable + "`"
if _, err := tx.ExecContext(context.Background(), query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (m *Mysql) Version() (version int, dirty bool, err error) {
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*mysql.MySQLError); ok {
if e.Number == 0 {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (m *Mysql) Drop() (err error) {
// select all tables
query := `SHOW TABLES LIKE '%'`
tables, err := m.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// disable checking foreign key constraints until finished
query = `SET foreign_key_checks = 0`
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
// enable foreign key checks
_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
}()
// delete one by one ...
for _, t := range tableNames {
query = "DROP TABLE IF EXISTS `" + t + "`"
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Mysql type.
func (m *Mysql) ensureVersionTable() (err error) {
if err = m.Lock(); err != nil {
return err
}
defer func() {
if e := m.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// check if migration table exists
var result string
query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
if err != sql.ErrNoRows {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
} else {
return nil
}
// if not, create the empty migration table
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Returns the bool value of the input.
// The 2nd return value indicates if the input was a valid bool value
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
func readBool(input string) (value bool, valid bool) {
switch input {
case "1", "true", "TRUE", "True":
return true, true
case "0", "false", "FALSE", "False":
return false, true
}
// Not a valid bool value
return
}
@@ -0,0 +1,420 @@
package mysql
import (
"context"
"crypto/ed25519"
"crypto/x509"
"database/sql"
sqldriver "database/sql/driver"
"encoding/pem"
"errors"
"fmt"
"log"
"math/big"
"math/rand"
"net/url"
"os"
"strconv"
"testing"
"github.com/dhui/dktest"
"github.com/go-sql-driver/mysql"
"github.com/golang-migrate/migrate/v4"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
"github.com/stretchr/testify/assert"
)
const defaultPort = 3306
var (
opts = dktest.Options{
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
PortRequired: true, ReadyFunc: isReady,
}
optsAnsiQuotes = dktest.Options{
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
PortRequired: true, ReadyFunc: isReady,
Cmd: []string{"--sql-mode=ANSI_QUOTES"},
}
// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
specs = []dktesting.ContainerSpec{
{ImageName: "mysql:5.5", Options: opts},
{ImageName: "mysql:5.6", Options: opts},
{ImageName: "mysql:5.7", Options: opts},
{ImageName: "mysql:8", Options: opts},
}
specsAnsiQuotes = []dktesting.ContainerSpec{
{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
{ImageName: "mysql:8", Options: optsAnsiQuotes},
}
)
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(defaultPort)
if err != nil {
return false
}
db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
return false
default:
fmt.Println(err)
}
return false
}
return true
}
func Test(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestMigrate(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestMigrateAnsiQuotes(t *testing.T) {
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
// check ensureVersionTable
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
// check again
if err := d.(*Mysql).ensureVersionTable(); err != nil {
t.Fatal(err)
}
})
}
func TestLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ms := d.(*Mysql)
err = ms.Lock()
if err != nil {
t.Fatal(err)
}
err = ms.Unlock()
if err != nil {
t.Fatal(err)
}
// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
err = ms.Lock()
if err != nil {
t.Fatal(err)
}
err = ms.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestNoLockParamValidation(t *testing.T) {
ip := "127.0.0.1"
port := 3306
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
_, err := p.Open(addr + "?x-no-lock=not-a-bool")
if !errors.Is(err, strconv.ErrSyntax) {
t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
}
}
func TestNoLockWorks(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(defaultPort)
if err != nil {
t.Fatal(err)
}
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
p := &Mysql{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
lock := d.(*Mysql)
p = &Mysql{}
d, err = p.Open(addr + "?x-no-lock=true")
if err != nil {
t.Fatal(err)
}
noLock := d.(*Mysql)
// Should be possible to take real lock and no-lock at the same time
if err = lock.Lock(); err != nil {
t.Fatal(err)
}
if err = noLock.Lock(); err != nil {
t.Fatal(err)
}
if err = lock.Unlock(); err != nil {
t.Fatal(err)
}
if err = noLock.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestExtractCustomQueryParams(t *testing.T) {
testcases := []struct {
name string
config *mysql.Config
expectedParams map[string]string
expectedCustomParams map[string]string
expectedErr error
}{
{name: "nil config", expectedErr: ErrNilConfig},
{
name: "no params",
config: mysql.NewConfig(),
expectedCustomParams: map[string]string{},
},
{
name: "no custom params",
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{},
},
{
name: "one param, one custom param",
config: &mysql.Config{
Params: map[string]string{"hello": "world", "x-foo": "bar"},
},
expectedParams: map[string]string{"hello": "world"},
expectedCustomParams: map[string]string{"x-foo": "bar"},
},
{
name: "multiple params, multiple custom params",
config: &mysql.Config{
Params: map[string]string{
"hello": "world",
"x-foo": "bar",
"dead": "beef",
"x-cat": "hat",
},
},
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
customParams, err := extractCustomQueryParams(tc.config)
if tc.config != nil {
assert.Equal(t, tc.expectedParams, tc.config.Params,
"Expected config params have custom params properly removed")
}
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
assert.Equal(t, tc.expectedCustomParams, customParams,
"Expected custom params to be properly extracted")
})
}
}
func createTmpCert(t *testing.T) string {
tmpCertFile, err := os.CreateTemp("", "migrate_test_cert")
if err != nil {
t.Fatal("Failed to create temp cert file:", err)
}
t.Cleanup(func() {
if err := os.Remove(tmpCertFile.Name()); err != nil {
t.Log("Failed to cleanup temp cert file:", err)
}
})
r := rand.New(rand.NewSource(0))
pub, priv, err := ed25519.GenerateKey(r)
if err != nil {
t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
}
tmpl := x509.Certificate{
SerialNumber: big.NewInt(0),
}
derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
if err != nil {
t.Fatal("Failed to generate temp cert file:", err)
}
if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
t.Fatal("Failed to encode ")
}
if err := tmpCertFile.Close(); err != nil {
t.Fatal("Failed to close temp cert file:", err)
}
return tmpCertFile.Name()
}
func TestURLToMySQLConfig(t *testing.T) {
tmpCertFilename := createTmpCert(t)
tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
testcases := []struct {
name string
urlStr string
expectedDSN string // empty string signifies that an error is expected
}{
{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded :",
urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "only user - with encoded @",
urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
// {name: "user/password - user with encoded :",
// urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
// expectedDSN: "username::password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - user with encoded @",
urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded :",
urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "user/password - password with encoded @",
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
{name: "custom tls",
urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
}
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
config, err := urlToMySQLConfig(tc.urlStr)
if err != nil {
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
}
dsn := config.FormatDSN()
if dsn != tc.expectedDSN {
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
}
})
}
}
@@ -0,0 +1,20 @@
# neo4j
The Neo4j driver (bolt) does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param.
This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
The queries **should** run in a single transaction, so partial migrations should not be a concern, but this is untested.
`neo4j://user:password@host:port/`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-multi-statement` | `MultiStatement` | Enable multiple statements to be ran in a single migration (See note above) |
| `user` | Contained within `AuthConfig` | The user to sign in as |
| `password` | Contained within `AuthConfig` | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 7687) |
| | `MigrationsLabel` | Name of the migrations node label |
## Supported versions
Only Neo4j v3.5+ is [supported](https://github.com/neo4j/neo4j-go-driver/issues/64#issuecomment-625133600)
@@ -0,0 +1,97 @@
## Create migrations
Let's create nodes called `Users`:
```
migrate create -ext cypher -dir db/migrations -seq create_user_nodes
```
If there were no errors, we should have two files available under `db/migrations` folder:
- 000001_create_user_nodes.down.cypher
- 000001_create_user_nodes.up.cypher
Note the `cypher` extension that we provided.
In the `.up.cypher` file let's create the table:
```
CREATE (u1:User {name: "Peter"})
CREATE (u2:User {name: "Paul"})
CREATE (u3:User {name: "Mary"})
```
And in the `.down.sql` let's delete it:
```
MATCH (u:User) WHERE u.name IN ["Peter", "Paul", "Mary"] DELETE u
```
Ideally your migrations should be idempotent. You can read more about idempotency in [getting started](GETTING_STARTED.md#create-migrations)
## Run migrations
```
migrate -database ${NEO4J_URL} -path db/migrations up
```
Let's check if the table was created properly by running `bin/cypher-shell -u neo4j -p password`, then `neo4j> MATCH (u:User)`
The output you are supposed to see:
```
+-----------------------------------------------------------------+
| u |
+-----------------------------------------------------------------+
| (:User {name: "Peter") |
| (:User {name: "Paul") |
| (:User {name: "Mary") |
+-----------------------------------------------------------------+
```
Great! Now let's check if running reverse migration also works:
```
migrate -database ${NEO4J_URL} -path db/migrations down
```
Make sure to check if your database changed as expected in this case as well.
## Database transactions
To show database transactions usage, let's create another set of migrations by running:
```
migrate create -ext cypher -dir db/migrations -seq add_mood_to_users
```
Again, it should create for us two migrations files:
- 000002_add_mood_to_users.down.cypher
- 000002_add_mood_to_users.up.cypher
In Neo4j, when we want our queries to be done in a transaction, we need to wrap it with `:BEGIN` and `:COMMIT` commands.
Migration up:
```
:BEGIN
MATCH (u:User)
SET u.mood = "Cheery"
:COMMIT
```
Migration down:
```
:BEGIN
MATCH (u:User)
SET u.mood = null
:COMMIT
```
## Optional: Run migrations within your Go app
Here is a very simple app running migrations for the above configuration:
```
import (
"log"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/neo4j"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
func main() {
m, err := migrate.New(
"file://db/migrations",
"neo4j://neo4j:password@localhost:7687/")
if err != nil {
log.Fatal(err)
}
if err := m.Up(); err != nil {
log.Fatal(err)
}
}
```
@@ -0,0 +1 @@
DROP CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
@@ -0,0 +1 @@
CREATE CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
@@ -0,0 +1,2 @@
CREATE (:Movie {name: "Footloose"})
CREATE (:Movie {name: "Ghost"})
@@ -0,0 +1,3 @@
CREATE (:Movie {name: "Hollow Man"});
CREATE (:Movie {name: "Mystic River"});
;;;
@@ -0,0 +1,303 @@
package neo4j
import (
"bytes"
"fmt"
"io"
neturl "net/url"
"strconv"
"sync/atomic"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
"github.com/neo4j/neo4j-go-driver/neo4j"
)
func init() {
db := Neo4j{}
database.Register("neo4j", &db)
}
const DefaultMigrationsLabel = "SchemaMigration"
var (
StatementSeparator = []byte(";")
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)
var (
ErrNilConfig = fmt.Errorf("no config")
)
type Config struct {
MigrationsLabel string
MultiStatement bool
MultiStatementMaxSize int
}
type Neo4j struct {
driver neo4j.Driver
lock uint32
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
nDriver := &Neo4j{
driver: driver,
config: config,
}
if err := nDriver.ensureVersionConstraint(); err != nil {
return nil, err
}
return nDriver, nil
}
func (n *Neo4j) Open(url string) (database.Driver, error) {
uri, err := neturl.Parse(url)
if err != nil {
return nil, err
}
password, _ := uri.User.Password()
authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
uri.User = nil
uri.Scheme = "bolt"
msQuery := uri.Query().Get("x-multi-statement")
// Whether to turn on/off TLS encryption.
tlsEncrypted := uri.Query().Get("x-tls-encrypted")
multi := false
encrypted := false
if msQuery != "" {
multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
if err != nil {
return nil, err
}
}
if tlsEncrypted != "" {
encrypted, err = strconv.ParseBool(tlsEncrypted)
if err != nil {
return nil, err
}
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
}
uri.RawQuery = ""
driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
config.Encrypted = encrypted
})
if err != nil {
return nil, err
}
return WithInstance(driver, &Config{
MigrationsLabel: DefaultMigrationsLabel,
MultiStatement: multi,
MultiStatementMaxSize: multiStatementMaxSize,
})
}
func (n *Neo4j) Close() error {
return n.driver.Close()
}
// local locking in order to pass tests, Neo doesn't support database locking
func (n *Neo4j) Lock() error {
if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
return database.ErrLocked
}
return nil
}
func (n *Neo4j) Unlock() error {
if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
return database.ErrNotLocked
}
return nil
}
func (n *Neo4j) Run(migration io.Reader) (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
}
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
}
}()
if n.config.MultiStatement {
_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
var stmtRunErr error
if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
trimStmt := bytes.TrimSpace(stmt)
if len(trimStmt) == 0 {
return true
}
trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
if len(trimStmt) == 0 {
return true
}
result, err := transaction.Run(string(trimStmt), nil)
if _, err := neo4j.Collect(result, err); err != nil {
stmtRunErr = err
return false
}
return true
}); err != nil {
return nil, err
}
return nil, stmtRunErr
})
return err
}
body, err := io.ReadAll(migration)
if err != nil {
return err
}
_, err = neo4j.Collect(session.Run(string(body[:]), nil))
return err
}
func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
}
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
}
}()
query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
n.config.MigrationsLabel)
_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
if err != nil {
return err
}
return nil
}
type MigrationRecord struct {
Version int
Dirty bool
}
func (n *Neo4j) Version() (version int, dirty bool, err error) {
session, err := n.driver.Session(neo4j.AccessModeRead)
if err != nil {
return database.NilVersion, false, err
}
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
}
}()
query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
n.config.MigrationsLabel)
result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
result, err := transaction.Run(query, nil)
if err != nil {
return nil, err
}
if result.Next() {
record := result.Record()
mr := MigrationRecord{}
versionResult, ok := record.Get("version")
if !ok {
mr.Version = database.NilVersion
} else {
mr.Version = int(versionResult.(int64))
}
dirtyResult, ok := record.Get("dirty")
if ok {
mr.Dirty = dirtyResult.(bool)
}
return mr, nil
}
return nil, result.Err()
})
if err != nil {
return database.NilVersion, false, err
}
if result == nil {
return database.NilVersion, false, err
}
mr := result.(MigrationRecord)
return mr.Version, mr.Dirty, err
}
func (n *Neo4j) Drop() (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
}
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
}
}()
if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
return err
}
return nil
}
func (n *Neo4j) ensureVersionConstraint() (err error) {
session, err := n.driver.Session(neo4j.AccessModeWrite)
if err != nil {
return err
}
defer func() {
if cerr := session.Close(); cerr != nil {
err = multierror.Append(err, cerr)
}
}()
/**
Get constraint and check to avoid error duplicate
using db.labels() to support Neo4j 3 and 4.
Neo4J 3 doesn't support db.constraints() YIELD name
*/
res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
if err != nil {
return err
}
if len(res) == 1 {
return nil
}
query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
return err
}
return nil
}
@@ -0,0 +1,138 @@
package neo4j
import (
"bytes"
"context"
"fmt"
"log"
"testing"
"github.com/dhui/dktest"
"github.com/neo4j/neo4j-go-driver/neo4j"
"github.com/golang-migrate/migrate/v4"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
var (
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady,
Env: map[string]string{"NEO4J_AUTH": "neo4j/migratetest", "NEO4J_ACCEPT_LICENSE_AGREEMENT": "yes"}}
specs = []dktesting.ContainerSpec{
{ImageName: "neo4j:4.0", Options: opts},
{ImageName: "neo4j:4.0-enterprise", Options: opts},
{ImageName: "neo4j:3.5", Options: opts},
{ImageName: "neo4j:3.5-enterprise", Options: opts},
}
)
func neoConnectionString(host, port string) string {
return fmt.Sprintf("bolt://neo4j:migratetest@%s:%s", host, port)
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.Port(7687)
if err != nil {
return false
}
driver, err := neo4j.NewDriver(
neoConnectionString(ip, port),
neo4j.BasicAuth("neo4j", "migratetest", ""),
func(config *neo4j.Config) {
config.Encrypted = false
})
if err != nil {
return false
}
defer func() {
if err := driver.Close(); err != nil {
log.Println("close error:", err)
}
}()
session, err := driver.Session(neo4j.AccessModeRead)
if err != nil {
return false
}
result, err := session.Run("RETURN 1", nil)
if err != nil {
return false
} else if result.Err() != nil {
return false
}
return true
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(7687)
if err != nil {
t.Fatal(err)
}
n := &Neo4j{}
d, err := n.Open(neoConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("MATCH (a) RETURN a"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(7687)
if err != nil {
t.Fatal(err)
}
n := &Neo4j{}
neoUrl := neoConnectionString(ip, port) + "/?x-multi-statement=true"
d, err := n.Open(neoUrl)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "neo4j", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMalformed(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.Port(7687)
if err != nil {
t.Fatal(err)
}
n := &Neo4j{}
d, err := n.Open(neoConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
migration := bytes.NewReader([]byte("CREATE (a {qid: 1) RETURN a"))
if err := d.Run(migration); err == nil {
t.Fatal("expected failure for malformed migration")
}
})
}
@@ -0,0 +1,242 @@
package database_test
import (
"encoding/hex"
"net/url"
"strings"
"testing"
)
const reservedChars = "!#$%&'()*+,/:;=?@[]"
const reservedCharTestNamePrefix = "reserved char "
const baseUsername = "username"
const scheme = "database://"
// TestUserUnencodedReservedURLChars documents the behavior of using unencoded reserved characters in usernames with
// net/url Parse()
func TestUserUnencodedReservedURLChars(t *testing.T) {
urlSuffix := "password@localhost:12345/myDB?someParam=true"
urlSuffixAndSep := ":" + urlSuffix
testcases := []struct {
char string
parses bool
expectedUsername string // empty string means that the username failed to parse
encodedURL string
}{
{char: "!", parses: true, expectedUsername: baseUsername + "!",
encodedURL: scheme + baseUsername + "%21" + urlSuffixAndSep},
{char: "#", parses: true, expectedUsername: "",
encodedURL: scheme + baseUsername + "#" + urlSuffixAndSep},
{char: "$", parses: true, expectedUsername: baseUsername + "$",
encodedURL: scheme + baseUsername + "$" + urlSuffixAndSep},
{char: "%", parses: false},
{char: "&", parses: true, expectedUsername: baseUsername + "&",
encodedURL: scheme + baseUsername + "&" + urlSuffixAndSep},
{char: "'", parses: true, expectedUsername: "username'",
encodedURL: scheme + baseUsername + "%27" + urlSuffixAndSep},
{char: "(", parses: true, expectedUsername: "username(",
encodedURL: scheme + baseUsername + "%28" + urlSuffixAndSep},
{char: ")", parses: true, expectedUsername: "username)",
encodedURL: scheme + baseUsername + "%29" + urlSuffixAndSep},
{char: "*", parses: true, expectedUsername: "username*",
encodedURL: scheme + baseUsername + "%2A" + urlSuffixAndSep},
{char: "+", parses: true, expectedUsername: "username+",
encodedURL: scheme + baseUsername + "+" + urlSuffixAndSep},
{char: ",", parses: true, expectedUsername: "username,",
encodedURL: scheme + baseUsername + "," + urlSuffixAndSep},
{char: "/", parses: true, expectedUsername: "",
encodedURL: scheme + baseUsername + "/" + urlSuffixAndSep},
{char: ":", parses: true, expectedUsername: baseUsername,
encodedURL: scheme + baseUsername + ":%3A" + urlSuffix},
{char: ";", parses: true, expectedUsername: "username;",
encodedURL: scheme + baseUsername + ";" + urlSuffixAndSep},
{char: "=", parses: true, expectedUsername: "username=",
encodedURL: scheme + baseUsername + "=" + urlSuffixAndSep},
{char: "?", parses: true, expectedUsername: "",
encodedURL: scheme + baseUsername + "?" + urlSuffixAndSep},
{char: "@", parses: true, expectedUsername: "username@",
encodedURL: scheme + baseUsername + "%40" + urlSuffixAndSep},
{char: "[", parses: false},
{char: "]", parses: false},
}
testedChars := make([]string, 0, len(reservedChars))
for _, tc := range testcases {
testedChars = append(testedChars, tc.char)
t.Run(reservedCharTestNamePrefix+tc.char, func(t *testing.T) {
s := scheme + baseUsername + tc.char + urlSuffixAndSep
u, err := url.Parse(s)
if err == nil {
if !tc.parses {
t.Error("Unexpectedly parsed reserved character. url:", s)
return
}
var username string
if u.User != nil {
username = u.User.Username()
}
if username != tc.expectedUsername {
t.Error("Got unexpected username:", username, "!=", tc.expectedUsername)
}
if s := u.String(); s != tc.encodedURL {
t.Error("Got unexpected encoded URL:", s, "!=", tc.encodedURL)
}
} else {
if tc.parses {
t.Error("Failed to parse reserved character. url:", s)
}
}
})
}
t.Run("All reserved chars tested", func(t *testing.T) {
if s := strings.Join(testedChars, ""); s != reservedChars {
t.Error("Not all reserved URL characters were tested:", s, "!=", reservedChars)
}
})
}
func TestUserEncodedReservedURLChars(t *testing.T) {
urlSuffix := "password@localhost:12345/myDB?someParam=true"
urlSuffixAndSep := ":" + urlSuffix
for _, c := range reservedChars {
c := string(c)
t.Run(reservedCharTestNamePrefix+c, func(t *testing.T) {
encodedChar := "%" + hex.EncodeToString([]byte(c))
s := scheme + baseUsername + encodedChar + urlSuffixAndSep
expectedUsername := baseUsername + c
u, err := url.Parse(s)
if err != nil {
t.Fatal("Failed to parse url with encoded reserved character. url:", s)
}
if u.User == nil {
t.Fatal("Failed to parse userinfo with encoded reserve character. url:", s)
}
if username := u.User.Username(); username != expectedUsername {
t.Fatal("Got unexpected username:", username, "!=", expectedUsername)
}
})
}
}
// TestPasswordUnencodedReservedURLChars documents the behavior of using unencoded reserved characters in passwords
// with net/url Parse()
func TestPasswordUnencodedReservedURLChars(t *testing.T) {
username := baseUsername
schemeAndUsernameAndSep := scheme + username + ":"
basePassword := "password"
urlSuffixAndSep := "@localhost:12345/myDB?someParam=true"
testcases := []struct {
char string
parses bool
expectedUsername string // empty string means that the username failed to parse
expectedPassword string // empty string means that the password failed to parse
encodedURL string
}{
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
{char: "#", parses: false},
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
{char: "%", parses: false},
{char: "&", parses: true, expectedUsername: username, expectedPassword: basePassword + "&",
encodedURL: schemeAndUsernameAndSep + basePassword + "&" + urlSuffixAndSep},
{char: "'", parses: true, expectedUsername: username, expectedPassword: "password'",
encodedURL: schemeAndUsernameAndSep + basePassword + "%27" + urlSuffixAndSep},
{char: "(", parses: true, expectedUsername: username, expectedPassword: "password(",
encodedURL: schemeAndUsernameAndSep + basePassword + "%28" + urlSuffixAndSep},
{char: ")", parses: true, expectedUsername: username, expectedPassword: "password)",
encodedURL: schemeAndUsernameAndSep + basePassword + "%29" + urlSuffixAndSep},
{char: "*", parses: true, expectedUsername: username, expectedPassword: "password*",
encodedURL: schemeAndUsernameAndSep + basePassword + "%2A" + urlSuffixAndSep},
{char: "+", parses: true, expectedUsername: username, expectedPassword: "password+",
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
{char: "/", parses: false},
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
{char: "?", parses: false},
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
{char: "[", parses: false},
{char: "]", parses: false},
}
testedChars := make([]string, 0, len(reservedChars))
for _, tc := range testcases {
testedChars = append(testedChars, tc.char)
t.Run(reservedCharTestNamePrefix+tc.char, func(t *testing.T) {
s := schemeAndUsernameAndSep + basePassword + tc.char + urlSuffixAndSep
u, err := url.Parse(s)
if err == nil {
if !tc.parses {
t.Error("Unexpectedly parsed reserved character. url:", s)
return
}
var username, password string
if u.User != nil {
username = u.User.Username()
password, _ = u.User.Password()
}
if username != tc.expectedUsername {
t.Error("Got unexpected username:", username, "!=", tc.expectedUsername)
}
if password != tc.expectedPassword {
t.Error("Got unexpected password:", password, "!=", tc.expectedPassword)
}
if s := u.String(); s != tc.encodedURL {
t.Error("Got unexpected encoded URL:", s, "!=", tc.encodedURL)
}
} else {
if tc.parses {
t.Error("Failed to parse reserved character. url:", s)
}
}
})
}
t.Run("All reserved chars tested", func(t *testing.T) {
if s := strings.Join(testedChars, ""); s != reservedChars {
t.Error("Not all reserved URL characters were tested:", s, "!=", reservedChars)
}
})
}
func TestPasswordEncodedReservedURLChars(t *testing.T) {
username := baseUsername
schemeAndUsernameAndSep := scheme + username + ":"
basePassword := "password"
urlSuffixAndSep := "@localhost:12345/myDB?someParam=true"
for _, c := range reservedChars {
c := string(c)
t.Run(reservedCharTestNamePrefix+c, func(t *testing.T) {
encodedChar := "%" + hex.EncodeToString([]byte(c))
s := schemeAndUsernameAndSep + basePassword + encodedChar + urlSuffixAndSep
expectedPassword := basePassword + c
u, err := url.Parse(s)
if err != nil {
t.Fatal("Failed to parse url with encoded reserved character. url:", s)
}
if u.User == nil {
t.Fatal("Failed to parse userinfo with encoded reserve character. url:", s)
}
if n := u.User.Username(); n != username {
t.Fatal("Got unexpected username:", n, "!=", username)
}
if p, _ := u.User.Password(); p != expectedPassword {
t.Fatal("Got unexpected password:", p, "!=", expectedPassword)
}
})
}
}
@@ -0,0 +1,43 @@
# pgx
This package is for [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). A backend for the newer [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5) is [also available](v5).
`pgx://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `x-lock-strategy` | `LockStrategy` | Strategy used for locking during migration (default: advisory) |
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock (default: schema_lock) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
## Multi-statement mode
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
@@ -0,0 +1,5 @@
CREATE TABLE users (
user_id integer unique,
name varchar(40),
email varchar(40)
);
@@ -0,0 +1 @@
ALTER TABLE users DROP COLUMN IF EXISTS city;
@@ -0,0 +1,3 @@
ALTER TABLE users ADD COLUMN city varchar(100);
@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX CONCURRENTLY users_email_index ON users (email);
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,5 @@
CREATE TABLE books (
user_id integer,
name varchar(40),
author varchar(40)
);
@@ -0,0 +1,5 @@
CREATE TABLE movies (
user_id integer,
name varchar(40),
director varchar(40)
);
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,623 @@
//go:build go1.9
// +build go1.9
package pgx
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/lib/pq"
)
const (
LockStrategyAdvisory = "advisory"
LockStrategyTable = "table"
)
func init() {
db := Postgres{}
database.Register("pgx", &db)
database.Register("pgx4", &db)
}
var (
multiStmtDelimiter = []byte(";")
DefaultMigrationsTable = "schema_migrations"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
DefaultLockTable = "schema_lock"
DefaultLockStrategy = LockStrategyAdvisory
)
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
LockTable string
LockStrategy string
migrationsSchemaName string
migrationsTableName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
}
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(schemaName) == 0 {
return nil, ErrNoSchema
}
config.SchemaName = schemaName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if len(config.LockTable) == 0 {
config.LockTable = DefaultLockTable
}
if len(config.LockStrategy) == 0 {
config.LockStrategy = DefaultLockStrategy
}
config.migrationsSchemaName = config.SchemaName
config.migrationsTableName = config.MigrationsTable
if config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
config.migrationsTableName = result[len(result)-1][1]
if len(result) == 2 {
config.migrationsSchemaName = result[0][1]
} else if len(result) > 2 {
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
}
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
px := &Postgres{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureLockTable(); err != nil {
return nil, err
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// Driver is registered as pgx, but connection string must use postgres schema
// when making actual connection
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}
statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutString)
if err != nil {
return nil, err
}
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
if multiStatementMaxSize <= 0 {
multiStatementMaxSize = DefaultMultiStatementMaxSize
}
}
multiStatementEnabled := false
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
multiStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
}
}
lockStrategy := purl.Query().Get("x-lock-strategy")
lockTable := purl.Query().Get("x-lock-table")
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
LockStrategy: lockStrategy,
LockTable: lockTable,
})
if err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
func (p *Postgres) Lock() error {
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
switch p.config.LockStrategy {
case LockStrategyAdvisory:
return p.applyAdvisoryLock()
case LockStrategyTable:
return p.applyTableLock()
default:
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
}
})
}
func (p *Postgres) Unlock() error {
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
switch p.config.LockStrategy {
case LockStrategyAdvisory:
return p.releaseAdvisoryLock()
case LockStrategyTable:
return p.releaseTableLock()
default:
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
}
})
}
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) applyAdvisoryLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
}
func (p *Postgres) applyTableLock() error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
defer func() {
errRollback := tx.Rollback()
if errRollback != nil {
err = multierror.Append(err, errRollback)
}
}()
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(query, aid)
if err != nil {
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
}
defer func() {
if errClose := rows.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// If row exists at all, lock is present
locked := rows.Next()
if locked {
return database.ErrLocked
}
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
}
return tx.Commit()
}
func (p *Postgres) releaseAdvisoryLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (p *Postgres) releaseTableLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.db.Exec(query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
}
return nil
}
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
return p.runStatement(migr)
}
func (p *Postgres) runStatement(statement []byte) error {
ctx := context.Background()
if p.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
var col uint
var lineColOK bool
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pgconn.PgError); ok {
if e.SQLState() == pgerrcode.UndefinedTable {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
// do not drop lock table
if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable {
continue
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Postgres type.
func (p *Postgres) ensureVersionTable() (err error) {
if err = p.Lock(); err != nil {
return err
}
defer func() {
if e := p.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (p *Postgres) ensureLockTable() error {
if p.config.LockStrategy != LockStrategyTable {
return nil
}
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
@@ -0,0 +1,789 @@
package pgx
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:9.5", Options: opts},
{ImageName: "postgres:9.6", Options: opts},
{ImageName: "postgres:10", Options: opts},
{ImageName: "postgres:11", Options: opts},
{ImageName: "postgres:12", Options: opts},
{ImageName: "postgres:13", Options: opts},
{ImageName: "postgres:14", Options: opts},
{ImageName: "postgres:15", Options: opts},
}
)
func pgConnectionString(host, port string, options ...string) string {
options = append(options, "sslmode=disable")
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMigrateLockTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMultipleStatements(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-multi-statement=true")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure created index exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-custom=foobar")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
})
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create foobar schema
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.SetVersion(1, false); err != nil {
t.Fatal(err)
}
// re-connect using that schema
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("expected NilVersion")
}
// now update version and compare
if err := d2.SetVersion(2, false); err != nil {
t.Fatal(err)
}
version, _, err = d2.Version()
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected version 2")
}
// meanwhile, the public schema still has the other version
version, _, err = d.Version()
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected version 2")
}
})
}
func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create migrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// bad unquoted x-migrations-table parameter
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// too many quoted x-migrations-table parameters
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
// make sure migrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table migrate.schema_migrations to exist")
}
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
}
})
}
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
defer func() {
if d2 == nil {
return
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
var e *database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}
func TestCheckBeforeCreateTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
// revoke privileges
mustRun(t, d, []string{
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
version, _, err := d3.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
}
defer func() {
if err := d3.Close(); err != nil {
t.Fatal(err)
}
}()
})
}
func TestParallelSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create foo and bar schemas
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// re-connect using that schemas
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dfoo.Close(); err != nil {
t.Error(err)
}
}()
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dbar.Close(); err != nil {
t.Error(err)
}
}()
if err := dfoo.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Unlock(); err != nil {
t.Fatal(err)
}
if err := dfoo.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ps := d.(*Postgres)
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestWithInstance_Concurrent(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// The number of concurrent processes running WithInstance
const concurrency = 30
// We can instantiate a single database handle because it is
// actually a connection pool, and so, each of the below go
// routines will have a high probability of using a separate
// connection, which is something we want to exercise.
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
db.SetMaxIdleConns(concurrency)
db.SetMaxOpenConns(concurrency)
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func(i int) {
defer wg.Done()
_, err := WithInstance(db, &Config{})
if err != nil {
t.Errorf("process %d error: %s", i, err)
}
}(i)
}
})
}
func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
wantLine uint
wantCol uint
input string
wantOk bool
}{
{
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
},
{
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
},
{
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
},
{
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
},
{
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
},
{
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
},
{
17, 2, 8, "SELECT *\nFROM foo", true, // last character
},
{
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
},
}
for i, tc := range testcases {
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
run := func(crlf bool, nonASCII bool) {
var name string
if crlf {
name = "crlf"
} else {
name = "lf"
}
if nonASCII {
name += "-nonascii"
} else {
name += "-ascii"
}
t.Run(name, func(t *testing.T) {
input := tc.input
if crlf {
input = strings.Replace(input, "\n", "\r\n", -1)
}
if nonASCII {
input = strings.Replace(input, "FROM", "FRÖM", -1)
}
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
if tc.wantOk {
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
}
if gotOK != tc.wantOk {
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
}
if gotLine != tc.wantLine {
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
}
if gotCol != tc.wantCol {
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
}
})
}
run(false, false)
run(true, false)
run(false, true)
run(true, true)
})
}
}
@@ -0,0 +1,41 @@
# pgx
This package is for [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5). A backend for the older [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). is [also available](..).
`pgx5://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
## Multi-statement mode
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
@@ -0,0 +1,486 @@
//go:build go1.9
// +build go1.9
package pgx
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
_ "github.com/jackc/pgx/v5/stdlib"
)
func init() {
db := Postgres{}
database.Register("pgx5", &db)
}
var (
multiStmtDelimiter = []byte(";")
DefaultMigrationsTable = "schema_migrations"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
)
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
migrationsSchemaName string
migrationsTableName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
}
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(schemaName) == 0 {
return nil, ErrNoSchema
}
config.SchemaName = schemaName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
config.migrationsSchemaName = config.SchemaName
config.migrationsTableName = config.MigrationsTable
if config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
config.migrationsTableName = result[len(result)-1][1]
if len(result) == 2 {
config.migrationsSchemaName = result[0][1]
} else if len(result) > 2 {
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
}
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
px := &Postgres{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// Driver is registered as pgx, but connection string must use postgres schema
// when making actual connection
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"
db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}
statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutString)
if err != nil {
return nil, err
}
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
if multiStatementMaxSize <= 0 {
multiStatementMaxSize = DefaultMultiStatementMaxSize
}
}
multiStatementEnabled := false
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
multiStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
}
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
})
if err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) Lock() error {
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
})
}
func (p *Postgres) Unlock() error {
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
})
}
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
return p.runStatement(migr)
}
func (p *Postgres) runStatement(statement []byte) error {
ctx := context.Background()
if p.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
var col uint
var lineColOK bool
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pgconn.PgError); ok {
if e.SQLState() == pgerrcode.UndefinedTable {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Postgres type.
func (p *Postgres) ensureVersionTable() (err error) {
if err = p.Lock(); err != nil {
return err
}
defer func() {
if e := p.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
@@ -0,0 +1,764 @@
package pgx
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"
"github.com/golang-migrate/migrate/v4"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4/database"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:9.5", Options: opts},
{ImageName: "postgres:9.6", Options: opts},
{ImageName: "postgres:10", Options: opts},
{ImageName: "postgres:11", Options: opts},
{ImageName: "postgres:12", Options: opts},
{ImageName: "postgres:13", Options: opts},
{ImageName: "postgres:14", Options: opts},
{ImageName: "postgres:15", Options: opts},
}
)
func pgConnectionString(host, port string, options ...string) string {
options = append(options, "sslmode=disable")
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMultipleStatements(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-multi-statement=true")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure created index exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-custom=foobar")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
})
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create foobar schema
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.SetVersion(1, false); err != nil {
t.Fatal(err)
}
// re-connect using that schema
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("expected NilVersion")
}
// now update version and compare
if err := d2.SetVersion(2, false); err != nil {
t.Fatal(err)
}
version, _, err = d2.Version()
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected version 2")
}
// meanwhile, the public schema still has the other version
version, _, err = d.Version()
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected version 2")
}
})
}
func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create migrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// bad unquoted x-migrations-table parameter
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// too many quoted x-migrations-table parameters
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
// make sure migrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table migrate.schema_migrations to exist")
}
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
}
})
}
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
defer func() {
if d2 == nil {
return
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
var e *database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}
func TestCheckBeforeCreateTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
// revoke privileges
mustRun(t, d, []string{
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
version, _, err := d3.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
}
defer func() {
if err := d3.Close(); err != nil {
t.Fatal(err)
}
}()
})
}
func TestParallelSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create foo and bar schemas
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// re-connect using that schemas
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dfoo.Close(); err != nil {
t.Error(err)
}
}()
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dbar.Close(); err != nil {
t.Error(err)
}
}()
if err := dfoo.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Unlock(); err != nil {
t.Fatal(err)
}
if err := dfoo.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ps := d.(*Postgres)
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestWithInstance_Concurrent(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// The number of concurrent processes running WithInstance
const concurrency = 30
// We can instantiate a single database handle because it is
// actually a connection pool, and so, each of the below go
// routines will have a high probability of using a separate
// connection, which is something we want to exercise.
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
db.SetMaxIdleConns(concurrency)
db.SetMaxOpenConns(concurrency)
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func(i int) {
defer wg.Done()
_, err := WithInstance(db, &Config{})
if err != nil {
t.Errorf("process %d error: %s", i, err)
}
}(i)
}
})
}
func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
wantLine uint
wantCol uint
input string
wantOk bool
}{
{
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
},
{
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
},
{
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
},
{
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
},
{
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
},
{
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
},
{
17, 2, 8, "SELECT *\nFROM foo", true, // last character
},
{
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
},
}
for i, tc := range testcases {
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
run := func(crlf bool, nonASCII bool) {
var name string
if crlf {
name = "crlf"
} else {
name = "lf"
}
if nonASCII {
name += "-nonascii"
} else {
name += "-ascii"
}
t.Run(name, func(t *testing.T) {
input := tc.input
if crlf {
input = strings.Replace(input, "\n", "\r\n", -1)
}
if nonASCII {
input = strings.Replace(input, "FROM", "FRÖM", -1)
}
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
if tc.wantOk {
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
}
if gotOK != tc.wantOk {
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
}
if gotLine != tc.wantLine {
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
}
if gotCol != tc.wantCol {
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
}
})
}
run(false, false)
run(true, false)
run(false, true)
run(true, true)
})
}
}
@@ -0,0 +1,39 @@
# postgres
`postgres://user:password@host:port/dbname?query` (`postgresql://` works, too)
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
## Multi-statement mode
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
@@ -0,0 +1,167 @@
# PostgreSQL tutorial for beginners
## Create/configure database
For the purpose of this tutorial let's create PostgreSQL database called `example`.
Our user here is `postgres`, password `password`, and host is `localhost`.
```
psql -h localhost -U postgres -w -c "create database example;"
```
When using Migrate CLI we need to pass to database URL. Let's export it to a variable for convenience:
```
export POSTGRESQL_URL='postgres://postgres:password@localhost:5432/example?sslmode=disable'
```
`sslmode=disable` means that the connection with our database will not be encrypted. Enabling it is left as an exercise.
You can find further description of database URLs [here](README.md#database-urls).
## Create migrations
Let's create table called `users`:
```
migrate create -ext sql -dir db/migrations -seq create_users_table
```
If there were no errors, we should have two files available under `db/migrations` folder:
- 000001_create_users_table.down.sql
- 000001_create_users_table.up.sql
Note the `sql` extension that we provided.
In the `.up.sql` file let's create the table:
```sql
CREATE TABLE IF NOT EXISTS users(
user_id serial PRIMARY KEY,
username VARCHAR (50) UNIQUE NOT NULL,
password VARCHAR (50) NOT NULL,
email VARCHAR (300) UNIQUE NOT NULL
);
```
And in the `.down.sql` let's delete it:
```sql
DROP TABLE IF EXISTS users;
```
By adding `IF EXISTS/IF NOT EXISTS` we are making migrations idempotent - you can read more about idempotency in [getting started](../../GETTING_STARTED.md#create-migrations)
## Run migrations
```
migrate -database ${POSTGRESQL_URL} -path db/migrations up
```
Let's check if the table was created properly by running `psql example -c "\d users"`.
The output you are supposed to see:
```
Table "public.users"
Column | Type | Modifiers
----------+------------------------+---------------------------------------------------------
user_id | integer | not null default nextval('users_user_id_seq'::regclass)
username | character varying(50) | not null
password | character varying(50) | not null
email | character varying(300) | not null
Indexes:
"users_pkey" PRIMARY KEY, btree (user_id)
"users_email_key" UNIQUE CONSTRAINT, btree (email)
"users_username_key" UNIQUE CONSTRAINT, btree (username)
```
Great! Now let's check if running reverse migration also works:
```
migrate -database ${POSTGRESQL_URL} -path db/migrations down
```
Make sure to check if your database changed as expected in this case as well.
## Database transactions
To show database transactions usage, let's create another set of migrations by running:
```
migrate create -ext sql -dir db/migrations -seq add_mood_to_users
```
Again, it should create for us two migrations files:
- 000002_add_mood_to_users.down.sql
- 000002_add_mood_to_users.up.sql
In Postgres, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands.
In our example, we are going to add a column to our database that can only accept enumerable values or NULL.
Migration up:
```sql
BEGIN;
CREATE TYPE enum_mood AS ENUM (
'happy',
'sad',
'neutral'
);
ALTER TABLE users ADD COLUMN mood enum_mood;
COMMIT;
```
Migration down:
```sql
BEGIN;
ALTER TABLE users DROP COLUMN mood;
DROP TYPE enum_mood;
COMMIT;
```
Now we can run our new migration and check the database:
```
migrate -database ${POSTGRESQL_URL} -path db/migrations up
psql example -c "\d users"
```
Expected output:
```
Table "public.users"
Column | Type | Modifiers
----------+------------------------+---------------------------------------------------------
user_id | integer | not null default nextval('users_user_id_seq'::regclass)
username | character varying(50) | not null
password | character varying(50) | not null
email | character varying(300) | not null
mood | enum_mood |
Indexes:
"users_pkey" PRIMARY KEY, btree (user_id)
"users_email_key" UNIQUE CONSTRAINT, btree (email)
"users_username_key" UNIQUE CONSTRAINT, btree (username)
```
## Optional: Run migrations within your Go app
Here is a very simple app running migrations for the above configuration:
```go
import (
"log"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
func main() {
m, err := migrate.New(
"file://db/migrations",
"postgres://postgres:postgres@localhost:5432/example?sslmode=disable")
if err != nil {
log.Fatal(err)
}
if err := m.Up(); err != nil {
log.Fatal(err)
}
}
```
You can find details [here](README.md#use-in-your-go-project)
## Fix issue where migrations run twice
When the schema and role names are the same, you might run into issues if you create this schema using migrations.
This is caused by the fact that the [default `search_path`](https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH) is `"$user", public`.
In the first run (with an empty database) the migrate table is created in `public`.
When the migrations create the `$user` schema, the next run will store (a new) migrate table in this schema (due to order of schemas in `search_path`) and tries to apply all migrations again (most likely failing).
To solve this you need to change the default `search_path` by removing the `$user` component, so the migrate table is always stored in the (available) `public` schema.
This can be done using the [`search_path` query parameter in the URL](https://github.com/jexia/migrate/blob/fix-postgres-version-table/database/postgres/README.md#postgres).
For example to force the migrations table in the public schema you can use:
```
export POSTGRESQL_URL='postgres://postgres:password@localhost:5432/example?sslmode=disable&search_path=public'
```
Note that you need to explicitly add the schema names to the table names in your migrations when you to modify the tables of the non-public schema.
Alternatively you can add the non-public schema manually (before applying the migrations) if that is possible in your case and let the tool store the migrations table in this schema as well.

Some files were not shown because too many files have changed in this diff Show More