diff --git a/main.go b/main.go index 0f4240dc0d1ab48969c9aab39a907c608344d64d..2932e4f5d32cfd9a01cb1fb109dd5124556b0937 100644 --- a/main.go +++ b/main.go @@ -3,14 +3,11 @@ import ( "context" "crypto/rand" - "embed" "encoding/base64" "encoding/json" "flag" "fmt" - "html/template" "io" - "io/fs" "log" "net/http" "os" @@ -22,15 +19,6 @@ "github.com/peterbourgon/ff/v3" "golang.org/x/oauth2" ) -var ( - //go:embed static/* - _staticFiles embed.FS - staticFiles, _ = fs.Sub(_staticFiles, "static") - //go:embed static/index.gotmpl - indexFile string - indexTmpl = template.Must(template.New("").Parse(indexFile)) -) - func randString(nByte int) (string, error) { b := make([]byte, nByte) if _, err := io.ReadFull(rand.Reader, b); err != nil { @@ -56,265 +44,127 @@ clientID string clientSecret string port int scopes string - origin string } -func (a args) redirect() string { - return strings.TrimSuffix(a.origin, "/") + "/callback" -} - -type OIDCConfig struct { - ClientProvider string `json:"client_provider"` - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` - Scopes string `json:"scopes"` -} - -type Server struct { - args args - - issuer string - scopes string - provider *oidc.Provider - verifier *oidc.IDTokenVerifier - oauth2Config *oauth2.Config -} - -type IndexContext struct { - Issuer string - ClientID string - ClientSecret string - Scopes string - RedirectURI string - - Results bool - TokenClaims string - UserInfo string - TokenResponse string -} - -func (s *Server) updateConfig(ctx context.Context, config *OIDCConfig) error { - if config.ClientProvider == "" || config.ClientID == "" { - return fmt.Errorf("client provider and client ID are required") +func main() { + var args args + fs := flag.NewFlagSet("oidc", flag.ExitOnError) + fs.StringVar(&args.clientProvider, "client-provider", "", "Client provider (e.g. https://accounts.google.com)") + fs.StringVar(&args.clientID, "client-id", "", "Client ID") + fs.StringVar(&args.clientSecret, "client-secret", "", "Client secret") + fs.IntVar(&args.port, "port", 8000, "Port to run on") + fs.StringVar(&args.scopes, "scopes", "profile,email", "Comma-delimited scopes") + fs.String("config", ".env", "Env config") + if err := ff.Parse(fs, os.Args[1:], + ff.WithEnvVarPrefix("OIDC"), + ff.WithConfigFileFlag("config"), + ff.WithAllowMissingConfigFile(true), + ff.WithConfigFileParser(ff.EnvParser), + ); err != nil { + log.Fatal(err) } + ctx := context.Background() + scopes := strings.Split(args.scopes, ",") - provider, err := oidc.NewProvider(ctx, config.ClientProvider) + provider, err := oidc.NewProvider(ctx, args.clientProvider) if err != nil { - return fmt.Errorf("failed to create OIDC provider: %v", err) + log.Fatal(err) } - oidcConfig := &oidc.Config{ - ClientID: config.ClientID, + ClientID: args.clientID, } verifier := provider.Verifier(oidcConfig) - scopes := strings.Fields(config.Scopes) - if len(scopes) == 0 { - scopes = []string{"profile", "email"} - } - - oauth2Config := &oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, + config := oauth2.Config{ + ClientID: args.clientID, + ClientSecret: args.clientSecret, Endpoint: provider.Endpoint(), - RedirectURL: s.args.redirect(), + RedirectURL: fmt.Sprintf("http://localhost:%d/callback", args.port), Scopes: append([]string{oidc.ScopeOpenID}, scopes...), } - s.issuer = config.ClientProvider - s.scopes = config.Scopes - s.provider = provider - s.verifier = verifier - s.oauth2Config = oauth2Config - - return nil -} - -func (s *Server) handleIndex(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodPost { - config := OIDCConfig{ - ClientProvider: r.FormValue("client_provider"), - ClientID: r.FormValue("client_id"), - ClientSecret: r.FormValue("client_secret"), - Scopes: r.FormValue("scopes"), - } - - if config.ClientProvider == "" || config.ClientID == "" { - http.Error(w, "Client provider and client ID are required", http.StatusBadRequest) + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + state, err := randString(16) + if err != nil { + http.Error(w, "Internal error", http.StatusInternalServerError) return } - - if err := s.updateConfig(r.Context(), &config); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + nonce, err := randString(16) + if err != nil { + http.Error(w, "Internal error", http.StatusInternalServerError) return } - - http.Redirect(w, r, "/auth", http.StatusFound) - return - } + setCallbackCookie(w, r, "state", state) + setCallbackCookie(w, r, "nonce", nonce) - indexTmpl.Execute(w, IndexContext{ - Issuer: s.args.clientProvider, - ClientID: s.args.clientID, - ClientSecret: s.args.clientSecret, - Scopes: s.args.scopes, - RedirectURI: s.args.redirect(), + http.Redirect(w, r, config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound) }) -} -func (s *Server) handleAuth(w http.ResponseWriter, r *http.Request) { - config := s.oauth2Config - - if config == nil { - http.Error(w, "OIDC not configured", http.StatusBadRequest) - return - } - - state, err := randString(16) - if err != nil { - http.Error(w, "Internal error", http.StatusInternalServerError) - return - } - nonce, err := randString(16) - if err != nil { - http.Error(w, "Internal error", http.StatusInternalServerError) - return - } - setCallbackCookie(w, r, "state", state) - setCallbackCookie(w, r, "nonce", nonce) - - http.Redirect(w, r, config.AuthCodeURL(state, oidc.Nonce(nonce)), http.StatusFound) -} - -func (s *Server) handleCallback(w http.ResponseWriter, r *http.Request) { - config := s.oauth2Config - provider := s.provider - verifier := s.verifier - - if config == nil || provider == nil || verifier == nil { - http.Error(w, "OIDC not configured", http.StatusBadRequest) - return - } - - state, err := r.Cookie("state") - if err != nil { - http.Error(w, "state not found", http.StatusBadRequest) - return - } - if r.URL.Query().Get("state") != state.Value { - http.Error(w, "state did not match", http.StatusBadRequest) - return - } - - oauth2Token, err := config.Exchange(r.Context(), r.URL.Query().Get("code")) - if err != nil { - http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) - return - } - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) - return - } - idToken, err := verifier.Verify(r.Context(), rawIDToken) - if err != nil { - http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) - return - } - - nonce, err := r.Cookie("nonce") - if err != nil { - http.Error(w, "nonce not found", http.StatusBadRequest) - return - } - if idToken.Nonce != nonce.Value { - http.Error(w, "nonce did not match", http.StatusBadRequest) - return - } - - userInfo, err := provider.UserInfo(r.Context(), oauth2.StaticTokenSource(oauth2Token)) - if err != nil { - http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError) - return - } - - var tokenClaims json.RawMessage - if err := idToken.Claims(&tokenClaims); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - tokenClaimsJSON, err := json.MarshalIndent(tokenClaims, "", "\t") - if err != nil { - http.Error(w, "Could not marshal token claims JSON: "+err.Error(), http.StatusInternalServerError) - return - } - - userInfoJSON, err := json.MarshalIndent(userInfo, "", "\t") - if err != nil { - http.Error(w, "Could not marshal userinfo JSON: "+err.Error(), http.StatusInternalServerError) - return - } + mux.HandleFunc("/callback", func(w http.ResponseWriter, r *http.Request) { + state, err := r.Cookie("state") + if err != nil { + http.Error(w, "state not found", http.StatusBadRequest) + return + } + if r.URL.Query().Get("state") != state.Value { + http.Error(w, "state did not match", http.StatusBadRequest) + return + } - tokenResponseJSON, err := json.MarshalIndent(oauth2Token, "", "\t") - if err != nil { - http.Error(w, "Could not marshal oauth2 token response JSON: "+err.Error(), http.StatusInternalServerError) - return - } + oauth2Token, err := config.Exchange(ctx, r.URL.Query().Get("code")) + if err != nil { + http.Error(w, "Failed to exchange token: "+err.Error(), http.StatusInternalServerError) + return + } + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + http.Error(w, "No id_token field in oauth2 token.", http.StatusInternalServerError) + return + } + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + http.Error(w, "Failed to verify ID Token: "+err.Error(), http.StatusInternalServerError) + return + } - indexTmpl.Execute(w, IndexContext{ - Issuer: s.issuer, - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - Scopes: s.scopes, - RedirectURI: config.RedirectURL, - Results: true, - TokenClaims: string(tokenClaimsJSON), - UserInfo: string(userInfoJSON), - TokenResponse: string(tokenResponseJSON), - }) -} + nonce, err := r.Cookie("nonce") + if err != nil { + http.Error(w, "nonce not found", http.StatusBadRequest) + return + } + if idToken.Nonce != nonce.Value { + http.Error(w, "nonce did not match", http.StatusBadRequest) + return + } -func main() { - var args args - fs := flag.NewFlagSet("oidc", flag.ExitOnError) - fs.StringVar(&args.clientProvider, "client-provider", "", "Default client provider (e.g. https://accounts.google.com)") - fs.StringVar(&args.clientID, "client-id", "", "Default client ID") - fs.StringVar(&args.clientSecret, "client-secret", "", "Default client secret") - fs.IntVar(&args.port, "port", 8000, "Port to run on") - fs.StringVar(&args.scopes, "scopes", "profile email", "Default scopes") - fs.StringVar(&args.origin, "origin", "http://localhost:8000", "Web origin") - fs.String("config", ".env", "Env config") - if err := ff.Parse(fs, os.Args[1:], - ff.WithEnvVarPrefix("OIDC"), - ff.WithConfigFileFlag("config"), - ff.WithAllowMissingConfigFile(true), - ff.WithConfigFileParser(ff.EnvParser), - ); err != nil { - log.Fatal(err) - } + userInfo, err := provider.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token)) + if err != nil { + http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError) + return + } - server := &Server{ - args: args, - } + resp := struct { + OAuth2Token *oauth2.Token `json:"oauth2_token"` + IDTokenClaims *json.RawMessage `json:"claims"` + UserInfo *oidc.UserInfo `json:"user_info"` + }{ + OAuth2Token: oauth2Token, + IDTokenClaims: new(json.RawMessage), + UserInfo: userInfo, + } - if args.clientProvider != "" && args.clientID != "" { - defaultConfig := &OIDCConfig{ - ClientProvider: args.clientProvider, - ClientID: args.clientID, - ClientSecret: args.clientSecret, - Scopes: args.scopes, + if err := idToken.Claims(&resp.IDTokenClaims); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return } - if err := server.updateConfig(context.Background(), defaultConfig); err != nil { - log.Printf("Warning: Failed to initialize default config: %v", err) + w.Header().Set("Content-Type", "application/json") + enc := json.NewEncoder(w) + enc.SetIndent("", "\t") + if err := enc.Encode(resp); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) } - } - - mux := http.NewServeMux() - mux.HandleFunc("/", server.handleIndex) - mux.Handle("/_/", http.StripPrefix("/_/", http.FileServerFS(staticFiles))) - mux.HandleFunc("/auth", server.handleAuth) - mux.HandleFunc("/callback", server.handleCallback) + }) bind := fmt.Sprintf(":%d", args.port) log.Println("listening on http://localhost" + bind) diff --git a/static/index.gotmpl b/static/index.gotmpl deleted file mode 100644 index c979f0f983701a229901dfe5f8630819b107178e..0000000000000000000000000000000000000000 --- a/static/index.gotmpl +++ /dev/null @@ -1,87 +0,0 @@ - - - -
- - -{{ .TokenClaims }}
- {{ .UserInfo }}
- {{ .TokenResponse }}
-