aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDavid Crawshaw <david@zentus.com>2019-07-07 18:00:13 -0400
committerDavid Crawshaw <david@zentus.com>2019-07-07 19:51:30 -0400
commit70a2ba7cf9789a7a1949c027cf663100a5c1b979 (patch)
tree89ff30524dafb0370efe38dbe608cec743b7abf4
parentwgcfg: new package for describing configuration (diff)
downloadwireguard-go-dc/wgcfg.tar.xz
wireguard-go-dc/wgcfg.zip
device: use key types from the wgcfg packagedc/wgcfg
-rw-r--r--device/cookie.go5
-rw-r--r--device/cookie_test.go6
-rw-r--r--device/device.go27
-rw-r--r--device/device_test.go7
-rw-r--r--device/noise-helpers.go27
-rw-r--r--device/noise-protocol.go47
-rw-r--r--device/noise-types.go82
-rw-r--r--device/noise_test.go28
-rw-r--r--device/peer.go6
-rw-r--r--device/uapi.go18
10 files changed, 72 insertions, 181 deletions
diff --git a/device/cookie.go b/device/cookie.go
index f134128..ec54f61 100644
--- a/device/cookie.go
+++ b/device/cookie.go
@@ -13,6 +13,7 @@ import (
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
+ "golang.zx2c4.com/wireguard/wgcfg"
)
type CookieChecker struct {
@@ -41,7 +42,7 @@ type CookieGenerator struct {
}
}
-func (st *CookieChecker) Init(pk NoisePublicKey) {
+func (st *CookieChecker) Init(pk wgcfg.Key) {
st.Lock()
defer st.Unlock()
@@ -171,7 +172,7 @@ func (st *CookieChecker) CreateReply(
return reply, nil
}
-func (st *CookieGenerator) Init(pk NoisePublicKey) {
+func (st *CookieGenerator) Init(pk wgcfg.Key) {
st.Lock()
defer st.Unlock()
diff --git a/device/cookie_test.go b/device/cookie_test.go
index 79a6a86..ef01d46 100644
--- a/device/cookie_test.go
+++ b/device/cookie_test.go
@@ -7,6 +7,8 @@ package device
import (
"testing"
+
+ "golang.zx2c4.com/wireguard/wgcfg"
)
func TestCookieMAC1(t *testing.T) {
@@ -18,11 +20,11 @@ func TestCookieMAC1(t *testing.T) {
checker CookieChecker
)
- sk, err := newPrivateKey()
+ sk, err := wgcfg.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
- pk := sk.publicKey()
+ pk := sk.Public()
generator.Init(pk)
checker.Init(pk)
diff --git a/device/device.go b/device/device.go
index a583fa9..3474e4d 100644
--- a/device/device.go
+++ b/device/device.go
@@ -13,6 +13,7 @@ import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
+ "golang.zx2c4.com/wireguard/wgcfg"
)
const (
@@ -46,13 +47,13 @@ type Device struct {
staticIdentity struct {
sync.RWMutex
- privateKey NoisePrivateKey
- publicKey NoisePublicKey
+ privateKey wgcfg.PrivateKey
+ publicKey wgcfg.Key
}
peers struct {
sync.RWMutex
- keyMap map[NoisePublicKey]*Peer
+ keyMap map[wgcfg.Key]*Peer
}
// unprotected / "self-synchronising resources"
@@ -96,7 +97,7 @@ type Device struct {
*
* Must hold device.peers.Mutex
*/
-func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
+func unsafeRemovePeer(device *Device, peer *Peer, key wgcfg.Key) {
// stop routing and processing of packets
@@ -200,7 +201,7 @@ func (device *Device) IsUnderLoad() bool {
return until.After(now)
}
-func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
+func (device *Device) SetPrivateKey(sk wgcfg.PrivateKey) error {
// lock required resources
@@ -217,9 +218,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// remove peers with matching public keys
- publicKey := sk.publicKey()
+ publicKey := sk.Public()
for key, peer := range device.peers.keyMap {
- if peer.handshake.remoteStatic.Equals(publicKey) {
+ if peer.handshake.remoteStatic.Equal(publicKey) {
unsafeRemovePeer(device, peer, key)
}
}
@@ -239,9 +240,9 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
handshake := &peer.handshake
if rmKey {
- handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
+ handshake.precomputedStaticStatic = [wgcfg.KeySize]byte{}
} else {
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
@@ -268,7 +269,7 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
}
device.tun.mtu = int32(mtu)
- device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+ device.peers.keyMap = make(map[wgcfg.Key]*Peer)
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
@@ -314,14 +315,14 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
return device
}
-func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
+func (device *Device) LookupPeer(pk wgcfg.Key) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
-func (device *Device) RemovePeer(key NoisePublicKey) {
+func (device *Device) RemovePeer(key wgcfg.Key) {
device.peers.Lock()
defer device.peers.Unlock()
// stop peer and remove from routing
@@ -340,7 +341,7 @@ func (device *Device) RemoveAllPeers() {
unsafeRemovePeer(device, peer, key)
}
- device.peers.keyMap = make(map[NoisePublicKey]*Peer)
+ device.peers.keyMap = make(map[wgcfg.Key]*Peer)
}
func (device *Device) FlushPacketQueues() {
diff --git a/device/device_test.go b/device/device_test.go
index cdbd458..82bb264 100644
--- a/device/device_test.go
+++ b/device/device_test.go
@@ -12,6 +12,8 @@ package device
import (
"bytes"
"testing"
+
+ "golang.zx2c4.com/wireguard/wgcfg"
)
func TestDevice(t *testing.T) {
@@ -44,7 +46,8 @@ func TestDevice(t *testing.T) {
}
func randDevice(t *testing.T) *Device {
- sk, err := newPrivateKey()
+ t.Helper()
+ sk, err := wgcfg.NewPrivateKey()
if err != nil {
t.Fatal(err)
}
@@ -56,12 +59,14 @@ func randDevice(t *testing.T) *Device {
}
func assertNil(t *testing.T, err error) {
+ t.Helper()
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
+ t.Helper()
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
diff --git a/device/noise-helpers.go b/device/noise-helpers.go
index f5e4b4b..ae52a7d 100644
--- a/device/noise-helpers.go
+++ b/device/noise-helpers.go
@@ -7,12 +7,10 @@ package device
import (
"crypto/hmac"
- "crypto/rand"
"crypto/subtle"
"hash"
"golang.org/x/crypto/blake2s"
- "golang.org/x/crypto/curve25519"
)
/* KDF related functions.
@@ -75,28 +73,3 @@ func setZero(arr []byte) {
arr[i] = 0
}
}
-
-func (sk *NoisePrivateKey) clamp() {
- sk[0] &= 248
- sk[31] = (sk[31] & 127) | 64
-}
-
-func newPrivateKey() (sk NoisePrivateKey, err error) {
- _, err = rand.Read(sk[:])
- sk.clamp()
- return
-}
-
-func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
- apk := (*[NoisePublicKeySize]byte)(&pk)
- ask := (*[NoisePrivateKeySize]byte)(sk)
- curve25519.ScalarBaseMult(apk, ask)
- return
-}
-
-func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
- apk := (*[NoisePublicKeySize]byte)(&pk)
- ask := (*[NoisePrivateKeySize]byte)(sk)
- curve25519.ScalarMult(&ss, ask, apk)
- return ss
-}
diff --git a/device/noise-protocol.go b/device/noise-protocol.go
index dd75cc3..c6284c3 100644
--- a/device/noise-protocol.go
+++ b/device/noise-protocol.go
@@ -14,6 +14,7 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
+ "golang.zx2c4.com/wireguard/wgcfg"
)
const (
@@ -63,8 +64,8 @@ const (
type MessageInitiation struct {
Type uint32
Sender uint32
- Ephemeral NoisePublicKey
- Static [NoisePublicKeySize + poly1305.TagSize]byte
+ Ephemeral wgcfg.Key
+ Static [wgcfg.KeySize + poly1305.TagSize]byte
Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte
@@ -74,7 +75,7 @@ type MessageResponse struct {
Type uint32
Sender uint32
Receiver uint32
- Ephemeral NoisePublicKey
+ Ephemeral wgcfg.Key
Empty [poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte
@@ -97,15 +98,15 @@ type MessageCookieReply struct {
type Handshake struct {
state int
mutex sync.RWMutex
- hash [blake2s.Size]byte // hash value
- chainKey [blake2s.Size]byte // chain key
- presharedKey NoiseSymmetricKey // psk
- localEphemeral NoisePrivateKey // ephemeral secret key
- localIndex uint32 // used to clear hash-table
- remoteIndex uint32 // index for sending
- remoteStatic NoisePublicKey // long term key
- remoteEphemeral NoisePublicKey // ephemeral public key
- precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
+ hash [blake2s.Size]byte // hash value
+ chainKey [blake2s.Size]byte // chain key
+ presharedKey wgcfg.SymmetricKey // psk
+ localEphemeral wgcfg.PrivateKey // ephemeral secret key
+ localIndex uint32 // used to clear hash-table
+ remoteIndex uint32 // index for sending
+ remoteStatic wgcfg.Key // long term key
+ remoteEphemeral wgcfg.Key // ephemeral public key
+ precomputedStaticStatic [wgcfg.KeySize]byte // precomputed shared secret
lastTimestamp tai64n.Timestamp
lastInitiationConsumption time.Time
lastSentHandshake time.Time
@@ -171,7 +172,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
- handshake.localEphemeral, err = newPrivateKey()
+ handshake.localEphemeral, err = wgcfg.NewPrivateKey()
if err != nil {
return nil, err
}
@@ -189,7 +190,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
msg := MessageInitiation{
Type: MessageInitiationType,
- Ephemeral: handshake.localEphemeral.publicKey(),
+ Ephemeral: handshake.localEphemeral.Public(),
Sender: handshake.localIndex,
}
@@ -200,7 +201,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
func() {
var key [chacha20poly1305.KeySize]byte
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ ss := handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
KDF2(
&handshake.chainKey,
&key,
@@ -252,10 +253,10 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
// decrypt static key
var err error
- var peerPK NoisePublicKey
+ var peerPK wgcfg.Key
func() {
var key [chacha20poly1305.KeySize]byte
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
@@ -352,18 +353,18 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
// create ephemeral key
- handshake.localEphemeral, err = newPrivateKey()
+ handshake.localEphemeral, err = wgcfg.NewPrivateKey()
if err != nil {
return nil, err
}
- msg.Ephemeral = handshake.localEphemeral.publicKey()
+ msg.Ephemeral = handshake.localEphemeral.Public()
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
func() {
- ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
+ ss := handshake.localEphemeral.SharedSecret(handshake.remoteEphemeral)
handshake.mixKey(ss[:])
- ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
+ ss = handshake.localEphemeral.SharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:])
}()
@@ -433,13 +434,13 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() {
- ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
+ ss := handshake.localEphemeral.SharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
func() {
- ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
+ ss := device.staticIdentity.privateKey.SharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
diff --git a/device/noise-types.go b/device/noise-types.go
deleted file mode 100644
index 6b1f16f..0000000
--- a/device/noise-types.go
+++ /dev/null
@@ -1,82 +0,0 @@
-/* SPDX-License-Identifier: MIT
- *
- * Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
- */
-
-package device
-
-import (
- "crypto/subtle"
- "encoding/hex"
- "errors"
-
- "golang.org/x/crypto/chacha20poly1305"
-)
-
-const (
- NoisePublicKeySize = 32
- NoisePrivateKeySize = 32
-)
-
-type (
- NoisePublicKey [NoisePublicKeySize]byte
- NoisePrivateKey [NoisePrivateKeySize]byte
- NoiseSymmetricKey [chacha20poly1305.KeySize]byte
- NoiseNonce uint64 // padded to 12-bytes
-)
-
-func loadExactHex(dst []byte, src string) error {
- slice, err := hex.DecodeString(src)
- if err != nil {
- return err
- }
- if len(slice) != len(dst) {
- return errors.New("hex string does not fit the slice")
- }
- copy(dst, slice)
- return nil
-}
-
-func (key NoisePrivateKey) IsZero() bool {
- var zero NoisePrivateKey
- return key.Equals(zero)
-}
-
-func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
- return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
-}
-
-func (key *NoisePrivateKey) FromHex(src string) (err error) {
- err = loadExactHex(key[:], src)
- key.clamp()
- return
-}
-
-func (key NoisePrivateKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
-
-func (key *NoisePublicKey) FromHex(src string) error {
- return loadExactHex(key[:], src)
-}
-
-func (key NoisePublicKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
-
-func (key NoisePublicKey) IsZero() bool {
- var zero NoisePublicKey
- return key.Equals(zero)
-}
-
-func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
- return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
-}
-
-func (key *NoiseSymmetricKey) FromHex(src string) error {
- return loadExactHex(key[:], src)
-}
-
-func (key NoiseSymmetricKey) ToHex() string {
- return hex.EncodeToString(key[:])
-}
diff --git a/device/noise_test.go b/device/noise_test.go
index 6ba3f2e..e431588 100644
--- a/device/noise_test.go
+++ b/device/noise_test.go
@@ -11,24 +11,6 @@ import (
"testing"
)
-func TestCurveWrappers(t *testing.T) {
- sk1, err := newPrivateKey()
- assertNil(t, err)
-
- sk2, err := newPrivateKey()
- assertNil(t, err)
-
- pk1 := sk1.publicKey()
- pk2 := sk2.publicKey()
-
- ss1 := sk1.sharedSecret(pk2)
- ss2 := sk2.sharedSecret(pk1)
-
- if ss1 != ss2 {
- t.Fatal("Failed to compute shared secet")
- }
-}
-
func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
@@ -36,8 +18,14 @@ func TestNoiseHandshake(t *testing.T) {
defer dev1.Close()
defer dev2.Close()
- peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
- peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
+ peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.Public())
+ if err != nil {
+ t.Fatal(err)
+ }
+ peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.Public())
+ if err != nil {
+ t.Fatal(err)
+ }
assertEqual(
t,
diff --git a/device/peer.go b/device/peer.go
index 4e7f2da..ec335a6 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -12,6 +12,8 @@ import (
"sync"
"sync/atomic"
"time"
+
+ "golang.zx2c4.com/wireguard/wgcfg"
)
const (
@@ -67,7 +69,7 @@ type Peer struct {
cookieGenerator CookieGenerator
}
-func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
+func (device *Device) NewPeer(pk wgcfg.Key) (*Peer, error) {
if device.isClosed.Get() {
return nil, errors.New("device closed")
@@ -110,7 +112,7 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
- handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
+ handshake.precomputedStaticStatic = device.staticIdentity.privateKey.SharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
diff --git a/device/uapi.go b/device/uapi.go
index 99cb421..8adca81 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -16,6 +16,7 @@ import (
"time"
"golang.zx2c4.com/wireguard/ipc"
+ "golang.zx2c4.com/wireguard/wgcfg"
)
type IPCError struct {
@@ -52,7 +53,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
- send("private_key=" + device.staticIdentity.privateKey.ToHex())
+ send("private_key=" + device.staticIdentity.privateKey.HexString())
}
if device.net.port != 0 {
@@ -69,8 +70,8 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
peer.RLock()
defer peer.RUnlock()
- send("public_key=" + peer.handshake.remoteStatic.ToHex())
- send("preshared_key=" + peer.handshake.presharedKey.ToHex())
+ send("public_key=" + peer.handshake.remoteStatic.HexString())
+ send("preshared_key=" + peer.handshake.presharedKey.HexString())
send("protocol_version=1")
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.DstToString())
@@ -136,8 +137,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "private_key":
- var sk NoisePrivateKey
- err := sk.FromHex(value)
+ sk, err := wgcfg.ParsePrivateHexKey(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
@@ -218,8 +218,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "public_key":
- var publicKey NoisePublicKey
- err := publicKey.FromHex(value)
+ publicKey, err := wgcfg.ParseHexKey(value)
if err != nil {
logError.Println("Failed to get peer by public key:", err)
return &IPCError{ipc.IpcErrorInvalid}
@@ -228,7 +227,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
// ignore peer with public key of device
device.staticIdentity.RLock()
- dummy = device.staticIdentity.publicKey.Equals(publicKey)
+ dummy = device.staticIdentity.publicKey.Equal(publicKey)
device.staticIdentity.RUnlock()
if dummy {
@@ -266,9 +265,10 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
// update PSK
logDebug.Println(peer, "- UAPI: Updating preshared key")
+ var err error
peer.handshake.mutex.Lock()
- err := peer.handshake.presharedKey.FromHex(value)
+ peer.handshake.presharedKey, err = wgcfg.ParseSymmetricHexKey(value)
peer.handshake.mutex.Unlock()
if err != nil {