package test

import (
	"context"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"

	"github.com/kms/api-key-service/internal/auth"
)

// newOAuth2DiscoveryServer starts a mock discovery endpoint that serves a
// discovery document advertising the given token and userinfo endpoints.
// The server is closed automatically when tb (test, subtest or benchmark run)
// finishes, so callers do not need to defer Close themselves.
func newOAuth2DiscoveryServer(tb testing.TB, tokenEndpoint, userInfoEndpoint string) *httptest.Server {
	tb.Helper()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		response := `{
			"issuer": "https://example.com",
			"authorization_endpoint": "https://example.com/auth",
			"token_endpoint": "` + tokenEndpoint + `",
			"userinfo_endpoint": "` + userInfoEndpoint + `"
		}`
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(response))
	}))
	tb.Cleanup(server.Close)
	return server
}

// TestOAuth2Provider_GetDiscoveryDocument checks that the provider fetches and
// decodes the discovery document, and that it reports an error for a missing
// provider URL, a non-200 status, and a malformed body.
func TestOAuth2Provider_GetDiscoveryDocument(t *testing.T) {
	tests := []struct {
		name           string
		providerURL    string
		mockResponse   string
		mockStatusCode int
		expectError    bool
		expectedIssuer string
	}{
		{
			name:        "successful discovery",
			providerURL: "https://example.com",
			mockResponse: `{
				"issuer": "https://example.com",
				"authorization_endpoint": "https://example.com/auth",
				"token_endpoint": "https://example.com/token",
				"userinfo_endpoint": "https://example.com/userinfo",
				"jwks_uri": "https://example.com/jwks"
			}`,
			mockStatusCode: http.StatusOK,
			expectError:    false,
			expectedIssuer: "https://example.com",
		},
		{
			name:        "missing provider URL",
			providerURL: "",
			expectError: true,
		},
		{
			name:           "invalid response status",
			providerURL:    "https://example.com",
			mockResponse:   `{"error": "not found"}`,
			mockStatusCode: http.StatusNotFound,
			expectError:    true,
		},
		{
			name:           "invalid JSON response",
			providerURL:    "https://example.com",
			mockResponse:   `invalid json`,
			mockStatusCode: http.StatusOK,
			expectError:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Every case with a provider URL is redirected to a local mock so that the
			// error cases actually exercise their configured status/body instead of
			// reaching the real https://example.com over the network.
			if tt.providerURL != "" {
				server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					// NOTE(review): OpenID Connect Discovery 1.0 specifies
					// "/.well-known/openid-configuration" (hyphen); confirm the
					// provider really requests the underscore form asserted here.
					assert.Equal(t, "/.well-known/openid_configuration", r.URL.Path)
					w.WriteHeader(tt.mockStatusCode)
					_, _ = w.Write([]byte(tt.mockResponse))
				}))
				defer server.Close()
				tt.providerURL = server.URL
			}

			configMock := NewMockConfig()
			configMock.values["SSO_PROVIDER_URL"] = tt.providerURL

			provider := auth.NewOAuth2Provider(configMock, zap.NewNop())

			discovery, err := provider.GetDiscoveryDocument(context.Background())
			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, discovery)
				return
			}
			assert.NoError(t, err)
			if assert.NotNil(t, discovery) {
				assert.Equal(t, tt.expectedIssuer, discovery.Issuer)
			}
		})
	}
}

// TestOAuth2Provider_GenerateAuthURL checks that the authorization URL is built
// from the discovered authorization endpoint and carries client_id, state and
// redirect_uri, and that a missing client ID is rejected.
func TestOAuth2Provider_GenerateAuthURL(t *testing.T) {
	server := newOAuth2DiscoveryServer(t, "https://example.com/token", "https://example.com/userinfo")

	tests := []struct {
		name        string
		clientID    string
		state       string
		redirectURI string
		expectError bool
	}{
		{
			name:        "successful URL generation",
			clientID:    "test-client-id",
			state:       "test-state",
			redirectURI: "https://app.example.com/callback",
			expectError: false,
		},
		{
			name:        "missing client ID",
			clientID:    "",
			state:       "test-state",
			redirectURI: "https://app.example.com/callback",
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			configMock := NewMockConfig()
			configMock.values["SSO_PROVIDER_URL"] = server.URL
			configMock.values["SSO_CLIENT_ID"] = tt.clientID

			provider := auth.NewOAuth2Provider(configMock, zap.NewNop())

			authURL, err := provider.GenerateAuthURL(context.Background(), tt.state, tt.redirectURI)
			if tt.expectError {
				assert.Error(t, err)
				assert.Empty(t, authURL)
				return
			}
			assert.NoError(t, err)
			assert.NotEmpty(t, authURL)
			assert.Contains(t, authURL, "https://example.com/auth")
			assert.Contains(t, authURL, "client_id="+tt.clientID)
			assert.Contains(t, authURL, "state="+tt.state)
			assert.Contains(t, authURL, "redirect_uri=")
		})
	}
}

