Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 164 additions & 72 deletions src/client/body/multipart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,61 @@ use crate::{client::body::PyStream, error::Error, header::HeaderMap};

/// A multipart form for a request.
#[pyclass(subclass)]
pub struct Multipart(pub Option<multipart::Form>);
pub struct Multipart {
/// The built form that is handed to the request builder. It is `None` on an
/// instance created from Python and is only populated on the value produced
/// by extraction.
pub form: Option<multipart::Form>,
/// The parts from which the form is built. Left empty on the extracted value,
/// which carries the finished `form` instead.
pub parts: Vec<Part>,
}

/// The data for a part value of a multipart form.
///
/// Text, bytes and file-path values can be duplicated, so a part holding one
/// can be built more than once. A stream value cannot be duplicated and is
/// moved out of its part when used.
#[derive(FromPyObject)]
pub enum Value {
/// String payload held as a `PyBackedStr`.
Text(PyBackedStr),
/// Binary payload held as a `PyBackedBytes`.
Bytes(PyBackedBytes),
/// Path of a file; the file part is created from this path when the part is built.
File(PathBuf),
/// Streaming payload; wrapped into a streaming body when the part is built.
Stream(PyStream),
}

/// A part of a multipart form.
///
/// Only the description of the part is stored here; the underlying
/// `multipart::Part` is constructed from these fields when a form is built.
#[pyclass(subclass)]
pub struct Part {
/// Field name under which the part is added to the form.
pub name: String,
/// Payload of the part. It is an `Option` so that a value which cannot be
/// cloned can be moved out, leaving `None` behind.
pub value: Option<Value>,
/// Optional file name applied to the built part.
pub filename: Option<String>,
/// Optional MIME type string; parsed when the part is built.
pub mime: Option<String>,
/// Optional body length; consulted only for stream values.
pub length: Option<u64>,
/// Optional extra headers applied to the built part.
pub headers: Option<HeaderMap>,
}

// ===== impl Multipart =====

