1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */

use aws_smithy_runtime_api::client::endpoint::{
    error::ResolveEndpointError, EndpointFuture, EndpointResolverParams, ResolveEndpoint,
};
use aws_smithy_runtime_api::client::interceptors::context::InterceptorContext;
use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents;
use aws_smithy_runtime_api::{box_error::BoxError, client::endpoint::EndpointPrefix};
use aws_smithy_types::config_bag::ConfigBag;
use aws_smithy_types::endpoint::Endpoint;
use http_02x::header::HeaderName;
use http_02x::uri::PathAndQuery;
use http_02x::{HeaderValue, Uri};
use std::borrow::Cow;
use std::fmt::Debug;
use std::str::FromStr;
use tracing::trace;

/// Endpoint resolver that always yields the same, pre-configured URI.
///
/// The URI is stored verbatim as a string; neither constructor validates it.
#[derive(Clone, Debug)]
pub struct StaticUriEndpointResolver {
    /// The URI handed back for every resolution request.
    endpoint: String,
}

impl StaticUriEndpointResolver {
    /// Builds a resolver pointing at `http://localhost:{port}`.
    pub fn http_localhost(port: u16) -> Self {
        // Same storage path as `uri`; only the URI text is synthesized here.
        Self::uri(format!("http://localhost:{port}"))
    }

    /// Builds a resolver pointing at the supplied URI.
    ///
    /// Accepts anything convertible into a `String`.
    pub fn uri(endpoint: impl Into<String>) -> Self {
        let endpoint = endpoint.into();
        Self { endpoint }
    }
}

impl ResolveEndpoint for StaticUriEndpointResolver {
    /// Returns an already-completed future carrying an [`Endpoint`] whose URL
    /// is the stored URI. The resolver params are ignored.
    fn resolve_endpoint<'a>(&'a self, _params: &'a EndpointResolverParams) -> EndpointFuture<'a> {
        let endpoint = Endpoint::builder().url(self.endpoint.clone()).build();
        EndpointFuture::ready(Ok(endpoint))
    }
}

/// Field-less parameter type to pair with [`StaticUriEndpointResolver`].
#[derive(Debug, Default)]
pub struct StaticUriEndpointResolverParams;

impl StaticUriEndpointResolverParams {
    /// Constructs a `StaticUriEndpointResolverParams`.
    ///
    /// Equivalent to the derived [`Default`] value.
    pub fn new() -> Self {
        Self::default()
    }
}

impl From<StaticUriEndpointResolverParams> for EndpointResolverParams {
    fn from(params: StaticUriEndpointResolverParams) -> Self {
        EndpointResolverParams::new(params)
    }
}

/// Resolves the endpoint for the request held in `ctx` and applies it.
///
/// Reads the [`EndpointResolverParams`] (required) and [`EndpointPrefix`]
/// (optional) from `cfg`, asks the endpoint resolver registered in
/// `runtime_components` for an endpoint, rewrites the request through
/// `apply_endpoint`, and finally stores the resolved endpoint in the
/// interceptor state of `cfg` so interceptors can read it.
///
/// # Errors
///
/// Returns the resolver's error if resolution fails, or the error from
/// `apply_endpoint` if the endpoint cannot be applied to the request.
///
/// # Panics
///
/// Panics if `cfg` holds no [`EndpointResolverParams`], or if `ctx` has no
/// request.
pub(super) async fn orchestrate_endpoint(
    ctx: &mut InterceptorContext,
    runtime_components: &RuntimeComponents,
    cfg: &mut ConfigBag,
) -> Result<(), BoxError> {
    trace!("orchestrating endpoint resolution");

    let resolver_params = cfg
        .load::<EndpointResolverParams>()
        .expect("endpoint resolver params must be set");
    let prefix = cfg.load::<EndpointPrefix>();
    tracing::debug!(endpoint_params = ?resolver_params, endpoint_prefix = ?prefix, "resolving endpoint");

    // The request is fetched before resolution starts, so a missing request
    // is reported before the resolver is ever invoked.
    let request = ctx.request_mut().expect("set during serialization");

    let resolver = runtime_components.endpoint_resolver();
    let resolved = resolver.resolve_endpoint(resolver_params).await?;
    tracing::debug!("will use endpoint {:?}", resolved);
    apply_endpoint(request, &resolved, prefix)?;

    // Publish the resolved endpoint so interceptors can inspect it.
    cfg.interceptor_state().store_put(resolved);
    Ok(())
}