// TestOAuth2Provider_ExchangeCodeForToken checks the authorization-code exchange
// against a mock token endpoint: a successful exchange, a missing client ID,
// and an error response from the token endpoint.
func TestOAuth2Provider_ExchangeCodeForToken(t *testing.T) {
	tests := []struct {
		name           string
		code           string
		redirectURI    string
		codeVerifier   string
		clientID       string
		clientSecret   string
		mockResponse   string
		mockStatusCode int
		expectError    bool
		expectedToken  string
	}{
		{
			name:         "successful token exchange",
			code:         "test-code",
			redirectURI:  "https://app.example.com/callback",
			codeVerifier: "test-verifier",
			clientID:     "test-client-id",
			clientSecret: "test-client-secret",
			mockResponse: `{
				"access_token": "test-access-token",
				"token_type": "Bearer",
				"expires_in": 3600,
				"refresh_token": "test-refresh-token"
			}`,
			mockStatusCode: http.StatusOK,
			expectError:    false,
			expectedToken:  "test-access-token",
		},
		{
			name:         "missing client ID",
			code:         "test-code",
			redirectURI:  "https://app.example.com/callback",
			codeVerifier: "test-verifier",
			clientID:     "",
			clientSecret: "test-client-secret",
			expectError:  true,
		},
		{
			name:           "token endpoint error",
			code:           "test-code",
			redirectURI:    "https://app.example.com/callback",
			codeVerifier:   "test-verifier",
			clientID:       "test-client-id",
			clientSecret:   "test-client-secret",
			mockResponse:   `{"error": "invalid_grant"}`,
			mockStatusCode: http.StatusBadRequest,
			expectError:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The token server is created for every case (not only successful ones) so
			// that error cases are served their configured response instead of hitting
			// the real https://example.com/token.
			tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
				status := tt.mockStatusCode
				if status == 0 {
					// No response configured (e.g. "missing client ID", which presumably
					// fails before any request). WriteHeader(0) would panic, so answer
					// with an explicit error if the endpoint is reached anyway.
					status = http.StatusInternalServerError
				}
				w.WriteHeader(status)
				_, _ = w.Write([]byte(tt.mockResponse))
			}))
			defer tokenServer.Close()

			discoveryServer := newOAuth2DiscoveryServer(t, tokenServer.URL, "https://example.com/userinfo")

			configMock := NewMockConfig()
			configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
			configMock.values["SSO_CLIENT_ID"] = tt.clientID
			configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret

			provider := auth.NewOAuth2Provider(configMock, zap.NewNop())

			tokenResp, err := provider.ExchangeCodeForToken(context.Background(), tt.code, tt.redirectURI, tt.codeVerifier)
			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, tokenResp)
				return
			}
			assert.NoError(t, err)
			if assert.NotNil(t, tokenResp) {
				assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
				assert.Equal(t, "Bearer", tokenResp.TokenType)
			}
		})
	}
}

// TestOAuth2Provider_GetUserInfo checks that the userinfo endpoint is called with
// a Bearer token and that success, 401 and malformed-body responses are handled.
func TestOAuth2Provider_GetUserInfo(t *testing.T) {
	tests := []struct {
		name           string
		accessToken    string
		mockResponse   string
		mockStatusCode int
		expectError    bool
		expectedSub    string
		expectedEmail  string
	}{
		{
			name:        "successful user info retrieval",
			accessToken: "test-access-token",
			mockResponse: `{
				"sub": "user123",
				"email": "user@example.com",
				"name": "Test User",
				"email_verified": true
			}`,
			mockStatusCode: http.StatusOK,
			expectError:    false,
			expectedSub:    "user123",
			expectedEmail:  "user@example.com",
		},
		{
			name:           "unauthorized access token",
			accessToken:    "invalid-token",
			mockResponse:   `{"error": "invalid_token"}`,
			mockStatusCode: http.StatusUnauthorized,
			expectError:    true,
		},
		{
			name:           "invalid JSON response",
			accessToken:    "test-access-token",
			mockResponse:   `invalid json`,
			mockStatusCode: http.StatusOK,
			expectError:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodGet, r.Method)
				assert.Equal(t, "Bearer "+tt.accessToken, r.Header.Get("Authorization"))
				w.WriteHeader(tt.mockStatusCode)
				_, _ = w.Write([]byte(tt.mockResponse))
			}))
			defer userInfoServer.Close()

			discoveryServer := newOAuth2DiscoveryServer(t, "https://example.com/token", userInfoServer.URL)

			configMock := NewMockConfig()
			configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL

			provider := auth.NewOAuth2Provider(configMock, zap.NewNop())

			userInfo, err := provider.GetUserInfo(context.Background(), tt.accessToken)
			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, userInfo)
				return
			}
			assert.NoError(t, err)
			if assert.NotNil(t, userInfo) {
				assert.Equal(t, tt.expectedSub, userInfo.Sub)
				assert.Equal(t, tt.expectedEmail, userInfo.Email)
			}
		})
	}
}

