Verifying a JSON Web Token from Amazon Cognito in Go and Gin

NOTE: This post is intentionally structured like Verifying a JSON Web Token from Amazon Cognito in Python and FastAPI, and shows the same methodology using Go-related technologies.

Intro

One popular option for integrating Amazon Cognito authentication/authorization with a backend requires decoding and verifying JSON Web Tokens (JWTs) on the server side. AWS details the steps required to validate an incoming JWT produced by Amazon Cognito in its official documentation, and offers an example using a dedicated JavaScript library.

We can implement the above steps in Go with the Gin web framework, using a combination of two modern and well-maintained Go libraries that interoperate well:

  • keyfunc for consuming a JWKS and exposing it as a jwt.Keyfunc, the key lookup function that golang-jwt expects.

  • golang-jwt for parsing and verifying JSON Web Tokens.

Retrieving a public JWKS

The JWKS (JSON Web Key Set) of a User Pool contains the public keys that correspond to the private keys used to sign your users' tokens. As soon as a Cognito User Pool is created, it publishes its JWKS at a public URI, composed as follows: https://cognito-idp.<region>.amazonaws.com/<userPoolId>/.well-known/jwks.json where region is the AWS region of your User Pool and userPoolId is the ID of the User Pool.
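
As a small illustration of the URL layout described above, the sketch below builds both the issuer URL and the JWKS URL from a region and a pool ID. The helper names cognitoIssuer and cognitoJWKSURL, and the pool ID us-east-1_EXAMPLE, are hypothetical and only used for this example.

```go
package main

import "fmt"

// cognitoIssuer returns the base URL of a User Pool. Cognito also uses this
// value as the "iss" claim of the tokens it issues.
func cognitoIssuer(region, userPoolID string) string {
	return fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", region, userPoolID)
}

// cognitoJWKSURL appends the well-known JWKS path to the issuer URL.
func cognitoJWKSURL(region, userPoolID string) string {
	return cognitoIssuer(region, userPoolID) + "/.well-known/jwks.json"
}

func main() {
	// "us-east-1_EXAMPLE" is a placeholder, not a real User Pool ID.
	fmt.Println(cognitoIssuer("us-east-1", "us-east-1_EXAMPLE"))
	fmt.Println(cognitoJWKSURL("us-east-1", "us-east-1_EXAMPLE"))
}
```

Deriving both values from one helper keeps the JWKS location and the expected issuer in sync.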

The keyfunc Go library has a convenient function, keyfunc.Get, for retrieving that key set with an HTTP request and parsing it, all in a single statement:

import (
    "fmt"

    "github.com/MicahParks/keyfunc/v2"
)


func GetJWKS(awsRegion string, cognitoUserPoolId string) (*keyfunc.JWKS, error) {

    jwksURL := fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", awsRegion, cognitoUserPoolId)

    jwks, err := keyfunc.Get(jwksURL, keyfunc.Options{})
    if err != nil {
        return nil, err
    }
    return jwks, nil
}

Verifying an incoming JWT in a Gin middleware

1. Defining a Gin authentication middleware

We can define a middleware in a file middlewares/cognito.go, with the following function signature, passing the token use (requiredTokenUse which can be access or id depending on the type of Cognito token we accept), AWS Region (awsDefaultRegion), the Cognito User Pool ID (cognitoUserPoolId), the Cognito App Client ID that we've created (cognitoAppClientId) and the JWKS we retrieved earlier (jwks).

func CognitoAuthMiddleware(requiredTokenUse string,
    awsDefaultRegion string,
    cognitoUserPoolId string,
    cognitoAppClientId string,
    jwks *keyfunc.JWKS) gin.HandlerFunc {
    return func(c *gin.Context) {
        // The verification steps described below go here
    }
}

2. Retrieving the JWT from an incoming request

By convention, an incoming request contains a JWT in its Authorization header using a Bearer token.

The Authorization header including the Bearer token has the format:

Authorization: Bearer abcdefg

Using the Gin context passed to our middleware (variable c), we can retrieve the header, and get the raw value of the JWT by splitting the string containing it.

// Retrieve JWT from the "Authorization" header
authHeader := c.GetHeader("Authorization")
splitToken := strings.Split(authHeader, "Bearer ")

if len(splitToken) != 2 {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}

tokenString := splitToken[1]

if tokenString == "" {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
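
Note that splitting on "Bearer " also accepts headers with extra text before the scheme, such as "Basic Bearer abcdefg". A stricter alternative is sketched below. The bearerToken helper is a hypothetical name, and matching the scheme case-insensitively is a design choice based on HTTP authentication schemes being case-insensitive.

```go
package main

import (
	"fmt"
	"strings"
)

// bearerToken extracts the token from an Authorization header value.
// It accepts exactly two whitespace-separated fields: the "Bearer" scheme
// (matched case-insensitively) followed by the token.
func bearerToken(header string) (string, bool) {
	parts := strings.Fields(header)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return "", false
	}
	return parts[1], true
}

func main() {
	headers := []string{"Bearer abcdefg", "Basic Bearer abcdefg", "Bearer ", ""}
	for _, h := range headers {
		token, ok := bearerToken(h)
		fmt.Printf("%q -> token=%q ok=%v\n", h, token, ok)
	}
}
```

Inside the middleware, a false second return value would lead to the same 401 response as in the snippet above.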

3. Verify the JWT signature, signing algorithm, issuer (iss) and existence of expiry time (exp)

Now, using golang-jwt we can perform some first rudimentary checks against the JWT using some convenient methods offered by the library.

  • The signature of the JWT is verified using the jwt.Parse function by providing the JWKS we retrieved earlier (jwks.Keyfunc), along with some extra checks we define below.

  • One of the most important verifications is restricting the signing algorithms that the parser will accept, and confirming that the incoming token uses one of those exclusively. Amazon Cognito signs its tokens with the asymmetric algorithm RS256. This can be enforced using jwt.WithValidMethods. Omitting this option leaves us open to algorithm confusion attacks.

  • The issuer claim (iss) should match our User Pool. For example, a User Pool created in the us-east-1 region will have the following iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>. We can verify this using the jwt.WithIssuer function.

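  • We can also require that the exp claim defining the expiration time is present in the token (using jwt.WithExpirationRequired). The option itself only enforces presence. When the claim is present, golang-jwt v5 also rejects a token whose exp is in the past during parsing, so the explicit comparison in step 5 acts as a second, defensive check.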

The above checks can be succinctly defined in one statement, facilitated by golang-jwt, as follows:

token, err := jwt.Parse(tokenString,
    jwks.Keyfunc,
    jwt.WithValidMethods([]string{"RS256"}),
    jwt.WithExpirationRequired(),
    jwt.WithIssuer(fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", awsDefaultRegion, cognitoUserPoolId)))
if err != nil || !token.Valid {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}

4. Parse JWT claims

If the above checks succeed, we can access the token's claims with a type assertion to jwt.MapClaims, the map type that jwt.Parse uses for claims by default.

claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
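
Since jwt.MapClaims is a map of string keys to untyped values, each claim needs its own type assertion later on. The sketch below shows which Go types appear when a payload is decoded with encoding/json into such a map, assuming default decoding settings. The decodeClaims helper and the payload values are made up for this example.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// decodeClaims unmarshals a JWT payload into a generic map, the same shape
// as jwt.MapClaims (map[string]interface{}).
func decodeClaims(payload []byte) (map[string]interface{}, error) {
	var claims map[string]interface{}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return nil, err
	}
	return claims, nil
}

func main() {
	// Hypothetical payload, loosely modeled on a Cognito access token.
	payload := []byte(`{"sub":"example-user","exp":1700000000,"token_use":"access","cognito:groups":["admins"]}`)
	claims, err := decodeClaims(payload)
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	// JSON strings become string, numbers float64, arrays []interface{}.
	for _, name := range []string{"sub", "exp", "token_use", "cognito:groups"} {
		fmt.Printf("%s has Go type %T\n", name, claims[name])
	}
}
```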

5. Compare the exp claim to the current time

Now we can start making additional checks against the claim values found in the token, starting with its expiry time (exp claim), which we read with the claims.GetExpirationTime method.

expClaim, err := claims.GetExpirationTime()
if err != nil {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
if expClaim.Unix() < time.Now().Unix() {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}

6. Check the token_use claim

Cognito can send ID or access tokens, each with a different set of attributes. Depending on the nature of the endpoint we want to protect we can choose to accept specific types.

ID tokens contain claims about the identity of the authenticated user, such as name, email, and phone_number.

Access tokens contain claims about the authorization of the authenticated user such as a list of the user's groups, and a list of scopes.

tokenUseClaim, ok := claims["token_use"].(string)
if !ok {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
if tokenUseClaim != requiredTokenUse {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
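
The middleware above accepts exactly one token type per endpoint. If an endpoint should accept both ID and access tokens, the comparison could be generalized to a list of allowed values, as in this sketch. The tokenUseAllowed helper is hypothetical and not part of the middleware in this post.

```go
package main

import "fmt"

// tokenUseAllowed reports whether the token_use claim matches one of the
// values the endpoint is willing to accept.
func tokenUseAllowed(claim string, allowed ...string) bool {
	for _, a := range allowed {
		if claim == a {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(tokenUseAllowed("id", "id", "access"))
	fmt.Println(tokenUseAllowed("access", "id"))
	fmt.Println(tokenUseAllowed("refresh", "id", "access"))
}
```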

7. Verify the existence of the subject (sub) claim

The subject (sub) claim is mandatory and contains the ID of the Cognito user. We can also set its value in the Gin context, for further processing in the request.

subClaim, err := claims.GetSubject()
// GetSubject returns an empty string, not an error, when "sub" is missing
if err != nil || subClaim == "" {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
c.Set("username", subClaim)

8. Verify the audience (aud)/client ID (client_id) claim

Depending on the type of token, we check a different claim against the Cognito app client ID created in the Cognito User Pool: the aud claim for an ID token, or the client_id claim for an access token.

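The aud claim can be retrieved using claims.GetAudience, while the client_id claim is specific to Cognito and has to be read directly from the claims map. Because GetAudience returns an empty list without an error when aud is absent, we also check its length before indexing.

var appClientIdClaim string
if tokenUseClaim == "id" {
    audienceClaims, err := claims.GetAudience()
    if err != nil || len(audienceClaims) == 0 {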
        c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
        c.Abort()
        return
    }
    appClientIdClaim = audienceClaims[0]

} else if tokenUseClaim == "access" {
    clientIdClaim, ok := claims["client_id"].(string)
    if !ok {
        c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
        c.Abort()
        return
    }
    appClientIdClaim = clientIdClaim
} else {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
if appClientIdClaim != cognitoAppClientId {
    c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
    c.Abort()
    return
}
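
The same selection logic can be written against a plain claims map, which makes it easy to unit test without Gin. In this sketch the appClientID helper is a hypothetical name. It accepts aud either as a single string or as a one-element array, since the JWT specification allows both forms.

```go
package main

import "fmt"

// appClientID picks the claim that identifies the app client, depending on
// the token type: "aud" for ID tokens, "client_id" for access tokens.
// The boolean is false when the claim is missing, empty or of the wrong type.
func appClientID(claims map[string]interface{}) (string, bool) {
	tokenUse, _ := claims["token_use"].(string)
	switch tokenUse {
	case "id":
		switch aud := claims["aud"].(type) {
		case string:
			return aud, aud != ""
		case []interface{}:
			if len(aud) == 1 {
				if s, ok := aud[0].(string); ok {
					return s, s != ""
				}
			}
		}
		return "", false
	case "access":
		id, ok := claims["client_id"].(string)
		return id, ok && id != ""
	}
	return "", false
}

func main() {
	// "example-client-id" is a placeholder value.
	idToken := map[string]interface{}{"token_use": "id", "aud": "example-client-id"}
	accessToken := map[string]interface{}{"token_use": "access", "client_id": "example-client-id"}
	missing := map[string]interface{}{"token_use": "id"}

	fmt.Println(appClientID(idToken))
	fmt.Println(appClientID(accessToken))
	fmt.Println(appClientID(missing))
}
```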

9. Optionally retrieve any Cognito user groups that the user belongs to

We can optionally parse any Cognito groups that the user belongs to, and set them in the Gin context for further usage within the request. These exist in a cognito:groups claim in the JWT.

userGroupsAttribute, ok := claims["cognito:groups"]
userGroupsClaims := make([]string, 0)
if ok {
    switch x := userGroupsAttribute.(type) {
    case []interface{}:
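        for _, e := range x {
            // Use a checked assertion so a non-string entry cannot cause a panic
            group, isString := e.(string)
            if !isString {
                c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
                c.Abort()
                return
            }
            userGroupsClaims = append(userGroupsClaims, group)
        }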
    default:
        c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
        c.Abort()
        return
    }
}

c.Set("groups", userGroupsClaims)
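
The conversion from the untyped claim value to a slice of strings can also be extracted into a standalone function. The sketch below assumes the claim was decoded from JSON into []interface{}. The groupsFromClaim helper is a hypothetical name, and it treats an absent claim as "no groups".

```go
package main

import (
	"errors"
	"fmt"
)

// groupsFromClaim converts the raw "cognito:groups" claim value into a slice
// of strings. A nil value (claim absent) yields an empty slice; any other
// shape than a list of strings is reported as an error.
func groupsFromClaim(v interface{}) ([]string, error) {
	if v == nil {
		return []string{}, nil
	}
	raw, ok := v.([]interface{})
	if !ok {
		return nil, errors.New("cognito:groups claim is not a list")
	}
	groups := make([]string, 0, len(raw))
	for _, e := range raw {
		s, ok := e.(string)
		if !ok {
			return nil, fmt.Errorf("group entry has type %T, want string", e)
		}
		groups = append(groups, s)
	}
	return groups, nil
}

func main() {
	// "admins" and "editors" are example group names.
	fmt.Println(groupsFromClaim([]interface{}{"admins", "editors"}))
	fmt.Println(groupsFromClaim(nil))
	fmt.Println(groupsFromClaim([]interface{}{"admins", 42}))
}
```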

10. Finally, proceed to the actual request

c.Next()

Complete middleware snippet

The complete snippet using all the above statements is as follows:

package middlewares

import (
	"fmt"
	"net/http"
	"strings"
	"time"

	"github.com/MicahParks/keyfunc/v2"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
)

// Gin middleware for verifying an incoming Cognito JWT, embedded in the request headers
// https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-tokens-verifying-a-jwt.html
func CognitoAuthMiddleware(requiredTokenUse string,
	awsDefaultRegion string,
	cognitoUserPoolId string,
	cognitoAppClientId string,
	jwks *keyfunc.JWKS) gin.HandlerFunc {
	return func(c *gin.Context) {

		// Retrieve JWT from the "Authorization" header
		authHeader := c.GetHeader("Authorization")
		splitToken := strings.Split(authHeader, "Bearer ")

		if len(splitToken) != 2 {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		tokenString := splitToken[1]

		if tokenString == "" {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// * Verify the signature of the JWT
		// * Verify that the algorithm used is RS256
		// * Verify that the 'exp' claim exists in the token
		// * Verification of the audience ('aud') is taken care of later, when we examine
		//   whether the token is 'id' or 'access'
		// * The issuer (iss) claim should match your user pool. For example, a user
		//   pool created in the us-east-1 region
		//   will have the following iss value: https://cognito-idp.us-east-1.amazonaws.com/<userpoolID>.
		token, err := jwt.Parse(tokenString,
			jwks.Keyfunc,
			jwt.WithValidMethods([]string{"RS256"}),
			jwt.WithExpirationRequired(),
			jwt.WithIssuer(fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s", awsDefaultRegion, cognitoUserPoolId)))
		if err != nil || !token.Valid {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// Attempt to parse the JWT claims
		claims, ok := token.Claims.(jwt.MapClaims)
		if !ok {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// Compare the "exp" claim to the current time
		expClaim, err := claims.GetExpirationTime()
		if err != nil {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}
		if expClaim.Unix() < time.Now().Unix() {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// Check the token_use claim.
		// If you are only accepting the access token in your web API operations, its value must be access.
		// If you are only using the ID token, its value must be id.
		// If you are using both ID and access tokens, the token_use claim must be either id or access.
		tokenUseClaim, ok := claims["token_use"].(string)
		if !ok {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}
		if tokenUseClaim != requiredTokenUse {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// "sub" claim exists in both ID and Access tokens
		// GetSubject returns an empty string, not an error, when "sub" is missing
		subClaim, err := claims.GetSubject()
		if err != nil || subClaim == "" {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}
		c.Set("username", subClaim)

		// The "aud" claim in an ID token and the "client_id" claim in an access token should match the app
		// client ID that was created in the Amazon Cognito user pool.
		var appClientIdClaim string
		if tokenUseClaim == "id" {
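			audienceClaims, err := claims.GetAudience()
			// An absent "aud" yields an empty list and no error, so guard the index below
			if err != nil || len(audienceClaims) == 0 {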
				c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
				c.Abort()
				return
			}
			appClientIdClaim = audienceClaims[0]

		} else if tokenUseClaim == "access" {
			clientIdClaim, ok := claims["client_id"].(string)
			if !ok {
				c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
				c.Abort()
				return
			}
			appClientIdClaim = clientIdClaim
		} else {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}
		if appClientIdClaim != cognitoAppClientId {
			c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
			c.Abort()
			return
		}

		// Retrieve any Cognito user groups that the user belongs to
		userGroupsAttribute, ok := claims["cognito:groups"]
		userGroupsClaims := make([]string, 0)
		if ok {
			switch x := userGroupsAttribute.(type) {
			case []interface{}:
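				for _, e := range x {
					// Use a checked assertion so a non-string entry cannot cause a panic
					group, isString := e.(string)
					if !isString {
						c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
						c.Abort()
						return
					}
					userGroupsClaims = append(userGroupsClaims, group)
				}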
			default:
				c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
				c.Abort()
				return
			}
		}

		c.Set("groups", userGroupsClaims)

		c.Next()
	}
}

Using the middleware in a Gin endpoint

Putting it all together, we can create Gin endpoints that are protected using the above middleware.

The code below shows how to define a Go main function containing two protected Gin endpoints, one requiring a Cognito ID token (/protected-with-id-token) and one requiring a Cognito access token (/protected-with-access-token).

package main

import (
	"fmt"
	"log"
	"net/http"
	"os"

	"github.com/MicahParks/keyfunc/v2"
	"github.com/angelospanag/go-gin-cognito-jwt-verification/middlewares"
	"github.com/gin-gonic/gin"
	"github.com/joho/godotenv"
)

func GetJWKS(awsRegion string, cognitoUserPoolId string) (*keyfunc.JWKS, error) {

	jwksURL := fmt.Sprintf("https://cognito-idp.%s.amazonaws.com/%s/.well-known/jwks.json", awsRegion, cognitoUserPoolId)

	jwks, err := keyfunc.Get(jwksURL, keyfunc.Options{})
	if err != nil {
		return nil, err
	}
	return jwks, nil
}

func main() {
	err := godotenv.Load()
	if err != nil {
		log.Fatal("Error loading .env file")
	}
	awsDefaultRegion := os.Getenv("AWS_DEFAULT_REGION")
	cognitoUserPoolId := os.Getenv("COGNITO_USER_POOL_ID")
	cognitoAppClientId := os.Getenv("COGNITO_APP_CLIENT_ID")

	jwks, err := GetJWKS(awsDefaultRegion, cognitoUserPoolId)

	if err != nil {
		log.Fatalf("Failed to retrieve Cognito JWKS\nError: %s", err)
	}

	router := gin.Default()

	router.GET("/healthcheck", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	router.GET("/protected-with-id-token", middlewares.CognitoAuthMiddleware(
		"id",
		awsDefaultRegion,
		cognitoUserPoolId,
		cognitoAppClientId,
		jwks), func(c *gin.Context) {
		username, _ := c.Get("username")
		c.JSON(http.StatusOK, gin.H{"username": username})
	})

	router.GET("/protected-with-access-token", middlewares.CognitoAuthMiddleware(
		"access",
		awsDefaultRegion,
		cognitoUserPoolId,
		cognitoAppClientId,
		jwks), func(c *gin.Context) {
		username, _ := c.Get("username")
		c.JSON(http.StatusOK, gin.H{"username": username})
	})
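	// router.Run blocks while serving and returns an error if the server cannot start
	if err := router.Run(); err != nil {
		log.Fatalf("Failed to start server\nError: %s", err)
	}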
}

Full implementation

You can find the fully working implementation in a GitHub repository.