Index: tools/android/forwarder2/socket.cc |
diff --git a/tools/android/forwarder2/socket.cc b/tools/android/forwarder2/socket.cc |
index 898afc3d4b4dccfee8a75879a7d8bbac162668c1..35553bbee6c9ce93a9ccfc9d26d1189b5e7ea06d 100644 |
--- a/tools/android/forwarder2/socket.cc |
+++ b/tools/android/forwarder2/socket.cc |
@@ -36,6 +36,10 @@ |
namespace { |
const int kNoTimeout = -1; |
const int kConnectTimeOut = 10; // Seconds. |
+ |
+bool FamilyIsTCP(int family) { |
+ return family == AF_INET || family == AF_INET6; |
+} |
} // namespace |
namespace forwarder2 { |
@@ -129,7 +133,6 @@ bool Socket::InitUnixSocket(const std::string& path, bool abstract) { |
abstract_ = abstract; |
family_ = PF_UNIX; |
addr_.addr_un.sun_family = family_; |
- |
if (abstract) { |
// Copied from net/base/unix_domain_socket_posix.cc |
// Convert the path given into abstract socket name. It must start with |
@@ -143,14 +146,12 @@ bool Socket::InitUnixSocket(const std::string& path, bool abstract) { |
memcpy(addr_.addr_un.sun_path, path.c_str(), path.size()); |
addr_len_ = sizeof(sockaddr_un); |
} |
- |
addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); |
return InitSocketInternal(); |
} |
bool Socket::InitTcpSocket(const std::string& host, int port) { |
port_ = port; |
- |
if (host.empty()) { |
// Use localhost: INADDR_LOOPBACK |
family_ = AF_INET; |
@@ -159,8 +160,7 @@ bool Socket::InitTcpSocket(const std::string& host, int port) { |
} else if (!Resolve(host)) { |
return false; |
} |
- CHECK(family_ == AF_INET || family_ == AF_INET6) |
- << "Invalid socket family."; |
+ CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; |
if (family_ == AF_INET) { |
addr_.addr4.sin_port = htons(port_); |
addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); |
@@ -180,7 +180,7 @@ bool Socket::BindAndListen() { |
SetSocketError(); |
return false; |
} |
- if (port_ == 0) { |
+ if (port_ == 0 && FamilyIsTCP(family_)) { |
SockAddr addr; |
memset(&addr, 0, sizeof(addr)); |
socklen_t addrlen = 0; |
@@ -225,19 +225,25 @@ bool Socket::Accept(Socket* new_socket) { |
} |
bool Socket::Connect() { |
- // Set non-block because we use select. |
- fcntl(socket_, F_SETFL, fcntl(socket_, F_GETFL) | O_NONBLOCK); |
+ // Set non-block because we use select for connect. |
+ const int kFlags = fcntl(socket_, F_GETFL); |
+ DCHECK(!(kFlags & O_NONBLOCK)); |
+ fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK); |
errno = 0; |
if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 && |
errno != EINPROGRESS) { |
SetSocketError(); |
+ PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags)); |
return false; |
} |
// Wait for connection to complete, or receive a notification. |
if (!WaitForEvent(WRITE, kConnectTimeOut)) { |
SetSocketError(); |
+ PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags)); |
return false; |
} |
+ // Disable non-block since our code assumes blocking semantics. |
+ fcntl(socket_, F_SETFL, kFlags); |
return true; |
} |
@@ -271,7 +277,7 @@ bool Socket::Resolve(const std::string& host) { |
} |
int Socket::GetPort() { |
- if (family_ != AF_INET && family_ != AF_INET6) { |
+ if (!FamilyIsTCP(family_)) { |
LOG(ERROR) << "Can't call GetPort() on an unix domain socket."; |
return 0; |
} |