Generalize keeping backreferences.

This commit is contained in:
Nuno Cruces 2022-03-23 13:24:02 +00:00
parent e03e9df189
commit da48e44e0c

View file

@ -191,23 +191,32 @@ func newDialogHook(ctx context.Context, initDialog func(wnd uintptr)) (*dialogHo
go hook.wait() go hook.wait()
} }
saveDialogHook(&hook) saveBackRef(tid, unsafe.Pointer(&hook))
return &hook, nil return &hook, nil
} }
func initDialogHook(wnd uintptr) { func dialogHookProc(code int32, wparam uintptr, lparam *_CWPRETSTRUCT) uintptr {
if lparam.Message == 0x0110 { // WM_INITDIALOG
var name [8]uint16
getClassName.Call(lparam.Wnd, uintptr(unsafe.Pointer(&name)), uintptr(len(name)))
if syscall.UTF16ToString(name[:]) == "#32770" { // The class for a dialog box
tid, _, _ := getCurrentThreadId.Call() tid, _, _ := getCurrentThreadId.Call()
hook := loadDialogHook(tid) hook := (*dialogHook)(loadBackRef(tid))
atomic.StoreUintptr(&hook.wnd, wnd) atomic.StoreUintptr(&hook.wnd, lparam.Wnd)
if hook.ctx != nil && hook.ctx.Err() != nil { if hook.ctx != nil && hook.ctx.Err() != nil {
sendMessage.Call(wnd, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0) sendMessage.Call(lparam.Wnd, 0x0112 /* WM_SYSCOMMAND */, 0xf060 /* SC_CLOSE */, 0)
} else if hook.init != nil { } else if hook.init != nil {
hook.init(wnd) hook.init(lparam.Wnd)
} }
} }
}
next, _, _ := callNextHookEx.Call(
0, uintptr(code), wparam, uintptr(unsafe.Pointer(lparam)))
return next
}
func (h *dialogHook) unhook() { func (h *dialogHook) unhook() {
deleteDialogHook(h.tid) deleteBackRef(h.tid)
if h.done != nil { if h.done != nil {
close(h.done) close(h.done)
} }
@ -224,45 +233,6 @@ func (h *dialogHook) wait() {
} }
} }
func dialogHookProc(code int32, wparam uintptr, lparam *_CWPRETSTRUCT) uintptr {
if lparam.Message == 0x0110 { // WM_INITDIALOG
var name [8]uint16
getClassName.Call(lparam.Wnd, uintptr(unsafe.Pointer(&name)), uintptr(len(name)))
if syscall.UTF16ToString(name[:]) == "#32770" { // The class for a dialog box
initDialogHook(lparam.Wnd)
}
}
next, _, _ := callNextHookEx.Call(
0, uintptr(code), wparam, uintptr(unsafe.Pointer(lparam)))
return next
}
var dialogHooks struct {
sync.Mutex
m map[uintptr]*dialogHook
}
func saveDialogHook(h *dialogHook) {
dialogHooks.Lock()
defer dialogHooks.Unlock()
if dialogHooks.m == nil {
dialogHooks.m = map[uintptr]*dialogHook{}
}
dialogHooks.m[h.tid] = h
}
func loadDialogHook(tid uintptr) *dialogHook {
dialogHooks.Lock()
defer dialogHooks.Unlock()
return dialogHooks.m[tid]
}
func deleteDialogHook(tid uintptr) {
dialogHooks.Lock()
defer dialogHooks.Unlock()
delete(dialogHooks.m, tid)
}
func hookDialogTitle(ctx context.Context, title *string) (unhook context.CancelFunc, err error) { func hookDialogTitle(ctx context.Context, title *string) (unhook context.CancelFunc, err error) {
var init func(wnd uintptr) var init func(wnd uintptr)
if title != nil { if title != nil {
@ -273,6 +243,32 @@ func hookDialogTitle(ctx context.Context, title *string) (unhook context.CancelF
return hookDialog(ctx, init) return hookDialog(ctx, init)
} }
var backRefs struct {
sync.Mutex
m map[uintptr]unsafe.Pointer
}
func saveBackRef(id uintptr, ptr unsafe.Pointer) {
backRefs.Lock()
defer backRefs.Unlock()
if backRefs.m == nil {
backRefs.m = map[uintptr]unsafe.Pointer{}
}
backRefs.m[id] = ptr
}
func loadBackRef(id uintptr) unsafe.Pointer {
backRefs.Lock()
defer backRefs.Unlock()
return backRefs.m[id]
}
func deleteBackRef(id uintptr) {
backRefs.Lock()
defer backRefs.Unlock()
delete(backRefs.m, id)
}
type dpi uintptr type dpi uintptr
func getDPI(wnd uintptr) dpi { func getDPI(wnd uintptr) dpi {