diff --git a/ssh/server.go b/ssh/server.go index c0d1c29e6f..5b5ccd96f4 100644 --- a/ssh/server.go +++ b/ssh/server.go @@ -149,7 +149,7 @@ func (s *ServerConfig) AddHostKey(key Signer) { } // cachedPubKey contains the results of querying whether a public key is -// acceptable for a user. +// acceptable for a user. This is a FIFO cache. type cachedPubKey struct { user string pubKeyData []byte @@ -157,7 +157,13 @@ type cachedPubKey struct { perms *Permissions } -const maxCachedPubKeys = 16 +// maxCachedPubKeys is the number of cache entries we store. +// +// Due to consistent misuse of the PublicKeyCallback API, we have reduced this +// to 1, such that the only key in the cache is the most recently seen one. This +// forces the behavior that the last call to PublicKeyCallback will always be +// with the key that is used for authentication. +const maxCachedPubKeys = 1 // pubKeyCache caches tests for public keys. Since SSH clients // will query whether a public key is acceptable before attempting to @@ -179,9 +185,10 @@ func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { // add adds the given tuple to the cache. func (c *pubKeyCache) add(candidate cachedPubKey) { - if len(c.keys) < maxCachedPubKeys { - c.keys = append(c.keys, candidate) + if len(c.keys) >= maxCachedPubKeys { + c.keys = c.keys[1:] } + c.keys = append(c.keys, candidate) } // ServerConn is an authenticated SSH connection, as seen from the diff --git a/ssh/server_test.go b/ssh/server_test.go index b6d8ab3333..ba1bd10e82 100644 --- a/ssh/server_test.go +++ b/ssh/server_test.go @@ -5,6 +5,7 @@ package ssh import ( + "bytes" "errors" "fmt" "io" @@ -299,6 +300,54 @@ func TestBannerError(t *testing.T) { } } +func TestPublicKeyCallbackLastSeen(t *testing.T) { + var lastSeenKey PublicKey + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + lastSeenKey = key + fmt.Printf("seen %#v\n", key) + if _, ok := key.(*dsaPublicKey); !ok { + return nil, errors.New("nope") + } + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + <-done + + expectedPublicKey := testSigners["dsa"].PublicKey().Marshal() + lastSeenMarshalled := lastSeenKey.Marshal() + if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) { + t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey()) + } +} + type markerConn struct { closed uint32 used uint32