winamp/Src/external_dependencies/openmpt-trunk/mptrack/HTTP.cpp

520 lines
12 KiB
C++
Raw Permalink Normal View History

2024-09-24 13:54:57 +01:00
/*
* HTTP.cpp
* --------
* Purpose: Simple HTTP client interface.
* Notes : (currently none)
* Authors: OpenMPT Devs
* The OpenMPT source code is released under the BSD license. Read LICENSE for more details.
*/
#include "stdafx.h"
#include "HTTP.h"
#include "mpt/system_error/system_error.hpp"
#include <WinInet.h>
#include "mpt/io/io.hpp"
#include "mpt/io/io_stdstream.hpp"
OPENMPT_NAMESPACE_BEGIN
// Splits a URI string into its components: scheme, optional authority
// (username, password, host, port), path, query and fragment.
// Throws bad_uri if the string contains no ':' scheme delimiter or if the scheme is empty.
URI ParseURI(mpt::ustring str)
{
	const std::size_t schemeEnd = str.find(':');
	if(schemeEnd == mpt::ustring::npos)
	{
		throw bad_uri("no scheme delimiter");
	}
	if(schemeEnd == 0)
	{
		throw bad_uri("no scheme");
	}

	URI uri;
	uri.scheme = str.substr(0, schemeEnd);
	str = str.substr(schemeEnd + 1);

	// An authority component is only present if the remainder starts with "//".
	if(str.substr(0, 2) == U_("//"))
	{
		str = str.substr(2);
		const std::size_t authorityEnd = str.find_first_of(U_("/?#"));
		mpt::ustring authority = str.substr(0, authorityEnd);

		// Optional "username[:password]@" prefix.
		const std::size_t atPos = authority.find(U_("@"));
		if(atPos != mpt::ustring::npos)
		{
			const mpt::ustring userinfo = authority.substr(0, atPos);
			authority = authority.substr(atPos + 1);
			const std::size_t passwordDelim = userinfo.find(U_(":"));
			uri.username = userinfo.substr(0, passwordDelim);
			if(passwordDelim != mpt::ustring::npos)
			{
				uri.password = userinfo.substr(passwordDelim + 1);
			}
		}

		// Host and optional port. If the authority contains both '[' and ']',
		// the last ':' only counts as port delimiter when it follows the ']'.
		const std::size_t openBracket = authority.find(U_("["));
		const std::size_t closeBracket = authority.find(U_("]"));
		const std::size_t lastColon = authority.find_last_of(U_(":"));
		const bool bracketed = (openBracket != mpt::ustring::npos) && (closeBracket != mpt::ustring::npos);
		const bool hasPort = (lastColon != mpt::ustring::npos) && (!bracketed || lastColon > closeBracket);
		if(hasPort)
		{
			uri.host = authority.substr(0, lastColon);
			uri.port = authority.substr(lastColon + 1);
		} else
		{
			uri.host = authority;
		}

		// Continue with whatever follows the authority (nothing if it ran to the end).
		if(authorityEnd == mpt::ustring::npos)
		{
			str = U_("");
		} else
		{
			str = str.substr(authorityEnd);
		}
	}

	const std::size_t pathEnd = str.find_first_of(U_("?#"));
	uri.path = str.substr(0, pathEnd);
	if(pathEnd == mpt::ustring::npos)
	{
		// Neither query nor fragment.
		return uri;
	}

	// str now starts with either '?' or '#'.
	str = str.substr(pathEnd);
	const std::size_t fragmentDelim = str.find(U_("#"));
	if(fragmentDelim == mpt::ustring::npos)
	{
		uri.query = str.substr(1);
		return uri;
	}
	if(fragmentDelim > 0)
	{
		// Text between the leading '?' and the '#'.
		uri.query = str.substr(1, fragmentDelim - 1);
	}
	uri.fragment = str.substr(fragmentDelim + 1);
	return uri;
}
namespace HTTP
{
// Builds an HTTP exception: the std::runtime_error text is the message prefixed with
// "HTTP error: " (converted to the exception charset), while the original message
// is kept for GetMessage().
exception::exception(const mpt::ustring &m)
	: std::runtime_error(std::string("HTTP error: ") + mpt::ToCharset(mpt::CharsetException, m))
	, message(m)
{
}
// Returns the message this exception was constructed with (without the "HTTP error: " prefix).
mpt::ustring exception::GetMessage() const
{
	return this->message;
}
class LastErrorException
: public exception
{
public:
LastErrorException()
: exception(mpt::windows::GetErrorMessage(GetLastError(), GetModuleHandle(TEXT("wininet.dll"))))
{
}
};
struct NativeHandle
{
HINTERNET native_handle;
NativeHandle(HINTERNET h)
: native_handle(h)
{
}
operator HINTERNET() const
{
return native_handle;
}
};
// Creates a handle wrapper that holds no native handle (NULL).
Handle::Handle()
	: handle(std::make_unique<NativeHandle>(static_cast<HINTERNET>(NULL)))
{
}
// True if a non-NULL native handle is held.
Handle::operator bool() const
{
	const HINTERNET raw = handle->native_handle;
	return raw != HINTERNET(NULL);
}
// True if no native handle is held (the stored handle is NULL).
bool Handle::operator!() const
{
	const HINTERNET raw = handle->native_handle;
	return raw == HINTERNET(NULL);
}
// Wraps the given native handle; this object closes it on destruction or reassignment.
Handle::Handle(NativeHandle h)
	: handle(std::make_unique<NativeHandle>(h.native_handle))
{
}
// Replaces the held native handle with h, closing the previously held one (if any) first.
Handle & Handle::operator=(NativeHandle h)
{
	HINTERNET &held = handle->native_handle;
	if(held)
	{
		InternetCloseHandle(held);
	}
	held = h.native_handle;
	return *this;
}
// Returns a copy of the held native handle wrapper; the caller does not gain ownership.
Handle::operator NativeHandle ()
{
	return NativeHandle(handle->native_handle);
}
// Closes the held native handle, if there is one.
Handle::~Handle()
{
	HINTERNET &held = handle->native_handle;
	if(!held)
	{
		return;
	}
	InternetCloseHandle(held);
	held = HINTERNET(NULL);
}
// Opens a WinInet session with the given user agent string, using the preconfigured
// access type (INTERNET_OPEN_TYPE_PRECONFIG).
// Throws HTTP::LastErrorException if InternetOpen fails.
InternetSession::InternetSession(mpt::ustring userAgent)
{
	const auto nativeUserAgent = mpt::ToWin(userAgent);
	const HINTERNET session = InternetOpen(nativeUserAgent.c_str(), INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0);
	internet = NativeHandle(session);
	if(!internet)
	{
		throw HTTP::LastErrorException();
	}
}
// Exposes the session's native handle; ownership stays with the session.
InternetSession::operator NativeHandle ()
{
	return static_cast<NativeHandle>(internet);
}
// Maps a request method to its HTTP verb string.
// A value not covered by the switch yields an empty string.
static mpt::winstring Verb(Method method)
{
	switch(method)
	{
	case Method::Get:
		return mpt::winstring(_T("GET"));
	case Method::Head:
		return mpt::winstring(_T("HEAD"));
	case Method::Post:
		return mpt::winstring(_T("POST"));
	case Method::Put:
		return mpt::winstring(_T("PUT"));
	case Method::Delete:
		return mpt::winstring(_T("DELETE"));
	case Method::Trace:
		return mpt::winstring(_T("TRACE"));
	case Method::Options:
		return mpt::winstring(_T("OPTIONS"));
	case Method::Connect:
		return mpt::winstring(_T("CONNECT"));
	case Method::Patch:
		return mpt::winstring(_T("PATCH"));
	}
	return mpt::winstring();
}
// Only GET and HEAD requests are treated as cachable.
static bool IsCachable(Method method)
{
	switch(method)
	{
	case Method::Get:
	case Method::Head:
		return true;
	default:
		return false;
	}
}
namespace
{
// Converts a list of MIME types into the NULL-terminated array of string pointers
// that is handed to HttpOpenRequest as the accept-types argument.
// The pointers refer to strings owned by this object.
class AcceptMimeTypesWrapper
{
private:
	std::vector<mpt::winstring> strings;  // owns the converted MIME type strings
	std::vector<LPCTSTR> array;           // pointers into `strings`, followed by a terminating NULL
public:
	AcceptMimeTypesWrapper(AcceptMimeTypes acceptMimeTypes)
	{
		for(const auto &mimeType : acceptMimeTypes)
		{
			strings.push_back(mpt::ToWin(mpt::Charset::ASCII, mimeType));
		}
		// Collect the pointers only after `strings` has reached its final size,
		// so that no later growth of `strings` can invalidate them.
		array.reserve(strings.size() + 1);
		for(const auto &converted : strings)
		{
			array.push_back(converted.c_str());
		}
		array.push_back(nullptr);
	}
	// Yields NULL if no MIME types were given, otherwise the NULL-terminated pointer array.
	operator LPCTSTR*()
	{
		if(strings.empty())
		{
			return NULL;
		}
		return array.data();
	}
};
}
// Forwards a progress notification to progressCallback; does nothing if no callback is set.
void Request::progress(Progress progress, uint64 transferred, std::optional<uint64> expectedSize) const
{
	if(!progressCallback)
	{
		return;
	}
	progressCallback(progress, transferred, expectedSize);
}
// Executes this request through the given WinInet session and returns the result.
//
// Each stage is reported through progress(). The response body is written to
// outputStream if one is set; otherwise it is collected and returned in Result::Data.
// Result::Status receives the HTTP status code, and Result::ContentLength is set only
// if the content length could be queried.
//
// Throws HTTP::LastErrorException if a WinInet call fails, and HTTP::exception if
// writing to outputStream fails.
Result Request::operator()(InternetSession &internet) const
{
	progress(Progress::Start, 0, std::nullopt);

	// Resolve the default port from the protocol.
	Port actualPort = port;
	if(actualPort == Port::Default)
	{
		actualPort = (protocol != Protocol::HTTP) ? Port::HTTPS : Port::HTTP;
	}

	Handle connection = NativeHandle(InternetConnect(
		NativeHandle(internet),
		mpt::ToWin(host).c_str(),
		static_cast<uint16>(actualPort),
		!username.empty() ? mpt::ToWin(username).c_str() : NULL,
		!password.empty() ? mpt::ToWin(password).c_str() : NULL,
		INTERNET_SERVICE_HTTP,
		0,
		0));
	if(!connection)
	{
		throw HTTP::LastErrorException();
	}
	progress(Progress::ConnectionEstablished, 0, std::nullopt);

	// Build the request target: path plus "?key=value&key..." if query arguments are present.
	// Arguments with an empty value are emitted as a bare key.
	mpt::ustring queryPath = path;
	if(!query.empty())
	{
		std::vector<mpt::ustring> arguments;
		for(const auto &[key, value] : query)
		{
			if(!value.empty())
			{
				arguments.push_back(MPT_UFORMAT("{}={}")(key, value));
			} else
			{
				arguments.push_back(MPT_UFORMAT("{}")(key));
			}
		}
		queryPath += U_("?") + mpt::String::Combine(arguments, U_("&"));
	}

	// The request must be opened with queryPath (not the bare path),
	// otherwise the query arguments never reach the server.
	Handle request = NativeHandle(HttpOpenRequest(
		NativeHandle(connection),
		Verb(method).c_str(),
		mpt::ToWin(queryPath).c_str(),
		NULL,
		!referrer.empty() ? mpt::ToWin(referrer).c_str() : NULL,
		AcceptMimeTypesWrapper(acceptMimeTypes),
		0
			| ((protocol != Protocol::HTTP) ? INTERNET_FLAG_SECURE : 0)
			| (IsCachable(method) ? 0 : INTERNET_FLAG_NO_CACHE_WRITE)
			| ((flags & NoCache) ? (INTERNET_FLAG_PRAGMA_NOCACHE | INTERNET_FLAG_RELOAD | INTERNET_FLAG_NO_CACHE_WRITE) : 0)
			| ((flags & AutoRedirect) ? 0 : INTERNET_FLAG_NO_AUTO_REDIRECT)
		,
		NULL));
	if(!request)
	{
		throw HTTP::LastErrorException();
	}
	progress(Progress::RequestOpened, 0, std::nullopt);

	{
		// Assemble additional request headers, one "Name: value\r\n" line each.
		std::string headersString;
		if(!dataMimeType.empty())
		{
			headersString += MPT_AFORMAT("Content-type: {}\r\n")(dataMimeType);
		}
		for(const auto &[key, value] : headers)
		{
			headersString += MPT_AFORMAT("{}: {}\r\n")(key, value);
		}
		// Convert once; both the pointer and the length below refer to this string.
		const auto nativeHeaders = mpt::ToWin(mpt::Charset::ASCII, headersString);
		if(HttpSendRequest(
			NativeHandle(request),
			!headersString.empty() ? nativeHeaders.c_str() : NULL,
			!headersString.empty() ? mpt::saturate_cast<DWORD>(nativeHeaders.length()) : 0,
			!data.empty() ? (LPVOID)data.data() : NULL,
			!data.empty() ? mpt::saturate_cast<DWORD>(data.size()) : 0)
			== FALSE)
		{
			throw HTTP::LastErrorException();
		}
	}
	progress(Progress::RequestSent, 0, std::nullopt);

	Result result;
	{
		DWORD statusCode = 0;
		DWORD length = sizeof(statusCode);
		if(HttpQueryInfo(NativeHandle(request), HTTP_QUERY_STATUS_CODE | HTTP_QUERY_FLAG_NUMBER, &statusCode, &length, NULL) == FALSE)
		{
			throw HTTP::LastErrorException();
		}
		result.Status = statusCode;
	}
	progress(Progress::ResponseReceived, 0, std::nullopt);

	// The content length is optional: DWORD(-1) marks "unknown", and a failing query is not an error.
	DWORD contentLength = static_cast<DWORD>(-1);
	{
		DWORD length = sizeof(contentLength);
		if(HttpQueryInfo(NativeHandle(request), HTTP_QUERY_CONTENT_LENGTH | HTTP_QUERY_FLAG_NUMBER, &contentLength, &length, NULL) == FALSE)
		{
			contentLength = static_cast<DWORD>(-1);
		}
	}
	uint64 transferred = 0;
	if(contentLength != static_cast<DWORD>(-1))
	{
		result.ContentLength = contentLength;
	}
	std::optional<uint64> expectedSize = result.ContentLength;
	progress(Progress::TransferBegin, transferred, expectedSize);

	{
		std::vector<std::byte> resultBuffer;
		DWORD bytesRead = 0;
		// Read the body in chunks of at most BUFFERSIZE_TINY bytes until a read yields zero bytes.
		do
		{
			std::array<std::byte, mpt::IO::BUFFERSIZE_TINY> downloadBuffer;
			DWORD availableSize = 0;
			if(InternetQueryDataAvailable(NativeHandle(request), &availableSize, 0, NULL) == FALSE)
			{
				throw HTTP::LastErrorException();
			}
			// Never request more than fits into the local buffer.
			availableSize = std::clamp(availableSize, DWORD(0), mpt::saturate_cast<DWORD>(mpt::IO::BUFFERSIZE_TINY));
			if(InternetReadFile(NativeHandle(request), downloadBuffer.data(), availableSize, &bytesRead) == FALSE)
			{
				throw HTTP::LastErrorException();
			}
			if(outputStream)
			{
				if(!mpt::IO::WriteRaw(*outputStream, mpt::as_span(downloadBuffer).first(bytesRead)))
				{
					throw HTTP::exception(U_("Writing output file failed."));
				}
			} else
			{
				mpt::append(resultBuffer, downloadBuffer.data(), downloadBuffer.data() + bytesRead);
			}
			transferred += bytesRead;
			progress(Progress::TransferRunning, transferred, expectedSize);
		} while(bytesRead != 0);
		result.Data = std::move(resultBuffer);
	}
	progress(Progress::TransferDone, transferred, expectedSize);
	return result;
}
// Configures protocol, host, port, credentials, path and query arguments from a parsed URI.
// Only the "http" and "https" schemes are accepted; the fragment is ignored.
// Throws bad_uri if the scheme is empty or is neither "http" nor "https".
Request &Request::SetURI(const URI &uri)
{
	if(uri.scheme == U_(""))
	{
		throw bad_uri("no scheme");
	}
	if(uri.scheme == U_("http"))
	{
		protocol = HTTP::Protocol::HTTP;
	} else if(uri.scheme == U_("https"))
	{
		protocol = HTTP::Protocol::HTTPS;
	} else
	{
		throw bad_uri("wrong scheme");
	}

	host = uri.host;

	// Without an explicit port, the default port is selected.
	if(uri.port.empty())
	{
		port = HTTP::Port::Default;
	} else
	{
		port = HTTP::Port(ConvertStrTo<uint16>(uri.port));
	}

	username = uri.username;
	password = uri.password;

	// An empty path is requested as "/".
	path = uri.path;
	if(path.empty())
	{
		path = U_("/");
	}

	// Split "k1=v1&k2=v2&..." into key/value pairs; an argument without '=' gets an empty value.
	query.clear();
	for(const auto &argument : mpt::String::Split<mpt::ustring>(uri.query, U_("&")))
	{
		const std::size_t valueDelim = argument.find(U_("="));
		if(valueDelim == mpt::ustring::npos)
		{
			query.push_back(std::make_pair(argument, mpt::ustring()));
		} else
		{
			query.push_back(std::make_pair(argument.substr(0, valueDelim), argument.substr(valueDelim + 1)));
		}
	}

	// ignore fragment
	return *this;
}
#if defined(MPT_BUILD_RETRO)
// On an original (non-emulated) Windows older than Vista, switches an HTTPS request to plain HTTP
// (protocol and, if explicitly set to the HTTPS port, the port as well). No-op otherwise.
Request &Request::InsecureTLSDowngradeWindowsXP()
{
	const bool isPreVistaWindows = mpt::OS::Windows::IsOriginal() && mpt::OS::Windows::Version::Current().IsBefore(mpt::OS::Windows::Version::WinVista);
	if(!isPreVistaWindows)
	{
		return *this;
	}
	// TLS 1.0 is not enabled by default until IE7. Since WinInet won't let us override this setting, we cannot assume that HTTPS
	// is going to work on older systems. Besides... Windows XP is already enough of a security risk by itself. :P
	if(protocol == Protocol::HTTPS)
	{
		protocol = Protocol::HTTP;
	}
	if(port == Port::HTTPS)
	{
		port = Port::HTTP;
	}
	return *this;
}
#endif // MPT_BUILD_RETRO
// Convenience helper: issues a GET request for `path` on `host` over the given protocol.
// All other request settings keep the defaults of a default-constructed Request.
Result SimpleGet(InternetSession &internet, Protocol protocol, const mpt::ustring &host, const mpt::ustring &path)
{
	HTTP::Request getRequest;
	getRequest.method = HTTP::Method::Get;
	getRequest.protocol = protocol;
	getRequest.host = host;
	getRequest.path = path;
	return internet(getRequest);
}
} // namespace HTTP
OPENMPT_NAMESPACE_END