diff --git a/math.go b/math.go index a240e562..d29fb6a9 100644 --- a/math.go +++ b/math.go @@ -1,8 +1,10 @@ package lo import ( + "fmt" "github.com/samber/lo/internal/constraints" "math" + "strconv" ) // Range creates an array of numbers (positive and/or negative) with given length. @@ -110,30 +112,26 @@ func MeanBy[T any, R constraints.Float | constraints.Integer](collection []T, it func Round[T float64 | float32](f T, n ...int) T { var nn = 3 if len(n) > 0 { - nnn := n[0] - if nnn >= 0 { - nn = nnn - if nn > 15 { - nn = 15 - } + nn = n[0] + if nn < 0 || nn > 15 { + panic("Round() precision must be between 0 and 15") } } - pow10N := math.Pow10(nn) - return T(math.Round(float64(f)*pow10N) / pow10N) + r, _ := strconv.ParseFloat(fmt.Sprintf("%.*f", nn, f), 64) + return T(r) } // Truncate returns the float32/float64 of the specified precision func Truncate[T float64 | float32](f T, n ...int) T { var nn = 3 if len(n) > 0 { - nnn := n[0] - if nnn >= 0 { - nn = nnn - if nn > 15 { - nn = 15 - } + nn = n[0] + if nn < 0 || nn > 15 { + panic("Truncate() precision must be between 0 and 15") } } pow10N := math.Pow10(nn) - return T(math.Floor(float64(f)*pow10N) / pow10N) + integer, fractional := math.Modf(float64(f)) + r, _ := strconv.ParseFloat(fmt.Sprintf("%.*f", nn, integer+math.Trunc(fractional*pow10N)/pow10N), 64) + return T(r) } diff --git a/math_example_test.go b/math_example_test.go index 486ae373..5faead48 100644 --- a/math_example_test.go +++ b/math_example_test.go @@ -90,16 +90,23 @@ func ExampleRound() { result2 := Round(1.23456, 2) result3 := Round(1.23456, 3) result4 := Round(1.23456, 7) + result5 := Round(1.234999999999999, 15) + result6 := Round(1.234999999999999, 7) fmt.Printf("%v\n", result1) fmt.Printf("%v\n", result2) fmt.Printf("%v\n", result3) fmt.Printf("%v\n", result4) + fmt.Printf("%v\n", result5) + fmt.Printf("%v\n", result6) + // Output: // 1.235 // 1.23 // 1.235 // 1.23456 + // 1.235 + // 1.235 } func ExampleTruncate() { @@ -107,14 +114,20 @@ func ExampleTruncate() { result2 := Truncate(1.23456, 2) result3 := Truncate(1.23456, 4) result4 := Truncate(1.23456, 7) + result5 := Truncate(1.2349999999999999, 15) + result6 := Truncate(1.2349999999999999, 7) fmt.Printf("%v\n", result1) fmt.Printf("%v\n", result2) fmt.Printf("%v\n", result3) fmt.Printf("%v\n", result4) + fmt.Printf("%v\n", result5) + fmt.Printf("%v\n", result6) // Output: // 1.234 // 1.23 // 1.2345 // 1.23456 + // 1.234999999999999 + // 1.2349999 } diff --git a/math_test.go b/math_test.go index ea9fc8a7..7dba4fba 100644 --- a/math_test.go +++ b/math_test.go @@ -137,8 +137,8 @@ func TestRound(t *testing.T) { result3 := Round(1.23456, 2) result4 := Round(1.23456, 3) result5 := Round(1.23456, 7) - result6 := Round(1.23456, 16) - result7 := Round(1.23456, -1) + result6 := Round(1.23456, 15) + result7 := Round(1.23456789, 7) result8 := Round(1.23456, 0) result9 := Round(1.00000000001, 5) @@ -148,7 +148,7 @@ func TestRound(t *testing.T) { is.Equal(result4, 1.235) is.Equal(result5, 1.23456) is.Equal(result6, 1.23456) - is.Equal(result7, 1.235) + is.Equal(result7, 1.2345679) is.Equal(result8, 1.0) is.Equal(result9, 1.0) } @@ -162,8 +162,8 @@ func TestTruncate(t *testing.T) { result3 := Truncate(1.23456, 2) result4 := Truncate(1.23456, 3) result5 := Truncate(1.23456, 7) - result6 := Truncate(1.23456, 16) - result7 := Truncate(1.23456, -1) + result6 := Truncate(1.23456, 15) + result7 := Truncate(1.23456789, 7) result8 := Truncate(1.23456, 0) result9 := Truncate(1.00000000001, 5) @@ -173,7 +173,7 @@ func TestTruncate(t *testing.T) { is.Equal(result4, 1.234) is.Equal(result5, 1.23456) is.Equal(result6, 1.23456) - is.Equal(result7, 1.234) + is.Equal(result7, 1.2345678) is.Equal(result8, 1.0) is.Equal(result9, 1.0) }