#include "sp_firewallControl.h"

#include <windows.h>
#include <comdef.h>
#include <netfw.h>

#include "path.h"

#include <algorithm>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Static member initialization.
INetFwPolicy2* FirewallController::firewallPolicy = nullptr;
bool FirewallController::comInitialized = false;

namespace {

// True when Initialize() obtained its own COM initialization on this thread
// (CoInitializeEx returned S_OK or S_FALSE) and therefore owes exactly one
// CoUninitialize(). False when COM was already initialized on the thread with a
// different concurrency model (RPC_E_CHANGED_MODE): COM is usable, but that
// failed CoInitializeEx call must not be balanced.
bool g_ownsComInit = false;

// Copies a possibly-null BSTR into a std::wstring (null becomes empty).
std::wstring BstrToWstring(BSTR value)
{
    return value ? std::wstring(value) : std::wstring();
}

// Adds the expected (ruleName -> executable path) entries for one version
// directory. Shared by sp_AddFirewallRuleByPath and sp_CheckAllRules so both
// always agree on the rule set. Works for std::map and std::multimap.
template <typename RuleMap>
void AppendVersionRules(RuleMap& rules, const std::string& header, const std::string& versionPath)
{
    rules.emplace(header + "_guardian", versionPath + "\\bin\\guardian.exe");
    rules.emplace(header + "_sphost", versionPath + "\\bin\\sphost.exe");
    rules.emplace(header + "_spshell", versionPath + "\\bin\\spshell.exe");
    rules.emplace(header + "_cefclient", versionPath + "\\bin\\Chromium\\cefclient.exe");
}

} // namespace

// Wildcard match supporting '*' (any run of characters, including none) and
// '?' (exactly one character). Comparison is case-sensitive.
//
// Iterative greedy matcher with single-point backtracking: no allocations and
// O(|pattern| * |text|) worst case. '*' in the pattern is always treated as a
// wildcard, even when the text contains a literal '*' at that position.
bool FirewallController::WildcardMatch(const std::wstring& pattern, const std::wstring& text)
{
    size_t p = 0;
    size_t s = 0;
    size_t starP = std::wstring::npos; // index of the most recent '*' in pattern
    size_t starS = 0;                  // text index that '*' currently matches up to

    while (s < text.size()) {
        if (p < pattern.size() && pattern[p] == L'*') {
            // Tentatively let '*' match nothing; remember where to resume.
            starP = p++;
            starS = s;
        } else if (p < pattern.size() && (pattern[p] == L'?' || pattern[p] == text[s])) {
            ++p;
            ++s;
        } else if (starP != std::wstring::npos) {
            // Mismatch: let the last '*' absorb one more character and retry.
            p = starP + 1;
            s = ++starS;
        } else {
            return false;
        }
    }

    // Text fully consumed: whatever is left of the pattern must be all '*'.
    while (p < pattern.size() && pattern[p] == L'*')
        ++p;
    return p == pattern.size();
}

// Initializes COM on the calling thread (STA when possible) and creates the
// INetFwPolicy2 object. Returns true when the controller is ready for use.
// Idempotent: returns true immediately when already initialized.
bool FirewallController::Initialize()
{
    if (comInitialized)
        return true;

    HRESULT hr = CoInitializeEx(nullptr, COINIT_APARTMENTTHREADED);
    if (hr == RPC_E_CHANGED_MODE) {
        // Thread already runs COM in another apartment model; COM is usable,
        // but this call did not initialize anything we must undo.
        g_ownsComInit = false;
    } else if (FAILED(hr)) {
        return false;
    } else {
        // S_OK and S_FALSE both require a balancing CoUninitialize().
        g_ownsComInit = true;
    }

    hr = CoCreateInstance(__uuidof(NetFwPolicy2), nullptr, CLSCTX_INPROC_SERVER,
                          __uuidof(INetFwPolicy2), reinterpret_cast<void**>(&firewallPolicy));
    if (FAILED(hr) || !firewallPolicy) {
        // Undo our COM initialization; previously it was leaked on this path.
        firewallPolicy = nullptr;
        if (g_ownsComInit) {
            CoUninitialize();
            g_ownsComInit = false;
        }
        return false;
    }

    comInitialized = true;
    return true;
}

