// Package db wraps a database/sql handle and applies SQL schema migrations
// from a directory of *.sql files.
package db

import (
	"database/sql"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"sort"
	"strings"
)

// Store wraps an open *sql.DB.
type Store struct {
	DB *sql.DB
}

// Open opens dsn with the database/sql driver selected by driverName(driver)
// and verifies the connection with Ping. For the "sqlite" driver it also sets
// a 10s busy timeout and requests WAL journal mode (best effort).
//
// On any error the underlying handle is closed and a nil *Store is returned.
func Open(driver, dsn string) (*Store, error) {
	drv := driverName(driver)
	db, err := sql.Open(drv, dsn)
	if err != nil {
		return nil, err
	}
	if err := db.Ping(); err != nil {
		_ = db.Close() // the Ping error is the one worth reporting
		return nil, err
	}
	if drv == "sqlite" {
		// NOTE(review): PRAGMA busy_timeout applies to a single SQLite
		// connection, and database/sql pools connections, so connections the
		// pool opens later may not carry this timeout. Consider setting it via
		// the driver's DSN parameters instead.
		if _, err := db.Exec(`PRAGMA busy_timeout = 10000`); err != nil {
			_ = db.Close() // the PRAGMA error is the one worth reporting
			return nil, err
		}
		// Best effort: the result is deliberately ignored, so an unsupported
		// journal mode does not prevent opening the store.
		_, _ = db.Exec(`PRAGMA journal_mode = WAL`)
	}
	return &Store{DB: db}, nil
}

// driverName maps common driver aliases to the name passed to sql.Open:
// "sqlite"/"sqlite3" -> "sqlite", "postgres"/"postgresql" -> "postgres",
// "mysql" -> "mysql". Matching ignores case and surrounding whitespace; any
// other value is returned unchanged (not trimmed or lowercased).
func driverName(driver string) string {
	switch strings.ToLower(strings.TrimSpace(driver)) {
	case "sqlite", "sqlite3":
		return "sqlite"
	case "postgres", "postgresql":
		return "postgres"
	case "mysql":
		return "mysql"
	default:
		return driver
	}
}

// Close closes the underlying database handle.
func (s *Store) Close() error {
	return s.DB.Close()
}

