aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJason A. Donenfeld <Jason@zx2c4.com>2021-01-25 19:00:43 +0100
committerJason A. Donenfeld <Jason@zx2c4.com>2021-01-25 20:48:28 +0100
commit18e47795e598973195887893e7d77baddec53ebb (patch)
tree3688fae4aedbad3be80d6ac97d4105f3fc0ba072
parentipc: add missing Windows errno (diff)
downloadwireguard-go-18e47795e598973195887893e7d77baddec53ebb.tar.xz
wireguard-go-18e47795e598973195887893e7d77baddec53ebb.zip
device: allow pipelining UAPI requests
The original spec ends with \n\n especially for this reason. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
-rw-r--r--device/uapi.go66
1 files changed, 36 insertions, 30 deletions
diff --git a/device/uapi.go b/device/uapi.go
index c1ddb38..31fbdc7 100644
--- a/device/uapi.go
+++ b/device/uapi.go
@@ -380,9 +380,6 @@ func (device *Device) IpcSet(uapiConf string) error {
}
func (device *Device) IpcHandle(socket net.Conn) {
-
- // create buffered read/writer
-
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
@@ -391,34 +388,43 @@ func (device *Device) IpcHandle(socket net.Conn) {
return bufio.NewReadWriter(reader, writer)
}(socket)
- defer buffered.Flush()
-
- op, err := buffered.ReadString('\n')
- if err != nil {
- return
- }
+ for {
+ op, err := buffered.ReadString('\n')
+ if err != nil {
+ return
+ }
- // handle operation
- switch op {
- case "set=1\n":
- err = device.IpcSetOperation(buffered.Reader)
- case "get=1\n":
- err = device.IpcGetOperation(buffered.Writer)
- default:
- device.log.Error.Println("invalid UAPI operation:", op)
- return
- }
+ // handle operation
+ switch op {
+ case "set=1\n":
+ err = device.IpcSetOperation(buffered.Reader)
+ case "get=1\n":
+ nextByte, err := buffered.ReadByte()
+ if err != nil {
+ return
+ }
+ if nextByte != '\n' {
+ err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %c", nextByte, err)
+ break
+ }
+ err = device.IpcGetOperation(buffered.Writer)
+ default:
+ device.log.Error.Println("invalid UAPI operation:", op)
+ return
+ }
- // write status
- var status *IPCError
- if err != nil && !errors.As(err, &status) {
- // shouldn't happen
- status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
- }
- if status != nil {
- device.log.Error.Println(status)
- fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
- } else {
- fmt.Fprintf(buffered, "errno=0\n\n")
+ // write status
+ var status *IPCError
+ if err != nil && !errors.As(err, &status) {
+ // shouldn't happen
+ status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
+ }
+ if status != nil {
+ device.log.Error.Println(status)
+ fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
+ } else {
+ fmt.Fprintf(buffered, "errno=0\n\n")
+ }
+ buffered.Flush()
}
}