whatcanGOwrong
This commit is contained in:
@@ -0,0 +1,45 @@
|
||||
# Cassandra / ScyllaDB
|
||||
|
||||
* `Drop()` method will not work on Cassandra 2.X because it rely on
|
||||
system_schema table which comes with 3.X
|
||||
* Other methods should work properly but are **not tested**
|
||||
* The Cassandra driver (gocql) does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param. There are two important caveats:
|
||||
* This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
|
||||
* The queries are not executed in any sort of transaction/batch, meaning you are responsible for fixing partial migrations.
|
||||
|
||||
**ScyllaDB**
|
||||
|
||||
* No additional configuration is required since it is a drop-in replacement for Cassandra.
|
||||
* The `Drop()` method` works for ScyllaDB 5.1
|
||||
|
||||
|
||||
## Usage
|
||||
`cassandra://host:port/keyspace?param1=value¶m2=value2`
|
||||
|
||||
|
||||
| URL Query | Default value | Description |
|
||||
|------------|-------------|-----------|
|
||||
| `x-migrations-table` | schema_migrations | Name of the migrations table |
|
||||
| `x-multi-statement` | false | Enable multiple statements to be ran in a single migration (See note above) |
|
||||
| `port` | 9042 | The port to bind to |
|
||||
| `consistency` | ALL | Migration consistency
|
||||
| `protocol` | | Cassandra protocol version (3 or 4)
|
||||
| `timeout` | 1 minute | Migration timeout
|
||||
| `connect-timeout` | 600ms | Initial connection timeout to the cluster |
|
||||
| `username` | nil | Username to use when authenticating. |
|
||||
| `password` | nil | Password to use when authenticating. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
| `disable-host-lookup`| false | Disable initial host lookup. |
|
||||
|
||||
`timeout` is parsed using [time.ParseDuration(s string)](https://golang.org/pkg/time/#ParseDuration)
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
2. `DROP TABLE schema_migrations`
|
||||
4. Download and install the latest migrate version.
|
||||
5. Force the current migration version with `migrate force <current_version>`.
|
||||
+352
@@ -0,0 +1,352 @@
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/gocql/gocql"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := new(Cassandra)
|
||||
database.Register("cassandra", db)
|
||||
}
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
|
||||
var (
|
||||
ErrNilConfig = errors.New("no config")
|
||||
ErrNoKeyspace = errors.New("no keyspace provided")
|
||||
ErrDatabaseDirty = errors.New("database is dirty")
|
||||
ErrClosedSession = errors.New("session is closed")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
KeyspaceName string
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Cassandra struct {
|
||||
session *gocql.Session
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
} else if len(config.KeyspaceName) == 0 {
|
||||
return nil, ErrNoKeyspace
|
||||
}
|
||||
|
||||
if session.Closed() {
|
||||
return nil, ErrClosedSession
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if config.MultiStatementMaxSize <= 0 {
|
||||
config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
|
||||
c := &Cassandra{
|
||||
session: session,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Open(url string) (database.Driver, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check for missing mandatory attributes
|
||||
if len(u.Path) == 0 {
|
||||
return nil, ErrNoKeyspace
|
||||
}
|
||||
|
||||
cluster := gocql.NewCluster(u.Host)
|
||||
cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
|
||||
cluster.Consistency = gocql.All
|
||||
cluster.Timeout = 1 * time.Minute
|
||||
|
||||
if len(u.Query().Get("username")) > 0 && len(u.Query().Get("password")) > 0 {
|
||||
authenticator := gocql.PasswordAuthenticator{
|
||||
Username: u.Query().Get("username"),
|
||||
Password: u.Query().Get("password"),
|
||||
}
|
||||
cluster.Authenticator = authenticator
|
||||
}
|
||||
|
||||
// Retrieve query string configuration
|
||||
if len(u.Query().Get("consistency")) > 0 {
|
||||
var consistency gocql.Consistency
|
||||
consistency, err = parseConsistency(u.Query().Get("consistency"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cluster.Consistency = consistency
|
||||
}
|
||||
if len(u.Query().Get("protocol")) > 0 {
|
||||
var protoversion int
|
||||
protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cluster.ProtoVersion = protoversion
|
||||
}
|
||||
if len(u.Query().Get("timeout")) > 0 {
|
||||
var timeout time.Duration
|
||||
timeout, err = time.ParseDuration(u.Query().Get("timeout"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cluster.Timeout = timeout
|
||||
}
|
||||
if len(u.Query().Get("connect-timeout")) > 0 {
|
||||
var connectTimeout time.Duration
|
||||
connectTimeout, err = time.ParseDuration(u.Query().Get("connect-timeout"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cluster.ConnectTimeout = connectTimeout
|
||||
}
|
||||
|
||||
if len(u.Query().Get("sslmode")) > 0 {
|
||||
if u.Query().Get("sslmode") != "disable" {
|
||||
sslOpts := &gocql.SslOptions{}
|
||||
|
||||
if len(u.Query().Get("sslrootcert")) > 0 {
|
||||
sslOpts.CaPath = u.Query().Get("sslrootcert")
|
||||
}
|
||||
if len(u.Query().Get("sslcert")) > 0 {
|
||||
sslOpts.CertPath = u.Query().Get("sslcert")
|
||||
}
|
||||
if len(u.Query().Get("sslkey")) > 0 {
|
||||
sslOpts.KeyPath = u.Query().Get("sslkey")
|
||||
}
|
||||
|
||||
if u.Query().Get("sslmode") == "verify-full" {
|
||||
sslOpts.EnableHostVerification = true
|
||||
}
|
||||
|
||||
cluster.SslOpts = sslOpts
|
||||
}
|
||||
}
|
||||
|
||||
if len(u.Query().Get("disable-host-lookup")) > 0 {
|
||||
if flag, err := strconv.ParseBool(u.Query().Get("disable-host-lookup")); err != nil && flag {
|
||||
cluster.DisableInitialHostLookup = true
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
session, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := u.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return WithInstance(session, &Config{
|
||||
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
|
||||
MigrationsTable: u.Query().Get("x-migrations-table"),
|
||||
MultiStatementEnabled: u.Query().Get("x-multi-statement") == "true",
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Cassandra) Close() error {
|
||||
c.session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Lock() error {
|
||||
if !c.isLocked.CAS(false, true) {
|
||||
return database.ErrLocked
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Unlock() error {
|
||||
if !c.isLocked.CAS(true, false) {
|
||||
return database.ErrNotLocked
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Run(migration io.Reader) error {
|
||||
if c.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(migration, multiStmtDelimiter, c.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
tq := strings.TrimSpace(string(m))
|
||||
if tq == "" {
|
||||
return true
|
||||
}
|
||||
if e := c.session.Query(tq).Exec(); e != nil {
|
||||
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// run migration
|
||||
if err := c.session.Query(string(migr)).Exec(); err != nil {
|
||||
// TODO: cast to Cassandra error and get line number
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
// DELETE instead of TRUNCATE because AWS Keyspaces does not support it
|
||||
// see: https://docs.aws.amazon.com/keyspaces/latest/devguide/cassandra-apis.html
|
||||
squery := `SELECT version FROM "` + c.config.MigrationsTable + `"`
|
||||
dquery := `DELETE FROM "` + c.config.MigrationsTable + `" WHERE version = ?`
|
||||
iter := c.session.Query(squery).Iter()
|
||||
var previous int
|
||||
for iter.Scan(&previous) {
|
||||
if err := c.session.Query(dquery, previous).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(dquery)}
|
||||
}
|
||||
}
|
||||
if err := iter.Close(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(squery)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query := `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
||||
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return current keyspace version
|
||||
func (c *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
|
||||
err = c.session.Query(query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == gocql.ErrNotFound:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if _, ok := err.(*gocql.Error); ok {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cassandra) Drop() error {
|
||||
// select all tables in current schema
|
||||
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
|
||||
iter := c.session.Query(query).Iter()
|
||||
var tableName string
|
||||
for iter.Scan(&tableName) {
|
||||
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Cassandra type.
|
||||
func (c *Cassandra) ensureVersionTable() (err error) {
|
||||
if err = c.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := c.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
err = c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err = c.Version(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ParseConsistency wraps gocql.ParseConsistency
|
||||
// to return an error instead of a panicking.
|
||||
func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
var ok bool
|
||||
err, ok = r.(error)
|
||||
if !ok {
|
||||
err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
consistency = gocql.ParseConsistency(consistencyStr)
|
||||
|
||||
return consistency, nil
|
||||
}
|
||||
+122
@@ -0,0 +1,122 @@
|
||||
package cassandra
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
import (
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/gocql/gocql"
|
||||
)
|
||||
|
||||
import (
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: http://cassandra.apache.org/download/
|
||||
// Although Cassandra 2.x is supported by the Apache Foundation,
|
||||
// the migrate db driver only supports Cassandra 3.x since it uses
|
||||
// the system_schema keyspace.
|
||||
// last ScyllaDB version tested is 5.1.11
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "cassandra:3.0", Options: opts},
|
||||
{ImageName: "cassandra:3.11", Options: opts},
|
||||
{ImageName: "scylladb/scylla:5.1.11", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
// Cassandra exposes 5 ports (7000, 7001, 7199, 9042 & 9160)
|
||||
// We only need the port bound to 9042
|
||||
ip, portStr, err := c.Port(9042)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
cluster := gocql.NewCluster(ip)
|
||||
cluster.Port = port
|
||||
cluster.Consistency = gocql.All
|
||||
p, err := cluster.CreateSession()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer p.Close()
|
||||
// Create keyspace for tests
|
||||
if err = p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec(); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
t.Run("test", test)
|
||||
t.Run("testMigrate", testMigrate)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, spec := range specs {
|
||||
t.Log("Cleaning up ", spec.ImageName)
|
||||
if err := spec.Cleanup(); err != nil {
|
||||
t.Error("Error removing ", spec.ImageName, "error:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(9042)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to get mapped port:", err)
|
||||
}
|
||||
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
|
||||
p := &Cassandra{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT table_name from system_schema.tables"))
|
||||
})
|
||||
}
|
||||
|
||||
func testMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(9042)
|
||||
if err != nil {
|
||||
t.Fatal("Unable to get mapped port:", err)
|
||||
}
|
||||
addr := fmt.Sprintf("cassandra://%v:%v/testks", ip, port)
|
||||
p := &Cassandra{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "testks", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
+1
@@ -0,0 +1 @@
|
||||
SELECT table_name from system_schema.tables
|
||||
+1
@@ -0,0 +1 @@
|
||||
SELECT table_name from system_schema.tables
|
||||
@@ -0,0 +1,26 @@
|
||||
# ClickHouse
|
||||
|
||||
`clickhouse://host:port?username=user&password=password&database=clicks&x-multi-statement=true`
|
||||
|
||||
| URL Query | Description |
|
||||
|------------|-------------|
|
||||
| `x-migrations-table`| Name of the migrations table |
|
||||
| `x-migrations-table-engine`| Engine to use for the migrations table, defaults to TinyLog |
|
||||
| `x-cluster-name` | Name of cluster for creating `schema_migrations` table cluster wide |
|
||||
| `database` | The name of the database to connect to |
|
||||
| `username` | The user to sign in as |
|
||||
| `password` | The user's password |
|
||||
| `host` | The host to connect to. |
|
||||
| `port` | The port to bind to. |
|
||||
| `x-multi-statement` | false | Enable multiple statements to be ran in a single migration (See note below) |
|
||||
|
||||
## Notes
|
||||
|
||||
* The Clickhouse driver does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param. There are two important caveats:
|
||||
* This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
|
||||
* The queries are not executed in any sort of transaction/batch, meaning you are responsible for fixing partial migrations.
|
||||
* Using the default TinyLog table engine for the schema_versions table prevents backing up the table if using the [clickhouse-backup](https://github.com/AlexAkulov/clickhouse-backup) tool. If backing up the database with make sure the migrations are run with `x-migrations-table-engine=MergeTree`.
|
||||
* Clickhouse cluster mode is not officially supported, since it's not tested right now, but you can try enabling `schema_migrations` table replication by specifying a `x-cluster-name`:
|
||||
* When `x-cluster-name` is specified, `x-migrations-table-engine` also should be specified. See the docs regarding [replicated table engines](https://clickhouse.tech/docs/en/engines/table-engines/mergetree-family/replication/#table_engines-replication).
|
||||
* When `x-cluster-name` is specified, only the `schema_migrations` table is replicated across the cluster. You still need to write your migrations so that the application tables are replicated within the cluster.
|
||||
* If you want to create database inside the migration, you should know, that table which will manage migrations `schema-migrations table` will be in `default` table, so you can't use `USE <database_name>` inside migration. In this case you may not specify the database in the connection string (example you can find [here](examples/migrations/003_create_database.up.sql))
|
||||
+316
@@ -0,0 +1,316 @@
|
||||
package clickhouse
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMigrationsTable = "schema_migrations"
|
||||
DefaultMigrationsTableEngine = "TinyLog"
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DatabaseName string
|
||||
ClusterName string
|
||||
MigrationsTable string
|
||||
MigrationsTableEngine string
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
func init() {
|
||||
database.Register("clickhouse", &ClickHouse{})
|
||||
}
|
||||
|
||||
func WithInstance(conn *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := conn.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch := &ClickHouse{
|
||||
conn: conn,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := ch.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
type ClickHouse struct {
|
||||
conn *sql.DB
|
||||
config *Config
|
||||
isLocked atomic.Bool
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) Open(dsn string) (database.Driver, error) {
|
||||
purl, err := url.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
q := migrate.FilterCustomQuery(purl)
|
||||
q.Scheme = "tcp"
|
||||
conn, err := sql.Open("clickhouse", q.String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
migrationsTableEngine := DefaultMigrationsTableEngine
|
||||
if s := purl.Query().Get("x-migrations-table-engine"); len(s) > 0 {
|
||||
migrationsTableEngine = s
|
||||
}
|
||||
|
||||
ch = &ClickHouse{
|
||||
conn: conn,
|
||||
config: &Config{
|
||||
MigrationsTable: purl.Query().Get("x-migrations-table"),
|
||||
MigrationsTableEngine: migrationsTableEngine,
|
||||
DatabaseName: purl.Query().Get("database"),
|
||||
ClusterName: purl.Query().Get("x-cluster-name"),
|
||||
MultiStatementEnabled: purl.Query().Get("x-multi-statement") == "true",
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
},
|
||||
}
|
||||
|
||||
if err := ch.init(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) init() error {
|
||||
if len(ch.config.DatabaseName) == 0 {
|
||||
if err := ch.conn.QueryRow("SELECT currentDatabase()").Scan(&ch.config.DatabaseName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(ch.config.MigrationsTable) == 0 {
|
||||
ch.config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if ch.config.MultiStatementMaxSize <= 0 {
|
||||
ch.config.MultiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
|
||||
if len(ch.config.MigrationsTableEngine) == 0 {
|
||||
ch.config.MigrationsTableEngine = DefaultMigrationsTableEngine
|
||||
}
|
||||
|
||||
return ch.ensureVersionTable()
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) Run(r io.Reader) error {
|
||||
if ch.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(r, multiStmtDelimiter, ch.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
tq := strings.TrimSpace(string(m))
|
||||
if tq == "" {
|
||||
return true
|
||||
}
|
||||
if _, e := ch.conn.Exec(string(m)); e != nil {
|
||||
err = database.Error{OrigErr: e, Err: "migration failed", Query: m}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
migration, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ch.conn.Exec(string(migration)); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migration}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (ch *ClickHouse) Version() (int, bool, error) {
|
||||
var (
|
||||
version int
|
||||
dirty uint8
|
||||
query = "SELECT version, dirty FROM `" + ch.config.MigrationsTable + "` ORDER BY sequence DESC LIMIT 1"
|
||||
)
|
||||
if err := ch.conn.QueryRow(query).Scan(&version, &dirty); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return version, dirty == 1, nil
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) SetVersion(version int, dirty bool) error {
|
||||
var (
|
||||
bool = func(v bool) uint8 {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
tx, err = ch.conn.Begin()
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "INSERT INTO " + ch.config.MigrationsTable + " (version, dirty, sequence) VALUES (?, ?, ?)"
|
||||
if _, err := tx.Exec(query, version, bool(dirty), time.Now().UnixNano()); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the ClickHouse type.
|
||||
func (ch *ClickHouse) ensureVersionTable() (err error) {
|
||||
if err = ch.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := ch.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var (
|
||||
table string
|
||||
query = "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName) + " LIKE '" + ch.config.MigrationsTable + "'"
|
||||
)
|
||||
// check if migration table exists
|
||||
if err := ch.conn.QueryRow(query).Scan(&table); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, create the empty migration table
|
||||
if len(ch.config.ClusterName) > 0 {
|
||||
query = fmt.Sprintf(`
|
||||
CREATE TABLE %s ON CLUSTER %s (
|
||||
version Int64,
|
||||
dirty UInt8,
|
||||
sequence UInt64
|
||||
) Engine=%s`, ch.config.MigrationsTable, ch.config.ClusterName, ch.config.MigrationsTableEngine)
|
||||
} else {
|
||||
query = fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
version Int64,
|
||||
dirty UInt8,
|
||||
sequence UInt64
|
||||
) Engine=%s`, ch.config.MigrationsTable, ch.config.MigrationsTableEngine)
|
||||
}
|
||||
|
||||
if strings.HasSuffix(ch.config.MigrationsTableEngine, "Tree") {
|
||||
query = fmt.Sprintf(`%s ORDER BY sequence`, query)
|
||||
}
|
||||
|
||||
if _, err := ch.conn.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) Drop() (err error) {
|
||||
query := "SHOW TABLES FROM " + quoteIdentifier(ch.config.DatabaseName)
|
||||
tables, err := ch.conn.Query(query)
|
||||
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
for tables.Next() {
|
||||
var table string
|
||||
if err := tables.Scan(&table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query = "DROP TABLE IF EXISTS " + quoteIdentifier(ch.config.DatabaseName) + "." + quoteIdentifier(table)
|
||||
|
||||
if _, err := ch.conn.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ch *ClickHouse) Lock() error {
|
||||
if !ch.isLocked.CAS(false, true) {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (ch *ClickHouse) Unlock() error {
|
||||
if !ch.isLocked.CAS(true, false) {
|
||||
return database.ErrNotLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
func (ch *ClickHouse) Close() error { return ch.conn.Close() }
|
||||
|
||||
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
|
||||
func quoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
||||
}
|
||||
+224
@@ -0,0 +1,224 @@
|
||||
package clickhouse_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
_ "github.com/ClickHouse/clickhouse-go"
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/clickhouse"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const defaultPort = 9000
|
||||
|
||||
var (
|
||||
tableEngines = []string{"TinyLog", "MergeTree"}
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"CLICKHOUSE_USER": "user", "CLICKHOUSE_PASSWORD": "password", "CLICKHOUSE_DB": "db"},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
}
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "yandex/clickhouse-server:21.3", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func clickhouseConnectionString(host, port, engine string) string {
|
||||
if engine != "" {
|
||||
return fmt.Sprintf(
|
||||
"clickhouse://%v:%v?username=user&password=password&database=db&x-multi-statement=true&x-migrations-table-engine=%v&debug=false",
|
||||
host, port, engine)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"clickhouse://%v:%v?username=user&password=password&database=db&x-multi-statement=true&debug=false",
|
||||
host, port)
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("clickhouse", clickhouseConnectionString(ip, port, ""))
|
||||
|
||||
if err != nil {
|
||||
log.Println("open error", err)
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn:
|
||||
return false
|
||||
default:
|
||||
fmt.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestCases(t *testing.T) {
|
||||
for _, engine := range tableEngines {
|
||||
t.Run("Test_"+engine, func(t *testing.T) { testSimple(t, engine) })
|
||||
t.Run("Migrate_"+engine, func(t *testing.T) { testMigrate(t, engine) })
|
||||
t.Run("Version_"+engine, func(t *testing.T) { testVersion(t, engine) })
|
||||
t.Run("Drop_"+engine, func(t *testing.T) { testDrop(t, engine) })
|
||||
}
|
||||
t.Run("WithInstanceDefaultConfigValues", func(t *testing.T) { testSimpleWithInstanceDefaultConfigValues(t) })
|
||||
}
|
||||
|
||||
func testSimple(t *testing.T, engine string) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := clickhouseConnectionString(ip, port, engine)
|
||||
p := &clickhouse.ClickHouse{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func testSimpleWithInstanceDefaultConfigValues(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := clickhouseConnectionString(ip, port, "")
|
||||
conn, err := sql.Open("clickhouse", addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d, err := clickhouse.WithInstance(conn, &clickhouse.Config{})
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func testMigrate(t *testing.T, engine string) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := clickhouseConnectionString(ip, port, engine)
|
||||
p := &clickhouse.ClickHouse{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "db", d)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func testVersion(t *testing.T, engine string) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
expectedVersion := 1
|
||||
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := clickhouseConnectionString(ip, port, engine)
|
||||
p := &clickhouse.ClickHouse{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = d.SetVersion(expectedVersion, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _, err := d.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if version != expectedVersion {
|
||||
t.Fatal("Version mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testDrop(t *testing.T, engine string) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := clickhouseConnectionString(ip, port, engine)
|
||||
p := &clickhouse.ClickHouse{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
err = d.Drop()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS test_1;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE TABLE test_1 (
|
||||
Date Date
|
||||
) Engine=Memory;
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS test_2;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE TABLE test_2 (
|
||||
Date Date
|
||||
) Engine=Memory;
|
||||
+10
@@ -0,0 +1,10 @@
|
||||
DROP TABLE IF EXISTS driver_ratings;
|
||||
DROP TABLE IF EXISTS user_ratings;
|
||||
DROP TABLE IF EXISTS orders;
|
||||
DROP TABLE IF EXISTS driver_ratings_queue;
|
||||
DROP TABLE IF EXISTS user_ratings_queue;
|
||||
DROP TABLE IF EXISTS orders_queue;
|
||||
DROP VIEW IF EXISTS user_ratings_queue_mv;
|
||||
DROP VIEW IF EXISTS driver_ratings_queue_mv;
|
||||
DROP VIEW IF EXISTS orders_queue_mv;
|
||||
DROP DATABASE IF EXISTS analytics;
|
||||
+81
@@ -0,0 +1,81 @@
|
||||
CREATE DATABASE IF NOT EXISTS analytics;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS analytics.driver_ratings(
|
||||
rate UInt8,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String,
|
||||
inserted_time DateTime DEFAULT now()
|
||||
) ENGINE = MergeTree
|
||||
PARTITION BY driverID
|
||||
ORDER BY (inserted_time);
|
||||
|
||||
CREATE TABLE analytics.driver_ratings_queue(
|
||||
rate UInt8,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String
|
||||
) ENGINE = Kafka
|
||||
SETTINGS kafka_broker_list = 'broker:9092',
|
||||
kafka_topic_list = 'driver-ratings',
|
||||
kafka_group_name = 'rating_readers',
|
||||
kafka_format = 'Avro',
|
||||
kafka_max_block_size = 1048576;
|
||||
|
||||
CREATE MATERIALIZED VIEW analytics.driver_ratings_queue_mv TO analytics.driver_ratings AS
|
||||
SELECT rate, userID, driverID, orderID
|
||||
FROM analytics.driver_ratings_queue;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS analytics.user_ratings(
|
||||
rate UInt8,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String,
|
||||
inserted_time DateTime DEFAULT now()
|
||||
) ENGINE = MergeTree
|
||||
PARTITION BY userID
|
||||
ORDER BY (inserted_time);
|
||||
|
||||
CREATE TABLE analytics.user_ratings_queue(
|
||||
rate UInt8,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String
|
||||
) ENGINE = Kafka
|
||||
SETTINGS kafka_broker_list = 'broker:9092',
|
||||
kafka_topic_list = 'user-ratings',
|
||||
kafka_group_name = 'rating_readers',
|
||||
kafka_format = 'JSON',
|
||||
kafka_max_block_size = 1048576;
|
||||
|
||||
CREATE MATERIALIZED VIEW analytics.user_ratings_queue_mv TO analytics.user_ratings AS
|
||||
SELECT rate, userID, driverID, orderID
|
||||
FROM analytics.user_ratings_queue;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS analytics.orders(
|
||||
from_place String,
|
||||
to_place String,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String,
|
||||
inserted_time DateTime DEFAULT now()
|
||||
) ENGINE = MergeTree
|
||||
PARTITION BY driverID
|
||||
ORDER BY (inserted_time);
|
||||
|
||||
CREATE TABLE analytics.orders_queue(
|
||||
from_place String,
|
||||
to_place String,
|
||||
userID Int64,
|
||||
driverID String,
|
||||
orderID String
|
||||
) ENGINE = Kafka
|
||||
SETTINGS kafka_broker_list = 'broker:9092',
|
||||
kafka_topic_list = 'orders',
|
||||
kafka_group_name = 'order_readers',
|
||||
kafka_format = 'Avro',
|
||||
kafka_max_block_size = 1048576;
|
||||
|
||||
CREATE MATERIALIZED VIEW analytics.orders_queue_mv TO orders AS
|
||||
SELECT from_place, to_place, userID, driverID, orderID
|
||||
FROM analytics.orders_queue;
|
||||
+19
@@ -0,0 +1,19 @@
|
||||
# cockroachdb
|
||||
|
||||
`cockroachdb://user:password@host:port/dbname?query` (`cockroach://`, and `crdb-postgres://` work, too)
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock |
|
||||
| `x-force-lock` | `ForceLock` | Force lock acquisition to fix faulty migrations which may not have released the schema lock (Boolean, default is `false`) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
+142
@@ -0,0 +1,142 @@
|
||||
# CockroachDB tutorial for beginners (insecure cluster)
|
||||
|
||||
## Create/configure database
|
||||
|
||||
First, let's start a local cluster - follow step 1. and 2. from [the docs](https://www.cockroachlabs.com/docs/stable/start-a-local-cluster.html#step-1-start-the-first-node).
|
||||
|
||||
Once you have it, create a database. Here I am going to create a database called `example`.
|
||||
Our user here is `cockroach`. We are not going to use a password, since it's not supported for insecure cluster.
|
||||
```
|
||||
cockroach sql --insecure --host=localhost:26257
|
||||
```
|
||||
```
|
||||
CREATE DATABASE example;
|
||||
CREATE USER IF NOT EXISTS cockroach;
|
||||
GRANT ALL ON DATABASE example TO cockroach;
|
||||
```
|
||||
|
||||
When using Migrate CLI we need to pass to database URL. Let's export it to a variable for convenience:
|
||||
```
|
||||
export COCKROACHDB_URL='cockroachdb://cockroach:@localhost:26257/example?sslmode=disable'
|
||||
```
|
||||
`sslmode=disable` means that the connection with our database will not be encrypted. This is needed to connect to an insecure node.
|
||||
|
||||
**NOTE:** Do not use COCKROACH_URL as a variable name here, it's already in use for discrete parameters and you may run into connection problems. For more info check out [docs](https://www.cockroachlabs.com/docs/stable/connection-parameters.html#connect-using-discrete-parameters).
|
||||
|
||||
You can find further description of database URLs [here](README.md#database-urls).
|
||||
|
||||
## Create migrations
|
||||
Let's create a table called `users`:
|
||||
```
|
||||
migrate create -ext sql -dir db/migrations -seq create_users_table
|
||||
```
|
||||
If there were no errors, we should have two files available under `db/migrations` folder:
|
||||
- 000001_create_users_table.down.sql
|
||||
- 000001_create_users_table.up.sql
|
||||
|
||||
Note the `sql` extension that we provided.
|
||||
|
||||
In the `.up.sql` file let's create the table:
|
||||
```
|
||||
CREATE TABLE IF NOT EXISTS example.users
|
||||
(
|
||||
user_id INT PRIMARY KEY,
|
||||
username VARCHAR (50) UNIQUE NOT NULL,
|
||||
password VARCHAR (50) NOT NULL,
|
||||
email VARCHAR (300) UNIQUE NOT NULL
|
||||
);
|
||||
```
|
||||
And in the `.down.sql` let's delete it:
|
||||
```
|
||||
DROP TABLE IF EXISTS example.users;
|
||||
```
|
||||
By adding `IF EXISTS/IF NOT EXISTS` we are making migrations idempotent - you can read more about idempotency in [getting started](/GETTING_STARTED.md#create-migrations)
|
||||
|
||||
## Run migrations
|
||||
```
|
||||
migrate -database ${COCKROACHDB_URL} -path db/migrations up
|
||||
```
|
||||
Let's check if the table was created properly by running `cockroach sql --insecure --host=localhost:26257 -e "show columns from example.users;"`.
|
||||
The output you are supposed to see:
|
||||
```
|
||||
column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden
|
||||
+-------------+--------------+-------------+----------------+-----------------------+----------------------------------------------+-----------+
|
||||
user_id | INT8 | false | NULL | | {primary,users_username_key,users_email_key} | false
|
||||
username | VARCHAR(50) | false | NULL | | {users_username_key} | false
|
||||
password | VARCHAR(50) | false | NULL | | {} | false
|
||||
email | VARCHAR(300) | false | NULL | | {users_email_key} | false
|
||||
(4 rows)
|
||||
```
|
||||
Now let's check if running reverse migration also works:
|
||||
```
|
||||
migrate -database ${COCKROACHDB_URL} -path db/migrations down
|
||||
```
|
||||
Make sure to check if your database changed as expected in this case as well.
|
||||
|
||||
## Database transactions
|
||||
|
||||
To show database transactions usage, let's create another set of migrations by running:
|
||||
```
|
||||
migrate create -ext sql -dir db/migrations -seq add_mood_to_users
|
||||
```
|
||||
Again, it should create for us two migrations files:
|
||||
- 000002_add_mood_to_users.down.sql
|
||||
- 000002_add_mood_to_users.up.sql
|
||||
|
||||
In Cockroach, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands, similar to PostgreSQL.
|
||||
In our example, we are going to add a column to our database that can only accept enumerable values or NULL.
|
||||
Migration up:
|
||||
```
|
||||
BEGIN;
|
||||
|
||||
ALTER TABLE example.users ADD COLUMN mood STRING;
|
||||
ALTER TABLE example.users ADD CONSTRAINT check_mood CHECK (mood IN ('happy', 'sad', 'neutral'));
|
||||
|
||||
COMMIT;
|
||||
```
|
||||
Migration down:
|
||||
```
|
||||
ALTER TABLE example.users DROP COLUMN mood;
|
||||
```
|
||||
|
||||
Now we can run our new migration and check the database:
|
||||
```
|
||||
migrate -database ${COCKROACHDB_URL} -path db/migrations up
|
||||
cockroach sql --insecure --host=localhost:26257 -e "show columns from example.users;"
|
||||
```
|
||||
Expected output:
|
||||
```
|
||||
column_name | data_type | is_nullable | column_default | generation_expression | indices | is_hidden
|
||||
+-------------+--------------+-------------+----------------+-----------------------+----------------------------------------------+-----------+
|
||||
user_id | INT8 | false | NULL | | {primary,users_username_key,users_email_key} | false
|
||||
username | VARCHAR(50) | false | NULL | | {users_username_key} | false
|
||||
password | VARCHAR(50) | false | NULL | | {} | false
|
||||
email | VARCHAR(300) | false | NULL | | {users_email_key} | false
|
||||
mood | STRING | true | NULL | | {} | false
|
||||
(5 rows)
|
||||
```
|
||||
|
||||
## Optional: Run migrations within your Go app
|
||||
Here is a very simple app running migrations for the above configuration:
|
||||
```
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/cockroachdb"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
func main() {
|
||||
m, err := migrate.New(
|
||||
"file://db/migrations",
|
||||
"cockroachdb://cockroach:@localhost:26257/example?sslmode=disable")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := m.Up(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
You can find details [here](README.md#use-in-your-go-project)
|
||||
+365
@@ -0,0 +1,365 @@
|
||||
package cockroachdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/lib/pq"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := CockroachDb{}
|
||||
database.Register("cockroach", &db)
|
||||
database.Register("cockroachdb", &db)
|
||||
database.Register("crdb-postgres", &db)
|
||||
}
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
var DefaultLockTable = "schema_lock"
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
LockTable string
|
||||
ForceLock bool
|
||||
DatabaseName string
|
||||
}
|
||||
|
||||
type CockroachDb struct {
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT current_database()`
|
||||
var databaseName string
|
||||
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if len(config.LockTable) == 0 {
|
||||
config.LockTable = DefaultLockTable
|
||||
}
|
||||
|
||||
px := &CockroachDb{
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
// ensureVersionTable is a locking operation, so we need to ensureLockTable before we ensureVersionTable.
|
||||
if err := px.ensureLockTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (c *CockroachDb) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// As Cockroach uses the postgres protocol, and 'postgres' is already a registered database, we need to replace the
|
||||
// connect prefix, with the actual protocol, so that the library can differentiate between the implementations
|
||||
re := regexp.MustCompile("^(cockroach(db)?|crdb-postgres)")
|
||||
connectString := re.ReplaceAllString(migrate.FilterCustomQuery(purl).String(), "postgres")
|
||||
|
||||
db, err := sql.Open("postgres", connectString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
if len(migrationsTable) == 0 {
|
||||
migrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
lockTable := purl.Query().Get("x-lock-table")
|
||||
if len(lockTable) == 0 {
|
||||
lockTable = DefaultLockTable
|
||||
}
|
||||
|
||||
forceLockQuery := purl.Query().Get("x-force-lock")
|
||||
forceLock, err := strconv.ParseBool(forceLockQuery)
|
||||
if err != nil {
|
||||
forceLock = false
|
||||
}
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
LockTable: lockTable,
|
||||
ForceLock: forceLock,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (c *CockroachDb) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
// Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed
|
||||
// See: https://github.com/cockroachdb/cockroach/issues/13546
|
||||
func (c *CockroachDb) Lock() error {
|
||||
return database.CasRestoreOnErr(&c.isLocked, false, true, database.ErrLocked, func() (err error) {
|
||||
return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) (err error) {
|
||||
aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + c.config.LockTable + " WHERE lock_id = $1"
|
||||
rows, err := tx.Query(query, aid)
|
||||
if err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := rows.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// If row exists at all, lock is present
|
||||
locked := rows.Next()
|
||||
if locked && !c.config.ForceLock {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
query = "INSERT INTO " + c.config.LockTable + " (lock_id) VALUES ($1)"
|
||||
if _, err := tx.Exec(query, aid); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// Locking is done manually with a separate lock table. Implementing advisory locks in CRDB is being discussed
|
||||
// See: https://github.com/cockroachdb/cockroach/issues/13546
|
||||
func (c *CockroachDb) Unlock() error {
|
||||
return database.CasRestoreOnErr(&c.isLocked, true, false, database.ErrNotLocked, func() (err error) {
|
||||
aid, err := database.GenerateAdvisoryLockId(c.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// In the event of an implementation (non-migration) error, it is possible for the lock to not be released. Until
|
||||
// a better locking mechanism is added, a manual purging of the lock table may be required in such circumstances
|
||||
query := "DELETE FROM " + c.config.LockTable + " WHERE lock_id = $1"
|
||||
if _, err := c.db.Exec(query, aid); err != nil {
|
||||
if e, ok := err.(*pq.Error); ok {
|
||||
// 42P01 is "UndefinedTableError" in CockroachDB
|
||||
// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
|
||||
if e.Code == "42P01" {
|
||||
// On drops, the lock table is fully removed; This is fine, and is a valid "unlocked" state for the schema
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CockroachDb) Run(migration io.Reader) error {
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run migration
|
||||
query := string(migr[:])
|
||||
if _, err := c.db.Exec(query); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CockroachDb) SetVersion(version int, dirty bool) error {
|
||||
return crdb.ExecuteTx(context.Background(), c.db, nil, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(`DELETE FROM "` + c.config.MigrationsTable + `"`); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
if _, err := tx.Exec(`INSERT INTO "`+c.config.MigrationsTable+`" (version, dirty) VALUES ($1, $2)`, version, dirty); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *CockroachDb) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
|
||||
err = c.db.QueryRow(query).Scan(&version, &dirty)
|
||||
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pq.Error); ok {
|
||||
// 42P01 is "UndefinedTableError" in CockroachDB
|
||||
// https://github.com/cockroachdb/cockroach/blob/master/pkg/sql/pgwire/pgerror/codes.go
|
||||
if e.Code == "42P01" {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CockroachDb) Drop() (err error) {
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())`
|
||||
tables, err := c.db.Query(query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + t + ` CASCADE`
|
||||
if _, err := c.db.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the CockroachDb type.
|
||||
func (c *CockroachDb) ensureVersionTable() (err error) {
|
||||
if err = c.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := c.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// check if migration table exists
|
||||
var count int
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
|
||||
if err := c.db.QueryRow(query, c.config.MigrationsTable).Scan(&count); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, create the empty migration table
|
||||
query = `CREATE TABLE "` + c.config.MigrationsTable + `" (version INT NOT NULL PRIMARY KEY, dirty BOOL NOT NULL)`
|
||||
if _, err := c.db.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CockroachDb) ensureLockTable() error {
|
||||
// check if lock table exists
|
||||
var count int
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
|
||||
if err := c.db.QueryRow(query, c.config.LockTable).Scan(&count); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, create the empty lock table
|
||||
query = `CREATE TABLE "` + c.config.LockTable + `" (lock_id INT NOT NULL PRIMARY KEY)`
|
||||
if _, err := c.db.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
+174
@@ -0,0 +1,174 @@
|
||||
package cockroachdb
|
||||
|
||||
// error codes https://github.com/lib/pq/blob/master/error.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"log"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
import (
|
||||
"github.com/dhui/dktest"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
import (
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const defaultPort = 26257
|
||||
|
||||
var (
|
||||
opts = dktest.Options{Cmd: []string{"start", "--insecure"}, PortRequired: true, ReadyFunc: isReady}
|
||||
// Released versions: https://www.cockroachlabs.com/docs/releases/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "cockroachdb/cockroach:v1.0.7", Options: opts},
|
||||
{ImageName: "cockroachdb/cockroach:v1.1.9", Options: opts},
|
||||
{ImageName: "cockroachdb/cockroach:v2.0.7", Options: opts},
|
||||
{ImageName: "cockroachdb/cockroach:v2.1.3", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
log.Println("port error:", err)
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("postgres://root@%v:%v?sslmode=disable", ip, port))
|
||||
if err != nil {
|
||||
log.Println("open error:", err)
|
||||
return false
|
||||
}
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
log.Println("ping error:", err)
|
||||
return false
|
||||
}
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func createDB(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("postgres://root@%v:%v?sslmode=disable", ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = db.Ping(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err = db.Exec("CREATE DATABASE migrate"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
|
||||
createDB(t, ci)
|
||||
|
||||
ip, port, err := ci.Port(26257)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
|
||||
c := &CockroachDb{}
|
||||
d, err := c.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
|
||||
createDB(t, ci)
|
||||
|
||||
ip, port, err := ci.Port(26257)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
|
||||
c := &CockroachDb{}
|
||||
d, err := c.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "migrate", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiStatement(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
|
||||
createDB(t, ci)
|
||||
|
||||
ip, port, err := ci.Port(26257)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable", ip, port)
|
||||
c := &CockroachDb{}
|
||||
d, err := c.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.(*CockroachDb).db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, ci dktest.ContainerInfo) {
|
||||
createDB(t, ci)
|
||||
|
||||
ip, port, err := ci.Port(26257)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("cockroach://root@%v:%v/migrate?sslmode=disable&x-custom=foobar", ip, port)
|
||||
c := &CockroachDb{}
|
||||
_, err = c.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS users;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE users (
|
||||
user_id INT UNIQUE,
|
||||
name STRING(40),
|
||||
email STRING(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS city;
|
||||
+1
@@ -0,0 +1 @@
|
||||
ALTER TABLE users ADD COLUMN city TEXT;
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS users_email_index;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS users_email_index ON users (email);
|
||||
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS books;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE books (
|
||||
user_id INT,
|
||||
name STRING(40),
|
||||
author STRING(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS movies;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE movies (
|
||||
user_id INT,
|
||||
name STRING(40),
|
||||
director STRING(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
@@ -0,0 +1,123 @@
|
||||
// Package database provides the Driver interface.
|
||||
// All database drivers must implement this interface, register themselves,
|
||||
// optionally provide a `WithInstance` function and pass the tests
|
||||
// in package database/testing.
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
|
||||
iurl "github.com/golang-migrate/migrate/v4/internal/url"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLocked = fmt.Errorf("can't acquire lock")
|
||||
ErrNotLocked = fmt.Errorf("can't unlock, as not currently locked")
|
||||
)
|
||||
|
||||
const NilVersion int = -1
|
||||
|
||||
var driversMu sync.RWMutex
|
||||
var drivers = make(map[string]Driver)
|
||||
|
||||
// Driver is the interface every database driver must implement.
|
||||
//
|
||||
// How to implement a database driver?
|
||||
// 1. Implement this interface.
|
||||
// 2. Optionally, add a function named `WithInstance`.
|
||||
// This function should accept an existing DB instance and a Config{} struct
|
||||
// and return a driver instance.
|
||||
// 3. Add a test that calls database/testing.go:Test()
|
||||
// 4. Add own tests for Open(), WithInstance() (when provided) and Close().
|
||||
// All other functions are tested by tests in database/testing.
|
||||
// Saves you some time and makes sure all database drivers behave the same way.
|
||||
// 5. Call Register in init().
|
||||
// 6. Create a internal/cli/build_<driver-name>.go file
|
||||
// 7. Add driver name in 'DATABASE' variable in Makefile
|
||||
//
|
||||
// Guidelines:
|
||||
// - Don't try to correct user input. Don't assume things.
|
||||
// When in doubt, return an error and explain the situation to the user.
|
||||
// - All configuration input must come from the URL string in func Open()
|
||||
// or the Config{} struct in WithInstance. Don't os.Getenv().
|
||||
type Driver interface {
|
||||
// Open returns a new driver instance configured with parameters
|
||||
// coming from the URL string. Migrate will call this function
|
||||
// only once per instance.
|
||||
Open(url string) (Driver, error)
|
||||
|
||||
// Close closes the underlying database instance managed by the driver.
|
||||
// Migrate will call this function only once per instance.
|
||||
Close() error
|
||||
|
||||
// Lock should acquire a database lock so that only one migration process
|
||||
// can run at a time. Migrate will call this function before Run is called.
|
||||
// If the implementation can't provide this functionality, return nil.
|
||||
// Return database.ErrLocked if database is already locked.
|
||||
Lock() error
|
||||
|
||||
// Unlock should release the lock. Migrate will call this function after
|
||||
// all migrations have been run.
|
||||
Unlock() error
|
||||
|
||||
// Run applies a migration to the database. migration is guaranteed to be not nil.
|
||||
Run(migration io.Reader) error
|
||||
|
||||
// SetVersion saves version and dirty state.
|
||||
// Migrate will call this function before and after each call to Run.
|
||||
// version must be >= -1. -1 means NilVersion.
|
||||
SetVersion(version int, dirty bool) error
|
||||
|
||||
// Version returns the currently active version and if the database is dirty.
|
||||
// When no migration has been applied, it must return version -1.
|
||||
// Dirty means, a previous migration failed and user interaction is required.
|
||||
Version() (version int, dirty bool, err error)
|
||||
|
||||
// Drop deletes everything in the database.
|
||||
// Note that this is a breaking action, a new call to Open() is necessary to
|
||||
// ensure subsequent calls work as expected.
|
||||
Drop() error
|
||||
}
|
||||
|
||||
// Open returns a new driver instance.
|
||||
func Open(url string) (Driver, error) {
|
||||
scheme, err := iurl.SchemeFromURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
driversMu.RLock()
|
||||
d, ok := drivers[scheme]
|
||||
driversMu.RUnlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database driver: unknown driver %v (forgotten import?)", scheme)
|
||||
}
|
||||
|
||||
return d.Open(url)
|
||||
}
|
||||
|
||||
// Register globally registers a driver.
|
||||
func Register(name string, driver Driver) {
|
||||
driversMu.Lock()
|
||||
defer driversMu.Unlock()
|
||||
if driver == nil {
|
||||
panic("Register driver is nil")
|
||||
}
|
||||
if _, dup := drivers[name]; dup {
|
||||
panic("Register called twice for driver " + name)
|
||||
}
|
||||
drivers[name] = driver
|
||||
}
|
||||
|
||||
// List lists the registered drivers
|
||||
func List() []string {
|
||||
driversMu.RLock()
|
||||
defer driversMu.RUnlock()
|
||||
names := make([]string, 0, len(drivers))
|
||||
for n := range drivers {
|
||||
names = append(names, n)
|
||||
}
|
||||
return names
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ExampleDriver() {
|
||||
// see database/stub for an example
|
||||
|
||||
// database/stub/stub.go has the driver implementation
|
||||
// database/stub/stub_test.go runs database/testing/test.go:Test
|
||||
}
|
||||
|
||||
// Using database/stub here is not possible as it
|
||||
// results in an import cycle.
|
||||
type mockDriver struct {
|
||||
url string
|
||||
}
|
||||
|
||||
func (m *mockDriver) Open(url string) (Driver, error) {
|
||||
return &mockDriver{
|
||||
url: url,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Lock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Unlock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Run(migration io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) SetVersion(version int, dirty bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Version() (version int, dirty bool, err error) {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func (m *mockDriver) Drop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestRegisterTwice(t *testing.T) {
|
||||
Register("mock", &mockDriver{})
|
||||
|
||||
var err interface{}
|
||||
func() {
|
||||
defer func() {
|
||||
err = recover()
|
||||
}()
|
||||
Register("mock", &mockDriver{})
|
||||
}()
|
||||
|
||||
if err == nil {
|
||||
t.Fatal("expected a panic when calling Register twice")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpen(t *testing.T) {
|
||||
// Make sure the driver is registered.
|
||||
// But if the previous test already registered it just ignore the panic.
|
||||
// If we don't do this it will be impossible to run this test standalone.
|
||||
func() {
|
||||
defer func() {
|
||||
_ = recover()
|
||||
}()
|
||||
Register("mock", &mockDriver{})
|
||||
}()
|
||||
|
||||
cases := []struct {
|
||||
url string
|
||||
err bool
|
||||
}{
|
||||
{
|
||||
"mock://user:pass@tcp(host:1337)/db",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"unknown://bla",
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, c := range cases {
|
||||
t.Run(c.url, func(t *testing.T) {
|
||||
d, err := Open(c.url)
|
||||
|
||||
if err == nil {
|
||||
if c.err {
|
||||
t.Fatal("expected an error for an unknown driver")
|
||||
} else {
|
||||
if md, ok := d.(*mockDriver); !ok {
|
||||
t.Fatalf("expected *mockDriver got %T", d)
|
||||
} else if md.url != c.url {
|
||||
t.Fatalf("expected %q got %q", c.url, md.url)
|
||||
}
|
||||
}
|
||||
} else if !c.err {
|
||||
t.Fatalf("did not expect %q", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Error should be used for errors involving queries ran against the database
|
||||
type Error struct {
|
||||
// Optional: the line number
|
||||
Line uint
|
||||
|
||||
// Query is a query excerpt
|
||||
Query []byte
|
||||
|
||||
// Err is a useful/helping error message for humans
|
||||
Err string
|
||||
|
||||
// OrigErr is the underlying error
|
||||
OrigErr error
|
||||
}
|
||||
|
||||
func (e Error) Error() string {
|
||||
if len(e.Err) == 0 {
|
||||
return fmt.Sprintf("%v in line %v: %s", e.OrigErr, e.Line, e.Query)
|
||||
}
|
||||
return fmt.Sprintf("%v in line %v: %s (details: %v)", e.Err, e.Line, e.Query, e.OrigErr)
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
# firebird
|
||||
|
||||
`firebirdsql://user:password@servername[:port_number]/database_name_or_file[?params1=value1[¶m2=value2]...]`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `auth_plugin_name` | | Authentication plugin name. Srp256/Srp/Legacy_Auth are available. (default is Srp) |
|
||||
| `column_name_to_lower` | | Force column name to lower. (default is false) |
|
||||
| `role` | | Role name |
|
||||
| `tzname` | | Time Zone name. (For Firebird 4.0+) |
|
||||
| `wire_crypt` | | Enable wire data encryption or not. For Firebird 3.0+ (default is true) |
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE users;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE users (
|
||||
user_id integer unique,
|
||||
name varchar(40),
|
||||
email varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
ALTER TABLE users DROP city;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE users ADD city varchar(100);
|
||||
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP INDEX users_email_index;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX users_email_index ON users (email);
|
||||
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE books;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE books (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
author varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE movies;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE movies (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
director varchar(40)
|
||||
);
|
||||
+259
@@ -0,0 +1,259 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package firebird
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
_ "github.com/nakagami/firebirdsql"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Firebird{}
|
||||
database.Register("firebird", &db)
|
||||
database.Register("firebirdsql", &db)
|
||||
}
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
DatabaseName string
|
||||
MigrationsTable string
|
||||
}
|
||||
|
||||
type Firebird struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fb := &Firebird{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := fb.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fb, nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Open(dsn string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db, err := sql.Open("firebirdsql", migrate.FilterCustomQuery(purl).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
MigrationsTable: purl.Query().Get("x-migrations-table"),
|
||||
DatabaseName: purl.Path,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Close() error {
|
||||
connErr := f.conn.Close()
|
||||
dbErr := f.db.Close()
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Lock() error {
|
||||
if !f.isLocked.CAS(false, true) {
|
||||
return database.ErrLocked
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Unlock() error {
|
||||
if !f.isLocked.CAS(true, false) {
|
||||
return database.ErrNotLocked
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Run(migration io.Reader) error {
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run migration
|
||||
query := string(migr[:])
|
||||
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firebird) SetVersion(version int, dirty bool) error {
|
||||
// Always re-write the schema version to prevent empty schema version
|
||||
// for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
|
||||
// TODO: parameterize this SQL statement
|
||||
// https://firebirdsql.org/refdocs/langrefupd20-execblock.html
|
||||
// VALUES (?, ?) doesn't work
|
||||
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
||||
DELETE FROM "%v";
|
||||
INSERT INTO "%v" (version, dirty) VALUES (%v, %v);
|
||||
END;`,
|
||||
f.config.MigrationsTable, f.config.MigrationsTable, version, btoi(dirty))
|
||||
|
||||
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firebird) Version() (version int, dirty bool, err error) {
|
||||
var d int
|
||||
query := fmt.Sprintf(`SELECT FIRST 1 version, dirty FROM "%v"`, f.config.MigrationsTable)
|
||||
err = f.conn.QueryRowContext(context.Background(), query).Scan(&version, &d)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
case err != nil:
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, itob(d), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Firebird) Drop() (err error) {
|
||||
// select all tables
|
||||
query := `SELECT rdb$relation_name FROM rdb$relations WHERE rdb$view_blr IS NULL AND (rdb$system_flag IS NULL OR rdb$system_flag = 0);`
|
||||
tables, err := f.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
||||
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
|
||||
execute statement 'drop table "%v"';
|
||||
END;`,
|
||||
t, t)
|
||||
|
||||
if _, err := f.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
func (f *Firebird) ensureVersionTable() (err error) {
|
||||
if err = f.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := f.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
query := fmt.Sprintf(`EXECUTE BLOCK AS BEGIN
|
||||
if (not exists(select 1 from rdb$relations where rdb$relation_name = '%v')) then
|
||||
execute statement 'create table "%v" (version bigint not null primary key, dirty smallint not null)';
|
||||
END;`,
|
||||
f.config.MigrationsTable, f.config.MigrationsTable)
|
||||
|
||||
if _, err = f.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// btoi converts bool to int
|
||||
func btoi(v bool) int {
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// itob converts int to bool
|
||||
func itob(v int) bool {
|
||||
return v != 0
|
||||
}
|
||||
+226
@@ -0,0 +1,226 @@
|
||||
package firebird
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
|
||||
_ "github.com/nakagami/firebirdsql"
|
||||
)
|
||||
|
||||
const (
|
||||
user = "test_user"
|
||||
password = "123456"
|
||||
dbName = "test.fdb"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
PortRequired: true,
|
||||
ReadyFunc: isReady,
|
||||
Env: map[string]string{
|
||||
"FIREBIRD_DATABASE": dbName,
|
||||
"FIREBIRD_USER": user,
|
||||
"FIREBIRD_PASSWORD": password,
|
||||
},
|
||||
}
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "jacobalberty/firebird:2.5-ss", Options: opts},
|
||||
{ImageName: "jacobalberty/firebird:3.0", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func fbConnectionString(host, port string) string {
|
||||
//firebird://user:password@servername[:port_number]/database_name_or_file[?params1=value1[¶m2=value2]...]
|
||||
return fmt.Sprintf("firebird://%s:%s@%s:%s//firebird/data/%s", user, password, host, port, dbName)
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("firebirdsql", fbConnectionString(ip, port))
|
||||
if err != nil {
|
||||
log.Println("open error:", err)
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fbConnectionString(ip, port)
|
||||
p := &Firebird{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT Count(*) FROM rdb$relations"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fbConnectionString(ip, port)
|
||||
p := &Firebird{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "firebirdsql", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fbConnectionString(ip, port)
|
||||
p := &Firebird{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wantErr := `migration failed in line 0: CREATE TABLEE foo (foo varchar(40)); (details: Dynamic SQL Error
|
||||
SQL error code = -104
|
||||
Token unknown - line 1, column 8
|
||||
TABLEE
|
||||
)`
|
||||
|
||||
if err := d.Run(strings.NewReader("CREATE TABLEE foo (foo varchar(40));")); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
msg := err.Error()
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, msg)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fbConnectionString(ip, port) + "?sslmode=disable&x-custom=foobar"
|
||||
p := &Firebird{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func Test_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fbConnectionString(ip, port)
|
||||
p := &Firebird{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dt.Test(t, d, []byte("SELECT Count(*) FROM rdb$relations"))
|
||||
|
||||
ps := d.(*Firebird)
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
# MongoDB
|
||||
|
||||
* Driver work with mongo through [db.runCommands](https://docs.mongodb.com/manual/reference/command/)
|
||||
* Migrations support json format. It contains array of commands for `db.runCommand`. Every command is executed in separate request to database
|
||||
* All keys have to be in quotes `"`
|
||||
* [Examples](./examples)
|
||||
|
||||
# Usage
|
||||
|
||||
`mongodb://user:password@host:port/dbname?query` (`mongodb+srv://` also works, but behaves a bit differently. See [docs](https://docs.mongodb.com/manual/reference/connection-string/#dns-seedlist-connection-format) for more information)
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-collection` | `MigrationsCollection` | Name of the migrations collection |
|
||||
| `x-transaction-mode` | `TransactionMode` | If set to `true` wrap commands in [transaction](https://docs.mongodb.com/manual/core/transactions). Available only for replica set. Driver is using [strconv.ParseBool](https://golang.org/pkg/strconv/#ParseBool) for parsing|
|
||||
| `x-advisory-locking` | `true` | Feature flag for advisory locking, if set to false, disable advisory locking |
|
||||
| `x-advisory-lock-collection` | `migrate_advisory_lock` | The name of the collection to use for advisory locking.|
|
||||
| `x-advisory-lock-timeout` | `15` | The max time in seconds that migrate will wait to acquire a lock before failing. |
|
||||
| `x-advisory-lock-timeout-interval` | `10` | The max time in seconds between attempts to acquire the advisory lock, the lock is attempted to be acquired using an exponential backoff algorithm. |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `user` | | The user to sign in as. Can be omitted |
|
||||
| `password` | | The user's password. Can be omitted |
|
||||
| `host` | | The host to connect to |
|
||||
| `port` | | The port to bind to |
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
[
|
||||
{
|
||||
"dropUser": "deminem"
|
||||
}
|
||||
]
|
||||
+12
@@ -0,0 +1,12 @@
|
||||
[
|
||||
{
|
||||
"createUser": "deminem",
|
||||
"pwd": "gogo",
|
||||
"roles": [
|
||||
{
|
||||
"role": "readWrite",
|
||||
"db": "testMigration"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+10
@@ -0,0 +1,10 @@
|
||||
[
|
||||
{
|
||||
"dropIndexes": "mycollection",
|
||||
"index": "username_sort_by_asc_created"
|
||||
},
|
||||
{
|
||||
"dropIndexes": "mycollection",
|
||||
"index": "unique_email"
|
||||
}
|
||||
]
|
||||
+21
@@ -0,0 +1,21 @@
|
||||
[{
|
||||
"createIndexes": "mycollection",
|
||||
"indexes": [
|
||||
{
|
||||
"key": {
|
||||
"username": 1,
|
||||
"created": -1
|
||||
},
|
||||
"name": "username_sort_by_asc_created",
|
||||
"background": true
|
||||
},
|
||||
{
|
||||
"key": {
|
||||
"email": 1
|
||||
},
|
||||
"name": "unique_email",
|
||||
"unique": true,
|
||||
"background": true
|
||||
}
|
||||
]
|
||||
}]
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"update": "users",
|
||||
"updates": [
|
||||
{
|
||||
"q": {},
|
||||
"u": {
|
||||
"$unset": {
|
||||
"status": ""
|
||||
}
|
||||
},
|
||||
"multi": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+16
@@ -0,0 +1,16 @@
|
||||
[
|
||||
{
|
||||
"update": "users",
|
||||
"updates": [
|
||||
{
|
||||
"q": {},
|
||||
"u": {
|
||||
"$set": {
|
||||
"status": "active"
|
||||
}
|
||||
},
|
||||
"multi": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+14
@@ -0,0 +1,14 @@
|
||||
[
|
||||
{
|
||||
"update": "users",
|
||||
"updates": [
|
||||
{
|
||||
"q": {},
|
||||
"u": {
|
||||
"fullname": ""
|
||||
},
|
||||
"multi": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
[
|
||||
{
|
||||
"aggregate": "users",
|
||||
"pipeline": [
|
||||
{
|
||||
"$project": {
|
||||
"_id": 1,
|
||||
"firstname": 1,
|
||||
"lastname": 1,
|
||||
"username": 1,
|
||||
"password": 1,
|
||||
"email": 1,
|
||||
"active": 1,
|
||||
"fullname": { "$concat": ["$firstname", " ", "$lastname"] }
|
||||
}
|
||||
},
|
||||
{
|
||||
"$out": "users"
|
||||
}
|
||||
],
|
||||
"cursor": {}
|
||||
}
|
||||
]
|
||||
@@ -0,0 +1,404 @@
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Mongo{}
|
||||
database.Register("mongodb", &db)
|
||||
database.Register("mongodb+srv", &db)
|
||||
}
|
||||
|
||||
var DefaultMigrationsCollection = "schema_migrations"
|
||||
|
||||
const DefaultLockingCollection = "migrate_advisory_lock" // the collection to use for advisory locking by default.
|
||||
const lockKeyUniqueValue = 0 // the unique value to lock on. If multiple clients try to insert the same key, it will fail (locked).
|
||||
const DefaultLockTimeout = 15 // the default maximum time to wait for a lock to be released.
|
||||
const DefaultLockTimeoutInterval = 10 // the default maximum intervals time for the locking timout.
|
||||
const DefaultAdvisoryLockingFlag = true // the default value for the advisory locking feature flag. Default is true.
|
||||
const LockIndexName = "lock_unique_key" // the name of the index which adds unique constraint to the locking_key field.
|
||||
const contextWaitTimeout = 5 * time.Second // how long to wait for the request to mongo to block/wait for.
|
||||
|
||||
var (
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrLockTimeoutConfigConflict = fmt.Errorf("both x-advisory-lock-timeout-interval and x-advisory-lock-timout-interval were specified")
|
||||
)
|
||||
|
||||
type Mongo struct {
|
||||
client *mongo.Client
|
||||
db *mongo.Database
|
||||
config *Config
|
||||
isLocked atomic.Bool
|
||||
}
|
||||
|
||||
type Locking struct {
|
||||
CollectionName string
|
||||
Timeout int
|
||||
Enabled bool
|
||||
Interval int
|
||||
}
|
||||
type Config struct {
|
||||
DatabaseName string
|
||||
MigrationsCollection string
|
||||
TransactionMode bool
|
||||
Locking Locking
|
||||
}
|
||||
type versionInfo struct {
|
||||
Version int `bson:"version"`
|
||||
Dirty bool `bson:"dirty"`
|
||||
}
|
||||
|
||||
type lockObj struct {
|
||||
Key int `bson:"locking_key"`
|
||||
Pid int `bson:"pid"`
|
||||
Hostname string `bson:"hostname"`
|
||||
CreatedAt time.Time `bson:"created_at"`
|
||||
}
|
||||
type findFilter struct {
|
||||
Key int `bson:"locking_key"`
|
||||
}
|
||||
|
||||
func WithInstance(instance *mongo.Client, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
if len(config.DatabaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
if len(config.MigrationsCollection) == 0 {
|
||||
config.MigrationsCollection = DefaultMigrationsCollection
|
||||
}
|
||||
if len(config.Locking.CollectionName) == 0 {
|
||||
config.Locking.CollectionName = DefaultLockingCollection
|
||||
}
|
||||
if config.Locking.Timeout <= 0 {
|
||||
config.Locking.Timeout = DefaultLockTimeout
|
||||
}
|
||||
if config.Locking.Interval <= 0 {
|
||||
config.Locking.Interval = DefaultLockTimeoutInterval
|
||||
}
|
||||
|
||||
mc := &Mongo{
|
||||
client: instance,
|
||||
db: instance.Database(config.DatabaseName),
|
||||
config: config,
|
||||
}
|
||||
|
||||
if mc.config.Locking.Enabled {
|
||||
if err := mc.ensureLockTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := mc.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
func (m *Mongo) Open(dsn string) (database.Driver, error) {
|
||||
// connstring is experimental package, but it used for parse connection string in mongo.Connect function
|
||||
uri, err := connstring.Parse(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(uri.Database) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
unknown := url.Values(uri.UnknownOptions)
|
||||
|
||||
migrationsCollection := unknown.Get("x-migrations-collection")
|
||||
lockCollection := unknown.Get("x-advisory-lock-collection")
|
||||
transactionMode, err := parseBoolean(unknown.Get("x-transaction-mode"), false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
advisoryLockingFlag, err := parseBoolean(unknown.Get("x-advisory-locking"), DefaultAdvisoryLockingFlag)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lockingTimout, err := parseInt(unknown.Get("x-advisory-lock-timeout"), DefaultLockTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lockTimeoutIntervalValue := unknown.Get("x-advisory-lock-timeout-interval")
|
||||
// The initial release had a typo for this argument but for backwards compatibility sake, we will keep supporting it
|
||||
// and we will error out if both values are set.
|
||||
lockTimeoutIntervalValueFromTypo := unknown.Get("x-advisory-lock-timout-interval")
|
||||
|
||||
lockTimeout := lockTimeoutIntervalValue
|
||||
|
||||
if lockTimeoutIntervalValue != "" && lockTimeoutIntervalValueFromTypo != "" {
|
||||
return nil, ErrLockTimeoutConfigConflict
|
||||
} else if lockTimeoutIntervalValueFromTypo != "" {
|
||||
lockTimeout = lockTimeoutIntervalValueFromTypo
|
||||
}
|
||||
|
||||
maxLockCheckInterval, err := parseInt(lockTimeout, DefaultLockTimeoutInterval)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(dsn))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = client.Ping(context.TODO(), nil); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc, err := WithInstance(client, &Config{
|
||||
DatabaseName: uri.Database,
|
||||
MigrationsCollection: migrationsCollection,
|
||||
TransactionMode: transactionMode,
|
||||
Locking: Locking{
|
||||
CollectionName: lockCollection,
|
||||
Timeout: lockingTimout,
|
||||
Enabled: advisoryLockingFlag,
|
||||
Interval: maxLockCheckInterval,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
// Parse the url param, convert it to boolean
|
||||
// returns error if param invalid. returns defaultValue if param not present
|
||||
func parseBoolean(urlParam string, defaultValue bool) (bool, error) {
|
||||
|
||||
// if parameter passed, parse it (otherwise return default value)
|
||||
if urlParam != "" {
|
||||
result, err := strconv.ParseBool(urlParam)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// if no url Param passed, return default value
|
||||
return defaultValue, nil
|
||||
}
|
||||
|
||||
// Parse the url param, convert it to int
|
||||
// returns error if param invalid. returns defaultValue if param not present
|
||||
func parseInt(urlParam string, defaultValue int) (int, error) {
|
||||
|
||||
// if parameter passed, parse it (otherwise return default value)
|
||||
if urlParam != "" {
|
||||
result, err := strconv.Atoi(urlParam)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// if no url Param passed, return default value
|
||||
return defaultValue, nil
|
||||
}
|
||||
func (m *Mongo) SetVersion(version int, dirty bool) error {
|
||||
migrationsCollection := m.db.Collection(m.config.MigrationsCollection)
|
||||
if err := migrationsCollection.Drop(context.TODO()); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "drop migrations collection failed"}
|
||||
}
|
||||
_, err := migrationsCollection.InsertOne(context.TODO(), bson.M{"version": version, "dirty": dirty})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "save version failed"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mongo) Version() (version int, dirty bool, err error) {
|
||||
var versionInfo versionInfo
|
||||
err = m.db.Collection(m.config.MigrationsCollection).FindOne(context.TODO(), bson.M{}).Decode(&versionInfo)
|
||||
switch {
|
||||
case err == mongo.ErrNoDocuments:
|
||||
return database.NilVersion, false, nil
|
||||
case err != nil:
|
||||
return 0, false, &database.Error{OrigErr: err, Err: "failed to get migration version"}
|
||||
default:
|
||||
return versionInfo.Version, versionInfo.Dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mongo) Run(migration io.Reader) error {
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var cmds []bson.D
|
||||
err = bson.UnmarshalExtJSON(migr, true, &cmds)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshaling json error: %s", err)
|
||||
}
|
||||
if m.config.TransactionMode {
|
||||
if err := m.executeCommandsWithTransaction(context.TODO(), cmds); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := m.executeCommands(context.TODO(), cmds); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mongo) executeCommandsWithTransaction(ctx context.Context, cmds []bson.D) error {
|
||||
err := m.db.Client().UseSession(ctx, func(sessionContext mongo.SessionContext) error {
|
||||
if err := sessionContext.StartTransaction(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "failed to start transaction"}
|
||||
}
|
||||
if err := m.executeCommands(sessionContext, cmds); err != nil {
|
||||
// When command execution is failed, it's aborting transaction
|
||||
// If you tried to call abortTransaction, it`s return error that transaction already aborted
|
||||
return err
|
||||
}
|
||||
if err := sessionContext.CommitTransaction(sessionContext); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "failed to commit transaction"}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mongo) executeCommands(ctx context.Context, cmds []bson.D) error {
|
||||
for _, cmd := range cmds {
|
||||
err := m.db.RunCommand(ctx, cmd).Err()
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: fmt.Sprintf("failed to execute command:%v", cmd)}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mongo) Close() error {
|
||||
return m.client.Disconnect(context.TODO())
|
||||
}
|
||||
|
||||
func (m *Mongo) Drop() error {
|
||||
return m.db.Drop(context.TODO())
|
||||
}
|
||||
|
||||
func (m *Mongo) ensureLockTable() error {
|
||||
indexes := m.db.Collection(m.config.Locking.CollectionName).Indexes()
|
||||
|
||||
indexOptions := options.Index().SetUnique(true).SetName(LockIndexName)
|
||||
_, err := indexes.CreateOne(context.TODO(), mongo.IndexModel{
|
||||
Options: indexOptions,
|
||||
Keys: findFilter{Key: -1},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the MongoDb type.
|
||||
func (m *Mongo) ensureVersionTable() (err error) {
|
||||
if err = m.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := m.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err = m.Version(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Utilizes advisory locking on the config.LockingCollection collection
|
||||
// This uses a unique index on the `locking_key` field.
|
||||
func (m *Mongo) Lock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
|
||||
if !m.config.Locking.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
pid := os.Getpid()
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = fmt.Sprintf("Could not determine hostname. Error: %s", err.Error())
|
||||
}
|
||||
|
||||
newLockObj := lockObj{
|
||||
Key: lockKeyUniqueValue,
|
||||
Pid: pid,
|
||||
Hostname: hostname,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
operation := func() error {
|
||||
timeout, cancelFunc := context.WithTimeout(context.Background(), contextWaitTimeout)
|
||||
_, err := m.db.Collection(m.config.Locking.CollectionName).InsertOne(timeout, newLockObj)
|
||||
defer cancelFunc()
|
||||
return err
|
||||
}
|
||||
exponentialBackOff := backoff.NewExponentialBackOff()
|
||||
duration := time.Duration(m.config.Locking.Timeout) * time.Second
|
||||
exponentialBackOff.MaxElapsedTime = duration
|
||||
exponentialBackOff.MaxInterval = time.Duration(m.config.Locking.Interval) * time.Second
|
||||
|
||||
err = backoff.Retry(operation, exponentialBackOff)
|
||||
if err != nil {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mongo) Unlock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
if !m.config.Locking.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
filter := findFilter{
|
||||
Key: lockKeyUniqueValue,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), contextWaitTimeout)
|
||||
_, err := m.db.Collection(m.config.Locking.CollectionName).DeleteMany(ctx, filter)
|
||||
defer cancel()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
+430
@@ -0,0 +1,430 @@
|
||||
package mongodb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
import (
|
||||
"github.com/dhui/dktest"
|
||||
"go.mongodb.org/mongo-driver/bson"
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
import (
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: https://www.mongodb.com/support-policy
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "mongo:3.4", Options: opts},
|
||||
{ImageName: "mongo:3.6", Options: opts},
|
||||
{ImageName: "mongo:4.0", Options: opts},
|
||||
{ImageName: "mongo:4.2", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func mongoConnectionString(host, port string) string {
|
||||
// there is connect option for excluding serverConnection algorithm
|
||||
// it's let avoid errors with mongo replica set connection in docker container
|
||||
return fmt.Sprintf("mongodb://%s:%s/testMigration?connect=direct", host, port)
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoConnectionString(ip, port)))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := client.Disconnect(ctx); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err = client.Ping(ctx, nil); err != nil {
|
||||
switch err {
|
||||
case io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
t.Run("test", test)
|
||||
t.Run("testMigrate", testMigrate)
|
||||
t.Run("testWithAuth", testWithAuth)
|
||||
t.Run("testLockWorks", testLockWorks)
|
||||
|
||||
t.Cleanup(func() {
|
||||
for _, spec := range specs {
|
||||
t.Log("Cleaning up ", spec.ImageName)
|
||||
if err := spec.Cleanup(); err != nil {
|
||||
t.Error("Error removing ", spec.ImageName, "error:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := mongoConnectionString(ip, port)
|
||||
p := &Mongo{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.TestNilVersion(t, d)
|
||||
dt.TestLockAndUnlock(t, d)
|
||||
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
|
||||
dt.TestSetVersion(t, d)
|
||||
dt.TestDrop(t, d)
|
||||
})
|
||||
}
|
||||
|
||||
func testMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := mongoConnectionString(ip, port)
|
||||
p := &Mongo{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func testWithAuth(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := mongoConnectionString(ip, port)
|
||||
p := &Mongo{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
createUserCMD := []byte(`[{"createUser":"deminem","pwd":"gogo","roles":[{"role":"readWrite","db":"testMigration"}]}]`)
|
||||
err = d.Run(bytes.NewReader(createUserCMD))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testcases := []struct {
|
||||
name string
|
||||
connectUri string
|
||||
isErrorExpected bool
|
||||
}{
|
||||
{"right auth data", "mongodb://deminem:gogo@%s:%v/testMigration", false},
|
||||
{"wrong auth data", "mongodb://wrong:auth@%s:%v/testMigration", true},
|
||||
}
|
||||
|
||||
for _, tcase := range testcases {
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
mc := &Mongo{}
|
||||
d, err := mc.Open(fmt.Sprintf(tcase.connectUri, ip, port))
|
||||
if err == nil {
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
switch {
|
||||
case tcase.isErrorExpected && err == nil:
|
||||
t.Fatalf("no error when expected")
|
||||
case !tcase.isErrorExpected && err != nil:
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func testLockWorks(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := mongoConnectionString(ip, port)
|
||||
p := &Mongo{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dt.TestRun(t, d, bytes.NewReader([]byte(`[{"insert":"hello","documents":[{"wild":"world"}]}]`)))
|
||||
|
||||
mc := d.(*Mongo)
|
||||
|
||||
err = mc.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = mc.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// enable locking,
|
||||
//try to hit a lock conflict
|
||||
mc.config.Locking.Enabled = true
|
||||
mc.config.Locking.Timeout = 1
|
||||
err = mc.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = mc.Lock()
|
||||
if err == nil {
|
||||
t.Fatal("should have failed, mongo should be locked already")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransaction(t *testing.T) {
|
||||
transactionSpecs := []dktesting.ContainerSpec{
|
||||
{ImageName: "mongo:4", Options: dktest.Options{PortRequired: true, ReadyFunc: isReady,
|
||||
Cmd: []string{"mongod", "--bind_ip_all", "--replSet", "rs0"}}},
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
for _, spec := range transactionSpecs {
|
||||
t.Log("Cleaning up ", spec.ImageName)
|
||||
if err := spec.Cleanup(); err != nil {
|
||||
t.Error("Error removing ", spec.ImageName, "error:", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
dktesting.ParallelTest(t, transactionSpecs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(mongoConnectionString(ip, port)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = client.Ping(context.TODO(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
//rs.initiate()
|
||||
err = client.Database("admin").RunCommand(context.TODO(), bson.D{bson.E{Key: "replSetInitiate", Value: bson.D{}}}).Err()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = waitForReplicaInit(client)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d, err := WithInstance(client, &Config{
|
||||
DatabaseName: "testMigration",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
//We have to create collection
|
||||
//transactions don't support operations with creating new dbs, collections
|
||||
//Unique index need for checking transaction aborting
|
||||
insertCMD := []byte(`[
|
||||
{"create":"hello"},
|
||||
{"createIndexes": "hello",
|
||||
"indexes": [{
|
||||
"key": {
|
||||
"wild": 1
|
||||
},
|
||||
"name": "unique_wild",
|
||||
"unique": true,
|
||||
"background": true
|
||||
}]
|
||||
}]`)
|
||||
err = d.Run(bytes.NewReader(insertCMD))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
testcases := []struct {
|
||||
name string
|
||||
cmds []byte
|
||||
documentsCount int64
|
||||
isErrorExpected bool
|
||||
}{
|
||||
{
|
||||
name: "success transaction",
|
||||
cmds: []byte(`[{"insert":"hello","documents":[
|
||||
{"wild":"world"},
|
||||
{"wild":"west"},
|
||||
{"wild":"natural"}
|
||||
]
|
||||
}]`),
|
||||
documentsCount: 3,
|
||||
isErrorExpected: false,
|
||||
},
|
||||
{
|
||||
name: "failure transaction",
|
||||
//transaction have to be failure - duplicate unique key wild:west
|
||||
//none of the documents should be added
|
||||
cmds: []byte(`[{"insert":"hello","documents":[{"wild":"flower"}]},
|
||||
{"insert":"hello","documents":[
|
||||
{"wild":"cat"},
|
||||
{"wild":"west"}
|
||||
]
|
||||
}]`),
|
||||
documentsCount: 3,
|
||||
isErrorExpected: true,
|
||||
},
|
||||
}
|
||||
for _, tcase := range testcases {
|
||||
t.Run(tcase.name, func(t *testing.T) {
|
||||
client, err := mongo.Connect(context.TODO(), options.Client().ApplyURI(mongoConnectionString(ip, port)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = client.Ping(context.TODO(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
d, err := WithInstance(client, &Config{
|
||||
DatabaseName: "testMigration",
|
||||
TransactionMode: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
runErr := d.Run(bytes.NewReader(tcase.cmds))
|
||||
if runErr != nil {
|
||||
if !tcase.isErrorExpected {
|
||||
t.Fatal(runErr)
|
||||
}
|
||||
}
|
||||
documentsCount, err := client.Database("testMigration").Collection("hello").CountDocuments(context.TODO(), bson.M{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tcase.documentsCount != documentsCount {
|
||||
t.Fatalf("expected %d and actual %d documents count not equal. run migration error:%s", tcase.documentsCount, documentsCount, runErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type isMaster struct {
|
||||
IsMaster bool `bson:"ismaster"`
|
||||
}
|
||||
|
||||
func waitForReplicaInit(client *mongo.Client) error {
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
timeout, err := strconv.Atoi(os.Getenv("MIGRATE_TEST_MONGO_REPLICA_SET_INIT_TIMEOUT"))
|
||||
if err != nil {
|
||||
timeout = 30
|
||||
}
|
||||
timeoutTimer := time.NewTimer(time.Duration(timeout) * time.Second)
|
||||
defer timeoutTimer.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
var status isMaster
|
||||
//Check that node is primary because
|
||||
//during replica set initialization, the first node first becomes a secondary and then becomes the primary
|
||||
//should consider that initialization is completed only after the node has become the primary
|
||||
result := client.Database("admin").RunCommand(context.TODO(), bson.D{bson.E{Key: "isMaster", Value: 1}})
|
||||
r, err := result.DecodeBytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bson.Unmarshal(r, &status)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if status.IsMaster {
|
||||
return nil
|
||||
}
|
||||
case <-timeoutTimer.C:
|
||||
return fmt.Errorf("replica init timeout")
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Package multistmt provides methods for parsing multi-statement database migrations
|
||||
package multistmt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
)
|
||||
|
||||
// StartBufSize is the default starting size of the buffer used to scan and parse multi-statement migrations
|
||||
var StartBufSize = 4096
|
||||
|
||||
// Handler handles a single migration parsed from a multi-statement migration.
|
||||
// It's given the single migration to handle and returns whether or not further statements
|
||||
// from the multi-statement migration should be parsed and handled.
|
||||
type Handler func(migration []byte) bool
|
||||
|
||||
func splitWithDelimiter(delimiter []byte) func(d []byte, atEOF bool) (int, []byte, error) {
|
||||
return func(d []byte, atEOF bool) (int, []byte, error) {
|
||||
// SplitFunc inspired by bufio.ScanLines() implementation
|
||||
if atEOF {
|
||||
if len(d) == 0 {
|
||||
return 0, nil, nil
|
||||
}
|
||||
return len(d), d, nil
|
||||
}
|
||||
if i := bytes.Index(d, delimiter); i >= 0 {
|
||||
return i + len(delimiter), d[:i+len(delimiter)], nil
|
||||
}
|
||||
return 0, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Parse parses the given multi-statement migration
|
||||
func Parse(reader io.Reader, delimiter []byte, maxMigrationSize int, h Handler) error {
|
||||
scanner := bufio.NewScanner(reader)
|
||||
scanner.Buffer(make([]byte, 0, StartBufSize), maxMigrationSize)
|
||||
scanner.Split(splitWithDelimiter(delimiter))
|
||||
for scanner.Scan() {
|
||||
cont := h(scanner.Bytes())
|
||||
if !cont {
|
||||
break
|
||||
}
|
||||
}
|
||||
return scanner.Err()
|
||||
}
|
||||
+57
@@ -0,0 +1,57 @@
|
||||
package multistmt_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
)
|
||||
|
||||
const maxMigrationSize = 1024
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
multiStmt string
|
||||
delimiter string
|
||||
expected []string
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "single statement, no delimiter", multiStmt: "single statement, no delimiter", delimiter: ";",
|
||||
expected: []string{"single statement, no delimiter"}, expectedErr: nil},
|
||||
{name: "single statement, one delimiter", multiStmt: "single statement, one delimiter;", delimiter: ";",
|
||||
expected: []string{"single statement, one delimiter;"}, expectedErr: nil},
|
||||
{name: "two statements, no trailing delimiter", multiStmt: "statement one; statement two", delimiter: ";",
|
||||
expected: []string{"statement one;", " statement two"}, expectedErr: nil},
|
||||
{name: "two statements, with trailing delimiter", multiStmt: "statement one; statement two;", delimiter: ";",
|
||||
expected: []string{"statement one;", " statement two;"}, expectedErr: nil},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
stmts := make([]string, 0, len(tc.expected))
|
||||
err := multistmt.Parse(strings.NewReader(tc.multiStmt), []byte(tc.delimiter), maxMigrationSize, func(b []byte) bool {
|
||||
stmts = append(stmts, string(b))
|
||||
return true
|
||||
})
|
||||
assert.Equal(t, tc.expectedErr, err)
|
||||
assert.Equal(t, tc.expected, stmts)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDiscontinue(t *testing.T) {
|
||||
multiStmt := "statement one; statement two"
|
||||
delimiter := ";"
|
||||
expected := []string{"statement one;"}
|
||||
|
||||
stmts := make([]string, 0, len(expected))
|
||||
err := multistmt.Parse(strings.NewReader(multiStmt), []byte(delimiter), maxMigrationSize, func(b []byte) bool {
|
||||
stmts = append(stmts, string(b))
|
||||
return false
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, expected, stmts)
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
# MySQL
|
||||
|
||||
`mysql://user:password@tcp(host:port)/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-no-lock` | `NoLock` | Set to `true` to skip `GET_LOCK`/`RELEASE_LOCK` statements. Useful for [multi-master MySQL flavors](https://www.percona.com/doc/percona-xtradb-cluster/LATEST/features/pxc-strict-mode.html#explicit-table-locking). Only run migrations from one host when this is enabled. |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds, functionally similar to [Server-side SELECT statement timeouts](https://dev.mysql.com/blog-archive/server-side-select-statement-timeouts/) but enforced by the client. Available for all versions of MySQL, not just >=5.7. |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. |
|
||||
| `port` | | The port to bind to. |
|
||||
| `tls` | | TLS / SSL encrypted connection parameter; see [go-sql-driver](https://github.com/go-sql-driver/mysql#tls). Use any name (e.g. `migrate`) if you want to use a custom TLS config (`x-tls-` queries). |
|
||||
| `x-tls-ca` | | The location of the CA (certificate authority) file. |
|
||||
| `x-tls-cert` | | The location of the client certificate file. Must be used with `x-tls-key`. |
|
||||
| `x-tls-key` | | The location of the private key file. Must be used with `x-tls-cert`. |
|
||||
| `x-tls-insecure-skip-verify` | | Whether or not to use SSL (true\|false) |
|
||||
|
||||
## Use with existing client
|
||||
|
||||
If you use the MySQL driver with existing database client, you must create the client with parameter `multiStatements=true`:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/mysql"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
func main() {
|
||||
db, _ := sql.Open("mysql", "user:password@tcp(host:port)/dbname?multiStatements=true")
|
||||
driver, _ := mysql.WithInstance(db, &mysql.Config{})
|
||||
m, _ := migrate.NewWithDatabaseInstance(
|
||||
"file:///migrations",
|
||||
"mysql",
|
||||
driver,
|
||||
)
|
||||
|
||||
m.Steps(2)
|
||||
}
|
||||
```
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://dev.mysql.com/doc/refman/5.7/en/commit.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS test;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE TABLE IF NOT EXISTS test (
|
||||
firstname VARCHAR(16)
|
||||
);
|
||||
@@ -0,0 +1,514 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
)
|
||||
|
||||
var _ database.Driver = (*Mysql)(nil) // explicit compile time type check
|
||||
|
||||
func init() {
|
||||
database.Register("mysql", &Mysql{})
|
||||
}
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
|
||||
var (
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrAppendPEM = fmt.Errorf("failed to append PEM")
|
||||
ErrTLSCertKeyConfig = fmt.Errorf("To use TLS client authentication, both x-tls-cert and x-tls-key must not be empty")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
NoLock bool
|
||||
StatementTimeout time.Duration
|
||||
}
|
||||
|
||||
type Mysql struct {
|
||||
// mysql RELEASE_LOCK must be called from the same conn, so
|
||||
// just do everything over a single conn anyway.
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
config *Config
|
||||
}
|
||||
|
||||
// connection instance must have `multiStatements` set to true
|
||||
func WithConnection(ctx context.Context, conn *sql.Conn, config *Config) (*Mysql, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := conn.PingContext(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx := &Mysql{
|
||||
conn: conn,
|
||||
db: nil,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT DATABASE()`
|
||||
var databaseName sql.NullString
|
||||
if err := conn.QueryRowContext(ctx, query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName.String) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName.String
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if err := mx.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
// instance must have `multiStatements` set to true
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx, err := WithConnection(ctx, conn, config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx.db = instance
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
// extractCustomQueryParams extracts the custom query params (ones that start with "x-") from
|
||||
// mysql.Config.Params (connection parameters) as to not interfere with connecting to MySQL
|
||||
func extractCustomQueryParams(c *mysql.Config) (map[string]string, error) {
|
||||
if c == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
customQueryParams := map[string]string{}
|
||||
|
||||
for k, v := range c.Params {
|
||||
if strings.HasPrefix(k, "x-") {
|
||||
customQueryParams[k] = v
|
||||
delete(c.Params, k)
|
||||
}
|
||||
}
|
||||
return customQueryParams, nil
|
||||
}
|
||||
|
||||
func urlToMySQLConfig(url string) (*mysql.Config, error) {
|
||||
// Need to parse out custom TLS parameters and call
|
||||
// mysql.RegisterTLSConfig() before mysql.ParseDSN() is called
|
||||
// which consumes the registered tls.Config
|
||||
// Fixes: https://github.com/golang-migrate/migrate/issues/411
|
||||
//
|
||||
// Can't use url.Parse() since it fails to parse MySQL DSNs
|
||||
// mysql.ParseDSN() also searches for "?" to find query parameters:
|
||||
// https://github.com/go-sql-driver/mysql/blob/46351a8/dsn.go#L344
|
||||
if idx := strings.LastIndex(url, "?"); idx > 0 {
|
||||
rawParams := url[idx+1:]
|
||||
parsedParams, err := nurl.ParseQuery(rawParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctls := parsedParams.Get("tls")
|
||||
if len(ctls) > 0 {
|
||||
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
|
||||
rootCertPool := x509.NewCertPool()
|
||||
pem, err := os.ReadFile(parsedParams.Get("x-tls-ca"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||
return nil, ErrAppendPEM
|
||||
}
|
||||
|
||||
clientCert := make([]tls.Certificate, 0, 1)
|
||||
if ccert, ckey := parsedParams.Get("x-tls-cert"), parsedParams.Get("x-tls-key"); ccert != "" || ckey != "" {
|
||||
if ccert == "" || ckey == "" {
|
||||
return nil, ErrTLSCertKeyConfig
|
||||
}
|
||||
certs, err := tls.LoadX509KeyPair(ccert, ckey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientCert = append(clientCert, certs)
|
||||
}
|
||||
|
||||
insecureSkipVerify := false
|
||||
insecureSkipVerifyStr := parsedParams.Get("x-tls-insecure-skip-verify")
|
||||
if len(insecureSkipVerifyStr) > 0 {
|
||||
x, err := strconv.ParseBool(insecureSkipVerifyStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
insecureSkipVerify = x
|
||||
}
|
||||
|
||||
err = mysql.RegisterTLSConfig(ctls, &tls.Config{
|
||||
RootCAs: rootCertPool,
|
||||
Certificates: clientCert,
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config, err := mysql.ParseDSN(strings.TrimPrefix(url, "mysql://"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.MultiStatements = true
|
||||
|
||||
// Keep backwards compatibility from when we used net/url.Parse() to parse the DSN.
|
||||
// net/url.Parse() would automatically unescape it for us.
|
||||
// See: https://play.golang.org/p/q9j1io-YICQ
|
||||
user, err := nurl.QueryUnescape(config.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.User = user
|
||||
|
||||
password, err := nurl.QueryUnescape(config.Passwd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.Passwd = password
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Open(url string) (database.Driver, error) {
|
||||
config, err := urlToMySQLConfig(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
customParams, err := extractCustomQueryParams(config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
noLockParam, noLock := customParams["x-no-lock"], false
|
||||
if noLockParam != "" {
|
||||
noLock, err = strconv.ParseBool(noLockParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse x-no-lock as bool: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
statementTimeoutParam := customParams["x-statement-timeout"]
|
||||
statementTimeout := 0
|
||||
if statementTimeoutParam != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutParam)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse x-statement-timeout as float: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("mysql", config.FormatDSN())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mx, err := WithInstance(db, &Config{
|
||||
DatabaseName: config.DBName,
|
||||
MigrationsTable: customParams["x-migrations-table"],
|
||||
NoLock: noLock,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return mx, nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Close() error {
|
||||
connErr := m.conn.Close()
|
||||
var dbErr error
|
||||
if m.db != nil {
|
||||
dbErr = m.db.Close()
|
||||
}
|
||||
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Lock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, false, true, database.ErrLocked, func() error {
|
||||
if m.config.NoLock {
|
||||
return nil
|
||||
}
|
||||
aid, err := database.GenerateAdvisoryLockId(
|
||||
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "SELECT GET_LOCK(?, 10)"
|
||||
var success bool
|
||||
if err := m.conn.QueryRowContext(context.Background(), query, aid).Scan(&success); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
|
||||
if !success {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mysql) Unlock() error {
|
||||
return database.CasRestoreOnErr(&m.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
if m.config.NoLock {
|
||||
return nil
|
||||
}
|
||||
|
||||
aid, err := database.GenerateAdvisoryLockId(
|
||||
fmt.Sprintf("%s:%s", m.config.DatabaseName, m.config.MigrationsTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT RELEASE_LOCK(?)`
|
||||
if _, err := m.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// NOTE: RELEASE_LOCK could return NULL or (or 0 if the code is changed),
|
||||
// in which case isLocked should be true until the timeout expires -- synchronizing
|
||||
// these states is likely not worth trying to do; reconsider the necessity of isLocked.
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Mysql) Run(migration io.Reader) error {
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if m.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, m.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
query := string(migr[:])
|
||||
if _, err := m.conn.ExecContext(ctx, query); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) SetVersion(version int, dirty bool) error {
|
||||
tx, err := m.conn.BeginTx(context.Background(), &sql.TxOptions{Isolation: sql.LevelSerializable})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := "DELETE FROM `" + m.config.MigrationsTable + "`"
|
||||
if _, err := tx.ExecContext(context.Background(), query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query := "INSERT INTO `" + m.config.MigrationsTable + "` (version, dirty) VALUES (?, ?)"
|
||||
if _, err := tx.ExecContext(context.Background(), query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Mysql) Version() (version int, dirty bool, err error) {
|
||||
query := "SELECT version, dirty FROM `" + m.config.MigrationsTable + "` LIMIT 1"
|
||||
err = m.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*mysql.MySQLError); ok {
|
||||
if e.Number == 0 {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Mysql) Drop() (err error) {
|
||||
// select all tables
|
||||
query := `SHOW TABLES LIKE '%'`
|
||||
tables, err := m.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// disable checking foreign key constraints until finished
|
||||
query = `SET foreign_key_checks = 0`
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
// enable foreign key checks
|
||||
_, _ = m.conn.ExecContext(context.Background(), `SET foreign_key_checks = 1`)
|
||||
}()
|
||||
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = "DROP TABLE IF EXISTS `" + t + "`"
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Mysql type.
|
||||
func (m *Mysql) ensureVersionTable() (err error) {
|
||||
if err = m.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := m.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// check if migration table exists
|
||||
var result string
|
||||
query := `SHOW TABLES LIKE '` + m.config.MigrationsTable + `'`
|
||||
if err := m.conn.QueryRowContext(context.Background(), query).Scan(&result); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
// if not, create the empty migration table
|
||||
query = "CREATE TABLE `" + m.config.MigrationsTable + "` (version bigint not null primary key, dirty boolean not null)"
|
||||
if _, err := m.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Returns the bool value of the input.
|
||||
// The 2nd return value indicates if the input was a valid bool value
|
||||
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
|
||||
func readBool(input string) (value bool, valid bool) {
|
||||
switch input {
|
||||
case "1", "true", "TRUE", "True":
|
||||
return true, true
|
||||
case "0", "false", "FALSE", "False":
|
||||
return false, true
|
||||
}
|
||||
|
||||
// Not a valid bool value
|
||||
return
|
||||
}
|
||||
@@ -0,0 +1,420 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const defaultPort = 3306
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
}
|
||||
optsAnsiQuotes = dktest.Options{
|
||||
Env: map[string]string{"MYSQL_ROOT_PASSWORD": "root", "MYSQL_DATABASE": "public"},
|
||||
PortRequired: true, ReadyFunc: isReady,
|
||||
Cmd: []string{"--sql-mode=ANSI_QUOTES"},
|
||||
}
|
||||
// Supported versions: https://www.mysql.com/support/supportedplatforms/database.html
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "mysql:5.5", Options: opts},
|
||||
{ImageName: "mysql:5.6", Options: opts},
|
||||
{ImageName: "mysql:5.7", Options: opts},
|
||||
{ImageName: "mysql:8", Options: opts},
|
||||
}
|
||||
specsAnsiQuotes = []dktesting.ContainerSpec{
|
||||
{ImageName: "mysql:5.5", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:5.6", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:5.7", Options: optsAnsiQuotes},
|
||||
{ImageName: "mysql:8", Options: optsAnsiQuotes},
|
||||
}
|
||||
)
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("root:root@tcp(%v:%v)/public", ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, mysql.ErrInvalidConn:
|
||||
return false
|
||||
default:
|
||||
fmt.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateAnsiQuotes(t *testing.T) {
|
||||
// mysql.SetLogger(mysql.Logger(log.New(io.Discard, "", log.Ltime)))
|
||||
|
||||
dktesting.ParallelTest(t, specsAnsiQuotes, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "public", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
|
||||
// check ensureVersionTable
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check again
|
||||
if err := d.(*Mysql).ensureVersionTable(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLockWorks(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ms := d.(*Mysql)
|
||||
|
||||
err = ms.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ms.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure the 2nd lock works (RELEASE_LOCK is very finicky)
|
||||
err = ms.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err = ms.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoLockParamValidation(t *testing.T) {
|
||||
ip := "127.0.0.1"
|
||||
port := 3306
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
_, err := p.Open(addr + "?x-no-lock=not-a-bool")
|
||||
if !errors.Is(err, strconv.ErrSyntax) {
|
||||
t.Fatal("Expected syntax error when passing a non-bool as x-no-lock parameter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoLockWorks(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(defaultPort)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := fmt.Sprintf("mysql://root:root@tcp(%v:%v)/public", ip, port)
|
||||
p := &Mysql{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
lock := d.(*Mysql)
|
||||
|
||||
p = &Mysql{}
|
||||
d, err = p.Open(addr + "?x-no-lock=true")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
noLock := d.(*Mysql)
|
||||
|
||||
// Should be possible to take real lock and no-lock at the same time
|
||||
if err = lock.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = noLock.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = lock.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = noLock.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractCustomQueryParams(t *testing.T) {
|
||||
testcases := []struct {
|
||||
name string
|
||||
config *mysql.Config
|
||||
expectedParams map[string]string
|
||||
expectedCustomParams map[string]string
|
||||
expectedErr error
|
||||
}{
|
||||
{name: "nil config", expectedErr: ErrNilConfig},
|
||||
{
|
||||
name: "no params",
|
||||
config: mysql.NewConfig(),
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "no custom params",
|
||||
config: &mysql.Config{Params: map[string]string{"hello": "world"}},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "one param, one custom param",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{"hello": "world", "x-foo": "bar"},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar"},
|
||||
},
|
||||
{
|
||||
name: "multiple params, multiple custom params",
|
||||
config: &mysql.Config{
|
||||
Params: map[string]string{
|
||||
"hello": "world",
|
||||
"x-foo": "bar",
|
||||
"dead": "beef",
|
||||
"x-cat": "hat",
|
||||
},
|
||||
},
|
||||
expectedParams: map[string]string{"hello": "world", "dead": "beef"},
|
||||
expectedCustomParams: map[string]string{"x-foo": "bar", "x-cat": "hat"},
|
||||
},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
customParams, err := extractCustomQueryParams(tc.config)
|
||||
if tc.config != nil {
|
||||
assert.Equal(t, tc.expectedParams, tc.config.Params,
|
||||
"Expected config params have custom params properly removed")
|
||||
}
|
||||
assert.Equal(t, tc.expectedErr, err, "Expected errors to match")
|
||||
assert.Equal(t, tc.expectedCustomParams, customParams,
|
||||
"Expected custom params to be properly extracted")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createTmpCert(t *testing.T) string {
|
||||
tmpCertFile, err := os.CreateTemp("", "migrate_test_cert")
|
||||
if err != nil {
|
||||
t.Fatal("Failed to create temp cert file:", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
if err := os.Remove(tmpCertFile.Name()); err != nil {
|
||||
t.Log("Failed to cleanup temp cert file:", err)
|
||||
}
|
||||
})
|
||||
|
||||
r := rand.New(rand.NewSource(0))
|
||||
pub, priv, err := ed25519.GenerateKey(r)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to generate ed25519 key for temp cert file:", err)
|
||||
}
|
||||
tmpl := x509.Certificate{
|
||||
SerialNumber: big.NewInt(0),
|
||||
}
|
||||
derBytes, err := x509.CreateCertificate(r, &tmpl, &tmpl, pub, priv)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to generate temp cert file:", err)
|
||||
}
|
||||
if err := pem.Encode(tmpCertFile, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
|
||||
t.Fatal("Failed to encode ")
|
||||
}
|
||||
if err := tmpCertFile.Close(); err != nil {
|
||||
t.Fatal("Failed to close temp cert file:", err)
|
||||
}
|
||||
return tmpCertFile.Name()
|
||||
}
|
||||
|
||||
func TestURLToMySQLConfig(t *testing.T) {
|
||||
tmpCertFilename := createTmpCert(t)
|
||||
tmpCertFilenameEscaped := url.PathEscape(tmpCertFilename)
|
||||
|
||||
testcases := []struct {
|
||||
name string
|
||||
urlStr string
|
||||
expectedDSN string // empty string signifies that an error is expected
|
||||
}{
|
||||
{name: "no user/password", urlStr: "mysql://tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user", urlStr: "mysql://username@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user - with encoded :",
|
||||
urlStr: "mysql://username%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "only user - with encoded @",
|
||||
urlStr: "mysql://username%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password", urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
// Not supported yet: https://github.com/go-sql-driver/mysql/issues/591
|
||||
// {name: "user/password - user with encoded :",
|
||||
// urlStr: "mysql://username%3A:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
// expectedDSN: "username::password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - user with encoded @",
|
||||
urlStr: "mysql://username%40:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username@:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - password with encoded :",
|
||||
urlStr: "mysql://username:password%3A@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password:@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "user/password - password with encoded @",
|
||||
urlStr: "mysql://username:password%40@tcp(127.0.0.1:3306)/myDB?multiStatements=true",
|
||||
expectedDSN: "username:password@@tcp(127.0.0.1:3306)/myDB?multiStatements=true"},
|
||||
{name: "custom tls",
|
||||
urlStr: "mysql://username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped,
|
||||
expectedDSN: "username:password@tcp(127.0.0.1:3306)/myDB?multiStatements=true&tls=custom&x-tls-ca=" + tmpCertFilenameEscaped},
|
||||
}
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
config, err := urlToMySQLConfig(tc.urlStr)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse url string:", tc.urlStr, "error:", err)
|
||||
}
|
||||
dsn := config.FormatDSN()
|
||||
if dsn != tc.expectedDSN {
|
||||
t.Error("Got unexpected DSN:", dsn, "!=", tc.expectedDSN)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
# neo4j
|
||||
The Neo4j driver (bolt) does not natively support executing multiple statements in a single query. To allow for multiple statements in a single migration, you can use the `x-multi-statement` param.
|
||||
This mode splits the migration text into separately-executed statements by a semi-colon `;`. Thus `x-multi-statement` cannot be used when a statement in the migration contains a string with a semi-colon.
|
||||
The queries **should** run in a single transaction, so partial migrations should not be a concern, but this is untested.
|
||||
|
||||
|
||||
`neo4j://user:password@host:port/`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-multi-statement` | `MultiStatement` | Enable multiple statements to be ran in a single migration (See note above) |
|
||||
| `user` | Contained within `AuthConfig` | The user to sign in as |
|
||||
| `password` | Contained within `AuthConfig` | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 7687) |
|
||||
| | `MigrationsLabel` | Name of the migrations node label |
|
||||
|
||||
## Supported versions
|
||||
|
||||
Only Neo4j v3.5+ is [supported](https://github.com/neo4j/neo4j-go-driver/issues/64#issuecomment-625133600)
|
||||
@@ -0,0 +1,97 @@
|
||||
## Create migrations
|
||||
Let's create nodes called `Users`:
|
||||
```
|
||||
migrate create -ext cypher -dir db/migrations -seq create_user_nodes
|
||||
```
|
||||
If there were no errors, we should have two files available under `db/migrations` folder:
|
||||
- 000001_create_user_nodes.down.cypher
|
||||
- 000001_create_user_nodes.up.cypher
|
||||
|
||||
Note the `cypher` extension that we provided.
|
||||
|
||||
In the `.up.cypher` file let's create the table:
|
||||
```
|
||||
CREATE (u1:User {name: "Peter"})
|
||||
CREATE (u2:User {name: "Paul"})
|
||||
CREATE (u3:User {name: "Mary"})
|
||||
```
|
||||
And in the `.down.sql` let's delete it:
|
||||
```
|
||||
MATCH (u:User) WHERE u.name IN ["Peter", "Paul", "Mary"] DELETE u
|
||||
```
|
||||
Ideally your migrations should be idempotent. You can read more about idempotency in [getting started](GETTING_STARTED.md#create-migrations)
|
||||
|
||||
## Run migrations
|
||||
```
|
||||
migrate -database ${NEO4J_URL} -path db/migrations up
|
||||
```
|
||||
Let's check if the table was created properly by running `bin/cypher-shell -u neo4j -p password`, then `neo4j> MATCH (u:User)`
|
||||
The output you are supposed to see:
|
||||
```
|
||||
+-----------------------------------------------------------------+
|
||||
| u |
|
||||
+-----------------------------------------------------------------+
|
||||
| (:User {name: "Peter") |
|
||||
| (:User {name: "Paul") |
|
||||
| (:User {name: "Mary") |
|
||||
+-----------------------------------------------------------------+
|
||||
```
|
||||
Great! Now let's check if running reverse migration also works:
|
||||
```
|
||||
migrate -database ${NEO4J_URL} -path db/migrations down
|
||||
```
|
||||
Make sure to check if your database changed as expected in this case as well.
|
||||
|
||||
## Database transactions
|
||||
|
||||
To show database transactions usage, let's create another set of migrations by running:
|
||||
```
|
||||
migrate create -ext cypher -dir db/migrations -seq add_mood_to_users
|
||||
```
|
||||
Again, it should create for us two migrations files:
|
||||
- 000002_add_mood_to_users.down.cypher
|
||||
- 000002_add_mood_to_users.up.cypher
|
||||
|
||||
In Neo4j, when we want our queries to be done in a transaction, we need to wrap it with `:BEGIN` and `:COMMIT` commands.
|
||||
Migration up:
|
||||
```
|
||||
:BEGIN
|
||||
|
||||
MATCH (u:User)
|
||||
SET u.mood = "Cheery"
|
||||
|
||||
:COMMIT
|
||||
```
|
||||
Migration down:
|
||||
```
|
||||
:BEGIN
|
||||
|
||||
MATCH (u:User)
|
||||
SET u.mood = null
|
||||
|
||||
:COMMIT
|
||||
```
|
||||
|
||||
## Optional: Run migrations within your Go app
|
||||
Here is a very simple app running migrations for the above configuration:
|
||||
```
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/neo4j"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
func main() {
|
||||
m, err := migrate.New(
|
||||
"file://db/migrations",
|
||||
"neo4j://neo4j:password@localhost:7687/")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := m.Up(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
|
||||
+1
@@ -0,0 +1 @@
|
||||
CREATE CONSTRAINT ON (m:Movie) ASSERT m.Name IS UNIQUE
|
||||
+2
@@ -0,0 +1,2 @@
|
||||
MATCH (m:Movie)
|
||||
DELETE m
|
||||
+2
@@ -0,0 +1,2 @@
|
||||
CREATE (:Movie {name: "Footloose"})
|
||||
CREATE (:Movie {name: "Ghost"})
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE (:Movie {name: "Hollow Man"});
|
||||
CREATE (:Movie {name: "Mystic River"});
|
||||
;;;
|
||||
@@ -0,0 +1,303 @@
|
||||
package neo4j
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
neturl "net/url"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/neo4j/neo4j-go-driver/neo4j"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Neo4j{}
|
||||
database.Register("neo4j", &db)
|
||||
}
|
||||
|
||||
const DefaultMigrationsLabel = "SchemaMigration"
|
||||
|
||||
var (
|
||||
StatementSeparator = []byte(";")
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsLabel string
|
||||
MultiStatement bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Neo4j struct {
|
||||
driver neo4j.Driver
|
||||
lock uint32
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(driver neo4j.Driver, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
nDriver := &Neo4j{
|
||||
driver: driver,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := nDriver.ensureVersionConstraint(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return nDriver, nil
|
||||
}
|
||||
|
||||
func (n *Neo4j) Open(url string) (database.Driver, error) {
|
||||
uri, err := neturl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
password, _ := uri.User.Password()
|
||||
authToken := neo4j.BasicAuth(uri.User.Username(), password, "")
|
||||
uri.User = nil
|
||||
uri.Scheme = "bolt"
|
||||
msQuery := uri.Query().Get("x-multi-statement")
|
||||
|
||||
// Whether to turn on/off TLS encryption.
|
||||
tlsEncrypted := uri.Query().Get("x-tls-encrypted")
|
||||
multi := false
|
||||
encrypted := false
|
||||
if msQuery != "" {
|
||||
multi, err = strconv.ParseBool(uri.Query().Get("x-multi-statement"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if tlsEncrypted != "" {
|
||||
encrypted, err = strconv.ParseBool(tlsEncrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := uri.Query().Get("x-multi-statement-max-size"); s != "" {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
uri.RawQuery = ""
|
||||
|
||||
driver, err := neo4j.NewDriver(uri.String(), authToken, func(config *neo4j.Config) {
|
||||
config.Encrypted = encrypted
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return WithInstance(driver, &Config{
|
||||
MigrationsLabel: DefaultMigrationsLabel,
|
||||
MultiStatement: multi,
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
})
|
||||
}
|
||||
|
||||
func (n *Neo4j) Close() error {
|
||||
return n.driver.Close()
|
||||
}
|
||||
|
||||
// local locking in order to pass tests, Neo doesn't support database locking
|
||||
func (n *Neo4j) Lock() error {
|
||||
if !atomic.CompareAndSwapUint32(&n.lock, 0, 1) {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Neo4j) Unlock() error {
|
||||
if !atomic.CompareAndSwapUint32(&n.lock, 1, 0) {
|
||||
return database.ErrNotLocked
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Neo4j) Run(migration io.Reader) (err error) {
|
||||
session, err := n.driver.Session(neo4j.AccessModeWrite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := session.Close(); cerr != nil {
|
||||
err = multierror.Append(err, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
if n.config.MultiStatement {
|
||||
_, err = session.WriteTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
|
||||
var stmtRunErr error
|
||||
if err := multistmt.Parse(migration, StatementSeparator, n.config.MultiStatementMaxSize, func(stmt []byte) bool {
|
||||
trimStmt := bytes.TrimSpace(stmt)
|
||||
if len(trimStmt) == 0 {
|
||||
return true
|
||||
}
|
||||
trimStmt = bytes.TrimSuffix(trimStmt, StatementSeparator)
|
||||
if len(trimStmt) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
result, err := transaction.Run(string(trimStmt), nil)
|
||||
if _, err := neo4j.Collect(result, err); err != nil {
|
||||
stmtRunErr = err
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, stmtRunErr
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = neo4j.Collect(session.Run(string(body[:]), nil))
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *Neo4j) SetVersion(version int, dirty bool) (err error) {
|
||||
session, err := n.driver.Session(neo4j.AccessModeWrite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := session.Close(); cerr != nil {
|
||||
err = multierror.Append(err, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
query := fmt.Sprintf("MERGE (sm:%s {version: $version}) SET sm.dirty = $dirty, sm.ts = datetime()",
|
||||
n.config.MigrationsLabel)
|
||||
_, err = neo4j.Collect(session.Run(query, map[string]interface{}{"version": version, "dirty": dirty}))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type MigrationRecord struct {
|
||||
Version int
|
||||
Dirty bool
|
||||
}
|
||||
|
||||
func (n *Neo4j) Version() (version int, dirty bool, err error) {
|
||||
session, err := n.driver.Session(neo4j.AccessModeRead)
|
||||
if err != nil {
|
||||
return database.NilVersion, false, err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := session.Close(); cerr != nil {
|
||||
err = multierror.Append(err, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
query := fmt.Sprintf(`MATCH (sm:%s) RETURN sm.version AS version, sm.dirty AS dirty
|
||||
ORDER BY COALESCE(sm.ts, datetime({year: 0})) DESC, sm.version DESC LIMIT 1`,
|
||||
n.config.MigrationsLabel)
|
||||
result, err := session.ReadTransaction(func(transaction neo4j.Transaction) (interface{}, error) {
|
||||
result, err := transaction.Run(query, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result.Next() {
|
||||
record := result.Record()
|
||||
mr := MigrationRecord{}
|
||||
versionResult, ok := record.Get("version")
|
||||
if !ok {
|
||||
mr.Version = database.NilVersion
|
||||
} else {
|
||||
mr.Version = int(versionResult.(int64))
|
||||
}
|
||||
|
||||
dirtyResult, ok := record.Get("dirty")
|
||||
if ok {
|
||||
mr.Dirty = dirtyResult.(bool)
|
||||
}
|
||||
|
||||
return mr, nil
|
||||
}
|
||||
return nil, result.Err()
|
||||
})
|
||||
if err != nil {
|
||||
return database.NilVersion, false, err
|
||||
}
|
||||
if result == nil {
|
||||
return database.NilVersion, false, err
|
||||
}
|
||||
mr := result.(MigrationRecord)
|
||||
return mr.Version, mr.Dirty, err
|
||||
}
|
||||
|
||||
func (n *Neo4j) Drop() (err error) {
|
||||
session, err := n.driver.Session(neo4j.AccessModeWrite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := session.Close(); cerr != nil {
|
||||
err = multierror.Append(err, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := neo4j.Collect(session.Run("MATCH (n) DETACH DELETE n", nil)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *Neo4j) ensureVersionConstraint() (err error) {
|
||||
session, err := n.driver.Session(neo4j.AccessModeWrite)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := session.Close(); cerr != nil {
|
||||
err = multierror.Append(err, cerr)
|
||||
}
|
||||
}()
|
||||
|
||||
/**
|
||||
Get constraint and check to avoid error duplicate
|
||||
using db.labels() to support Neo4j 3 and 4.
|
||||
Neo4J 3 doesn't support db.constraints() YIELD name
|
||||
*/
|
||||
res, err := neo4j.Collect(session.Run(fmt.Sprintf("CALL db.labels() YIELD label WHERE label=\"%s\" RETURN label", n.config.MigrationsLabel), nil))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(res) == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("CREATE CONSTRAINT ON (a:%s) ASSERT a.version IS UNIQUE", n.config.MigrationsLabel)
|
||||
if _, err := neo4j.Collect(session.Run(query, nil)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package neo4j
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/neo4j/neo4j-go-driver/neo4j"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{PortRequired: true, ReadyFunc: isReady,
|
||||
Env: map[string]string{"NEO4J_AUTH": "neo4j/migratetest", "NEO4J_ACCEPT_LICENSE_AGREEMENT": "yes"}}
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "neo4j:4.0", Options: opts},
|
||||
{ImageName: "neo4j:4.0-enterprise", Options: opts},
|
||||
{ImageName: "neo4j:3.5", Options: opts},
|
||||
{ImageName: "neo4j:3.5-enterprise", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func neoConnectionString(host, port string) string {
|
||||
return fmt.Sprintf("bolt://neo4j:migratetest@%s:%s", host, port)
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.Port(7687)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
driver, err := neo4j.NewDriver(
|
||||
neoConnectionString(ip, port),
|
||||
neo4j.BasicAuth("neo4j", "migratetest", ""),
|
||||
func(config *neo4j.Config) {
|
||||
config.Encrypted = false
|
||||
})
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := driver.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
session, err := driver.Session(neo4j.AccessModeRead)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
result, err := session.Run("RETURN 1", nil)
|
||||
if err != nil {
|
||||
return false
|
||||
} else if result.Err() != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(7687)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n := &Neo4j{}
|
||||
d, err := n.Open(neoConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("MATCH (a) RETURN a"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(7687)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n := &Neo4j{}
|
||||
neoUrl := neoConnectionString(ip, port) + "/?x-multi-statement=true"
|
||||
d, err := n.Open(neoUrl)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "neo4j", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMalformed(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.Port(7687)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
n := &Neo4j{}
|
||||
d, err := n.Open(neoConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
migration := bytes.NewReader([]byte("CREATE (a {qid: 1) RETURN a"))
|
||||
if err := d.Run(migration); err == nil {
|
||||
t.Fatal("expected failure for malformed migration")
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,242 @@
|
||||
package database_test
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const reservedChars = "!#$%&'()*+,/:;=?@[]"
|
||||
const reservedCharTestNamePrefix = "reserved char "
|
||||
|
||||
const baseUsername = "username"
|
||||
|
||||
const scheme = "database://"
|
||||
|
||||
// TestUserUnencodedReservedURLChars documents the behavior of using unencoded reserved characters in usernames with
|
||||
// net/url Parse()
|
||||
func TestUserUnencodedReservedURLChars(t *testing.T) {
|
||||
urlSuffix := "password@localhost:12345/myDB?someParam=true"
|
||||
urlSuffixAndSep := ":" + urlSuffix
|
||||
|
||||
testcases := []struct {
|
||||
char string
|
||||
parses bool
|
||||
expectedUsername string // empty string means that the username failed to parse
|
||||
encodedURL string
|
||||
}{
|
||||
{char: "!", parses: true, expectedUsername: baseUsername + "!",
|
||||
encodedURL: scheme + baseUsername + "%21" + urlSuffixAndSep},
|
||||
{char: "#", parses: true, expectedUsername: "",
|
||||
encodedURL: scheme + baseUsername + "#" + urlSuffixAndSep},
|
||||
{char: "$", parses: true, expectedUsername: baseUsername + "$",
|
||||
encodedURL: scheme + baseUsername + "$" + urlSuffixAndSep},
|
||||
{char: "%", parses: false},
|
||||
{char: "&", parses: true, expectedUsername: baseUsername + "&",
|
||||
encodedURL: scheme + baseUsername + "&" + urlSuffixAndSep},
|
||||
{char: "'", parses: true, expectedUsername: "username'",
|
||||
encodedURL: scheme + baseUsername + "%27" + urlSuffixAndSep},
|
||||
{char: "(", parses: true, expectedUsername: "username(",
|
||||
encodedURL: scheme + baseUsername + "%28" + urlSuffixAndSep},
|
||||
{char: ")", parses: true, expectedUsername: "username)",
|
||||
encodedURL: scheme + baseUsername + "%29" + urlSuffixAndSep},
|
||||
{char: "*", parses: true, expectedUsername: "username*",
|
||||
encodedURL: scheme + baseUsername + "%2A" + urlSuffixAndSep},
|
||||
{char: "+", parses: true, expectedUsername: "username+",
|
||||
encodedURL: scheme + baseUsername + "+" + urlSuffixAndSep},
|
||||
{char: ",", parses: true, expectedUsername: "username,",
|
||||
encodedURL: scheme + baseUsername + "," + urlSuffixAndSep},
|
||||
{char: "/", parses: true, expectedUsername: "",
|
||||
encodedURL: scheme + baseUsername + "/" + urlSuffixAndSep},
|
||||
{char: ":", parses: true, expectedUsername: baseUsername,
|
||||
encodedURL: scheme + baseUsername + ":%3A" + urlSuffix},
|
||||
{char: ";", parses: true, expectedUsername: "username;",
|
||||
encodedURL: scheme + baseUsername + ";" + urlSuffixAndSep},
|
||||
{char: "=", parses: true, expectedUsername: "username=",
|
||||
encodedURL: scheme + baseUsername + "=" + urlSuffixAndSep},
|
||||
{char: "?", parses: true, expectedUsername: "",
|
||||
encodedURL: scheme + baseUsername + "?" + urlSuffixAndSep},
|
||||
{char: "@", parses: true, expectedUsername: "username@",
|
||||
encodedURL: scheme + baseUsername + "%40" + urlSuffixAndSep},
|
||||
{char: "[", parses: false},
|
||||
{char: "]", parses: false},
|
||||
}
|
||||
|
||||
testedChars := make([]string, 0, len(reservedChars))
|
||||
for _, tc := range testcases {
|
||||
testedChars = append(testedChars, tc.char)
|
||||
t.Run(reservedCharTestNamePrefix+tc.char, func(t *testing.T) {
|
||||
s := scheme + baseUsername + tc.char + urlSuffixAndSep
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
if !tc.parses {
|
||||
t.Error("Unexpectedly parsed reserved character. url:", s)
|
||||
return
|
||||
}
|
||||
var username string
|
||||
if u.User != nil {
|
||||
username = u.User.Username()
|
||||
}
|
||||
if username != tc.expectedUsername {
|
||||
t.Error("Got unexpected username:", username, "!=", tc.expectedUsername)
|
||||
}
|
||||
if s := u.String(); s != tc.encodedURL {
|
||||
t.Error("Got unexpected encoded URL:", s, "!=", tc.encodedURL)
|
||||
}
|
||||
} else {
|
||||
if tc.parses {
|
||||
t.Error("Failed to parse reserved character. url:", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("All reserved chars tested", func(t *testing.T) {
|
||||
if s := strings.Join(testedChars, ""); s != reservedChars {
|
||||
t.Error("Not all reserved URL characters were tested:", s, "!=", reservedChars)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserEncodedReservedURLChars(t *testing.T) {
|
||||
urlSuffix := "password@localhost:12345/myDB?someParam=true"
|
||||
urlSuffixAndSep := ":" + urlSuffix
|
||||
|
||||
for _, c := range reservedChars {
|
||||
c := string(c)
|
||||
t.Run(reservedCharTestNamePrefix+c, func(t *testing.T) {
|
||||
encodedChar := "%" + hex.EncodeToString([]byte(c))
|
||||
s := scheme + baseUsername + encodedChar + urlSuffixAndSep
|
||||
expectedUsername := baseUsername + c
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse url with encoded reserved character. url:", s)
|
||||
}
|
||||
if u.User == nil {
|
||||
t.Fatal("Failed to parse userinfo with encoded reserve character. url:", s)
|
||||
}
|
||||
if username := u.User.Username(); username != expectedUsername {
|
||||
t.Fatal("Got unexpected username:", username, "!=", expectedUsername)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPasswordUnencodedReservedURLChars documents the behavior of using unencoded reserved characters in passwords
|
||||
// with net/url Parse()
|
||||
func TestPasswordUnencodedReservedURLChars(t *testing.T) {
|
||||
username := baseUsername
|
||||
schemeAndUsernameAndSep := scheme + username + ":"
|
||||
basePassword := "password"
|
||||
urlSuffixAndSep := "@localhost:12345/myDB?someParam=true"
|
||||
|
||||
testcases := []struct {
|
||||
char string
|
||||
parses bool
|
||||
expectedUsername string // empty string means that the username failed to parse
|
||||
expectedPassword string // empty string means that the password failed to parse
|
||||
encodedURL string
|
||||
}{
|
||||
{char: "!", parses: true, expectedUsername: username, expectedPassword: basePassword + "!",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%21" + urlSuffixAndSep},
|
||||
{char: "#", parses: false},
|
||||
{char: "$", parses: true, expectedUsername: username, expectedPassword: basePassword + "$",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "$" + urlSuffixAndSep},
|
||||
{char: "%", parses: false},
|
||||
{char: "&", parses: true, expectedUsername: username, expectedPassword: basePassword + "&",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "&" + urlSuffixAndSep},
|
||||
{char: "'", parses: true, expectedUsername: username, expectedPassword: "password'",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%27" + urlSuffixAndSep},
|
||||
{char: "(", parses: true, expectedUsername: username, expectedPassword: "password(",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%28" + urlSuffixAndSep},
|
||||
{char: ")", parses: true, expectedUsername: username, expectedPassword: "password)",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%29" + urlSuffixAndSep},
|
||||
{char: "*", parses: true, expectedUsername: username, expectedPassword: "password*",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%2A" + urlSuffixAndSep},
|
||||
{char: "+", parses: true, expectedUsername: username, expectedPassword: "password+",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "+" + urlSuffixAndSep},
|
||||
{char: ",", parses: true, expectedUsername: username, expectedPassword: "password,",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "," + urlSuffixAndSep},
|
||||
{char: "/", parses: false},
|
||||
{char: ":", parses: true, expectedUsername: username, expectedPassword: "password:",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%3A" + urlSuffixAndSep},
|
||||
{char: ";", parses: true, expectedUsername: username, expectedPassword: "password;",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + ";" + urlSuffixAndSep},
|
||||
{char: "=", parses: true, expectedUsername: username, expectedPassword: "password=",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "=" + urlSuffixAndSep},
|
||||
{char: "?", parses: false},
|
||||
{char: "@", parses: true, expectedUsername: username, expectedPassword: "password@",
|
||||
encodedURL: schemeAndUsernameAndSep + basePassword + "%40" + urlSuffixAndSep},
|
||||
{char: "[", parses: false},
|
||||
{char: "]", parses: false},
|
||||
}
|
||||
|
||||
testedChars := make([]string, 0, len(reservedChars))
|
||||
for _, tc := range testcases {
|
||||
testedChars = append(testedChars, tc.char)
|
||||
t.Run(reservedCharTestNamePrefix+tc.char, func(t *testing.T) {
|
||||
s := schemeAndUsernameAndSep + basePassword + tc.char + urlSuffixAndSep
|
||||
u, err := url.Parse(s)
|
||||
if err == nil {
|
||||
if !tc.parses {
|
||||
t.Error("Unexpectedly parsed reserved character. url:", s)
|
||||
return
|
||||
}
|
||||
var username, password string
|
||||
if u.User != nil {
|
||||
username = u.User.Username()
|
||||
password, _ = u.User.Password()
|
||||
}
|
||||
if username != tc.expectedUsername {
|
||||
t.Error("Got unexpected username:", username, "!=", tc.expectedUsername)
|
||||
}
|
||||
if password != tc.expectedPassword {
|
||||
t.Error("Got unexpected password:", password, "!=", tc.expectedPassword)
|
||||
}
|
||||
if s := u.String(); s != tc.encodedURL {
|
||||
t.Error("Got unexpected encoded URL:", s, "!=", tc.encodedURL)
|
||||
}
|
||||
} else {
|
||||
if tc.parses {
|
||||
t.Error("Failed to parse reserved character. url:", s)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("All reserved chars tested", func(t *testing.T) {
|
||||
if s := strings.Join(testedChars, ""); s != reservedChars {
|
||||
t.Error("Not all reserved URL characters were tested:", s, "!=", reservedChars)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPasswordEncodedReservedURLChars(t *testing.T) {
|
||||
username := baseUsername
|
||||
schemeAndUsernameAndSep := scheme + username + ":"
|
||||
basePassword := "password"
|
||||
urlSuffixAndSep := "@localhost:12345/myDB?someParam=true"
|
||||
|
||||
for _, c := range reservedChars {
|
||||
c := string(c)
|
||||
t.Run(reservedCharTestNamePrefix+c, func(t *testing.T) {
|
||||
encodedChar := "%" + hex.EncodeToString([]byte(c))
|
||||
s := schemeAndUsernameAndSep + basePassword + encodedChar + urlSuffixAndSep
|
||||
expectedPassword := basePassword + c
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
t.Fatal("Failed to parse url with encoded reserved character. url:", s)
|
||||
}
|
||||
if u.User == nil {
|
||||
t.Fatal("Failed to parse userinfo with encoded reserve character. url:", s)
|
||||
}
|
||||
if n := u.User.Username(); n != username {
|
||||
t.Fatal("Got unexpected username:", n, "!=", username)
|
||||
}
|
||||
if p, _ := u.User.Password(); p != expectedPassword {
|
||||
t.Fatal("Got unexpected password:", p, "!=", expectedPassword)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
# pgx
|
||||
|
||||
This package is for [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). A backend for the newer [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5) is [also available](v5).
|
||||
|
||||
`pgx://user:password@host:port/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
|
||||
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
|
||||
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
|
||||
| `x-lock-strategy` | `LockStrategy` | Strategy used for locking during migration (default: advisory) |
|
||||
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock (default: schema_lock) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
|
||||
## Multi-statement mode
|
||||
|
||||
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
|
||||
behavior is not desirable because some statements can be only run outside of transaction (e.g.
|
||||
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
|
||||
you have to put such statements in a separate migration files.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS users;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE users (
|
||||
user_id integer unique,
|
||||
name varchar(40),
|
||||
email varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS city;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE users ADD COLUMN city varchar(100);
|
||||
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS users_email_index;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX CONCURRENTLY users_email_index ON users (email);
|
||||
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS books;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE books (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
author varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS movies;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE movies (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
director varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
@@ -0,0 +1,623 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgerrcode"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
LockStrategyAdvisory = "advisory"
|
||||
LockStrategyTable = "table"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Postgres{}
|
||||
database.Register("pgx", &db)
|
||||
database.Register("pgx4", &db)
|
||||
}
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMigrationsTable = "schema_migrations"
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
DefaultLockTable = "schema_lock"
|
||||
DefaultLockStrategy = LockStrategyAdvisory
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrNoSchema = fmt.Errorf("no schema")
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
LockTable string
|
||||
LockStrategy string
|
||||
migrationsSchemaName string
|
||||
migrationsTableName string
|
||||
StatementTimeout time.Duration
|
||||
MigrationsTableQuoted bool
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Postgres struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT CURRENT_DATABASE()`
|
||||
var databaseName string
|
||||
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName
|
||||
}
|
||||
|
||||
if config.SchemaName == "" {
|
||||
query := `SELECT CURRENT_SCHEMA()`
|
||||
var schemaName string
|
||||
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(schemaName) == 0 {
|
||||
return nil, ErrNoSchema
|
||||
}
|
||||
|
||||
config.SchemaName = schemaName
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if len(config.LockTable) == 0 {
|
||||
config.LockTable = DefaultLockTable
|
||||
}
|
||||
|
||||
if len(config.LockStrategy) == 0 {
|
||||
config.LockStrategy = DefaultLockStrategy
|
||||
}
|
||||
|
||||
config.migrationsSchemaName = config.SchemaName
|
||||
config.migrationsTableName = config.MigrationsTable
|
||||
if config.MigrationsTableQuoted {
|
||||
re := regexp.MustCompile(`"(.*?)"`)
|
||||
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
|
||||
config.migrationsTableName = result[len(result)-1][1]
|
||||
if len(result) == 2 {
|
||||
config.migrationsSchemaName = result[0][1]
|
||||
} else if len(result) > 2 {
|
||||
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
px := &Postgres{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := px.ensureLockTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Driver is registered as pgx, but connection string must use postgres schema
|
||||
// when making actual connection
|
||||
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
|
||||
purl.Scheme = "postgres"
|
||||
|
||||
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
migrationsTableQuoted := false
|
||||
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
|
||||
migrationsTableQuoted, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
|
||||
}
|
||||
}
|
||||
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
|
||||
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
|
||||
}
|
||||
|
||||
statementTimeoutString := purl.Query().Get("x-statement-timeout")
|
||||
statementTimeout := 0
|
||||
if statementTimeoutString != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if multiStatementMaxSize <= 0 {
|
||||
multiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementEnabled := false
|
||||
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
|
||||
multiStatementEnabled, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
lockStrategy := purl.Query().Get("x-lock-strategy")
|
||||
lockTable := purl.Query().Get("x-lock-table")
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
MigrationsTableQuoted: migrationsTableQuoted,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
MultiStatementEnabled: multiStatementEnabled,
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
LockStrategy: lockStrategy,
|
||||
LockTable: lockTable,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Close() error {
|
||||
connErr := p.conn.Close()
|
||||
dbErr := p.db.Close()
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Lock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
|
||||
switch p.config.LockStrategy {
|
||||
case LockStrategyAdvisory:
|
||||
return p.applyAdvisoryLock()
|
||||
case LockStrategyTable:
|
||||
return p.applyTableLock()
|
||||
default:
|
||||
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Unlock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
switch p.config.LockStrategy {
|
||||
case LockStrategyAdvisory:
|
||||
return p.releaseAdvisoryLock()
|
||||
case LockStrategyTable:
|
||||
return p.releaseTableLock()
|
||||
default:
|
||||
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) applyAdvisoryLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This will wait indefinitely until the lock can be acquired.
|
||||
query := `SELECT pg_advisory_lock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) applyTableLock() error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
defer func() {
|
||||
errRollback := tx.Rollback()
|
||||
if errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
}()
|
||||
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
|
||||
rows, err := tx.Query(query, aid)
|
||||
if err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if errClose := rows.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// If row exists at all, lock is present
|
||||
locked := rows.Next()
|
||||
if locked {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
|
||||
if _, err := tx.Exec(query, aid); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *Postgres) releaseAdvisoryLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT pg_advisory_unlock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) releaseTableLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
|
||||
if _, err := p.db.Exec(query, aid); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Run(migration io.Reader) error {
|
||||
if p.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
if err = p.runStatement(m); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.runStatement(migr)
|
||||
}
|
||||
|
||||
func (p *Postgres) runStatement(statement []byte) error {
|
||||
ctx := context.Background()
|
||||
if p.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
query := string(statement)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
|
||||
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
||||
// replace crlf with lf
|
||||
s = strings.Replace(s, "\r\n", "\n", -1)
|
||||
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
||||
runes := []rune(s)
|
||||
if pos > len(runes) {
|
||||
return 0, 0, false
|
||||
}
|
||||
sel := runes[:pos]
|
||||
line = uint(runesCount(sel, newLine) + 1)
|
||||
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
||||
return line, col, true
|
||||
}
|
||||
|
||||
const newLine = '\n'
|
||||
|
||||
func runesCount(input []rune, target rune) int {
|
||||
var count int
|
||||
for _, r := range input {
|
||||
if r == target {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runesLastIndex(input []rune, target rune) int {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
if input[i] == target {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *Postgres) SetVersion(version int, dirty bool) error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
|
||||
if _, err := tx.Exec(query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
|
||||
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pgconn.PgError); ok {
|
||||
if e.SQLState() == pgerrcode.UndefinedTable {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Postgres) Drop() (err error) {
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
||||
tables, err := p.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// do not drop lock table
|
||||
if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Postgres type.
|
||||
func (p *Postgres) ensureVersionTable() (err error) {
|
||||
if err = p.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := p.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
|
||||
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
|
||||
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
|
||||
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
|
||||
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
|
||||
var count int
|
||||
err = row.Scan(&count)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
|
||||
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) ensureLockTable() error {
|
||||
if p.config.LockStrategy != LockStrategyTable {
|
||||
return nil
|
||||
}
|
||||
|
||||
var count int
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
|
||||
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
|
||||
if _, err := p.db.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
|
||||
func quoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
||||
}
|
||||
@@ -0,0 +1,789 @@
|
||||
package pgx
|
||||
|
||||
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const (
|
||||
pgPassword = "postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
|
||||
PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: https://www.postgresql.org/support/versioning/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "postgres:9.5", Options: opts},
|
||||
{ImageName: "postgres:9.6", Options: opts},
|
||||
{ImageName: "postgres:10", Options: opts},
|
||||
{ImageName: "postgres:11", Options: opts},
|
||||
{ImageName: "postgres:12", Options: opts},
|
||||
{ImageName: "postgres:13", Options: opts},
|
||||
{ImageName: "postgres:14", Options: opts},
|
||||
{ImageName: "postgres:15", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func pgConnectionString(host, port string, options ...string) string {
|
||||
options = append(options, "sslmode=disable")
|
||||
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func mustRun(t *testing.T, d database.Driver, statements []string) {
|
||||
for _, statement := range statements {
|
||||
if err := d.Run(strings.NewReader(statement)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateLockTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatements(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-multi-statement=true")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure created index exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
|
||||
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-custom=foobar")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foobar schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.SetVersion(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
version, _, err := d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("expected NilVersion")
|
||||
}
|
||||
|
||||
// now update version and compare
|
||||
if err := d2.SetVersion(2, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _, err = d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
|
||||
// meanwhile, the public schema still has the other version
|
||||
version, _, err = d.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationTableOption(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, _ := p.Open(addr)
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create migrate schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// bad unquoted x-migrations-table parameter
|
||||
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// too many quoted x-migrations-table parameters
|
||||
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// good quoted x-migrations-table parameter
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure migrate.schema_migrations table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table migrate.schema_migrations to exist")
|
||||
}
|
||||
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
defer func() {
|
||||
if d2 == nil {
|
||||
return
|
||||
}
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var e *database.Error
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
|
||||
// re-connect using that x-migrations-table and x-migrations-table-quoted
|
||||
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckBeforeCreateTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// revoke privileges
|
||||
mustRun(t, d, []string{
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _, err := d3.Version()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d3.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestParallelSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foo and bar schemas
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schemas
|
||||
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dfoo.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dbar.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := dfoo.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dfoo.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgres_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ps := d.(*Postgres)
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance_Concurrent(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The number of concurrent processes running WithInstance
|
||||
const concurrency = 30
|
||||
|
||||
// We can instantiate a single database handle because it is
|
||||
// actually a connection pool, and so, each of the below go
|
||||
// routines will have a high probability of using a separate
|
||||
// connection, which is something we want to exercise.
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetMaxIdleConns(concurrency)
|
||||
db.SetMaxOpenConns(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_, err := WithInstance(db, &Config{})
|
||||
if err != nil {
|
||||
t.Errorf("process %d error: %s", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
}
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos int
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
|
||||
},
|
||||
{
|
||||
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
|
||||
},
|
||||
{
|
||||
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
|
||||
},
|
||||
{
|
||||
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
|
||||
},
|
||||
{
|
||||
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
|
||||
},
|
||||
{
|
||||
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
|
||||
},
|
||||
{
|
||||
17, 2, 8, "SELECT *\nFROM foo", true, // last character
|
||||
},
|
||||
{
|
||||
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
run := func(crlf bool, nonASCII bool) {
|
||||
var name string
|
||||
if crlf {
|
||||
name = "crlf"
|
||||
} else {
|
||||
name = "lf"
|
||||
}
|
||||
if nonASCII {
|
||||
name += "-nonascii"
|
||||
} else {
|
||||
name += "-ascii"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
input := tc.input
|
||||
if crlf {
|
||||
input = strings.Replace(input, "\n", "\r\n", -1)
|
||||
}
|
||||
if nonASCII {
|
||||
input = strings.Replace(input, "FROM", "FRÖM", -1)
|
||||
}
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
|
||||
|
||||
if tc.wantOk {
|
||||
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
|
||||
}
|
||||
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
run(false, false)
|
||||
run(true, false)
|
||||
run(false, true)
|
||||
run(true, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
# pgx
|
||||
|
||||
This package is for [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5). A backend for the older [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). is [also available](..).
|
||||
|
||||
`pgx5://user:password@host:port/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
|
||||
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
|
||||
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
|
||||
## Multi-statement mode
|
||||
|
||||
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
|
||||
behavior is not desirable because some statements can be only run outside of transaction (e.g.
|
||||
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
|
||||
you have to put such statements in a separate migration files.
|
||||
@@ -0,0 +1,486 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Postgres{}
|
||||
database.Register("pgx5", &db)
|
||||
}
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMigrationsTable = "schema_migrations"
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrNoSchema = fmt.Errorf("no schema")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
migrationsSchemaName string
|
||||
migrationsTableName string
|
||||
StatementTimeout time.Duration
|
||||
MigrationsTableQuoted bool
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Postgres struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT CURRENT_DATABASE()`
|
||||
var databaseName string
|
||||
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName
|
||||
}
|
||||
|
||||
if config.SchemaName == "" {
|
||||
query := `SELECT CURRENT_SCHEMA()`
|
||||
var schemaName string
|
||||
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(schemaName) == 0 {
|
||||
return nil, ErrNoSchema
|
||||
}
|
||||
|
||||
config.SchemaName = schemaName
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
config.migrationsSchemaName = config.SchemaName
|
||||
config.migrationsTableName = config.MigrationsTable
|
||||
if config.MigrationsTableQuoted {
|
||||
re := regexp.MustCompile(`"(.*?)"`)
|
||||
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
|
||||
config.migrationsTableName = result[len(result)-1][1]
|
||||
if len(result) == 2 {
|
||||
config.migrationsSchemaName = result[0][1]
|
||||
} else if len(result) > 2 {
|
||||
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
px := &Postgres{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Driver is registered as pgx, but connection string must use postgres schema
|
||||
// when making actual connection
|
||||
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
|
||||
purl.Scheme = "postgres"
|
||||
|
||||
db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
migrationsTableQuoted := false
|
||||
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
|
||||
migrationsTableQuoted, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
|
||||
}
|
||||
}
|
||||
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
|
||||
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
|
||||
}
|
||||
|
||||
statementTimeoutString := purl.Query().Get("x-statement-timeout")
|
||||
statementTimeout := 0
|
||||
if statementTimeoutString != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if multiStatementMaxSize <= 0 {
|
||||
multiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementEnabled := false
|
||||
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
|
||||
multiStatementEnabled, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
MigrationsTableQuoted: migrationsTableQuoted,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
MultiStatementEnabled: multiStatementEnabled,
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Close() error {
|
||||
connErr := p.conn.Close()
|
||||
dbErr := p.db.Close()
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) Lock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This will wait indefinitely until the lock can be acquired.
|
||||
query := `SELECT pg_advisory_lock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Unlock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT pg_advisory_unlock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Run(migration io.Reader) error {
|
||||
if p.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
if err = p.runStatement(m); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.runStatement(migr)
|
||||
}
|
||||
|
||||
func (p *Postgres) runStatement(statement []byte) error {
|
||||
ctx := context.Background()
|
||||
if p.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
query := string(statement)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
|
||||
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
||||
// replace crlf with lf
|
||||
s = strings.Replace(s, "\r\n", "\n", -1)
|
||||
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
||||
runes := []rune(s)
|
||||
if pos > len(runes) {
|
||||
return 0, 0, false
|
||||
}
|
||||
sel := runes[:pos]
|
||||
line = uint(runesCount(sel, newLine) + 1)
|
||||
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
||||
return line, col, true
|
||||
}
|
||||
|
||||
const newLine = '\n'
|
||||
|
||||
func runesCount(input []rune, target rune) int {
|
||||
var count int
|
||||
for _, r := range input {
|
||||
if r == target {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runesLastIndex(input []rune, target rune) int {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
if input[i] == target {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *Postgres) SetVersion(version int, dirty bool) error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
|
||||
if _, err := tx.Exec(query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
|
||||
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pgconn.PgError); ok {
|
||||
if e.SQLState() == pgerrcode.UndefinedTable {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Postgres) Drop() (err error) {
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
||||
tables, err := p.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Postgres type.
|
||||
func (p *Postgres) ensureVersionTable() (err error) {
|
||||
if err = p.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := p.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
|
||||
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
|
||||
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
|
||||
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
|
||||
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
|
||||
var count int
|
||||
err = row.Scan(&count)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
|
||||
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
|
||||
func quoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
||||
}
|
||||
@@ -0,0 +1,764 @@
|
||||
package pgx
|
||||
|
||||
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const (
|
||||
pgPassword = "postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
|
||||
PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: https://www.postgresql.org/support/versioning/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "postgres:9.5", Options: opts},
|
||||
{ImageName: "postgres:9.6", Options: opts},
|
||||
{ImageName: "postgres:10", Options: opts},
|
||||
{ImageName: "postgres:11", Options: opts},
|
||||
{ImageName: "postgres:12", Options: opts},
|
||||
{ImageName: "postgres:13", Options: opts},
|
||||
{ImageName: "postgres:14", Options: opts},
|
||||
{ImageName: "postgres:15", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func pgConnectionString(host, port string, options ...string) string {
|
||||
options = append(options, "sslmode=disable")
|
||||
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func mustRun(t *testing.T, d database.Driver, statements []string) {
|
||||
for _, statement := range statements {
|
||||
if err := d.Run(strings.NewReader(statement)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatements(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-multi-statement=true")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure created index exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
|
||||
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-custom=foobar")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foobar schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.SetVersion(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
version, _, err := d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("expected NilVersion")
|
||||
}
|
||||
|
||||
// now update version and compare
|
||||
if err := d2.SetVersion(2, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _, err = d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
|
||||
// meanwhile, the public schema still has the other version
|
||||
version, _, err = d.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationTableOption(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, _ := p.Open(addr)
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create migrate schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// bad unquoted x-migrations-table parameter
|
||||
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// too many quoted x-migrations-table parameters
|
||||
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// good quoted x-migrations-table parameter
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure migrate.schema_migrations table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table migrate.schema_migrations to exist")
|
||||
}
|
||||
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
defer func() {
|
||||
if d2 == nil {
|
||||
return
|
||||
}
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var e *database.Error
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
|
||||
// re-connect using that x-migrations-table and x-migrations-table-quoted
|
||||
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckBeforeCreateTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// revoke privileges
|
||||
mustRun(t, d, []string{
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _, err := d3.Version()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d3.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestParallelSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foo and bar schemas
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schemas
|
||||
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dfoo.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dbar.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := dfoo.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dfoo.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgres_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ps := d.(*Postgres)
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance_Concurrent(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The number of concurrent processes running WithInstance
|
||||
const concurrency = 30
|
||||
|
||||
// We can instantiate a single database handle because it is
|
||||
// actually a connection pool, and so, each of the below go
|
||||
// routines will have a high probability of using a separate
|
||||
// connection, which is something we want to exercise.
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetMaxIdleConns(concurrency)
|
||||
db.SetMaxOpenConns(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_, err := WithInstance(db, &Config{})
|
||||
if err != nil {
|
||||
t.Errorf("process %d error: %s", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
}
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos int
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
|
||||
},
|
||||
{
|
||||
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
|
||||
},
|
||||
{
|
||||
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
|
||||
},
|
||||
{
|
||||
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
|
||||
},
|
||||
{
|
||||
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
|
||||
},
|
||||
{
|
||||
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
|
||||
},
|
||||
{
|
||||
17, 2, 8, "SELECT *\nFROM foo", true, // last character
|
||||
},
|
||||
{
|
||||
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
run := func(crlf bool, nonASCII bool) {
|
||||
var name string
|
||||
if crlf {
|
||||
name = "crlf"
|
||||
} else {
|
||||
name = "lf"
|
||||
}
|
||||
if nonASCII {
|
||||
name += "-nonascii"
|
||||
} else {
|
||||
name += "-ascii"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
input := tc.input
|
||||
if crlf {
|
||||
input = strings.Replace(input, "\n", "\r\n", -1)
|
||||
}
|
||||
if nonASCII {
|
||||
input = strings.Replace(input, "FROM", "FRÖM", -1)
|
||||
}
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
|
||||
|
||||
if tc.wantOk {
|
||||
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
|
||||
}
|
||||
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
run(false, false)
|
||||
run(true, false)
|
||||
run(false, true)
|
||||
run(true, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
# postgres
|
||||
|
||||
`postgres://user:password@host:port/dbname?query` (`postgresql://` works, too)
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
|
||||
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
|
||||
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
|
||||
## Multi-statement mode
|
||||
|
||||
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
|
||||
behavior is not desirable because some statements can be only run outside of transaction (e.g.
|
||||
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
|
||||
you have to put such statements in a separate migration files.
|
||||
+167
@@ -0,0 +1,167 @@
|
||||
# PostgreSQL tutorial for beginners
|
||||
|
||||
## Create/configure database
|
||||
|
||||
For the purpose of this tutorial let's create PostgreSQL database called `example`.
|
||||
Our user here is `postgres`, password `password`, and host is `localhost`.
|
||||
```
|
||||
psql -h localhost -U postgres -w -c "create database example;"
|
||||
```
|
||||
When using Migrate CLI we need to pass to database URL. Let's export it to a variable for convenience:
|
||||
```
|
||||
export POSTGRESQL_URL='postgres://postgres:password@localhost:5432/example?sslmode=disable'
|
||||
```
|
||||
`sslmode=disable` means that the connection with our database will not be encrypted. Enabling it is left as an exercise.
|
||||
|
||||
You can find further description of database URLs [here](README.md#database-urls).
|
||||
|
||||
## Create migrations
|
||||
Let's create table called `users`:
|
||||
```
|
||||
migrate create -ext sql -dir db/migrations -seq create_users_table
|
||||
```
|
||||
If there were no errors, we should have two files available under `db/migrations` folder:
|
||||
- 000001_create_users_table.down.sql
|
||||
- 000001_create_users_table.up.sql
|
||||
|
||||
Note the `sql` extension that we provided.
|
||||
|
||||
In the `.up.sql` file let's create the table:
|
||||
```sql
|
||||
CREATE TABLE IF NOT EXISTS users(
|
||||
user_id serial PRIMARY KEY,
|
||||
username VARCHAR (50) UNIQUE NOT NULL,
|
||||
password VARCHAR (50) NOT NULL,
|
||||
email VARCHAR (300) UNIQUE NOT NULL
|
||||
);
|
||||
```
|
||||
And in the `.down.sql` let's delete it:
|
||||
```sql
|
||||
DROP TABLE IF EXISTS users;
|
||||
```
|
||||
By adding `IF EXISTS/IF NOT EXISTS` we are making migrations idempotent - you can read more about idempotency in [getting started](../../GETTING_STARTED.md#create-migrations)
|
||||
|
||||
## Run migrations
|
||||
```
|
||||
migrate -database ${POSTGRESQL_URL} -path db/migrations up
|
||||
```
|
||||
Let's check if the table was created properly by running `psql example -c "\d users"`.
|
||||
The output you are supposed to see:
|
||||
```
|
||||
Table "public.users"
|
||||
Column | Type | Modifiers
|
||||
----------+------------------------+---------------------------------------------------------
|
||||
user_id | integer | not null default nextval('users_user_id_seq'::regclass)
|
||||
username | character varying(50) | not null
|
||||
password | character varying(50) | not null
|
||||
email | character varying(300) | not null
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (user_id)
|
||||
"users_email_key" UNIQUE CONSTRAINT, btree (email)
|
||||
"users_username_key" UNIQUE CONSTRAINT, btree (username)
|
||||
```
|
||||
Great! Now let's check if running reverse migration also works:
|
||||
```
|
||||
migrate -database ${POSTGRESQL_URL} -path db/migrations down
|
||||
```
|
||||
Make sure to check if your database changed as expected in this case as well.
|
||||
|
||||
## Database transactions
|
||||
|
||||
To show database transactions usage, let's create another set of migrations by running:
|
||||
```
|
||||
migrate create -ext sql -dir db/migrations -seq add_mood_to_users
|
||||
```
|
||||
Again, it should create for us two migrations files:
|
||||
- 000002_add_mood_to_users.down.sql
|
||||
- 000002_add_mood_to_users.up.sql
|
||||
|
||||
In Postgres, when we want our queries to be done in a transaction, we need to wrap it with `BEGIN` and `COMMIT` commands.
|
||||
In our example, we are going to add a column to our database that can only accept enumerable values or NULL.
|
||||
Migration up:
|
||||
```sql
|
||||
BEGIN;
|
||||
|
||||
CREATE TYPE enum_mood AS ENUM (
|
||||
'happy',
|
||||
'sad',
|
||||
'neutral'
|
||||
);
|
||||
ALTER TABLE users ADD COLUMN mood enum_mood;
|
||||
|
||||
COMMIT;
|
||||
```
|
||||
Migration down:
|
||||
```sql
|
||||
BEGIN;
|
||||
|
||||
ALTER TABLE users DROP COLUMN mood;
|
||||
DROP TYPE enum_mood;
|
||||
|
||||
COMMIT;
|
||||
```
|
||||
|
||||
Now we can run our new migration and check the database:
|
||||
```
|
||||
migrate -database ${POSTGRESQL_URL} -path db/migrations up
|
||||
psql example -c "\d users"
|
||||
```
|
||||
Expected output:
|
||||
```
|
||||
Table "public.users"
|
||||
Column | Type | Modifiers
|
||||
----------+------------------------+---------------------------------------------------------
|
||||
user_id | integer | not null default nextval('users_user_id_seq'::regclass)
|
||||
username | character varying(50) | not null
|
||||
password | character varying(50) | not null
|
||||
email | character varying(300) | not null
|
||||
mood | enum_mood |
|
||||
Indexes:
|
||||
"users_pkey" PRIMARY KEY, btree (user_id)
|
||||
"users_email_key" UNIQUE CONSTRAINT, btree (email)
|
||||
"users_username_key" UNIQUE CONSTRAINT, btree (username)
|
||||
```
|
||||
|
||||
## Optional: Run migrations within your Go app
|
||||
Here is a very simple app running migrations for the above configuration:
|
||||
```go
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
_ "github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
func main() {
|
||||
m, err := migrate.New(
|
||||
"file://db/migrations",
|
||||
"postgres://postgres:postgres@localhost:5432/example?sslmode=disable")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := m.Up(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
```
|
||||
You can find details [here](README.md#use-in-your-go-project)
|
||||
|
||||
## Fix issue where migrations run twice
|
||||
|
||||
When the schema and role names are the same, you might run into issues if you create this schema using migrations.
|
||||
This is caused by the fact that the [default `search_path`](https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH) is `"$user", public`.
|
||||
In the first run (with an empty database) the migrate table is created in `public`.
|
||||
When the migrations create the `$user` schema, the next run will store (a new) migrate table in this schema (due to order of schemas in `search_path`) and tries to apply all migrations again (most likely failing).
|
||||
|
||||
To solve this you need to change the default `search_path` by removing the `$user` component, so the migrate table is always stored in the (available) `public` schema.
|
||||
This can be done using the [`search_path` query parameter in the URL](https://github.com/jexia/migrate/blob/fix-postgres-version-table/database/postgres/README.md#postgres).
|
||||
|
||||
For example to force the migrations table in the public schema you can use:
|
||||
```
|
||||
export POSTGRESQL_URL='postgres://postgres:password@localhost:5432/example?sslmode=disable&search_path=public'
|
||||
```
|
||||
|
||||
Note that you need to explicitly add the schema names to the table names in your migrations when you to modify the tables of the non-public schema.
|
||||
|
||||
Alternatively you can add the non-public schema manually (before applying the migrations) if that is possible in your case and let the tool store the migrations table in this schema as well.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS users;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user