From a899c69072e5224e7b9fc297288cdabc1cd122c6 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 14 Sep 2021 14:02:04 +0100 Subject: [PATCH] Avoid hook closures in syscall.NewCallback, #20. --- util_windows.go | 148 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 101 insertions(+), 47 deletions(-) diff --git a/util_windows.go b/util_windows.go index a406ad7..95b78f3 100644 --- a/util_windows.go +++ b/util_windows.go @@ -7,6 +7,7 @@ import ( "reflect" "runtime" "strconv" + "sync" "sync/atomic" "syscall" "unsafe" @@ -155,58 +156,111 @@ func hookDialog(ctx context.Context, initDialog func(wnd uintptr)) (unhook conte if ctx != nil && ctx.Err() != nil { return nil, ctx.Err() } - - var hook, wnd uintptr - callNextHookEx := callNextHookEx.Addr() - tid, _, _ := getCurrentThreadId.Call() - hook, _, err = setWindowsHookEx.Call(12, // WH_CALLWNDPROCRET - syscall.NewCallback(func(code int32, wparam uintptr, lparam *_CWPRETSTRUCT) uintptr { - if lparam.Message == 0x0110 { // WM_INITDIALOG - var name [8]uint16 - getClassName.Call(lparam.Wnd, uintptr(unsafe.Pointer(&name)), uintptr(len(name))) - if syscall.UTF16ToString(name[:]) == "#32770" { // The class for a dialog box - var close bool - - if ctx != nil && ctx.Err() != nil { - close = true - } else { - atomic.StoreUintptr(&wnd, lparam.Wnd) - } - - if close { - sendMessage.Call(lparam.Wnd, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0) - } else if initDialog != nil { - initDialog(lparam.Wnd) - } - } - } - next, _, _ := syscall.Syscall6(callNextHookEx, 4, - hook, uintptr(code), wparam, uintptr(unsafe.Pointer(lparam)), - 0, 0) - return next - }), 0, tid) - - if hook == 0 { + hook, err := newDialogHook(ctx, initDialog) + if err != nil { return nil, err } - if ctx == nil { - return func() { unhookWindowsHookEx.Call(hook) }, nil + return hook.unhook, nil +} + +type dialogHook struct { + ctx context.Context + tid uintptr + wnd uintptr + hook uintptr + done chan struct{} + init func(wnd uintptr) +} + +func newDialogHook(ctx context.Context, initDialog func(wnd uintptr)) (*dialogHook, error) { + tid, _, _ := getCurrentThreadId.Call() + hk, _, err := setWindowsHookEx.Call(12, // WH_CALLWNDPROCRET + syscall.NewCallback(dialogHookProc), 0, tid) + if hk == 0 { + return nil, err } - wait := make(chan struct{}) - go func() { - select { - case <-ctx.Done(): - if w := atomic.LoadUintptr(&wnd); w != 0 { - sendMessage.Call(w, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0) - } - case <-wait: + hook := dialogHook{ + ctx: ctx, + tid: tid, + hook: hk, + init: initDialog, + } + if ctx != nil { + hook.done = make(chan struct{}) + go hook.wait() + } + + saveDialogHook(&hook) + return &hook, nil +} + +func initDialogHook(wnd uintptr) { + tid, _, _ := getCurrentThreadId.Call() + hook := loadDialogHook(tid) + atomic.StoreUintptr(&hook.wnd, wnd) + if hook.ctx != nil && hook.ctx.Err() != nil { + sendMessage.Call(wnd, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0) + } else if hook.init != nil { + hook.init(wnd) + } +} + +func (h *dialogHook) unhook() { + deleteDialogHook(h.tid) + if h.done != nil { + close(h.done) + } + unhookWindowsHookEx.Call(h.hook) +} + +func (h *dialogHook) wait() { + select { + case <-h.ctx.Done(): + if wnd := atomic.LoadUintptr(&h.wnd); wnd != 0 { + sendMessage.Call(wnd, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0) } - }() - return func() { - unhookWindowsHookEx.Call(hook) - close(wait) - }, nil + case <-h.done: + } +} + +func dialogHookProc(code int32, wparam uintptr, lparam *_CWPRETSTRUCT) uintptr { + if lparam.Message == 0x0110 { // WM_INITDIALOG + var name [8]uint16 + getClassName.Call(lparam.Wnd, uintptr(unsafe.Pointer(&name)), uintptr(len(name))) + if syscall.UTF16ToString(name[:]) == "#32770" { // The class for a dialog box + initDialogHook(lparam.Wnd) + } + } + next, _, _ := callNextHookEx.Call( + 0, uintptr(code), wparam, uintptr(unsafe.Pointer(lparam))) + return next +} + +var dialogHooks struct { + sync.Mutex + m map[uintptr]*dialogHook +} + +func saveDialogHook(h *dialogHook) { + dialogHooks.Lock() + defer dialogHooks.Unlock() + if dialogHooks.m == nil { + dialogHooks.m = map[uintptr]*dialogHook{} + } + dialogHooks.m[h.tid] = h +} + +func loadDialogHook(tid uintptr) *dialogHook { + dialogHooks.Lock() + defer dialogHooks.Unlock() + return dialogHooks.m[tid] +} + +func deleteDialogHook(tid uintptr) { + dialogHooks.Lock() + defer dialogHooks.Unlock() + delete(dialogHooks.m, tid) } func hookDialogTitle(ctx context.Context, title *string) (unhook context.CancelFunc, err error) {