// ApplyMigrations applies every not-yet-applied "*.sql" file in dir, in
// lexical filename order. Each applied file is recorded in schema_migrations
// under its filename minus the ".sql" suffix. Directories are skipped. It
// returns an error if dir holds no migration files; that check happens before
// the database is touched.
//
// Migration "003_project_unix_updated_by" has a programmatic fallback. If its
// SQL fails, applyProjectUnixMigrationFallback is run instead, and on success
// the migration is recorded as applied.
//
// NOTE(review): a migration and its schema_migrations row are written by
// separate Exec calls, not in one transaction, so a failure between them
// leaves the migration applied but unrecorded. The "?" placeholders also
// assume a driver that accepts them (e.g. SQLite, MySQL); PostgreSQL drivers
// expect $1.
func (s *Store) ApplyMigrations(dir string) error {
	files, err := listMigrationFiles(dir)
	if err != nil {
		return err
	}
	if len(files) == 0 {
		return errors.New("no migration files found")
	}
	if _, err := s.DB.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (version TEXT PRIMARY KEY)`); err != nil {
		return err
	}
	for _, path := range files {
		if err := s.applyMigrationFile(path); err != nil {
			return err
		}
	}
	return nil
}

// listMigrationFiles returns the sorted paths of the non-directory entries of
// dir whose names end in ".sql".
func listMigrationFiles(dir string) ([]string, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	var files []string
	for _, e := range entries {
		if e.IsDir() || !strings.HasSuffix(e.Name(), ".sql") {
			continue
		}
		files = append(files, filepath.Join(dir, e.Name()))
	}
	sort.Strings(files)
	return files, nil
}

// applyMigrationFile executes the migration at path unless its version is
// already recorded, then records the version in schema_migrations.
func (s *Store) applyMigrationFile(path string) error {
	base := filepath.Base(path)
	version := strings.TrimSuffix(base, ".sql")

	applied, err := s.hasMigration(version)
	if err != nil {
		return err
	}
	if applied {
		return nil
	}

	content, err := os.ReadFile(path)
	if err != nil {
		return err
	}
	if _, execErr := s.DB.Exec(string(content)); execErr != nil {
		if version != "003_project_unix_updated_by" {
			return fmt.Errorf("apply %s: %w", base, execErr)
		}
		// Only this migration has a step-by-step fallback.
		if fbErr := s.applyProjectUnixMigrationFallback(); fbErr != nil {
			// Keep the SQL error visible; the fallback error stays the
			// wrapped one, as before.
			return fmt.Errorf("apply %s: %w (migration SQL error: %v)", base, fbErr, execErr)
		}
	}

	_, err = s.DB.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version)
	return err
}

// hasMigration reports whether version is recorded in schema_migrations.
func (s *Store) hasMigration(version string) (bool, error) {
	var v string
	err := s.DB.QueryRow(`SELECT version FROM schema_migrations WHERE version = ?`, version).Scan(&v)
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, err
	}
}

// applyProjectUnixMigrationFallback brings the projects table to the shape
// migration "003_project_unix_updated_by" targets, using SQLite-specific SQL
// (PRAGMA table_info, strftime).
//
// It adds each of updated_by, created_at_unix and updated_at_unix only if the
// column is missing. It then backfills updated_by from created_by where it is
// '', and created_at_unix/updated_at_unix from created_at/updated_at as Unix
// seconds where they are 0. The result is 0 when the source value is NULL or
// cannot be parsed by strftime. Each step only touches missing columns or
// default values.
func (s *Store) applyProjectUnixMigrationFallback() error {
	cols, err := s.projectsColumns()
	if err != nil {
		return err
	}

	additions := []struct{ column, ddl string }{
		{"updated_by", `ALTER TABLE projects ADD COLUMN updated_by TEXT NOT NULL DEFAULT ''`},
		{"created_at_unix", `ALTER TABLE projects ADD COLUMN created_at_unix INTEGER NOT NULL DEFAULT 0`},
		{"updated_at_unix", `ALTER TABLE projects ADD COLUMN updated_at_unix INTEGER NOT NULL DEFAULT 0`},
	}
	for _, a := range additions {
		if cols[a.column] {
			continue
		}
		if _, err := s.DB.Exec(a.ddl); err != nil {
			return err
		}
	}

	_, err = s.DB.Exec(`
UPDATE projects
SET updated_by = CASE WHEN updated_by = '' THEN created_by ELSE updated_by END,
    created_at_unix = CASE WHEN created_at_unix = 0 THEN COALESCE(CAST(strftime('%s', created_at) AS INTEGER), 0) ELSE created_at_unix END,
    updated_at_unix = CASE WHEN updated_at_unix = 0 THEN COALESCE(CAST(strftime('%s', updated_at) AS INTEGER), 0) ELSE updated_at_unix END
`)
	return err
}

// projectsColumns returns the set of column names of the projects table, as
// reported by SQLite's PRAGMA table_info. The result set is fully consumed and
// closed before returning, so callers can run DDL afterwards.
func (s *Store) projectsColumns() (map[string]bool, error) {
	rows, err := s.DB.Query(`PRAGMA table_info(projects)`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	cols := map[string]bool{}
	for rows.Next() {
		// PRAGMA table_info yields: cid, name, type, notnull, dflt_value, pk.
		var cid, notnull, pk int
		var name, ctype string
		var dflt sql.NullString
		if err := rows.Scan(&cid, &name, &ctype, &notnull, &dflt, &pk); err != nil {
			return nil, err
		}
		cols[name] = true
	}
	// Next returning false may mean an iteration error rather than the end of
	// the rows; treating that as "column missing" would trigger bogus ALTERs.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return cols, nil
}