Expand Up @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" "io" "net" "sync" "time" Expand Down Expand Up @@ -164,16 +165,36 @@ func (q *msgQueue) dropped() { q.cond.Broadcast() } // pqListener is an interface that represents a *pq.Listener for testing type pqListener interface { io.Closer Listen(string) error Unlisten(string) error NotifyChan() <-chan *pq.Notification } type pqListenerShim struct { *pq.Listener } func (l pqListenerShim) NotifyChan() <-chan *pq.Notification { return l.Notify } // PGPubsub is a pubsub implementation using PostgreSQL. type PGPubsub struct { ctx context.Context cancel context.CancelFunc logger slog.Logger listenDone chan struct{} pgListener *pq.Listener db *sql.DB mut sync.Mutex queues map[string]map[uuid.UUID]*msgQueue logger slog.Logger listenDone chan struct{} pgListener pqListener db *sql.DB qMu sync.Mutex queues map[string]map[uuid.UUID]*msgQueue // making the close state its own mutex domain simplifies closing logic so // that we don't have to hold the qMu --- which could block processing // notifications while the pqListener is closing. closeMu sync.Mutex closedListener bool closeListenerErr error Expand All @@ -192,16 +213,14 @@ const BufferSize = 2048 // Subscribe calls the listener when an event matching the name is received. func (p *PGPubsub) Subscribe(event string, listener Listener) (cancel func(), err error) { return p.subscribeQueue(event, newMsgQueue(p.ctx , listener, nil)) return p.subscribeQueue(event, newMsgQueue(context.Background() , listener, nil)) } func (p *PGPubsub) SubscribeWithErr(event string, listener ListenerWithErr) (cancel func(), err error) { return p.subscribeQueue(event, newMsgQueue(p.ctx , nil, listener)) return p.subscribeQueue(event, newMsgQueue(context.Background() , nil, listener)) } func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), err error) { p.mut.Lock() defer p.mut.Unlock() defer func() { if err != nil { // if we hit an error, we need to close the queue so we don't Expand All @@ -213,9 +232,13 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), } }() // The pgListener waits for the response to `LISTEN` on a mainloop that also dispatches // notifies. We need to avoid holding the mutex while this happens, since holding the mutex // blocks reading notifications and can deadlock the pgListener. // c.f. https://github.com/coder/coder/issues/11950 err = p.pgListener.Listen(event) if err == nil { p.logger.Debug(p.ctx , "started listening to event channel", slog.F("event", event)) p.logger.Debug(context.Background() , "started listening to event channel", slog.F("event", event)) } if errors.Is(err, pq.ErrChannelAlreadyOpen) { // It's ok if it's already open! Expand All @@ -224,6 +247,8 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), if err != nil { return nil, xerrors.Errorf("listen: %w", err) } p.qMu.Lock() defer p.qMu.Unlock() var eventQs map[uuid.UUID]*msgQueue var ok bool Expand All @@ -234,30 +259,36 @@ func (p *PGPubsub) subscribeQueue(event string, newQ *msgQueue) (cancel func(), id := uuid.New() eventQs[id] = newQ return func() { p.mut.Lock() defer p.mut.Unlock() p.qMu.Lock() listeners := p.queues[event] q := listeners[id] q.close() delete(listeners, id) if len(listeners) == 0 { delete(p.queues, event) } p.qMu.Unlock() // as above, we must not hold the lock while calling into pgListener if len(listeners) == 0 { uErr := p.pgListener.Unlisten(event) p.closeMu.Lock() defer p.closeMu.Unlock() if uErr != nil && !p.closedListener { p.logger.Warn(p.ctx , "failed to unlisten", slog.Error(uErr), slog.F("event", event)) p.logger.Warn(context.Background() , "failed to unlisten", slog.Error(uErr), slog.F("event", event)) } else { p.logger.Debug(p.ctx , "stopped listening to event channel", slog.F("event", event)) p.logger.Debug(context.Background() , "stopped listening to event channel", slog.F("event", event)) } } }, nil } func (p *PGPubsub) Publish(event string, message []byte) error { p.logger.Debug(p.ctx , "publish", slog.F("event", event), slog.F("message_len", len(message))) p.logger.Debug(context.Background() , "publish", slog.F("event", event), slog.F("message_len", len(message))) // This is safe because we are calling pq.QuoteLiteral. pg_notify doesn't // support the first parameter being a prepared statement. //nolint:gosec _, err := p.db.ExecContext(p.ctx , `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) _, err := p.db.ExecContext(context.Background() , `select pg_notify(`+pq.QuoteLiteral(event)+`, $1)`, message) if err != nil { p.publishesTotal.WithLabelValues("false").Inc() return xerrors.Errorf("exec pg_notify: %w", err) Expand All @@ -269,53 +300,38 @@ func (p *PGPubsub) Publish(event string, message []byte) error { // Close closes the pubsub instance. func (p *PGPubsub) Close() error { p.logger.Info(p.ctx, "pubsub is closing") p.cancel() p.logger.Info(context.Background(), "pubsub is closing") err := p.closeListener() <-p.listenDone p.logger.Debug(p.ctx , "pubsub closed") p.logger.Debug(context.Background() , "pubsub closed") return err } // closeListener closes the pgListener, unless it has already been closed. func (p *PGPubsub) closeListener() error { p.mut .Lock() defer p.mut .Unlock() p.closeMu .Lock() defer p.closeMu .Unlock() if p.closedListener { return p.closeListenerErr } p.closeListenerErr = p.pgListener.Close() p.closedListener = true p.closeListenerErr = p.pgListener.Close() return p.closeListenerErr } // listen begins receiving messages on the pq listener. func (p *PGPubsub) listen() { defer func() { p.logger.Info(p.ctx, "pubsub listen stopped receiving notify") cErr := p.closeListener() if cErr != nil { p.logger.Error(p.ctx, "failed to close listener") } p.logger.Info(context.Background(), "pubsub listen stopped receiving notify") close(p.listenDone) }() var ( notif *pq.Notification ok bool ) for { select { case <-p.ctx.Done(): return case notif, ok = <-p.pgListener.Notify: if !ok { return } } notify := p.pgListener.NotifyChan() for notif := range notify { // A nil notification can be dispatched on reconnect. if notif == nil { p.logger.Debug(p.ctx , "notifying subscribers of a reconnection") p.logger.Debug(context.Background() , "notifying subscribers of a reconnection") p.recordReconnect() continue } Expand All @@ -331,8 +347,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { p.messagesTotal.WithLabelValues(sizeLabel).Inc() p.receivedBytesTotal.Add(float64(len(notif.Extra))) p.mut .Lock() defer p.mut .Unlock() p.qMu .Lock() defer p.qMu .Unlock() queues, ok := p.queues[notif.Channel] if !ok { return Expand All @@ -344,8 +360,8 @@ func (p *PGPubsub) listenReceive(notif *pq.Notification) { } func (p *PGPubsub) recordReconnect() { p.mut .Lock() defer p.mut .Unlock() p.qMu .Lock() defer p.qMu .Unlock() for _, listeners := range p.queues { for _, q := range listeners { q.dropped() Expand Down Expand Up @@ -409,30 +425,32 @@ func (p *PGPubsub) startListener(ctx context.Context, connectURL string) error { d: net.Dialer{}, } ) p.pgListener = pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { switch t { case pq.ListenerEventConnected: p.logger.Info(ctx, "pubsub connected to postgres") p.connected.Set(1.0) case pq.ListenerEventDisconnected: p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) p.connected.Set(0) case pq.ListenerEventReconnected: p.logger.Info(ctx, "pubsub reconnected to postgres") p.connected.Set(1) case pq.ListenerEventConnectionAttemptFailed: p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) } // This callback gets events whenever the connection state changes. // Don't send if the errChannel has already been closed. select { case <-errCh: return default: errCh <- err close(errCh) } }) p.pgListener = pqListenerShim{ Listener: pq.NewDialListener(dialer, connectURL, time.Second, time.Minute, func(t pq.ListenerEventType, err error) { switch t { case pq.ListenerEventConnected: p.logger.Info(ctx, "pubsub connected to postgres") p.connected.Set(1.0) case pq.ListenerEventDisconnected: p.logger.Error(ctx, "pubsub disconnected from postgres", slog.Error(err)) p.connected.Set(0) case pq.ListenerEventReconnected: p.logger.Info(ctx, "pubsub reconnected to postgres") p.connected.Set(1) case pq.ListenerEventConnectionAttemptFailed: p.logger.Error(ctx, "pubsub failed to connect to postgres", slog.Error(err)) } // This callback gets events whenever the connection state changes. // Don't send if the errChannel has already been closed. select { case <-errCh: return default: errCh <- err close(errCh) } }), } select { case err := <-errCh: if err != nil { Expand Down Expand Up @@ -501,24 +519,31 @@ func (p *PGPubsub) Collect(metrics chan<- prometheus.Metric) { p.connected.Collect(metrics) // implicit metrics p.mut .Lock() p.qMu .Lock() events := len(p.queues) subs := 0 for _, subscriberMap := range p.queues { subs += len(subscriberMap) } p.mut .Unlock() p.qMu .Unlock() metrics <- prometheus.MustNewConstMetric(currentSubscribersDesc, prometheus.GaugeValue, float64(subs)) metrics <- prometheus.MustNewConstMetric(currentEventsDesc, prometheus.GaugeValue, float64(events)) } // New creates a new Pubsub implementation using a PostgreSQL connection. func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connectURL string) (*PGPubsub, error) { // Start a new context that will be canceled when the pubsub is closed. ctx, cancel := context.WithCancel(context.Background()) p := &PGPubsub{ ctx: ctx, cancel: cancel, p := newWithoutListener(logger, database) if err := p.startListener(startCtx, connectURL); err != nil { return nil, err } go p.listen() logger.Info(startCtx, "pubsub has started") return p, nil } // newWithoutListener creates a new PGPubsub without creating the pqListener. func newWithoutListener(logger slog.Logger, database *sql.DB) *PGPubsub { return &PGPubsub{ logger: logger, listenDone: make(chan struct{}), db: database, Expand Down Expand Up @@ -567,10 +592,4 @@ func New(startCtx context.Context, logger slog.Logger, database *sql.DB, connect Help: "Whether we are connected (1) or not connected (0) to postgres", }), } if err := p.startListener(startCtx, connectURL); err != nil { return nil, err } go p.listen() logger.Info(ctx, "pubsub has started") return p, nil }