Files

193 lines
4.1 KiB
Go

package db
import "database/sql"
import "errors"
import "fmt"
import "os"
import "path/filepath"
import "sort"
import "strings"
// Store bundles the database handle used by this package's methods.
type Store struct {
	// DB is the underlying database/sql handle. It is exported, so callers
	// can run their own statements through it.
	DB *sql.DB
}
// Open opens a database with the given driver and data source name and
// returns a Store wrapping it.
//
// The driver argument is normalized with driverName before being handed to
// sql.Open. The connection is verified with Ping. When the normalized driver
// is "sqlite", a busy timeout of 10000 is configured (a failure here is
// returned) and WAL journal mode is requested (a failure here is ignored).
//
// If any step after sql.Open fails, the handle is closed before returning so
// that the caller is never left with a leaked connection pool.
func Open(driver, dsn string) (*Store, error) {
	drv := driverName(driver)

	db, err := sql.Open(drv, dsn)
	if err != nil {
		return nil, err
	}

	if err := db.Ping(); err != nil {
		// The ping failure is the error worth reporting; a secondary close
		// error would only obscure it.
		_ = db.Close()
		return nil, err
	}

	if drv == "sqlite" {
		// NOTE(review): PRAGMAs issued through db.Exec run on a single pooled
		// connection; confirm the pool is limited to one connection or that
		// the DSN also carries these settings.
		if _, err := db.Exec(`PRAGMA busy_timeout = 10000`); err != nil {
			_ = db.Close()
			return nil, err
		}
		// Best effort: the store is still usable if WAL cannot be enabled,
		// so this error is deliberately ignored.
		_, _ = db.Exec(`PRAGMA journal_mode = WAL`)
	}

	return &Store{DB: db}, nil
}
// driverName maps the accepted spellings of a driver onto one canonical name.
//
// Matching ignores surrounding whitespace and letter case: "sqlite"/"sqlite3"
// become "sqlite", "postgres"/"postgresql" become "postgres", and "mysql"
// stays "mysql". Any other value is returned exactly as it was passed in,
// without trimming or lower-casing.
func driverName(driver string) string {
	key := strings.ToLower(strings.TrimSpace(driver))
	switch {
	case key == "sqlite" || key == "sqlite3":
		return "sqlite"
	case key == "postgres" || key == "postgresql":
		return "postgres"
	case key == "mysql":
		return key
	}
	return driver
}
// Close closes the Store's underlying database handle and returns whatever
// error that close reports.
func (s *Store) Close() error {
	handle := s.DB
	return handle.Close()
}
// ApplyMigrations runs every not-yet-applied ".sql" file found directly in
// dir, in lexicographic order of file path.
//
// A migration's version is its file name without the ".sql" suffix. Applied
// versions are tracked in the schema_migrations table, which is created on
// demand. Each file is sent to the database as a single Exec call and, on
// success, its version is recorded.
//
// If the migration "003_project_unix_updated_by" fails to execute, the
// statement-by-statement fallback applyProjectUnixMigrationFallback is tried
// before giving up; on fallback success the version is recorded as usual.
//
// It returns an error when dir cannot be read, when dir contains no ".sql"
// files, or when any database step fails. Migrations applied before the
// failure stay applied.
//
// NOTE: executing a migration and recording its version are two separate
// statements and are not wrapped in a transaction, so a failure between them
// leaves the migration applied but unrecorded.
// NOTE(review): the `?` placeholder may not be accepted by every driver that
// driverName recognizes (some expect `$1`) — confirm against the drivers in
// use.
func (s *Store) ApplyMigrations(dir string) error {
	entries, err := os.ReadDir(dir)
	if err != nil {
		// *os.PathError already names the operation and the path.
		return err
	}

	files := make([]string, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		if name := entry.Name(); strings.HasSuffix(name, ".sql") {
			files = append(files, filepath.Join(dir, name))
		}
	}
	if len(files) == 0 {
		return errors.New("no migration files found")
	}
	sort.Strings(files)

	if _, err := s.DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY)`); err != nil {
		return fmt.Errorf("create schema_migrations: %w", err)
	}

	for _, path := range files {
		base := filepath.Base(path)
		version := strings.TrimSuffix(base, ".sql")

		applied, err := s.hasMigration(version)
		if err != nil {
			return fmt.Errorf("check migration %s: %w", version, err)
		}
		if applied {
			continue
		}

		content, err := os.ReadFile(path)
		if err != nil {
			return err
		}

		if _, execErr := s.DB.Exec(string(content)); execErr != nil {
			if version != "003_project_unix_updated_by" {
				return fmt.Errorf("apply %s: %w", base, execErr)
			}
			// Only this migration has a fallback. If the fallback fails too,
			// report both errors so the original cause is not lost.
			if fbErr := s.applyProjectUnixMigrationFallback(); fbErr != nil {
				return fmt.Errorf("apply %s: %w (fallback attempted after exec error: %v)", base, fbErr, execErr)
			}
		}

		if _, err := s.DB.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version); err != nil {
			return fmt.Errorf("record migration %s: %w", version, err)
		}
	}
	return nil
}
// hasMigration reports whether version is present in schema_migrations.
//
// It returns (true, nil) when a matching row exists, (false, nil) when the
// query yields no rows, and (false, err) for any other query or scan error.
func (s *Store) hasMigration(version string) (bool, error) {
	var found string
	err := s.DB.QueryRow(`SELECT version FROM schema_migrations WHERE version = ?`, version).Scan(&found)
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		// errors.Is rather than == so that a wrapped ErrNoRows is still
		// recognized as "not applied" instead of a failure.
		return false, nil
	default:
		return false, err
	}
}
// applyProjectUnixMigrationFallback brings the projects table to the shape
// expected after migration "003_project_unix_updated_by" one step at a time.
//
// It lists the table's existing columns with PRAGMA table_info(projects),
// adds updated_by, created_at_unix and updated_at_unix if each is missing,
// and then backfills rows that still hold the column defaults: updated_by
// from created_by, and the two *_unix columns from strftime('%s', ...) of
// created_at / updated_at. Columns that already exist are not altered, so
// the function can run against a partially migrated table.
//
// It returns the first error encountered, unwrapped.
func (s *Store) applyProjectUnixMigrationFallback() error {
	rows, err := s.DB.Query(`PRAGMA table_info(projects)`)
	if err != nil {
		return err
	}
	defer rows.Close()

	existing := map[string]bool{}
	for rows.Next() {
		// Only the column name is used; the other fields are scanned
		// because Scan needs a destination per result column.
		var (
			cid, notnull, pk int
			name, ctype      string
			dflt             sql.NullString
		)
		if err := rows.Scan(&cid, &name, &ctype, &notnull, &dflt, &pk); err != nil {
			return err
		}
		existing[name] = true
	}
	// Next returning false can mean an iteration error rather than the end
	// of the result set. Without this check a truncated column list would
	// lead to wrong ALTER TABLE decisions below.
	if err := rows.Err(); err != nil {
		return err
	}

	// Ordered list so the ALTER statements run in a fixed sequence.
	additions := []struct {
		column string
		ddl    string
	}{
		{"updated_by", `ALTER TABLE projects ADD COLUMN updated_by TEXT NOT NULL DEFAULT ''`},
		{"created_at_unix", `ALTER TABLE projects ADD COLUMN created_at_unix INTEGER NOT NULL DEFAULT 0`},
		{"updated_at_unix", `ALTER TABLE projects ADD COLUMN updated_at_unix INTEGER NOT NULL DEFAULT 0`},
	}
	for _, add := range additions {
		if existing[add.column] {
			continue
		}
		if _, err := s.DB.Exec(add.ddl); err != nil {
			return err
		}
	}

	// Backfill only rows still holding the defaults ('' or 0), so values
	// that were already populated are left untouched.
	_, err = s.DB.Exec(`
UPDATE projects
SET
updated_by = CASE WHEN updated_by = '' THEN created_by ELSE updated_by END,
created_at_unix = CASE WHEN created_at_unix = 0 THEN COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) ELSE created_at_unix END,
updated_at_unix = CASE WHEN updated_at_unix = 0 THEN COALESCE(CAST(strftime('%s', updated_at) AS INTEGER), 0) ELSE updated_at_unix END
`)
	return err
}