whatcanGOwrong

This commit is contained in:
2024-09-19 21:38:24 -04:00
commit d0ae4d841d
17908 changed files with 4096831 additions and 0 deletions
@@ -0,0 +1,43 @@
# pgx
This package is for [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). A backend for the newer [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5) is [also available](v5).
`pgx://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `x-lock-strategy` | `LockStrategy` | Strategy used for locking during migration (default: advisory) |
| `x-lock-table` | `LockTable` | Name of the table which maintains the migration lock (default: schema_lock) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
## Multi-statement mode
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
@@ -0,0 +1,5 @@
CREATE TABLE users (
user_id integer unique,
name varchar(40),
email varchar(40)
);
@@ -0,0 +1 @@
ALTER TABLE users DROP COLUMN IF EXISTS city;
@@ -0,0 +1,3 @@
ALTER TABLE users ADD COLUMN city varchar(100);
@@ -0,0 +1,3 @@
CREATE UNIQUE INDEX CONCURRENTLY users_email_index ON users (email);
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,5 @@
CREATE TABLE books (
user_id integer,
name varchar(40),
author varchar(40)
);
@@ -0,0 +1,5 @@
CREATE TABLE movies (
user_id integer,
name varchar(40),
director varchar(40)
);
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1 @@
-- Lorem ipsum dolor sit amet, consectetur adipiscing elit. Aenean sed interdum velit, tristique iaculis justo. Pellentesque ut porttitor dolor. Donec sit amet pharetra elit. Cras vel ligula ex. Phasellus posuere.
@@ -0,0 +1,623 @@
//go:build go1.9
// +build go1.9
package pgx
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
_ "github.com/jackc/pgx/v4/stdlib"
"github.com/lib/pq"
)
const (
LockStrategyAdvisory = "advisory"
LockStrategyTable = "table"
)
func init() {
db := Postgres{}
database.Register("pgx", &db)
database.Register("pgx4", &db)
}
var (
multiStmtDelimiter = []byte(";")
DefaultMigrationsTable = "schema_migrations"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
DefaultLockTable = "schema_lock"
DefaultLockStrategy = LockStrategyAdvisory
)
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
)
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
LockTable string
LockStrategy string
migrationsSchemaName string
migrationsTableName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
}
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(schemaName) == 0 {
return nil, ErrNoSchema
}
config.SchemaName = schemaName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
if len(config.LockTable) == 0 {
config.LockTable = DefaultLockTable
}
if len(config.LockStrategy) == 0 {
config.LockStrategy = DefaultLockStrategy
}
config.migrationsSchemaName = config.SchemaName
config.migrationsTableName = config.MigrationsTable
if config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
config.migrationsTableName = result[len(result)-1][1]
if len(result) == 2 {
config.migrationsSchemaName = result[0][1]
} else if len(result) > 2 {
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
}
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
px := &Postgres{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureLockTable(); err != nil {
return nil, err
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// Driver is registered as pgx, but connection string must use postgres schema
// when making actual connection
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"
db, err := sql.Open("pgx/v4", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}
statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutString)
if err != nil {
return nil, err
}
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
if multiStatementMaxSize <= 0 {
multiStatementMaxSize = DefaultMultiStatementMaxSize
}
}
multiStatementEnabled := false
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
multiStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
}
}
lockStrategy := purl.Query().Get("x-lock-strategy")
lockTable := purl.Query().Get("x-lock-table")
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
LockStrategy: lockStrategy,
LockTable: lockTable,
})
if err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
func (p *Postgres) Lock() error {
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
switch p.config.LockStrategy {
case LockStrategyAdvisory:
return p.applyAdvisoryLock()
case LockStrategyTable:
return p.applyTableLock()
default:
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
}
})
}
func (p *Postgres) Unlock() error {
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
switch p.config.LockStrategy {
case LockStrategyAdvisory:
return p.releaseAdvisoryLock()
case LockStrategyTable:
return p.releaseTableLock()
default:
return fmt.Errorf("unknown lock strategy \"%s\"", p.config.LockStrategy)
}
})
}
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) applyAdvisoryLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
}
func (p *Postgres) applyTableLock() error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
defer func() {
errRollback := tx.Rollback()
if errRollback != nil {
err = multierror.Append(err, errRollback)
}
}()
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
query := "SELECT * FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
rows, err := tx.Query(query, aid)
if err != nil {
return database.Error{OrigErr: err, Err: "failed to fetch migration lock", Query: []byte(query)}
}
defer func() {
if errClose := rows.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// If row exists at all, lock is present
locked := rows.Next()
if locked {
return database.ErrLocked
}
query = "INSERT INTO " + pq.QuoteIdentifier(p.config.LockTable) + " (lock_id) VALUES ($1)"
if _, err := tx.Exec(query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to set migration lock", Query: []byte(query)}
}
return tx.Commit()
}
func (p *Postgres) releaseAdvisoryLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (p *Postgres) releaseTableLock() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
if err != nil {
return err
}
query := "DELETE FROM " + pq.QuoteIdentifier(p.config.LockTable) + " WHERE lock_id = $1"
if _, err := p.db.Exec(query, aid); err != nil {
return database.Error{OrigErr: err, Err: "failed to release migration lock", Query: []byte(query)}
}
return nil
}
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
return p.runStatement(migr)
}
func (p *Postgres) runStatement(statement []byte) error {
ctx := context.Background()
if p.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
var col uint
var lineColOK bool
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pgconn.PgError); ok {
if e.SQLState() == pgerrcode.UndefinedTable {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
// do not drop lock table
if tableName == p.config.LockTable && p.config.LockStrategy == LockStrategyTable {
continue
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Postgres type.
func (p *Postgres) ensureVersionTable() (err error) {
if err = p.Lock(); err != nil {
return err
}
defer func() {
if e := p.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
func (p *Postgres) ensureLockTable() error {
if p.config.LockStrategy != LockStrategyTable {
return nil
}
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
if err := p.db.QueryRow(query, p.config.LockTable).Scan(&count); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE ` + pq.QuoteIdentifier(p.config.LockTable) + ` (lock_id BIGINT NOT NULL PRIMARY KEY)`
if _, err := p.db.Exec(query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
@@ -0,0 +1,789 @@
package pgx
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:9.5", Options: opts},
{ImageName: "postgres:9.6", Options: opts},
{ImageName: "postgres:10", Options: opts},
{ImageName: "postgres:11", Options: opts},
{ImageName: "postgres:12", Options: opts},
{ImageName: "postgres:13", Options: opts},
{ImageName: "postgres:14", Options: opts},
{ImageName: "postgres:15", Options: opts},
}
)
func pgConnectionString(host, port string, options ...string) string {
options = append(options, "sslmode=disable")
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMigrateLockTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-lock-strategy=table", "x-lock-table=lock_table")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://./examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMultipleStatements(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-multi-statement=true")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure created index exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-custom=foobar")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
})
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create foobar schema
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.SetVersion(1, false); err != nil {
t.Fatal(err)
}
// re-connect using that schema
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("expected NilVersion")
}
// now update version and compare
if err := d2.SetVersion(2, false); err != nil {
t.Fatal(err)
}
version, _, err = d2.Version()
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected version 2")
}
// meanwhile, the public schema still has the other version
version, _, err = d.Version()
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected version 2")
}
})
}
func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create migrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// bad unquoted x-migrations-table parameter
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// too many quoted x-migrations-table parameters
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
// make sure migrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table migrate.schema_migrations to exist")
}
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
}
})
}
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
defer func() {
if d2 == nil {
return
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
var e *database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}
func TestCheckBeforeCreateTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
// revoke privileges
mustRun(t, d, []string{
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
version, _, err := d3.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
}
defer func() {
if err := d3.Close(); err != nil {
t.Fatal(err)
}
}()
})
}
func TestParallelSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create foo and bar schemas
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// re-connect using that schemas
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dfoo.Close(); err != nil {
t.Error(err)
}
}()
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dbar.Close(); err != nil {
t.Error(err)
}
}()
if err := dfoo.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Unlock(); err != nil {
t.Fatal(err)
}
if err := dfoo.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ps := d.(*Postgres)
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestWithInstance_Concurrent(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// The number of concurrent processes running WithInstance
const concurrency = 30
// We can instantiate a single database handle because it is
// actually a connection pool, and so, each of the below go
// routines will have a high probability of using a separate
// connection, which is something we want to exercise.
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
db.SetMaxIdleConns(concurrency)
db.SetMaxOpenConns(concurrency)
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func(i int) {
defer wg.Done()
_, err := WithInstance(db, &Config{})
if err != nil {
t.Errorf("process %d error: %s", i, err)
}
}(i)
}
})
}
func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
wantLine uint
wantCol uint
input string
wantOk bool
}{
{
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
},
{
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
},
{
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
},
{
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
},
{
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
},
{
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
},
{
17, 2, 8, "SELECT *\nFROM foo", true, // last character
},
{
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
},
}
for i, tc := range testcases {
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
run := func(crlf bool, nonASCII bool) {
var name string
if crlf {
name = "crlf"
} else {
name = "lf"
}
if nonASCII {
name += "-nonascii"
} else {
name += "-ascii"
}
t.Run(name, func(t *testing.T) {
input := tc.input
if crlf {
input = strings.Replace(input, "\n", "\r\n", -1)
}
if nonASCII {
input = strings.Replace(input, "FROM", "FRÖM", -1)
}
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
if tc.wantOk {
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
}
if gotOK != tc.wantOk {
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
}
if gotLine != tc.wantLine {
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
}
if gotCol != tc.wantCol {
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
}
})
}
run(false, false)
run(true, false)
run(false, true)
run(true, true)
})
}
}
@@ -0,0 +1,41 @@
# pgx
This package is for [pgx/v5](https://pkg.go.dev/github.com/jackc/pgx/v5). A backend for the older [pgx/v4](https://pkg.go.dev/github.com/jackc/pgx/v4). is [also available](..).
`pgx5://user:password@host:port/dbname?query`
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
| `dbname` | `DatabaseName` | The name of the database to connect to |
| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. |
| `user` | | The user to sign in as |
| `password` | | The user's password |
| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) |
| `port` | | The port to bind to. (default is 5432) |
| `fallback_application_name` | | An application_name to fall back to if one isn't provided. |
| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. |
| `sslcert` | | Cert file location. The file must contain PEM encoded data. |
| `sslkey` | | Key file location. The file must contain PEM encoded data. |
| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. |
| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) |
## Upgrading from v1
1. Write down the current migration version from schema_migrations
1. `DROP TABLE schema_migrations`
2. Wrap your existing migrations in transactions ([BEGIN/COMMIT](https://www.postgresql.org/docs/current/static/transaction-iso.html)) if you use multiple statements within one migration.
3. Download and install the latest migrate version.
4. Force the current migration version with `migrate force <current_version>`.
## Multi-statement mode
In PostgreSQL running multiple SQL statements in one `Exec` executes them inside a transaction. Sometimes this
behavior is not desirable because some statements can be only run outside of transaction (e.g.
`CREATE INDEX CONCURRENTLY`). If you want to use `CREATE INDEX CONCURRENTLY` without activating multi-statement mode
you have to put such statements in a separate migration files.
@@ -0,0 +1,486 @@
//go:build go1.9
// +build go1.9
package pgx
import (
"context"
"database/sql"
"fmt"
"io"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
"go.uber.org/atomic"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/database/multistmt"
"github.com/hashicorp/go-multierror"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5/pgconn"
_ "github.com/jackc/pgx/v5/stdlib"
)
func init() {
db := Postgres{}
database.Register("pgx5", &db)
}
var (
multiStmtDelimiter = []byte(";")
DefaultMigrationsTable = "schema_migrations"
DefaultMultiStatementMaxSize = 10 * 1 << 20 // 10 MB
)
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoDatabaseName = fmt.Errorf("no database name")
ErrNoSchema = fmt.Errorf("no schema")
)
type Config struct {
MigrationsTable string
DatabaseName string
SchemaName string
migrationsSchemaName string
migrationsTableName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
type Postgres struct {
// Locking and unlocking need to use the same connection
conn *sql.Conn
db *sql.DB
isLocked atomic.Bool
// Open and WithInstance need to guarantee that config is never nil
config *Config
}
func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if err := instance.Ping(); err != nil {
return nil, err
}
if config.DatabaseName == "" {
query := `SELECT CURRENT_DATABASE()`
var databaseName string
if err := instance.QueryRow(query).Scan(&databaseName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(databaseName) == 0 {
return nil, ErrNoDatabaseName
}
config.DatabaseName = databaseName
}
if config.SchemaName == "" {
query := `SELECT CURRENT_SCHEMA()`
var schemaName string
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(schemaName) == 0 {
return nil, ErrNoSchema
}
config.SchemaName = schemaName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
config.migrationsSchemaName = config.SchemaName
config.migrationsTableName = config.MigrationsTable
if config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllStringSubmatch(config.MigrationsTable, -1)
config.migrationsTableName = result[len(result)-1][1]
if len(result) == 2 {
config.migrationsSchemaName = result[0][1]
} else if len(result) > 2 {
return nil, fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", config.MigrationsTable)
}
}
conn, err := instance.Conn(context.Background())
if err != nil {
return nil, err
}
px := &Postgres{
conn: conn,
db: instance,
config: config,
}
if err := px.ensureVersionTable(); err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
// Driver is registered as pgx, but connection string must use postgres schema
// when making actual connection
// i.e. pgx://user:password@host:port/db => postgres://user:password@host:port/db
purl.Scheme = "postgres"
db, err := sql.Open("pgx/v5", migrate.FilterCustomQuery(purl).String())
if err != nil {
return nil, err
}
migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}
statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
statementTimeout, err = strconv.Atoi(statementTimeoutString)
if err != nil {
return nil, err
}
}
multiStatementMaxSize := DefaultMultiStatementMaxSize
if s := purl.Query().Get("x-multi-statement-max-size"); len(s) > 0 {
multiStatementMaxSize, err = strconv.Atoi(s)
if err != nil {
return nil, err
}
if multiStatementMaxSize <= 0 {
multiStatementMaxSize = DefaultMultiStatementMaxSize
}
}
multiStatementEnabled := false
if s := purl.Query().Get("x-multi-statement"); len(s) > 0 {
multiStatementEnabled, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-multi-statement: %w", err)
}
}
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
})
if err != nil {
return nil, err
}
return px, nil
}
func (p *Postgres) Close() error {
connErr := p.conn.Close()
dbErr := p.db.Close()
if connErr != nil || dbErr != nil {
return fmt.Errorf("conn: %v, db: %v", connErr, dbErr)
}
return nil
}
// https://www.postgresql.org/docs/9.6/static/explicit-locking.html#ADVISORY-LOCKS
func (p *Postgres) Lock() error {
return database.CasRestoreOnErr(&p.isLocked, false, true, database.ErrLocked, func() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
// This will wait indefinitely until the lock can be acquired.
query := `SELECT pg_advisory_lock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)}
}
return nil
})
}
func (p *Postgres) Unlock() error {
return database.CasRestoreOnErr(&p.isLocked, true, false, database.ErrNotLocked, func() error {
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.migrationsSchemaName, p.config.migrationsTableName)
if err != nil {
return err
}
query := `SELECT pg_advisory_unlock($1)`
if _, err := p.conn.ExecContext(context.Background(), query, aid); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
})
}
func (p *Postgres) Run(migration io.Reader) error {
if p.config.MultiStatementEnabled {
var err error
if e := multistmt.Parse(migration, multiStmtDelimiter, p.config.MultiStatementMaxSize, func(m []byte) bool {
if err = p.runStatement(m); err != nil {
return false
}
return true
}); e != nil {
return e
}
return err
}
migr, err := io.ReadAll(migration)
if err != nil {
return err
}
return p.runStatement(migr)
}
func (p *Postgres) runStatement(statement []byte) error {
ctx := context.Background()
if p.config.StatementTimeout != 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, p.config.StatementTimeout)
defer cancel()
}
query := string(statement)
if strings.TrimSpace(query) == "" {
return nil
}
if _, err := p.conn.ExecContext(ctx, query); err != nil {
if pgErr, ok := err.(*pgconn.PgError); ok {
var line uint
var col uint
var lineColOK bool
line, col, lineColOK = computeLineFromPos(query, int(pgErr.Position))
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
if lineColOK {
message = fmt.Sprintf("%s (column %d)", message, col)
}
if pgErr.Detail != "" {
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
}
return database.Error{OrigErr: err, Err: message, Query: statement, Line: line}
}
return database.Error{OrigErr: err, Err: "migration failed", Query: statement}
}
return nil
}
func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) {
// replace crlf with lf
s = strings.Replace(s, "\r\n", "\n", -1)
// pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes
runes := []rune(s)
if pos > len(runes) {
return 0, 0, false
}
sel := runes[:pos]
line = uint(runesCount(sel, newLine) + 1)
col = uint(pos - 1 - runesLastIndex(sel, newLine))
return line, col, true
}
const newLine = '\n'
func runesCount(input []rune, target rune) int {
var count int
for _, r := range input {
if r == target {
count++
}
}
return count
}
func runesLastIndex(input []rune, target rune) int {
for i := len(input) - 1; i >= 0; i-- {
if input[i] == target {
return i
}
}
return -1
}
func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}
query := `TRUNCATE ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
// Also re-write the schema version for nil dirty versions to prevent
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
}
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
if err := tx.Commit(); err != nil {
return &database.Error{OrigErr: err, Err: "transaction commit failed"}
}
return nil
}
func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
return database.NilVersion, false, nil
case err != nil:
if e, ok := err.(*pgconn.PgError); ok {
if e.SQLState() == pgerrcode.UndefinedTable {
return database.NilVersion, false, nil
}
}
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
default:
return version, dirty, nil
}
}
func (p *Postgres) Drop() (err error) {
// select all tables in current schema
query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema()) AND table_type='BASE TABLE'`
tables, err := p.conn.QueryContext(context.Background(), query)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
defer func() {
if errClose := tables.Close(); errClose != nil {
err = multierror.Append(err, errClose)
}
}()
// delete one table after another
tableNames := make([]string, 0)
for tables.Next() {
var tableName string
if err := tables.Scan(&tableName); err != nil {
return err
}
if len(tableName) > 0 {
tableNames = append(tableNames, tableName)
}
}
if err := tables.Err(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
}
return nil
}
// ensureVersionTable checks if versions table exists and, if not, creates it.
// Note that this function locks the database, which deviates from the usual
// convention of "caller locks" in the Postgres type.
func (p *Postgres) ensureVersionTable() (err error) {
if err = p.Lock(); err != nil {
return err
}
defer func() {
if e := p.Unlock(); e != nil {
if err == nil {
err = e
} else {
err = multierror.Append(err, e)
}
}
}()
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_schema = $1 AND table_name = $2 LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.migrationsSchemaName, p.config.migrationsTableName)
var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if count == 1 {
return nil
}
query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.migrationsSchemaName) + `.` + quoteIdentifier(p.config.migrationsTableName) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
return nil
}
// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
@@ -0,0 +1,764 @@
package pgx
// error codes https://github.com/jackc/pgerrcode/blob/master/errcode.go
import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"
"github.com/golang-migrate/migrate/v4"
"github.com/dhui/dktest"
"github.com/golang-migrate/migrate/v4/database"
dt "github.com/golang-migrate/migrate/v4/database/testing"
"github.com/golang-migrate/migrate/v4/dktesting"
_ "github.com/golang-migrate/migrate/v4/source/file"
)
const (
pgPassword = "postgres"
)
var (
opts = dktest.Options{
Env: map[string]string{"POSTGRES_PASSWORD": pgPassword},
PortRequired: true, ReadyFunc: isReady}
// Supported versions: https://www.postgresql.org/support/versioning/
specs = []dktesting.ContainerSpec{
{ImageName: "postgres:9.5", Options: opts},
{ImageName: "postgres:9.6", Options: opts},
{ImageName: "postgres:10", Options: opts},
{ImageName: "postgres:11", Options: opts},
{ImageName: "postgres:12", Options: opts},
{ImageName: "postgres:13", Options: opts},
{ImageName: "postgres:14", Options: opts},
{ImageName: "postgres:15", Options: opts},
}
)
func pgConnectionString(host, port string, options ...string) string {
options = append(options, "sslmode=disable")
return fmt.Sprintf("postgres://postgres:%s@%s:%s/postgres?%s", pgPassword, host, port, strings.Join(options, "&"))
}
func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
ip, port, err := c.FirstPort()
if err != nil {
return false
}
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
return false
}
defer func() {
if err := db.Close(); err != nil {
log.Println("close error:", err)
}
}()
if err = db.PingContext(ctx); err != nil {
switch err {
case sqldriver.ErrBadConn, io.EOF:
return false
default:
log.Println(err)
}
return false
}
return true
}
func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}
func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
dt.Test(t, d, []byte("SELECT 1"))
})
}
func TestMigrate(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
m, err := migrate.NewWithDatabaseInstance("file://../examples/migrations", "pgx", d)
if err != nil {
t.Fatal(err)
}
dt.TestMigrate(t, m)
})
}
func TestMultipleStatements(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLE bar (bar text);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure second table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'bar' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestMultipleStatementsInMultiStatementMode(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-multi-statement=true")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE INDEX CONCURRENTLY idx_foo ON foo (foo);")); err != nil {
t.Fatalf("expected err to be nil, got %v", err)
}
// make sure created index exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM pg_indexes WHERE schemaname = (SELECT current_schema()) AND indexname = 'idx_foo')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table bar to exist")
}
})
}
func TestErrorParsing(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
`(foo text); CREATE TABLEE bar (bar text); (details: ERROR: syntax error at or near "TABLEE" (SQLSTATE 42601))`
if err := d.Run(strings.NewReader("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);")); err == nil {
t.Fatal("expected err but got nil")
} else if err.Error() != wantErr {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
})
}
func TestFilterCustomQuery(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port, "x-custom=foobar")
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
})
}
func TestWithSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create foobar schema
if err := d.Run(strings.NewReader("CREATE SCHEMA foobar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.SetVersion(1, false); err != nil {
t.Fatal(err)
}
// re-connect using that schema
d2, err := p.Open(pgConnectionString(ip, port, "search_path=foobar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
version, _, err := d2.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("expected NilVersion")
}
// now update version and compare
if err := d2.SetVersion(2, false); err != nil {
t.Fatal(err)
}
version, _, err = d2.Version()
if err != nil {
t.Fatal(err)
}
if version != 2 {
t.Fatal("expected version 2")
}
// meanwhile, the public schema still has the other version
version, _, err = d.Version()
if err != nil {
t.Fatal(err)
}
if version != 1 {
t.Fatal("expected version 2")
}
})
}
func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()
// create migrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// bad unquoted x-migrations-table parameter
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// too many quoted x-migrations-table parameters
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}
// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
// make sure migrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table migrate.schema_migrations to exist")
}
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
}
})
}
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
defer func() {
if d2 == nil {
return
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()
var e *database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}
func TestCheckBeforeCreateTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
// Check that opening the postgres connection returns NilVersion
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
})
// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
// revoke privileges
mustRun(t, d, []string{
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})
// re-connect using that schema
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
version, _, err := d3.Version()
if err != nil {
t.Fatal(err)
}
if version != database.NilVersion {
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
}
defer func() {
if err := d3.Close(); err != nil {
t.Fatal(err)
}
}()
})
}
func TestParallelSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()
// create foo and bar schemas
if err := d.Run(strings.NewReader("CREATE SCHEMA foo AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
if err := d.Run(strings.NewReader("CREATE SCHEMA bar AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}
// re-connect using that schemas
dfoo, err := p.Open(pgConnectionString(ip, port, "search_path=foo"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dfoo.Close(); err != nil {
t.Error(err)
}
}()
dbar, err := p.Open(pgConnectionString(ip, port, "search_path=bar"))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := dbar.Close(); err != nil {
t.Error(err)
}
}()
if err := dfoo.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Lock(); err != nil {
t.Fatal(err)
}
if err := dbar.Unlock(); err != nil {
t.Fatal(err)
}
if err := dfoo.Unlock(); err != nil {
t.Fatal(err)
}
})
}
func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
addr := pgConnectionString(ip, port)
p := &Postgres{}
d, err := p.Open(addr)
if err != nil {
t.Fatal(err)
}
dt.Test(t, d, []byte("SELECT 1"))
ps := d.(*Postgres)
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
err = ps.Lock()
if err != nil {
t.Fatal(err)
}
err = ps.Unlock()
if err != nil {
t.Fatal(err)
}
})
}
func TestWithInstance_Concurrent(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}
// The number of concurrent processes running WithInstance
const concurrency = 30
// We can instantiate a single database handle because it is
// actually a connection pool, and so, each of the below go
// routines will have a high probability of using a separate
// connection, which is something we want to exercise.
db, err := sql.Open("pgx", pgConnectionString(ip, port))
if err != nil {
t.Fatal(err)
}
defer func() {
if err := db.Close(); err != nil {
t.Error(err)
}
}()
db.SetMaxIdleConns(concurrency)
db.SetMaxOpenConns(concurrency)
var wg sync.WaitGroup
defer wg.Wait()
wg.Add(concurrency)
for i := 0; i < concurrency; i++ {
go func(i int) {
defer wg.Done()
_, err := WithInstance(db, &Config{})
if err != nil {
t.Errorf("process %d error: %s", i, err)
}
}(i)
}
})
}
func Test_computeLineFromPos(t *testing.T) {
testcases := []struct {
pos int
wantLine uint
wantCol uint
input string
wantOk bool
}{
{
15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists
},
{
16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line
},
{
25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error
},
{
27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines
},
{
10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo
},
{
11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line
},
{
17, 2, 8, "SELECT *\nFROM foo", true, // last character
},
{
18, 0, 0, "SELECT *\nFROM foo", false, // invalid position
},
}
for i, tc := range testcases {
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
run := func(crlf bool, nonASCII bool) {
var name string
if crlf {
name = "crlf"
} else {
name = "lf"
}
if nonASCII {
name += "-nonascii"
} else {
name += "-ascii"
}
t.Run(name, func(t *testing.T) {
input := tc.input
if crlf {
input = strings.Replace(input, "\n", "\r\n", -1)
}
if nonASCII {
input = strings.Replace(input, "FROM", "FRÖM", -1)
}
gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos)
if tc.wantOk {
t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input)
}
if gotOK != tc.wantOk {
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
}
if gotLine != tc.wantLine {
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
}
if gotCol != tc.wantCol {
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
}
})
}
run(false, false)
run(true, false)
run(false, true)
run(true, true)
})
}
}