// Returns every firewall rule whose name matches ruleNamePattern (see
// WildcardMatch). Returns an empty vector when not initialized or on error.
std::vector<FirewallRuleInfo> FirewallController::QueryRules(const std::wstring& ruleNamePattern)
{
    std::vector<FirewallRuleInfo> matchedRules;
    if (!comInitialized || !firewallPolicy)
        return matchedRules;

    INetFwRules* rules = nullptr;
    if (FAILED(firewallPolicy->get_Rules(&rules)) || !rules)
        return matchedRules;

    // get__NewEnum hands back an IUnknown; IEnumVARIANT must be obtained via
    // QueryInterface rather than a pointer cast.
    IEnumVARIANT* enumerator = nullptr;
    IUnknown* enumUnknown = nullptr;
    if (SUCCEEDED(rules->get__NewEnum(&enumUnknown)) && enumUnknown) {
        if (FAILED(enumUnknown->QueryInterface(__uuidof(IEnumVARIANT),
                                               reinterpret_cast<void**>(&enumerator))))
            enumerator = nullptr;
        enumUnknown->Release();
    }

    if (enumerator) {
        VARIANT var;
        VariantInit(&var);
        while (enumerator->Next(1, &var, nullptr) == S_OK) {
            INetFwRule* rule = nullptr;
            if (V_VT(&var) == VT_DISPATCH && V_DISPATCH(&var) &&
                SUCCEEDED(V_DISPATCH(&var)->QueryInterface(__uuidof(INetFwRule),
                                                           reinterpret_cast<void**>(&rule))) &&
                rule) {
                BSTR name = nullptr;
                if (SUCCEEDED(rule->get_Name(&name))) {
                    const std::wstring ruleName = BstrToWstring(name);
                    if (WildcardMatch(ruleNamePattern, ruleName)) {
                        // Pre-initialize outputs so a failing getter leaves a
                        // defined value (SysFreeString(nullptr) is a no-op).
                        BSTR desc = nullptr;
                        BSTR app = nullptr;
                        BSTR service = nullptr;
                        NET_FW_RULE_DIRECTION direction = NET_FW_RULE_DIR_IN;
                        VARIANT_BOOL enabled = VARIANT_FALSE;
                        rule->get_Description(&desc);
                        rule->get_ApplicationName(&app);
                        rule->get_ServiceName(&service);
                        rule->get_Direction(&direction);
                        rule->get_Enabled(&enabled);

                        FirewallRuleInfo info;
                        info.name = ruleName;
                        info.direction = static_cast<decltype(info.direction)>(direction);
                        info.description = BstrToWstring(desc);
                        info.applicationName = BstrToWstring(app);
                        info.serviceName = BstrToWstring(service);
                        info.enabled = (enabled == VARIANT_TRUE);
                        matchedRules.push_back(std::move(info));

                        SysFreeString(desc);
                        SysFreeString(app);
                        SysFreeString(service);
                    }
                    SysFreeString(name);
                }
                rule->Release();
            }
            VariantClear(&var);
        }
        enumerator->Release();
    }
    rules->Release();
    return matchedRules;
}

// Releases the policy object and undoes the COM initialization performed by
// Initialize() (only if Initialize() actually owns one).
void FirewallController::Shutdown()
{
    if (firewallPolicy) {
        firewallPolicy->Release();
        firewallPolicy = nullptr;
    }
    if (comInitialized) {
        if (g_ownsComInit) {
            CoUninitialize();
            g_ownsComInit = false;
        }
        comInitialized = false;
    }
}

