add CAS authentication #138

Merged
hsh-michaelhoennig merged 24 commits from feature/add-cas-authentication into master 2024-12-23 12:49:46 +01:00
9 changed files with 213 additions and 53 deletions
Showing only changes of commit ee001b520c - Show all commits

View File

@ -0,0 +1,54 @@
package net.hostsharing.hsadminng.config;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import java.util.*;
public class AuthenticatedHttpServletRequestWrapper extends HttpServletRequestWrapper {
private final Map<String, String> customHeaders = new HashMap<>();
public AuthenticatedHttpServletRequestWrapper(HttpServletRequest request) {
super(request);
}
public void addHeader(final String name, final String value) {
customHeaders.put(name, value);
}
@Override
public String getHeader(final String name) {
// Check custom headers first
final var customHeaderValue = customHeaders.get(name);
if (customHeaderValue != null) {
return customHeaderValue;
}
// Fall back to the original headers
return super.getHeader(name);
}
@Override
public Enumeration<String> getHeaderNames() {
// Combine original headers and custom headers
final var headerNames = new HashSet<>(customHeaders.keySet());
final var originalHeaderNames = super.getHeaderNames();
while (originalHeaderNames.hasMoreElements()) {
headerNames.add(originalHeaderNames.nextElement());
}
return Collections.enumeration(headerNames);
}
@Override
public Enumeration<String> getHeaders(final String name) {
// Combine original headers and custom header
final var values = new HashSet<String>();
if (customHeaders.containsKey(name)) {
values.add(customHeaders.get(name));
}
final var originalValues = super.getHeaders(name);
while (originalValues.hasMoreElements()) {
values.add(originalValues.nextElement());
}
return Collections.enumeration(values);
}
}

View File

@ -9,13 +9,14 @@ import jakarta.servlet.http.HttpServletResponse;
import lombok.SneakyThrows;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.stereotype.Component;
@Component
public class CasAuthenticationFilter implements Filter {
@Autowired
private CasServiceTicketValidator ticketValidator;
private CasAuthenticator casAuthenticator;
@Override
@SneakyThrows
@ -23,13 +24,16 @@ public class CasAuthenticationFilter implements Filter {
final var httpRequest = (HttpServletRequest) request;
final var httpResponse = (HttpServletResponse) response;
final var ticket = httpRequest.getHeader("Authorization");
try {
final var currentSubject = casAuthenticator.authenticate(httpRequest);
if (!ticketValidator.validateTicket(ticket)) {
final var authenticatedRequest = new AuthenticatedHttpServletRequestWrapper(httpRequest);
authenticatedRequest.addHeader("current-subject", currentSubject);
chain.doFilter(authenticatedRequest, response);
} catch (final BadCredentialsException exc) {
// TODO.impl: should not be necessary if ResponseStatusException worked
httpResponse.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
return;
}
chain.doFilter(request, response);
}
}
}

View File

@ -0,0 +1,54 @@
package net.hostsharing.hsadminng.config;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.SneakyThrows;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import jakarta.servlet.http.HttpServletRequest;
import javax.xml.parsers.DocumentBuilderFactory;
@Service
@NoArgsConstructor
@AllArgsConstructor
public class CasAuthenticator {
@Value("${hsadminng.cas.server-url}")
private String casServerUrl;
@Value("${hsadminng.cas.service-url}")
private String serviceUrl;
private final RestTemplate restTemplate = new RestTemplate();
@SneakyThrows
public String authenticate(final HttpServletRequest httpRequest) {
// FIXME: create FakeCasAuthenticator
if (casServerUrl.equals("fake")) {
return httpRequest.getHeader("current-subject");
}
final var ticket = httpRequest.getHeader("Authorization");
final var url = casServerUrl + "/p3/serviceValidate" +
"?service=" + serviceUrl +
"&ticket=" + ticket;
final var response = restTemplate.getForObject(url, String.class);
final var doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
.parse(new java.io.ByteArrayInputStream(response.getBytes()));
if ( doc.getElementsByTagName("cas:authenticationSuccess").getLength() == 0 ) {
// TODO.impl: for unknown reasons, this results in a 403 FORBIDDEN
// throw new ResponseStatusException(HttpStatus.UNAUTHORIZED, "CAS service ticket could not be validated");
throw new BadCredentialsException("CAS service ticket could not be validated");
}
final var authentication = new UsernamePasswordAuthenticationToken("test-user-from-authenticate", null, null); // TODO
SecurityContextHolder.getContext().setAuthentication(authentication);
return authentication.getName();
}
}

View File

