1use crate::pb::health_server::{Health, HealthServer};
4use crate::pb::{HealthCheckRequest, HealthCheckResponse};
5use crate::ServingStatus;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::sync::{watch, RwLock};
10use tokio_stream::Stream;
11#[cfg(feature = "transport")]
12use tonic::server::NamedService;
13use tonic::{Request, Response, Status};
14
15pub fn health_reporter() -> (HealthReporter, HealthServer<impl Health>) {
24 let reporter = HealthReporter::new();
25 let service = HealthService::new(reporter.statuses.clone());
26 let server = HealthServer::new(service);
27
28 (reporter, server)
29}
30
31type StatusPair = (watch::Sender<ServingStatus>, watch::Receiver<ServingStatus>);
32
33#[derive(Clone, Debug)]
37pub struct HealthReporter {
38 statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
39}
40
41impl HealthReporter {
42 fn new() -> Self {
43 let server_status = ("".to_string(), watch::channel(ServingStatus::Serving));
45
46 let statuses = Arc::new(RwLock::new(HashMap::from([server_status])));
47
48 HealthReporter { statuses }
49 }
50
51 #[cfg(feature = "transport")]
54 pub async fn set_serving<S>(&mut self)
55 where
56 S: NamedService,
57 {
58 let service_name = <S as NamedService>::NAME;
59 self.set_service_status(service_name, ServingStatus::Serving)
60 .await;
61 }
62
63 #[cfg(feature = "transport")]
66 pub async fn set_not_serving<S>(&mut self)
67 where
68 S: NamedService,
69 {
70 let service_name = <S as NamedService>::NAME;
71 self.set_service_status(service_name, ServingStatus::NotServing)
72 .await;
73 }
74
75 pub async fn set_service_status<S>(&mut self, service_name: S, status: ServingStatus)
78 where
79 S: AsRef<str>,
80 {
81 let service_name = service_name.as_ref();
82 let mut writer = self.statuses.write().await;
83 match writer.get(service_name) {
84 Some((tx, _)) => {
85 tx.send(status).expect("channel should not be closed");
90 }
91 None => {
92 writer.insert(service_name.to_string(), watch::channel(status));
93 }
94 };
95 }
96
97 pub async fn clear_service_status(&mut self, service_name: &str) {
99 let mut writer = self.statuses.write().await;
100 let _ = writer.remove(service_name);
101 }
102}
103
104#[derive(Debug)]
106pub struct HealthService {
107 statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
108}
109
110impl HealthService {
111 fn new(services: Arc<RwLock<HashMap<String, StatusPair>>>) -> Self {
112 HealthService { statuses: services }
113 }
114
115 async fn service_health(&self, service_name: &str) -> Option<ServingStatus> {
116 let reader = self.statuses.read().await;
117 reader.get(service_name).map(|p| *p.1.borrow())
118 }
119}
120
121#[tonic::async_trait]
122impl Health for HealthService {
123 async fn check(
124 &self,
125 request: Request<HealthCheckRequest>,
126 ) -> Result<Response<HealthCheckResponse>, Status> {
127 let service_name = request.get_ref().service.as_str();
128 let status = self.service_health(service_name).await;
129
130 match status {
131 None => Err(Status::not_found("service not registered")),
132 Some(status) => Ok(Response::new(HealthCheckResponse {
133 status: crate::pb::health_check_response::ServingStatus::from(status) as i32,
134 })),
135 }
136 }
137
138 type WatchStream =
139 Pin<Box<dyn Stream<Item = Result<HealthCheckResponse, Status>> + Send + 'static>>;
140
141 async fn watch(
142 &self,
143 request: Request<HealthCheckRequest>,
144 ) -> Result<Response<Self::WatchStream>, Status> {
145 let service_name = request.get_ref().service.as_str();
146 let mut status_rx = match self.statuses.read().await.get(service_name) {
147 None => return Err(Status::not_found("service not registered")),
148 Some(pair) => pair.1.clone(),
149 };
150
151 let output = async_stream::try_stream! {
152 let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
154 yield HealthCheckResponse { status };
155
156 #[allow(clippy::redundant_pattern_matching)]
157 while let Ok(_) = status_rx.changed().await {
158 let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
159 yield HealthCheckResponse { status };
160 }
161 };
162
163 Ok(Response::new(Box::pin(output) as Self::WatchStream))
164 }
165}
166
#[cfg(test)]
mod tests {
    use crate::pb::health_server::Health;
    use crate::pb::HealthCheckRequest;
    use crate::server::{HealthReporter, HealthService};
    use crate::ServingStatus;
    use tokio::sync::watch;
    use tokio_stream::StreamExt;
    use tonic::{Code, Request, Status};

