From 99f1a258bcfc7eaa8b3a90332e51ae06f8768454 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 23 Feb 2024 13:06:35 +0000 Subject: [PATCH] Don't CoUninitialize, close MTA dialog. --- file_windows.go | 43 ++++++++++++++++++++++++++++++++++------- internal/win/ole32.go | 1 + internal/win/shell32.go | 13 +++++++++++-- 3 files changed, 48 insertions(+), 9 deletions(-) diff --git a/file_windows.go b/file_windows.go index bff5d6c..a117009 100644 --- a/file_windows.go +++ b/file_windows.go @@ -246,6 +246,18 @@ func fileOpenDialog(opts options, multi bool) (string, []string, bool, error) { } defer unhook() + if opts.ctx != nil && opts.ctx.Done() != nil { + wait := make(chan struct{}) + defer close(wait) + go func() { + select { + case <-opts.ctx.Done(): + dialog.Close(win.E_TIMEOUT) + case <-wait: + } + }() + } + err = dialog.Show(owner) if opts.ctx != nil && opts.ctx.Err() != nil { return "", nil, true, opts.ctx.Err() @@ -352,6 +364,18 @@ func fileSaveDialog(opts options) (string, bool, error) { } defer unhook() + if opts.ctx != nil && opts.ctx.Done() != nil { + wait := make(chan struct{}) + defer close(wait) + go func() { + select { + case <-opts.ctx.Done(): + dialog.Close(win.E_TIMEOUT) + case <-wait: + } + }() + } + err = dialog.Show(owner) if opts.ctx != nil && opts.ctx.Err() != nil { return "", true, opts.ctx.Err() @@ -428,14 +452,19 @@ func browseForFolderCallback(wnd win.HWND, msg uint32, lparam, data uintptr) uin func coInitialize() (context.CancelFunc, error) { runtime.LockOSThread() - err := win.CoInitializeEx(0, win.COINIT_APARTMENTTHREADED|win.COINIT_DISABLE_OLE1DDE) - if err == nil || err == win.S_FALSE { - return func() { - win.CoUninitialize() - runtime.UnlockOSThread() - }, nil + // .NET uses MTA for all background threads, so do the same. + // If someone needs STA because they're doing UI, + // they should initialize COM themselves before. + err := win.CoInitializeEx(0, win.COINIT_MULTITHREADED|win.COINIT_DISABLE_OLE1DDE) + if err == win.S_FALSE { + // COM was already initialized, we simply increased the ref count. + // Make this a no-op by decreasing our ref count. + win.CoUninitialize() + return runtime.UnlockOSThread, nil } - if err == win.RPC_E_CHANGED_MODE { + // Don't uninitialize COM; this is against the docs, but it's what .NET does. + // Eventually all threads will have COM initialized. + if err == nil || err == win.RPC_E_CHANGED_MODE { return runtime.UnlockOSThread, nil } runtime.UnlockOSThread() diff --git a/internal/win/ole32.go b/internal/win/ole32.go index 9b13e36..9383a93 100644 --- a/internal/win/ole32.go +++ b/internal/win/ole32.go @@ -23,6 +23,7 @@ const ( CLSCTX_ALL = windows.CLSCTX_INPROC_SERVER | windows.CLSCTX_INPROC_HANDLER | windows.CLSCTX_LOCAL_SERVER | windows.CLSCTX_REMOTE_SERVER E_CANCELED = windows.ERROR_CANCELLED | windows.FACILITY_WIN32<<16 | 0x80000000 + E_TIMEOUT = windows.ERROR_TIMEOUT | windows.FACILITY_WIN32<<16 | 0x80000000 RPC_E_CHANGED_MODE = syscall.Errno(windows.RPC_E_CHANGED_MODE) S_FALSE = syscall.Errno(windows.S_FALSE) ) diff --git a/internal/win/shell32.go b/internal/win/shell32.go index 5b107e4..61cb0ab 100644 --- a/internal/win/shell32.go +++ b/internal/win/shell32.go @@ -251,6 +251,15 @@ func (u *IFileDialog) SetTitle(title *uint16) (err error) { return } +func (u *IFileDialog) GetResult() (item *IShellItem, err error) { + vtbl := *(**iFileDialogVtbl)(unsafe.Pointer(u)) + hr, _, _ := u.call(vtbl.GetResult, uintptr(unsafe.Pointer(&item))) + if hr != 0 { + err = syscall.Errno(hr) + } + return +} + func (u *IFileDialog) SetDefaultExtension(extension *uint16) (err error) { vtbl := *(**iFileDialogVtbl)(unsafe.Pointer(u)) hr, _, _ := u.call(vtbl.SetDefaultExtension, uintptr(unsafe.Pointer(extension))) @@ -260,9 +269,9 @@ func (u *IFileDialog) SetDefaultExtension(extension *uint16) (err error) { return } -func (u *IFileDialog) GetResult() (item *IShellItem, err error) { +func (u *IFileDialog) Close(res syscall.Errno) (err error) { vtbl := *(**iFileDialogVtbl)(unsafe.Pointer(u)) - hr, _, _ := u.call(vtbl.GetResult, uintptr(unsafe.Pointer(&item))) + hr, _, _ := u.call(vtbl.Close, uintptr(res)) if hr != 0 { err = syscall.Errno(hr) }