whatcanGOwrong

This commit is contained in:
2024-09-19 21:38:24 -04:00
commit d0ae4d841d
17908 changed files with 4096831 additions and 0 deletions
@@ -0,0 +1,8 @@
//go:build aws_s3
// +build aws_s3
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/aws_s3"
)
@@ -0,0 +1,8 @@
//go:build bitbucket
// +build bitbucket
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/bitbucket"
)
@@ -0,0 +1,8 @@
//go:build cassandra
// +build cassandra
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/cassandra"
)
@@ -0,0 +1,9 @@
//go:build clickhouse
// +build clickhouse
package cli
import (
_ "github.com/ClickHouse/clickhouse-go"
_ "github.com/golang-migrate/migrate/v4/database/clickhouse"
)
@@ -0,0 +1,8 @@
//go:build cockroachdb
// +build cockroachdb
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/cockroachdb"
)
@@ -0,0 +1,8 @@
//go:build firebird
// +build firebird
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/firebird"
)
@@ -0,0 +1,8 @@
//go:build github
// +build github
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/github"
)
@@ -0,0 +1,8 @@
//go:build github
// +build github
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/github_ee"
)
@@ -0,0 +1,8 @@
//go:build gitlab
// +build gitlab
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/gitlab"
)
@@ -0,0 +1,8 @@
//go:build go_bindata
// +build go_bindata
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/go_bindata"
)
@@ -0,0 +1,8 @@
//go:build godoc_vfs
// +build godoc_vfs
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/godoc_vfs"
)
@@ -0,0 +1,8 @@
//go:build google_cloud_storage
// +build google_cloud_storage
package cli
import (
_ "github.com/golang-migrate/migrate/v4/source/google_cloud_storage"
)
@@ -0,0 +1,8 @@
//go:build mongodb
// +build mongodb
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/mongodb"
)
@@ -0,0 +1,8 @@
//go:build mysql
// +build mysql
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/mysql"
)
@@ -0,0 +1,8 @@
//go:build neo4j
// +build neo4j
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/neo4j"
)
@@ -0,0 +1,8 @@
//go:build pgx
// +build pgx
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/pgx"
)
@@ -0,0 +1,8 @@
//go:build pgx5
// +build pgx5
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/pgx/v5"
)
@@ -0,0 +1,8 @@
//go:build postgres
// +build postgres
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/postgres"
)
@@ -0,0 +1,8 @@
//go:build ql
// +build ql
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/ql"
)
@@ -0,0 +1,8 @@
//go:build redshift
// +build redshift
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/redshift"
)
@@ -0,0 +1,8 @@
//go:build rqlite
// +build rqlite
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/rqlite"
)
@@ -0,0 +1,8 @@
//go:build snowflake
// +build snowflake
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/snowflake"
)
@@ -0,0 +1,8 @@
//go:build spanner
// +build spanner
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/spanner"
)
@@ -0,0 +1,8 @@
//go:build sqlcipher
// +build sqlcipher
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/sqlcipher"
)
@@ -0,0 +1,8 @@
//go:build sqlite
// +build sqlite
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/sqlite"
)
@@ -0,0 +1,8 @@
//go:build sqlite3
// +build sqlite3
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/sqlite3"
)
@@ -0,0 +1,8 @@
//go:build sqlserver
// +build sqlserver
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/sqlserver"
)
@@ -0,0 +1,8 @@
//go:build yugabytedb
// +build yugabytedb
package cli
import (
_ "github.com/golang-migrate/migrate/v4/database/yugabytedb"
)
@@ -0,0 +1,248 @@
package cli
import (
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/stub" // TODO remove again
_ "github.com/golang-migrate/migrate/v4/source/file"
)
var (
errInvalidSequenceWidth = errors.New("Digits must be positive")
errIncompatibleSeqAndFormat = errors.New("The seq and format options are mutually exclusive")
errInvalidTimeFormat = errors.New("Time format may not be empty")
)
func nextSeqVersion(matches []string, seqDigits int) (string, error) {
if seqDigits <= 0 {
return "", errInvalidSequenceWidth
}
nextSeq := uint64(1)
if len(matches) > 0 {
filename := matches[len(matches)-1]
matchSeqStr := filepath.Base(filename)
idx := strings.Index(matchSeqStr, "_")
if idx < 1 { // Using 1 instead of 0 since there should be at least 1 digit
return "", fmt.Errorf("Malformed migration filename: %s", filename)
}
var err error
matchSeqStr = matchSeqStr[0:idx]
nextSeq, err = strconv.ParseUint(matchSeqStr, 10, 64)
if err != nil {
return "", err
}
nextSeq++
}
version := fmt.Sprintf("%0[2]*[1]d", nextSeq, seqDigits)
if len(version) > seqDigits {
return "", fmt.Errorf("Next sequence number %s too large. At most %d digits are allowed", version, seqDigits)
}
return version, nil
}
func timeVersion(startTime time.Time, format string) (version string, err error) {
switch format {
case "":
err = errInvalidTimeFormat
case "unix":
version = strconv.FormatInt(startTime.Unix(), 10)
case "unixNano":
version = strconv.FormatInt(startTime.UnixNano(), 10)
default:
version = startTime.Format(format)
}
return
}
// createCmd (meant to be called via a CLI command) creates a new migration
func createCmd(dir string, startTime time.Time, format string, name string, ext string, seq bool, seqDigits int, print bool) error {
if seq && format != defaultTimeFormat {
return errIncompatibleSeqAndFormat
}
var version string
var err error
dir = filepath.Clean(dir)
ext = "." + strings.TrimPrefix(ext, ".")
if seq {
matches, err := filepath.Glob(filepath.Join(dir, "*"+ext))
if err != nil {
return err
}
version, err = nextSeqVersion(matches, seqDigits)
if err != nil {
return err
}
} else {
version, err = timeVersion(startTime, format)
if err != nil {
return err
}
}
versionGlob := filepath.Join(dir, version+"_*"+ext)
matches, err := filepath.Glob(versionGlob)
if err != nil {
return err
}
if len(matches) > 0 {
return fmt.Errorf("duplicate migration version: %s", version)
}
if err = os.MkdirAll(dir, os.ModePerm); err != nil {
return err
}
for _, direction := range []string{"up", "down"} {
basename := fmt.Sprintf("%s_%s.%s%s", version, name, direction, ext)
filename := filepath.Join(dir, basename)
if err = createFile(filename); err != nil {
return err
}
if print {
absPath, _ := filepath.Abs(filename)
log.Println(absPath)
}
}
return nil
}
func createFile(filename string) error {
// create exclusive (fails if file already exists)
// os.Create() specifies 0666 as the FileMode, so we're doing the same
f, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0666)
if err != nil {
return err
}
return f.Close()
}
func gotoCmd(m *migrate.Migrate, v uint) error {
if err := m.Migrate(v); err != nil {
if err != migrate.ErrNoChange {
return err
}
log.Println(err)
}
return nil
}
func upCmd(m *migrate.Migrate, limit int) error {
if limit >= 0 {
if err := m.Steps(limit); err != nil {
if err != migrate.ErrNoChange {
return err
}
log.Println(err)
}
} else {
if err := m.Up(); err != nil {
if err != migrate.ErrNoChange {
return err
}
log.Println(err)
}
}
return nil
}
func downCmd(m *migrate.Migrate, limit int) error {
if limit >= 0 {
if err := m.Steps(-limit); err != nil {
if err != migrate.ErrNoChange {
return err
}
log.Println(err)
}
} else {
if err := m.Down(); err != nil {
if err != migrate.ErrNoChange {
return err
}
log.Println(err)
}
}
return nil
}
func dropCmd(m *migrate.Migrate) error {
if err := m.Drop(); err != nil {
return err
}
return nil
}
func forceCmd(m *migrate.Migrate, v int) error {
if err := m.Force(v); err != nil {
return err
}
return nil
}
func versionCmd(m *migrate.Migrate) error {
v, dirty, err := m.Version()
if err != nil {
return err
}
if dirty {
log.Printf("%v (dirty)\n", v)
} else {
log.Println(v)
}
return nil
}
// numDownMigrationsFromArgs returns an int for number of migrations to apply
// and a bool indicating if we need a confirm before applying
func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) {
if applyAll {
if len(args) > 0 {
return 0, false, errors.New("-all cannot be used with other arguments")
}
return -1, false, nil
}
switch len(args) {
case 0:
return -1, true, nil
case 1:
downValue := args[0]
n, err := strconv.ParseUint(downValue, 10, 64)
if err != nil {
return 0, false, errors.New("can't read limit argument N")
}
return int(n), false, nil
default:
return 0, false, errors.New("too many arguments")
}
}
@@ -0,0 +1,292 @@
package cli
import (
"errors"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
type CreateCmdSuite struct {
suite.Suite
}
func TestCreateCmdSuite(t *testing.T) {
suite.Run(t, &CreateCmdSuite{})
}
func (s *CreateCmdSuite) mustCreateTempDir() string {
tmpDir, err := os.MkdirTemp("", "migrate_")
if err != nil {
s.FailNow(err.Error())
}
return tmpDir
}
func (s *CreateCmdSuite) mustCreateDir(dir string) {
if err := os.MkdirAll(dir, 0755); err != nil {
s.FailNow(err.Error())
}
}
func (s *CreateCmdSuite) mustRemoveDir(dir string) {
if err := os.RemoveAll(dir); err != nil {
s.FailNow(err.Error())
}
}
func (s *CreateCmdSuite) mustWriteFile(dir, file, body string) {
if err := os.WriteFile(filepath.Join(dir, file), []byte(body), 0644); err != nil {
s.FailNow(err.Error())
}
}
func (s *CreateCmdSuite) mustGetwd() string {
cwd, err := os.Getwd()
if err != nil {
s.FailNow(err.Error())
}
return cwd
}
func (s *CreateCmdSuite) mustChdir(dir string) {
if err := os.Chdir(dir); err != nil {
s.FailNow(err.Error())
}
}
func (s *CreateCmdSuite) assertEmptyDir(dir string) bool {
fis, err := os.ReadDir(dir)
if err != nil {
return s.Fail(err.Error())
}
return s.Empty(fis)
}
func (s *CreateCmdSuite) TestNextSeqVersion() {
cases := []struct {
tid string
matches []string
seqDigits int
expected string
expectedErr error
}{
{"Bad digits", []string{}, 0, "", errInvalidSequenceWidth},
{"Single digit initialize", []string{}, 1, "1", nil},
{"Single digit malformed", []string{"bad"}, 1, "", errors.New("Malformed migration filename: bad")},
{"Single digit no int", []string{"bad_bad"}, 1, "", errors.New(`strconv.ParseUint: parsing "bad": invalid syntax`)},
{"Single digit negative seq", []string{"-5_test"}, 1, "", errors.New(`strconv.ParseUint: parsing "-5": invalid syntax`)},
{"Single digit increment", []string{"3_test", "4_test"}, 1, "5", nil},
{"Single digit overflow", []string{"9_test"}, 1, "", errors.New("Next sequence number 10 too large. At most 1 digits are allowed")},
{"Zero-pad initialize", []string{}, 6, "000001", nil},
{"Zero-pad malformed", []string{"bad"}, 6, "", errors.New("Malformed migration filename: bad")},
{"Zero-pad no int", []string{"bad_bad"}, 6, "", errors.New(`strconv.ParseUint: parsing "bad": invalid syntax`)},
{"Zero-pad negative seq", []string{"-000005_test"}, 6, "", errors.New(`strconv.ParseUint: parsing "-000005": invalid syntax`)},
{"Zero-pad increment", []string{"000003_test", "000004_test"}, 6, "000005", nil},
{"Zero-pad overflow", []string{"999999_test"}, 6, "", errors.New("Next sequence number 1000000 too large. At most 6 digits are allowed")},
{"dir absolute path", []string{"/migrationDir/000001_test"}, 6, "000002", nil},
{"dir relative path", []string{"migrationDir/000001_test"}, 6, "000002", nil},
{"dir dot prefix", []string{"./migrationDir/000001_test"}, 6, "000002", nil},
{"dir parent prefix", []string{"../migrationDir/000001_test"}, 6, "000002", nil},
{"dir no prefix", []string{"000001_test"}, 6, "000002", nil},
}
for _, c := range cases {
s.Run(c.tid, func() {
v, err := nextSeqVersion(c.matches, c.seqDigits)
if c.expectedErr != nil {
s.EqualError(err, c.expectedErr.Error())
} else {
s.NoError(err)
s.Equal(c.expected, v)
}
})
}
}
func (s *CreateCmdSuite) TestTimeVersion() {
ts := time.Date(2000, 12, 25, 00, 01, 02, 3456789, time.UTC)
tsUnixStr := strconv.FormatInt(ts.Unix(), 10)
tsUnixNanoStr := strconv.FormatInt(ts.UnixNano(), 10)
cases := []struct {
tid string
time time.Time
format string
expected string
expectedErr error
}{
{"Bad format", ts, "", "", errInvalidTimeFormat},
{"unix", ts, "unix", tsUnixStr, nil},
{"unixNano", ts, "unixNano", tsUnixNanoStr, nil},
{"custom ymthms", ts, "20060102150405", "20001225000102", nil},
}
for _, c := range cases {
s.Run(c.tid, func() {
v, err := timeVersion(c.time, c.format)
if c.expectedErr != nil {
s.EqualError(err, c.expectedErr.Error())
} else {
s.NoError(err)
s.Equal(c.expected, v)
}
})
}
}
// TestCreateCmd tests function createCmd.
//
// For each test case, it creates a temp dir as "sandbox" (called `baseDir`) and
// all path manipulations are relative to `baseDir`.
func (s *CreateCmdSuite) TestCreateCmd() {
ts := time.Date(2000, 12, 25, 00, 01, 02, 3456789, time.UTC)
tsUnixStr := strconv.FormatInt(ts.Unix(), 10)
tsUnixNanoStr := strconv.FormatInt(ts.UnixNano(), 10)
testCwd := s.mustGetwd()
cases := []struct {
tid string
existingDirs []string // directory paths to create before test. relative to baseDir.
cwd string // path to chdir to before test. relative to baseDir.
existingFiles []string // file paths created before test. relative to baseDir.
expectedFiles []string // file paths expected to exist after test. paths relative to baseDir.
expectedErr error
dir string // `dir` parameter. if absolute path, will be converted to baseDir/dir.
startTime time.Time
format string
seq bool
seqDigits int
ext string
name string
}{
{"seq and format", nil, "", nil, nil, errIncompatibleSeqAndFormat, ".", ts, "unix", true, 4, "sql", "name"},
{"seq init dir dot", nil, "", nil, []string{"0001_name.up.sql", "0001_name.down.sql"}, nil, ".", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir dot trailing slash", nil, "", nil, []string{"0001_name.up.sql", "0001_name.down.sql"}, nil, "./", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir double dot", []string{"subdir"}, "subdir", nil, []string{"0001_name.up.sql", "0001_name.down.sql"}, nil, "..", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir double dot trailing slash", []string{"subdir"}, "subdir", nil, []string{"0001_name.up.sql", "0001_name.down.sql"}, nil, "../", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir absolute", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "/subdir", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir absolute trailing slash", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "/subdir/", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir relative", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "subdir", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir relative trailing slash", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "subdir/", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir dot relative", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "./subdir", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir dot relative trailing slash", []string{"subdir"}, "", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "./subdir/", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir double dot relative", []string{"subdir"}, "subdir", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "../subdir", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir double dot relative trailing slash", []string{"subdir"}, "subdir", nil, []string{"subdir/0001_name.up.sql", "subdir/0001_name.down.sql"}, nil, "../subdir/", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq init dir maze", []string{"subdir"}, "subdir", nil, []string{"0001_name.up.sql", "0001_name.down.sql"}, nil, "..//subdir/./.././/subdir/..", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq width invalid", nil, "", nil, nil, errInvalidSequenceWidth, ".", ts, defaultTimeFormat, true, 0, "sql", "name"},
{"seq malformed", nil, "", []string{"bad.sql"}, []string{"bad.sql"}, errors.New("Malformed migration filename: bad.sql"), ".", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq not int", nil, "", []string{"bad_bad.sql"}, []string{"bad_bad.sql"}, errors.New(`strconv.ParseUint: parsing "bad": invalid syntax`), ".", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq negative", nil, "", []string{"-5_negative.sql"}, []string{"-5_negative.sql"}, errors.New(`strconv.ParseUint: parsing "-5": invalid syntax`), ".", ts, defaultTimeFormat, true, 4, "sql", "name"},
{"seq increment", nil, "", []string{"3_three.sql", "4_four.sql"}, []string{"3_three.sql", "4_four.sql", "0005_five.up.sql", "0005_five.down.sql"}, nil, ".", ts, defaultTimeFormat, true, 4, "sql", "five"},
{"seq overflow", nil, "", []string{"9_nine.sql"}, []string{"9_nine.sql"}, errors.New(`Next sequence number 10 too large. At most 1 digits are allowed`), ".", ts, defaultTimeFormat, true, 1, "sql", "ten"},
{"time empty format", nil, "", nil, nil, errInvalidTimeFormat, ".", ts, "", false, 0, "sql", "name"},
{"time unix", nil, "", nil, []string{tsUnixStr + "_name.up.sql", tsUnixStr + "_name.down.sql"}, nil, ".", ts, "unix", false, 0, "sql", "name"},
{"time unixNano", nil, "", nil, []string{tsUnixNanoStr + "_name.up.sql", tsUnixNanoStr + "_name.down.sql"}, nil, ".", ts, "unixNano", false, 0, "sql", "name"},
{"time custom format", nil, "", nil, []string{"20001225000102_name.up.sql", "20001225000102_name.down.sql"}, nil, ".", ts, "20060102150405", false, 0, "sql", "name"},
{"time version collision", nil, "", []string{"20001225_name.up.sql", "20001225_name.down.sql"}, []string{"20001225_name.up.sql", "20001225_name.down.sql"}, errors.New("duplicate migration version: 20001225"), ".", ts, "20060102", false, 0, "sql", "name"},
{"dir invalid", nil, "", []string{"file"}, []string{"file"}, errors.New("mkdir 'test: this is invalid dir name'\x00: invalid argument"), "'test: this is invalid dir name'\000", ts, "unix", false, 0, "sql", "name"},
}
for _, c := range cases {
s.Run(c.tid, func() {
baseDir := s.mustCreateTempDir()
for _, d := range c.existingDirs {
s.mustCreateDir(filepath.Join(baseDir, d))
}
cwd := baseDir
if c.cwd != "" {
cwd = filepath.Join(baseDir, c.cwd)
}
s.mustChdir(cwd)
for _, f := range c.existingFiles {
s.mustWriteFile(baseDir, f, "")
}
dir := c.dir
dir = filepath.ToSlash(dir)
volName := filepath.VolumeName(baseDir)
// Windows specific, can not recognize \subdir as abs path
isWindowsAbsPathNoLetter := strings.HasPrefix(dir, "/") && volName != ""
isRealAbsPath := filepath.IsAbs(dir)
if isWindowsAbsPathNoLetter || isRealAbsPath {
dir = filepath.Join(baseDir, dir)
}
err := createCmd(dir, c.startTime, c.format, c.name, c.ext, c.seq, c.seqDigits, false)
if c.expectedErr != nil {
s.EqualError(err, c.expectedErr.Error())
} else {
s.NoError(err)
}
if len(c.expectedFiles) == 0 {
s.assertEmptyDir(baseDir)
} else {
for _, f := range c.expectedFiles {
s.FileExists(filepath.Join(baseDir, f))
}
}
s.mustChdir(testCwd)
s.mustRemoveDir(baseDir)
})
}
}
func TestNumDownFromArgs(t *testing.T) {
cases := []struct {
name string
args []string
applyAll bool
expectedNeedConfirm bool
expectedNum int
expectedErrStr string
}{
{"no args", []string{}, false, true, -1, ""},
{"down all", []string{}, true, false, -1, ""},
{"down 5", []string{"5"}, false, false, 5, ""},
{"down N", []string{"N"}, false, false, 0, "can't read limit argument N"},
{"extra arg after -all", []string{"5"}, true, false, 0, "-all cannot be used with other arguments"},
{"extra arg before -all", []string{"5", "-all"}, false, false, 0, "too many arguments"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
num, needsConfirm, err := numDownMigrationsFromArgs(c.applyAll, c.args)
if needsConfirm != c.expectedNeedConfirm {
t.Errorf("Incorrect needsConfirm was: %v wanted %v", needsConfirm, c.expectedNeedConfirm)
}
if num != c.expectedNum {
t.Errorf("Incorrect num was: %v wanted %v", num, c.expectedNum)
}
if err != nil {
if err.Error() != c.expectedErrStr {
t.Error("Incorrect error: " + err.Error() + " != " + c.expectedErrStr)
}
} else if c.expectedErrStr != "" {
t.Error("Expected error: " + c.expectedErrStr + " but got nil instead")
}
})
}
}
@@ -0,0 +1,44 @@
package cli
import (
"fmt"
logpkg "log"
"os"
)
// Log represents the logger
type Log struct {
verbose bool
}
// Printf prints out formatted string into a log
func (l *Log) Printf(format string, v ...interface{}) {
if l.verbose {
logpkg.Printf(format, v...)
} else {
fmt.Fprintf(os.Stderr, format, v...)
}
}
// Println prints out args into a log
func (l *Log) Println(args ...interface{}) {
if l.verbose {
logpkg.Println(args...)
} else {
fmt.Fprintln(os.Stderr, args...)
}
}
// Verbose shows if verbose print enabled
func (l *Log) Verbose() bool {
return l.verbose
}
func (l *Log) fatal(args ...interface{}) {
l.Println(args...)
os.Exit(1)
}
func (l *Log) fatalErr(err error) {
l.fatal("error:", err)
}
@@ -0,0 +1,377 @@
package cli
import (
"flag"
"fmt"
"os"
"os/signal"
"strconv"
"strings"
"syscall"
"time"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database"
"github.com/golang-migrate/migrate/v4/source"
)
const (
defaultTimeFormat = "20060102150405"
defaultTimezone = "UTC"
createUsage = `create [-ext E] [-dir D] [-seq] [-digits N] [-format] [-tz] NAME
Create a set of timestamped up/down migrations titled NAME, in directory D with extension E.
Use -seq option to generate sequential up/down migrations with N digits.
Use -format option to specify a Go time format string. Note: migrations with the same time cause "duplicate migration version" error.
Use -tz option to specify the timezone that will be used when generating non-sequential migrations (defaults: UTC).
`
gotoUsage = `goto V Migrate to version V`
upUsage = `up [N] Apply all or N up migrations`
downUsage = `down [N] [-all] Apply all or N down migrations
Use -all to apply all down migrations`
dropUsage = `drop [-f] Drop everything inside database
Use -f to bypass confirmation`
forceUsage = `force V Set version V but don't run migration (ignores dirty state)`
)
func handleSubCmdHelp(help bool, usage string, flagSet *flag.FlagSet) {
if help {
fmt.Fprintln(os.Stderr, usage)
flagSet.PrintDefaults()
os.Exit(0)
}
}
func newFlagSetWithHelp(name string) (*flag.FlagSet, *bool) {
flagSet := flag.NewFlagSet(name, flag.ExitOnError)
helpPtr := flagSet.Bool("help", false, "Print help information")
return flagSet, helpPtr
}
// set main log
var log = &Log{}
func printUsageAndExit() {
flag.Usage()
// If a command is not found we exit with a status 2 to match the behavior
// of flag.Parse() with flag.ExitOnError when parsing an invalid flag.
os.Exit(2)
}
// Main function of a cli application. It is public for backwards compatibility with `cli` package
func Main(version string) {
helpPtr := flag.Bool("help", false, "")
versionPtr := flag.Bool("version", false, "")
verbosePtr := flag.Bool("verbose", false, "")
prefetchPtr := flag.Uint("prefetch", 10, "")
lockTimeoutPtr := flag.Uint("lock-timeout", 15, "")
pathPtr := flag.String("path", "", "")
databasePtr := flag.String("database", "", "")
sourcePtr := flag.String("source", "", "")
flag.Usage = func() {
fmt.Fprintf(os.Stderr,
`Usage: migrate OPTIONS COMMAND [arg...]
migrate [ -version | -help ]
Options:
-source Location of the migrations (driver://url)
-path Shorthand for -source=file://path
-database Run migrations against this database (driver://url)
-prefetch N Number of migrations to load in advance before executing (default 10)
-lock-timeout N Allow N seconds to acquire database lock (default 15)
-verbose Print verbose logging
-version Print version
-help Print usage
Commands:
%s
%s
%s
%s
%s
%s
version Print current migration version
Source drivers: `+strings.Join(source.List(), ", ")+`
Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoUsage, upUsage, downUsage, dropUsage, forceUsage)
}
flag.Parse()
// initialize logger
log.verbose = *verbosePtr
// show cli version
if *versionPtr {
fmt.Fprintln(os.Stderr, version)
os.Exit(0)
}
// show help
if *helpPtr {
flag.Usage()
os.Exit(0)
}
// translate -path into -source if given
if *sourcePtr == "" && *pathPtr != "" {
*sourcePtr = fmt.Sprintf("file://%v", *pathPtr)
}
// initialize migrate
// don't catch migraterErr here and let each command decide
// how it wants to handle the error
migrater, migraterErr := migrate.New(*sourcePtr, *databasePtr)
defer func() {
if migraterErr == nil {
if _, err := migrater.Close(); err != nil {
log.Println(err)
}
}
}()
if migraterErr == nil {
migrater.Log = log
migrater.PrefetchMigrations = *prefetchPtr
migrater.LockTimeout = time.Duration(int64(*lockTimeoutPtr)) * time.Second
// handle Ctrl+c
signals := make(chan os.Signal, 1)
signal.Notify(signals, syscall.SIGINT)
go func() {
for range signals {
log.Println("Stopping after this running migration ...")
migrater.GracefulStop <- true
return
}
}()
}
startTime := time.Now()
if len(flag.Args()) < 1 {
printUsageAndExit()
}
args := flag.Args()[1:]
switch flag.Arg(0) {
case "create":
seq := false
seqDigits := 6
createFlagSet, help := newFlagSetWithHelp("create")
extPtr := createFlagSet.String("ext", "", "File extension")
dirPtr := createFlagSet.String("dir", "", "Directory to place file in (default: current working directory)")
formatPtr := createFlagSet.String("format", defaultTimeFormat, `The Go time format string to use. If the string "unix" or "unixNano" is specified, then the seconds or nanoseconds since January 1, 1970 UTC respectively will be used. Caution, due to the behavior of time.Time.Format(), invalid format strings will not error`)
timezoneName := createFlagSet.String("tz", defaultTimezone, `The timezone that will be used for generating timestamps (default: utc)`)
createFlagSet.BoolVar(&seq, "seq", seq, "Use sequential numbers instead of timestamps (default: false)")
createFlagSet.IntVar(&seqDigits, "digits", seqDigits, "The number of digits to use in sequences (default: 6)")
if err := createFlagSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*help, createUsage, createFlagSet)
if createFlagSet.NArg() == 0 {
log.fatal("error: please specify name")
}
name := createFlagSet.Arg(0)
if *extPtr == "" {
log.fatal("error: -ext flag must be specified")
}
timezone, err := time.LoadLocation(*timezoneName)
if err != nil {
log.fatal(err)
}
if err := createCmd(*dirPtr, startTime.In(timezone), *formatPtr, name, *extPtr, seq, seqDigits, true); err != nil {
log.fatalErr(err)
}
case "goto":
gotoSet, helpPtr := newFlagSetWithHelp("goto")
if err := gotoSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*helpPtr, gotoUsage, gotoSet)
if migraterErr != nil {
log.fatalErr(migraterErr)
}
if gotoSet.NArg() == 0 {
log.fatal("error: please specify version argument V")
}
v, err := strconv.ParseUint(gotoSet.Arg(0), 10, 64)
if err != nil {
log.fatal("error: can't read version argument V")
}
if err := gotoCmd(migrater, uint(v)); err != nil {
log.fatalErr(err)
}
if log.verbose {
log.Println("Finished after", time.Since(startTime))
}
case "up":
upSet, helpPtr := newFlagSetWithHelp("up")
if err := upSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*helpPtr, upUsage, upSet)
if migraterErr != nil {
log.fatalErr(migraterErr)
}
limit := -1
if upSet.NArg() > 0 {
n, err := strconv.ParseUint(upSet.Arg(0), 10, 64)
if err != nil {
log.fatal("error: can't read limit argument N")
}
limit = int(n)
}
if err := upCmd(migrater, limit); err != nil {
log.fatalErr(err)
}
if log.verbose {
log.Println("Finished after", time.Since(startTime))
}
case "down":
downFlagSet, helpPtr := newFlagSetWithHelp("down")
applyAll := downFlagSet.Bool("all", false, "Apply all down migrations")
if err := downFlagSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*helpPtr, downUsage, downFlagSet)
if migraterErr != nil {
log.fatalErr(migraterErr)
}
downArgs := downFlagSet.Args()
num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs)
if err != nil {
log.fatalErr(err)
}
if needsConfirm {
log.Println("Are you sure you want to apply all down migrations? [y/N]")
var response string
fmt.Scanln(&response)
response = strings.ToLower(strings.TrimSpace(response))
if response == "y" {
log.Println("Applying all down migrations")
} else {
log.fatal("Not applying all down migrations")
}
}
if err := downCmd(migrater, num); err != nil {
log.fatalErr(err)
}
if log.verbose {
log.Println("Finished after", time.Since(startTime))
}
case "drop":
dropFlagSet, help := newFlagSetWithHelp("drop")
forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt")
if err := dropFlagSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*help, dropUsage, dropFlagSet)
if !*forceDrop {
log.Println("Are you sure you want to drop the entire database schema? [y/N]")
var response string
fmt.Scanln(&response)
response = strings.ToLower(strings.TrimSpace(response))
if response == "y" {
log.Println("Dropping the entire database schema")
} else {
log.fatal("Aborted dropping the entire database schema")
}
}
if migraterErr != nil {
log.fatalErr(migraterErr)
}
if err := dropCmd(migrater); err != nil {
log.fatalErr(err)
}
if log.verbose {
log.Println("Finished after", time.Since(startTime))
}
case "force":
forceSet, helpPtr := newFlagSetWithHelp("force")
if err := forceSet.Parse(args); err != nil {
log.fatalErr(err)
}
handleSubCmdHelp(*helpPtr, forceUsage, forceSet)
if migraterErr != nil {
log.fatalErr(migraterErr)
}
if forceSet.NArg() == 0 {
log.fatal("error: please specify version argument V")
}
v, err := strconv.ParseInt(forceSet.Arg(0), 10, 64)
if err != nil {
log.fatal("error: can't read version argument V")
}
if v < -1 {
log.fatal("error: argument V must be >= -1")
}
if err := forceCmd(migrater, int(v)); err != nil {
log.fatalErr(err)
}
if log.verbose {
log.Println("Finished after", time.Since(startTime))
}
case "version":
if migraterErr != nil {
log.fatalErr(migraterErr)
}
if err := versionCmd(migrater); err != nil {
log.fatalErr(err)
}
default:
printUsageAndExit()
}
}
@@ -0,0 +1,25 @@
package url
import (
"errors"
"strings"
)
var errNoScheme = errors.New("no scheme")
var errEmptyURL = errors.New("URL cannot be empty")
// schemeFromURL returns the scheme from a URL string
func SchemeFromURL(url string) (string, error) {
if url == "" {
return "", errEmptyURL
}
i := strings.Index(url, ":")
// No : or : is the first character.
if i < 1 {
return "", errNoScheme
}
return url[0:i], nil
}
@@ -0,0 +1,48 @@
package url
import (
"testing"
)
func TestSchemeFromUrl(t *testing.T) {
cases := []struct {
name string
urlStr string
expected string
expectErr error
}{
{
name: "Simple",
urlStr: "protocol://path",
expected: "protocol",
},
{
// See issue #264
name: "MySQLWithPort",
urlStr: "mysql://user:pass@tcp(host:1337)/db",
expected: "mysql",
},
{
name: "Empty",
urlStr: "",
expectErr: errEmptyURL,
},
{
name: "NoScheme",
urlStr: "hello",
expectErr: errNoScheme,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s, err := SchemeFromURL(tc.urlStr)
if err != tc.expectErr {
t.Fatalf("expected %q, but received %q", tc.expectErr, err)
}
if s != tc.expected {
t.Fatalf("expected %q, but received %q", tc.expected, s)
}
})
}
}