aws_runtime/
endpoint_override.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6//! Endpoint override detection for business metrics tracking
7
8use aws_smithy_runtime_api::box_error::BoxError;
9use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextRef;
10use aws_smithy_runtime_api::client::interceptors::Intercept;
11use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
12use aws_smithy_runtime_api::client::runtime_plugin::RuntimePlugin;
13use aws_smithy_types::config_bag::{ConfigBag, FrozenLayer, Layer};
14
15use crate::sdk_feature::AwsSdkFeature;
16
/// Interceptor that flags custom endpoint URLs for business metrics.
///
/// At runtime it inspects the configured endpoint resolver. A `StaticUriEndpointResolver`
/// is taken as the signal that `.endpoint_url()` was called. In that case the
/// `AwsSdkFeature::EndpointOverride` feature flag is recorded for business metrics tracking.
#[derive(Debug, Default)]
#[non_exhaustive]
pub struct EndpointOverrideInterceptor;

impl EndpointOverrideInterceptor {
    /// Returns a new `EndpointOverrideInterceptor` (identical to `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }
}
32
33impl Intercept for EndpointOverrideInterceptor {
34    fn name(&self) -> &'static str {
35        "EndpointOverrideInterceptor"
36    }
37
38    fn read_after_serialization(
39        &self,
40        _context: &BeforeTransmitInterceptorContextRef<'_>,
41        runtime_components: &RuntimeComponents,
42        cfg: &mut ConfigBag,
43    ) -> Result<(), BoxError> {
44        // Check if the endpoint resolver is a StaticUriEndpointResolver
45        // This indicates that .endpoint_url() was called
46        let resolver = runtime_components.endpoint_resolver();
47
48        // Check the resolver's debug string to see if it's StaticUriEndpointResolver
49        let debug_str = format!("{:?}", resolver);
50
51        if debug_str.contains("StaticUriEndpointResolver") {
52            // Store in interceptor_state
53            cfg.interceptor_state()
54                .store_append(AwsSdkFeature::EndpointOverride);
55        }
56
57        Ok(())
58    }
59}
60
61/// Runtime plugin that detects when a custom endpoint URL has been configured
62/// and tracks it for business metrics.
63///
64/// This plugin is created by the codegen decorator when a user explicitly
65/// sets an endpoint URL via `.endpoint_url()`. It stores the
66/// `AwsSdkFeature::EndpointOverride` feature flag in the ConfigBag for
67/// business metrics tracking.
68#[derive(Debug, Default)]
69#[non_exhaustive]
70pub struct EndpointOverrideRuntimePlugin {
71    config: Option<FrozenLayer>,
72}
73
74impl EndpointOverrideRuntimePlugin {
75    /// Creates a new `EndpointOverrideRuntimePlugin` with the given config layer
76    pub fn new(config: Option<FrozenLayer>) -> Self {
77        Self { config }
78    }
79
80    /// Creates a new `EndpointOverrideRuntimePlugin` and marks that endpoint override is enabled
81    pub fn new_with_feature_flag() -> Self {
82        let mut layer = Layer::new("endpoint_override");
83        layer.store_append(AwsSdkFeature::EndpointOverride);
84        Self {
85            config: Some(layer.freeze()),
86        }
87    }
88}
89
90impl RuntimePlugin for EndpointOverrideRuntimePlugin {
91    fn config(&self) -> Option<FrozenLayer> {
92        self.config.clone()
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdk_feature::AwsSdkFeature;

    use aws_smithy_runtime::client::orchestrator::endpoints::StaticUriEndpointResolver;
    use aws_smithy_runtime_api::client::endpoint::SharedEndpointResolver;
    use aws_smithy_runtime_api::client::interceptors::context::{Input, InterceptorContext};
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
    use aws_smithy_types::config_bag::ConfigBag;

    /// Runs `EndpointOverrideInterceptor::read_after_serialization` once and returns the
    /// `AwsSdkFeature`s it left in the interceptor state.
    ///
    /// Runtime components come from `RuntimeComponentsBuilder::for_tests()`. When
    /// `endpoint_resolver` is `Some`, it replaces that builder's endpoint resolver.
    fn run_interceptor(endpoint_resolver: Option<SharedEndpointResolver>) -> Vec<AwsSdkFeature> {
        // Walk the context through serialization so it is in the before-transmit phase.
        let mut context = InterceptorContext::new(Input::doesnt_matter());
        context.enter_serialization_phase();
        context.set_request(HttpRequest::empty());
        let _ = context.take_input();
        context.enter_before_transmit_phase();

        let mut builder = RuntimeComponentsBuilder::for_tests();
        if let Some(resolver) = endpoint_resolver {
            builder = builder.with_endpoint_resolver(Some(resolver));
        }
        let rc = builder.build().unwrap();
        let mut cfg = ConfigBag::base();

        let interceptor = EndpointOverrideInterceptor::new();
        let ctx = Into::into(&context);
        interceptor
            .read_after_serialization(&ctx, &rc, &mut cfg)
            .unwrap();

        let features: Vec<_> = cfg
            .interceptor_state()
            .load::<AwsSdkFeature>()
            .cloned()
            .collect();
        features
    }

    #[test]
    fn test_plugin_with_no_config() {
        let plugin = EndpointOverrideRuntimePlugin::default();
        assert!(plugin.config().is_none());
    }

    #[test]
    fn test_plugin_with_feature_flag() {
        let plugin = EndpointOverrideRuntimePlugin::new_with_feature_flag();
        let config = plugin.config().expect("config should be set");

        // Verify the feature flag is present in the config
        let features: Vec<_> = config.load::<AwsSdkFeature>().cloned().collect();
        assert_eq!(features.len(), 1);
        assert_eq!(features[0], AwsSdkFeature::EndpointOverride);
    }

    #[test]
    fn test_interceptor_detects_static_uri_resolver() {
        let endpoint_resolver = SharedEndpointResolver::new(StaticUriEndpointResolver::uri(
            "https://custom.example.com",
        ));

        let features = run_interceptor(Some(endpoint_resolver));

        assert_eq!(features.len(), 1, "Expected 1 feature, got: {:?}", features);
        assert_eq!(features[0], AwsSdkFeature::EndpointOverride);
    }

    #[test]
    fn test_interceptor_ignores_non_static_resolver() {
        // NOTE(review): relies on `RuntimeComponentsBuilder::for_tests()` supplying its own
        // placeholder endpoint resolver, which is not a `StaticUriEndpointResolver`.
        let features = run_interceptor(None);

        assert!(
            features.is_empty(),
            "Expected no features, got: {:?}",
            features
        );
    }
}