| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| package fakeawskms |
| |
| import ( |
| "bytes" |
| "strings" |
| "testing" |
| |
| "github.com/aws/aws-sdk-go/aws" |
| "github.com/aws/aws-sdk-go/service/kms" |
| ) |
| |
| const validKeyID = "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab" |
| const validKeyID2 = "arn:aws:kms:us-west-2:123:key/different" |
| |
| func TestEncyptDecryptWithValidKeyId(t *testing.T) { |
| fakeKMS, err := New([]string{validKeyID}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| plaintext := []byte("plaintext") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| encRequest := &kms.EncryptInput{ |
| KeyId: aws.String(validKeyID), |
| Plaintext: plaintext, |
| EncryptionContext: context, |
| } |
| |
| encResponse, err := fakeKMS.Encrypt(encRequest) |
| if err != nil { |
| t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) |
| } |
| |
| ciphertext := encResponse.CiphertextBlob |
| |
| decRequest := &kms.DecryptInput{ |
| KeyId: aws.String(validKeyID), |
| CiphertextBlob: ciphertext, |
| EncryptionContext: context, |
| } |
| decResponse, err := fakeKMS.Decrypt(decRequest) |
| if err != nil { |
| t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) |
| } |
| if !bytes.Equal(decResponse.Plaintext, plaintext) { |
| t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) |
| } |
| if strings.Compare(*decResponse.KeyId, validKeyID) != 0 { |
| t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID) |
| } |
| |
| // decrypt with a different context should fail |
| otherContextValue := "otherContextValue" |
| otherContext := map[string]*string{"contextName": &otherContextValue} |
| otherDecRequest := &kms.DecryptInput{ |
| KeyId: aws.String(validKeyID), |
| CiphertextBlob: ciphertext, |
| EncryptionContext: otherContext, |
| } |
| if _, err := fakeKMS.Decrypt(otherDecRequest); err == nil { |
| t.Fatal("fakeKMS.Decrypt(otherDecRequest) err = nil, want not nil") |
| } |
| } |
| |
| func TestEncyptWithUnknownKeyID(t *testing.T) { |
| fakeKMS, err := New([]string{validKeyID}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| plaintext := []byte("plaintext") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| encRequestWithUnknownKeyID := &kms.EncryptInput{ |
| KeyId: aws.String(validKeyID2), |
| Plaintext: plaintext, |
| EncryptionContext: context, |
| } |
| |
| if _, err := fakeKMS.Encrypt(encRequestWithUnknownKeyID); err == nil { |
| t.Fatal("fakeKMS.Encrypt(encRequestWithvalidKeyID2) err = nil, want not nil") |
| } |
| } |
| |
| func TestDecryptWithInvalidCiphertext(t *testing.T) { |
| fakeKMS, err := New([]string{validKeyID}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| invalidCiphertext := []byte("plaintext") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| decRequest := &kms.DecryptInput{ |
| CiphertextBlob: invalidCiphertext, |
| EncryptionContext: context, |
| } |
| |
| if _, err := fakeKMS.Decrypt(decRequest); err == nil { |
| t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") |
| } |
| } |
| |
| func TestDecryptWithUnknownKeyId(t *testing.T) { |
| fakeKMS, err := New([]string{validKeyID}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| ciphertext := []byte("invalidCiphertext") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| decRequest := &kms.DecryptInput{ |
| KeyId: aws.String(validKeyID2), |
| CiphertextBlob: ciphertext, |
| EncryptionContext: context, |
| } |
| |
| if _, err := fakeKMS.Decrypt(decRequest); err == nil { |
| t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") |
| } |
| } |
| |
| func TestDecryptWithWrongKeyId(t *testing.T) { |
| fakeKMS, err := New([]string{validKeyID, validKeyID2}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| plaintext := []byte("plaintext") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| encRequest := &kms.EncryptInput{ |
| KeyId: aws.String(validKeyID), |
| Plaintext: plaintext, |
| EncryptionContext: context, |
| } |
| |
| encResponse, err := fakeKMS.Encrypt(encRequest) |
| if err != nil { |
| t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) |
| } |
| |
| ciphertext := encResponse.CiphertextBlob |
| |
| decRequest := &kms.DecryptInput{ |
| KeyId: aws.String(validKeyID2), // wrong key id |
| CiphertextBlob: ciphertext, |
| EncryptionContext: context, |
| } |
| if _, err := fakeKMS.Decrypt(decRequest); err == nil { |
| t.Fatal("fakeKMS.Decrypt(decRequest) err = nil, want not nil") |
| } |
| } |
| |
| func TestDecryptWithoutKeyId(t *testing.T) { |
| // setting the keyId in DecryptInput is not required, see |
| // https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/kms#DecryptInput |
| |
| fakeKMS, err := New([]string{validKeyID, validKeyID2}) |
| if err != nil { |
| t.Fatalf("New() err = %s, want nil", err) |
| } |
| |
| plaintext := []byte("plaintext") |
| plaintext2 := []byte("plaintext2") |
| contextValue := "contextValue" |
| context := map[string]*string{"contextName": &contextValue} |
| |
| encRequest := &kms.EncryptInput{ |
| KeyId: aws.String(validKeyID), |
| Plaintext: plaintext, |
| EncryptionContext: context, |
| } |
| encResponse, err := fakeKMS.Encrypt(encRequest) |
| if err != nil { |
| t.Fatalf("fakeKMS.Encrypt(encRequest) err = %s, want nil", err) |
| } |
| if strings.Compare(*encResponse.KeyId, validKeyID) != 0 { |
| t.Fatalf("encResponse.KeyId = %q, want %q", *encResponse.KeyId, validKeyID) |
| } |
| |
| encRequest2 := &kms.EncryptInput{ |
| KeyId: aws.String(validKeyID2), |
| Plaintext: plaintext2, |
| EncryptionContext: context, |
| } |
| encResponse2, err := fakeKMS.Encrypt(encRequest2) |
| if err != nil { |
| t.Fatalf("fakeKMS.Encrypt(encRequest2) err = %s, want nil", err) |
| } |
| if strings.Compare(*encResponse2.KeyId, validKeyID2) != 0 { |
| t.Fatalf("encResponse2.KeyId = %q, want %q", *encResponse2.KeyId, validKeyID2) |
| } |
| |
| decRequest := &kms.DecryptInput{ |
| // KeyId is not set |
| CiphertextBlob: encResponse.CiphertextBlob, |
| EncryptionContext: context, |
| } |
| decResponse, err := fakeKMS.Decrypt(decRequest) |
| if err != nil { |
| t.Fatalf("fakeKMS.Decrypt(decRequest) err = %s, want nil", err) |
| } |
| if !bytes.Equal(decResponse.Plaintext, plaintext) { |
| t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext) |
| } |
| if strings.Compare(*decResponse.KeyId, validKeyID) != 0 { |
| t.Fatalf("decResponse.KeyId = %q, want %q", *decResponse.KeyId, validKeyID) |
| } |
| |
| decRequest2 := &kms.DecryptInput{ |
| // KeyId is not set |
| CiphertextBlob: encResponse2.CiphertextBlob, |
| EncryptionContext: context, |
| } |
| decResponse2, err := fakeKMS.Decrypt(decRequest2) |
| if err != nil { |
| t.Fatalf("fakeKMS.Decrypt(decRequest2) err = %s, want nil", err) |
| } |
| if !bytes.Equal(decResponse2.Plaintext, plaintext2) { |
| t.Fatalf("decResponse.Plaintext = %q, want %q", decResponse.Plaintext, plaintext2) |
| } |
| if strings.Compare(*decResponse2.KeyId, validKeyID2) != 0 { |
| t.Fatalf("decResponse2.KeyId = %q, want %q", *decResponse2.KeyId, validKeyID2) |
| } |
| } |
| |
| func TestSerializeContext(t *testing.T) { |
| uvw := "uvw" |
| xyz := "xyz" |
| rst := "rst" |
| context := map[string]*string{"def": &uvw, "abc": &xyz, "ghi": &rst} |
| |
| got := string(serializeContext(context)) |
| want := "{\"abc\":\"xyz\",\"def\":\"uvw\",\"ghi\":\"rst\"}" |
| if got != want { |
| t.Fatalf("SerializeContext(context) = %s, want %s", got, want) |
| } |
| |
| gotEscaped := string(serializeContext(map[string]*string{"a\"b": &xyz})) |
| wantEscaped := "{\"a\\\"b\":\"xyz\"}" |
| if gotEscaped != wantEscaped { |
| t.Fatalf("SerializeContext(context) = %s, want %s", gotEscaped, wantEscaped) |
| } |
| |
| gotEmpty := string(serializeContext(map[string]*string{})) |
| if gotEmpty != "{}" { |
| t.Fatalf("SerializeContext(context) = %s, want %s", gotEmpty, "{}") |
| } |
| } |