add CAS authentication #138
@ -0,0 +1,54 @@
|
||||
package net.hostsharing.hsadminng.config;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import jakarta.servlet.http.HttpServletRequestWrapper;
|
||||
import java.util.*;
|
||||
|
||||
public class AuthenticatedHttpServletRequestWrapper extends HttpServletRequestWrapper {
|
||||
|
||||
private final Map<String, String> customHeaders = new HashMap<>();
|
||||
|
||||
public AuthenticatedHttpServletRequestWrapper(HttpServletRequest request) {
|
||||
super(request);
|
||||
}
|
||||
|
||||
public void addHeader(final String name, final String value) {
|
||||
customHeaders.put(name, value);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getHeader(final String name) {
|
||||
// Check custom headers first
|
||||
final var customHeaderValue = customHeaders.get(name);
|
||||
if (customHeaderValue != null) {
|
||||
return customHeaderValue;
|
||||
}
|
||||
// Fall back to the original headers
|
||||
return super.getHeader(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaderNames() {
|
||||
// Combine original headers and custom headers
|
||||
final var headerNames = new HashSet<>(customHeaders.keySet());
|
||||
final var originalHeaderNames = super.getHeaderNames();
|
||||
while (originalHeaderNames.hasMoreElements()) {
|
||||
headerNames.add(originalHeaderNames.nextElement());
|
||||
}
|
||||
return Collections.enumeration(headerNames);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Enumeration<String> getHeaders(final String name) {
|
||||
// Combine original headers and custom header
|
||||
final var values = new HashSet<String>();
|
||||
if (customHeaders.containsKey(name)) {
|
||||
values.add(customHeaders.get(name));
|
||||
}
|
||||
final var originalValues = super.getHeaders(name);
|
||||
while (originalValues.hasMoreElements()) {
|
||||
values.add(originalValues.nextElement());
|
||||
}
|
||||
return Collections.enumeration(values);
|
||||
}
|
||||
}
|
@ -9,13 +9,14 @@ import jakarta.servlet.http.HttpServletResponse;
|
||||
|
||||
import lombok.SneakyThrows;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.security.authentication.BadCredentialsException;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
@Component
|
||||
public class CasAuthenticationFilter implements Filter {
|
||||
|
||||
@Autowired
|
||||
private CasServiceTicketValidator ticketValidator;
|
||||
private CasAuthenticator casAuthenticator;
|
||||
|
||||
@Override
|
||||
@SneakyThrows
|
||||
@ -23,13 +24,16 @@ public class CasAuthenticationFilter implements Filter {
|
||||
final var httpRequest = (HttpServletRequest) request;
|
||||
final var httpResponse = (HttpServletResponse) response;
|
||||
|
||||
final var ticket = httpRequest.getHeader("Authorization");
|
||||
try {
|
||||
final var currentSubject = casAuthenticator.authenticate(httpRequest);
|
||||
|
||||
if (!ticketValidator.validateTicket(ticket)) {
|
||||
final var authenticatedRequest = new AuthenticatedHttpServletRequestWrapper(httpRequest);
|
||||
authenticatedRequest.addHeader("current-subject", currentSubject);
|
||||
|
||||
chain.doFilter(authenticatedRequest, response);
|
||||
} catch (final BadCredentialsException exc) {
|
||||
// TODO.impl: should not be necessary if ResponseStatusException worked
|
||||
httpResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
|
||||
return;
|
||||
}
|
||||
|
||||
chain.doFilter(request, response);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,54 @@
|
||||
package net.hostsharing.hsadminng.config;
|
||||
|
||||
import lombok.AllArgsConstructor;
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.SneakyThrows;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.security.authentication.BadCredentialsException;
|
||||
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
|
||||
@Service
|
||||
@NoArgsConstructor
|
||||
@AllArgsConstructor
|
||||
public class CasAuthenticator {
|
||||
|
||||
@Value("${hsadminng.cas.server-url}")
|
||||
private String casServerUrl;
|
||||
|
||||
@Value("${hsadminng.cas.service-url}")
|
||||
private String serviceUrl;
|
||||
|
||||
private final RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
@SneakyThrows
|
||||
public String authenticate(final HttpServletRequest httpRequest) {
|
||||
// FIXME: create FakeCasAuthenticator
|
||||
if (casServerUrl.equals("fake")) {
|
||||
return httpRequest.getHeader("current-subject");
|
||||
}
|
||||
|
||||
final var ticket = httpRequest.getHeader("Authorization");
|
||||
final var url = casServerUrl + "/p3/serviceValidate" +
|
||||
"?service=" + serviceUrl +
|
||||
"&ticket=" + ticket;
|
||||
|
||||
final var response = restTemplate.getForObject(url, String.class);
|
||||
|
||||
final var doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
|
||||
.parse(new java.io.ByteArrayInputStream(response.getBytes()));
|
||||
if ( doc.getElementsByTagName("cas:authenticationSuccess").getLength() == 0 ) {
|
||||
// TODO.impl: for unknown reasons, this results in a 403 FORBIDDEN
|
||||
// throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, "CAS service ticket could not be validated");
|
||||
throw new BadCredentialsException("CAS service ticket could not be validated");
|
||||
}
|
||||
final var authentication = new UsernamePasswordAuthenticationToken("test-user-from-authenticate", null, null); // TODO
|
||||
SecurityContextHolder.getContext().setAuthentication(authentication);
|
||||
return authentication.getName();
|
||||
}
|
||||
}
|
@ -1,40 +0,0 @@
|
||||
package net.hostsharing.hsadminng.config;
|
||||
|
||||
import lombok.SneakyThrows;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import javax.xml.parsers.DocumentBuilderFactory;
|
||||
import java.net.URLEncoder;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
|
||||
@Service
|
||||
public class CasServiceTicketValidator {
|
||||
|
||||
@Value("${hsadminng.cas.server-url}")
|
||||
private String casServerUrl;
|
||||
|
||||
@Value("${hsadminng.cas.service-url}")
|
||||
private String serviceUrl;
|
||||
|
||||
private final RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
@SneakyThrows
|
||||
public boolean validateTicket(final String ticket) {
|
||||
if (casServerUrl.equals("fake")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
final var url = casServerUrl + "/p3/serviceValidate" +
|
||||
"?service=" + URLEncoder.encode(serviceUrl, StandardCharsets.UTF_8) +
|
||||
"&ticket=" + URLEncoder.encode(ticket, StandardCharsets.UTF_8);
|
||||
|
||||
final var response = restTemplate.getForObject(url, String.class);
|
||||
|
||||
final var doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
|
||||
.parse(new java.io.ByteArrayInputStream(response.getBytes()));
|
||||
|
||||
return doc.getElementsByTagName("cas:authenticationSuccess").getLength() > 0;
|
||||
}
|
||||
}
|
@ -1,16 +1,22 @@
|
||||
package net.hostsharing.hsadminng.ping;
|
||||
|
||||
import org.springframework.stereotype.Controller;
|
||||
import org.springframework.web.bind.annotation.RequestHeader;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMethod;
|
||||
import org.springframework.web.bind.annotation.ResponseBody;
|
||||
|
||||
import jakarta.validation.constraints.NotNull;
|
||||
|
||||
@Controller
|
||||
public class PingController {
|
||||
|
||||
@ResponseBody
|
||||
@RequestMapping(value = "/api/ping", method = RequestMethod.GET)
|
||||
public String ping() {
|
||||
return "pong\n";
|
||||
public String ping(
|
||||
@RequestHeader(name = "current-subject") @NotNull String currentSubject,
|
||||
@RequestHeader(name = "assumed-roles", required = false) String assumedRoles
|
||||
) {
|
||||
return "pong " + currentSubject + "\n";
|
||||
}
|
||||
}
|
||||
|
@ -1,18 +1,22 @@
|
||||
package net.hostsharing.hsadminng.config;
|
||||
|
||||
import com.github.tomakehurst.wiremock.WireMockServer;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.web.client.TestRestTemplate;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.test.context.TestPropertySource;
|
||||
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import static com.github.tomakehurst.wiremock.client.WireMock.*;
|
||||
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
|
||||
@TestPropertySource(properties = {"server.port=0"})
|
||||
@TestPropertySource(properties = {"server.port=0", "hsadminng.cas.server-url=http://localhost:8088/cas"})
|
||||
// IMPORTANT: To test prod config, do not use test profile!
|
||||
class CasAuthenticationFilterIntegrationTest {
|
||||
|
||||
@ -22,10 +26,63 @@ class CasAuthenticationFilterIntegrationTest {
|
||||
@Autowired
|
||||
private TestRestTemplate restTemplate;
|
||||
|
||||
@Autowired
|
||||
private WireMockServer wireMockServer;
|
||||
|
||||
@Test
|
||||
public void shouldAcceptRequest() {
|
||||
// given
|
||||
wireMockServer.stubFor(get(urlEqualTo("/cas/p3/serviceValidate?service=http://localhost:8080/api&ticket=valid"))
|
||||
.willReturn(aResponse()
|
||||
.withStatus(200)
|
||||
.withBody("""
|
||||
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
|
||||
<cas:authenticationSuccess>
|
||||
<cas:user>test-user</cas:user>
|
||||
</cas:authenticationSuccess>
|
||||
</cas:serviceResponse>
|
||||
""")));
|
||||
|
||||
// when
|
||||
final var result = restTemplate.exchange(
|
||||
"http://localhost:" + this.serverPort + "/api/ping",
|
||||
HttpMethod.GET,
|
||||
new HttpEntity<>(null, headers("Authorization", "valid")),
|
||||
String.class
|
||||
);
|
||||
|
||||
// then
|
||||
assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK);
|
||||
assertThat(result.getBody()).isEqualTo("pong test-user-from-authenticate\n");
|
||||
}
|
||||
|
||||
@Test
|
||||
public void shouldRejectRequest() {
|
||||
final var result = this.restTemplate.getForEntity(
|
||||
"http://localhost:" + this.serverPort + "/api/ping", String.class);
|
||||
// given
|
||||
wireMockServer.stubFor(get(urlEqualTo("/cas/p3/serviceValidate?service=http://localhost:8080/api&ticket=invalid"))
|
||||
.willReturn(aResponse()
|
||||
.withStatus(200)
|
||||
.withBody("""
|
||||
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
|
||||
<cas:authenticationFailure code="INVALID_REQUEST"></cas:authenticationFailure>
|
||||
</cas:serviceResponse>
|
||||
""")));
|
||||
|
||||
// when
|
||||
final var result = restTemplate.exchange(
|
||||
"http://localhost:" + this.serverPort + "/api/ping",
|
||||
HttpMethod.GET,
|
||||
new HttpEntity<>(null, headers("Authorization", "invalid")),
|
||||
String.class
|
||||
);
|
||||
|
||||
// then
|
||||
assertThat(result.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED);
|
||||
}
|
||||
|
||||
private HttpHeaders headers(final String key, final String value) {
|
||||
final var headers = new HttpHeaders();
|
||||
headers.set(key, value);
|
||||
return headers;
|
||||
}
|
||||
}
|
||||
|
@ -40,7 +40,7 @@ class WebSecurityConfigIntegrationTest {
|
||||
final var result = restTemplate.exchange(
|
||||
"http://localhost:" + this.serverPort + "/api/ping",
|
||||
HttpMethod.GET,
|
||||
new HttpEntity<Object>(null, headers),
|
||||
new HttpEntity<>(null, headers),
|
||||
String.class
|
||||
);
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
package net.hostsharing.hsadminng.test;
|
||||
|
||||
import net.hostsharing.hsadminng.config.CasAuthenticator;
|
||||
import org.springframework.boot.test.context.TestConfiguration;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
@ -16,4 +17,9 @@ public class DisableSecurityConfig {
|
||||
.csrf(AbstractHttpConfigurer::disable);
|
||||
return http.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
public CasAuthenticator casServiceTicketValidator() {
|
||||
return new CasAuthenticator("fake", null);
|
||||
}
|
||||
}
|
||||
|
@ -0,0 +1,19 @@
|
||||
package net.hostsharing.hsadminng.test;
|
||||
|
||||
import com.github.tomakehurst.wiremock.WireMockServer;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
|
||||
@Configuration
|
||||
public class WireMockConfig {
|
||||
|
||||
private static final WireMockServer wireMockServer = new WireMockServer(8088); // Use a different port to avoid conflicts
|
||||
|
||||
@Bean
|
||||
public WireMockServer wireMockServer() {
|
||||
if (!wireMockServer.isRunning()) {
|
||||
wireMockServer.start();
|
||||
}
|
||||
return wireMockServer;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user