tonic_health/
server.rs

1//! Contains all healthcheck based server utilities.
2
3use crate::pb::health_server::{Health, HealthServer};
4use crate::pb::{HealthCheckRequest, HealthCheckResponse};
5use crate::ServingStatus;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::sync::Arc;
9use tokio::sync::{watch, RwLock};
10use tokio_stream::Stream;
11#[cfg(feature = "transport")]
12use tonic::server::NamedService;
13use tonic::{Request, Response, Status};
14
15/// Creates a `HealthReporter` and a linked `HealthServer` pair. Together,
16/// these types can be used to serve the gRPC Health Checking service.
17///
18/// A `HealthReporter` is used to update the state of gRPC services.
19///
20/// A `HealthServer` is a Tonic gRPC server for the `grpc.health.v1.Health`,
21/// which can be added to a Tonic runtime using `add_service` on the runtime
22/// builder.
23pub fn health_reporter() -> (HealthReporter, HealthServer<impl Health>) {
24    let reporter = HealthReporter::new();
25    let service = HealthService::new(reporter.statuses.clone());
26    let server = HealthServer::new(service);
27
28    (reporter, server)
29}
30
31type StatusPair = (watch::Sender<ServingStatus>, watch::Receiver<ServingStatus>);
32
33/// A handle providing methods to update the health status of gRPC services. A
34/// `HealthReporter` is connected to a `HealthServer` which serves the statuses
35/// over the `grpc.health.v1.Health` service.
36#[derive(Clone, Debug)]
37pub struct HealthReporter {
38    statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
39}
40
41impl HealthReporter {
42    fn new() -> Self {
43        // According to the gRPC Health Check specification, the empty service "" corresponds to the overall server health
44        let server_status = ("".to_string(), watch::channel(ServingStatus::Serving));
45
46        let statuses = Arc::new(RwLock::new(HashMap::from([server_status])));
47
48        HealthReporter { statuses }
49    }
50
51    /// Sets the status of the service implemented by `S` to `Serving`. This notifies any watchers
52    /// if there is a change in status.
53    #[cfg(feature = "transport")]
54    pub async fn set_serving<S>(&mut self)
55    where
56        S: NamedService,
57    {
58        let service_name = <S as NamedService>::NAME;
59        self.set_service_status(service_name, ServingStatus::Serving)
60            .await;
61    }
62
63    /// Sets the status of the service implemented by `S` to `NotServing`. This notifies any watchers
64    /// if there is a change in status.
65    #[cfg(feature = "transport")]
66    pub async fn set_not_serving<S>(&mut self)
67    where
68        S: NamedService,
69    {
70        let service_name = <S as NamedService>::NAME;
71        self.set_service_status(service_name, ServingStatus::NotServing)
72            .await;
73    }
74
75    /// Sets the status of the service with `service_name` to `status`. This notifies any watchers
76    /// if there is a change in status.
77    pub async fn set_service_status<S>(&mut self, service_name: S, status: ServingStatus)
78    where
79        S: AsRef<str>,
80    {
81        let service_name = service_name.as_ref();
82        let mut writer = self.statuses.write().await;
83        match writer.get(service_name) {
84            Some((tx, _)) => {
85                // We only ever hand out clones of the receiver, so the originally-created
86                // receiver should always be present, only being dropped when clearing the
87                // service status. Consequently, `tx.send` should not fail, making use
88                // of `expect` here safe.
89                tx.send(status).expect("channel should not be closed");
90            }
91            None => {
92                writer.insert(service_name.to_string(), watch::channel(status));
93            }
94        };
95    }
96
97    /// Clear the status of the given service.
98    pub async fn clear_service_status(&mut self, service_name: &str) {
99        let mut writer = self.statuses.write().await;
100        let _ = writer.remove(service_name);
101    }
102}
103
104/// A service providing implementations of gRPC health checking protocol.
105#[derive(Debug)]
106pub struct HealthService {
107    statuses: Arc<RwLock<HashMap<String, StatusPair>>>,
108}
109
110impl HealthService {
111    fn new(services: Arc<RwLock<HashMap<String, StatusPair>>>) -> Self {
112        HealthService { statuses: services }
113    }
114
115    async fn service_health(&self, service_name: &str) -> Option<ServingStatus> {
116        let reader = self.statuses.read().await;
117        reader.get(service_name).map(|p| *p.1.borrow())
118    }
119}
120
121#[tonic::async_trait]
122impl Health for HealthService {
123    async fn check(
124        &self,
125        request: Request<HealthCheckRequest>,
126    ) -> Result<Response<HealthCheckResponse>, Status> {
127        let service_name = request.get_ref().service.as_str();
128        let status = self.service_health(service_name).await;
129
130        match status {
131            None => Err(Status::not_found("service not registered")),
132            Some(status) => Ok(Response::new(HealthCheckResponse {
133                status: crate::pb::health_check_response::ServingStatus::from(status) as i32,
134            })),
135        }
136    }
137
138    type WatchStream =
139        Pin<Box<dyn Stream<Item = Result<HealthCheckResponse, Status>> + Send + 'static>>;
140
141    async fn watch(
142        &self,
143        request: Request<HealthCheckRequest>,
144    ) -> Result<Response<Self::WatchStream>, Status> {
145        let service_name = request.get_ref().service.as_str();
146        let mut status_rx = match self.statuses.read().await.get(service_name) {
147            None => return Err(Status::not_found("service not registered")),
148            Some(pair) => pair.1.clone(),
149        };
150
151        let output = async_stream::try_stream! {
152            // yield the current value
153            let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
154            yield HealthCheckResponse { status };
155
156            #[allow(clippy::redundant_pattern_matching)]
157            while let Ok(_) = status_rx.changed().await {
158                let status = crate::pb::health_check_response::ServingStatus::from(*status_rx.borrow()) as i32;
159                yield HealthCheckResponse { status };
160            }
161        };
162
163        Ok(Response::new(Box::pin(output) as Self::WatchStream))
164    }
165}
166
#[cfg(test)]
mod tests {
    use crate::pb::health_server::Health;
    use crate::pb::HealthCheckRequest;
    use crate::server::{HealthReporter, HealthService};
    use crate::ServingStatus;
    use tokio::sync::watch;
    use tokio_stream::StreamExt;
    use tonic::{Code, Request, Status};

