Skip to content
Open
47 changes: 39 additions & 8 deletions src/windows/common/wslutil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,45 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
Filename = Url.substr(lastSlash + 1);
}

const auto downloadFolder =
winrt::Windows::Storage::StorageFolder::GetFolderFromPathAsync(std::filesystem::temp_directory_path().wstring()).get();

const auto file =
downloadFolder.CreateFileAsync(Filename, winrt::Windows::Storage::CreationCollisionOption::GenerateUniqueName).get();
auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] { file.DeleteAsync().get(); });
// GetFolderFromPathAsync won't work if the folder is hidden or system.
auto downloadFolderPath = std::filesystem::temp_directory_path();
auto filenameStem = std::filesystem::path(Filename).stem().wstring();
auto filenameExtension = std::filesystem::path(Filename).extension().wstring();
Comment thread
chemwolf6922 marked this conversation as resolved.
std::wstring filePath{};
winrt::Windows::Storage::Streams::IRandomAccessStream outputStream{};
for (int suffix = 1; outputStream == nullptr; suffix++)
{
if (suffix == 1)
{
filePath = (downloadFolderPath / Filename).wstring();
}
else
{
filePath = (downloadFolderPath / std::format(L"{} ({}){}", filenameStem, suffix, filenameExtension)).wstring();
}
try
{
outputStream = winrt::Windows::Storage::Streams::FileRandomAccessStream::OpenAsync(
filePath,
winrt::Windows::Storage::FileAccessMode::ReadWrite,
winrt::Windows::Storage::StorageOpenOptions::None,
winrt::Windows::Storage::Streams::FileOpenDisposition::CreateNew)
.get();
}
catch (...)
{
if (wil::ResultFromCaughtException() != HRESULT_FROM_WIN32(ERROR_ALREADY_EXISTS))
{
throw;
}
}
}

const auto outputStream = file.OpenAsync(winrt::Windows::Storage::FileAccessMode::ReadWrite).get().GetOutputStreamAt(0);
auto deleteFileOnFailure = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
outputStream.Close();
std::error_code ec;
std::filesystem::remove(filePath, ec);
});

// By default downloaded files are cached in %appdata%/local/packages/{package-family}/AC/InetCache .
// Disable caching since there's no reason to keep local copies of .msixbundle files.
Expand Down Expand Up @@ -410,7 +441,7 @@ std::wstring wsl::windows::common::wslutil::DownloadFileImpl(
download.get();
deleteFileOnFailure.release();

return file.Path().c_str();
return filePath;
}

[[nodiscard]] HANDLE wsl::windows::common::wslutil::DuplicateHandle(_In_ HANDLE Handle, _In_ std::optional<DWORD> DesiredAccess, _In_ BOOL InheritHandle)
Expand Down
65 changes: 65 additions & 0 deletions test/windows/UnitTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7448,5 +7448,70 @@ Error code: Wsl/InstallDistro/WSL_E_INVALID_JSON\r\n",
}
}

TEST_METHOD(DownloadToHiddenSystemTempFolder)
{
// Avoid contaminating the real temp folder.
const auto testTempFolder = std::filesystem::temp_directory_path() / L"wsl-download-test";
std::filesystem::create_directories(testTempFolder);
auto cleanupTempFolder = wil::scope_exit_log(WI_DIAGNOSTICS_INFO, [&] {
std::error_code error;
std::filesystem::remove_all(testTempFolder, error);
});

const auto originalAttributes = GetFileAttributesW(testTempFolder.c_str());
VERIFY_IS_TRUE(originalAttributes != INVALID_FILE_ATTRIBUTES);
VERIFY_IS_TRUE(SetFileAttributesW(testTempFolder.c_str(), originalAttributes | FILE_ATTRIBUTE_HIDDEN | FILE_ATTRIBUTE_SYSTEM));
Comment thread
chemwolf6922 marked this conversation as resolved.

ScopedEnvVariable temp(L"TEMP", testTempFolder.wstring());
ScopedEnvVariable tmp(L"TMP", testTempFolder.wstring());

VERIFY_IS_TRUE(std::filesystem::equivalent(std::filesystem::temp_directory_path(), testTempFolder));

constexpr USHORT port = 6666;
const auto endpoint = std::format(L"http://127.0.0.1:{}/", port);
constexpr auto fileName = L"downloaded-file.bin";
constexpr auto fileContent = L"wsl download test content";
UniqueWebServer server(endpoint.c_str(), fileContent);

const auto url = endpoint + fileName;
const auto noProgress = [](uint64_t, uint64_t) {};

wsl::shared::retry::RetryWithTimeout<void>(
[&]() {
wil::unique_socket probe{socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)};
THROW_LAST_ERROR_IF(!probe);

sockaddr_in address{};
address.sin_family = AF_INET;
address.sin_port = htons(port);
address.sin_addr.s_addr = htonl(INADDR_LOOPBACK);

THROW_LAST_ERROR_IF(connect(probe.get(), reinterpret_cast<const sockaddr*>(&address), sizeof(address)) == SOCKET_ERROR);
},
std::chrono::milliseconds(500),
std::chrono::seconds(5));

const auto firstPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress);

auto readFile = [](const std::filesystem::path& Path) {
std::ifstream file(Path, std::ios::binary);
VERIFY_IS_TRUE(file.good());
return std::string{std::istreambuf_iterator<char>(file), {}};
};

VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).parent_path(), testTempFolder);
VERIFY_ARE_EQUAL(std::filesystem::path(firstPath).filename().wstring(), std::wstring(fileName));
VERIFY_IS_TRUE(std::filesystem::exists(firstPath));
VERIFY_ARE_EQUAL(readFile(firstPath), wsl::shared::string::WideToMultiByte(fileContent));

const auto secondPath = wsl::windows::common::wslutil::DownloadFileImpl(url, L"", noProgress);

VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).parent_path(), testTempFolder);
VERIFY_ARE_EQUAL(std::filesystem::path(secondPath).filename().wstring(), std::wstring(L"downloaded-file (2).bin"));
VERIFY_IS_TRUE(std::filesystem::exists(firstPath));
VERIFY_IS_TRUE(std::filesystem::exists(secondPath));
VERIFY_ARE_EQUAL(readFile(secondPath), wsl::shared::string::WideToMultiByte(fileContent));
}

}; // namespace UnitTests
} // namespace UnitTests