@ -1,40 +0,0 @@
package net.hostsharing.hsadminng.config;
import lombok.SneakyThrows;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import javax.xml.parsers.DocumentBuilderFactory;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
@Service
public class CasServiceTicketValidator {
@Value("${hsadminng.cas.server-url}")
private String casServerUrl;
@Value("${hsadminng.cas.service-url}")
private String serviceUrl;
private final RestTemplate restTemplate = new RestTemplate();
@SneakyThrows
public boolean validateTicket(final String ticket) {
if (casServerUrl.equals("fake")) {
return true;
}
final var url = casServerUrl + "/p3/serviceValidate" +
"?service=" + URLEncoder.encode(serviceUrl, StandardCharsets.UTF_8) +
"&ticket=" + URLEncoder.encode(ticket, StandardCharsets.UTF_8);
final var response = restTemplate.getForObject(url, String.class);
final var doc = DocumentBuilderFactory.newInstance().newDocumentBuilder()
.parse(new java.io.ByteArrayInputStream(response.getBytes()));
return doc.getElementsByTagName("cas:authenticationSuccess").getLength() > 0;
}
}

View File

@ -1,16 +1,22 @@
package net.hostsharing.hsadminng.ping;
import org.springframework.stereotype.Controller;
import org.springframework.web.bind.annotation.RequestHeader;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.ResponseBody;
import jakarta.validation.constraints.NotNull;
@Controller
public class PingController {
@ResponseBody
@RequestMapping(value = "/api/ping", method = RequestMethod.GET)
public String ping() {
return "pong\n";
public String ping(
@RequestHeader(name = "current-subject") @NotNull String currentSubject,
@RequestHeader(name = "assumed-roles", required = false) String assumedRoles
) {
return "pong " + currentSubject + "\n";
}
}

View File

@ -1,18 +1,22 @@
package net.hostsharing.hsadminng.config;
import com.github.tomakehurst.wiremock.WireMockServer;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.client.TestRestTemplate;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.test.context.TestPropertySource;
import static org.assertj.core.api.Assertions.assertThat;
import static com.github.tomakehurst.wiremock.client.WireMock.*;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@TestPropertySource(properties = {"server.port=0"})
@TestPropertySource(properties = {"server.port=0", "hsadminng.cas.server-url=http://localhost:8088/cas"})
// IMPORTANT: To test prod config, do not use test profile!
class CasAuthenticationFilterIntegrationTest {
@ -22,10 +26,63 @@ class CasAuthenticationFilterIntegrationTest {
@Autowired
private TestRestTemplate restTemplate;
@Autowired
private WireMockServer wireMockServer;
@Test
public void shouldAcceptRequest() {
// given
wireMockServer.stubFor(get(urlEqualTo("/cas/p3/serviceValidate?service=http://localhost:8080/api&ticket=valid"))
.willReturn(aResponse()
.withStatus(200)
.withBody("""
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationSuccess>
<cas:user>test-user</cas:user>
</cas:authenticationSuccess>
</cas:serviceResponse>
""")));
// when
final var result = restTemplate.exchange(
"http://localhost:" + this.serverPort + "/api/ping",
HttpMethod.GET,
new HttpEntity<>(null, headers("Authorization", "valid")),
String.class
);
// then
assertThat(result.getStatusCode()).isEqualTo(HttpStatus.OK);
assertThat(result.getBody()).isEqualTo("pong test-user-from-authenticate\n");
}
@Test
public void shouldRejectRequest() {
final var result = this.restTemplate.getForEntity(
"http://localhost:" + this.serverPort + "/api/ping", String.class);
// given
wireMockServer.stubFor(get(urlEqualTo("/cas/p3/serviceValidate?service=http://localhost:8080/api&ticket=invalid"))
.willReturn(aResponse()
.withStatus(200)
.withBody("""
<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
<cas:authenticationFailure code="INVALID_REQUEST"></cas:authenticationFailure>
</cas:serviceResponse>
""")));
// when
final var result = restTemplate.exchange(
"http://localhost:" + this.serverPort + "/api/ping",
HttpMethod.GET,
new HttpEntity<>(null, headers("Authorization", "invalid")),
String.class
);
// then
assertThat(result.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED);
}
private HttpHeaders headers(final String key, final String value) {
final var headers = new HttpHeaders();
headers.set(key, value);
return headers;
}
}

View File

@ -40,7 +40,7 @@ class WebSecurityConfigIntegrationTest {
final var result = restTemplate.exchange(
"http://localhost:" + this.serverPort + "/api/ping",
HttpMethod.GET,
new HttpEntity<Object>(null, headers),
new HttpEntity<>(null, headers),
String.class
);

View File

@ -1,5 +1,6 @@
package net.hostsharing.hsadminng.test;
import net.hostsharing.hsadminng.config.CasAuthenticator;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
@ -16,4 +17,9 @@ public class DisableSecurityConfig {
.csrf(AbstractHttpConfigurer::disable);
return http.build();
}
@Bean
public CasAuthenticator casServiceTicketValidator() {
return new CasAuthenticator("fake", null);
}
}

View File

@ -0,0 +1,19 @@
package net.hostsharing.hsadminng.test;
import com.github.tomakehurst.wiremock.WireMockServer;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class WireMockConfig {
private static final WireMockServer wireMockServer = new WireMockServer(8088); // Use a different port to avoid conflicts
@Bean
public WireMockServer wireMockServer() {
if (!wireMockServer.isRunning()) {
wireMockServer.start();
}
return wireMockServer;
}
}