// Creates and registers a firewall rule enabled on the domain, private and
// public profiles. Empty appPath / localPorts / remoteAddresses / description
// leave the corresponding property at the API default. protocol is "TCP",
// "UDP", or anything else for "any protocol".
// Returns false when not initialized, when any property cannot be set, or when
// adding the rule fails (E_ACCESSDENIED when the process is not elevated).
bool FirewallController::AddFirewallRule(
    const std::wstring& ruleName,
    const std::wstring& appPath,
    FirewallRuleDirection direction,
    FirewallRuleAction action,
    const std::wstring& protocol,
    const std::wstring& localPorts,
    const std::wstring& remoteAddresses,
    const std::wstring& description)
{
    if (!comInitialized || !firewallPolicy)
        return false;

    INetFwRule* rule = nullptr;
    HRESULT hr = CoCreateInstance(__uuidof(NetFwRule), nullptr, CLSCTX_INPROC_SERVER,
                                  __uuidof(INetFwRule), reinterpret_cast<void**>(&rule));
    if (FAILED(hr) || !rule)
        return false;

    LONG protocolValue = NET_FW_IP_PROTOCOL_ANY; // default
    if (protocol == L"TCP")
        protocolValue = NET_FW_IP_PROTOCOL_TCP;
    else if (protocol == L"UDP")
        protocolValue = NET_FW_IP_PROTOCOL_UDP;

    // Set rule properties. Every setter is checked: silently dropping e.g. the
    // application path or the port list would register a far broader allow
    // rule than requested. Protocol must be set before LocalPorts, otherwise
    // put_LocalPorts is rejected.
    hr = rule->put_Name(_bstr_t(ruleName.c_str()));
    if (SUCCEEDED(hr) && !appPath.empty())
        hr = rule->put_ApplicationName(_bstr_t(appPath.c_str()));
    if (SUCCEEDED(hr))
        hr = rule->put_Protocol(protocolValue);
    if (SUCCEEDED(hr) && !localPorts.empty())
        hr = rule->put_LocalPorts(_bstr_t(localPorts.c_str()));
    if (SUCCEEDED(hr) && !remoteAddresses.empty())
        hr = rule->put_RemoteAddresses(_bstr_t(remoteAddresses.c_str()));
    if (SUCCEEDED(hr))
        hr = rule->put_Direction(static_cast<NET_FW_RULE_DIRECTION>(direction));
    if (SUCCEEDED(hr))
        hr = rule->put_Action(static_cast<NET_FW_ACTION>(action));
    if (SUCCEEDED(hr))
        hr = rule->put_Enabled(VARIANT_TRUE);
    if (SUCCEEDED(hr) && !description.empty())
        hr = rule->put_Description(_bstr_t(description.c_str()));
    if (SUCCEEDED(hr)) // apply to all profiles
        hr = rule->put_Profiles(NET_FW_PROFILE2_DOMAIN | NET_FW_PROFILE2_PRIVATE | NET_FW_PROFILE2_PUBLIC);

    // Register the rule.
    if (SUCCEEDED(hr)) {
        INetFwRules* rules = nullptr;
        hr = firewallPolicy->get_Rules(&rules);
        if (SUCCEEDED(hr) && rules) {
            hr = rules->Add(rule);
            rules->Release();
        } else if (SUCCEEDED(hr)) {
            hr = E_POINTER;
        }
    }
    rule->Release();

    return SUCCEEDED(hr);
}

// Removes the rule named ruleName. Returns true only when Remove returns S_OK;
// permission errors (E_ACCESSDENIED) and other failures return false.
bool FirewallController::DeleteFirewallRule(const std::wstring& ruleName)
{
    if (!comInitialized || !firewallPolicy)
        return false;

    INetFwRules* rules = nullptr;
    HRESULT hr = firewallPolicy->get_Rules(&rules);
    if (FAILED(hr) || !rules)
        return false;

    hr = rules->Remove(_bstr_t(ruleName.c_str()));
    rules->Release();
    return hr == S_OK;
}

// Removes every rule matching ruleNamePattern whose application path does not
// match any of the whitelistPaths wildcard patterns. Returns true when every
// required removal succeeded (or nothing matched).
bool FirewallController::CleanupRulesExceptWhitelist(
    const std::wstring& ruleNamePattern,
    const std::vector<std::wstring>& whitelistPaths)
{
    if (!comInitialized || !firewallPolicy)
        return false;

    const auto matchedRules = QueryRules(ruleNamePattern);
    if (matchedRules.empty())
        return true;

    INetFwRules* rules = nullptr;
    HRESULT hr = firewallPolicy->get_Rules(&rules);
    if (FAILED(hr) || !rules)
        return false;

    bool allSuccess = true;
    for (const auto& ruleInfo : matchedRules) {
        const bool isWhitelisted = std::any_of(
            whitelistPaths.begin(), whitelistPaths.end(),
            [&ruleInfo](const std::wstring& whitelistPath) {
                return WildcardMatch(whitelistPath, ruleInfo.applicationName);
            });
        if (!isWhitelisted) {
            hr = rules->Remove(_bstr_t(ruleInfo.name.c_str()));
            if (FAILED(hr))
                allSuccess = false;
        }
    }
    rules->Release();
    return allSuccess;
}

// Converts an ANSI (CP_ACP) C string to UTF-16. Returns an empty string for a
// null input or when the conversion fails (previously undefined behavior).
std::wstring charToWstring(const char* szIn)
{
    if (!szIn || !*szIn)
        return std::wstring();

    const int length = MultiByteToWideChar(CP_ACP, 0, szIn, -1, nullptr, 0);
    if (length <= 0)
        return std::wstring();

    std::wstring result(static_cast<size_t>(length), L'\0');
    if (MultiByteToWideChar(CP_ACP, 0, szIn, -1, &result[0], length) <= 0)
        return std::wstring();
    result.resize(static_cast<size_t>(length) - 1); // drop the converted terminator
    return result;
}

