123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461 |
- #include "sp_firewallControl.h"
- #include <comutil.h>
- #include <map>
- #include "path.h"
- #include <locale>
- #include <codecvt>
- #include <string>
- // 静态成员初始化
- INetFwPolicy2* FirewallController::firewallPolicy = nullptr;
- bool FirewallController::comInitialized = false;
- // 通配符匹配实现(支持*和?)
- bool FirewallController::WildcardMatch(const std::wstring& pattern, const std::wstring& text) {
- size_t p = 0, s = 0;
- while (1) {
- if (p == pattern.size() && s == text.size()) return true;
- if (p == pattern.size()) return false;
- if (s == text.size()) return p + 1 == pattern.size() && pattern[p] == L'*';
- if (pattern[p] == text[s] || pattern[p] == L'?') {
- p++; s++;
- continue;
- }
- if (pattern[p] == L'*') {
- if (p + 1 == pattern.size()) return true;
- do {
- if (WildcardMatch(pattern.substr(p + 1), text.substr(s)))
- return true;
- s++;
- } while (s != text.size());
- return false;
- }
- return false;
- }
- }
- bool FirewallController::Initialize() {
- if (comInitialized) return true;
- HRESULT hr = CoInitializeEx(0, COINIT_APARTMENTTHREADED);
- if (FAILED(hr)) return false;
- hr = CoCreateInstance(__uuidof(NetFwPolicy2), nullptr, CLSCTX_INPROC_SERVER,
- __uuidof(INetFwPolicy2), (void**)&firewallPolicy);
- comInitialized = SUCCEEDED(hr);
- return comInitialized;
- }
- std::vector<FirewallRuleInfo> FirewallController::QueryRules(const std::wstring& ruleNamePattern) {
- std::vector<FirewallRuleInfo> matchedRules;
- if (!comInitialized || !firewallPolicy) return matchedRules;
- INetFwRules* rules = nullptr;
- if (FAILED(firewallPolicy->get_Rules(&rules))) return matchedRules;
- IEnumVARIANT* enumerator = nullptr;
- if (SUCCEEDED(rules->get__NewEnum((IUnknown**)&enumerator))) {
- VARIANT var;
- while (enumerator->Next(1, &var, nullptr) == S_OK) {
- INetFwRule* rule = nullptr;
- if (SUCCEEDED(V_DISPATCH(&var)->QueryInterface(__uuidof(INetFwRule), (void**)&rule))) {
- BSTR name;
- if (SUCCEEDED(rule->get_Name(&name))) {
- if (WildcardMatch(ruleNamePattern, name)) {
- FirewallRuleInfo info;
- info.name = name;
- BSTR desc, app, service;
- rule->get_Description(&desc);
- rule->get_ApplicationName(&app);
- rule->get_ServiceName(&service);
- NET_FW_RULE_DIRECTION direction;
- rule->get_Direction(&direction);
- info.direction = static_cast<long>(direction);
- VARIANT_BOOL enabled;
- rule->get_Enabled(&enabled);
- info.description = desc ? desc : L"";
- info.applicationName = app ? app : L"";
- info.serviceName = service ? service : L"";
- info.enabled = enabled == VARIANT_TRUE;
- matchedRules.push_back(info);
- SysFreeString(desc);
- SysFreeString(app);
- SysFreeString(service);
- }
- SysFreeString(name);
- }
- rule->Release();
- }
- VariantClear(&var);
- }
- enumerator->Release();
- }
- rules->Release();
- return matchedRules;
- }
- void FirewallController::Shutdown() {
- if (firewallPolicy)
- {
- firewallPolicy->Release();
- firewallPolicy = nullptr;
- }
- if (comInitialized)
- {
- CoUninitialize();
- comInitialized = false;
- }
- }
- bool FirewallController::AddFirewallRule(
- const std::wstring& ruleName,
- const std::wstring& appPath,
- FirewallRuleDirection direction,
- FirewallRuleAction action,
- const std::wstring& protocol,
- const std::wstring& localPorts,
- const std::wstring& remoteAddresses,
- const std::wstring& description)
- {
- if (!comInitialized || !firewallPolicy) return false;
- INetFwRule* rule = nullptr;
- HRESULT hr = CoCreateInstance(
- __uuidof(NetFwRule),
- nullptr,
- CLSCTX_INPROC_SERVER,
- __uuidof(INetFwRule),
- (void**)&rule
- );
- if (FAILED(hr)) return false;
- // 设置规则属性
- rule->put_Name(_bstr_t(ruleName.c_str()));
- if (!appPath.empty()) rule->put_ApplicationName(_bstr_t(appPath.c_str()));
- if (!localPorts.empty()) rule->put_LocalPorts(_bstr_t(localPorts.c_str()));
-
- LONG protocolValue = 0;
- if (protocol == L"TCP") protocolValue = NET_FW_IP_PROTOCOL_TCP;
- else if (protocol == L"UDP") protocolValue = NET_FW_IP_PROTOCOL_UDP;
- else protocolValue = NET_FW_IP_PROTOCOL_ANY; // 默认值
- rule->put_Protocol(protocolValue); // 传入LONG类型值
- rule->put_RemoteAddresses(_bstr_t(remoteAddresses.c_str()));
- rule->put_Direction(static_cast<NET_FW_RULE_DIRECTION>(direction));
- rule->put_Action(static_cast<NET_FW_ACTION>(action));
- rule->put_Enabled(VARIANT_TRUE);
- if (!description.empty()) rule->put_Description(_bstr_t(description.c_str()));
- // 应用到所有配置文件
- rule->put_Profiles(NET_FW_PROFILE2_DOMAIN | NET_FW_PROFILE2_PRIVATE | NET_FW_PROFILE2_PUBLIC);
- // 添加规则
- INetFwRules* rules = nullptr;
- hr = firewallPolicy->get_Rules(&rules);
- if (SUCCEEDED(hr)) {
- hr = rules->Add(rule);
- rules->Release();
- }
- rule->Release();
- if (hr == S_OK) {
- return true; // 明确表示成功
- } else if (hr == E_ACCESSDENIED) {
- // 处理权限错误
- return false;
- }
- else {
- return false;
- }
- }
- bool FirewallController::DeleteFirewallRule(const std::wstring& ruleName) {
- if (!comInitialized || !firewallPolicy) return false;
- INetFwRules* rules = nullptr;
- HRESULT hr = firewallPolicy->get_Rules(&rules);
- if (FAILED(hr)) return false;
- hr = rules->Remove(_bstr_t(ruleName.c_str()));
- rules->Release();
- if (hr == S_OK) {
- return true; // 明确表示成功
- } else if (hr == E_ACCESSDENIED) {
- // 处理权限错误
- return false;
- }
- else{
- return false;
- }
- }
- bool FirewallController::CleanupRulesExceptWhitelist(
- const std::wstring& ruleNamePattern,
- const std::vector<std::wstring>& whitelistPaths) {
- if (!comInitialized || !firewallPolicy) return false;
- auto matchedRules = QueryRules(ruleNamePattern);
- if (matchedRules.empty()) return true;
- INetFwRules* rules = nullptr;
- HRESULT hr = firewallPolicy->get_Rules(&rules);
- if (FAILED(hr)) return false;
- bool allSuccess = true;
- for (const auto& ruleInfo : matchedRules) {
- bool isWhitelisted = false;
- for (const auto& whitelistPath : whitelistPaths) {
- if (WildcardMatch(whitelistPath, ruleInfo.applicationName)) {
- isWhitelisted = true;
- break;
- }
- }
- if (!isWhitelisted) {
- hr = rules->Remove(_bstr_t(ruleInfo.name.c_str()));
- if (FAILED(hr)) allSuccess = false;
- }
- }
- rules->Release();
- return allSuccess;
- }
- std::wstring charToWstring(const char* szIn) {
- int length = MultiByteToWideChar(CP_ACP, 0, szIn, -1, NULL, 0);
- wchar_t* buf = new wchar_t[length];
- MultiByteToWideChar(CP_ACP, 0, szIn, -1, buf, length);
- std::wstring result(buf);
- delete[] buf;
- return result;
- }
- std::string WStringToString(const std::wstring& wstr, unsigned int codepage = CP_ACP) {
- if (wstr.empty()) return "";
- int len = WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL);
- if (len <= 0) return "";
- std::string result(len, '\0');
- WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(), &result[0], len, NULL, NULL);
- return result;
- }
- bool sp_AddFirewallRule(const char *ruleName, const char *appPath)
- {
- int ret = -1;
- if (FirewallController::Initialize()) {
- std::wstring ruleNameW = charToWstring(ruleName);
- std::wstring appPathW = charToWstring(appPath);
- ret = FirewallController::AddFirewallRule(ruleNameW, appPathW,
- FirewallRuleDirection::Inbound, FirewallRuleAction::Allow) ? 0 : -1;
- FirewallController::Shutdown();
- }
- return ret;
- }
// Returns the path component between the first and second separator
// ('\\' or '/'), e.g. "D:\\Runxxx\\version" -> "Runxxx". Returns an empty
// string when the path contains no separator; when there is no second
// separator the component runs to the end of the path.
std::string getFirstLevelDir(const std::string& path) {
    static const char* const kSeparators = "\\/";

    const std::string::size_type firstSep = path.find_first_of(kSeparators);
    if (firstSep == std::string::npos) {
        return std::string();
    }

    const std::string::size_type componentBegin = firstSep + 1;
    std::string::size_type componentEnd = path.find_first_of(kSeparators, componentBegin);
    if (componentEnd == std::string::npos) {
        componentEnd = path.length();
    }
    return path.substr(componentBegin, componentEnd - componentBegin);
}
// Returns everything after the last separator ('\\' or '/') in `path`, or
// the whole string when it contains no separator. A path that ends with a
// separator therefore yields an empty string.
std::string getLastFolderName(std::string path) {
    const std::string::size_type lastSep = path.find_last_of("\\/");
    if (lastSep == std::string::npos) {
        return path;
    }
    // `path` is a by-value copy, so trim it in place.
    path.erase(0, lastSep + 1);
    return path;
}
- bool sp_AddFirewallRuleByPath(const char *pszPath)
- {
- bool returnRet = true;
- //pszPath input path D:\\Runxxx\\version\\7.1.1.1
- std::string inputPath = std::string(pszPath);
- while (!inputPath.empty() && (inputPath.back() == '\\' || inputPath.back() == '/')) {
- inputPath.pop_back();
- }
- std::string firstLevelDir = getFirstLevelDir(inputPath);
- std::string lastDir = getLastFolderName(inputPath);
- std::string header = firstLevelDir + "_" + lastDir;
- std::map<std::string, std::string> firewallRule;
- firewallRule[header + "_guardian"] = inputPath + "\\bin\\guardian.exe";
- firewallRule[header + "_sphost"] = inputPath + "\\bin\\sphost.exe";
- firewallRule[header + "_spshell"] = inputPath + "\\bin\\spshell.exe";
- firewallRule[header + "_cefclient"] = inputPath + "\\bin\\Chromium\\cefclient.exe";
- FirewallController::Initialize();
- for(auto &it : firewallRule)
- {
- std::wstring ruleNameW = charToWstring(it.first.c_str());
- std::wstring appPathW = charToWstring(it.second.c_str());
- bool ret = FirewallController::AddFirewallRule(ruleNameW, appPathW, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
- DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
- ("Add firewall rule %s. firstLevelDir: %s, ruleName: %s, path: %s", ret ? "success" : "failed",
- firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
- returnRet = returnRet && ret;
- }
- FirewallController::Shutdown();
- return returnRet;
- }
- void listFolders(const std::string& dirPath, std::vector<std::string>& folders) {
- folders.clear();
- WIN32_FIND_DATA findData;
- HANDLE hFind = FindFirstFile((dirPath + "\\*").c_str(), &findData);
- if (hFind == INVALID_HANDLE_VALUE) {
- return;
- }
- do {
- if (findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) {
- if (strcmp(findData.cFileName, ".") != 0 && strcmp(findData.cFileName, "..") != 0) {
- folders.push_back(findData.cFileName); // 存入vector
- }
- }
- } while (FindNextFile(hFind, &findData));
- FindClose(hFind);
- }
using Multimap = std::multimap<std::string, std::string>;

// Compares two multimaps by (key, value) pair.
//   extra   - receives, in `target` order, each pair of `target` that has no
//             identical pair in `base`.
//   missing - receives, in `base` order, each pair of `base` that has no
//             identical pair in `target`.
// Both vectors are appended to; they are not cleared first.
void compareMultimaps(const Multimap& base, const Multimap& target,
                      std::vector<std::pair<std::string, std::string>>& extra,
                      std::vector<std::pair<std::string, std::string>>& missing) {
    // Appends to `out` every entry of `from` that has no equal (key, value)
    // entry in `other`.
    const auto collectUnmatched = [](const Multimap& from, const Multimap& other,
                                     std::vector<std::pair<std::string, std::string>>& out) {
        for (Multimap::const_iterator entry = from.begin(); entry != from.end(); ++entry) {
            const auto sameKey = other.equal_range(entry->first);
            bool matched = false;
            for (Multimap::const_iterator candidate = sameKey.first;
                 candidate != sameKey.second; ++candidate) {
                if (candidate->second == entry->second) {
                    matched = true;
                    break;
                }
            }
            if (!matched) {
                out.emplace_back(entry->first, entry->second);
            }
        }
    };

    collectUnmatched(target, base, extra);    // present in target only
    collectUnmatched(base, target, missing);  // present in base only
}
- bool sp_CheckAllRules()
- {
- if (FirewallController::Initialize() == false)
- {
- DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("FirewallController::Initialize failed");
- return false;
- }
- //get current first level dir and version path
- char szVersionDir[MAX_PATH] = {};
- GetModuleFileNameA(NULL, szVersionDir, MAX_PATH);
- *strrchr(szVersionDir, SPLIT_SLASH) = 0;
- *strrchr(szVersionDir, SPLIT_SLASH) = 0;
- std::string currentVerDir = szVersionDir;
- *strrchr(szVersionDir, SPLIT_SLASH) = 0;
- std::string versionDir = szVersionDir;
- while (!versionDir.empty() && (versionDir.back() == '\\' || versionDir.back() == '/')) {
- versionDir.pop_back();
- }
- //basic check
- std::string firstLevelDir = getFirstLevelDir(versionDir);
-
- std::vector<std::string> WhiteListDirs;
- listFolders(versionDir, WhiteListDirs);
- if (WhiteListDirs.size() == 0)
- {
- DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("WhiteListDirs is empty. versionDir: %s", versionDir.c_str());
- return false;
- }
-
- // firewall ruls I guess
- Multimap firewallRuleAll;
- for(auto &it : WhiteListDirs)
- {
- std::string lastDir = getLastFolderName(it);
- std::string header = firstLevelDir + "_" + lastDir;
- std::string currentVerDir = versionDir + "\\" + it;
- firewallRuleAll.insert({header + "_guardian", currentVerDir + "\\bin\\guardian.exe"});
- firewallRuleAll.insert({header + "_sphost", currentVerDir + "\\bin\\sphost.exe"});
- firewallRuleAll.insert({header + "_spshell", currentVerDir + "\\bin\\spshell.exe"});
- firewallRuleAll.insert({header + "_cefclient", currentVerDir + "\\bin\\Chromium\\cefclient.exe"});
- }
-
- //query all rules
- Multimap ruleArr;
- std::string pattern = firstLevelDir + "_*";
- std::wstring headerW = charToWstring(pattern.c_str());
- auto rules = FirewallController::QueryRules(headerW);
- for (const auto& ruleInfo : rules) {
- std::string ruleName = WStringToString(ruleInfo.name);
- std::string appPath = WStringToString(ruleInfo.applicationName);
- ruleArr.insert(std::make_pair(ruleName, appPath));// can repeated
- }
- for(auto &it : firewallRuleAll)
- {
- DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("guess firewall rule:%s, %s", it.first.c_str(), it.second.c_str());
- }
- for(auto &it : ruleArr)
- {
- DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("current firewall rule:%s, %s", it.first.c_str(), it.second.c_str());
- }
- std::vector<std::pair<std::string, std::string>> extra, missing;
- compareMultimaps(firewallRuleAll, ruleArr, extra, missing);
- for(auto &it : extra)
- {
- std::wstring ruleName = charToWstring(it.first.c_str());
- bool ret = FirewallController::DeleteFirewallRule(ruleName);
- DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
- ("Delete firewall rule %s. ruleName: %s", ret ? "success" : "failed", it.first.c_str());
- }
- for(auto &it : missing)
- {
- std::wstring ruleName = charToWstring(it.first.c_str());
- std::wstring appPath = charToWstring(it.second.c_str());
- bool ret = FirewallController::AddFirewallRule(ruleName, appPath, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
- DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
- ("Add firewall rule %s. header: %s, ruleName: %s, path: %s", ret ? "success" : "failed",
- firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
- }
-
- FirewallController::Shutdown();
- return true;
- }
|