// TestOAuth2Provider_ValidateIDToken checks that ValidateIDToken extracts the
// subject from a well-formed JWT and rejects malformed or empty tokens.
func TestOAuth2Provider_ValidateIDToken(t *testing.T) {
	tests := []struct {
		name        string
		idToken     string
		expectError bool
		expectedSub string
	}{
		{
			name: "valid ID token",
			// This is a mock JWT token with payload: {"sub": "user123", "email": "user@example.com", "name": "Test User"}
			// NOTE(review): the signature segment is literally "invalid-signature", so this
			// case asserts that the provider accepts an ID token WITHOUT verifying its
			// signature. Confirm that is intentional; a real OIDC client must verify it.
			idToken:     "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwibmFtZSI6IlRlc3QgVXNlciJ9.invalid-signature",
			expectError: false,
			expectedSub: "user123",
		},
		{
			name:        "invalid token format",
			idToken:     "invalid.token",
			expectError: true,
		},
		{
			name:        "empty token",
			idToken:     "",
			expectError: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			provider := auth.NewOAuth2Provider(NewMockConfig(), zap.NewNop())

			authContext, err := provider.ValidateIDToken(context.Background(), tt.idToken)
			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, authContext)
				return
			}
			assert.NoError(t, err)
			if assert.NotNil(t, authContext) {
				assert.Equal(t, tt.expectedSub, authContext.UserID)
			}
		})
	}
}

// TestOAuth2Provider_RefreshAccessToken checks the refresh-token grant against a
// mock token endpoint for both a successful refresh and an invalid_grant error.
func TestOAuth2Provider_RefreshAccessToken(t *testing.T) {
	tests := []struct {
		name           string
		refreshToken   string
		clientID       string
		clientSecret   string
		mockResponse   string
		mockStatusCode int
		expectError    bool
		expectedToken  string
	}{
		{
			name:         "successful token refresh",
			refreshToken: "test-refresh-token",
			clientID:     "test-client-id",
			clientSecret: "test-client-secret",
			mockResponse: `{
				"access_token": "new-access-token",
				"token_type": "Bearer",
				"expires_in": 3600,
				"refresh_token": "new-refresh-token"
			}`,
			mockStatusCode: http.StatusOK,
			expectError:    false,
			expectedToken:  "new-access-token",
		},
		{
			name:           "invalid refresh token",
			refreshToken:   "invalid-refresh-token",
			clientID:       "test-client-id",
			clientSecret:   "test-client-secret",
			mockResponse:   `{"error": "invalid_grant"}`,
			mockStatusCode: http.StatusBadRequest,
			expectError:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, http.MethodPost, r.Method)
				assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
				w.WriteHeader(tt.mockStatusCode)
				_, _ = w.Write([]byte(tt.mockResponse))
			}))
			defer tokenServer.Close()

			discoveryServer := newOAuth2DiscoveryServer(t, tokenServer.URL, "https://example.com/userinfo")

			configMock := NewMockConfig()
			configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
			configMock.values["SSO_CLIENT_ID"] = tt.clientID
			configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret

			provider := auth.NewOAuth2Provider(configMock, zap.NewNop())

			tokenResp, err := provider.RefreshAccessToken(context.Background(), tt.refreshToken)
			if tt.expectError {
				assert.Error(t, err)
				assert.Nil(t, tokenResp)
				return
			}
			assert.NoError(t, err)
			if assert.NotNil(t, tokenResp) {
				assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
				assert.Equal(t, "Bearer", tokenResp.TokenType)
			}
		})
	}
}

// Benchmark tests for OAuth2 operations

// BenchmarkOAuth2Provider_GetDiscoveryDocument measures repeated discovery
// document fetches against a local mock server.
func BenchmarkOAuth2Provider_GetDiscoveryDocument(b *testing.B) {
	server := newOAuth2DiscoveryServer(b, "https://example.com/token", "https://example.com/userinfo")

	configMock := NewMockConfig()
	configMock.values["SSO_PROVIDER_URL"] = server.URL

	provider := auth.NewOAuth2Provider(configMock, zap.NewNop())
	ctx := context.Background()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := provider.GetDiscoveryDocument(ctx); err != nil {
			b.Fatal(err)
		}
	}
}

// BenchmarkOAuth2Provider_GenerateAuthURL measures authorization-URL generation,
// including the discovery lookup it performs, against a local mock server.
func BenchmarkOAuth2Provider_GenerateAuthURL(b *testing.B) {
	server := newOAuth2DiscoveryServer(b, "https://example.com/token", "https://example.com/userinfo")

	configMock := NewMockConfig()
	configMock.values["SSO_PROVIDER_URL"] = server.URL
	configMock.values["SSO_CLIENT_ID"] = "test-client-id"

	provider := auth.NewOAuth2Provider(configMock, zap.NewNop())
	ctx := context.Background()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		if _, err := provider.GenerateAuthURL(ctx, "test-state", "https://app.example.com/callback"); err != nil {
			b.Fatal(err)
		}
	}
}