sp_firewallControl.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
#include "sp_firewallControl.h"

#include <comutil.h>

#include <algorithm>
#include <codecvt>
#include <cstring>
#include <locale>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "path.h"
  8. // 静态成员初始化
  9. INetFwPolicy2* FirewallController::firewallPolicy = nullptr;
  10. bool FirewallController::comInitialized = false;
  11. // 通配符匹配实现(支持*和?)
  12. bool FirewallController::WildcardMatch(const std::wstring& pattern, const std::wstring& text) {
  13. size_t p = 0, s = 0;
  14. while (1) {
  15. if (p == pattern.size() && s == text.size()) return true;
  16. if (p == pattern.size()) return false;
  17. if (s == text.size()) return p + 1 == pattern.size() && pattern[p] == L'*';
  18. if (pattern[p] == text[s] || pattern[p] == L'?') {
  19. p++; s++;
  20. continue;
  21. }
  22. if (pattern[p] == L'*') {
  23. if (p + 1 == pattern.size()) return true;
  24. do {
  25. if (WildcardMatch(pattern.substr(p + 1), text.substr(s)))
  26. return true;
  27. s++;
  28. } while (s != text.size());
  29. return false;
  30. }
  31. return false;
  32. }
  33. }
  34. bool FirewallController::Initialize() {
  35. if (comInitialized) return true;
  36. HRESULT hr = CoInitializeEx(0, COINIT_APARTMENTTHREADED);
  37. if (FAILED(hr)) return false;
  38. hr = CoCreateInstance(__uuidof(NetFwPolicy2), nullptr, CLSCTX_INPROC_SERVER,
  39. __uuidof(INetFwPolicy2), (void**)&firewallPolicy);
  40. comInitialized = SUCCEEDED(hr);
  41. return comInitialized;
  42. }
  43. std::vector<FirewallRuleInfo> FirewallController::QueryRules(const std::wstring& ruleNamePattern) {
  44. std::vector<FirewallRuleInfo> matchedRules;
  45. if (!comInitialized || !firewallPolicy) return matchedRules;
  46. INetFwRules* rules = nullptr;
  47. if (FAILED(firewallPolicy->get_Rules(&rules))) return matchedRules;
  48. IEnumVARIANT* enumerator = nullptr;
  49. if (SUCCEEDED(rules->get__NewEnum((IUnknown**)&enumerator))) {
  50. VARIANT var;
  51. while (enumerator->Next(1, &var, nullptr) == S_OK) {
  52. INetFwRule* rule = nullptr;
  53. if (SUCCEEDED(V_DISPATCH(&var)->QueryInterface(__uuidof(INetFwRule), (void**)&rule))) {
  54. BSTR name;
  55. if (SUCCEEDED(rule->get_Name(&name))) {
  56. if (WildcardMatch(ruleNamePattern, name)) {
  57. FirewallRuleInfo info;
  58. info.name = name;
  59. BSTR desc, app, service;
  60. rule->get_Description(&desc);
  61. rule->get_ApplicationName(&app);
  62. rule->get_ServiceName(&service);
  63. NET_FW_RULE_DIRECTION direction;
  64. rule->get_Direction(&direction);
  65. info.direction = static_cast<long>(direction);
  66. VARIANT_BOOL enabled;
  67. rule->get_Enabled(&enabled);
  68. info.description = desc ? desc : L"";
  69. info.applicationName = app ? app : L"";
  70. info.serviceName = service ? service : L"";
  71. info.enabled = enabled == VARIANT_TRUE;
  72. matchedRules.push_back(info);
  73. SysFreeString(desc);
  74. SysFreeString(app);
  75. SysFreeString(service);
  76. }
  77. SysFreeString(name);
  78. }
  79. rule->Release();
  80. }
  81. VariantClear(&var);
  82. }
  83. enumerator->Release();
  84. }
  85. rules->Release();
  86. return matchedRules;
  87. }
  88. void FirewallController::Shutdown() {
  89. if (firewallPolicy)
  90. {
  91. firewallPolicy->Release();
  92. firewallPolicy = nullptr;
  93. }
  94. if (comInitialized)
  95. {
  96. CoUninitialize();
  97. comInitialized = false;
  98. }
  99. }
  100. bool FirewallController::AddFirewallRule(
  101. const std::wstring& ruleName,
  102. const std::wstring& appPath,
  103. FirewallRuleDirection direction,
  104. FirewallRuleAction action,
  105. const std::wstring& protocol,
  106. const std::wstring& localPorts,
  107. const std::wstring& remoteAddresses,
  108. const std::wstring& description)
  109. {
  110. if (!comInitialized || !firewallPolicy) return false;
  111. INetFwRule* rule = nullptr;
  112. HRESULT hr = CoCreateInstance(
  113. __uuidof(NetFwRule),
  114. nullptr,
  115. CLSCTX_INPROC_SERVER,
  116. __uuidof(INetFwRule),
  117. (void**)&rule
  118. );
  119. if (FAILED(hr)) return false;
  120. // 设置规则属性
  121. rule->put_Name(_bstr_t(ruleName.c_str()));
  122. if (!appPath.empty()) rule->put_ApplicationName(_bstr_t(appPath.c_str()));
  123. if (!localPorts.empty()) rule->put_LocalPorts(_bstr_t(localPorts.c_str()));
  124. LONG protocolValue = 0;
  125. if (protocol == L"TCP") protocolValue = NET_FW_IP_PROTOCOL_TCP;
  126. else if (protocol == L"UDP") protocolValue = NET_FW_IP_PROTOCOL_UDP;
  127. else protocolValue = NET_FW_IP_PROTOCOL_ANY; // 默认值
  128. rule->put_Protocol(protocolValue); // 传入LONG类型值
  129. rule->put_RemoteAddresses(_bstr_t(remoteAddresses.c_str()));
  130. rule->put_Direction(static_cast<NET_FW_RULE_DIRECTION>(direction));
  131. rule->put_Action(static_cast<NET_FW_ACTION>(action));
  132. rule->put_Enabled(VARIANT_TRUE);
  133. if (!description.empty()) rule->put_Description(_bstr_t(description.c_str()));
  134. // 应用到所有配置文件
  135. rule->put_Profiles(NET_FW_PROFILE2_DOMAIN | NET_FW_PROFILE2_PRIVATE | NET_FW_PROFILE2_PUBLIC);
  136. // 添加规则
  137. INetFwRules* rules = nullptr;
  138. hr = firewallPolicy->get_Rules(&rules);
  139. if (SUCCEEDED(hr)) {
  140. hr = rules->Add(rule);
  141. rules->Release();
  142. }
  143. rule->Release();
  144. if (hr == S_OK) {
  145. return true; // 明确表示成功
  146. } else if (hr == E_ACCESSDENIED) {
  147. // 处理权限错误
  148. return false;
  149. }
  150. else {
  151. return false;
  152. }
  153. }
  154. bool FirewallController::DeleteFirewallRule(const std::wstring& ruleName) {
  155. if (!comInitialized || !firewallPolicy) return false;
  156. INetFwRules* rules = nullptr;
  157. HRESULT hr = firewallPolicy->get_Rules(&rules);
  158. if (FAILED(hr)) return false;
  159. hr = rules->Remove(_bstr_t(ruleName.c_str()));
  160. rules->Release();
  161. if (hr == S_OK) {
  162. return true; // 明确表示成功
  163. } else if (hr == E_ACCESSDENIED) {
  164. // 处理权限错误
  165. return false;
  166. }
  167. else{
  168. return false;
  169. }
  170. }
  171. bool FirewallController::CleanupRulesExceptWhitelist(
  172. const std::wstring& ruleNamePattern,
  173. const std::vector<std::wstring>& whitelistPaths) {
  174. if (!comInitialized || !firewallPolicy) return false;
  175. auto matchedRules = QueryRules(ruleNamePattern);
  176. if (matchedRules.empty()) return true;
  177. INetFwRules* rules = nullptr;
  178. HRESULT hr = firewallPolicy->get_Rules(&rules);
  179. if (FAILED(hr)) return false;
  180. bool allSuccess = true;
  181. for (const auto& ruleInfo : matchedRules) {
  182. bool isWhitelisted = false;
  183. for (const auto& whitelistPath : whitelistPaths) {
  184. if (WildcardMatch(whitelistPath, ruleInfo.applicationName)) {
  185. isWhitelisted = true;
  186. break;
  187. }
  188. }
  189. if (!isWhitelisted) {
  190. hr = rules->Remove(_bstr_t(ruleInfo.name.c_str()));
  191. if (FAILED(hr)) allSuccess = false;
  192. }
  193. }
  194. rules->Release();
  195. return allSuccess;
  196. }
  197. std::wstring charToWstring(const char* szIn) {
  198. int length = MultiByteToWideChar(CP_ACP, 0, szIn, -1, NULL, 0);
  199. wchar_t* buf = new wchar_t[length];
  200. MultiByteToWideChar(CP_ACP, 0, szIn, -1, buf, length);
  201. std::wstring result(buf);
  202. delete[] buf;
  203. return result;
  204. }
  205. std::string WStringToString(const std::wstring& wstr, unsigned int codepage = CP_ACP) {
  206. if (wstr.empty()) return "";
  207. int len = WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(), NULL, 0, NULL, NULL);
  208. if (len <= 0) return "";
  209. std::string result(len, '\0');
  210. WideCharToMultiByte(codepage, 0, wstr.c_str(), (int)wstr.size(), &result[0], len, NULL, NULL);
  211. return result;
  212. }
  213. bool sp_AddFirewallRule(const char *ruleName, const char *appPath)
  214. {
  215. int ret = -1;
  216. if (FirewallController::Initialize()) {
  217. std::wstring ruleNameW = charToWstring(ruleName);
  218. std::wstring appPathW = charToWstring(appPath);
  219. ret = FirewallController::AddFirewallRule(ruleNameW, appPathW,
  220. FirewallRuleDirection::Inbound, FirewallRuleAction::Allow) ? 0 : -1;
  221. FirewallController::Shutdown();
  222. }
  223. return ret;
  224. }
  225. std::string getFirstLevelDir(const std::string& path) {
  226. size_t start = path.find_first_of("\\/"); // 查找第一个分隔符
  227. if (start == std::string::npos) return ""; // 无分隔符时返回空
  228. size_t end = path.find_first_of("\\/", start + 1); // 查找第二个分隔符
  229. if (end == std::string::npos) end = path.length();
  230. return path.substr(start + 1, end - start - 1); // 截取Runxxx
  231. }
  232. std::string getLastFolderName(std::string path) {
  233. // 找到最后一个反斜杠的位置
  234. size_t lastSlash = path.find_last_of("\\/");
  235. if (lastSlash != std::string::npos) {
  236. return path.substr(lastSlash + 1);
  237. }
  238. return path; // 无斜杠时返回整个字符串
  239. }
  240. bool sp_AddFirewallRuleByPath(const char *pszPath)
  241. {
  242. bool returnRet = true;
  243. //pszPath input path D:\\Runxxx\\version\\7.1.1.1
  244. std::string inputPath = std::string(pszPath);
  245. while (!inputPath.empty() && (inputPath.back() == '\\' || inputPath.back() == '/')) {
  246. inputPath.pop_back();
  247. }
  248. std::string firstLevelDir = getFirstLevelDir(inputPath);
  249. std::string lastDir = getLastFolderName(inputPath);
  250. std::string header = firstLevelDir + "_" + lastDir;
  251. std::map<std::string, std::string> firewallRule;
  252. firewallRule[header + "_guardian"] = inputPath + "\\bin\\guardian.exe";
  253. firewallRule[header + "_sphost"] = inputPath + "\\bin\\sphost.exe";
  254. firewallRule[header + "_spshell"] = inputPath + "\\bin\\spshell.exe";
  255. firewallRule[header + "_cefclient"] = inputPath + "\\bin\\Chromium\\cefclient.exe";
  256. FirewallController::Initialize();
  257. for(auto &it : firewallRule)
  258. {
  259. std::wstring ruleNameW = charToWstring(it.first.c_str());
  260. std::wstring appPathW = charToWstring(it.second.c_str());
  261. bool ret = FirewallController::AddFirewallRule(ruleNameW, appPathW, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
  262. DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
  263. ("Add firewall rule %s. firstLevelDir: %s, ruleName: %s, path: %s", ret ? "success" : "failed",
  264. firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
  265. returnRet = returnRet && ret;
  266. }
  267. FirewallController::Shutdown();
  268. return returnRet;
  269. }
  270. void listFolders(const std::string& dirPath, std::vector<std::string>& folders) {
  271. folders.clear();
  272. WIN32_FIND_DATA findData;
  273. HANDLE hFind = FindFirstFile((dirPath + "\\*").c_str(), &findData);
  274. if (hFind == INVALID_HANDLE_VALUE) {
  275. return;
  276. }
  277. do {
  278. if (findData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) {
  279. if (strcmp(findData.cFileName, ".") != 0 && strcmp(findData.cFileName, "..") != 0) {
  280. folders.push_back(findData.cFileName); // 存入vector
  281. }
  282. }
  283. } while (FindNextFile(hFind, &findData));
  284. FindClose(hFind);
  285. }
  286. using Multimap = std::multimap<std::string, std::string>;
  287. // 比较两个multimap,返回多出和缺少的键值对
  288. void compareMultimaps(const Multimap& base, const Multimap& target,
  289. std::vector<std::pair<std::string, std::string>>& extra,
  290. std::vector<std::pair<std::string, std::string>>& missing) {
  291. // 检查多出的数据:遍历target,不在base中的键值对
  292. for (const auto& pair : target) {
  293. auto range = base.equal_range(pair.first);
  294. auto it = std::find_if(range.first, range.second, [&pair](const auto& p) {
  295. return p.second == pair.second;
  296. });
  297. if (it == range.second) {
  298. extra.push_back(pair);
  299. }
  300. }
  301. // 检查缺少的数据:遍历base,不在target中的键值对
  302. for (const auto& pair : base) {
  303. auto range = target.equal_range(pair.first);
  304. auto it = std::find_if(range.first, range.second, [&pair](const auto& p) {
  305. return p.second == pair.second;
  306. });
  307. if (it == range.second) {
  308. missing.push_back(pair);
  309. }
  310. }
  311. }
  312. bool sp_CheckAllRules()
  313. {
  314. if (FirewallController::Initialize() == false)
  315. {
  316. DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("FirewallController::Initialize failed");
  317. return false;
  318. }
  319. //get current first level dir and version path
  320. char szVersionDir[MAX_PATH] = {};
  321. GetModuleFileNameA(NULL, szVersionDir, MAX_PATH);
  322. *strrchr(szVersionDir, SPLIT_SLASH) = 0;
  323. *strrchr(szVersionDir, SPLIT_SLASH) = 0;
  324. std::string currentVerDir = szVersionDir;
  325. *strrchr(szVersionDir, SPLIT_SLASH) = 0;
  326. std::string versionDir = szVersionDir;
  327. while (!versionDir.empty() && (versionDir.back() == '\\' || versionDir.back() == '/')) {
  328. versionDir.pop_back();
  329. }
  330. //basic check
  331. std::string firstLevelDir = getFirstLevelDir(versionDir);
  332. std::vector<std::string> WhiteListDirs;
  333. listFolders(versionDir, WhiteListDirs);
  334. if (WhiteListDirs.size() == 0)
  335. {
  336. DbgWithLink(LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)("WhiteListDirs is empty. versionDir: %s", versionDir.c_str());
  337. return false;
  338. }
  339. // firewall ruls I guess
  340. Multimap firewallRuleAll;
  341. for(auto &it : WhiteListDirs)
  342. {
  343. std::string lastDir = getLastFolderName(it);
  344. std::string header = firstLevelDir + "_" + lastDir;
  345. std::string currentVerDir = versionDir + "\\" + it;
  346. firewallRuleAll.insert({header + "_guardian", currentVerDir + "\\bin\\guardian.exe"});
  347. firewallRuleAll.insert({header + "_sphost", currentVerDir + "\\bin\\sphost.exe"});
  348. firewallRuleAll.insert({header + "_spshell", currentVerDir + "\\bin\\spshell.exe"});
  349. firewallRuleAll.insert({header + "_cefclient", currentVerDir + "\\bin\\Chromium\\cefclient.exe"});
  350. }
  351. //query all rules
  352. Multimap ruleArr;
  353. std::string pattern = firstLevelDir + "_*";
  354. std::wstring headerW = charToWstring(pattern.c_str());
  355. auto rules = FirewallController::QueryRules(headerW);
  356. for (const auto& ruleInfo : rules) {
  357. std::string ruleName = WStringToString(ruleInfo.name);
  358. std::string appPath = WStringToString(ruleInfo.applicationName);
  359. ruleArr.insert(std::make_pair(ruleName, appPath));// can repeated
  360. }
  361. for(auto &it : firewallRuleAll)
  362. {
  363. DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("guess firewall rule:%s, %s", it.first.c_str(), it.second.c_str());
  364. }
  365. for(auto &it : ruleArr)
  366. {
  367. DbgWithLink(LOG_LEVEL_DEBUG, LOG_TYPE_SYSTEM)("current firewall rule:%s, %s", it.first.c_str(), it.second.c_str());
  368. }
  369. std::vector<std::pair<std::string, std::string>> extra, missing;
  370. compareMultimaps(firewallRuleAll, ruleArr, extra, missing);
  371. for(auto &it : extra)
  372. {
  373. std::wstring ruleName = charToWstring(it.first.c_str());
  374. bool ret = FirewallController::DeleteFirewallRule(ruleName);
  375. DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
  376. ("Delete firewall rule %s. ruleName: %s", ret ? "success" : "failed", it.first.c_str());
  377. }
  378. for(auto &it : missing)
  379. {
  380. std::wstring ruleName = charToWstring(it.first.c_str());
  381. std::wstring appPath = charToWstring(it.second.c_str());
  382. bool ret = FirewallController::AddFirewallRule(ruleName, appPath, FirewallRuleDirection::Inbound, FirewallRuleAction::Allow);
  383. DbgWithLink(ret ? LOG_LEVEL_DEBUG : LOG_LEVEL_INFO, LOG_TYPE_SYSTEM)
  384. ("Add firewall rule %s. header: %s, ruleName: %s, path: %s", ret ? "success" : "failed",
  385. firstLevelDir.c_str(), it.first.c_str(), it.second.c_str());
  386. }
  387. FirewallController::Shutdown();
  388. return true;
  389. }