sp_rpc.c 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. #include "precompile.h"
  2. #include "sp_def.h"
  3. #include "sp_svc.h"
  4. #include "sp_rpc.h"
  5. #include "sp_dbg_export.h"
  6. #include "list.h"
  7. #include "memutil.h"
  8. #include "spinlock.h"
  9. #include "refcnt.h"
  10. #include "sp_logwithlinkforc.h"
  11. #include <winpr/synch.h>
  #define TAG SPBASE_TAG("sp_rpc")

  /* Number of hash buckets used to find an outstanding client call by rpc_id. */
  #define BUCKET_SIZE 127

  /*
   * sp_rpc_client_t life cycle:
   *
   *            create          async_call        answer          destroy
   *   ------------------> INIT -----------> SENT --------> CALLED ---------> TERM
   *
   * STATE_ERROR is the failure state: it is entered when the call is closed,
   * cancelled, loses its bus connection, or its request cannot be posted.
   * sp_rpc_client_destroy() moves a client from any state to STATE_TERM.
   */
  #define STATE_INIT   0
  #define STATE_SENT   1
  #define STATE_CALLED 2
  #define STATE_TERM   3
  #define STATE_ERROR  4

  /* RPC sub-type, extracted from pkt_type with SP_GET_TYPE(). */
  #define RPC_CMD_INFO 0 /* one-way call, no answer expected */
  #define RPC_CMD_REQ  1 /* request expecting an answer */
  #define RPC_CMD_ANS  2 /* answer to an earlier request */
  27. struct sp_rpc_server_t
  28. {
  29. int stop;
  30. sp_rpc_server_callback cb;
  31. sp_svc_t *svc;
  32. DECLARE_REF_COUNT_MEMBER(ref_cnt);
  33. };
  34. DECLARE_REF_COUNT_STATIC(sp_rpc_server, sp_rpc_server_t)
  35. static void __threadpool_server_on_pkt(threadpool_t *threadpool, void *arg, param_size_t param1, param_size_t param2)
  36. {
  37. sp_rpc_server_t *server = (sp_rpc_server_t *)arg;
  38. iobuffer_t *pkt = (iobuffer_t*)param1;
  39. int epid;
  40. int svc_id;
  41. int pkt_type;
  42. int pkt_id;
  43. int cmd_type;
  44. iobuffer_read(pkt, IOBUF_T_I4, &epid, 0);
  45. iobuffer_read(pkt, IOBUF_T_I4, &svc_id, 0);
  46. iobuffer_read(pkt, IOBUF_T_I4, &pkt_type, 0);
  47. iobuffer_read(pkt, IOBUF_T_I4, &pkt_id, 0);
  48. cmd_type = SP_GET_TYPE(pkt_type);
  49. if (cmd_type == RPC_CMD_INFO) {
  50. server->cb.on_info(server, epid, svc_id, pkt_id, &pkt, server->cb.user_data);
  51. } else if (cmd_type == RPC_CMD_REQ) {
  52. int call_type;
  53. iobuffer_read(pkt, IOBUF_T_I4, &call_type, NULL);
  54. server->cb.on_req(server, epid, svc_id, pkt_id, call_type, &pkt, server->cb.user_data);
  55. } else {
  56. DbgWithLinkForC(LOG_LEVEL_WARN, LOG_TYPE_SYSTEM, "RPC CMD unknown types!");
  57. }
  58. sp_rpc_server_dec_ref(server); // @
  59. if (pkt)
  60. iobuffer_dec_ref(pkt);
  61. }
  62. static int server_on_pkt(sp_svc_t *svc, int epid, int svc_id, int pkt_type, int pkt_id, iobuffer_t **p_pkt, void *user_data)
  63. {
  64. sp_rpc_server_t *server = (sp_rpc_server_t*)user_data;
  65. int rc;
  66. iobuffer_t *pkt;
  67. pkt = *p_pkt;
  68. *p_pkt = NULL;
  69. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_id, 0);
  70. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_type, 0);
  71. iobuffer_write_head(pkt, IOBUF_T_I4, &svc_id, 0);
  72. iobuffer_write_head(pkt, IOBUF_T_I4, &epid, 0);
  73. sp_rpc_server_inc_ref(server); // @
  74. rc = threadpool_queue_workitem2(sp_svc_get_threadpool(svc), NULL, &__threadpool_server_on_pkt, server, (param_size_t)pkt, 0);
  75. if (rc != 0) {
  76. sp_rpc_server_dec_ref(server); // @
  77. iobuffer_dec_ref(pkt);
  78. }
  79. return FALSE;
  80. }
  81. int sp_rpc_server_create(sp_svc_t *svc, sp_rpc_server_callback *cb, sp_rpc_server_t **p_server)
  82. {
  83. sp_rpc_server_t *server = MALLOC_T(sp_rpc_server_t);
  84. server->stop = 0;
  85. memcpy(&server->cb, cb, sizeof(sp_rpc_server_callback));
  86. server->svc = svc;
  87. REF_COUNT_INIT(&server->ref_cnt);
  88. *p_server = server;
  89. return 0;
  90. }
  91. void sp_rpc_server_destroy(sp_rpc_server_t *server)
  92. {
  93. sp_rpc_server_dec_ref(server);
  94. }
  95. int sp_rpc_server_start(sp_rpc_server_t *server)
  96. {
  97. server->stop = 0;
  98. return sp_svc_add_pkt_handler(server->svc, (int)server, SP_PKT_RPC, &server_on_pkt, server);
  99. }
  100. int sp_rpc_server_stop(sp_rpc_server_t *server)
  101. {
  102. // BugFix [4/5/2020 11:55 Gifur]
  103. if (/*!*/server->stop)
  104. return Error_Bug;
  105. server->stop = 1;
  106. return sp_svc_remove_pkt_handler(server->svc, (int)server, SP_PKT_RPC);
  107. }
  108. sp_svc_t *sp_rpc_server_get_svc(sp_rpc_server_t *server)
  109. {
  110. return server->svc;
  111. }
  112. int sp_rpc_server_send_answer(sp_rpc_server_t *server, int epid, int svc_id, int rpc_id, iobuffer_t **ans_pkt)
  113. {
  114. return sp_svc_post(server->svc, epid, svc_id, SP_PKT_RPC | RPC_CMD_ANS, rpc_id, ans_pkt);
  115. }
  116. static void __sp_rpc_destroy(sp_rpc_server_t *server)
  117. {
  118. if (server->cb.on_destroy) {
  119. (*server->cb.on_destroy)(server, server->cb.user_data);
  120. }
  121. free(server);
  122. }
  123. IMPLEMENT_REF_COUNT_MT(sp_rpc_server, sp_rpc_server_t, ref_cnt, __sp_rpc_destroy)
  124. struct sp_rpc_client_t
  125. {
  126. struct hlist_node hentry; // element of sp_rpc_client_mgr_t->rpc_buckets[index]
  127. int state;
  128. int remote_epid;
  129. int remote_svc_id;
  130. unsigned int rpc_id;
  131. int call_type;
  132. spinlock_t lock;
  133. sp_rpc_client_callback cb;
  134. sp_rpc_client_mgr_t *mgr;
  135. DECLARE_REF_COUNT_MEMBER(ref_cnt);
  136. };
  137. DECLARE_REF_COUNT_STATIC(sp_rpc_client, sp_rpc_client_t)
  138. struct sp_rpc_client_mgr_t
  139. {
  140. struct hlist_head rpc_buckets[BUCKET_SIZE]; // list of sp_rpc_client_t
  141. sp_svc_t *svc;
  142. int rpc_cnt;
  143. int stop;
  144. int local_seq;
  145. sp_rpc_client_mgr_callback cb;
  146. CRITICAL_SECTION lock;
  147. DECLARE_REF_COUNT_MEMBER(ref_cnt);
  148. };
  149. DECLARE_REF_COUNT_STATIC(sp_rpc_client_mgr, sp_rpc_client_mgr_t)
  150. static __inline void mgr_lock(sp_rpc_client_mgr_t *mgr)
  151. {
  152. EnterCriticalSection(&mgr->lock);
  153. }
  154. static __inline void mgr_unlock(sp_rpc_client_mgr_t *mgr)
  155. {
  156. LeaveCriticalSection(&mgr->lock);
  157. }
  158. static __inline void client_lock(sp_rpc_client_t *client)
  159. {
  160. spinlock_enter(&client->lock, -1);
  161. }
  162. static __inline void client_unlock(sp_rpc_client_t *client)
  163. {
  164. spinlock_leave(&client->lock);
  165. }
  166. static void client_set_error(sp_rpc_client_t *client, int error);
  167. static void client_process_ans(sp_rpc_client_t *client, iobuffer_t **ans_pkt);
  168. static void __threadpool_mgr_on_req(threadpool_t *threadpool, void *arg, param_size_t param1, param_size_t param2)
  169. {
  170. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t *)arg;
  171. iobuffer_t *pkt = (iobuffer_t*)param1;
  172. int epid;
  173. int svc_id;
  174. int pkt_type;
  175. int pkt_id;
  176. int cmd_type;
  177. iobuffer_read(pkt, IOBUF_T_I4, &epid, 0);
  178. iobuffer_read(pkt, IOBUF_T_I4, &svc_id, 0);
  179. iobuffer_read(pkt, IOBUF_T_I4, &pkt_type, 0);
  180. iobuffer_read(pkt, IOBUF_T_I4, &pkt_id, 0);
  181. cmd_type = SP_GET_TYPE(pkt_type);
  182. if (cmd_type == RPC_CMD_REQ && mgr->cb.on_req)
  183. {
  184. int call_type;
  185. iobuffer_read(pkt, IOBUF_T_I4, &call_type, NULL);
  186. mgr->cb.on_req(mgr, epid, svc_id, pkt_id, call_type, &pkt, mgr->cb.user_data);
  187. }
  188. else
  189. {
  190. DbgWithLinkForC(LOG_LEVEL_WARN, LOG_TYPE_SYSTEM, "RPC CMD unknown types!");
  191. }
  192. sp_rpc_client_mgr_dec_ref(mgr); // @
  193. if (pkt)
  194. iobuffer_dec_ref(pkt);
  195. }
  196. static int mgr_on_pkt(sp_svc_t *svc,int epid, int svc_id, int pkt_type, int pkt_id, iobuffer_t **p_pkt, void *user_data)
  197. {
  198. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t*)user_data;
  199. WLog_DBG(TAG, "sp_rpc::mgr_on_pkt: epid:%d, svc_id: %d, pkt_type:0x%08X, pkt_id: %d, rpc:%d", epid, svc_id, pkt_type, pkt_id, SP_GET_TYPE(pkt_type));
  200. if (SP_GET_TYPE(pkt_type) == RPC_CMD_ANS) {
  201. int rpc_id = pkt_id;
  202. int slot = ((unsigned int)rpc_id) % BUCKET_SIZE;
  203. sp_rpc_client_t *tpos;
  204. struct hlist_node *pos, *n;
  205. mgr_lock(mgr);
  206. hlist_for_each_entry_safe(tpos, pos, n, &mgr->rpc_buckets[slot], sp_rpc_client_t, hentry) {
  207. if (tpos->rpc_id == rpc_id) {
  208. client_process_ans(tpos, p_pkt);
  209. break;
  210. }
  211. }
  212. mgr_unlock(mgr);
  213. return FALSE;
  214. }
  215. else if (SP_GET_TYPE(pkt_type) == RPC_CMD_REQ)
  216. {
  217. int rc;
  218. iobuffer_t *pkt = *p_pkt;
  219. *p_pkt = NULL;
  220. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_id, 0);
  221. iobuffer_write_head(pkt, IOBUF_T_I4, &pkt_type, 0);
  222. iobuffer_write_head(pkt, IOBUF_T_I4, &svc_id, 0);
  223. iobuffer_write_head(pkt, IOBUF_T_I4, &epid, 0);
  224. sp_rpc_client_mgr_inc_ref(mgr);
  225. rc = threadpool_queue_workitem2(sp_svc_get_threadpool(svc), NULL, &__threadpool_mgr_on_req, mgr, (param_size_t)pkt, 0);
  226. if (rc != 0) {
  227. sp_rpc_client_mgr_dec_ref(mgr); // @
  228. iobuffer_dec_ref(pkt);
  229. }
  230. }
  231. return TRUE;
  232. }
  233. static void mgr_on_sys(sp_svc_t *svc,int epid, int state, void *user_data)
  234. {
  235. sp_rpc_client_mgr_t *mgr = (sp_rpc_client_mgr_t*)user_data;
  236. if (state == BUS_STATE_OFF) {
  237. int i;
  238. sp_rpc_client_t *tpos;
  239. struct hlist_node *pos, *n;
  240. mgr_lock(mgr);
  241. for (i = 0; i < BUCKET_SIZE; ++i) {
  242. hlist_for_each_entry_safe(tpos, pos, n, &mgr->rpc_buckets[i], sp_rpc_client_t, hentry) {
  243. if (tpos->remote_epid == epid) {
  244. client_set_error(tpos, Error_NetBroken);
  245. }
  246. }
  247. }
  248. mgr_unlock(mgr);
  249. }
  250. }
  251. int sp_rpc_client_mgr_create(sp_svc_t *svc, sp_rpc_client_mgr_callback *cb, sp_rpc_client_mgr_t **p_mgr)
  252. {
  253. int i;
  254. sp_rpc_client_mgr_t *mgr = MALLOC_T(sp_rpc_client_mgr_t);
  255. mgr->local_seq = 0;
  256. mgr->rpc_cnt = 0;
  257. mgr->stop = 0;
  258. mgr->svc = svc;
  259. memcpy(&mgr->cb, cb, sizeof(sp_rpc_client_mgr_callback));
  260. for (i = 0;i < BUCKET_SIZE; ++i) {
  261. INIT_HLIST_HEAD(&mgr->rpc_buckets[i]);
  262. }
  263. InitializeCriticalSection(&mgr->lock);
  264. REF_COUNT_INIT(&mgr->ref_cnt);
  265. *p_mgr = mgr;
  266. return 0;
  267. }
  268. // {bug} not delete rpc_buckets arrary
  269. void sp_rpc_client_mgr_destroy(sp_rpc_client_mgr_t *mgr)
  270. {
  271. sp_rpc_client_mgr_dec_ref(mgr);
  272. }
  273. int sp_rpc_client_mgr_start(sp_rpc_client_mgr_t *mgr)
  274. {
  275. mgr->stop = 0;
  276. sp_svc_add_pkt_handler(mgr->svc, (int)mgr, SP_PKT_RPC, &mgr_on_pkt, mgr);
  277. sp_svc_add_sys_handler(mgr->svc, (int)mgr, &mgr_on_sys, mgr);
  278. return 0;
  279. }
  280. int sp_rpc_client_mgr_stop(sp_rpc_client_mgr_t *mgr)
  281. {
  282. sp_svc_remove_pkt_handler(mgr->svc, (int)mgr, SP_PKT_RPC);
  283. sp_svc_remove_sys_handler(mgr->svc, (int)mgr);
  284. return 0;
  285. }
  286. sp_svc_t *sp_rpc_client_mgr_get_svc(sp_rpc_client_mgr_t *mgr)
  287. {
  288. return mgr->svc;
  289. }
  290. int sp_rpc_client_mgr_cancel_all(sp_rpc_client_mgr_t *mgr)
  291. {
  292. int i;
  293. mgr_lock(mgr);
  294. for (i = 0; i < BUCKET_SIZE; ++i) {
  295. sp_rpc_client_t *tpos;
  296. struct hlist_node *pos;
  297. hlist_for_each_entry(tpos, pos, &mgr->rpc_buckets[i], sp_rpc_client_t, hentry) {
  298. client_set_error(tpos, Error_Cancel);
  299. }
  300. }
  301. mgr_unlock(mgr);
  302. return 0;
  303. }
  304. int sp_rpc_client_mgr_get_client_cnt(sp_rpc_client_mgr_t *mgr)
  305. {
  306. return mgr->rpc_cnt;
  307. }
  308. int sp_rpc_client_mgr_one_way_call(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int call_type, iobuffer_t **info_pkt)
  309. {
  310. return sp_svc_post(mgr->svc, epid, svc_id, SP_PKT_RPC| RPC_CMD_INFO, call_type, info_pkt);
  311. }
  312. int sp_rpc_client_mgr_send_answer(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int rpc_id, iobuffer_t **ans_pkt)
  313. {
  314. return sp_svc_post(mgr->svc, epid, svc_id, SP_PKT_RPC | RPC_CMD_ANS, rpc_id, ans_pkt);
  315. }
  316. static void __sp_rpc_client_mgr_destroy(sp_rpc_client_mgr_t *mgr)
  317. {
  318. if (mgr->cb.on_destroy)
  319. mgr->cb.on_destroy(mgr, mgr->cb.user_data);
  320. DeleteCriticalSection(&mgr->lock);
  321. free(mgr);
  322. }
  323. IMPLEMENT_REF_COUNT_MT_STATIC(sp_rpc_client_mgr, sp_rpc_client_mgr_t, ref_cnt, __sp_rpc_client_mgr_destroy)
  324. int sp_rpc_client_create(sp_rpc_client_mgr_t *mgr, int epid, int svc_id, int call_type, sp_rpc_client_callback *cb, sp_rpc_client_t **p_client)
  325. {
  326. sp_rpc_client_t *client = MALLOC_T(sp_rpc_client_t);
  327. client->mgr = mgr;
  328. client->remote_epid = epid;
  329. client->remote_svc_id = svc_id;
  330. client->call_type = call_type;
  331. memcpy(&client->cb, cb, sizeof(sp_rpc_client_callback));
  332. client->rpc_id = (int)InterlockedIncrement((LONG*)&mgr->local_seq);
  333. spinlock_init(&client->lock);
  334. client->state = STATE_INIT;
  335. REF_COUNT_INIT(&client->ref_cnt);
  336. sp_rpc_client_mgr_inc_ref(mgr);
  337. sp_rpc_client_inc_ref(client);
  338. mgr_lock(mgr);
  339. hlist_add_head(&client->hentry, &mgr->rpc_buckets[client->rpc_id % BUCKET_SIZE]);
  340. client->mgr->rpc_cnt++;
  341. mgr_unlock(mgr);
  342. *p_client = client;
  343. return 0;
  344. }
  345. int sp_rpc_client_close(sp_rpc_client_t *client)
  346. {
  347. int rc;
  348. client_lock(client);
  349. if (client->state != STATE_TERM && client->state != STATE_ERROR) {
  350. client->state = STATE_ERROR;
  351. rc = 0;
  352. } else {
  353. rc = Error_Duplication;
  354. }
  355. client_unlock(client);
  356. return rc;
  357. }
  358. void sp_rpc_client_destroy(sp_rpc_client_t *client)
  359. {
  360. mgr_lock(client->mgr);
  361. client->mgr->rpc_cnt --;
  362. hlist_del(&client->hentry);
  363. mgr_unlock(client->mgr);
  364. sp_rpc_client_dec_ref(client);
  365. client_lock(client);
  366. client->state = STATE_TERM;
  367. client_unlock(client);
  368. sp_rpc_client_dec_ref(client);
  369. }
  370. int sp_rpc_client_async_call(sp_rpc_client_t *client, iobuffer_t **req_pkt)
  371. {
  372. sp_rpc_client_mgr_t *mgr = client->mgr;
  373. int rc = 0;
  374. if (client->state != STATE_INIT)
  375. return Error_Bug;
  376. client_lock(client);
  377. if (client->state == STATE_INIT) {
  378. client->state = STATE_SENT;
  379. sp_rpc_client_inc_ref(client); // @
  380. iobuffer_write_head(*req_pkt, IOBUF_T_I4, &client->call_type, 0);
  381. rc = sp_svc_post(mgr->svc, client->remote_epid, client->remote_svc_id, SP_PKT_RPC|RPC_CMD_REQ, client->rpc_id, req_pkt);
  382. if (rc != 0) {
  383. sp_rpc_client_dec_ref(client); // @
  384. client->state = STATE_ERROR;
  385. }
  386. } else {
  387. rc = Error_NetBroken;
  388. }
  389. client_unlock(client);
  390. return rc;
  391. }
  392. int sp_rpc_client_get_rpc_id(sp_rpc_client_t *client)
  393. {
  394. return client->rpc_id;
  395. }
  396. int sp_rpc_client_get_remote_epid(sp_rpc_client_t *client)
  397. {
  398. return client->remote_epid;
  399. }
  400. int sp_rpc_client_get_remote_svc_id(sp_rpc_client_t *client)
  401. {
  402. return client->remote_svc_id;
  403. }
  404. static void client_set_error(sp_rpc_client_t *client, int error)
  405. {
  406. if (client->state != STATE_ERROR && client->state != STATE_TERM) {
  407. client_lock(client);
  408. if (client->state != STATE_ERROR && client->state != STATE_TERM) {
  409. if (client->state == STATE_SENT) {
  410. if (client->cb.on_ans) {
  411. client->cb.on_ans(client, error, NULL, client->cb.user_data);
  412. }
  413. } else {
  414. client->state = STATE_ERROR;
  415. }
  416. }
  417. client_unlock(client);
  418. }
  419. sp_rpc_client_dec_ref(client); // @
  420. }
  421. static void client_process_ans(sp_rpc_client_t *client, iobuffer_t **ans_pkt)
  422. {
  423. if (client->state == STATE_SENT) {
  424. client_lock(client);
  425. if (client->state == STATE_SENT) {
  426. client->state = STATE_CALLED;
  427. if (client->cb.on_ans) {
  428. client->cb.on_ans(client, 0, ans_pkt, client->cb.user_data);
  429. }
  430. }
  431. client_unlock(client);
  432. }
  433. sp_rpc_client_dec_ref(client); // @
  434. }
  435. static void __client_destroy(sp_rpc_client_t *client)
  436. {
  437. if (client->cb.on_destroy)
  438. client->cb.on_destroy(client, client->cb.user_data);
  439. sp_rpc_client_mgr_dec_ref(client->mgr);
  440. free(client);
  441. }
  442. IMPLEMENT_REF_COUNT_MT_STATIC(sp_rpc_client, sp_rpc_client_t, ref_cnt, __client_destroy)