OLD | NEW |
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. | 1 // Copyright (c) 2012 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 #include "tools/android/forwarder2/socket.h" | 5 #include "tools/android/forwarder2/socket.h" |
6 | 6 |
7 #include <arpa/inet.h> | 7 #include <arpa/inet.h> |
8 #include <fcntl.h> | 8 #include <fcntl.h> |
9 #include <netdb.h> | 9 #include <netdb.h> |
10 #include <netinet/in.h> | 10 #include <netinet/in.h> |
(...skipping 18 matching lines...) Expand all Loading... |
29 do { \ | 29 do { \ |
30 int local_errno = errno; \ | 30 int local_errno = errno; \ |
31 (void) HANDLE_EINTR(Func); \ | 31 (void) HANDLE_EINTR(Func); \ |
32 errno = local_errno; \ | 32 errno = local_errno; \ |
33 } while (false); | 33 } while (false); |
34 | 34 |
35 | 35 |
36 namespace { | 36 namespace { |
37 const int kNoTimeout = -1; | 37 const int kNoTimeout = -1; |
38 const int kConnectTimeOut = 10; // Seconds. | 38 const int kConnectTimeOut = 10; // Seconds. |
| 39 |
| 40 bool FamilyIsTCP(int family) { |
| 41 return family == AF_INET || family == AF_INET6; |
| 42 } |
39 } // namespace | 43 } // namespace |
40 | 44 |
41 namespace forwarder2 { | 45 namespace forwarder2 { |
42 | 46 |
43 bool Socket::BindUnix(const std::string& path, bool abstract) { | 47 bool Socket::BindUnix(const std::string& path, bool abstract) { |
44 errno = 0; | 48 errno = 0; |
45 if (!InitUnixSocket(path, abstract) || !BindAndListen()) { | 49 if (!InitUnixSocket(path, abstract) || !BindAndListen()) { |
46 Close(); | 50 Close(); |
47 return false; | 51 return false; |
48 } | 52 } |
(...skipping 73 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
122 // For abstract sockets we need one extra byte for the leading zero. | 126 // For abstract sockets we need one extra byte for the leading zero. |
123 if ((abstract && path.size() + 2 /* '\0' */ > kPathMax) || | 127 if ((abstract && path.size() + 2 /* '\0' */ > kPathMax) || |
124 (!abstract && path.size() + 1 /* '\0' */ > kPathMax)) { | 128 (!abstract && path.size() + 1 /* '\0' */ > kPathMax)) { |
125 LOG(ERROR) << "The provided path is too big to create a unix " | 129 LOG(ERROR) << "The provided path is too big to create a unix " |
126 << "domain socket: " << path; | 130 << "domain socket: " << path; |
127 return false; | 131 return false; |
128 } | 132 } |
129 abstract_ = abstract; | 133 abstract_ = abstract; |
130 family_ = PF_UNIX; | 134 family_ = PF_UNIX; |
131 addr_.addr_un.sun_family = family_; | 135 addr_.addr_un.sun_family = family_; |
132 | |
133 if (abstract) { | 136 if (abstract) { |
134 // Copied from net/base/unix_domain_socket_posix.cc | 137 // Copied from net/base/unix_domain_socket_posix.cc |
135 // Convert the path given into abstract socket name. It must start with | 138 // Convert the path given into abstract socket name. It must start with |
136 // the '\0' character, so we are adding it. |addr_len| must specify the | 139 // the '\0' character, so we are adding it. |addr_len| must specify the |
137 // length of the structure exactly, as potentially the socket name may | 140 // length of the structure exactly, as potentially the socket name may |
138 // have '\0' characters embedded (although we don't support this). | 141 // have '\0' characters embedded (although we don't support this). |
139 // Note that addr_.addr_un.sun_path is already zero initialized. | 142 // Note that addr_.addr_un.sun_path is already zero initialized. |
140 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); | 143 memcpy(addr_.addr_un.sun_path + 1, path.c_str(), path.size()); |
141 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; | 144 addr_len_ = path.size() + offsetof(struct sockaddr_un, sun_path) + 1; |
142 } else { | 145 } else { |
143 memcpy(addr_.addr_un.sun_path, path.c_str(), path.size()); | 146 memcpy(addr_.addr_un.sun_path, path.c_str(), path.size()); |
144 addr_len_ = sizeof(sockaddr_un); | 147 addr_len_ = sizeof(sockaddr_un); |
145 } | 148 } |
146 | |
147 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); | 149 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr_un); |
148 return InitSocketInternal(); | 150 return InitSocketInternal(); |
149 } | 151 } |
150 | 152 |
151 bool Socket::InitTcpSocket(const std::string& host, int port) { | 153 bool Socket::InitTcpSocket(const std::string& host, int port) { |
152 port_ = port; | 154 port_ = port; |
153 | |
154 if (host.empty()) { | 155 if (host.empty()) { |
155 // Use localhost: INADDR_LOOPBACK | 156 // Use localhost: INADDR_LOOPBACK |
156 family_ = AF_INET; | 157 family_ = AF_INET; |
157 addr_.addr4.sin_family = family_; | 158 addr_.addr4.sin_family = family_; |
158 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); | 159 addr_.addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); |
159 } else if (!Resolve(host)) { | 160 } else if (!Resolve(host)) { |
160 return false; | 161 return false; |
161 } | 162 } |
162 CHECK(family_ == AF_INET || family_ == AF_INET6) | 163 CHECK(FamilyIsTCP(family_)) << "Invalid socket family."; |
163 << "Invalid socket family."; | |
164 if (family_ == AF_INET) { | 164 if (family_ == AF_INET) { |
165 addr_.addr4.sin_port = htons(port_); | 165 addr_.addr4.sin_port = htons(port_); |
166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); | 166 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr4); |
167 addr_len_ = sizeof(addr_.addr4); | 167 addr_len_ = sizeof(addr_.addr4); |
168 } else if (family_ == AF_INET6) { | 168 } else if (family_ == AF_INET6) { |
169 addr_.addr6.sin6_port = htons(port_); | 169 addr_.addr6.sin6_port = htons(port_); |
170 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); | 170 addr_ptr_ = reinterpret_cast<sockaddr*>(&addr_.addr6); |
171 addr_len_ = sizeof(addr_.addr6); | 171 addr_len_ = sizeof(addr_.addr6); |
172 } | 172 } |
173 return InitSocketInternal(); | 173 return InitSocketInternal(); |
174 } | 174 } |
175 | 175 |
176 bool Socket::BindAndListen() { | 176 bool Socket::BindAndListen() { |
177 errno = 0; | 177 errno = 0; |
178 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || | 178 if (HANDLE_EINTR(bind(socket_, addr_ptr_, addr_len_)) < 0 || |
179 HANDLE_EINTR(listen(socket_, 5)) < 0) { | 179 HANDLE_EINTR(listen(socket_, 5)) < 0) { |
180 SetSocketError(); | 180 SetSocketError(); |
181 return false; | 181 return false; |
182 } | 182 } |
183 if (port_ == 0) { | 183 if (port_ == 0 && FamilyIsTCP(family_)) { |
184 SockAddr addr; | 184 SockAddr addr; |
185 memset(&addr, 0, sizeof(addr)); | 185 memset(&addr, 0, sizeof(addr)); |
186 socklen_t addrlen = 0; | 186 socklen_t addrlen = 0; |
187 sockaddr* addr_ptr = NULL; | 187 sockaddr* addr_ptr = NULL; |
188 uint16* port_ptr = NULL; | 188 uint16* port_ptr = NULL; |
189 if (family_ == AF_INET) { | 189 if (family_ == AF_INET) { |
190 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); | 190 addr_ptr = reinterpret_cast<sockaddr*>(&addr.addr4); |
191 port_ptr = &addr.addr4.sin_port; | 191 port_ptr = &addr.addr4.sin_port; |
192 addrlen = sizeof(addr.addr4); | 192 addrlen = sizeof(addr.addr4); |
193 } else if (family_ == AF_INET6) { | 193 } else if (family_ == AF_INET6) { |
(...skipping 24 matching lines...) Expand all Loading... |
218 SetSocketError(); | 218 SetSocketError(); |
219 return false; | 219 return false; |
220 } | 220 } |
221 | 221 |
222 tools::DisableNagle(new_socket_fd); | 222 tools::DisableNagle(new_socket_fd); |
223 new_socket->socket_ = new_socket_fd; | 223 new_socket->socket_ = new_socket_fd; |
224 return true; | 224 return true; |
225 } | 225 } |
226 | 226 |
227 bool Socket::Connect() { | 227 bool Socket::Connect() { |
228 // Set non-block because we use select. | 228 // Set non-block because we use select for connect. |
229 fcntl(socket_, F_SETFL, fcntl(socket_, F_GETFL) | O_NONBLOCK); | 229 const int kFlags = fcntl(socket_, F_GETFL); |
| 230 DCHECK(!(kFlags & O_NONBLOCK)); |
| 231 fcntl(socket_, F_SETFL, kFlags | O_NONBLOCK); |
230 errno = 0; | 232 errno = 0; |
231 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 && | 233 if (HANDLE_EINTR(connect(socket_, addr_ptr_, addr_len_)) < 0 && |
232 errno != EINPROGRESS) { | 234 errno != EINPROGRESS) { |
233 SetSocketError(); | 235 SetSocketError(); |
| 236 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags)); |
234 return false; | 237 return false; |
235 } | 238 } |
236 // Wait for connection to complete, or receive a notification. | 239 // Wait for connection to complete, or receive a notification. |
237 if (!WaitForEvent(WRITE, kConnectTimeOut)) { | 240 if (!WaitForEvent(WRITE, kConnectTimeOut)) { |
238 SetSocketError(); | 241 SetSocketError(); |
| 242 PRESERVE_ERRNO_HANDLE_EINTR(fcntl(socket_, F_SETFL, kFlags)); |
239 return false; | 243 return false; |
240 } | 244 } |
| 245 // Disable non-block since our code assumes blocking semantics. |
| 246 fcntl(socket_, F_SETFL, kFlags); |
241 return true; | 247 return true; |
242 } | 248 } |
243 | 249 |
244 bool Socket::Resolve(const std::string& host) { | 250 bool Socket::Resolve(const std::string& host) { |
245 struct addrinfo hints; | 251 struct addrinfo hints; |
246 struct addrinfo* res; | 252 struct addrinfo* res; |
247 memset(&hints, 0, sizeof(hints)); | 253 memset(&hints, 0, sizeof(hints)); |
248 hints.ai_family = AF_UNSPEC; | 254 hints.ai_family = AF_UNSPEC; |
249 hints.ai_socktype = SOCK_STREAM; | 255 hints.ai_socktype = SOCK_STREAM; |
250 hints.ai_flags |= AI_CANONNAME; | 256 hints.ai_flags |= AI_CANONNAME; |
(...skipping 13 matching lines...) Expand all Loading... |
264 case AF_INET6: | 270 case AF_INET6: |
265 memcpy(&addr_.addr6, | 271 memcpy(&addr_.addr6, |
266 reinterpret_cast<sockaddr_in6*>(res->ai_addr), | 272 reinterpret_cast<sockaddr_in6*>(res->ai_addr), |
267 sizeof(sockaddr_in6)); | 273 sizeof(sockaddr_in6)); |
268 break; | 274 break; |
269 } | 275 } |
270 return true; | 276 return true; |
271 } | 277 } |
272 | 278 |
273 int Socket::GetPort() { | 279 int Socket::GetPort() { |
274 if (family_ != AF_INET && family_ != AF_INET6) { | 280 if (!FamilyIsTCP(family_)) { |
275 LOG(ERROR) << "Can't call GetPort() on an unix domain socket."; | 281 LOG(ERROR) << "Can't call GetPort() on an unix domain socket."; |
276 return 0; | 282 return 0; |
277 } | 283 } |
278 return port_; | 284 return port_; |
279 } | 285 } |
280 | 286 |
281 bool Socket::IsFdInSet(const fd_set& fds) const { | 287 bool Socket::IsFdInSet(const fd_set& fds) const { |
282 if (IsClosed()) | 288 if (IsClosed()) |
283 return false; | 289 return false; |
284 return FD_ISSET(socket_, &fds); | 290 return FD_ISSET(socket_, &fds); |
(...skipping 83 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
368 return !FD_ISSET(exit_notifier_fd_, &read_fds); | 374 return !FD_ISSET(exit_notifier_fd_, &read_fds); |
369 } | 375 } |
370 | 376 |
371 // static | 377 // static |
372 int Socket::GetHighestFileDescriptor( | 378 int Socket::GetHighestFileDescriptor( |
373 const Socket& s1, const Socket& s2) { | 379 const Socket& s1, const Socket& s2) { |
374 return std::max(s1.socket_, s2.socket_); | 380 return std::max(s1.socket_, s2.socket_); |
375 } | 381 } |
376 | 382 |
377 } // namespace forwarder | 383 } // namespace forwarder |
OLD | NEW |