Skip to content

Commit d807ad0

Browse files
committed
Implement kwargs in function calls
1 parent e8afd0a commit d807ad0

File tree

4 files changed

+118
-33
lines changed

4 files changed

+118
-33
lines changed

vm/src/builtins.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,7 @@ pub fn builtin_build_class_(vm: &mut VirtualMachine, mut args: PyFuncArgs) -> Py
419419
function,
420420
PyFuncArgs {
421421
args: vec![namespace.clone()],
422+
kwargs: None,
422423
},
423424
);
424425
objtype::new(metaclass, name, bases, namespace)

vm/src/obj/objtype.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ pub fn get_attribute(vm: &mut VirtualMachine, obj: PyObjectRef, name: &String) -
131131
descriptor,
132132
PyFuncArgs {
133133
args: vec![attr, obj, cls],
134+
kwargs: None,
134135
},
135136
);
136137
}

vm/src/pyobject.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,21 +512,30 @@ impl fmt::Debug for PyObject {
512512
#[derive(Debug, Default, Clone)]
513513
pub struct PyFuncArgs {
514514
pub args: Vec<PyObjectRef>,
515-
// TODO: add kwargs here
515+
pub kwargs: Option<Vec<(String, PyObjectRef)>>,
516516
}
517517

518518
impl PyFuncArgs {
519-
pub fn new() -> PyFuncArgs {
520-
PyFuncArgs { args: vec![] }
519+
pub fn new(mut args: Vec<PyObjectRef>, kwarg_names: Vec<String>) -> PyFuncArgs {
520+
let mut kwargs = vec![];
521+
for name in kwarg_names.iter().rev() {
522+
kwargs.push((name.clone(), args.pop().unwrap()));
523+
}
524+
PyFuncArgs {
525+
args: args,
526+
kwargs: Some(kwargs),
527+
}
521528
}
522529

523530
pub fn insert(&self, item: PyObjectRef) -> PyFuncArgs {
524531
let mut args = PyFuncArgs {
525532
args: self.args.clone(),
533+
kwargs: self.kwargs.clone(),
526534
};
527535
args.args.insert(0, item);
528536
return args;
529537
}
538+
530539
pub fn shift(&mut self) -> PyObjectRef {
531540
self.args.remove(0)
532541
}

vm/src/vm.rs

Lines changed: 104 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use super::objbool;
2222
use super::objobject;
2323
use super::pyobject::{
2424
AttributeProtocol, DictProtocol, IdProtocol, ParentProtocol, PyContext, PyFuncArgs, PyObject,
25-
PyObjectKind, PyObjectRef, PyResult,
25+
PyObjectKind, PyObjectRef, PyResult, ToRust,
2626
};
2727
use super::stdlib;
2828
use super::sysmodule;
@@ -61,7 +61,10 @@ impl VirtualMachine {
6161
// TODO: maybe there is a clearer way to create an instance:
6262
info!("New exception created: {}", msg);
6363
let args: Vec<PyObjectRef> = Vec::new();
64-
let args = PyFuncArgs { args: args };
64+
let args = PyFuncArgs {
65+
args: args,
66+
kwargs: None,
67+
};
6568

6669
// Call function:
6770
let exception = self.invoke(exc_type, args).unwrap();
@@ -267,6 +270,7 @@ impl VirtualMachine {
267270
self,
268271
PyFuncArgs {
269272
args: vec![traceback, pos],
273+
kwargs: None,
270274
},
271275
).unwrap();
272276
// exception.__trace
@@ -348,7 +352,10 @@ impl VirtualMachine {
348352
Ok(v) => v,
349353
Err(err) => return Err(err),
350354
};
351-
let args = PyFuncArgs { args: args };
355+
let args = PyFuncArgs {
356+
args: args,
357+
kwargs: None,
358+
};
352359
self.invoke(func, args)
353360
}
354361

@@ -496,7 +503,7 @@ impl VirtualMachine {
496503
}
497504
}
498505

499-
pub fn invoke(&mut self, func_ref: PyObjectRef, args: PyFuncArgs) -> PyResult {
506+
pub fn invoke(&mut self, func_ref: PyObjectRef, mut args: PyFuncArgs) -> PyResult {
500507
trace!("Invoke: {:?} {:?}", func_ref, args);
501508
match func_ref.borrow().kind {
502509
PyObjectKind::RustFunction { function } => function(self, args),
@@ -505,44 +512,84 @@ impl VirtualMachine {
505512
ref scope,
506513
ref defaults,
507514
} => {
508-
let mut scope = self.ctx.new_scope(Some(scope.clone()));
509515
let code_object = copy_code(code.clone());
510516
let nargs = args.args.len();
517+
518+
// Check the number of positional arguments
511519
let nexpected_args = code_object.arg_names.len();
512-
let args = if nargs > nexpected_args {
520+
if nargs > nexpected_args {
513521
return Err(self.new_type_error(format!(
514522
"Expected {} arguments (got: {})",
515523
nexpected_args, nargs
516524
)));
517-
} else if nargs < nexpected_args {
518-
// Use defaults if available
519-
let nrequired_defaults = nexpected_args - nargs;
525+
}
526+
527+
let mut scope = self.ctx.new_scope(Some(scope.clone()));
528+
529+
// Copy positional arguments into local variables
530+
for (n, arg) in args.args.iter().enumerate() {
531+
scope.set_item(&code_object.arg_names[n], arg.clone())
532+
}
533+
534+
// TODO: Pack other positional arguments in to *args
535+
536+
// Handle keyword arguments
537+
if let Some(ref mut kwargs) = args.kwargs {
538+
for (name, value) in kwargs {
539+
if !code_object.arg_names.contains(&name) {
540+
return Err(self.new_type_error(format!(
541+
"Got an unexpected keyword argument '{}'",
542+
name
543+
)));
544+
}
545+
if scope.contains_key(&name) {
546+
return Err(self.new_type_error(format!(
547+
"Got multiple values for argument '{}'",
548+
name
549+
)));
550+
}
551+
scope.set_item(&name, value.clone());
552+
}
553+
}
554+
555+
// Add missing positional arguments, if we have fewer positional arguments than the
556+
// function definition calls for
557+
if nargs < nexpected_args {
520558
let available_defaults = match defaults.borrow().kind {
521559
PyObjectKind::Tuple { ref elements } => elements.clone(),
522560
PyObjectKind::None => vec![],
523561
_ => panic!("function defaults not tuple or None"),
524562
};
525-
if nrequired_defaults > available_defaults.len() {
563+
564+
// Given the number of defaults available, check all the arguments for which we
565+
// _don't_ have defaults; if any are missing, raise an exception
566+
let required_args = nexpected_args - available_defaults.len();
567+
let mut missing = vec![];
568+
for i in 0..required_args {
569+
let variable_name = &code_object.arg_names[i];
570+
if !scope.contains_key(variable_name) {
571+
missing.push(variable_name)
572+
}
573+
}
574+
if !missing.is_empty() {
526575
return Err(self.new_type_error(format!(
527-
"Expected {}-{} arguments (got: {})",
528-
nexpected_args - available_defaults.len(),
529-
nexpected_args,
530-
nargs
576+
"Missing {} required positional arguments: {:?}",
577+
missing.len(),
578+
missing
531579
)));
532580
}
533-
let default_args = available_defaults
534-
.clone()
535-
.split_off(available_defaults.len() - nrequired_defaults);
536-
let mut new_args = args.args.clone();
537-
new_args.extend(default_args);
538-
new_args
539-
} else {
540-
// nargs == nexpected_args
541-
args.args
581+
582+
// We have sufficient defaults, so iterate over the corresponding names and use
583+
// the default if we don't already have a value
584+
let mut default_index = 0;
585+
for i in required_args..nexpected_args {
586+
let arg_name = &code_object.arg_names[i];
587+
if !scope.contains_key(arg_name) {
588+
scope.set_item(arg_name, available_defaults[default_index].clone());
589+
}
590+
default_index += 1;
591+
}
542592
};
543-
for (name, value) in code_object.arg_names.iter().zip(args) {
544-
scope.set_item(name, value);
545-
}
546593
let frame = Frame::new(code.clone(), scope);
547594
self.run_frame(frame)
548595
}
@@ -821,8 +868,10 @@ impl VirtualMachine {
821868
}
822869
bytecode::Instruction::CallFunction { count } => {
823870
let args: Vec<PyObjectRef> = self.pop_multiple(*count);
824-
// TODO: kwargs
825-
let args = PyFuncArgs { args: args };
871+
let args = PyFuncArgs {
872+
args: args,
873+
kwargs: None,
874+
};
826875
let func_ref = self.pop_value();
827876

828877
// Call function:
@@ -839,8 +888,32 @@ impl VirtualMachine {
839888
}
840889
}
841890
}
842-
bytecode::Instruction::CallFunctionKw { count: _ } => {
843-
unimplemented!("keyword arg calls not yet implemented");
891+
bytecode::Instruction::CallFunctionKw { count } => {
892+
let kwarg_names = self.pop_value();
893+
let args: Vec<PyObjectRef> = self.pop_multiple(*count);
894+
895+
let kwarg_names = kwarg_names
896+
.to_vec()
897+
.unwrap()
898+
.iter()
899+
.map(|pyobj| pyobj.to_str().unwrap())
900+
.collect();
901+
let args = PyFuncArgs::new(args, kwarg_names);
902+
let func_ref = self.pop_value();
903+
904+
// Call function:
905+
let func_result = self.invoke(func_ref, args);
906+
907+
match func_result {
908+
Ok(value) => {
909+
self.push_value(value);
910+
None
911+
}
912+
Err(value) => {
913+
// Ripple exception upwards:
914+
Some(Err(value))
915+
}
916+
}
844917
}
845918
bytecode::Instruction::Jump { target } => {
846919
self.jump(target);
@@ -924,6 +997,7 @@ impl VirtualMachine {
924997
self,
925998
PyFuncArgs {
926999
args: vec![expr.clone()],
1000+
kwargs: None,
9271001
},
9281002
).unwrap();
9291003
}

0 commit comments

Comments
 (0)