    /// Asserts that the wire-level status integer encodes `expected`.
    fn assert_serving_status(wire: i32, expected: ServingStatus) {
        assert_eq!(
            wire,
            crate::pb::health_check_response::ServingStatus::from(expected) as i32
        );
    }

    /// Asserts that `wire` holds a gRPC status whose code is `expected`.
    fn assert_grpc_status(wire: Option<Status>, expected: Code) {
        let code = wire.expect("status is not None").code();
        assert_eq!(code, expected);
    }

    /// Builds a health check request for the given service name.
    fn request_for(service: &str) -> Request<HealthCheckRequest> {
        Request::new(HealthCheckRequest {
            service: service.to_string(),
        })
    }

    /// Pulls the next message from a watch stream and returns its wire-level
    /// status, panicking if the stream has ended or produced an error.
    async fn next_status(stream: &mut <HealthService as Health>::WatchStream) -> i32 {
        stream
            .next()
            .await
            .expect("streamed response is Some")
            .expect("response is ok")
            .status
    }

    /// Creates a reporter/service pair sharing one status map that also holds
    /// a `"TestService"` entry with status `Unknown`.
    async fn make_test_service() -> (HealthReporter, HealthService) {
        let reporter = HealthReporter::new();

        // Insert the test value directly; the write guard is a temporary that
        // is released at the end of this statement.
        reporter.statuses.write().await.insert(
            "TestService".to_string(),
            watch::channel(ServingStatus::Unknown),
        );

        let service = HealthService::new(reporter.statuses.clone());
        (reporter, service)
    }

    #[tokio::test]
    async fn test_service_check() {
        let (mut reporter, service) = make_test_service().await;

        // Overall server health
        let resp = service.check(request_for("")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Serving);

        // Unregistered service
        let resp = service.check(request_for("Unregistered")).await;
        assert!(resp.is_err());
        assert_grpc_status(resp.err(), Code::NotFound);

        // Registered service - initial state
        let resp = service.check(request_for("TestService")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Unknown);

        // Registered service - updated state
        reporter
            .set_service_status("TestService", ServingStatus::Serving)
            .await;
        let resp = service.check(request_for("TestService")).await;
        assert!(resp.is_ok());
        assert_serving_status(resp.unwrap().into_inner().status, ServingStatus::Serving);
    }

    #[tokio::test]
    async fn test_service_watch() {
        let (mut reporter, service) = make_test_service().await;

        // Overall server health
        let resp = service.watch(request_for("")).await;
        assert!(resp.is_ok());
        let mut stream = resp.unwrap().into_inner();
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Serving);

        // Unregistered service
        let resp = service.watch(request_for("Unregistered")).await;
        assert!(resp.is_err());
        assert_grpc_status(resp.err(), Code::NotFound);

        // Registered service
        let resp = service.watch(request_for("TestService")).await;
        assert!(resp.is_ok());
        let mut stream = resp.unwrap().into_inner();

        // Registered service - initial state
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Unknown);

        // Registered service - updated state
        reporter
            .set_service_status("TestService", ServingStatus::NotServing)
            .await;
        assert_serving_status(next_status(&mut stream).await, ServingStatus::NotServing);

        // Registered service - updated state
        reporter
            .set_service_status("TestService", ServingStatus::Serving)
            .await;
        assert_serving_status(next_status(&mut stream).await, ServingStatus::Serving);

        // De-registered service
        reporter.clear_service_status("TestService").await;
        let item = stream.next().await;
        assert!(item.is_none());
    }
}