Commit 3add02c6 authored by Eric Cano's avatar Eric Cano
Browse files

Added a SocketPair class for IPC communication between parent and child

processes. The socket transmits and queues variable size messages.
parent 1ff9b9ab
......@@ -92,6 +92,7 @@ set (COMMON_LIB_SRC_FILES
threading/ChildProcess.cpp
threading/Daemon.cpp
threading/Mutex.cpp
threading/SocketPair.cpp
threading/System.cpp
threading/Threading.cpp
utils/utils.cpp
......@@ -124,6 +125,7 @@ set (COMMON_UNIT_TESTS_LIB_SRC_FILES
log/StringLoggerTest.cpp
remoteFS/RemotePathTest.cpp
threading/DaemonTest.cpp
threading/SocketPairTest.cpp
utils/UtilsTest.cpp
UserIdentityTest.cpp)
......
/*
* The CERN Tape Archive (CTA) project
* Copyright (C) 2015 CERN
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#include "common/threading/SocketPair.hpp"
#include "common/exception/Errnum.hpp"
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#include <poll.h>
#include <memory>
#include <list>
namespace cta { namespace server {
//------------------------------------------------------------------------------
// Constructor
//------------------------------------------------------------------------------
SocketPair::SocketPair() {
int fd[2];
cta::exception::Errnum::throwOnMinusOne(
::socketpair(AF_LOCAL, SOCK_SEQPACKET, 0, fd),
"In SocketPair::SocketPair(): failed to socketpair(): ");
m_parentFd = fd[0];
m_childFd = fd[1];
if (m_parentFd < 0 || m_childFd < 0) {
std::stringstream err;
err << "In SocketPair::SocketPair(): unexpected file descriptor: "
<< "fd[0]=" << fd[0] << " fd[1]=" << fd[1];
throw cta::exception::Exception(err.str());
}
}
//------------------------------------------------------------------------------
// Destructor
//------------------------------------------------------------------------------
SocketPair::~SocketPair() {
if (m_parentFd != -1)
::close(m_parentFd);
if (m_childFd != -1)
::close(m_childFd);
}
//------------------------------------------------------------------------------
// SocketPair::close
//------------------------------------------------------------------------------
void SocketPair::close(Side sideToClose) {
if (m_currentSide != Side::both)
throw CloseAlreadyCalled("In SocketPair::close(): one side was already closed");
switch(sideToClose) {
case Side::child:
::close(m_childFd);
m_childFd = -1;
m_currentSide = Side::parent;
break;
case Side::parent:
::close(m_parentFd);
m_parentFd = -1;
m_currentSide = Side::child;
break;
default:
throw cta::exception::Exception("In SocketPair::close(): invalid side");
}
}
//------------------------------------------------------------------------------
// SocketPair::close
//------------------------------------------------------------------------------
bool SocketPair::pollFlag() {
return m_pollFlag;
}
//------------------------------------------------------------------------------
// SocketPair::poll
//------------------------------------------------------------------------------
void SocketPair::poll(pollMap& socketPairs, time_t timeout, Side sourceToPoll) {
std::unique_ptr<struct ::pollfd[]> fds(new ::pollfd[socketPairs.size()]);
struct ::pollfd *fdsp=fds.get();
std::list<std::string> keys;
for (const auto & sp: socketPairs) {
keys.push_back(sp.first);
fdsp->fd = sp.second->getFdForAccess(sourceToPoll);
fdsp->revents = 0;
fdsp->events = POLLIN;
fdsp++;
}
int rc=::poll(fds.get(), socketPairs.size(), timeout * 1000);
if (rc > 0) {
// We have readable fds, copy the results in the provided map
fdsp=fds.get();
for (const auto & key: keys) {
socketPairs.at(key)->m_pollFlag = (fdsp++)->revents & POLLIN;
}
} else if (!rc) {
throw Timeout("In SocketPair::poll(): timeout");
} else {
throw cta::exception::Errnum("In SocketPair::poll(): failed to poll(): ");
}
}
//------------------------------------------------------------------------------
// SocketPair::getFdForAccess
//------------------------------------------------------------------------------
int SocketPair::getFdForAccess(Side sourceOrDestination) {
// First, make sure the source to poll makes sense.
// There is an inversion here. If our current side is parent, we should
// read from the child and vice versa.
Side sideForThisPair = sourceOrDestination;
switch (sideForThisPair) {
case Side::current:
switch(m_currentSide) {
case Side::child:
sideForThisPair = Side::parent;
goto done;
case Side::parent:
sideForThisPair = Side::child;
goto done;
default:
throw cta::exception::Exception("In SocketPair::getFdForPoll(): invalid side (current)");
}
case Side::child:
sideForThisPair = Side::parent;
break;
case Side::parent:
sideForThisPair = Side::child;
break;
default:
throw cta::exception::Exception("In SocketPair::poll(): invalid side (both)");
}
done:
// Now make sure the file descriptor is valid.
int fd;
switch (sideForThisPair) {
case Side::child:
fd = m_childFd;
break;
case Side::parent:
fd = m_parentFd;
break;
default:
throw cta::exception::Exception("In SocketPair::poll(): invalid sideForThisPair (internal error)");
}
if (-1 == fd)
throw cta::exception::Exception("In SocketPair::poll(): file descriptor is closed");
return fd;
}
//------------------------------------------------------------------------------
// SocketPair::receive
//------------------------------------------------------------------------------
std::string SocketPair::receive(Side source) {
int fd=getFdForAccess(source);
char buff[2048];
struct ::msghdr hdr;
struct ::iovec iov;
hdr.msg_name = nullptr;
hdr.msg_namelen = 0;
hdr.msg_iov = &iov;
hdr.msg_iovlen = 1;
hdr.msg_iov->iov_base = (void*)buff;
hdr.msg_iov->iov_len = sizeof(buff);
hdr.msg_control = nullptr;
hdr.msg_controllen = 0;
hdr.msg_flags = 0;
ssize_t size=recvmsg(fd, &hdr, MSG_DONTWAIT);
if (size > 0) {
if (hdr.msg_flags & MSG_TRUNC) {
throw Overflow("In SocketPair::receive(): message was truncated.");
}
std::string ret;
ret.append(buff, size);
return ret;
} else if (!size) {
throw PeerDisconnected("In SocketPair::receive(): connection reset by peer.");
} else {
if (errno == EAGAIN) {
throw NothingToReceive("In SocketPair::receive(): nothing to receive.");
} else {
throw cta::exception::Errnum("In SocketPair::receive(): failed to recv(): ");
}
}
}
//------------------------------------------------------------------------------
// SocketPair::send
//------------------------------------------------------------------------------
void SocketPair::send(const std::string& msg, Side destination) {
int fd=getFdForAccess(destination);
cta::exception::Errnum::throwOnMinusOne(::send(fd, msg.data(), msg.size(), 0),
"In SocketPair::send(): failed to send(): ");
}
}} // namespace cta::server
\ No newline at end of file
/*
* The CERN Tape Archive (CTA) project
* Copyright (C) 2015 CERN
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#pragma once
#include <string>
#include <stdint.h>
#include <map>
#include "common/exception/Exception.hpp"
namespace cta { namespace server {
/**
* A class implementing a datagram communication between a parent process and
* its child. The communication is achieved of a unix domain socket of type
* datagram. It will hence allow transmission of bound binary packets.
* Higher level class will have the duty to implement semantics on top of that.
*/
class SocketPair {
public:
/// Constructor: opens the socket pair.
SocketPair();
/// Destructor: closes the remaining socketpairs
~SocketPair();
/// Enum allowing description of sides (parent, child)
enum class Side: uint8_t {
parent,
child,
current,
both
};
CTA_GENERATE_EXCEPTION_CLASS(CloseAlreadyCalled);
/// Close one side (after forking)
void close(Side sideToClose);
/// Send a buffer (optional side parameter allows use without closing,
/// useful for testing).
void send(const std::string& msg, Side destination = Side::current);
CTA_GENERATE_EXCEPTION_CLASS(NothingToReceive);
CTA_GENERATE_EXCEPTION_CLASS(PeerDisconnected);
/// Receive a buffer immediately (optional side parameter allows use without
/// closing, useful for testing).
std::string receive(Side source = Side::current);
/// A typedef used to store socketpairs to be passed to ppoll.
typedef std::map<std::string, SocketPair *> pollMap;
CTA_GENERATE_EXCEPTION_CLASS(Timeout);
CTA_GENERATE_EXCEPTION_CLASS(Overflow);
/// Poll the socketpairs listed in the map for reading (optional side
/// parameter allows use without closing, useful for testing).
static void poll(pollMap & socketPairs, time_t timeout,
Side sourceToPoll = Side::current);
/// Flag holding the result of a poll for a given socketpair.
bool pollFlag();
private:
int m_parentFd = -1; ///< The file descriptor for the
int m_childFd = -1;
Side m_currentSide = Side::both;
bool m_pollFlag = false;
/// An internal helper function getting the right file descriptor for
/// a given source or destination. With checks.
int getFdForAccess(Side sourceOrDestination);
};
}} // namespace cta::server
/******************************************************************************
*
* This file is part of the Castor project.
* See http://castor.web.cern.ch/castor
*
* Copyright (C) 2003 CERN
* This program is free software; you can redistribute it and/or
* modify it under the terms of the GNU General Public License
* as published by the Free Software Foundation; either version 2
* of the License, or (at your option) any later version.
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*
*
*
* @author Castor Dev team, castor-dev@cern.ch
*****************************************************************************/
#include <gtest/gtest.h>
#include <algorithm>
#include "common/threading/SocketPair.hpp"
#include "common/exception/Errnum.hpp"
namespace unitTests {
TEST(cta_threading_SocketPair, BasicTest) {
using cta::server::SocketPair;
cta::server::SocketPair sp0, sp1;
SocketPair::pollMap pollList;
pollList["0"] = &sp0;
pollList["1"] = &sp1;
sp0.send("C2P0", SocketPair::Side::parent);
sp0.send("P2C0", SocketPair::Side::child);
// We should have something to read
SocketPair::poll(pollList, 0, SocketPair::Side::parent);
ASSERT_TRUE(sp0.pollFlag());
ASSERT_FALSE(sp1.pollFlag());
ASSERT_EQ("P2C0", sp0.receive(SocketPair::Side::parent));
// Nothing to read (= timeout)
ASSERT_THROW(SocketPair::poll(pollList, 0, SocketPair::Side::parent), cta::server::SocketPair::Timeout);
// We should have something to read from child.
SocketPair::poll(pollList, 0, SocketPair::Side::child);
ASSERT_TRUE(sp0.pollFlag());
ASSERT_FALSE(sp1.pollFlag());
ASSERT_EQ("C2P0", sp0.receive(SocketPair::Side::child));
ASSERT_THROW(sp0.receive(SocketPair::Side::child), SocketPair::NothingToReceive);
}
TEST(cta_threading_SocketPair, Multimessages) {
using cta::server::SocketPair;
cta::server::SocketPair sp;
SocketPair::pollMap pollList;
pollList["0"] = &sp;
sp.send("C2P0", SocketPair::Side::parent);
sp.send("C2P1", SocketPair::Side::parent);
sp.send("C2P2", SocketPair::Side::parent);
// We should have something to read
SocketPair::poll(pollList, 0, SocketPair::Side::child);
ASSERT_TRUE(sp.pollFlag());
// Read 2 messages
ASSERT_EQ("C2P0", sp.receive(SocketPair::Side::child));
ASSERT_EQ("C2P1", sp.receive(SocketPair::Side::child));
// We should still something to read
SocketPair::poll(pollList, 0, SocketPair::Side::child);
ASSERT_TRUE(sp.pollFlag());
// Read 2 messages (2nd should fail)
ASSERT_EQ("C2P2", sp.receive(SocketPair::Side::child));
ASSERT_THROW(sp.receive(SocketPair::Side::child), SocketPair::NothingToReceive);
// Nothing to read (= timeout)
ASSERT_THROW(SocketPair::poll(pollList, 0, SocketPair::Side::child), cta::server::SocketPair::Timeout);
}
TEST(cta_threading_SocketPair, MaxLength) {
// We should be able to read up to 2048 bytes (this is an internal limit that
// could be raised)
// Limit to send is higher
// 1) prepare messages.
std::string smallMessage = "Hello!";
std::string maxMessage;
int i = 0;
maxMessage.resize(2048, '.');
std::for_each(maxMessage.begin(), maxMessage.end(), [&](char &c){ c='A' + (i++ % 26);});
std::string oversizeMessage;
oversizeMessage.resize(2049, '.');
// 2) send/receive them
using cta::server::SocketPair;
cta::server::SocketPair sp;
sp.send(smallMessage, SocketPair::Side::parent);
sp.send(maxMessage, SocketPair::Side::parent);
sp.send(oversizeMessage, SocketPair::Side::parent);
sp.send(smallMessage, SocketPair::Side::parent);
ASSERT_EQ(smallMessage, sp.receive(SocketPair::Side::child));
ASSERT_EQ(maxMessage, sp.receive(SocketPair::Side::child));
ASSERT_THROW(sp.receive(SocketPair::Side::child), SocketPair::Overflow);
ASSERT_EQ(smallMessage, sp.receive(SocketPair::Side::child));
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment