whatcanGOwrong
This commit is contained in:
@@ -0,0 +1,43 @@
|
||||
# pgx
|
||||
|
||||
This package is for [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). A backend for the newer [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5) is [also available](v5).
|
||||
|
||||
`pgx://user:password@host:port/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
|
||||
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
|
||||
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
|
||||
| `x-lock-strategy` | `LockStrategy` | Strategy used for locking during migration (default: advisory) |
|
||||
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock (default: schema_lock) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
|
||||
## Multi-statement mode
|
||||
|
||||
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
|
||||
behavior is not desirable because some statements can be only run outside of transaction (e.g.
|
||||
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
|
||||
you have to put such statements in a separate migration files.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS users;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE users (
|
||||
user_id integer unique,
|
||||
name varchar(40),
|
||||
email varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
ALTER TABLE users DROP COLUMN IF EXISTS city;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
ALTER TABLE users ADD COLUMN city varchar(100);
|
||||
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP INDEX IF EXISTS users_email_index;
|
||||
+3
@@ -0,0 +1,3 @@
|
||||
CREATE UNIQUE INDEX CONCURRENTLY users_email_index ON users (email);
|
||||
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS books;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE books (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
author varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS movies;
|
||||
+5
@@ -0,0 +1,5 @@
|
||||
CREATE TABLE movies (
|
||||
user_id integer,
|
||||
name varchar(40),
|
||||
director varchar(40)
|
||||
);
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
+1
@@ -0,0 +1 @@
|
||||
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
|
||||
@@ -0,0 +1,623 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgerrcode"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
LockStrategyAdvisory = "advisory"
|
||||
LockStrategyTable = "table"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Postgres{}
|
||||
database.Register("pgx", &db)
|
||||
database.Register("pgx4", &db)
|
||||
}
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMigrationsTable = "schema_migrations"
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
DefaultLockTable = "schema_lock"
|
||||
DefaultLockStrategy = LockStrategyAdvisory
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrNoSchema = fmt.Errorf("no schema")
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
LockTable string
|
||||
LockStrategy string
|
||||
migrationsSchemaName string
|
||||
migrationsTableName string
|
||||
StatementTimeout time.Duration
|
||||
MigrationsTableQuoted bool
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Postgres struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT CURRENT_DATABASE()`
|
||||
var databaseName string
|
||||
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName
|
||||
}
|
||||
|
||||
if config.SchemaName == "" {
|
||||
query := `SELECT CURRENT_SCHEMA()`
|
||||
var schemaName string
|
||||
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(schemaName) == 0 {
|
||||
return nil, ErrNoSchema
|
||||
}
|
||||
|
||||
config.SchemaName = schemaName
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
if len(config.LockTable) == 0 {
|
||||
config.LockTable = DefaultLockTable
|
||||
}
|
||||
|
||||
if len(config.LockStrategy) == 0 {
|
||||
config.LockStrategy = DefaultLockStrategy
|
||||
}
|
||||
|
||||
config.migrationsSchemaName = config.SchemaName
|
||||
config.migrationsTableName = config.MigrationsTable
|
||||
if config.MigrationsTableQuoted {
|
||||
re := regexp.MustCompile(`"(.*?)"`)
|
||||
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
|
||||
config.migrationsTableName = result[len(result)-1][1]
|
||||
if len(result) == 2 {
|
||||
config.migrationsSchemaName = result[0][1]
|
||||
} else if len(result) > 2 {
|
||||
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
px := &Postgres{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := px.ensureLockTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Driver is registered as pgx, but connection string must use postgres schema
|
||||
// when making actual connection
|
||||
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
|
||||
purl.Scheme = "postgres"
|
||||
|
||||
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
migrationsTableQuoted := false
|
||||
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
|
||||
migrationsTableQuoted, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
|
||||
}
|
||||
}
|
||||
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
|
||||
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
|
||||
}
|
||||
|
||||
statementTimeoutString := purl.Query().Get("x-statement-timeout")
|
||||
statementTimeout := 0
|
||||
if statementTimeoutString != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if multiStatementMaxSize <= 0 {
|
||||
multiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementEnabled := false
|
||||
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
|
||||
multiStatementEnabled, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
lockStrategy := purl.Query().Get("x-lock-strategy")
|
||||
lockTable := purl.Query().Get("x-lock-table")
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
MigrationsTableQuoted: migrationsTableQuoted,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
MultiStatementEnabled: multiStatementEnabled,
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
LockStrategy: lockStrategy,
|
||||
LockTable: lockTable,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Close() error {
|
||||
connErr := p.conn.Close()
|
||||
dbErr := p.db.Close()
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Lock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
|
||||
switch p.config.LockStrategy {
|
||||
case LockStrategyAdvisory:
|
||||
return p.applyAdvisoryLock()
|
||||
case LockStrategyTable:
|
||||
return p.applyTableLock()
|
||||
default:
|
||||
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Unlock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
switch p.config.LockStrategy {
|
||||
case LockStrategyAdvisory:
|
||||
return p.releaseAdvisoryLock()
|
||||
case LockStrategyTable:
|
||||
return p.releaseTableLock()
|
||||
default:
|
||||
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) applyAdvisoryLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This will wait indefinitely until the lock can be acquired.
|
||||
query := `SELECT pg_advisory_lock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) applyTableLock() error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
defer func() {
|
||||
errRollback := tx.Rollback()
|
||||
if errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
}()
|
||||
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
|
||||
rows, err := tx.Query(query, aid)
|
||||
if err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if errClose := rows.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// If row exists at all, lock is present
|
||||
locked := rows.Next()
|
||||
if locked {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
||||
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
|
||||
if _, err := tx.Exec(query, aid); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (p *Postgres) releaseAdvisoryLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT pg_advisory_unlock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) releaseTableLock() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
|
||||
if _, err := p.db.Exec(query, aid); err != nil {
|
||||
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Run(migration io.Reader) error {
|
||||
if p.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
if err = p.runStatement(m); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.runStatement(migr)
|
||||
}
|
||||
|
||||
func (p *Postgres) runStatement(statement []byte) error {
|
||||
ctx := context.Background()
|
||||
if p.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
query := string(statement)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
|
||||
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
||||
// replace crlf with lf
|
||||
s = strings.Replace(s, "\r\n", "\n", -1)
|
||||
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
||||
runes := []rune(s)
|
||||
if pos > len(runes) {
|
||||
return 0, 0, false
|
||||
}
|
||||
sel := runes[:pos]
|
||||
line = uint(runesCount(sel, newLine) + 1)
|
||||
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
||||
return line, col, true
|
||||
}
|
||||
|
||||
const newLine = '\n'
|
||||
|
||||
func runesCount(input []rune, target rune) int {
|
||||
var count int
|
||||
for _, r := range input {
|
||||
if r == target {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runesLastIndex(input []rune, target rune) int {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
if input[i] == target {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *Postgres) SetVersion(version int, dirty bool) error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
|
||||
if _, err := tx.Exec(query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
|
||||
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pgconn.PgError); ok {
|
||||
if e.SQLState() == pgerrcode.UndefinedTable {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Postgres) Drop() (err error) {
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
||||
tables, err := p.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// do not drop lock table
|
||||
if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Postgres type.
|
||||
func (p *Postgres) ensureVersionTable() (err error) {
|
||||
if err = p.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := p.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
|
||||
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
|
||||
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
|
||||
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
|
||||
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
|
||||
var count int
|
||||
err = row.Scan(&count)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
|
||||
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) ensureLockTable() error {
|
||||
if p.config.LockStrategy != LockStrategyTable {
|
||||
return nil
|
||||
}
|
||||
|
||||
var count int
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
|
||||
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
|
||||
if _, err := p.db.Exec(query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
|
||||
func quoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
||||
}
|
||||
@@ -0,0 +1,789 @@
|
||||
package pgx
|
||||
|
||||
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const (
|
||||
pgPassword = "postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
|
||||
PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: https://www.postgresql.org/support/versioning/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "postgres:9.5", Options: opts},
|
||||
{ImageName: "postgres:9.6", Options: opts},
|
||||
{ImageName: "postgres:10", Options: opts},
|
||||
{ImageName: "postgres:11", Options: opts},
|
||||
{ImageName: "postgres:12", Options: opts},
|
||||
{ImageName: "postgres:13", Options: opts},
|
||||
{ImageName: "postgres:14", Options: opts},
|
||||
{ImageName: "postgres:15", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func pgConnectionString(host, port string, options ...string) string {
|
||||
options = append(options, "sslmode=disable")
|
||||
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func mustRun(t *testing.T, d database.Driver, statements []string) {
|
||||
for _, statement := range statements {
|
||||
if err := d.Run(strings.NewReader(statement)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrateLockTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatements(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-multi-statement=true")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure created index exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
|
||||
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-custom=foobar")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foobar schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.SetVersion(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
version, _, err := d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("expected NilVersion")
|
||||
}
|
||||
|
||||
// now update version and compare
|
||||
if err := d2.SetVersion(2, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _, err = d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
|
||||
// meanwhile, the public schema still has the other version
|
||||
version, _, err = d.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationTableOption(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, _ := p.Open(addr)
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create migrate schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// bad unquoted x-migrations-table parameter
|
||||
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// too many quoted x-migrations-table parameters
|
||||
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// good quoted x-migrations-table parameter
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure migrate.schema_migrations table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table migrate.schema_migrations to exist")
|
||||
}
|
||||
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
defer func() {
|
||||
if d2 == nil {
|
||||
return
|
||||
}
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var e *database.Error
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
|
||||
// re-connect using that x-migrations-table and x-migrations-table-quoted
|
||||
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckBeforeCreateTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// revoke privileges
|
||||
mustRun(t, d, []string{
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _, err := d3.Version()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d3.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestParallelSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foo and bar schemas
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schemas
|
||||
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dfoo.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dbar.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := dfoo.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dfoo.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgres_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ps := d.(*Postgres)
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance_Concurrent(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The number of concurrent processes running WithInstance
|
||||
const concurrency = 30
|
||||
|
||||
// We can instantiate a single database handle because it is
|
||||
// actually a connection pool, and so, each of the below go
|
||||
// routines will have a high probability of using a separate
|
||||
// connection, which is something we want to exercise.
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetMaxIdleConns(concurrency)
|
||||
db.SetMaxOpenConns(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_, err := WithInstance(db, &Config{})
|
||||
if err != nil {
|
||||
t.Errorf("process %d error: %s", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
}
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos int
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
|
||||
},
|
||||
{
|
||||
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
|
||||
},
|
||||
{
|
||||
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
|
||||
},
|
||||
{
|
||||
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
|
||||
},
|
||||
{
|
||||
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
|
||||
},
|
||||
{
|
||||
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
|
||||
},
|
||||
{
|
||||
17, 2, 8, "SELECT *\nFROM foo", true, // last character
|
||||
},
|
||||
{
|
||||
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
run := func(crlf bool, nonASCII bool) {
|
||||
var name string
|
||||
if crlf {
|
||||
name = "crlf"
|
||||
} else {
|
||||
name = "lf"
|
||||
}
|
||||
if nonASCII {
|
||||
name += "-nonascii"
|
||||
} else {
|
||||
name += "-ascii"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
input := tc.input
|
||||
if crlf {
|
||||
input = strings.Replace(input, "\n", "\r\n", -1)
|
||||
}
|
||||
if nonASCII {
|
||||
input = strings.Replace(input, "FROM", "FRÖM", -1)
|
||||
}
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
|
||||
|
||||
if tc.wantOk {
|
||||
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
|
||||
}
|
||||
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
run(false, false)
|
||||
run(true, false)
|
||||
run(false, true)
|
||||
run(true, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
# pgx
|
||||
|
||||
This package is for [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5). A backend for the older [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). is [also available](..).
|
||||
|
||||
`pgx5://user:password@host:port/dbname?query`
|
||||
|
||||
| URL Query | WithInstance Config | Description |
|
||||
|------------|---------------------|-------------|
|
||||
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
|
||||
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
|
||||
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
|
||||
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
|
||||
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
|
||||
| `dbname` | `DatabaseName` | The name of the database to connect to |
|
||||
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
|
||||
| `user` | | The user to sign in as |
|
||||
| `password` | | The user's password |
|
||||
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
|
||||
| `port` | | The port to bind to. (default is 5432) |
|
||||
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
|
||||
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
|
||||
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
|
||||
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
|
||||
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
|
||||
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
|
||||
|
||||
|
||||
## Upgrading from v1
|
||||
|
||||
1. Write down the current migration version from schema_migrations
|
||||
1. `DROP TABLE schema_migrations`
|
||||
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
|
||||
3. Download and install the latest migrate version.
|
||||
4. Force the current migration version with `migrate force <current_version>`.
|
||||
|
||||
## Multi-statement mode
|
||||
|
||||
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
|
||||
behavior is not desirable because some statements can be only run outside of transaction (e.g.
|
||||
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
|
||||
you have to put such statements in a separate migration files.
|
||||
@@ -0,0 +1,486 @@
|
||||
//go:build go1.9
|
||||
// +build go1.9
|
||||
|
||||
package pgx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
nurl "net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
"github.com/golang-migrate/migrate/v4/database/multistmt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/jackc/pgerrcode"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
func init() {
|
||||
db := Postgres{}
|
||||
database.Register("pgx5", &db)
|
||||
}
|
||||
|
||||
var (
|
||||
multiStmtDelimiter = []byte(";")
|
||||
|
||||
DefaultMigrationsTable = "schema_migrations"
|
||||
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||
ErrNoSchema = fmt.Errorf("no schema")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
migrationsSchemaName string
|
||||
migrationsTableName string
|
||||
StatementTimeout time.Duration
|
||||
MigrationsTableQuoted bool
|
||||
MultiStatementEnabled bool
|
||||
MultiStatementMaxSize int
|
||||
}
|
||||
|
||||
type Postgres struct {
|
||||
// Locking and unlocking need to use the same connection
|
||||
conn *sql.Conn
|
||||
db *sql.DB
|
||||
isLocked atomic.Bool
|
||||
|
||||
// Open and WithInstance need to guarantee that config is never nil
|
||||
config *Config
|
||||
}
|
||||
|
||||
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
}
|
||||
|
||||
if err := instance.Ping(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if config.DatabaseName == "" {
|
||||
query := `SELECT CURRENT_DATABASE()`
|
||||
var databaseName string
|
||||
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(databaseName) == 0 {
|
||||
return nil, ErrNoDatabaseName
|
||||
}
|
||||
|
||||
config.DatabaseName = databaseName
|
||||
}
|
||||
|
||||
if config.SchemaName == "" {
|
||||
query := `SELECT CURRENT_SCHEMA()`
|
||||
var schemaName string
|
||||
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(schemaName) == 0 {
|
||||
return nil, ErrNoSchema
|
||||
}
|
||||
|
||||
config.SchemaName = schemaName
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
config.migrationsSchemaName = config.SchemaName
|
||||
config.migrationsTableName = config.MigrationsTable
|
||||
if config.MigrationsTableQuoted {
|
||||
re := regexp.MustCompile(`"(.*?)"`)
|
||||
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
|
||||
config.migrationsTableName = result[len(result)-1][1]
|
||||
if len(result) == 2 {
|
||||
config.migrationsSchemaName = result[0][1]
|
||||
} else if len(result) > 2 {
|
||||
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := instance.Conn(context.Background())
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
px := &Postgres{
|
||||
conn: conn,
|
||||
db: instance,
|
||||
config: config,
|
||||
}
|
||||
|
||||
if err := px.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Open(url string) (database.Driver, error) {
|
||||
purl, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Driver is registered as pgx, but connection string must use postgres schema
|
||||
// when making actual connection
|
||||
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
|
||||
purl.Scheme = "postgres"
|
||||
|
||||
db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
migrationsTable := purl.Query().Get("x-migrations-table")
|
||||
migrationsTableQuoted := false
|
||||
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
|
||||
migrationsTableQuoted, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
|
||||
}
|
||||
}
|
||||
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
|
||||
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
|
||||
}
|
||||
|
||||
statementTimeoutString := purl.Query().Get("x-statement-timeout")
|
||||
statementTimeout := 0
|
||||
if statementTimeoutString != "" {
|
||||
statementTimeout, err = strconv.Atoi(statementTimeoutString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementMaxSize := DefaultMultiStatementMaxSize
|
||||
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
|
||||
multiStatementMaxSize, err = strconv.Atoi(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if multiStatementMaxSize <= 0 {
|
||||
multiStatementMaxSize = DefaultMultiStatementMaxSize
|
||||
}
|
||||
}
|
||||
|
||||
multiStatementEnabled := false
|
||||
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
|
||||
multiStatementEnabled, err = strconv.ParseBool(s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
px, err := WithInstance(db, &Config{
|
||||
DatabaseName: purl.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
MigrationsTableQuoted: migrationsTableQuoted,
|
||||
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
|
||||
MultiStatementEnabled: multiStatementEnabled,
|
||||
MultiStatementMaxSize: multiStatementMaxSize,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return px, nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Close() error {
|
||||
connErr := p.conn.Close()
|
||||
dbErr := p.db.Close()
|
||||
if connErr != nil || dbErr != nil {
|
||||
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
|
||||
func (p *Postgres) Lock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This will wait indefinitely until the lock can be acquired.
|
||||
query := `SELECT pg_advisory_lock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Unlock() error {
|
||||
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
query := `SELECT pg_advisory_unlock($1)`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (p *Postgres) Run(migration io.Reader) error {
|
||||
if p.config.MultiStatementEnabled {
|
||||
var err error
|
||||
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
|
||||
if err = p.runStatement(m); err != nil {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
migr, err := io.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return p.runStatement(migr)
|
||||
}
|
||||
|
||||
func (p *Postgres) runStatement(statement []byte) error {
|
||||
ctx := context.Background()
|
||||
if p.config.StatementTimeout != 0 {
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
|
||||
defer cancel()
|
||||
}
|
||||
query := string(statement)
|
||||
if strings.TrimSpace(query) == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := p.conn.ExecContext(ctx, query); err != nil {
|
||||
|
||||
if pgErr, ok := err.(*pgconn.PgError); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
|
||||
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
|
||||
// replace crlf with lf
|
||||
s = strings.Replace(s, "\r\n", "\n", -1)
|
||||
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
|
||||
runes := []rune(s)
|
||||
if pos > len(runes) {
|
||||
return 0, 0, false
|
||||
}
|
||||
sel := runes[:pos]
|
||||
line = uint(runesCount(sel, newLine) + 1)
|
||||
col = uint(pos - 1 - runesLastIndex(sel, newLine))
|
||||
return line, col, true
|
||||
}
|
||||
|
||||
const newLine = '\n'
|
||||
|
||||
func runesCount(input []rune, target rune) int {
|
||||
var count int
|
||||
for _, r := range input {
|
||||
if r == target {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func runesLastIndex(input []rune, target rune) int {
|
||||
for i := len(input) - 1; i >= 0; i-- {
|
||||
if input[i] == target {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (p *Postgres) SetVersion(version int, dirty bool) error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction start failed"}
|
||||
}
|
||||
|
||||
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
// Also re-write the schema version for nil dirty versions to prevent
|
||||
// empty schema version for failed down migration on the first migration
|
||||
// See: https://github.com/golang-migrate/migrate/issues/330
|
||||
if version >= 0 || (version == database.NilVersion && dirty) {
|
||||
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
|
||||
if _, err := tx.Exec(query, version, dirty); err != nil {
|
||||
if errRollback := tx.Rollback(); errRollback != nil {
|
||||
err = multierror.Append(err, errRollback)
|
||||
}
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Postgres) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
|
||||
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == sql.ErrNoRows:
|
||||
return database.NilVersion, false, nil
|
||||
|
||||
case err != nil:
|
||||
if e, ok := err.(*pgconn.PgError); ok {
|
||||
if e.SQLState() == pgerrcode.UndefinedTable {
|
||||
return database.NilVersion, false, nil
|
||||
}
|
||||
}
|
||||
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
|
||||
default:
|
||||
return version, dirty, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Postgres) Drop() (err error) {
|
||||
// select all tables in current schema
|
||||
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
|
||||
tables, err := p.conn.QueryContext(context.Background(), query)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
defer func() {
|
||||
if errClose := tables.Close(); errClose != nil {
|
||||
err = multierror.Append(err, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// delete one table after another
|
||||
tableNames := make([]string, 0)
|
||||
for tables.Next() {
|
||||
var tableName string
|
||||
if err := tables.Scan(&tableName); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(tableName) > 0 {
|
||||
tableNames = append(tableNames, tableName)
|
||||
}
|
||||
}
|
||||
if err := tables.Err(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(tableNames) > 0 {
|
||||
// delete one by one ...
|
||||
for _, t := range tableNames {
|
||||
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
|
||||
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ensureVersionTable checks if versions table exists and, if not, creates it.
|
||||
// Note that this function locks the database, which deviates from the usual
|
||||
// convention of "caller locks" in the Postgres type.
|
||||
func (p *Postgres) ensureVersionTable() (err error) {
|
||||
if err = p.Lock(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if e := p.Unlock(); e != nil {
|
||||
if err == nil {
|
||||
err = e
|
||||
} else {
|
||||
err = multierror.Append(err, e)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
|
||||
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
|
||||
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
|
||||
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
|
||||
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
|
||||
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
|
||||
|
||||
var count int
|
||||
err = row.Scan(&count)
|
||||
if err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if count == 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
|
||||
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
|
||||
func quoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
|
||||
}
|
||||
@@ -0,0 +1,764 @@
|
||||
package pgx
|
||||
|
||||
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
|
||||
"github.com/dhui/dktest"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4/database"
|
||||
dt "github.com/golang-migrate/migrate/v4/database/testing"
|
||||
"github.com/golang-migrate/migrate/v4/dktesting"
|
||||
_ "github.com/golang-migrate/migrate/v4/source/file"
|
||||
)
|
||||
|
||||
const (
|
||||
pgPassword = "postgres"
|
||||
)
|
||||
|
||||
var (
|
||||
opts = dktest.Options{
|
||||
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
|
||||
PortRequired: true, ReadyFunc: isReady}
|
||||
// Supported versions: https://www.postgresql.org/support/versioning/
|
||||
specs = []dktesting.ContainerSpec{
|
||||
{ImageName: "postgres:9.5", Options: opts},
|
||||
{ImageName: "postgres:9.6", Options: opts},
|
||||
{ImageName: "postgres:10", Options: opts},
|
||||
{ImageName: "postgres:11", Options: opts},
|
||||
{ImageName: "postgres:12", Options: opts},
|
||||
{ImageName: "postgres:13", Options: opts},
|
||||
{ImageName: "postgres:14", Options: opts},
|
||||
{ImageName: "postgres:15", Options: opts},
|
||||
}
|
||||
)
|
||||
|
||||
func pgConnectionString(host, port string, options ...string) string {
|
||||
options = append(options, "sslmode=disable")
|
||||
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
|
||||
}
|
||||
|
||||
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
log.Println("close error:", err)
|
||||
}
|
||||
}()
|
||||
if err = db.PingContext(ctx); err != nil {
|
||||
switch err {
|
||||
case sqldriver.ErrBadConn, io.EOF:
|
||||
return false
|
||||
default:
|
||||
log.Println(err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func mustRun(t *testing.T, d database.Driver, statements []string) {
|
||||
for _, statement := range statements {
|
||||
if err := d.Run(strings.NewReader(statement)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Test(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
dt.TestMigrate(t, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatements(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure second table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-multi-statement=true")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
|
||||
t.Fatalf("expected err to be nil, got %v", err)
|
||||
}
|
||||
|
||||
// make sure created index exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table bar to exist")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
|
||||
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
|
||||
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port, "x-custom=foobar")
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foobar schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.SetVersion(1, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
version, _, err := d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("expected NilVersion")
|
||||
}
|
||||
|
||||
// now update version and compare
|
||||
if err := d2.SetVersion(2, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
version, _, err = d2.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 2 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
|
||||
// meanwhile, the public schema still has the other version
|
||||
version, _, err = d.Version()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if version != 1 {
|
||||
t.Fatal("expected version 2")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMigrationTableOption(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, _ := p.Open(addr)
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create migrate schema
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// bad unquoted x-migrations-table parameter
|
||||
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// too many quoted x-migrations-table parameters
|
||||
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if (err != nil) && (err.Error() != wantErr) {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
|
||||
// good quoted x-migrations-table parameter
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make sure migrate.schema_migrations table exists
|
||||
var exists bool
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table migrate.schema_migrations to exist")
|
||||
}
|
||||
|
||||
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
|
||||
pgPassword, ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
defer func() {
|
||||
if d2 == nil {
|
||||
return
|
||||
}
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
var e *database.Error
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
|
||||
// re-connect using that x-migrations-table and x-migrations-table-quoted
|
||||
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if !errors.As(err, &e) || err == nil {
|
||||
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
|
||||
t.Fatal(e)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestCheckBeforeCreateTable(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
|
||||
// Check that opening the postgres connection returns NilVersion
|
||||
p := &Postgres{}
|
||||
|
||||
d, err := p.Open(addr)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
|
||||
// since this is a test environment and we're not expecting to the pgPassword to be malicious
|
||||
mustRun(t, d, []string{
|
||||
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
|
||||
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
|
||||
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
|
||||
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := d2.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// revoke privileges
|
||||
mustRun(t, d, []string{
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
|
||||
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
|
||||
})
|
||||
|
||||
// re-connect using that schema
|
||||
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
|
||||
pgPassword, ip, port))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
version, _, err := d3.Version()
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if version != database.NilVersion {
|
||||
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := d3.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
func TestParallelSchema(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := d.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// create foo and bar schemas
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schemas
|
||||
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dfoo.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := dbar.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := dfoo.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dfoo.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPostgres_Lock(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr := pgConnectionString(ip, port)
|
||||
p := &Postgres{}
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
dt.Test(t, d, []byte("SELECT 1"))
|
||||
|
||||
ps := d.(*Postgres)
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Lock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = ps.Unlock()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance_Concurrent(t *testing.T) {
|
||||
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
|
||||
ip, port, err := c.FirstPort()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// The number of concurrent processes running WithInstance
|
||||
const concurrency = 30
|
||||
|
||||
// We can instantiate a single database handle because it is
|
||||
// actually a connection pool, and so, each of the below go
|
||||
// routines will have a high probability of using a separate
|
||||
// connection, which is something we want to exercise.
|
||||
db, err := sql.Open("pgx", pgConnectionString(ip, port))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer func() {
|
||||
if err := db.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetMaxIdleConns(concurrency)
|
||||
db.SetMaxOpenConns(concurrency)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
defer wg.Wait()
|
||||
|
||||
wg.Add(concurrency)
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
_, err := WithInstance(db, &Config{})
|
||||
if err != nil {
|
||||
t.Errorf("process %d error: %s", i, err)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
})
|
||||
}
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos int
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
|
||||
},
|
||||
{
|
||||
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
|
||||
},
|
||||
{
|
||||
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
|
||||
},
|
||||
{
|
||||
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
|
||||
},
|
||||
{
|
||||
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
|
||||
},
|
||||
{
|
||||
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
|
||||
},
|
||||
{
|
||||
17, 2, 8, "SELECT *\nFROM foo", true, // last character
|
||||
},
|
||||
{
|
||||
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
run := func(crlf bool, nonASCII bool) {
|
||||
var name string
|
||||
if crlf {
|
||||
name = "crlf"
|
||||
} else {
|
||||
name = "lf"
|
||||
}
|
||||
if nonASCII {
|
||||
name += "-nonascii"
|
||||
} else {
|
||||
name += "-ascii"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
input := tc.input
|
||||
if crlf {
|
||||
input = strings.Replace(input, "\n", "\r\n", -1)
|
||||
}
|
||||
if nonASCII {
|
||||
input = strings.Replace(input, "FROM", "FRÖM", -1)
|
||||
}
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
|
||||
|
||||
if tc.wantOk {
|
||||
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
|
||||
}
|
||||
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
run(false, false)
|
||||
run(true, false)
|
||||
run(false, true)
|
||||
run(true, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user