#[pymethods]
impl Multipart {
/// Creates a new multipart form.
/// Creates a new multipart.
#[new]
#[pyo3(signature = (*parts))]
pub fn new(parts: &Bound<PyTuple>) -> PyResult<Multipart> {
let mut form = multipart::Form::new();
pub fn new(py: Python, parts: &Bound<PyTuple>) -> PyResult<Multipart> {
let mut new_parts = Vec::with_capacity(parts.len());
for part in parts {
let part = part.cast::<Part>()?;
let mut part = part.borrow_mut();
form = part
.name
.take()
.zip(part.inner.take())
.map(|(name, inner)| form.part(name, inner))
.ok_or_else(|| Error::Memory)?;
new_parts.push(part.try_clone(py)?);
}

Ok(Self {
form: None,
parts: new_parts,
})
}
}

impl Multipart {
fn build_form(&mut self, py: Python) -> PyResult<multipart::Form> {
let mut form = multipart::Form::new();
for part in &mut self.parts {
let (name, inner) = part.build_form_part(py)?;
form = form.part(name, inner);
}
Ok(Multipart(Some(form)))
Ok(form)
}
}

Expand All @@ -40,31 +75,120 @@ impl FromPyObject<'_, '_> for Multipart {

fn extract(ob: Borrowed<PyAny>) -> PyResult<Self> {
let multipart = ob.cast::<Multipart>()?;
multipart
.borrow_mut()
.0
.take()
.map(Some)
.map(Self)
.ok_or_else(|| Error::Memory)
.map_err(Into::into)
let mut multipart = multipart.borrow_mut();
let form = multipart.build_form(ob.py())?;

Ok(Multipart {
form: Some(form),
parts: Vec::new(),
})
}
}

/// A part of a multipart form.
#[pyclass(subclass)]
pub struct Part {
pub name: Option<String>,
pub inner: Option<multipart::Part>,
// ===== impl Value =====

impl Value {
    /// Returns a duplicate of this value when it can be duplicated.
    ///
    /// Text and bytes are duplicated with `clone_ref`, a file path with a plain
    /// `clone`. A stream has no duplicate, so `None` is returned for it.
    fn try_clone(&self, py: Python) -> Option<Self> {
        let duplicate = match self {
            // A stream can only be consumed once; there is nothing to copy.
            Value::Stream(_) => return None,
            Value::Text(text) => Value::Text(text.clone_ref(py)),
            Value::Bytes(bytes) => Value::Bytes(bytes.clone_ref(py)),
            Value::File(path) => Value::File(path.clone()),
        };
        Some(duplicate)
    }
}

/// The data for a part value of a multipart form.
#[derive(FromPyObject)]
pub enum Value {
Text(PyBackedStr),
Bytes(PyBackedBytes),
File(PathBuf),
Stream(PyStream),
// ===== impl Part =====

impl Part {
/// Creates a new [`Part`] carrying `value`, with every other field copied
/// from `self`.
fn with_value(&self, value: Value) -> Part {
    let Part {
        name,
        filename,
        mime,
        length,
        headers,
        ..
    } = self;

    Part {
        name: name.clone(),
        value: Some(value),
        filename: filename.clone(),
        mime: mime.clone(),
        length: *length,
        headers: headers.clone(),
    }
}

fn build_inner(value: Value, length: Option<u64>) -> Result<multipart::Part, Error> {
Ok(match value {
Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))),
Value::Bytes(bytes) => multipart::Part::stream(Body::from(Bytes::from_owner(bytes))),
Value::File(path) => pyo3_async_runtimes::tokio::get_runtime()
.block_on(multipart::Part::file(path))
.map_err(Error::from)?,
Comment on lines +128 to +130
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using `block_on` inside `build_inner` is risky and can lead to panics if this code is executed on a thread that is already part of a Tokio runtime (e.g., the event loop thread). Since `Multipart::extract` is called during argument parsing for async functions like `execute_request`, it is highly likely to be called from such a context.

Consider making `build_inner`, `into_inner`, and `build_form` asynchronous. You can then defer the actual construction of the `multipart::Form` to the `execute_request` function in `src/client/req.rs`, where it can be properly awaited without blocking the thread. This would also involve changing `Multipart::extract` to only clone the parts and leave the `form` field as `None` initially.

Value::Stream(stream) => {
let stream = Body::wrap_stream(stream);
match length {
Some(length) => multipart::Part::stream_with_length(stream, length),
None => multipart::Part::stream(stream),
}
}
})
}

fn clone_value_or_take(&mut self, py: Python) -> PyResult<Value> {
self.value
.as_ref()
.and_then(|value| value.try_clone(py))
.or_else(|| self.value.take())
.ok_or_else(|| Error::Memory.into())
}

fn build_form_part(&mut self, py: Python) -> PyResult<(String, multipart::Part)> {
let value = self.clone_value_or_take(py)?;
let name = self.name.clone();
let filename = self.filename.clone();
let mime = self.mime.clone();
let length = self.length;
let headers = self.headers.clone();

py.detach(move || {
let mut inner = Self::build_inner(value, length)?;

if let Some(filename) = filename {
inner = inner.file_name(filename);
}

if let Some(mime) = mime {
inner = inner.mime_str(&mime).map_err(Error::Library)?;
}

if let Some(headers) = headers {
inner = inner.headers(headers.0);
}

Ok((name, inner))
})
}

fn try_clone(&mut self, py: Python) -> PyResult<Part> {
if let Some(part) = self
.value
.as_ref()
.and_then(|value| value.try_clone(py))
.map(|value| self.with_value(value))
{
return Ok(part);
}

self.value
.take()
.map(|value| self.with_value(value))
.ok_or_else(|| Error::Memory)
.map_err(Into::into)
}
}

