aws_smithy_schema/schema/
protocol.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Client protocol trait for protocol-agnostic request serialization and response deserialization.
7//!
8//! A [`ClientProtocol`] uses a combination of codecs, serializers, and deserializers to
9//! serialize requests and deserialize responses for a specific Smithy protocol
10//! (e.g., AWS JSON 1.0, REST JSON, REST XML, RPCv2 CBOR).
11//!
12//! # Implementing a custom protocol
13//!
14//! Third parties can create custom protocols and use them with any client without
15//! modifying a code generator.
16//!
17//! ```ignore
//! use aws_smithy_schema::protocol::ClientProtocol;
//! use aws_smithy_schema::serde::{SerdeError, SerializableStruct, ShapeDeserializer};
//! use aws_smithy_schema::{Schema, ShapeId};
//! use aws_smithy_types::config_bag::ConfigBag;
21//!
22//! #[derive(Debug)]
23//! struct MyProtocol {
24//!     codec: MyJsonCodec,
25//! }
26//!
27//! impl ClientProtocol for MyProtocol {
28//!     fn protocol_id(&self) -> &ShapeId { &MY_PROTOCOL_ID }
29//!
30//!     fn serialize_request(
31//!         &self,
32//!         input: &dyn SerializableStruct,
33//!         input_schema: &Schema,
34//!         endpoint: &str,
35//!         cfg: &ConfigBag,
36//!     ) -> Result<aws_smithy_runtime_api::http::Request, SerdeError> {
37//!         todo!()
38//!     }
39//!
40//!     fn deserialize_response<'a>(
41//!         &self,
42//!         response: &'a aws_smithy_runtime_api::http::Response,
43//!         output_schema: &Schema,
44//!         cfg: &ConfigBag,
45//!     ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError> {
46//!         todo!()
47//!     }
48//! }
49//! ```
50
51use crate::serde::{SerdeError, SerializableStruct, ShapeDeserializer};
52use crate::{Schema, ShapeId};
53use aws_smithy_types::config_bag::ConfigBag;
54
55// Implementation note: We use concrete aws_smithy_runtime_api::http::{Request, Response} types here
56// rather than associated types. While the SEP allows for transport-agnostic protocols (e.g., MQTT),
57// the current SDK uses HTTP throughout. If transport abstraction is needed in the future,
58// Request/Response could become associated types while maintaining object safety of this trait.
59
60/// An object-safe client protocol for serializing requests and deserializing responses.
61///
62/// Each Smithy protocol (e.g., `aws.protocols#restJson1`, `smithy.protocols#rpcv2Cbor`)
63/// is represented by an implementation of this trait. Protocols combine one or more
64/// codecs and serializers to produce protocol-specific request messages and parse
65/// response messages.
66///
67/// # Lifecycle
68///
69/// `ClientProtocol` instances are immutable and thread-safe. They are typically created
70/// once and shared across all requests for a client. Serializers and deserializers are
71/// created per-request internally.
72pub trait ClientProtocol: Send + Sync + std::fmt::Debug {
73    /// Returns the Smithy shape ID of this protocol.
74    ///
75    /// This enables runtime protocol selection and differentiation. For example,
76    /// `aws.protocols#restJson1` or `smithy.protocols#rpcv2Cbor`.
77    fn protocol_id(&self) -> &ShapeId;
78
79    /// Serializes an operation input into an HTTP request.
80    ///
81    /// # Arguments
82    ///
83    /// * `input` - The operation input to serialize.
84    /// * `input_schema` - Schema describing the operation's input shape.
85    /// * `endpoint` - The target endpoint URI as a string.
86    /// * `cfg` - The config bag containing request-scoped configuration
87    ///   (e.g., service name, operation name for RPC protocols).
88    fn serialize_request(
89        &self,
90        input: &dyn SerializableStruct,
91        input_schema: &Schema,
92        endpoint: &str,
93        cfg: &ConfigBag,
94    ) -> Result<aws_smithy_runtime_api::http::Request, SerdeError>;
95
96    /// Deserializes an HTTP response, returning a boxed [`ShapeDeserializer`]
97    /// for the response body.
98    ///
99    /// The returned deserializer reads only body members. For outputs with
100    /// HTTP-bound members (headers, status code), generated code reads those
101    /// directly from the response before using this deserializer for body members.
102    ///
103    /// # Arguments
104    ///
105    /// * `response` - The HTTP response to deserialize.
106    /// * `output_schema` - Schema describing the operation's output shape.
107    /// * `cfg` - The config bag containing request-scoped configuration.
108    fn deserialize_response<'a>(
109        &self,
110        response: &'a aws_smithy_runtime_api::http::Response,
111        output_schema: &Schema,
112        cfg: &ConfigBag,
113    ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError>;
114
115    /// Updates a previously serialized request with a new endpoint.
116    ///
117    /// Required by SEP requirement 7: "ClientProtocol MUST be able to update a
118    /// previously serialized request with a new endpoint." The orchestrator calls
119    /// this after endpoint resolution, which happens after `serialize_request`.
120    ///
121    /// The default implementation applies the endpoint URL (with prefix if present),
122    /// sets the request URI, and copies any endpoint headers onto the request.
123    /// This replicates the existing `apply_endpoint` logic from the orchestrator.
124    /// Custom implementations should rarely need to override this.
125    fn update_endpoint(
126        &self,
127        request: &mut aws_smithy_runtime_api::http::Request,
128        endpoint: &aws_smithy_types::endpoint::Endpoint,
129        cfg: &ConfigBag,
130    ) -> Result<(), SerdeError> {
131        use std::borrow::Cow;
132
133        let endpoint_prefix =
134            cfg.load::<aws_smithy_runtime_api::client::endpoint::EndpointPrefix>();
135        let endpoint_url = match endpoint_prefix {
136            None => Cow::Borrowed(endpoint.url()),
137            Some(prefix) => {
138                let parsed: http::Uri = endpoint
139                    .url()
140                    .parse()
141                    .map_err(|e| SerdeError::custom(format!("invalid endpoint URI: {e}")))?;
142                let scheme = parsed.scheme_str().unwrap_or_default();
143                let prefix = prefix.as_str();
144                let authority = parsed.authority().map(|a| a.as_str()).unwrap_or_default();
145                let path_and_query = parsed
146                    .path_and_query()
147                    .map(|pq| pq.as_str())
148                    .unwrap_or_default();
149                Cow::Owned(format!("{scheme}://{prefix}{authority}{path_and_query}"))
150            }
151        };
152
153        request.uri_mut().set_endpoint(&endpoint_url).map_err(|e| {
154            SerdeError::custom(format!("failed to apply endpoint `{endpoint_url}`: {e}"))
155        })?;
156
157        for (header_name, header_values) in endpoint.headers() {
158            request.headers_mut().remove(header_name);
159            for value in header_values {
160                request
161                    .headers_mut()
162                    .append(header_name.to_owned(), value.to_owned());
163            }
164        }
165
166        Ok(())
167    }
168}
169
170/// A shared, type-erased client protocol stored in a [`ConfigBag`].
171///
172/// This wraps an `Arc<dyn ClientProtocol>` so it can be stored
173/// and retrieved from the config bag for runtime protocol selection.
174#[derive(Clone, Debug)]
175pub struct SharedClientProtocol {
176    inner: std::sync::Arc<dyn ClientProtocol>,
177}
178
179impl SharedClientProtocol {
180    /// Creates a new shared protocol from any `ClientProtocol` implementation.
181    pub fn new(protocol: impl ClientProtocol + 'static) -> Self {
182        Self {
183            inner: std::sync::Arc::new(protocol),
184        }
185    }
186}
187
188impl std::ops::Deref for SharedClientProtocol {
189    type Target = dyn ClientProtocol;
190
191    fn deref(&self) -> &Self::Target {
192        &*self.inner
193    }
194}
195
196impl aws_smithy_types::config_bag::Storable for SharedClientProtocol {
197    type Storer = aws_smithy_types::config_bag::StoreReplace<Self>;
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serde::{SerdeError, SerializableStruct, ShapeDeserializer};
    use crate::{Schema, ShapeId};
    use aws_smithy_runtime_api::http::{Request, Response};
    use aws_smithy_types::body::SdkBody;
    use aws_smithy_types::config_bag::{ConfigBag, Layer};
    use aws_smithy_types::endpoint::Endpoint;

    /// Minimal protocol impl that uses only the default `update_endpoint`.
    #[derive(Debug)]
    struct StubProtocol;

    static STUB_ID: ShapeId = ShapeId::from_static("test#StubProtocol", "test", "StubProtocol");

    impl ClientProtocol for StubProtocol {
        fn protocol_id(&self) -> &ShapeId {
            &STUB_ID
        }
        fn serialize_request(
            &self,
            _input: &dyn SerializableStruct,
            _input_schema: &Schema,
            _endpoint: &str,
            _cfg: &ConfigBag,
        ) -> Result<Request, SerdeError> {
            unimplemented!()
        }
        fn deserialize_response<'a>(
            &self,
            _response: &'a Response,
            _output_schema: &Schema,
            _cfg: &ConfigBag,
        ) -> Result<Box<dyn ShapeDeserializer + 'a>, SerdeError> {
            unimplemented!()
        }
    }

    /// Builds an empty-bodied request whose URI is `uri`.
    fn request_with_uri(uri: &str) -> Request {
        let mut req = Request::new(SdkBody::empty());
        req.set_uri(uri).unwrap();
        req
    }

    #[test]
    fn basic_endpoint() {
        let proto = StubProtocol;
        let mut req = request_with_uri("/original/path");
        let endpoint = Endpoint::builder()
            .url("https://service.us-east-1.amazonaws.com")
            .build();
        let cfg = ConfigBag::base();

        proto.update_endpoint(&mut req, &endpoint, &cfg).unwrap();
        assert_eq!(
            req.uri(),
            "https://service.us-east-1.amazonaws.com/original/path"
        );
    }

    #[test]
    fn endpoint_with_prefix() {
        let proto = StubProtocol;
        let mut req = request_with_uri("/path");
        let endpoint = Endpoint::builder()
            .url("https://service.us-east-1.amazonaws.com")
            .build();
        let mut cfg = ConfigBag::base();
        let mut layer = Layer::new("test");
        layer.store_put(
            aws_smithy_runtime_api::client::endpoint::EndpointPrefix::new("myprefix.").unwrap(),
        );
        cfg.push_shared_layer(layer.freeze());

        proto.update_endpoint(&mut req, &endpoint, &cfg).unwrap();
        assert_eq!(
            req.uri(),
            "https://myprefix.service.us-east-1.amazonaws.com/path"
        );
    }

    #[test]
    fn endpoint_with_headers() {
        let proto = StubProtocol;
        let mut req = request_with_uri("/path");
        let endpoint = Endpoint::builder()
            .url("https://example.com")
            .header("x-custom", "value1")
            .header("x-custom", "value2")
            .build();
        let cfg = ConfigBag::base();

        proto.update_endpoint(&mut req, &endpoint, &cfg).unwrap();
        assert_eq!(req.uri(), "https://example.com/path");
        let values: Vec<&str> = req.headers().get_all("x-custom").collect();
        assert_eq!(values, vec!["value1", "value2"]);
    }

    #[test]
    fn endpoint_headers_replace_existing_values() {
        let proto = StubProtocol;
        let mut req = request_with_uri("/path");
        req.headers_mut().insert("x-custom", "stale");
        req.headers_mut().insert("x-untouched", "kept");
        let endpoint = Endpoint::builder()
            .url("https://example.com")
            .header("x-custom", "fresh")
            .build();
        let cfg = ConfigBag::base();

        proto.update_endpoint(&mut req, &endpoint, &cfg).unwrap();
        // The endpoint's value replaces the pre-existing one instead of being appended.
        let values: Vec<&str> = req.headers().get_all("x-custom").collect();
        assert_eq!(values, vec!["fresh"]);
        // Headers the endpoint does not mention are left alone.
        assert_eq!(req.headers().get("x-untouched"), Some("kept"));
    }

    #[test]
    fn endpoint_with_path() {
        let proto = StubProtocol;
        let mut req = request_with_uri("/operation");
        let endpoint = Endpoint::builder().url("https://example.com/base").build();
        let cfg = ConfigBag::base();

        proto.update_endpoint(&mut req, &endpoint, &cfg).unwrap();
        assert_eq!(req.uri(), "https://example.com/base/operation");
    }

    #[test]
    fn shared_protocol_round_trips_through_config_bag() {
        let mut layer = Layer::new("test");
        layer.store_put(SharedClientProtocol::new(StubProtocol));
        let mut cfg = ConfigBag::base();
        cfg.push_shared_layer(layer.freeze());

        let loaded = cfg
            .load::<SharedClientProtocol>()
            .expect("protocol was stored in the config bag");
        // `Deref` exposes the wrapped protocol, which hands back the very same static ID.
        assert!(std::ptr::eq(loaded.protocol_id(), &STUB_ID));
    }
}