blob: 1624235c1f4d93dea6ee279b73ca766cc404fe53 [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
////////////////////////////////////////////////////////////////////////////////
package fakeawskms
import (
"bytes"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/kms"
)
// Key ARNs shared by the tests in this file.
const (
	// validKeyID is the key ARN that every test passes to New.
	validKeyID = "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab"
	// validKeyID2 is a second key ARN; tests either pass it to New alongside
	// validKeyID or use it as a key ID that was not passed to New.
	validKeyID2 = "arn:aws:kms:us-west-2:123:key/different"
)
// TestEncyptDecryptWithValidKeyId encrypts a plaintext with a fake KMS that
// was constructed with validKeyID and checks that decrypting the resulting
// ciphertext with the same key ID and encryption context returns the original
// plaintext and key ID. It then checks that decrypting the same ciphertext
// with a different encryption context value returns an error.
func TestEncyptDecryptWithValidKeyId(t *testing.T) {
	fakeKMS, err := New([]string{validKeyID})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}

	want := []byte("plaintext")
	ctxVal := "contextValue"
	encCtx := map[string]*string{"contextName": &ctxVal}

	encResponse, err := fakeKMS.Encrypt(&kms.EncryptInput{
		KeyId:             aws.String(validKeyID),
		Plaintext:         want,
		EncryptionContext: encCtx,
	})
	if err != nil {
		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
	}
	blob := encResponse.CiphertextBlob

	// Round trip: same key ID and same encryption context.
	decResponse, err := fakeKMS.Decrypt(&kms.DecryptInput{
		KeyId:             aws.String(validKeyID),
		CiphertextBlob:    blob,
		EncryptionContext: encCtx,
	})
	if err != nil {
		t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
	}
	if got := decResponse.Plaintext; !bytes.Equal(got, want) {
		t.Fatalf("decResponse.Plaintext = %q, want %q", got, want)
	}
	if gotKeyID := *decResponse.KeyId; strings.Compare(gotKeyID, validKeyID) != 0 {
		t.Fatalf("decResponse.KeyId = %q, want %q", gotKeyID, validKeyID)
	}

	// Decrypting with a different context value must fail.
	otherVal := "otherContextValue"
	_, err = fakeKMS.Decrypt(&kms.DecryptInput{
		KeyId:             aws.String(validKeyID),
		CiphertextBlob:    blob,
		EncryptionContext: map[string]*string{"contextName": &otherVal},
	})
	if err == nil {
		t.Fatal("fakeKMS.Decrypt(otherDecRequest) err = nil, want not nil")
	}
}
// TestEncyptWithUnknownKeyID checks that Encrypt returns an error when the
// request names a key ID (validKeyID2) that was not passed to New.
func TestEncyptWithUnknownKeyID(t *testing.T) {
	fakeKMS, err := New([]string{validKeyID})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}
	plaintext := []byte("plaintext")
	contextValue := "contextValue"
	context := map[string]*string{"contextName": &contextValue}
	encRequestWithUnknownKeyID := &kms.EncryptInput{
		KeyId:             aws.String(validKeyID2),
		Plaintext:         plaintext,
		EncryptionContext: context,
	}
	if _, err := fakeKMS.Encrypt(encRequestWithUnknownKeyID); err == nil {
		// The message names the request variable actually passed to Encrypt.
		t.Fatal("fakeKMS.Encrypt(encRequestWithUnknownKeyID) err = nil, want not nil")
	}
}
// TestDecryptWithInvalidCiphertext checks that Decrypt returns an error when
// given a CiphertextBlob that did not come from Encrypt. The request leaves
// KeyId unset.
func TestDecryptWithInvalidCiphertext(t *testing.T) {
	fakeKMS, err := New([]string{validKeyID})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}

	ctxVal := "contextValue"
	req := &kms.DecryptInput{
		// Arbitrary bytes, not the output of Encrypt.
		CiphertextBlob:    []byte("plaintext"),
		EncryptionContext: map[string]*string{"contextName": &ctxVal},
	}

	_, err = fakeKMS.Decrypt(req)
	if err == nil {
		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
	}
}
// TestDecryptWithUnknownKeyId checks that Decrypt returns an error when the
// request names a key ID (validKeyID2) that was not passed to New.
//
// NOTE(review): the ciphertext used here is an arbitrary byte string rather
// than the output of Encrypt, so the error could come from the ciphertext
// instead of the key ID check — confirm against the fake's implementation.
func TestDecryptWithUnknownKeyId(t *testing.T) {
	fakeKMS, err := New([]string{validKeyID})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}

	ctxVal := "contextValue"
	req := &kms.DecryptInput{
		KeyId:             aws.String(validKeyID2),
		CiphertextBlob:    []byte("invalidCiphertext"),
		EncryptionContext: map[string]*string{"contextName": &ctxVal},
	}

	_, err = fakeKMS.Decrypt(req)
	if err == nil {
		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
	}
}
// TestDecryptWithWrongKeyId constructs a fake KMS with both validKeyID and
// validKeyID2, encrypts under validKeyID, and checks that a Decrypt request
// that names validKeyID2 for that ciphertext returns an error.
func TestDecryptWithWrongKeyId(t *testing.T) {
	fakeKMS, err := New([]string{validKeyID, validKeyID2})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}

	ctxVal := "contextValue"
	encCtx := map[string]*string{"contextName": &ctxVal}

	encResponse, err := fakeKMS.Encrypt(&kms.EncryptInput{
		KeyId:             aws.String(validKeyID),
		Plaintext:         []byte("plaintext"),
		EncryptionContext: encCtx,
	})
	if err != nil {
		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
	}

	_, err = fakeKMS.Decrypt(&kms.DecryptInput{
		KeyId:             aws.String(validKeyID2), // wrong key id
		CiphertextBlob:    encResponse.CiphertextBlob,
		EncryptionContext: encCtx,
	})
	if err == nil {
		t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil")
	}
}
// TestDecryptWithoutKeyId checks that Decrypt succeeds when DecryptInput.KeyId
// is left unset. It constructs a fake KMS with validKeyID and validKeyID2,
// encrypts a different plaintext under each key, and verifies that decrypting
// each ciphertext without a KeyId returns the matching plaintext together with
// the ID of the key that was used for encryption.
func TestDecryptWithoutKeyId(t *testing.T) {
	// Setting the KeyId in DecryptInput is not required, see
	// https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/kms#DecryptInput
	fakeKMS, err := New([]string{validKeyID, validKeyID2})
	if err != nil {
		t.Fatalf("New() err = %s, want nil", err)
	}
	plaintext := []byte("plaintext")
	plaintext2 := []byte("plaintext2")
	contextValue := "contextValue"
	context := map[string]*string{"contextName": &contextValue}

	// Encrypt one plaintext under each of the two keys.
	encRequest := &kms.EncryptInput{
		KeyId:             aws.String(validKeyID),
		Plaintext:         plaintext,
		EncryptionContext: context,
	}
	encResponse, err := fakeKMS.Encrypt(encRequest)
	if err != nil {
		t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err)
	}
	if *encResponse.KeyId != validKeyID {
		t.Fatalf("encResponse.KeyId = %q, want %q", *encResponse.KeyId, validKeyID)
	}
	encRequest2 := &kms.EncryptInput{
		KeyId:             aws.String(validKeyID2),
		Plaintext:         plaintext2,
		EncryptionContext: context,
	}
	encResponse2, err := fakeKMS.Encrypt(encRequest2)
	if err != nil {
		t.Fatalf("fakeKMS.Encrypt(encRequest2) err = %s, want nil", err)
	}
	if *encResponse2.KeyId != validKeyID2 {
		t.Fatalf("encResponse2.KeyId = %q, want %q", *encResponse2.KeyId, validKeyID2)
	}

	// Decrypt the first ciphertext without naming a key.
	decRequest := &kms.DecryptInput{
		// KeyId is not set
		CiphertextBlob:    encResponse.CiphertextBlob,
		EncryptionContext: context,
	}
	decResponse, err := fakeKMS.Decrypt(decRequest)
	if err != nil {
		t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err)
	}
	if !bytes.Equal(decResponse.Plaintext, plaintext) {
		t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext)
	}
	if *decResponse.KeyId != validKeyID {
		t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID)
	}

	// Decrypt the second ciphertext without naming a key.
	decRequest2 := &kms.DecryptInput{
		// KeyId is not set
		CiphertextBlob:    encResponse2.CiphertextBlob,
		EncryptionContext: context,
	}
	decResponse2, err := fakeKMS.Decrypt(decRequest2)
	if err != nil {
		t.Fatalf("fakeKMS.Decrypt(decRequest2) err = %s, want nil", err)
	}
	if !bytes.Equal(decResponse2.Plaintext, plaintext2) {
		// Report the value that was actually checked (decResponse2, not decResponse).
		t.Fatalf("decResponse2.Plaintext = %q, want %q", decResponse2.Plaintext, plaintext2)
	}
	if *decResponse2.KeyId != validKeyID2 {
		t.Fatalf("decResponse2.KeyId = %q, want %q", *decResponse2.KeyId, validKeyID2)
	}
}
// TestSerializeContext checks the byte output of serializeContext for three
// inputs: a map whose keys are supplied out of order (the expected output
// lists them in sorted order), a map whose key contains a double quote (the
// expected output escapes it), and an empty map (expected output "{}").
func TestSerializeContext(t *testing.T) {
	uvw, xyz, rst := "uvw", "xyz", "rst"

	cases := []struct {
		in   map[string]*string
		want string
	}{
		{
			in:   map[string]*string{"def": &uvw, "abc": &xyz, "ghi": &rst},
			want: "{\"abc\":\"xyz\",\"def\":\"uvw\",\"ghi\":\"rst\"}",
		},
		{
			in:   map[string]*string{"a\"b": &xyz},
			want: "{\"a\\\"b\":\"xyz\"}",
		},
		{
			in:   map[string]*string{},
			want: "{}",
		},
	}

	for _, tc := range cases {
		if got := string(serializeContext(tc.in)); got != tc.want {
			t.Fatalf("SerializeContext(context) = %s, want %s", got, tc.want)
		}
	}
}
OSZAR »