    /// Builds a `Check`/`Watch` request for the named service.
    fn request_for(service: &str) -> Request<HealthCheckRequest> {
        Request::new(HealthCheckRequest {
            service: service.to_string(),
        })
    }

    /// Asserts that a wire-level status value encodes `expected`.
    fn assert_serving_status(wire: i32, expected: ServingStatus) {
        let want = crate::pb::health_check_response::ServingStatus::from(expected) as i32;
        assert_eq!(wire, want);
    }

    /// Asserts that an error is present and carries the `expected` gRPC code.
    fn assert_grpc_status(wire: Option<Status>, expected: Code) {
        let code = wire.expect("status is not None").code();
        assert_eq!(code, expected);
    }

    /// Pulls the next item off a watch stream and returns its wire status.
    async fn next_status(stream: &mut <HealthService as Health>::WatchStream) -> i32 {
        let response = stream
            .next()
            .await
            .expect("streamed response is Some")
            .expect("response is ok");
        response.status
    }

    /// Builds a reporter/service pair with an extra "TestService" entry in `Unknown` state.
    async fn make_test_service() -> (HealthReporter, HealthService) {
        let health_reporter = HealthReporter::new();
        // Insert directly into the map so the entry starts out as `Unknown`.
        health_reporter.statuses.write().await.insert(
            "TestService".to_string(),
            watch::channel(ServingStatus::Unknown),
        );
        let health_service = HealthService::new(health_reporter.statuses.clone());
        (health_reporter, health_service)
    }

    #[tokio::test]
    async fn test_service_check() {
        let (mut reporter, service) = make_test_service().await;

        // The server-wide (empty-name) entry is registered as serving.
        let resp = service.check(request_for("")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Serving);

        // Unregistered services are rejected.
        let resp = service.check(request_for("Unregistered")).await;
        assert!(resp.is_err());
        assert_grpc_status(resp.err(), Code::NotFound);

        // A registered service reports its initial status...
        let resp = service.check(request_for("TestService")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Unknown);

        // ...and reflects later updates.
        reporter
            .set_service_status("TestService", ServingStatus::Serving)
            .await;
        let resp = service.check(request_for("TestService")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Serving);
    }

    #[tokio::test]
    async fn test_service_watch() {
        let (mut reporter, service) = make_test_service().await;

        let resp = service.watch(request_for("")).await;
        assert!(resp.is_ok());
        let mut stream = resp.unwrap().into_inner();
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Serving);

        let resp = service.watch(request_for("Unregistered")).await;
        assert!(resp.is_err());
        assert_grpc_status(resp.err(), Code::NotFound);

        let resp = service.watch(request_for("TestService")).await;
        assert!(resp.is_ok());
        let mut stream = resp.unwrap().into_inner();
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Unknown);

        reporter
            .set_service_status("TestService", ServingStatus::NotServing)
            .await;
        assert_serving_status(next_status(&mut stream).await, ServingStatus::NotServing);

        reporter
            .set_service_status("TestService", ServingStatus::Serving)
            .await;
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Serving);

        // Clearing the entry drops its sender, which ends the watch stream.
        reporter.clear_service_status("TestService").await;
        assert!(stream.next().await.is_none());
    }
}
323}