#[pymethods]
Expand All @@ -80,52 +204,20 @@ impl Part {
headers = None
))]
pub fn new(
py: Python,
name: String,
value: Value,
filename: Option<String>,
mime: Option<&str>,
length: Option<u64>,
headers: Option<HeaderMap>,
) -> PyResult<Part> {
py.detach(|| {
// Create the inner part
let mut inner = match value {
Value::Text(text) => multipart::Part::stream(Body::from(Bytes::from_owner(text))),
Value::Bytes(bytes) => {
multipart::Part::stream(Body::from(Bytes::from_owner(bytes)))
}
Value::File(path) => pyo3_async_runtimes::tokio::get_runtime()
.block_on(multipart::Part::file(path))
.map_err(Error::from)?,
Value::Stream(stream) => {
let stream = Body::wrap_stream(stream);
match length {
Some(length) => multipart::Part::stream_with_length(stream, length),
None => multipart::Part::stream(stream),
}
}
};

// Set the filename and MIME type if provided
if let Some(filename) = filename {
inner = inner.file_name(filename);
}

// Set the MIME type if provided
if let Some(mime) = mime {
inner = inner.mime_str(mime).map_err(Error::Library)?;
}

// Set the headers if provided
if let Some(headers) = headers {
inner = inner.headers(headers.0);
}

Ok(Part {
name: Some(name),
inner: Some(inner),
})
})
) -> Part {
Part {
name,
value: Some(value),
filename,
mime: mime.map(ToOwned::to_owned),
length,
headers,
}
}
}
2 changes: 1 addition & 1 deletion src/client/req.rs
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ where
apply_option!(
set_if_some,
builder,
request.multipart.and_then(|form| form.0),
request.multipart.and_then(|form| form.form),
multipart
);
apply_option!(
Expand Down
90 changes: 90 additions & 0 deletions tests/multipart_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from pathlib import Path

import pytest
import wreq
from wreq import Multipart, Part

client = wreq.Client(tls_info=True)


def assert_form_value(data, key, expected):
    """Assert that form field ``key`` in the echoed response holds ``expected``.

    ``data`` is the decoded JSON body; its ``"form"`` entry maps field names to
    either a single value or a list of values. For a list, ``expected`` only
    has to be one of the entries.
    """
    value = data["form"][key]
    if isinstance(value, list):
        assert expected in value, f"{expected!r} not found in form[{key!r}]: {value!r}"
    else:
        assert value == expected, f"form[{key!r}] is {value!r}, expected {expected!r}"


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_reuse_multipart_with_clonable_parts():
    """A form made only of text, bytes and file parts can be posted repeatedly."""
    form = Multipart(
        Part(name="a", value="1"),
        Part(name="b", value=b"2"),
        Part(name="c", value=Path("./README.md"), filename="README.md", mime="text/plain"),
    )

    # Post the very same form object three times; every request must carry all parts.
    for _ in range(3):
        resp = await client.post("https://httpbin.io/post", multipart=form)
        async with resp:
            assert resp.status.is_success()
            data = await resp.json()
            assert_form_value(data, "a", "1")
            assert_form_value(data, "b", "2")
            assert "c" in data["files"]


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_stream_part_is_one_shot_when_reusing_multipart():
    """A form holding a stream part works once; posting it again raises RuntimeError."""

    def file_stream(path):
        # Yield the file in 1 KiB chunks so the part is sent as a stream.
        with open(path, "rb") as f:
            while chunk := f.read(1024):
                yield chunk

    form = Multipart(
        Part(
            name="stream",
            value=file_stream("./README.md"),
            filename="README.md",
            mime="text/plain",
        ),
    )

    # First use consumes the stream and succeeds.
    resp = await client.post("https://httpbin.io/post", multipart=form)
    async with resp:
        assert resp.status.is_success()

    # Second use has no stream left to send.
    with pytest.raises(RuntimeError):
        resp = await client.post("https://httpbin.io/post", multipart=form)
        async with resp:
            pass


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_reuse_same_part_without_copy_for_clonable_value():
    """One text part can be placed in two separate forms, and both can be posted."""
    part = Part(name="a", value="1")

    form1 = Multipart(part)
    form2 = Multipart(part)

    for form in (form1, form2):
        resp = await client.post("https://httpbin.io/post", multipart=form)
        async with resp:
            assert resp.status.is_success()
            data = await resp.json()
            assert_form_value(data, "a", "1")


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3, reruns_delay=2)
async def test_reuse_same_part_without_copy_fails_for_stream_value():
    """A stream part can be placed in only one form; a second form raises RuntimeError."""

    def bytes_stream():
        yield b"hello"

    part = Part(name="stream", value=bytes_stream())
    # The first form takes the stream out of the part.
    Multipart(part)

    with pytest.raises(RuntimeError):
        Multipart(part)
Loading