From fd70bdb8219210b3ebb95d00844b3c86899e03fc Mon Sep 17 00:00:00 2001
From: Dan Kortschak
Date: Sun, 30 Jun 2024 07:13:17 +0930
Subject: [PATCH 1/6] cmd/worklog/pgstore: new package for postgres support

---
 .github/workflows/ci.yml       |  47 ++
 cmd/worklog/pgstore/db.go      | 991 +++++++++++++++++++++++++++++++++
 cmd/worklog/pgstore/db_test.go | 751 +++++++++++++++++++++++++
 go.mod                         |   4 +
 go.sum                         |  17 +-
 5 files changed, 1808 insertions(+), 2 deletions(-)
 create mode 100644 cmd/worklog/pgstore/db.go
 create mode 100644 cmd/worklog/pgstore/db_test.go

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 877f377..62df736 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -75,3 +75,50 @@
         sudo apt-get remove -qq libxss-dev libxres-dev libx11-dev
         sudo apt-get autoremove -qq
         go build -tags no_xorg ./cmd/watcher
+
+  postgres-test:
+    runs-on: ubuntu-latest
+    services:
+      postgres:
+        image: postgres
+        env:
+          POSTGRES_USER: postgres
+          POSTGRES_PASSWORD: password
+          POSTGRES_DB: postgres
+        options: >-
+          --name postgres
+          --health-cmd pg_isready
+          --health-interval 10s
+          --health-timeout 5s
+          --health-retries 5
+        ports:
+          - 5432:5432
+
+    env:
+      PGHOST: localhost
+      PGPORT: 5432
+      PGUSER: test_user
+      PGPASSWORD: password
+
+    steps:
+      - name: install Go
+        uses: actions/setup-go@v4
+        with:
+          go-version: stable
+
+      - name: checkout code
+        uses: actions/checkout@v4
+        with:
+          fetch-depth: 1
+
+      - name: unit tests postgres
+        run: |
+          psql --host $PGHOST \
+            --username="postgres" \
+            --dbname="postgres" \
+            --command="CREATE USER $PGUSER PASSWORD '$PGPASSWORD'" \
+            --command="ALTER USER $PGUSER CREATEDB" \
+            --command="CREATE USER ${PGUSER}_ro PASSWORD '$PGPASSWORD'" \
+            --command="\du"
+
+          go test ./cmd/worklog/pgstore
diff --git a/cmd/worklog/pgstore/db.go b/cmd/worklog/pgstore/db.go
new file mode 100644
index 0000000..94964ea
--- /dev/null
+++ b/cmd/worklog/pgstore/db.go
@@ -0,0 +1,991 @@
+// Copyright ©2024 Dan Kortschak. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pgstore provides a worklog data storage layer using PostgreSQL.
+package pgstore
+
+import (
+	"bufio"
+	"bytes"
+	"compress/gzip"
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"net"
+	"net/url"
+	"os"
+	"path"
+	"path/filepath"
+	"sort"
+	"strings"
+	"time"
+
+	"github.com/jackc/pgx/v5"
+	"github.com/jackc/pgx/v5/pgconn"
+	"golang.org/x/sys/execabs"
+
+	worklog "github.com/kortschak/dex/cmd/worklog/api"
+)
+
+// DB is a persistent store.
+type DB struct {
+	name    string
+	host    string
+	store   *pgx.Conn
+	roStore *pgx.Conn
+
+	roWarn error
+}
+
+type execer interface {
+	Exec(ctx context.Context, query string, args ...any) (pgconn.CommandTag, error)
+}
+
+type result struct {
+	pgconn.CommandTag
+}
+
+func (r result) RowsAffected() (int64, error) {
+	return r.CommandTag.RowsAffected(), nil
+}
+
+func (r result) LastInsertId() (int64, error) {
+	return 0, nil
+}
+
+type querier interface {
+	Query(ctx context.Context, query string, args ...any) (pgx.Rows, error)
+}
+
+func txDone(ctx context.Context, tx pgx.Tx, err *error) {
+	if *err == nil {
+		*err = tx.Commit(ctx)
+	} else {
+		*err = errors.Join(*err, tx.Rollback(ctx))
+	}
+}
+
+// Open opens a PostgreSQL DB. See [pgx.Connect] for name handling details.
+// Two connections to the database are created, one using the username within
+// the name parameter and one with the same username, but "_ro" appended. If
+// the second user does not have SELECT role_table_grants for the 'buckets'
+// and 'events' tables in the database, or has any non-SELECT
+// role_table_grants, the second connection will be closed and a [Warning]
+// error will be returned. If the second connection is closed, the [DB.Select]
+// method will return a non-nil error and perform no DB operation. Open
+// attempts to get the CNAME for the host, which may wait indefinitely, so a
+// timeout context can be provided to fall back to the kernel-provided
+// hostname.
+func Open(ctx context.Context, name, host string) (*DB, error) {
+	u, err := url.Parse(name)
+	if err != nil {
+		return nil, err
+	}
+	if u.User == nil {
+		return nil, fmt.Errorf("missing user info: %s", u)
+	}
+
+	if host == "" {
+		host, err = hostname(ctx)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	pgHost, pgPort, err := net.SplitHostPort(u.Host)
+	if err != nil {
+		return nil, err
+	}
+	if _, ok := u.User.Password(); !ok {
+		userInfo, err := pgUserinfo(u.User.Username(), pgHost, pgPort, os.Getenv("PGPASSWORD"))
+		if err != nil {
+			return nil, err
+		}
+		u.User = userInfo
+	}
+
+	db, err := pgx.Connect(ctx, name)
+	if err != nil {
+		return nil, err
+	}
+	_, err = db.Exec(ctx, Schema)
+	if err != nil {
+		return nil, errors.Join(err, db.Close(ctx))
+	}
+
+	var dbRO *pgx.Conn
+	roUser := u.User.Username() + "_ro"
+	userInfo, warn := pgUserinfo(roUser, pgHost, pgPort, "")
+	if warn != nil {
+		warn = Warning{error: fmt.Errorf("ro user failed to get password: %w", warn)}
+	} else {
+		u.User = userInfo
+
+		// Check that the ro user can read the store's tables
+		// and has no other grants. If either of these checks
+		// fail, return a Warning error. If the user has other
+		// grants, deny its use.
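+		// For example, a conforming read-only role would
+		// typically carry only grants such as (role and
+		// database names here are illustrative):
+		//
+		//	GRANT CONNECT ON DATABASE worklog TO worklog_ro;
+		//	GRANT USAGE ON SCHEMA public TO worklog_ro;
+		//	GRANT SELECT ON buckets, events TO worklog_ro;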
+		const (
+			nonSelects = `select
+	count(*) = 0
+from
+	information_schema.role_table_grants
+where
+	not privilege_type ilike 'SELECT'
+	and grantee = $1`
+			otherSelects = `select
+	count(distinct table_name) = 0
+from
+	information_schema.role_table_grants
+where
+	privilege_type ilike 'SELECT' and not table_name in ('buckets', 'events')
+	and grantee = $1`
+			selects = `select
+	count(distinct table_name) = 2
+from
+	information_schema.role_table_grants
+where
+	privilege_type ilike 'SELECT' and table_name in ('buckets', 'events')
+	and grantee = $1`
+		)
+		for _, check := range []struct {
+			statement string
+			warn      Warning
+		}{
+			{
+				statement: nonSelects,
+				warn: Warning{
+					error: errors.New("ro user failed capability restriction checks"),
+					allow: false,
+				},
+			},
+			{
+				statement: otherSelects,
+				warn: Warning{
+					error: errors.New("ro user failed table read capability restriction checks"),
+					allow: false,
+				},
+			},
+			{
+				statement: selects,
+				warn: Warning{
+					error: errors.New("ro user failed read capability checks"),
+					allow: true,
+				},
+			},
+		} {
+			var ok bool
+			err = db.QueryRow(ctx, check.statement, roUser).Scan(&ok)
+			if err != nil {
+				db.Close(ctx)
+				return nil, err
+			}
+			if !ok {
+				warn = check.warn
+				break
+			}
+		}
+		if w, ok := warn.(Warning); warn == nil || (ok && w.allow) {
+			dbRO, err = pgx.Connect(ctx, u.String())
+			if err != nil {
+				warn = Warning{error: err}
+				dbRO = nil
+			}
+		}
+	}
+	u.User = nil
+
+	return &DB{name: u.String(), host: host, store: db, roStore: dbRO, roWarn: warn}, warn
+}
+
+func pgUserinfo(pgUser, pgHost, pgPort, pgPassword string) (*url.Userinfo, error) {
+	if pgPassword != "" {
+		return url.UserPassword(pgUser, pgPassword), nil
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, fmt.Errorf("could not get home directory: %w", err)
+	}
+	pgpass, err := os.Open(filepath.Join(home, ".pgpass"))
+	if err != nil {
+		return nil, fmt.Errorf("could not open .pgpass: %w", err)
+	}
+	defer pgpass.Close()
+	fi, err := pgpass.Stat()
+	if err != nil {
+		return nil, fmt.Errorf("could not stat .pgpass: %w", err)
+	}
+	if fi.Mode()&0o077 != 0o000 {
+		return nil, fmt.Errorf(".pgpass permissions too relaxed: %s", fi.Mode())
+	}
+	sc := bufio.NewScanner(pgpass)
+	found := false
+	var e pgPassEntry
+	for sc.Scan() {
+		line := strings.TrimSpace(sc.Text())
+		if line == "" || strings.HasPrefix(line, "#") {
+			continue
+		}
+		e, err = parsePgPassLine(line)
+		if err != nil {
+			return nil, fmt.Errorf("could not parse .pgpass: %w", err)
+		}
+		if e.match(pgUser, pgHost, pgPort, "*") {
+			found = true
+			break
+		}
+	}
+	if err := sc.Err(); err != nil {
+		return nil, fmt.Errorf("unexpected error reading .pgpass: %w", err)
+	}
+	if !found {
+		return nil, errors.New("must have postgres password in $PGPASSWORD or .pgpass")
+	}
+
+	return url.UserPassword(pgUser, e.password), nil
+}
+
+type pgPassEntry struct {
+	host     string
+	port     string
+	database string
+	user     string
+	password string
+}
+
+func (e pgPassEntry) match(user, host, port, database string) bool {
+	return user == e.user &&
+		(host == e.host || e.host == "*") &&
+		(port == e.port || e.port == "*") &&
+		(database == e.database || e.database == "*")
+}
+
+func parsePgPassLine(text string) (pgPassEntry, error) {
+	var (
+		entry  pgPassEntry
+		field  int
+		last   int
+		escape bool
+	)
+	for i, r := range text {
+		switch r {
+		case '\\':
+			escape = !escape
+			continue
+		case ':':
+			if escape {
+				break
+			}
+			switch field {
+			case 0:
+				entry.host = text[last:i]
+			case 1:
+				entry.port = text[last:i]
+			case 2:
+				entry.database =
text[last:i] + case 3: + entry.user = text[last:i] + default: + return entry, errors.New("too many fields") + } + last = i + 1 + field++ + } + escape = false + } + entry.password = text[last:] + return entry, nil +} + +// Warning is a warning-only error. +type Warning struct { + error + allow bool +} + +// hostname returns the FQDN of the local host, falling back to the hostname +// reported by the kernel if CNAME lookup fails. +func hostname(ctx context.Context) (string, error) { + host, err := os.Hostname() + if err != nil { + return "", err + } + cname, err := net.DefaultResolver.LookupCNAME(ctx, host) + if err != nil { + return host, nil + } + return strings.TrimSuffix(cname, "."), nil +} + +// Name returns the name of the database as provided to Open. +func (db *DB) Name() string { + if db == nil { + return "" + } + return db.name +} + +// Backup creates a backup of the DB using the pg_dump command into the provided +// directory as a gzip file. It returns the path of the backup. +func (db *DB) Backup(ctx context.Context, dir string) (string, error) { + u, err := url.Parse(db.name) + if err != nil { + return "", err + } + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + return "", err + } + dbname := path.Base(u.Path) + + dst := filepath.Join(dir, dbname+"_"+time.Now().In(time.UTC).Format("20060102150405")+".gz") + cmd := execabs.Command("pg_dump", "-h", host, "-p", port, dbname) + f, err := os.OpenFile(dst, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0o600) + if err != nil { + return "", err + } + w := gzip.NewWriter(f) + cmd.Stdout = w + var buf bytes.Buffer + cmd.Stderr = &buf + err = cmd.Run() + if err != nil { + return "", fmt.Errorf("%s: %w", bytes.TrimSpace(buf.Bytes()), err) + } + err = errors.Join(w.Close(), f.Sync(), f.Close()) + return dst, err +} + +// Close closes the database. +func (db *DB) Close(ctx context.Context) error { + var roErr error + if db.roStore != nil { + roErr = db.roStore.Close(ctx) + } + return errors.Join(db.store.Close(ctx), roErr) +} + +// Schema is the DB schema. +const Schema = ` +create table if not exists buckets ( + rowid SERIAL PRIMARY KEY, + id TEXT UNIQUE NOT NULL, + name TEXT NOT NULL, + type TEXT NOT NULL, + client TEXT NOT NULL, + hostname TEXT NOT NULL, + created TIMESTAMP WITH TIME ZONE NOT NULL, + timezone TEXT NOT NULL, -- tz of created, not mutated after first write + datastr JSONB NOT NULL +); +create table if not exists events ( + id SERIAL PRIMARY KEY, + bucketrow INTEGER NOT NULL, + starttime TIMESTAMP WITH TIME ZONE NOT NULL, + endtime TIMESTAMP WITH TIME ZONE NOT NULL, + timezone TEXT NOT NULL, -- tz of starttime, not mutated after first write + datastr JSONB NOT NULL, + FOREIGN KEY (bucketrow) REFERENCES buckets(rowid) +); +create index if not exists event_index_id ON events(id); +create index if not exists event_index_starttime ON events(bucketrow, starttime); +create index if not exists event_index_endtime ON events(bucketrow, endtime); +` + +const tzFormat = "-07:00" + +// BucketID returns the internal bucket ID for the provided bucket uid. +func (db *DB) BucketID(uid string) string { + return fmt.Sprintf("%s_%s", uid, db.host) +} + +const CreateBucket = `insert into buckets(id, name, type, client, hostname, created, timezone, datastr) values ($1, $2, $3, $4, $5, $6, $7, $8) +on conflict (id) do nothing;` + +// CreateBucket creates a new entry in the bucket table. +// The SQL command run is [CreateBucket]. 
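+//
+// A minimal usage sketch (the bucket values here are illustrative):
+//
+//	m, err := db.CreateBucket(ctx, "window", "window-watcher",
+//		"currentwindow", "worklog", time.Now(), nil)
+//	if err != nil {
+//		// Handle error.
+//	}
+//	fmt.Println(m.ID) // The uid suffixed with "_" and the DB's host name.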
+func (db *DB) CreateBucket(ctx context.Context, uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) { + bid := db.BucketID(uid) + tx, err := db.store.Begin(ctx) + if err != nil { + return nil, err + } + defer txDone(ctx, tx, &err) + return createBucket(ctx, tx, bid, name, typ, client, db.host, created, data) +} + +func createBucket(ctx context.Context, tx pgx.Tx, bid, name, typ, client, host string, created time.Time, data map[string]any) (*worklog.BucketMetadata, error) { + if data == nil { + // datastr has a NOT NULL constraint. + data = make(map[string]any) + } + _, err := tx.Exec(ctx, CreateBucket, bid, name, typ, client, host, created, created.Format(tzFormat), data) + if err != nil { + return nil, err + } + m, err := bucketMetadata(ctx, tx, bid) + if err != nil { + return nil, err + } + return m, nil +} + +const BucketMetadata = `select id, name, type, client, hostname, created, timezone, datastr from buckets where id = $1` + +// BucketMetadata returns the metadata for the bucket with the provided internal +// bucket ID. +// The SQL command run is [BucketMetadata]. +func (db *DB) BucketMetadata(ctx context.Context, bid string) (*worklog.BucketMetadata, error) { + return bucketMetadata(ctx, db.store, bid) +} + +func bucketMetadata(ctx context.Context, db querier, bid string) (*worklog.BucketMetadata, error) { + rows, err := db.Query(ctx, BucketMetadata, bid) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, io.EOF + } + var ( + m worklog.BucketMetadata + tz string + ) + err = rows.Scan(&m.ID, &m.Name, &m.Type, &m.Client, &m.Hostname, &m.Created, &tz, &m.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return &m, fmt.Errorf("invalid timezone for %s bucket %s: %w", m.Name, m.ID, err) + } + m.Created = m.Created.In(timezone.Location()) + if rows.Next() { + return &m, errors.New("unexpected item") + } + return &m, nil +} + +const InsertEvent = `insert into events(bucketrow, starttime, endtime, timezone, datastr) values ((select rowid from buckets where id = $1), $2, $3, $4, $5)` + +// InsertEvent inserts a new event into the events table. +// The SQL command run is [InsertEvent]. +func (db *DB) InsertEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { + bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) + return insertEvent(ctx, db.store, bid, e) +} + +func insertEvent(ctx context.Context, db execer, bid string, e *worklog.Event) (sql.Result, error) { + res, err := db.Exec(ctx, InsertEvent, bid, e.Start, e.End, e.Start.Format(tzFormat), e.Data) + return result{res}, err +} + +const UpdateEvent = `update events set starttime = $1, endtime = $2, datastr = $3 where id = $4 and bucketrow = ( + select rowid from buckets where id = $5 +)` + +// UpdateEvent updates the event in the store corresponding to the provided +// event. +// The SQL command run is [UpdateEvent]. 
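+//
+// A common pattern, used to continue a still-open event, is to retrieve the
+// last event in the bucket and update it under its existing ID, sketched
+// here with error handling elided:
+//
+//	last, _ := db.LastEvent(ctx, e.Bucket)
+//	e.ID = last.ID
+//	_, _ = db.UpdateEvent(ctx, &e)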
+func (db *DB) UpdateEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { + bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) + res, err := db.store.Exec(ctx, UpdateEvent, e.Start, e.End, e.Data, e.ID, bid) + return result{res}, err +} + +const LastEvent = `select id, starttime, endtime, timezone, datastr from events where bucketrow = ( + select rowid from buckets where id = $1 +) and endtime = ( + select max(endtime) from events where bucketrow = ( + select rowid from buckets where id = $1 + ) limit 1 +) limit 1` + +// LastEvent returns the last event in the named bucket. +// The SQL command run is [LastEvent]. +func (db *DB) LastEvent(ctx context.Context, uid string) (*worklog.Event, error) { + bid := db.BucketID(uid) + rows, err := db.store.Query(ctx, LastEvent, bid) + if err != nil { + return nil, err + } + defer rows.Close() + if !rows.Next() { + return nil, io.EOF + } + var ( + e worklog.Event + tz string + ) + err = rows.Scan(&e.ID, &e.Start, &e.End, &tz, &e.Data) + if err != nil { + return nil, err + } + timezone, err := time.Parse(tzFormat, tz) + if err != nil { + return &e, fmt.Errorf("invalid timezone for event %d: %w", e.ID, err) + } + loc := timezone.Location() + e.Start = e.Start.In(loc) + e.End = e.End.In(loc) + if rows.Next() { + return &e, errors.New("unexpected item") + } + return &e, nil +} + +// Dump dumps the complete database into a slice of [worklog.BucketMetadata]. +func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { + m, err := db.buckets(ctx) + if err != nil { + return nil, err + } + for i, b := range m { + bucket, ok := strings.CutSuffix(b.ID, "_"+b.Hostname) + if !ok { + return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID) + } + e, err := db.events(ctx, b.ID) + if err != nil { + return m, err + } + for j := range e { + e[j].Bucket = bucket + } + m[i].Events = e + } + return m, nil +} + +// DumpRange dumps the database spanning the specified time range into a slice +// of [worklog.BucketMetadata]. 
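+//
+// A zero start or end time leaves that side of the range unbounded, so a
+// sketch dumping only the last day of data is:
+//
+//	end := time.Now()
+//	buckets, err := db.DumpRange(ctx, end.Add(-24*time.Hour), end)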
+func (db *DB) DumpRange(ctx context.Context, start, end time.Time) ([]worklog.BucketMetadata, error) {
+	m, err := db.buckets(ctx)
+	if err != nil {
+		return nil, err
+	}
+	for i, b := range m {
+		bucket, ok := strings.CutSuffix(b.ID, "_"+b.Hostname)
+		if !ok {
+			return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID)
+		}
+		e, err := db.dumpEventsRange(ctx, b.ID, start, end, nil)
+		if err != nil {
+			return m, err
+		}
+		for j := range e {
+			e[j].Bucket = bucket
+		}
+		m[i].Events = e
+	}
+	return m, nil
+}
+
+const (
+	dumpEventsRange = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and endtime >= $2 and starttime <= $3 limit $4`
+
+	dumpEventsRangeUntil = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and starttime <= $2 limit $3`
+
+	dumpEventsRangeFrom = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and endtime >= $2 limit $3`
+
+	dumpEventsLimit = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) limit $2`
+)
+
+func (db *DB) dumpEventsRange(ctx context.Context, bid string, start, end time.Time, limit *int) ([]worklog.Event, error) {
+	var e []worklog.Event
+	err := db.eventsRangeFunc(ctx, bid, start, end, limit, func(m worklog.Event) error {
+		e = append(e, m)
+		return nil
+	}, false)
+	return e, err
+}
+
+// Load loads a complete database from a slice of [worklog.BucketMetadata].
+// Event IDs will be regenerated by the backing database and so will not
+// match the input data. If replace is true and a bucket already exists matching
+// the bucket in the provided buckets slice, the existing events will be
+// deleted and replaced. If replace is false, the new events will be added to
+// the existing events in the store.
+func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replace bool) (err error) {
+	tx, err := db.store.Begin(ctx)
+	if err != nil {
+		return err
+	}
+	defer txDone(ctx, tx, &err)
+	for _, m := range buckets {
+		var b *worklog.BucketMetadata
+		b, err = createBucket(ctx, tx, m.ID, m.Name, m.Type, m.Client, m.Hostname, m.Created, m.Data)
+		if err != nil {
+			return err
+		}
+		if !sameBucket(&m, b) {
+			return fmt.Errorf("mismatched bucket: %s != %s", bucketString(&m), bucketString(b))
+		}
+		if replace {
+			_, err = tx.Exec(ctx, DeleteBucketEvents, m.ID)
+			if err != nil {
+				return err
+			}
+		}
+		for i, e := range m.Events {
+			bid := fmt.Sprintf("%s_%s", e.Bucket, m.Hostname)
+			_, err = insertEvent(ctx, tx, bid, &m.Events[i])
+			if err != nil {
+				return err
+			}
+		}
+	}
+	return nil
+}
+
+func sameBucket(a, b *worklog.BucketMetadata) bool {
+	return a.ID == b.ID &&
+		a.Name == b.Name &&
+		a.Type == b.Type &&
+		a.Client == b.Client &&
+		a.Hostname == b.Hostname
+}
+
+func bucketString(b *worklog.BucketMetadata) string {
+	return fmt.Sprintf("{id=%s name=%s type=%s client=%s hostname=%s}",
+		b.ID,
+		b.Name,
+		b.Type,
+		b.Client,
+		b.Hostname)
+}
+
+const Buckets = `select id, name, type, client, hostname, created, timezone, datastr from buckets`
+
+// Buckets returns the full set of bucket metadata.
+// The SQL command run is [Buckets].
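+//
+// A short sketch listing the stored buckets:
+//
+//	buckets, err := db.Buckets(ctx)
+//	if err != nil {
+//		// Handle error.
+//	}
+//	for _, b := range buckets {
+//		fmt.Println(b.ID, b.Type, b.Created)
+//	}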
+func (db *DB) Buckets(ctx context.Context) ([]worklog.BucketMetadata, error) {
+	return db.buckets(ctx)
+}
+
+func (db *DB) buckets(ctx context.Context) ([]worklog.BucketMetadata, error) {
+	rows, err := db.store.Query(ctx, Buckets)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var b []worklog.BucketMetadata
+	for rows.Next() {
+		var (
+			m  worklog.BucketMetadata
+			tz string
+		)
+		err = rows.Scan(&m.ID, &m.Name, &m.Type, &m.Client, &m.Hostname, &m.Created, &tz, &m.Data)
+		if err != nil {
+			return nil, err
+		}
+		timezone, err := time.Parse(tzFormat, tz)
+		if err != nil {
+			return nil, fmt.Errorf("invalid timezone for %s bucket %s: %w", m.Name, m.ID, err)
+		}
+		m.Created = m.Created.In(timezone.Location())
+		b = append(b, m)
+	}
+	return b, nil
+}
+
+const Event = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and id = $2 limit 1`
+
+const Events = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+)`
+
+// Events returns the full set of events in the bucket with the provided
+// internal bucket ID.
+// The SQL command run is [Events].
+func (db *DB) Events(ctx context.Context, bid string) ([]worklog.Event, error) {
+	return db.events(ctx, bid)
+}
+
+func (db *DB) events(ctx context.Context, bid string) ([]worklog.Event, error) {
+	rows, err := db.store.Query(ctx, Events, bid)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+	var e []worklog.Event
+	for rows.Next() {
+		var (
+			m  worklog.Event
+			tz string
+		)
+		err = rows.Scan(&m.ID, &m.Start, &m.End, &tz, &m.Data)
+		if err != nil {
+			return nil, err
+		}
+		timezone, err := time.Parse(tzFormat, tz)
+		if err != nil {
+			return nil, fmt.Errorf("invalid timezone for event %d: %w", m.ID, err)
+		}
+		loc := timezone.Location()
+		m.Start = m.Start.In(loc)
+		m.End = m.End.In(loc)
+		e = append(e, m)
+	}
+	return e, nil
+}
+
+const (
+	EventsRange = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and endtime >= $2 and starttime <= $3 order by endtime desc limit $4`
+
+	EventsRangeUntil = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and starttime <= $2 order by endtime desc limit $3`
+
+	EventsRangeFrom = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) and endtime >= $2 order by endtime desc limit $3`
+
+	EventsLimit = `select id, starttime, endtime, timezone, datastr from events where bucketrow = (
+	select rowid from buckets where id = $1
+) order by endtime desc limit $2`
+)
+
+// EventsRange returns the events in the bucket with the provided bucket ID
+// within the specified time range, sorted descending by end time.
+// The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom]
+// or [EventsLimit] depending on whether start and end are zero.
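+//
+// For example, a sketch retrieving the ten most recent events in a bucket,
+// leaving both ends of the range unbounded:
+//
+//	events, err := db.EventsRange(ctx, db.BucketID("window"), time.Time{}, time.Time{}, 10)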
+func (db *DB) EventsRange(ctx context.Context, bid string, start, end time.Time, limit int) ([]worklog.Event, error) {
+	var lim *int
+	if limit >= 0 {
+		lim = &limit
+	}
+	var e []worklog.Event
+	err := db.eventsRangeFunc(ctx, bid, start, end, lim, func(m worklog.Event) error {
+		e = append(e, m)
+		return nil
+	}, true)
+	return e, err
+}
+
+// EventsRangeFunc calls fn on all the events in the bucket with the provided
+// bucket ID within the specified time range, sorted descending by end time.
+// The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom]
+// or [EventsLimit] depending on whether start and end are zero.
+func (db *DB) EventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error {
+	var lim *int
+	if limit >= 0 {
+		lim = &limit
+	}
+	return db.eventsRangeFunc(ctx, bid, start, end, lim, fn, true)
+}
+
+func (db *DB) eventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit *int, fn func(worklog.Event) error, order bool) error {
+	var (
+		query string
+		rows  pgx.Rows
+		err   error
+	)
+	switch {
+	case !start.IsZero() && !end.IsZero():
+		query = EventsRange
+		if !order {
+			query = dumpEventsRange
+		}
+		rows, err = db.store.Query(ctx, query, bid, start, end, limit)
+	case !start.IsZero():
+		query = EventsRangeFrom
+		if !order {
+			query = dumpEventsRangeFrom
+		}
+		rows, err = db.store.Query(ctx, query, bid, start, limit)
+	case !end.IsZero():
+		query = EventsRangeUntil
+		if !order {
+			query = dumpEventsRangeUntil
+		}
+		rows, err = db.store.Query(ctx, query, bid, end, limit)
+	default:
+		query = EventsLimit
+		if !order {
+			query = dumpEventsLimit
+		}
+		rows, err = db.store.Query(ctx, query, bid, limit)
+	}
+	if err != nil {
+		return err
+	}
+	defer rows.Close()
+	for rows.Next() {
+		var (
+			m  worklog.Event
+			tz string
+		)
+		err = rows.Scan(&m.ID, &m.Start, &m.End, &tz, &m.Data)
+		if err != nil {
+			return err
+		}
+		timezone, err := time.Parse(tzFormat, tz)
+		if err != nil {
+			return fmt.Errorf("invalid timezone for event %d: %w", m.ID, err)
+		}
+		loc := timezone.Location()
+		m.Start = m.Start.In(loc)
+		m.End = m.End.In(loc)
+		err = fn(m)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// Select allows running a PostgreSQL SELECT query. The query is run on the
+// store's read-only connection to the database.
+func (db *DB) Select(ctx context.Context, query string) ([]map[string]any, error) {
+	if db.roStore == nil {
+		return nil, db.wrapROWarn(errors.New("no read-only connection"))
+	}
+
+	rows, err := db.roStore.Query(ctx, query)
+	if err != nil {
+		return nil, db.wrapROWarn(err)
+	}
+	defer rows.Close()
+
+	descs := rows.FieldDescriptions()
+	var e []map[string]any
+	for rows.Next() {
+		cols, err := rows.Values()
+		if err != nil {
+			return nil, db.wrapROWarn(err)
+		}
+		result := make(map[string]any)
+		for i, v := range cols {
+			result[descs[i].Name] = v
+		}
+		e = append(e, result)
+	}
+	return e, db.wrapROWarn(rows.Err())
+}
+
+func (db *DB) wrapROWarn(err error) error {
+	if err == nil || db.roWarn == nil {
+		return err
+	}
+	return fmt.Errorf("%w: %w", err, db.roWarn)
+}
+
+const AmendEventsPrepare = `update events set datastr = jsonb_set(datastr, '{amend}', '[]')
+	where
+		starttime < $3 and
+		endtime > $2 and
+		not datastr::jsonb ?
'amend' and + bucketrow = ( + select rowid from buckets where id = $1 + );` +const AmendEventsUpdate = `update events set datastr = jsonb_set( + datastr, + '{amend}', + datastr->'amend' || jsonb_build_object( + 'time', $2::text, + 'msg', $3::text, + 'replace', ( + with replace as ( + select jsonb($6::text) replacements + ) + select + jsonb_agg(new order by idx) trimmed_replacements + from + replace, lateral ( + select idx, jsonb_object_agg(key, + case + when key = 'start' + then to_jsonb(greatest(old::text::timestamptz, starttime)) + when key = 'end' + then to_jsonb(least(old::text::timestamptz, endtime)) + else old + end + ) + from + jsonb_array_elements(replacements) + with ordinality rs(r, idx), + jsonb_each(r) each(key, old) + where + (r->>'start')::timestamptz < endtime and + (r->>'end')::timestamptz > starttime + group BY idx + ) news(idx, new) + ) + ) + ) + where + starttime < $5 and + endtime > $4 and + bucketrow = ( + select rowid from buckets where id = $1 + );` + +// AmendEvents adds amendment notes to the data for events in the store +// overlapping the note. On return the note.Replace slice will be sorted. +// +// The SQL commands run are [AmendEventsPrepare] and [AmendEventsUpdate] +// in a transaction. +func (db *DB) AmendEvents(ctx context.Context, ts time.Time, note *worklog.Amendment) (sql.Result, error) { + if len(note.Replace) == 0 { + return driver.RowsAffected(0), nil + } + sort.Slice(note.Replace, func(i, j int) bool { + return note.Replace[i].Start.Before(note.Replace[j].Start) + }) + start := note.Replace[0].Start + end := note.Replace[0].End + for i, r := range note.Replace[1:] { + if note.Replace[i].End.After(r.Start) { + return nil, fmt.Errorf("overlapping replacements: [%d].end (%s) is after [%d].start (%s)", + i, note.Replace[i].End.Format(time.RFC3339), i+1, r.Start.Format(time.RFC3339)) + } + if r.End.After(end) { + end = r.End + } + } + replace, err := json.Marshal(note.Replace) + if err != nil { + return nil, err + } + var res pgconn.CommandTag + err = pgx.BeginFunc(ctx, db.store, func(tx pgx.Tx) error { + _, err = tx.Exec(ctx, AmendEventsPrepare, db.BucketID(note.Bucket), start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano)) + if err != nil { + return fmt.Errorf("prepare amendment list: %w", err) + } + res, err = tx.Exec(ctx, AmendEventsUpdate, db.BucketID(note.Bucket), ts.Format(time.RFC3339Nano), note.Message, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), replace) + if err != nil { + return fmt.Errorf("add amendments: %w", err) + } + return nil + }) + if err != nil { + return nil, err + } + return result{res}, err +} + +const DeleteEvent = `delete from events where bucketrow = ( + select rowid from buckets where id = $1 +) and id = $2` + +const DeleteBucketEvents = `delete from events where bucketrow in ( + select rowid from buckets where id = $1 +)` + +const DeleteBucket = `delete from buckets where id = $1` diff --git a/cmd/worklog/pgstore/db_test.go b/cmd/worklog/pgstore/db_test.go new file mode 100644 index 0000000..4af9145 --- /dev/null +++ b/cmd/worklog/pgstore/db_test.go @@ -0,0 +1,751 @@ +// Copyright ©2024 Dan Kortschak. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
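+
+// The tests in this file require a running PostgreSQL server. Connection
+// details are read from the PGHOST, PGPORT, PGUSER and PGPASSWORD
+// environment variables (or ~/.pgpass), mirroring the CI configuration, so
+// a local run looks like:
+//
+//	PGHOST=localhost PGPORT=5432 PGUSER=test_user PGPASSWORD=password \
+//		go test ./cmd/worklog/pgstore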
+ +package pgstore + +import ( + "compress/gzip" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "io/fs" + "net/url" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/jackc/pgx/v5" + "golang.org/x/exp/slices" + + worklog "github.com/kortschak/dex/cmd/worklog/api" +) + +var ( + verbose = flag.Bool("verbose_log", false, "print full logging") + lines = flag.Bool("show_lines", false, "log source code position") + keep = flag.Bool("keep", false, "keep test database after tests") +) + +const testDir = "testdata" + +func TestDB(t *testing.T) { + if !*keep { + t.Cleanup(func() { + os.RemoveAll(testDir) + }) + } + const dbBaseName = "test_worklog_database" + + pgHost := os.Getenv("PGHOST") + if pgHost == "" { + t.Fatal("must have postgres host in $PGHOST") + } + pgPort := os.Getenv("PGPORT") + if pgPort == "" { + t.Fatal("must have postgres port number in $PGPORT") + } + pgHostPort := pgHost + ":" + pgPort + + pgUser := os.Getenv("PGUSER") + if pgUser == "" { + t.Fatal("must have postgres user in $PGUSER") + } + pgUserinfo, err := pgUserinfo(pgUser, pgHost, pgPort, os.Getenv("PGPASSWORD")) + if err != nil { + t.Fatal(err) + } + pgPassword, ok := pgUserinfo.Password() + if !ok { + t.Fatal("did not get password") + } + + ctx := context.Background() + for _, interval := range []struct { + name string + duration time.Duration + }{ + {name: "second", duration: time.Second}, + {name: "subsecond", duration: time.Second / 10}, + } { + dbName := dbBaseName + "_" + interval.name + dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + if !*keep { + defer drop() + } + + t.Run(interval.name, func(t *testing.T) { + t.Run("db", func(t *testing.T) { + u := url.URL{ + Scheme: "postgres", + User: url.UserPassword(pgUser, pgPassword), + Host: pgHostPort, + Path: dbName, + } + dbURL := u.String() + + db, err := Open(ctx, dbURL, "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close db: %v", err) + } + + now := time.Now().Round(time.Millisecond) + bucket := "test_bucket" + data := []worklog.BucketMetadata{ + { + ID: "test_bucket_test_host", + Name: "test_bucket_name", + Type: "test_bucket_type", + Client: "testing", + Hostname: "test_host", + Created: now, + Data: map[string]any{"key0": "value0"}, + Events: []worklog.Event{ + { + Bucket: bucket, + Start: now.Add(1 * interval.duration), + End: now.Add(2 * interval.duration), + Data: map[string]any{"key1": "value1"}, + }, + { + Bucket: bucket, + Start: now.Add(3 * interval.duration), + End: now.Add(4 * interval.duration), + Data: map[string]any{"key2": "value2"}, + }, + { + Bucket: bucket, + Start: now.Add(5 * interval.duration), + End: now.Add(6 * interval.duration), + Data: map[string]any{"key3": "value3"}, + }, + }, + }, + } + + // worklog.Event.ID is the only int64 field, so + // this is easier than filtering on field name. 
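+			// An equivalent alternative would be
+			// cmpopts.IgnoreFields(worklog.Event{}, "ID") from
+			// github.com/google/go-cmp/cmp/cmpopts.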
+ ignoreID := cmp.FilterValues( + func(_, _ int64) bool { return true }, + cmp.Ignore(), + ) + + db, err = Open(ctx, dbURL, "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + defer func() { + err = db.Close(ctx) + if err != nil { + t.Errorf("failed to close db: %v", err) + } + }() + + t.Run("load_dump", func(t *testing.T) { + err = db.Load(ctx, data, false) + if err != nil { + t.Errorf("failed to load data no replace: %v", err) + } + err = db.Load(ctx, data, true) + if err != nil { + t.Errorf("failed to load data replace: %v", err) + } + + got, err := db.Dump(ctx) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + want := data + if !cmp.Equal(want, got, ignoreID) { + t.Errorf("unexpected dump result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) + } + + gotRange, err := db.DumpRange(ctx, now.Add(3*interval.duration), now.Add(4*interval.duration)) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRange := slices.Clone(data) + wantRange[0].Events = wantRange[0].Events[1:2] + if !cmp.Equal(wantRange, gotRange, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRange, gotRange, ignoreID)) + } + + gotRangeFrom, err := db.DumpRange(ctx, now.Add(3*interval.duration), time.Time{}) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeFrom := slices.Clone(data) + wantRangeFrom[0].Events = wantRangeFrom[0].Events[1:] + if !cmp.Equal(wantRangeFrom, gotRangeFrom, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeFrom, gotRangeFrom, ignoreID)) + } + + gotRangeUntil, err := db.DumpRange(ctx, time.Time{}, now.Add(4*interval.duration)) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeUntil := slices.Clone(data) + wantRangeUntil[0].Events = wantRangeUntil[0].Events[:2] + if !cmp.Equal(wantRangeUntil, gotRangeUntil, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeUntil, gotRangeUntil, ignoreID)) + } + + gotRangeAll, err := db.DumpRange(ctx, time.Time{}, time.Time{}) + if err != nil { + t.Errorf("failed to dump data: %v", err) + } + + wantRangeAll := slices.Clone(data) + if !cmp.Equal(wantRangeAll, gotRangeAll, ignoreID) { + t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeAll, gotRangeAll, ignoreID)) + } + + }) + + t.Run("backup", func(t *testing.T) { + workDir := filepath.Join(testDir, interval.name) + err := os.MkdirAll(workDir, 0o755) + if err != nil && !errors.Is(err, fs.ErrExist) { + t.Fatalf("failed to make dir: %v", err) + } + + path, err := db.Backup(ctx, workDir) + if err != nil { + if strings.Contains(err.Error(), "aborting because of server version mismatch") { + // On CI, we may have a different local version + // to the version on the service container. + // Skip if that's the case. 
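+						// pg_dump refuses to dump from a server
+						// with a newer major version than its own.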
+						t.Skip(err.Error())
+					}
+					t.Fatalf("failed to backup database: %v", err)
+				}
+				f, err := os.Open(path)
+				if errors.Is(err, fs.ErrNotExist) {
+					t.Fatal("did not create backup")
+				} else if err != nil {
+					t.Fatalf("unexpected error opening backup file: %v", err)
+				}
+				r, err := gzip.NewReader(f)
+				if err != nil {
+					t.Fatalf("unexpected error opening gzip backup file: %v", err)
+				}
+				_, err = io.Copy(io.Discard, r)
+				if err != nil {
+					t.Errorf("unexpected error gunzipping backup file: %v", err)
+				}
+			})
+
+			t.Run("last_event", func(t *testing.T) {
+				got, err := db.LastEvent(ctx, bucket)
+				if err != nil {
+					t.Fatalf("failed to get last event: %v", err)
+				}
+				got.Bucket = bucket
+
+				want := &data[0].Events[len(data[0].Events)-1]
+				if !cmp.Equal(want, got, ignoreID) {
+					t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID))
+				}
+			})
+
+			t.Run("update_last_event", func(t *testing.T) {
+				e := data[0].Events[len(data[0].Events)-1]
+				e.End = e.End.Add(interval.duration)
+				last, err := db.LastEvent(ctx, e.Bucket)
+				if err != nil {
+					t.Fatalf("failed to get last event: %v", err)
+				}
+				e.ID = last.ID
+				_, err = db.UpdateEvent(ctx, &e)
+				if err != nil {
+					t.Fatalf("failed to update event: %v", err)
+				}
+				got, err := db.LastEvent(ctx, bucket)
+				if err != nil {
+					t.Errorf("failed to get last event: %v", err)
+				}
+				got.Bucket = bucket
+
+				want := &e
+				if !cmp.Equal(want, got, ignoreID) {
+					t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID))
+				}
+			})
+
+			t.Run("events_range", func(t *testing.T) {
+				bid := db.BucketID(bucket)
+				for _, loc := range []*time.Location{time.Local, time.UTC} {
+					t.Run(loc.String(), func(t *testing.T) {
+						got, err := db.EventsRange(ctx, bid, now.Add(3*interval.duration).In(loc), now.Add(4*interval.duration).In(loc), -1)
+						if err != nil {
+							t.Errorf("failed to load data: %v", err)
+						}
+						for i := range got {
+							got[i].Bucket = bucket
+						}
+
+						want := data[0].Events[1:2]
+						if !cmp.Equal(want, got, ignoreID) {
+							t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID))
+						}
+					})
+				}
+			})
+
+			t.Run("update_last_event_coequal", func(t *testing.T) {
+				dbName := dbName + "_coequal"
+				dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName)
+				drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName)
+				if !*keep {
+					defer drop()
+				}
+
+				u := url.URL{
+					Scheme: "postgres",
+					User:   url.UserPassword(pgUser, pgPassword),
+					Host:   pgHostPort,
+					Path:   dbName,
+				}
+				db, err := Open(ctx, u.String(), "test_host")
+				if err != nil && !errors.As(err, &Warning{}) {
+					t.Fatalf("failed to open db: %v", err)
+				}
+				defer db.Close(ctx)
+
+				buckets := []string{
+					`{"id":"window","name":"window-watcher","type":"currentwindow","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38.305691865+09:30"}`,
+					`{"id":"afk","name":"afk-watcher","type":"afkstatus","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38.310302464+09:30"}`,
+				}
+				for _, msg := range buckets {
+					var b worklog.BucketMetadata
+					err := json.Unmarshal([]byte(msg), &b)
+					if err != nil {
+						t.Fatalf("failed to unmarshal bucket message: %v", err)
+					}
+					_, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data)
+					if err != nil {
+						t.Fatalf("failed to create bucket: %v", err)
+					}
+				}
+
+				events := []string{
`{"bucket":"window","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:39.248859996+09:30","data":{"app":"Gnome-terminal","title":"Terminal"},"continue":false}`, + `{"bucket":"afk","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:39.248859996+09:30","data":{"afk":false,"locked":false},"continue":false}`, + `{"bucket":"window","start":"2023-06-12T19:54:40.247357339+09:30","end":"2023-06-12T19:54:40.247357339+09:30","data":{"app":"Gnome-terminal","title":"Terminal"},"continue":false}`, + `{"bucket":"afk","start":"2023-06-12T19:54:39.248859996+09:30","end":"2023-06-12T19:54:40.247357339+09:30","data":{"afk":false,"locked":false},"continue":true}`, + } + for i, msg := range events { + var note *worklog.Event + err := json.Unmarshal([]byte(msg), ¬e) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + if note.Continue != nil && *note.Continue { + last, err := db.LastEvent(ctx, note.Bucket) + if err != nil { + t.Fatalf("failed to get last event: %v", err) + } + note.ID = last.ID + _, err = db.UpdateEvent(ctx, note) + if err != nil { + t.Fatalf("failed to update event: %v", err) + } + } else { + _, err = db.InsertEvent(ctx, note) + if err != nil { + t.Fatalf("failed to insert event: %v", err) + } + } + + dump, err := db.Dump(ctx) + if err != nil { + t.Fatalf("failed to dump db after step %d: %v", i, err) + } + t.Logf("note: %#v\ndump: %#v", note, dump) + + for _, b := range dump { + for _, e := range b.Events { + if e.Bucket == "window" { + if _, ok := e.Data["afk"]; ok { + t.Errorf("unexpectedly found afk data in window bucket: %v", e) + } + } + } + } + } + }) + + t.Run("amend", func(t *testing.T) { + // t.Skip("not yet working") + + dbName := dbName + "_amend" + dropTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + drop := createTestDB(t, ctx, pgUserinfo, pgHostPort, dbName) + if !*keep { + defer drop() + } + + u := url.URL{ + Scheme: "postgres", + User: pgUserinfo, + Host: pgHostPort, + Path: dbName, + } + db, err := Open(ctx, u.String(), "test_host") + if err != nil && !errors.As(err, &Warning{}) { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close(ctx) + + buckets := []string{ + `{"id":"window","name":"window-watcher","type":"currentwindow","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38Z"}`, + `{"id":"afk","name":"afk-watcher","type":"afkstatus","client":"worklog","hostname":"test_host","created":"2023-06-12T19:54:38Z"}`, + } + for _, msg := range buckets { + var b worklog.BucketMetadata + err := json.Unmarshal([]byte(msg), &b) + if err != nil { + t.Fatalf("failed to unmarshal bucket message: %v", err) + } + _, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) + if err != nil { + t.Fatalf("failed to create bucket: %v", err) + } + } + + events := []string{ + `{"bucket":"window","start":"2023-06-12T19:54:40Z","end":"2023-06-12T19:54:45Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:40Z","end":"2023-06-12T19:54:45Z","data":{"afk":false,"locked":false}}`, + `{"bucket":"window","start":"2023-06-12T19:54:45Z","end":"2023-06-12T19:54:50Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:45Z","end":"2023-06-12T19:54:50Z","data":{"afk":true,"locked":true}}`, + `{"bucket":"window","start":"2023-06-12T19:54:50Z","end":"2023-06-12T19:54:55Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + 
`{"bucket":"afk","start":"2023-06-12T19:54:50Z","end":"2023-06-12T19:54:55Z","data":{"afk":false,"locked":false}}`, + `{"bucket":"window","start":"2023-06-12T19:54:55Z","end":"2023-06-12T19:54:59Z","data":{"app":"Gnome-terminal","title":"Terminal"}}`, + `{"bucket":"afk","start":"2023-06-12T19:54:55Z","end":"2023-06-12T19:54:59Z","data":{"afk":true,"locked":true}}`, + } + for _, msg := range events { + var note *worklog.Event + err := json.Unmarshal([]byte(msg), ¬e) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + _, err = db.InsertEvent(ctx, note) + if err != nil { + t.Fatalf("failed to insert event: %v", err) + } + } + msg := `{"bucket":"afk","msg":"testing","replace":[{"start":"2023-06-12T19:54:39Z","end":"2023-06-12T19:54:51Z","data":{"afk":true,"locked":true}}]}` + var amendment *worklog.Amendment + err = json.Unmarshal([]byte(msg), &amendment) + if err != nil { + t.Fatalf("failed to unmarshal event message: %v", err) + } + _, err = db.AmendEvents(ctx, time.Time{}, amendment) + if err != nil { + t.Errorf("unexpected error amending events: %v", err) + } + dump, err := db.Dump(ctx) + if err != nil { + t.Fatalf("failed to dump db: %v", err) + } + for _, bucket := range dump { + for i, event := range bucket.Events { + switch event.Bucket { + case "window": + _, ok := event.Data["amend"] + if ok { + t.Errorf("unexpected amendment in window event %d: %v", i, event.Data) + } + case "afk": + a, ok := event.Data["amend"] + if !ok { + for _, r := range amendment.Replace { + if overlaps(event.Start, event.End, r.Start, r.End) { + t.Errorf("expected amendment for event %d of afk", i) + break + } + } + break + } + var n []worklog.Amendment + err = remarshalJSON(&n, a) + if err != nil { + t.Errorf("unexpected error remarshalling []AmendEvents: %v", err) + } + if len(n) == 0 { + t.Fatal("unexpected zero-length []AmendEvents") + } + for _, r := range n[len(n)-1].Replace { + if r.Start.Before(event.Start) { + t.Errorf("replacement start extends before start of event: %s < %s", + r.Start.Format(time.RFC3339), event.Start.Format(time.RFC3339)) + } + if noted, ok := findOverlap(r, amendment.Replace); ok && !r.Start.Equal(event.Start) && !r.Start.Equal(noted.Start) { + t.Errorf("non-truncated replacement start was altered: %s != %s", + r.Start.Format(time.RFC3339), noted.Start.Format(time.RFC3339)) + } + if r.End.After(event.End) { + t.Errorf("replacement end extends beyond end of event: %s > %s", + r.End.Format(time.RFC3339), event.End.Format(time.RFC3339)) + } + if noted, ok := findOverlap(r, amendment.Replace); ok && !r.End.Equal(event.End) && !r.End.Equal(noted.End) { + t.Errorf("non-truncated replacement end was altered: %s != %s", + r.End.Format(time.RFC3339), noted.End.Format(time.RFC3339)) + } + } + default: + t.Errorf("unexpected event bucket name in event %d of %s: %s", i, bucket.ID, event.Bucket) + } + } + } + }) + + t.Run("dynamic_query", func(t *testing.T) { + grantReadAccess(t, ctx, pgUserinfo, pgHost, dbName, pgUser+"_ro") + u := url.URL{ + Scheme: "postgres", + User: pgUserinfo, + Host: pgHostPort, + Path: dbName, + } + db, err := Open(ctx, u.String(), "test_host") + if err != nil { + t.Fatalf("failed to open db: %v", err) + } + defer db.Close(ctx) + + dynamicTests := []struct { + name string + sql string + wantErr []error + }{ + { + name: "kitchen_or", + sql: `select datastr ->> 'title', starttime, datastr ->> 'afk' from events + where + not (datastr ->> 'afk')::boolean or datastr ->> 'title' = 'Terminal' + limit 2`, + }, + { + name: "kitchen_and", + sql: 
`select datastr ->> 'title', starttime, datastr ->> 'afk' from events + where + not (datastr ->> 'afk')::boolean and datastr ->> 'title' = 'Terminal' + limit 2`, + }, + { + name: "count", + sql: `select count(*) from events`, + }, + { + name: "all", + sql: `select * from events`, + }, + { + name: "non_null_afk", + sql: `select * from events where datastr ? 'app'`, + }, + { + name: "drop_table", + sql: `drop table events`, + wantErr: []error{ + errors.New("ERROR: must be owner of relation events (SQLSTATE 42501)"), + errors.New("ERROR: must be owner of table events (SQLSTATE 42501)"), + }, + }, + { + name: "sneaky_create_table", + sql: "select count(*) from events; create table if not exists t(i)", + wantErr: []error{ + errors.New("ERROR: syntax error at end of input (SQLSTATE 42601)"), + }, + }, + { + name: "sneaky_drop_table", + sql: "select count(*) from events; drop table events", + wantErr: []error{ + errors.New("ERROR: cannot insert multiple commands into a prepared statement (SQLSTATE 42601)"), + }, + }, + } + + for _, test := range dynamicTests { + t.Run(test.name, func(t *testing.T) { + got, err := db.Select(ctx, test.sql) + if !sameErrorIn(err, test.wantErr) { + t.Errorf("unexpected error: got:%v want:%v", err, test.wantErr) + return + } + if err != nil { + return + } + + rows, err := db.store.Query(ctx, test.sql) + if err != nil { + t.Fatalf("unexpected error for query: %v", err) + } + descs := rows.FieldDescriptions() + var want []map[string]any + for rows.Next() { + args := make([]any, len(descs)) + for i := range args { + var a any + args[i] = &a + } + err = rows.Scan(args...) + if err != nil { + t.Fatal(err) + } + row := make(map[string]any) + for i, a := range args { + row[descs[i].Name] = *(a.(*any)) + } + want = append(want, row) + } + rows.Close() + + if !cmp.Equal(want, got) { + t.Errorf("unexpected result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got)) + } + }) + } + }) + }) + }) + } +} + +func createTestDB(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname string) func() { + t.Helper() + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: "template1", + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open admin database: %v", err) + } + _, err = db.Exec(ctx, "create database "+dbname) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close admin connection: %v", err) + } + + return func() { + dropTestDB(t, ctx, user, host, dbname) + } +} + +func dropTestDB(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname string) { + t.Helper() + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: "template1", + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open admin database: %v", err) + } + _, err = db.Exec(ctx, "drop database if exists "+dbname) + if err != nil { + t.Fatalf("failed to drop test database: %v", err) + } + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close admin connection: %v", err) + } +} + +func grantReadAccess(t *testing.T, ctx context.Context, user *url.Userinfo, host, dbname, target string) { + t.Helper() + + u := url.URL{ + Scheme: "postgres", + User: user, + Host: host, + Path: dbname, + } + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + t.Fatalf("failed to open database: %v", err) + } + + statements := []string{ + fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO %s", dbname, target), + 
fmt.Sprintf("GRANT USAGE ON SCHEMA public TO %s", target), + fmt.Sprintf("GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s", target), + } + for _, s := range statements { + _, err = db.Exec(ctx, s) + if err != nil { + t.Fatalf("failed to execute grant: %v", err) + } + } + + err = db.Close(ctx) + if err != nil { + t.Fatalf("failed to close connection: %v", err) + } +} + +func ptr[T any](v T) *T { return &v } + +func findOverlap(n worklog.Replacement, h []worklog.Replacement) (worklog.Replacement, bool) { + for _, c := range h { + if overlaps(n.Start, n.End, c.Start, c.End) { + return c, true + } + } + return worklog.Replacement{}, false +} + +func overlaps(as, ae, bs, be time.Time) bool { + return ae.After(bs) && as.Before(be) +} +func remarshalJSON(dst, src any) error { + b, err := json.Marshal(src) + if err != nil { + return err + } + return json.Unmarshal(b, dst) +} + +func sameErrorIn(err error, list []error) bool { + switch { + case err != nil && list != nil: + return slices.ContainsFunc(list, func(e error) bool { + return err.Error() == e.Error() + }) + case err == nil && list == nil: + return true + default: + return false + } +} diff --git a/go.mod b/go.mod index 79c639b..6cd7825 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,7 @@ require ( github.com/gofrs/flock v0.8.1 github.com/google/cel-go v0.20.1 github.com/google/go-cmp v0.6.0 + github.com/jackc/pgx/v5 v5.6.0 github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4 github.com/kortschak/goroutine v1.1.2 github.com/kortschak/jsonrpc2 v0.0.0-20240214190357-0539ebd6a045 @@ -41,6 +42,8 @@ require ( github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/golang-lru/v2 v2.0.7 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect @@ -49,6 +52,7 @@ require ( github.com/tdewolff/font v0.0.0-20240417221047-e5855237f87b // indirect github.com/tdewolff/minify/v2 v2.20.5 // indirect github.com/tdewolff/parse/v2 v2.7.3 // indirect + golang.org/x/crypto v0.23.0 // indirect golang.org/x/exp/shiny v0.0.0-20231006140011-7918f672742d // indirect golang.org/x/text v0.16.0 // indirect google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37 // indirect diff --git a/go.sum b/go.sum index 50f147a..82d9dca 100644 --- a/go.sum +++ b/go.sum @@ -76,6 +76,14 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.6.0 h1:SWJzexBzPL5jb0GEsrPMLIsi/3jOo7RHlzTjcAeDrPY= +github.com/jackc/pgx/v5 v5.6.0/go.mod h1:DNZ/vlrUnhWCoFGxHAG8U2ljioxukquj7utPDgtQdTw= +github.com/jackc/puddle/v2 
v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4 h1:LBA2QNMoqB/6H1QG9oSzcVQxqLQJ8vX/Vr4SiyGlLNI= github.com/kortschak/ardilla v0.0.0-20240121074954-8297d203ffa4/go.mod h1:CzpsZoRc37ZYYkb/iJHFRW9m+pWuyiS3WAqvuBVFiCo= github.com/kortschak/goroutine v1.1.2 h1:lhllcCuERxMIK5cYr8yohZZScL1na+JM5JYPRclWjck= @@ -111,7 +119,9 @@ github.com/sstallion/go-hid v0.14.1/go.mod h1:fPKp4rqx0xuoTV94gwKojsPG++KNKhxuU8 github.com/stoewer/go-strcase v1.2.0 h1:Z2iHWqGXH00XYgqDmNgQbIBxf3wrNq0F3feEy0ainaU= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/tdewolff/canvas v0.0.0-20240512164826-1cb71758b3b2 h1:n5JfO/T/+VqXuHym08mEAzmDz2QDwMA71+/zl1SG+dE= @@ -125,6 +135,8 @@ github.com/tdewolff/parse/v2 v2.7.3/go.mod h1:9p2qMIHpjRSTr1qnFxQr+igogyTUTlwvf9 github.com/tdewolff/test v1.0.10/go.mod h1:6DAvZliBAAnD7rhVgwaM7DE5/d9NMOAJ09SqYqeK4QE= github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739 h1:IkjBCtQOOjIn03u/dMQK9g+Iw9ewps4mCl1nB8Sscbo= github.com/tdewolff/test v1.0.11-0.20240106005702-7de5f7df4739/go.mod h1:XPuWBzvdUzhCuxWO1ojpXsyzsA5bFoS3tO/Q3kFuTG8= +golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f h1:99ci1mjWVBWwJiEKYY6jWa4d2nTQVIEhZIptnrVb1XY= golang.org/x/exp v0.0.0-20240416160154-fe59bbe5cc7f/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI= golang.org/x/exp/shiny v0.0.0-20231006140011-7918f672742d h1:grE48C8cjIY0aiHVmFyYgYxxSARQWBABLXKZfQPrBhY= @@ -164,9 +176,10 @@ google.golang.org/genproto v0.0.0-20221207170731-23e4bf6bdc37/go.mod h1:RGgjbofJ google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= modernc.org/cc/v4 v4.21.2 h1:dycHFB/jDc3IyacKipCNSDrjIC0Lm1hyoWOZTRR20Lk= From ee950f81c415546ca504076f464642d246004167 Mon Sep 17 
00:00:00 2001 From: Dan Kortschak Date: Sun, 22 Sep 2024 07:15:05 +0930 Subject: [PATCH 2/6] cmd/worklog: add initial generalised database wiring Deprecate the database_dir configuration, to be removed in a month. --- cmd/worklog/README.md | 2 +- cmd/worklog/api/api.go | 19 ++-- cmd/worklog/main.go | 46 ++++++++- cmd/worklog/main_test.go | 88 ++++++++++++++++- testdata/worklog_load.txt | 4 +- testdata/worklog_load_deprecated.txt | 143 +++++++++++++++++++++++++++ 6 files changed, 289 insertions(+), 13 deletions(-) create mode 100644 testdata/worklog_load_deprecated.txt diff --git a/cmd/worklog/README.md b/cmd/worklog/README.md index 33094d1..afe9046 100644 --- a/cmd/worklog/README.md +++ b/cmd/worklog/README.md @@ -37,7 +37,7 @@ log_mode = "log" log_level = "info" [module.worklog.options] -database_dir = "worklog" +database = "sqlite:worklog" [module.worklog.options.rules.afk] name = "afk-watcher" diff --git a/cmd/worklog/api/api.go b/cmd/worklog/api/api.go index 32e54cb..7281921 100644 --- a/cmd/worklog/api/api.go +++ b/cmd/worklog/api/api.go @@ -23,12 +23,19 @@ type Config struct { LogLevel *slog.Level `json:"log_level,omitempty"` AddSource *bool `json:"log_add_source,omitempty"` Options struct { - DynamicLocation *bool `json:"dynamic_location,omitempty"` - Web *Web `json:"web,omitempty"` - DatabaseDir string `json:"database_dir,omitempty"` // Relative to XDG_STATE_HOME. - Hostname string `json:"hostname,omitempty"` - Heartbeat *rpc.Duration `json:"heartbeat,omitempty"` - Rules map[string]Rule `json:"rules,omitempty"` + DynamicLocation *bool `json:"dynamic_location,omitempty"` + Web *Web `json:"web,omitempty"` + // Database is the URL location of the worklog + // database. When the scheme is sqlite, the location + // is a directory relative to XDG_STATE_HOME as + // URL opaque data. + Database string `json:"database,omitempty"` + Hostname string `json:"hostname,omitempty"` + Heartbeat *rpc.Duration `json:"heartbeat,omitempty"` + Rules map[string]Rule `json:"rules,omitempty"` + + // Deprecated: Use Database with sqlite scheme. + DatabaseDir string `json:"database_dir,omitempty"` // Relative to XDG_STATE_HOME. } `json:"options,omitempty"` } diff --git a/cmd/worklog/main.go b/cmd/worklog/main.go index 3fb9b1b..743e5a3 100644 --- a/cmd/worklog/main.go +++ b/cmd/worklog/main.go @@ -319,8 +319,20 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) } } - if m.Body.Options.DatabaseDir != "" { - dir, err := xdg.State(m.Body.Options.DatabaseDir) + databaseDir, err := dbDir(m.Body) + if err != nil { + d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.Any("error", err)) + return nil, rpc.NewError(rpc.ErrCodeInvalidMessage, + err.Error(), + map[string]any{ + "type": rpc.ErrCodeParameters, + "database": m.Body.Options.Database, + "database_dir": m.Body.Options.DatabaseDir, + }, + ) + } + if databaseDir != "" { + dir, err := xdg.State(databaseDir) switch err { case nil: case syscall.ENOENT: @@ -330,7 +342,7 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.String("error", "no XDG_STATE_HOME")) return nil, err } - dir = filepath.Join(dir, m.Body.Options.DatabaseDir) + dir = filepath.Join(dir, databaseDir) err = os.Mkdir(dir, 0o750) if err != nil { err := err.(*os.PathError) // See godoc for os.Mkdir for why this is safe. 
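As context for the `dbDir` helper added in the next hunk: the accepted sqlite form carries the directory as URL opaque data, which is what distinguishes it from a rooted path or a server URL. A minimal standalone sketch (standard library only, using values from the tests below) of how `net/url` splits each spelling:

```go
package main

import (
	"fmt"
	"net/url"
)

func main() {
	for _, s := range []string{
		"sqlite:database_directory",  // accepted: directory is the opaque data
		"sqlite:/database_directory", // rejected: rooted form has empty Opaque
		"postgres://username:password@localhost:5432/database_name",
	} {
		u, err := url.Parse(s)
		if err != nil {
			fmt.Println(s, "->", err)
			continue
		}
		// dbDir returns u.Opaque for the sqlite scheme and "" otherwise.
		fmt.Printf("scheme=%q opaque=%q path=%q\n", u.Scheme, u.Opaque, u.Path)
	}
}
```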
@@ -392,6 +404,34 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) } } +func dbDir(cfg worklog.Config) (string, error) { + opt := cfg.Options + if opt.Database == "" { + return opt.DatabaseDir, nil + } + u, err := url.Parse(opt.Database) + if err != nil { + return "", err + } + switch u.Scheme { + case "": + return "", errors.New("missing scheme in database configuration") + case "sqlite": + if opt.DatabaseDir != "" && u.Opaque != opt.DatabaseDir { + return "", fmt.Errorf("inconsistent database directory configuration: (%s:)%s != %s", u.Scheme, u.Opaque, opt.DatabaseDir) + } + if u.Opaque == "" { + return "", fmt.Errorf("sqlite configuration missing opaque data: %s", opt.Database) + } + return u.Opaque, nil + default: + if opt.DatabaseDir != "" { + return "", fmt.Errorf("inconsistent database configuration: both %s database and sqlite directory configured", u.Scheme) + } + return "", nil + } +} + func (d *daemon) replaceTimezone(ctx context.Context, dynamic *bool) { if dynamic == nil { return diff --git a/cmd/worklog/main_test.go b/cmd/worklog/main_test.go index f31357e..951e84f 100644 --- a/cmd/worklog/main_test.go +++ b/cmd/worklog/main_test.go @@ -143,10 +143,11 @@ func TestDaemon(t *testing.T) { type options struct { DynamicLocation *bool `json:"dynamic_location,omitempty"` Web *worklog.Web `json:"web,omitempty"` - DatabaseDir string `json:"database_dir,omitempty"` // Relative to XDG_STATE_HOME. + Database string `json:"database,omitempty"` Hostname string `json:"hostname,omitempty"` Heartbeat *rpc.Duration `json:"heartbeat,omitempty"` Rules map[string]worklog.Rule `json:"rules,omitempty"` + DatabaseDir string `json:"database_dir,omitempty"` // Relative to XDG_STATE_HOME. } err := conn.Call(ctx, "configure", rpc.NewMessage(uid, worklog.Config{ Options: options{ @@ -1715,3 +1716,88 @@ func TestMergeReplacements(t *testing.T) { }) } } + +var dbDirTests = []struct { + name string + config worklog.Config + want string + wantErr error +}{ + { + name: "none", + }, + { + name: "deprecated", + config: mkDBDirOptions("database_directory", ""), + want: "database_directory", + }, + { + name: "url_only_sqlite", + config: mkDBDirOptions("", "sqlite:database_directory"), + want: "database_directory", + }, + { + name: "url_only_postgres", + config: mkDBDirOptions("", "postgres://username:password@localhost:5432/database_name"), + want: "", + }, + { + name: "both_consistent", + config: mkDBDirOptions("database_directory", "sqlite:database_directory"), + want: "database_directory", + }, + { + name: "missing_scheme", + config: mkDBDirOptions("database_dir", "database_directory"), + want: "", + wantErr: errors.New("missing scheme in database configuration"), + }, + { + name: "both_inconsistent_sqlite", + config: mkDBDirOptions("database_dir", "sqlite:database_directory"), + want: "", + wantErr: errors.New("inconsistent database directory configuration: (sqlite:)database_directory != database_dir"), + }, + { + name: "invalid_sqlite_url", + config: mkDBDirOptions("", "sqlite:/database_directory"), + want: "", + wantErr: errors.New("sqlite configuration missing opaque data: sqlite:/database_directory"), + }, + { + name: "both_inconsistent_postgres", + config: mkDBDirOptions("database_dir", "postgres://username:password@localhost:5432/database_name"), + want: "", + wantErr: errors.New("inconsistent database configuration: both postgres database and sqlite directory configured"), + }, +} + +func mkDBDirOptions(dir, url string) worklog.Config { + var cfg worklog.Config + 
cfg.Options.DatabaseDir = dir + cfg.Options.Database = url + return cfg +} + +func TestDBDir(t *testing.T) { + for _, test := range dbDirTests { + t.Run(test.name, func(t *testing.T) { + got, err := dbDir(test.config) + if !sameError(err, test.wantErr) { + t.Errorf("unexpected error calling dbDir: got:%v want:%v", err, test.wantErr) + } + if got != test.want { + t.Errorf("unexpected result: got:%q want:%q", got, test.want) + } + }) + } +} + +func sameError(a, b error) bool { + switch { + case a != nil && b != nil: + return a.Error() == b.Error() + default: + return a == b + } +} diff --git a/testdata/worklog_load.txt b/testdata/worklog_load.txt index 7c0befb..1c57cbd 100644 --- a/testdata/worklog_load.txt +++ b/testdata/worklog_load.txt @@ -1,6 +1,6 @@ # The build of worklog takes a fair while due to the size of the dependency # tree, the size of some of the individual dependencies and the absence of -# cachine when building within a test script. +# caching when building within a test script. [short] stop 'Skipping long test.' env HOME=${WORK} @@ -43,7 +43,7 @@ log_level = "debug" log_add_source = true [module.worklog.options] -database_dir = "worklog" +database = "sqlite:worklog" hostname = "localhost" [module.worklog.options.web] addr = "localhost:7979" diff --git a/testdata/worklog_load_deprecated.txt b/testdata/worklog_load_deprecated.txt new file mode 100644 index 0000000..5224d46 --- /dev/null +++ b/testdata/worklog_load_deprecated.txt @@ -0,0 +1,143 @@ +# The build of worklog takes a fair while due to the size of the dependency +# tree, the size of some of the individual dependencies and the absence of +# caching when building within a test script. +[short] stop 'Skipping long test.' + +env HOME=${WORK} + +[linux] env XDG_CONFIG_HOME=${HOME}/.config +[linux] env XDG_RUNTIME_DIR=${HOME}/runtime +[linux] mkdir ${XDG_CONFIG_HOME}/dex +[linux] mkdir ${XDG_RUNTIME_DIR} +[linux] mv config.toml ${HOME}/.config/dex/config.toml + +[darwin] mkdir ${HOME}'/Library/Application Support/dex' +[darwin] mv config.toml ${HOME}'/Library/Application Support/dex/config.toml' + +env GOBIN=${WORK}/bin +env PATH=${GOBIN}:${PATH} +cd ${PKG_ROOT} +go install ./cmd/worklog +cd ${WORK} + +dex -log debug -lines &dex& +sleep 1s + +POST dump.json http://localhost:7978/load/?replace=true + +GET -json http://localhost:7978/dump/ +cmp stdout want.json + +GET http://localhost:7978/query?sql=select+count(*)+as+event_count+from+events +cmp stdout want_event_count.json + +-- config.toml -- +[kernel] +device = [] +network = "tcp" + +[module.worklog] +path = "worklog" +log_mode = "log" +log_level = "debug" +log_add_source = true + +[module.worklog.options] +database_dir = "worklog" +hostname = "localhost" +[module.worklog.options.web] +addr = "localhost:7978" +can_modify = true + +-- dump.json -- +{ + "buckets": [ + { + "id": "afk_localhost", + "name": "afk-watcher", + "type": "afkstatus", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.424207553+10:30", + "events": [ + { + "bucket": "afk", + "id": 2, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": "2023-12-04T17:21:28.270750821+10:30", + "data": { + "afk": false, + "locked": false + } + } + ] + }, + { + "id": "window_localhost", + "name": "window-watcher", + "type": "currentwindow", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.428793055+10:30", + "events": [ + { + "bucket": "window", + "id": 1, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": 
"2023-12-04T17:21:28.270750821+10:30", + "data": { + "app": "Gnome-terminal", + "title": "Terminal" + } + } + ] + } + ] +} +-- want.json -- +{ + "buckets": [ + { + "id": "afk_localhost", + "name": "afk-watcher", + "type": "afkstatus", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.424207553+10:30", + "events": [ + { + "bucket": "afk", + "id": 1, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": "2023-12-04T17:21:28.270750821+10:30", + "data": { + "afk": false, + "locked": false + } + } + ] + }, + { + "id": "window_localhost", + "name": "window-watcher", + "type": "currentwindow", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.428793055+10:30", + "events": [ + { + "bucket": "window", + "id": 2, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": "2023-12-04T17:21:28.270750821+10:30", + "data": { + "app": "Gnome-terminal", + "title": "Terminal" + } + } + ] + } + ] +} +-- want_event_count.json -- +[{"event_count":2}] From 1ff7d37b684d4f86a572d94c2a1a870b43fc6b43 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Sun, 22 Sep 2024 09:18:48 +0930 Subject: [PATCH 3/6] cmd/worklog: factor storage into an interface --- cmd/worklog/dashboard.go | 11 ++++----- cmd/worklog/main.go | 51 ++++++++++++++++++++++++++++++++++------ cmd/worklog/main_test.go | 3 +-- 3 files changed, 50 insertions(+), 15 deletions(-) diff --git a/cmd/worklog/dashboard.go b/cmd/worklog/dashboard.go index 0cdacdc..2cd93d8 100644 --- a/cmd/worklog/dashboard.go +++ b/cmd/worklog/dashboard.go @@ -22,7 +22,6 @@ import ( "time" worklog "github.com/kortschak/dex/cmd/worklog/api" - "github.com/kortschak/dex/cmd/worklog/store" ) func (d *daemon) dashboardData(ctx context.Context) http.HandlerFunc { @@ -91,7 +90,7 @@ func dateQuery(u *url.URL, loc *time.Location) (time.Time, error) { return time.ParseInLocation(time.DateOnly, d, loc) } -func (d *daemon) eventData(ctx context.Context, db *store.DB, rules map[string]map[string]ruleDetail, date time.Time, raw bool) (map[string]any, error) { +func (d *daemon) eventData(ctx context.Context, db storage, rules map[string]map[string]ruleDetail, date time.Time, raw bool) (map[string]any, error) { if raw { return d.rawEventData(ctx, db, rules, date) } @@ -189,7 +188,7 @@ func (d *daemon) eventData(ctx context.Context, db *store.DB, rules map[string]m return events, nil } -func (d *daemon) rawEventData(ctx context.Context, db *store.DB, rules map[string]map[string]ruleDetail, date time.Time) (map[string]any, error) { +func (d *daemon) rawEventData(ctx context.Context, db storage, rules map[string]map[string]ruleDetail, date time.Time) (map[string]any, error) { start, end := day(date) events := map[string]any{ "date": zoneTranslatedTime(start, date.Location()), @@ -220,7 +219,7 @@ func (d *daemon) rawEventData(ctx context.Context, db *store.DB, rules map[strin return events, nil } -func (d *daemon) dayData(ctx context.Context, db *store.DB, rules map[string]map[string]ruleDetail, start, end time.Time) (atKeyboard []worklog.Event, dayEvents, windowEvents map[string][]worklog.Event, transitions graph, err error) { +func (d *daemon) dayData(ctx context.Context, db storage, rules map[string]map[string]ruleDetail, start, end time.Time) (atKeyboard []worklog.Event, dayEvents, windowEvents map[string][]worklog.Event, transitions graph, err error) { dayEvents = make(map[string][]worklog.Event) windowEvents = make(map[string][]worklog.Event) transitions = newGraph(rng{min: 5, max: 30}, rng{min: 1, max: 5}) @@ -411,7 
+410,7 @@ type summary struct { Warnings []string `json:"warn,omitempty"` } -func (d *daemon) rangeSummary(ctx context.Context, db *store.DB, rules map[string]map[string]ruleDetail, start, end time.Time, raw bool, req *url.URL) (summary, error) { +func (d *daemon) rangeSummary(ctx context.Context, db storage, rules map[string]map[string]ruleDetail, start, end time.Time, raw bool, req *url.URL) (summary, error) { events := summary{ Start: start, End: end, @@ -590,7 +589,7 @@ func mergeSummaries(summaries []summary, cooldown time.Duration) (summary, error return sum, nil } -func (d *daemon) atKeyboard(ctx context.Context, db *store.DB, rules map[string]map[string]ruleDetail, start, end time.Time) ([]worklog.Event, error) { +func (d *daemon) atKeyboard(ctx context.Context, db storage, rules map[string]map[string]ruleDetail, start, end time.Time) ([]worklog.Event, error) { var atKeyboard []worklog.Event for srcBucket, ruleSet := range rules { for dstBucket, rule := range ruleSet { diff --git a/cmd/worklog/main.go b/cmd/worklog/main.go index 743e5a3..01cc713 100644 --- a/cmd/worklog/main.go +++ b/cmd/worklog/main.go @@ -8,6 +8,7 @@ package main import ( "bytes" "context" + "database/sql" "embed" "encoding/json" "errors" @@ -185,7 +186,7 @@ type daemon struct { rMu sync.Mutex lastEvents map[string]*worklog.Event - db atomic.Pointer[store.DB] + db atomicIfaceValue[storage] lastReport map[rpc.UID]worklog.Report @@ -221,13 +222,49 @@ type atomicIfaceValue[T any] struct { } func (v *atomicIfaceValue[T]) Store(val T) { - v.val.Store(&val) + // We need to be able to store a nil T, so work + // around T any not being comparable to nil. + switch any(val).(type) { + case nil: + v.val.Store((*T)(nil)) + default: + v.val.Store(&val) + } } func (v *atomicIfaceValue[T]) Load() T { - return *v.val.Load() + p := v.val.Load() + if p == nil { + // Get our nil T. 
+ var zero T + return zero + } + return *p } +type storage interface { + AmendEvents(ts time.Time, note *worklog.Amendment) (sql.Result, error) + Backup(ctx context.Context, n int, sleep time.Duration) (string, error) + BucketID(uid string) string + BucketMetadata(bid string) (*worklog.BucketMetadata, error) + Buckets() ([]worklog.BucketMetadata, error) + Close() error + CreateBucket(uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) + Dump() ([]worklog.BucketMetadata, error) + DumpRange(start, end time.Time) ([]worklog.BucketMetadata, error) + Events(bid string) ([]worklog.Event, error) + EventsRange(bid string, start, end time.Time, limit int) ([]worklog.Event, error) + EventsRangeFunc(bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error + InsertEvent(e *worklog.Event) (sql.Result, error) + LastEvent(uid string) (*worklog.Event, error) + Load(buckets []worklog.BucketMetadata, replace bool) (err error) + Name() string + Select(query string) ([]map[string]any, error) + UpdateEvent(e *worklog.Event) (sql.Result, error) +} + +var _ storage = (*store.DB)(nil) + type current interface { Location() (*time.Location, error) } @@ -539,10 +576,10 @@ func (d *daemon) configureRules(ctx context.Context, rules map[string]worklog.Ru d.rules.Store(ruleDetails) } -func (d *daemon) openDB(ctx context.Context, db *store.DB, path, hostname string) error { +func (d *daemon) openDB(ctx context.Context, db storage, path, hostname string) error { if db != nil { d.log.LogAttrs(ctx, slog.LevelInfo, "close database", slog.String("path", db.Name())) - d.db.Store((*store.DB)(nil)) + d.db.Store((storage)(nil)) db.Close() } // store.Open may need to get the hostname, which may @@ -562,7 +599,7 @@ func (d *daemon) openDB(ctx context.Context, db *store.DB, path, hostname string return nil } -func (d *daemon) configureDB(ctx context.Context, db *store.DB) { +func (d *daemon) configureDB(ctx context.Context, db storage) { rules := d.rules.Load() for bucket, rule := range rules { d.log.LogAttrs(ctx, slog.LevelDebug, "create bucket", slog.Any("bucket", bucket)) @@ -1278,7 +1315,7 @@ func queryError(dec *json.Decoder, body []byte, err error) any { } type dbLib struct { - db *store.DB + db storage log *slog.Logger } diff --git a/cmd/worklog/main_test.go b/cmd/worklog/main_test.go index 951e84f..7051a81 100644 --- a/cmd/worklog/main_test.go +++ b/cmd/worklog/main_test.go @@ -29,7 +29,6 @@ import ( "golang.org/x/sys/execabs" worklog "github.com/kortschak/dex/cmd/worklog/api" - "github.com/kortschak/dex/cmd/worklog/store" "github.com/kortschak/dex/internal/slogext" "github.com/kortschak/dex/rpc" ) @@ -522,7 +521,7 @@ func mergeSummaryData() int { return 0 } -func newTestDaemon(ctx context.Context, cancel context.CancelFunc, verbose bool, dbName string, replace bool, data []byte, ruleBytes []byte) (*daemon, *store.DB, map[string]map[string]ruleDetail, int) { +func newTestDaemon(ctx context.Context, cancel context.CancelFunc, verbose bool, dbName string, replace bool, data []byte, ruleBytes []byte) (*daemon, storage, map[string]map[string]ruleDetail, int) { var ( level slog.LevelVar addSource = slogext.NewAtomicBool(*lines) From 3a2e929ebbd6d7a46c39e342578a3ad6c9a08b99 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Sun, 22 Sep 2024 09:52:59 +0930 Subject: [PATCH 4/6] cmd/worklog{,/store}: generalise storage interface --- cmd/worklog/README.md | 2 +- cmd/worklog/dashboard.go | 4 +- cmd/worklog/main.go | 116 ++++++++++++++++++++--------------- 
cmd/worklog/main_test.go | 10 +-- cmd/worklog/store/db.go | 30 ++++----- cmd/worklog/store/db_test.go | 73 +++++++++++----------- 6 files changed, 127 insertions(+), 108 deletions(-) diff --git a/cmd/worklog/README.md b/cmd/worklog/README.md index afe9046..3da6f25 100644 --- a/cmd/worklog/README.md +++ b/cmd/worklog/README.md @@ -160,7 +160,7 @@ In addition to the dashboard endpoint provided by the `worklog` server, there ar - `GET` `/data/`: accepts `date` and `tz` query parameters for the day of the data to collect, and a `raw` parameter to return un-summarised data. - `GET` `/summary/`: accepts `start`, `end` and `tz` query parameters for time ranges, a `cooldown` parameter to ignore brief AFK periods, an `other` parameter which is a list of other worklog instance URLs to collate into the result, and a `raw` parameter to return un-summarised data. - `GET` `/dump/`: accepts `start` and `end` query parameters. -- `GET` `/backup/`: accepts `pages_per_step` and `sleep` query parameters corresponding to the [SQLite backup API](https://www.sqlite.org/backup.html)'s [`sqlite3_backup_step` `N` parameter](https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep) and the time between successive `sqlite3_backup_step` calls. The backup endpoint is only available when the server address is a loop-back address. +- `GET` `/backup/`: when using SQLite for data storage, accepts `pages_per_step` and `sleep` query parameters corresponding to the [SQLite backup API](https://www.sqlite.org/backup.html)'s [`sqlite3_backup_step` `N` parameter](https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupstep) and the time between successive `sqlite3_backup_step` calls; when using PostgreSQL for storage, accepts `directory` indicating the destination directory to write the backup to. The backup endpoint is only available when the server address is a loop-back address. - `GET`/`POST` `/query`: takes an SQLite SELECT statement (content-type:application/sql or URL parameter, sql) or a CEL program (content-type:application/cel) that may use a built-in `query()` function. The query endpoint is only available when the server address is a loop-back address. A potentially useful configuration for debugging rules is diff --git a/cmd/worklog/dashboard.go b/cmd/worklog/dashboard.go index 2cd93d8..c38bbf8 100644 --- a/cmd/worklog/dashboard.go +++ b/cmd/worklog/dashboard.go @@ -226,7 +226,7 @@ func (d *daemon) dayData(ctx context.Context, db storage, rules map[string]map[s for srcBucket, ruleSet := range rules { for dstBucket, rule := range ruleSet { var nextApp worklog.Event // EventsRange is sorted descending. - err := db.EventsRangeFunc(db.BucketID(srcBucket), start, end, -1, func(m worklog.Event) error { + err := db.EventsRangeFunc(ctx, db.BucketID(srcBucket), start, end, -1, func(m worklog.Event) error { // canonicalise to the time zone that the event was // recorded in for the purposes of the dashboard. // See comment in atKeyboard. @@ -593,7 +593,7 @@ func (d *daemon) atKeyboard(ctx context.Context, db storage, rules map[string]ma var atKeyboard []worklog.Event for srcBucket, ruleSet := range rules { for dstBucket, rule := range ruleSet { - err := db.EventsRangeFunc(db.BucketID(srcBucket), start, end, -1, func(m worklog.Event) error { + err := db.EventsRangeFunc(ctx, db.BucketID(srcBucket), start, end, -1, func(m worklog.Event) error { // atKeyboard is used for week and year intervals which // may involve work spanning multiple time zones. 
We // canonicalise to the time zone that the event was diff --git a/cmd/worklog/main.go b/cmd/worklog/main.go index 01cc713..b53a647 100644 --- a/cmd/worklog/main.go +++ b/cmd/worklog/main.go @@ -44,6 +44,7 @@ import ( sqlite3 "modernc.org/sqlite/lib" worklog "github.com/kortschak/dex/cmd/worklog/api" + "github.com/kortschak/dex/cmd/worklog/pgstore" "github.com/kortschak/dex/cmd/worklog/store" "github.com/kortschak/dex/internal/celext" "github.com/kortschak/dex/internal/localtime" @@ -243,27 +244,29 @@ func (v *atomicIfaceValue[T]) Load() T { } type storage interface { - AmendEvents(ts time.Time, note *worklog.Amendment) (sql.Result, error) - Backup(ctx context.Context, n int, sleep time.Duration) (string, error) + AmendEvents(ctx context.Context, ts time.Time, note *worklog.Amendment) (sql.Result, error) BucketID(uid string) string - BucketMetadata(bid string) (*worklog.BucketMetadata, error) - Buckets() ([]worklog.BucketMetadata, error) - Close() error - CreateBucket(uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) - Dump() ([]worklog.BucketMetadata, error) - DumpRange(start, end time.Time) ([]worklog.BucketMetadata, error) - Events(bid string) ([]worklog.Event, error) - EventsRange(bid string, start, end time.Time, limit int) ([]worklog.Event, error) - EventsRangeFunc(bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error - InsertEvent(e *worklog.Event) (sql.Result, error) - LastEvent(uid string) (*worklog.Event, error) - Load(buckets []worklog.BucketMetadata, replace bool) (err error) + BucketMetadata(ctx context.Context, bid string) (*worklog.BucketMetadata, error) + Buckets(ctx context.Context) ([]worklog.BucketMetadata, error) + Close(ctx context.Context) error + CreateBucket(ctx context.Context, uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) + Dump(ctx context.Context) ([]worklog.BucketMetadata, error) + DumpRange(ctx context.Context, start, end time.Time) ([]worklog.BucketMetadata, error) + Events(ctx context.Context, bid string) ([]worklog.Event, error) + EventsRange(ctx context.Context, bid string, start, end time.Time, limit int) ([]worklog.Event, error) + EventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error + InsertEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) + LastEvent(ctx context.Context, uid string) (*worklog.Event, error) + Load(ctx context.Context, buckets []worklog.BucketMetadata, replace bool) (err error) Name() string - Select(query string) ([]map[string]any, error) - UpdateEvent(e *worklog.Event) (sql.Result, error) + Select(ctx context.Context, query string) ([]map[string]any, error) + UpdateEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) } -var _ storage = (*store.DB)(nil) +var ( + _ storage = (*store.DB)(nil) + _ storage = (*pgstore.DB)(nil) +) type current interface { Location() (*time.Location, error) @@ -580,7 +583,7 @@ func (d *daemon) openDB(ctx context.Context, db storage, path, hostname string) if db != nil { d.log.LogAttrs(ctx, slog.LevelInfo, "close database", slog.String("path", db.Name())) d.db.Store((storage)(nil)) - db.Close() + db.Close(ctx) } // store.Open may need to get the hostname, which may // wait indefinitely due to network unavailability. 
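The `atomicIfaceValue` wrapper introduced in patch 3 exists because `atomic.Pointer[T]` cannot hold a nil interface value directly, and `openDB` above relies on `d.db.Store((storage)(nil))` round-tripping back to a nil `storage`. A compilable distillation of just that behaviour, with a stub one-method interface standing in for the real `storage`:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// atomicIfaceValue mirrors the wrapper in cmd/worklog/main.go: it stores
// a *T so an explicit nil interface can be represented as a nil pointer
// and recovered as the zero value of T.
type atomicIfaceValue[T any] struct {
	val atomic.Pointer[T]
}

func (v *atomicIfaceValue[T]) Store(val T) {
	switch any(val).(type) {
	case nil:
		v.val.Store((*T)(nil))
	default:
		v.val.Store(&val)
	}
}

func (v *atomicIfaceValue[T]) Load() T {
	p := v.val.Load()
	if p == nil {
		var zero T
		return zero
	}
	return *p
}

// storage is a one-method stub standing in for the real interface.
type storage interface{ Name() string }

func main() {
	var db atomicIfaceValue[storage]
	fmt.Println(db.Load() == nil) // true: nothing stored yet
	db.Store(nil)
	fmt.Println(db.Load() == nil) // true: explicit nil survives the round trip
}
```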
@@ -603,7 +606,7 @@ func (d *daemon) configureDB(ctx context.Context, db storage) { rules := d.rules.Load() for bucket, rule := range rules { d.log.LogAttrs(ctx, slog.LevelDebug, "create bucket", slog.Any("bucket", bucket)) - m, err := db.CreateBucket(bucket, rule.name, rule.typ, d.uid, time.Now(), nil) + m, err := db.CreateBucket(ctx, bucket, rule.name, rule.typ, d.uid, time.Now(), nil) var sqlErr *sqlite.Error switch { case err == nil: @@ -650,7 +653,7 @@ func (d *daemon) record(ctx context.Context, src rpc.UID, curr, last worklog.Rep d.log.LogAttrs(ctx, slog.LevelDebug, "last event in cache", slog.String("bucket", bucket), slog.Any("last", lastEvent)) } else { var err error - lastEvent, err = db.LastEvent(bucket) + lastEvent, err = db.LastEvent(ctx, bucket) if err == nil { d.log.LogAttrs(ctx, slog.LevelDebug, "last event from store", slog.String("bucket", bucket), slog.Any("last", lastEvent)) } else { @@ -697,13 +700,13 @@ func (d *daemon) record(ctx context.Context, src rpc.UID, curr, last worklog.Rep var isNew bool if lastEvent.ID != 0 && note.Continue != nil && *note.Continue { note.ID = lastEvent.ID - _, err = db.UpdateEvent(note) + _, err = db.UpdateEvent(ctx, note) if err != nil { d.log.LogAttrs(ctx, slog.LevelError, "failed update event", slog.Any("error", err), slog.Any("note", note)) continue } } else { - res, err := db.InsertEvent(note) + res, err := db.InsertEvent(ctx, note) if err != nil { d.log.LogAttrs(ctx, slog.LevelError, "failed insert event", slog.Any("error", err), slog.Any("note", note)) continue @@ -971,7 +974,7 @@ func (d *daemon) amend(ctx context.Context) http.HandlerFunc { http.ServeContent(w, req, "amended.json", time.Now(), strings.NewReader("[]")) return } - _, err = db.AmendEvents(now, &note) + _, err = db.AmendEvents(ctx, now, &note) if err != nil { d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) w.WriteHeader(http.StatusBadRequest) @@ -979,7 +982,7 @@ } var amended []worklog.Event for _, r := range mergeReplacement(note.Replace) { - e, err := db.EventsRange(db.BucketID(note.Bucket), r.start, r.end, -1) + e, err := db.EventsRange(ctx, db.BucketID(note.Bucket), r.start, r.end, -1) + amended = append(amended, e...) 
if err != nil { d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) @@ -1073,9 +1076,9 @@ func (d *daemon) dump(ctx context.Context) http.HandlerFunc { return } } - dump, err = db.DumpRange(start, end) + dump, err = db.DumpRange(ctx, start, end) } else { - dump, err = db.Dump() + dump, err = db.Dump(ctx) } if err != nil { d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) @@ -1107,30 +1110,44 @@ func (d *daemon) backup(ctx context.Context) http.HandlerFunc { w.WriteHeader(http.StatusInternalServerError) return } - var n int - if req.URL.Query().Has("pages_per_step") { - var err error - n, err = strconv.Atoi(req.URL.Query().Get("pages_per_step")) - if err != nil { - d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, `{"error":%q}`, err) - return + + var ( + path string + err error + ) + switch db := db.(type) { + case *store.DB: + var n int + if req.URL.Query().Has("pages_per_step") { + n, err = strconv.Atoi(req.URL.Query().Get("pages_per_step")) + if err != nil { + d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `{"error":%q}`, err) + return + } } - } - var sleep time.Duration - if req.URL.Query().Has("sleep") { - var err error - n, err = strconv.Atoi(req.URL.Query().Get("sleep")) - if err != nil { - d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) + var sleep time.Duration + if req.URL.Query().Has("sleep") { + sleep, err = time.ParseDuration(req.URL.Query().Get("sleep")) + if err != nil { + d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, `{"error":%q}`, err) + return + } + } + path, err = db.Backup(ctx, n, sleep) + case *pgstore.DB: + dir := req.URL.Query().Get("directory") + if dir == "" { + d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.String("error", "missing directory parameter"), slog.String("url", req.RequestURI)) w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, `{"error":%q}`, err) + fmt.Fprintf(w, `{"error":"missing directory parameter"}`) return } + path, err = db.Backup(ctx, dir) } - - path, err := db.Backup(ctx, n, sleep) if err != nil { d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) w.WriteHeader(http.StatusInternalServerError) @@ -1197,7 +1214,7 @@ func (d *daemon) load(ctx context.Context) http.HandlerFunc { w.WriteHeader(http.StatusBadRequest) return } - err = db.Load(dump.Buckets, replace) + err = db.Load(ctx, dump.Buckets, replace) if err != nil { d.log.LogAttrs(ctx, slog.LevelWarn, "web server", slog.Any("error", err), slog.String("url", req.RequestURI)) w.WriteHeader(http.StatusBadRequest) @@ -1241,7 +1258,7 @@ func (d *daemon) query(ctx context.Context) http.HandlerFunc { body.WriteString(req.URL.Query().Get("sql")) fallthrough case "application/sql": - resp, err := db.Select(body.String()) + resp, err := db.Select(ctx, body.String()) if err != nil { w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(map[string]any{"err": err.Error()}) @@ -1265,7 +1282,7 @@ prg, err := compile(body.String(), 
[]cel.EnvOption{ cel.OptionalTypes(cel.OptionalTypesVersion(1)), celext.Lib(d.log), - cel.Lib(dbLib{db: db, log: d.log}), + cel.Lib(dbLib{ctx: ctx, db: db, log: d.log}), }) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -1315,6 +1332,7 @@ func queryError(dec *json.Decoder, body []byte, err error) any { } type dbLib struct { + ctx context.Context db storage log *slog.Logger } @@ -1341,7 +1359,7 @@ func (l dbLib) query(arg ref.Val) ref.Val { if !ok { return types.ValOrErr(sql, "no such overload") } - resp, err := l.db.Select(string(sql)) + resp, err := l.db.Select(l.ctx, string(sql)) if err != nil { return types.NewErr(err.Error()) } diff --git a/cmd/worklog/main_test.go b/cmd/worklog/main_test.go index 7051a81..b074f06 100644 --- a/cmd/worklog/main_test.go +++ b/cmd/worklog/main_test.go @@ -280,7 +280,7 @@ func mergeAfk() int { }, }) db := d.db.Load() - defer db.Close() + defer db.Close(ctx) d.configureDB(ctx, db) dec := json.NewDecoder(bytes.NewReader(data)) @@ -298,7 +298,7 @@ func mergeAfk() int { last = curr } - dump, err := db.Dump() + dump, err := db.Dump(ctx) if err != nil { fmt.Fprintf(os.Stderr, "failed to dump db: %v\n", err) return 1 @@ -390,7 +390,7 @@ func dashboardData() int { if status != 0 { return status } - defer db.Close() + defer db.Close(ctx) events, err := d.eventData(ctx, db, rules, date, *raw) if err != nil { @@ -460,7 +460,7 @@ func summaryData() int { if status != 0 { return status } - defer db.Close() + defer db.Close(ctx) events, err := d.rangeSummary(ctx, db, rules, start, end, *raw, nil) if err != nil { @@ -550,7 +550,7 @@ func newTestDaemon(ctx context.Context, cancel context.CancelFunc, verbose bool, fmt.Fprintf(os.Stderr, "failed to unmarshal db dump: %v\n", err) return nil, nil, nil, 1 } - err = db.Load(buckets, replace) + err = db.Load(ctx, buckets, replace) if err != nil { fmt.Fprintf(os.Stderr, "failed to load db dump: %v\n", err) return nil, nil, nil, 1 diff --git a/cmd/worklog/store/db.go b/cmd/worklog/store/db.go index b97412c..ce81bf3 100644 --- a/cmd/worklog/store/db.go +++ b/cmd/worklog/store/db.go @@ -192,7 +192,7 @@ func (db *DB) Backup(ctx context.Context, n int, sleep time.Duration) (string, e } // Close closes the database. -func (db *DB) Close() error { +func (db *DB) Close(ctx context.Context) error { db.mu.Lock() defer db.mu.Unlock() return errors.Join(db.store.Close(), db.roStore.Close()) @@ -234,7 +234,7 @@ const CreateBucket = `insert into buckets(id, name, type, client, hostname, crea // CreateBucket creates a new entry in the bucket table. If the entry already // exists it will return an sqlite.Error with the code sqlite3.SQLITE_CONSTRAINT_UNIQUE. // The SQL command run is [CreateBucket]. -func (db *DB) CreateBucket(uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) { +func (db *DB) CreateBucket(ctx context.Context, uid, name, typ, client string, created time.Time, data map[string]any) (m *worklog.BucketMetadata, err error) { bid := db.BucketID(uid) db.mu.Lock() defer db.mu.Unlock() @@ -277,7 +277,7 @@ const BucketMetadata = `select id, name, type, client, hostname, created, datast // BucketMetadata returns the metadata for the bucket with the provided internal // bucket ID. // The SQL command run is [BucketMetadata]. 
-func (db *DB) BucketMetadata(bid string) (*worklog.BucketMetadata, error) { +func (db *DB) BucketMetadata(ctx context.Context, bid string) (*worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() return bucketMetadata(db.store, bid) @@ -321,7 +321,7 @@ const InsertEvent = `insert into events(bucketrow, starttime, endtime, datastr) // InsertEvent inserts a new event into the events table. // The SQL command run is [InsertEvent]. -func (db *DB) InsertEvent(e *worklog.Event) (sql.Result, error) { +func (db *DB) InsertEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) db.mu.Lock() defer db.mu.Unlock() @@ -343,7 +343,7 @@ const UpdateEvent = `update events set starttime = ?, endtime = ?, datastr = ? w // UpdateEvent updates the event in the store corresponding to the provided // event. // The SQL command run is [UpdateEvent]. -func (db *DB) UpdateEvent(e *worklog.Event) (sql.Result, error) { +func (db *DB) UpdateEvent(ctx context.Context, e *worklog.Event) (sql.Result, error) { msg, err := json.Marshal(e.Data) if err != nil { return nil, err @@ -364,7 +364,7 @@ const LastEvent = `select id, starttime, endtime, datastr from events where buck // LastEvent returns the last event in the named bucket. // The SQL command run is [LastEvent]. -func (db *DB) LastEvent(uid string) (*worklog.Event, error) { +func (db *DB) LastEvent(ctx context.Context, uid string) (*worklog.Event, error) { bid := db.BucketID(uid) db.mu.Lock() rows, err := db.store.Query(LastEvent, bid) @@ -407,7 +407,7 @@ func (db *DB) LastEvent(uid string) (*worklog.Event, error) { } // Dump dumps the complete database into a slice of [worklog.BucketMetadata]. -func (db *DB) Dump() ([]worklog.BucketMetadata, error) { +func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() m, err := db.buckets() @@ -433,7 +433,7 @@ func (db *DB) Dump() ([]worklog.BucketMetadata, error) { // DumpRange dumps the database spanning the specified time range into a slice // of [worklog.BucketMetadata]. -func (db *DB) DumpRange(start, end time.Time) ([]worklog.BucketMetadata, error) { +func (db *DB) DumpRange(ctx context.Context, start, end time.Time) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() m, err := db.buckets() @@ -490,7 +490,7 @@ func (db *DB) dumpEventsRange(bid string, start, end time.Time, limit int) ([]wo // the bucket in the provided buckets slice, the existing events will be // deleted and replaced. If replace is false, the new events will be added to // the existing events in the store. -func (db *DB) Load(buckets []worklog.BucketMetadata, replace bool) (err error) { +func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replace bool) (err error) { db.mu.Lock() defer db.mu.Unlock() tx, err := db.store.Begin() @@ -539,7 +539,7 @@ const Buckets = `select id, name, type, client, hostname, created, datastr from // Buckets returns the full set of bucket metadata. // The SQL command run is [Buckets]. -func (db *DB) Buckets() ([]worklog.BucketMetadata, error) { +func (db *DB) Buckets(ctx context.Context) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() return db.buckets() @@ -588,7 +588,7 @@ const Events = `select id, starttime, endtime, datastr from events where bucketr // Buckets returns the full set of events in the bucket with the provided // internal bucket ID. // The SQL command run is [Events]. 
-func (db *DB) Events(bid string) ([]worklog.Event, error) { +func (db *DB) Events(ctx context.Context, bid string) ([]worklog.Event, error) { db.mu.Lock() defer db.mu.Unlock() return db.events(bid) @@ -653,7 +653,7 @@ const ( // within the specified time range, sorted descending by end time. // The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom] // or [EventsLimit] depending on whether start and end are zero. -func (db *DB) EventsRange(bid string, start, end time.Time, limit int) ([]worklog.Event, error) { +func (db *DB) EventsRange(ctx context.Context, bid string, start, end time.Time, limit int) ([]worklog.Event, error) { db.mu.Lock() defer db.mu.Unlock() var e []worklog.Event @@ -668,7 +668,7 @@ func (db *DB) EventsRange(bid string, start, end time.Time, limit int) ([]worklo // bucket ID within the specified time range, sorted descending by end time. // The SQL command run is [EventsRange], [EventsRangeUntil], [EventsRangeFrom] // or [EventsLimit] depending on whether start and end are zero. -func (db *DB) EventsRangeFunc(bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error { +func (db *DB) EventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error { db.mu.Lock() defer db.mu.Unlock() return db.eventsRangeFunc(bid, start, end, limit, fn, true) @@ -745,7 +745,7 @@ func (db *DB) eventsRangeFunc(bid string, start, end time.Time, limit int, fn fu // Select allows running an SQLite SELECT query. The query is run on a read-only // connection to the database. -func (db *DB) Select(query string) ([]map[string]any, error) { +func (db *DB) Select(ctx context.Context, query string) ([]map[string]any, error) { db.mu.Lock() defer db.mu.Unlock() rows, err := db.roStore.Query(query) @@ -831,7 +831,7 @@ commit;` // overlapping the note. On return the note.Replace slice will be sorted. // // The SQL command run is [AmendEvents]. 
-func (db *DB) AmendEvents(ts time.Time, note *worklog.Amendment) (sql.Result, error) { +func (db *DB) AmendEvents(ctx context.Context, ts time.Time, note *worklog.Amendment) (sql.Result, error) { if len(note.Replace) == 0 { return driver.RowsAffected(0), nil } diff --git a/cmd/worklog/store/db_test.go b/cmd/worklog/store/db_test.go index 736de40..f7a857d 100644 --- a/cmd/worklog/store/db_test.go +++ b/cmd/worklog/store/db_test.go @@ -36,6 +36,7 @@ func TestDB(t *testing.T) { }) } + ctx := context.Background() for _, interval := range []struct { name string duration time.Duration @@ -63,7 +64,7 @@ func TestDB(t *testing.T) { if err != nil { t.Fatalf("failed to create db: %v", err) } - err = db.Close() + err = db.Close(ctx) if err != nil { t.Fatalf("failed to close db: %v", err) } @@ -115,22 +116,22 @@ func TestDB(t *testing.T) { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } }() - err = db.Load(data, false) + err = db.Load(ctx, data, false) if err != nil { t.Errorf("failed to load data: %v", err) } - err = db.Load(data, true) + err = db.Load(ctx, data, true) if err != nil { t.Errorf("failed to load data: %v", err) } - got, err := db.Dump() + got, err := db.Dump(ctx) if err != nil { t.Errorf("failed to dump data: %v", err) } @@ -140,7 +141,7 @@ func TestDB(t *testing.T) { t.Errorf("unexpected dump result:\n--- want:\n+++ got:\n%s", cmp.Diff(want, got, ignoreID)) } - gotRange, err := db.DumpRange(now.Add(3*interval.duration), now.Add(4*interval.duration)) + gotRange, err := db.DumpRange(ctx, now.Add(3*interval.duration), now.Add(4*interval.duration)) if err != nil { t.Errorf("failed to dump data: %v", err) } @@ -151,7 +152,7 @@ func TestDB(t *testing.T) { t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRange, gotRange, ignoreID)) } - gotRangeFrom, err := db.DumpRange(now.Add(3*interval.duration), time.Time{}) + gotRangeFrom, err := db.DumpRange(ctx, now.Add(3*interval.duration), time.Time{}) if err != nil { t.Errorf("failed to dump data: %v", err) } @@ -162,7 +163,7 @@ func TestDB(t *testing.T) { t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeFrom, gotRangeFrom, ignoreID)) } - gotRangeUntil, err := db.DumpRange(time.Time{}, now.Add(4*interval.duration)) + gotRangeUntil, err := db.DumpRange(ctx, time.Time{}, now.Add(4*interval.duration)) if err != nil { t.Errorf("failed to dump data: %v", err) } @@ -173,7 +174,7 @@ func TestDB(t *testing.T) { t.Errorf("unexpected dump range result:\n--- want:\n+++ got:\n%s", cmp.Diff(wantRangeUntil, gotRangeUntil, ignoreID)) } - gotRangeAll, err := db.DumpRange(time.Time{}, time.Time{}) + gotRangeAll, err := db.DumpRange(ctx, time.Time{}, time.Time{}) if err != nil { t.Errorf("failed to dump data: %v", err) } @@ -185,18 +186,18 @@ func TestDB(t *testing.T) { }) t.Run("last_event", func(t *testing.T) { - db, err := Open(context.Background(), path, "test_host") + db, err := Open(ctx, path, "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } }() - got, err := db.LastEvent(bucket) + got, err := db.LastEvent(ctx, bucket) if err != nil { t.Errorf("failed to get last event: %v", err) } @@ -209,12 +210,12 @@ func TestDB(t *testing.T) { }) t.Run("update_last_event", func(t *testing.T) { - db, err := Open(context.Background(), path, "test_host") + db, err := 
Open(ctx, path, "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } @@ -222,19 +223,19 @@ func TestDB(t *testing.T) { e := data[0].Events[len(data[0].Events)-1] e.End = e.End.Add(interval.duration) - last, err := db.LastEvent(e.Bucket) + last, err := db.LastEvent(ctx, e.Bucket) if err != nil { t.Fatalf("failed to get last event: %v", err) } e.ID = last.ID - _, err = db.UpdateEvent(&e) + _, err = db.UpdateEvent(ctx, &e) if err != nil { t.Fatalf("failed to update event: %v", err) } if err != nil { t.Errorf("failed to update event: %v", err) } - got, err := db.LastEvent(bucket) + got, err := db.LastEvent(ctx, bucket) if err != nil { t.Errorf("failed to get last event: %v", err) } @@ -247,12 +248,12 @@ func TestDB(t *testing.T) { }) t.Run("events_range", func(t *testing.T) { - db, err := Open(context.Background(), path, "test_host") + db, err := Open(ctx, path, "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } @@ -261,7 +262,7 @@ func TestDB(t *testing.T) { bid := db.BucketID(bucket) for _, loc := range []*time.Location{time.Local, time.UTC} { t.Run(loc.String(), func(t *testing.T) { - got, err := db.EventsRange(bid, now.Add(3*interval.duration).In(loc), now.Add(4*interval.duration).In(loc), -1) + got, err := db.EventsRange(ctx, bid, now.Add(3*interval.duration).In(loc), now.Add(4*interval.duration).In(loc), -1) if err != nil { t.Errorf("failed to load data: %v", err) } @@ -278,12 +279,12 @@ func TestDB(t *testing.T) { }) t.Run("update_last_event_coequal", func(t *testing.T) { - db, err := Open(context.Background(), filepath.Join(workDir, "coequal.db"), "test_host") + db, err := Open(ctx, filepath.Join(workDir, "coequal.db"), "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } @@ -299,7 +300,7 @@ func TestDB(t *testing.T) { if err != nil { t.Fatalf("failed to unmarshal bucket message: %v", err) } - _, err = db.CreateBucket(b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) + _, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) if err != nil { t.Fatalf("failed to create bucket: %v", err) } @@ -318,23 +319,23 @@ func TestDB(t *testing.T) { t.Fatalf("failed to unmarshal event message: %v", err) } if note.Continue != nil && *note.Continue { - last, err := db.LastEvent(note.Bucket) + last, err := db.LastEvent(ctx, note.Bucket) if err != nil { t.Fatalf("failed to get last event: %v", err) } note.ID = last.ID - _, err = db.UpdateEvent(note) + _, err = db.UpdateEvent(ctx, note) if err != nil { t.Fatalf("failed to update event: %v", err) } } else { - _, err = db.InsertEvent(note) + _, err = db.InsertEvent(ctx, note) if err != nil { t.Fatalf("failed to insert event: %v", err) } } - dump, err := db.Dump() + dump, err := db.Dump(ctx) if err != nil { t.Fatalf("failed to dump db after step %d: %v", i, err) } @@ -353,12 +354,12 @@ func TestDB(t *testing.T) { }) t.Run("amend", func(t *testing.T) { - db, err := Open(context.Background(), filepath.Join(workDir, "amend.db"), "test_host") + db, err := Open(ctx, filepath.Join(workDir, "amend.db"), "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { 
t.Errorf("failed to close db: %v", err) } @@ -374,7 +375,7 @@ func TestDB(t *testing.T) { if err != nil { t.Fatalf("failed to unmarshal bucket message: %v", err) } - _, err = db.CreateBucket(b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) + _, err = db.CreateBucket(ctx, b.ID, b.Name, b.Type, b.Client, b.Created, b.Data) if err != nil { t.Fatalf("failed to create bucket: %v", err) } @@ -396,7 +397,7 @@ func TestDB(t *testing.T) { if err != nil { t.Fatalf("failed to unmarshal event message: %v", err) } - _, err = db.InsertEvent(note) + _, err = db.InsertEvent(ctx, note) if err != nil { t.Fatalf("failed to insert event: %v", err) } @@ -407,11 +408,11 @@ func TestDB(t *testing.T) { if err != nil { t.Fatalf("failed to unmarshal event message: %v", err) } - _, err = db.AmendEvents(time.Time{}, amendment) + _, err = db.AmendEvents(ctx, time.Time{}, amendment) if err != nil { t.Errorf("unexpected error amending events: %v", err) } - dump, err := db.Dump() + dump, err := db.Dump(ctx) if err != nil { t.Fatalf("failed to dump db: %v", err) } @@ -468,12 +469,12 @@ func TestDB(t *testing.T) { }) t.Run("dynamic_query", func(t *testing.T) { - db, err := Open(context.Background(), path, "test_host") + db, err := Open(ctx, path, "test_host") if err != nil { t.Fatalf("failed to create db: %v", err) } defer func() { - err = db.Close() + err = db.Close(ctx) if err != nil { t.Errorf("failed to close db: %v", err) } @@ -529,7 +530,7 @@ func TestDB(t *testing.T) { for _, test := range dynamicTests { t.Run(test.name, func(t *testing.T) { - got, err := db.Select(test.sql) + got, err := db.Select(ctx, test.sql) if !sameError(err, test.wantErr) { t.Errorf("unexpected error: got:%v want:%v", err, test.wantErr) return From 2c7c3eac456e57606adf14e863417a162aca14e9 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Sun, 22 Sep 2024 10:59:37 +0930 Subject: [PATCH 5/6] cmd/worklog/store: use contexts --- cmd/worklog/store/db.go | 84 ++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/cmd/worklog/store/db.go b/cmd/worklog/store/db.go index ce81bf3..8690f31 100644 --- a/cmd/worklog/store/db.go +++ b/cmd/worklog/store/db.go @@ -41,11 +41,11 @@ type DB struct { } type execer interface { - Exec(query string, args ...any) (sql.Result, error) + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } type querier interface { - Query(query string, args ...any) (*sql.Rows, error) + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) } func txDone(tx *sql.Tx, err *error) { @@ -88,7 +88,7 @@ func Open(ctx context.Context, name, host string) (*DB, error) { if err != nil { return nil, err } - _, err = db.Exec(Schema) + _, err = db.ExecContext(ctx, Schema) if err != nil { return nil, err } @@ -192,7 +192,7 @@ func (db *DB) Backup(ctx context.Context, n int, sleep time.Duration) (string, e } // Close closes the database. 
-func (db *DB) Close(ctx context.Context) error { +func (db *DB) Close(_ context.Context) error { db.mu.Lock() defer db.mu.Unlock() return errors.Join(db.store.Close(), db.roStore.Close()) @@ -238,15 +238,15 @@ func (db *DB) CreateBucket(ctx context.Context, uid, name, typ, client string, c bid := db.BucketID(uid) db.mu.Lock() defer db.mu.Unlock() - tx, err := db.store.Begin() + tx, err := db.store.BeginTx(ctx, nil) if err != nil { return nil, err } defer txDone(tx, &err) - return createBucket(tx, bid, name, typ, client, db.host, created, data) + return createBucket(ctx, tx, bid, name, typ, client, db.host, created, data) } -func createBucket(tx *sql.Tx, bid, name, typ, client, host string, created time.Time, data map[string]any) (*worklog.BucketMetadata, error) { +func createBucket(ctx context.Context, tx *sql.Tx, bid, name, typ, client, host string, created time.Time, data map[string]any) (*worklog.BucketMetadata, error) { var ( msg = []byte{} // datastr has a NOT NULL constraint. err error @@ -257,12 +257,12 @@ func createBucket(tx *sql.Tx, bid, name, typ, client, host string, created time. return nil, err } } - _, err = tx.Exec(CreateBucket, bid, name, typ, client, host, created.Format(time.RFC3339Nano), msg) + _, err = tx.ExecContext(ctx, CreateBucket, bid, name, typ, client, host, created.Format(time.RFC3339Nano), msg) var sqlErr *sqlite.Error if errors.As(err, &sqlErr) && sqlErr.Code() != sqlite3.SQLITE_CONSTRAINT_UNIQUE { return nil, err } - m, err := bucketMetadata(tx, bid) + m, err := bucketMetadata(ctx, tx, bid) if err != nil { return nil, err } @@ -280,11 +280,11 @@ const BucketMetadata = `select id, name, type, client, hostname, created, datast func (db *DB) BucketMetadata(ctx context.Context, bid string) (*worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() - return bucketMetadata(db.store, bid) + return bucketMetadata(ctx, db.store, bid) } -func bucketMetadata(db querier, bid string) (*worklog.BucketMetadata, error) { - rows, err := db.Query(BucketMetadata, bid) +func bucketMetadata(ctx context.Context, db querier, bid string) (*worklog.BucketMetadata, error) { + rows, err := db.QueryContext(ctx, BucketMetadata, bid) if err != nil { return nil, err } @@ -325,15 +325,15 @@ func (db *DB) InsertEvent(ctx context.Context, e *worklog.Event) (sql.Result, er bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) db.mu.Lock() defer db.mu.Unlock() - return insertEvent(db.store, bid, e) + return insertEvent(ctx, db.store, bid, e) } -func insertEvent(db execer, bid string, e *worklog.Event) (sql.Result, error) { +func insertEvent(ctx context.Context, db execer, bid string, e *worklog.Event) (sql.Result, error) { msg, err := json.Marshal(e.Data) if err != nil { return nil, err } - return db.Exec(InsertEvent, bid, e.Start.Format(time.RFC3339Nano), e.End.Format(time.RFC3339Nano), msg) + return db.ExecContext(ctx, InsertEvent, bid, e.Start.Format(time.RFC3339Nano), e.End.Format(time.RFC3339Nano), msg) } const UpdateEvent = `update events set starttime = ?, endtime = ?, datastr = ? where id = ? 
and bucketrow = ( @@ -351,7 +351,7 @@ func (db *DB) UpdateEvent(ctx context.Context, e *worklog.Event) (sql.Result, er bid := fmt.Sprintf("%s_%s", e.Bucket, db.host) db.mu.Lock() defer db.mu.Unlock() - return db.store.Exec(UpdateEvent, e.Start.Format(time.RFC3339Nano), e.End.Format(time.RFC3339Nano), msg, e.ID, bid) + return db.store.ExecContext(ctx, UpdateEvent, e.Start.Format(time.RFC3339Nano), e.End.Format(time.RFC3339Nano), msg, e.ID, bid) } const LastEvent = `select id, starttime, endtime, datastr from events where bucketrow = ( @@ -367,7 +367,7 @@ const LastEvent = `select id, starttime, endtime, datastr from events where buck func (db *DB) LastEvent(ctx context.Context, uid string) (*worklog.Event, error) { bid := db.BucketID(uid) db.mu.Lock() - rows, err := db.store.Query(LastEvent, bid) + rows, err := db.store.QueryContext(ctx, LastEvent, bid) db.mu.Unlock() if err != nil { return nil, err @@ -410,7 +410,7 @@ func (db *DB) LastEvent(ctx context.Context, uid string) (*worklog.Event, error) func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() - m, err := db.buckets() + m, err := db.buckets(ctx) if err != nil { return nil, err } @@ -419,7 +419,7 @@ func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { if !ok { return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID) } - e, err := db.events(b.ID) + e, err := db.events(ctx, b.ID) if err != nil { return m, err } @@ -436,7 +436,7 @@ func (db *DB) Dump(ctx context.Context) ([]worklog.BucketMetadata, error) { func (db *DB) DumpRange(ctx context.Context, start, end time.Time) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() - m, err := db.buckets() + m, err := db.buckets(ctx) if err != nil { return nil, err } @@ -445,7 +445,7 @@ func (db *DB) DumpRange(ctx context.Context, start, end time.Time) ([]worklog.Bu if !ok { return m, fmt.Errorf("invalid bucket ID at %d: %s", i, b.ID) } - e, err := db.dumpEventsRange(b.ID, start, end, -1) + e, err := db.dumpEventsRange(ctx, b.ID, start, end, -1) if err != nil { return m, err } @@ -475,9 +475,9 @@ const ( ) limit ?` ) -func (db *DB) dumpEventsRange(bid string, start, end time.Time, limit int) ([]worklog.Event, error) { +func (db *DB) dumpEventsRange(ctx context.Context, bid string, start, end time.Time, limit int) ([]worklog.Event, error) { var e []worklog.Event - err := db.eventsRangeFunc(bid, start, end, limit, func(m worklog.Event) error { + err := db.eventsRangeFunc(ctx, bid, start, end, limit, func(m worklog.Event) error { e = append(e, m) return nil }, false) @@ -493,14 +493,14 @@ func (db *DB) dumpEventsRange(bid string, start, end time.Time, limit int) ([]wo func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replace bool) (err error) { db.mu.Lock() defer db.mu.Unlock() - tx, err := db.store.Begin() + tx, err := db.store.BeginTx(ctx, nil) if err != nil { return err } defer txDone(tx, &err) for _, m := range buckets { var b *worklog.BucketMetadata - b, err = createBucket(tx, m.ID, m.Name, m.Type, m.Client, m.Hostname, m.Created, m.Data) + b, err = createBucket(ctx, tx, m.ID, m.Name, m.Type, m.Client, m.Hostname, m.Created, m.Data) if err != nil { var sqlErr *sqlite.Error if errors.As(err, &sqlErr) && sqlErr.Code() != sqlite3.SQLITE_CONSTRAINT_UNIQUE { @@ -510,7 +510,7 @@ func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replac return err } if replace { - _, err = tx.Exec(DeleteBucketEvents, m.ID) + _, err = tx.ExecContext(ctx, 
DeleteBucketEvents, m.ID) if err != nil { return err } @@ -518,7 +518,7 @@ func (db *DB) Load(ctx context.Context, buckets []worklog.BucketMetadata, replac } for i, e := range m.Events { bid := fmt.Sprintf("%s_%s", e.Bucket, m.Hostname) - _, err = insertEvent(tx, bid, &m.Events[i]) + _, err = insertEvent(ctx, tx, bid, &m.Events[i]) if err != nil { return err } @@ -542,11 +542,11 @@ const Buckets = `select id, name, type, client, hostname, created, datastr from func (db *DB) Buckets(ctx context.Context) ([]worklog.BucketMetadata, error) { db.mu.Lock() defer db.mu.Unlock() - return db.buckets() + return db.buckets(ctx) } -func (db *DB) buckets() ([]worklog.BucketMetadata, error) { - rows, err := db.store.Query(Buckets) +func (db *DB) buckets(ctx context.Context) ([]worklog.BucketMetadata, error) { + rows, err := db.store.QueryContext(ctx, Buckets) if err != nil { return nil, err } @@ -591,11 +591,11 @@ const Events = `select id, starttime, endtime, datastr from events where bucketr func (db *DB) Events(ctx context.Context, bid string) ([]worklog.Event, error) { db.mu.Lock() defer db.mu.Unlock() - return db.events(bid) + return db.events(ctx, bid) } -func (db *DB) events(bid string) ([]worklog.Event, error) { - rows, err := db.store.Query(Events, bid) +func (db *DB) events(ctx context.Context, bid string) ([]worklog.Event, error) { + rows, err := db.store.QueryContext(ctx, Events, bid) if err != nil { return nil, err } @@ -657,7 +657,7 @@ func (db *DB) EventsRange(ctx context.Context, bid string, start, end time.Time, db.mu.Lock() defer db.mu.Unlock() var e []worklog.Event - err := db.eventsRangeFunc(bid, start, end, limit, func(m worklog.Event) error { + err := db.eventsRangeFunc(ctx, bid, start, end, limit, func(m worklog.Event) error { e = append(e, m) return nil }, true) @@ -671,10 +671,10 @@ func (db *DB) EventsRange(ctx context.Context, bid string, start, end time.Time, func (db *DB) EventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error) error { db.mu.Lock() defer db.mu.Unlock() - return db.eventsRangeFunc(bid, start, end, limit, fn, true) + return db.eventsRangeFunc(ctx, bid, start, end, limit, fn, true) } -func (db *DB) eventsRangeFunc(bid string, start, end time.Time, limit int, fn func(worklog.Event) error, order bool) error { +func (db *DB) eventsRangeFunc(ctx context.Context, bid string, start, end time.Time, limit int, fn func(worklog.Event) error, order bool) error { var ( query string rows *sql.Rows @@ -686,25 +686,25 @@ func (db *DB) eventsRangeFunc(bid string, start, end time.Time, limit int, fn fu if !order { query = dumpEventsRange } - rows, err = db.store.Query(query, bid, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), limit) + rows, err = db.store.QueryContext(ctx, query, bid, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), limit) case !start.IsZero(): query = EventsRangeFrom if !order { query = dumpEventsRangeFrom } - rows, err = db.store.Query(query, bid, start.Format(time.RFC3339Nano), limit) + rows, err = db.store.QueryContext(ctx, query, bid, start.Format(time.RFC3339Nano), limit) case !end.IsZero(): query = EventsRangeUntil if !order { query = dumpEventsRangeUntil } - rows, err = db.store.Query(query, bid, end.Format(time.RFC3339Nano), limit) + rows, err = db.store.QueryContext(ctx, query, bid, end.Format(time.RFC3339Nano), limit) default: query = EventsLimit if !order { query = dumpEventsLimit } - rows, err = db.store.Query(query, bid, limit) + rows, err = 
db.store.QueryContext(ctx, query, bid, limit) } if err != nil { return err @@ -748,7 +748,7 @@ func (db *DB) eventsRangeFunc(bid string, start, end time.Time, limit int, fn fu func (db *DB) Select(ctx context.Context, query string) ([]map[string]any, error) { db.mu.Lock() defer db.mu.Unlock() - rows, err := db.roStore.Query(query) + rows, err := db.roStore.QueryContext(ctx, query) if err != nil { return nil, err } @@ -855,7 +855,7 @@ func (db *DB) AmendEvents(ctx context.Context, ts time.Time, note *worklog.Amend } db.mu.Lock() defer db.mu.Unlock() - return db.store.Exec(AmendEvents, db.BucketID(note.Bucket), ts.Format(time.RFC3339Nano), note.Message, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), replace) + return db.store.ExecContext(ctx, AmendEvents, db.BucketID(note.Bucket), ts.Format(time.RFC3339Nano), note.Message, start.Format(time.RFC3339Nano), end.Format(time.RFC3339Nano), replace) } const DeleteEvent = `delete from events where bucketrow = ( From bf380e9038dcf788099c9f2e69c221b99deac246 Mon Sep 17 00:00:00 2001 From: Dan Kortschak Date: Thu, 26 Sep 2024 20:16:01 +0930 Subject: [PATCH 6/6] cmd/worklog: wire pgstore in to program --- .github/workflows/ci.yml | 23 +++- cmd/worklog/README.md | 8 +- cmd/worklog/main.go | 140 +++++++++++++++-------- cmd/worklog/main_test.go | 91 ++++++++------- main_postgres_test.go | 9 ++ main_test.go | 128 +++++++++++++++++++++ testdata/worklog_load_postgres.txt | 172 +++++++++++++++++++++++++++++ 7 files changed, 484 insertions(+), 87 deletions(-) create mode 100644 main_postgres_test.go create mode 100644 testdata/worklog_load_postgres.txt diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62df736..61f02c7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -99,6 +99,7 @@ jobs: PGPORT: 5432 PGUSER: test_user PGPASSWORD: password + POSTGRES_DB: postgres steps: - name: install Go @@ -111,14 +112,32 @@ jobs: with: fetch-depth: 1 - - name: unit tests postgres + - name: non-Go linux dependencies + run: | + sudo apt-get update + sudo apt-get install -qq libudev-dev + + - name: set up postgres users run: | psql --host $PGHOST \ --username="postgres" \ --dbname="postgres" \ --command="CREATE USER $PGUSER PASSWORD '$PGPASSWORD'" \ --command="ALTER USER $PGUSER CREATEDB" \ - --command="CREATE USER ${PGUSER}_ro PASSWORD '$PGPASSWORD'" \ + --command="CREATE USER ${PGUSER}_ro PASSWORD '${PGPASSWORD}_ro'" \ --command="\du" + echo ${PGHOST}:${PGPORT}:*:${PGUSER}:${PGPASSWORD} >> ~/.pgpass + echo ${PGHOST}:${PGPORT}:*:${PGUSER}_ro:${PGPASSWORD}_ro >> ~/.pgpass + chmod 600 ~/.pgpass + - name: unit tests postgres + run: | go test ./cmd/worklog/pgstore + + - name: integration tests postgres + uses: nick-fields/retry@v3 + with: + timeout_minutes: 10 + max_attempts: 3 + command: | + go test -tags postgres -run TestScripts/worklog_load_postgres diff --git a/cmd/worklog/README.md b/cmd/worklog/README.md index 3da6f25..052c4fb 100644 --- a/cmd/worklog/README.md +++ b/cmd/worklog/README.md @@ -1,6 +1,6 @@ # `worklog` -`worklog` is a module that records screen activity, screen-saver lock state and AFK status. It takes messages from the `watcher` module and records them in an SQLite database and serves a small dashboard page that shows work activity. +`worklog` is a module that records screen activity, screen-saver lock state and AFK status. It takes messages from the `watcher` module and records them in an SQLite or PostgreSQL database and serves a small dashboard page that shows work activity. 
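The store backend is selected by the scheme of the `database` option. A minimal sketch of the two forms follows; the user, host, and database names here are placeholders, and the option name matches the configuration used later in this patch:

```
[module.worklog.options]
# SQLite store kept in an XDG state directory named "worklog".
database = "sqlite:worklog"
# Or a PostgreSQL store; the password is looked up via ~/.pgpass.
#database = "postgres://worklog_user@localhost:5432/worklog"
```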
Example configuration fragment (requires a kernel configuration fragment): ``` @@ -187,3 +187,9 @@ The CEL environment enables the CEL [optional types library](https://pkg.go.dev/ ## CEL extensions The CEL environment provides the [`Lib`](https://pkg.go.dev/github.com/kortschak/dex/internal/celext#Lib) and [`StateLib`](https://pkg.go.dev/github.com/kortschak/dex/internal/celext#StateLib) extensions from the celext package. `StateLib` is only available in `module.*.options.rules.*.src`. + +## PostgreSQL store + +When using PostgreSQL as a store, the `~/.pgpass` file MAY be used for password look-up for the primary connection to the database and MUST be used for the read-only connection. + +The read-only connection is made on start-up. Before the connection is made, the read-only user, which is `${PGUSER}_ro` where `${PGUSER}` is the user for the primary connection, is checked for its ability to read the tables used by the store, and for whether it can perform any non-SELECT operations. If the user cannot read the tables, a warning is emitted, but the connection is made. If non-SELECT operations are allowed for the user, or the user can read other tables, no connection is made. Since this check is only made at start-up, there is a TOCTOU concern here, but exploiting it would require the attacker to already hold ALTER and GRANT privileges, at which point the game is already lost. diff --git a/cmd/worklog/main.go b/cmd/worklog/main.go index b53a647..3690c54 100644 --- a/cmd/worklog/main.go +++ b/cmd/worklog/main.go @@ -187,6 +187,7 @@ type daemon struct { rMu sync.Mutex lastEvents map[string]*worklog.Event + dMu sync.Mutex // dMu is only used for protecting configuration of db. db atomicIfaceValue[storage] lastReport map[rpc.UID]worklog.Report @@ -359,7 +360,7 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) } } - databaseDir, err := dbDir(m.Body) + scheme, databaseDir, err := dbDir(m.Body) if err != nil { d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.Any("error", err)) return nil, rpc.NewError(rpc.ErrCodeInvalidMessage, @@ -371,50 +372,86 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) }, ) } - if databaseDir != "" { - dir, err := xdg.State(databaseDir) - switch err { - case nil: - case syscall.ENOENT: - var ok bool - dir, ok = xdg.StateHome() - if !ok { - d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.String("error", "no XDG_STATE_HOME")) - return nil, err + d.dMu.Lock() + defer d.dMu.Unlock() + switch { + default: + d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.String("error", "unknown scheme"), slog.String("url", m.Body.Options.Database)) + return nil, rpc.NewError(rpc.ErrCodeInvalidMessage, + fmt.Sprintf("unknown scheme in database configuration: %s", m.Body.Options.Database), // err is nil on this path, so it must not be dereferenced here. + map[string]any{ + "type": rpc.ErrCodeParameters, + "database": m.Body.Options.Database, + }, + ) + case scheme == "": + // Do nothing. + case scheme == "postgres", scheme == "sqlite": + var ( + open opener + addr string + ) + switch { + case scheme == "postgres": + addr = m.Body.Options.Database + open = func(ctx context.Context, addr, host string) (storage, error) { + db, err := pgstore.Open(ctx, addr, host) + if _, ok := err.(pgstore.Warning); ok { + d.log.LogAttrs(ctx, slog.LevelWarn, "configure database", slog.Any("error", err)) + err = nil + } + return db, err } - dir = filepath.Join(dir, databaseDir) - err = os.Mkdir(dir, 0o750) - if err != nil { - err := err.(*os.PathError) // See godoc for os.Mkdir for why this is safe.
- d.log.LogAttrs(ctx, slog.LevelError, "create database dir", slog.Any("error", err)) - return nil, rpc.NewError(rpc.ErrCodeInternal, + + case scheme == "sqlite": + dir, err := xdg.State(databaseDir) + switch err { + case nil: + case syscall.ENOENT: + var ok bool + dir, ok = xdg.StateHome() + if !ok { + d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.String("error", "no XDG_STATE_HOME")) + return nil, err + } + dir = filepath.Join(dir, databaseDir) + err = os.Mkdir(dir, 0o750) + if err != nil { + err := err.(*os.PathError) // See godoc for os.Mkdir for why this is safe. + d.log.LogAttrs(ctx, slog.LevelError, "create database dir", slog.Any("error", err)) + return nil, rpc.NewError(rpc.ErrCodeInternal, + err.Error(), + map[string]any{ + "type": rpc.ErrCodePath, + "op": err.Op, + "path": err.Path, + "err": fmt.Sprint(err.Err), + }, + ) + } + default: + d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.Any("error", err)) + return nil, jsonrpc2.NewError( + rpc.ErrCodeInternal, err.Error(), - map[string]any{ - "type": rpc.ErrCodePath, - "op": err.Op, - "path": err.Path, - "err": fmt.Sprint(err.Err), - }, ) } - default: - d.log.LogAttrs(ctx, slog.LevelError, "configure database", slog.Any("error", err)) - return nil, jsonrpc2.NewError( - rpc.ErrCodeInternal, - err.Error(), - ) + + addr = filepath.Join(dir, "db.sqlite") + open = func(ctx context.Context, addr, host string) (storage, error) { + return store.Open(ctx, addr, host) + } } - path := filepath.Join(dir, "db.sqlite") - if db := d.db.Load(); db == nil || path != db.Name() { - err = d.openDB(ctx, db, path, m.Body.Options.Hostname) + if db := d.db.Load(); !sameDB(db, addr) { + err = d.openDB(ctx, db, open, addr, m.Body.Options.Hostname) if err != nil { return nil, rpc.NewError(rpc.ErrCodeInternal, err.Error(), map[string]any{ "type": rpc.ErrCodeStoreErr, "op": "open", - "path": path, + "name": addr, }, ) } @@ -444,34 +481,41 @@ func (d *daemon) Handle(ctx context.Context, req *jsonrpc2.Request) (any, error) } } -func dbDir(cfg worklog.Config) (string, error) { +func dbDir(cfg worklog.Config) (scheme, dir string, err error) { opt := cfg.Options if opt.Database == "" { - return opt.DatabaseDir, nil + if opt.DatabaseDir != "" { + scheme = "sqlite" + } + return scheme, opt.DatabaseDir, nil } u, err := url.Parse(opt.Database) if err != nil { - return "", err + return "", "", err } switch u.Scheme { case "": - return "", errors.New("missing scheme in database configuration") + return "", "", errors.New("missing scheme in database configuration") case "sqlite": if opt.DatabaseDir != "" && u.Opaque != opt.DatabaseDir { - return "", fmt.Errorf("inconsistent database directory configuration: (%s:)%s != %s", u.Scheme, u.Opaque, opt.DatabaseDir) + return "", "", fmt.Errorf("inconsistent database directory configuration: (%s:)%s != %s", u.Scheme, u.Opaque, opt.DatabaseDir) } if u.Opaque == "" { - return "", fmt.Errorf("sqlite configuration missing opaque data: %s", opt.Database) + return "", "", fmt.Errorf("sqlite configuration missing opaque data: %s", opt.Database) } - return u.Opaque, nil + return u.Scheme, u.Opaque, nil default: if opt.DatabaseDir != "" { - return "", fmt.Errorf("inconsistent database configuration: both %s database and sqlite directory configured", u.Scheme) + return "", "", fmt.Errorf("inconsistent database configuration: both %s database and sqlite directory configured", u.Scheme) } - return "", nil + return u.Scheme, "", nil } } +func sameDB(db storage, name string) bool { + return db != nil && name 
== db.Name() +} + func (d *daemon) replaceTimezone(ctx context.Context, dynamic *bool) { if dynamic == nil { return @@ -579,26 +623,28 @@ func (d *daemon) configureRules(ctx context.Context, rules map[string]worklog.Ru d.rules.Store(ruleDetails) } -func (d *daemon) openDB(ctx context.Context, db storage, path, hostname string) error { +type opener = func(ctx context.Context, addr, hostname string) (storage, error) + +func (d *daemon) openDB(ctx context.Context, db storage, open opener, addr, hostname string) error { if db != nil { - d.log.LogAttrs(ctx, slog.LevelInfo, "close database", slog.String("path", db.Name())) + d.log.LogAttrs(ctx, slog.LevelInfo, "close database", slog.String("name", db.Name())) d.db.Store((storage)(nil)) db.Close(ctx) } - // store.Open may need to get the hostname, which may + // An opener may need to get the hostname, which may // wait indefinitely due to network unavailability. // So make a timeout and allow the fallback to the // kernel-provided hostname. This fallback is - // implemented by store.Open. + // implemented by both store.Open and pgstore.Open. ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - db, err := store.Open(ctx, path, hostname) + db, err := open(ctx, addr, hostname) if err != nil { d.log.LogAttrs(ctx, slog.LevelError, "open database", slog.Any("error", err)) return err } d.db.Store(db) - d.log.LogAttrs(ctx, slog.LevelInfo, "open database", slog.String("path", path)) + d.log.LogAttrs(ctx, slog.LevelInfo, "open database", slog.String("name", addr)) return nil } diff --git a/cmd/worklog/main_test.go b/cmd/worklog/main_test.go index b074f06..a9eb2ac 100644 --- a/cmd/worklog/main_test.go +++ b/cmd/worklog/main_test.go @@ -29,6 +29,7 @@ import ( "golang.org/x/sys/execabs" worklog "github.com/kortschak/dex/cmd/worklog/api" + "github.com/kortschak/dex/cmd/worklog/store" "github.com/kortschak/dex/internal/slogext" "github.com/kortschak/dex/rpc" ) @@ -267,7 +268,7 @@ func mergeAfk() int { ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) defer cancel() d := newDaemon("worklog", log, &level, addSource, ctx, cancel) - err = d.openDB(ctx, nil, "db.sqlite3", "localhost") + err = d.openDB(ctx, nil, open, "db.sqlite3", "localhost") if err != nil { fmt.Fprintf(os.Stderr, "failed to create db: %v\n", err) return 1 @@ -536,7 +537,7 @@ func newTestDaemon(ctx context.Context, cancel context.CancelFunc, verbose bool, })) d := newDaemon("worklog", log, &level, addSource, ctx, cancel) - err := d.openDB(ctx, nil, dbName, "localhost") + err := d.openDB(ctx, nil, open, dbName, "localhost") if err != nil { fmt.Fprintf(os.Stderr, "failed to create db: %v\n", err) return nil, nil, nil, 1 @@ -568,6 +569,10 @@ func newTestDaemon(ctx context.Context, cancel context.CancelFunc, verbose bool, return d, db, webRules, 0 } +func open(ctx context.Context, addr, host string) (storage, error) { + return store.Open(ctx, addr, host) +} + func generateData(ts *testscript.TestScript, neg bool, args []string) { if neg { ts.Fatalf("unsupported: !
gen_data") @@ -1717,57 +1722,66 @@ func TestMergeReplacements(t *testing.T) { } var dbDirTests = []struct { - name string - config worklog.Config - want string - wantErr error + name string + config worklog.Config + wantScheme string + wantDir string + wantErr error }{ { name: "none", }, { - name: "deprecated", - config: mkDBDirOptions("database_directory", ""), - want: "database_directory", + name: "deprecated", + config: mkDBDirOptions("database_directory", ""), + wantScheme: "sqlite", + wantDir: "database_directory", }, { - name: "url_only_sqlite", - config: mkDBDirOptions("", "sqlite:database_directory"), - want: "database_directory", + name: "url_only_sqlite", + config: mkDBDirOptions("", "sqlite:database_directory"), + wantScheme: "sqlite", + wantDir: "database_directory", }, { - name: "url_only_postgres", - config: mkDBDirOptions("", "postgres://username:password@localhost:5432/database_name"), - want: "", + name: "url_only_postgres", + config: mkDBDirOptions("", "postgres://username:password@localhost:5432/database_name"), + wantScheme: "postgres", + wantDir: "", }, { - name: "both_consistent", - config: mkDBDirOptions("database_directory", "sqlite:database_directory"), - want: "database_directory", + name: "both_consistent", + config: mkDBDirOptions("database_directory", "sqlite:database_directory"), + wantScheme: "sqlite", + wantDir: "database_directory", }, { - name: "missing_scheme", - config: mkDBDirOptions("database_dir", "database_directory"), - want: "", - wantErr: errors.New("missing scheme in database configuration"), + name: "missing_scheme", + config: mkDBDirOptions("database_dir", "database_directory"), + wantScheme: "", + wantDir: "", + wantErr: errors.New("missing scheme in database configuration"), }, { - name: "both_inconsistent_sqlite", - config: mkDBDirOptions("database_dir", "sqlite:database_directory"), - want: "", - wantErr: errors.New("inconsistent database directory configuration: (sqlite:)database_directory != database_dir"), + name: "both_inconsistent_sqlite", + config: mkDBDirOptions("database_dir", "sqlite:database_directory"), + wantScheme: "", + wantDir: "", + wantErr: errors.New("inconsistent database directory configuration: (sqlite:)database_directory != database_dir"), }, { - name: "invalid_sqlite_url", - config: mkDBDirOptions("", "sqlite:/database_directory"), - want: "", - wantErr: errors.New("sqlite configuration missing opaque data: sqlite:/database_directory"), + name: "invalid_sqlite_url", + config: mkDBDirOptions("", "sqlite:/database_directory"), + wantScheme: "", + wantDir: "", + wantErr: errors.New("sqlite configuration missing opaque data: sqlite:/database_directory"), }, { - name: "both_inconsistent_postgres", - config: mkDBDirOptions("database_dir", "postgres://username:password@localhost:5432/database_name"), - want: "", - wantErr: errors.New("inconsistent database configuration: both postgres database and sqlite directory configured"), + name: "both_inconsistent_postgres", + config: mkDBDirOptions("database_dir", "postgres://username:password@localhost:5432/database_name"), + wantScheme: "", + wantDir: "", + wantErr: errors.New("inconsistent database configuration: both postgres database and sqlite directory configured"), }, } @@ -1781,12 +1795,15 @@ func mkDBDirOptions(dir, url string) worklog.Config { func TestDBDir(t *testing.T) { for _, test := range dbDirTests { t.Run(test.name, func(t *testing.T) { - got, err := dbDir(test.config) + gotScheme, gotDir, err := dbDir(test.config) if !sameError(err, test.wantErr) { 
t.Errorf("unexpected error calling dbDir: got:%v want:%v", err, test.wantErr) } - if got != test.want { - t.Errorf("unexpected result: got:%q want:%q", got, test.want) + if gotScheme != test.wantScheme { + t.Errorf("unexpected scheme result: got:%q want:%q", gotScheme, test.wantScheme) + } + if gotDir != test.wantDir { + t.Errorf("unexpected dir result: got:%q want:%q", gotDir, test.wantDir) } }) } diff --git a/main_postgres_test.go b/main_postgres_test.go new file mode 100644 index 0000000..7509294 --- /dev/null +++ b/main_postgres_test.go @@ -0,0 +1,9 @@ +// Copyright ©2024 Dan Kortschak. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build postgres + +package main + +func init() { postgres = true } diff --git a/main_test.go b/main_test.go index 579d75d..5abf17f 100644 --- a/main_test.go +++ b/main_test.go @@ -6,18 +6,22 @@ package main import ( "bytes" + "context" "crypto/tls" "encoding/json" "flag" "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "regexp" + "strings" "testing" "time" + "github.com/jackc/pgx/v5" "github.com/rogpeppe/go-internal/gotooltest" "github.com/rogpeppe/go-internal/testscript" ) @@ -25,6 +29,9 @@ import ( var ( update = flag.Bool("update", false, "update tests") keep = flag.Bool("keep", false, "keep $WORK directory after tests") + + // postgres indicates tests were invoked with -tags postgres. + postgres bool ) func TestMain(m *testing.M) { @@ -45,6 +52,9 @@ func TestScripts(t *testing.T) { Cmds: map[string]func(ts *testscript.TestScript, neg bool, args []string){ "sleep": sleep, "grep_from_file": grep, + "expand": expand, + "createdb": createDB, + "grant_read": grantReadAccess, }, Setup: func(e *testscript.Env) error { pwd, err := os.Getwd() @@ -52,8 +62,25 @@ func TestScripts(t *testing.T) { return err } e.Setenv("PKG_ROOT", pwd) + for _, k := range []string{ + "PGUSER", "PGPASSWORD", + "PGHOST", "PGPORT", + "POSTGRES_DB", + } { + if v, ok := os.LookupEnv(k); ok { + e.Setenv(k, v) + } + } return nil }, + Condition: func(cond string) (bool, error) { + switch cond { + case "postgres": + return postgres, nil + default: + return false, fmt.Errorf("unknown condition: %s", cond) + } + }, } if err := gotooltest.Setup(&p); err != nil { t.Fatal(err) @@ -97,6 +124,107 @@ func grep(ts *testscript.TestScript, neg bool, args []string) { } } +func expand(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("unsupported: ! expand") + } + if len(args) != 2 { + ts.Fatalf("usage: expand src dst") + } + src, err := os.ReadFile(ts.MkAbs(args[0])) + ts.Check(err) + src = []byte(os.Expand(string(src), func(key string) string { + return ts.Getenv(key) + })) + err = os.WriteFile(ts.MkAbs(args[1]), src, 0o644) + ts.Check(err) +} + +func createDB(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("unsupported: ! 
createdb") + } + if len(args) != 2 { + ts.Fatalf("usage: createdb postgres://user:password@host:port/server new_db") + } + u, err := url.Parse(args[0]) + ts.Check(err) + createTestDB(ts, u, args[1]) +} + +func createTestDB(ts *testscript.TestScript, u *url.URL, dbname string) { + ctx := context.Background() + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + ts.Fatalf("failed to open admin database: %v", err) + } + _, err = db.Exec(ctx, "create database "+dbname) + if err != nil { + ts.Fatalf("failed to create test database: %v", err) + } + err = db.Close(ctx) + if err != nil { + ts.Fatalf("failed to close admin connection: %v", err) + } + + ts.Defer(func() { + dropTestDB(ts, ctx, u, dbname) + }) +} + +func dropTestDB(ts *testscript.TestScript, ctx context.Context, u *url.URL, dbname string) { + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + ts.Logf("failed to open admin database: %v", err) + return + } + _, err = db.Exec(ctx, "drop database if exists "+dbname) + if err != nil { + ts.Logf("failed to drop test database: %v", err) + return + } + err = db.Close(ctx) + if err != nil { + ts.Logf("failed to close admin connection: %v", err) + } +} + +func grantReadAccess(ts *testscript.TestScript, neg bool, args []string) { + if neg { + ts.Fatalf("unsupported: ! grant_read") + } + if len(args) != 2 { + ts.Fatalf("usage: grant_read postgres://user:password@host:port/dbname target_user") + } + + u, err := url.Parse(args[0]) + ts.Check(err) + ctx := context.Background() + db, err := pgx.Connect(ctx, u.String()) + if err != nil { + ts.Fatalf("failed to open database: %v", err) + } + + target := args[1] + statements := []string{ + fmt.Sprintf("GRANT CONNECT ON DATABASE %s TO %s", strings.TrimLeft(u.Path, "/"), target), + fmt.Sprintf("GRANT USAGE ON SCHEMA public TO %s", target), + fmt.Sprintf("GRANT SELECT ON ALL TABLES IN SCHEMA public TO %s", target), + } + for _, s := range statements { + _, err = db.Exec(ctx, s) + if err != nil { + ts.Logf("failed to execute grant: %q %v", s, err) + break + } + } + + err = db.Close(ctx) + if err != nil { + ts.Fatalf("failed to close connection: %v", err) + } +} + func get() int { jsonData := flag.Bool("json", false, "data from GET is JSON") flag.Parse() diff --git a/testdata/worklog_load_postgres.txt b/testdata/worklog_load_postgres.txt new file mode 100644 index 0000000..02b9204 --- /dev/null +++ b/testdata/worklog_load_postgres.txt @@ -0,0 +1,172 @@ +# Only run in a postgres-available environment. +[!postgres] stop 'Skipping postgres test.' + +# The build of worklog takes a fair while due to the size of the dependency +# tree, the size of some of the individual dependencies and the absence of +# caching when building within a test script. +[short] stop 'Skipping long test.' + +env HOME=${WORK} + +[linux] env XDG_CONFIG_HOME=${HOME}/.config +[linux] env XDG_RUNTIME_DIR=${HOME}/runtime +[linux] mkdir ${XDG_CONFIG_HOME}/dex +[linux] mkdir ${XDG_RUNTIME_DIR} +[linux] expand config.toml ${HOME}/.config/dex/config.toml + +[darwin] mkdir ${HOME}'/Library/Application Support/dex' +[darwin] expand config.toml ${HOME}'/Library/Application Support/dex/config.toml' + +env GOBIN=${WORK}/bin +env PATH=${GOBIN}:${PATH} +cd ${PKG_ROOT} +go install ./cmd/worklog +cd ${WORK} + +# Create the database... +createdb postgres://${PGUSER}:${PGPASSWORD}@${PGHOST}:${PGPORT}/${POSTGRES_DB} test_database +# and set up user details. +expand pgpass ${HOME}/.pgpass +chmod 600 ${HOME}/.pgpass + +# Start dex to load the data. 
+dex -log debug -lines &dex& +sleep 1s + +# Load the data... +POST dump.json http://localhost:9797/load/?replace=true +# and confirm. +GET -json http://localhost:9797/dump/ +cmp stdout want.json + +# Show that the non-granted user cannot read. +GET http://localhost:9797/query?sql=select+count(*)+as+event_count+from+events +cp stdout got_event_count_no_grant.json +grep_from_file want_event_count_no_grant.pattern got_event_count_no_grant.json + +# Grant the ro user read access... +grant_read postgres://${PGUSER}:${PGPASSWORD}@${PGHOST}:${PGPORT}/test_database ${PGUSER}_ro + +# and confirm that they can read. +GET http://localhost:9797/query?sql=select+count(*)+as+event_count+from+events +cmp stdout want_event_count_grant.json + +# Terminate dex to allow the test database to be dropped. +kill -INT dex +wait dex + +-- config.toml -- +[kernel] +device = [] +network = "tcp" + +[module.worklog] +path = "worklog" +log_mode = "log" +log_level = "debug" +log_add_source = true + +[module.worklog.options] +database = "postgres://${PGUSER}@${PGHOST}:${PGPORT}/test_database" +hostname = "localhost" +[module.worklog.options.web] +addr = "localhost:9797" +can_modify = true + +-- pgpass -- +${PGHOST}:${PGPORT}:*:${PGUSER}:${PGPASSWORD} +${PGHOST}:${PGPORT}:*:${PGUSER}_ro:${PGPASSWORD}_ro +-- dump.json -- +{ + "buckets": [ + { + "id": "afk_localhost", + "name": "afk-watcher", + "type": "afkstatus", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.424207553+10:30", + "events": [ + { + "bucket": "afk", + "id": 2, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": "2023-12-04T17:21:28.270750821+10:30", + "data": { + "afk": false, + "locked": false + } + } + ] + }, + { + "id": "window_localhost", + "name": "window-watcher", + "type": "currentwindow", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.428793055+10:30", + "events": [ + { + "bucket": "window", + "id": 1, + "start": "2023-12-04T17:21:28.270750821+10:30", + "end": "2023-12-04T17:21:28.270750821+10:30", + "data": { + "app": "Gnome-terminal", + "title": "Terminal" + } + } + ] + } + ] +} +-- want.json -- +{ + "buckets": [ + { + "id": "afk_localhost", + "name": "afk-watcher", + "type": "afkstatus", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.424207+10:30", + "events": [ + { + "bucket": "afk", + "id": 1, + "start": "2023-12-04T17:21:28.27075+10:30", + "end": "2023-12-04T17:21:28.27075+10:30", + "data": { + "afk": false, + "locked": false + } + } + ] + }, + { + "id": "window_localhost", + "name": "window-watcher", + "type": "currentwindow", + "client": "worklog", + "hostname": "localhost", + "created": "2023-12-04T17:21:27.428793+10:30", + "events": [ + { + "bucket": "window", + "id": 2, + "start": "2023-12-04T17:21:28.27075+10:30", + "end": "2023-12-04T17:21:28.27075+10:30", + "data": { + "app": "Gnome-terminal", + "title": "Terminal" + } + } + ] + } + ] +} +-- want_event_count_no_grant.pattern -- +{"err":"ERROR: permission denied for (?:relation|table) events \(SQLSTATE 42501\): ro user failed read capability checks"} +-- want_event_count_grant.json -- +[{"event_count":2}]
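
For reference, the scheme dispatch that this series wires into `Handle` can be reduced to a standalone sketch. `openerFor` is a name introduced here for illustration only; the `opener` and `storage` declarations mirror those in cmd/worklog/main.go, and a plain `log.Printf` stands in for the daemon's structured logger:

```
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/kortschak/dex/cmd/worklog/pgstore"
	"github.com/kortschak/dex/cmd/worklog/store"
)

// storage mirrors the methods the daemon uses: both *store.DB and
// *pgstore.DB provide Name and a context-aware Close.
type storage interface {
	Name() string
	Close(context.Context) error
}

// opener mirrors the type alias added to cmd/worklog/main.go.
type opener = func(ctx context.Context, addr, hostname string) (storage, error)

// openerFor returns the opener for a database URL scheme, following
// the dispatch added to Handle: "postgres" uses pgstore.Open with
// pgstore.Warning downgraded to a log message, "sqlite" uses
// store.Open, and anything else is rejected.
func openerFor(scheme string) (opener, error) {
	switch scheme {
	case "postgres":
		return func(ctx context.Context, addr, host string) (storage, error) {
			db, err := pgstore.Open(ctx, addr, host)
			if _, ok := err.(pgstore.Warning); ok {
				// The store is usable; only the read-only user
				// failed its capability checks, so Select may
				// be unavailable.
				log.Printf("configure database: %v", err)
				err = nil
			}
			return db, err
		}, nil
	case "sqlite":
		return func(ctx context.Context, addr, host string) (storage, error) {
			return store.Open(ctx, addr, host)
		}, nil
	default:
		return nil, fmt.Errorf("unknown scheme in database configuration: %s", scheme)
	}
}
```

This keeps the backend choice in one place: the daemon resolves the scheme once via dbDir, obtains the matching opener, and openDB then treats both stores uniformly through the storage interface.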