// Converts a UTF-16 string to a multibyte string in the given code page
// (CP_ACP by default). Returns an empty string on failure.
std::string WStringToString(const std::wstring& wstr, unsigned int codepage = CP_ACP)
{
    if (wstr.empty())
        return "";

    const int len = WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(),
                                        NULL, 0, NULL, NULL);
    if (len <= 0)
        return "";

    std::string result(len, '\0');
    WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(), &result[0], len, NULL, NULL);
    return result;
}

// Adds an inbound "allow" rule for appPath named ruleName, using a temporary
// firewall session. Returns true on success.
// NOTE(review): an empty appPath produces a rule that applies to all
// applications; callers are expected to pass a real path.
bool sp_AddFirewallRule(const char *ruleName, const char *appPath)
{
    if (!ruleName || !appPath)
        return false;
    if (!FirewallController::Initialize())
        return false;

    // Previously an int (0 = success, -1 = failure) was returned through bool,
    // which inverted the result: success became false and failure true.
    const bool added = FirewallController::AddFirewallRule(
        charToWstring(ruleName), charToWstring(appPath),
        FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
    FirewallController::Shutdown();
    return added;
}

// Returns the first path component after the first separator, e.g.
// "D:\\Runxxx\\version" -> "Runxxx". Returns "" when there is no separator.
std::string getFirstLevelDir(const std::string& path)
{
    const size_t start = path.find_first_of("\\/"); // first separator
    if (start == std::string::npos)
        return ""; // no separator
    size_t end = path.find_first_of("\\/", start + 1); // second separator
    if (end == std::string::npos)
        end = path.length();
    return path.substr(start + 1, end - start - 1); // e.g. "Runxxx"
}

// Returns the text after the last '\\' or '/', or the whole string when it
// contains no separator.
std::string getLastFolderName(std::string path)
{
    const size_t lastSlash = path.find_last_of("\\/");
    if (lastSlash != std::string::npos)
        return path.substr(lastSlash + 1);
    return path;
}

// Adds inbound "allow" rules for the executables of one version directory,
// e.g. pszPath = "D:\\Runxxx\\version\\7.1.1.1". Rule names are
// "<firstLevelDir>_<lastDir>_<exe>". Returns true when every rule was added.
bool sp_AddFirewallRuleByPath(const char *pszPath)
{
    if (!pszPath)
        return false;

    std::string inputPath(pszPath);
    while (!inputPath.empty() && (inputPath.back() == '\\' || inputPath.back() == '/'))
        inputPath.pop_back();

    const std::string firstLevelDir = getFirstLevelDir(inputPath);
    const std::string header = firstLevelDir + "_" + getLastFolderName(inputPath);

    std::map<std::string, std::string> firewallRule;
    AppendVersionRules(firewallRule, header, inputPath);

    if (!FirewallController::Initialize()) {
        DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("FirewallController::Initialize failed");
        return false;
    }

    bool returnRet = true;
    for (const auto& it : firewallRule) {
        const std::wstring ruleNameW = charToWstring(it.first.c_str());
        const std::wstring appPathW = charToWstring(it.second.c_str());
        const bool ret = FirewallController::AddFirewallRule(
            ruleNameW, appPathW, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
        DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
            ("Add firewall rule %s. firstLevelDir: %s, ruleName: %s, path: %s",
             ret ? "success" : "failed", firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
        returnRet = returnRet && ret;
    }
    FirewallController::Shutdown();
    return returnRet;
}

// Fills `folders` with the names (not full paths) of the immediate
// subdirectories of dirPath, excluding "." and "..". Leaves it empty when
// dirPath cannot be enumerated.
void listFolders(const std::string& dirPath, std::vector<std::string>& folders)
{
    folders.clear();

    // Explicit ANSI variants: this function works on char strings regardless
    // of whether UNICODE is defined.
    WIN32_FIND_DATAA findData;
    HANDLE hFind = FindFirstFileA((dirPath + "\\*").c_str(), &findData);
    if (hFind == INVALID_HANDLE_VALUE)
        return;

    do {
        const bool isDirectory = (findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0;
        if (isDirectory &&
            std::strcmp(findData.cFileName, ".") != 0 &&
            std::strcmp(findData.cFileName, "..") != 0) {
            folders.push_back(findData.cFileName);
        }
    } while (FindNextFileA(hFind, &findData));
    FindClose(hFind);
}

using Multimap = std::multimap<std::string, std::string>;

// Compares two multimaps entry-by-entry (key AND value).
// extra:   entries present in target but not in base.
// missing: entries present in base but not in target.
// Results are appended; the output vectors are not cleared first.
void compareMultimaps(const Multimap& base, const Multimap& target,
                      std::vector<std::pair<std::string, std::string>>& extra,
                      std::vector<std::pair<std::string, std::string>>& missing)
{
    // True when `haystack` holds an entry with exactly this key and value.
    auto containsEntry = [](const Multimap& haystack, const Multimap::value_type& entry) {
        const auto range = haystack.equal_range(entry.first);
        return std::any_of(range.first, range.second,
                           [&entry](const Multimap::value_type& candidate) {
                               return candidate.second == entry.second;
                           });
    };

    for (const auto& entry : target)
        if (!containsEntry(base, entry))
            extra.push_back(entry);

    for (const auto& entry : base)
        if (!containsEntry(target, entry))
            missing.push_back(entry);
}

// Reconciles the firewall with the installed versions. The expected rule set
// has one rule per executable for every subdirectory of the version directory.
// The version directory is derived from this module's path, three components
// up. Rules named "<firstLevelDir>_*" that are not expected are deleted, and
// expected rules that are absent are added. Returns false when initialization
// fails or no version directories are found; individual add/delete failures
// are logged but do not affect the result.
bool sp_CheckAllRules()
{
    if (!FirewallController::Initialize()) {
        DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("FirewallController::Initialize failed");
        return false;
    }
    // Guarantees Shutdown() on every exit path below (early returns
    // previously leaked the COM session).
    struct ShutdownGuard {
        ~ShutdownGuard() { FirewallController::Shutdown(); }
    } shutdownGuard;

    // Module path: <versionDir>\<version>\bin\<exe>. Strip three components,
    // checking each one exists (previously a null strrchr was dereferenced).
    char szVersionDir[MAX_PATH] = {};
    const DWORD pathLen = GetModuleFileNameA(NULL, szVersionDir, MAX_PATH);
    bool pathOk = pathLen > 0 && pathLen < MAX_PATH;
    for (int level = 0; pathOk && level < 3; ++level) {
        char* sep = std::strrchr(szVersionDir, SPLIT_SLASH);
        if (sep)
            *sep = 0;
        else
            pathOk = false;
    }
    if (!pathOk) {
        DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
            ("Cannot derive version directory from module path: %s", szVersionDir);
        return false;
    }

    std::string versionDir = szVersionDir;
    while (!versionDir.empty() && (versionDir.back() == '\\' || versionDir.back() == '/'))
        versionDir.pop_back();

    // Basic check.
    const std::string firstLevelDir = getFirstLevelDir(versionDir);
    std::vector<std::string> WhiteListDirs;
    listFolders(versionDir, WhiteListDirs);
    if (WhiteListDirs.empty()) {
        DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("WhiteListDirs is empty. versionDir: %s", versionDir.c_str());
        return false;
    }

    // Expected ("guessed") firewall rules for every installed version.
    Multimap firewallRuleAll;
    for (const auto& dir : WhiteListDirs) {
        const std::string header = firstLevelDir + "_" + getLastFolderName(dir);
        AppendVersionRules(firewallRuleAll, header, versionDir + "\\" + dir);
    }

    // Rules currently registered under our prefix (names may repeat).
    Multimap ruleArr;
    const std::string pattern = firstLevelDir + "_*";
    const auto rules = FirewallController::QueryRules(charToWstring(pattern.c_str()));
    for (const auto& ruleInfo : rules)
        ruleArr.emplace(WStringToString(ruleInfo.name), WStringToString(ruleInfo.applicationName));

    for (const auto& it : firewallRuleAll)
        DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("guess firewall rule:%s, %s", it.first.c_str(), it.second.c_str());
    for (const auto& it : ruleArr)
        DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("current firewall rule:%s, %s", it.first.c_str(), it.second.c_str());

    std::vector<std::pair<std::string, std::string>> extra, missing;
    compareMultimaps(firewallRuleAll, ruleArr, extra, missing);

    for (const auto& it : extra) {
        const bool ret = FirewallController::DeleteFirewallRule(charToWstring(it.first.c_str()));
        DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
            ("Delete firewall rule %s. ruleName: %s", ret ? "success" : "failed", it.first.c_str());
    }

    for (const auto& it : missing) {
        const std::wstring ruleName = charToWstring(it.first.c_str());
        const std::wstring appPath = charToWstring(it.second.c_str());
        const bool ret = FirewallController::AddFirewallRule(
            ruleName, appPath, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
        DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
            ("Add firewall rule %s. header: %s, ruleName: %s, path: %s",
             ret ? "success" : "failed", firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
    }

    return true;
}