553 lines
15 KiB
Go
553 lines
15 KiB
Go
package test
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/kms/api-key-service/internal/auth"
|
|
)
|
|
|
|
func TestOAuth2Provider_GetDiscoveryDocument(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
providerURL string
|
|
mockResponse string
|
|
mockStatusCode int
|
|
expectError bool
|
|
expectedIssuer string
|
|
}{
|
|
{
|
|
name: "successful discovery",
|
|
providerURL: "https://example.com",
|
|
mockResponse: `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "https://example.com/userinfo",
|
|
"jwks_uri": "https://example.com/jwks"
|
|
}`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: false,
|
|
expectedIssuer: "https://example.com",
|
|
},
|
|
{
|
|
name: "missing provider URL",
|
|
providerURL: "",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid response status",
|
|
providerURL: "https://example.com",
|
|
mockResponse: `{"error": "not found"}`,
|
|
mockStatusCode: http.StatusNotFound,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid JSON response",
|
|
providerURL: "https://example.com",
|
|
mockResponse: `invalid json`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock server if needed
|
|
var server *httptest.Server
|
|
if tt.providerURL != "" && !tt.expectError {
|
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "/.well-known/openid_configuration", r.URL.Path)
|
|
w.WriteHeader(tt.mockStatusCode)
|
|
w.Write([]byte(tt.mockResponse))
|
|
}))
|
|
defer server.Close()
|
|
tt.providerURL = server.URL
|
|
}
|
|
|
|
// Create config mock
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = tt.providerURL
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
discovery, err := provider.GetDiscoveryDocument(ctx)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, discovery)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, discovery)
|
|
assert.Equal(t, tt.expectedIssuer, discovery.Issuer)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuth2Provider_GenerateAuthURL(t *testing.T) {
|
|
// Create mock discovery server
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer server.Close()
|
|
|
|
tests := []struct {
|
|
name string
|
|
clientID string
|
|
state string
|
|
redirectURI string
|
|
expectError bool
|
|
}{
|
|
{
|
|
name: "successful URL generation",
|
|
clientID: "test-client-id",
|
|
state: "test-state",
|
|
redirectURI: "https://app.example.com/callback",
|
|
expectError: false,
|
|
},
|
|
{
|
|
name: "missing client ID",
|
|
clientID: "",
|
|
state: "test-state",
|
|
redirectURI: "https://app.example.com/callback",
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
|
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
authURL, err := provider.GenerateAuthURL(ctx, tt.state, tt.redirectURI)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Empty(t, authURL)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotEmpty(t, authURL)
|
|
assert.Contains(t, authURL, "https://example.com/auth")
|
|
assert.Contains(t, authURL, "client_id="+tt.clientID)
|
|
assert.Contains(t, authURL, "state="+tt.state)
|
|
assert.Contains(t, authURL, "redirect_uri=")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuth2Provider_ExchangeCodeForToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
code string
|
|
redirectURI string
|
|
codeVerifier string
|
|
clientID string
|
|
clientSecret string
|
|
mockResponse string
|
|
mockStatusCode int
|
|
expectError bool
|
|
expectedToken string
|
|
}{
|
|
{
|
|
name: "successful token exchange",
|
|
code: "test-code",
|
|
redirectURI: "https://app.example.com/callback",
|
|
codeVerifier: "test-verifier",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
mockResponse: `{
|
|
"access_token": "test-access-token",
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
"refresh_token": "test-refresh-token"
|
|
}`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: false,
|
|
expectedToken: "test-access-token",
|
|
},
|
|
{
|
|
name: "missing client ID",
|
|
code: "test-code",
|
|
redirectURI: "https://app.example.com/callback",
|
|
codeVerifier: "test-verifier",
|
|
clientID: "",
|
|
clientSecret: "test-client-secret",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "token endpoint error",
|
|
code: "test-code",
|
|
redirectURI: "https://app.example.com/callback",
|
|
codeVerifier: "test-verifier",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
mockResponse: `{"error": "invalid_grant"}`,
|
|
mockStatusCode: http.StatusBadRequest,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock servers
|
|
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer discoveryServer.Close()
|
|
|
|
var tokenServer *httptest.Server
|
|
if !tt.expectError {
|
|
tokenServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "POST", r.Method)
|
|
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
|
|
|
w.WriteHeader(tt.mockStatusCode)
|
|
w.Write([]byte(tt.mockResponse))
|
|
}))
|
|
defer tokenServer.Close()
|
|
|
|
// Update discovery server to return the token server URL
|
|
discoveryServer.Close()
|
|
discoveryServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "` + tokenServer.URL + `",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
}
|
|
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
|
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
|
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
tokenResp, err := provider.ExchangeCodeForToken(ctx, tt.code, tt.redirectURI, tt.codeVerifier)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, tokenResp)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, tokenResp)
|
|
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
|
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuth2Provider_GetUserInfo(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
accessToken string
|
|
mockResponse string
|
|
mockStatusCode int
|
|
expectError bool
|
|
expectedSub string
|
|
expectedEmail string
|
|
}{
|
|
{
|
|
name: "successful user info retrieval",
|
|
accessToken: "test-access-token",
|
|
mockResponse: `{
|
|
"sub": "user123",
|
|
"email": "user@example.com",
|
|
"name": "Test User",
|
|
"email_verified": true
|
|
}`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: false,
|
|
expectedSub: "user123",
|
|
expectedEmail: "user@example.com",
|
|
},
|
|
{
|
|
name: "unauthorized access token",
|
|
accessToken: "invalid-token",
|
|
mockResponse: `{"error": "invalid_token"}`,
|
|
mockStatusCode: http.StatusUnauthorized,
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "invalid JSON response",
|
|
accessToken: "test-access-token",
|
|
mockResponse: `invalid json`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock servers
|
|
userInfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "GET", r.Method)
|
|
assert.Equal(t, "Bearer "+tt.accessToken, r.Header.Get("Authorization"))
|
|
|
|
w.WriteHeader(tt.mockStatusCode)
|
|
w.Write([]byte(tt.mockResponse))
|
|
}))
|
|
defer userInfoServer.Close()
|
|
|
|
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "` + userInfoServer.URL + `"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer discoveryServer.Close()
|
|
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
userInfo, err := provider.GetUserInfo(ctx, tt.accessToken)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, userInfo)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, userInfo)
|
|
assert.Equal(t, tt.expectedSub, userInfo.Sub)
|
|
assert.Equal(t, tt.expectedEmail, userInfo.Email)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuth2Provider_ValidateIDToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
idToken string
|
|
expectError bool
|
|
expectedSub string
|
|
}{
|
|
{
|
|
name: "valid ID token",
|
|
// This is a mock JWT token with payload: {"sub": "user123", "email": "user@example.com", "name": "Test User"}
|
|
idToken: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJ1c2VyMTIzIiwiZW1haWwiOiJ1c2VyQGV4YW1wbGUuY29tIiwibmFtZSI6IlRlc3QgVXNlciJ9.invalid-signature",
|
|
expectError: false,
|
|
expectedSub: "user123",
|
|
},
|
|
{
|
|
name: "invalid token format",
|
|
idToken: "invalid.token",
|
|
expectError: true,
|
|
},
|
|
{
|
|
name: "empty token",
|
|
idToken: "",
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
configMock := NewMockConfig()
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
authContext, err := provider.ValidateIDToken(ctx, tt.idToken)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, authContext)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, authContext)
|
|
assert.Equal(t, tt.expectedSub, authContext.UserID)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOAuth2Provider_RefreshAccessToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
refreshToken string
|
|
clientID string
|
|
clientSecret string
|
|
mockResponse string
|
|
mockStatusCode int
|
|
expectError bool
|
|
expectedToken string
|
|
}{
|
|
{
|
|
name: "successful token refresh",
|
|
refreshToken: "test-refresh-token",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
mockResponse: `{
|
|
"access_token": "new-access-token",
|
|
"token_type": "Bearer",
|
|
"expires_in": 3600,
|
|
"refresh_token": "new-refresh-token"
|
|
}`,
|
|
mockStatusCode: http.StatusOK,
|
|
expectError: false,
|
|
expectedToken: "new-access-token",
|
|
},
|
|
{
|
|
name: "invalid refresh token",
|
|
refreshToken: "invalid-refresh-token",
|
|
clientID: "test-client-id",
|
|
clientSecret: "test-client-secret",
|
|
mockResponse: `{"error": "invalid_grant"}`,
|
|
mockStatusCode: http.StatusBadRequest,
|
|
expectError: true,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
// Create mock servers
|
|
tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, "POST", r.Method)
|
|
assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
|
|
|
w.WriteHeader(tt.mockStatusCode)
|
|
w.Write([]byte(tt.mockResponse))
|
|
}))
|
|
defer tokenServer.Close()
|
|
|
|
discoveryServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "` + tokenServer.URL + `",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer discoveryServer.Close()
|
|
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = discoveryServer.URL
|
|
configMock.values["SSO_CLIENT_ID"] = tt.clientID
|
|
configMock.values["SSO_CLIENT_SECRET"] = tt.clientSecret
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
|
|
ctx := context.Background()
|
|
tokenResp, err := provider.RefreshAccessToken(ctx, tt.refreshToken)
|
|
|
|
if tt.expectError {
|
|
assert.Error(t, err)
|
|
assert.Nil(t, tokenResp)
|
|
} else {
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, tokenResp)
|
|
assert.Equal(t, tt.expectedToken, tokenResp.AccessToken)
|
|
assert.Equal(t, "Bearer", tokenResp.TokenType)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// Benchmark tests for OAuth2 operations
|
|
func BenchmarkOAuth2Provider_GetDiscoveryDocument(b *testing.B) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer server.Close()
|
|
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
ctx := context.Background()
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := provider.GetDiscoveryDocument(ctx)
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func BenchmarkOAuth2Provider_GenerateAuthURL(b *testing.B) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
response := `{
|
|
"issuer": "https://example.com",
|
|
"authorization_endpoint": "https://example.com/auth",
|
|
"token_endpoint": "https://example.com/token",
|
|
"userinfo_endpoint": "https://example.com/userinfo"
|
|
}`
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.Write([]byte(response))
|
|
}))
|
|
defer server.Close()
|
|
|
|
configMock := NewMockConfig()
|
|
configMock.values["SSO_PROVIDER_URL"] = server.URL
|
|
configMock.values["SSO_CLIENT_ID"] = "test-client-id"
|
|
|
|
logger := zap.NewNop()
|
|
provider := auth.NewOAuth2Provider(configMock, logger)
|
|
ctx := context.Background()
|
|
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
_, err := provider.GenerateAuthURL(ctx, "test-state", "https://app.example.com/callback")
|
|
if err != nil {
|
|
b.Fatal(err)
|
|
}
|
|
}
|
|
}
|