From ab56e18f1bf729223566160a9a5d0794e2fc96e1 Mon Sep 17 00:00:00 2001 From: "Jason A. Donenfeld" Date: Sat, 4 May 2019 22:40:19 +0200 Subject: ui: syntax: implement trafic blocking semantics This is our "auto kill switch". --- ui/editdialog.go | 281 ++++++++++++++++++++---------------------------- ui/filesave.go | 45 ++++++++ ui/syntax/syntaxedit.c | 64 +++++++++++ ui/syntax/syntaxedit.go | 17 ++- ui/syntax/syntaxedit.h | 7 ++ ui/util.go | 125 --------------------- 6 files changed, 248 insertions(+), 291 deletions(-) create mode 100644 ui/filesave.go delete mode 100644 ui/util.go (limited to 'ui') diff --git a/ui/editdialog.go b/ui/editdialog.go index 3cd5b606..77a044a0 100644 --- a/ui/editdialog.go +++ b/ui/editdialog.go @@ -15,44 +15,17 @@ import ( "golang.zx2c4.com/wireguard/windows/ui/syntax" ) -const ( - configKeyDNS = "DNS" - configKeyAllowedIPs = "AllowedIPs" -) - -var ( - ipv4Wildcard = orderedStringSetFromSlice([]string{"0.0.0.0/0"}) - ipv4PublicNetworks = orderedStringSetFromSlice([]string{ - "0.0.0.0/5", "8.0.0.0/7", "11.0.0.0/8", "12.0.0.0/6", "16.0.0.0/4", "32.0.0.0/3", - "64.0.0.0/2", "128.0.0.0/3", "160.0.0.0/5", "168.0.0.0/6", "172.0.0.0/12", - "172.32.0.0/11", "172.64.0.0/10", "172.128.0.0/9", "173.0.0.0/8", "174.0.0.0/7", - "176.0.0.0/4", "192.0.0.0/9", "192.128.0.0/11", "192.160.0.0/13", "192.169.0.0/16", - "192.170.0.0/15", "192.172.0.0/14", "192.176.0.0/12", "192.192.0.0/10", - "193.0.0.0/8", "194.0.0.0/7", "196.0.0.0/6", "200.0.0.0/5", "208.0.0.0/4", - }) -) - -type allowedIPsState int - -const ( - allowedIPsStateInvalid allowedIPsState = iota - allowedIPsStateContainsIPV4Wildcard - allowedIPsStateContainsIPV4PublicNetworks - allowedIPsStateOther -) - type EditDialog struct { *walk.Dialog - nameEdit *walk.LineEdit - pubkeyEdit *walk.LineEdit - syntaxEdit *syntax.SyntaxEdit - excludePrivateIPsCB *walk.CheckBox - saveButton *walk.PushButton - tunnel *service.Tunnel - config conf.Config - allowedIPsState allowedIPsState - lastPrivateKey string - inCheckedChanged bool + nameEdit *walk.LineEdit + pubkeyEdit *walk.LineEdit + syntaxEdit *syntax.SyntaxEdit + blockUntunneledTrafficCB *walk.CheckBox + saveButton *walk.PushButton + tunnel *service.Tunnel + config conf.Config + lastPrivateKey string + blockUntunneledTraficCheckGuard bool } func runTunnelEditDialog(owner walk.Form, tunnel *service.Tunnel) *conf.Config { @@ -106,19 +79,17 @@ func runTunnelEditDialog(owner walk.Form, tunnel *service.Tunnel) *conf.Config { dlg.syntaxEdit, _ = syntax.NewSyntaxEdit(dlg) layout.SetRange(dlg.syntaxEdit, walk.Rectangle{0, 2, 2, 1}) - dlg.syntaxEdit.PrivateKeyChanged().Attach(dlg.onSyntaxEditPrivateKeyChanged) - dlg.syntaxEdit.SetText(dlg.config.ToWgQuick()) - dlg.syntaxEdit.TextChanged().Attach(dlg.updateExcludePrivateIPsCBVisible) buttonsContainer, _ := walk.NewComposite(dlg) layout.SetRange(buttonsContainer, walk.Rectangle{0, 3, 2, 1}) buttonsContainer.SetLayout(walk.NewHBoxLayout()) buttonsContainer.Layout().SetMargins(walk.Margins{}) - dlg.excludePrivateIPsCB, _ = walk.NewCheckBox(buttonsContainer) - dlg.excludePrivateIPsCB.SetText("Exclude private IPs") - dlg.excludePrivateIPsCB.CheckedChanged().Attach(dlg.onExcludePrivateIPsCBCheckedChanged) - dlg.updateExcludePrivateIPsCBVisible() + dlg.blockUntunneledTrafficCB, _ = walk.NewCheckBox(buttonsContainer) + dlg.blockUntunneledTrafficCB.SetText("Block untunneled traffic (kill-switch)") + dlg.blockUntunneledTrafficCB.SetToolTipText("When a configuration has exactly one peer, and that peer has an allowed IPs containing at least one of 0.0.0.0/0 or ::/0, then the tunnel service engages a firewall ruleset to block all traffic that is neither to nor from the tunnel interface, with special exceptions for DHCP and NDP.") + dlg.blockUntunneledTrafficCB.SetVisible(false) + dlg.blockUntunneledTrafficCB.CheckedChanged().Attach(dlg.onBlockUntunneledTrafficCBCheckedChanged) walk.NewHSpacer(buttonsContainer) @@ -133,142 +104,126 @@ func runTunnelEditDialog(owner walk.Form, tunnel *service.Tunnel) *conf.Config { dlg.SetCancelButton(cancelButton) dlg.SetDefaultButton(dlg.saveButton) - dlg.updateAllowedIPsState() + dlg.syntaxEdit.PrivateKeyChanged().Attach(dlg.onSyntaxEditPrivateKeyChanged) + dlg.syntaxEdit.BlockUntunneledTrafficStateChanged().Attach(dlg.onBlockUntunneledTrafficStateChanged) + dlg.syntaxEdit.SetText(dlg.config.ToWgQuick()) if dlg.Run() == walk.DlgCmdOK { - // Save return &dlg.config } return nil } -func (dlg *EditDialog) updateAllowedIPsState() { - var newState allowedIPsState - if len(dlg.config.Peers) == 1 { - if allowedIPs := dlg.allowedIPsSet(); allowedIPs.IsSupersetOf(ipv4Wildcard) { - newState = allowedIPsStateContainsIPV4Wildcard - } else if allowedIPs.IsSupersetOf(ipv4PublicNetworks) { - newState = allowedIPsStateContainsIPV4PublicNetworks - } else { - newState = allowedIPsStateOther - } - } else { - newState = allowedIPsStateInvalid +func (dlg *EditDialog) onBlockUntunneledTrafficCBCheckedChanged() { + if dlg.blockUntunneledTraficCheckGuard { + return } + var ( + v40 = [4]byte{} + v60 = [16]byte{} + v48 = [4]byte{0x80} + v68 = [16]byte{0x80} + ) - if newState != dlg.allowedIPsState { - dlg.allowedIPsState = newState + block := dlg.blockUntunneledTrafficCB.Checked() + cfg, err := conf.FromWgQuick(dlg.syntaxEdit.Text(), "temporary") + var newAllowedIPs []conf.IPCidr - dlg.excludePrivateIPsCB.SetVisible(dlg.canExcludePrivateIPs()) - dlg.excludePrivateIPsCB.SetChecked(dlg.privateIPsExcluded()) + if err != nil { + goto err } -} - -func (dlg *EditDialog) canExcludePrivateIPs() bool { - return dlg.allowedIPsState == allowedIPsStateContainsIPV4PublicNetworks || - dlg.allowedIPsState == allowedIPsStateContainsIPV4Wildcard -} - -func (dlg *EditDialog) privateIPsExcluded() bool { - return dlg.allowedIPsState == allowedIPsStateContainsIPV4PublicNetworks -} - -func (dlg *EditDialog) setPrivateIPsExcluded(excluded bool) { - if !dlg.canExcludePrivateIPs() || dlg.privateIPsExcluded() == excluded { - return + if len(cfg.Peers) != 1 { + goto err } - var oldNetworks, newNetworks *orderedStringSet - if excluded { - oldNetworks, newNetworks = ipv4Wildcard, ipv4PublicNetworks - } else { - oldNetworks, newNetworks = ipv4PublicNetworks, ipv4Wildcard - } - input := dlg.allowedIPs() - output := newOrderedStringSet() - var replaced bool - - // Replace the first instance of the wildcard with the public network list, or vice versa. - for _, network := range input { - if oldNetworks.Contains(network) { - if !replaced { - output.UniteWith(newNetworks) - replaced = true + newAllowedIPs = make([]conf.IPCidr, 0, len(cfg.Peers[0].AllowedIPs)) + if block { + var ( + foundV401 bool + foundV41281 bool + foundV600001 bool + foundV680001 bool + ) + for _, allowedip := range cfg.Peers[0].AllowedIPs { + if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { + foundV600001 = true + } else if allowedip.Cidr == 1 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v68[:]) { + foundV680001 = true + } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { + foundV401 = true + } else if allowedip.Cidr == 1 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v48[:]) { + foundV41281 = true + } else { + newAllowedIPs = append(newAllowedIPs, allowedip) } - } else { - output.Add(network) } - } - - // DNS servers only need to be handled specially when we're excluding private IPs. - for _, route := range dlg.dnsRoutes() { - if excluded { - output.Add(route) - } else { - output.Remove(route) - output.Remove(route + "/32") + if !((foundV401 && foundV41281) || (foundV600001 && foundV680001)) { + goto err } - } - - if excluded { - dlg.allowedIPsState = allowedIPsStateContainsIPV4PublicNetworks + if foundV401 && foundV41281 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 0}) + } else if foundV401 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1}) + } else if foundV41281 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1}) + } + if foundV600001 && foundV680001 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 0}) + } else if foundV600001 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1}) + } else if foundV680001 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1}) + } + cfg.Peers[0].AllowedIPs = newAllowedIPs } else { - dlg.allowedIPsState = allowedIPsStateContainsIPV4Wildcard - } - - dlg.replaceLine(configKeyAllowedIPs, strings.Join(output.ToSlice(), ", ")) -} - -func (dlg *EditDialog) replaceLine(key, value string) { - text := dlg.syntaxEdit.Text() - - start := strings.Index(text, key) - end := start + strings.Index(text[start:], "\n") - oldLine := text[start:end] - newLine := fmt.Sprintf("%s = %s", key, value) - - dlg.syntaxEdit.SetText(strings.ReplaceAll(text, oldLine, newLine)) -} - -func (dlg *EditDialog) updateExcludePrivateIPsCBVisible() { - dlg.updateAllowedIPsState() - - dlg.excludePrivateIPsCB.SetVisible(dlg.canExcludePrivateIPs()) -} - -func (dlg *EditDialog) dnsRoutes() []string { - return dlg.routes(configKeyDNS) -} - -func (dlg *EditDialog) allowedIPs() []string { - return dlg.routes(configKeyAllowedIPs) -} - -func (dlg *EditDialog) allowedIPsSet() *orderedStringSet { - return orderedStringSetFromSlice(dlg.allowedIPs()) -} - -func (dlg *EditDialog) routes(key string) []string { - var routes []string - - lines := strings.Split(dlg.syntaxEdit.Text(), "\n") - for _, line := range lines { - if strings.HasPrefix(strings.TrimSpace(line), key) { - routesMaybeWithSpace := strings.Split(strings.TrimSpace(line[strings.IndexByte(line, '=')+1:]), ",") - routes = make([]string, len(routesMaybeWithSpace)) - for i, route := range routesMaybeWithSpace { - routes[i] = strings.TrimSpace(route) + var ( + foundV400 bool + foundV600 bool + ) + for _, allowedip := range cfg.Peers[0].AllowedIPs { + if allowedip.Cidr == 0 && len(allowedip.IP) == 16 && allowedip.IP.Equal(v60[:]) { + foundV600 = true + } else if allowedip.Cidr == 0 && len(allowedip.IP) == 4 && allowedip.IP.Equal(v40[:]) { + foundV400 = true + } else { + newAllowedIPs = append(newAllowedIPs, allowedip) } - break } + if !(foundV400 || foundV600) { + goto err + } + if foundV400 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v40[:], 1}) + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v48[:], 1}) + } + if foundV600 { + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v60[:], 1}) + newAllowedIPs = append(newAllowedIPs, conf.IPCidr{v68[:], 1}) + } + cfg.Peers[0].AllowedIPs = newAllowedIPs } + dlg.syntaxEdit.SetText(cfg.ToWgQuick()) + return - return routes +err: + walk.MsgBox(dlg, "Invalid configuration", "Unable to toggle untunneled traffic blocking state.", walk.MsgBoxIconWarning) + dlg.blockUntunneledTrafficCB.SetVisible(false) } -func (dlg *EditDialog) onExcludePrivateIPsCBCheckedChanged() { - dlg.setPrivateIPsExcluded(dlg.excludePrivateIPsCB.Checked()) +func (dlg *EditDialog) onBlockUntunneledTrafficStateChanged(state int) { + dlg.blockUntunneledTraficCheckGuard = true + switch state { + case syntax.InevaluableBlockingUntunneledTraffic: + dlg.blockUntunneledTrafficCB.SetVisible(false) + case syntax.BlockingUntunneledTraffic: + dlg.blockUntunneledTrafficCB.SetVisible(true) + dlg.blockUntunneledTrafficCB.SetChecked(true) + case syntax.NotBlockingUntunneledTraffic: + dlg.blockUntunneledTrafficCB.SetVisible(true) + dlg.blockUntunneledTrafficCB.SetChecked(false) + } + dlg.blockUntunneledTraficCheckGuard = false } func (dlg *EditDialog) onSyntaxEditPrivateKeyChanged(privateKey string) { @@ -287,14 +242,18 @@ func (dlg *EditDialog) onSyntaxEditPrivateKeyChanged(privateKey string) { func (dlg *EditDialog) onSaveButtonClicked() { newName := dlg.nameEdit.Text() if newName == "" { - walk.MsgBox(dlg, "Invalid configuration", "Name is required", walk.MsgBoxIconWarning) + walk.MsgBox(dlg, "Invalid name", "A name is required.", walk.MsgBoxIconWarning) + return + } + if !conf.TunnelNameIsValid(newName) { + walk.MsgBox(dlg, "Invalid name", fmt.Sprintf("Tunnel name ā€˜%sā€™ is invalid.", newName), walk.MsgBoxIconWarning) return } if dlg.tunnel != nil && dlg.tunnel.Name != newName { names, err := conf.ListConfigNames() if err != nil { - walk.MsgBox(dlg, "Error", err.Error(), walk.MsgBoxIconError) + walk.MsgBox(dlg, "Unable to list existing tunnels", err.Error(), walk.MsgBoxIconError) return } @@ -306,18 +265,12 @@ func (dlg *EditDialog) onSaveButtonClicked() { } } - if !conf.TunnelNameIsValid(newName) { - walk.MsgBox(dlg, "Invalid configuration", fmt.Sprintf("Tunnel name ā€˜%sā€™ is invalid.", newName), walk.MsgBoxIconWarning) - return - } - cfg, err := conf.FromWgQuick(dlg.syntaxEdit.Text(), newName) if err != nil { - walk.MsgBox(dlg, "Error", err.Error(), walk.MsgBoxIconError) + walk.MsgBox(dlg, "Unable to create new configuration", err.Error(), walk.MsgBoxIconError) return } dlg.config = *cfg - dlg.Accept() } diff --git a/ui/filesave.go b/ui/filesave.go new file mode 100644 index 00000000..b17f106c --- /dev/null +++ b/ui/filesave.go @@ -0,0 +1,45 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. + */ + +package ui + +import ( + "fmt" + "os" + + "github.com/lxn/walk" +) + +func writeFileWithOverwriteHandling(owner walk.Form, filePath string, write func(file *os.File) error) bool { + showError := func(err error) bool { + if err == nil { + return false + } + + walk.MsgBox(owner, "Writing file failed", err.Error(), walk.MsgBoxIconError) + + return true + } + + file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) + if err != nil { + if os.IsExist(err) { + if walk.DlgCmdNo == walk.MsgBox(owner, "Writing file failed", fmt.Sprintf(`File "%s" already exists. + +Do you want to overwrite it?`, filePath), walk.MsgBoxYesNo|walk.MsgBoxDefButton2|walk.MsgBoxIconWarning) { + return false + } + + if file, err = os.Create(filePath); err != nil { + return !showError(err) + } + } else { + return !showError(err) + } + } + defer file.Close() + + return !showError(write(file)) +} diff --git a/ui/syntax/syntaxedit.c b/ui/syntax/syntaxedit.c index 1b674623..cf75b0e4 100644 --- a/ui/syntax/syntaxedit.c +++ b/ui/syntax/syntaxedit.c @@ -21,6 +21,7 @@ const GUID CDECL IID_ITextDocument = { 0x8CC497C0, 0xA1DF, 0x11CE, { 0x80, 0x98, struct syntaxedit_data { IRichEditOle *irich; ITextDocument *idoc; + enum block_state last_block_state; bool highlight_guard; }; @@ -54,6 +55,67 @@ static const struct span_style stylemap[] = { [HighlightError] = { .color = RGB(0xC4, 0x1A, 0x16), .effects = CFE_UNDERLINE } }; +static void evaluate_untunneled_blocking(struct syntaxedit_data *this, HWND hWnd, const char *msg, struct highlight_span *spans) +{ + enum block_state state = InevaluableBlockingUntunneledTraffic; + bool on_allowedips = false; + bool seen_peer = false; + bool seen_v6_00 = false, seen_v4_00 = false; + bool seen_v6_01 = false, seen_v6_80001 = false, seen_v4_01 = false, seen_v4_1281 = false; + + for (struct highlight_span *span = spans; span->type != HighlightEnd; ++span) { + switch (span->type) { + case HighlightError: + goto done; + case HighlightSection: + if (span->len != 6 || strncasecmp(&msg[span->start], "[peer]", 6)) + break; + if (!seen_peer) + seen_peer = true; + else + goto done; + break; + case HighlightField: + on_allowedips = span->len == 10 && !strncasecmp(&msg[span->start], "allowedips", 10); + break; + case HighlightIP: + if (!on_allowedips || !seen_peer) + break; + if ((span + 1)->type != HighlightDelimiter || (span + 2)->type != HighlightCidr) + break; + if ((span + 2)->len != 1) + break; + if (msg[(span + 2)->start] == '0') { + if (span->len == 7 && !strncmp(&msg[span->start], "0.0.0.0", 7)) + seen_v4_00 = true; + else if (span->len == 2 && !strncmp(&msg[span->start], "::", 2)) + seen_v6_00 = true; + } else if (msg[(span + 2)->start] == '1') { + if (span->len == 7 && !strncmp(&msg[span->start], "0.0.0.0", 7)) + seen_v4_01 = true; + else if (span->len == 9 && !strncmp(&msg[span->start], "128.0.0.0", 9)) + seen_v4_1281 = true; + else if (span->len == 2 && !strncmp(&msg[span->start], "::", 2)) + seen_v6_01 = true; + else if (span->len == 6 && !strncmp(&msg[span->start], "8000::", 6)) + seen_v6_80001 = true; + } + break; + } + } + + if (seen_v4_00 || seen_v6_00) + state = BlockingUntunneledTraffic; + else if ((seen_v4_01 && seen_v4_1281) || (seen_v6_01 && seen_v6_80001)) + state = NotBlockingUntunneledTraffic; + +done: + if (state != this->last_block_state) { + SendMessage(hWnd, SE_TRAFFIC_BLOCK, 0, state); + this->last_block_state = state; + } +} + static void highlight_text(HWND hWnd) { GETTEXTLENGTHEX gettextlengthex = { @@ -104,6 +166,8 @@ static void highlight_text(HWND hWnd) if (!spans) goto out; + evaluate_untunneled_blocking(this, hWnd, msg, spans); + this->idoc->lpVtbl->Undo(this->idoc, tomSuspend, NULL); SendMessage(hWnd, WM_SETREDRAW, FALSE, 0); SendMessage(hWnd, EM_EXGETSEL, 0, (LPARAM)&orig_selection); diff --git a/ui/syntax/syntaxedit.go b/ui/syntax/syntaxedit.go index 5598d7a8..67e132c4 100644 --- a/ui/syntax/syntaxedit.go +++ b/ui/syntax/syntaxedit.go @@ -20,10 +20,17 @@ import "C" type SyntaxEdit struct { walk.WidgetBase - textChangedPublisher walk.EventPublisher - privateKeyPublisher walk.StringEventPublisher + textChangedPublisher walk.EventPublisher + privateKeyPublisher walk.StringEventPublisher + blockUntunneledTrafficPublisher walk.IntEventPublisher } +const ( + InevaluableBlockingUntunneledTraffic = C.InevaluableBlockingUntunneledTraffic + BlockingUntunneledTraffic = C.BlockingUntunneledTraffic + NotBlockingUntunneledTraffic = C.NotBlockingUntunneledTraffic +) + func (se *SyntaxEdit) LayoutFlags() walk.LayoutFlags { return walk.GrowableHorz | walk.GrowableVert | walk.GreedyHorz | walk.GreedyVert } @@ -63,6 +70,10 @@ func (se *SyntaxEdit) PrivateKeyChanged() *walk.StringEvent { return se.privateKeyPublisher.Event() } +func (se *SyntaxEdit) BlockUntunneledTrafficStateChanged() *walk.IntEvent { + return se.blockUntunneledTrafficPublisher.Event() +} + func (se *SyntaxEdit) WndProc(hwnd win.HWND, msg uint32, wParam, lParam uintptr) uintptr { switch msg { case win.WM_NOTIFY, win.WM_COMMAND: @@ -78,6 +89,8 @@ func (se *SyntaxEdit) WndProc(hwnd win.HWND, msg uint32, wParam, lParam uintptr) } else { se.privateKeyPublisher.Publish(C.GoString((*C.char)(unsafe.Pointer(lParam)))) } + case C.SE_TRAFFIC_BLOCK: + se.blockUntunneledTrafficPublisher.Publish(int(lParam)) } return se.WidgetBase.WndProc(hwnd, msg, wParam, lParam) } diff --git a/ui/syntax/syntaxedit.h b/ui/syntax/syntaxedit.h index 4013f328..7d158b29 100644 --- a/ui/syntax/syntaxedit.h +++ b/ui/syntax/syntaxedit.h @@ -17,6 +17,13 @@ #define WM_REFLECT (WM_USER + 0x1C00) #define SE_PRIVATE_KEY (WM_USER + 0x3100) +#define SE_TRAFFIC_BLOCK (WM_USER + 0x3101) + +enum block_state { + InevaluableBlockingUntunneledTraffic, + BlockingUntunneledTraffic, + NotBlockingUntunneledTraffic +}; extern bool register_syntax_edit(void); diff --git a/ui/util.go b/ui/util.go deleted file mode 100644 index 0cca0909..00000000 --- a/ui/util.go +++ /dev/null @@ -1,125 +0,0 @@ -/* SPDX-License-Identifier: MIT - * - * Copyright (C) 2019 WireGuard LLC. All Rights Reserved. - */ - -package ui - -import ( - "fmt" - "os" - - "github.com/lxn/walk" -) - -type orderedStringSet struct { - items []string - item2index map[string]int -} - -func orderedStringSetFromSlice(items []string) *orderedStringSet { - oss := newOrderedStringSet() - oss.AddMany(items) - return oss -} - -func newOrderedStringSet() *orderedStringSet { - return &orderedStringSet{item2index: make(map[string]int)} -} - -func (oss *orderedStringSet) Add(item string) bool { - if _, ok := oss.item2index[item]; ok { - return false - } - - oss.item2index[item] = len(oss.items) - oss.items = append(oss.items, item) - return true -} - -func (oss *orderedStringSet) AddMany(items []string) { - for _, item := range items { - oss.Add(item) - } -} - -func (oss *orderedStringSet) UniteWith(other *orderedStringSet) { - if other == oss { - return - } - - oss.AddMany(other.items) -} - -func (oss *orderedStringSet) Remove(item string) bool { - if i, ok := oss.item2index[item]; ok { - oss.items = append(oss.items[:i], oss.items[i+1:]...) - delete(oss.item2index, item) - return true - } - - return false -} - -func (oss *orderedStringSet) Len() int { - return len(oss.items) -} - -func (oss *orderedStringSet) ToSlice() []string { - return append(([]string)(nil), oss.items...) -} - -func (oss *orderedStringSet) Contains(item string) bool { - _, ok := oss.item2index[item] - return ok -} - -func (oss *orderedStringSet) IsSupersetOf(other *orderedStringSet) bool { - if oss.Len() < other.Len() { - return false - } - - for _, item := range other.items { - if !oss.Contains(item) { - return false - } - } - - return true -} - -func (oss *orderedStringSet) String() string { - return fmt.Sprintf("%v", oss.items) -} - -func writeFileWithOverwriteHandling(owner walk.Form, filePath string, write func(file *os.File) error) bool { - showError := func(err error) bool { - if err == nil { - return false - } - - walk.MsgBox(owner, "Writing file failed", err.Error(), walk.MsgBoxIconError) - - return true - } - - file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_EXCL, 0600) - if err != nil { - if os.IsExist(err) { - if walk.DlgCmdNo == walk.MsgBox(owner, "Writing file failed", fmt.Sprintf(`File "%s" already exists. - -Do you want to overwrite it?`, filePath), walk.MsgBoxYesNo|walk.MsgBoxDefButton2|walk.MsgBoxIconWarning) { - return false - } - - if file, err = os.Create(filePath); err != nil { - return !showError(err) - } - } else { - return !showError(err) - } - } - defer file.Close() - - return !showError(write(file)) -} -- cgit v1.2.3-59-g8ed1b