mirror of
https://github.com/schollz/cowyo.git
synced 2023-08-10 21:13:00 +03:00
275 lines
6.6 KiB
Go
275 lines
6.6 KiB
Go
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||
|
// Use of this source code is governed by a BSD-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package securecookie
|
||
|
|
||
|
import (
|
||
|
"crypto/aes"
|
||
|
"crypto/hmac"
|
||
|
"crypto/sha256"
|
||
|
"encoding/base64"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"strings"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
// Asserts that cookieError and MultiError are Error implementations.
|
||
|
var _ Error = cookieError{}
|
||
|
var _ Error = MultiError{}
|
||
|
|
||
|
// testCookies holds the values fed to the serializer round-trip tests.
// Every entry is a map[string]string stored as an interface{}.
var testCookies = []interface{}{
	map[string]string{
		"foo": "bar",
	},
	map[string]string{
		"baz": "ding",
	},
}
|
||
|
|
||
|
// testStrings holds the plain string inputs shared by the MAC, encryption
// and encoding round-trip tests.
var testStrings = []string{
	"foo",
	"bar",
	"baz",
}
|
||
|
|
||
|
func TestSecureCookie(t *testing.T) {
|
||
|
// TODO test too old / too new timestamps
|
||
|
s1 := New([]byte("12345"), []byte("1234567890123456"))
|
||
|
s2 := New([]byte("54321"), []byte("6543210987654321"))
|
||
|
value := map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"baz": 128,
|
||
|
}
|
||
|
|
||
|
for i := 0; i < 50; i++ {
|
||
|
// Running this multiple times to check if any special character
|
||
|
// breaks encoding/decoding.
|
||
|
encoded, err1 := s1.Encode("sid", value)
|
||
|
if err1 != nil {
|
||
|
t.Error(err1)
|
||
|
continue
|
||
|
}
|
||
|
dst := make(map[string]interface{})
|
||
|
err2 := s1.Decode("sid", encoded, &dst)
|
||
|
if err2 != nil {
|
||
|
t.Fatalf("%v: %v", err2, encoded)
|
||
|
}
|
||
|
if !reflect.DeepEqual(dst, value) {
|
||
|
t.Fatalf("Expected %v, got %v.", value, dst)
|
||
|
}
|
||
|
dst2 := make(map[string]interface{})
|
||
|
err3 := s2.Decode("sid", encoded, &dst2)
|
||
|
if err3 == nil {
|
||
|
t.Fatalf("Expected failure decoding.")
|
||
|
}
|
||
|
err4, ok := err3.(Error)
|
||
|
if !ok {
|
||
|
t.Fatalf("Expected error to implement Error, got: %#v", err3)
|
||
|
}
|
||
|
if !err4.IsDecode() {
|
||
|
t.Fatalf("Expected DecodeError, got: %#v", err4)
|
||
|
}
|
||
|
|
||
|
// Test other error type flags.
|
||
|
if err4.IsUsage() {
|
||
|
t.Fatalf("Expected IsUsage() == false, got: %#v", err4)
|
||
|
}
|
||
|
if err4.IsInternal() {
|
||
|
t.Fatalf("Expected IsInternal() == false, got: %#v", err4)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestSecureCookieNilKey(t *testing.T) {
|
||
|
s1 := New(nil, nil)
|
||
|
value := map[string]interface{}{
|
||
|
"foo": "bar",
|
||
|
"baz": 128,
|
||
|
}
|
||
|
_, err := s1.Encode("sid", value)
|
||
|
if err != errHashKeyNotSet {
|
||
|
t.Fatal("Wrong error returned:", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestDecodeInvalid(t *testing.T) {
|
||
|
// List of invalid cookies, which must not be accepted, base64-decoded
|
||
|
// (they will be encoded before passing to Decode).
|
||
|
invalidCookies := []string{
|
||
|
"",
|
||
|
" ",
|
||
|
"\n",
|
||
|
"||",
|
||
|
"|||",
|
||
|
"cookie",
|
||
|
}
|
||
|
s := New([]byte("12345"), nil)
|
||
|
var dst string
|
||
|
for i, v := range invalidCookies {
|
||
|
for _, enc := range []*base64.Encoding{
|
||
|
base64.StdEncoding,
|
||
|
base64.URLEncoding,
|
||
|
} {
|
||
|
err := s.Decode("name", enc.EncodeToString([]byte(v)), &dst)
|
||
|
if err == nil {
|
||
|
t.Fatalf("%d: expected failure decoding", i)
|
||
|
}
|
||
|
err2, ok := err.(Error)
|
||
|
if !ok || !err2.IsDecode() {
|
||
|
t.Fatalf("%d: Expected IsDecode(), got: %#v", i, err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestAuthentication(t *testing.T) {
|
||
|
hash := hmac.New(sha256.New, []byte("secret-key"))
|
||
|
for _, value := range testStrings {
|
||
|
hash.Reset()
|
||
|
signed := createMac(hash, []byte(value))
|
||
|
hash.Reset()
|
||
|
err := verifyMac(hash, []byte(value), signed)
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEncryption(t *testing.T) {
|
||
|
block, err := aes.NewCipher([]byte("1234567890123456"))
|
||
|
if err != nil {
|
||
|
t.Fatalf("Block could not be created")
|
||
|
}
|
||
|
var encrypted, decrypted []byte
|
||
|
for _, value := range testStrings {
|
||
|
if encrypted, err = encrypt(block, []byte(value)); err != nil {
|
||
|
t.Error(err)
|
||
|
} else {
|
||
|
if decrypted, err = decrypt(block, encrypted); err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
if string(decrypted) != value {
|
||
|
t.Errorf("Expected %v, got %v.", value, string(decrypted))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestGobSerialization(t *testing.T) {
|
||
|
var (
|
||
|
sz GobEncoder
|
||
|
serialized []byte
|
||
|
deserialized map[string]string
|
||
|
err error
|
||
|
)
|
||
|
for _, value := range testCookies {
|
||
|
if serialized, err = sz.Serialize(value); err != nil {
|
||
|
t.Error(err)
|
||
|
} else {
|
||
|
deserialized = make(map[string]string)
|
||
|
if err = sz.Deserialize(serialized, &deserialized); err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) {
|
||
|
t.Errorf("Expected %v, got %v.", value, deserialized)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestJSONSerialization(t *testing.T) {
|
||
|
var (
|
||
|
sz JSONEncoder
|
||
|
serialized []byte
|
||
|
deserialized map[string]string
|
||
|
err error
|
||
|
)
|
||
|
for _, value := range testCookies {
|
||
|
if serialized, err = sz.Serialize(value); err != nil {
|
||
|
t.Error(err)
|
||
|
} else {
|
||
|
deserialized = make(map[string]string)
|
||
|
if err = sz.Deserialize(serialized, &deserialized); err != nil {
|
||
|
t.Error(err)
|
||
|
}
|
||
|
if fmt.Sprintf("%v", deserialized) != fmt.Sprintf("%v", value) {
|
||
|
t.Errorf("Expected %v, got %v.", value, deserialized)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestEncoding(t *testing.T) {
|
||
|
for _, value := range testStrings {
|
||
|
encoded := encode([]byte(value))
|
||
|
decoded, err := decode(encoded)
|
||
|
if err != nil {
|
||
|
t.Error(err)
|
||
|
} else if string(decoded) != value {
|
||
|
t.Errorf("Expected %v, got %s.", value, string(decoded))
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMultiError(t *testing.T) {
|
||
|
s1, s2 := New(nil, nil), New(nil, nil)
|
||
|
_, err := EncodeMulti("sid", "value", s1, s2)
|
||
|
if len(err.(MultiError)) != 2 {
|
||
|
t.Errorf("Expected 2 errors, got %s.", err)
|
||
|
} else {
|
||
|
if strings.Index(err.Error(), "hash key is not set") == -1 {
|
||
|
t.Errorf("Expected missing hash key error, got %s.", err.Error())
|
||
|
}
|
||
|
ourErr, ok := err.(Error)
|
||
|
if !ok || !ourErr.IsUsage() {
|
||
|
t.Fatalf("Expected error to be a usage error; got %#v", err)
|
||
|
}
|
||
|
if ourErr.IsDecode() {
|
||
|
t.Errorf("Expected error NOT to be a decode error; got %#v", ourErr)
|
||
|
}
|
||
|
if ourErr.IsInternal() {
|
||
|
t.Errorf("Expected error NOT to be an internal error; got %#v", ourErr)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMultiNoCodecs(t *testing.T) {
|
||
|
_, err := EncodeMulti("foo", "bar")
|
||
|
if err != errNoCodecs {
|
||
|
t.Errorf("EncodeMulti: bad value for error, got: %v", err)
|
||
|
}
|
||
|
|
||
|
var dst []byte
|
||
|
err = DecodeMulti("foo", "bar", &dst)
|
||
|
if err != errNoCodecs {
|
||
|
t.Errorf("DecodeMulti: bad value for error, got: %v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestMissingKey(t *testing.T) {
|
||
|
s1 := New(nil, nil)
|
||
|
|
||
|
var dst []byte
|
||
|
err := s1.Decode("sid", "value", &dst)
|
||
|
if err != errHashKeyNotSet {
|
||
|
t.Fatalf("Expected %#v, got %#v", errHashKeyNotSet, err)
|
||
|
}
|
||
|
if err2, ok := err.(Error); !ok || !err2.IsUsage() {
|
||
|
t.Errorf("Expected missing hash key to be IsUsage(); was %#v", err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ----------------------------------------------------------------------------
|
||
|
|
||
|
// FooBar is a small custom struct used by TestCustomType as the value that
// is encoded and decoded.
type FooBar struct {
	Foo int    // integer field
	Bar string // string field
}
|
||
|
|
||
|
func TestCustomType(t *testing.T) {
|
||
|
s1 := New([]byte("12345"), []byte("1234567890123456"))
|
||
|
// Type is not registered in gob. (!!!)
|
||
|
src := &FooBar{42, "bar"}
|
||
|
encoded, _ := s1.Encode("sid", src)
|
||
|
|
||
|
dst := &FooBar{}
|
||
|
_ = s1.Decode("sid", encoded, dst)
|
||
|
if dst.Foo != 42 || dst.Bar != "bar" {
|
||
|
t.Fatalf("Expected %#v, got %#v", src, dst)
|
||
|
}
|
||
|
}
|