diff --git a/daemon.go b/daemon.go index 1fffeaa..78c1c17 100644 --- a/daemon.go +++ b/daemon.go @@ -2,6 +2,8 @@ package main import ( "encoding/json" + "fmt" + "net" "os" "path/filepath" "sync" @@ -139,3 +141,220 @@ func (ls *LabelStore) save() error { return os.WriteFile(ls.path, data, 0o644) } + +// Daemon is the long-running vmuxd process. It polls /proc for Claude sessions, +// maintains the registry, and serves requests over a Unix socket. +type Daemon struct { + registry *SessionRegistry + labels *LabelStore + sockPath string + procDir string + claudeDir string + workspaceResolver func(claudePID int) string // nil = no workspace resolution + pollInterval time.Duration + stopCh chan struct{} + listener net.Listener +} + +// NewDaemon creates a daemon ready to start. +func NewDaemon(sockPath, procDir, claudeDir string, labels *LabelStore) *Daemon { + return &Daemon{ + registry: NewRegistry(), + labels: labels, + sockPath: sockPath, + procDir: procDir, + claudeDir: claudeDir, + pollInterval: 5 * time.Second, + stopCh: make(chan struct{}), + } +} + +// Start runs the daemon: initial scan, then listens on the Unix socket +// and polls for sessions in the background. +func (d *Daemon) Start() error { + // Synchronous initial scan before accepting connections + d.scanOnce(time.Now()) + + if err := d.cleanStaleSocket(); err != nil { + return fmt.Errorf("clean stale socket: %w", err) + } + + ln, err := net.Listen("unix", d.sockPath) + if err != nil { + return fmt.Errorf("listen %s: %w", d.sockPath, err) + } + d.listener = ln + + // Poll loop in background + go d.pollLoop() + + // Accept loop + go d.acceptLoop() + + return nil +} + +// Stop shuts down the daemon gracefully. +func (d *Daemon) Stop() { + select { + case <-d.stopCh: + return // already stopped + default: + close(d.stopCh) + } + if d.listener != nil { + d.listener.Close() + } +} + +// Wait blocks until the daemon stops. +func (d *Daemon) Wait() { + <-d.stopCh +} + +func (d *Daemon) acceptLoop() { + for { + conn, err := d.listener.Accept() + if err != nil { + select { + case <-d.stopCh: + return + default: + continue + } + } + go d.handleConnection(conn) + } +} + +func (d *Daemon) pollLoop() { + ticker := time.NewTicker(d.pollInterval) + defer ticker.Stop() + + for { + select { + case <-d.stopCh: + return + case t := <-ticker.C: + d.scanOnce(t) + } + } +} + +func (d *Daemon) scanOnce(now time.Time) { + procs, err := FindClaudeProcesses(d.procDir) + if err != nil { + return + } + + activeIDs := make(map[string]bool) + + for _, proc := range procs { + _, messages, err := FindSessionForProcess(d.claudeDir, proc) + if err != nil { + continue + } + + state := DetectState(messages, now) + preview := ExtractPreview(messages) + + var sessionID, gitBranch string + for _, msg := range messages { + if msg.SessionID != "" { + sessionID = msg.SessionID + } + if msg.GitBranch != "" { + gitBranch = msg.GitBranch + } + } + + if sessionID == "" { + continue + } + + workspace := "" + if d.workspaceResolver != nil { + workspace = d.workspaceResolver(proc.PID) + } + + label := d.labels.Get(sessionID) + + info := SessionInfo{ + PID: proc.PID, + SessionID: sessionID, + Cwd: proc.Cwd, + GitBranch: gitBranch, + State: state.String(), + Preview: preview, + Workspace: workspace, + Label: label, + } + + d.registry.Update(info) + activeIDs[sessionID] = true + } + + d.registry.RemoveStale(activeIDs) +} + +func (d *Daemon) handleConnection(conn net.Conn) { + defer conn.Close() + + var req Request + if err := json.NewDecoder(conn).Decode(&req); err != nil { + writeResponse(conn, Response{Error: "invalid request: " + err.Error()}) + return + } + + switch req.Action { + case "list": + sessions := d.registry.List() + writeResponse(conn, Response{OK: true, Sessions: sessions}) + + case "label": + var args LabelArgs + if err := json.Unmarshal(req.Args, &args); err != nil { + writeResponse(conn, Response{Error: "invalid label args: " + err.Error()}) + return + } + if err := d.labels.Set(args.SessionID, args.Label); err != nil { + writeResponse(conn, Response{Error: "set label: " + err.Error()}) + return + } + // Update registry with new label + d.registry.mu.Lock() + if ts, ok := d.registry.sessions[args.SessionID]; ok { + ts.Info.Label = args.Label + } + d.registry.mu.Unlock() + writeResponse(conn, Response{OK: true}) + + case "stop": + writeResponse(conn, Response{OK: true}) + d.Stop() + + default: + writeResponse(conn, Response{Error: "unknown action: " + req.Action}) + } +} + +func writeResponse(conn net.Conn, resp Response) { + json.NewEncoder(conn).Encode(resp) +} + +// cleanStaleSocket removes a leftover socket file if no process is listening on it. +func (d *Daemon) cleanStaleSocket() error { + if _, err := os.Stat(d.sockPath); os.IsNotExist(err) { + return nil + } + + // Try connecting to check if another daemon is running + conn, err := net.DialTimeout("unix", d.sockPath, 500*time.Millisecond) + if err == nil { + conn.Close() + return fmt.Errorf("another daemon is already listening on %s", d.sockPath) + } + + // Stale socket, remove it + return os.Remove(d.sockPath) +} diff --git a/daemon_test.go b/daemon_test.go index 0172a3d..757823e 100644 --- a/daemon_test.go +++ b/daemon_test.go @@ -1,6 +1,8 @@ package main import ( + "encoding/json" + "net" "os" "path/filepath" "testing" @@ -173,7 +175,151 @@ func TestRegistryUpdateTimestamp(t *testing.T) { } } -// Placeholder to verify file exists -func init() { - _ = os.TempDir() +// sendRequest dials the daemon socket, sends a Request, and decodes the Response. +func sendRequest(t *testing.T, sockPath string, req Request) Response { + t.Helper() + conn, err := net.Dial("unix", sockPath) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + if err := json.NewEncoder(conn).Encode(req); err != nil { + t.Fatalf("encode: %v", err) + } + + var resp Response + if err := json.NewDecoder(conn).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + return resp +} + +// newTestDaemon creates a daemon using temp dirs (no real /proc or claude dir). +func newTestDaemon(t *testing.T) *Daemon { + t.Helper() + dir := t.TempDir() + sockPath := filepath.Join(dir, "vmux.sock") + procDir := filepath.Join(dir, "proc") // empty, no processes + claudeDir := filepath.Join(dir, "claude") + labelsPath := filepath.Join(dir, "labels.json") + + os.MkdirAll(procDir, 0o755) + os.MkdirAll(claudeDir, 0o755) + + labels, err := NewLabelStore(labelsPath) + if err != nil { + t.Fatalf("labels: %v", err) + } + + return NewDaemon(sockPath, procDir, claudeDir, labels) +} + +func TestDaemonStartStop(t *testing.T) { + d := newTestDaemon(t) + + if err := d.Start(); err != nil { + t.Fatalf("start: %v", err) + } + + // Send stop via socket + resp := sendRequest(t, d.sockPath, Request{Action: "stop"}) + if !resp.OK { + t.Errorf("stop resp.OK = false, error = %q", resp.Error) + } + + // Daemon should stop within a short time + done := make(chan struct{}) + go func() { + d.Wait() + close(done) + }() + + select { + case <-done: + // ok + case <-time.After(2 * time.Second): + t.Fatal("daemon did not stop within 2s") + } +} + +func TestDaemonListOverSocket(t *testing.T) { + d := newTestDaemon(t) + + if err := d.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer d.Stop() + + // Populate registry after Start (initial scan clears unknown sessions) + d.registry.Update(SessionInfo{ + PID: 42, + SessionID: "test-sess", + State: "Working", + Cwd: "/tmp/test", + }) + + resp := sendRequest(t, d.sockPath, Request{Action: "list"}) + if !resp.OK { + t.Fatalf("list resp.OK = false, error = %q", resp.Error) + } + if len(resp.Sessions) != 1 { + t.Fatalf("sessions len = %d, want 1", len(resp.Sessions)) + } + if resp.Sessions[0].SessionID != "test-sess" { + t.Errorf("session_id = %q, want %q", resp.Sessions[0].SessionID, "test-sess") + } +} + +func TestDaemonLabelOverSocket(t *testing.T) { + d := newTestDaemon(t) + + if err := d.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer d.Stop() + + // Populate registry after Start (initial scan clears unknown sessions) + d.registry.Update(SessionInfo{ + PID: 42, + SessionID: "test-sess", + State: "Working", + Cwd: "/tmp/test", + }) + + // Set label + args, _ := json.Marshal(LabelArgs{SessionID: "test-sess", Label: "review MR"}) + resp := sendRequest(t, d.sockPath, Request{Action: "label", Args: args}) + if !resp.OK { + t.Fatalf("label resp.OK = false, error = %q", resp.Error) + } + + // Verify label appears in list + resp = sendRequest(t, d.sockPath, Request{Action: "list"}) + if !resp.OK { + t.Fatalf("list resp.OK = false, error = %q", resp.Error) + } + if len(resp.Sessions) != 1 { + t.Fatalf("sessions len = %d, want 1", len(resp.Sessions)) + } + if resp.Sessions[0].Label != "review MR" { + t.Errorf("label = %q, want %q", resp.Sessions[0].Label, "review MR") + } +} + +func TestDaemonUnknownAction(t *testing.T) { + d := newTestDaemon(t) + + if err := d.Start(); err != nil { + t.Fatalf("start: %v", err) + } + defer d.Stop() + + resp := sendRequest(t, d.sockPath, Request{Action: "bogus"}) + if resp.OK { + t.Error("expected OK=false for unknown action") + } + if resp.Error == "" { + t.Error("expected error message for unknown action") + } }