package middleware import ( "context" "log/slog" "net/http" "slices" "strings" "github.com/go-chi/chi/v5" "shlf.space/internal/atproto" "shlf.space/internal/server/oauth" notfound "shlf.space/internal/views/not-found" ) type CtxKey string const UnreadNotificationCountCtxKey CtxKey = "unreadNotificationCount" type Middleware struct { oauth *oauth.OAuth idResolver *atproto.Resolver } func New(oauth *oauth.OAuth, idResolver *atproto.Resolver) Middleware { return Middleware{ oauth: oauth, idResolver: idResolver, } } type middlewareFunc func(http.Handler) http.Handler func (mw Middleware) ResolveIdent() middlewareFunc { excluded := []string{"favicon.ico", "favicon.svg"} return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { didOrHandle := chi.URLParam(r, "user") didOrHandle = strings.TrimPrefix(didOrHandle, "@") if slices.Contains(excluded, didOrHandle) { next.ServeHTTP(w, r) return } id, err := mw.idResolver.ResolveIdent(r.Context(), didOrHandle) if err != nil { slog.Error("failed to resolve did/handle", "err", err) w.WriteHeader(http.StatusNotFound) notfound.NotFoundPage(notfound.NotFoundParams{}).Render(r.Context(), w) return } ctx := context.WithValue(r.Context(), "resolvedId", *id) next.ServeHTTP(w, r.WithContext(ctx)) }) } }