193 lines
4.1 KiB
Go
193 lines
4.1 KiB
Go
package db
|
|
|
|
import "database/sql"
|
|
import "errors"
|
|
import "fmt"
|
|
import "os"
|
|
import "path/filepath"
|
|
import "sort"
|
|
import "strings"
|
|
|
|
// Store wraps a database/sql connection pool and carries the schema-migration
// helpers. Open builds one after verifying the connection with a ping.
type Store struct {
	// DB is the underlying connection pool; Close releases it.
	DB *sql.DB
}
|
|
|
|
func Open(driver, dsn string) (*Store, error) {
|
|
var db *sql.DB
|
|
var drv string
|
|
var err error
|
|
drv = driverName(driver)
|
|
db, err = sql.Open(drv, dsn)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
err = db.Ping()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if drv == "sqlite" {
|
|
_, err = db.Exec(`PRAGMA busy_timeout = 10000`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
_, _ = db.Exec(`PRAGMA journal_mode = WAL`)
|
|
}
|
|
return &Store{DB: db}, nil
|
|
}
|
|
|
|
// driverName maps a user-supplied driver name onto the name to pass to
// sql.Open. Matching ignores case and surrounding whitespace:
//
//	"sqlite", "sqlite3"        -> "sqlite"
//	"postgres", "postgresql"   -> "postgres"
//	"mysql"                    -> "mysql"
//
// Any other name is returned with surrounding whitespace trimmed but its case
// preserved, because names registered with sql.Register are case-sensitive.
func driverName(driver string) string {
	trimmed := strings.TrimSpace(driver)
	switch strings.ToLower(trimmed) {
	case "sqlite", "sqlite3":
		return "sqlite"
	case "postgres", "postgresql":
		return "postgres"
	case "mysql":
		return "mysql"
	default:
		return trimmed
	}
}
|
|
|
|
func (s *Store) Close() error {
|
|
return s.DB.Close()
|
|
}
|
|
|
|
func (s *Store) ApplyMigrations(dir string) error {
|
|
var entries []os.DirEntry
|
|
var err error
|
|
var files []string
|
|
var i int
|
|
var e os.DirEntry
|
|
var name string
|
|
var base string
|
|
var version string
|
|
var applied bool
|
|
var content []byte
|
|
entries, err = os.ReadDir(dir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i = 0; i < len(entries); i++ {
|
|
e = entries[i]
|
|
if e.IsDir() {
|
|
continue
|
|
}
|
|
name = e.Name()
|
|
if strings.HasSuffix(name, ".sql") {
|
|
files = append(files, filepath.Join(dir, name))
|
|
}
|
|
}
|
|
sort.Strings(files)
|
|
if len(files) == 0 {
|
|
return errors.New("no migration files found")
|
|
}
|
|
_, err = s.DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for i = 0; i < len(files); i++ {
|
|
base = filepath.Base(files[i])
|
|
version = strings.TrimSuffix(base, ".sql")
|
|
applied, err = s.hasMigration(version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if applied {
|
|
continue
|
|
}
|
|
content, err = os.ReadFile(files[i])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.DB.Exec(string(content))
|
|
if err != nil {
|
|
if version == "003_project_unix_updated_by" {
|
|
err = s.applyProjectUnixMigrationFallback()
|
|
if err == nil {
|
|
_, err = s.DB.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
continue
|
|
}
|
|
}
|
|
return fmt.Errorf("apply %s: %w", base, err)
|
|
}
|
|
_, err = s.DB.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) hasMigration(version string) (bool, error) {
|
|
var v string
|
|
var row *sql.Row
|
|
var err error
|
|
row = s.DB.QueryRow(`SELECT version FROM schema_migrations WHERE version = ?`, version)
|
|
err = row.Scan(&v)
|
|
switch err {
|
|
case nil:
|
|
return true, nil
|
|
case sql.ErrNoRows:
|
|
return false, nil
|
|
default:
|
|
return false, err
|
|
}
|
|
}
|
|
|
|
func (s *Store) applyProjectUnixMigrationFallback() error {
|
|
var err error
|
|
var rows *sql.Rows
|
|
var cols map[string]bool
|
|
var cid int
|
|
var name string
|
|
var ctype string
|
|
var notnull int
|
|
var dflt sql.NullString
|
|
var pk int
|
|
cols = map[string]bool{}
|
|
rows, err = s.DB.Query(`PRAGMA table_info(projects)`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
err = rows.Scan(&cid, &name, &ctype, ¬null, &dflt, &pk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
cols[name] = true
|
|
}
|
|
if !cols["updated_by"] {
|
|
_, err = s.DB.Exec(`ALTER TABLE projects ADD COLUMN updated_by TEXT NOT NULL DEFAULT ''`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if !cols["created_at_unix"] {
|
|
_, err = s.DB.Exec(`ALTER TABLE projects ADD COLUMN created_at_unix INTEGER NOT NULL DEFAULT 0`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
if !cols["updated_at_unix"] {
|
|
_, err = s.DB.Exec(`ALTER TABLE projects ADD COLUMN updated_at_unix INTEGER NOT NULL DEFAULT 0`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
_, err = s.DB.Exec(`
|
|
UPDATE projects
|
|
SET
|
|
updated_by = CASE WHEN updated_by = '' THEN created_by ELSE updated_by END,
|
|
created_at_unix = CASE WHEN created_at_unix = 0 THEN COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) ELSE created_at_unix END,
|
|
updated_at_unix = CASE WHEN updated_at_unix = 0 THEN COALESCE(CAST(strftime('%s', updated_at) AS INTEGER), 0) ELSE updated_at_unix END
|
|
`)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|