/// Rewrites `request` so that it targets `endpoint`.
///
/// When `endpoint_prefix` is given, the endpoint URL is parsed and rebuilt
/// with the prefix inserted directly in front of the authority; otherwise the
/// endpoint URL is used unchanged. The resulting URL is set as the endpoint of
/// the request URI. Afterwards every header carried by the endpoint replaces
/// any header of the same name already on the request.
///
/// # Errors
///
/// Fails if the endpoint URL cannot be parsed as a [`Uri`] (it is only parsed
/// when a prefix is present), if the URL cannot be applied to the request URI,
/// or if an endpoint header name or value is invalid.
fn apply_endpoint(
    request: &mut HttpRequest,
    endpoint: &Endpoint,
    endpoint_prefix: Option<&EndpointPrefix>,
) -> Result<(), BoxError> {
    let endpoint_url: Cow<'_, str> = if let Some(prefix) = endpoint_prefix {
        let uri: Uri = endpoint.url().parse()?;
        // Any component missing from the parsed URI contributes an empty string.
        let scheme = uri.scheme_str().unwrap_or_default();
        let authority = uri
            .authority()
            .map(|authority| authority.as_str())
            .unwrap_or_default();
        let path_and_query = uri
            .path_and_query()
            .map(PathAndQuery::as_str)
            .unwrap_or_default();
        let prefix = prefix.as_str();
        Cow::Owned(format!("{scheme}://{prefix}{authority}{path_and_query}"))
    } else {
        // No prefix: borrow the endpoint URL as-is, no allocation needed.
        Cow::Borrowed(endpoint.url())
    };

    let applied = request.uri_mut().set_endpoint(&endpoint_url);
    if let Err(err) = applied {
        let failure = ResolveEndpointError::message(format!(
            "failed to apply endpoint `{}` to request `{:?}`",
            endpoint_url, request,
        ))
        .with_source(Some(err.into()));
        return Err(failure.into());
    }

    for (name, values) in endpoint.headers() {
        // Endpoint headers override, rather than extend, existing ones.
        request.headers_mut().remove(name);
        for value in values {
            let parsed_name = HeaderName::from_str(name).map_err(|err| {
                ResolveEndpointError::message("invalid header name").with_source(Some(err.into()))
            })?;
            let parsed_value = HeaderValue::from_str(value).map_err(|err| {
                ResolveEndpointError::message("invalid header value").with_source(Some(err.into()))
            })?;
            request.headers_mut().append(parsed_name, parsed_value);
        }
    }
    Ok(())
}

#[cfg(test)]
mod test {
    use aws_smithy_runtime_api::client::endpoint::EndpointPrefix;
    use aws_smithy_runtime_api::client::orchestrator::HttpRequest;
    use aws_smithy_types::endpoint::Endpoint;

    /// An endpoint prefix ends up in front of the endpoint host, and the
    /// request's own path and query are kept.
    #[test]
    fn test_apply_endpoint() {
        // Arrange
        let endpoint = Endpoint::builder().url("https://s3.amazon.com").build();
        let prefix = EndpointPrefix::new("prefix.subdomain.").unwrap();
        let mut request = HttpRequest::empty();
        request.set_uri("/foo?bar=1").unwrap();

        // Act
        super::apply_endpoint(&mut request, &endpoint, Some(&prefix)).expect("should succeed");

        // Assert
        let expected = "https://prefix.subdomain.s3.amazon.com/foo?bar=1";
        assert_eq!(request.uri(), expected);
    }
}