8 #ifndef FAIR_MQ_SHMEM_SOCKET_H_
9 #define FAIR_MQ_SHMEM_SOCKET_H_
15 #include <FairMQSocket.h>
16 #include <FairMQMessage.h>
17 #include <FairMQLogger.h>
18 #include <fairmq/tools/Strings.h>
32 ZMsg() {
int rc __attribute__((unused)) = zmq_msg_init(&fMsg); assert(rc == 0); }
33 explicit ZMsg(
size_t size) {
int rc __attribute__((unused)) = zmq_msg_init_size(&fMsg, size); assert(rc == 0); }
34 ~
ZMsg() {
int rc __attribute__((unused)) = zmq_msg_close(&fMsg); assert(rc == 0); }
36 void* Data() {
return zmq_msg_data(&fMsg); }
37 size_t Size() {
return zmq_msg_size(&fMsg); }
38 zmq_msg_t* Msg() {
return &fMsg; }
50 , fId(
id +
"." + name +
"." + type)
59 if (type ==
"sub" || type ==
"pub") {
60 LOG(error) <<
"PUB/SUB socket type is not supported for shared memory transport";
61 throw SocketError(
"PUB/SUB socket type is not supported for shared memory transport");
64 fSocket = zmq_socket(context, GetConstant(type));
66 if (fSocket ==
nullptr) {
67 LOG(error) <<
"Failed creating socket " << fId <<
", reason: " << zmq_strerror(errno);
68 throw SocketError(tools::ToString(
"Failed creating socket ", fId,
", reason: ", zmq_strerror(errno)));
71 if (zmq_setsockopt(fSocket, ZMQ_IDENTITY, fId.c_str(), fId.length()) != 0) {
72 LOG(error) <<
"Failed setting ZMQ_IDENTITY socket option, reason: " << zmq_strerror(errno);
78 if (zmq_setsockopt(fSocket, ZMQ_LINGER, &linger,
sizeof(linger)) != 0) {
79 LOG(error) <<
"Failed setting ZMQ_LINGER socket option, reason: " << zmq_strerror(errno);
82 if (zmq_setsockopt(fSocket, ZMQ_SNDTIMEO, &fTimeout,
sizeof(fTimeout)) != 0) {
83 LOG(error) <<
"Failed setting ZMQ_SNDTIMEO socket option, reason: " << zmq_strerror(errno);
86 if (zmq_setsockopt(fSocket, ZMQ_RCVTIMEO, &fTimeout,
sizeof(fTimeout)) != 0) {
87 LOG(error) <<
"Failed setting ZMQ_RCVTIMEO socket option, reason: " << zmq_strerror(errno);
97 LOG(debug) <<
"Created socket " << GetId();
103 std::string GetId()
const override {
return fId; }
105 bool Bind(
const std::string& address)
override
108 if (zmq_bind(fSocket, address.c_str()) != 0) {
109 if (errno == EADDRINUSE) {
113 LOG(error) <<
"Failed binding socket " << fId <<
", reason: " << zmq_strerror(errno);
119 bool Connect(
const std::string& address)
override
122 if (zmq_connect(fSocket, address.c_str()) != 0) {
123 LOG(error) <<
"Failed connecting socket " << fId <<
", reason: " << zmq_strerror(errno);
129 bool ShouldRetry(
int flags,
int timeout,
int& elapsed)
const
131 if ((flags & ZMQ_DONTWAIT) == 0) {
134 if (elapsed >= timeout) {
144 int HandleErrors()
const
146 if (zmq_errno() == ETERM) {
147 LOG(debug) <<
"Terminating socket " << fId;
148 return static_cast<int>(TransferCode::error);
150 LOG(error) <<
"Failed transfer on socket " << fId <<
", reason: " << zmq_strerror(errno);
151 return static_cast<int>(TransferCode::error);
155 int64_t Send(MessagePtr& msg,
const int timeout = -1)
override
159 flags = ZMQ_DONTWAIT;
165 std::memcpy(zmqMsg.Data(), &(shmMsg->fMeta),
sizeof(
MetaHeader));
168 int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags);
170 shmMsg->fQueued =
true;
172 size_t size = msg->GetSize();
175 }
else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
176 if (fManager.Interrupted()) {
177 return static_cast<int>(TransferCode::interrupted);
178 }
else if (ShouldRetry(flags, timeout, elapsed)) {
181 return static_cast<int>(TransferCode::timeout);
184 return HandleErrors();
188 return static_cast<int>(TransferCode::error);
191 int64_t Receive(MessagePtr& msg,
const int timeout = -1)
override
195 flags = ZMQ_DONTWAIT;
203 int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags);
208 tools::ToString(
"Received message is not a valid FairMQ shared memory message. ",
209 "Possibly due to a misconfigured transport on the sender side. ",
210 "Expected size of ",
sizeof(
MetaHeader),
" bytes, received ", nbytes));
214 size_t size = hdr->fSize;
215 shmMsg->fMeta = *hdr;
220 }
else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
221 if (fManager.Interrupted()) {
222 return static_cast<int>(TransferCode::interrupted);
223 }
else if (ShouldRetry(flags, timeout, elapsed)) {
226 return static_cast<int>(TransferCode::timeout);
229 return HandleErrors();
234 int64_t Send(std::vector<MessagePtr>& msgVec,
const int timeout = -1)
override
238 flags = ZMQ_DONTWAIT;
243 const unsigned int vecSize = msgVec.size();
249 for (
auto& msg : msgVec) {
251 std::memcpy(metas++, &(shmMsg->fMeta),
sizeof(
MetaHeader));
255 int64_t totalSize = 0;
256 int nbytes = zmq_msg_send(zmqMsg.Msg(), fSocket, flags);
258 assert(
static_cast<unsigned int>(nbytes) == (vecSize *
sizeof(
MetaHeader)));
260 for (
auto& msg : msgVec) {
262 shmMsg->fQueued =
true;
263 totalSize += shmMsg->fMeta.fSize;
268 fBytesTx += totalSize;
271 }
else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
272 if (fManager.Interrupted()) {
273 return static_cast<int>(TransferCode::interrupted);
274 }
else if (ShouldRetry(flags, timeout, elapsed)) {
277 return static_cast<int>(TransferCode::timeout);
280 return HandleErrors();
284 return static_cast<int>(TransferCode::error);
287 int64_t Receive(std::vector<MessagePtr>& msgVec,
const int timeout = -1)
override
291 flags = ZMQ_DONTWAIT;
298 int64_t totalSize = 0;
299 int nbytes = zmq_msg_recv(zmqMsg.Msg(), fSocket, flags);
302 const auto hdrVecSize = zmqMsg.Size();
304 assert(hdrVecSize > 0);
307 tools::ToString(
"Received message is not a valid FairMQ shared memory message. ",
308 "Possibly due to a misconfigured transport on the sender side. ",
309 "Expected size of ",
sizeof(
MetaHeader),
" bytes, received ", nbytes));
312 const auto numMessages = hdrVecSize /
sizeof(
MetaHeader);
313 msgVec.reserve(numMessages);
315 for (
size_t m = 0; m < numMessages; m++) {
317 msgVec.emplace_back(std::make_unique<Message>(fManager, hdrVec[m], GetTransport()));
319 totalSize += shmMsg->GetSize();
324 fBytesRx += totalSize;
327 }
else if (zmq_errno() == EAGAIN || zmq_errno() == EINTR) {
328 if (fManager.Interrupted()) {
329 return static_cast<int>(TransferCode::interrupted);
330 }
else if (ShouldRetry(flags, timeout, elapsed)) {
333 return static_cast<int>(TransferCode::timeout);
336 return HandleErrors();
340 return static_cast<int>(TransferCode::error);
343 void* GetSocket()
const {
return fSocket; }
345 void Close()
override
349 if (fSocket ==
nullptr) {
353 if (zmq_close(fSocket) != 0) {
354 LOG(error) <<
"Failed closing socket " << fId <<
", reason: " << zmq_strerror(errno);
360 void SetOption(
const std::string& option,
const void* value,
size_t valueSize)
override
362 if (zmq_setsockopt(fSocket, GetConstant(option), value, valueSize) < 0) {
363 LOG(error) <<
"Failed setting socket option, reason: " << zmq_strerror(errno);
367 void GetOption(
const std::string& option,
void* value,
size_t* valueSize)
override
369 if (zmq_getsockopt(fSocket, GetConstant(option), value, valueSize) < 0) {
370 LOG(error) <<
"Failed getting socket option, reason: " << zmq_strerror(errno);
374 void SetLinger(
const int value)
override
376 if (zmq_setsockopt(fSocket, ZMQ_LINGER, &value,
sizeof(value)) < 0) {
377 throw SocketError(tools::ToString(
"failed setting ZMQ_LINGER, reason: ", zmq_strerror(errno)));
383 size_t eventsSize =
sizeof(uint32_t);
384 if (zmq_getsockopt(fSocket, ZMQ_EVENTS, events, &eventsSize) < 0) {
385 throw SocketError(tools::ToString(
"failed setting ZMQ_EVENTS, reason: ", zmq_strerror(errno)));
389 int GetLinger()
const override
392 size_t valueSize =
sizeof(value);
393 if (zmq_getsockopt(fSocket, ZMQ_LINGER, &value, &valueSize) < 0) {
394 throw SocketError(tools::ToString(
"failed getting ZMQ_LINGER, reason: ", zmq_strerror(errno)));
399 void SetSndBufSize(
const int value)
override
401 if (zmq_setsockopt(fSocket, ZMQ_SNDHWM, &value,
sizeof(value)) < 0) {
402 throw SocketError(tools::ToString(
"failed setting ZMQ_SNDHWM, reason: ", zmq_strerror(errno)));
406 int GetSndBufSize()
const override
409 size_t valueSize =
sizeof(value);
410 if (zmq_getsockopt(fSocket, ZMQ_SNDHWM, &value, &valueSize) < 0) {
411 throw SocketError(tools::ToString(
"failed getting ZMQ_SNDHWM, reason: ", zmq_strerror(errno)));
416 void SetRcvBufSize(
const int value)
override
418 if (zmq_setsockopt(fSocket, ZMQ_RCVHWM, &value,
sizeof(value)) < 0) {
419 throw SocketError(tools::ToString(
"failed setting ZMQ_RCVHWM, reason: ", zmq_strerror(errno)));
423 int GetRcvBufSize()
const override
426 size_t valueSize =
sizeof(value);
427 if (zmq_getsockopt(fSocket, ZMQ_RCVHWM, &value, &valueSize) < 0) {
428 throw SocketError(tools::ToString(
"failed getting ZMQ_RCVHWM, reason: ", zmq_strerror(errno)));
433 void SetSndKernelSize(
const int value)
override
435 if (zmq_setsockopt(fSocket, ZMQ_SNDBUF, &value,
sizeof(value)) < 0) {
436 throw SocketError(tools::ToString(
"failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno)));
440 int GetSndKernelSize()
const override
443 size_t valueSize =
sizeof(value);
444 if (zmq_getsockopt(fSocket, ZMQ_SNDBUF, &value, &valueSize) < 0) {
445 throw SocketError(tools::ToString(
"failed getting ZMQ_SNDBUF, reason: ", zmq_strerror(errno)));
450 void SetRcvKernelSize(
const int value)
override
452 if (zmq_setsockopt(fSocket, ZMQ_RCVBUF, &value,
sizeof(value)) < 0) {
453 throw SocketError(tools::ToString(
"failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno)));
457 int GetRcvKernelSize()
const override
460 size_t valueSize =
sizeof(value);
461 if (zmq_getsockopt(fSocket, ZMQ_RCVBUF, &value, &valueSize) < 0) {
462 throw SocketError(tools::ToString(
"failed getting ZMQ_RCVBUF, reason: ", zmq_strerror(errno)));
467 unsigned long GetBytesTx()
const override {
return fBytesTx; }
468 unsigned long GetBytesRx()
const override {
return fBytesRx; }
469 unsigned long GetMessagesTx()
const override {
return fMessagesTx; }
470 unsigned long GetMessagesRx()
const override {
return fMessagesRx; }
472 static int GetConstant(
const std::string& constant)
474 if (constant ==
"")
return 0;
475 if (constant ==
"sub")
return ZMQ_SUB;
476 if (constant ==
"pub")
return ZMQ_PUB;
477 if (constant ==
"xsub")
return ZMQ_XSUB;
478 if (constant ==
"xpub")
return ZMQ_XPUB;
479 if (constant ==
"push")
return ZMQ_PUSH;
480 if (constant ==
"pull")
return ZMQ_PULL;
481 if (constant ==
"req")
return ZMQ_REQ;
482 if (constant ==
"rep")
return ZMQ_REP;
483 if (constant ==
"dealer")
return ZMQ_DEALER;
484 if (constant ==
"router")
return ZMQ_ROUTER;
485 if (constant ==
"pair")
return ZMQ_PAIR;
487 if (constant ==
"snd-hwm")
return ZMQ_SNDHWM;
488 if (constant ==
"rcv-hwm")
return ZMQ_RCVHWM;
489 if (constant ==
"snd-size")
return ZMQ_SNDBUF;
490 if (constant ==
"rcv-size")
return ZMQ_RCVBUF;
491 if (constant ==
"snd-more")
return ZMQ_SNDMORE;
492 if (constant ==
"rcv-more")
return ZMQ_RCVMORE;
494 if (constant ==
"linger")
return ZMQ_LINGER;
495 if (constant ==
"no-block")
return ZMQ_DONTWAIT;
496 if (constant ==
"snd-more no-block")
return ZMQ_DONTWAIT|ZMQ_SNDMORE;
498 if (constant ==
"fd")
return ZMQ_FD;
499 if (constant ==
"events")
501 if (constant ==
"pollin")
503 if (constant ==
"pollout")
506 throw SocketError(tools::ToString(
"GetConstant called with an invalid argument: ", constant));
509 ~Socket()
override { Close(); }
// Traffic counters exposed through GetBytesTx/GetBytesRx/GetMessagesTx/GetMessagesRx.
// std::atomic_ulong is the standard alias for std::atomic<unsigned long>.
std::atomic_ulong fBytesTx;    // payload bytes sent (added to in Send())
std::atomic_ulong fBytesRx;    // payload bytes received (added to in Receive())
std::atomic_ulong fMessagesTx; // messages sent
std::atomic_ulong fMessagesRx; // messages received