OLD | NEW |
1 // Copyright 2016 The Chromium Authors. All rights reserved. | 1 // Copyright 2016 The Chromium Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style license that can be | 2 // Use of this source code is governed by a BSD-style license that can be |
3 // found in the LICENSE file. | 3 // found in the LICENSE file. |
4 | 4 |
5 package prpc | 5 package prpc |
6 | 6 |
7 import ( | 7 import ( |
| 8 "fmt" |
8 "net/http" | 9 "net/http" |
9 "sort" | 10 "sort" |
10 "sync" | 11 "sync" |
11 | 12 |
12 "github.com/julienschmidt/httprouter" | 13 "github.com/julienschmidt/httprouter" |
13 "golang.org/x/net/context" | 14 "golang.org/x/net/context" |
14 "google.golang.org/grpc" | 15 "google.golang.org/grpc" |
| 16 "google.golang.org/grpc/codes" |
15 | 17 |
| 18 "github.com/luci/luci-go/common/logging" |
16 "github.com/luci/luci-go/server/auth" | 19 "github.com/luci/luci-go/server/auth" |
17 "github.com/luci/luci-go/server/middleware" | 20 "github.com/luci/luci-go/server/middleware" |
18 ) | 21 ) |
19 | 22 |
20 // Server is a pRPC server to serve RPC requests. | 23 // Server is a pRPC server to serve RPC requests. |
21 // Zero value is valid. | 24 // Zero value is valid. |
22 type Server struct { | 25 type Server struct { |
23 // CustomAuthenticator, if true, disables the forced authentication set
by | 26 // CustomAuthenticator, if true, disables the forced authentication set
by |
24 // RegisterDefaultAuth. | 27 // RegisterDefaultAuth. |
25 CustomAuthenticator bool | 28 CustomAuthenticator bool |
(...skipping 22 matching lines...) Expand all Loading... |
48 desc: grpcDesc, | 51 desc: grpcDesc, |
49 } | 52 } |
50 } | 53 } |
51 | 54 |
52 s.mu.Lock() | 55 s.mu.Lock() |
53 defer s.mu.Unlock() | 56 defer s.mu.Unlock() |
54 | 57 |
55 if s.services == nil { | 58 if s.services == nil { |
56 s.services = map[string]*service{} | 59 s.services = map[string]*service{} |
57 } else if _, ok := s.services[desc.ServiceName]; ok { | 60 } else if _, ok := s.services[desc.ServiceName]; ok { |
58 » » panicf("service %q is already registered", desc.ServiceName) | 61 » » panic(fmt.Errorf("service %q is already registered", desc.Servic
eName)) |
59 } | 62 } |
60 | 63 |
61 s.services[desc.ServiceName] = serv | 64 s.services[desc.ServiceName] = serv |
62 } | 65 } |
63 | 66 |
64 // authenticate forces authentication set by RegisterDefaultAuth. | 67 // authenticate forces authentication set by RegisterDefaultAuth. |
65 func (s *Server) authenticate(base middleware.Base) middleware.Base { | 68 func (s *Server) authenticate(base middleware.Base) middleware.Base { |
66 a := GetDefaultAuth() | 69 a := GetDefaultAuth() |
67 if a == nil { | 70 if a == nil { |
68 » » panicf("prpc: CustomAuthenticator is false, but default authenti
cator was not registered. " + | 71 » » panic("prpc: CustomAuthenticator is false, but default authentic
ator was not registered. " + |
69 "Forgot to import appengine/gaeauth/server package?") | 72 "Forgot to import appengine/gaeauth/server package?") |
70 } | 73 } |
71 | 74 |
72 return func(h middleware.Handler) httprouter.Handle { | 75 return func(h middleware.Handler) httprouter.Handle { |
73 return base(func(c context.Context, w http.ResponseWriter, r *ht
tp.Request, p httprouter.Params) { | 76 return base(func(c context.Context, w http.ResponseWriter, r *ht
tp.Request, p httprouter.Params) { |
74 c = auth.SetAuthenticator(c, a) | 77 c = auth.SetAuthenticator(c, a) |
75 c, err := a.Authenticate(c, r) | 78 c, err := a.Authenticate(c, r) |
76 if err != nil { | 79 if err != nil { |
77 » » » » writeError(c, w, withStatus(err, http.StatusUnau
thorized)) | 80 » » » » res := errResponse(codes.Unauthenticated, http.S
tatusUnauthorized, err.Error()) |
| 81 » » » » res.write(c, w) |
78 return | 82 return |
79 } | 83 } |
80 h(c, w, r, p) | 84 h(c, w, r, p) |
81 }) | 85 }) |
82 } | 86 } |
83 } | 87 } |
84 | 88 |
85 // InstallHandlers installs HTTP POST handlers at | 89 // InstallHandlers installs HTTP handlers at /prpc/:service/:method. |
86 // /prpc/{service_name}/{method_name} for all registered services. | 90 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 91 // for pRPC protocol. |
87 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { | 92 func (s *Server) InstallHandlers(r *httprouter.Router, base middleware.Base) { |
88 s.mu.Lock() | 93 s.mu.Lock() |
89 defer s.mu.Unlock() | 94 defer s.mu.Unlock() |
90 | 95 |
91 if !s.CustomAuthenticator { | 96 if !s.CustomAuthenticator { |
92 base = s.authenticate(base) | 97 base = s.authenticate(base) |
93 } | 98 } |
94 | 99 |
95 » for _, service := range s.services { | 100 » r.POST("/prpc/:service/:method", base(s.handle)) |
96 » » for _, m := range service.methods { | 101 } |
97 » » » m.InstallHandlers(r, base) | 102 |
98 » » } | 103 // handle handles RPCs. |
| 104 // See https://godoc.org/github.com/luci/luci-go/common/prpc#hdr-Protocol |
| 105 // for pRPC protocol. |
| 106 func (s *Server) handle(c context.Context, w http.ResponseWriter, r *http.Reques
t, p httprouter.Params) { |
| 107 » serviceName := p.ByName("service") |
| 108 » methodName := p.ByName("method") |
| 109 » res := s.respond(c, w, r, serviceName, methodName) |
| 110 |
| 111 » c = logging.SetFields(c, logging.Fields{ |
| 112 » » "service": serviceName, |
| 113 » » "method": methodName, |
| 114 » }) |
| 115 » res.write(c, w) |
| 116 } |
| 117 |
| 118 func (s *Server) respond(c context.Context, w http.ResponseWriter, r *http.Reque
st, serviceName, methodName string) *response { |
| 119 » service := s.services[serviceName] |
| 120 » if service == nil { |
| 121 » » return errResponse( |
| 122 » » » codes.Unimplemented, |
| 123 » » » http.StatusNotImplemented, |
| 124 » » » fmt.Sprintf("service %q is not implemented", serviceName
)) |
99 } | 125 } |
| 126 |
| 127 method := service.methods[methodName] |
| 128 if method == nil { |
| 129 return errResponse( |
| 130 codes.Unimplemented, |
| 131 http.StatusNotImplemented, |
| 132 fmt.Sprintf("method %q in service %q is not implemented"
, methodName, serviceName)) |
| 133 } |
| 134 |
| 135 return method.handle(c, w, r) |
100 } | 136 } |
101 | 137 |
102 // ServiceNames returns a sorted list of full names of all registered services. | 138 // ServiceNames returns a sorted list of full names of all registered services. |
103 func (s *Server) ServiceNames() []string { | 139 func (s *Server) ServiceNames() []string { |
104 s.mu.Lock() | 140 s.mu.Lock() |
105 defer s.mu.Unlock() | 141 defer s.mu.Unlock() |
106 | 142 |
107 names := make([]string, 0, len(s.services)) | 143 names := make([]string, 0, len(s.services)) |
108 for name := range s.services { | 144 for name := range s.services { |
109 names = append(names, name) | 145 names = append(names, name) |
110 } | 146 } |
111 sort.Strings(names) | 147 sort.Strings(names) |
112 return names | 148 return names |
113 } | 149 } |
OLD | NEW |